1
0
Fork 0
forked from forgejo/forgejo

DBContext is just a Context (#17100)

* DBContext is just a Context

This PR removes some of the specialness from the DBContext and makes it context
This allows us to simplify the GetEngine code to wrap around any context in future
and means that we can change our loadRepo(e Engine) functions to simply take contexts.

Signed-off-by: Andrew Thornton <art27@cantab.net>

* fix unit tests

Signed-off-by: Andrew Thornton <art27@cantab.net>

* another place that needs to set the initial context

Signed-off-by: Andrew Thornton <art27@cantab.net>

* avoid race

Signed-off-by: Andrew Thornton <art27@cantab.net>

* change attachment error

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
zeripath 2021-09-23 16:45:36 +01:00 committed by GitHub
parent b22be7f594
commit 9302eba971
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
129 changed files with 1112 additions and 1022 deletions

View file

@ -208,7 +208,7 @@ func (source *LoginSource) SkipVerify() bool {
// CreateLoginSource inserts a LoginSource in the DB if not already
// existing with the given name.
func CreateLoginSource(source *LoginSource) error {
has, err := db.DefaultContext().Engine().Where("name=?", source.Name).Exist(new(LoginSource))
has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(LoginSource))
if err != nil {
return err
} else if has {
@ -219,7 +219,7 @@ func CreateLoginSource(source *LoginSource) error {
source.IsSyncEnabled = false
}
_, err = db.DefaultContext().Engine().Insert(source)
_, err = db.GetEngine(db.DefaultContext).Insert(source)
if err != nil {
return err
}
@ -240,7 +240,7 @@ func CreateLoginSource(source *LoginSource) error {
err = registerableSource.RegisterSource()
if err != nil {
// remove the LoginSource in case of errors while registering configuration
if _, err := db.DefaultContext().Engine().Delete(source); err != nil {
if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
log.Error("CreateLoginSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -250,13 +250,13 @@ func CreateLoginSource(source *LoginSource) error {
// LoginSources returns a slice of all login sources found in DB.
func LoginSources() ([]*LoginSource, error) {
auths := make([]*LoginSource, 0, 6)
return auths, db.DefaultContext().Engine().Find(&auths)
return auths, db.GetEngine(db.DefaultContext).Find(&auths)
}
// LoginSourcesByType returns all sources of the specified type
func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 1)
if err := db.DefaultContext().Engine().Where("type = ?", loginType).Find(&sources); err != nil {
if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@ -265,7 +265,7 @@ func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) {
// AllActiveLoginSources returns all active sources
func AllActiveLoginSources() ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 5)
if err := db.DefaultContext().Engine().Where("is_active = ?", true).Find(&sources); err != nil {
if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@ -274,7 +274,7 @@ func AllActiveLoginSources() ([]*LoginSource, error) {
// ActiveLoginSources returns all active sources of the specified type
func ActiveLoginSources(loginType LoginType) ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 1)
if err := db.DefaultContext().Engine().Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil {
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@ -305,7 +305,7 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) {
return source, nil
}
has, err := db.DefaultContext().Engine().ID(id).Get(source)
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
@ -325,7 +325,7 @@ func UpdateSource(source *LoginSource) error {
}
}
_, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(source)
_, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
@ -346,7 +346,7 @@ func UpdateSource(source *LoginSource) error {
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider it self
if _, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(originalLoginSource); err != nil {
if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalLoginSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@ -355,14 +355,14 @@ func UpdateSource(source *LoginSource) error {
// DeleteSource deletes a LoginSource record in DB.
func DeleteSource(source *LoginSource) error {
count, err := db.DefaultContext().Engine().Count(&User{LoginSource: source.ID})
count, err := db.GetEngine(db.DefaultContext).Count(&User{LoginSource: source.ID})
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{source.ID}
}
count, err = db.DefaultContext().Engine().Count(&ExternalLoginUser{LoginSourceID: source.ID})
count, err = db.GetEngine(db.DefaultContext).Count(&ExternalLoginUser{LoginSourceID: source.ID})
if err != nil {
return err
} else if count > 0 {
@ -375,12 +375,12 @@ func DeleteSource(source *LoginSource) error {
}
}
_, err = db.DefaultContext().Engine().ID(source.ID).Delete(new(LoginSource))
_, err = db.GetEngine(db.DefaultContext).ID(source.ID).Delete(new(LoginSource))
return err
}
// CountLoginSources returns number of login sources.
func CountLoginSources() int64 {
count, _ := db.DefaultContext().Engine().Count(new(LoginSource))
count, _ := db.GetEngine(db.DefaultContext).Count(new(LoginSource))
return count
}