1
0
Fork 0
forked from forgejo/forgejo

Integrate public as bindata optionally (#293)

* Dropped unused codekit config

* Integrated dynamic and static bindata for public

* Ignore public bindata

* Add a general generate make task

* Integrated flexible public assets into web command

* Updated vendoring, added all missiong govendor deps

* Made the linter happy with the bindata and dynamic code

* Moved public bindata definition to modules directory

* Ignoring the new bindata path now

* Updated to the new public modules import path

* Updated public bindata command and drop the new prefix
This commit is contained in:
Thomas Boerger 2016-11-29 17:26:36 +01:00 committed by Lunny Xiao
parent 4680c349dd
commit b6a95a8cb3
691 changed files with 305318 additions and 1272 deletions

130
vendor/github.com/pingcap/tidb/evaluator/builtin.go generated vendored Normal file
View file

@ -0,0 +1,130 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/util/types"
)
// OldFunc is for a old builtin function.
type OldFunc struct {
// F is the specific calling function.
F func([]interface{}, context.Context) (interface{}, error)
// MinArgs is the minimal arguments needed,
MinArgs int
// MaxArgs is the maximal arguments needed, -1 for infinity.
MaxArgs int
// IsStatic shows whether this function can be called statically.
IsStatic bool
// IsAggregate represents whether this function is an aggregate function or not.
IsAggregate bool
}
// Func is for a builtin function.
type Func struct {
// F is the specific calling function.
F func([]types.Datum, context.Context) (types.Datum, error)
// MinArgs is the minimal arguments needed,
MinArgs int
// MaxArgs is the maximal arguments needed, -1 for infinity.
MaxArgs int
}
// OldFuncs holds all has old registered builtin functions.
var OldFuncs = map[string]OldFunc{
// control functions
"if": {builtinIf, 3, 3, true, false},
"ifnull": {builtinIfNull, 2, 2, true, false},
"nullif": {builtinNullIf, 2, 2, true, false},
// string functions
"replace": {builtinReplace, 3, 3, true, false},
"strcmp": {builtinStrcmp, 2, 2, true, false},
"convert": {builtinConvert, 2, 2, true, false},
"substring": {builtinSubstring, 2, 3, true, false},
"substring_index": {builtinSubstringIndex, 3, 3, true, false},
"locate": {builtinLocate, 2, 3, true, false},
"trim": {builtinTrim, 1, 3, true, false},
// information functions
"current_user": {builtinCurrentUser, 0, 0, false, false},
"database": {builtinDatabase, 0, 0, false, false},
"found_rows": {builtinFoundRows, 0, 0, false, false},
"user": {builtinUser, 0, 0, false, false},
"connection_id": {builtinConnectionID, 0, 0, true, false},
"version": {builtinVersion, 0, 0, true, false},
}
// Funcs holds all registered builtin functions.
var Funcs = map[string]Func{
// common functions
"coalesce": {builtinCoalesce, 1, -1},
// math functions
"abs": {builtinAbs, 1, 1},
"pow": {builtinPow, 2, 2},
"power": {builtinPow, 2, 2},
"rand": {builtinRand, 0, 1},
// time functions
"curdate": {builtinCurrentDate, 0, 0},
"current_date": {builtinCurrentDate, 0, 0},
"current_time": {builtinCurrentTime, 0, 1},
"current_timestamp": {builtinNow, 0, 1},
"curtime": {builtinCurrentTime, 0, 1},
"date": {builtinDate, 1, 1},
"day": {builtinDay, 1, 1},
"dayname": {builtinDayName, 1, 1},
"dayofmonth": {builtinDayOfMonth, 1, 1},
"dayofweek": {builtinDayOfWeek, 1, 1},
"dayofyear": {builtinDayOfYear, 1, 1},
"hour": {builtinHour, 1, 1},
"microsecond": {builtinMicroSecond, 1, 1},
"minute": {builtinMinute, 1, 1},
"month": {builtinMonth, 1, 1},
"now": {builtinNow, 0, 1},
"second": {builtinSecond, 1, 1},
"sysdate": {builtinSysDate, 0, 1},
"week": {builtinWeek, 1, 2},
"weekday": {builtinWeekDay, 1, 1},
"weekofyear": {builtinWeekOfYear, 1, 1},
"year": {builtinYear, 1, 1},
"yearweek": {builtinYearWeek, 1, 2},
"extract": {builtinExtract, 2, 2},
"date_arith": {builtinDateArith, 3, 3},
// string functions
"concat": {builtinConcat, 1, -1},
"concat_ws": {builtinConcatWS, 2, -1},
"left": {builtinLeft, 2, 2},
"length": {builtinLength, 1, 1},
"lower": {builtinLower, 1, 1},
"repeat": {builtinRepeat, 2, 2},
"upper": {builtinUpper, 1, 1},
}
// See: http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce
func builtinCoalesce(args []types.Datum, ctx context.Context) (d types.Datum, err error) {
for _, d = range args {
if d.Kind() != types.KindNull {
return d, nil
}
}
return d, nil
}

View file

@ -0,0 +1,76 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/util/types"
)
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_if
func builtinIf(args []interface{}, _ context.Context) (interface{}, error) {
// if(expr1, expr2, expr3)
// if expr1 is true, return expr2, otherwise, return expr3
v1 := args[0]
v2 := args[1]
v3 := args[2]
if v1 == nil {
return v3, nil
}
b, err := types.ToBool(v1)
if err != nil {
return nil, err
}
// TODO: check return type, must be numeric or string
if b == 1 {
return v2, nil
}
return v3, nil
}
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_ifnull
func builtinIfNull(args []interface{}, _ context.Context) (interface{}, error) {
// ifnull(expr1, expr2)
// if expr1 is not null, return expr1, otherwise, return expr2
v1 := args[0]
v2 := args[1]
if v1 != nil {
return v1, nil
}
return v2, nil
}
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_nullif
func builtinNullIf(args []interface{}, _ context.Context) (interface{}, error) {
// nullif(expr1, expr2)
// returns null if expr1 = expr2 is true, otherwise returns expr1
v1 := args[0]
v2 := args[1]
if v1 == nil || v2 == nil {
return v1, nil
}
if n, err := types.Compare(v1, v2); err != nil || n == 0 {
return nil, err
}
return v1, nil
}

View file

@ -0,0 +1,78 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/db"
"github.com/pingcap/tidb/sessionctx/variable"
)
// See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html
func builtinDatabase(args []interface{}, ctx context.Context) (v interface{}, err error) {
d := db.GetCurrentSchema(ctx)
if d == "" {
return nil, nil
}
return d, nil
}
func builtinFoundRows(arg []interface{}, ctx context.Context) (interface{}, error) {
data := variable.GetSessionVars(ctx)
if data == nil {
return nil, errors.Errorf("Missing session variable when evalue builtin")
}
return data.FoundRows, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user
// TODO: The value of CURRENT_USER() can differ from the value of USER(). We will finish this after we support grant tables.
func builtinCurrentUser(args []interface{}, ctx context.Context) (v interface{}, err error) {
data := variable.GetSessionVars(ctx)
if data == nil {
return nil, errors.Errorf("Missing session variable when evalue builtin")
}
return data.User, nil
}
func builtinUser(args []interface{}, ctx context.Context) (v interface{}, err error) {
data := variable.GetSessionVars(ctx)
if data == nil {
return nil, errors.Errorf("Missing session variable when evalue builtin")
}
return data.User, nil
}
func builtinConnectionID(args []interface{}, ctx context.Context) (v interface{}, err error) {
data := variable.GetSessionVars(ctx)
if data == nil {
return nil, errors.Errorf("Missing session variable when evalue builtin")
}
return data.ConnectionID, nil
}
func builtinVersion(args []interface{}, ctx context.Context) (v interface{}, err error) {
return mysql.ServerVersion, nil
}

View file

@ -0,0 +1,83 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"math"
"math/rand"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/util/types"
)
// see https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html
func builtinAbs(args []types.Datum, _ context.Context) (d types.Datum, err error) {
d = args[0]
switch d.Kind() {
case types.KindNull:
return d, nil
case types.KindUint64:
return d, nil
case types.KindInt64:
iv := d.GetInt64()
if iv >= 0 {
d.SetInt64(iv)
return d, nil
}
d.SetInt64(-iv)
return d, nil
default:
// we will try to convert other types to float
// TODO: if time has no precision, it will be a integer
f, err := d.ToFloat64()
d.SetFloat64(math.Abs(f))
return d, errors.Trace(err)
}
}
func builtinRand(args []types.Datum, _ context.Context) (d types.Datum, err error) {
if len(args) == 1 && args[0].Kind() != types.KindNull {
seed, err := args[0].ToInt64()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
rand.Seed(seed)
}
d.SetFloat64(rand.Float64())
return d, nil
}
func builtinPow(args []types.Datum, _ context.Context) (d types.Datum, err error) {
x, err := args[0].ToFloat64()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
y, err := args[1].ToFloat64()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
d.SetFloat64(math.Pow(x, y))
return d, nil
}

View file

@ -0,0 +1,476 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"fmt"
"strings"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
"golang.org/x/text/transform"
)
// https://dev.mysql.com/doc/refman/5.7/en/string-functions.html
func builtinLength(args []types.Datum, _ context.Context) (d types.Datum, err error) {
switch args[0].Kind() {
case types.KindNull:
d.SetNull()
return d, nil
default:
s, err := args[0].ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
d.SetInt64(int64(len(s)))
return d, nil
}
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func builtinConcat(args []types.Datum, _ context.Context) (d types.Datum, err error) {
var s []byte
for _, a := range args {
if a.Kind() == types.KindNull {
d.SetNull()
return d, nil
}
var ss string
ss, err = a.ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
s = append(s, []byte(ss)...)
}
d.SetBytesAsString(s)
return d, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat-ws
func builtinConcatWS(args []types.Datum, _ context.Context) (d types.Datum, err error) {
var sep string
s := make([]string, 0, len(args))
for i, a := range args {
if a.Kind() == types.KindNull {
if i == 0 {
d.SetNull()
return d, nil
}
continue
}
ss, err := a.ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
if i == 0 {
sep = ss
continue
}
s = append(s, ss)
}
d.SetString(strings.Join(s, sep))
return d, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_left
func builtinLeft(args []types.Datum, _ context.Context) (d types.Datum, err error) {
str, err := args[0].ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
length, err := args[1].ToInt64()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
l := int(length)
if l < 0 {
l = 0
} else if l > len(str) {
l = len(str)
}
d.SetString(str[:l])
return d, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
func builtinRepeat(args []types.Datum, _ context.Context) (d types.Datum, err error) {
str, err := args[0].ToString()
if err != nil {
d.SetNull()
return d, err
}
ch := fmt.Sprintf("%v", str)
num := 0
x := args[1]
switch x.Kind() {
case types.KindInt64:
num = int(x.GetInt64())
case types.KindUint64:
num = int(x.GetUint64())
}
if num < 1 {
d.SetString("")
return d, nil
}
d.SetString(strings.Repeat(ch, num))
return d, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_lower
func builtinLower(args []types.Datum, _ context.Context) (d types.Datum, err error) {
x := args[0]
switch x.Kind() {
case types.KindNull:
d.SetNull()
return d, nil
default:
s, err := x.ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
d.SetString(strings.ToLower(s))
return d, nil
}
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_upper
func builtinUpper(args []types.Datum, _ context.Context) (d types.Datum, err error) {
x := args[0]
switch x.Kind() {
case types.KindNull:
d.SetNull()
return d, nil
default:
s, err := x.ToString()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
d.SetString(strings.ToUpper(s))
return d, nil
}
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html
func builtinStrcmp(args []interface{}, _ context.Context) (interface{}, error) {
if args[0] == nil || args[1] == nil {
return nil, nil
}
left, err := types.ToString(args[0])
if err != nil {
return nil, errors.Trace(err)
}
right, err := types.ToString(args[1])
if err != nil {
return nil, errors.Trace(err)
}
res := types.CompareString(left, right)
return res, nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_replace
func builtinReplace(args []interface{}, _ context.Context) (interface{}, error) {
for _, arg := range args {
if arg == nil {
return nil, nil
}
}
str, err := types.ToString(args[0])
if err != nil {
return nil, errors.Trace(err)
}
oldStr, err := types.ToString(args[1])
if err != nil {
return nil, errors.Trace(err)
}
newStr, err := types.ToString(args[2])
if err != nil {
return nil, errors.Trace(err)
}
return strings.Replace(str, oldStr, newStr, -1), nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert
func builtinConvert(args []interface{}, _ context.Context) (interface{}, error) {
value := args[0]
Charset := args[1].(string)
// Casting nil to any type returns nil
if value == nil {
return nil, nil
}
str, ok := value.(string)
if !ok {
return nil, nil
}
if strings.ToLower(Charset) == "ascii" {
return value, nil
} else if strings.ToLower(Charset) == "utf8mb4" {
return value, nil
}
encoding, _ := charset.Lookup(Charset)
if encoding == nil {
return nil, errors.Errorf("unknown encoding: %s", Charset)
}
target, _, err := transform.String(encoding.NewDecoder(), str)
if err != nil {
log.Errorf("Convert %s to %s with error: %v", str, Charset, err)
return nil, errors.Trace(err)
}
return target, nil
}
func builtinSubstring(args []interface{}, _ context.Context) (interface{}, error) {
// The meaning of the elements of args.
// arg[0] -> StrExpr
// arg[1] -> Pos
// arg[2] -> Len (Optional)
str, err := types.ToString(args[0])
if err != nil {
return nil, errors.Errorf("Substring invalid args, need string but get %T", args[0])
}
t := args[1]
p, ok := t.(int64)
if !ok {
return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t)
}
pos := int(p)
length := -1
if len(args) == 3 {
t = args[2]
p, ok = t.(int64)
if !ok {
return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t)
}
length = int(p)
}
// The forms without a len argument return a substring from string str starting at position pos.
// The forms with a len argument return a substring len characters long from string str, starting at position pos.
// The forms that use FROM are standard SQL syntax. It is also possible to use a negative value for pos.
// In this case, the beginning of the substring is pos characters from the end of the string, rather than the beginning.
// A negative value may be used for pos in any of the forms of this function.
if pos < 0 {
pos = len(str) + pos
} else {
pos--
}
if pos > len(str) || pos <= 0 {
pos = len(str)
}
end := len(str)
if length != -1 {
end = pos + length
}
if end > len(str) {
end = len(str)
}
return str[pos:end], nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index
func builtinSubstringIndex(args []interface{}, _ context.Context) (interface{}, error) {
// The meaning of the elements of args.
// args[0] -> StrExpr
// args[1] -> Delim
// args[2] -> Count
fs := args[0]
str, err := types.ToString(fs)
if err != nil {
return nil, errors.Errorf("Substring_Index invalid args, need string but get %T", fs)
}
t := args[1]
delim, err := types.ToString(t)
if err != nil {
return nil, errors.Errorf("Substring_Index invalid delim, need string but get %T", t)
}
if len(delim) == 0 {
return "", nil
}
t = args[2]
c, err := types.ToInt64(t)
if err != nil {
return nil, errors.Trace(err)
}
count := int(c)
strs := strings.Split(str, delim)
var (
start = 0
end = len(strs)
)
if count > 0 {
// If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
if count < end {
end = count
}
} else {
// If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
count = -count
if count < end {
start = end - count
}
}
substrs := strs[start:end]
return strings.Join(substrs, delim), nil
}
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func builtinLocate(args []interface{}, _ context.Context) (interface{}, error) {
// The meaning of the elements of args.
// args[0] -> SubStr
// args[1] -> Str
// args[2] -> Pos
// eval str
fs := args[1]
if fs == nil {
return nil, nil
}
str, err := types.ToString(fs)
if err != nil {
return nil, errors.Trace(err)
}
// eval substr
fs = args[0]
if fs == nil {
return nil, nil
}
subStr, err := types.ToString(fs)
if err != nil {
return nil, errors.Trace(err)
}
// eval pos
pos := int64(0)
if len(args) == 3 {
t := args[2]
p, err := types.ToInt64(t)
if err != nil {
return nil, errors.Trace(err)
}
pos = p - 1
if pos < 0 || pos > int64(len(str)) {
return 0, nil
}
if pos > int64(len(str)-len(subStr)) {
return 0, nil
}
}
if len(subStr) == 0 {
return pos + 1, nil
}
i := strings.Index(str[pos:], subStr)
if i == -1 {
return 0, nil
}
return int64(i) + pos + 1, nil
}
const spaceChars = "\n\t\r "
// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim
func builtinTrim(args []interface{}, _ context.Context) (interface{}, error) {
// args[0] -> Str
// args[1] -> RemStr
// args[2] -> Direction
// eval str
fs := args[0]
if fs == nil {
return nil, nil
}
str, err := types.ToString(fs)
if err != nil {
return nil, errors.Trace(err)
}
remstr := ""
// eval remstr
if len(args) > 1 {
fs = args[1]
if fs != nil {
remstr, err = types.ToString(fs)
if err != nil {
return nil, errors.Trace(err)
}
}
}
// do trim
var result string
var direction ast.TrimDirectionType
if len(args) > 2 {
direction = args[2].(ast.TrimDirectionType)
} else {
direction = ast.TrimBothDefault
}
if direction == ast.TrimLeading {
if len(remstr) > 0 {
result = trimLeft(str, remstr)
} else {
result = strings.TrimLeft(str, spaceChars)
}
} else if direction == ast.TrimTrailing {
if len(remstr) > 0 {
result = trimRight(str, remstr)
} else {
result = strings.TrimRight(str, spaceChars)
}
} else if len(remstr) > 0 {
x := trimLeft(str, remstr)
result = trimRight(x, remstr)
} else {
result = strings.Trim(str, spaceChars)
}
return result, nil
}
func trimLeft(str, remstr string) string {
for {
x := strings.TrimPrefix(str, remstr)
if len(x) == len(str) {
return x
}
str = x
}
}
func trimRight(str, remstr string) string {
for {
x := strings.TrimSuffix(str, remstr)
if len(x) == len(str) {
return x
}
str = x
}
}

View file

@ -0,0 +1,555 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"fmt"
"regexp"
"strings"
"time"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
)
func convertToTime(arg types.Datum, tp byte) (d types.Datum, err error) {
f := types.NewFieldType(tp)
f.Decimal = mysql.MaxFsp
d, err = arg.ConvertTo(f)
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
if d.Kind() == types.KindNull {
return d, nil
}
if d.Kind() != types.KindMysqlTime {
err = errors.Errorf("need time type, but got %T", d.GetValue())
d.SetNull()
return d, err
}
return d, nil
}
func convertToDuration(arg types.Datum, fsp int) (d types.Datum, err error) {
f := types.NewFieldType(mysql.TypeDuration)
f.Decimal = fsp
d, err = arg.ConvertTo(f)
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
if d.Kind() == types.KindNull {
d.SetNull()
return d, nil
}
if d.Kind() != types.KindMysqlDuration {
err = errors.Errorf("need duration type, but got %T", d.GetValue())
d.SetNull()
return d, err
}
return d, nil
}
func builtinDate(args []types.Datum, _ context.Context) (types.Datum, error) {
return convertToTime(args[0], mysql.TypeDate)
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_day
// day is a synonym for DayOfMonth
func builtinDay(args []types.Datum, ctx context.Context) (types.Datum, error) {
return builtinDayOfMonth(args, ctx)
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_hour
func builtinHour(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToDuration(args[0], mysql.MaxFsp)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
h := int64(d.GetMysqlDuration().Hour())
d.SetInt64(h)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_minute
func builtinMinute(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToDuration(args[0], mysql.MaxFsp)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
m := int64(d.GetMysqlDuration().Minute())
d.SetInt64(m)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_second
func builtinSecond(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToDuration(args[0], mysql.MaxFsp)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
s := int64(d.GetMysqlDuration().Second())
d.SetInt64(s)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_microsecond
func builtinMicroSecond(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToDuration(args[0], mysql.MaxFsp)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
m := int64(d.GetMysqlDuration().MicroSecond())
d.SetInt64(m)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_month
func builtinMonth(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
i := int64(0)
if t.IsZero() {
d.SetInt64(i)
return d, nil
}
i = int64(t.Month())
d.SetInt64(i)
return d, nil
}
func builtinNow(args []types.Datum, _ context.Context) (d types.Datum, err error) {
// TODO: if NOW is used in stored function or trigger, NOW will return the beginning time
// of the execution.
fsp := 0
if len(args) == 1 && args[0].Kind() != types.KindNull {
if fsp, err = checkFsp(args[0]); err != nil {
d.SetNull()
return d, errors.Trace(err)
}
}
t := mysql.Time{
Time: time.Now(),
Type: mysql.TypeDatetime,
// set unspecified for later round
Fsp: mysql.UnspecifiedFsp,
}
tr, err := t.RoundFrac(int(fsp))
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
d.SetMysqlTime(tr)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayname
func builtinDayName(args []types.Datum, ctx context.Context) (types.Datum, error) {
d, err := builtinWeekDay(args, ctx)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
weekday := d.GetInt64()
if (weekday < 0) || (weekday >= int64(len(mysql.WeekdayNames))) {
d.SetNull()
return d, errors.Errorf("no name for invalid weekday: %d.", weekday)
}
d.SetString(mysql.WeekdayNames[weekday])
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofmonth
func builtinDayOfMonth(args []types.Datum, _ context.Context) (d types.Datum, err error) {
// TODO: some invalid format like 2000-00-00 will return 0 too.
d, err = convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
d.SetInt64(int64(0))
return d, nil
}
d.SetInt64(int64(t.Day()))
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofweek
func builtinDayOfWeek(args []types.Datum, _ context.Context) (d types.Datum, err error) {
d, err = convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
d.SetNull()
// TODO: log warning or return error?
return d, nil
}
// 1 is Sunday, 2 is Monday, .... 7 is Saturday
d.SetInt64(int64(t.Weekday()) + 1)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofyear
func builtinDayOfYear(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
t := d.GetMysqlTime()
if t.IsZero() {
// TODO: log warning or return error?
d.SetNull()
return d, nil
}
yd := int64(t.YearDay())
d.SetInt64(yd)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_week
func builtinWeek(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
// TODO: log warning or return error?
d.SetNull()
return d, nil
}
// TODO: support multi mode for week
_, week := t.ISOWeek()
wi := int64(week)
d.SetInt64(wi)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_weekday
func builtinWeekDay(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
// TODO: log warning or return error?
d.SetNull()
return d, nil
}
// Monday is 0, ... Sunday = 6 in MySQL
// but in go, Sunday is 0, ... Saturday is 6
// w will do a conversion.
w := (int64(t.Weekday()) + 6) % 7
d.SetInt64(w)
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_weekofyear
func builtinWeekOfYear(args []types.Datum, ctx context.Context) (types.Datum, error) {
// WeekOfYear is equivalent to to Week(date, 3)
d := types.Datum{}
d.SetInt64(3)
return builtinWeek([]types.Datum{args[0], d}, ctx)
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_year
func builtinYear(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
d.SetInt64(0)
return d, nil
}
d.SetInt64(int64(t.Year()))
return d, nil
}
// See http://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_yearweek
func builtinYearWeek(args []types.Datum, _ context.Context) (types.Datum, error) {
d, err := convertToTime(args[0], mysql.TypeDate)
if err != nil || d.Kind() == types.KindNull {
d.SetNull()
return d, errors.Trace(err)
}
// No need to check type here.
t := d.GetMysqlTime()
if t.IsZero() {
d.SetNull()
// TODO: log warning or return error?
return d, nil
}
// TODO: support multi mode for week
year, week := t.ISOWeek()
d.SetInt64(int64(year*100 + week))
return d, nil
}
func builtinSysDate(args []types.Datum, ctx context.Context) (types.Datum, error) {
// SYSDATE is not the same as NOW if NOW is used in a stored function or trigger.
// But here we can just think they are the same because we don't support stored function
// and trigger now.
return builtinNow(args, ctx)
}
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate
func builtinCurrentDate(args []types.Datum, _ context.Context) (d types.Datum, err error) {
year, month, day := time.Now().Date()
t := mysql.Time{
Time: time.Date(year, month, day, 0, 0, 0, 0, time.Local),
Type: mysql.TypeDate, Fsp: 0}
d.SetMysqlTime(t)
return d, nil
}
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curtime
func builtinCurrentTime(args []types.Datum, _ context.Context) (d types.Datum, err error) {
fsp := 0
if len(args) == 1 && args[0].Kind() != types.KindNull {
if fsp, err = checkFsp(args[0]); err != nil {
d.SetNull()
return d, errors.Trace(err)
}
}
d.SetString(time.Now().Format("15:04:05.000000"))
return convertToDuration(d, fsp)
}
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract
func builtinExtract(args []types.Datum, _ context.Context) (d types.Datum, err error) {
unit := args[0].GetString()
vd := args[1]
if vd.Kind() == types.KindNull {
d.SetNull()
return d, nil
}
f := types.NewFieldType(mysql.TypeDatetime)
f.Decimal = mysql.MaxFsp
val, err := vd.ConvertTo(f)
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
if val.Kind() == types.KindNull {
d.SetNull()
return d, nil
}
if val.Kind() != types.KindMysqlTime {
err = errors.Errorf("need time type, but got %T", val)
d.SetNull()
return d, err
}
t := val.GetMysqlTime()
n, err1 := mysql.ExtractTimeNum(unit, t)
if err1 != nil {
d.SetNull()
return d, errors.Trace(err1)
}
d.SetInt64(n)
return d, nil
}
func checkFsp(arg types.Datum) (int, error) {
fsp, err := arg.ToInt64()
if err != nil {
return 0, errors.Trace(err)
}
if int(fsp) > mysql.MaxFsp {
return 0, errors.Errorf("Too big precision %d specified. Maximum is 6.", fsp)
} else if fsp < 0 {
return 0, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp)
}
return int(fsp), nil
}
func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, err error) {
// Op is used for distinguishing date_add and date_sub.
// args[0] -> Op
// args[1] -> Date
// args[2] -> DateArithInterval
// health check for date and interval
if args[1].Kind() == types.KindNull {
d.SetNull()
return d, nil
}
nodeDate := args[1]
nodeInterval := args[2].GetInterface().(ast.DateArithInterval)
nodeIntervalIntervalDatum := nodeInterval.Interval.GetDatum()
if nodeIntervalIntervalDatum.Kind() == types.KindNull {
d.SetNull()
return d, nil
}
// parse date
fieldType := mysql.TypeDate
var resultField *types.FieldType
switch nodeDate.Kind() {
case types.KindMysqlTime:
x := nodeDate.GetMysqlTime()
if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) {
fieldType = mysql.TypeDatetime
}
case types.KindString:
x := nodeDate.GetString()
if !mysql.IsDateFormat(x) {
fieldType = mysql.TypeDatetime
}
case types.KindInt64:
x := nodeDate.GetInt64()
if t, err1 := mysql.ParseTimeFromInt64(x); err1 == nil {
if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) {
fieldType = mysql.TypeDatetime
}
}
}
if mysql.IsClockUnit(nodeInterval.Unit) {
fieldType = mysql.TypeDatetime
}
resultField = types.NewFieldType(fieldType)
resultField.Decimal = mysql.MaxFsp
value, err := nodeDate.ConvertTo(resultField)
if err != nil {
d.SetNull()
return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate)
}
if value.Kind() == types.KindNull {
d.SetNull()
return d, ErrInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue())
}
if value.Kind() != types.KindMysqlTime {
d.SetNull()
return d, ErrInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue())
}
result := value.GetMysqlTime()
// parse interval
var interval string
if strings.ToLower(nodeInterval.Unit) == "day" {
day, err2 := parseDayInterval(*nodeIntervalIntervalDatum)
if err2 != nil {
d.SetNull()
return d, ErrInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalIntervalDatum.GetString())
}
interval = fmt.Sprintf("%d", day)
} else {
if nodeIntervalIntervalDatum.Kind() == types.KindString {
interval = fmt.Sprintf("%v", nodeIntervalIntervalDatum.GetString())
} else {
ii, err := nodeIntervalIntervalDatum.ToInt64()
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
interval = fmt.Sprintf("%v", ii)
}
}
year, month, day, duration, err := mysql.ExtractTimeValue(nodeInterval.Unit, interval)
if err != nil {
d.SetNull()
return d, errors.Trace(err)
}
op := args[0].GetInterface().(ast.DateArithType)
if op == ast.DateSub {
year, month, day, duration = -year, -month, -day, -duration
}
result.Time = result.Time.Add(duration)
result.Time = result.Time.AddDate(int(year), int(month), int(day))
if result.Time.Nanosecond() == 0 {
result.Fsp = 0
}
d.SetMysqlTime(result)
return d, nil
}
var reg = regexp.MustCompile(`[\d]+`)
func parseDayInterval(value types.Datum) (int64, error) {
switch value.Kind() {
case types.KindString:
vs := value.GetString()
s := strings.ToLower(vs)
if s == "false" {
return 0, nil
} else if s == "true" {
return 1, nil
}
value.SetString(reg.FindString(vs))
}
return value.ToInt64()
}

717
vendor/github.com/pingcap/tidb/evaluator/evaluator.go generated vendored Normal file
View file

@ -0,0 +1,717 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/types"
)
// Error instances.
var (
ErrInvalidOperation = terror.ClassEvaluator.New(CodeInvalidOperation, "invalid operation")
)
// Error codes.
const (
CodeInvalidOperation terror.ErrCode = 1
)
// Eval evaluates an expression to a value.
func Eval(ctx context.Context, expr ast.ExprNode) (interface{}, error) {
e := &Evaluator{ctx: ctx}
expr.Accept(e)
if e.err != nil {
return nil, errors.Trace(e.err)
}
return expr.GetValue(), nil
}
// EvalBool evalueates an expression to a boolean value.
func EvalBool(ctx context.Context, expr ast.ExprNode) (bool, error) {
val, err := Eval(ctx, expr)
if err != nil {
return false, errors.Trace(err)
}
if val == nil {
return false, nil
}
i, err := types.ToBool(val)
if err != nil {
return false, errors.Trace(err)
}
return i != 0, nil
}
func boolToInt64(v bool) int64 {
if v {
return int64(1)
}
return int64(0)
}
// Evaluator is an ast Visitor that evaluates an expression.
type Evaluator struct {
ctx context.Context
err error
multipleRows bool
existRow bool
}
// Enter implements ast.Visitor interface.
func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch v := in.(type) {
case *ast.SubqueryExpr:
if v.Evaluated && !v.UseOuterContext {
return in, true
}
case *ast.PatternInExpr, *ast.CompareSubqueryExpr:
e.multipleRows = true
case *ast.ExistsSubqueryExpr:
e.existRow = true
}
return in, false
}
// Leave implements ast.Visitor interface.
func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) {
switch v := in.(type) {
case *ast.AggregateFuncExpr:
ok = e.aggregateFunc(v)
case *ast.BetweenExpr:
ok = e.between(v)
case *ast.BinaryOperationExpr:
ok = e.binaryOperation(v)
case *ast.CaseExpr:
ok = e.caseExpr(v)
case *ast.ColumnName:
ok = true
case *ast.ColumnNameExpr:
ok = e.columnName(v)
case *ast.CompareSubqueryExpr:
e.multipleRows = false
ok = e.compareSubquery(v)
case *ast.DefaultExpr:
ok = e.defaultExpr(v)
case *ast.ExistsSubqueryExpr:
e.existRow = false
ok = e.existsSubquery(v)
case *ast.FuncCallExpr:
ok = e.funcCall(v)
case *ast.FuncCastExpr:
ok = e.funcCast(v)
case *ast.IsNullExpr:
ok = e.isNull(v)
case *ast.IsTruthExpr:
ok = e.isTruth(v)
case *ast.ParamMarkerExpr:
ok = e.paramMarker(v)
case *ast.ParenthesesExpr:
ok = e.parentheses(v)
case *ast.PatternInExpr:
e.multipleRows = false
ok = e.patternIn(v)
case *ast.PatternLikeExpr:
ok = e.patternLike(v)
case *ast.PatternRegexpExpr:
ok = e.patternRegexp(v)
case *ast.PositionExpr:
ok = e.position(v)
case *ast.RowExpr:
ok = e.row(v)
case *ast.SubqueryExpr:
ok = e.subqueryExpr(v)
case ast.SubqueryExec:
ok = e.subqueryExec(v)
case *ast.UnaryOperationExpr:
ok = e.unaryOperation(v)
case *ast.ValueExpr:
ok = true
case *ast.ValuesExpr:
ok = e.values(v)
case *ast.VariableExpr:
ok = e.variable(v)
case *ast.WhenClause:
ok = true
}
out = in
return
}
func (e *Evaluator) between(v *ast.BetweenExpr) bool {
var l, r ast.ExprNode
op := opcode.AndAnd
if v.Not {
// v < lv || v > rv
op = opcode.OrOr
l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left}
r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right}
} else {
// v >= lv && v <= rv
l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left}
r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right}
}
ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r}
ret.Accept(e)
if e.err != nil {
return false
}
v.SetDatum(*ret.GetDatum())
return true
}
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
tmp := types.NewDatum(boolToInt64(true))
target := &tmp
if v.Value != nil {
target = v.Value.GetDatum()
}
if target.Kind() != types.KindNull {
for _, val := range v.WhenClauses {
cmp, err := target.CompareDatum(*val.Expr.GetDatum())
if err != nil {
e.err = errors.Trace(err)
return false
}
if cmp == 0 {
v.SetDatum(*val.Result.GetDatum())
return true
}
}
}
if v.ElseClause != nil {
v.SetDatum(*v.ElseClause.GetDatum())
} else {
v.SetNull()
}
return true
}
func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
v.SetDatum(*v.Refer.Expr.GetDatum())
return true
}
func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
return true
}
func (e *Evaluator) compareSubquery(cs *ast.CompareSubqueryExpr) bool {
lvDatum := cs.L.GetDatum()
if lvDatum.Kind() == types.KindNull {
cs.SetNull()
return true
}
lv := lvDatum.GetValue()
x, err := e.checkResult(cs, lv, cs.R.GetValue().([]interface{}))
if err != nil {
e.err = errors.Trace(err)
return false
}
cs.SetValue(x)
return true
}
func (e *Evaluator) checkResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
if cs.All {
return e.checkAllResult(cs, lv, result)
}
return e.checkAnyResult(cs, lv, result)
}
func (e *Evaluator) checkAllResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
hasNull := false
for _, v := range result {
if v == nil {
hasNull = true
continue
}
comRes, err := types.Compare(lv, v)
if err != nil {
return nil, errors.Trace(err)
}
res, err := getCompResult(cs.Op, comRes)
if err != nil {
return nil, errors.Trace(err)
}
if !res {
return false, nil
}
}
if hasNull {
// If no matched but we get null, return null.
// Like `insert t (c) values (1),(2),(null)`, then
// `select 3 > all (select c from t)`, returns null.
return nil, nil
}
return true, nil
}
func (e *Evaluator) checkAnyResult(cs *ast.CompareSubqueryExpr, lv interface{}, result []interface{}) (interface{}, error) {
hasNull := false
for _, v := range result {
if v == nil {
hasNull = true
continue
}
comRes, err := types.Compare(lv, v)
if err != nil {
return nil, errors.Trace(err)
}
res, err := getCompResult(cs.Op, comRes)
if err != nil {
return nil, errors.Trace(err)
}
if res {
return true, nil
}
}
if hasNull {
// If no matched but we get null, return null.
// Like `insert t (c) values (1),(2),(null)`, then
// `select 0 > any (select c from t)`, returns null.
return nil, nil
}
return false, nil
}
func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
datum := v.Sel.GetDatum()
if datum.Kind() == types.KindNull {
v.SetInt64(0)
return true
}
r := datum.GetValue()
rows, _ := r.([]interface{})
if len(rows) > 0 {
v.SetInt64(1)
} else {
v.SetInt64(0)
}
return true
}
// Evaluate SubqueryExpr.
// Get the value from v.SubQuery and set it to v.
func (e *Evaluator) subqueryExpr(v *ast.SubqueryExpr) bool {
if v.SubqueryExec != nil {
v.SetDatum(*v.SubqueryExec.GetDatum())
}
v.Evaluated = true
return true
}
// Do the real work to evaluate subquery.
func (e *Evaluator) subqueryExec(v ast.SubqueryExec) bool {
rowCount := 2
if e.multipleRows {
rowCount = -1
} else if e.existRow {
rowCount = 1
}
rows, err := v.EvalRows(e.ctx, rowCount)
if err != nil {
e.err = errors.Trace(err)
return false
}
if e.multipleRows || e.existRow {
v.SetValue(rows)
return true
}
switch len(rows) {
case 0:
v.GetDatum().SetNull()
case 1:
v.SetValue(rows[0])
default:
e.err = errors.New("Subquery returns more than 1 row")
return false
}
return true
}
func (e *Evaluator) checkInList(not bool, in interface{}, list []interface{}) interface{} {
hasNull := false
for _, v := range list {
if v == nil {
hasNull = true
continue
}
r, err := types.Compare(types.Coerce(in, v))
if err != nil {
e.err = errors.Trace(err)
return nil
}
if r == 0 {
if !not {
return 1
}
return 0
}
}
if hasNull {
// if no matched but we got null in In, return null
// e.g 1 in (null, 2, 3) returns null
return nil
}
if not {
return 1
}
return 0
}
func (e *Evaluator) patternIn(n *ast.PatternInExpr) bool {
lhs := n.Expr.GetDatum()
if lhs.Kind() == types.KindNull {
n.SetNull()
return true
}
if n.Sel == nil {
values := make([]interface{}, 0, len(n.List))
for _, ei := range n.List {
values = append(values, ei.GetValue())
}
x := e.checkInList(n.Not, lhs.GetValue(), values)
if e.err != nil {
return false
}
n.SetValue(x)
return true
}
se := n.Sel.(*ast.SubqueryExpr)
sel := se.SubqueryExec
res := sel.GetValue().([]interface{})
x := e.checkInList(n.Not, lhs.GetValue(), res)
if e.err != nil {
return false
}
n.SetValue(x)
return true
}
func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
var boolVal bool
if v.Expr.GetDatum().Kind() == types.KindNull {
boolVal = true
}
if v.Not {
boolVal = !boolVal
}
v.SetInt64(boolToInt64(boolVal))
return true
}
func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
var boolVal bool
datum := v.Expr.GetDatum()
if datum.Kind() != types.KindNull {
ival, err := datum.ToBool()
if err != nil {
e.err = errors.Trace(err)
return false
}
if ival == v.True {
boolVal = true
}
}
if v.Not {
boolVal = !boolVal
}
v.GetDatum().SetInt64(boolToInt64(boolVal))
return true
}
func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
return true
}
func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
v.SetDatum(*v.Expr.GetDatum())
return true
}
func (e *Evaluator) position(v *ast.PositionExpr) bool {
v.SetDatum(*v.Refer.Expr.GetDatum())
return true
}
func (e *Evaluator) row(v *ast.RowExpr) bool {
row := make([]interface{}, 0, len(v.Values))
for _, val := range v.Values {
row = append(row, val.GetValue())
}
v.SetValue(row)
return true
}
func (e *Evaluator) unaryOperation(u *ast.UnaryOperationExpr) bool {
defer func() {
if er := recover(); er != nil {
e.err = errors.Errorf("%v", er)
}
}()
aDatum := u.V.GetDatum()
if aDatum.Kind() == types.KindNull {
u.SetNull()
return true
}
switch op := u.Op; op {
case opcode.Not:
n, err := aDatum.ToBool()
if err != nil {
e.err = errors.Trace(err)
} else if n == 0 {
u.SetInt64(1)
} else {
u.SetInt64(0)
}
case opcode.BitNeg:
// for bit operation, we will use int64 first, then return uint64
n, err := aDatum.ToInt64()
if err != nil {
e.err = errors.Trace(err)
return false
}
u.SetUint64(uint64(^n))
case opcode.Plus:
switch aDatum.Kind() {
case types.KindInt64,
types.KindUint64,
types.KindFloat64,
types.KindFloat32,
types.KindMysqlDuration,
types.KindMysqlTime,
types.KindString,
types.KindMysqlDecimal,
types.KindBytes,
types.KindMysqlHex,
types.KindMysqlBit,
types.KindMysqlEnum,
types.KindMysqlSet:
u.SetDatum(*aDatum)
default:
e.err = ErrInvalidOperation
return false
}
case opcode.Minus:
switch aDatum.Kind() {
case types.KindInt64:
u.SetInt64(-aDatum.GetInt64())
case types.KindUint64:
u.SetInt64(-int64(aDatum.GetUint64()))
case types.KindFloat64:
u.SetFloat64(-aDatum.GetFloat64())
case types.KindFloat32:
u.SetFloat32(-aDatum.GetFloat32())
case types.KindMysqlDuration:
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlDuration().ToNumber()))
case types.KindMysqlTime:
u.SetValue(mysql.ZeroDecimal.Sub(aDatum.GetMysqlTime().ToNumber()))
case types.KindString:
f, err := types.StrToFloat(aDatum.GetString())
e.err = errors.Trace(err)
u.SetFloat64(-f)
case types.KindMysqlDecimal:
f, _ := aDatum.GetMysqlDecimal().Float64()
u.SetValue(mysql.NewDecimalFromFloat(-f))
case types.KindBytes:
f, err := types.StrToFloat(string(aDatum.GetBytes()))
e.err = errors.Trace(err)
u.SetFloat64(-f)
case types.KindMysqlHex:
u.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
case types.KindMysqlBit:
u.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
case types.KindMysqlEnum:
u.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
case types.KindMysqlSet:
u.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
default:
e.err = ErrInvalidOperation
return false
}
default:
e.err = ErrInvalidOperation
return false
}
return true
}
func (e *Evaluator) values(v *ast.ValuesExpr) bool {
v.SetDatum(*v.Column.GetDatum())
return true
}
func (e *Evaluator) variable(v *ast.VariableExpr) bool {
name := strings.ToLower(v.Name)
sessionVars := variable.GetSessionVars(e.ctx)
globalVars := variable.GetGlobalVarAccessor(e.ctx)
if !v.IsSystem {
// user vars
if value, ok := sessionVars.Users[name]; ok {
v.SetString(value)
return true
}
// select null user vars is permitted.
v.SetNull()
return true
}
_, ok := variable.SysVars[name]
if !ok {
// select null sys vars is not permitted
e.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
return false
}
if !v.IsGlobal {
if value, ok := sessionVars.Systems[name]; ok {
v.SetString(value)
return true
}
}
value, err := globalVars.GetGlobalSysVar(e.ctx, name)
if err != nil {
e.err = errors.Trace(err)
return false
}
v.SetString(value)
return true
}
func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool {
of, ok := OldFuncs[v.FnName.L]
if ok {
if len(v.Args) < of.MinArgs || (of.MaxArgs != -1 && len(v.Args) > of.MaxArgs) {
e.err = ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", of.MinArgs, of.MaxArgs)
return false
}
a := make([]interface{}, len(v.Args))
for i, arg := range v.Args {
a[i] = arg.GetValue()
}
val, err := of.F(a, e.ctx)
if err != nil {
e.err = errors.Trace(err)
return false
}
v.SetValue(val)
return true
}
f, ok := Funcs[v.FnName.L]
if !ok {
e.err = ErrInvalidOperation.Gen("unknown function %s", v.FnName.O)
return false
}
if len(v.Args) < f.MinArgs || (f.MaxArgs != -1 && len(v.Args) > f.MaxArgs) {
e.err = ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs)
return false
}
a := make([]types.Datum, len(v.Args))
for i, arg := range v.Args {
a[i] = *arg.GetDatum()
}
val, err := f.F(a, e.ctx)
if err != nil {
e.err = errors.Trace(err)
return false
}
v.SetDatum(val)
return true
}
func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
value := v.Expr.GetValue()
// Casting nil to any type returns null
if value == nil {
v.SetNull()
return true
}
var err error
value, err = types.Cast(value, v.Tp)
if err != nil {
e.err = errors.Trace(err)
return false
}
v.SetValue(value)
return true
}
func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
name := strings.ToLower(v.F)
switch name {
case ast.AggFuncAvg:
e.evalAggAvg(v)
case ast.AggFuncCount:
e.evalAggCount(v)
case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum:
e.evalAggSetValue(v)
case ast.AggFuncGroupConcat:
e.evalAggGroupConcat(v)
}
return e.err == nil
}
func (e *Evaluator) evalAggCount(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetInt64(ctx.Count)
}
func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetValue(ctx.Value)
}
func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
switch x := ctx.Value.(type) {
case float64:
ctx.Value = x / float64(ctx.Count)
case mysql.Decimal:
ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0))
}
v.SetValue(ctx.Value)
}
func (e *Evaluator) evalAggGroupConcat(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
if ctx.Buffer != nil {
v.SetValue(ctx.Buffer.String())
} else {
v.SetValue(nil)
}
}

View file

@ -0,0 +1,564 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"math"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
const (
zeroI64 int64 = 0
oneI64 int64 = 1
)
func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool {
switch o.Op {
case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
return e.handleLogicOperation(o)
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
return e.handleComparisonOp(o)
case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
return e.handleBitOp(o)
case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv:
return e.handleArithmeticOp(o)
default:
e.err = ErrInvalidOperation
return false
}
}
func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool {
switch o.Op {
case opcode.AndAnd:
return e.handleAndAnd(o)
case opcode.OrOr:
return e.handleOrOr(o)
case opcode.LogicXor:
return e.handleXor(o)
default:
e.err = ErrInvalidOperation.Gen("unkown operator %s", o.Op)
return false
}
}
func (e *Evaluator) handleAndAnd(o *ast.BinaryOperationExpr) bool {
leftVal := o.L.GetValue()
righVal := o.R.GetValue()
if leftVal != nil {
x, err := types.ToBool(leftVal)
if err != nil {
e.err = errors.Trace(err)
return false
} else if x == 0 {
// false && any other types is false
o.SetValue(x)
return true
}
}
if righVal != nil {
y, err := types.ToBool(righVal)
if err != nil {
e.err = errors.Trace(err)
return false
} else if y == 0 {
o.SetValue(y)
return true
}
}
if leftVal == nil || righVal == nil {
o.SetValue(nil)
return true
}
o.SetValue(int64(1))
return true
}
func (e *Evaluator) handleOrOr(o *ast.BinaryOperationExpr) bool {
leftVal := o.L.GetValue()
righVal := o.R.GetValue()
if leftVal != nil {
x, err := types.ToBool(leftVal)
if err != nil {
e.err = errors.Trace(err)
return false
} else if x == 1 {
// true || any other types is true.
o.SetValue(x)
return true
}
}
if righVal != nil {
y, err := types.ToBool(righVal)
if err != nil {
e.err = errors.Trace(err)
return false
} else if y == 1 {
o.SetValue(y)
return true
}
}
if leftVal == nil || righVal == nil {
o.SetValue(nil)
return true
}
o.SetValue(int64(0))
return true
}
func (e *Evaluator) handleXor(o *ast.BinaryOperationExpr) bool {
leftVal := o.L.GetValue()
righVal := o.R.GetValue()
if leftVal == nil || righVal == nil {
o.SetValue(nil)
return true
}
x, err := types.ToBool(leftVal)
if err != nil {
e.err = errors.Trace(err)
return false
}
y, err := types.ToBool(righVal)
if err != nil {
e.err = errors.Trace(err)
return false
}
if x == y {
o.SetValue(int64(0))
} else {
o.SetValue(int64(1))
}
return true
}
func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool {
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
if a == nil || b == nil {
// for <=>, if a and b are both nil, return true.
// if a or b is nil, return false.
if o.Op == opcode.NullEQ {
if a == nil && b == nil {
o.SetValue(oneI64)
} else {
o.SetValue(zeroI64)
}
} else {
o.SetValue(nil)
}
return true
}
n, err := types.Compare(a, b)
if err != nil {
e.err = errors.Trace(err)
return false
}
r, err := getCompResult(o.Op, n)
if err != nil {
e.err = errors.Trace(err)
return false
}
if r {
o.SetValue(oneI64)
} else {
o.SetValue(zeroI64)
}
return true
}
func getCompResult(op opcode.Op, value int) (bool, error) {
switch op {
case opcode.LT:
return value < 0, nil
case opcode.LE:
return value <= 0, nil
case opcode.GE:
return value >= 0, nil
case opcode.GT:
return value > 0, nil
case opcode.EQ:
return value == 0, nil
case opcode.NE:
return value != 0, nil
case opcode.NullEQ:
return value == 0, nil
default:
return false, ErrInvalidOperation.Gen("invalid op %v in comparision operation", op)
}
}
func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool {
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
if a == nil || b == nil {
o.SetValue(nil)
return true
}
x, err := types.ToInt64(a)
if err != nil {
e.err = errors.Trace(err)
return false
}
y, err := types.ToInt64(b)
if err != nil {
e.err = errors.Trace(err)
return false
}
// use a int64 for bit operator, return uint64
switch o.Op {
case opcode.And:
o.SetValue(uint64(x & y))
case opcode.Or:
o.SetValue(uint64(x | y))
case opcode.Xor:
o.SetValue(uint64(x ^ y))
case opcode.RightShift:
o.SetValue(uint64(x) >> uint64(y))
case opcode.LeftShift:
o.SetValue(uint64(x) << uint64(y))
default:
e.err = ErrInvalidOperation.Gen("invalid op %v in bit operation", o.Op)
return false
}
return true
}
func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool {
a, err := coerceArithmetic(o.L.GetValue())
if err != nil {
e.err = errors.Trace(err)
return false
}
b, err := coerceArithmetic(o.R.GetValue())
if err != nil {
e.err = errors.Trace(err)
return false
}
a, b = types.Coerce(a, b)
if a == nil || b == nil {
o.SetValue(nil)
return true
}
var result interface{}
switch o.Op {
case opcode.Plus:
result, e.err = computePlus(a, b)
case opcode.Minus:
result, e.err = computeMinus(a, b)
case opcode.Mul:
result, e.err = computeMul(a, b)
case opcode.Div:
result, e.err = computeDiv(a, b)
case opcode.Mod:
result, e.err = computeMod(a, b)
case opcode.IntDiv:
result, e.err = computeIntDiv(a, b)
default:
e.err = ErrInvalidOperation.Gen("invalid op %v in arithmetic operation", o.Op)
return false
}
o.SetValue(result)
return e.err == nil
}
func computePlus(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.AddInt64(x, y)
case uint64:
return types.AddInteger(y, x)
}
case uint64:
switch y := b.(type) {
case int64:
return types.AddInteger(x, y)
case uint64:
return types.AddUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x + y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Add(y), nil
}
}
return types.InvOp2(a, b, opcode.Plus)
}
func computeMinus(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.SubInt64(x, y)
case uint64:
return types.SubIntWithUint(x, y)
}
case uint64:
switch y := b.(type) {
case int64:
return types.SubUintWithInt(x, y)
case uint64:
return types.SubUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x - y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Sub(y), nil
}
}
return types.InvOp2(a, b, opcode.Minus)
}
func computeMul(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.MulInt64(x, y)
case uint64:
return types.MulInteger(y, x)
}
case uint64:
switch y := b.(type) {
case int64:
return types.MulInteger(x, y)
case uint64:
return types.MulUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x * y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Mul(y), nil
}
}
return types.InvOp2(a, b, opcode.Mul)
}
func computeDiv(a, b interface{}) (interface{}, error) {
// MySQL support integer divison Div and division operator /
// we use opcode.Div for division operator and will use another for integer division later.
// for division operator, we will use float64 for calculation.
switch x := a.(type) {
case float64:
y, err := types.ToFloat64(b)
if err != nil {
return nil, errors.Trace(err)
}
if y == 0 {
return nil, nil
}
return x / y, nil
default:
// the scale of the result is the scale of the first operand plus
// the value of the div_precision_increment system variable (which is 4 by default)
// we will use 4 here
xa, err := types.ToDecimal(a)
if err != nil {
return nil, errors.Trace(err)
}
xb, err := types.ToDecimal(b)
if err != nil {
return nil, errors.Trace(err)
}
if f, _ := xb.Float64(); f == 0 {
// division by zero return null
return nil, nil
}
return xa.Div(xb), nil
}
}
func computeMod(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return x % y, nil
case uint64:
if y == 0 {
return nil, nil
} else if x < 0 {
// first is int64, return int64.
return -int64(uint64(-x) % y), nil
}
return int64(uint64(x) % y), nil
}
case uint64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
} else if y < 0 {
// first is uint64, return uint64.
return uint64(x % uint64(-y)), nil
}
return x % uint64(y), nil
case uint64:
if y == 0 {
return nil, nil
}
return x % y, nil
}
case float64:
switch y := b.(type) {
case float64:
if y == 0 {
return nil, nil
}
return math.Mod(x, y), nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
xf, _ := x.Float64()
yf, _ := y.Float64()
if yf == 0 {
return nil, nil
}
return math.Mod(xf, yf), nil
}
}
return types.InvOp2(a, b, opcode.Mod)
}
func computeIntDiv(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return types.DivInt64(x, y)
case uint64:
if y == 0 {
return nil, nil
}
return types.DivIntWithUint(x, y)
}
case uint64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return types.DivUintWithInt(x, y)
case uint64:
if y == 0 {
return nil, nil
}
return x / y, nil
}
}
// if any is none integer, use decimal to calculate
x, err := types.ToDecimal(a)
if err != nil {
return nil, errors.Trace(err)
}
y, err := types.ToDecimal(b)
if err != nil {
return nil, errors.Trace(err)
}
if f, _ := y.Float64(); f == 0 {
return nil, nil
}
return x.Div(y).IntPart(), nil
}
func coerceArithmetic(a interface{}) (interface{}, error) {
switch x := a.(type) {
case string:
// MySQL will convert string to float for arithmetic operation
f, err := types.StrToFloat(x)
if err != nil {
return nil, errors.Trace(err)
}
return f, errors.Trace(err)
case mysql.Time:
// if time has no precision, return int64
v := x.ToNumber()
if x.Fsp == 0 {
return v.IntPart(), nil
}
return v, nil
case mysql.Duration:
// if duration has no precision, return int64
v := x.ToNumber()
if x.Fsp == 0 {
return v.IntPart(), nil
}
return v, nil
case []byte:
// []byte is the same as string, converted to float for arithmetic operator.
f, err := types.StrToFloat(string(x))
if err != nil {
return nil, errors.Trace(err)
}
return f, errors.Trace(err)
case mysql.Hex:
return x.ToNumber(), nil
case mysql.Bit:
return x.ToNumber(), nil
case mysql.Enum:
return x.ToNumber(), nil
case mysql.Set:
return x.ToNumber(), nil
default:
return x, nil
}
}

View file

@ -0,0 +1,217 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package evaluator
import (
"regexp"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/types"
)
const (
patMatch = iota + 1
patOne
patAny
)
// handle escapes and wild cards convert pattern characters and pattern types,
func compilePattern(pattern string, escape byte) (patChars, patTypes []byte) {
var lastAny bool
patChars = make([]byte, len(pattern))
patTypes = make([]byte, len(pattern))
patLen := 0
for i := 0; i < len(pattern); i++ {
var tp byte
var c = pattern[i]
switch c {
case escape:
lastAny = false
tp = patMatch
if i < len(pattern)-1 {
i++
c = pattern[i]
if c == escape || c == '_' || c == '%' {
// valid escape.
} else {
// invalid escape, fall back to escape byte
// mysql will treat escape character as the origin value even
// the escape sequence is invalid in Go or C.
// e.g, \m is invalid in Go, but in MySQL we will get "m" for select '\m'.
// Following case is correct just for escape \, not for others like +.
// TODO: add more checks for other escapes.
i--
c = escape
}
}
case '_':
lastAny = false
tp = patOne
case '%':
if lastAny {
continue
}
lastAny = true
tp = patAny
default:
lastAny = false
tp = patMatch
}
patChars[patLen] = c
patTypes[patLen] = tp
patLen++
}
for i := 0; i < patLen-1; i++ {
if (patTypes[i] == patAny) && (patTypes[i+1] == patOne) {
patTypes[i] = patOne
patTypes[i+1] = patAny
}
}
patChars = patChars[:patLen]
patTypes = patTypes[:patLen]
return
}
const caseDiff = 'a' - 'A'
func matchByteCI(a, b byte) bool {
if a == b {
return true
}
if a >= 'a' && a <= 'z' && a-caseDiff == b {
return true
}
return a >= 'A' && a <= 'Z' && a+caseDiff == b
}
func doMatch(str string, patChars, patTypes []byte) bool {
var sIdx int
for i := 0; i < len(patChars); i++ {
switch patTypes[i] {
case patMatch:
if sIdx >= len(str) || !matchByteCI(str[sIdx], patChars[i]) {
return false
}
sIdx++
case patOne:
sIdx++
if sIdx > len(str) {
return false
}
case patAny:
i++
if i == len(patChars) {
return true
}
for sIdx < len(str) {
if matchByteCI(patChars[i], str[sIdx]) && doMatch(str[sIdx:], patChars[i:], patTypes[i:]) {
return true
}
sIdx++
}
return false
}
}
return sIdx == len(str)
}
func (e *Evaluator) patternLike(p *ast.PatternLikeExpr) bool {
expr := p.Expr.GetValue()
if expr == nil {
p.SetValue(nil)
return true
}
sexpr, err := types.ToString(expr)
if err != nil {
e.err = errors.Trace(err)
return false
}
// We need to compile pattern if it has not been compiled or it is not static.
var needCompile = len(p.PatChars) == 0 || !ast.IsConstant(p.Pattern)
if needCompile {
pattern := p.Pattern.GetValue()
if pattern == nil {
p.SetValue(nil)
return true
}
spattern, err := types.ToString(pattern)
if err != nil {
e.err = errors.Trace(err)
return false
}
p.PatChars, p.PatTypes = compilePattern(spattern, p.Escape)
}
match := doMatch(sexpr, p.PatChars, p.PatTypes)
if p.Not {
match = !match
}
p.SetValue(boolToInt64(match))
return true
}
func (e *Evaluator) patternRegexp(p *ast.PatternRegexpExpr) bool {
var sexpr string
if p.Sexpr != nil {
sexpr = *p.Sexpr
} else {
expr := p.Expr.GetValue()
if expr == nil {
p.SetValue(nil)
return true
}
var err error
sexpr, err = types.ToString(expr)
if err != nil {
e.err = errors.Errorf("non-string Expression in LIKE: %v (Value of type %T)", expr, expr)
return false
}
if ast.IsConstant(p.Expr) {
p.Sexpr = new(string)
*p.Sexpr = sexpr
}
}
re := p.Re
if re == nil {
pattern := p.Pattern.GetValue()
if pattern == nil {
p.SetValue(nil)
return true
}
spattern, err := types.ToString(pattern)
if err != nil {
e.err = errors.Errorf("non-string pattern in LIKE: %v (Value of type %T)", pattern, pattern)
return false
}
if re, err = regexp.Compile(spattern); err != nil {
e.err = errors.Trace(err)
return false
}
if ast.IsConstant(p.Pattern) {
p.Re = re
}
}
match := re.MatchString(sexpr)
if p.Not {
match = !match
}
p.SetValue(boolToInt64(match))
return true
}

136
vendor/github.com/pingcap/tidb/evaluator/helper.go generated vendored Normal file
View file

@ -0,0 +1,136 @@
package evaluator
import (
"strconv"
"strings"
"time"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/types"
)
var (
// CurrentTimestamp is the keyword getting default value for datetime and timestamp type.
CurrentTimestamp = "CURRENT_TIMESTAMP"
currentTimestampL = "current_timestamp"
// ZeroTimestamp shows the zero datetime and timestamp.
ZeroTimestamp = "0000-00-00 00:00:00"
)
var (
errDefaultValue = errors.New("invalid default value")
)
// GetTimeValue gets the time value with type tp.
func GetTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) {
return getTimeValue(ctx, v, tp, fsp)
}
func getTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) {
value := mysql.Time{
Type: tp,
Fsp: fsp,
}
defaultTime, err := getSystemTimestamp(ctx)
if err != nil {
return nil, errors.Trace(err)
}
switch x := v.(type) {
case string:
upperX := strings.ToUpper(x)
if upperX == CurrentTimestamp {
value.Time = defaultTime
} else if upperX == ZeroTimestamp {
value, _ = mysql.ParseTimeFromNum(0, tp, fsp)
} else {
value, err = mysql.ParseTime(x, tp, fsp)
if err != nil {
return nil, errors.Trace(err)
}
}
case *ast.ValueExpr:
switch x.Kind() {
case types.KindString:
value, err = mysql.ParseTime(x.GetString(), tp, fsp)
if err != nil {
return nil, errors.Trace(err)
}
case types.KindInt64:
value, err = mysql.ParseTimeFromNum(x.GetInt64(), tp, fsp)
if err != nil {
return nil, errors.Trace(err)
}
case types.KindNull:
return nil, nil
default:
return nil, errors.Trace(errDefaultValue)
}
case *ast.FuncCallExpr:
if x.FnName.L == currentTimestampL {
return CurrentTimestamp, nil
}
return nil, errors.Trace(errDefaultValue)
case *ast.UnaryOperationExpr:
// support some expression, like `-1`
v, err := Eval(ctx, x)
if err != nil {
return nil, errors.Trace(err)
}
ft := types.NewFieldType(mysql.TypeLonglong)
xval, err := types.Convert(v, ft)
if err != nil {
return nil, errors.Trace(err)
}
value, err = mysql.ParseTimeFromNum(xval.(int64), tp, fsp)
if err != nil {
return nil, errors.Trace(err)
}
default:
return nil, nil
}
return value, nil
}
// IsCurrentTimeExpr returns whether e is CurrentTimeExpr.
func IsCurrentTimeExpr(e ast.ExprNode) bool {
x, ok := e.(*ast.FuncCallExpr)
if !ok {
return false
}
return x.FnName.L == currentTimestampL
}
func getSystemTimestamp(ctx context.Context) (time.Time, error) {
value := time.Now()
if ctx == nil {
return value, nil
}
// check whether use timestamp varibale
sessionVars := variable.GetSessionVars(ctx)
if v, ok := sessionVars.Systems["timestamp"]; ok {
if v != "" {
timestamp, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return time.Time{}, errors.Trace(err)
}
if timestamp <= 0 {
return value, nil
}
return time.Unix(timestamp, 0), nil
}
}
return value, nil
}