Fix passkey authentication
Also prep for better router layout
This commit is contained in:
parent
e2260e4a0f
commit
b9eb4234f4
11 changed files with 289 additions and 21 deletions
2
go.mod
2
go.mod
|
@ -20,7 +20,7 @@ require (
|
|||
github.com/rs/zerolog v1.33.0
|
||||
github.com/xhit/go-simple-mail/v2 v2.16.0
|
||||
gitlab.com/mstarongitlab/goap v1.1.0
|
||||
gitlab.com/mstarongitlab/goutils v1.3.0
|
||||
gitlab.com/mstarongitlab/goutils v1.4.1
|
||||
golang.org/x/image v0.20.0
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/gorm v1.25.10
|
||||
|
|
2
go.sum
2
go.sum
|
@ -304,6 +304,8 @@ gitlab.com/mstarongitlab/goap v1.1.0 h1:uN05RP+Tq2NR2IuPq6XQa5oLpfailpoEvxo1Sfeh
|
|||
gitlab.com/mstarongitlab/goap v1.1.0/go.mod h1:rt9IYvJBPh1z6t+vvzifmxDtGjGlr8683tSPfa5dbXI=
|
||||
gitlab.com/mstarongitlab/goutils v1.3.0 h1:uuxPHjIU36lyJ8/z4T2xI32zOyh53Xj0Au8K12qkaJ4=
|
||||
gitlab.com/mstarongitlab/goutils v1.3.0/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
|
||||
gitlab.com/mstarongitlab/goutils v1.4.1 h1:g6bLX1gGqQeoRwmFC0Aw9lvxyw22cWjyG7RXsi+JmlI=
|
||||
gitlab.com/mstarongitlab/goutils v1.4.1/go.mod h1:SvqfzFxgashuZPqR9kPwQ9gFA7I1yskZjhmGmY2pAow=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
|
|
3
main.go
3
main.go
|
@ -3,7 +3,6 @@ package main
|
|||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
@ -77,8 +76,6 @@ func main() {
|
|||
log.Fatal().Err(err).Msg("Failed to setup passkey support")
|
||||
}
|
||||
|
||||
fmt.Println(nojsFS.ReadDir("frontend-noscript"))
|
||||
|
||||
server := server.NewServer(
|
||||
store,
|
||||
pkey,
|
||||
|
|
16
server/constants.go
Normal file
16
server/constants.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package server
|
||||
|
||||
const ContextKeyPasskeyUsername = "context-passkey-username"
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyStorage ContextKey = "Context key for storage"
|
||||
ContextKeyActorId ContextKey = "Context key for actor id"
|
||||
)
|
||||
|
||||
const (
|
||||
HttpErrIdPlaceholder = iota
|
||||
HttpErrIdMissingContextValue
|
||||
HttpErrIdDbFailure
|
||||
)
|
22
server/contextUtils.go
Normal file
22
server/contextUtils.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gitlab.com/mstarongitlab/goutils/other"
|
||||
"gitlab.com/mstarongitlab/linstrom/storage"
|
||||
)
|
||||
|
||||
func StorageFromRequest(w http.ResponseWriter, r *http.Request) *storage.Storage {
|
||||
store, ok := r.Context().Value(ContextKeyStorage).(*storage.Storage)
|
||||
if !ok {
|
||||
other.HttpErr(
|
||||
w,
|
||||
HttpErrIdMissingContextValue,
|
||||
"Missing storage reference",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
return store
|
||||
}
|
14
server/frontend.go
Normal file
14
server/frontend.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func setupFrontendRouter(interactiveFs, noscriptFs fs.FS) http.Handler {
|
||||
router := http.NewServeMux()
|
||||
router.Handle("/noscript/", http.StripPrefix("/noscript", http.FileServerFS(noscriptFs)))
|
||||
router.Handle("/", http.FileServerFS(interactiveFs))
|
||||
|
||||
return router
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/rs/zerolog/hlog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gitlab.com/mstarongitlab/goutils/other"
|
||||
)
|
||||
|
||||
type HandlerBuilder func(http.Handler) http.Handler
|
||||
|
@ -51,3 +52,36 @@ func LoggingMiddleware(handler http.Handler) http.Handler {
|
|||
hlog.RequestIDHandler("req_id", "Request-Id"),
|
||||
)
|
||||
}
|
||||
|
||||
func passkeyIdToAccountIdTransformerMiddleware(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s := StorageFromRequest(w, r)
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
log := hlog.FromRequest(r)
|
||||
passkeyId, ok := r.Context().Value(ContextKeyPasskeyUsername).(string)
|
||||
if !ok {
|
||||
other.HttpErr(
|
||||
w,
|
||||
HttpErrIdMissingContextValue,
|
||||
"Actor name missing",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
log.Debug().Bytes("passkey-bytes", []byte(passkeyId)).Msg("Id from passkey auth")
|
||||
acc, err := s.FindAccountByPasskeyId([]byte(passkeyId))
|
||||
if err != nil {
|
||||
other.HttpErr(
|
||||
w,
|
||||
HttpErrIdDbFailure,
|
||||
"Failed to get account from storage",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), ContextKeyActorId, acc.ID))
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/mstarongithub/passkey"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -17,7 +18,9 @@ type Server struct {
|
|||
|
||||
func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) *Server {
|
||||
handler := buildRootHandler(pkey, reactiveFS, staticFS)
|
||||
handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{}))
|
||||
handler = ChainMiddlewares(handler, LoggingMiddleware, ContextValsMiddleware(map[any]any{
|
||||
ContextKeyStorage: store,
|
||||
}))
|
||||
return &Server{
|
||||
store: store,
|
||||
router: handler,
|
||||
|
@ -27,15 +30,22 @@ func NewServer(store *storage.Storage, pkey *passkey.Passkey, reactiveFS, static
|
|||
func buildRootHandler(pkey *passkey.Passkey, reactiveFS, staticFS fs.FS) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
pkey.MountRoutes(mux, "/webauthn/")
|
||||
mux.Handle("/", http.FileServerFS(reactiveFS))
|
||||
mux.Handle("/nojs/", http.StripPrefix("/nojs", http.FileServerFS(staticFS)))
|
||||
mux.Handle("/", setupFrontendRouter(reactiveFS, staticFS))
|
||||
mux.Handle("/pk/", http.StripPrefix("/pk", http.FileServer(http.Dir("pk-auth"))))
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, true) })
|
||||
mux.Handle(
|
||||
"/authonly/",
|
||||
pkey.Auth(
|
||||
ContextKeyPasskeyUsername,
|
||||
nil,
|
||||
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
|
||||
)(ChainMiddlewares(setupTestEndpoints(), passkeyIdToAccountIdTransformerMiddleware)),
|
||||
)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func (s *Server) Start(addr string) {
|
||||
func (s *Server) Start(addr string) error {
|
||||
log.Info().Str("addr", addr).Msg("Starting server")
|
||||
http.ListenAndServe(addr, s.router)
|
||||
return http.ListenAndServe(addr, s.router)
|
||||
}
|
||||
|
|
16
server/testingEndpoints.go
Normal file
16
server/testingEndpoints.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func setupTestEndpoints() http.Handler {
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc(
|
||||
"/",
|
||||
func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "test root") },
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
|
@ -10,6 +10,8 @@ import (
|
|||
// various prefixes for accessing items in the cache (since it's a simple key-value store)
|
||||
const (
|
||||
cacheUserHandleToIdPrefix = "acc-name-to-id:"
|
||||
cacheLocalUsernameToIdPrefix = "acc-local-name-to-id:"
|
||||
cachePasskeyIdToAccIdPrefix = "acc-pkey-id-to-id:"
|
||||
cacheUserIdToAccPrefix = "acc-id-to-data:"
|
||||
cacheNoteIdToNotePrefix = "note-id-to-data:"
|
||||
)
|
||||
|
@ -17,7 +19,7 @@ const (
|
|||
// An error describing the case where some value was just not found in the cache
|
||||
var errCacheNotFound = errors.New("not found in cache")
|
||||
|
||||
// Find an account id in cache using a given user handle
|
||||
// Find an account id in cache using a given user handle ("@bob@example.com" or "bob@example.com")
|
||||
// accId contains the Id of the account if found
|
||||
// err contains an error describing why an account's id couldn't be found
|
||||
// The most common one should be errCacheNotFound
|
||||
|
@ -38,6 +40,44 @@ func (s *Storage) cacheHandleToAccUid(handle string) (accId *string, err error)
|
|||
return &target, nil
|
||||
}
|
||||
|
||||
// Find a local account's id in cache using a given username ("bob")
|
||||
// accId containst the Id of the account if found
|
||||
// err contains an error describing why an account's id couldn't be found
|
||||
// The most common one should be errCacheNotFound
|
||||
func (s *Storage) cacheLocalUsernameToAccUid(username string) (accId *string, err error) {
|
||||
// Where to put the data (in case it's found)
|
||||
var target string
|
||||
found, err := s.cache.Get(cacheLocalUsernameToIdPrefix+username, &target)
|
||||
// If nothing was found, check error
|
||||
if !found {
|
||||
// Case error is set and NOT redis' error for nothing found: Return that error
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
} else {
|
||||
// Else return errCacheNotFound
|
||||
return nil, errCacheNotFound
|
||||
}
|
||||
}
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
func (s *Storage) cachePkeyIdToAccId(pkeyId []byte) (accId *string, err error) {
|
||||
// Where to put the data (in case it's found)
|
||||
var target string
|
||||
found, err := s.cache.Get(cachePasskeyIdToAccIdPrefix+string(pkeyId), &target)
|
||||
// If nothing was found, check error
|
||||
if !found {
|
||||
// Case error is set and NOT redis' error for nothing found: Return that error
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
} else {
|
||||
// Else return errCacheNotFound
|
||||
return nil, errCacheNotFound
|
||||
}
|
||||
}
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
// Find an account's data in cache using a given account id
|
||||
// acc contains the full account as stored last time if found
|
||||
// err contains an error describing why an account couldn't be found
|
||||
|
|
131
storage/user.go
131
storage/user.go
|
@ -1,6 +1,7 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"strings"
|
||||
|
@ -20,10 +21,10 @@ import (
|
|||
// If remote, this is used for caching the account
|
||||
type Account struct {
|
||||
ID string `gorm:"primarykey"` // ID is a uuid for this account
|
||||
// Handle of the user (eg "max" if the full username is @max@example.com)
|
||||
// Username of the user (eg "max" if the full username is @max@example.com)
|
||||
// Assume unchangable (once set by a user) to be kind to other implementations
|
||||
// Would be an easy avenue to fuck with them though
|
||||
Handle string
|
||||
Username string
|
||||
CreatedAt time.Time // When this entry was created. Automatically set by gorm
|
||||
// When this account was last updated. Will also be used for refreshing remote accounts. Automatically set by gorm
|
||||
UpdatedAt time.Time
|
||||
|
@ -44,7 +45,7 @@ type Account struct {
|
|||
Background string // ID of a media file used as background image
|
||||
Banner string // ID of a media file used as banner
|
||||
Indexable bool // Whether this account can be found by crawlers
|
||||
PublicKeyPem *string // The public key of the account
|
||||
PublicKey []byte // The public key of the account
|
||||
// Whether this account restricts following
|
||||
// If true, the owner must approve of a follow request first
|
||||
RestrictedFollow bool
|
||||
|
@ -61,7 +62,7 @@ type Account struct {
|
|||
|
||||
// --- And internal account stuff ---
|
||||
// Still public fields since they wouldn't be able to be stored in the db otherwise
|
||||
PrivateKeyPem *string // The private key of the account. Nil if remote user
|
||||
PrivateKey []byte // The private key of the account. Nil if remote user
|
||||
WebAuthnId []byte // The unique and random ID of this account used for passkey authentication
|
||||
// Whether the account got verified and is allowed to be active
|
||||
// For local accounts being active means being allowed to login and perform interactions
|
||||
|
@ -176,6 +177,114 @@ func (s *Storage) FindAccountById(id string) (*Account, error) {
|
|||
return acc, nil
|
||||
}
|
||||
|
||||
func (s *Storage) FindLocalAccountByUsername(username string) (*Account, error) {
|
||||
log.Trace().Caller().Send()
|
||||
log.Debug().Str("account-username", username).Msg("Looking for local account")
|
||||
log.Debug().Str("account-username", username).Msg("Checking cache first")
|
||||
|
||||
// Try and find the account in cache first
|
||||
cacheAccId, err := s.cacheLocalUsernameToAccUid(username)
|
||||
if err == nil {
|
||||
log.Info().Str("account-username", username).Msg("Hit account handle in cache")
|
||||
// Then always load via id since unique key access should be faster than string matching
|
||||
return s.FindAccountById(*cacheAccId)
|
||||
} else {
|
||||
if !errors.Is(err, errCacheNotFound) {
|
||||
log.Error().Err(err).Str("account-username", username).Msg("Problem while checking cache for account")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to find in cache, go the slow route of hitting the db
|
||||
log.Debug().Str("account-username", username).Msg("Didn't hit account in cache, going to db")
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("account-username", username).
|
||||
Msg("Failed to split up account username")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
acc := Account{}
|
||||
res := s.db.Where("username = ?", username).
|
||||
Where("server = ?", config.GlobalConfig.General.GetFullDomain()).
|
||||
First(&acc)
|
||||
if res.Error != nil {
|
||||
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||||
log.Info().
|
||||
Str("account-username", username).
|
||||
Msg("Local account with username not found")
|
||||
} else {
|
||||
log.Error().Err(err).Str("account-username", username).Msg("Failed to get local account with username")
|
||||
}
|
||||
return nil, res.Error
|
||||
}
|
||||
log.Info().Str("account-username", username).Msg("Found account, also inserting into cache")
|
||||
if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("account-username", username).
|
||||
Msg("Found account but failed to insert into cache")
|
||||
}
|
||||
if err = s.cache.Set(cacheLocalUsernameToIdPrefix+username, acc.ID); err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("account-username", username).
|
||||
Msg("Failed to store local username to id in cache")
|
||||
}
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *Storage) FindAccountByPasskeyId(pkeyId []byte) (*Account, error) {
|
||||
log.Trace().Caller().Send()
|
||||
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Looking for account")
|
||||
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Checking cache first")
|
||||
|
||||
// Try and find the account in cache first
|
||||
cacheAccId, err := s.cachePkeyIdToAccId(pkeyId)
|
||||
if err == nil {
|
||||
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Hit passkey id in cache")
|
||||
// Then always load via id since unique key access should be faster than string matching
|
||||
return s.FindAccountById(*cacheAccId)
|
||||
} else {
|
||||
if err != errCacheNotFound {
|
||||
log.Error().Err(err).Bytes("account-passkey-id", pkeyId).Msg("Problem while checking cache for account")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to find in cache, go the slow route of hitting the db
|
||||
log.Debug().Bytes("account-passkey-id", pkeyId).Msg("Didn't hit account in cache, going to db")
|
||||
|
||||
acc := Account{}
|
||||
res := s.db.Where("web_authn_id = ?", pkeyId).
|
||||
First(&acc)
|
||||
if res.Error != nil {
|
||||
if res.Error == gorm.ErrRecordNotFound {
|
||||
log.Info().
|
||||
Bytes("account-passkey-id", pkeyId).
|
||||
Msg("Local account with passkey id not found")
|
||||
} else {
|
||||
log.Error().Err(res.Error).Bytes("account-passkey-id", pkeyId).Msg("Failed to get local account with passkey id")
|
||||
}
|
||||
return nil, res.Error
|
||||
}
|
||||
log.Info().Bytes("account-passkey-id", pkeyId).Msg("Found account, also inserting into cache")
|
||||
// if err = s.cache.Set(cacheUserIdToAccPrefix+acc.ID, &acc); err != nil {
|
||||
// log.Warn().
|
||||
// Err(err).
|
||||
// Bytes("account-passkey-id", pkeyId).
|
||||
// Msg("Found account but failed to insert into cache")
|
||||
// }
|
||||
// if err = s.cache.Set(cachePasskeyIdToAccIdPrefix+string(pkeyId), acc.ID); err != nil {
|
||||
// log.Warn().
|
||||
// Err(err).
|
||||
// Bytes("account-passkey-id", pkeyId).
|
||||
// Msg("Failed to store local username to id in cache")
|
||||
// }
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
// Update a given account in storage and cache
|
||||
func (s *Storage) UpdateAccount(acc *Account) error {
|
||||
// If the account is nil or doesn't have an id, error out
|
||||
|
@ -236,11 +345,19 @@ func (s *Storage) NewLocalAccount(handle string) (*Account, error) {
|
|||
log.Error().Err(err).Msg("Failed to create empty account for use")
|
||||
return nil, err
|
||||
}
|
||||
acc.Handle = handle
|
||||
acc.Username = handle
|
||||
acc.Server = config.GlobalConfig.General.GetFullDomain()
|
||||
acc.Remote = false
|
||||
acc.DisplayName = handle
|
||||
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate key pair for new local account")
|
||||
return nil, err
|
||||
}
|
||||
acc.PrivateKey = privateKey
|
||||
acc.PublicKey = publicKey
|
||||
|
||||
log.Debug().
|
||||
Str("account-handle", handle).
|
||||
Str("account-id", acc.ID).
|
||||
|
@ -267,7 +384,7 @@ func (a *Account) WebAuthnID() []byte {
|
|||
|
||||
func (u *Account) WebAuthnName() string {
|
||||
log.Trace().Caller().Send()
|
||||
return u.Handle
|
||||
return u.Username
|
||||
}
|
||||
|
||||
func (u *Account) WebAuthnDisplayName() string {
|
||||
|
@ -302,7 +419,7 @@ func (s *Storage) GetOrCreateUser(userID string) passkey.User {
|
|||
Str("account-handle", userID).
|
||||
Msg("Looking for or creating account for passkey stuff")
|
||||
acc := &Account{}
|
||||
res := s.db.Where(Account{Handle: userID, Server: config.GlobalConfig.General.GetFullDomain()}).
|
||||
res := s.db.Where(Account{Username: userID, Server: config.GlobalConfig.General.GetFullDomain()}).
|
||||
First(acc)
|
||||
if errors.Is(res.Error, gorm.ErrRecordNotFound) {
|
||||
log.Debug().Str("account-handle", userID)
|
||||
|
|
Loading…
Reference in a new issue