Compare commits

...

4 commits

Author SHA1 Message Date
582988add2
Add go:generate command and new field to user
Some checks are pending
/ test (push) Waiting to run
2025-03-31 08:08:01 +02:00
8ba5e98c50
Idk why, but testcontainers is borked on Discovery 2025-03-31 08:07:24 +02:00
7d385e48de
More work on auth system 2025-03-31 08:07:16 +02:00
2afb14c4b3
Copy and adapt role helper generator to new system 2025-03-31 08:06:43 +02:00
11 changed files with 650 additions and 72 deletions

View file

@ -17,5 +17,7 @@ type Authenticator struct {
}
func calcAccessExpirationTimestamp() time.Time {
// For now, the default expiration is one month after creation
// though "never" might also be a good option
return time.Now().Add(time.Hour * 24 * 30)
}

View file

@ -9,4 +9,6 @@ var (
ErrUnsupportedAuthMethod = errors.New("authentication method not supported for this user")
ErrInvalidCombination = errors.New("invalid account and token combination")
ErrProcessTimeout = errors.New("authentication process timed out")
// A user may not login, for whatever reason
ErrCantLogin = errors.New("user can't login")
)

View file

@ -5,10 +5,12 @@ import (
"git.mstar.dev/mstar/goutils/other"
"git.mstar.dev/mstar/goutils/sliceutils"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"git.mstar.dev/mstar/linstrom/storage-new"
"git.mstar.dev/mstar/linstrom/storage-new/dbgen"
"git.mstar.dev/mstar/linstrom/storage-new/models"
)
@ -79,8 +81,10 @@ func (a *Authenticator) StartPasswordLogin(
username string,
password string,
) (nextState LoginNextState, token string, err error) {
var acc *models.User
acc, err = dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
if ok, err := a.canUsernameLogin(username); !ok {
return LoginNextFailure, "", other.Error("auth", "user may not login", err)
}
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
switch err {
case nil:
break
@ -143,10 +147,10 @@ func (a *Authenticator) StartPasswordLogin(
loginToken := models.LoginProcessToken{
User: *acc,
UserId: acc.ID,
ExpiresAt: time.Now().Add(time.Minute * 5),
ExpiresAt: calcAccessExpirationTimestamp(),
Token: uuid.NewString(),
}
err = dbgen.LoginProcessToken.Clauses(clause.OnConflict{DoNothing: true}).
Omit(dbgen.LoginProcessToken.Token).
err = dbgen.LoginProcessToken.Clauses(clause.OnConflict{UpdateAll: true}).
Create(&loginToken)
if err != nil {
@ -159,3 +163,24 @@ func (a *Authenticator) StartPasswordLogin(
return nextStates, loginToken.Token, nil
}
func (a *Authenticator) canUsernameLogin(username string) (bool, error) {
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
if err != nil {
return false, err
}
if !acc.FinishedRegistration {
return false, ErrCantLogin
}
// TODO: Check roles too
finalRole := storage.CollapseRolesIntoOne(
sliceutils.Map(acc.Roles, func(t models.UserToRole) models.Role {
return t.Role
})...)
if finalRole.CanLogin != nil && !*finalRole.CanLogin {
return false, ErrCantLogin
}
return true, nil
}

View file

@ -15,6 +15,9 @@ import (
)
func (a *Authenticator) StartPasskeyLogin(username string) (*protocol.CredentialAssertion, error) {
if ok, err := a.canUsernameLogin(username); !ok {
return nil, other.Error("auth", "user may not login", err)
}
acc, err := dbgen.User.Where(dbgen.User.Username.Eq(username)).First()
if err != nil {
return nil, err

View file

@ -0,0 +1,195 @@
/*
Tool for generating helper functions for storage.Role structs inside of the storage package
It generates the following functions:
- CollapseRolesIntoOne: Collapse a list of roles into one singular role. Each value will be set to the
value of the role with the highest priority
- RoleDeepCopy: Copy a role, including all arrays. Every value will be copied too
- CompareRoles: Compare two roles. Returns true only if all fields are equal
(if one of the fields is nil, that field is seen as equal)
*/
package main
import (
"flag"
"fmt"
"io"
"os"
"regexp"
"strings"
"git.mstar.dev/mstar/goutils/sliceutils"
)
var findRoleStructRegex = regexp.MustCompile(`type Role struct \{([\s\S]+)\}\n\n/\*`)
var (
flagInputFile = flag.String("input", "", "Specify the input file. If empty, read from stdin")
flagOutputFile = flag.String(
"output",
"",
"Specify the output file. If empty, writes to stdout",
)
flagOutputModule = flag.String(
"mod",
"",
"Module for the output file. If empty, use mod of input file",
)
)
func main() {
flag.Parse()
var input io.Reader
var output io.Writer
if *flagInputFile == "" {
input = os.Stdin
} else {
file, err := os.Open(*flagInputFile)
if err != nil {
panic(err)
}
defer file.Close()
input = file
}
if *flagOutputFile == "" {
output = os.Stdout
} else {
file, err := os.Create(*flagOutputFile)
if err != nil {
panic(err)
}
defer file.Close()
output = file
}
data, err := io.ReadAll(input)
if err != nil {
panic(err)
}
if !findRoleStructRegex.Match(data) {
panic("Input doesn't contain role struct")
}
content := findRoleStructRegex.FindStringSubmatch(string(data))[1]
lines := strings.Split(content, "\n")
lines = sliceutils.Map(lines, func(t string) string { return strings.TrimSpace(t) })
importantLines := sliceutils.Filter(lines, func(t string) bool {
if strings.HasPrefix(t, "//") {
return false
}
if strings.Contains(t, "gorm.Model") {
return false
}
data := sliceutils.Filter(strings.Split(t, " "), func(t string) bool { return t != "" })
if len(data) < 2 {
return false
}
if !strings.HasPrefix(data[1], "*") && !strings.HasPrefix(data[1], "[]") {
return false
}
return true
})
nameTypeMap := map[string]string{}
for _, line := range importantLines {
parts := sliceutils.Filter(strings.Split(line, " "), func(t string) bool { return t != "" })
nameTypeMap[parts[0]] = parts[1]
}
pkgString, _, _ := strings.Cut(string(data), "\n")
outBuilder := strings.Builder{}
outBuilder.WriteString(`// Code generated by cmd/RolesGenerator DO NOT EDIT.
// If you need to refresh the content, run go generate again
`)
if *flagOutputModule == "" {
outBuilder.WriteString(pkgString + "\n\n")
} else {
outBuilder.WriteString("package " + *flagOutputModule + "\n\n")
}
outBuilder.WriteString(
`import (
"slices"
"git.mstar.dev/mstar/goutils/sliceutils"
"git.mstar.dev/mstar/linstrom/storage-new/models"
)
`,
)
// Build role collapse function
outBuilder.WriteString(
`func CollapseRolesIntoOne(roles ...models.Role) models.Role {
startingRole := RoleDeepCopy(models.DefaultUserRole)
slices.SortFunc(roles, func(a, b models.Role) int { return int(int64(a.Priority)-int64(b.Priority)) })
for _, role := range roles {
`)
// Write all the stupid conditions here
for valName, valType := range nameTypeMap {
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(fmt.Sprintf(` if role.%s != nil {
startingRole.%s = append(startingRole.%s, role.%s...)
}
`, valName, valName, valName, valName))
} else {
outBuilder.WriteString(fmt.Sprintf(` if role.%s != nil {
*startingRole.%s = *role.%s
}
`, valName, valName, valName))
}
}
// Then finish up with the end of the function
outBuilder.WriteString(
` }
return startingRole
}
`)
// Then build the deep copy function
outBuilder.WriteString("\nfunc RoleDeepCopy(o models.Role) models.Role {\n")
outBuilder.WriteString(` n := models.Role{}
n.Model = o.Model
n.Name = o.Name
n.Priority = o.Priority
n.IsUserRole = o.IsUserRole
n.IsBuiltIn = o.IsBuiltIn
`)
for valName, valType := range nameTypeMap {
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(fmt.Sprintf(" n.%s = slices.Clone(o.%s)\n", valName, valName))
} else {
outBuilder.WriteString(fmt.Sprintf(` if o.%s == nil { n.%s = nil } else {
t := *o.%s
n.%s = &t
}
`, valName, valName, valName, valName))
}
}
outBuilder.WriteString(" return n\n}\n\n")
// Build compare function
outBuilder.WriteString("func CompareRoles(a, b *models.Role) bool {\n")
outBuilder.WriteString(" return ")
lastName, lastType := "", ""
for valName, valType := range nameTypeMap {
lastName = valName
lastType = valType
outBuilder.WriteString(fmt.Sprintf("(a.%s == nil || b.%s == nil || ", valName, valName))
if strings.HasPrefix(valType, "[]") {
outBuilder.WriteString(
fmt.Sprintf("sliceutils.CompareUnordered(a.%s,b.%s)) && ", valName, valName),
)
} else {
outBuilder.WriteString(fmt.Sprintf("a.%s == b.%s) && ", valName, valName))
}
}
outBuilder.WriteString("(a == nil || b == nil || ")
if strings.HasPrefix(lastType, "[]") {
outBuilder.WriteString(
fmt.Sprintf("sliceutils.CompareUnordered(a.%s,b.%s))", lastName, lastName),
)
} else {
outBuilder.WriteString(fmt.Sprintf("a.%s == b.%s)", lastName, lastName))
}
outBuilder.WriteString("\n}")
// And write the entire thing to the output
fmt.Fprint(output, outBuilder.String())
}

View file

@ -1,19 +1,15 @@
package main
import (
"context"
"flag"
"time"
"git.mstar.dev/mstar/goutils/other"
"github.com/rs/zerolog/log"
"github.com/testcontainers/testcontainers-go"
postgresContainer "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"gorm.io/driver/postgres"
"gorm.io/gen"
"gorm.io/gorm"
"git.mstar.dev/mstar/linstrom/config"
"git.mstar.dev/mstar/linstrom/shared"
"git.mstar.dev/mstar/linstrom/storage-new/models"
)
@ -36,36 +32,36 @@ func main() {
other.SetupFlags()
flag.Parse()
other.ConfigureLogging(nil)
// config.ReadAndWriteToGlobal(*shared.FlagConfigFile)
config.ReadAndWriteToGlobal(*shared.FlagConfigFile)
// Set up a temporary postgres container for gorm-gen to do its thing
log.Info().Msg("Starting temporary postgres container")
pgContainer, err := postgresContainer.Run(
context.Background(),
"postgres:16.4-alpine",
postgresContainer.WithDatabase(dbName),
postgresContainer.WithUsername(dbUser),
postgresContainer.WithPassword(dbPass),
testcontainers.WithWaitStrategyAndDeadline(
time.Minute,
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(time.Second*5),
),
)
if err != nil {
log.Fatal().Err(err).Msg("Failed to setup temporary postgres container")
}
log.Info().Msg("Temporary postgres container started")
defer func() {
if err := testcontainers.TerminateContainer(pgContainer); err != nil {
log.Fatal().Err(err).Msg("Failed to terminate temporary postgres container")
}
log.Info().Msg("Temporary postgres container stopped")
}()
// log.Info().Msg("Starting temporary postgres container")
// pgContainer, err := postgresContainer.Run(
// context.Background(),
// "postgres:16.4-alpine",
// postgresContainer.WithDatabase(dbName),
// postgresContainer.WithUsername(dbUser),
// postgresContainer.WithPassword(dbPass),
// testcontainers.WithWaitStrategyAndDeadline(
// time.Minute,
// wait.ForLog("database system is ready to accept connections").
// WithOccurrence(2).
// WithStartupTimeout(time.Second*5),
// ),
// )
// if err != nil {
// log.Fatal().Err(err).Msg("Failed to setup temporary postgres container")
// }
// log.Info().Msg("Temporary postgres container started")
// defer func() {
// if err := testcontainers.TerminateContainer(pgContainer); err != nil {
// log.Fatal().Err(err).Msg("Failed to terminate temporary postgres container")
// }
// log.Info().Msg("Temporary postgres container stopped")
// }()
db, err := gorm.Open(
// postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()),
postgres.Open(pgContainer.MustConnectionString(context.Background())),
postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()),
// postgres.Open(pgContainer.MustConnectionString(context.Background())),
&gorm.Config{
PrepareStmt: false,
Logger: shared.NewGormLogger(log.Logger),

View file

@ -28,6 +28,7 @@ func newLoginProcessToken(db *gorm.DB, opts ...gen.DOOption) loginProcessToken {
_loginProcessToken.ALL = field.NewAsterisk(tableName)
_loginProcessToken.UserId = field.NewString(tableName, "user_id")
_loginProcessToken.Token = field.NewString(tableName, "token")
_loginProcessToken.ExpiresAt = field.NewTime(tableName, "expires_at")
_loginProcessToken.User = loginProcessTokenBelongsToUser{
db: db.Session(&gorm.Session{}),
@ -177,10 +178,11 @@ func newLoginProcessToken(db *gorm.DB, opts ...gen.DOOption) loginProcessToken {
type loginProcessToken struct {
loginProcessTokenDo
ALL field.Asterisk
UserId field.String
Token field.String
User loginProcessTokenBelongsToUser
ALL field.Asterisk
UserId field.String
Token field.String
ExpiresAt field.Time
User loginProcessTokenBelongsToUser
fieldMap map[string]field.Expr
}
@ -199,6 +201,7 @@ func (l *loginProcessToken) updateTableName(table string) *loginProcessToken {
l.ALL = field.NewAsterisk(table)
l.UserId = field.NewString(table, "user_id")
l.Token = field.NewString(table, "token")
l.ExpiresAt = field.NewTime(table, "expires_at")
l.fillFieldMap()
@ -215,9 +218,10 @@ func (l *loginProcessToken) GetFieldByName(fieldName string) (field.OrderExpr, b
}
func (l *loginProcessToken) fillFieldMap() {
l.fieldMap = make(map[string]field.Expr, 3)
l.fieldMap = make(map[string]field.Expr, 4)
l.fieldMap["user_id"] = l.UserId
l.fieldMap["token"] = l.Token
l.fieldMap["expires_at"] = l.ExpiresAt
}

View file

@ -36,14 +36,16 @@ func newUser(db *gorm.DB, opts ...gen.DOOption) user {
_user.Description = field.NewString(tableName, "description")
_user.IsBot = field.NewBool(tableName, "is_bot")
_user.IconId = field.NewString(tableName, "icon_id")
_user.BackgroundId = field.NewString(tableName, "background_id")
_user.BannerId = field.NewString(tableName, "banner_id")
_user.BackgroundId = field.NewField(tableName, "background_id")
_user.BannerId = field.NewField(tableName, "banner_id")
_user.Indexable = field.NewBool(tableName, "indexable")
_user.PublicKey = field.NewBytes(tableName, "public_key")
_user.RestrictedFollow = field.NewBool(tableName, "restricted_follow")
_user.Location = field.NewString(tableName, "location")
_user.Birthday = field.NewTime(tableName, "birthday")
_user.Location = field.NewField(tableName, "location")
_user.Birthday = field.NewField(tableName, "birthday")
_user.Verified = field.NewBool(tableName, "verified")
_user.PasskeyId = field.NewBytes(tableName, "passkey_id")
_user.FinishedRegistration = field.NewBool(tableName, "finished_registration")
_user.RemoteInfo = userHasOneRemoteInfo{
db: db.Session(&gorm.Session{}),
@ -310,26 +312,28 @@ func newUser(db *gorm.DB, opts ...gen.DOOption) user {
type user struct {
userDo
ALL field.Asterisk
ID field.String
Username field.String
CreatedAt field.Time
UpdatedAt field.Time
DeletedAt field.Field
ServerId field.Uint
DisplayName field.String
Description field.String
IsBot field.Bool
IconId field.String
BackgroundId field.String
BannerId field.String
Indexable field.Bool
PublicKey field.Bytes
RestrictedFollow field.Bool
Location field.String
Birthday field.Time
Verified field.Bool
RemoteInfo userHasOneRemoteInfo
ALL field.Asterisk
ID field.String
Username field.String
CreatedAt field.Time
UpdatedAt field.Time
DeletedAt field.Field
ServerId field.Uint
DisplayName field.String
Description field.String
IsBot field.Bool
IconId field.String
BackgroundId field.Field
BannerId field.Field
Indexable field.Bool
PublicKey field.Bytes
RestrictedFollow field.Bool
Location field.Field
Birthday field.Field
Verified field.Bool
PasskeyId field.Bytes
FinishedRegistration field.Bool
RemoteInfo userHasOneRemoteInfo
InfoFields userHasManyInfoFields
@ -376,14 +380,16 @@ func (u *user) updateTableName(table string) *user {
u.Description = field.NewString(table, "description")
u.IsBot = field.NewBool(table, "is_bot")
u.IconId = field.NewString(table, "icon_id")
u.BackgroundId = field.NewString(table, "background_id")
u.BannerId = field.NewString(table, "banner_id")
u.BackgroundId = field.NewField(table, "background_id")
u.BannerId = field.NewField(table, "banner_id")
u.Indexable = field.NewBool(table, "indexable")
u.PublicKey = field.NewBytes(table, "public_key")
u.RestrictedFollow = field.NewBool(table, "restricted_follow")
u.Location = field.NewString(table, "location")
u.Birthday = field.NewTime(table, "birthday")
u.Location = field.NewField(table, "location")
u.Birthday = field.NewField(table, "birthday")
u.Verified = field.NewBool(table, "verified")
u.PasskeyId = field.NewBytes(table, "passkey_id")
u.FinishedRegistration = field.NewBool(table, "finished_registration")
u.fillFieldMap()
@ -400,7 +406,7 @@ func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (u *user) fillFieldMap() {
u.fieldMap = make(map[string]field.Expr, 29)
u.fieldMap = make(map[string]field.Expr, 31)
u.fieldMap["id"] = u.ID
u.fieldMap["username"] = u.Username
u.fieldMap["created_at"] = u.CreatedAt
@ -419,6 +425,8 @@ func (u *user) fillFieldMap() {
u.fieldMap["location"] = u.Location
u.fieldMap["birthday"] = u.Birthday
u.fieldMap["verified"] = u.Verified
u.fieldMap["passkey_id"] = u.PasskeyId
u.fieldMap["finished_registration"] = u.FinishedRegistration
}

View file

@ -70,7 +70,8 @@ type User struct {
// 64 byte unique id for passkeys, because UUIDs are 128 bytes and passkey spec says 64 bytes max
// In theory, could also slash Id in half, but that would be a lot more calculations than the
// saved space is worth
PasskeyId []byte
PasskeyId []byte
FinishedRegistration bool // Whether this account has completed registration yet
// ---- "Remote" linked values
InfoFields []UserInfoField

File diff suppressed because one or more lines are too long

View file

@ -1 +1,3 @@
package storage
//go:generate go run ../cmd/NewRoleHelperGenerator/main.go -input ./models/Role.go -output role_generated.go -mod storage