Fix passkey authentication

Also prep for better router layout
This commit is contained in:
Melody Becker 2024-10-15 16:16:18 +02:00
parent e2260e4a0f
commit b9eb4234f4
11 changed files with 289 additions and 21 deletions

2
go.mod
View file

@ -20,7 +20,7 @@ require (
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/xhit/go-simple-mail/v2 v2.16.0 github.com/xhit/go-simple-mail/v2 v2.16.0
gitlab.com/mstarongitlab/goap v1.1.0 gitlab.com/mstarongitlab/goap v1.1.0
gitlab.com/mstarongitlab/goutils v1.3.0 gitlab.com/mstarongitlab/goutils v1.4.1
golang.org/x/image v0.20.0 golang.org/x/image v0.20.0
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/gorm v1.25.10 gorm.io/gorm v1.25.10

2
go.sum
View file

@ -304,6 +304,8 @@ gitlab.com/mstarongitlab/goap v1.1.0 h1:uN05RP+Tq2NR2IuPq6XQa5oLpfailpoEvxo1Sfeh
gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI= gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI=
gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4= gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4=
gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow= gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
gitlab.com/mstarongitlab/goutils v1.4.1 h1:g6bLX1gGqQeoRwmFC0Aw9lvxyw22cWjyG7RXsi+JmlI=
gitlab.com/mstarongitlab/goutils v1.4.1/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=

View file

@ -3,7 +3,6 @@ package main
import ( import (
"embed" "embed"
"fmt"
"io" "io"
"os" "os"
"strings" "strings"
@ -77,8 +76,6 @@ func main() {
log.Fatal().Err(err).Msg("Failed to setup passkey support") log.Fatal().Err(err).Msg("Failed to setup passkey support")
} }
fmt.Println(nojsFS.ReadDir("frontend-noscript"))
server := server.NewServer( server := server.NewServer(
store, store,
pkey, pkey,

16
server/constants.go Normal file
View file

@ -0,0 +1,16 @@
package server
const ContextKeyPasskeyUsername = "context-passkey-username"
type ContextKey string
const (
ContextKeyStorage ContextKey = "Context key for storage"
ContextKeyActorId ContextKey = "Context key for actor id"
)
const (
HttpErrIdPlaceholder = iota
HttpErrIdMissingContextValue
HttpErrIdDbFailure
)

22
server/contextUtils.go Normal file
View file

@ -0,0 +1,22 @@
package server
import (
"net/http"
"gitlab.com/mstarongitlab/goutils/other"
"gitlab.com/mstarongitlab/linstrom/storage"
)
func StorageFromRequest(w http.ResponseWriter, r *http.Request) *storage.Storage {
store, ok := r.Context().Value(ContextKeyStorage).(*storage.Storage)
if !ok {
other.HttpErr(
w,
HttpErrIdMissingContextValue,
"Missing storage reference",
http.StatusInternalServerError,
)
return nil
}
return store
}

14
server/frontend.go Normal file
View file

@ -0,0 +1,14 @@
package server
import (
"io/fs"
"net/http"
)
func setupFrontendRouter(interactiveFs, noscriptFs fs.FS) http.Handler {
router := http.NewServeMux()
router.Handle("/noscript/", http.StripPrefix("/noscript", http.FileServerFS(noscriptFs)))
router.Handle("/", http.FileServerFS(interactiveFs))
return router
}

View file

@ -8,6 +8,7 @@ import (
"github.com/rs/zerolog/hlog" "github.com/rs/zerolog/hlog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gitlab.com/mstarongitlab/goutils/other"
) )
type HandlerBuilder func(http.Handler) http.Handler type HandlerBuilder func(http.Handler) http.Handler
@ -51,3 +52,36 @@ func LoggingMiddleware(handler http.Handler) http.Handler {
hlog.RequestIDHandler("req_id", "Request-Id"), hlog.RequestIDHandler("req_id", "Request-Id"),
) )
} }
func passkeyIdToAccountIdTransformerMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s := StorageFromRequest(w, r)
if s == nil {
return
}
log := hlog.FromRequest(r)
passkeyId, ok := r.Context().Value(ContextKeyPasskeyUsername).(string)
if !ok {
other.HttpErr(
w,
HttpErrIdMissingContextValue,
"Actor name missing",
http.StatusInternalServerError,
)
return
}
log.Debug().Bytes("passkey-bytes", []byte(passkeyId)).Msg("Id from passkey auth")
acc, err := s.FindAccountByPasskeyId([]byte(passkeyId))
if err != nil {
other.HttpErr(
w,
HttpErrIdDbFailure,
"Failed to get account from storage",
http.StatusInternalServerError,
)
return
}
r = r.WithContext(context.WithValue(r.Context(), ContextKeyActorId, acc.ID))
handler.ServeHTTP(w, r)
})
}

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net/http" "net/http"
"net/url"
"github.com/mstarongithub/passkey" "github.com/mstarongithub/passkey"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -17,7 +18,9 @@ type Server struct {
func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server { func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server {
handler := buildRootHandler(pkey, reactiveFS, staticFS) handler := buildRootHandler(pkey, reactiveFS, staticFS)
handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{})) handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{
ContextKeyStorage: store,
}))
return &Server{ return &Server{
store: store, store: store,
router: handler, router: handler,
@ -27,15 +30,22 @@ func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, static
func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler { func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
pkey.MountRoutes(mux, "/webauthn/") pkey.MountRoutes(mux, "/webauthn/")
mux.Handle("/", http.FileServerFS(reactiveFS)) mux.Handle("/", setupFrontendRouter(reactiveFS, staticFS))
mux.Handle("/nojs/", http.StripPrefix("/nojs", http.FileServerFS(staticFS)))
mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth")))) mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth"))))
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) }) mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) })
mux.Handle(
"/authonly/",
pkey.Auth(
ContextKeyPasskeyUsername,
nil,
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
)(ChainMiddlewares(setupTestEndpoints(), passkeyIdToAccountIdTransformerMiddleware)),
)
return mux return mux
} }
func (s *Server) Start(addr string) { func (s *Server) Start(addr string) error {
log.Info().Str("addr", addr).Msg("Starting server") log.Info().Str("addr", addr).Msg("Starting server")
http.ListenAndServe(addr, s.router) return http.ListenAndServe(addr, s.router)
} }

View file

@ -0,0 +1,16 @@
package server
import (
"fmt"
"net/http"
)
func setupTestEndpoints() http.Handler {
router := http.NewServeMux()
router.HandleFunc(
"/",
func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "test root") },
)
return router
}

View file

@ -10,6 +10,8 @@ import (
// various prefixes for accessing items in the cache (since it's a simple key-value store) // various prefixes for accessing items in the cache (since it's a simple key-value store)
const ( const (
cacheUserHandleToIdPrefix = "acc-name-to-id:" cacheUserHandleToIdPrefix = "acc-name-to-id:"
cacheLocalUsernameToIdPrefix = "acc-local-name-to-id:"
cachePasskeyIdToAccIdPrefix = "acc-pkey-id-to-id:"
cacheUserIdToAccPrefix = "acc-id-to-data:" cacheUserIdToAccPrefix = "acc-id-to-data:"
cacheNoteIdToNotePrefix = "note-id-to-data:" cacheNoteIdToNotePrefix = "note-id-to-data:"
) )
@ -17,7 +19,7 @@ const (
// An error describing the case where some value was just not found in the cache // An error describing the case where some value was just not found in the cache
var errCacheNotFound = errors.New("not found in cache") var errCacheNotFound = errors.New("not found in cache")
// Find an account id in cache using a given user handle // Find an account id in cache using a given user handle ("@bob@example.com" or "bob@example.com")
// accId contains the Id of the account if found // accId contains the Id of the account if found
// err contains an error describing why an account's id couldn't be found // err contains an error describing why an account's id couldn't be found
// The most common one should be errCacheNotFound // The most common one should be errCacheNotFound
@ -38,6 +40,44 @@ func (s *Storage) cacheHandleToAccUid(handle string) (accId *string, err error)
return &target, nil return &target, nil
} }
// Find a local account's id in cache using a given username ("bob")
// accId containst the Id of the account if found
// err contains an error describing why an account's id couldn't be found
// The most common one should be errCacheNotFound
func (s *Storage) cacheLocalUsernameToAccUid(username string) (accId *string, err error) {
// Where to put the data (in case it's found)
var target string
found, err := s.cache.Get(cacheLocalUsernameToIdPrefix+username, &target)
// If nothing was found, check error
if !found {
// Case error is set and NOT redis' error for nothing found: Return that error
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
} else {
// Else return errCacheNotFound
return nil, errCacheNotFound
}
}
return &target, nil
}
func (s *Storage) cachePkeyIdToAccId(pkeyId []byte) (accId *string, err error) {
// Where to put the data (in case it's found)
var target string
found, err := s.cache.Get(cachePasskeyIdToAccIdPrefix+string(pkeyId), &target)
// If nothing was found, check error
if !found {
// Case error is set and NOT redis' error for nothing found: Return that error
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
} else {
// Else return errCacheNotFound
return nil, errCacheNotFound
}
}
return &target, nil
}
// Find an account's data in cache using a given account id // Find an account's data in cache using a given account id
// acc contains the full account as stored last time if found // acc contains the full account as stored last time if found
// err contains an error describing why an account couldn't be found // err contains an error describing why an account couldn't be found

View file

@ -1,6 +1,7 @@
package storage package storage
import ( import (
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"errors" "errors"
"strings" "strings"
@ -20,10 +21,10 @@ import (
// If remote, this is used for caching the account // If remote, this is used for caching the account
type Account struct { type Account struct {
ID string `gorm:"primarykey"` // ID is a uuid for this account ID string `gorm:"primarykey"` // ID is a uuid for this account
// Handle of the user (eg "max" if the full username is @max@example.com) // Username of the user (eg "max" if the full username is @max@example.com)
// Assume unchangable (once set by a user) to be kind to other implementations // Assume unchangable (once set by a user) to be kind to other implementations
// Would be an easy avenue to fuck with them though // Would be an easy avenue to fuck with them though
Handle string Username string
CreatedAt time.Time // When this entry was created. Automatically set by gorm CreatedAt time.Time // When this entry was created. Automatically set by gorm
// When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm // When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm
UpdatedAt time.Time UpdatedAt time.Time
@ -44,7 +45,7 @@ type Account struct {
Background string // ID of a media file used as background image Background string // ID of a media file used as background image
Banner string // ID of a media file used as banner Banner string // ID of a media file used as banner
Indexable bool // Whether this account can be found by crawlers Indexable bool // Whether this account can be found by crawlers
PublicKeyPem *string // The public key of the account PublicKey []byte // The public key of the account
// Whether this account restricts following // Whether this account restricts following
// If true, the owner must approve of a follow request first // If true, the owner must approve of a follow request first
RestrictedFollow bool RestrictedFollow bool
@ -61,7 +62,7 @@ type Account struct {
// --- And internal account stuff --- // --- And internal account stuff ---
// Still public fields since they wouldn't be able to be stored in the db otherwise // Still public fields since they wouldn't be able to be stored in the db otherwise
PrivateKeyPem *string // The private key of the account. Nil if remote user PrivateKey []byte // The private key of the account. Nil if remote user
WebAuthnId []byte // The unique and random ID of this account used for passkey authentication WebAuthnId []byte // The unique and random ID of this account used for passkey authentication
// Whether the account got verified and is allowed to be active // Whether the account got verified and is allowed to be active
// For local accounts being active means being allowed to login and perform interactions // For local accounts being active means being allowed to login and perform interactions
@ -176,6 +177,114 @@ func (s *Storage) FindAccountById(id string) (*Account, error) {
return acc, nil return acc, nil
} }
func (s *Storage) FindLocalAccountByUsername(username string) (*Account, error) {
log.Trace().Caller().Send()
log.Debug().Str("account-username", username).Msg("Looking for local account")
log.Debug().Str("account-username", username).Msg("Checking cache first")
// Try and find the account in cache first
cacheAccId, err := s.cacheLocalUsernameToAccUid(username)
if err == nil {
log.Info().Str("account-username", username).Msg("Hit account handle in cache")
// Then always load via id since unique key access should be faster than string matching
return s.FindAccountById(*cacheAccId)
} else {
if !errors.Is(err, errCacheNotFound) {
log.Error().Err(err).Str("account-username", username).Msg("Problem while checking cache for account")
return nil, err
}
}
// Failed to find in cache, go the slow route of hitting the db
log.Debug().Str("account-username", username).Msg("Didn't hit account in cache, going to db")
if err != nil {
log.Warn().
Err(err).
Str("account-username", username).
Msg("Failed to split up account username")
return nil, err
}
acc := Account{}
res := s.db.Where("username = ?", username).
Where("server = ?", config.GlobalConfig.General.GetFullDomain()).
First(&acc)
if res.Error != nil {
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
log.Info().
Str("account-username", username).
Msg("Local account with username not found")
} else {
log.Error().Err(err).Str("account-username", username).Msg("Failed to get local account with username")
}
return nil, res.Error
}
log.Info().Str("account-username", username).Msg("Found account, also inserting into cache")
if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
log.Warn().
Err(err).
Str("account-username", username).
Msg("Found account but failed to insert into cache")
}
if err = s.cache.Set(cacheLocalUsernameToIdPrefix+username, acc.ID); err != nil {
log.Warn().
Err(err).
Str("account-username", username).
Msg("Failed to store local username to id in cache")
}
return &acc, nil
}
func (s *Storage) FindAccountByPasskeyId(pkeyId []byte) (*Account, error) {
log.Trace().Caller().Send()
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Looking for account")
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Checking cache first")
// Try and find the account in cache first
cacheAccId, err := s.cachePkeyIdToAccId(pkeyId)
if err == nil {
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Hit passkey id in cache")
// Then always load via id since unique key access should be faster than string matching
return s.FindAccountById(*cacheAccId)
} else {
if err != errCacheNotFound {
log.Error().Err(err).Bytes("account-passkey-id", pkeyId).Msg("Problem while checking cache for account")
return nil, err
}
}
// Failed to find in cache, go the slow route of hitting the db
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Didn't hit account in cache, going to db")
acc := Account{}
res := s.db.Where("web_authn_id = ?", pkeyId).
First(&acc)
if res.Error != nil {
if res.Error == gorm.ErrRecordNotFound {
log.Info().
Bytes("account-passkey-id", pkeyId).
Msg("Local account with passkey id not found")
} else {
log.Error().Err(res.Error).Bytes("account-passkey-id", pkeyId).Msg("Failed to get local account with passkey id")
}
return nil, res.Error
}
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Found account, also inserting into cache")
// if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
// log.Warn().
// Err(err).
// Bytes("account-passkey-id", pkeyId).
// Msg("Found account but failed to insert into cache")
// }
// if err = s.cache.Set(cachePasskeyIdToAccIdPrefix+string(pkeyId), acc.ID); err != nil {
// log.Warn().
// Err(err).
// Bytes("account-passkey-id", pkeyId).
// Msg("Failed to store local username to id in cache")
// }
return &acc, nil
}
// Update a given account in storage and cache // Update a given account in storage and cache
func (s *Storage) UpdateAccount(acc *Account) error { func (s *Storage) UpdateAccount(acc *Account) error {
// If the account is nil or doesn't have an id, error out // If the account is nil or doesn't have an id, error out
@ -236,11 +345,19 @@ func (s *Storage) NewLocalAccount(handle string) (*Account, error) {
log.Error().Err(err).Msg("Failed to create empty account for use") log.Error().Err(err).Msg("Failed to create empty account for use")
return nil, err return nil, err
} }
acc.Handle = handle acc.Username = handle
acc.Server = config.GlobalConfig.General.GetFullDomain() acc.Server = config.GlobalConfig.General.GetFullDomain()
acc.Remote = false acc.Remote = false
acc.DisplayName = handle acc.DisplayName = handle
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
log.Error().Err(err).Msg("Failed to generate key pair for new local account")
return nil, err
}
acc.PrivateKey = privateKey
acc.PublicKey = publicKey
log.Debug(). log.Debug().
Str("account-handle", handle). Str("account-handle", handle).
Str("account-id", acc.ID). Str("account-id", acc.ID).
@ -267,7 +384,7 @@ func (a *Account) WebAuthnID() []byte {
func (u *Account) WebAuthnName() string { func (u *Account) WebAuthnName() string {
log.Trace().Caller().Send() log.Trace().Caller().Send()
return u.Handle return u.Username
} }
func (u *Account) WebAuthnDisplayName() string { func (u *Account) WebAuthnDisplayName() string {
@ -302,7 +419,7 @@ func (s *Storage) GetOrCreateUser(userID string) passkey.User {
Str("account-handle", userID). Str("account-handle", userID).
Msg("Looking for or creating account for passkey stuff") Msg("Looking for or creating account for passkey stuff")
acc := &Account{} acc := &Account{}
res := s.db.Where(Account{Handle: userID, Server: config.GlobalConfig.General.GetFullDomain()}). res := s.db.Where(Account{Username: userID, Server: config.GlobalConfig.General.GetFullDomain()}).
First(acc) First(acc)
if errors.Is(res.Error, gorm.ErrRecordNotFound) { if errors.Is(res.Error, gorm.ErrRecordNotFound) {
log.Debug().Str("account-handle", userID) log.Debug().Str("account-handle", userID)