Compare commits

..

4 commits

Author SHA1 Message Date
582988add2
Add go:generate command and new field to user
Some checks are pending
/ test (push) Waiting to run
2025-03-31 08:08:01 +02:00
8ba5e98c50
Idk why, but testcontainers is borked on Discovery 2025-03-31 08:07:24 +02:00
7d385e48de
More work on auth system 2025-03-31 08:07:16 +02:00
2afb14c4b3
Copy and adapt role helper generator to new system 2025-03-31 08:06:43 +02:00
11 changed files with 650 additions and 72 deletions

View file

@ -17,5 +17,7 @@ type Authenticator struct {
} }
func calcAccessExpirationTimestamp() time.Time { func calcAccessExpirationTimestamp() time.Time {
// For now, the default expiration is one month after creation
// though "never" might also be a good option
return time.Now().Add(time.Hour * 24 * 30) return time.Now().Add(time.Hour * 24 * 30)
} }

View file

@ -9,4 +9,6 @@ var (
ErrUnsupportedAuthMethod = errors.New("authentication method not supported for this user") ErrUnsupportedAuthMethod = errors.New("authentication method not supported for this user")
ErrInvalidCombination = errors.New("invalid account and token combination") ErrInvalidCombination = errors.New("invalid account and token combination")
ErrProcessTimeout = errors.New("authentication process timed out") ErrProcessTimeout = errors.New("authentication process timed out")
// A user may not login, for whatever reason
ErrCantLogin = errors.New("user can't login")
) )

View file

@ -5,10 +5,12 @@ import (
"git.mstar.dev/mstar/goutils/other" "git.mstar.dev/mstar/goutils/other"
"git.mstar.dev/mstar/goutils/sliceutils" "git.mstar.dev/mstar/goutils/sliceutils"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"git.mstar.dev/mstar/linstrom/storage-new"
"git.mstar.dev/mstar/linstrom/storage-new/dbgen" "git.mstar.dev/mstar/linstrom/storage-new/dbgen"
"git.mstar.dev/mstar/linstrom/storage-new/models" "git.mstar.dev/mstar/linstrom/storage-new/models"
) )
@ -79,8 +81,10 @@ func (a *Authenticator) StartPasswordLogin(
username string, username string,
password string, password string,
) (nextState LoginNextState, token string, err error) { ) (nextState LoginNextState, token string, err error) {
var acc *models.User if ok, err := a.canUsernameLogin(username); !ok {
acc, err = dbgen.User.Where(dbgen.User.Username.Eq(username)).First() return LoginNextFailure, "", other.Error("auth", "user may not login", err)
}
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
switch err { switch err {
case nil: case nil:
break break
@ -143,10 +147,10 @@ func (a *Authenticator) StartPasswordLogin(
loginToken := models.LoginProcessToken{ loginToken := models.LoginProcessToken{
User: *acc, User: *acc,
UserId: acc.ID, UserId: acc.ID,
ExpiresAt: time.Now().Add(time.Minute * 5), ExpiresAt: calcAccessExpirationTimestamp(),
Token: uuid.NewString(),
} }
err = dbgen.LoginProcessToken.Clauses(clause.OnConflict{DoNothing: true}). err = dbgen.LoginProcessToken.Clauses(clause.OnConflict{UpdateAll: true}).
Omit(dbgen.LoginProcessToken.Token).
Create(&loginToken) Create(&loginToken)
if err != nil { if err != nil {
@ -159,3 +163,24 @@ func (a *Authenticator) StartPasswordLogin(
return nextStates, loginToken.Token, nil return nextStates, loginToken.Token, nil
} }
func (a *Authenticator) canUsernameLogin(username string) (bool, error) {
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
if err != nil {
return false, err
}
if !acc.FinishedRegistration {
return false, ErrCantLogin
}
// TODO: Check roles too
finalRole := storage.CollapseRolesIntoOne(
sliceutils.Map(acc.Roles, func(t models.UserToRole) models.Role {
return t.Role
})...)
if finalRole.CanLogin != nil && !*finalRole.CanLogin {
return false, ErrCantLogin
}
return true, nil
}

View file

@ -15,6 +15,9 @@ import (
) )
func (a *Authenticator) StartPasskeyLogin(username string) (*protocol.CredentialAssertion, error) { func (a *Authenticator) StartPasskeyLogin(username string) (*protocol.CredentialAssertion, error) {
if ok, err := a.canUsernameLogin(username); !ok {
return nil, other.Error("auth", "user may not login", err)
}
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First() acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -0,0 +1,195 @@
/*
Tool for generating helper functions for storage.Role structs inside of the storage package
It generates the following functions:
- CollapseRolesIntoOne: Collapse a list of roles into one singular role. Each value will be set to the
value of the role with the highest priority
- RoleDeepCopy: Copy a role, including all arrays. Every value will be copied too
- CompareRoles: Compare two roles. Returns true only if all fields are equal
(if one of the fields is nil, that field is seen as equal)
*/
package main
import (
"flag"
"fmt"
"io"
"os"
"regexp"
"strings"
"git.mstar.dev/mstar/goutils/sliceutils"
)
var findRoleStructRegex = regexp.MustCompile(`type Role struct \{([\s\S]+)\}\n\n/\*`)
var (
flagInputFile = flag.String("input", "", "Specify the input file. If empty, read from stdin")
flagOutputFile = flag.String(
"output",
"",
"Specify the output file. If empty, writes to stdout",
)
flagOutputModule = flag.String(
"mod",
"",
"Module for the output file. If empty, use mod of input file",
)
)
func main() {
flag.Parse()
var input io.Reader
var output io.Writer
if *flagInputFile == "" {
input = os.Stdin
} else {
file, err := os.Open(*flagInputFile)
if err != nil {
panic(err)
}
defer file.Close()
input = file
}
if *flagOutputFile == "" {
output = os.Stdout
} else {
file, err := os.Create(*flagOutputFile)
if err != nil {
panic(err)
}
defer file.Close()
output = file
}
data, err := io.ReadAll(input)
if err != nil {
panic(err)
}
if !findRoleStructRegex.Match(data) {
panic("Input doesn't contain role struct")
}
content := findRoleStructRegex.FindStringSubmatch(string(data))[1]
lines := strings.Split(content, "\n")
lines = sliceutils.Map(lines, func(t string) string { return strings.TrimSpace(t) })
importantLines := sliceutils.Filter(lines, func(t string) bool {
if strings.HasPrefix(t, "//") {
return false
}
if strings.Contains(t, "gorm.Model") {
return false
}
data := sliceutils.Filter(strings.Split(t, " "), func(t string) bool { return t != "" })
if len(data) < 2 {
return false
}
if !strings.HasPrefix(data[1], "*") && !strings.HasPrefix(data[1], "[]") {
return false
}
return true
})
nameTypeMap := map[string]string{}
for _, line := range importantLines {
parts := sliceutils.Filter(strings.Split(line, " "), func(t string) bool { return t != "" })
nameTypeMap[parts[0]] = parts[1]
}
pkgString, _, _ := strings.Cut(string(data), "\n")
outBuilder := strings.Builder{}
outBuilder.WriteString(`// Code generated by cmd/RolesGenerator DO NOT EDIT.
// If you need to refresh the content, run go generate again
`)
if *flagOutputModule == "" {
outBuilder.WriteString(pkgString + "\n\n")
} else {
outBuilder.WriteString("package " + *flagOutputModule + "\n\n")
}
outBuilder.WriteString(
`import (
"slices"
"git.mstar.dev/mstar/goutils/sliceutils"
"git.mstar.dev/mstar/linstrom/storage-new/models"
)
`,
)
// Build role collapse function
outBuilder.WriteString(
`func CollapseRolesIntoOne(roles ...models.Role) models.Role {
startingRole := RoleDeepCopy(models.DefaultUserRole)
slices.SortFunc(roles, func(a, b models.Role) int { return int(int64(a.Priority)-int64(b.Priority)) })
for _, role := range roles {
`)
// Write all the stupid conditions here
for valName, valType := range nameTypeMap {
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(fmt.Sprintf(` if role.%s != nil {
startingRole.%s = append(startingRole.%s, role.%s...)
}
`, valName, valName, valName, valName))
} else {
outBuilder.WriteString(fmt.Sprintf(` if role.%s != nil {
*startingRole.%s = *role.%s
}
`, valName, valName, valName))
}
}
// Then finish up with the end of the function
outBuilder.WriteString(
` }
return startingRole
}
`)
// Then build the deep copy function
outBuilder.WriteString("\nfunc RoleDeepCopy(o models.Role) models.Role {\n")
outBuilder.WriteString(` n := models.Role{}
n.Model = o.Model
n.Name = o.Name
n.Priority = o.Priority
n.IsUserRole = o.IsUserRole
n.IsBuiltIn = o.IsBuiltIn
`)
for valName, valType := range nameTypeMap {
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(fmt.Sprintf(" n.%s = slices.Clone(o.%s)\n", valName, valName))
} else {
outBuilder.WriteString(fmt.Sprintf(` if o.%s == nil { n.%s = nil } else {
t := *o.%s
n.%s = &t
}
`, valName, valName, valName, valName))
}
}
outBuilder.WriteString(" return n\n}\n\n")
// Build compare function
outBuilder.WriteString("func CompareRoles(a, b *models.Role) bool {\n")
outBuilder.WriteString(" return ")
lastName, lastType := "", ""
for valName, valType := range nameTypeMap {
lastName = valName
lastType = valType
outBuilder.WriteString(fmt.Sprintf("(a.%s == nil || b.%s == nil || ", valName, valName))
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(
fmt.Sprintf("sliceutils.CompareUnordered(a.%s,b.%s)) && ", valName, valName),
)
} else {
outBuilder.WriteString(fmt.Sprintf("a.%s == b.%s) && ", valName, valName))
}
}
outBuilder.WriteString("(a == nil || b == nil || ")
if strings.HasPrefix(lastType, "[]") {
outBuilder.WriteString(
fmt.Sprintf("sliceutils.CompareUnordered(a.%s,b.%s))", lastName, lastName),
)
} else {
outBuilder.WriteString(fmt.Sprintf("a.%s == b.%s)", lastName, lastName))
}
outBuilder.WriteString("\n}")
// And write the entire thing to the output
fmt.Fprint(output, outBuilder.String())
}

View file

@ -1,19 +1,15 @@
package main package main
import ( import (
"context"
"flag" "flag"
"time"
"git.mstar.dev/mstar/goutils/other" "git.mstar.dev/mstar/goutils/other"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/testcontainers/testcontainers-go"
postgresContainer "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gen" "gorm.io/gen"
"gorm.io/gorm" "gorm.io/gorm"
"git.mstar.dev/mstar/linstrom/config"
"git.mstar.dev/mstar/linstrom/shared" "git.mstar.dev/mstar/linstrom/shared"
"git.mstar.dev/mstar/linstrom/storage-new/models" "git.mstar.dev/mstar/linstrom/storage-new/models"
) )
@ -36,36 +32,36 @@ func main() {
other.SetupFlags() other.SetupFlags()
flag.Parse() flag.Parse()
other.ConfigureLogging(nil) other.ConfigureLogging(nil)
// config.ReadAndWriteToGlobal(*shared.FlagConfigFile) config.ReadAndWriteToGlobal(*shared.FlagConfigFile)
// Set up a temporary postgres container for gorm-gen to do its thing // Set up a temporary postgres container for gorm-gen to do its thing
log.Info().Msg("Starting temporary postgres container") // log.Info().Msg("Starting temporary postgres container")
pgContainer, err := postgresContainer.Run( // pgContainer, err := postgresContainer.Run(
context.Background(), // context.Background(),
"postgres:16.4-alpine", // "postgres:16.4-alpine",
postgresContainer.WithDatabase(dbName), // postgresContainer.WithDatabase(dbName),
postgresContainer.WithUsername(dbUser), // postgresContainer.WithUsername(dbUser),
postgresContainer.WithPassword(dbPass), // postgresContainer.WithPassword(dbPass),
testcontainers.WithWaitStrategyAndDeadline( // testcontainers.WithWaitStrategyAndDeadline(
time.Minute, // time.Minute,
wait.ForLog("database system is ready to accept connections"). // wait.ForLog("database system is ready to accept connections").
WithOccurrence(2). // WithOccurrence(2).
WithStartupTimeout(time.Second*5), // WithStartupTimeout(time.Second*5),
), // ),
) // )
if err != nil { // if err != nil {
log.Fatal().Err(err).Msg("Failed to setup temporary postgres container") // log.Fatal().Err(err).Msg("Failed to setup temporary postgres container")
} // }
log.Info().Msg("Temporary postgres container started") // log.Info().Msg("Temporary postgres container started")
defer func() { // defer func() {
if err := testcontainers.TerminateContainer(pgContainer); err != nil { // if err := testcontainers.TerminateContainer(pgContainer); err != nil {
log.Fatal().Err(err).Msg("Failed to terminate temporary postgres container") // log.Fatal().Err(err).Msg("Failed to terminate temporary postgres container")
} // }
log.Info().Msg("Temporary postgres container stopped") // log.Info().Msg("Temporary postgres container stopped")
}() // }()
db, err := gorm.Open( db, err := gorm.Open(
// postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()), postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()),
postgres.Open(pgContainer.MustConnectionString(context.Background())), // postgres.Open(pgContainer.MustConnectionString(context.Background())),
&gorm.Config{ &gorm.Config{
PrepareStmt: false, PrepareStmt: false,
Logger: shared.NewGormLogger(log.Logger), Logger: shared.NewGormLogger(log.Logger),

View file

@ -28,6 +28,7 @@ func newLoginProcessToken(db *gorm.DB, opts ...gen.DOOption) loginProcessToken {
_loginProcessToken.ALL = field.NewAsterisk(tableName) _loginProcessToken.ALL = field.NewAsterisk(tableName)
_loginProcessToken.UserId = field.NewString(tableName, "user_id") _loginProcessToken.UserId = field.NewString(tableName, "user_id")
_loginProcessToken.Token = field.NewString(tableName, "token") _loginProcessToken.Token = field.NewString(tableName, "token")
_loginProcessToken.ExpiresAt = field.NewTime(tableName, "expires_at")
_loginProcessToken.User = loginProcessTokenBelongsToUser{ _loginProcessToken.User = loginProcessTokenBelongsToUser{
db: db.Session(&gorm.Session{}), db: db.Session(&gorm.Session{}),
@ -180,6 +181,7 @@ type loginProcessToken struct {
ALL field.Asterisk ALL field.Asterisk
UserId field.String UserId field.String
Token field.String Token field.String
ExpiresAt field.Time
User loginProcessTokenBelongsToUser User loginProcessTokenBelongsToUser
fieldMap map[string]field.Expr fieldMap map[string]field.Expr
@ -199,6 +201,7 @@ func (l *loginProcessToken) updateTableName(table string) *loginProcessToken {
l.ALL = field.NewAsterisk(table) l.ALL = field.NewAsterisk(table)
l.UserId = field.NewString(table, "user_id") l.UserId = field.NewString(table, "user_id")
l.Token = field.NewString(table, "token") l.Token = field.NewString(table, "token")
l.ExpiresAt = field.NewTime(table, "expires_at")
l.fillFieldMap() l.fillFieldMap()
@ -215,9 +218,10 @@ func (l *loginProcessToken) GetFieldByName(fieldName string) (field.OrderExpr, b
} }
func (l *loginProcessToken) fillFieldMap() { func (l *loginProcessToken) fillFieldMap() {
l.fieldMap = make(map[string]field.Expr, 3) l.fieldMap = make(map[string]field.Expr, 4)
l.fieldMap["user_id"] = l.UserId l.fieldMap["user_id"] = l.UserId
l.fieldMap["token"] = l.Token l.fieldMap["token"] = l.Token
l.fieldMap["expires_at"] = l.ExpiresAt
} }

View file

@ -36,14 +36,16 @@ func newUser(db *gorm.DB, opts ...gen.DOOption) user {
_user.Description = field.NewString(tableName, "description") _user.Description = field.NewString(tableName, "description")
_user.IsBot = field.NewBool(tableName, "is_bot") _user.IsBot = field.NewBool(tableName, "is_bot")
_user.IconId = field.NewString(tableName, "icon_id") _user.IconId = field.NewString(tableName, "icon_id")
_user.BackgroundId = field.NewString(tableName, "background_id") _user.BackgroundId = field.NewField(tableName, "background_id")
_user.BannerId = field.NewString(tableName, "banner_id") _user.BannerId = field.NewField(tableName, "banner_id")
_user.Indexable = field.NewBool(tableName, "indexable") _user.Indexable = field.NewBool(tableName, "indexable")
_user.PublicKey = field.NewBytes(tableName, "public_key") _user.PublicKey = field.NewBytes(tableName, "public_key")
_user.RestrictedFollow = field.NewBool(tableName, "restricted_follow") _user.RestrictedFollow = field.NewBool(tableName, "restricted_follow")
_user.Location = field.NewString(tableName, "location") _user.Location = field.NewField(tableName, "location")
_user.Birthday = field.NewTime(tableName, "birthday") _user.Birthday = field.NewField(tableName, "birthday")
_user.Verified = field.NewBool(tableName, "verified") _user.Verified = field.NewBool(tableName, "verified")
_user.PasskeyId = field.NewBytes(tableName, "passkey_id")
_user.FinishedRegistration = field.NewBool(tableName, "finished_registration")
_user.RemoteInfo = userHasOneRemoteInfo{ _user.RemoteInfo = userHasOneRemoteInfo{
db: db.Session(&gorm.Session{}), db: db.Session(&gorm.Session{}),
@ -321,14 +323,16 @@ type user struct {
Description field.String Description field.String
IsBot field.Bool IsBot field.Bool
IconId field.String IconId field.String
BackgroundId field.String BackgroundId field.Field
BannerId field.String BannerId field.Field
Indexable field.Bool Indexable field.Bool
PublicKey field.Bytes PublicKey field.Bytes
RestrictedFollow field.Bool RestrictedFollow field.Bool
Location field.String Location field.Field
Birthday field.Time Birthday field.Field
Verified field.Bool Verified field.Bool
PasskeyId field.Bytes
FinishedRegistration field.Bool
RemoteInfo userHasOneRemoteInfo RemoteInfo userHasOneRemoteInfo
InfoFields userHasManyInfoFields InfoFields userHasManyInfoFields
@ -376,14 +380,16 @@ func (u *user) updateTableName(table string) *user {
u.Description = field.NewString(table, "description") u.Description = field.NewString(table, "description")
u.IsBot = field.NewBool(table, "is_bot") u.IsBot = field.NewBool(table, "is_bot")
u.IconId = field.NewString(table, "icon_id") u.IconId = field.NewString(table, "icon_id")
u.BackgroundId = field.NewString(table, "background_id") u.BackgroundId = field.NewField(table, "background_id")
u.BannerId = field.NewString(table, "banner_id") u.BannerId = field.NewField(table, "banner_id")
u.Indexable = field.NewBool(table, "indexable") u.Indexable = field.NewBool(table, "indexable")
u.PublicKey = field.NewBytes(table, "public_key") u.PublicKey = field.NewBytes(table, "public_key")
u.RestrictedFollow = field.NewBool(table, "restricted_follow") u.RestrictedFollow = field.NewBool(table, "restricted_follow")
u.Location = field.NewString(table, "location") u.Location = field.NewField(table, "location")
u.Birthday = field.NewTime(table, "birthday") u.Birthday = field.NewField(table, "birthday")
u.Verified = field.NewBool(table, "verified") u.Verified = field.NewBool(table, "verified")
u.PasskeyId = field.NewBytes(table, "passkey_id")
u.FinishedRegistration = field.NewBool(table, "finished_registration")
u.fillFieldMap() u.fillFieldMap()
@ -400,7 +406,7 @@ func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
} }
func (u *user) fillFieldMap() { func (u *user) fillFieldMap() {
u.fieldMap = make(map[string]field.Expr, 29) u.fieldMap = make(map[string]field.Expr, 31)
u.fieldMap["id"] = u.ID u.fieldMap["id"] = u.ID
u.fieldMap["username"] = u.Username u.fieldMap["username"] = u.Username
u.fieldMap["created_at"] = u.CreatedAt u.fieldMap["created_at"] = u.CreatedAt
@ -419,6 +425,8 @@ func (u *user) fillFieldMap() {
u.fieldMap["location"] = u.Location u.fieldMap["location"] = u.Location
u.fieldMap["birthday"] = u.Birthday u.fieldMap["birthday"] = u.Birthday
u.fieldMap["verified"] = u.Verified u.fieldMap["verified"] = u.Verified
u.fieldMap["passkey_id"] = u.PasskeyId
u.fieldMap["finished_registration"] = u.FinishedRegistration
} }

View file

@ -71,6 +71,7 @@ type User struct {
// In theory, could also slash Id in half, but that would be a lot more calculations than the // In theory, could also slash Id in half, but that would be a lot more calculations than the
// saved space is worth // saved space is worth
PasskeyId []byte PasskeyId []byte
FinishedRegistration bool // Whether this account has completed registration yet
// ---- "Remote" linked values // ---- "Remote" linked values
InfoFields []UserInfoField InfoFields []UserInfoField

File diff suppressed because one or more lines are too long

View file

@ -1 +1,3 @@
package storage package storage
//go:generate go run ../cmd/NewRoleHelperGenerator/main.go -input ./models/Role.go -output role_generated.go -mod storage