feat: use XORM EngineGroup instead of single Engine connection (#7212)

Resolves #7207

Add new configuration to make XORM work with a main and replicas database instances. The follow configuration parameters were added:

- `HOST_PRIMARY`
- `HOST_REPLICAS`
- `LOAD_BALANCE_POLICY`. Options:
    - `"WeightRandom"` -> `xorm.WeightRandomPolicy`
    - `"WeightRoundRobin`  -> `WeightRoundRobinPolicy`
    - `"LeastCon"` -> `LeastConnPolicy`
    - `"RoundRobin"` -> `xorm.RoundRobinPolicy()`
    - default: `xorm.RandomPolicy()`
- `LOAD_BALANCE_WEIGHTS`

Co-authored-by: pat-s <patrick.schratz@gmail.com@>
Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/7212
Reviewed-by: Gusted <gusted@noreply.codeberg.org>
Co-authored-by: pat-s <patrick.schratz@gmail.com>
Co-committed-by: pat-s <patrick.schratz@gmail.com>
This commit is contained in:
pat-s 2025-03-30 11:34:02 +00:00 committed by Gusted
parent a23d0453a3
commit 63a80bf2b9
19 changed files with 463 additions and 129 deletions

View file

@ -20,7 +20,6 @@ import (
"forgejo.org/services/doctor" "forgejo.org/services/doctor"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"xorm.io/xorm"
) )
// CmdDoctor represents the available doctor sub-command. // CmdDoctor represents the available doctor sub-command.
@ -120,7 +119,7 @@ func runRecreateTable(ctx *cli.Context) error {
args := ctx.Args() args := ctx.Args()
names := make([]string, 0, ctx.NArg()) names := make([]string, 0, ctx.NArg())
for i := 0; i < ctx.NArg(); i++ { for i := range ctx.NArg() {
names = append(names, args.Get(i)) names = append(names, args.Get(i))
} }
@ -130,11 +129,17 @@ func runRecreateTable(ctx *cli.Context) error {
} }
recreateTables := migrate_base.RecreateTables(beans...) recreateTables := migrate_base.RecreateTables(beans...)
return db.InitEngineWithMigration(stdCtx, func(x *xorm.Engine) error { return db.InitEngineWithMigration(stdCtx, func(x db.Engine) error {
if err := migrations.EnsureUpToDate(x); err != nil { engine, err := db.GetMasterEngine(x)
if err != nil {
return err return err
} }
return recreateTables(x)
if err := migrations.EnsureUpToDate(engine); err != nil {
return err
}
return recreateTables(engine)
}) })
} }

View file

@ -36,7 +36,13 @@ func runMigrate(ctx *cli.Context) error {
log.Info("Log path: %s", setting.Log.RootPath) log.Info("Log path: %s", setting.Log.RootPath)
log.Info("Configuration file: %s", setting.CustomConf) log.Info("Configuration file: %s", setting.CustomConf)
if err := db.InitEngineWithMigration(context.Background(), migrations.Migrate); err != nil { if err := db.InitEngineWithMigration(context.Background(), func(dbEngine db.Engine) error {
masterEngine, err := db.GetMasterEngine(dbEngine)
if err != nil {
return err
}
return migrations.Migrate(masterEngine)
}); err != nil {
log.Fatal("Failed to initialize ORM engine: %v", err) log.Fatal("Failed to initialize ORM engine: %v", err)
return err return err
} }

View file

@ -23,6 +23,7 @@ import (
"forgejo.org/modules/storage" "forgejo.org/modules/storage"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"xorm.io/xorm"
) )
// CmdMigrateStorage represents the available migrate storage sub-command. // CmdMigrateStorage represents the available migrate storage sub-command.
@ -195,7 +196,9 @@ func runMigrateStorage(ctx *cli.Context) error {
log.Info("Log path: %s", setting.Log.RootPath) log.Info("Log path: %s", setting.Log.RootPath)
log.Info("Configuration file: %s", setting.CustomConf) log.Info("Configuration file: %s", setting.CustomConf)
if err := db.InitEngineWithMigration(context.Background(), migrations.Migrate); err != nil { if err := db.InitEngineWithMigration(context.Background(), func(e db.Engine) error {
return migrations.Migrate(e.(*xorm.Engine))
}); err != nil {
log.Fatal("Failed to initialize ORM engine: %v", err) log.Fatal("Failed to initialize ORM engine: %v", err)
return err return err
} }

View file

@ -95,34 +95,70 @@ func init() {
} }
} }
// newXORMEngine returns a new XORM engine from the configuration // newXORMEngineGroup creates an xorm.EngineGroup (with one master and one or more slaves).
func newXORMEngine() (*xorm.Engine, error) { // It assumes you have separate master and slave DSNs defined via the settings package.
connStr, err := setting.DBConnStr() func newXORMEngineGroup() (Engine, error) {
// Retrieve master DSN from settings.
masterConnStr, err := setting.DBMasterConnStr()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to determine master DSN: %w", err)
} }
var engine *xorm.Engine var masterEngine *xorm.Engine
// For PostgreSQL: if a schema is provided, we use the special "postgresschema" driver.
if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 { if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 {
// OK whilst we sort out our schema issues - create a schema aware postgres
registerPostgresSchemaDriver() registerPostgresSchemaDriver()
engine, err = xorm.NewEngine("postgresschema", connStr) masterEngine, err = xorm.NewEngine("postgresschema", masterConnStr)
} else { } else {
engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr) masterEngine, err = xorm.NewEngine(setting.Database.Type.String(), masterConnStr)
} }
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create master engine: %w", err)
} }
if setting.Database.Type.IsMySQL() { if setting.Database.Type.IsMySQL() {
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) masterEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
} }
engine.SetSchema(setting.Database.Schema) masterEngine.SetSchema(setting.Database.Schema)
return engine, nil
slaveConnStrs, err := setting.DBSlaveConnStrs()
if err != nil {
return nil, fmt.Errorf("failed to load slave DSNs: %w", err)
} }
// SyncAllTables sync the schemas of all tables, is required by unit test code var slaveEngines []*xorm.Engine
// Iterate over all slave DSNs and create engines
for _, dsn := range slaveConnStrs {
slaveEngine, err := xorm.NewEngine(setting.Database.Type.String(), dsn)
if err != nil {
return nil, fmt.Errorf("failed to create slave engine for dsn %q: %w", dsn, err)
}
if setting.Database.Type.IsMySQL() {
slaveEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
}
slaveEngine.SetSchema(setting.Database.Schema)
slaveEngines = append(slaveEngines, slaveEngine)
}
policy := setting.BuildLoadBalancePolicy(&setting.Database, slaveEngines)
// Create the EngineGroup using the selected policy
group, err := xorm.NewEngineGroup(masterEngine, slaveEngines, policy)
if err != nil {
return nil, fmt.Errorf("failed to create engine group: %w", err)
}
return engineGroupWrapper{group}, nil
}
type engineGroupWrapper struct {
*xorm.EngineGroup
}
func (w engineGroupWrapper) AddHook(hook contexts.Hook) bool {
w.EngineGroup.AddHook(hook)
return true
}
// SyncAllTables sync the schemas of all tables
func SyncAllTables() error { func SyncAllTables() error {
_, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{ _, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
WarnIfDatabaseColumnMissed: true, WarnIfDatabaseColumnMissed: true,
@ -130,26 +166,27 @@ func SyncAllTables() error {
return err return err
} }
// InitEngine initializes the xorm.Engine and sets it as db.DefaultContext // InitEngine initializes the xorm EngineGroup and sets it as db.DefaultContext
func InitEngine(ctx context.Context) error { func InitEngine(ctx context.Context) error {
xormEngine, err := newXORMEngine() xormEngine, err := newXORMEngineGroup()
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to database: %w", err) return fmt.Errorf("failed to connect to database: %w", err)
} }
// Try to cast to the concrete type to access diagnostic methods
xormEngine.SetMapper(names.GonicMapper{}) if eng, ok := xormEngine.(engineGroupWrapper); ok {
// WARNING: for serv command, MUST remove the output to os.stdout, eng.SetMapper(names.GonicMapper{})
// so use log file to instead print to stdout. // WARNING: for serv command, MUST remove the output to os.Stdout,
xormEngine.SetLogger(NewXORMLogger(setting.Database.LogSQL)) // so use a log file instead of printing to stdout.
xormEngine.ShowSQL(setting.Database.LogSQL) eng.SetLogger(NewXORMLogger(setting.Database.LogSQL))
xormEngine.SetMaxOpenConns(setting.Database.MaxOpenConns) eng.ShowSQL(setting.Database.LogSQL)
xormEngine.SetMaxIdleConns(setting.Database.MaxIdleConns) eng.SetMaxOpenConns(setting.Database.MaxOpenConns)
xormEngine.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) eng.SetMaxIdleConns(setting.Database.MaxIdleConns)
xormEngine.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) eng.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
xormEngine.SetDefaultContext(ctx) eng.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime)
eng.SetDefaultContext(ctx)
if setting.Database.SlowQueryThreshold > 0 { if setting.Database.SlowQueryThreshold > 0 {
xormEngine.AddHook(&SlowQueryHook{ eng.AddHook(&SlowQueryHook{
Treshold: setting.Database.SlowQueryThreshold, Treshold: setting.Database.SlowQueryThreshold,
Logger: log.GetLogger("xorm"), Logger: log.GetLogger("xorm"),
}) })
@ -160,22 +197,30 @@ func InitEngine(ctx context.Context) error {
errorLogger = log.GetLogger(log.DEFAULT) errorLogger = log.GetLogger(log.DEFAULT)
} }
xormEngine.AddHook(&ErrorQueryHook{ eng.AddHook(&ErrorQueryHook{
Logger: errorLogger, Logger: errorLogger,
}) })
xormEngine.AddHook(&TracingHook{}) eng.AddHook(&TracingHook{})
SetDefaultEngine(ctx, eng)
} else {
// Fallback: if type assertion fails, set default engine without extended diagnostics
SetDefaultEngine(ctx, xormEngine) SetDefaultEngine(ctx, xormEngine)
}
return nil return nil
} }
// SetDefaultEngine sets the default engine for db // SetDefaultEngine sets the default engine for db.
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) { func SetDefaultEngine(ctx context.Context, eng Engine) {
x = eng masterEngine, err := GetMasterEngine(eng)
if err == nil {
x = masterEngine
}
DefaultContext = &Context{ DefaultContext = &Context{
Context: ctx, Context: ctx,
e: x, e: eng,
} }
} }
@ -191,12 +236,12 @@ func UnsetDefaultEngine() {
DefaultContext = nil DefaultContext = nil
} }
// InitEngineWithMigration initializes a new xorm.Engine and sets it as the db.DefaultContext // InitEngineWithMigration initializes a new xorm EngineGroup, runs migrations, and sets it as db.DefaultContext
// This function must never call .Sync() if the provided migration function fails. // This function must never call .Sync() if the provided migration function fails.
// When called from the "doctor" command, the migration function is a version check // When called from the "doctor" command, the migration function is a version check
// that prevents the doctor from fixing anything in the database if the migration level // that prevents the doctor from fixing anything in the database if the migration level
// is different from the expected value. // is different from the expected value.
func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) { func InitEngineWithMigration(ctx context.Context, migrateFunc func(Engine) error) (err error) {
if err = InitEngine(ctx); err != nil { if err = InitEngine(ctx); err != nil {
return err return err
} }
@ -230,14 +275,14 @@ func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine)
return nil return nil
} }
// NamesToBean return a list of beans or an error // NamesToBean returns a list of beans given names
func NamesToBean(names ...string) ([]any, error) { func NamesToBean(names ...string) ([]any, error) {
beans := []any{} beans := []any{}
if len(names) == 0 { if len(names) == 0 {
beans = append(beans, tables...) beans = append(beans, tables...)
return beans, nil return beans, nil
} }
// Need to map provided names to beans... // Map provided names to beans
beanMap := make(map[string]any) beanMap := make(map[string]any)
for _, bean := range tables { for _, bean := range tables {
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
@ -259,7 +304,7 @@ func NamesToBean(names ...string) ([]any, error) {
return beans, nil return beans, nil
} }
// DumpDatabase dumps all data from database according the special database SQL syntax to file system. // DumpDatabase dumps all data from database using special SQL syntax to the file system.
func DumpDatabase(filePath, dbType string) error { func DumpDatabase(filePath, dbType string) error {
var tbs []*schemas.Table var tbs []*schemas.Table
for _, t := range tables { for _, t := range tables {
@ -295,29 +340,33 @@ func MaxBatchInsertSize(bean any) int {
return 999 / len(t.ColumnsSeq()) return 999 / len(t.ColumnsSeq())
} }
// IsTableNotEmpty returns true if table has at least one record // IsTableNotEmpty returns true if the table has at least one record
func IsTableNotEmpty(beanOrTableName any) (bool, error) { func IsTableNotEmpty(beanOrTableName any) (bool, error) {
return x.Table(beanOrTableName).Exist() return x.Table(beanOrTableName).Exist()
} }
// DeleteAllRecords will delete all the records of this table // DeleteAllRecords deletes all records in the given table.
func DeleteAllRecords(tableName string) error { func DeleteAllRecords(tableName string) error {
_, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName)) _, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName))
return err return err
} }
// GetMaxID will return max id of the table // GetMaxID returns the maximum id in the table
func GetMaxID(beanOrTableName any) (maxID int64, err error) { func GetMaxID(beanOrTableName any) (maxID int64, err error) {
_, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID) _, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
return maxID, err return maxID, err
} }
func SetLogSQL(ctx context.Context, on bool) { func SetLogSQL(ctx context.Context, on bool) {
e := GetEngine(ctx) ctxEngine := GetEngine(ctx)
if x, ok := e.(*xorm.Engine); ok {
x.ShowSQL(on) if sess, ok := ctxEngine.(*xorm.Session); ok {
} else if sess, ok := e.(*xorm.Session); ok {
sess.Engine().ShowSQL(on) sess.Engine().ShowSQL(on)
} else if wrapper, ok := ctxEngine.(engineGroupWrapper); ok {
// Handle engineGroupWrapper directly
wrapper.ShowSQL(on)
} else if masterEngine, err := GetMasterEngine(ctxEngine); err == nil {
masterEngine.ShowSQL(on)
} }
} }
@ -374,3 +423,18 @@ func (h *ErrorQueryHook) AfterProcess(c *contexts.ContextHook) error {
} }
return nil return nil
} }
// GetMasterEngine extracts the master xorm.Engine from the provided xorm.Engine.
// This handles both direct xorm.Engine cases and engines that implement a Master() method.
func GetMasterEngine(x Engine) (*xorm.Engine, error) {
if getter, ok := x.(interface{ Master() *xorm.Engine }); ok {
return getter.Master(), nil
}
engine, ok := x.(*xorm.Engine)
if !ok {
return nil, fmt.Errorf("unsupported engine type: %T", x)
}
return engine, nil
}

View file

@ -33,10 +33,11 @@ func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int6
func TestSyncMaxResourceIndex(t *testing.T) { func TestSyncMaxResourceIndex(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase()) require.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine() xe, err := unittest.GetXORMEngine()
require.NoError(t, err)
require.NoError(t, xe.Sync(&TestIndex{})) require.NoError(t, xe.Sync(&TestIndex{}))
err := db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 51) err = db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 51)
require.NoError(t, err) require.NoError(t, err)
// sync new max index // sync new max index
@ -88,7 +89,8 @@ func TestSyncMaxResourceIndex(t *testing.T) {
func TestGetNextResourceIndex(t *testing.T) { func TestGetNextResourceIndex(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase()) require.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine() xe, err := unittest.GetXORMEngine()
require.NoError(t, err)
require.NoError(t, xe.Sync(&TestIndex{})) require.NoError(t, xe.Sync(&TestIndex{}))
// create a new record // create a new record

View file

@ -17,7 +17,8 @@ import (
func TestIterate(t *testing.T) { func TestIterate(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase()) require.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine() xe, err := unittest.GetXORMEngine()
require.NoError(t, err)
require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) require.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
cnt, err := db.GetEngine(db.DefaultContext).Count(&repo_model.RepoUnit{}) cnt, err := db.GetEngine(db.DefaultContext).Count(&repo_model.RepoUnit{})

View file

@ -29,11 +29,12 @@ func (opts mockListOptions) ToConds() builder.Cond {
func TestFind(t *testing.T) { func TestFind(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase()) require.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine() xe, err := unittest.GetXORMEngine()
require.NoError(t, err)
require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) require.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
var repoUnitCount int var repoUnitCount int
_, err := db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount) _, err = db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, repoUnitCount) assert.NotEmpty(t, repoUnitCount)

View file

@ -8,6 +8,7 @@ import (
"context" "context"
"fmt" "fmt"
"forgejo.org/models/db"
"forgejo.org/models/forgejo_migrations" "forgejo.org/models/forgejo_migrations"
"forgejo.org/models/migrations/v1_10" "forgejo.org/models/migrations/v1_10"
"forgejo.org/models/migrations/v1_11" "forgejo.org/models/migrations/v1_11"
@ -510,3 +511,12 @@ Please try upgrading to a lower version first (suggested v1.6.4), then upgrade t
// Execute Forgejo specific migrations. // Execute Forgejo specific migrations.
return forgejo_migrations.Migrate(x) return forgejo_migrations.Migrate(x)
} }
// WrapperMigrate is a wrapper for Migrate to be called in diagnostics
func WrapperMigrate(e db.Engine) error {
engine, err := db.GetMasterEngine(e)
if err != nil {
return err
}
return Migrate(engine)
}

View file

@ -175,7 +175,10 @@ func newXORMEngine() (*xorm.Engine, error) {
if err := db.InitEngine(context.Background()); err != nil { if err := db.InitEngine(context.Background()); err != nil {
return nil, err return nil, err
} }
x := unittest.GetXORMEngine() x, err := unittest.GetXORMEngine()
if err != nil {
return nil, err
}
return x, nil return x, nil
} }

View file

@ -22,11 +22,11 @@ import (
var fixturesLoader *testfixtures.Loader var fixturesLoader *testfixtures.Loader
// GetXORMEngine gets the XORM engine // GetXORMEngine gets the XORM engine
func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine, err error) {
if len(engine) == 1 { if len(engine) == 1 {
return engine[0] return engine[0], nil
} }
return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) return db.GetMasterEngine(db.DefaultContext.(*db.Context).Engine())
} }
func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() { func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() {
@ -41,7 +41,10 @@ func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() {
// InitFixtures initialize test fixtures for a test database // InitFixtures initialize test fixtures for a test database
func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
e := GetXORMEngine(engine...) e, err := GetXORMEngine(engine...)
if err != nil {
return err
}
var fixtureOptionFiles func(*testfixtures.Loader) error var fixtureOptionFiles func(*testfixtures.Loader) error
if opts.Dir != "" { if opts.Dir != "" {
fixtureOptionFiles = testfixtures.Directory(opts.Dir) fixtureOptionFiles = testfixtures.Directory(opts.Dir)
@ -93,10 +96,12 @@ func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
// LoadFixtures load fixtures for a test database // LoadFixtures load fixtures for a test database
func LoadFixtures(engine ...*xorm.Engine) error { func LoadFixtures(engine ...*xorm.Engine) error {
e := GetXORMEngine(engine...) e, err := GetXORMEngine(engine...)
var err error if err != nil {
return err
}
// (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times. // (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times.
for i := 0; i < 5; i++ { for range 5 {
if err = fixturesLoader.Load(); err == nil { if err = fixturesLoader.Load(); err == nil {
break break
} }

View file

@ -10,8 +10,13 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time" "time"
"forgejo.org/modules/log"
"xorm.io/xorm"
) )
var ( var (
@ -24,9 +29,19 @@ var (
EnableSQLite3 bool EnableSQLite3 bool
// Database holds the database settings // Database holds the database settings
Database = struct { Database = DatabaseSettings{
Timeout: 500,
IterateBufferSize: 50,
}
)
type DatabaseSettings struct {
Type DatabaseType Type DatabaseType
Host string Host string
HostPrimary string
HostReplica string
LoadBalancePolicy string
LoadBalanceWeights string
Name string Name string
User string User string
Passwd string Passwd string
@ -47,11 +62,7 @@ var (
IterateBufferSize int IterateBufferSize int
AutoMigration bool AutoMigration bool
SlowQueryThreshold time.Duration SlowQueryThreshold time.Duration
}{
Timeout: 500,
IterateBufferSize: 50,
} }
)
// LoadDBSetting loads the database settings // LoadDBSetting loads the database settings
func LoadDBSetting() { func LoadDBSetting() {
@ -63,6 +74,10 @@ func loadDBSetting(rootCfg ConfigProvider) {
Database.Type = DatabaseType(sec.Key("DB_TYPE").String()) Database.Type = DatabaseType(sec.Key("DB_TYPE").String())
Database.Host = sec.Key("HOST").String() Database.Host = sec.Key("HOST").String()
Database.HostPrimary = sec.Key("HOST_PRIMARY").String()
Database.HostReplica = sec.Key("HOST_REPLICA").String()
Database.LoadBalancePolicy = sec.Key("LOAD_BALANCE_POLICY").String()
Database.LoadBalanceWeights = sec.Key("LOAD_BALANCE_WEIGHTS").String()
Database.Name = sec.Key("NAME").String() Database.Name = sec.Key("NAME").String()
Database.User = sec.Key("USER").String() Database.User = sec.Key("USER").String()
if len(Database.Passwd) == 0 { if len(Database.Passwd) == 0 {
@ -99,8 +114,93 @@ func loadDBSetting(rootCfg ConfigProvider) {
} }
} }
// DBConnStr returns database connection string // DBMasterConnStr returns the connection string for the master (primary) database.
func DBConnStr() (string, error) { // If a primary host is defined in the configuration, it is used;
// otherwise, it falls back to Database.Host.
// Returns an error if no master host is provided but a slave is defined.
func DBMasterConnStr() (string, error) {
var host string
if Database.HostPrimary != "" {
host = Database.HostPrimary
} else {
host = Database.Host
}
if host == "" && Database.HostReplica != "" {
return "", errors.New("master host is not defined while slave is defined; cannot proceed")
}
// For SQLite, no host is needed
if host == "" && !Database.Type.IsSQLite3() {
return "", errors.New("no database host defined")
}
return dbConnStrWithHost(host)
}
// DBSlaveConnStrs returns one or more connection strings for the replica databases.
// If a replica host is defined (possibly as a comma-separated list) then those DSNs are returned.
// Otherwise, this function falls back to the master DSN (with a warning log).
func DBSlaveConnStrs() ([]string, error) {
var dsns []string
if Database.HostReplica != "" {
// support multiple replica hosts separated by commas
replicas := strings.SplitSeq(Database.HostReplica, ",")
for r := range replicas {
trimmed := strings.TrimSpace(r)
if trimmed == "" {
continue
}
dsn, err := dbConnStrWithHost(trimmed)
if err != nil {
return nil, err
}
dsns = append(dsns, dsn)
}
}
// Fall back to master if no slave DSN was provided.
if len(dsns) == 0 {
master, err := DBMasterConnStr()
if err != nil {
return nil, err
}
log.Debug("Database: No dedicated replica host defined; falling back to primary DSN for replica connections")
dsns = append(dsns, master)
}
return dsns, nil
}
func BuildLoadBalancePolicy(settings *DatabaseSettings, slaveEngines []*xorm.Engine) xorm.GroupPolicy {
var policy xorm.GroupPolicy
switch settings.LoadBalancePolicy { // Use the settings parameter directly
case "WeightRandom":
var weights []int
if settings.LoadBalanceWeights != "" { // Use the settings parameter directly
for part := range strings.SplitSeq(settings.LoadBalanceWeights, ",") {
w, err := strconv.Atoi(strings.TrimSpace(part))
if err != nil {
w = 1 // use a default weight if conversion fails
}
weights = append(weights, w)
}
}
// If no valid weights were provided, default each slave to weight 1
if len(weights) == 0 {
weights = make([]int, len(slaveEngines))
for i := range weights {
weights[i] = 1
}
}
policy = xorm.WeightRandomPolicy(weights)
case "RoundRobin":
policy = xorm.RoundRobinPolicy()
default:
policy = xorm.RandomPolicy()
}
return policy
}
// dbConnStrWithHost constructs the connection string, given a host value.
func dbConnStrWithHost(host string) (string, error) {
var connStr string var connStr string
paramSep := "?" paramSep := "?"
if strings.Contains(Database.Name, paramSep) { if strings.Contains(Database.Name, paramSep) {
@ -109,23 +209,25 @@ func DBConnStr() (string, error) {
switch Database.Type { switch Database.Type {
case "mysql": case "mysql":
connType := "tcp" connType := "tcp"
if len(Database.Host) > 0 && Database.Host[0] == '/' { // looks like a unix socket // if the host starts with '/' it is assumed to be a unix socket path
if len(host) > 0 && host[0] == '/' {
connType = "unix" connType = "unix"
} }
tls := Database.SSLMode tls := Database.SSLMode
if tls == "disable" { // allow (Postgres-inspired) default value to work in MySQL // allow the "disable" value (borrowed from Postgres defaults) to behave as false
if tls == "disable" {
tls = "false" tls = "false"
} }
connStr = fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", connStr = fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s",
Database.User, Database.Passwd, connType, Database.Host, Database.Name, paramSep, tls) Database.User, Database.Passwd, connType, host, Database.Name, paramSep, tls)
case "postgres": case "postgres":
connStr = getPostgreSQLConnectionString(Database.Host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) connStr = getPostgreSQLConnectionString(host, Database.User, Database.Passwd, Database.Name, Database.SSLMode)
case "sqlite3": case "sqlite3":
if !EnableSQLite3 { if !EnableSQLite3 {
return "", errors.New("this Gitea binary was not built with SQLite3 support") return "", errors.New("this Gitea binary was not built with SQLite3 support")
} }
if err := os.MkdirAll(filepath.Dir(Database.Path), os.ModePerm); err != nil { if err := os.MkdirAll(filepath.Dir(Database.Path), os.ModePerm); err != nil {
return "", fmt.Errorf("Failed to create directories: %w", err) return "", fmt.Errorf("failed to create directories: %w", err)
} }
journalMode := "" journalMode := ""
if Database.SQLiteJournalMode != "" { if Database.SQLiteJournalMode != "" {
@ -136,7 +238,6 @@ func DBConnStr() (string, error) {
default: default:
return "", fmt.Errorf("unknown database type: %s", Database.Type) return "", fmt.Errorf("unknown database type: %s", Database.Type)
} }
return connStr, nil return connStr, nil
} }

View file

@ -4,6 +4,7 @@
package setting package setting
import ( import (
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -107,3 +108,104 @@ func Test_getPostgreSQLConnectionString(t *testing.T) {
assert.Equal(t, test.Output, connStr) assert.Equal(t, test.Output, connStr)
} }
} }
func getPostgreSQLEngineGroupConnectionStrings(primaryHost, replicaHosts, user, passwd, name, sslmode string) (string, []string) {
// Determine the primary connection string.
primary := primaryHost
if strings.TrimSpace(primary) == "" {
primary = "127.0.0.1:5432"
}
primaryConn := getPostgreSQLConnectionString(primary, user, passwd, name, sslmode)
// Build the replica connection strings.
replicaConns := []string{}
if strings.TrimSpace(replicaHosts) != "" {
// Split comma-separated replica host values.
hosts := strings.Split(replicaHosts, ",")
for _, h := range hosts {
trimmed := strings.TrimSpace(h)
if trimmed != "" {
replicaConns = append(replicaConns,
getPostgreSQLConnectionString(trimmed, user, passwd, name, sslmode))
}
}
}
return primaryConn, replicaConns
}
func Test_getPostgreSQLEngineGroupConnectionStrings(t *testing.T) {
tests := []struct {
primaryHost string // primary host setting (e.g. "localhost" or "[::1]:1234")
replicaHosts string // comma-separated replica hosts (e.g. "replica1,replica2:2345")
user string
passwd string
name string
sslmode string
outputPrimary string
outputReplicas []string
}{
{
// No primary override (empty => default) and no replicas.
primaryHost: "",
replicaHosts: "",
user: "",
passwd: "",
name: "",
sslmode: "",
outputPrimary: "postgres://:@127.0.0.1:5432?sslmode=",
outputReplicas: []string{},
},
{
// Primary set and one replica.
primaryHost: "localhost",
replicaHosts: "replicahost",
user: "user",
passwd: "pass",
name: "gitea",
sslmode: "disable",
outputPrimary: "postgres://user:pass@localhost:5432/gitea?sslmode=disable",
outputReplicas: []string{"postgres://user:pass@replicahost:5432/gitea?sslmode=disable"},
},
{
// Primary with explicit port; multiple replicas (one without and one with an explicit port).
primaryHost: "localhost:5433",
replicaHosts: "replica1,replica2:5434",
user: "test",
passwd: "secret",
name: "db",
sslmode: "require",
outputPrimary: "postgres://test:secret@localhost:5433/db?sslmode=require",
outputReplicas: []string{
"postgres://test:secret@replica1:5432/db?sslmode=require",
"postgres://test:secret@replica2:5434/db?sslmode=require",
},
},
{
// IPv6 addresses for primary and replica.
primaryHost: "[::1]:1234",
replicaHosts: "[::2]:2345",
user: "ipv6",
passwd: "ipv6pass",
name: "ipv6db",
sslmode: "disable",
outputPrimary: "postgres://ipv6:ipv6pass@[::1]:1234/ipv6db?sslmode=disable",
outputReplicas: []string{
"postgres://ipv6:ipv6pass@[::2]:2345/ipv6db?sslmode=disable",
},
},
}
for _, test := range tests {
primary, replicas := getPostgreSQLEngineGroupConnectionStrings(
test.primaryHost,
test.replicaHosts,
test.user,
test.passwd,
test.name,
test.sslmode,
)
assert.Equal(t, test.outputPrimary, primary)
assert.Equal(t, test.outputReplicas, replicas)
}
}

View file

@ -364,6 +364,9 @@ var ignoredErrorMessage = []string{
// TestDatabaseCollation // TestDatabaseCollation
`[E] [Error SQL Query] INSERT INTO test_collation_tbl (txt) VALUES ('main') []`, `[E] [Error SQL Query] INSERT INTO test_collation_tbl (txt) VALUES ('main') []`,
// Test_CmdForgejo_Actions
`DB: No dedicated replica host defined; falling back to primary DSN for replica connections`,
// TestDevtestErrorpages // TestDevtestErrorpages
`ErrorPage() [E] Example error: Example error`, `ErrorPage() [E] Example error: Example error`,
} }

View file

@ -28,7 +28,7 @@ func InitDBEngine(ctx context.Context) (err error) {
default: default:
} }
log.Info("ORM engine initialization attempt #%d/%d...", i+1, setting.Database.DBConnectRetries) log.Info("ORM engine initialization attempt #%d/%d...", i+1, setting.Database.DBConnectRetries)
if err = db.InitEngineWithMigration(ctx, migrateWithSetting); err == nil { if err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error { return migrateWithSetting(eng.(*xorm.Engine)) }); err == nil {
break break
} else if i == setting.Database.DBConnectRetries-1 { } else if i == setting.Database.DBConnectRetries-1 {
return err return err

View file

@ -361,7 +361,8 @@ func SubmitInstall(ctx *context.Context) {
} }
// Init the engine with migration // Init the engine with migration
if err = db.InitEngineWithMigration(ctx, migrations.Migrate); err != nil { // Wrap migrations.Migrate into a function of type func(db.Engine) error to fix diagnostics.
if err = db.InitEngineWithMigration(ctx, migrations.WrapperMigrate); err != nil {
db.UnsetDefaultEngine() db.UnsetDefaultEngine()
ctx.Data["Err_DbSetting"] = true ctx.Data["Err_DbSetting"] = true
ctx.RenderWithErr(ctx.Tr("install.invalid_db_setting", err), tplInstall, &form) ctx.RenderWithErr(ctx.Tr("install.invalid_db_setting", err), tplInstall, &form)
@ -587,7 +588,7 @@ func SubmitInstall(ctx *context.Context) {
go func() { go func() {
// Sleep for a while to make sure the user's browser has loaded the post-install page and its assets (images, css, js) // Sleep for a while to make sure the user's browser has loaded the post-install page and its assets (images, css, js)
// What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Gitea in the future .... // What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Forgejo in the future ....
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
// Now get the http.Server from this request and shut it down // Now get the http.Server from this request and shut it down

View file

@ -78,7 +78,14 @@ func genericOrphanCheck(name, subject, refobject, joincond string) consistencyCh
func checkDBConsistency(ctx context.Context, logger log.Logger, autofix bool) error { func checkDBConsistency(ctx context.Context, logger log.Logger, autofix bool) error {
// make sure DB version is up-to-date // make sure DB version is up-to-date
if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { ensureUpToDateWrapper := func(e db.Engine) error {
engine, err := db.GetMasterEngine(e)
if err != nil {
return err
}
return migrations.EnsureUpToDate(engine)
}
if err := db.InitEngineWithMigration(ctx, ensureUpToDateWrapper); err != nil {
logger.Critical("Model version on the database does not match the current Gitea version. Model consistency will not be checked until the database is upgraded") logger.Critical("Model version on the database does not match the current Gitea version. Model consistency will not be checked until the database is upgraded")
return err return err
} }

View file

@ -9,11 +9,15 @@ import (
"forgejo.org/models/db" "forgejo.org/models/db"
"forgejo.org/models/migrations" "forgejo.org/models/migrations"
"forgejo.org/modules/log" "forgejo.org/modules/log"
"xorm.io/xorm"
) )
func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error { func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error {
logger.Info("Expected database version: %d", migrations.ExpectedDBVersion()) logger.Info("Expected database version: %d", migrations.ExpectedDBVersion())
if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { if err := db.InitEngineWithMigration(ctx, func(eng db.Engine) error {
return migrations.EnsureUpToDate(eng.(*xorm.Engine))
}); err != nil {
if !autofix { if !autofix {
logger.Critical("Error: %v during ensure up to date", err) logger.Critical("Error: %v during ensure up to date", err)
return err return err
@ -21,7 +25,9 @@ func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error
logger.Warn("Got Error: %v during ensure up to date", err) logger.Warn("Got Error: %v during ensure up to date", err)
logger.Warn("Attempting to migrate to the latest DB version to fix this.") logger.Warn("Attempting to migrate to the latest DB version to fix this.")
err = db.InitEngineWithMigration(ctx, migrations.Migrate) err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error {
return migrations.Migrate(eng.(*xorm.Engine))
})
if err != nil { if err != nil {
logger.Critical("Error: %v during migration", err) logger.Critical("Error: %v during migration", err)
} }

View file

@ -16,7 +16,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"xorm.io/xorm"
) )
type TestCollationTbl struct { type TestCollationTbl struct {
@ -48,11 +47,13 @@ func TestDatabaseCollationSelfCheckUI(t *testing.T) {
} }
func TestDatabaseCollation(t *testing.T) { func TestDatabaseCollation(t *testing.T) {
x := db.GetEngine(db.DefaultContext).(*xorm.Engine) engine, err := db.GetMasterEngine(db.GetEngine(db.DefaultContext))
require.NoError(t, err)
x := engine
// all created tables should use case-sensitive collation by default // all created tables should use case-sensitive collation by default
_, _ = x.Exec("DROP TABLE IF EXISTS test_collation_tbl") _, _ = x.Exec("DROP TABLE IF EXISTS test_collation_tbl")
err := x.Sync(&TestCollationTbl{}) err = x.Sync(&TestCollationTbl{})
require.NoError(t, err) require.NoError(t, err)
_, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('main')") _, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('main')")
_, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('Main')") // case-sensitive, so it inserts a new row _, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('Main')") // case-sensitive, so it inserts a new row

View file

@ -278,23 +278,36 @@ func doMigrationTest(t *testing.T, version string) {
setting.InitSQLLoggersForCli(log.INFO) setting.InitSQLLoggersForCli(log.INFO)
err := db.InitEngineWithMigration(t.Context(), wrappedMigrate) err := db.InitEngineWithMigration(t.Context(), func(e db.Engine) error {
engine, err := db.GetMasterEngine(e)
if err != nil {
return err
}
currentEngine = engine
return wrappedMigrate(engine)
})
require.NoError(t, err) require.NoError(t, err)
currentEngine.Close() currentEngine.Close()
beans, _ := db.NamesToBean() beans, _ := db.NamesToBean()
err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error {
currentEngine = x currentEngine, err = db.GetMasterEngine(e)
return migrate_base.RecreateTables(beans...)(x) if err != nil {
return err
}
return migrate_base.RecreateTables(beans...)(currentEngine)
}) })
require.NoError(t, err) require.NoError(t, err)
currentEngine.Close() currentEngine.Close()
// We do this a second time to ensure that there is not a problem with retained indices // We do this a second time to ensure that there is not a problem with retained indices
err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error {
currentEngine = x currentEngine, err = db.GetMasterEngine(e)
return migrate_base.RecreateTables(beans...)(x) if err != nil {
return err
}
return migrate_base.RecreateTables(beans...)(currentEngine)
}) })
require.NoError(t, err) require.NoError(t, err)