Fix passkey authentication
Also prep for better router layout
This commit is contained in:
parent
e2260e4a0f
commit
b9eb4234f4
11 changed files with 289 additions and 21 deletions
2
go.mod
2
go.mod
|
@ -20,7 +20,7 @@ require (
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
github.com/xhit/go-simple-mail/v2 v2.16.0
|
github.com/xhit/go-simple-mail/v2 v2.16.0
|
||||||
gitlab.com/mstarongitlab/goap v1.1.0
|
gitlab.com/mstarongitlab/goap v1.1.0
|
||||||
gitlab.com/mstarongitlab/goutils v1.3.0
|
gitlab.com/mstarongitlab/goutils v1.4.1
|
||||||
golang.org/x/image v0.20.0
|
golang.org/x/image v0.20.0
|
||||||
gorm.io/driver/postgres v1.5.7
|
gorm.io/driver/postgres v1.5.7
|
||||||
gorm.io/gorm v1.25.10
|
gorm.io/gorm v1.25.10
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -304,6 +304,8 @@ gitlab.com/mstarongitlab/goap v1.1.0 h1:uN05RP+Tq2NR2IuPq6XQa5oLpfailpoEvxo1Sfeh
|
||||||
gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI=
|
gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI=
|
||||||
gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4=
|
gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4=
|
||||||
gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
|
gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
|
||||||
|
gitlab.com/mstarongitlab/goutils v1.4.1 h1:g6bLX1gGqQeoRwmFC0Aw9lvxyw22cWjyG7RXsi+JmlI=
|
||||||
|
gitlab.com/mstarongitlab/goutils v1.4.1/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
|
||||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||||
|
|
3
main.go
3
main.go
|
@ -3,7 +3,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -77,8 +76,6 @@ func main() {
|
||||||
log.Fatal().Err(err).Msg("Failed to setup passkey support")
|
log.Fatal().Err(err).Msg("Failed to setup passkey support")
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(nojsFS.ReadDir("frontend-noscript"))
|
|
||||||
|
|
||||||
server := server.NewServer(
|
server := server.NewServer(
|
||||||
store,
|
store,
|
||||||
pkey,
|
pkey,
|
||||||
|
|
16
server/constants.go
Normal file
16
server/constants.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
const ContextKeyPasskeyUsername = "context-passkey-username"
|
||||||
|
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ContextKeyStorage ContextKey = "Context key for storage"
|
||||||
|
ContextKeyActorId ContextKey = "Context key for actor id"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HttpErrIdPlaceholder = iota
|
||||||
|
HttpErrIdMissingContextValue
|
||||||
|
HttpErrIdDbFailure
|
||||||
|
)
|
22
server/contextUtils.go
Normal file
22
server/contextUtils.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"gitlab.com/mstarongitlab/goutils/other"
|
||||||
|
"gitlab.com/mstarongitlab/linstrom/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StorageFromRequest(w http.ResponseWriter, r *http.Request) *storage.Storage {
|
||||||
|
store, ok := r.Context().Value(ContextKeyStorage).(*storage.Storage)
|
||||||
|
if !ok {
|
||||||
|
other.HttpErr(
|
||||||
|
w,
|
||||||
|
HttpErrIdMissingContextValue,
|
||||||
|
"Missing storage reference",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return store
|
||||||
|
}
|
14
server/frontend.go
Normal file
14
server/frontend.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupFrontendRouter(interactiveFs, noscriptFs fs.FS) http.Handler {
|
||||||
|
router := http.NewServeMux()
|
||||||
|
router.Handle("/noscript/", http.StripPrefix("/noscript", http.FileServerFS(noscriptFs)))
|
||||||
|
router.Handle("/", http.FileServerFS(interactiveFs))
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/rs/zerolog/hlog"
|
"github.com/rs/zerolog/hlog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"gitlab.com/mstarongitlab/goutils/other"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HandlerBuilder func(http.Handler) http.Handler
|
type HandlerBuilder func(http.Handler) http.Handler
|
||||||
|
@ -51,3 +52,36 @@ func LoggingMiddleware(handler http.Handler) http.Handler {
|
||||||
hlog.RequestIDHandler("req_id", "Request-Id"),
|
hlog.RequestIDHandler("req_id", "Request-Id"),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func passkeyIdToAccountIdTransformerMiddleware(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s := StorageFromRequest(w, r)
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log := hlog.FromRequest(r)
|
||||||
|
passkeyId, ok := r.Context().Value(ContextKeyPasskeyUsername).(string)
|
||||||
|
if !ok {
|
||||||
|
other.HttpErr(
|
||||||
|
w,
|
||||||
|
HttpErrIdMissingContextValue,
|
||||||
|
"Actor name missing",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug().Bytes("passkey-bytes", []byte(passkeyId)).Msg("Id from passkey auth")
|
||||||
|
acc, err := s.FindAccountByPasskeyId([]byte(passkeyId))
|
||||||
|
if err != nil {
|
||||||
|
other.HttpErr(
|
||||||
|
w,
|
||||||
|
HttpErrIdDbFailure,
|
||||||
|
"Failed to get account from storage",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), ContextKeyActorId, acc.ID))
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/mstarongithub/passkey"
|
"github.com/mstarongithub/passkey"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
@ -17,7 +18,9 @@ type Server struct {
|
||||||
|
|
||||||
func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server {
|
func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server {
|
||||||
handler := buildRootHandler(pkey, reactiveFS, staticFS)
|
handler := buildRootHandler(pkey, reactiveFS, staticFS)
|
||||||
handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{}))
|
handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{
|
||||||
|
ContextKeyStorage: store,
|
||||||
|
}))
|
||||||
return &Server{
|
return &Server{
|
||||||
store: store,
|
store: store,
|
||||||
router: handler,
|
router: handler,
|
||||||
|
@ -27,15 +30,22 @@ func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, static
|
||||||
func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler {
|
func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
pkey.MountRoutes(mux, "/webauthn/")
|
pkey.MountRoutes(mux, "/webauthn/")
|
||||||
mux.Handle("/", http.FileServerFS(reactiveFS))
|
mux.Handle("/", setupFrontendRouter(reactiveFS, staticFS))
|
||||||
mux.Handle("/nojs/", http.StripPrefix("/nojs", http.FileServerFS(staticFS)))
|
|
||||||
mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth"))))
|
mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth"))))
|
||||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) })
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) })
|
||||||
|
mux.Handle(
|
||||||
|
"/authonly/",
|
||||||
|
pkey.Auth(
|
||||||
|
ContextKeyPasskeyUsername,
|
||||||
|
nil,
|
||||||
|
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
|
||||||
|
)(ChainMiddlewares(setupTestEndpoints(), passkeyIdToAccountIdTransformerMiddleware)),
|
||||||
|
)
|
||||||
|
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start(addr string) {
|
func (s *Server) Start(addr string) error {
|
||||||
log.Info().Str("addr", addr).Msg("Starting server")
|
log.Info().Str("addr", addr).Msg("Starting server")
|
||||||
http.ListenAndServe(addr, s.router)
|
return http.ListenAndServe(addr, s.router)
|
||||||
}
|
}
|
||||||
|
|
16
server/testingEndpoints.go
Normal file
16
server/testingEndpoints.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestEndpoints() http.Handler {
|
||||||
|
router := http.NewServeMux()
|
||||||
|
router.HandleFunc(
|
||||||
|
"/",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "test root") },
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
|
@ -9,15 +9,17 @@ import (
|
||||||
|
|
||||||
// various prefixes for accessing items in the cache (since it's a simple key-value store)
|
// various prefixes for accessing items in the cache (since it's a simple key-value store)
|
||||||
const (
|
const (
|
||||||
cacheUserHandleToIdPrefix = "acc-name-to-id:"
|
cacheUserHandleToIdPrefix = "acc-name-to-id:"
|
||||||
cacheUserIdToAccPrefix = "acc-id-to-data:"
|
cacheLocalUsernameToIdPrefix = "acc-local-name-to-id:"
|
||||||
cacheNoteIdToNotePrefix = "note-id-to-data:"
|
cachePasskeyIdToAccIdPrefix = "acc-pkey-id-to-id:"
|
||||||
|
cacheUserIdToAccPrefix = "acc-id-to-data:"
|
||||||
|
cacheNoteIdToNotePrefix = "note-id-to-data:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// An error describing the case where some value was just not found in the cache
|
// An error describing the case where some value was just not found in the cache
|
||||||
var errCacheNotFound = errors.New("not found in cache")
|
var errCacheNotFound = errors.New("not found in cache")
|
||||||
|
|
||||||
// Find an account id in cache using a given user handle
|
// Find an account id in cache using a given user handle ("@bob@example.com" or "bob@example.com")
|
||||||
// accId contains the Id of the account if found
|
// accId contains the Id of the account if found
|
||||||
// err contains an error describing why an account's id couldn't be found
|
// err contains an error describing why an account's id couldn't be found
|
||||||
// The most common one should be errCacheNotFound
|
// The most common one should be errCacheNotFound
|
||||||
|
@ -38,6 +40,44 @@ func (s *Storage) cacheHandleToAccUid(handle string) (accId *string, err error)
|
||||||
return &target, nil
|
return &target, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find a local account's id in cache using a given username ("bob")
|
||||||
|
// accId containst the Id of the account if found
|
||||||
|
// err contains an error describing why an account's id couldn't be found
|
||||||
|
// The most common one should be errCacheNotFound
|
||||||
|
func (s *Storage) cacheLocalUsernameToAccUid(username string) (accId *string, err error) {
|
||||||
|
// Where to put the data (in case it's found)
|
||||||
|
var target string
|
||||||
|
found, err := s.cache.Get(cacheLocalUsernameToIdPrefix+username, &target)
|
||||||
|
// If nothing was found, check error
|
||||||
|
if !found {
|
||||||
|
// Case error is set and NOT redis' error for nothing found: Return that error
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
// Else return errCacheNotFound
|
||||||
|
return nil, errCacheNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) cachePkeyIdToAccId(pkeyId []byte) (accId *string, err error) {
|
||||||
|
// Where to put the data (in case it's found)
|
||||||
|
var target string
|
||||||
|
found, err := s.cache.Get(cachePasskeyIdToAccIdPrefix+string(pkeyId), &target)
|
||||||
|
// If nothing was found, check error
|
||||||
|
if !found {
|
||||||
|
// Case error is set and NOT redis' error for nothing found: Return that error
|
||||||
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
// Else return errCacheNotFound
|
||||||
|
return nil, errCacheNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &target, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Find an account's data in cache using a given account id
|
// Find an account's data in cache using a given account id
|
||||||
// acc contains the full account as stored last time if found
|
// acc contains the full account as stored last time if found
|
||||||
// err contains an error describing why an account couldn't be found
|
// err contains an error describing why an account couldn't be found
|
||||||
|
|
133
storage/user.go
133
storage/user.go
|
@ -1,6 +1,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -20,10 +21,10 @@ import (
|
||||||
// If remote, this is used for caching the account
|
// If remote, this is used for caching the account
|
||||||
type Account struct {
|
type Account struct {
|
||||||
ID string `gorm:"primarykey"` // ID is a uuid for this account
|
ID string `gorm:"primarykey"` // ID is a uuid for this account
|
||||||
// Handle of the user (eg "max" if the full username is @max@example.com)
|
// Username of the user (eg "max" if the full username is @max@example.com)
|
||||||
// Assume unchangable (once set by a user) to be kind to other implementations
|
// Assume unchangable (once set by a user) to be kind to other implementations
|
||||||
// Would be an easy avenue to fuck with them though
|
// Would be an easy avenue to fuck with them though
|
||||||
Handle string
|
Username string
|
||||||
CreatedAt time.Time // When this entry was created. Automatically set by gorm
|
CreatedAt time.Time // When this entry was created. Automatically set by gorm
|
||||||
// When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm
|
// When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
@ -44,7 +45,7 @@ type Account struct {
|
||||||
Background string // ID of a media file used as background image
|
Background string // ID of a media file used as background image
|
||||||
Banner string // ID of a media file used as banner
|
Banner string // ID of a media file used as banner
|
||||||
Indexable bool // Whether this account can be found by crawlers
|
Indexable bool // Whether this account can be found by crawlers
|
||||||
PublicKeyPem *string // The public key of the account
|
PublicKey []byte // The public key of the account
|
||||||
// Whether this account restricts following
|
// Whether this account restricts following
|
||||||
// If true, the owner must approve of a follow request first
|
// If true, the owner must approve of a follow request first
|
||||||
RestrictedFollow bool
|
RestrictedFollow bool
|
||||||
|
@ -61,8 +62,8 @@ type Account struct {
|
||||||
|
|
||||||
// --- And internal account stuff ---
|
// --- And internal account stuff ---
|
||||||
// Still public fields since they wouldn't be able to be stored in the db otherwise
|
// Still public fields since they wouldn't be able to be stored in the db otherwise
|
||||||
PrivateKeyPem *string // The private key of the account. Nil if remote user
|
PrivateKey []byte // The private key of the account. Nil if remote user
|
||||||
WebAuthnId []byte // The unique and random ID of this account used for passkey authentication
|
WebAuthnId []byte // The unique and random ID of this account used for passkey authentication
|
||||||
// Whether the account got verified and is allowed to be active
|
// Whether the account got verified and is allowed to be active
|
||||||
// For local accounts being active means being allowed to login and perform interactions
|
// For local accounts being active means being allowed to login and perform interactions
|
||||||
// For remote users, if an account is not verified, any interactions it sends are discarded
|
// For remote users, if an account is not verified, any interactions it sends are discarded
|
||||||
|
@ -176,6 +177,114 @@ func (s *Storage) FindAccountById(id string) (*Account, error) {
|
||||||
return acc, nil
|
return acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Storage) FindLocalAccountByUsername(username string) (*Account, error) {
|
||||||
|
log.Trace().Caller().Send()
|
||||||
|
log.Debug().Str("account-username", username).Msg("Looking for local account")
|
||||||
|
log.Debug().Str("account-username", username).Msg("Checking cache first")
|
||||||
|
|
||||||
|
// Try and find the account in cache first
|
||||||
|
cacheAccId, err := s.cacheLocalUsernameToAccUid(username)
|
||||||
|
if err == nil {
|
||||||
|
log.Info().Str("account-username", username).Msg("Hit account handle in cache")
|
||||||
|
// Then always load via id since unique key access should be faster than string matching
|
||||||
|
return s.FindAccountById(*cacheAccId)
|
||||||
|
} else {
|
||||||
|
if !errors.Is(err, errCacheNotFound) {
|
||||||
|
log.Error().Err(err).Str("account-username", username).Msg("Problem while checking cache for account")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Failed to find in cache, go the slow route of hitting the db
|
||||||
|
log.Debug().Str("account-username", username).Msg("Didn't hit account in cache, going to db")
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().
|
||||||
|
Err(err).
|
||||||
|
Str("account-username", username).
|
||||||
|
Msg("Failed to split up account username")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
acc := Account{}
|
||||||
|
res := s.db.Where("username = ?", username).
|
||||||
|
Where("server = ?", config.GlobalConfig.General.GetFullDomain()).
|
||||||
|
First(&acc)
|
||||||
|
if res.Error != nil {
|
||||||
|
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||||||
|
log.Info().
|
||||||
|
Str("account-username", username).
|
||||||
|
Msg("Local account with username not found")
|
||||||
|
} else {
|
||||||
|
log.Error().Err(err).Str("account-username", username).Msg("Failed to get local account with username")
|
||||||
|
}
|
||||||
|
return nil, res.Error
|
||||||
|
}
|
||||||
|
log.Info().Str("account-username", username).Msg("Found account, also inserting into cache")
|
||||||
|
if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
|
||||||
|
log.Warn().
|
||||||
|
Err(err).
|
||||||
|
Str("account-username", username).
|
||||||
|
Msg("Found account but failed to insert into cache")
|
||||||
|
}
|
||||||
|
if err = s.cache.Set(cacheLocalUsernameToIdPrefix+username, acc.ID); err != nil {
|
||||||
|
log.Warn().
|
||||||
|
Err(err).
|
||||||
|
Str("account-username", username).
|
||||||
|
Msg("Failed to store local username to id in cache")
|
||||||
|
}
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Storage) FindAccountByPasskeyId(pkeyId []byte) (*Account, error) {
|
||||||
|
log.Trace().Caller().Send()
|
||||||
|
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Looking for account")
|
||||||
|
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Checking cache first")
|
||||||
|
|
||||||
|
// Try and find the account in cache first
|
||||||
|
cacheAccId, err := s.cachePkeyIdToAccId(pkeyId)
|
||||||
|
if err == nil {
|
||||||
|
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Hit passkey id in cache")
|
||||||
|
// Then always load via id since unique key access should be faster than string matching
|
||||||
|
return s.FindAccountById(*cacheAccId)
|
||||||
|
} else {
|
||||||
|
if err != errCacheNotFound {
|
||||||
|
log.Error().Err(err).Bytes("account-passkey-id", pkeyId).Msg("Problem while checking cache for account")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Failed to find in cache, go the slow route of hitting the db
|
||||||
|
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Didn't hit account in cache, going to db")
|
||||||
|
|
||||||
|
acc := Account{}
|
||||||
|
res := s.db.Where("web_authn_id = ?", pkeyId).
|
||||||
|
First(&acc)
|
||||||
|
if res.Error != nil {
|
||||||
|
if res.Error == gorm.ErrRecordNotFound {
|
||||||
|
log.Info().
|
||||||
|
Bytes("account-passkey-id", pkeyId).
|
||||||
|
Msg("Local account with passkey id not found")
|
||||||
|
} else {
|
||||||
|
log.Error().Err(res.Error).Bytes("account-passkey-id", pkeyId).Msg("Failed to get local account with passkey id")
|
||||||
|
}
|
||||||
|
return nil, res.Error
|
||||||
|
}
|
||||||
|
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Found account, also inserting into cache")
|
||||||
|
// if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
|
||||||
|
// log.Warn().
|
||||||
|
// Err(err).
|
||||||
|
// Bytes("account-passkey-id", pkeyId).
|
||||||
|
// Msg("Found account but failed to insert into cache")
|
||||||
|
// }
|
||||||
|
// if err = s.cache.Set(cachePasskeyIdToAccIdPrefix+string(pkeyId), acc.ID); err != nil {
|
||||||
|
// log.Warn().
|
||||||
|
// Err(err).
|
||||||
|
// Bytes("account-passkey-id", pkeyId).
|
||||||
|
// Msg("Failed to store local username to id in cache")
|
||||||
|
// }
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Update a given account in storage and cache
|
// Update a given account in storage and cache
|
||||||
func (s *Storage) UpdateAccount(acc *Account) error {
|
func (s *Storage) UpdateAccount(acc *Account) error {
|
||||||
// If the account is nil or doesn't have an id, error out
|
// If the account is nil or doesn't have an id, error out
|
||||||
|
@ -236,11 +345,19 @@ func (s *Storage) NewLocalAccount(handle string) (*Account, error) {
|
||||||
log.Error().Err(err).Msg("Failed to create empty account for use")
|
log.Error().Err(err).Msg("Failed to create empty account for use")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
acc.Handle = handle
|
acc.Username = handle
|
||||||
acc.Server = config.GlobalConfig.General.GetFullDomain()
|
acc.Server = config.GlobalConfig.General.GetFullDomain()
|
||||||
acc.Remote = false
|
acc.Remote = false
|
||||||
acc.DisplayName = handle
|
acc.DisplayName = handle
|
||||||
|
|
||||||
|
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to generate key pair for new local account")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
acc.PrivateKey = privateKey
|
||||||
|
acc.PublicKey = publicKey
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("account-handle", handle).
|
Str("account-handle", handle).
|
||||||
Str("account-id", acc.ID).
|
Str("account-id", acc.ID).
|
||||||
|
@ -267,7 +384,7 @@ func (a *Account) WebAuthnID() []byte {
|
||||||
|
|
||||||
func (u *Account) WebAuthnName() string {
|
func (u *Account) WebAuthnName() string {
|
||||||
log.Trace().Caller().Send()
|
log.Trace().Caller().Send()
|
||||||
return u.Handle
|
return u.Username
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Account) WebAuthnDisplayName() string {
|
func (u *Account) WebAuthnDisplayName() string {
|
||||||
|
@ -302,7 +419,7 @@ func (s *Storage) GetOrCreateUser(userID string) passkey.User {
|
||||||
Str("account-handle", userID).
|
Str("account-handle", userID).
|
||||||
Msg("Looking for or creating account for passkey stuff")
|
Msg("Looking for or creating account for passkey stuff")
|
||||||
acc := &Account{}
|
acc := &Account{}
|
||||||
res := s.db.Where(Account{Handle: userID, Server: config.GlobalConfig.General.GetFullDomain()}).
|
res := s.db.Where(Account{Username: userID, Server: config.GlobalConfig.General.GetFullDomain()}).
|
||||||
First(acc)
|
First(acc)
|
||||||
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||||||
log.Debug().Str("account-handle", userID)
|
log.Debug().Str("account-handle", userID)
|
||||||
|
|
Loading…
Reference in a new issue