More refactoring of db.DefaultContext (#27083)

Next step of #27065
This commit is contained in:
JakobDev 2023-09-15 08:13:19 +02:00 committed by GitHub
parent f8a1094406
commit c548dde205
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
83 changed files with 336 additions and 320 deletions

View file

@ -5,6 +5,7 @@
package auth
import (
"context"
"crypto/subtle"
"encoding/hex"
"fmt"
@ -95,7 +96,7 @@ func init() {
}
// NewAccessToken creates new access token.
func NewAccessToken(t *AccessToken) error {
func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt, err := util.CryptoRandomString(10)
if err != nil {
return err
@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
_, err = db.GetEngine(db.DefaultContext).Insert(t)
_, err = db.GetEngine(ctx).Insert(t)
return err
}
@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
}
// GetAccessTokenBySHA returns access token by given token value
func GetAccessTokenBySHA(token string) (*AccessToken, error) {
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}
var tokens []AccessToken
err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}
// AccessTokenByNameExists checks if a token name has been used already by a user.
func AccessTokenByNameExists(token *AccessToken) (bool, error) {
return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}
// ListAccessTokensOptions contain filter options
@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
}
// ListAccessTokens returns a list of access tokens belongs to given user.
func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
@ -222,14 +223,14 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
}
// UpdateAccessToken updates information of access token.
func UpdateAccessToken(t *AccessToken) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// CountAccessTokens count access tokens belongs to given user by options
func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
}
@ -237,8 +238,8 @@ func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
}
// DeleteAccessTokenByID deletes access token by given ID.
func DeleteAccessTokenByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {

View file

@ -7,6 +7,7 @@ import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
@ -18,7 +19,7 @@ func TestNewAccessToken(t *testing.T) {
UID: 3,
Name: "Token C",
}
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
invalidToken := &auth_model.AccessToken{
@ -26,7 +27,7 @@ func TestNewAccessToken(t *testing.T) {
UID: 2,
Name: "Token F",
}
assert.Error(t, auth_model.NewAccessToken(invalidToken))
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
}
func TestAccessTokenByNameExists(t *testing.T) {
@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
}
// Check to make sure it doesn't exists already
exist, err := auth_model.AccessTokenByNameExists(token)
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.False(t, exist)
// Save it to the database
assert.NoError(t, auth_model.NewAccessToken(token))
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
// This token must be found by name in the DB now
exist, err = auth_model.AccessTokenByNameExists(token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.True(t, exist)
@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {
// Name matches but different user ID, this shouldn't exists in the
// database
exist, err = auth_model.AccessTokenByNameExists(user4Token)
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}
func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)
_, err = auth_model.GetAccessTokenBySHA("notahash")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
_, err = auth_model.GetAccessTokenBySHA("")
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}
func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}
tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}
func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"
assert.NoError(t, auth_model.UpdateAccessToken(token))
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
}
func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
unittest.AssertNotExistsBean(t, token)
err = auth_model.DeleteAccessTokenByID(100, 100)
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}

View file

@ -4,6 +4,7 @@
package auth
import (
"context"
"crypto/md5"
"crypto/subtle"
"encoding/base32"
@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
}
// NewTwoFactor creates a new two-factor authentication token.
func NewTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).Insert(t)
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}
// UpdateTwoFactor updates a two-factor authentication token.
func UpdateTwoFactor(t *TwoFactor) error {
_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
// HasTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func HasTwoFactorByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
func DeleteTwoFactorByID(id, userID int64) error {
cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {