1
0
Fork 0
forked from forgejo/forgejo

[Vendor] Update directly used dependencys (#15593)

* update github.com/blevesearch/bleve v2.0.2 -> v2.0.3

* github.com/denisenkom/go-mssqldb v0.9.0 -> v0.10.0

* github.com/editorconfig/editorconfig-core-go v2.4.1 -> v2.4.2

* github.com/go-chi/cors v1.1.1 -> v1.2.0

* github.com/go-git/go-billy v5.0.0 -> v5.1.0

* github.com/go-git/go-git v5.2.0 -> v5.3.0

* github.com/go-ldap/ldap v3.2.4 -> v3.3.0

* github.com/go-redis/redis v8.6.0 -> v8.8.2

* github.com/go-sql-driver/mysql v1.5.0 -> v1.6.0

* github.com/go-swagger/go-swagger v0.26.1 -> v0.27.0

* github.com/lib/pq v1.9.0 -> v1.10.1

* github.com/mattn/go-sqlite3 v1.14.6 -> v1.14.7

* github.com/go-testfixtures/testfixtures v3.5.0 -> v3.6.0

* github.com/issue9/identicon v1.0.1 -> v1.2.0

* github.com/klauspost/compress v1.11.8 -> v1.12.1

* github.com/mgechev/revive v1.0.3 -> v1.0.6

* github.com/microcosm-cc/bluemonday v1.0.7 -> v1.0.8

* github.com/niklasfasching/go-org v1.4.0 -> v1.5.0

* github.com/olivere/elastic v7.0.22 -> v7.0.24

* github.com/pelletier/go-toml v1.8.1 -> v1.9.0

* github.com/prometheus/client_golang v1.9.0 -> v1.10.0

* github.com/xanzy/go-gitlab v0.44.0 -> v0.48.0

* github.com/yuin/goldmark v1.3.3 -> v1.3.5

* github.com/6543/go-version v1.2.4 -> v1.3.1

* do github.com/lib/pq v1.10.0 -> v1.10.1 again ...
This commit is contained in:
6543 2021-04-23 02:08:53 +02:00 committed by GitHub
parent 834fc74873
commit 792b4dba2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
558 changed files with 32080 additions and 24669 deletions

8
vendor/github.com/denisenkom/go-mssqldb/.gitignore generated vendored Normal file
View file

@ -0,0 +1,8 @@
/.idea
/.connstr
.vscode
.terraform
*.tfstate*
*.log
*.swp
*~

10
vendor/github.com/denisenkom/go-mssqldb/.golangci.yml generated vendored Normal file
View file

@ -0,0 +1,10 @@
linters:
enable:
# basic go linters
- gofmt
- golint
- govet
# sql related linters
- rowserrcheck
- sqlclosecheck

View file

@ -6,19 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
"fmt"
)
var _ driver.Connector = &accessTokenConnector{}
// accessTokenConnector wraps Connector and injects a
// fresh access token when connecting to the database
type accessTokenConnector struct {
Connector
accessTokenProvider func() (string, error)
}
// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
@ -32,20 +21,10 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
return nil, err
}
c := &accessTokenConnector{
Connector: *conn,
accessTokenProvider: tokenProvider,
}
return c, nil
}
// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = func(ctx context.Context) (string, error) {
return tokenProvider()
}
return c.Connector.Connect(ctx)
return conn, nil
}

View file

@ -39,6 +39,9 @@ environment:
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 116
SQLINSTANCE: SQL2017
install:
- set GOROOT=c:\go%GOVERSION%

View file

@ -48,8 +48,8 @@ type tdsBuffer struct {
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
return &tdsBuffer{
packetSize: int(bufsize),
wbuf: make([]byte, 1<<16),
rbuf: make([]byte, 1<<16),
wbuf: make([]byte, bufsize),
rbuf: make([]byte, bufsize),
rpos: 8,
transport: transport,
}
@ -137,19 +137,28 @@ func (w *tdsBuffer) FinishPacket() error {
var headerSize = binary.Size(header{})
func (r *tdsBuffer) readNextPacket() error {
h := header{}
var err error
err = binary.Read(r.transport, binary.BigEndian, &h)
buf := r.rbuf[:headerSize]
_, err := io.ReadFull(r.transport, buf)
if err != nil {
return err
}
h := header{
PacketType: packetType(buf[0]),
Status: buf[1],
Size: binary.BigEndian.Uint16(buf[2:4]),
Spid: binary.BigEndian.Uint16(buf[4:6]),
PacketNo: buf[6],
Pad: buf[7],
}
if int(h.Size) > r.packetSize {
return errors.New("Invalid packet size, it is longer than buffer size")
return errors.New("invalid packet size, it is longer than buffer size")
}
if headerSize > int(h.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
return errors.New("invalid packet size, it is shorter than header size")
}
_, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
//s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size])
//fmt.Print(s)
if err != nil {
return err
}

View file

@ -44,8 +44,9 @@ type BulkOptions struct {
type DataValue interface{}
const (
sqlDateFormat = "2006-01-02"
sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
sqlDateFormat = "2006-01-02"
sqlDateTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
sqlTimeFormat = "15:04:05.9999999"
)
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
@ -86,7 +87,7 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
b.bulkColumns = append(b.bulkColumns, *bulkCol)
b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
} else {
return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename)
}
}
@ -166,7 +167,7 @@ func (b *Bulk) AddRow(row []interface{}) (err error) {
}
if len(row) != len(b.bulkColumns) {
return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
return fmt.Errorf("row does not have the same number of columns than the destination table %d %d",
len(row), len(b.bulkColumns))
}
@ -215,7 +216,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
}
func (b *Bulk) Done() (rowcount int64, err error) {
if b.headerSent == false {
if !b.headerSent {
//no rows had been sent
return 0, nil
}
@ -233,24 +234,13 @@ func (b *Bulk) Done() (rowcount int64, err error) {
buf.FinishPacket()
tokchan := make(chan tokenStruct, 5)
go processResponse(b.ctx, b.cn.sess, tokchan, nil)
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneStruct:
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
}
if token.isError() {
return 0, token.getError()
}
case error:
return 0, b.cn.checkBadConn(token)
}
reader := startReading(b.cn.sess, b.ctx, nil)
err = reader.iterateResponse()
if err != nil {
return 0, b.cn.checkBadConn(err)
}
return rowCount, nil
return reader.rowCount, nil
}
func (b *Bulk) createColMetadata() []byte {
@ -421,7 +411,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = len(res.buffer)
case string:
var t time.Time
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
res.buffer = encodeDateTime2(t, int(col.ti.Scale))
@ -437,7 +427,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = len(res.buffer)
case string:
var t time.Time
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale))
@ -468,7 +458,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
case time.Time:
t = val
case string:
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
default:
@ -485,7 +475,22 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
} else {
err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size)
}
case typeTimeN:
var t time.Time
switch val := val.(type) {
case time.Time:
res.buffer = encodeTime(val.Hour(), val.Minute(), val.Second(), val.Nanosecond(), int(col.ti.Scale))
res.ti.Size = len(res.buffer)
case string:
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to time: %v", err)
}
res.buffer = encodeTime(t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), int(col.ti.Scale))
res.ti.Size = len(res.buffer)
default:
err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val)
return
}
// case typeMoney, typeMoney4, typeMoneyN:
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
prec := col.ti.Prec

View file

@ -37,11 +37,17 @@ type connectParams struct {
failOverPartner string
failOverPort uint64
packetSize uint16
fedAuthAccessToken string
fedAuthLibrary int
fedAuthADALWorkflow byte
}
// default packet size for TDS buffer
const defaultPacketSize = 4096
func parseConnectParams(dsn string) (connectParams, error) {
var p connectParams
p := connectParams{
fedAuthLibrary: fedAuthLibraryReserved,
}
var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
@ -65,7 +71,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
var err error
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
if err != nil {
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
return p, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error())
}
}
server := params["server"]
@ -87,20 +93,19 @@ func parseConnectParams(dsn string) (connectParams, error) {
var err error
p.port, err = strconv.ParseUint(strport, 10, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
f := "invalid tcp port '%v': %v"
return p, fmt.Errorf(f, strport, err.Error())
}
}
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
// Default packet size remains at 4096 bytes
p.packetSize = 4096
p.packetSize = defaultPacketSize
strpsize, ok := params["packet size"]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
if err != nil {
f := "Invalid packet size '%v': %v"
f := "invalid packet size '%v': %v"
return p, fmt.Errorf(f, strpsize, err.Error())
}
@ -123,7 +128,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
if strconntimeout, ok := params["connection timeout"]; ok {
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "Invalid connection timeout '%v': %v"
f := "invalid connection timeout '%v': %v"
return p, fmt.Errorf(f, strconntimeout, err.Error())
}
p.conn_timeout = time.Duration(timeout) * time.Second
@ -132,7 +137,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
if strdialtimeout, ok := params["dial timeout"]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "Invalid dial timeout '%v': %v"
f := "invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
}
p.dial_timeout = time.Duration(timeout) * time.Second
@ -144,7 +149,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
if keepAlive, ok := params["keepalive"]; ok {
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
f := "invalid keepAlive value '%s': %s"
return p, fmt.Errorf(f, keepAlive, err.Error())
}
p.keepAlive = time.Duration(timeout) * time.Second
@ -157,7 +162,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
var err error
p.encrypt, err = strconv.ParseBool(encrypt)
if err != nil {
f := "Invalid encrypt '%s': %s"
f := "invalid encrypt '%s': %s"
return p, fmt.Errorf(f, encrypt, err.Error())
}
}
@ -169,7 +174,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
var err error
p.trustServerCertificate, err = strconv.ParseBool(trust)
if err != nil {
f := "Invalid trust server certificate '%s': %s"
f := "invalid trust server certificate '%s': %s"
return p, fmt.Errorf(f, trust, err.Error())
}
}
@ -209,7 +214,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
if ok {
if appintent == "ReadOnly" {
if p.database == "" {
return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly")
return p, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly")
}
p.typeFlags |= fReadOnlyIntent
}
@ -225,7 +230,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
var err error
p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
f := "invalid tcp port '%v': %v"
return p, fmt.Errorf(f, failOverPort, err.Error())
}
}
@ -233,6 +238,30 @@ func parseConnectParams(dsn string) (connectParams, error) {
return p, nil
}
// convert connectionParams to url style connection string
// used mostly for testing
func (p connectParams) toUrl() *url.URL {
q := url.Values{}
if p.database != "" {
q.Add("database", p.database)
}
if p.logFlags != 0 {
q.Add("log", strconv.FormatUint(p.logFlags, 10))
}
res := url.URL{
Scheme: "sqlserver",
Host: p.host,
User: url.UserPassword(p.user, p.password),
}
if p.instance != "" {
res.Path = p.instance
}
if len(q) > 0 {
res.RawQuery = q.Encode()
}
return &res
}
func splitConnectionString(dsn string) (res map[string]string) {
res = map[string]string{}
parts := strings.Split(dsn, ";")
@ -340,7 +369,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case parserStateBeforeKey:
switch {
case c == '=':
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i)
case !unicode.IsSpace(c) && c != ';':
state = parserStateKey
key += string(c)
@ -419,7 +448,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
}
case parserStateEndValue:
@ -429,7 +458,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
}
}
}
@ -444,7 +473,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case parserStateBareValue:
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
case parserStateBracedValue:
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn))
case parserStateBracedValueClosingBrace: // End of braced value
res[key] = value
case parserStateEndValue: // Okay

82
vendor/github.com/denisenkom/go-mssqldb/fedauth.go generated vendored Normal file
View file

@ -0,0 +1,82 @@
package mssql
import (
"context"
"errors"
)
// Federated authentication library affects the login data structure and message sequence.
const (
// fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme
fedAuthLibraryLiveIDCompactToken = 0x00
// fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available
// without additional information provided during the login sequence.
fedAuthLibrarySecurityToken = 0x01
// fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the
// login sequence using the server SPN and STS URL provided by the server during login.
fedAuthLibraryADAL = 0x02
// fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies.
fedAuthLibraryReserved = 0x7F
)
// Federated authentication ADAL workflow affects the mechanism used to authenticate.
const (
// fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory
fedAuthADALWorkflowPassword = 0x01
// fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory
fedAuthADALWorkflowIntegrated = 0x02
// fedAuthADALWorkflowMSI uses the managed identity service to obtain a token
fedAuthADALWorkflowMSI = 0x03
)
// newSecurityTokenConnector creates a new connector from a DSN and a token provider.
// When invoked, token provider implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
// The returned connector may be used with sql.OpenDB.
func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}
conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}
conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = tokenProvider
return conn, nil
}
// newADALTokenConnector creates a new connector from a DSN and a Active Directory token provider.
// Token provider implementations are called during federated
// authentication login sequences where the server provides a service
// principal name and security token service endpoint that should be used
// to obtain the token. Implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
//
// The returned connector may be used with sql.OpenDB.
func newActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}
conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}
conn.params.fedAuthLibrary = fedAuthLibraryADAL
conn.params.fedAuthADALWorkflow = adalWorkflow
conn.adalTokenProvider = tokenProvider
return conn, nil
}

View file

@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
if err != nil {
return nil, err
}
return &Connector{
params: params,
driver: d,
@ -100,6 +101,12 @@ type Connector struct {
params connectParams
driver *Driver
// callback that can provide a security token during login
securityTokenProvider func(ctx context.Context) (string, error)
// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)
// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
@ -148,15 +155,7 @@ type Conn struct {
processQueryText bool
connectionGood bool
outs map[string]interface{}
returnStatus *ReturnStatus
}
func (c *Conn) setReturnStatus(s ReturnStatus) {
if c.returnStatus == nil {
return
}
*c.returnStatus = s
outs map[string]interface{}
}
func (c *Conn) checkBadConn(err error) error {
@ -201,20 +200,15 @@ func (c *Conn) clearOuts() {
}
func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan, c.outs)
reader := startReading(c.sess, ctx, c.outs)
c.clearOuts()
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return c.checkBadConn(token.getError())
}
case error:
return c.checkBadConn(token)
}
var resultError error
err := reader.iterateResponse()
if err != nil {
return c.checkBadConn(err)
}
return nil
return resultError
}
func (c *Conn) Commit() error {
@ -239,7 +233,7 @@ func (c *Conn) sendCommitRequest() error {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Faild to send CommitXact: %v", err)
return fmt.Errorf("faild to send CommitXact: %v", err)
}
return nil
}
@ -266,7 +260,7 @@ func (c *Conn) sendRollbackRequest() error {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send RollbackXact: %v", err)
return fmt.Errorf("failed to send RollbackXact: %v", err)
}
return nil
}
@ -303,7 +297,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send BeginXact: %v", err)
return fmt.Errorf("failed to send BeginXact: %v", err)
}
return nil
}
@ -478,7 +472,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
conn.sess.log.Printf("Failed to send Rpc with %v", err)
}
conn.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
return fmt.Errorf("failed to send RPC: %v", err)
}
}
return
@ -595,38 +589,46 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
}
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
ctx, cancel := context.WithCancel(ctx)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
// process metadata
var cols []columnStruct
loop:
for tok := range tokchan {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
cancel()
return nil, s.c.checkBadConn(token.getError())
for {
tok, err := reader.nextToken()
if err == nil {
if tok == nil {
break
} else {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token.getError())
}
case ReturnStatus:
s.c.sess.setReturnStatus(token)
}
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
} else {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token)
return nil, s.c.checkBadConn(err)
}
}
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel}
return
}
@ -648,48 +650,46 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
}
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneInProcStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
if token.isError() {
return nil, token.getError()
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
return nil, token
}
err = reader.iterateResponse()
if err != nil {
return nil, s.c.checkBadConn(err)
}
return &Result{s.c, rowCount}, nil
return &Result{s.c, reader.rowCount}, nil
}
type Rows struct {
stmt *Stmt
cols []columnStruct
tokchan chan tokenStruct
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
}
func (rc *Rows) Close() error {
// need to add a test which returns lots of rows
// and check closing after reading only few rows
rc.cancel()
for _ = range rc.tokchan {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return nil
} else {
// continue consuming tokens
continue
}
} else {
if err == rc.reader.ctx.Err() {
return nil
} else {
return err
}
}
}
rc.tokchan = nil
return nil
}
func (rc *Rows) Columns() (res []string) {
@ -707,27 +707,34 @@ func (rc *Rows) Next(dest []driver.Value) error {
if rc.nextCols != nil {
return io.EOF
}
for tok := range rc.tokchan {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case ReturnStatus:
rc.stmt.c.sess.setReturnStatus(tokdata)
}
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case ReturnStatus:
rc.stmt.c.setReturnStatus(tokdata)
case error:
return rc.stmt.c.checkBadConn(tokdata)
} else {
return rc.stmt.c.checkBadConn(err)
}
}
return io.EOF
}
func (rc *Rows) HasNextResultSet() bool {
@ -895,35 +902,41 @@ func (c *Conn) Ping(ctx context.Context) error {
var _ driver.ConnBeginTx = &Conn{}
func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) {
switch level {
case sql.LevelDefault:
return isolationUseCurrent, nil
case sql.LevelReadUncommitted:
return isolationReadUncommited, nil
case sql.LevelReadCommitted:
return isolationReadCommited, nil
case sql.LevelWriteCommitted:
return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
return isolationRepeatableRead, nil
case sql.LevelSnapshot:
return isolationSnapshot, nil
case sql.LevelSerializable:
return isolationSerializable, nil
case sql.LevelLinearizable:
return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported")
default:
return isolationUseCurrent, errors.New("isolation level is not supported or unknown")
}
}
// BeginTx satisfies ConnBeginTx.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if opts.ReadOnly {
return nil, errors.New("Read-only transactions are not supported")
return nil, errors.New("read-only transactions are not supported")
}
var tdsIsolation isoLevel
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
tdsIsolation = isolationUseCurrent
case sql.LevelReadUncommitted:
tdsIsolation = isolationReadUncommited
case sql.LevelReadCommitted:
tdsIsolation = isolationReadCommited
case sql.LevelWriteCommitted:
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
tdsIsolation = isolationRepeatableRead
case sql.LevelSnapshot:
tdsIsolation = isolationSnapshot
case sql.LevelSerializable:
tdsIsolation = isolationSerializable
case sql.LevelLinearizable:
return nil, errors.New("LevelLinearizable isolation level is not supported")
default:
return nil, errors.New("Isolation level is not supported or unknown")
tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation))
if err != nil {
return nil, err
}
return c.begin(ctx, tdsIsolation)
}

View file

@ -48,5 +48,5 @@ func (c *Connector) Driver() driver.Driver {
}
func (r *Result) LastInsertId() (int64, error) {
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.")
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query")
}

View file

@ -110,7 +110,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
return nil
case *ReturnStatus:
*v = 0 // By default the return value should be zero.
c.returnStatus = v
c.sess.returnStatus = v
return driver.ErrRemoveArgument
case TVP:
return nil

View file

@ -7,8 +7,8 @@ import (
)
type timeoutConn struct {
c net.Conn
timeout time.Duration
c net.Conn
timeout time.Duration
}
func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn {
@ -51,21 +51,21 @@ func (c timeoutConn) RemoteAddr() net.Addr {
}
func (c timeoutConn) SetDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetDeadline(t)
}
func (c timeoutConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetReadDeadline(t)
}
func (c timeoutConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetWriteDeadline(t)
}
// this connection is used during TLS Handshake
// TDS protocol requires TLS handshake messages to be sent inside TDS packets
type tlsHandshakeConn struct {
buf *tdsBuffer
buf *tdsBuffer
packetPending bool
continueRead bool
}
@ -75,7 +75,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
c.packetPending = false
err = c.buf.FinishPacket()
if err != nil {
err = fmt.Errorf("Cannot send handshake packet: %s", err.Error())
err = fmt.Errorf("cannot send handshake packet: %s", err.Error())
return
}
c.continueRead = false
@ -84,7 +84,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
var packet packetType
packet, err = c.buf.BeginRead()
if err != nil {
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())
err = fmt.Errorf("cannot read handshake packet: %s", err.Error())
return
}
if packet != packPrelogin {
@ -105,27 +105,27 @@ func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) {
}
func (c *tlsHandshakeConn) Close() error {
panic("Not implemented")
return c.buf.transport.Close()
}
func (c *tlsHandshakeConn) LocalAddr() net.Addr {
panic("Not implemented")
return nil
}
func (c *tlsHandshakeConn) RemoteAddr() net.Addr {
panic("Not implemented")
return nil
}
func (c *tlsHandshakeConn) SetDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetWriteDeadline(_ time.Time) error {
return nil
}
// this connection just delegates all methods to it's wrapped connection
@ -148,21 +148,21 @@ func (c passthroughConn) Close() error {
}
func (c passthroughConn) LocalAddr() net.Addr {
panic("Not implemented")
return c.c.LocalAddr()
}
func (c passthroughConn) RemoteAddr() net.Addr {
panic("Not implemented")
return c.c.RemoteAddr()
}
func (c passthroughConn) SetDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetDeadline(t)
}
func (c passthroughConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetReadDeadline(t)
}
func (c passthroughConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetWriteDeadline(t)
}

View file

@ -14,6 +14,7 @@ import (
"time"
"unicode/utf16"
//lint:ignore SA1019 MD4 is used by legacy NTLM
"golang.org/x/crypto/md4"
)
@ -126,18 +127,6 @@ func createDesKey(bytes, material []byte) {
material[7] = (byte)(bytes[6] << 1)
}
func oddParity(bytes []byte) {
for i := 0; i < len(bytes); i++ {
b := bytes[i]
needsParity := (((b >> 7) ^ (b >> 6) ^ (b >> 5) ^ (b >> 4) ^ (b >> 3) ^ (b >> 2) ^ (b >> 1)) & 0x01) == 0
if needsParity {
bytes[i] = bytes[i] | byte(0x01)
} else {
bytes[i] = bytes[i] & byte(0xfe)
}
}
}
func encryptDes(key []byte, cleartext []byte, ciphertext []byte) {
var desKey [8]byte
createDesKey(key, desKey[:])

View file

@ -22,12 +22,6 @@ type param struct {
buffer []byte
}
const (
fWithRecomp = 1
fNoMetaData = 2
fReuseMetaData = 4
)
var (
sp_Cursor = procId{1, ""}
sp_CursorOpen = procId{2, ""}

View file

@ -82,19 +82,20 @@ const (
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
packSQLBatch packetType = 1
packRPCRequest = 3
packReply = 4
packRPCRequest packetType = 3
packReply packetType = 4
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
packAttention = 6
packAttention packetType = 6
packBulkLoadBCP = 7
packTransMgrReq = 14
packNormal = 15
packLogin7 = 16
packSSPIMessage = 17
packPrelogin = 18
packBulkLoadBCP packetType = 7
packFedAuthToken packetType = 8
packTransMgrReq packetType = 14
packNormal packetType = 15
packLogin7 packetType = 16
packSSPIMessage packetType = 17
packPrelogin packetType = 18
)
// prelogin fields
@ -118,6 +119,17 @@ const (
encryptReq = 3 // Encryption is required.
)
const (
featExtSESSIONRECOVERY byte = 0x01
featExtFEDAUTH byte = 0x02
featExtCOLUMNENCRYPTION byte = 0x04
featExtGLOBALTRANSACTIONS byte = 0x05
featExtAZURESQLSUPPORT byte = 0x08
featExtDATACLASSIFICATION byte = 0x09
featExtUTF8SUPPORT byte = 0x0A
featExtTERMINATOR byte = 0xFF
)
type tdsSession struct {
buf *tdsBuffer
loginAck loginAckStruct
@ -129,6 +141,7 @@ type tdsSession struct {
log optionalLogger
routedServer string
routedPort uint16
returnStatus *ReturnStatus
}
const (
@ -155,13 +168,13 @@ func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
var err error
w.BeginPacket(packPrelogin, false)
w.BeginPacket(packetType, false)
offset := uint16(5*len(fields) + 1)
keys := make(keySlice, 0, len(fields))
for k, _ := range fields {
for k := range fields {
keys = append(keys, k)
}
sort.Sort(keys)
@ -210,12 +223,15 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
if err != nil {
return nil, err
}
if packet_type != 4 {
return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
if packet_type != packReply {
return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE")
}
if len(struct_buf) == 0 {
return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte")
}
offset := 0
results := map[uint8][]byte{}
for true {
for {
rec_type := struct_buf[offset]
if rec_type == preloginTERMINATOR {
break
@ -240,6 +256,16 @@ const (
fIntSecurity = 0x80
)
// OptionFlags3
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
fChangePassword = 1
fSendYukonBinaryXML = 2
fUserInstance = 4
fUnknownCollationHandling = 8
fExtension = 0x10
)
// TypeFlags
const (
// 4 bits for fSQLType
@ -247,12 +273,6 @@ const (
fReadOnlyIntent = 32
)
// OptionFlags3
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
const (
fExtension = 0x10
)
type login struct {
TDSVersion uint32
PacketSize uint32
@ -295,7 +315,7 @@ func (e *featureExts) Add(f featureExt) error {
}
id := f.featureID()
if _, exists := e.features[id]; exists {
f := "Login error: Feature with ID '%v' is already present in FeatureExt block."
f := "login error: Feature with ID '%v' is already present in FeatureExt block"
return fmt.Errorf(f, id)
}
if e.features == nil {
@ -326,37 +346,63 @@ func (e featureExts) toBytes() []byte {
return d
}
type featureExtFedAuthSTS struct {
FedAuthEcho bool
// featureExtFedAuth tracks federated authentication state before and during login
type featureExtFedAuth struct {
// FedAuthLibrary is populated by the federated authentication provider.
FedAuthLibrary int
// ADALWorkflow is populated by the federated authentication provider.
ADALWorkflow byte
// FedAuthEcho is populated from the prelogin response
FedAuthEcho bool
// FedAuthToken is populated during login with the value from the provider.
FedAuthToken string
Nonce []byte
// Nonce is populated during login with the value from the provider.
Nonce []byte
// Signature is populated during login with the value from the server.
Signature []byte
}
func (e *featureExtFedAuthSTS) featureID() byte {
return 0x02
func (e *featureExtFedAuth) featureID() byte {
return featExtFEDAUTH
}
func (e *featureExtFedAuthSTS) toBytes() []byte {
func (e *featureExtFedAuth) toBytes() []byte {
if e == nil {
return nil
}
options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
options := byte(e.FedAuthLibrary) << 1
if e.FedAuthEcho {
options |= 1 // fFedAuthEcho
}
d := make([]byte, 5)
d[0] = options
// Feature extension format depends on the federated auth library.
// Options are described at
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
var d []byte
// looks like string in
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
tokenBytes := str2ucs2(e.FedAuthToken)
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
d = append(d, tokenBytes...)
switch e.FedAuthLibrary {
case fedAuthLibrarySecurityToken:
d = make([]byte, 5)
d[0] = options
if len(e.Nonce) == 32 {
d = append(d, e.Nonce...)
// looks like string in
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
tokenBytes := str2ucs2(e.FedAuthToken)
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
d = append(d, tokenBytes...)
if len(e.Nonce) == 32 {
d = append(d, e.Nonce...)
}
case fedAuthLibraryADAL:
d = []byte{options, e.ADALWorkflow}
}
return d
@ -418,7 +464,7 @@ func str2ucs2(s string) []byte {
func ucs22str(s []byte) (string, error) {
if len(s)%2 != 0 {
return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
return "", fmt.Errorf("illegal UCS2 string length: %d", len(s))
}
buf := make([]uint16, len(s)/2)
for i := 0; i < len(s); i += 2 {
@ -436,7 +482,7 @@ func manglePassword(password string) []byte {
}
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
func sendLogin(w *tdsBuffer, login login) error {
func sendLogin(w *tdsBuffer, login *login) error {
w.BeginPacket(packLogin7, false)
hostname := str2ucs2(login.HostName)
username := str2ucs2(login.UserName)
@ -572,6 +618,36 @@ func sendLogin(w *tdsBuffer, login login) error {
return w.FinishPacket()
}
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0
func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) {
fedauthtoken := str2ucs2(fedAuth.FedAuthToken)
tokenlen := len(fedauthtoken)
datalen := 4 + tokenlen + len(fedAuth.Nonce)
w.BeginPacket(packFedAuthToken, false)
err = binary.Write(w, binary.LittleEndian, uint32(datalen))
if err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, uint32(tokenlen))
if err != nil {
return
}
_, err = w.Write(fedauthtoken)
if err != nil {
return
}
_, err = w.Write(fedAuth.Nonce)
if err != nil {
return
}
return w.FinishPacket()
}
func readUcs2(r io.Reader, numchars int) (res string, err error) {
buf := make([]byte, numchars*2)
_, err = io.ReadFull(r, buf)
@ -770,12 +846,13 @@ type auth interface {
// use the first one that allows a connection.
func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
ip := net.ParseIP(p.host)
if ip == nil {
return nil, err
ip := net.ParseIP(p.host)
if ip == nil {
ips, err = net.LookupIP(p.host)
if err != nil {
return
}
} else {
ips = []net.IP{ip}
}
if len(ips) == 1 {
@ -802,7 +879,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
}
// Wait for either the *first* successful connection, or all the errors
wait_loop:
for i, _ := range ips {
for i := range ips {
select {
case conn = <-connChan:
// Got a connection to use, close any others
@ -824,12 +901,123 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
}
// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
if conn == nil {
f := "Unable to open tcp connection with host '%v:%v': %v"
f := "unable to open tcp connection with host '%v:%v': %v"
return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
}
return conn, err
}
func preparePreloginFields(p connectParams, fe *featureExtFedAuth) map[uint8][]byte {
instance_buf := []byte(p.instance)
instance_buf = append(instance_buf, 0) // zero terminate instance name
var encrypt byte
if p.disableEncryption {
encrypt = encryptNotSup
} else if p.encrypt {
encrypt = encryptOn
} else {
encrypt = encryptOff
}
fields := map[uint8][]byte{
preloginVERSION: {0, 0, 0, 0, 0, 0},
preloginENCRYPTION: {encrypt},
preloginINSTOPT: instance_buf,
preloginTHREADID: {0, 0, 0, 0},
preloginMARS: {0}, // MARS disabled
}
if fe.FedAuthLibrary != fedAuthLibraryReserved {
fields[preloginFEDAUTHREQUIRED] = []byte{1}
}
return fields
}
func interpretPreloginResponse(p connectParams, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) {
// If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication
// is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated
// authentication is allowed, while 1 means only federated authentication is allowed.
if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok {
if len(fedAuthSupport) != 1 {
return 0, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport))
}
// We need to be able to echo the value back to the server
fe.FedAuthEcho = fedAuthSupport[0] != 0
} else if fe.FedAuthLibrary != fedAuthLibraryReserved {
return 0, fmt.Errorf("Federated authentication is not supported by the server")
}
encryptBytes, ok := fields[preloginENCRYPTION]
if !ok {
return 0, fmt.Errorf("encrypt negotiation failed")
}
encrypt = encryptBytes[0]
if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
return 0, fmt.Errorf("server does not support encryption")
}
return
}
func prepareLogin(ctx context.Context, c *Connector, p connectParams, log optionalLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
l = &login{
TDSVersion: verTDS74,
PacketSize: packetSize,
Database: p.database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.workstation,
ServerName: p.host,
AppName: p.appname,
TypeFlags: p.typeFlags,
}
switch {
case fe.FedAuthLibrary == fedAuthLibrarySecurityToken:
if p.logFlags&logDebug != 0 {
log.Println("Starting federated authentication using security token")
}
fe.FedAuthToken, err = c.securityTokenProvider(ctx)
if err != nil {
if p.logFlags&logDebug != 0 {
log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err)
}
return nil, err
}
l.FeatureExt.Add(fe)
case fe.FedAuthLibrary == fedAuthLibraryADAL:
if p.logFlags&logDebug != 0 {
log.Println("Starting federated authentication using ADAL")
}
l.FeatureExt.Add(fe)
case auth != nil:
if p.logFlags&logDebug != 0 {
log.Println("Starting SSPI login")
}
l.SSPI, err = auth.InitialBytes()
if err != nil {
return nil, err
}
l.OptionFlags2 |= fIntSecurity
return l, nil
default:
// Default to SQL server authentication with user and password
l.UserName = p.user
l.Password = p.password
}
return l, nil
}
func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
dialCtx := ctx
if p.dial_timeout > 0 {
@ -842,24 +1030,24 @@ func connect(ctx context.Context, c *Connector, log optionalLogger, p connectPar
// both instance name and port specified
// when port is specified instance name is not used
// you should not provide instance name when you provide port
log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored");
log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
}
if p.instance != "" && p.port == 0 {
p.instance = strings.ToUpper(p.instance)
d := c.getDialer(&p)
instances, err := getInstances(dialCtx, d, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
f := "unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
}
strport, ok := instances[p.instance]["tcp"]
if !ok {
f := "No instance matching '%v' returned from host '%v'"
f := "no instance matching '%v' returned from host '%v'"
return nil, fmt.Errorf(f, p.instance, p.host)
}
port, err := strconv.ParseUint(strport, 0, 16)
if err != nil {
f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
f := "invalid tcp port returned from Sql Server Browser '%v': %v"
return nil, fmt.Errorf(f, strport, err.Error())
}
p.port = port
@ -880,25 +1068,14 @@ initiate_connection:
logFlags: p.logFlags,
}
instance_buf := []byte(p.instance)
instance_buf = append(instance_buf, 0) // zero terminate instance name
var encrypt byte
if p.disableEncryption {
encrypt = encryptNotSup
} else if p.encrypt {
encrypt = encryptOn
} else {
encrypt = encryptOff
}
fields := map[uint8][]byte{
preloginVERSION: {0, 0, 0, 0, 0, 0},
preloginENCRYPTION: {encrypt},
preloginINSTOPT: instance_buf,
preloginTHREADID: {0, 0, 0, 0},
preloginMARS: {0}, // MARS disabled
fedAuth := &featureExtFedAuth{
FedAuthLibrary: p.fedAuthLibrary,
ADALWorkflow: p.fedAuthADALWorkflow,
}
err = writePrelogin(outbuf, fields)
fields := preparePreloginFields(p, fedAuth)
err = writePrelogin(packPrelogin, outbuf, fields)
if err != nil {
return nil, err
}
@ -908,13 +1085,9 @@ initiate_connection:
return nil, err
}
encryptBytes, ok := fields[preloginENCRYPTION]
if !ok {
return nil, fmt.Errorf("Encrypt negotiation failed")
}
encrypt = encryptBytes[0]
if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
return nil, fmt.Errorf("Server does not support encryption")
encrypt, err := interpretPreloginResponse(p, fedAuth, fields)
if err != nil {
return nil, err
}
if encrypt != encryptNotSup {
@ -922,7 +1095,7 @@ initiate_connection:
if p.certificate != "" {
pem, err := ioutil.ReadFile(p.certificate)
if err != nil {
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
return nil, fmt.Errorf("cannot read certificate %q: %v", p.certificate, err)
}
certs := x509.NewCertPool()
certs.AppendCertsFromPEM(pem)
@ -954,54 +1127,46 @@ initiate_connection:
}
}
login := login{
TDSVersion: verTDS74,
PacketSize: uint32(outbuf.PackageSize()),
Database: p.database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.workstation,
ServerName: p.host,
AppName: p.appname,
TypeFlags: p.typeFlags,
}
auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
switch {
case p.fedAuthAccessToken != "": // accesstoken ignores user/password
featurext := &featureExtFedAuthSTS{
FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
FedAuthToken: p.fedAuthAccessToken,
Nonce: fields[preloginNONCEOPT],
}
login.FeatureExt.Add(featurext)
case authOk:
login.SSPI, err = auth.InitialBytes()
if err != nil {
return nil, err
}
login.OptionFlags2 |= fIntSecurity
if authOk {
defer auth.Free()
default:
login.UserName = p.user
login.Password = p.password
} else {
auth = nil
}
login, err := prepareLogin(ctx, c, p, log, auth, fedAuth, uint32(outbuf.PackageSize()))
if err != nil {
return nil, err
}
err = sendLogin(outbuf, login)
if err != nil {
return nil, err
}
// processing login response
success := false
for {
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), &sess, tokchan, nil)
for tok := range tokchan {
// Loop until a packet containing a login acknowledgement is received.
// SSPI and federated authentication scenarios may require multiple
// packet exchanges to complete the login sequence.
for loginAck := false; !loginAck; {
reader := startReading(&sess, ctx, nil)
for {
tok, err := reader.nextToken()
if err != nil {
return nil, err
}
if tok == nil {
break
}
switch token := tok.(type) {
case sspiMsg:
sspi_msg, err := auth.NextBytes(token)
if err != nil {
return nil, err
}
if sspi_msg != nil && len(sspi_msg) > 0 {
if len(sspi_msg) > 0 {
outbuf.BeginPacket(packSSPIMessage, false)
_, err = outbuf.Write(sspi_msg)
if err != nil {
@ -1013,23 +1178,41 @@ initiate_connection:
}
sspi_msg = nil
}
// TODO: for Live ID authentication it may be necessary to
// compare fedAuth.Nonce == token.Nonce and keep track of signature
//case fedAuthAckStruct:
//fedAuth.Signature = token.Signature
case fedAuthInfoStruct:
// For ADAL workflows this contains the STS URL and server SPN.
// If received outside of an ADAL workflow, ignore.
if c == nil || c.adalTokenProvider == nil {
continue
}
// Request the AD token given the server SPN and STS URL
fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL)
if err != nil {
return nil, err
}
// Now need to send the token as a FEDINFO packet
err = sendFedAuthInfo(outbuf, fedAuth)
if err != nil {
return nil, err
}
case loginAckStruct:
success = true
sess.loginAck = token
case error:
return nil, fmt.Errorf("Login error: %s", token.Error())
loginAck = true
case doneStruct:
if token.isError() {
return nil, fmt.Errorf("Login error: %s", token.getError())
return nil, fmt.Errorf("login error: %s", token.getError())
}
goto loginEnd
case error:
return nil, fmt.Errorf("login error: %s", token.Error())
}
}
}
loginEnd:
if !success {
return nil, fmt.Errorf("Login failed")
}
if sess.routedServer != "" {
toconn.Close()
p.host = sess.routedServer
@ -1041,3 +1224,9 @@ loginEnd:
}
return &sess, nil
}
func (sess *tdsSession) setReturnStatus(status ReturnStatus) {
if sess.returnStatus != nil {
*sess.returnStatus = status
}
}

View file

@ -6,12 +6,11 @@ import (
"errors"
"fmt"
"io"
"net"
"io/ioutil"
"strconv"
"strings"
)
//go:generate stringer -type token
//go:generate go run golang.org/x/tools/cmd/stringer -type token
type token byte
@ -29,6 +28,7 @@ const (
tokenNbcRow token = 210 // 0xd2
tokenEnvChange token = 227 // 0xE3
tokenSSPI token = 237 // 0xED
tokenFedAuthInfo token = 238 // 0xEE
tokenDone token = 253 // 0xFD
tokenDoneProc token = 254
tokenDoneInProc token = 255
@ -70,6 +70,11 @@ const (
envRouting = 20
)
const (
fedAuthInfoSTSURL = 0x01
fedAuthInfoSPN = 0x02
)
// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
@ -105,26 +110,6 @@ func (d doneStruct) getError() Error {
type doneInProcStruct doneStruct
var doneFlags2str = map[uint16]string{
doneFinal: "final",
doneMore: "more",
doneError: "error",
doneInxact: "inxact",
doneCount: "count",
doneAttn: "attn",
doneSrvError: "srverror",
}
func doneFlags2Str(flags uint16) string {
strs := make([]string, 0, len(doneFlags2str))
for flag, tag := range doneFlags2str {
if flags&flag != 0 {
strs = append(strs, tag)
}
}
return strings.Join(strs, "|")
}
// ENVCHANGE stream
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
func processEnvChg(sess *tdsSession) {
@ -380,9 +365,8 @@ func processEnvChg(sess *tdsSession) {
default:
// ignore rest of records because we don't know how to skip those
sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype)
break
return
}
}
}
@ -425,6 +409,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg {
return sspiMsg(buf)
}
type fedAuthInfoStruct struct {
STSURL string
ServerSPN string
}
type fedAuthInfoOpt struct {
fedAuthInfoID byte
dataLength, dataOffset uint32
}
func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct {
size := r.uint32()
var STSURL, SPN string
var err error
// Each fedAuthInfoOpt is one byte to indicate the info ID,
// then a four byte offset and a four byte length.
count := r.uint32()
offset := uint32(4)
opts := make([]fedAuthInfoOpt, count)
for i := uint32(0); i < count; i++ {
fedAuthInfoID := r.byte()
dataLength := r.uint32()
dataOffset := r.uint32()
offset += 1 + 4 + 4
opts[i] = fedAuthInfoOpt{
fedAuthInfoID: fedAuthInfoID,
dataLength: dataLength,
dataOffset: dataOffset,
}
}
data := make([]byte, size-offset)
r.ReadFull(data)
for i := uint32(0); i < count; i++ {
if opts[i].dataOffset < offset {
badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d",
opts[i].dataOffset, offset)
// returns via panic
}
if opts[i].dataOffset+opts[i].dataLength > size {
badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d",
opts[i].dataOffset+opts[i].dataLength, size)
// returns via panic
}
optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength]
switch opts[i].fedAuthInfoID {
case fedAuthInfoSTSURL:
STSURL, err = ucs22str(optData)
case fedAuthInfoSPN:
SPN, err = ucs22str(optData)
default:
err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID))
}
if err != nil {
badStreamPanic(err)
}
}
return fedAuthInfoStruct{
STSURL: STSURL,
ServerSPN: SPN,
}
}
type loginAckStruct struct {
Interface uint8
TDSVersion uint32
@ -449,19 +505,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct {
}
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a
func parseFeatureExtAck(r *tdsBuffer) {
// at most 1 featureAck per feature in featureExt
// go-mssqldb will add at most 1 feature, the spec defines 7 different features
for i := 0; i < 8; i++ {
featureID := r.byte() // FeatureID
if featureID == 0xff {
return
type fedAuthAckStruct struct {
Nonce []byte
Signature []byte
}
func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
ack := map[byte]interface{}{}
for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() {
length := r.uint32()
switch feature {
case featExtFEDAUTH:
// In theory we need to know the federated authentication library to
// know how to parse, but the alternatives provide compatible structures.
fedAuthAck := fedAuthAckStruct{}
if length >= 32 {
fedAuthAck.Nonce = make([]byte, 32)
r.ReadFull(fedAuthAck.Nonce)
length -= 32
}
if length >= 32 {
fedAuthAck.Signature = make([]byte, 32)
r.ReadFull(fedAuthAck.Signature)
length -= 32
}
ack[feature] = fedAuthAck
}
// Skip unprocessed bytes
if length > 0 {
io.CopyN(ioutil.Discard, r, int64(length))
}
size := r.uint32() // FeatureAckDataLen
d := make([]byte, size)
r.ReadFull(d)
}
panic("parsed more than 7 featureAck's, protocol implementation error?")
return ack
}
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
@ -579,7 +659,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
}
var columns []columnStruct
errs := make([]Error, 0, 5)
for {
for tokens := 0; ; tokens += 1 {
token := token(sess.buf.byte())
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got token %v", token)
@ -588,6 +668,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
case tokenSSPI:
ch <- parseSSPIMsg(sess.buf)
return
case tokenFedAuthInfo:
ch <- parseFedAuthInfo(sess.buf)
return
case tokenReturnStatus:
returnStatus := parseReturnStatus(sess.buf)
ch <- returnStatus
@ -595,7 +678,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
loginAck := parseLoginAck(sess.buf)
ch <- loginAck
case tokenFeatureExtAck:
parseFeatureExtAck(sess.buf)
featureExtAck := parseFeatureExtAck(sess.buf)
ch <- featureExtAck
case tokenOrder:
order := parseOrder(sess.buf)
ch <- order
@ -670,158 +754,137 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
}
}
type parseRespIter byte
const (
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
parseRespIterNext // Fetch the next token.
parseRespIterDone // Done with parsing the response.
)
type parseRespState byte
const (
parseRespStateNormal parseRespState = iota // Normal response state.
parseRespStateCancel // Query is canceled, wait for server to confirm.
parseRespStateClosing // Waiting for tokens to come through.
)
type parseResp struct {
sess *tdsSession
ctxDone <-chan struct{}
state parseRespState
cancelError error
type tokenProcessor struct {
tokChan chan tokenStruct
ctx context.Context
sess *tdsSession
outs map[string]interface{}
lastRow []interface{}
rowCount int64
firstError error
}
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
if err := sendAttention(ts.sess.buf); err != nil {
ts.dlogf("failed to send attention signal %v", err)
ch <- err
return parseRespIterDone
}
ts.state = parseRespStateCancel
return parseRespIterContinue
}
func (ts *parseResp) dlog(msg string) {
// logging from goroutine is disabled to prevent
// data race detection from firing
// The race is probably happening when
// test logger changes between tests.
/*if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Println(msg)
}*/
}
func (ts *parseResp) dlogf(f string, v ...interface{}) {
/*if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Printf(f, v...)
}*/
}
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
switch ts.state {
default:
panic("unknown state")
case parseRespStateNormal:
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished")
return parseRespIterDone
}
if err, ok := tok.(net.Error); ok && err.Timeout() {
ts.cancelError = err
ts.dlog("got timeout error, sending attention signal to server")
return ts.sendAttention(ch)
}
// Pass the token along.
ch <- tok
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.dlog("got cancel message, sending attention signal to server")
return ts.sendAttention(ch)
}
case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished but waiting for attention ack")
return parseRespIterNext
}
switch tok := tok.(type) {
default:
// Ignore all other tokens while waiting.
// The TDS spec says other tokens may arrive after an attention
// signal is sent. Ignore these tokens and continue looking for
// a DONE with attention confirm mark.
case doneStruct:
if tok.Status&doneAttn != 0 {
ts.dlog("got cancellation confirmation from server")
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
} else {
ch <- ctx.Err()
}
return parseRespIterDone
}
// If an error happens during cancel, pass it along and just stop.
// We are uncertain to receive more tokens.
case error:
ch <- tok
ts.state = parseRespStateClosing
}
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.state = parseRespStateClosing
return parseRespIterContinue
}
case parseRespStateClosing: // Wait for current token chan to close.
if _, ok := <-tokChan; !ok {
ts.dlog("response finished")
return parseRespIterDone
}
return parseRespIterContinue
}
}
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
ts := &parseResp{
func startReading(sess *tdsSession, ctx context.Context, outs map[string]interface{}) *tokenProcessor {
tokChan := make(chan tokenStruct, 5)
go processSingleResponse(sess, tokChan, outs)
return &tokenProcessor{
tokChan: tokChan,
ctx: ctx,
sess: sess,
ctxDone: ctx.Done(),
outs: outs,
}
defer func() {
// Ensure any remaining error is piped through
// or the query may look like it executed when it actually failed.
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
}
close(ch)
}()
}
// Loop over multiple responses.
func (t *tokenProcessor) iterateResponse() error {
for {
ts.dlog("initiating response reading")
tokChan := make(chan tokenStruct)
go processSingleResponse(sess, tokChan, outs)
// Loop over multiple tokens in response.
tokensLoop:
for {
switch ts.iter(ctx, ch, tokChan) {
case parseRespIterContinue:
// Nothing, continue to next token.
case parseRespIterNext:
break tokensLoop
case parseRespIterDone:
return
tok, err := t.nextToken()
if err == nil {
if tok == nil {
return t.firstError
} else {
switch token := tok.(type) {
case []columnStruct:
t.sess.columns = token
case []interface{}:
t.lastRow = token
case doneInProcStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
}
if token.isError() && t.firstError == nil {
t.firstError = token.getError()
}
case ReturnStatus:
t.sess.setReturnStatus(token)
/*case error:
if resultError == nil {
resultError = token
}*/
}
}
} else {
return err
}
}
}
func (t tokenProcessor) nextToken() (tokenStruct, error) {
// we do this separate non-blocking check on token channel to
// prioritize it over cancellation channel
select {
case tok, more := <-t.tokChan:
err, more := tok.(error)
if more {
// this is an error and not a token
return nil, err
} else {
return tok, nil
}
default:
// there are no tokens on the channel, will need to wait
}
select {
case tok, more := <-t.tokChan:
if more {
err, ok := tok.(error)
if ok {
// this is an error and not a token
return nil, err
} else {
return tok, nil
}
} else {
// completed reading response
return nil, nil
}
case <-t.ctx.Done():
if err := sendAttention(t.sess.buf); err != nil {
// unable to send attention, current connection is bad
// notify caller and close channel
return nil, err
}
// now the server should send cancellation confirmation
// it is possible that we already received full response
// just before we sent cancellation request
// in this case current response would not contain confirmation
// and we would need to read one more response
// first lets finish reading current response and look
// for confirmation in it
if readCancelConfirmation(t.tokChan) {
// we got confirmation in current response
return nil, t.ctx.Err()
}
// we did not get cancellation confirmation in the current response
// read one more response, it must be there
t.tokChan = make(chan tokenStruct, 5)
go processSingleResponse(t.sess, t.tokChan, t.outs)
if readCancelConfirmation(t.tokChan) {
return nil, t.ctx.Err()
}
// we did not get cancellation confirmation, something is not
// right, this connection is not usable anymore
return nil, errors.New("did not get cancellation confirmation from the server")
}
}
func readCancelConfirmation(tokChan chan tokenStruct) bool {
for tok := range tokChan {
switch tok := tok.(type) {
default:
// just skip token
case doneStruct:
if tok.Status&doneAttn != 0 {
// got cancellation confirmation, exit
return true
}
}
}
return false
}

View file

@ -1,29 +1,24 @@
// Code generated by "stringer -type token"; DO NOT EDIT
// Code generated by "stringer -type token"; DO NOT EDIT.
package mssql
import "fmt"
import "strconv"
const (
_token_name_0 = "tokenReturnStatus"
_token_name_1 = "tokenColMetadata"
_token_name_2 = "tokenOrdertokenErrortokenInfo"
_token_name_3 = "tokenLoginAck"
_token_name_4 = "tokenRowtokenNbcRow"
_token_name_5 = "tokenEnvChange"
_token_name_6 = "tokenSSPI"
_token_name_7 = "tokenDonetokenDoneProctokenDoneInProc"
_token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck"
_token_name_3 = "tokenRowtokenNbcRow"
_token_name_4 = "tokenEnvChange"
_token_name_5 = "tokenSSPItokenFedAuthInfo"
_token_name_6 = "tokenDonetokenDoneProctokenDoneInProc"
)
var (
_token_index_0 = [...]uint8{0, 17}
_token_index_1 = [...]uint8{0, 16}
_token_index_2 = [...]uint8{0, 10, 20, 29}
_token_index_3 = [...]uint8{0, 13}
_token_index_4 = [...]uint8{0, 8, 19}
_token_index_5 = [...]uint8{0, 14}
_token_index_6 = [...]uint8{0, 9}
_token_index_7 = [...]uint8{0, 9, 22, 37}
_token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76}
_token_index_3 = [...]uint8{0, 8, 19}
_token_index_5 = [...]uint8{0, 9, 25}
_token_index_6 = [...]uint8{0, 9, 22, 37}
)
func (i token) String() string {
@ -32,22 +27,21 @@ func (i token) String() string {
return _token_name_0
case i == 129:
return _token_name_1
case 169 <= i && i <= 171:
case 169 <= i && i <= 174:
i -= 169
return _token_name_2[_token_index_2[i]:_token_index_2[i+1]]
case i == 173:
return _token_name_3
case 209 <= i && i <= 210:
i -= 209
return _token_name_4[_token_index_4[i]:_token_index_4[i+1]]
return _token_name_3[_token_index_3[i]:_token_index_3[i+1]]
case i == 227:
return _token_name_5
case i == 237:
return _token_name_6
return _token_name_4
case 237 <= i && i <= 238:
i -= 237
return _token_name_5[_token_index_5[i]:_token_index_5[i+1]]
case 253 <= i && i <= 255:
i -= 253
return _token_name_7[_token_index_7[i]:_token_index_7[i+1]]
return _token_name_6[_token_index_6[i]:_token_index_6[i+1]]
default:
return fmt.Sprintf("token(%d)", i)
return "token(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View file

@ -21,11 +21,11 @@ type isoLevel uint8
const (
isolationUseCurrent isoLevel = 0
isolationReadUncommited = 1
isolationReadCommited = 2
isolationRepeatableRead = 3
isolationSerializable = 4
isolationSnapshot = 5
isolationReadUncommited isoLevel = 1
isolationReadCommited isoLevel = 2
isolationRepeatableRead isoLevel = 3
isolationSerializable isoLevel = 4
isolationSnapshot isoLevel = 5
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {

View file

@ -4,6 +4,7 @@ package mssql
import (
"bytes"
"database/sql"
"encoding/binary"
"errors"
"fmt"
@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
field := refStr.Field(fieldIdx)
tvpVal := field.Interface()
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
continue
}
valOf := reflect.ValueOf(tvpVal)
elemKind := field.Kind()
if elemKind == reflect.Ptr && valOf.IsNil() {
@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
defaultValues = append(defaultValues, v.Interface())
continue
}
defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
}
if columnCount-len(tvpFieldIndexes) == columnCount {
@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) {
}
splitVal := strings.Split(tvpName, ".")
if len(splitVal) > 2 {
return "", "", errors.New("wrong tvp name")
return "", "", ErrorObjectName
}
const (
openSquareBrackets = "["
closeSquareBrackets = "]"
)
if len(splitVal) == 2 {
res := make([]string, 2)
for key, value := range splitVal {
tmp := strings.Replace(value, "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
tmp := strings.Replace(value, openSquareBrackets, "", -1)
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
res[key] = tmp
}
return res[0], res[1], nil
}
tmp := strings.Replace(splitVal[0], "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1)
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
return "", tmp, nil
}
@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) {
func getCountSQLSeparators(str string) int {
return strings.Count(str, sqlSeparator)
}
// verify types https://golang.org/pkg/database/sql/
func (tvp TVP) createZeroType(fieldVal interface{}) interface{} {
const (
defaultBool = false
defaultFloat64 = float64(0)
defaultInt64 = int64(0)
defaultString = ""
)
switch fieldVal.(type) {
case sql.NullBool:
return defaultBool
case sql.NullFloat64:
return defaultFloat64
case sql.NullInt64:
return defaultInt64
case sql.NullString:
return defaultString
}
return fieldVal
}
// verify types https://golang.org/pkg/database/sql/
func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool {
const (
defaultNull = uint8(0)
)
switch val := tvpVal.(type) {
case sql.NullBool:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullFloat64:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullInt64:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullString:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
return true
}
}
return false
}

View file

@ -665,7 +665,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
for true {
for {
chunksize := r.uint32()
if chunksize == 0 {
break
@ -690,6 +690,10 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
}
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
if buf == nil {
err = binary.Write(w, binary.LittleEndian, uint64(_PLP_NULL))
return
}
if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil {
return
}
@ -807,7 +811,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) {
default:
badStreamPanicf("Invalid type %d", ti.TypeId)
}
return
}
func decodeMoney(buf []byte) []byte {
@ -834,8 +837,7 @@ func decodeGuid(buf []byte) []byte {
}
func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte {
var sign uint8
sign = buf[0]
sign := buf[0]
var dec decimal.Decimal
dec.SetPositive(sign != 0)
dec.SetPrec(prec)
@ -1187,7 +1189,7 @@ func makeDecl(ti typeInfo) string {
return fmt.Sprintf("char(%d)", ti.Size)
case typeBigVarChar, typeVarChar:
if ti.Size > 8000 || ti.Size == 0 {
return fmt.Sprintf("varchar(max)")
return "varchar(max)"
} else {
return fmt.Sprintf("varchar(%d)", ti.Size)
}