forked from forgejo/forgejo
parent
a1b2a11812
commit
f91dbbba98
90 changed files with 434 additions and 464 deletions
|
@ -4,6 +4,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"code.gitea.io/gitea/models/db"
|
||||
|
@ -22,8 +23,8 @@ func init() {
|
|||
}
|
||||
|
||||
// UpdateSession updates the session with provided id
|
||||
func UpdateSession(key string, data []byte) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
|
||||
func UpdateSession(ctx context.Context, key string, data []byte) error {
|
||||
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
|
||||
Data: data,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
})
|
||||
|
@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error {
|
|||
}
|
||||
|
||||
// ReadSession reads the data for the provided session
|
||||
func ReadSession(key string) (*Session, error) {
|
||||
func ReadSession(ctx context.Context, key string) (*Session, error) {
|
||||
session := Session{
|
||||
Key: key,
|
||||
}
|
||||
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) {
|
|||
}
|
||||
|
||||
// ExistSession checks if a session exists
|
||||
func ExistSession(key string) (bool, error) {
|
||||
func ExistSession(ctx context.Context, key string) (bool, error) {
|
||||
session := Session{
|
||||
Key: key,
|
||||
}
|
||||
return db.GetEngine(db.DefaultContext).Get(&session)
|
||||
return db.GetEngine(ctx).Get(&session)
|
||||
}
|
||||
|
||||
// DestroySession destroys a session
|
||||
func DestroySession(key string) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
|
||||
func DestroySession(ctx context.Context, key string) error {
|
||||
_, err := db.GetEngine(ctx).Delete(&Session{
|
||||
Key: key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RegenerateSession regenerates a session from the old id
|
||||
func RegenerateSession(oldKey, newKey string) (*Session, error) {
|
||||
ctx, committer, err := db.TxContext(db.DefaultContext)
|
||||
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) {
|
|||
}
|
||||
|
||||
// CountSessions returns the number of sessions
|
||||
func CountSessions() (int64, error) {
|
||||
return db.GetEngine(db.DefaultContext).Count(&Session{})
|
||||
func CountSessions(ctx context.Context) (int64, error) {
|
||||
return db.GetEngine(ctx).Count(&Session{})
|
||||
}
|
||||
|
||||
// CleanupSessions cleans up expired sessions
|
||||
func CleanupSessions(maxLifetime int64) error {
|
||||
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
|
||||
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
|
||||
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string {
|
|||
}
|
||||
|
||||
// UpdateSignCount will update the database value of SignCount
|
||||
func (cred *WebAuthnCredential) UpdateSignCount() error {
|
||||
return cred.updateSignCount(db.DefaultContext)
|
||||
}
|
||||
|
||||
func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error {
|
||||
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
|
||||
return err
|
||||
}
|
||||
|
@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
|
|||
}
|
||||
|
||||
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
|
||||
func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) {
|
||||
return getWebAuthnCredentialsByUID(db.DefaultContext, uid)
|
||||
}
|
||||
|
||||
func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
|
||||
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
|
||||
creds := make(WebAuthnCredentialList, 0)
|
||||
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
|
||||
}
|
||||
|
||||
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
|
||||
func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) {
|
||||
return existsWebAuthnCredentialsByUID(db.DefaultContext, uid)
|
||||
}
|
||||
|
||||
func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByName returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) {
|
||||
return getWebAuthnCredentialByName(db.DefaultContext, uid, name)
|
||||
}
|
||||
|
||||
func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
|
||||
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
|
||||
return nil, err
|
||||
|
@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*
|
|||
}
|
||||
|
||||
// GetWebAuthnCredentialByID returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) {
|
||||
return getWebAuthnCredentialByID(db.DefaultContext, id)
|
||||
}
|
||||
|
||||
func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
|
||||
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
|
||||
return nil, err
|
||||
|
@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti
|
|||
}
|
||||
|
||||
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
|
||||
func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) {
|
||||
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
|
||||
func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) {
|
||||
return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID)
|
||||
}
|
||||
|
||||
func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
|
||||
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
|
||||
return nil, err
|
||||
|
@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b
|
|||
}
|
||||
|
||||
// CreateCredential will create a new WebAuthnCredential from the given Credential
|
||||
func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
|
||||
return createCredential(db.DefaultContext, userID, name, cred)
|
||||
}
|
||||
|
||||
func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
|
||||
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
|
||||
c := &WebAuthnCredential{
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
|
@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba
|
|||
}
|
||||
|
||||
// DeleteCredential will delete WebAuthnCredential
|
||||
func DeleteCredential(id, userID int64) (bool, error) {
|
||||
return deleteCredential(db.DefaultContext, id, userID)
|
||||
}
|
||||
|
||||
func deleteCredential(ctx context.Context, id, userID int64) (bool, error) {
|
||||
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
|
||||
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
|
||||
return had > 0, err
|
||||
}
|
||||
|
||||
// WebAuthnCredentials implementns the webauthn.User interface
|
||||
func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
|
||||
dbCreds, err := GetWebAuthnCredentialsByUID(userID)
|
||||
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
|
||||
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
auth_model "code.gitea.io/gitea/models/auth"
|
||||
"code.gitea.io/gitea/models/db"
|
||||
"code.gitea.io/gitea/models/unittest"
|
||||
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
|
@ -16,11 +17,11 @@ import (
|
|||
func TestGetWebAuthnCredentialByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialByID(1)
|
||||
res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn credential", res.Name)
|
||||
|
||||
_, err = auth_model.GetWebAuthnCredentialByID(342432)
|
||||
_, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
|
||||
}
|
||||
|
@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) {
|
|||
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialsByUID(32)
|
||||
res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "WebAuthn credential", res[0].Name)
|
||||
|
@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
|
|||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 1
|
||||
assert.NoError(t, cred.UpdateSignCount())
|
||||
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
|
||||
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
|
||||
}
|
||||
|
||||
|
@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
|
|||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 0xffffffff
|
||||
assert.NoError(t, cred.UpdateSignCount())
|
||||
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
|
||||
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
|
||||
}
|
||||
|
||||
func TestCreateCredential(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
|
||||
res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn Created Credential", res.Name)
|
||||
assert.Equal(t, []byte("Test"), res.CredentialID)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue