mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-05-31 11:52:10 +00:00
Allow detect whether it's in a database transaction for a context.Context (#21756)
Fix #19513 This PR introduce a new db method `InTransaction(context.Context)`, and also builtin check on `db.TxContext` and `db.WithTx`. There is also a new method `db.AutoTx` has been introduced but could be used by other PRs. `WithTx` will always open a new transaction, if a transaction exist in context, return an error. `AutoTx` will try to open a new transaction if no transaction exist in context. That means it will always enter a transaction if there is no error. Co-authored-by: delvh <dev.lh@web.de> Co-authored-by: 6543 <6543@obermui.de>
This commit is contained in:
parent
a0a425a13b
commit
34283a74e8
91 changed files with 252 additions and 176 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
|
@ -86,7 +87,11 @@ type Committer interface {
|
|||
}
|
||||
|
||||
// TxContext represents a transaction Context
|
||||
func TxContext() (*Context, Committer, error) {
|
||||
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
|
||||
if InTransaction(parentCtx) {
|
||||
return nil, nil, ErrAlreadyInTransaction
|
||||
}
|
||||
|
||||
sess := x.NewSession()
|
||||
if err := sess.Begin(); err != nil {
|
||||
sess.Close()
|
||||
|
@ -97,14 +102,24 @@ func TxContext() (*Context, Committer, error) {
|
|||
}
|
||||
|
||||
// WithTx represents executing database operations on a transaction
|
||||
// you can optionally change the context to a parent one
|
||||
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
|
||||
parentCtx := DefaultContext
|
||||
if len(stdCtx) != 0 && stdCtx[0] != nil {
|
||||
// TODO: make sure parent context has no open session
|
||||
parentCtx = stdCtx[0]
|
||||
// This function will always open a new transaction, if a transaction exist in parentCtx return an error.
|
||||
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
if InTransaction(parentCtx) {
|
||||
return ErrAlreadyInTransaction
|
||||
}
|
||||
return txWithNoCheck(parentCtx, f)
|
||||
}
|
||||
|
||||
// AutoTx represents executing database operations on a transaction, if the transaction exist,
|
||||
// this function will reuse it otherwise will create a new one and close it when finished.
|
||||
func AutoTx(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
if InTransaction(parentCtx) {
|
||||
return f(newContext(parentCtx, GetEngine(parentCtx), true))
|
||||
}
|
||||
return txWithNoCheck(parentCtx, f)
|
||||
}
|
||||
|
||||
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
sess := x.NewSession()
|
||||
defer sess.Close()
|
||||
if err := sess.Begin(); err != nil {
|
||||
|
@ -180,3 +195,28 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
|
|||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
// InTransaction returns true if the engine is in a transaction otherwise return false
|
||||
func InTransaction(ctx context.Context) bool {
|
||||
var e Engine
|
||||
if engined, ok := ctx.(Engined); ok {
|
||||
e = engined.Engine()
|
||||
} else {
|
||||
enginedInterface := ctx.Value(enginedContextKey)
|
||||
if enginedInterface != nil {
|
||||
e = enginedInterface.(Engined).Engine()
|
||||
}
|
||||
}
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch t := e.(type) {
|
||||
case *xorm.Engine:
|
||||
return false
|
||||
case *xorm.Session:
|
||||
return t.IsInTx()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
33
models/db/context_test.go
Normal file
33
models/db/context_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInTransaction(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.False(t, db.InTransaction(db.DefaultContext))
|
||||
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
assert.NoError(t, err)
|
||||
defer committer.Close()
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
}
|
|
@ -5,11 +5,14 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/modules/util"
|
||||
)
|
||||
|
||||
var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")
|
||||
|
||||
// ErrCancelled represents an error due to context cancellation
|
||||
type ErrCancelled struct {
|
||||
Message string
|
||||
|
|
|
@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
|
|||
assert.EqualValues(t, 62, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(func(ctx context.Context) error {
|
||||
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
|
@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
|
|||
assert.EqualValues(t, 73, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(func(ctx context.Context) error {
|
||||
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
|
@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
|
|||
assert.EqualValues(t, 2, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(func(ctx context.Context) error {
|
||||
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex)
|
||||
|
@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
|
|||
assert.EqualValues(t, 3, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(func(ctx context.Context) error {
|
||||
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 4, maxIndex)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue