1
0
Fork 0
forked from forgejo/forgejo

update mssql drive to last working version 20180314172330-6a30f4e59a44 (#7306)

This commit is contained in:
Antoine GIRARD 2019-06-30 05:28:17 +02:00 committed by Lunny Xiao
parent aeb8f7aad8
commit 1e46eedce7
46 changed files with 3158 additions and 491 deletions

View file

@ -1,6 +1,7 @@
package mssql
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
@ -9,11 +10,13 @@ import (
"io"
"io/ioutil"
"net"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf16"
"unicode/utf8"
)
@ -47,8 +50,11 @@ func parseInstances(msg []byte) map[string]map[string]string {
return results
}
func getInstances(address string) (map[string]map[string]string, error) {
conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
dialer := &net.Dialer{
Timeout: 5 * time.Second,
}
conn, err := dialer.DialContext(ctx, "udp", address+":1434")
if err != nil {
return nil, err
}
@ -79,11 +85,16 @@ const (
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
packSQLBatch = 1
packRPCRequest = 3
packReply = 4
packCancel = 6
packSQLBatch packetType = 1
packRPCRequest = 3
packReply = 4
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
packAttention = 6
packBulkLoadBCP = 7
packTransMgrReq = 14
packNormal = 15
@ -119,7 +130,7 @@ type tdsSession struct {
columns []columnStruct
tranid uint64
logFlags uint64
log *Logger
log optionalLogger
routedServer string
routedPort uint16
}
@ -131,6 +142,7 @@ const (
logSQL = 8
logParams = 16
logTransaction = 32
logDebug = 64
)
type columnStruct struct {
@ -490,6 +502,11 @@ func readBVarChar(r io.Reader) (res string, err error) {
if err != nil {
return "", err
}
// A zero length could be returned, return an empty string
if numchars == 0 {
return "", nil
}
return readUcs2(r, int(numchars))
}
@ -588,7 +605,7 @@ func (hdr transDescrHdr) pack() (res []byte) {
}
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
// calculatint total length
// Calculating total length.
var totallen uint32 = 4
for _, hdr := range headers {
totallen += 4 + 2 + uint32(len(hdr.data))
@ -616,9 +633,7 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
return nil
}
func sendSqlBatch72(buf *tdsBuffer,
sqltext string,
headers []headerStruct) (err error) {
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
buf.BeginPacket(packSQLBatch)
if err = writeAllHeaders(buf, headers); err != nil {
@ -632,6 +647,13 @@ func sendSqlBatch72(buf *tdsBuffer,
return buf.FinishPacket()
}
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) error {
buf.BeginPacket(packAttention)
return buf.FinishPacket()
}
type connectParams struct {
logFlags uint64
port uint64
@ -654,6 +676,7 @@ type connectParams struct {
typeFlags uint8
failOverPartner string
failOverPort uint64
packetSize uint16
}
func splitConnectionString(dsn string) (res map[string]string) {
@ -677,9 +700,241 @@ func splitConnectionString(dsn string) (res map[string]string) {
return res
}
// Splits a URL in the ODBC format
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
res := map[string]string{}
type parserState int
const (
// Before the start of a key
parserStateBeforeKey parserState = iota
// Inside a key
parserStateKey
// Beginning of a value. May be bare or braced
parserStateBeginValue
// Inside a bare value
parserStateBareValue
// Inside a braced value
parserStateBracedValue
// A closing brace inside a braced value.
// May be the end of the value or an escaped closing brace, depending on the next character
parserStateBracedValueClosingBrace
// After a value. Next character should be a semicolon or whitespace.
parserStateEndValue
)
var state = parserStateBeforeKey
var key string
var value string
for i, c := range dsn {
switch state {
case parserStateBeforeKey:
switch {
case c == '=':
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
case !unicode.IsSpace(c) && c != ';':
state = parserStateKey
key += string(c)
}
case parserStateKey:
switch c {
case '=':
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
}
state = parserStateBeginValue
case ';':
// Key without value
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
}
res[key] = value
key = ""
value = ""
state = parserStateBeforeKey
default:
key += string(c)
}
case parserStateBeginValue:
switch {
case c == '{':
state = parserStateBracedValue
case c == ';':
// Empty value
res[key] = value
key = ""
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
state = parserStateBareValue
value += string(c)
}
case parserStateBareValue:
if c == ';' {
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
key = ""
value = ""
state = parserStateBeforeKey
} else {
value += string(c)
}
case parserStateBracedValue:
if c == '}' {
state = parserStateBracedValueClosingBrace
} else {
value += string(c)
}
case parserStateBracedValueClosingBrace:
if c == '}' {
// Escaped closing brace
value += string(c)
state = parserStateBracedValue
continue
}
// End of braced value
res[key] = value
key = ""
value = ""
// This character is the first character past the end,
// so it needs to be parsed like the parserStateEndValue state.
state = parserStateEndValue
switch {
case c == ';':
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
}
case parserStateEndValue:
switch {
case c == ';':
state = parserStateBeforeKey
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
}
}
}
switch state {
case parserStateBeforeKey: // Okay
case parserStateKey: // Unfinished key. Treat as key without value.
key = normalizeOdbcKey(key)
if len(key) == 0 {
return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
}
res[key] = value
case parserStateBeginValue: // Empty value
res[key] = value
case parserStateBareValue:
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
case parserStateBracedValue:
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
case parserStateBracedValueClosingBrace: // End of braced value
res[key] = value
case parserStateEndValue: // Okay
}
return res, nil
}
// Normalizes the given string as an ODBC-format key
func normalizeOdbcKey(s string) string {
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
}
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
func splitConnectionStringURL(dsn string) (map[string]string, error) {
res := map[string]string{}
u, err := url.Parse(dsn)
if err != nil {
return res, err
}
if u.Scheme != "sqlserver" {
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
}
if u.User != nil {
res["user id"] = u.User.Username()
p, exists := u.User.Password()
if exists {
res["password"] = p
}
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
}
if len(u.Path) > 0 {
res["server"] = host + "\\" + u.Path[1:]
} else {
res["server"] = host
}
if len(port) > 0 {
res["port"] = port
}
query := u.Query()
for k, v := range query {
if len(v) > 1 {
return res, fmt.Errorf("key %s provided more than once", k)
}
res[strings.ToLower(k)] = v[0]
}
return res, nil
}
func parseConnectParams(dsn string) (connectParams, error) {
params := splitConnectionString(dsn)
var p connectParams
var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
}
params = parameters
} else if strings.HasPrefix(dsn, "sqlserver://") {
parameters, err := splitConnectionStringURL(dsn)
if err != nil {
return p, err
}
params = parameters
} else {
params = splitConnectionString(dsn)
}
strlog, ok := params["log"]
if ok {
var err error
@ -712,7 +967,32 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}
p.dial_timeout = 5 * time.Second
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
// Default packet size remains at 4096 bytes
p.packetSize = 4096
strpsize, ok := params["packet size"]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
if err != nil {
f := "Invalid packet size '%v': %v"
return p, fmt.Errorf(f, strpsize, err.Error())
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
p.packetSize = uint16(psize)
if p.packetSize < 512 {
p.packetSize = 512
} else if p.packetSize > 32767 {
p.packetSize = 32767
}
}
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.dial_timeout = 15 * time.Second
p.conn_timeout = 30 * time.Second
strconntimeout, ok := params["connection timeout"]
if ok {
@ -732,8 +1012,12 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
p.dial_timeout = time.Duration(timeout) * time.Second
}
keepAlive, ok := params["keepalive"]
if ok {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.keepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
@ -743,7 +1027,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
encrypt, ok := params["encrypt"]
if ok {
if strings.ToUpper(encrypt) == "DISABLE" {
if strings.EqualFold(encrypt, "DISABLE") {
p.disableEncryption = true
} else {
var err error
@ -819,7 +1103,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
return p, nil
}
type Auth interface {
type auth interface {
InitialBytes() ([]byte, error)
NextBytes([]byte) ([]byte, error)
Free()
@ -828,7 +1112,7 @@ type Auth interface {
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
func dialConnection(p connectParams) (conn net.Conn, err error) {
func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
@ -839,9 +1123,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
ips = []net.IP{ip}
}
if len(ips) == 1 {
d := createDialer(p)
d := createDialer(&p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
conn, err = d.Dial("tcp", addr)
conn, err = d.Dial(ctx, addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
@ -850,9 +1134,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
portStr := strconv.Itoa(int(p.port))
for _, ip := range ips {
go func(ip net.IP) {
d := createDialer(p)
d := createDialer(&p)
addr := net.JoinHostPort(ip.String(), portStr)
conn, err := d.Dial("tcp", addr)
conn, err := d.Dial(ctx, addr)
if err == nil {
connChan <- conn
} else {
@ -887,16 +1171,15 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
f := "Unable to open tcp connection with host '%v:%v': %v"
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
}
return conn, err
}
func connect(p connectParams) (res *tdsSession, err error) {
func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
res = nil
// if instance is specified use instance resolution service
if p.instance != "" {
p.instance = strings.ToUpper(p.instance)
instances, err := getInstances(p.host)
instances, err := getInstances(ctx, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
@ -914,16 +1197,17 @@ func connect(p connectParams) (res *tdsSession, err error) {
}
initiate_connection:
conn, err := dialConnection(p)
conn, err := dialConnection(ctx, p)
if err != nil {
return nil, err
}
toconn := NewTimeoutConn(conn, p.conn_timeout)
outbuf := newTdsBuffer(4096, toconn)
outbuf := newTdsBuffer(p.packetSize, toconn)
sess := tdsSession{
buf: outbuf,
log: log,
logFlags: p.logFlags,
}
@ -969,8 +1253,7 @@ initiate_connection:
if p.certificate != "" {
pem, err := ioutil.ReadFile(p.certificate)
if err != nil {
f := "Cannot read certificate '%s': %s"
return nil, fmt.Errorf(f, p.certificate, err.Error())
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
}
certs := x509.NewCertPool()
certs.AppendCertsFromPEM(pem)
@ -980,15 +1263,20 @@ initiate_connection:
config.InsecureSkipVerify = true
}
config.ServerName = p.hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
outbuf.transport = conn
toconn.buf = outbuf
tlsConn := tls.Client(toconn, &config)
err = tlsConn.Handshake()
toconn.buf = nil
outbuf.transport = tlsConn
if err != nil {
f := "TLS Handshake failed: %s"
return nil, fmt.Errorf(f, err.Error())
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
@ -999,7 +1287,7 @@ initiate_connection:
login := login{
TDSVersion: verTDS74,
PacketSize: uint32(len(outbuf.buf)),
PacketSize: uint32(outbuf.PackageSize()),
Database: p.database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.workstation,
@ -1028,7 +1316,7 @@ initiate_connection:
var sspi_msg []byte
continue_login:
tokchan := make(chan tokenStruct, 5)
go processResponse(&sess, tokchan)
go processResponse(context.Background(), &sess, tokchan, nil)
success := false
for tok := range tokchan {
switch token := tok.(type) {
@ -1042,6 +1330,10 @@ continue_login:
sess.loginAck = token
case error:
return nil, fmt.Errorf("Login error: %s", token.Error())
case doneStruct:
if token.isError() {
return nil, fmt.Errorf("Login error: %s", token.getError())
}
}
}
if sspi_msg != nil {