Add first helper method to user

This commit is contained in:
Melody Becker 2025-04-04 13:46:11 +02:00
parent f3a139b809
commit 11e0059631
Signed by: mstar
SSH key fingerprint: SHA256:9VAo09aaVNTWKzPW7Hq2LW+ox9OdwmTSHRoD4mlz1yI
6 changed files with 16 additions and 7 deletions

View file

@ -177,7 +177,7 @@ func oneStorageAuthToLoginState(in models.AuthenticationMethodType) LoginNextSta
// //
// TODO: Decide whether to include the reason for disallowed login // TODO: Decide whether to include the reason for disallowed login
func (a *Authenticator) canUsernameLogin(username string) (bool, error) { func (a *Authenticator) canUsernameLogin(username string) (bool, error) {
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -31,7 +31,7 @@ func (a *Authenticator) StartPasskeyLogin(
if ok, err := a.canUsernameLogin(username); !ok { if ok, err := a.canUsernameLogin(username); !ok {
return nil, "", other.Error("auth", "user may not login", err) return nil, "", other.Error("auth", "user may not login", err)
} }
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return nil, "", other.Error("auth", "failed to acquire user for login", err) return nil, "", other.Error("auth", "failed to acquire user for login", err)
} }
@ -64,7 +64,7 @@ func (a *Authenticator) CompletePasskeyLogin(
response *http.Request, response *http.Request,
) (accessToken string, err error) { ) (accessToken string, err error) {
// Get user in question // Get user in question
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return "", other.Error("auth", "failed to get user for passkey login completion", err) return "", other.Error("auth", "failed to get user for passkey login completion", err)
} }
@ -140,7 +140,7 @@ func (a *Authenticator) StartPasskeyRegistration(
if ok, err := a.canUsernameLogin(username); !ok { if ok, err := a.canUsernameLogin(username); !ok {
return nil, "", other.Error("auth", "user may not login", err) return nil, "", other.Error("auth", "user may not login", err)
} }
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return nil, "", other.Error("auth", "failed to acquire user for login", err) return nil, "", other.Error("auth", "failed to acquire user for login", err)
} }

View file

@ -24,7 +24,7 @@ func (a *Authenticator) PerformPasswordLogin(
if ok, err := a.canUsernameLogin(username); !ok { if ok, err := a.canUsernameLogin(username); !ok {
return LoginNextFailure, "", other.Error("auth", "user may not login", err) return LoginNextFailure, "", other.Error("auth", "user may not login", err)
} }
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
switch err { switch err {
case nil: case nil:
break break
@ -110,7 +110,7 @@ func (a *Authenticator) PerformPasswordLogin(
// If there is no password set yet (i.e. during account registration or passkey only so far) // If there is no password set yet (i.e. during account registration or passkey only so far)
// it creates the password link // it creates the password link
func (a *Authenticator) PerformPasswordRegister(username, password string) error { func (a *Authenticator) PerformPasswordRegister(username, password string) error {
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return other.Error("auth", "failed to get user to add a password to", err) return other.Error("auth", "failed to get user to add a password to", err)
} }

View file

@ -129,7 +129,7 @@ func (a *Authenticator) StartTotpRegistration(
if ok, err := a.canUsernameLogin(username); !ok { if ok, err := a.canUsernameLogin(username); !ok {
return nil, other.Error("auth", "user may not login", err) return nil, other.Error("auth", "user may not login", err)
} }
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.GetByUsername(username)
if err != nil { if err != nil {
return nil, other.Error("auth", "failed to find account", err) return nil, other.Error("auth", "failed to find account", err)
} }

View file

@ -53,6 +53,7 @@ func main() {
log.Info().Msg("Basic operations applied, applying extra features") log.Info().Msg("Basic operations applied, applying extra features")
g.ApplyInterface(func(models.INotification) {}, models.Notification{}) g.ApplyInterface(func(models.INotification) {}, models.Notification{})
g.ApplyInterface(func(models.IUser) {}, models.User{})
log.Info().Msg("Extra features applied, starting generation") log.Info().Msg("Extra features applied, starting generation")
g.Execute() g.Execute()

View file

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"time" "time"
"gorm.io/gen"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -76,3 +77,10 @@ type User struct {
RemoteInfo *UserRemoteLinks RemoteInfo *UserRemoteLinks
AuthMethods []UserAuthMethod AuthMethods []UserAuthMethod
} }
type IUser interface {
// Get a user by a username
//
// SELECT * FROM @@table WHERE username = @username LIMIT 1
GetByUsername(username string) (*gen.T, error)
}