1
0
Fork 0
forked from forgejo/forgejo

Next round of db.DefaultContext refactor (#27089)

Part of #27065
This commit is contained in:
JakobDev 2023-09-16 16:39:12 +02:00 committed by GitHub
parent a1b2a11812
commit f91dbbba98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
90 changed files with 434 additions and 464 deletions

View file

@ -4,6 +4,7 @@
package auth
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
@ -22,8 +23,8 @@ func init() {
}
// UpdateSession updates the session with provided id
func UpdateSession(key string, data []byte) error {
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error {
}
// ReadSession reads the data for the provided session
func ReadSession(key string) (*Session, error) {
func ReadSession(ctx context.Context, key string) (*Session, error) {
session := Session{
Key: key,
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) {
}
// ExistSession checks if a session exists
func ExistSession(key string) (bool, error) {
func ExistSession(ctx context.Context, key string) (bool, error) {
session := Session{
Key: key,
}
return db.GetEngine(db.DefaultContext).Get(&session)
return db.GetEngine(ctx).Get(&session)
}
// DestroySession destroys a session
func DestroySession(key string) error {
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}
// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) {
}
// CountSessions returns the number of sessions
func CountSessions() (int64, error) {
return db.GetEngine(db.DefaultContext).Count(&Session{})
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}
// CleanupSessions cleans up expired sessions
func CleanupSessions(maxLifetime int64) error {
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}

View file

@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string {
}
// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount() error {
return cred.updateSignCount(db.DefaultContext)
}
func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error {
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
}
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) {
return getWebAuthnCredentialsByUID(db.DefaultContext, uid)
}
func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) {
return existsWebAuthnCredentialsByUID(db.DefaultContext, uid)
}
func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) {
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByName returns WebAuthn credential by id
func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByName(db.DefaultContext, uid, name)
}
func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*
}
// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByID(db.DefaultContext, id)
}
func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti
}
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID)
}
func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b
}
// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
return createCredential(db.DefaultContext, userID, name, cred)
}
func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
c := &WebAuthnCredential{
UserID: userID,
Name: name,
@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba
}
// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(id, userID int64) (bool, error) {
return deleteCredential(db.DefaultContext, id, userID)
}
func deleteCredential(ctx context.Context, id, userID int64) (bool, error) {
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}
// WebAuthnCredentials implementns the webauthn.User interface
func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(userID)
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
if err != nil {
return nil, err
}

View file

@ -7,6 +7,7 @@ import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/go-webauthn/webauthn/webauthn"
@ -16,11 +17,11 @@ import (
func TestGetWebAuthnCredentialByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialByID(1)
res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, "WebAuthn credential", res.Name)
_, err = auth_model.GetWebAuthnCredentialByID(342432)
_, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432)
assert.Error(t, err)
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
}
@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) {
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialsByUID(32)
res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32)
assert.NoError(t, err)
assert.Len(t, res, 1)
assert.Equal(t, "WebAuthn credential", res[0].Name)
@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 1
assert.NoError(t, cred.UpdateSignCount())
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
}
@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 0xffffffff
assert.NoError(t, cred.UpdateSignCount())
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
}
func TestCreateCredential(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
assert.NoError(t, err)
assert.Equal(t, "WebAuthn Created Credential", res.Name)
assert.Equal(t, []byte("Test"), res.CredentialID)