diff --git a/go.mod b/go.mod index a13db44..e643d2b 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/rs/zerolog v1.33.0 github.com/xhit/go-simple-mail/v2 v2.16.0 gitlab.com/mstarongitlab/goap v1.1.0 - gitlab.com/mstarongitlab/goutils v1.3.0 + gitlab.com/mstarongitlab/goutils v1.4.1 golang.org/x/image v0.20.0 gorm.io/driver/postgres v1.5.7 gorm.io/gorm v1.25.10 diff --git a/go.sum b/go.sum index f43807b..eb7bec7 100644 --- a/go.sum +++ b/go.sum @@ -304,6 +304,8 @@ gitlab.com/mstarongitlab/goap v1.1.0 h1:uN05RP+Tq2NR2IuPq6XQa5oLpfailpoEvxo1Sfeh gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI= gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4= gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow= +gitlab.com/mstarongitlab/goutils v1.4.1 h1:g6bLX1gGqQeoRwmFC0Aw9lvxyw22cWjyG7RXsi+JmlI= +gitlab.com/mstarongitlab/goutils v1.4.1/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= diff --git a/main.go b/main.go index e452460..084d8c5 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "embed" - "fmt" "io" "os" "strings" @@ -77,8 +76,6 @@ func main() { log.Fatal().Err(err).Msg("Failed to setup passkey support") } - fmt.Println(nojsFS.ReadDir("frontend-noscript")) - server := server.NewServer( store, pkey, diff --git a/server/constants.go b/server/constants.go new file mode 100644 index 0000000..346977b --- /dev/null +++ b/server/constants.go @@ -0,0 +1,16 @@ +package server + +const ContextKeyPasskeyUsername = "context-passkey-username" + +type ContextKey string + +const ( + ContextKeyStorage ContextKey = "Context key for storage" + ContextKeyActorId ContextKey = "Context key for actor id" +) + +const ( + HttpErrIdPlaceholder = iota + HttpErrIdMissingContextValue + HttpErrIdDbFailure +) diff --git a/server/contextUtils.go b/server/contextUtils.go new file mode 100644 index 0000000..f815c6f --- /dev/null +++ b/server/contextUtils.go @@ -0,0 +1,22 @@ +package server + +import ( + "net/http" + + "gitlab.com/mstarongitlab/goutils/other" + "gitlab.com/mstarongitlab/linstrom/storage" +) + +func StorageFromRequest(w http.ResponseWriter, r *http.Request) *storage.Storage { + store, ok := r.Context().Value(ContextKeyStorage).(*storage.Storage) + if !ok { + other.HttpErr( + w, + HttpErrIdMissingContextValue, + "Missing storage reference", + http.StatusInternalServerError, + ) + return nil + } + return store +} diff --git a/server/frontend.go b/server/frontend.go new file mode 100644 index 0000000..990822a --- /dev/null +++ b/server/frontend.go @@ -0,0 +1,14 @@ +package server + +import ( + "io/fs" + "net/http" +) + +func setupFrontendRouter(interactiveFs, noscriptFs fs.FS) http.Handler { + router := http.NewServeMux() + router.Handle("/noscript/", http.StripPrefix("/noscript", http.FileServerFS(noscriptFs))) + router.Handle("/", http.FileServerFS(interactiveFs)) + + return router +} diff --git a/server/middlewares.go b/server/middlewares.go index 4ed9806..798214f 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" + "gitlab.com/mstarongitlab/goutils/other" ) type HandlerBuilder func(http.Handler) http.Handler @@ -51,3 +52,36 @@ func LoggingMiddleware(handler http.Handler) http.Handler { hlog.RequestIDHandler("req_id", "Request-Id"), ) } + +func passkeyIdToAccountIdTransformerMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := StorageFromRequest(w, r) + if s == nil { + return + } + log := hlog.FromRequest(r) + passkeyId, ok := r.Context().Value(ContextKeyPasskeyUsername).(string) + if !ok { + other.HttpErr( + w, + HttpErrIdMissingContextValue, + "Actor name missing", + http.StatusInternalServerError, + ) + return + } + log.Debug().Bytes("passkey-bytes", []byte(passkeyId)).Msg("Id from passkey auth") + acc, err := s.FindAccountByPasskeyId([]byte(passkeyId)) + if err != nil { + other.HttpErr( + w, + HttpErrIdDbFailure, + "Failed to get account from storage", + http.StatusInternalServerError, + ) + return + } + r = r.WithContext(context.WithValue(r.Context(), ContextKeyActorId, acc.ID)) + handler.ServeHTTP(w, r) + }) +} diff --git a/server/server.go b/server/server.go index 7d33e67..ab33d20 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "fmt" "io/fs" "net/http" + "net/url" "github.com/mstarongithub/passkey" "github.com/rs/zerolog/log" @@ -17,7 +18,9 @@ type Server struct { func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server { handler := buildRootHandler(pkey, reactiveFS, staticFS) - handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{})) + handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{ + ContextKeyStorage: store, + })) return &Server{ store: store, router: handler, @@ -27,15 +30,22 @@ func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, static func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler { mux := http.NewServeMux() pkey.MountRoutes(mux, "/webauthn/") - mux.Handle("/", http.FileServerFS(reactiveFS)) - mux.Handle("/nojs/", http.StripPrefix("/nojs", http.FileServerFS(staticFS))) + mux.Handle("/", setupFrontendRouter(reactiveFS, staticFS)) mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth")))) mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) }) + mux.Handle( + "/authonly/", + pkey.Auth( + ContextKeyPasskeyUsername, + nil, + passkey.RedirectUnauthorized(url.URL{Path: "/"}), + )(ChainMiddlewares(setupTestEndpoints(), passkeyIdToAccountIdTransformerMiddleware)), + ) return mux } -func (s *Server) Start(addr string) { +func (s *Server) Start(addr string) error { log.Info().Str("addr", addr).Msg("Starting server") - http.ListenAndServe(addr, s.router) + return http.ListenAndServe(addr, s.router) } diff --git a/server/testingEndpoints.go b/server/testingEndpoints.go new file mode 100644 index 0000000..67be6f4 --- /dev/null +++ b/server/testingEndpoints.go @@ -0,0 +1,16 @@ +package server + +import ( + "fmt" + "net/http" +) + +func setupTestEndpoints() http.Handler { + router := http.NewServeMux() + router.HandleFunc( + "/", + func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "test root") }, + ) + + return router +} diff --git a/storage/cache.go b/storage/cache.go index 789d7fc..f61a0b9 100644 --- a/storage/cache.go +++ b/storage/cache.go @@ -9,15 +9,17 @@ import ( // various prefixes for accessing items in the cache (since it's a simple key-value store) const ( - cacheUserHandleToIdPrefix = "acc-name-to-id:" - cacheUserIdToAccPrefix = "acc-id-to-data:" - cacheNoteIdToNotePrefix = "note-id-to-data:" + cacheUserHandleToIdPrefix = "acc-name-to-id:" + cacheLocalUsernameToIdPrefix = "acc-local-name-to-id:" + cachePasskeyIdToAccIdPrefix = "acc-pkey-id-to-id:" + cacheUserIdToAccPrefix = "acc-id-to-data:" + cacheNoteIdToNotePrefix = "note-id-to-data:" ) // An error describing the case where some value was just not found in the cache var errCacheNotFound = errors.New("not found in cache") -// Find an account id in cache using a given user handle +// Find an account id in cache using a given user handle ("@bob@example.com" or "bob@example.com") // accId contains the Id of the account if found // err contains an error describing why an account's id couldn't be found // The most common one should be errCacheNotFound @@ -38,6 +40,44 @@ func (s *Storage) cacheHandleToAccUid(handle string) (accId *string, err error) return &target, nil } +// Find a local account's id in cache using a given username ("bob") +// accId containst the Id of the account if found +// err contains an error describing why an account's id couldn't be found +// The most common one should be errCacheNotFound +func (s *Storage) cacheLocalUsernameToAccUid(username string) (accId *string, err error) { + // Where to put the data (in case it's found) + var target string + found, err := s.cache.Get(cacheLocalUsernameToIdPrefix+username, &target) + // If nothing was found, check error + if !found { + // Case error is set and NOT redis' error for nothing found: Return that error + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } else { + // Else return errCacheNotFound + return nil, errCacheNotFound + } + } + return &target, nil +} + +func (s *Storage) cachePkeyIdToAccId(pkeyId []byte) (accId *string, err error) { + // Where to put the data (in case it's found) + var target string + found, err := s.cache.Get(cachePasskeyIdToAccIdPrefix+string(pkeyId), &target) + // If nothing was found, check error + if !found { + // Case error is set and NOT redis' error for nothing found: Return that error + if err != nil && !errors.Is(err, redis.Nil) { + return nil, err + } else { + // Else return errCacheNotFound + return nil, errCacheNotFound + } + } + return &target, nil +} + // Find an account's data in cache using a given account id // acc contains the full account as stored last time if found // err contains an error describing why an account couldn't be found diff --git a/storage/user.go b/storage/user.go index 0634183..4e28242 100644 --- a/storage/user.go +++ b/storage/user.go @@ -1,6 +1,7 @@ package storage import ( + "crypto/ed25519" "crypto/rand" "errors" "strings" @@ -20,10 +21,10 @@ import ( // If remote, this is used for caching the account type Account struct { ID string `gorm:"primarykey"` // ID is a uuid for this account - // Handle of the user (eg "max" if the full username is @max@example.com) + // Username of the user (eg "max" if the full username is @max@example.com) // Assume unchangable (once set by a user) to be kind to other implementations // Would be an easy avenue to fuck with them though - Handle string + Username string CreatedAt time.Time // When this entry was created. Automatically set by gorm // When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm UpdatedAt time.Time @@ -44,7 +45,7 @@ type Account struct { Background string // ID of a media file used as background image Banner string // ID of a media file used as banner Indexable bool // Whether this account can be found by crawlers - PublicKeyPem *string // The public key of the account + PublicKey []byte // The public key of the account // Whether this account restricts following // If true, the owner must approve of a follow request first RestrictedFollow bool @@ -61,8 +62,8 @@ type Account struct { // --- And internal account stuff --- // Still public fields since they wouldn't be able to be stored in the db otherwise - PrivateKeyPem *string // The private key of the account. Nil if remote user - WebAuthnId []byte // The unique and random ID of this account used for passkey authentication + PrivateKey []byte // The private key of the account. Nil if remote user + WebAuthnId []byte // The unique and random ID of this account used for passkey authentication // Whether the account got verified and is allowed to be active // For local accounts being active means being allowed to login and perform interactions // For remote users, if an account is not verified, any interactions it sends are discarded @@ -176,6 +177,114 @@ func (s *Storage) FindAccountById(id string) (*Account, error) { return acc, nil } +func (s *Storage) FindLocalAccountByUsername(username string) (*Account, error) { + log.Trace().Caller().Send() + log.Debug().Str("account-username", username).Msg("Looking for local account") + log.Debug().Str("account-username", username).Msg("Checking cache first") + + // Try and find the account in cache first + cacheAccId, err := s.cacheLocalUsernameToAccUid(username) + if err == nil { + log.Info().Str("account-username", username).Msg("Hit account handle in cache") + // Then always load via id since unique key access should be faster than string matching + return s.FindAccountById(*cacheAccId) + } else { + if !errors.Is(err, errCacheNotFound) { + log.Error().Err(err).Str("account-username", username).Msg("Problem while checking cache for account") + return nil, err + } + } + + // Failed to find in cache, go the slow route of hitting the db + log.Debug().Str("account-username", username).Msg("Didn't hit account in cache, going to db") + if err != nil { + log.Warn(). + Err(err). + Str("account-username", username). + Msg("Failed to split up account username") + return nil, err + } + + acc := Account{} + res := s.db.Where("username = ?", username). + Where("server = ?", config.GlobalConfig.General.GetFullDomain()). + First(&acc) + if res.Error != nil { + if errors.Is(res.Error, gorm.ErrRecordNotFound) { + log.Info(). + Str("account-username", username). + Msg("Local account with username not found") + } else { + log.Error().Err(err).Str("account-username", username).Msg("Failed to get local account with username") + } + return nil, res.Error + } + log.Info().Str("account-username", username).Msg("Found account, also inserting into cache") + if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil { + log.Warn(). + Err(err). + Str("account-username", username). + Msg("Found account but failed to insert into cache") + } + if err = s.cache.Set(cacheLocalUsernameToIdPrefix+username, acc.ID); err != nil { + log.Warn(). + Err(err). + Str("account-username", username). + Msg("Failed to store local username to id in cache") + } + return &acc, nil +} + +func (s *Storage) FindAccountByPasskeyId(pkeyId []byte) (*Account, error) { + log.Trace().Caller().Send() + log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Looking for account") + log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Checking cache first") + + // Try and find the account in cache first + cacheAccId, err := s.cachePkeyIdToAccId(pkeyId) + if err == nil { + log.Info().Bytes("account-passkey-id", pkeyId).Msg("Hit passkey id in cache") + // Then always load via id since unique key access should be faster than string matching + return s.FindAccountById(*cacheAccId) + } else { + if err != errCacheNotFound { + log.Error().Err(err).Bytes("account-passkey-id", pkeyId).Msg("Problem while checking cache for account") + return nil, err + } + } + + // Failed to find in cache, go the slow route of hitting the db + log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Didn't hit account in cache, going to db") + + acc := Account{} + res := s.db.Where("web_authn_id = ?", pkeyId). + First(&acc) + if res.Error != nil { + if res.Error == gorm.ErrRecordNotFound { + log.Info(). + Bytes("account-passkey-id", pkeyId). + Msg("Local account with passkey id not found") + } else { + log.Error().Err(res.Error).Bytes("account-passkey-id", pkeyId).Msg("Failed to get local account with passkey id") + } + return nil, res.Error + } + log.Info().Bytes("account-passkey-id", pkeyId).Msg("Found account, also inserting into cache") + // if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil { + // log.Warn(). + // Err(err). + // Bytes("account-passkey-id", pkeyId). + // Msg("Found account but failed to insert into cache") + // } + // if err = s.cache.Set(cachePasskeyIdToAccIdPrefix+string(pkeyId), acc.ID); err != nil { + // log.Warn(). + // Err(err). + // Bytes("account-passkey-id", pkeyId). + // Msg("Failed to store local username to id in cache") + // } + return &acc, nil +} + // Update a given account in storage and cache func (s *Storage) UpdateAccount(acc *Account) error { // If the account is nil or doesn't have an id, error out @@ -236,11 +345,19 @@ func (s *Storage) NewLocalAccount(handle string) (*Account, error) { log.Error().Err(err).Msg("Failed to create empty account for use") return nil, err } - acc.Handle = handle + acc.Username = handle acc.Server = config.GlobalConfig.General.GetFullDomain() acc.Remote = false acc.DisplayName = handle + publicKey, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + log.Error().Err(err).Msg("Failed to generate key pair for new local account") + return nil, err + } + acc.PrivateKey = privateKey + acc.PublicKey = publicKey + log.Debug(). Str("account-handle", handle). Str("account-id", acc.ID). @@ -267,7 +384,7 @@ func (a *Account) WebAuthnID() []byte { func (u *Account) WebAuthnName() string { log.Trace().Caller().Send() - return u.Handle + return u.Username } func (u *Account) WebAuthnDisplayName() string { @@ -302,7 +419,7 @@ func (s *Storage) GetOrCreateUser(userID string) passkey.User { Str("account-handle", userID). Msg("Looking for or creating account for passkey stuff") acc := &Account{} - res := s.db.Where(Account{Handle: userID, Server: config.GlobalConfig.General.GetFullDomain()}). + res := s.db.Where(Account{Username: userID, Server: config.GlobalConfig.General.GetFullDomain()}). First(acc) if errors.Is(res.Error, gorm.ErrRecordNotFound) { log.Debug().Str("account-handle", userID)