1
0
Fork 0
forked from forgejo/forgejo

Integrate public as bindata optionally (#293)

* Dropped unused codekit config

* Integrated dynamic and static bindata for public

* Ignore public bindata

* Add a general generate make task

* Integrated flexible public assets into web command

* Updated vendoring, added all missiong govendor deps

* Made the linter happy with the bindata and dynamic code

* Moved public bindata definition to modules directory

* Ignoring the new bindata path now

* Updated to the new public modules import path

* Updated public bindata command and drop the new prefix
This commit is contained in:
Thomas Boerger 2016-11-29 17:26:36 +01:00 committed by Lunny Xiao
parent 4680c349dd
commit b6a95a8cb3
691 changed files with 305318 additions and 1272 deletions

65
vendor/github.com/pingcap/tidb/optimizer/logic.go generated vendored Normal file
View file

@ -0,0 +1,65 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/evaluator"
)
// logicOptimize does logic optimization works on AST.
func logicOptimize(ctx context.Context, node ast.Node) error {
return preEvaluate(ctx, node)
}
// preEvaluate evaluates preEvaluable expression and rewrites constant expression to value expression.
func preEvaluate(ctx context.Context, node ast.Node) error {
pe := preEvaluator{ctx: ctx}
node.Accept(&pe)
return pe.err
}
type preEvaluator struct {
ctx context.Context
err error
}
func (r *preEvaluator) Enter(in ast.Node) (ast.Node, bool) {
return in, false
}
func (r *preEvaluator) Leave(in ast.Node) (ast.Node, bool) {
if expr, ok := in.(ast.ExprNode); ok {
if _, ok = expr.(*ast.ValueExpr); ok {
return in, true
} else if ast.IsPreEvaluable(expr) {
val, err := evaluator.Eval(r.ctx, expr)
if err != nil {
r.err = err
return in, false
}
if ast.IsConstant(expr) {
// The expression is constant, rewrite the expression to value expression.
valExpr := &ast.ValueExpr{}
valExpr.SetText(expr.Text())
valExpr.SetType(expr.GetType())
valExpr.SetValue(val)
return valExpr, true
}
expr.SetValue(val)
}
}
return in, true
}

90
vendor/github.com/pingcap/tidb/optimizer/optimizer.go generated vendored Normal file
View file

@ -0,0 +1,90 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/optimizer/plan"
"github.com/pingcap/tidb/terror"
)
// Optimize does optimization and creates a Plan.
// The node must be prepared first.
func Optimize(ctx context.Context, node ast.Node, sb plan.SubQueryBuilder) (plan.Plan, error) {
// We have to infer type again because after parameter is set, the expression type may change.
if err := InferType(node); err != nil {
return nil, errors.Trace(err)
}
if err := logicOptimize(ctx, node); err != nil {
return nil, errors.Trace(err)
}
p, err := plan.BuildPlan(node, sb)
if err != nil {
return nil, errors.Trace(err)
}
err = plan.Refine(p)
if err != nil {
return nil, errors.Trace(err)
}
return p, nil
}
// Prepare prepares a raw statement parsed from parser.
// The statement must be prepared before it can be passed to optimize function.
// We pass InfoSchema instead of getting from Context in case it is changed after resolving name.
func Prepare(is infoschema.InfoSchema, ctx context.Context, node ast.Node) error {
ast.SetFlag(node)
if err := Preprocess(node, is, ctx); err != nil {
return errors.Trace(err)
}
if err := Validate(node, true); err != nil {
return errors.Trace(err)
}
return nil
}
// Optimizer error codes.
const (
CodeOneColumn terror.ErrCode = 1
CodeSameColumns terror.ErrCode = 2
CodeMultiWildCard terror.ErrCode = 3
CodeUnsupported terror.ErrCode = 4
CodeInvalidGroupFuncUse terror.ErrCode = 5
CodeIllegalReference terror.ErrCode = 6
)
// Optimizer base errors.
var (
ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)")
ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns")
ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once")
ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported")
ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(CodeInvalidGroupFuncUse, "Invalid use of group function")
ErrIllegalReference = terror.ClassOptimizer.New(CodeIllegalReference, "Illegal reference")
)
func init() {
mySQLErrCodes := map[terror.ErrCode]uint16{
CodeOneColumn: mysql.ErrOperandColumns,
CodeSameColumns: mysql.ErrOperandColumns,
CodeMultiWildCard: mysql.ErrParse,
CodeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse,
CodeIllegalReference: mysql.ErrIllegalReference,
}
terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mySQLErrCodes
}

109
vendor/github.com/pingcap/tidb/optimizer/plan/cost.go generated vendored Normal file
View file

@ -0,0 +1,109 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"math"
)
// Pre-defined cost factors.
const (
FullRangeCount = 10000
HalfRangeCount = 4000
MiddleRangeCount = 100
RowCost = 1.0
IndexCost = 2.0
SortCost = 2.0
FilterRate = 0.5
)
// CostEstimator estimates the cost of a plan.
type costEstimator struct {
}
// Enter implements Visitor Enter interface.
func (c *costEstimator) Enter(p Plan) (Plan, bool) {
return p, false
}
// Leave implements Visitor Leave interface.
func (c *costEstimator) Leave(p Plan) (Plan, bool) {
switch v := p.(type) {
case *IndexScan:
c.indexScan(v)
case *Limit:
v.rowCount = v.Src().RowCount()
v.startupCost = v.Src().StartupCost()
v.totalCost = v.Src().TotalCost()
case *SelectFields:
if v.Src() != nil {
v.startupCost = v.Src().StartupCost()
v.rowCount = v.Src().RowCount()
v.totalCost = v.Src().TotalCost()
}
case *SelectLock:
v.startupCost = v.Src().StartupCost()
v.rowCount = v.Src().RowCount()
v.totalCost = v.Src().TotalCost()
case *Sort:
// Sort plan must retrieve all the rows before returns the first row.
v.startupCost = v.Src().TotalCost() + v.Src().RowCount()*SortCost
if v.limit == 0 {
v.rowCount = v.Src().RowCount()
} else {
v.rowCount = math.Min(v.Src().RowCount(), v.limit)
}
v.totalCost = v.startupCost + v.rowCount*RowCost
case *TableScan:
c.tableScan(v)
}
return p, true
}
func (c *costEstimator) tableScan(v *TableScan) {
var rowCount float64 = FullRangeCount
for _, con := range v.AccessConditions {
rowCount *= guesstimateFilterRate(con)
}
v.startupCost = 0
if v.limit == 0 {
// limit is zero means no limit.
v.rowCount = rowCount
} else {
v.rowCount = math.Min(rowCount, v.limit)
}
v.totalCost = v.rowCount * RowCost
}
func (c *costEstimator) indexScan(v *IndexScan) {
var rowCount float64 = FullRangeCount
for _, con := range v.AccessConditions {
rowCount *= guesstimateFilterRate(con)
}
v.startupCost = 0
if v.limit == 0 {
// limit is zero means no limit.
v.rowCount = rowCount
} else {
v.rowCount = math.Min(rowCount, v.limit)
}
v.totalCost = v.rowCount * RowCost
}
// EstimateCost estimates the cost of the plan.
func EstimateCost(p Plan) float64 {
var estimator costEstimator
p.Accept(&estimator)
return p.TotalCost()
}

View file

@ -0,0 +1,115 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/parser/opcode"
)
const (
rateFull float64 = 1
rateEqual float64 = 0.01
rateNotEqual float64 = 0.99
rateBetween float64 = 0.1
rateGreaterOrLess float64 = 0.33
rateIsFalse float64 = 0.1
rateIsNull float64 = 0.1
rateLike float64 = 0.1
)
// guesstimateFilterRate guesstimates the filter rate for an expression.
// For example: a table has 100 rows, after filter expression 'a between 0 and 9',
// 10 rows returned, then the filter rate is '0.1'.
// It only depends on the expression type, not the expression value.
// The expr parameter should contain only one column name.
func guesstimateFilterRate(expr ast.ExprNode) float64 {
switch x := expr.(type) {
case *ast.BetweenExpr:
return rateBetween
case *ast.BinaryOperationExpr:
return guesstimateBinop(x)
case *ast.ColumnNameExpr:
return rateFull
case *ast.IsNullExpr:
return guesstimateIsNull(x)
case *ast.IsTruthExpr:
return guesstimateIsTrue(x)
case *ast.ParenthesesExpr:
return guesstimateFilterRate(x.Expr)
case *ast.PatternInExpr:
return guesstimatePatternIn(x)
case *ast.PatternLikeExpr:
return guesstimatePatternLike(x)
}
return rateFull
}
func guesstimateBinop(expr *ast.BinaryOperationExpr) float64 {
switch expr.Op {
case opcode.AndAnd:
// P(A and B) = P(A) * P(B)
return guesstimateFilterRate(expr.L) * guesstimateFilterRate(expr.R)
case opcode.OrOr:
// P(A or B) = P(A) + P(B) P(A and B)
rateL := guesstimateFilterRate(expr.L)
rateR := guesstimateFilterRate(expr.R)
return rateL + rateR - rateL*rateR
case opcode.EQ:
return rateEqual
case opcode.GT, opcode.GE, opcode.LT, opcode.LE:
return rateGreaterOrLess
case opcode.NE:
return rateNotEqual
}
return rateFull
}
func guesstimateIsNull(expr *ast.IsNullExpr) float64 {
if expr.Not {
return rateFull - rateIsNull
}
return rateIsNull
}
func guesstimateIsTrue(expr *ast.IsTruthExpr) float64 {
if expr.True == 0 {
if expr.Not {
return rateFull - rateIsFalse
}
return rateIsFalse
}
if expr.Not {
return rateIsFalse + rateIsNull
}
return rateFull - rateIsFalse - rateIsNull
}
func guesstimatePatternIn(expr *ast.PatternInExpr) float64 {
if len(expr.List) > 0 {
rate := rateEqual * float64(len(expr.List))
if expr.Not {
return rateFull - rate
}
return rate
}
return rateFull
}
func guesstimatePatternLike(expr *ast.PatternLikeExpr) float64 {
if expr.Not {
return rateFull - rateLike
}
return rateLike
}

127
vendor/github.com/pingcap/tidb/optimizer/plan/plan.go generated vendored Normal file
View file

@ -0,0 +1,127 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"math"
"github.com/pingcap/tidb/ast"
)
// Plan is a description of an execution flow.
// It is created from ast.Node first, then optimized by optimizer,
// then used by executor to create a Cursor which executes the statement.
type Plan interface {
// Accept a visitor, implementation should call Visitor.Enter first,
// then call children Accept methods, finally call Visitor.Leave.
Accept(v Visitor) (out Plan, ok bool)
// Fields returns the result fields of the plan.
Fields() []*ast.ResultField
// SetFields sets the results fields of the plan.
SetFields(fields []*ast.ResultField)
// The cost before returning fhe first row.
StartupCost() float64
// The cost after returning all the rows.
TotalCost() float64
// The expected row count.
RowCount() float64
// SetLimit is used to push limit to upstream to estimate the cost.
SetLimit(limit float64)
}
// WithSrcPlan is a Plan has a source Plan.
type WithSrcPlan interface {
Plan
Src() Plan
SetSrc(src Plan)
}
// Visitor visits a Plan.
type Visitor interface {
// Enter is called before visit children.
// The out plan should be of exactly the same type as the in plan.
// if skipChildren is true, the children should not be visited.
Enter(in Plan) (out Plan, skipChildren bool)
// Leave is called after children has been visited, the out Plan can
// be another type, this is different than ast.Visitor Leave, because
// Plans only contain children plans as Plan interface type, so it is safe
// to return a different type of plan.
Leave(in Plan) (out Plan, ok bool)
}
// basePlan implements base Plan interface.
// Should be used as embedded struct in Plan implementations.
type basePlan struct {
fields []*ast.ResultField
startupCost float64
totalCost float64
rowCount float64
limit float64
}
// StartupCost implements Plan StartupCost interface.
func (p *basePlan) StartupCost() float64 {
return p.startupCost
}
// TotalCost implements Plan TotalCost interface.
func (p *basePlan) TotalCost() float64 {
return p.totalCost
}
// RowCount implements Plan RowCount interface.
func (p *basePlan) RowCount() float64 {
if p.limit == 0 {
return p.rowCount
}
return math.Min(p.rowCount, p.limit)
}
// SetLimit implements Plan SetLimit interface.
func (p *basePlan) SetLimit(limit float64) {
p.limit = limit
}
// Fields implements Plan Fields interface.
func (p *basePlan) Fields() []*ast.ResultField {
return p.fields
}
// SetFields implements Plan SetFields interface.
func (p *basePlan) SetFields(fields []*ast.ResultField) {
p.fields = fields
}
// srcPlan implements base PlanWithSrc interface.
type planWithSrc struct {
basePlan
src Plan
}
// Src implements PlanWithSrc interface.
func (p *planWithSrc) Src() Plan {
return p.src
}
// SetSrc implements PlanWithSrc interface.
func (p *planWithSrc) SetSrc(src Plan) {
p.src = src
}
// SetLimit implements Plan interface.
func (p *planWithSrc) SetLimit(limit float64) {
p.limit = limit
p.src.SetLimit(limit)
}

View file

@ -0,0 +1,926 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
)
// Error instances.
var (
ErrUnsupportedType = terror.ClassOptimizerPlan.New(CodeUnsupportedType, "Unsupported type")
)
// Error codes.
const (
CodeUnsupportedType terror.ErrCode = 1
)
// BuildPlan builds a plan from a node.
// It returns ErrUnsupportedType if ast.Node type is not supported yet.
func BuildPlan(node ast.Node, sb SubQueryBuilder) (Plan, error) {
builder := planBuilder{sb: sb}
p := builder.build(node)
return p, builder.err
}
// planBuilder builds Plan from an ast.Node.
// It just builds the ast node straightforwardly.
type planBuilder struct {
err error
hasAgg bool
sb SubQueryBuilder
obj interface{}
}
func (b *planBuilder) build(node ast.Node) Plan {
switch x := node.(type) {
case *ast.AdminStmt:
return b.buildAdmin(x)
case *ast.AlterTableStmt:
return b.buildDDL(x)
case *ast.CreateDatabaseStmt:
return b.buildDDL(x)
case *ast.CreateIndexStmt:
return b.buildDDL(x)
case *ast.CreateTableStmt:
return b.buildDDL(x)
case *ast.DeallocateStmt:
return &Deallocate{Name: x.Name}
case *ast.DeleteStmt:
return b.buildDelete(x)
case *ast.DropDatabaseStmt:
return b.buildDDL(x)
case *ast.DropIndexStmt:
return b.buildDDL(x)
case *ast.DropTableStmt:
return b.buildDDL(x)
case *ast.ExecuteStmt:
return &Execute{Name: x.Name, UsingVars: x.UsingVars}
case *ast.ExplainStmt:
return b.buildExplain(x)
case *ast.InsertStmt:
return b.buildInsert(x)
case *ast.PrepareStmt:
return b.buildPrepare(x)
case *ast.SelectStmt:
return b.buildSelect(x)
case *ast.UnionStmt:
return b.buildUnion(x)
case *ast.UpdateStmt:
return b.buildUpdate(x)
case *ast.UseStmt:
return b.buildSimple(x)
case *ast.SetCharsetStmt:
return b.buildSimple(x)
case *ast.SetStmt:
return b.buildSimple(x)
case *ast.ShowStmt:
return b.buildShow(x)
case *ast.DoStmt:
return b.buildSimple(x)
case *ast.BeginStmt:
return b.buildSimple(x)
case *ast.CommitStmt:
return b.buildSimple(x)
case *ast.RollbackStmt:
return b.buildSimple(x)
case *ast.CreateUserStmt:
return b.buildSimple(x)
case *ast.SetPwdStmt:
return b.buildSimple(x)
case *ast.GrantStmt:
return b.buildSimple(x)
case *ast.TruncateTableStmt:
return b.buildDDL(x)
}
b.err = ErrUnsupportedType.Gen("Unsupported type %T", node)
return nil
}
// Detect aggregate function or groupby clause.
func (b *planBuilder) detectSelectAgg(sel *ast.SelectStmt) bool {
if sel.GroupBy != nil {
return true
}
for _, f := range sel.GetResultFields() {
if ast.HasAggFlag(f.Expr) {
return true
}
}
if sel.Having != nil {
if ast.HasAggFlag(sel.Having.Expr) {
return true
}
}
if sel.OrderBy != nil {
for _, item := range sel.OrderBy.Items {
if ast.HasAggFlag(item.Expr) {
return true
}
}
}
return false
}
// extractSelectAgg extracts aggregate functions and converts ColumnNameExpr to aggregate function.
func (b *planBuilder) extractSelectAgg(sel *ast.SelectStmt) []*ast.AggregateFuncExpr {
extractor := &ast.AggregateFuncExtractor{AggFuncs: make([]*ast.AggregateFuncExpr, 0)}
for _, f := range sel.GetResultFields() {
n, ok := f.Expr.Accept(extractor)
if !ok {
b.err = errors.New("Failed to extract agg expr!")
return nil
}
ve, ok := f.Expr.(*ast.ValueExpr)
if ok && len(f.Column.Name.O) > 0 {
agg := &ast.AggregateFuncExpr{
F: ast.AggFuncFirstRow,
Args: []ast.ExprNode{ve},
}
extractor.AggFuncs = append(extractor.AggFuncs, agg)
n = agg
}
f.Expr = n.(ast.ExprNode)
}
// Extract agg funcs from having clause.
if sel.Having != nil {
n, ok := sel.Having.Expr.Accept(extractor)
if !ok {
b.err = errors.New("Failed to extract agg expr from having clause")
return nil
}
sel.Having.Expr = n.(ast.ExprNode)
}
// Extract agg funcs from orderby clause.
if sel.OrderBy != nil {
for _, item := range sel.OrderBy.Items {
n, ok := item.Expr.Accept(extractor)
if !ok {
b.err = errors.New("Failed to extract agg expr from orderby clause")
return nil
}
item.Expr = n.(ast.ExprNode)
// If item is PositionExpr, we need to rebind it.
// For PositionExpr will refer to a ResultField in fieldlist.
// After extract AggExpr from fieldlist, it may be changed (See the code above).
if pe, ok := item.Expr.(*ast.PositionExpr); ok {
pe.Refer = sel.GetResultFields()[pe.N-1]
}
}
}
return extractor.AggFuncs
}
func (b *planBuilder) buildSubquery(n ast.Node) {
sv := &subqueryVisitor{
builder: b,
}
_, ok := n.Accept(sv)
if !ok {
log.Errorf("Extract subquery error")
}
}
func (b *planBuilder) buildSelect(sel *ast.SelectStmt) Plan {
var aggFuncs []*ast.AggregateFuncExpr
hasAgg := b.detectSelectAgg(sel)
if hasAgg {
aggFuncs = b.extractSelectAgg(sel)
}
// Build subquery
// Convert subquery to expr with plan
b.buildSubquery(sel)
var p Plan
if sel.From != nil {
p = b.buildFrom(sel)
if b.err != nil {
return nil
}
if sel.LockTp != ast.SelectLockNone {
p = b.buildSelectLock(p, sel.LockTp)
if b.err != nil {
return nil
}
}
if hasAgg {
p = b.buildAggregate(p, aggFuncs, sel.GroupBy)
}
p = b.buildSelectFields(p, sel.GetResultFields())
if b.err != nil {
return nil
}
} else {
if hasAgg {
p = b.buildAggregate(p, aggFuncs, nil)
}
p = b.buildSelectFields(p, sel.GetResultFields())
if b.err != nil {
return nil
}
}
if sel.Having != nil {
p = b.buildHaving(p, sel.Having)
if b.err != nil {
return nil
}
}
if sel.Distinct {
p = b.buildDistinct(p)
if b.err != nil {
return nil
}
}
if sel.OrderBy != nil && !matchOrder(p, sel.OrderBy.Items) {
p = b.buildSort(p, sel.OrderBy.Items)
if b.err != nil {
return nil
}
}
if sel.Limit != nil {
p = b.buildLimit(p, sel.Limit)
if b.err != nil {
return nil
}
}
return p
}
func (b *planBuilder) buildFrom(sel *ast.SelectStmt) Plan {
from := sel.From.TableRefs
if from.Right == nil {
return b.buildSingleTable(sel)
}
return b.buildJoin(sel)
}
func (b *planBuilder) buildSingleTable(sel *ast.SelectStmt) Plan {
from := sel.From.TableRefs
ts, ok := from.Left.(*ast.TableSource)
if !ok {
b.err = ErrUnsupportedType.Gen("Unsupported type %T", from.Left)
return nil
}
var bestPlan Plan
switch v := ts.Source.(type) {
case *ast.TableName:
case *ast.SelectStmt:
bestPlan = b.buildSelect(v)
}
if bestPlan != nil {
return bestPlan
}
tn, ok := ts.Source.(*ast.TableName)
if !ok {
b.err = ErrUnsupportedType.Gen("Unsupported type %T", ts.Source)
return nil
}
conditions := splitWhere(sel.Where)
path := &joinPath{table: tn, conditions: conditions}
candidates := b.buildAllAccessMethodsPlan(path)
var lowestCost float64
for _, v := range candidates {
cost := EstimateCost(b.buildPseudoSelectPlan(v, sel))
if bestPlan == nil {
bestPlan = v
lowestCost = cost
}
if cost < lowestCost {
bestPlan = v
lowestCost = cost
}
}
return bestPlan
}
func (b *planBuilder) buildAllAccessMethodsPlan(path *joinPath) []Plan {
var candidates []Plan
p := b.buildTableScanPlan(path)
candidates = append(candidates, p)
for _, index := range path.table.TableInfo.Indices {
ip := b.buildIndexScanPlan(index, path)
candidates = append(candidates, ip)
}
return candidates
}
func (b *planBuilder) buildTableScanPlan(path *joinPath) Plan {
tn := path.table
p := &TableScan{
Table: tn.TableInfo,
}
// Equal condition contains a column from previous joined table.
p.RefAccess = len(path.eqConds) > 0
p.SetFields(tn.GetResultFields())
var pkName model.CIStr
if p.Table.PKIsHandle {
for _, colInfo := range p.Table.Columns {
if mysql.HasPriKeyFlag(colInfo.Flag) {
pkName = colInfo.Name
}
}
}
for _, con := range path.conditions {
if pkName.L != "" {
checker := conditionChecker{tableName: tn.TableInfo.Name, pkName: pkName}
if checker.check(con) {
p.AccessConditions = append(p.AccessConditions, con)
} else {
p.FilterConditions = append(p.FilterConditions, con)
}
} else {
p.FilterConditions = append(p.FilterConditions, con)
}
}
return p
}
func (b *planBuilder) buildIndexScanPlan(index *model.IndexInfo, path *joinPath) Plan {
tn := path.table
ip := &IndexScan{Table: tn.TableInfo, Index: index}
ip.RefAccess = len(path.eqConds) > 0
ip.SetFields(tn.GetResultFields())
condMap := map[ast.ExprNode]bool{}
for _, con := range path.conditions {
condMap[con] = true
}
out:
// Build equal access conditions first.
// Starts from the first index column, if equal condition is found, add it to access conditions,
// proceed to the next index column. until we can't find any equal condition for the column.
for ip.AccessEqualCount < len(index.Columns) {
for con := range condMap {
binop, ok := con.(*ast.BinaryOperationExpr)
if !ok || binop.Op != opcode.EQ {
continue
}
if ast.IsPreEvaluable(binop.L) {
binop.L, binop.R = binop.R, binop.L
}
if !ast.IsPreEvaluable(binop.R) {
continue
}
cn, ok2 := binop.L.(*ast.ColumnNameExpr)
if !ok2 || cn.Refer.Column.Name.L != index.Columns[ip.AccessEqualCount].Name.L {
continue
}
ip.AccessConditions = append(ip.AccessConditions, con)
delete(condMap, con)
ip.AccessEqualCount++
continue out
}
break
}
for con := range condMap {
if ip.AccessEqualCount < len(ip.Index.Columns) {
// Try to add non-equal access condition for index column at AccessEqualCount.
checker := conditionChecker{tableName: tn.TableInfo.Name, idx: index, columnOffset: ip.AccessEqualCount}
if checker.check(con) {
ip.AccessConditions = append(ip.AccessConditions, con)
} else {
ip.FilterConditions = append(ip.FilterConditions, con)
}
} else {
ip.FilterConditions = append(ip.FilterConditions, con)
}
}
return ip
}
// buildPseudoSelectPlan pre-builds more complete plans that may affect total cost.
func (b *planBuilder) buildPseudoSelectPlan(p Plan, sel *ast.SelectStmt) Plan {
if sel.OrderBy == nil {
return p
}
if sel.GroupBy != nil {
return p
}
if !matchOrder(p, sel.OrderBy.Items) {
np := &Sort{ByItems: sel.OrderBy.Items}
np.SetSrc(p)
p = np
}
if sel.Limit != nil {
np := &Limit{Offset: sel.Limit.Offset, Count: sel.Limit.Count}
np.SetSrc(p)
np.SetLimit(0)
p = np
}
return p
}
func (b *planBuilder) buildSelectLock(src Plan, lock ast.SelectLockType) *SelectLock {
selectLock := &SelectLock{
Lock: lock,
}
selectLock.SetSrc(src)
selectLock.SetFields(src.Fields())
return selectLock
}
func (b *planBuilder) buildSelectFields(src Plan, fields []*ast.ResultField) Plan {
selectFields := &SelectFields{}
selectFields.SetSrc(src)
selectFields.SetFields(fields)
return selectFields
}
func (b *planBuilder) buildAggregate(src Plan, aggFuncs []*ast.AggregateFuncExpr, groupby *ast.GroupByClause) Plan {
// Add aggregate plan.
aggPlan := &Aggregate{
AggFuncs: aggFuncs,
}
aggPlan.SetSrc(src)
if src != nil {
aggPlan.SetFields(src.Fields())
}
if groupby != nil {
aggPlan.GroupByItems = groupby.Items
}
return aggPlan
}
func (b *planBuilder) buildHaving(src Plan, having *ast.HavingClause) Plan {
p := &Having{
Conditions: splitWhere(having.Expr),
}
p.SetSrc(src)
p.SetFields(src.Fields())
return p
}
func (b *planBuilder) buildSort(src Plan, byItems []*ast.ByItem) Plan {
sort := &Sort{
ByItems: byItems,
}
sort.SetSrc(src)
sort.SetFields(src.Fields())
return sort
}
func (b *planBuilder) buildLimit(src Plan, limit *ast.Limit) Plan {
li := &Limit{
Offset: limit.Offset,
Count: limit.Count,
}
li.SetSrc(src)
li.SetFields(src.Fields())
return li
}
func (b *planBuilder) buildPrepare(x *ast.PrepareStmt) Plan {
p := &Prepare{
Name: x.Name,
}
if x.SQLVar != nil {
p.SQLText, _ = x.SQLVar.GetValue().(string)
} else {
p.SQLText = x.SQLText
}
return p
}
func (b *planBuilder) buildAdmin(as *ast.AdminStmt) Plan {
var p Plan
switch as.Tp {
case ast.AdminCheckTable:
p = &CheckTable{Tables: as.Tables}
case ast.AdminShowDDL:
p = &ShowDDL{}
p.SetFields(buildShowDDLFields())
default:
b.err = ErrUnsupportedType.Gen("Unsupported type %T", as)
}
return p
}
func buildShowDDLFields() []*ast.ResultField {
rfs := make([]*ast.ResultField, 0, 6)
rfs = append(rfs, buildResultField("", "SCHEMA_VER", mysql.TypeLonglong, 4))
rfs = append(rfs, buildResultField("", "OWNER", mysql.TypeVarchar, 64))
rfs = append(rfs, buildResultField("", "JOB", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "BG_SCHEMA_VER", mysql.TypeLonglong, 4))
rfs = append(rfs, buildResultField("", "BG_OWNER", mysql.TypeVarchar, 64))
rfs = append(rfs, buildResultField("", "BG_JOB", mysql.TypeVarchar, 128))
return rfs
}
func buildResultField(tableName, name string, tp byte, size int) *ast.ResultField {
cs := charset.CharsetBin
cl := charset.CharsetBin
flag := mysql.UnsignedFlag
if tp == mysql.TypeVarchar || tp == mysql.TypeBlob {
cs = mysql.DefaultCharset
cl = mysql.DefaultCollationName
flag = 0
}
fieldType := types.FieldType{
Charset: cs,
Collate: cl,
Tp: tp,
Flen: size,
Flag: uint(flag),
}
colInfo := &model.ColumnInfo{
Name: model.NewCIStr(name),
FieldType: fieldType,
}
expr := &ast.ValueExpr{}
expr.SetType(&fieldType)
return &ast.ResultField{
Column: colInfo,
ColumnAsName: colInfo.Name,
TableAsName: model.NewCIStr(tableName),
DBName: model.NewCIStr(infoschema.Name),
Expr: expr,
}
}
// matchOrder checks if the plan has the same ordering as items.
func matchOrder(p Plan, items []*ast.ByItem) bool {
switch x := p.(type) {
case *Aggregate:
return false
case *IndexScan:
if len(items) > len(x.Index.Columns) {
return false
}
for i, item := range items {
if item.Desc {
return false
}
var rf *ast.ResultField
switch y := item.Expr.(type) {
case *ast.ColumnNameExpr:
rf = y.Refer
case *ast.PositionExpr:
rf = y.Refer
default:
return false
}
if rf.Table.Name.L != x.Table.Name.L || rf.Column.Name.L != x.Index.Columns[i].Name.L {
return false
}
}
return true
case *TableScan:
if len(items) != 1 || !x.Table.PKIsHandle {
return false
}
if items[0].Desc {
return false
}
var refer *ast.ResultField
switch x := items[0].Expr.(type) {
case *ast.ColumnNameExpr:
refer = x.Refer
case *ast.PositionExpr:
refer = x.Refer
default:
return false
}
if mysql.HasPriKeyFlag(refer.Column.Flag) {
return true
}
return false
case *JoinOuter:
return false
case *JoinInner:
return false
case *Sort:
// Sort plan should not be checked here as there should only be one sort plan in a plan tree.
return false
case WithSrcPlan:
return matchOrder(x.Src(), items)
}
return true
}
// splitWhere split a where expression to a list of AND conditions.
func splitWhere(where ast.ExprNode) []ast.ExprNode {
var conditions []ast.ExprNode
switch x := where.(type) {
case nil:
case *ast.BinaryOperationExpr:
if x.Op == opcode.AndAnd {
conditions = append(conditions, splitWhere(x.L)...)
conditions = append(conditions, splitWhere(x.R)...)
} else {
conditions = append(conditions, x)
}
case *ast.ParenthesesExpr:
conditions = append(conditions, splitWhere(x.Expr)...)
default:
conditions = append(conditions, where)
}
return conditions
}
// SubQueryBuilder is the interface for building SubQuery executor.
type SubQueryBuilder interface {
Build(p Plan) ast.SubqueryExec
}
// subqueryVisitor visits AST and handles SubqueryExpr.
type subqueryVisitor struct {
builder *planBuilder
}
func (se *subqueryVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch x := in.(type) {
case *ast.SubqueryExpr:
p := se.builder.build(x.Query)
// The expr pointor is copyed into ResultField when running name resolver.
// So we can not just replace the expr node in AST. We need to put SubQuery into the expr.
// See: optimizer.nameResolver.createResultFields()
x.SubqueryExec = se.builder.sb.Build(p)
return in, true
case *ast.Join:
// SubSelect in from clause will be handled in buildJoin().
return in, true
}
return in, false
}
func (se *subqueryVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}
func (b *planBuilder) buildUnion(union *ast.UnionStmt) Plan {
sels := make([]Plan, len(union.SelectList.Selects))
for i, sel := range union.SelectList.Selects {
sels[i] = b.buildSelect(sel)
}
var p Plan
p = &Union{
Selects: sels,
}
unionFields := union.GetResultFields()
for _, sel := range sels {
for i, f := range sel.Fields() {
if i == len(unionFields) {
b.err = errors.New("The used SELECT statements have a different number of columns")
return nil
}
uField := unionFields[i]
/*
* The lengths of the columns in the UNION result take into account the values retrieved by all of the SELECT statements
* SELECT REPEAT('a',1) UNION SELECT REPEAT('b',10);
* +---------------+
* | REPEAT('a',1) |
* +---------------+
* | a |
* | bbbbbbbbbb |
* +---------------+
*/
if f.Column.Flen > uField.Column.Flen {
uField.Column.Flen = f.Column.Flen
}
// For select nul union select "abc", we should not convert "abc" to nil.
// And the result field type should be VARCHAR.
if uField.Column.Tp == 0 || uField.Column.Tp == mysql.TypeNull {
uField.Column.Tp = f.Column.Tp
}
}
}
for _, v := range unionFields {
v.Expr.SetType(&v.Column.FieldType)
}
p.SetFields(unionFields)
if union.Distinct {
p = b.buildDistinct(p)
}
if union.OrderBy != nil {
p = b.buildSort(p, union.OrderBy.Items)
}
if union.Limit != nil {
p = b.buildLimit(p, union.Limit)
}
return p
}
func (b *planBuilder) buildDistinct(src Plan) Plan {
d := &Distinct{}
d.src = src
d.SetFields(src.Fields())
return d
}
func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) Plan {
sel := &ast.SelectStmt{From: update.TableRefs, Where: update.Where, OrderBy: update.Order, Limit: update.Limit}
p := b.buildFrom(sel)
if sel.OrderBy != nil && !matchOrder(p, sel.OrderBy.Items) {
p = b.buildSort(p, sel.OrderBy.Items)
if b.err != nil {
return nil
}
}
if sel.Limit != nil {
p = b.buildLimit(p, sel.Limit)
if b.err != nil {
return nil
}
}
orderedList := b.buildUpdateLists(update.List, p.Fields())
if b.err != nil {
return nil
}
return &Update{OrderedList: orderedList, SelectPlan: p}
}
func (b *planBuilder) buildUpdateLists(list []*ast.Assignment, fields []*ast.ResultField) []*ast.Assignment {
newList := make([]*ast.Assignment, len(fields))
for _, assign := range list {
offset, err := columnOffsetInFields(assign.Column, fields)
if err != nil {
b.err = errors.Trace(err)
return nil
}
newList[offset] = assign
}
return newList
}
func (b *planBuilder) buildDelete(del *ast.DeleteStmt) Plan {
sel := &ast.SelectStmt{From: del.TableRefs, Where: del.Where, OrderBy: del.Order, Limit: del.Limit}
p := b.buildFrom(sel)
if sel.OrderBy != nil && !matchOrder(p, sel.OrderBy.Items) {
p = b.buildSort(p, sel.OrderBy.Items)
if b.err != nil {
return nil
}
}
if sel.Limit != nil {
p = b.buildLimit(p, sel.Limit)
if b.err != nil {
return nil
}
}
var tables []*ast.TableName
if del.Tables != nil {
tables = del.Tables.Tables
}
return &Delete{
Tables: tables,
IsMultiTable: del.IsMultiTable,
SelectPlan: p,
}
}
func columnOffsetInFields(cn *ast.ColumnName, fields []*ast.ResultField) (int, error) {
offset := -1
tableNameL := cn.Table.L
columnNameL := cn.Name.L
if tableNameL != "" {
for i, f := range fields {
// Check table name.
if f.TableAsName.L != "" {
if tableNameL != f.TableAsName.L {
continue
}
} else {
if tableNameL != f.Table.Name.L {
continue
}
}
// Check column name.
if f.ColumnAsName.L != "" {
if columnNameL != f.ColumnAsName.L {
continue
}
} else {
if columnNameL != f.Column.Name.L {
continue
}
}
offset = i
}
} else {
for i, f := range fields {
matchAsName := f.ColumnAsName.L != "" && f.ColumnAsName.L == columnNameL
matchColumnName := f.ColumnAsName.L == "" && f.Column.Name.L == columnNameL
if matchAsName || matchColumnName {
if offset != -1 {
return -1, errors.Errorf("column %s is ambiguous.", cn.Name.O)
}
offset = i
}
}
}
if offset == -1 {
return -1, errors.Errorf("column %s not found", cn.Name.O)
}
return offset, nil
}
func (b *planBuilder) buildShow(show *ast.ShowStmt) Plan {
var p Plan
p = &Show{
Tp: show.Tp,
DBName: show.DBName,
Table: show.Table,
Column: show.Column,
Flag: show.Flag,
Full: show.Full,
User: show.User,
}
p.SetFields(show.GetResultFields())
var conditions []ast.ExprNode
if show.Pattern != nil {
conditions = append(conditions, show.Pattern)
}
if show.Where != nil {
conditions = append(conditions, show.Where)
}
if len(conditions) != 0 {
filter := &Filter{Conditions: conditions}
filter.SetSrc(p)
p = filter
}
return p
}
func (b *planBuilder) buildSimple(node ast.StmtNode) Plan {
return &Simple{Statement: node}
}
func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
insertPlan := &Insert{
Table: insert.Table,
Columns: insert.Columns,
Lists: insert.Lists,
Setlist: insert.Setlist,
OnDuplicate: insert.OnDuplicate,
IsReplace: insert.IsReplace,
Priority: insert.Priority,
}
if insert.Select != nil {
insertPlan.SelectPlan = b.build(insert.Select)
if b.err != nil {
return nil
}
}
return insertPlan
}
func (b *planBuilder) buildDDL(node ast.DDLNode) Plan {
return &DDL{Statement: node}
}
func (b *planBuilder) buildExplain(explain *ast.ExplainStmt) Plan {
if show, ok := explain.Stmt.(*ast.ShowStmt); ok {
return b.buildShow(show)
}
targetPlan := b.build(explain.Stmt)
if b.err != nil {
return nil
}
p := &Explain{StmtPlan: targetPlan}
p.SetFields(buildExplainFields())
return p
}
// See: https://dev.mysql.com/doc/refman/5.7/en/explain-output.html
func buildExplainFields() []*ast.ResultField {
rfs := make([]*ast.ResultField, 0, 10)
rfs = append(rfs, buildResultField("", "id", mysql.TypeLonglong, 4))
rfs = append(rfs, buildResultField("", "select_type", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "table", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "type", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "possible_keys", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "key", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "key_len", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "ref", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "rows", mysql.TypeVarchar, 128))
rfs = append(rfs, buildResultField("", "Extra", mysql.TypeVarchar, 128))
return rfs
}

View file

@ -0,0 +1,795 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"strings"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
)
// equalCond represents an equivalent join condition, like "t1.c1 = t2.c1".
type equalCond struct {
left *ast.ResultField
leftIdx bool
right *ast.ResultField
rightIdx bool
}
func newEqualCond(left, right *ast.ResultField) *equalCond {
eq := &equalCond{left: left, right: right}
eq.leftIdx = equivHasIndex(eq.left)
eq.rightIdx = equivHasIndex(eq.right)
return eq
}
func equivHasIndex(rf *ast.ResultField) bool {
if rf.Table.PKIsHandle && mysql.HasPriKeyFlag(rf.Column.Flag) {
return true
}
for _, idx := range rf.Table.Indices {
if len(idx.Columns) == 1 && idx.Columns[0].Name.L == rf.Column.Name.L {
return true
}
}
return false
}
// joinPath can be a single table path, inner join or outer join.
type joinPath struct {
// for table path
table *ast.TableName
totalFilterRate float64
// for subquery
subquery ast.Node
asName model.CIStr
neighborCount int // number of neighbor table.
idxDepCount int // number of paths this table depends on.
ordering *ast.ResultField
orderingDesc bool
// for outer join path
outer *joinPath
inner *joinPath
rightJoin bool
// for inner join path
inners []*joinPath
// common
parent *joinPath
filterRate float64
conditions []ast.ExprNode
eqConds []*equalCond
// The joinPaths that this path's index depends on.
idxDeps map[*joinPath]bool
neighbors map[*joinPath]bool
}
// newTablePath creates a new table join path.
func newTablePath(table *ast.TableName) *joinPath {
return &joinPath{
table: table,
filterRate: rateFull,
}
}
// newSubqueryPath creates a new subquery join path.
func newSubqueryPath(node ast.Node, asName model.CIStr) *joinPath {
return &joinPath{
subquery: node,
asName: asName,
filterRate: rateFull,
}
}
// newOuterJoinPath creates a new outer join path and pushes on condition to children paths.
// The returned joinPath slice has one element.
func newOuterJoinPath(isRightJoin bool, leftPath, rightPath *joinPath, on *ast.OnCondition) *joinPath {
outerJoin := &joinPath{rightJoin: isRightJoin, outer: leftPath, inner: rightPath, filterRate: 1}
leftPath.parent = outerJoin
rightPath.parent = outerJoin
if isRightJoin {
outerJoin.outer, outerJoin.inner = outerJoin.inner, outerJoin.outer
}
if on != nil {
conditions := splitWhere(on.Expr)
availablePaths := []*joinPath{outerJoin.outer}
for _, con := range conditions {
if !outerJoin.inner.attachCondition(con, availablePaths) {
log.Errorf("Inner failed to attach ON condition")
}
}
}
return outerJoin
}
// newInnerJoinPath creates inner join path and pushes on condition to children paths.
// If left path or right path is also inner join, it will be merged.
func newInnerJoinPath(leftPath, rightPath *joinPath, on *ast.OnCondition) *joinPath {
var innerJoin *joinPath
if len(leftPath.inners) != 0 {
innerJoin = leftPath
} else {
innerJoin = &joinPath{filterRate: leftPath.filterRate}
innerJoin.inners = append(innerJoin.inners, leftPath)
}
if len(rightPath.inners) != 0 {
innerJoin.inners = append(innerJoin.inners, rightPath.inners...)
innerJoin.conditions = append(innerJoin.conditions, rightPath.conditions...)
} else {
innerJoin.inners = append(innerJoin.inners, rightPath)
}
innerJoin.filterRate *= rightPath.filterRate
for _, in := range innerJoin.inners {
in.parent = innerJoin
}
if on != nil {
conditions := splitWhere(on.Expr)
for _, con := range conditions {
if !innerJoin.attachCondition(con, nil) {
innerJoin.conditions = append(innerJoin.conditions, con)
}
}
}
return innerJoin
}
func (p *joinPath) resultFields() []*ast.ResultField {
if p.table != nil {
return p.table.GetResultFields()
}
if p.outer != nil {
if p.rightJoin {
return append(p.inner.resultFields(), p.outer.resultFields()...)
}
return append(p.outer.resultFields(), p.inner.resultFields()...)
}
var rfs []*ast.ResultField
for _, in := range p.inners {
rfs = append(rfs, in.resultFields()...)
}
return rfs
}
// attachCondition tries to attach a condition as deep as possible.
// availablePaths are paths join before this path.
func (p *joinPath) attachCondition(condition ast.ExprNode, availablePaths []*joinPath) (attached bool) {
filterRate := guesstimateFilterRate(condition)
// table
if p.table != nil || p.subquery != nil {
attacher := conditionAttachChecker{targetPath: p, availablePaths: availablePaths}
condition.Accept(&attacher)
if attacher.invalid {
return false
}
p.conditions = append(p.conditions, condition)
p.filterRate *= filterRate
return true
}
// inner join
if len(p.inners) > 0 {
for _, in := range p.inners {
if in.attachCondition(condition, availablePaths) {
p.filterRate *= filterRate
return true
}
}
attacher := &conditionAttachChecker{targetPath: p, availablePaths: availablePaths}
condition.Accept(attacher)
if attacher.invalid {
return false
}
p.conditions = append(p.conditions, condition)
p.filterRate *= filterRate
return true
}
// outer join
if p.outer.attachCondition(condition, availablePaths) {
p.filterRate *= filterRate
return true
}
if p.inner.attachCondition(condition, append(availablePaths, p.outer)) {
p.filterRate *= filterRate
return true
}
return false
}
func (p *joinPath) containsTable(table *ast.TableName) bool {
if p.table != nil {
return p.table == table
}
if p.subquery != nil {
return p.asName.L == table.Name.L
}
if len(p.inners) != 0 {
for _, in := range p.inners {
if in.containsTable(table) {
return true
}
}
return false
}
return p.outer.containsTable(table) || p.inner.containsTable(table)
}
// attachEqualCond tries to attach an equalCond deep into a table path if applicable.
func (p *joinPath) attachEqualCond(eqCon *equalCond, availablePaths []*joinPath) (attached bool) {
// table
if p.table != nil {
var prevTable *ast.TableName
var needSwap bool
if eqCon.left.TableName == p.table {
prevTable = eqCon.right.TableName
} else if eqCon.right.TableName == p.table {
prevTable = eqCon.left.TableName
needSwap = true
}
if prevTable != nil {
for _, prev := range availablePaths {
if prev.containsTable(prevTable) {
if needSwap {
eqCon.left, eqCon.right = eqCon.right, eqCon.left
eqCon.leftIdx, eqCon.rightIdx = eqCon.rightIdx, eqCon.leftIdx
}
p.eqConds = append(p.eqConds, eqCon)
return true
}
}
}
return false
}
// inner join
if len(p.inners) > 0 {
for _, in := range p.inners {
if in.attachEqualCond(eqCon, availablePaths) {
p.filterRate *= rateEqual
return true
}
}
return false
}
// outer join
if p.outer.attachEqualCond(eqCon, availablePaths) {
p.filterRate *= rateEqual
return true
}
if p.inner.attachEqualCond(eqCon, append(availablePaths, p.outer)) {
p.filterRate *= rateEqual
return true
}
return false
}
func (p *joinPath) extractEqualConditon() {
var equivs []*equalCond
var cons []ast.ExprNode
for _, con := range p.conditions {
eq := equivFromExpr(con)
if eq != nil {
equivs = append(equivs, eq)
if p.table != nil {
if eq.right.TableName == p.table {
eq.left, eq.right = eq.right, eq.left
eq.leftIdx, eq.rightIdx = eq.rightIdx, eq.leftIdx
}
}
} else {
cons = append(cons, con)
}
}
p.eqConds = equivs
p.conditions = cons
for _, in := range p.inners {
in.extractEqualConditon()
}
if p.outer != nil {
p.outer.extractEqualConditon()
p.inner.extractEqualConditon()
}
}
func (p *joinPath) addIndexDependency() {
if p.outer != nil {
p.outer.addIndexDependency()
p.inner.addIndexDependency()
return
}
if p.table != nil {
return
}
for _, eq := range p.eqConds {
if !eq.leftIdx && !eq.rightIdx {
continue
}
pathLeft := p.findInnerContains(eq.left.TableName)
if pathLeft == nil {
continue
}
pathRight := p.findInnerContains(eq.right.TableName)
if pathRight == nil {
continue
}
if eq.leftIdx && eq.rightIdx {
pathLeft.addNeighbor(pathRight)
pathRight.addNeighbor(pathLeft)
} else if eq.leftIdx {
if !pathLeft.hasOuterIdxEqualCond() {
pathLeft.addIndexDep(pathRight)
}
} else if eq.rightIdx {
if !pathRight.hasOuterIdxEqualCond() {
pathRight.addIndexDep(pathLeft)
}
}
}
for _, in := range p.inners {
in.removeIndexDepCycle(in)
in.addIndexDependency()
}
}
func (p *joinPath) hasOuterIdxEqualCond() bool {
if p.table != nil {
for _, eq := range p.eqConds {
if eq.leftIdx {
return true
}
}
return false
}
if p.outer != nil {
return p.outer.hasOuterIdxEqualCond()
}
for _, in := range p.inners {
if in.hasOuterIdxEqualCond() {
return true
}
}
return false
}
func (p *joinPath) findInnerContains(table *ast.TableName) *joinPath {
for _, in := range p.inners {
if in.containsTable(table) {
return in
}
}
return nil
}
func (p *joinPath) addNeighbor(neighbor *joinPath) {
if p.neighbors == nil {
p.neighbors = map[*joinPath]bool{}
}
p.neighbors[neighbor] = true
p.neighborCount++
}
func (p *joinPath) addIndexDep(dep *joinPath) {
if p.idxDeps == nil {
p.idxDeps = map[*joinPath]bool{}
}
p.idxDeps[dep] = true
p.idxDepCount++
}
func (p *joinPath) removeIndexDepCycle(origin *joinPath) {
if p.idxDeps == nil {
return
}
for dep := range p.idxDeps {
if dep == origin {
delete(p.idxDeps, origin)
continue
}
dep.removeIndexDepCycle(origin)
}
}
func (p *joinPath) score() float64 {
return 1 / p.filterRate
}
func (p *joinPath) String() string {
if p.table != nil {
return p.table.TableInfo.Name.L
}
if p.outer != nil {
return "outer{" + p.outer.String() + "," + p.inner.String() + "}"
}
var innerStrs []string
for _, in := range p.inners {
innerStrs = append(innerStrs, in.String())
}
return "inner{" + strings.Join(innerStrs, ",") + "}"
}
func (p *joinPath) optimizeJoinOrder(availablePaths []*joinPath) {
if p.table != nil {
return
}
if p.outer != nil {
p.outer.optimizeJoinOrder(availablePaths)
p.inner.optimizeJoinOrder(append(availablePaths, p.outer))
return
}
var ordered []*joinPath
pathMap := map[*joinPath]bool{}
for _, in := range p.inners {
pathMap[in] = true
}
for len(pathMap) > 0 {
next := p.nextPath(pathMap, availablePaths)
next.optimizeJoinOrder(availablePaths)
ordered = append(ordered, next)
delete(pathMap, next)
availablePaths = append(availablePaths, next)
for path := range pathMap {
if path.idxDeps != nil {
delete(path.idxDeps, next)
}
if path.neighbors != nil {
delete(path.neighbors, next)
}
}
p.reattach(pathMap, availablePaths)
}
p.inners = ordered
}
// reattach is called by inner joinPath to retry attach conditions to inner paths
// after an inner path has been added to available paths.
func (p *joinPath) reattach(pathMap map[*joinPath]bool, availablePaths []*joinPath) {
if len(p.conditions) != 0 {
remainedConds := make([]ast.ExprNode, 0, len(p.conditions))
for _, con := range p.conditions {
var attached bool
for path := range pathMap {
if path.attachCondition(con, availablePaths) {
attached = true
break
}
}
if !attached {
remainedConds = append(remainedConds, con)
}
}
p.conditions = remainedConds
}
if len(p.eqConds) != 0 {
remainedEqConds := make([]*equalCond, 0, len(p.eqConds))
for _, eq := range p.eqConds {
var attached bool
for path := range pathMap {
if path.attachEqualCond(eq, availablePaths) {
attached = true
break
}
}
if !attached {
remainedEqConds = append(remainedEqConds, eq)
}
}
p.eqConds = remainedEqConds
}
}
func (p *joinPath) nextPath(pathMap map[*joinPath]bool, availablePaths []*joinPath) *joinPath {
cans := p.candidates(pathMap)
if len(cans) == 0 {
var v *joinPath
for v = range pathMap {
log.Errorf("index dep %v, prevs %v\n", v.idxDeps, len(availablePaths))
}
return v
}
indexPath := p.nextIndexPath(cans)
if indexPath != nil {
return indexPath
}
return p.pickPath(cans)
}
func (p *joinPath) candidates(pathMap map[*joinPath]bool) []*joinPath {
var cans []*joinPath
for t := range pathMap {
if len(t.idxDeps) > 0 {
continue
}
cans = append(cans, t)
}
return cans
}
func (p *joinPath) nextIndexPath(candidates []*joinPath) *joinPath {
var best *joinPath
for _, can := range candidates {
// Since we may not have equal conditions attached on the path, we
// need to check neighborCount and idxDepCount to see if this path
// can be joined with index.
neighborIsAvailable := len(can.neighbors) < can.neighborCount
idxDepIsAvailable := can.idxDepCount > 0
if can.hasOuterIdxEqualCond() || neighborIsAvailable || idxDepIsAvailable {
if best == nil {
best = can
}
if can.score() > best.score() {
best = can
}
}
}
return best
}
func (p *joinPath) pickPath(candidates []*joinPath) *joinPath {
var best *joinPath
for _, path := range candidates {
if best == nil {
best = path
}
if path.score() > best.score() {
best = path
}
}
return best
}
// conditionAttachChecker checks if an expression is valid to
// attach to a path. attach is valid only if all the referenced tables in the
// expression are available.
type conditionAttachChecker struct {
targetPath *joinPath
availablePaths []*joinPath
invalid bool
}
func (c *conditionAttachChecker) Enter(in ast.Node) (ast.Node, bool) {
switch x := in.(type) {
case *ast.ColumnNameExpr:
table := x.Refer.TableName
if c.targetPath.containsTable(table) {
return in, false
}
c.invalid = true
for _, path := range c.availablePaths {
if path.containsTable(table) {
c.invalid = false
return in, false
}
}
}
return in, false
}
func (c *conditionAttachChecker) Leave(in ast.Node) (ast.Node, bool) {
return in, !c.invalid
}
func (b *planBuilder) buildJoin(sel *ast.SelectStmt) Plan {
nrfinder := &nullRejectFinder{nullRejectTables: map[*ast.TableName]bool{}}
if sel.Where != nil {
sel.Where.Accept(nrfinder)
}
path := b.buildBasicJoinPath(sel.From.TableRefs, nrfinder.nullRejectTables)
rfs := path.resultFields()
whereConditions := splitWhere(sel.Where)
for _, whereCond := range whereConditions {
if !path.attachCondition(whereCond, nil) {
// TODO: Find a better way to handle this condition.
path.conditions = append(path.conditions, whereCond)
log.Errorf("Failed to attach where condtion.")
}
}
path.extractEqualConditon()
path.addIndexDependency()
path.optimizeJoinOrder(nil)
p := b.buildPlanFromJoinPath(path)
p.SetFields(rfs)
return p
}
type nullRejectFinder struct {
nullRejectTables map[*ast.TableName]bool
}
func (n *nullRejectFinder) Enter(in ast.Node) (ast.Node, bool) {
switch x := in.(type) {
case *ast.BinaryOperationExpr:
if x.Op == opcode.NullEQ || x.Op == opcode.OrOr {
return in, true
}
case *ast.IsNullExpr:
if !x.Not {
return in, true
}
case *ast.IsTruthExpr:
if x.Not {
return in, true
}
}
return in, false
}
func (n *nullRejectFinder) Leave(in ast.Node) (ast.Node, bool) {
switch x := in.(type) {
case *ast.ColumnNameExpr:
n.nullRejectTables[x.Refer.TableName] = true
}
return in, true
}
func (b *planBuilder) buildBasicJoinPath(node ast.ResultSetNode, nullRejectTables map[*ast.TableName]bool) *joinPath {
switch x := node.(type) {
case nil:
return nil
case *ast.Join:
leftPath := b.buildBasicJoinPath(x.Left, nullRejectTables)
if x.Right == nil {
return leftPath
}
righPath := b.buildBasicJoinPath(x.Right, nullRejectTables)
isOuter := b.isOuterJoin(x.Tp, leftPath, righPath, nullRejectTables)
if isOuter {
return newOuterJoinPath(x.Tp == ast.RightJoin, leftPath, righPath, x.On)
}
return newInnerJoinPath(leftPath, righPath, x.On)
case *ast.TableSource:
switch v := x.Source.(type) {
case *ast.TableName:
return newTablePath(v)
case *ast.SelectStmt, *ast.UnionStmt:
return newSubqueryPath(v, x.AsName)
default:
b.err = ErrUnsupportedType.Gen("unsupported table source type %T", x)
return nil
}
default:
b.err = ErrUnsupportedType.Gen("unsupported table source type %T", x)
return nil
}
}
func (b *planBuilder) isOuterJoin(tp ast.JoinType, leftPaths, rightPaths *joinPath,
nullRejectTables map[*ast.TableName]bool) bool {
var innerPath *joinPath
switch tp {
case ast.LeftJoin:
innerPath = rightPaths
case ast.RightJoin:
innerPath = leftPaths
default:
return false
}
for table := range nullRejectTables {
if innerPath.containsTable(table) {
return false
}
}
return true
}
func equivFromExpr(expr ast.ExprNode) *equalCond {
binop, ok := expr.(*ast.BinaryOperationExpr)
if !ok || binop.Op != opcode.EQ {
return nil
}
ln, lOK := binop.L.(*ast.ColumnNameExpr)
rn, rOK := binop.R.(*ast.ColumnNameExpr)
if !lOK || !rOK {
return nil
}
if ln.Name.Table.L == "" || rn.Name.Table.L == "" {
return nil
}
if ln.Name.Schema.L == rn.Name.Schema.L && ln.Name.Table.L == rn.Name.Table.L {
return nil
}
return newEqualCond(ln.Refer, rn.Refer)
}
func (b *planBuilder) buildPlanFromJoinPath(path *joinPath) Plan {
if path.table != nil {
return b.buildTablePlanFromJoinPath(path)
}
if path.subquery != nil {
return b.buildSubqueryJoinPath(path)
}
if path.outer != nil {
join := &JoinOuter{
Outer: b.buildPlanFromJoinPath(path.outer),
Inner: b.buildPlanFromJoinPath(path.inner),
}
if path.rightJoin {
join.SetFields(append(join.Inner.Fields(), join.Outer.Fields()...))
} else {
join.SetFields(append(join.Outer.Fields(), join.Inner.Fields()...))
}
return join
}
join := &JoinInner{}
for _, in := range path.inners {
join.Inners = append(join.Inners, b.buildPlanFromJoinPath(in))
join.fields = append(join.fields, in.resultFields()...)
}
join.Conditions = path.conditions
for _, equiv := range path.eqConds {
cond := &ast.BinaryOperationExpr{L: equiv.left.Expr, R: equiv.right.Expr, Op: opcode.EQ}
join.Conditions = append(join.Conditions, cond)
}
return join
}
func (b *planBuilder) buildTablePlanFromJoinPath(path *joinPath) Plan {
for _, equiv := range path.eqConds {
columnNameExpr := &ast.ColumnNameExpr{}
columnNameExpr.Name = &ast.ColumnName{}
columnNameExpr.Name.Name = equiv.left.Column.Name
columnNameExpr.Name.Table = equiv.left.Table.Name
columnNameExpr.Refer = equiv.left
condition := &ast.BinaryOperationExpr{L: columnNameExpr, R: equiv.right.Expr, Op: opcode.EQ}
ast.SetFlag(condition)
path.conditions = append(path.conditions, condition)
}
candidates := b.buildAllAccessMethodsPlan(path)
var p Plan
var lowestCost float64
for _, can := range candidates {
cost := EstimateCost(can)
if p == nil {
p = can
lowestCost = cost
}
if cost < lowestCost {
p = can
lowestCost = cost
}
}
return p
}
// Build subquery join path plan
func (b *planBuilder) buildSubqueryJoinPath(path *joinPath) Plan {
for _, equiv := range path.eqConds {
columnNameExpr := &ast.ColumnNameExpr{}
columnNameExpr.Name = &ast.ColumnName{}
columnNameExpr.Name.Name = equiv.left.Column.Name
columnNameExpr.Name.Table = equiv.left.Table.Name
columnNameExpr.Refer = equiv.left
condition := &ast.BinaryOperationExpr{L: columnNameExpr, R: equiv.right.Expr, Op: opcode.EQ}
ast.SetFlag(condition)
path.conditions = append(path.conditions, condition)
}
p := b.build(path.subquery)
if len(path.conditions) == 0 {
return p
}
filterPlan := &Filter{Conditions: path.conditions}
filterPlan.SetSrc(p)
filterPlan.SetFields(p.Fields())
return filterPlan
}

677
vendor/github.com/pingcap/tidb/optimizer/plan/plans.go generated vendored Normal file
View file

@ -0,0 +1,677 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"fmt"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/util/types"
)
// TableRange represents a range of row handle.
type TableRange struct {
LowVal int64
HighVal int64
}
// TableScan represents a table scan plan.
type TableScan struct {
basePlan
Table *model.TableInfo
Desc bool
Ranges []TableRange
// RefAccess indicates it references a previous joined table, used in explain.
RefAccess bool
// AccessConditions can be used to build index range.
AccessConditions []ast.ExprNode
// FilterConditions can be used to filter result.
FilterConditions []ast.ExprNode
}
// Accept implements Plan Accept interface.
func (p *TableScan) Accept(v Visitor) (Plan, bool) {
np, _ := v.Enter(p)
return v.Leave(np)
}
// ShowDDL is for showing DDL information.
type ShowDDL struct {
basePlan
}
// Accept implements Plan Accept interface.
func (p *ShowDDL) Accept(v Visitor) (Plan, bool) {
np, _ := v.Enter(p)
return v.Leave(np)
}
// CheckTable is for checking table data.
type CheckTable struct {
basePlan
Tables []*ast.TableName
}
// Accept implements Plan Accept interface.
func (p *CheckTable) Accept(v Visitor) (Plan, bool) {
np, _ := v.Enter(p)
return v.Leave(np)
}
// IndexRange represents an index range to be scanned.
type IndexRange struct {
LowVal []types.Datum
LowExclude bool
HighVal []types.Datum
HighExclude bool
}
// IsPoint returns if the index range is a point.
func (ir *IndexRange) IsPoint() bool {
if len(ir.LowVal) != len(ir.HighVal) {
return false
}
for i := range ir.LowVal {
a := ir.LowVal[i]
b := ir.HighVal[i]
if a.Kind() == types.KindMinNotNull || b.Kind() == types.KindMaxValue {
return false
}
cmp, err := a.CompareDatum(b)
if err != nil {
return false
}
if cmp != 0 {
return false
}
}
return !ir.LowExclude && !ir.HighExclude
}
// IndexScan represents an index scan plan.
type IndexScan struct {
basePlan
// The index used.
Index *model.IndexInfo
// The table to lookup.
Table *model.TableInfo
// Ordered and non-overlapping ranges to be scanned.
Ranges []*IndexRange
// Desc indicates whether the index should be scanned in descending order.
Desc bool
// RefAccess indicates it references a previous joined table, used in explain.
RefAccess bool
// AccessConditions can be used to build index range.
AccessConditions []ast.ExprNode
// Number of leading equal access condition.
// The offset of each equal condition correspond to the offset of index column.
// For example, an index has column (a, b, c), condition is 'a = 0 and b = 0 and c > 0'
// AccessEqualCount would be 2.
AccessEqualCount int
// FilterConditions can be used to filter result.
FilterConditions []ast.ExprNode
}
// Accept implements Plan Accept interface.
func (p *IndexScan) Accept(v Visitor) (Plan, bool) {
np, _ := v.Enter(p)
return v.Leave(np)
}
// JoinOuter represents outer join plan.
type JoinOuter struct {
basePlan
Outer Plan
Inner Plan
}
// Accept implements Plan interface.
func (p *JoinOuter) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*JoinOuter)
var ok bool
p.Outer, ok = p.Outer.Accept(v)
if !ok {
return p, false
}
p.Inner, ok = p.Inner.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// JoinInner represents inner join plan.
type JoinInner struct {
basePlan
Inners []Plan
Conditions []ast.ExprNode
}
func (p *JoinInner) String() string {
return fmt.Sprintf("JoinInner()")
}
// Accept implements Plan interface.
func (p *JoinInner) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*JoinInner)
for i, in := range p.Inners {
x, ok := in.Accept(v)
if !ok {
return p, false
}
p.Inners[i] = x
}
return v.Leave(p)
}
// SelectLock represents a select lock plan.
type SelectLock struct {
planWithSrc
Lock ast.SelectLockType
}
// Accept implements Plan Accept interface.
func (p *SelectLock) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*SelectLock)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *SelectLock) SetLimit(limit float64) {
p.limit = limit
p.src.SetLimit(p.limit)
}
// SelectFields represents a select fields plan.
type SelectFields struct {
planWithSrc
}
// Accept implements Plan Accept interface.
func (p *SelectFields) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*SelectFields)
if p.src != nil {
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *SelectFields) SetLimit(limit float64) {
p.limit = limit
if p.src != nil {
p.src.SetLimit(limit)
}
}
// Sort represents a sorting plan.
type Sort struct {
planWithSrc
ByItems []*ast.ByItem
}
// Accept implements Plan Accept interface.
func (p *Sort) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Sort)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
// It set the Src limit only if it is bypassed.
// Bypass has to be determined before this get called.
func (p *Sort) SetLimit(limit float64) {
p.limit = limit
}
// Limit represents offset and limit plan.
type Limit struct {
planWithSrc
Offset uint64
Count uint64
}
// Accept implements Plan Accept interface.
func (p *Limit) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Limit)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
// As Limit itself determine the real limit,
// We just ignore the input, and set the real limit.
func (p *Limit) SetLimit(limit float64) {
p.limit = float64(p.Offset + p.Count)
p.src.SetLimit(p.limit)
}
// Union represents Union plan.
type Union struct {
basePlan
Selects []Plan
}
// Accept implements Plan Accept interface.
func (p *Union) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(p)
}
p = np.(*Union)
for i, sel := range p.Selects {
var ok bool
p.Selects[i], ok = sel.Accept(v)
if !ok {
return p, false
}
}
return v.Leave(p)
}
// Distinct represents Distinct plan.
type Distinct struct {
planWithSrc
}
// Accept implements Plan Accept interface.
func (p *Distinct) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(p)
}
p = np.(*Distinct)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *Distinct) SetLimit(limit float64) {
p.limit = limit
if p.src != nil {
p.src.SetLimit(limit)
}
}
// Prepare represents prepare plan.
type Prepare struct {
basePlan
Name string
SQLText string
}
// Accept implements Plan Accept interface.
func (p *Prepare) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Prepare)
return v.Leave(p)
}
// Execute represents prepare plan.
type Execute struct {
basePlan
Name string
UsingVars []ast.ExprNode
ID uint32
}
// Accept implements Plan Accept interface.
func (p *Execute) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Execute)
return v.Leave(p)
}
// Deallocate represents deallocate plan.
type Deallocate struct {
basePlan
Name string
}
// Accept implements Plan Accept interface.
func (p *Deallocate) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Deallocate)
return v.Leave(p)
}
// Aggregate represents a select fields plan.
type Aggregate struct {
planWithSrc
AggFuncs []*ast.AggregateFuncExpr
GroupByItems []*ast.ByItem
}
// Accept implements Plan Accept interface.
func (p *Aggregate) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Aggregate)
if p.src != nil {
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *Aggregate) SetLimit(limit float64) {
p.limit = limit
if p.src != nil {
p.src.SetLimit(limit)
}
}
// Having represents a having plan.
// The having plan should after aggregate plan.
type Having struct {
planWithSrc
// Originally the WHERE or ON condition is parsed into a single expression,
// but after we converted to CNF(Conjunctive normal form), it can be
// split into a list of AND conditions.
Conditions []ast.ExprNode
}
// Accept implements Plan Accept interface.
func (p *Having) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Having)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *Having) SetLimit(limit float64) {
p.limit = limit
// We assume 50% of the src row is filtered out.
p.src.SetLimit(limit * 2)
}
// Update represents an update plan.
type Update struct {
basePlan
OrderedList []*ast.Assignment // OrderedList has the same offset as TablePlan's result fields.
SelectPlan Plan
}
// Accept implements Plan Accept interface.
func (p *Update) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Update)
var ok bool
p.SelectPlan, ok = p.SelectPlan.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// Delete represents a delete plan.
type Delete struct {
basePlan
SelectPlan Plan
Tables []*ast.TableName
IsMultiTable bool
}
// Accept implements Plan Accept interface.
func (p *Delete) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Delete)
var ok bool
p.SelectPlan, ok = p.SelectPlan.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// Filter represents a plan that filter srcplan result.
type Filter struct {
planWithSrc
// Originally the WHERE or ON condition is parsed into a single expression,
// but after we converted to CNF(Conjunctive normal form), it can be
// split into a list of AND conditions.
Conditions []ast.ExprNode
}
// Accept implements Plan Accept interface.
func (p *Filter) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Filter)
var ok bool
p.src, ok = p.src.Accept(v)
if !ok {
return p, false
}
return v.Leave(p)
}
// SetLimit implements Plan SetLimit interface.
func (p *Filter) SetLimit(limit float64) {
p.limit = limit
// We assume 50% of the src row is filtered out.
p.src.SetLimit(limit * 2)
}
// Show represents a show plan.
type Show struct {
basePlan
Tp ast.ShowStmtType // Databases/Tables/Columns/....
DBName string
Table *ast.TableName // Used for showing columns.
Column *ast.ColumnName // Used for `desc table column`.
Flag int // Some flag parsed from sql, such as FULL.
Full bool
User string // Used for show grants.
// Used by show variables
GlobalScope bool
}
// Accept implements Plan Accept interface.
func (p *Show) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Show)
return v.Leave(p)
}
// Simple represents a simple statement plan which doesn't need any optimization.
type Simple struct {
basePlan
Statement ast.StmtNode
}
// Accept implements Plan Accept interface.
func (p *Simple) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Simple)
return v.Leave(p)
}
// Insert represents an insert plan.
type Insert struct {
basePlan
Table *ast.TableRefsClause
Columns []*ast.ColumnName
Lists [][]ast.ExprNode
Setlist []*ast.Assignment
OnDuplicate []*ast.Assignment
SelectPlan Plan
IsReplace bool
Priority int
}
// Accept implements Plan Accept interface.
func (p *Insert) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*Insert)
if p.SelectPlan != nil {
var ok bool
p.SelectPlan, ok = p.SelectPlan.Accept(v)
if !ok {
return p, false
}
}
return v.Leave(p)
}
// DDL represents a DDL statement plan.
type DDL struct {
basePlan
Statement ast.DDLNode
}
// Accept implements Plan Accept interface.
func (p *DDL) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
return v.Leave(np)
}
p = np.(*DDL)
return v.Leave(p)
}
// Explain represents a explain plan.
type Explain struct {
basePlan
StmtPlan Plan
}
// Accept implements Plan Accept interface.
func (p *Explain) Accept(v Visitor) (Plan, bool) {
np, skip := v.Enter(p)
if skip {
v.Leave(np)
}
p = np.(*Explain)
return v.Leave(p)
}

505
vendor/github.com/pingcap/tidb/optimizer/plan/range.go generated vendored Normal file
View file

@ -0,0 +1,505 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"fmt"
"math"
"sort"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
type rangePoint struct {
value types.Datum
excl bool // exclude
start bool
}
func (rp rangePoint) String() string {
val := rp.value.GetValue()
if rp.value.Kind() == types.KindMinNotNull {
val = "-inf"
} else if rp.value.Kind() == types.KindMaxValue {
val = "+inf"
}
if rp.start {
symbol := "["
if rp.excl {
symbol = "("
}
return fmt.Sprintf("%s%v", symbol, val)
}
symbol := "]"
if rp.excl {
symbol = ")"
}
return fmt.Sprintf("%v%s", val, symbol)
}
type rangePointSorter struct {
points []rangePoint
err error
}
func (r *rangePointSorter) Len() int {
return len(r.points)
}
func (r *rangePointSorter) Less(i, j int) bool {
a := r.points[i]
b := r.points[j]
cmp, err := a.value.CompareDatum(b.value)
if err != nil {
r.err = err
return true
}
if cmp == 0 {
return r.equalValueLess(a, b)
}
return cmp < 0
}
func (r *rangePointSorter) equalValueLess(a, b rangePoint) bool {
if a.start && b.start {
return !a.excl && b.excl
} else if a.start {
return !b.excl
} else if b.start {
return a.excl || b.excl
}
return a.excl && !b.excl
}
func (r *rangePointSorter) Swap(i, j int) {
r.points[i], r.points[j] = r.points[j], r.points[i]
}
type rangeBuilder struct {
err error
}
func (r *rangeBuilder) build(expr ast.ExprNode) []rangePoint {
switch x := expr.(type) {
case *ast.BinaryOperationExpr:
return r.buildFromBinop(x)
case *ast.PatternInExpr:
return r.buildFromIn(x)
case *ast.ParenthesesExpr:
return r.build(x.Expr)
case *ast.BetweenExpr:
return r.buildFromBetween(x)
case *ast.IsNullExpr:
return r.buildFromIsNull(x)
case *ast.IsTruthExpr:
return r.buildFromIsTruth(x)
case *ast.PatternLikeExpr:
rans := r.buildFromPatternLike(x)
return rans
case *ast.ColumnNameExpr:
return r.buildFromColumnName(x)
}
return fullRange
}
func (r *rangeBuilder) buildFromBinop(x *ast.BinaryOperationExpr) []rangePoint {
if x.Op == opcode.OrOr {
return r.union(r.build(x.L), r.build(x.R))
} else if x.Op == opcode.AndAnd {
return r.intersection(r.build(x.L), r.build(x.R))
}
// This has been checked that the binary operation is comparison operation, and one of
// the operand is column name expression.
var value types.Datum
var op opcode.Op
if _, ok := x.L.(*ast.ValueExpr); ok {
value = types.NewDatum(x.L.GetValue())
switch x.Op {
case opcode.GE:
op = opcode.LE
case opcode.GT:
op = opcode.LT
case opcode.LT:
op = opcode.GT
case opcode.LE:
op = opcode.GE
default:
op = x.Op
}
} else {
value = types.NewDatum(x.R.GetValue())
op = x.Op
}
if value.Kind() == types.KindNull {
return nil
}
switch op {
case opcode.EQ:
startPoint := rangePoint{value: value, start: true}
endPoint := rangePoint{value: value}
return []rangePoint{startPoint, endPoint}
case opcode.NE:
startPoint1 := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint1 := rangePoint{value: value, excl: true}
startPoint2 := rangePoint{value: value, start: true, excl: true}
endPoint2 := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2}
case opcode.LT:
startPoint := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint := rangePoint{value: value, excl: true}
return []rangePoint{startPoint, endPoint}
case opcode.LE:
startPoint := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint := rangePoint{value: value}
return []rangePoint{startPoint, endPoint}
case opcode.GT:
startPoint := rangePoint{value: value, start: true, excl: true}
endPoint := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint, endPoint}
case opcode.GE:
startPoint := rangePoint{value: value, start: true}
endPoint := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint, endPoint}
}
return nil
}
func (r *rangeBuilder) buildFromIn(x *ast.PatternInExpr) []rangePoint {
if x.Not {
r.err = ErrUnsupportedType.Gen("NOT IN is not supported")
return fullRange
}
var rangePoints []rangePoint
for _, v := range x.List {
startPoint := rangePoint{value: types.NewDatum(v.GetValue()), start: true}
endPoint := rangePoint{value: types.NewDatum(v.GetValue())}
rangePoints = append(rangePoints, startPoint, endPoint)
}
sorter := rangePointSorter{points: rangePoints}
sort.Sort(&sorter)
if sorter.err != nil {
r.err = sorter.err
}
// check duplicates
hasDuplicate := false
isStart := false
for _, v := range rangePoints {
if isStart == v.start {
hasDuplicate = true
break
}
isStart = v.start
}
if !hasDuplicate {
return rangePoints
}
// remove duplicates
distinctRangePoints := make([]rangePoint, 0, len(rangePoints))
isStart = false
for i := 0; i < len(rangePoints); i++ {
current := rangePoints[i]
if isStart == current.start {
continue
}
distinctRangePoints = append(distinctRangePoints, current)
isStart = current.start
}
return distinctRangePoints
}
func (r *rangeBuilder) buildFromBetween(x *ast.BetweenExpr) []rangePoint {
if x.Not {
binop1 := &ast.BinaryOperationExpr{Op: opcode.LT, L: x.Expr, R: x.Left}
binop2 := &ast.BinaryOperationExpr{Op: opcode.GT, L: x.Expr, R: x.Right}
range1 := r.buildFromBinop(binop1)
range2 := r.buildFromBinop(binop2)
return r.union(range1, range2)
}
binop1 := &ast.BinaryOperationExpr{Op: opcode.GE, L: x.Expr, R: x.Left}
binop2 := &ast.BinaryOperationExpr{Op: opcode.LE, L: x.Expr, R: x.Right}
range1 := r.buildFromBinop(binop1)
range2 := r.buildFromBinop(binop2)
return r.intersection(range1, range2)
}
func (r *rangeBuilder) buildFromIsNull(x *ast.IsNullExpr) []rangePoint {
if x.Not {
startPoint := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint, endPoint}
}
startPoint := rangePoint{start: true}
endPoint := rangePoint{}
return []rangePoint{startPoint, endPoint}
}
func (r *rangeBuilder) buildFromIsTruth(x *ast.IsTruthExpr) []rangePoint {
if x.True != 0 {
if x.Not {
// NOT TRUE range is {[null null] [0, 0]}
startPoint1 := rangePoint{start: true}
endPoint1 := rangePoint{}
startPoint2 := rangePoint{start: true}
startPoint2.value.SetInt64(0)
endPoint2 := rangePoint{}
endPoint2.value.SetInt64(0)
return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2}
}
// TRUE range is {[-inf 0) (0 +inf]}
startPoint1 := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint1 := rangePoint{excl: true}
endPoint1.value.SetInt64(0)
startPoint2 := rangePoint{excl: true, start: true}
startPoint2.value.SetInt64(0)
endPoint2 := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2}
}
if x.Not {
startPoint1 := rangePoint{start: true}
endPoint1 := rangePoint{excl: true}
endPoint1.value.SetInt64(0)
startPoint2 := rangePoint{start: true, excl: true}
startPoint2.value.SetInt64(0)
endPoint2 := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2}
}
startPoint := rangePoint{start: true}
startPoint.value.SetInt64(0)
endPoint := rangePoint{}
endPoint.value.SetInt64(0)
return []rangePoint{startPoint, endPoint}
}
func (r *rangeBuilder) buildFromPatternLike(x *ast.PatternLikeExpr) []rangePoint {
if x.Not {
// Pattern not like is not supported.
r.err = ErrUnsupportedType.Gen("NOT LIKE is not supported.")
return fullRange
}
pattern, err := types.ToString(x.Pattern.GetValue())
if err != nil {
r.err = errors.Trace(err)
return fullRange
}
lowValue := make([]byte, 0, len(pattern))
// unscape the pattern
var exclude bool
for i := 0; i < len(pattern); i++ {
if pattern[i] == x.Escape {
i++
if i < len(pattern) {
lowValue = append(lowValue, pattern[i])
} else {
lowValue = append(lowValue, x.Escape)
}
continue
}
if pattern[i] == '%' {
break
} else if pattern[i] == '_' {
exclude = true
break
}
lowValue = append(lowValue, pattern[i])
}
if len(lowValue) == 0 {
return []rangePoint{{value: types.MinNotNullDatum(), start: true}, {value: types.MaxValueDatum()}}
}
startPoint := rangePoint{start: true, excl: exclude}
startPoint.value.SetBytesAsString(lowValue)
highValue := make([]byte, len(lowValue))
copy(highValue, lowValue)
endPoint := rangePoint{excl: true}
for i := len(highValue) - 1; i >= 0; i-- {
highValue[i]++
if highValue[i] != 0 {
endPoint.value.SetBytesAsString(highValue)
break
}
if i == 0 {
endPoint.value = types.MaxValueDatum()
break
}
}
ranges := make([]rangePoint, 2)
ranges[0] = startPoint
ranges[1] = endPoint
return ranges
}
func (r *rangeBuilder) buildFromColumnName(x *ast.ColumnNameExpr) []rangePoint {
// column name expression is equivalent to column name is true.
startPoint1 := rangePoint{value: types.MinNotNullDatum(), start: true}
endPoint1 := rangePoint{excl: true}
endPoint1.value.SetInt64(0)
startPoint2 := rangePoint{excl: true, start: true}
startPoint2.value.SetInt64(0)
endPoint2 := rangePoint{value: types.MaxValueDatum()}
return []rangePoint{startPoint1, endPoint1, startPoint2, endPoint2}
}
func (r *rangeBuilder) intersection(a, b []rangePoint) []rangePoint {
return r.merge(a, b, false)
}
func (r *rangeBuilder) union(a, b []rangePoint) []rangePoint {
return r.merge(a, b, true)
}
func (r *rangeBuilder) merge(a, b []rangePoint, union bool) []rangePoint {
sorter := rangePointSorter{points: append(a, b...)}
sort.Sort(&sorter)
if sorter.err != nil {
r.err = sorter.err
return nil
}
var (
merged []rangePoint
inRangeCount int
requiredInRangeCount int
)
if union {
requiredInRangeCount = 1
} else {
requiredInRangeCount = 2
}
for _, val := range sorter.points {
if val.start {
inRangeCount++
if inRangeCount == requiredInRangeCount {
// just reached the required in range count, a new range started.
merged = append(merged, val)
}
} else {
if inRangeCount == requiredInRangeCount {
// just about to leave the required in range count, the range is ended.
merged = append(merged, val)
}
inRangeCount--
}
}
return merged
}
// buildIndexRanges build index ranges from range points.
// Only the first column in the index is built, extra column ranges will be appended by
// appendIndexRanges.
func (r *rangeBuilder) buildIndexRanges(rangePoints []rangePoint) []*IndexRange {
indexRanges := make([]*IndexRange, 0, len(rangePoints)/2)
for i := 0; i < len(rangePoints); i += 2 {
startPoint := rangePoints[i]
endPoint := rangePoints[i+1]
ir := &IndexRange{
LowVal: []types.Datum{startPoint.value},
LowExclude: startPoint.excl,
HighVal: []types.Datum{endPoint.value},
HighExclude: endPoint.excl,
}
indexRanges = append(indexRanges, ir)
}
return indexRanges
}
// appendIndexRanges appends additional column ranges for multi-column index.
// The additional column ranges can only be appended to point ranges.
// for example we have an index (a, b), if the condition is (a > 1 and b = 2)
// then we can not build a conjunctive ranges for this index.
func (r *rangeBuilder) appendIndexRanges(origin []*IndexRange, rangePoints []rangePoint) []*IndexRange {
var newIndexRanges []*IndexRange
for i := 0; i < len(origin); i++ {
oRange := origin[i]
if !oRange.IsPoint() {
newIndexRanges = append(newIndexRanges, oRange)
} else {
newIndexRanges = append(newIndexRanges, r.appendIndexRange(oRange, rangePoints)...)
}
}
return newIndexRanges
}
func (r *rangeBuilder) appendIndexRange(origin *IndexRange, rangePoints []rangePoint) []*IndexRange {
newRanges := make([]*IndexRange, 0, len(rangePoints)/2)
for i := 0; i < len(rangePoints); i += 2 {
startPoint := rangePoints[i]
lowVal := make([]types.Datum, len(origin.LowVal)+1)
copy(lowVal, origin.LowVal)
lowVal[len(origin.LowVal)] = startPoint.value
endPoint := rangePoints[i+1]
highVal := make([]types.Datum, len(origin.HighVal)+1)
copy(highVal, origin.HighVal)
highVal[len(origin.HighVal)] = endPoint.value
ir := &IndexRange{
LowVal: lowVal,
LowExclude: startPoint.excl,
HighVal: highVal,
HighExclude: endPoint.excl,
}
newRanges = append(newRanges, ir)
}
return newRanges
}
func (r *rangeBuilder) buildTableRanges(rangePoints []rangePoint) []TableRange {
tableRanges := make([]TableRange, 0, len(rangePoints)/2)
for i := 0; i < len(rangePoints); i += 2 {
startPoint := rangePoints[i]
if startPoint.value.Kind() == types.KindNull || startPoint.value.Kind() == types.KindMinNotNull {
startPoint.value.SetInt64(math.MinInt64)
}
startInt, err := types.ToInt64(startPoint.value.GetValue())
if err != nil {
r.err = errors.Trace(err)
return tableRanges
}
startDatum := types.NewDatum(startInt)
cmp, err := startDatum.CompareDatum(startPoint.value)
if err != nil {
r.err = errors.Trace(err)
return tableRanges
}
if cmp < 0 || (cmp == 0 && startPoint.excl) {
startInt++
}
endPoint := rangePoints[i+1]
if endPoint.value.Kind() == types.KindNull {
endPoint.value.SetInt64(math.MinInt64)
} else if endPoint.value.Kind() == types.KindMaxValue {
endPoint.value.SetInt64(math.MaxInt64)
}
endInt, err := types.ToInt64(endPoint.value.GetValue())
if err != nil {
r.err = errors.Trace(err)
return tableRanges
}
endDatum := types.NewDatum(endInt)
cmp, err = endDatum.CompareDatum(endPoint.value)
if err != nil {
r.err = errors.Trace(err)
return tableRanges
}
if cmp > 0 || (cmp == 0 && endPoint.excl) {
endInt--
}
if startInt > endInt {
continue
}
tableRanges = append(tableRanges, TableRange{LowVal: startInt, HighVal: endInt})
}
return tableRanges
}

View file

@ -0,0 +1,193 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"math"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// Refine tries to build index or table range.
func Refine(p Plan) error {
r := refiner{}
p.Accept(&r)
return r.err
}
type refiner struct {
err error
}
func (r *refiner) Enter(in Plan) (Plan, bool) {
return in, false
}
func (r *refiner) Leave(in Plan) (Plan, bool) {
switch x := in.(type) {
case *IndexScan:
r.buildIndexRange(x)
case *Limit:
x.SetLimit(0)
case *TableScan:
r.buildTableRange(x)
}
return in, r.err == nil
}
var fullRange = []rangePoint{
{start: true},
{value: types.MaxValueDatum()},
}
func (r *refiner) buildIndexRange(p *IndexScan) {
rb := rangeBuilder{}
if p.AccessEqualCount > 0 {
// Build ranges for equal access conditions.
point := rb.build(p.AccessConditions[0])
p.Ranges = rb.buildIndexRanges(point)
for i := 1; i < p.AccessEqualCount; i++ {
point = rb.build(p.AccessConditions[i])
p.Ranges = rb.appendIndexRanges(p.Ranges, point)
}
}
rangePoints := fullRange
// Build rangePoints for non-equal access condtions.
for i := p.AccessEqualCount; i < len(p.AccessConditions); i++ {
rangePoints = rb.intersection(rangePoints, rb.build(p.AccessConditions[i]))
}
if p.AccessEqualCount == 0 {
p.Ranges = rb.buildIndexRanges(rangePoints)
} else if p.AccessEqualCount < len(p.AccessConditions) {
p.Ranges = rb.appendIndexRanges(p.Ranges, rangePoints)
}
r.err = rb.err
return
}
func (r *refiner) buildTableRange(p *TableScan) {
if len(p.AccessConditions) == 0 {
p.Ranges = []TableRange{{math.MinInt64, math.MaxInt64}}
return
}
rb := rangeBuilder{}
rangePoints := fullRange
for _, cond := range p.AccessConditions {
rangePoints = rb.intersection(rangePoints, rb.build(cond))
}
p.Ranges = rb.buildTableRanges(rangePoints)
r.err = rb.err
}
// conditionChecker checks if this condition can be pushed to index plan.
type conditionChecker struct {
tableName model.CIStr
idx *model.IndexInfo
// the offset of the indexed column to be checked.
columnOffset int
pkName model.CIStr
}
func (c *conditionChecker) check(condition ast.ExprNode) bool {
switch x := condition.(type) {
case *ast.BinaryOperationExpr:
return c.checkBinaryOperation(x)
case *ast.BetweenExpr:
if ast.IsPreEvaluable(x.Left) && ast.IsPreEvaluable(x.Right) && c.checkColumnExpr(x.Expr) {
return true
}
case *ast.ColumnNameExpr:
return c.checkColumnExpr(x)
case *ast.IsNullExpr:
if c.checkColumnExpr(x.Expr) {
return true
}
case *ast.IsTruthExpr:
if c.checkColumnExpr(x.Expr) {
return true
}
case *ast.ParenthesesExpr:
return c.check(x.Expr)
case *ast.PatternInExpr:
if x.Sel != nil || x.Not {
return false
}
if !c.checkColumnExpr(x.Expr) {
return false
}
for _, val := range x.List {
if !ast.IsPreEvaluable(val) {
return false
}
}
return true
case *ast.PatternLikeExpr:
if x.Not {
return false
}
if !c.checkColumnExpr(x.Expr) {
return false
}
if !ast.IsPreEvaluable(x.Pattern) {
return false
}
patternVal := x.Pattern.GetValue()
if patternVal == nil {
return false
}
patternStr, err := types.ToString(patternVal)
if err != nil {
return false
}
firstChar := patternStr[0]
return firstChar != '%' && firstChar != '.'
}
return false
}
func (c *conditionChecker) checkBinaryOperation(b *ast.BinaryOperationExpr) bool {
switch b.Op {
case opcode.OrOr:
return c.check(b.L) && c.check(b.R)
case opcode.AndAnd:
return c.check(b.L) && c.check(b.R)
case opcode.EQ, opcode.NE, opcode.GE, opcode.GT, opcode.LE, opcode.LT:
if ast.IsPreEvaluable(b.L) {
return c.checkColumnExpr(b.R)
} else if ast.IsPreEvaluable(b.R) {
return c.checkColumnExpr(b.L)
}
}
return false
}
func (c *conditionChecker) checkColumnExpr(expr ast.ExprNode) bool {
cn, ok := expr.(*ast.ColumnNameExpr)
if !ok {
return false
}
if cn.Refer.Table.Name.L != c.tableName.L {
return false
}
if c.pkName.L != "" {
return c.pkName.L == cn.Refer.Column.Name.L
}
if c.idx != nil {
return cn.Refer.Column.Name.L == c.idx.Columns[c.columnOffset].Name.L
}
return true
}

View file

@ -0,0 +1,89 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package plan
import (
"fmt"
"math"
"strings"
)
// ToString explains a Plan, returns description string.
func ToString(p Plan) string {
var e stringer
p.Accept(&e)
return strings.Join(e.strs, "->")
}
type stringer struct {
strs []string
idxs []int
}
func (e *stringer) Enter(in Plan) (Plan, bool) {
switch in.(type) {
case *JoinOuter, *JoinInner:
e.idxs = append(e.idxs, len(e.strs))
}
return in, false
}
func (e *stringer) Leave(in Plan) (Plan, bool) {
var str string
switch x := in.(type) {
case *CheckTable:
str = "CheckTable"
case *IndexScan:
str = fmt.Sprintf("Index(%s.%s)", x.Table.Name.L, x.Index.Name.L)
case *Limit:
str = "Limit"
case *SelectFields:
str = "Fields"
case *SelectLock:
str = "Lock"
case *ShowDDL:
str = "ShowDDL"
case *Sort:
str = "Sort"
case *TableScan:
if len(x.Ranges) > 0 {
ran := x.Ranges[0]
if ran.LowVal != math.MinInt64 || ran.HighVal != math.MaxInt64 {
str = fmt.Sprintf("Range(%s)", x.Table.Name.L)
} else {
str = fmt.Sprintf("Table(%s)", x.Table.Name.L)
}
} else {
str = fmt.Sprintf("Table(%s)", x.Table.Name.L)
}
case *JoinOuter:
last := len(e.idxs) - 1
idx := e.idxs[last]
chilrden := e.strs[idx:]
e.strs = e.strs[:idx]
str = "OuterJoin{" + strings.Join(chilrden, "->") + "}"
e.idxs = e.idxs[:last]
case *JoinInner:
last := len(e.idxs) - 1
idx := e.idxs[last]
chilrden := e.strs[idx:]
e.strs = e.strs[:idx]
str = "InnerJoin{" + strings.Join(chilrden, "->") + "}"
e.idxs = e.idxs[:last]
default:
str = fmt.Sprintf("%T", in)
}
e.strs = append(e.strs, str)
return in, true
}

29
vendor/github.com/pingcap/tidb/optimizer/preprocess.go generated vendored Normal file
View file

@ -0,0 +1,29 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/infoschema"
)
// Preprocess does preprocess work for optimizer.
func Preprocess(node ast.Node, info infoschema.InfoSchema, ctx context.Context) error {
if err := ResolveName(node, info, ctx); err != nil {
return errors.Trace(err)
}
return nil
}

924
vendor/github.com/pingcap/tidb/optimizer/resolver.go generated vendored Normal file
View file

@ -0,0 +1,924 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"fmt"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/db"
"github.com/pingcap/tidb/util/types"
)
// ResolveName resolves table name and column name.
// It generates ResultFields for ResultSetNode and resolves ColumnNameExpr to a ResultField.
func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) error {
defaultSchema := db.GetCurrentSchema(ctx)
resolver := nameResolver{Info: info, Ctx: ctx, DefaultSchema: model.NewCIStr(defaultSchema)}
node.Accept(&resolver)
return errors.Trace(resolver.Err)
}
// nameResolver is the visitor to resolve table name and column name.
// In general, a reference can only refer to information that are available for it.
// So children elements are visited in the order that previous elements make information
// available for following elements.
//
// During visiting, information are collected and stored in resolverContext.
// When we enter a subquery, a new resolverContext is pushed to the contextStack, so subquery
// information can overwrite outer query information. When we look up for a column reference,
// we look up from top to bottom in the contextStack.
type nameResolver struct {
Info infoschema.InfoSchema
Ctx context.Context
DefaultSchema model.CIStr
Err error
useOuterContext bool
contextStack []*resolverContext
}
// resolverContext stores information in a single level of select statement
// that table name and column name can be resolved.
type resolverContext struct {
/* For Select Statement. */
// table map to lookup and check table name conflict.
tableMap map[string]int
// table map to lookup and check derived-table(subselect) name conflict.
derivedTableMap map[string]int
// tableSources collected in from clause.
tables []*ast.TableSource
// result fields collected in select field list.
fieldList []*ast.ResultField
// result fields collected in group by clause.
groupBy []*ast.ResultField
// The join node stack is used by on condition to find out
// available tables to reference. On condition can only
// refer to tables involved in current join.
joinNodeStack []*ast.Join
// When visiting TableRefs, tables in this context are not available
// because it is being collected.
inTableRefs bool
// When visiting on conditon only tables in current join node are available.
inOnCondition bool
// When visiting field list, fieldList in this context are not available.
inFieldList bool
// When visiting group by, groupBy fields are not available.
inGroupBy bool
// When visiting having, only fieldList and groupBy fields are available.
inHaving bool
// When visiting having, checks if the expr is an aggregate function expr.
inHavingAgg bool
// OrderBy clause has different resolving rule than group by.
inOrderBy bool
// When visiting column name in ByItem, we should know if the column name is in an expression.
inByItemExpression bool
// If subquery use outer context.
useOuterContext bool
// When visiting multi-table delete stmt table list.
inDeleteTableList bool
// When visiting create/drop table statement.
inCreateOrDropTable bool
// When visiting show statement.
inShow bool
}
// currentContext gets the current resolverContext.
func (nr *nameResolver) currentContext() *resolverContext {
stackLen := len(nr.contextStack)
if stackLen == 0 {
return nil
}
return nr.contextStack[stackLen-1]
}
// pushContext is called when we enter a statement.
func (nr *nameResolver) pushContext() {
nr.contextStack = append(nr.contextStack, &resolverContext{
tableMap: map[string]int{},
derivedTableMap: map[string]int{},
})
}
// popContext is called when we leave a statement.
func (nr *nameResolver) popContext() {
nr.contextStack = nr.contextStack[:len(nr.contextStack)-1]
}
// pushJoin is called when we enter a join node.
func (nr *nameResolver) pushJoin(j *ast.Join) {
ctx := nr.currentContext()
ctx.joinNodeStack = append(ctx.joinNodeStack, j)
}
// popJoin is called when we leave a join node.
func (nr *nameResolver) popJoin() {
ctx := nr.currentContext()
ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1]
}
// Enter implements ast.Visitor interface.
func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
switch v := inNode.(type) {
case *ast.AdminStmt:
nr.pushContext()
case *ast.AggregateFuncExpr:
ctx := nr.currentContext()
if ctx.inHaving {
ctx.inHavingAgg = true
}
case *ast.AlterTableStmt:
nr.pushContext()
case *ast.ByItem:
if _, ok := v.Expr.(*ast.ColumnNameExpr); !ok {
// If ByItem is not a single column name expression,
// the resolving rule is different from order by clause.
nr.currentContext().inByItemExpression = true
}
if nr.currentContext().inGroupBy {
// make sure item is not aggregate function
if ast.HasAggFlag(v.Expr) {
nr.Err = ErrInvalidGroupFuncUse
return inNode, true
}
}
case *ast.CreateIndexStmt:
nr.pushContext()
case *ast.CreateTableStmt:
nr.pushContext()
nr.currentContext().inCreateOrDropTable = true
case *ast.DeleteStmt:
nr.pushContext()
case *ast.DeleteTableList:
nr.currentContext().inDeleteTableList = true
case *ast.DoStmt:
nr.pushContext()
case *ast.DropTableStmt:
nr.pushContext()
nr.currentContext().inCreateOrDropTable = true
case *ast.DropIndexStmt:
nr.pushContext()
case *ast.FieldList:
nr.currentContext().inFieldList = true
case *ast.GroupByClause:
nr.currentContext().inGroupBy = true
case *ast.HavingClause:
nr.currentContext().inHaving = true
case *ast.InsertStmt:
nr.pushContext()
case *ast.Join:
nr.pushJoin(v)
case *ast.OnCondition:
nr.currentContext().inOnCondition = true
case *ast.OrderByClause:
nr.currentContext().inOrderBy = true
case *ast.SelectStmt:
nr.pushContext()
case *ast.SetStmt:
for _, assign := range v.Variables {
if cn, ok := assign.Value.(*ast.ColumnNameExpr); ok && cn.Name.Table.L == "" {
// Convert column name expression to string value expression.
assign.Value = ast.NewValueExpr(cn.Name.Name.O)
}
}
nr.pushContext()
case *ast.ShowStmt:
nr.pushContext()
nr.currentContext().inShow = true
nr.fillShowFields(v)
case *ast.TableRefsClause:
nr.currentContext().inTableRefs = true
case *ast.TruncateTableStmt:
nr.pushContext()
case *ast.UnionStmt:
nr.pushContext()
case *ast.UpdateStmt:
nr.pushContext()
}
return inNode, false
}
// Leave implements ast.Visitor interface.
func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) {
switch v := inNode.(type) {
case *ast.AdminStmt:
nr.popContext()
case *ast.AggregateFuncExpr:
ctx := nr.currentContext()
if ctx.inHaving {
ctx.inHavingAgg = false
}
case *ast.AlterTableStmt:
nr.popContext()
case *ast.TableName:
nr.handleTableName(v)
case *ast.ColumnNameExpr:
nr.handleColumnName(v)
case *ast.CreateIndexStmt:
nr.popContext()
case *ast.CreateTableStmt:
nr.popContext()
case *ast.DeleteTableList:
nr.currentContext().inDeleteTableList = false
case *ast.DoStmt:
nr.popContext()
case *ast.DropIndexStmt:
nr.popContext()
case *ast.DropTableStmt:
nr.popContext()
case *ast.TableSource:
nr.handleTableSource(v)
case *ast.OnCondition:
nr.currentContext().inOnCondition = false
case *ast.Join:
nr.handleJoin(v)
nr.popJoin()
case *ast.TableRefsClause:
nr.currentContext().inTableRefs = false
case *ast.FieldList:
nr.handleFieldList(v)
nr.currentContext().inFieldList = false
case *ast.GroupByClause:
ctx := nr.currentContext()
ctx.inGroupBy = false
for _, item := range v.Items {
switch x := item.Expr.(type) {
case *ast.ColumnNameExpr:
ctx.groupBy = append(ctx.groupBy, x.Refer)
}
}
case *ast.HavingClause:
nr.currentContext().inHaving = false
case *ast.OrderByClause:
nr.currentContext().inOrderBy = false
case *ast.ByItem:
nr.currentContext().inByItemExpression = false
case *ast.PositionExpr:
nr.handlePosition(v)
case *ast.SelectStmt:
ctx := nr.currentContext()
v.SetResultFields(ctx.fieldList)
if ctx.useOuterContext {
nr.useOuterContext = true
}
nr.popContext()
case *ast.SetStmt:
nr.popContext()
case *ast.ShowStmt:
nr.popContext()
case *ast.SubqueryExpr:
if nr.useOuterContext {
// TODO: check this
// If there is a deep nest of subquery, there may be something wrong.
v.UseOuterContext = true
nr.useOuterContext = false
}
case *ast.TruncateTableStmt:
nr.popContext()
case *ast.UnionStmt:
ctx := nr.currentContext()
v.SetResultFields(ctx.fieldList)
if ctx.useOuterContext {
nr.useOuterContext = true
}
nr.popContext()
case *ast.UnionSelectList:
nr.handleUnionSelectList(v)
case *ast.InsertStmt:
nr.popContext()
case *ast.DeleteStmt:
nr.popContext()
case *ast.UpdateStmt:
nr.popContext()
}
return inNode, nr.Err == nil
}
// handleTableName looks up and sets the schema information and result fields for table name.
func (nr *nameResolver) handleTableName(tn *ast.TableName) {
if tn.Schema.L == "" {
tn.Schema = nr.DefaultSchema
}
ctx := nr.currentContext()
if ctx.inCreateOrDropTable {
// The table may not exist in create table or drop table statement.
// Skip resolving the table to avoid error.
return
}
if ctx.inDeleteTableList {
idx, ok := ctx.tableMap[nr.tableUniqueName(tn.Schema, tn.Name)]
if !ok {
nr.Err = errors.Errorf("Unknown table %s", tn.Name.O)
return
}
ts := ctx.tables[idx]
tableName := ts.Source.(*ast.TableName)
tn.DBInfo = tableName.DBInfo
tn.TableInfo = tableName.TableInfo
tn.SetResultFields(tableName.GetResultFields())
return
}
table, err := nr.Info.TableByName(tn.Schema, tn.Name)
if err != nil {
nr.Err = errors.Trace(err)
return
}
tn.TableInfo = table.Meta()
dbInfo, _ := nr.Info.SchemaByName(tn.Schema)
tn.DBInfo = dbInfo
rfs := make([]*ast.ResultField, 0, len(tn.TableInfo.Columns))
for _, v := range tn.TableInfo.Columns {
if v.State != model.StatePublic {
continue
}
expr := &ast.ValueExpr{}
expr.SetType(&v.FieldType)
rf := &ast.ResultField{
Column: v,
Table: tn.TableInfo,
DBName: tn.Schema,
Expr: expr,
TableName: tn,
}
rfs = append(rfs, rf)
}
tn.SetResultFields(rfs)
return
}
// handleTableSources checks name duplication
// and puts the table source in current resolverContext.
// Note:
// "select * from t as a join (select 1) as a;" is not duplicate.
// "select * from t as a join t as a;" is duplicate.
// "select * from (select 1) as a join (select 1) as a;" is duplicate.
func (nr *nameResolver) handleTableSource(ts *ast.TableSource) {
for _, v := range ts.GetResultFields() {
v.TableAsName = ts.AsName
}
ctx := nr.currentContext()
switch ts.Source.(type) {
case *ast.TableName:
var name string
if ts.AsName.L != "" {
name = ts.AsName.L
} else {
tableName := ts.Source.(*ast.TableName)
name = nr.tableUniqueName(tableName.Schema, tableName.Name)
}
if _, ok := ctx.tableMap[name]; ok {
nr.Err = errors.Errorf("duplicated table/alias name %s", name)
return
}
ctx.tableMap[name] = len(ctx.tables)
case *ast.SelectStmt:
name := ts.AsName.L
if _, ok := ctx.derivedTableMap[name]; ok {
nr.Err = errors.Errorf("duplicated table/alias name %s", name)
return
}
ctx.derivedTableMap[name] = len(ctx.tables)
}
dupNames := make(map[string]struct{}, len(ts.GetResultFields()))
for _, f := range ts.GetResultFields() {
// duplicate column name in one table is not allowed.
// "select * from (select 1, 1) as a;" is duplicate.
name := f.ColumnAsName.L
if name == "" {
name = f.Column.Name.L
}
if _, ok := dupNames[name]; ok {
nr.Err = errors.Errorf("Duplicate column name '%s'", name)
return
}
dupNames[name] = struct{}{}
}
ctx.tables = append(ctx.tables, ts)
return
}
// handleJoin sets result fields for join.
func (nr *nameResolver) handleJoin(j *ast.Join) {
if j.Right == nil {
j.SetResultFields(j.Left.GetResultFields())
return
}
leftLen := len(j.Left.GetResultFields())
rightLen := len(j.Right.GetResultFields())
rfs := make([]*ast.ResultField, leftLen+rightLen)
copy(rfs, j.Left.GetResultFields())
copy(rfs[leftLen:], j.Right.GetResultFields())
j.SetResultFields(rfs)
}
// handleColumnName looks up and sets ResultField for
// the column name.
func (nr *nameResolver) handleColumnName(cn *ast.ColumnNameExpr) {
ctx := nr.currentContext()
if ctx.inOnCondition {
// In on condition, only tables within current join is available.
nr.resolveColumnNameInOnCondition(cn)
return
}
// Try to resolve the column name form top to bottom in the context stack.
for i := len(nr.contextStack) - 1; i >= 0; i-- {
if nr.resolveColumnNameInContext(nr.contextStack[i], cn) {
// Column is already resolved or encountered an error.
if i < len(nr.contextStack)-1 {
// If in subselect, the query use outer query.
nr.currentContext().useOuterContext = true
}
return
}
}
nr.Err = errors.Errorf("unknown column %s", cn.Name.Name.L)
}
// resolveColumnNameInContext looks up and sets ResultField for a column with the ctx.
func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) bool {
if ctx.inTableRefs {
// In TableRefsClause, column reference only in join on condition which is handled before.
return false
}
if ctx.inFieldList {
// only resolve column using tables.
return nr.resolveColumnInTableSources(cn, ctx.tables)
}
if ctx.inGroupBy {
// From tables first, then field list.
// If ctx.InByItemExpression is true, the item is not an identifier.
// Otherwise it is an identifier.
if ctx.inByItemExpression {
// From table first, then field list.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
}
found := nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
if nr.Err == nil && found {
// Check if resolved refer is an aggregate function expr.
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
}
}
return found
}
// Resolve from table first, then from select list.
found := nr.resolveColumnInTableSources(cn, ctx.tables)
if nr.Err != nil {
return found
}
// We should copy the refer here.
// Because if the ByItem is an identifier, we should check if it
// is ambiguous even it is already resolved from table source.
// If the ByItem is not an identifier, we do not need the second check.
r := cn.Refer
if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
if nr.Err != nil {
return true
}
if r != nil {
// It is not ambiguous and already resolved from table source.
// We should restore its Refer.
cn.Refer = r
}
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
}
return true
}
return found
}
if ctx.inHaving {
// First group by, then field list.
if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
return true
}
if ctx.inHavingAgg {
// If cn is in an aggregate function in having clause, check tablesource first.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
}
}
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
if ctx.inOrderBy {
if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
return true
}
if ctx.inByItemExpression {
// From table first, then field list.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
}
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
// Field list first, then from table.
if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
return true
}
return nr.resolveColumnInTableSources(cn, ctx.tables)
}
if ctx.inShow {
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
// In where clause.
return nr.resolveColumnInTableSources(cn, ctx.tables)
}
// resolveColumnNameInOnCondition resolves the column name in current join.
func (nr *nameResolver) resolveColumnNameInOnCondition(cn *ast.ColumnNameExpr) {
ctx := nr.currentContext()
join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1]
tableSources := appendTableSources(nil, join)
if !nr.resolveColumnInTableSources(cn, tableSources) {
nr.Err = errors.Errorf("unkown column name %s", cn.Name.Name.O)
}
}
func (nr *nameResolver) resolveColumnInTableSources(cn *ast.ColumnNameExpr, tableSources []*ast.TableSource) (done bool) {
var matchedResultField *ast.ResultField
tableNameL := cn.Name.Table.L
columnNameL := cn.Name.Name.L
if tableNameL != "" {
var matchedTable ast.ResultSetNode
for _, ts := range tableSources {
if tableNameL == ts.AsName.L {
// different table name.
matchedTable = ts
break
} else if ts.AsName.L != "" {
// Table as name shadows table real name.
continue
}
if tn, ok := ts.Source.(*ast.TableName); ok {
if cn.Name.Schema.L != "" && cn.Name.Schema.L != tn.Schema.L {
continue
}
if tableNameL == tn.Name.L {
matchedTable = ts
}
}
}
if matchedTable != nil {
resultFields := matchedTable.GetResultFields()
for _, rf := range resultFields {
if rf.ColumnAsName.L == columnNameL || rf.Column.Name.L == columnNameL {
// resolve column.
matchedResultField = rf
break
}
}
}
} else {
for _, ts := range tableSources {
rfs := ts.GetResultFields()
for _, rf := range rfs {
matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == columnNameL
matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == columnNameL
if matchAsName || matchColumnName {
if matchedResultField != nil {
nr.Err = errors.Errorf("column %s is ambiguous.", cn.Name.Name.O)
return true
}
matchedResultField = rf
}
}
}
}
if matchedResultField != nil {
// Bind column.
cn.Refer = matchedResultField
return true
}
return false
}
func (nr *nameResolver) resolveColumnInResultFields(ctx *resolverContext, cn *ast.ColumnNameExpr, rfs []*ast.ResultField) bool {
var matched *ast.ResultField
for _, rf := range rfs {
if cn.Name.Table.L != "" {
// Check table name
if rf.TableAsName.L != "" {
if cn.Name.Table.L != rf.TableAsName.L {
continue
}
} else if cn.Name.Table.L != rf.Table.Name.L {
continue
}
}
matchAsName := cn.Name.Name.L == rf.ColumnAsName.L
var matchColumnName bool
if ctx.inHaving {
matchColumnName = cn.Name.Name.L == rf.Column.Name.L
} else {
matchColumnName = rf.ColumnAsName.L == "" && cn.Name.Name.L == rf.Column.Name.L
}
if matchAsName || matchColumnName {
if rf.Column.Name.L == "" {
// This is not a real table column, resolve it directly.
cn.Refer = rf
return true
}
if matched == nil {
matched = rf
} else {
sameColumn := matched.TableName == rf.TableName && matched.Column.Name.L == rf.Column.Name.L
if !sameColumn {
nr.Err = errors.Errorf("column %s is ambiguous.", cn.Name.Name.O)
return true
}
}
}
}
if matched != nil {
// If in GroupBy, we clone the ResultField
if ctx.inGroupBy || ctx.inHaving || ctx.inOrderBy {
nf := *matched
expr := matched.Expr
if cexpr, ok := expr.(*ast.ColumnNameExpr); ok {
expr = cexpr.Refer.Expr
}
nf.Expr = expr
matched = &nf
}
// Bind column.
cn.Refer = matched
return true
}
return false
}
// handleFieldList expands wild card field and sets fieldList in current context.
func (nr *nameResolver) handleFieldList(fieldList *ast.FieldList) {
var resultFields []*ast.ResultField
for _, v := range fieldList.Fields {
resultFields = append(resultFields, nr.createResultFields(v)...)
}
nr.currentContext().fieldList = resultFields
}
func getInnerFromParentheses(expr ast.ExprNode) ast.ExprNode {
if pexpr, ok := expr.(*ast.ParenthesesExpr); ok {
return getInnerFromParentheses(pexpr.Expr)
}
return expr
}
// createResultFields creates result field list for a single select field.
func (nr *nameResolver) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) {
ctx := nr.currentContext()
if field.WildCard != nil {
if len(ctx.tables) == 0 {
nr.Err = errors.New("No table used.")
return
}
tableRfs := []*ast.ResultField{}
if field.WildCard.Table.L == "" {
for _, v := range ctx.tables {
tableRfs = append(tableRfs, v.GetResultFields()...)
}
} else {
name := nr.tableUniqueName(field.WildCard.Schema, field.WildCard.Table)
tableIdx, ok1 := ctx.tableMap[name]
derivedTableIdx, ok2 := ctx.derivedTableMap[name]
if !ok1 && !ok2 {
nr.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O)
}
if ok1 {
tableRfs = ctx.tables[tableIdx].GetResultFields()
}
if ok2 {
tableRfs = append(tableRfs, ctx.tables[derivedTableIdx].GetResultFields()...)
}
}
for _, trf := range tableRfs {
// Convert it to ColumnNameExpr
cn := &ast.ColumnName{
Schema: trf.DBName,
Table: trf.Table.Name,
Name: trf.ColumnAsName,
}
cnExpr := &ast.ColumnNameExpr{
Name: cn,
Refer: trf,
}
ast.SetFlag(cnExpr)
cnExpr.SetType(trf.Expr.GetType())
rf := *trf
rf.Expr = cnExpr
rfs = append(rfs, &rf)
}
return
}
// The column is visited before so it must has been resolved already.
rf := &ast.ResultField{ColumnAsName: field.AsName}
innerExpr := getInnerFromParentheses(field.Expr)
switch v := innerExpr.(type) {
case *ast.ColumnNameExpr:
rf.Column = v.Refer.Column
rf.Table = v.Refer.Table
rf.DBName = v.Refer.DBName
rf.TableName = v.Refer.TableName
rf.Expr = v
default:
rf.Column = &model.ColumnInfo{} // Empty column info.
rf.Table = &model.TableInfo{} // Empty table info.
rf.Expr = v
}
if field.AsName.L == "" {
switch x := innerExpr.(type) {
case *ast.ColumnNameExpr:
rf.ColumnAsName = model.NewCIStr(x.Name.Name.O)
case *ast.ValueExpr:
if innerExpr.Text() != "" {
rf.ColumnAsName = model.NewCIStr(innerExpr.Text())
} else {
rf.ColumnAsName = model.NewCIStr(field.Text())
}
default:
rf.ColumnAsName = model.NewCIStr(field.Text())
}
}
rfs = append(rfs, rf)
return
}
func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) {
switch v := resultSetNode.(type) {
case *ast.TableSource:
out = append(in, v)
case *ast.Join:
out = appendTableSources(in, v.Left)
if v.Right != nil {
out = appendTableSources(out, v.Right)
}
}
return
}
func (nr *nameResolver) tableUniqueName(schema, table model.CIStr) string {
if schema.L != "" && schema.L != nr.DefaultSchema.L {
return schema.L + "." + table.L
}
return table.L
}
func (nr *nameResolver) handlePosition(pos *ast.PositionExpr) {
ctx := nr.currentContext()
if pos.N < 1 || pos.N > len(ctx.fieldList) {
nr.Err = errors.Errorf("Unknown column '%d'", pos.N)
return
}
matched := ctx.fieldList[pos.N-1]
nf := *matched
expr := matched.Expr
if cexpr, ok := expr.(*ast.ColumnNameExpr); ok {
expr = cexpr.Refer.Expr
}
nf.Expr = expr
pos.Refer = &nf
if nr.currentContext().inGroupBy {
// make sure item is not aggregate function
if ast.HasAggFlag(pos.Refer.Expr) {
nr.Err = errors.New("group by cannot contain aggregate function")
}
}
}
func (nr *nameResolver) handleUnionSelectList(u *ast.UnionSelectList) {
firstSelFields := u.Selects[0].GetResultFields()
unionFields := make([]*ast.ResultField, len(firstSelFields))
// Copy first result fields, because we may change the result field type.
for i, v := range firstSelFields {
rf := *v
col := *v.Column
rf.Column = &col
if rf.Column.Flen == 0 {
rf.Column.Flen = types.UnspecifiedLength
}
rf.Expr = &ast.ValueExpr{}
unionFields[i] = &rf
}
nr.currentContext().fieldList = unionFields
}
func (nr *nameResolver) fillShowFields(s *ast.ShowStmt) {
if s.DBName == "" {
if s.Table != nil && s.Table.Schema.L != "" {
s.DBName = s.Table.Schema.O
} else {
s.DBName = nr.DefaultSchema.O
}
} else if s.Table != nil && s.Table.Schema.L == "" {
s.Table.Schema = model.NewCIStr(s.DBName)
}
var fields []*ast.ResultField
var (
names []string
ftypes []byte
)
switch s.Tp {
case ast.ShowEngines:
names = []string{"Engine", "Support", "Comment", "Transactions", "XA", "Savepoints"}
case ast.ShowDatabases:
names = []string{"Database"}
case ast.ShowTables:
names = []string{fmt.Sprintf("Tables_in_%s", s.DBName)}
if s.Full {
names = append(names, "Table_type")
}
case ast.ShowTableStatus:
names = []string{"Name", "Engine", "Version", "Row_format", "Rows", "Avg_row_length",
"Data_length", "Max_data_length", "Index_length", "Data_free", "Auto_increment",
"Create_time", "Update_time", "Check_time", "Collation", "Checksum",
"Create_options", "Comment"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeLonglong,
mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong,
mysql.TypeDatetime, mysql.TypeDatetime, mysql.TypeDatetime, mysql.TypeVarchar, mysql.TypeVarchar,
mysql.TypeVarchar, mysql.TypeVarchar}
case ast.ShowColumns:
names = column.ColDescFieldNames(s.Full)
case ast.ShowWarnings:
names = []string{"Level", "Code", "Message"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeLong, mysql.TypeVarchar}
case ast.ShowCharset:
names = []string{"Charset", "Description", "Default collation", "Maxlen"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong}
case ast.ShowVariables:
names = []string{"Variable_name", "Value"}
case ast.ShowStatus:
names = []string{"Variable_name", "Value"}
case ast.ShowCollation:
names = []string{"Collation", "Charset", "Id", "Default", "Compiled", "Sortlen"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong,
mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong}
case ast.ShowCreateTable:
names = []string{"Table", "Create Table"}
case ast.ShowGrants:
names = []string{fmt.Sprintf("Grants for %s", s.User)}
case ast.ShowTriggers:
names = []string{"Trigger", "Event", "Table", "Statement", "Timing", "Created",
"sql_mode", "Definer", "character_set_client", "collation_connection", "Database Collation"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar,
mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar}
case ast.ShowProcedureStatus:
names = []string{}
ftypes = []byte{}
case ast.ShowIndex:
names = []string{"Table", "Non_unique", "Key_name", "Seq_in_index",
"Column_name", "Collation", "Cardinality", "Sub_part", "Packed",
"Null", "Index_type", "Comment", "Index_comment"}
ftypes = []byte{mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeVarchar, mysql.TypeLonglong,
mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeLonglong, mysql.TypeLonglong,
mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar}
}
for i, name := range names {
f := &ast.ResultField{
ColumnAsName: model.NewCIStr(name),
Column: &model.ColumnInfo{}, // Empty column info.
Table: &model.TableInfo{}, // Empty table info.
}
if ftypes == nil || ftypes[i] == 0 {
// use varchar as the default return column type
f.Column.Tp = mysql.TypeVarchar
} else {
f.Column.Tp = ftypes[i]
}
f.Column.Charset, f.Column.Collate = types.DefaultCharsetForType(f.Column.Tp)
f.Expr = &ast.ValueExpr{}
f.Expr.SetType(&f.Column.FieldType)
fields = append(fields, f)
}
if s.Pattern != nil && s.Pattern.Expr == nil {
rf := fields[0]
s.Pattern.Expr = &ast.ColumnNameExpr{
Name: &ast.ColumnName{Name: rf.ColumnAsName},
}
ast.SetFlag(s.Pattern)
}
s.SetResultFields(fields)
nr.currentContext().fieldList = fields
}

349
vendor/github.com/pingcap/tidb/optimizer/typeinferer.go generated vendored Normal file
View file

@ -0,0 +1,349 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"strings"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
)
// InferType infers result type for ast.ExprNode.
func InferType(node ast.Node) error {
var inferrer typeInferrer
// TODO: get the default charset from ctx
inferrer.defaultCharset = "utf8"
node.Accept(&inferrer)
return inferrer.err
}
type typeInferrer struct {
err error
defaultCharset string
}
func (v *typeInferrer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
return in, false
}
func (v *typeInferrer) Leave(in ast.Node) (out ast.Node, ok bool) {
switch x := in.(type) {
case *ast.AggregateFuncExpr:
v.aggregateFunc(x)
case *ast.BetweenExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.BinaryOperationExpr:
v.binaryOperation(x)
case *ast.CaseExpr:
v.handleCaseExpr(x)
case *ast.ColumnNameExpr:
x.SetType(&x.Refer.Column.FieldType)
case *ast.CompareSubqueryExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.ExistsSubqueryExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.FuncCallExpr:
v.handleFuncCallExpr(x)
case *ast.FuncCastExpr:
x.SetType(x.Tp)
if len(x.Type.Charset) == 0 {
x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp)
}
case *ast.IsNullExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.IsTruthExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.ParamMarkerExpr:
x.SetType(types.DefaultTypeForValue(x.GetValue()))
case *ast.ParenthesesExpr:
x.SetType(x.Expr.GetType())
case *ast.PatternInExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.PatternLikeExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.PatternRegexpExpr:
x.SetType(types.NewFieldType(mysql.TypeLonglong))
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
case *ast.SelectStmt:
v.selectStmt(x)
case *ast.UnaryOperationExpr:
v.unaryOperation(x)
case *ast.ValueExpr:
v.handleValueExpr(x)
case *ast.ValuesExpr:
v.handleValuesExpr(x)
case *ast.VariableExpr:
x.SetType(types.NewFieldType(mysql.TypeVarString))
x.Type.Charset = v.defaultCharset
cln, err := charset.GetDefaultCollation(v.defaultCharset)
if err != nil {
v.err = err
}
x.Type.Collate = cln
// TODO: handle all expression types.
}
return in, true
}
func (v *typeInferrer) selectStmt(x *ast.SelectStmt) {
rf := x.GetResultFields()
for _, val := range rf {
// column ID is 0 means it is not a real column from table, but a temporary column,
// so its type is not pre-defined, we need to set it.
if val.Column.ID == 0 && val.Expr.GetType() != nil {
val.Column.FieldType = *(val.Expr.GetType())
}
}
}
func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
name := strings.ToLower(x.F)
switch name {
case ast.AggFuncCount:
ft := types.NewFieldType(mysql.TypeLonglong)
ft.Flen = 21
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
x.SetType(ft)
case ast.AggFuncMax, ast.AggFuncMin:
x.SetType(x.Args[0].GetType())
case ast.AggFuncSum, ast.AggFuncAvg:
ft := types.NewFieldType(mysql.TypeNewDecimal)
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
x.SetType(ft)
case ast.AggFuncGroupConcat:
ft := types.NewFieldType(mysql.TypeVarString)
ft.Charset = v.defaultCharset
cln, err := charset.GetDefaultCollation(v.defaultCharset)
if err != nil {
v.err = err
}
ft.Collate = cln
x.SetType(ft)
}
}
func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) {
switch x.Op {
case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
x.Type = types.NewFieldType(mysql.TypeLonglong)
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
x.Type = types.NewFieldType(mysql.TypeLonglong)
case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
x.Type = types.NewFieldType(mysql.TypeLonglong)
x.Type.Flag |= mysql.UnsignedFlag
case opcode.IntDiv:
x.Type = types.NewFieldType(mysql.TypeLonglong)
case opcode.Plus, opcode.Minus, opcode.Mul, opcode.Mod:
if x.L.GetType() != nil && x.R.GetType() != nil {
xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp)
x.Type = types.NewFieldType(xTp)
leftUnsigned := x.L.GetType().Flag & mysql.UnsignedFlag
rightUnsigned := x.R.GetType().Flag & mysql.UnsignedFlag
// If both operands are unsigned, result is unsigned.
x.Type.Flag |= (leftUnsigned & rightUnsigned)
}
case opcode.Div:
if x.L.GetType() != nil && x.R.GetType() != nil {
xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp)
if xTp == mysql.TypeLonglong {
xTp = mysql.TypeDecimal
}
x.Type = types.NewFieldType(xTp)
}
}
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
}
func mergeArithType(a, b byte) byte {
switch a {
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
return mysql.TypeDouble
}
switch b {
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
return mysql.TypeDouble
}
if a == mysql.TypeNewDecimal || b == mysql.TypeNewDecimal {
return mysql.TypeNewDecimal
}
return mysql.TypeLonglong
}
func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) {
switch x.Op {
case opcode.Not:
x.Type = types.NewFieldType(mysql.TypeLonglong)
case opcode.BitNeg:
x.Type = types.NewFieldType(mysql.TypeLonglong)
x.Type.Flag |= mysql.UnsignedFlag
case opcode.Plus:
x.Type = x.V.GetType()
case opcode.Minus:
x.Type = types.NewFieldType(mysql.TypeLonglong)
if x.V.GetType() != nil {
switch x.V.GetType().Tp {
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
x.Type.Tp = mysql.TypeDouble
case mysql.TypeNewDecimal:
x.Type.Tp = mysql.TypeNewDecimal
}
}
}
x.Type.Charset = charset.CharsetBin
x.Type.Collate = charset.CollationBin
}
func (v *typeInferrer) handleValueExpr(x *ast.ValueExpr) {
tp := types.DefaultTypeForValue(x.GetValue())
// Set charset and collation
x.SetType(tp)
}
func (v *typeInferrer) handleValuesExpr(x *ast.ValuesExpr) {
x.SetType(x.Column.GetType())
}
func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int {
if len(x.Args) == 1 {
a := x.Args[0].GetValue()
fsp, err := types.ToInt64(a)
if err != nil {
v.err = err
}
return int(fsp)
}
return 0
}
func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
var (
tp *types.FieldType
chs = charset.CharsetBin
)
switch x.FnName.L {
case "abs", "ifnull", "nullif":
tp = x.Args[0].GetType()
case "pow", "power", "rand":
tp = types.NewFieldType(mysql.TypeDouble)
case "curdate", "current_date", "date":
tp = types.NewFieldType(mysql.TypeDate)
case "curtime", "current_time":
tp = types.NewFieldType(mysql.TypeDuration)
tp.Decimal = v.getFsp(x)
case "current_timestamp", "date_arith":
tp = types.NewFieldType(mysql.TypeDatetime)
case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek",
"found_rows", "length", "extract", "locate":
tp = types.NewFieldType(mysql.TypeLonglong)
case "now", "sysdate":
tp = types.NewFieldType(mysql.TypeDatetime)
tp.Decimal = v.getFsp(x)
case "dayname", "version", "database", "user", "current_user",
"concat", "concat_ws", "left", "lower", "repeat", "replace", "upper", "convert",
"substring", "substring_index", "trim":
tp = types.NewFieldType(mysql.TypeVarString)
chs = v.defaultCharset
case "strcmp":
tp = types.NewFieldType(mysql.TypeLonglong)
case "connection_id":
tp = types.NewFieldType(mysql.TypeLonglong)
tp.Flag |= mysql.UnsignedFlag
case "if":
// TODO: fix this
// See: https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if
// The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows.
// Expression Return Value
// expr2 or expr3 returns a string string
// expr2 or expr3 returns a floating-point value floating-point
// expr2 or expr3 returns an integer integer
tp = x.Args[1].GetType()
default:
tp = types.NewFieldType(mysql.TypeUnspecified)
}
// If charset is unspecified.
if len(tp.Charset) == 0 {
tp.Charset = chs
cln := charset.CollationBin
if chs != charset.CharsetBin {
var err error
cln, err = charset.GetDefaultCollation(chs)
if err != nil {
v.err = err
}
}
tp.Collate = cln
}
x.SetType(tp)
}
// The return type of a CASE expression is the compatible aggregated type of all return values,
// but also depends on the context in which it is used.
// If used in a string context, the result is returned as a string.
// If used in a numeric context, the result is returned as a decimal, real, or integer value.
func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) {
var currType *types.FieldType
for _, w := range x.WhenClauses {
t := w.Result.GetType()
if currType == nil {
currType = t
continue
}
mtp := types.MergeFieldType(currType.Tp, t.Tp)
if mtp == t.Tp && mtp != currType.Tp {
currType.Charset = t.Charset
currType.Collate = t.Collate
}
currType.Tp = mtp
}
if x.ElseClause != nil {
t := x.ElseClause.GetType()
if currType == nil {
currType = t
} else {
mtp := types.MergeFieldType(currType.Tp, t.Tp)
if mtp == t.Tp && mtp != currType.Tp {
currType.Charset = t.Charset
currType.Collate = t.Collate
}
currType.Tp = mtp
}
}
x.SetType(currType)
// TODO: We need a better way to set charset/collation
x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp)
}

246
vendor/github.com/pingcap/tidb/optimizer/validator.go generated vendored Normal file
View file

@ -0,0 +1,246 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/opcode"
)
// Validate checkes whether the node is valid.
func Validate(node ast.Node, inPrepare bool) error {
v := validator{inPrepare: inPrepare}
node.Accept(&v)
return v.err
}
// validator is an ast.Visitor that validates
// ast Nodes parsed from parser.
type validator struct {
err error
wildCardCount int
inPrepare bool
inAggregate bool
}
func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch in.(type) {
case *ast.AggregateFuncExpr:
if v.inAggregate {
// Aggregate function can not contain aggregate function.
v.err = ErrInvalidGroupFuncUse
return in, true
}
v.inAggregate = true
}
return in, false
}
func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
switch x := in.(type) {
case *ast.AggregateFuncExpr:
v.inAggregate = false
case *ast.BetweenExpr:
v.checkAllOneColumn(x.Expr, x.Left, x.Right)
case *ast.BinaryOperationExpr:
v.checkBinaryOperation(x)
case *ast.ByItem:
v.checkAllOneColumn(x.Expr)
case *ast.CreateTableStmt:
v.checkAutoIncrement(x)
case *ast.CompareSubqueryExpr:
v.checkSameColumns(x.L, x.R)
case *ast.FieldList:
v.checkFieldList(x)
case *ast.HavingClause:
v.checkAllOneColumn(x.Expr)
case *ast.IsNullExpr:
v.checkAllOneColumn(x.Expr)
case *ast.IsTruthExpr:
v.checkAllOneColumn(x.Expr)
case *ast.ParamMarkerExpr:
if !v.inPrepare {
v.err = parser.ErrSyntax.Gen("syntax error, unexpected '?'")
}
case *ast.PatternInExpr:
v.checkSameColumns(append(x.List, x.Expr)...)
}
return in, v.err == nil
}
// checkAllOneColumn checks that all expressions have one column.
// Expression may have more than one column when it is a rowExpr or
// a Subquery with more than one result fields.
func (v *validator) checkAllOneColumn(exprs ...ast.ExprNode) {
for _, expr := range exprs {
switch x := expr.(type) {
case *ast.RowExpr:
v.err = ErrOneColumn
case *ast.SubqueryExpr:
if len(x.Query.GetResultFields()) != 1 {
v.err = ErrOneColumn
}
}
}
return
}
func checkAutoIncrementOp(colDef *ast.ColumnDef, num int) (bool, error) {
var hasAutoIncrement bool
if colDef.Options[num].Tp == ast.ColumnOptionAutoIncrement {
hasAutoIncrement = true
if len(colDef.Options) == num+1 {
return hasAutoIncrement, nil
}
for _, op := range colDef.Options[num+1:] {
if op.Tp == ast.ColumnOptionDefaultValue {
return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
}
}
}
if colDef.Options[num].Tp == ast.ColumnOptionDefaultValue && len(colDef.Options) != num+1 {
for _, op := range colDef.Options[num+1:] {
if op.Tp == ast.ColumnOptionAutoIncrement {
return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
}
}
}
return hasAutoIncrement, nil
}
func isConstraintKeyTp(constraints []*ast.Constraint, colDef *ast.ColumnDef) bool {
for _, c := range constraints {
if len(c.Keys) < 1 {
}
// If the constraint as follows: primary key(c1, c2)
// we only support c1 column can be auto_increment.
if colDef.Name.Name.L != c.Keys[0].Column.Name.L {
continue
}
switch c.Tp {
case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex,
ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey:
return true
}
}
return false
}
func (v *validator) checkAutoIncrement(stmt *ast.CreateTableStmt) {
var (
isKey bool
count int
autoIncrementCol *ast.ColumnDef
)
for _, colDef := range stmt.Cols {
var hasAutoIncrement bool
for i, op := range colDef.Options {
ok, err := checkAutoIncrementOp(colDef, i)
if err != nil {
v.err = err
return
}
if ok {
hasAutoIncrement = true
}
switch op.Tp {
case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey, ast.ColumnOptionUniqIndex,
ast.ColumnOptionUniq, ast.ColumnOptionKey, ast.ColumnOptionIndex:
isKey = true
}
}
if hasAutoIncrement {
count++
autoIncrementCol = colDef
}
}
if count < 1 {
return
}
if !isKey {
isKey = isConstraintKeyTp(stmt.Constraints, autoIncrementCol)
}
if !isKey || count > 1 {
v.err = errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")
}
switch autoIncrementCol.Tp.Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong,
mysql.TypeFloat, mysql.TypeDouble, mysql.TypeLonglong, mysql.TypeInt24:
default:
v.err = errors.Errorf("Incorrect column specifier for column '%s'", autoIncrementCol.Name.Name.O)
}
}
func (v *validator) checkBinaryOperation(x *ast.BinaryOperationExpr) {
// row constructor only supports comparison operation.
switch x.Op {
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
v.checkSameColumns(x.L, x.R)
default:
v.checkAllOneColumn(x.L, x.R)
}
}
func columnCount(ex ast.ExprNode) int {
switch x := ex.(type) {
case *ast.RowExpr:
return len(x.Values)
case *ast.SubqueryExpr:
return len(x.Query.GetResultFields())
default:
return 1
}
}
func (v *validator) checkSameColumns(exprs ...ast.ExprNode) {
if len(exprs) == 0 {
return
}
count := columnCount(exprs[0])
for i := 1; i < len(exprs); i++ {
if columnCount(exprs[i]) != count {
v.err = ErrSameColumns
return
}
}
}
// checkFieldList checks if there is only one '*' and each field has only one column.
func (v *validator) checkFieldList(x *ast.FieldList) {
var hasWildCard bool
for _, val := range x.Fields {
if val.WildCard != nil && val.WildCard.Table.L == "" {
if hasWildCard {
v.err = ErrMultiWildCard
return
}
hasWildCard = true
}
v.checkAllOneColumn(val.Expr)
if v.err != nil {
return
}
}
}