Upgrade xorm to v1.0.0 (#10646)

* Upgrade xorm to v1.0.0

* small nit

* Fix tests

* Update xorm

* Update xorm

* fix go.sum

* fix test

* Fix bug when dump

* Fix bug

* update xorm to latest

* Fix migration test

* update xorm to latest

* Fix import order

* Use xorm tag
This commit is contained in:
Lunny Xiao 2020-03-22 23:12:55 +08:00 committed by GitHub
parent dcaa5643d7
commit c61b902538
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
154 changed files with 7195 additions and 5962 deletions

178
vendor/xorm.io/xorm/session_schema.go generated vendored
View file

@ -5,11 +5,15 @@
package xorm
import (
"bufio"
"database/sql"
"fmt"
"io"
"os"
"strings"
"xorm.io/core"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// Ping test if database is ok
@ -32,13 +36,18 @@ func (session *Session) CreateTable(bean interface{}) error {
}
func (session *Session) createTable(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
sqlStr := session.statement.genCreateTableSQL()
_, err := session.exec(sqlStr)
return err
sqlStrs := session.statement.GenCreateTableSQL()
for _, s := range sqlStrs {
_, err := session.exec(s)
if err != nil {
return err
}
}
return nil
}
// CreateIndexes create indexes
@ -51,11 +60,11 @@ func (session *Session) CreateIndexes(bean interface{}) error {
}
func (session *Session) createIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
sqls := session.statement.genIndexSQL()
sqls := session.statement.GenIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -74,11 +83,11 @@ func (session *Session) CreateUniques(bean interface{}) error {
}
func (session *Session) createUniques(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
sqls := session.statement.genUniqueSQL()
sqls := session.statement.GenUniqueSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -98,11 +107,11 @@ func (session *Session) DropIndexes(bean interface{}) error {
}
func (session *Session) dropIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil {
if err := session.statement.SetRefBean(bean); err != nil {
return err
}
sqls := session.statement.genDelIndexSQL()
sqls := session.statement.GenDelIndexSQL()
for _, sqlStr := range sqls {
_, err := session.exec(sqlStr)
if err != nil {
@ -123,18 +132,16 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
if !checkIfExist {
exist, err := session.engine.dialect.IsTableExist(session.ctx, tableName)
if err != nil {
return err
}
needDrop = len(results) > 0
checkIfExist = exist
}
if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
if checkIfExist {
_, err := session.exec(sqlStr)
return err
}
@ -153,9 +160,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
}
func (session *Session) isTableExist(tableName string) (bool, error) {
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
results, err := session.queryBytes(sqlStr, args...)
return len(results) > 0, err
return session.engine.dialect.IsTableExist(session.ctx, tableName)
}
// IsTableEmpty if table have any records
@ -182,17 +187,17 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
// find if index is exist according cols
func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
indexes, err := session.engine.dialect.GetIndexes(tableName)
indexes, err := session.engine.dialect.GetIndexes(session.ctx, tableName)
if err != nil {
return false, err
}
for _, index := range indexes {
if sliceEq(index.Cols, cols) {
if utils.SliceEq(index.Cols, cols) {
if unique {
return index.Type == core.UniqueType, nil
return index.Type == schemas.UniqueType, nil
}
return index.Type == core.IndexType, nil
return index.Type == schemas.IndexType, nil
}
}
return false, nil
@ -200,21 +205,21 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
func (session *Session) addColumn(colName string) error {
col := session.statement.RefTable.GetColumn(colName)
sql, args := session.statement.genAddColumnStr(col)
_, err := session.exec(sql, args...)
sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col)
_, err := session.exec(sql)
return err
}
func (session *Session) addIndex(tableName, idxName string) error {
index := session.statement.RefTable.Indexes[idxName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr)
return err
}
func (session *Session) addUnique(tableName, uqeName string) error {
index := session.statement.RefTable.Indexes[uqeName]
sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index)
_, err := session.exec(sqlStr)
return err
}
@ -228,7 +233,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
defer session.Close()
}
tables, err := engine.dialect.GetTables()
tables, err := engine.dialect.GetTables(session.ctx)
if err != nil {
return err
}
@ -240,8 +245,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
}()
for _, bean := range beans {
v := rValue(bean)
table, err := engine.mapType(v)
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
@ -253,7 +258,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *core.Table
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
@ -287,7 +292,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// check columns
for _, col := range table.Columns() {
var oriCol *core.Column
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
@ -298,7 +303,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return err
}
@ -306,27 +311,27 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
err = nil
expectedType := engine.dialect.SqlType(col)
curType := engine.dialect.SqlType(oriCol)
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == core.Text &&
strings.HasPrefix(curType, core.Varchar) {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES {
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL {
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
@ -335,21 +340,23 @@ func (session *Session) Sync2(beans ...interface{}) error {
tbNameWithSchema, col.Name, curType, expectedType)
}
}
} else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL {
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
}
if col.Default != oriCol.Default {
if (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) &&
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")) {
} else {
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
@ -365,10 +372,10 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*core.Index)
var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *core.Index
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
@ -379,7 +386,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
@ -395,7 +402,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
@ -404,13 +411,13 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
for name, index := range addedNames {
if index.Type == core.UniqueType {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
@ -428,3 +435,56 @@ func (session *Session) Sync2(beans ...interface{}) error {
return nil
}
// ImportFile SQL DDL file
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath)
if err != nil {
return nil, err
}
defer file.Close()
return session.Import(file)
}
// Import SQL DDL from io.Reader
func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
var results []sql.Result
var lastError error
scanner := bufio.NewScanner(r)
var inSingleQuote bool
semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
for i, b := range data {
if b == '\'' {
inSingleQuote = !inSingleQuote
}
if !inSingleQuote && b == ';' {
return i + 1, data[0:i], nil
}
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
scanner.Split(semiColSpliter)
for scanner.Scan() {
query := strings.Trim(scanner.Text(), " \t\n\r")
if len(query) > 0 {
result, err := session.Exec(query)
results = append(results, result)
if err != nil {
return nil, err
}
}
}
return results, lastError
}