package server import ( "context" "net/http" "slices" "strings" "time" "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" "gitlab.com/mstarongitlab/goutils/other" "gitlab.com/mstarongitlab/linstrom/config" "gitlab.com/mstarongitlab/linstrom/storage" ) type HandlerBuilder func(http.Handler) http.Handler func ChainMiddlewares(base http.Handler, links ...HandlerBuilder) http.Handler { slices.Reverse(links) for _, f := range links { base = f(base) } return base } func ContextValsMiddleware(pairs map[any]any) HandlerBuilder { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() for key, val := range pairs { ctx = context.WithValue(ctx, key, val) } newRequest := r.WithContext(ctx) h.ServeHTTP(w, newRequest) }) } } func LoggingMiddleware(handler http.Handler) http.Handler { return ChainMiddlewares(handler, hlog.NewHandler(log.Logger), hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { if strings.HasPrefix(r.URL.Path, "/assets") { return } hlog.FromRequest(r).Info(). Str("method", r.Method). Stringer("url", r.URL). Int("status", status). Int("size", size). Dur("duration", duration). Send() }), hlog.RemoteAddrHandler("ip"), hlog.UserAgentHandler("user_agent"), hlog.RefererHandler("referer"), hlog.RequestIDHandler("req_id", "Request-Id"), ) } func passkeyIdToAccountIdTransformerMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := StorageFromRequest(r) if s == nil { return } log := hlog.FromRequest(r) passkeyId, ok := r.Context().Value(ContextKeyPasskeyUsername).(string) if !ok { other.HttpErr( w, HttpErrIdMissingContextValue, "Actor name missing", http.StatusInternalServerError, ) return } log.Debug().Bytes("passkey-bytes", []byte(passkeyId)).Msg("Id from passkey auth") acc, err := s.FindAccountByPasskeyId([]byte(passkeyId)) if err != nil { other.HttpErr( w, HttpErrIdDbFailure, "Failed to get account from storage", http.StatusInternalServerError, ) return } r = r.WithContext(context.WithValue(r.Context(), ContextKeyActorId, acc.ID)) handler.ServeHTTP(w, r) }) } func profilingAuthenticationMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.FormValue("password") != config.GlobalConfig.Admin.ProfilingPassword { other.HttpErr(w, HttpErrIdNotAuthenticated, "Bad password", http.StatusUnauthorized) return } handler.ServeHTTP(w, r) }) } // Middleware for inserting a logged in account's id into the request context if a session exists // Does not cancel requests ever. If an error occurs, it's treated as if no session is set func checkSessionMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("sid") log := hlog.FromRequest(r) if err != nil { // No cookie is ok, this function is only for inserting account id into the context // if one exists, not for checking permissions log.Debug().Msg("No session cookie, passing along") handler.ServeHTTP(w, r) return } store := StorageFromRequest(r) session, ok := store.GetSession(cookie.Value) if !ok { // Failed to get session from cookie id. Log, then move on as if no session is set log.Warn(). Str("session-id", cookie.Value). Msg("Cookie with session id found, but session doesn't exist") handler.ServeHTTP(w, r) return } if session.Expires.Before(time.Now()) { // Session expired. Move on as if no session was set store.DeleteSession(cookie.Value) handler.ServeHTTP(w, r) return } acc, err := store.FindAccountByPasskeyId(session.UserID) if err != nil { // Failed to get account for passkey id. Log, then move on as if no session is set log.Error(). Err(err). Bytes("passkey-id", session.UserID). Msg("Failed to get account with passkey id while checking session. Ignoring session") handler.ServeHTTP(w, r) return } handler.ServeHTTP( w, r.WithContext( context.WithValue( r.Context(), ContextKeyActorId, acc.ID, ), ), ) }) } func requireValidSessionMiddleware( h func(http.ResponseWriter, *http.Request), ) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { _, ok := r.Context().Value(ContextKeyActorId).(string) if !ok { other.HttpErr( w, HttpErrIdNotAuthenticated, "Not authenticated", http.StatusUnauthorized, ) return } h(w, r) } } func buildRequirePermissionsMiddleware(permissionRole *storage.Role) HandlerBuilder { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { accId, ok := r.Context().Value(ContextKeyActorId).(string) if !ok { other.HttpErr( w, HttpErrIdNotAuthenticated, "Not authenticated", http.StatusUnauthorized, ) return } store := StorageFromRequest(r) log := hlog.FromRequest(r) acc, err := store.FindAccountById(accId) // Assumption: If this handler is hit, the middleware for checking if a session exists at all has already passed // and thus a valid account id must exist in the context if err != nil { log.Error(). Err(err). Str("account-id", accId). Msg("Error while getting account from session") other.HttpErr( w, HttpErrIdDbFailure, "Error while getting account from session", http.StatusInternalServerError, ) return } roles, err := store.FindRolesByNames(acc.Roles) // Assumption: There will always be at least two roles per user, default user and user-specific one if err != nil { other.HttpErr( w, HttpErrIdDbFailure, "Failed to get roles for account", http.StatusInternalServerError, ) return } collapsedRole := storage.CollapseRolesIntoOne(roles...) if !storage.CompareRoles(&collapsedRole, permissionRole) { other.HttpErr( w, HttpErrIdNotAuthenticated, "Insufficient permisions", http.StatusForbidden, ) return } h.ServeHTTP(w, r) }) } }