1
0
Fork 0
forked from forgejo/forgejo

Upgrade xorm to v1.2.2 (#16663)

* Upgrade xorm to v1.2.2

* Change the Engine interface to match xorm v1.2.2
This commit is contained in:
Lunny Xiao 2021-08-13 07:11:42 +08:00 committed by GitHub
parent 5fbccad906
commit 7224cfc578
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
134 changed files with 42889 additions and 5428 deletions

View file

@ -6,6 +6,7 @@ package dialects
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
@ -777,12 +778,24 @@ var (
var (
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
postgresColAliases = map[string]string{
"numeric": "decimal",
}
)
type postgres struct {
Base
}
// Alias returns a alias of column
func (db *postgres) Alias(col string) string {
v, ok := postgresColAliases[strings.ToLower(col)]
if ok {
return v
}
return col
}
func (db *postgres) Init(uri *URI) error {
db.quoter = postgresQuoter
return db.Base.Init(db, uri)
@ -797,7 +810,10 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas
var version string
if !rows.Next() {
return nil, errors.New("Unknow version")
if rows.Err() != nil {
return nil, rows.Err()
}
return nil, errors.New("unknow version")
}
if err := rows.Scan(&version); err != nil {
@ -860,21 +876,16 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}
// FormatBytes formats bytes
func (db *postgres) FormatBytes(bs []byte) string {
return fmt.Sprintf("E'\\x%x'", bs)
}
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
case schemas.TinyInt:
case schemas.TinyInt, schemas.UnsignedTinyInt:
res = schemas.SmallInt
return res
case schemas.Bit:
res = schemas.Boolean
return res
case schemas.MediumInt, schemas.Int, schemas.Integer:
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
if c.IsAutoIncrement {
return schemas.Serial
}
@ -930,6 +941,21 @@ func (db *postgres) SQLType(c *schemas.Column) string {
return res
}
func (db *postgres) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATETIME", "TIMESTAMP":
return schemas.TIME_TYPE
case "VARCHAR", "TEXT":
return schemas.TEXT_TYPE
case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
return schemas.NUMERIC_TYPE
case "BOOL":
return schemas.BOOL_TYPE
default:
return schemas.UNKNOW_TYPE
}
}
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
@ -1039,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
}
defer rows.Close()
return rows.Next(), nil
if rows.Next() {
return true, nil
}
return false, rows.Err()
}
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -1169,7 +1198,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
}
}
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name)
return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name)
}
col.Length = maxLen
@ -1177,19 +1206,22 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
if !col.DefaultIsEmpty {
if col.SQLType.IsText() {
if strings.HasSuffix(col.Default, "::character varying") {
col.Default = strings.TrimRight(col.Default, "::character varying")
col.Default = strings.TrimSuffix(col.Default, "::character varying")
} else if !strings.HasPrefix(col.Default, "'") {
col.Default = "'" + col.Default + "'"
}
} else if col.SQLType.IsTime() {
if strings.HasSuffix(col.Default, "::timestamp without time zone") {
col.Default = strings.TrimRight(col.Default, "::timestamp without time zone")
col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone")
}
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
if rows.Err() != nil {
return nil, nil, rows.Err()
}
return colSeq, cols, nil
}
@ -1220,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch
table.Name = name
tables = append(tables, table)
}
if rows.Err() != nil {
return nil, rows.Err()
}
return tables, nil
}
@ -1236,7 +1271,7 @@ func getIndexColName(indexdef string) []string {
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1"
if len(db.getSchema()) != 0 {
args = append(args, db.getSchema())
s = s + " AND schemaname=$2"
@ -1248,7 +1283,7 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
}
defer rows.Close()
indexes := make(map[string]*schemas.Index, 0)
indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, indexdef string
@ -1290,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
index.IsRegular = isRegular
indexes[index.Name] = index
}
if rows.Err() != nil {
return nil, rows.Err()
}
return indexes, nil
}
@ -1298,18 +1336,11 @@ func (db *postgres) Filters() []Filter {
}
type pqDriver struct {
baseDriver
}
type values map[string]string
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@ -1329,30 +1360,94 @@ func parseURL(connstr string) (string, error) {
return "", nil
}
func parseOpts(name string, o values) error {
if len(name) == 0 {
return fmt.Errorf("invalid options: %s", name)
func parseOpts(urlStr string, o values) error {
if len(urlStr) == 0 {
return fmt.Errorf("invalid options: %s", urlStr)
}
name = strings.TrimSpace(name)
urlStr = strings.TrimSpace(urlStr)
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
return fmt.Errorf("invalid option: %q", p)
var (
inQuote bool
state int // 0 key, 1 space, 2 value, 3 equal
start int
key string
)
for i, c := range urlStr {
switch c {
case ' ':
if !inQuote {
if state == 2 {
state = 1
v := urlStr[start:i]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
} else if state != 1 {
return fmt.Errorf("wrong format: %v", urlStr)
}
}
case '\'':
if state == 3 {
state = 2
start = i
} else if state != 2 {
return fmt.Errorf("wrong format: %v", urlStr)
}
inQuote = !inQuote
case '=':
if !inQuote {
if state != 0 {
return fmt.Errorf("wrong format: %v", urlStr)
}
key = urlStr[start:i]
state = 3
}
default:
if state == 3 {
state = 2
start = i
} else if state == 1 {
state = 0
start = i
}
}
if i == len(urlStr)-1 {
if state != 2 {
return errors.New("no value matched key")
}
v := urlStr[start : i+1]
if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
v = v[1 : len(v)-1]
} else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
}
o[key] = v
}
o.Set(kv[0], kv[1])
}
return nil
}
func (p *pqDriver) Features() *DriverFeatures {
return &DriverFeatures{
SupportReturnInsertedID: false,
}
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES}
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
var err error
if strings.Contains(dataSourceName, "://") {
if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") {
return nil, fmt.Errorf("unsupported protocol %v", dataSourceName)
}
db.DBName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
@ -1364,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return nil, err
}
db.DBName = o.Get("dbname")
db.DBName = o["dbname"]
}
if db.DBName == "" {
@ -1374,6 +1469,32 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return db, nil
}
func (p *pqDriver) GenScanResult(colType string) (interface{}, error) {
switch colType {
case "VARCHAR", "TEXT":
var s sql.NullString
return &s, nil
case "BIGINT", "BIGSERIAL":
var s sql.NullInt64
return &s, nil
case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL":
var s sql.NullInt32
return &s, nil
case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
var s sql.NullFloat64
return &s, nil
case "DATETIME", "TIMESTAMP":
var s sql.NullTime
return &s, nil
case "BOOL":
var s sql.NullBool
return &s, nil
default:
var r sql.RawBytes
return &r, nil
}
}
type pqDriverPgx struct {
pqDriver
}
@ -1401,6 +1522,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri
parts := strings.Split(defaultSchema, ",")
return strings.TrimSpace(parts[len(parts)-1]), nil
}
if rows.Err() != nil {
return "", rows.Err()
}
return "", errors.New("No default schema")
return "", errors.New("no default schema")
}