1
0
Fork 0
forked from forgejo/forgejo

DBContext is just a Context (#17100)

* DBContext is just a Context

This PR removes some of the specialness from the DBContext and makes it context
This allows us to simplify the GetEngine code to wrap around any context in future
and means that we can change our loadRepo(e Engine) functions to simply take contexts.

Signed-off-by: Andrew Thornton <art27@cantab.net>

* fix unit tests

Signed-off-by: Andrew Thornton <art27@cantab.net>

* another place that needs to set the initial context

Signed-off-by: Andrew Thornton <art27@cantab.net>

* avoid race

Signed-off-by: Andrew Thornton <art27@cantab.net>

* change attachment error

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
zeripath 2021-09-23 16:45:36 +01:00 committed by GitHub
parent b22be7f594
commit 9302eba971
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
129 changed files with 1112 additions and 1022 deletions

View file

@ -122,7 +122,7 @@ func (pr *PullRequest) loadAttributes(e db.Engine) (err error) {
// LoadAttributes loads pull request attributes from database
func (pr *PullRequest) LoadAttributes() error {
return pr.loadAttributes(db.DefaultContext().Engine())
return pr.loadAttributes(db.GetEngine(db.DefaultContext))
}
func (pr *PullRequest) loadHeadRepo(e db.Engine) (err error) {
@ -148,12 +148,12 @@ func (pr *PullRequest) loadHeadRepo(e db.Engine) (err error) {
// LoadHeadRepo loads the head repository
func (pr *PullRequest) LoadHeadRepo() error {
return pr.loadHeadRepo(db.DefaultContext().Engine())
return pr.loadHeadRepo(db.GetEngine(db.DefaultContext))
}
// LoadBaseRepo loads the target repository
func (pr *PullRequest) LoadBaseRepo() error {
return pr.loadBaseRepo(db.DefaultContext().Engine())
return pr.loadBaseRepo(db.GetEngine(db.DefaultContext))
}
func (pr *PullRequest) loadBaseRepo(e db.Engine) (err error) {
@ -180,7 +180,7 @@ func (pr *PullRequest) loadBaseRepo(e db.Engine) (err error) {
// LoadIssue loads issue information from database
func (pr *PullRequest) LoadIssue() (err error) {
return pr.loadIssue(db.DefaultContext().Engine())
return pr.loadIssue(db.GetEngine(db.DefaultContext))
}
func (pr *PullRequest) loadIssue(e db.Engine) (err error) {
@ -197,7 +197,7 @@ func (pr *PullRequest) loadIssue(e db.Engine) (err error) {
// LoadProtectedBranch loads the protected branch of the base branch
func (pr *PullRequest) LoadProtectedBranch() (err error) {
return pr.loadProtectedBranch(db.DefaultContext().Engine())
return pr.loadProtectedBranch(db.GetEngine(db.DefaultContext))
}
func (pr *PullRequest) loadProtectedBranch(e db.Engine) (err error) {
@ -257,7 +257,7 @@ type ReviewCount struct {
// GetApprovalCounts returns the approval counts by type
// FIXME: Only returns official counts due to double counting of non-official counts
func (pr *PullRequest) GetApprovalCounts() ([]*ReviewCount, error) {
return pr.getApprovalCounts(db.DefaultContext().Engine())
return pr.getApprovalCounts(db.GetEngine(db.DefaultContext))
}
func (pr *PullRequest) getApprovalCounts(e db.Engine) ([]*ReviewCount, error) {
@ -284,7 +284,7 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error {
return nil
}
sess := db.DefaultContext().NewSession()
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@ -393,7 +393,7 @@ func (pr *PullRequest) SetMerged() (bool, error) {
pr.HasMerged = true
sess := db.DefaultContext().NewSession()
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err := sess.Begin(); err != nil {
return false, err
@ -455,7 +455,7 @@ func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []st
issue.Index = idx
sess := db.DefaultContext().NewSession()
sess := db.NewSession(db.DefaultContext)
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@ -492,7 +492,7 @@ func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []st
// by given head/base and repo/branch.
func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string, flow PullRequestFlow) (*PullRequest, error) {
pr := new(PullRequest)
has, err := db.DefaultContext().Engine().
has, err := db.GetEngine(db.DefaultContext).
Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND flow = ? AND issue.is_closed=?",
headRepoID, headBranch, baseRepoID, baseBranch, false, flow, false).
Join("INNER", "issue", "issue.id=pull_request.issue_id").
@ -510,7 +510,7 @@ func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch
// by given head information (repo and branch).
func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) {
pr := new(PullRequest)
has, err := db.DefaultContext().Engine().
has, err := db.GetEngine(db.DefaultContext).
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
OrderBy("id DESC").
Get(pr)
@ -527,7 +527,7 @@ func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) {
Index: index,
}
has, err := db.DefaultContext().Engine().Get(pr)
has, err := db.GetEngine(db.DefaultContext).Get(pr)
if err != nil {
return nil, err
} else if !has {
@ -557,13 +557,13 @@ func getPullRequestByID(e db.Engine, id int64) (*PullRequest, error) {
// GetPullRequestByID returns a pull request by given ID.
func GetPullRequestByID(id int64) (*PullRequest, error) {
return getPullRequestByID(db.DefaultContext().Engine(), id)
return getPullRequestByID(db.GetEngine(db.DefaultContext), id)
}
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) {
var pr PullRequest
has, err := db.DefaultContext().Engine().Where("issue_id = ?", issueID).Get(&pr)
has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", issueID).Get(&pr)
if err != nil {
return nil, err
}
@ -591,7 +591,7 @@ func getPullRequestByIssueID(e db.Engine, issueID int64) (*PullRequest, error) {
func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) {
pulls := make([]*PullRequest, 0, 10)
err := db.DefaultContext().Engine().
err := db.GetEngine(db.DefaultContext).
Where("has_merged=? AND flow = ? AND issue.is_closed=? AND issue.poster_id=?",
false, PullRequestFlowAGit, false, uid).
Join("INNER", "issue", "issue.id=pull_request.issue_id").
@ -602,24 +602,24 @@ func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) {
// GetPullRequestByIssueID returns pull request by given issue ID.
func GetPullRequestByIssueID(issueID int64) (*PullRequest, error) {
return getPullRequestByIssueID(db.DefaultContext().Engine(), issueID)
return getPullRequestByIssueID(db.GetEngine(db.DefaultContext), issueID)
}
// Update updates all fields of pull request.
func (pr *PullRequest) Update() error {
_, err := db.DefaultContext().Engine().ID(pr.ID).AllCols().Update(pr)
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr)
return err
}
// UpdateCols updates specific fields of pull request.
func (pr *PullRequest) UpdateCols(cols ...string) error {
_, err := db.DefaultContext().Engine().ID(pr.ID).Cols(cols...).Update(pr)
_, err := db.GetEngine(db.DefaultContext).ID(pr.ID).Cols(cols...).Update(pr)
return err
}
// UpdateColsIfNotMerged updates specific fields of a pull request if it has not been merged
func (pr *PullRequest) UpdateColsIfNotMerged(cols ...string) error {
_, err := db.DefaultContext().Engine().Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr)
_, err := db.GetEngine(db.DefaultContext).Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr)
return err
}
@ -665,7 +665,7 @@ func (pr *PullRequest) GetWorkInProgressPrefix() string {
// UpdateCommitDivergence update Divergence of a pull request
func (pr *PullRequest) UpdateCommitDivergence(ahead, behind int) error {
return pr.updateCommitDivergence(db.DefaultContext().Engine(), ahead, behind)
return pr.updateCommitDivergence(db.GetEngine(db.DefaultContext), ahead, behind)
}
func (pr *PullRequest) updateCommitDivergence(e db.Engine, ahead, behind int) error {