forked from forgejo/forgejo
update mssql drive to last working version 20180314172330-6a30f4e59a44 (#7306)
This commit is contained in:
parent
aeb8f7aad8
commit
1e46eedce7
46 changed files with 3158 additions and 491 deletions
358
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
358
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
|
@ -9,11 +10,13 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
@ -47,8 +50,11 @@ func parseInstances(msg []byte) map[string]map[string]string {
|
|||
return results
|
||||
}
|
||||
|
||||
func getInstances(address string) (map[string]map[string]string, error) {
|
||||
conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
|
||||
func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "udp", address+":1434")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -79,11 +85,16 @@ const (
|
|||
)
|
||||
|
||||
// packet types
|
||||
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
|
||||
const (
|
||||
packSQLBatch = 1
|
||||
packRPCRequest = 3
|
||||
packReply = 4
|
||||
packCancel = 6
|
||||
packSQLBatch packetType = 1
|
||||
packRPCRequest = 3
|
||||
packReply = 4
|
||||
|
||||
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
||||
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
||||
packAttention = 6
|
||||
|
||||
packBulkLoadBCP = 7
|
||||
packTransMgrReq = 14
|
||||
packNormal = 15
|
||||
|
@ -119,7 +130,7 @@ type tdsSession struct {
|
|||
columns []columnStruct
|
||||
tranid uint64
|
||||
logFlags uint64
|
||||
log *Logger
|
||||
log optionalLogger
|
||||
routedServer string
|
||||
routedPort uint16
|
||||
}
|
||||
|
@ -131,6 +142,7 @@ const (
|
|||
logSQL = 8
|
||||
logParams = 16
|
||||
logTransaction = 32
|
||||
logDebug = 64
|
||||
)
|
||||
|
||||
type columnStruct struct {
|
||||
|
@ -490,6 +502,11 @@ func readBVarChar(r io.Reader) (res string, err error) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// A zero length could be returned, return an empty string
|
||||
if numchars == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return readUcs2(r, int(numchars))
|
||||
}
|
||||
|
||||
|
@ -588,7 +605,7 @@ func (hdr transDescrHdr) pack() (res []byte) {
|
|||
}
|
||||
|
||||
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
||||
// calculatint total length
|
||||
// Calculating total length.
|
||||
var totallen uint32 = 4
|
||||
for _, hdr := range headers {
|
||||
totallen += 4 + 2 + uint32(len(hdr.data))
|
||||
|
@ -616,9 +633,7 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func sendSqlBatch72(buf *tdsBuffer,
|
||||
sqltext string,
|
||||
headers []headerStruct) (err error) {
|
||||
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
|
||||
buf.BeginPacket(packSQLBatch)
|
||||
|
||||
if err = writeAllHeaders(buf, headers); err != nil {
|
||||
|
@ -632,6 +647,13 @@ func sendSqlBatch72(buf *tdsBuffer,
|
|||
return buf.FinishPacket()
|
||||
}
|
||||
|
||||
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
||||
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
||||
func sendAttention(buf *tdsBuffer) error {
|
||||
buf.BeginPacket(packAttention)
|
||||
return buf.FinishPacket()
|
||||
}
|
||||
|
||||
type connectParams struct {
|
||||
logFlags uint64
|
||||
port uint64
|
||||
|
@ -654,6 +676,7 @@ type connectParams struct {
|
|||
typeFlags uint8
|
||||
failOverPartner string
|
||||
failOverPort uint64
|
||||
packetSize uint16
|
||||
}
|
||||
|
||||
func splitConnectionString(dsn string) (res map[string]string) {
|
||||
|
@ -677,9 +700,241 @@ func splitConnectionString(dsn string) (res map[string]string) {
|
|||
return res
|
||||
}
|
||||
|
||||
// Splits a URL in the ODBC format
|
||||
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
||||
res := map[string]string{}
|
||||
|
||||
type parserState int
|
||||
const (
|
||||
// Before the start of a key
|
||||
parserStateBeforeKey parserState = iota
|
||||
|
||||
// Inside a key
|
||||
parserStateKey
|
||||
|
||||
// Beginning of a value. May be bare or braced
|
||||
parserStateBeginValue
|
||||
|
||||
// Inside a bare value
|
||||
parserStateBareValue
|
||||
|
||||
// Inside a braced value
|
||||
parserStateBracedValue
|
||||
|
||||
// A closing brace inside a braced value.
|
||||
// May be the end of the value or an escaped closing brace, depending on the next character
|
||||
parserStateBracedValueClosingBrace
|
||||
|
||||
// After a value. Next character should be a semicolon or whitespace.
|
||||
parserStateEndValue
|
||||
)
|
||||
|
||||
var state = parserStateBeforeKey
|
||||
|
||||
var key string
|
||||
var value string
|
||||
|
||||
for i, c := range dsn {
|
||||
switch state {
|
||||
case parserStateBeforeKey:
|
||||
switch {
|
||||
case c == '=':
|
||||
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
|
||||
case !unicode.IsSpace(c) && c != ';':
|
||||
state = parserStateKey
|
||||
key += string(c)
|
||||
}
|
||||
|
||||
case parserStateKey:
|
||||
switch c {
|
||||
case '=':
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||
}
|
||||
|
||||
state = parserStateBeginValue
|
||||
|
||||
case ';':
|
||||
// Key without value
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||
}
|
||||
|
||||
res[key] = value
|
||||
key = ""
|
||||
value = ""
|
||||
state = parserStateBeforeKey
|
||||
|
||||
default:
|
||||
key += string(c)
|
||||
}
|
||||
|
||||
case parserStateBeginValue:
|
||||
switch {
|
||||
case c == '{':
|
||||
state = parserStateBracedValue
|
||||
case c == ';':
|
||||
// Empty value
|
||||
res[key] = value
|
||||
key = ""
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
state = parserStateBareValue
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBareValue:
|
||||
if c == ';' {
|
||||
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||
key = ""
|
||||
value = ""
|
||||
state = parserStateBeforeKey
|
||||
} else {
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBracedValue:
|
||||
if c == '}' {
|
||||
state = parserStateBracedValueClosingBrace
|
||||
} else {
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBracedValueClosingBrace:
|
||||
if c == '}' {
|
||||
// Escaped closing brace
|
||||
value += string(c)
|
||||
state = parserStateBracedValue
|
||||
continue
|
||||
}
|
||||
|
||||
// End of braced value
|
||||
res[key] = value
|
||||
key = ""
|
||||
value = ""
|
||||
|
||||
// This character is the first character past the end,
|
||||
// so it needs to be parsed like the parserStateEndValue state.
|
||||
state = parserStateEndValue
|
||||
switch {
|
||||
case c == ';':
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
}
|
||||
|
||||
case parserStateEndValue:
|
||||
switch {
|
||||
case c == ';':
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch state {
|
||||
case parserStateBeforeKey: // Okay
|
||||
case parserStateKey: // Unfinished key. Treat as key without value.
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
|
||||
}
|
||||
res[key] = value
|
||||
case parserStateBeginValue: // Empty value
|
||||
res[key] = value
|
||||
case parserStateBareValue:
|
||||
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||
case parserStateBracedValue:
|
||||
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
|
||||
case parserStateBracedValueClosingBrace: // End of braced value
|
||||
res[key] = value
|
||||
case parserStateEndValue: // Okay
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Normalizes the given string as an ODBC-format key
|
||||
func normalizeOdbcKey(s string) string {
|
||||
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
|
||||
}
|
||||
|
||||
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value
|
||||
func splitConnectionStringURL(dsn string) (map[string]string, error) {
|
||||
res := map[string]string{}
|
||||
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if u.Scheme != "sqlserver" {
|
||||
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
res["user id"] = u.User.Username()
|
||||
p, exists := u.User.Password()
|
||||
if exists {
|
||||
res["password"] = p
|
||||
}
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
host = u.Host
|
||||
}
|
||||
|
||||
if len(u.Path) > 0 {
|
||||
res["server"] = host + "\\" + u.Path[1:]
|
||||
} else {
|
||||
res["server"] = host
|
||||
}
|
||||
|
||||
if len(port) > 0 {
|
||||
res["port"] = port
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
for k, v := range query {
|
||||
if len(v) > 1 {
|
||||
return res, fmt.Errorf("key %s provided more than once", k)
|
||||
}
|
||||
res[strings.ToLower(k)] = v[0]
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func parseConnectParams(dsn string) (connectParams, error) {
|
||||
params := splitConnectionString(dsn)
|
||||
var p connectParams
|
||||
|
||||
var params map[string]string
|
||||
if strings.HasPrefix(dsn, "odbc:") {
|
||||
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
params = parameters
|
||||
} else if strings.HasPrefix(dsn, "sqlserver://") {
|
||||
parameters, err := splitConnectionStringURL(dsn)
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
params = parameters
|
||||
} else {
|
||||
params = splitConnectionString(dsn)
|
||||
}
|
||||
|
||||
strlog, ok := params["log"]
|
||||
if ok {
|
||||
var err error
|
||||
|
@ -712,7 +967,32 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
}
|
||||
}
|
||||
|
||||
p.dial_timeout = 5 * time.Second
|
||||
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||
// Default packet size remains at 4096 bytes
|
||||
p.packetSize = 4096
|
||||
strpsize, ok := params["packet size"]
|
||||
if ok {
|
||||
var err error
|
||||
psize, err := strconv.ParseUint(strpsize, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid packet size '%v': %v"
|
||||
return p, fmt.Errorf(f, strpsize, err.Error())
|
||||
}
|
||||
|
||||
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
||||
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
||||
// a higher packet size, the server will respond with an ENVCHANGE request to
|
||||
// alter the packet size to 16383 bytes.
|
||||
p.packetSize = uint16(psize)
|
||||
if p.packetSize < 512 {
|
||||
p.packetSize = 512
|
||||
} else if p.packetSize > 32767 {
|
||||
p.packetSize = 32767
|
||||
}
|
||||
}
|
||||
|
||||
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||
p.dial_timeout = 15 * time.Second
|
||||
p.conn_timeout = 30 * time.Second
|
||||
strconntimeout, ok := params["connection timeout"]
|
||||
if ok {
|
||||
|
@ -732,8 +1012,12 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
}
|
||||
p.dial_timeout = time.Duration(timeout) * time.Second
|
||||
}
|
||||
keepAlive, ok := params["keepalive"]
|
||||
if ok {
|
||||
|
||||
// default keep alive should be 30 seconds according to spec:
|
||||
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||
p.keepAlive = 30 * time.Second
|
||||
|
||||
if keepAlive, ok := params["keepalive"]; ok {
|
||||
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid keepAlive value '%s': %s"
|
||||
|
@ -743,7 +1027,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
}
|
||||
encrypt, ok := params["encrypt"]
|
||||
if ok {
|
||||
if strings.ToUpper(encrypt) == "DISABLE" {
|
||||
if strings.EqualFold(encrypt, "DISABLE") {
|
||||
p.disableEncryption = true
|
||||
} else {
|
||||
var err error
|
||||
|
@ -819,7 +1103,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
type Auth interface {
|
||||
type auth interface {
|
||||
InitialBytes() ([]byte, error)
|
||||
NextBytes([]byte) ([]byte, error)
|
||||
Free()
|
||||
|
@ -828,7 +1112,7 @@ type Auth interface {
|
|||
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
|
||||
// list of IP addresses. So if there is more than one, try them all and
|
||||
// use the first one that allows a connection.
|
||||
func dialConnection(p connectParams) (conn net.Conn, err error) {
|
||||
func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
|
||||
var ips []net.IP
|
||||
ips, err = net.LookupIP(p.host)
|
||||
if err != nil {
|
||||
|
@ -839,9 +1123,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
|||
ips = []net.IP{ip}
|
||||
}
|
||||
if len(ips) == 1 {
|
||||
d := createDialer(p)
|
||||
d := createDialer(&p)
|
||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
|
||||
conn, err = d.Dial("tcp", addr)
|
||||
conn, err = d.Dial(ctx, addr)
|
||||
|
||||
} else {
|
||||
//Try Dials in parallel to avoid waiting for timeouts.
|
||||
|
@ -850,9 +1134,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
|||
portStr := strconv.Itoa(int(p.port))
|
||||
for _, ip := range ips {
|
||||
go func(ip net.IP) {
|
||||
d := createDialer(p)
|
||||
d := createDialer(&p)
|
||||
addr := net.JoinHostPort(ip.String(), portStr)
|
||||
conn, err := d.Dial("tcp", addr)
|
||||
conn, err := d.Dial(ctx, addr)
|
||||
if err == nil {
|
||||
connChan <- conn
|
||||
} else {
|
||||
|
@ -887,16 +1171,15 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
|||
f := "Unable to open tcp connection with host '%v:%v': %v"
|
||||
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func connect(p connectParams) (res *tdsSession, err error) {
|
||||
func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
|
||||
res = nil
|
||||
// if instance is specified use instance resolution service
|
||||
if p.instance != "" {
|
||||
p.instance = strings.ToUpper(p.instance)
|
||||
instances, err := getInstances(p.host)
|
||||
instances, err := getInstances(ctx, p.host)
|
||||
if err != nil {
|
||||
f := "Unable to get instances from Sql Server Browser on host %v: %v"
|
||||
return nil, fmt.Errorf(f, p.host, err.Error())
|
||||
|
@ -914,16 +1197,17 @@ func connect(p connectParams) (res *tdsSession, err error) {
|
|||
}
|
||||
|
||||
initiate_connection:
|
||||
conn, err := dialConnection(p)
|
||||
conn, err := dialConnection(ctx, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toconn := NewTimeoutConn(conn, p.conn_timeout)
|
||||
|
||||
outbuf := newTdsBuffer(4096, toconn)
|
||||
outbuf := newTdsBuffer(p.packetSize, toconn)
|
||||
sess := tdsSession{
|
||||
buf: outbuf,
|
||||
log: log,
|
||||
logFlags: p.logFlags,
|
||||
}
|
||||
|
||||
|
@ -969,8 +1253,7 @@ initiate_connection:
|
|||
if p.certificate != "" {
|
||||
pem, err := ioutil.ReadFile(p.certificate)
|
||||
if err != nil {
|
||||
f := "Cannot read certificate '%s': %s"
|
||||
return nil, fmt.Errorf(f, p.certificate, err.Error())
|
||||
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
|
||||
}
|
||||
certs := x509.NewCertPool()
|
||||
certs.AppendCertsFromPEM(pem)
|
||||
|
@ -980,15 +1263,20 @@ initiate_connection:
|
|||
config.InsecureSkipVerify = true
|
||||
}
|
||||
config.ServerName = p.hostInCertificate
|
||||
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
|
||||
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
||||
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||
config.DynamicRecordSizingDisabled = true
|
||||
outbuf.transport = conn
|
||||
toconn.buf = outbuf
|
||||
tlsConn := tls.Client(toconn, &config)
|
||||
err = tlsConn.Handshake()
|
||||
|
||||
toconn.buf = nil
|
||||
outbuf.transport = tlsConn
|
||||
if err != nil {
|
||||
f := "TLS Handshake failed: %s"
|
||||
return nil, fmt.Errorf(f, err.Error())
|
||||
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
|
||||
}
|
||||
if encrypt == encryptOff {
|
||||
outbuf.afterFirst = func() {
|
||||
|
@ -999,7 +1287,7 @@ initiate_connection:
|
|||
|
||||
login := login{
|
||||
TDSVersion: verTDS74,
|
||||
PacketSize: uint32(len(outbuf.buf)),
|
||||
PacketSize: uint32(outbuf.PackageSize()),
|
||||
Database: p.database,
|
||||
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
||||
HostName: p.workstation,
|
||||
|
@ -1028,7 +1316,7 @@ initiate_connection:
|
|||
var sspi_msg []byte
|
||||
continue_login:
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
go processResponse(&sess, tokchan)
|
||||
go processResponse(context.Background(), &sess, tokchan, nil)
|
||||
success := false
|
||||
for tok := range tokchan {
|
||||
switch token := tok.(type) {
|
||||
|
@ -1042,6 +1330,10 @@ continue_login:
|
|||
sess.loginAck = token
|
||||
case error:
|
||||
return nil, fmt.Errorf("Login error: %s", token.Error())
|
||||
case doneStruct:
|
||||
if token.isError() {
|
||||
return nil, fmt.Errorf("Login error: %s", token.getError())
|
||||
}
|
||||
}
|
||||
}
|
||||
if sspi_msg != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue