Move almost all functions' parameter db.Engine to context.Context (#19748)

* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
This commit is contained in:
Lunny Xiao 2022-05-20 22:08:52 +08:00 committed by GitHub
parent d81e31ad78
commit fd7d83ace6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
232 changed files with 1463 additions and 2108 deletions

View file

@ -60,11 +60,6 @@ func (a *Attachment) DownloadURL() string {
return setting.AppURL + "attachments/" + url.PathEscape(a.UUID)
}
// GetAttachmentByID returns attachment by given id
func GetAttachmentByID(id int64) (*Attachment, error) {
return getAttachmentByID(db.GetEngine(db.DefaultContext), id)
}
// _____ __ __ .__ __
// / _ \_/ |__/ |______ ____ | |__ _____ ____ _____/ |_
// / /_\ \ __\ __\__ \ _/ ___\| | \ / \_/ __ \ / \ __\
@ -88,9 +83,10 @@ func (err ErrAttachmentNotExist) Error() string {
return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID)
}
func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) {
// GetAttachmentByID returns attachment by given id
func GetAttachmentByID(ctx context.Context, id int64) (*Attachment, error) {
attach := &Attachment{}
if has, err := e.ID(id).Get(attach); err != nil {
if has, err := db.GetEngine(ctx).ID(id).Get(attach); err != nil {
return nil, err
} else if !has {
return nil, ErrAttachmentNotExist{ID: id, UUID: ""}
@ -98,9 +94,10 @@ func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) {
return attach, nil
}
func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) {
// GetAttachmentByUUID returns attachment by given UUID.
func GetAttachmentByUUID(ctx context.Context, uuid string) (*Attachment, error) {
attach := &Attachment{}
has, err := e.Where("uuid=?", uuid).Get(attach)
has, err := db.GetEngine(ctx).Where("uuid=?", uuid).Get(attach)
if err != nil {
return nil, err
} else if !has {
@ -111,22 +108,13 @@ func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) {
// GetAttachmentsByUUIDs returns attachment by given UUID list.
func GetAttachmentsByUUIDs(ctx context.Context, uuids []string) ([]*Attachment, error) {
return getAttachmentsByUUIDs(db.GetEngine(ctx), uuids)
}
func getAttachmentsByUUIDs(e db.Engine, uuids []string) ([]*Attachment, error) {
if len(uuids) == 0 {
return []*Attachment{}, nil
}
// Silently drop invalid uuids.
attachments := make([]*Attachment, 0, len(uuids))
return attachments, e.In("uuid", uuids).Find(&attachments)
}
// GetAttachmentByUUID returns attachment by given UUID.
func GetAttachmentByUUID(uuid string) (*Attachment, error) {
return getAttachmentByUUID(db.GetEngine(db.DefaultContext), uuid)
return attachments, db.GetEngine(ctx).In("uuid", uuids).Find(&attachments)
}
// ExistAttachmentsByUUID returns true if attachment is exist by given UUID
@ -134,37 +122,22 @@ func ExistAttachmentsByUUID(uuid string) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("`uuid`=?", uuid).Exist(new(Attachment))
}
// GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName.
func GetAttachmentByReleaseIDFileName(releaseID int64, fileName string) (*Attachment, error) {
return getAttachmentByReleaseIDFileName(db.GetEngine(db.DefaultContext), releaseID, fileName)
}
// GetAttachmentsByIssueIDCtx returns all attachments of an issue.
func GetAttachmentsByIssueIDCtx(ctx context.Context, issueID int64) ([]*Attachment, error) {
// GetAttachmentsByIssueID returns all attachments of an issue.
func GetAttachmentsByIssueID(ctx context.Context, issueID int64) ([]*Attachment, error) {
attachments := make([]*Attachment, 0, 10)
return attachments, db.GetEngine(ctx).Where("issue_id = ? AND comment_id = 0", issueID).Find(&attachments)
}
// GetAttachmentsByIssueID returns all attachments of an issue.
func GetAttachmentsByIssueID(issueID int64) ([]*Attachment, error) {
return GetAttachmentsByIssueIDCtx(db.DefaultContext, issueID)
}
// GetAttachmentsByCommentID returns all attachments if comment by given ID.
func GetAttachmentsByCommentID(commentID int64) ([]*Attachment, error) {
return GetAttachmentsByCommentIDCtx(db.DefaultContext, commentID)
}
// GetAttachmentsByCommentIDCtx returns all attachments if comment by given ID.
func GetAttachmentsByCommentIDCtx(ctx context.Context, commentID int64) ([]*Attachment, error) {
func GetAttachmentsByCommentID(ctx context.Context, commentID int64) ([]*Attachment, error) {
attachments := make([]*Attachment, 0, 10)
return attachments, db.GetEngine(ctx).Where("comment_id=?", commentID).Find(&attachments)
}
// getAttachmentByReleaseIDFileName return a file based on the the following infos:
func getAttachmentByReleaseIDFileName(e db.Engine, releaseID int64, fileName string) (*Attachment, error) {
// GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName.
func GetAttachmentByReleaseIDFileName(ctx context.Context, releaseID int64, fileName string) (*Attachment, error) {
attach := &Attachment{ReleaseID: releaseID, Name: fileName}
has, err := e.Get(attach)
has, err := db.GetEngine(ctx).Get(attach)
if err != nil {
return nil, err
} else if !has {
@ -207,7 +180,7 @@ func DeleteAttachments(ctx context.Context, attachments []*Attachment, remove bo
// DeleteAttachmentsByIssue deletes all attachments associated with the given issue.
func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) {
attachments, err := GetAttachmentsByIssueID(issueID)
attachments, err := GetAttachmentsByIssueID(db.DefaultContext, issueID)
if err != nil {
return 0, err
}
@ -217,7 +190,7 @@ func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) {
// DeleteAttachmentsByComment deletes all attachments associated with the given comment.
func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) {
attachments, err := GetAttachmentsByCommentID(commentID)
attachments, err := GetAttachmentsByCommentID(db.DefaultContext, commentID)
if err != nil {
return 0, err
}
@ -225,11 +198,6 @@ func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) {
return DeleteAttachments(db.DefaultContext, attachments, remove)
}
// UpdateAttachment updates the given attachment in database
func UpdateAttachment(atta *Attachment) error {
return UpdateAttachmentCtx(db.DefaultContext, atta)
}
// UpdateAttachmentByUUID Updates attachment via uuid
func UpdateAttachmentByUUID(ctx context.Context, attach *Attachment, cols ...string) error {
if attach.UUID == "" {
@ -239,8 +207,8 @@ func UpdateAttachmentByUUID(ctx context.Context, attach *Attachment, cols ...str
return err
}
// UpdateAttachmentCtx updates the given attachment in database
func UpdateAttachmentCtx(ctx context.Context, atta *Attachment) error {
// UpdateAttachment updates the given attachment in database
func UpdateAttachment(ctx context.Context, atta *Attachment) error {
sess := db.GetEngine(ctx).Cols("name", "issue_id", "release_id", "comment_id", "download_count")
if atta.ID != 0 && atta.UUID == "" {
sess = sess.ID(atta.ID)

View file

@ -16,7 +16,7 @@ import (
func TestIncreaseDownloadCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
attachment, err := GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
attachment, err := GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
assert.NoError(t, err)
assert.Equal(t, int64(0), attachment.DownloadCount)
@ -24,7 +24,7 @@ func TestIncreaseDownloadCount(t *testing.T) {
err = attachment.IncreaseDownloadCount()
assert.NoError(t, err)
attachment, err = GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
attachment, err = GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
assert.NoError(t, err)
assert.Equal(t, int64(1), attachment.DownloadCount)
}
@ -33,11 +33,11 @@ func TestGetByCommentOrIssueID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
// count of attachments from issue ID
attachments, err := GetAttachmentsByIssueID(1)
attachments, err := GetAttachmentsByIssueID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Len(t, attachments, 1)
attachments, err = GetAttachmentsByCommentID(1)
attachments, err = GetAttachmentsByCommentID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Len(t, attachments, 2)
}
@ -56,7 +56,7 @@ func TestDeleteAttachments(t *testing.T) {
err = DeleteAttachment(&Attachment{ID: 8}, false)
assert.NoError(t, err)
attachment, err := GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18")
attachment, err := GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18")
assert.Error(t, err)
assert.True(t, IsErrAttachmentNotExist(err))
assert.Nil(t, attachment)
@ -65,7 +65,7 @@ func TestDeleteAttachments(t *testing.T) {
func TestGetAttachmentByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
attach, err := GetAttachmentByID(1)
attach, err := GetAttachmentByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attach.UUID)
}
@ -81,12 +81,12 @@ func TestAttachment_DownloadURL(t *testing.T) {
func TestUpdateAttachment(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
attach, err := GetAttachmentByID(1)
attach, err := GetAttachmentByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attach.UUID)
attach.Name = "new_name"
assert.NoError(t, UpdateAttachment(attach))
assert.NoError(t, UpdateAttachment(db.DefaultContext, attach))
unittest.AssertExistsAndLoadBean(t, &Attachment{Name: "new_name"})
}

View file

@ -5,6 +5,7 @@
package repo
import (
"context"
"fmt"
"image/png"
"io"
@ -25,11 +26,11 @@ func (repo *Repository) CustomAvatarRelativePath() string {
// RelAvatarLink returns a relative link to the repository's avatar.
func (repo *Repository) RelAvatarLink() string {
return repo.relAvatarLink(db.GetEngine(db.DefaultContext))
return repo.relAvatarLink(db.DefaultContext)
}
// generateRandomAvatar generates a random avatar for repository.
func generateRandomAvatar(e db.Engine, repo *Repository) error {
func generateRandomAvatar(ctx context.Context, repo *Repository) error {
idToString := fmt.Sprintf("%d", repo.ID)
seed := idToString
@ -51,14 +52,14 @@ func generateRandomAvatar(e db.Engine, repo *Repository) error {
log.Info("New random avatar created for repository: %d", repo.ID)
if _, err := e.ID(repo.ID).Cols("avatar").NoAutoTime().Update(repo); err != nil {
if _, err := db.GetEngine(ctx).ID(repo.ID).Cols("avatar").NoAutoTime().Update(repo); err != nil {
return err
}
return nil
}
func (repo *Repository) relAvatarLink(e db.Engine) string {
func (repo *Repository) relAvatarLink(ctx context.Context) string {
// If no avatar - path is empty
avatarPath := repo.CustomAvatarRelativePath()
if len(avatarPath) == 0 {
@ -66,7 +67,7 @@ func (repo *Repository) relAvatarLink(e db.Engine) string {
case "image":
return setting.RepoAvatar.FallbackImage
case "random":
if err := generateRandomAvatar(e, repo); err != nil {
if err := generateRandomAvatar(ctx, repo); err != nil {
log.Error("generateRandomAvatar: %v", err)
}
default:
@ -79,12 +80,12 @@ func (repo *Repository) relAvatarLink(e db.Engine) string {
// AvatarLink returns a link to the repository's avatar.
func (repo *Repository) AvatarLink() string {
return repo.avatarLink(db.GetEngine(db.DefaultContext))
return repo.avatarLink(db.DefaultContext)
}
// avatarLink returns user avatar absolute link.
func (repo *Repository) avatarLink(e db.Engine) string {
link := repo.relAvatarLink(e)
func (repo *Repository) avatarLink(ctx context.Context) string {
link := repo.relAvatarLink(ctx)
// we only prepend our AppURL to our known (relative, internal) avatar link to get an absolute URL
if strings.HasPrefix(link, "/") && !strings.HasPrefix(link, "//") {
return setting.AppURL + strings.TrimPrefix(link, setting.AppSubURL)[1:]

View file

@ -37,15 +37,14 @@ type Collaborator struct {
// GetCollaborators returns the collaborators for a repository
func GetCollaborators(ctx context.Context, repoID int64, listOptions db.ListOptions) ([]*Collaborator, error) {
e := db.GetEngine(ctx)
collaborations, err := getCollaborations(e, repoID, listOptions)
collaborations, err := getCollaborations(ctx, repoID, listOptions)
if err != nil {
return nil, fmt.Errorf("getCollaborations: %v", err)
}
collaborators := make([]*Collaborator, 0, len(collaborations))
for _, c := range collaborations {
user, err := user_model.GetUserByIDEngine(e, c.UserID)
user, err := user_model.GetUserByIDCtx(ctx, c.UserID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
log.Warn("Inconsistent DB: User: %d is listed as collaborator of %-v but does not exist", c.UserID, repoID)
@ -85,12 +84,14 @@ func IsCollaborator(ctx context.Context, repoID, userID int64) (bool, error) {
return db.GetEngine(ctx).Get(&Collaboration{RepoID: repoID, UserID: userID})
}
func getCollaborations(e db.Engine, repoID int64, listOptions db.ListOptions) ([]*Collaboration, error) {
func getCollaborations(ctx context.Context, repoID int64, listOptions db.ListOptions) ([]*Collaboration, error) {
if listOptions.Page == 0 {
collaborations := make([]*Collaboration, 0, 8)
return collaborations, e.Find(&collaborations, &Collaboration{RepoID: repoID})
return collaborations, db.GetEngine(ctx).Find(&collaborations, &Collaboration{RepoID: repoID})
}
e := db.GetEngine(ctx)
e = db.SetEnginePagination(e, &listOptions)
collaborations := make([]*Collaboration, 0, listOptions.PageSize)

View file

@ -10,16 +10,12 @@ import (
"code.gitea.io/gitea/models/db"
)
func getRepositoriesByForkID(e db.Engine, forkID int64) ([]*Repository, error) {
repos := make([]*Repository, 0, 10)
return repos, e.
Where("fork_id=?", forkID).
Find(&repos)
}
// GetRepositoriesByForkID returns all repositories with given fork ID.
func GetRepositoriesByForkID(ctx context.Context, forkID int64) ([]*Repository, error) {
return getRepositoriesByForkID(db.GetEngine(ctx), forkID)
repos := make([]*Repository, 0, 10)
return repos, db.GetEngine(ctx).
Where("fork_id=?", forkID).
Find(&repos)
}
// GetForkedRepo checks if given user has already forked a repository with given ID.

View file

@ -5,6 +5,7 @@
package repo
import (
"context"
"math"
"strings"
@ -66,22 +67,18 @@ func (stats LanguageStatList) getLanguagePercentages() map[string]float32 {
return langPerc
}
func getLanguageStats(e db.Engine, repo *Repository) (LanguageStatList, error) {
// GetLanguageStats returns the language statistics for a repository
func GetLanguageStats(ctx context.Context, repo *Repository) (LanguageStatList, error) {
stats := make(LanguageStatList, 0, 6)
if err := e.Where("`repo_id` = ?", repo.ID).Desc("`size`").Find(&stats); err != nil {
if err := db.GetEngine(ctx).Where("`repo_id` = ?", repo.ID).Desc("`size`").Find(&stats); err != nil {
return nil, err
}
return stats, nil
}
// GetLanguageStats returns the language statistics for a repository
func GetLanguageStats(repo *Repository) (LanguageStatList, error) {
return getLanguageStats(db.GetEngine(db.DefaultContext), repo)
}
// GetTopLanguageStats returns the top language statistics for a repository
func GetTopLanguageStats(repo *Repository, limit int) (LanguageStatList, error) {
stats, err := getLanguageStats(db.GetEngine(db.DefaultContext), repo)
stats, err := GetLanguageStats(db.DefaultContext, repo)
if err != nil {
return nil, err
}
@ -120,7 +117,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int
defer committer.Close()
sess := db.GetEngine(ctx)
oldstats, err := getLanguageStats(sess, repo)
oldstats, err := GetLanguageStats(ctx, repo)
if err != nil {
return err
}
@ -151,7 +148,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int
}
// Insert new language
if !upd {
if _, err := sess.Insert(&LanguageStat{
if err := db.Insert(ctx, &LanguageStat{
RepoID: repo.ID,
CommitID: commitID,
IsPrimary: llang == topLang,
@ -176,7 +173,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int
}
// Update indexer status
if err = updateIndexerStatus(sess, repo, RepoIndexerTypeStats, commitID); err != nil {
if err = UpdateIndexerStatus(ctx, repo, RepoIndexerTypeStats, commitID); err != nil {
return err
}
@ -190,10 +187,9 @@ func CopyLanguageStat(originalRepo, destRepo *Repository) error {
return err
}
defer committer.Close()
sess := db.GetEngine(ctx)
RepoLang := make(LanguageStatList, 0, 6)
if err := sess.Where("`repo_id` = ?", originalRepo.ID).Desc("`size`").Find(&RepoLang); err != nil {
if err := db.GetEngine(ctx).Where("`repo_id` = ?", originalRepo.ID).Desc("`size`").Find(&RepoLang); err != nil {
return err
}
if len(RepoLang) > 0 {
@ -204,10 +200,10 @@ func CopyLanguageStat(originalRepo, destRepo *Repository) error {
}
// update destRepo's indexer status
tmpCommitID := RepoLang[0].CommitID
if err := updateIndexerStatus(sess, destRepo, RepoIndexerTypeStats, tmpCommitID); err != nil {
if err := UpdateIndexerStatus(ctx, destRepo, RepoIndexerTypeStats, tmpCommitID); err != nil {
return err
}
if _, err := sess.Insert(&RepoLang); err != nil {
if err := db.Insert(ctx, &RepoLang); err != nil {
return err
}
}

View file

@ -14,8 +14,6 @@ import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/xorm"
)
// ErrMirrorNotExist mirror does not exist error
@ -56,21 +54,16 @@ func (m *Mirror) BeforeInsert() {
}
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (m *Mirror) AfterLoad(session *xorm.Session) {
if m == nil {
return
// GetRepository returns the repository.
func (m *Mirror) GetRepository() *Repository {
if m.Repo != nil {
return m.Repo
}
var err error
m.Repo, err = getRepositoryByID(session, m.RepoID)
m.Repo, err = GetRepositoryByIDCtx(db.DefaultContext, m.RepoID)
if err != nil {
log.Error("getRepositoryByID[%d]: %v", m.ID, err)
}
}
// GetRepository returns the repository.
func (m *Mirror) GetRepository() *Repository {
return m.Repo
}
@ -88,9 +81,10 @@ func (m *Mirror) ScheduleNextUpdate() {
}
}
func getMirrorByRepoID(e db.Engine, repoID int64) (*Mirror, error) {
// GetMirrorByRepoID returns mirror information of a repository.
func GetMirrorByRepoID(ctx context.Context, repoID int64) (*Mirror, error) {
m := &Mirror{RepoID: repoID}
has, err := e.Get(m)
has, err := db.GetEngine(ctx).Get(m)
if err != nil {
return nil, err
} else if !has {
@ -99,19 +93,10 @@ func getMirrorByRepoID(e db.Engine, repoID int64) (*Mirror, error) {
return m, nil
}
// GetMirrorByRepoID returns mirror information of a repository.
func GetMirrorByRepoID(repoID int64) (*Mirror, error) {
return getMirrorByRepoID(db.GetEngine(db.DefaultContext), repoID)
}
func updateMirror(e db.Engine, m *Mirror) error {
_, err := e.ID(m.ID).AllCols().Update(m)
return err
}
// UpdateMirror updates the mirror
func UpdateMirror(m *Mirror) error {
return updateMirror(db.GetEngine(db.DefaultContext), m)
func UpdateMirror(ctx context.Context, m *Mirror) error {
_, err := db.GetEngine(ctx).ID(m.ID).AllCols().Update(m)
return err
}
// TouchMirror updates the mirror updatedUnix
@ -146,7 +131,7 @@ func InsertMirror(mirror *Mirror) error {
// MirrorRepositoryList contains the mirror repositories
type MirrorRepositoryList []*Repository
func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error {
func (repos MirrorRepositoryList) loadAttributes(ctx context.Context) error {
if len(repos) == 0 {
return nil
}
@ -161,7 +146,7 @@ func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error {
repoIDs = append(repoIDs, repos[i].ID)
}
mirrors := make([]*Mirror, 0, len(repoIDs))
if err := e.
if err := db.GetEngine(ctx).
Where("id > 0").
In("repo_id", repoIDs).
Find(&mirrors); err != nil {
@ -174,11 +159,12 @@ func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error {
}
for i := range repos {
repos[i].Mirror = set[repos[i].ID]
repos[i].Mirror.Repo = repos[i]
}
return nil
}
// LoadAttributes loads the attributes for the given MirrorRepositoryList
func (repos MirrorRepositoryList) LoadAttributes() error {
return repos.loadAttributes(db.GetEngine(db.DefaultContext))
return repos.loadAttributes(db.DefaultContext)
}

View file

@ -11,8 +11,6 @@ import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/xorm"
)
// ErrPushMirrorNotExist mirror does not exist error
@ -35,21 +33,16 @@ func init() {
db.RegisterModel(new(PushMirror))
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (m *PushMirror) AfterLoad(session *xorm.Session) {
if m == nil {
return
// GetRepository returns the path of the repository.
func (m *PushMirror) GetRepository() *Repository {
if m.Repo != nil {
return m.Repo
}
var err error
m.Repo, err = getRepositoryByID(session, m.RepoID)
m.Repo, err = GetRepositoryByIDCtx(db.DefaultContext, m.RepoID)
if err != nil {
log.Error("getRepositoryByID[%d]: %v", m.ID, err)
}
}
// GetRepository returns the path of the repository.
func (m *PushMirror) GetRepository() *Repository {
return m.Repo
}

View file

@ -289,7 +289,7 @@ func (repo *Repository) LoadUnits(ctx context.Context) (err error) {
return nil
}
repo.Units, err = getUnitsByRepoID(db.GetEngine(ctx), repo.ID)
repo.Units, err = getUnitsByRepoID(ctx, repo.ID)
if log.IsTrace() {
unitTypeStrings := make([]string, len(repo.Units))
for i, unit := range repo.Units {
@ -383,7 +383,7 @@ func (repo *Repository) GetOwner(ctx context.Context) (err error) {
return nil
}
repo.Owner, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), repo.OwnerID)
repo.Owner, err = user_model.GetUserByIDCtx(ctx, repo.OwnerID)
return err
}
@ -454,15 +454,15 @@ func (repo *Repository) ComposeDocumentMetas() map[string]string {
// returns an error on failure (NOTE: no error is returned for
// non-fork repositories, and BaseRepo will be left untouched)
func (repo *Repository) GetBaseRepo() (err error) {
return repo.getBaseRepo(db.GetEngine(db.DefaultContext))
return repo.getBaseRepo(db.DefaultContext)
}
func (repo *Repository) getBaseRepo(e db.Engine) (err error) {
func (repo *Repository) getBaseRepo(ctx context.Context) (err error) {
if !repo.IsFork {
return nil
}
repo.BaseRepo, err = getRepositoryByID(e, repo.ForkID)
repo.BaseRepo, err = GetRepositoryByIDCtx(ctx, repo.ForkID)
return err
}
@ -481,16 +481,6 @@ func (repo *Repository) RepoPath() string {
return RepoPath(repo.OwnerName, repo.Name)
}
// GitConfigPath returns the path to a repository's git config/ directory
func GitConfigPath(repoPath string) string {
return filepath.Join(repoPath, "config")
}
// GitConfigPath returns the repository git config path
func (repo *Repository) GitConfigPath() string {
return GitConfigPath(repo.RepoPath())
}
// Link returns the repository link
func (repo *Repository) Link() string {
return setting.AppSubURL + "/" + url.PathEscape(repo.OwnerName) + "/" + url.PathEscape(repo.Name)
@ -669,9 +659,10 @@ func GetRepositoryByName(ownerID int64, name string) (*Repository, error) {
return repo, err
}
func getRepositoryByID(e db.Engine, id int64) (*Repository, error) {
// GetRepositoryByIDCtx returns the repository by given id if exists.
func GetRepositoryByIDCtx(ctx context.Context, id int64) (*Repository, error) {
repo := new(Repository)
has, err := e.ID(id).Get(repo)
has, err := db.GetEngine(ctx).ID(id).Get(repo)
if err != nil {
return nil, err
} else if !has {
@ -682,12 +673,7 @@ func getRepositoryByID(e db.Engine, id int64) (*Repository, error) {
// GetRepositoryByID returns the repository by given id if exists.
func GetRepositoryByID(id int64) (*Repository, error) {
return getRepositoryByID(db.GetEngine(db.DefaultContext), id)
}
// GetRepositoryByIDCtx returns the repository by given id if exists.
func GetRepositoryByIDCtx(ctx context.Context, id int64) (*Repository, error) {
return getRepositoryByID(db.GetEngine(ctx), id)
return GetRepositoryByIDCtx(db.DefaultContext, id)
}
// GetRepositoriesMapByIDs returns the repositories by given id slice.
@ -696,8 +682,8 @@ func GetRepositoriesMapByIDs(ids []int64) (map[int64]*Repository, error) {
return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos)
}
// IsRepositoryExistCtx returns true if the repository with given name under user has already existed.
func IsRepositoryExistCtx(ctx context.Context, u *user_model.User, repoName string) (bool, error) {
// IsRepositoryExist returns true if the repository with given name under user has already existed.
func IsRepositoryExist(ctx context.Context, u *user_model.User, repoName string) (bool, error) {
has, err := db.GetEngine(ctx).Get(&Repository{
OwnerID: u.ID,
LowerName: strings.ToLower(repoName),
@ -709,29 +695,20 @@ func IsRepositoryExistCtx(ctx context.Context, u *user_model.User, repoName stri
return has && isDir, err
}
// IsRepositoryExist returns true if the repository with given name under user has already existed.
func IsRepositoryExist(u *user_model.User, repoName string) (bool, error) {
return IsRepositoryExistCtx(db.DefaultContext, u, repoName)
}
// GetTemplateRepo populates repo.TemplateRepo for a generated repository and
// returns an error on failure (NOTE: no error is returned for
// non-generated repositories, and TemplateRepo will be left untouched)
func GetTemplateRepo(repo *Repository) (*Repository, error) {
return getTemplateRepo(db.GetEngine(db.DefaultContext), repo)
}
func getTemplateRepo(e db.Engine, repo *Repository) (*Repository, error) {
func GetTemplateRepo(ctx context.Context, repo *Repository) (*Repository, error) {
if !repo.IsGenerated() {
return nil, nil
}
return getRepositoryByID(e, repo.TemplateID)
return GetRepositoryByIDCtx(ctx, repo.TemplateID)
}
// TemplateRepo returns the repository, which is template of this repository
func (repo *Repository) TemplateRepo() *Repository {
repo, err := GetTemplateRepo(repo)
repo, err := GetTemplateRepo(db.DefaultContext, repo)
if err != nil {
log.Error("TemplateRepo: %v", err)
return nil
@ -739,26 +716,27 @@ func (repo *Repository) TemplateRepo() *Repository {
return repo
}
func countRepositories(userID int64, private bool) int64 {
sess := db.GetEngine(db.DefaultContext).Where("id > 0")
if userID > 0 {
sess.And("owner_id = ?", userID)
}
if !private {
sess.And("is_private=?", false)
}
count, err := sess.Count(new(Repository))
if err != nil {
log.Error("countRepositories: %v", err)
}
return count
type CountRepositoryOptions struct {
OwnerID int64
Private util.OptionalBool
}
// CountRepositories returns number of repositories.
// Argument private only takes effect when it is false,
// set it true to count all repositories.
func CountRepositories(private bool) int64 {
return countRepositories(-1, private)
func CountRepositories(ctx context.Context, opts CountRepositoryOptions) (int64, error) {
sess := db.GetEngine(ctx).Where("id > 0")
if opts.OwnerID > 0 {
sess.And("owner_id = ?", opts.OwnerID)
}
if !opts.Private.IsNone() {
sess.And("is_private=?", opts.Private.IsTrue())
}
count, err := sess.Count(new(Repository))
if err != nil {
return 0, fmt.Errorf("countRepositories: %v", err)
}
return count, nil
}

View file

@ -5,6 +5,7 @@
package repo
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
@ -62,8 +63,8 @@ func GetUnindexedRepos(indexerType RepoIndexerType, maxRepoID int64, page, pageS
return ids, err
}
// getIndexerStatus loads repo codes indxer status
func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) {
// GetIndexerStatus loads repo codes indxer status
func GetIndexerStatus(ctx context.Context, repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) {
switch indexerType {
case RepoIndexerTypeCode:
if repo.CodeIndexerStatus != nil {
@ -75,7 +76,7 @@ func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType
}
}
status := &RepoIndexerStatus{RepoID: repo.ID}
if has, err := e.Where("`indexer_type` = ?", indexerType).Get(status); err != nil {
if has, err := db.GetEngine(ctx).Where("`indexer_type` = ?", indexerType).Get(status); err != nil {
return nil, err
} else if !has {
status.IndexerType = indexerType
@ -90,36 +91,25 @@ func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType
return status, nil
}
// GetIndexerStatus loads repo codes indxer status
func GetIndexerStatus(repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) {
return getIndexerStatus(db.GetEngine(db.DefaultContext), repo, indexerType)
}
// updateIndexerStatus updates indexer status
func updateIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType, sha string) error {
status, err := getIndexerStatus(e, repo, indexerType)
// UpdateIndexerStatus updates indexer status
func UpdateIndexerStatus(ctx context.Context, repo *Repository, indexerType RepoIndexerType, sha string) error {
status, err := GetIndexerStatus(ctx, repo, indexerType)
if err != nil {
return fmt.Errorf("UpdateIndexerStatus: Unable to getIndexerStatus for repo: %s Error: %v", repo.FullName(), err)
}
if len(status.CommitSha) == 0 {
status.CommitSha = sha
_, err := e.Insert(status)
if err != nil {
if err := db.Insert(ctx, status); err != nil {
return fmt.Errorf("UpdateIndexerStatus: Unable to insert repoIndexerStatus for repo: %s Sha: %s Error: %v", repo.FullName(), sha, err)
}
return nil
}
status.CommitSha = sha
_, err = e.ID(status.ID).Cols("commit_sha").
_, err = db.GetEngine(ctx).ID(status.ID).Cols("commit_sha").
Update(status)
if err != nil {
return fmt.Errorf("UpdateIndexerStatus: Unable to update repoIndexerStatus for repo: %s Sha: %s Error: %v", repo.FullName(), sha, err)
}
return nil
}
// UpdateIndexerStatus updates indexer status
func UpdateIndexerStatus(repo *Repository, indexerType RepoIndexerType, sha string) error {
return updateIndexerStatus(db.GetEngine(db.DefaultContext), repo, indexerType, sha)
}

View file

@ -22,9 +22,10 @@ func GetUserMirrorRepositories(userID int64) ([]*Repository, error) {
func IterateRepository(f func(repo *Repository) error) error {
var start int
batchSize := setting.Database.IterateBufferSize
sess := db.GetEngine(db.DefaultContext)
for {
repos := make([]*Repository, 0, batchSize)
if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&repos); err != nil {
if err := sess.Limit(batchSize, start).Find(&repos); err != nil {
return err
}
if len(repos) == 0 {

View file

@ -9,17 +9,24 @@ import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/util"
"github.com/stretchr/testify/assert"
)
var (
countRepospts = CountRepositoryOptions{OwnerID: 10}
countReposptsPublic = CountRepositoryOptions{OwnerID: 10, Private: util.OptionalBoolFalse}
countReposptsPrivate = CountRepositoryOptions{OwnerID: 10, Private: util.OptionalBoolTrue}
)
func TestGetRepositoryCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
count, err1 := GetRepositoryCount(db.DefaultContext, 10)
privateCount, err2 := GetPrivateRepositoryCount(&user_model.User{ID: int64(10)})
publicCount, err3 := GetPublicRepositoryCount(&user_model.User{ID: int64(10)})
ctx := db.DefaultContext
count, err1 := CountRepositories(ctx, countRepospts)
privateCount, err2 := CountRepositories(ctx, countReposptsPrivate)
publicCount, err3 := CountRepositories(ctx, countReposptsPublic)
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NoError(t, err3)
@ -30,7 +37,7 @@ func TestGetRepositoryCount(t *testing.T) {
func TestGetPublicRepositoryCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
count, err := GetPublicRepositoryCount(&user_model.User{ID: int64(10)})
count, err := CountRepositories(db.DefaultContext, countReposptsPublic)
assert.NoError(t, err)
assert.Equal(t, int64(1), count)
}
@ -38,7 +45,7 @@ func TestGetPublicRepositoryCount(t *testing.T) {
func TestGetPrivateRepositoryCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
count, err := GetPrivateRepositoryCount(&user_model.User{ID: int64(10)})
count, err := CountRepositories(db.DefaultContext, countReposptsPrivate)
assert.NoError(t, err)
assert.Equal(t, int64(2), count)
}

View file

@ -5,6 +5,7 @@
package repo
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
@ -206,9 +207,9 @@ func (r *RepoUnit) ExternalTrackerConfig() *ExternalTrackerConfig {
return r.Config.(*ExternalTrackerConfig)
}
func getUnitsByRepoID(e db.Engine, repoID int64) (units []*RepoUnit, err error) {
func getUnitsByRepoID(ctx context.Context, repoID int64) (units []*RepoUnit, err error) {
var tmpUnits []*RepoUnit
if err := e.Where("repo_id = ?", repoID).Find(&tmpUnits); err != nil {
if err := db.GetEngine(ctx).Where("repo_id = ?", repoID).Find(&tmpUnits); err != nil {
return nil, err
}

View file

@ -5,6 +5,8 @@
package repo
import (
"context"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/timeutil"
@ -29,7 +31,7 @@ func StarRepo(userID, repoID int64, star bool) error {
return err
}
defer committer.Close()
staring := isStaring(db.GetEngine(ctx), userID, repoID)
staring := IsStaring(ctx, userID, repoID)
if star {
if staring {
@ -65,12 +67,8 @@ func StarRepo(userID, repoID int64, star bool) error {
}
// IsStaring checks if user has starred given repository.
func IsStaring(userID, repoID int64) bool {
return isStaring(db.GetEngine(db.DefaultContext), userID, repoID)
}
func isStaring(e db.Engine, userID, repoID int64) bool {
has, _ := e.Get(&Star{UID: userID, RepoID: repoID})
func IsStaring(ctx context.Context, userID, repoID int64) bool {
has, _ := db.GetEngine(ctx).Get(&Star{UID: userID, RepoID: repoID})
return has
}

View file

@ -28,8 +28,8 @@ func TestStarRepo(t *testing.T) {
func TestIsStaring(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.True(t, IsStaring(2, 4))
assert.False(t, IsStaring(3, 4))
assert.True(t, IsStaring(db.DefaultContext, 2, 4))
assert.False(t, IsStaring(db.DefaultContext, 3, 4))
}
func TestRepository_GetStargazers(t *testing.T) {

View file

@ -99,8 +99,9 @@ func GetTopicByName(name string) (*Topic, error) {
// addTopicByNameToRepo adds a topic name to a repo and increments the topic count.
// Returns topic after the addition
func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic, error) {
func addTopicByNameToRepo(ctx context.Context, repoID int64, topicName string) (*Topic, error) {
var topic Topic
e := db.GetEngine(ctx)
has, err := e.Where("name = ?", topicName).Get(&topic)
if err != nil {
return nil, err
@ -108,7 +109,7 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic,
if !has {
topic.Name = topicName
topic.RepoCount = 1
if _, err := e.Insert(&topic); err != nil {
if err := db.Insert(ctx, &topic); err != nil {
return nil, err
}
} else {
@ -118,7 +119,7 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic,
}
}
if _, err := e.Insert(&RepoTopic{
if err := db.Insert(ctx, &RepoTopic{
RepoID: repoID,
TopicID: topic.ID,
}); err != nil {
@ -129,8 +130,9 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic,
}
// removeTopicFromRepo remove a topic from a repo and decrements the topic repo count
func removeTopicFromRepo(e db.Engine, repoID int64, topic *Topic) error {
func removeTopicFromRepo(ctx context.Context, repoID int64, topic *Topic) error {
topic.RepoCount--
e := db.GetEngine(ctx)
if _, err := e.ID(topic.ID).Cols("repo_count").Update(topic); err != nil {
return err
}
@ -208,17 +210,13 @@ func CountTopics(opts *FindTopicOptions) (int64, error) {
}
// GetRepoTopicByName retrieves topic from name for a repo if it exist
func GetRepoTopicByName(repoID int64, topicName string) (*Topic, error) {
return getRepoTopicByName(db.GetEngine(db.DefaultContext), repoID, topicName)
}
func getRepoTopicByName(e db.Engine, repoID int64, topicName string) (*Topic, error) {
func GetRepoTopicByName(ctx context.Context, repoID int64, topicName string) (*Topic, error) {
cond := builder.NewCond()
var topic Topic
cond = cond.And(builder.Eq{"repo_topic.repo_id": repoID}).And(builder.Eq{"topic.name": topicName})
sess := e.Table("topic").Where(cond)
sess := db.GetEngine(ctx).Table("topic").Where(cond)
sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id")
has, err := sess.Get(&topic)
has, err := sess.Select("topic.*").Get(&topic)
if has {
return &topic, err
}
@ -234,7 +232,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) {
defer committer.Close()
sess := db.GetEngine(ctx)
topic, err := getRepoTopicByName(sess, repoID, topicName)
topic, err := GetRepoTopicByName(ctx, repoID, topicName)
if err != nil {
return nil, err
}
@ -243,7 +241,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) {
return topic, nil
}
topic, err = addTopicByNameToRepo(sess, repoID, topicName)
topic, err = addTopicByNameToRepo(ctx, repoID, topicName)
if err != nil {
return nil, err
}
@ -266,7 +264,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) {
// DeleteTopic removes a topic name from a repository (if it has it)
func DeleteTopic(repoID int64, topicName string) (*Topic, error) {
topic, err := GetRepoTopicByName(repoID, topicName)
topic, err := GetRepoTopicByName(db.DefaultContext, repoID, topicName)
if err != nil {
return nil, err
}
@ -275,7 +273,7 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) {
return nil, nil
}
err = removeTopicFromRepo(db.GetEngine(db.DefaultContext), repoID, topic)
err = removeTopicFromRepo(db.DefaultContext, repoID, topic)
return topic, err
}
@ -329,14 +327,14 @@ func SaveTopics(repoID int64, topicNames ...string) error {
}
for _, topicName := range addedTopicNames {
_, err := addTopicByNameToRepo(sess, repoID, topicName)
_, err := addTopicByNameToRepo(ctx, repoID, topicName)
if err != nil {
return err
}
}
for _, topic := range removeTopics {
err := removeTopicFromRepo(sess, repoID, topic)
err := removeTopicFromRepo(ctx, repoID, topic)
if err != nil {
return err
}
@ -361,7 +359,7 @@ func SaveTopics(repoID int64, topicNames ...string) error {
// GenerateTopics generates topics from a template repository
func GenerateTopics(ctx context.Context, templateRepo, generateRepo *Repository) error {
for _, topic := range templateRepo.Topics {
if _, err := addTopicByNameToRepo(db.GetEngine(ctx), generateRepo.ID, topic); err != nil {
if _, err := addTopicByNameToRepo(ctx, generateRepo.ID, topic); err != nil {
return err
}
}

View file

@ -42,17 +42,12 @@ func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error {
return err
}
// UpdateRepositoryColsCtx updates repository's columns
func UpdateRepositoryColsCtx(ctx context.Context, repo *Repository, cols ...string) error {
// UpdateRepositoryCols updates repository's columns
func UpdateRepositoryCols(ctx context.Context, repo *Repository, cols ...string) error {
_, err := db.GetEngine(ctx).ID(repo.ID).Cols(cols...).Update(repo)
return err
}
// UpdateRepositoryCols updates repository's columns
func UpdateRepositoryCols(repo *Repository, cols ...string) error {
return UpdateRepositoryColsCtx(db.DefaultContext, repo, cols...)
}
// ErrReachLimitOfRepo represents a "ReachLimitOfRepo" kind of error.
type ErrReachLimitOfRepo struct {
Limit int
@ -110,7 +105,7 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo
return err
}
has, err := IsRepositoryExist(u, name)
has, err := IsRepositoryExist(db.DefaultContext, u, name)
if err != nil {
return fmt.Errorf("IsRepositoryExist: %v", err)
} else if has {
@ -141,7 +136,7 @@ func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName s
return err
}
has, err := IsRepositoryExist(repo.Owner, newRepoName)
has, err := IsRepositoryExist(db.DefaultContext, repo.Owner, newRepoName)
if err != nil {
return fmt.Errorf("IsRepositoryExist: %v", err)
} else if has {

View file

@ -5,10 +5,7 @@
package repo
import (
"context"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
)
// GetStarredRepos returns the repos starred by a particular user
@ -51,37 +48,3 @@ func GetWatchedRepos(userID int64, private bool, listOptions db.ListOptions) ([]
total, err := sess.FindAndCount(&repos)
return repos, total, err
}
// CountUserRepositories returns number of repositories user owns.
// Argument private only takes effect when it is false,
// set it true to count all repositories.
func CountUserRepositories(userID int64, private bool) int64 {
return countRepositories(userID, private)
}
func getRepositoryCount(e db.Engine, ownerID int64) (int64, error) {
return e.Count(&Repository{OwnerID: ownerID})
}
func getPublicRepositoryCount(e db.Engine, u *user_model.User) (int64, error) {
return e.Where("is_private = ?", false).Count(&Repository{OwnerID: u.ID})
}
func getPrivateRepositoryCount(e db.Engine, u *user_model.User) (int64, error) {
return e.Where("is_private = ?", true).Count(&Repository{OwnerID: u.ID})
}
// GetRepositoryCount returns the total number of repositories of user.
func GetRepositoryCount(ctx context.Context, ownerID int64) (int64, error) {
return getRepositoryCount(db.GetEngine(ctx), ownerID)
}
// GetPublicRepositoryCount returns the total number of public repositories of user.
func GetPublicRepositoryCount(u *user_model.User) (int64, error) {
return getPublicRepositoryCount(db.GetEngine(db.DefaultContext), u)
}
// GetPrivateRepositoryCount returns the total number of private repositories of user.
func GetPrivateRepositoryCount(u *user_model.User) (int64, error) {
return getPrivateRepositoryCount(db.GetEngine(db.DefaultContext), u)
}

View file

@ -116,8 +116,8 @@ func WatchRepoMode(userID, repoID int64, mode WatchMode) (err error) {
return watchRepoMode(db.DefaultContext, watch, mode)
}
// WatchRepoCtx watch or unwatch repository.
func WatchRepoCtx(ctx context.Context, userID, repoID int64, doWatch bool) (err error) {
// WatchRepo watch or unwatch repository.
func WatchRepo(ctx context.Context, userID, repoID int64, doWatch bool) (err error) {
var watch Watch
if watch, err = GetWatch(ctx, userID, repoID); err != nil {
return err
@ -132,11 +132,6 @@ func WatchRepoCtx(ctx context.Context, userID, repoID int64, doWatch bool) (err
return err
}
// WatchRepo watch or unwatch repository.
func WatchRepo(userID, repoID int64, watch bool) (err error) {
return WatchRepoCtx(db.DefaultContext, userID, repoID, watch)
}
// GetWatchers returns all watchers of given repository.
func GetWatchers(ctx context.Context, repoID int64) ([]*Watch, error) {
watches := make([]*Watch, 0, 10)
@ -176,7 +171,8 @@ func GetRepoWatchers(repoID int64, opts db.ListOptions) ([]*user_model.User, err
return users, sess.Find(&users)
}
func watchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error {
// WatchIfAuto subscribes to repo if AutoWatchOnChanges is set
func WatchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error {
if !isWrite || !setting.Service.AutoWatchOnChanges {
return nil
}
@ -189,8 +185,3 @@ func watchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error
}
return watchRepoMode(ctx, watch, WatchModeAuto)
}
// WatchIfAuto subscribes to repo if AutoWatchOnChanges is set
func WatchIfAuto(userID, repoID int64, isWrite bool) error {
return watchIfAuto(db.DefaultContext, userID, repoID, isWrite)
}

View file

@ -73,13 +73,13 @@ func TestWatchIfAuto(t *testing.T) {
prevCount := repo.NumWatches
// Must not add watch
assert.NoError(t, WatchIfAuto(8, 1, true))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 8, 1, true))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)
// Should not add watch
assert.NoError(t, WatchIfAuto(10, 1, true))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 10, 1, true))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)
@ -87,31 +87,31 @@ func TestWatchIfAuto(t *testing.T) {
setting.Service.AutoWatchOnChanges = true
// Must not add watch
assert.NoError(t, WatchIfAuto(8, 1, true))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 8, 1, true))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)
// Should not add watch
assert.NoError(t, WatchIfAuto(12, 1, false))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, false))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)
// Should add watch
assert.NoError(t, WatchIfAuto(12, 1, true))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, true))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount+1)
// Should remove watch, inhibit from adding auto
assert.NoError(t, WatchRepo(12, 1, false))
assert.NoError(t, WatchRepo(db.DefaultContext, 12, 1, false))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)
// Must not add watch
assert.NoError(t, WatchIfAuto(12, 1, true))
assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, true))
watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1})
assert.NoError(t, err)
assert.Len(t, watchers, prevCount)