Compare commits
3 commits
8f53e8a967
...
4ef9e19fbc
Author | SHA1 | Date | |
---|---|---|---|
4ef9e19fbc | |||
6f16289b41 | |||
6a2b213787 |
7 changed files with 152 additions and 7 deletions
29
auth-new/accessTokens.go
Normal file
29
auth-new/accessTokens.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mstar.dev/mstar/goutils/other"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"git.mstar.dev/mstar/linstrom/storage-new/dbgen"
|
||||||
|
"git.mstar.dev/mstar/linstrom/storage-new/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check whether a given access token is valid (exists and hasn't expired).
|
||||||
|
// If it is, returns the user it belongs to
|
||||||
|
func (a *Authenticator) IsValidAccessToken(token string) (*models.User, error) {
|
||||||
|
dbToken, err := dbgen.AccessToken.GetTokenIfValid(token)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
if dbToken.ExpiresAt.Before(time.Now()) {
|
||||||
|
return nil, ErrTokenExpired
|
||||||
|
} else {
|
||||||
|
return &dbToken.User, nil
|
||||||
|
}
|
||||||
|
case gorm.ErrRecordNotFound:
|
||||||
|
return nil, ErrTokenNotFound
|
||||||
|
default:
|
||||||
|
return nil, other.Error("auth", "failed to check for token", err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,6 +26,10 @@ var (
|
||||||
ErrInvalidPasskeyRegistrationData = errors.New(
|
ErrInvalidPasskeyRegistrationData = errors.New(
|
||||||
"stored passkey registration data was formatted badly",
|
"stored passkey registration data was formatted badly",
|
||||||
)
|
)
|
||||||
|
// The given token has expired
|
||||||
|
ErrTokenExpired = errors.New("token expired")
|
||||||
|
// The given token doesn't exist
|
||||||
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helper error type to combine two errors into one
|
// Helper error type to combine two errors into one
|
||||||
|
|
|
@ -18,12 +18,6 @@ import (
|
||||||
"git.mstar.dev/mstar/linstrom/storage-new/models"
|
"git.mstar.dev/mstar/linstrom/storage-new/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
dbName = "linstrom"
|
|
||||||
dbUser = "linstrom"
|
|
||||||
dbPass = "linstrom"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
other.SetupFlags()
|
other.SetupFlags()
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
@ -32,7 +26,6 @@ func main() {
|
||||||
|
|
||||||
db, err := gorm.Open(
|
db, err := gorm.Open(
|
||||||
postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()),
|
postgres.Open(config.GlobalConfig.Storage.BuildPostgresDSN()),
|
||||||
// postgres.Open(pgContainer.MustConnectionString(context.Background())),
|
|
||||||
&gorm.Config{
|
&gorm.Config{
|
||||||
PrepareStmt: false,
|
PrepareStmt: false,
|
||||||
Logger: shared.NewGormLogger(log.Logger),
|
Logger: shared.NewGormLogger(log.Logger),
|
||||||
|
@ -54,6 +47,7 @@ func main() {
|
||||||
log.Info().Msg("Basic operations applied, applying extra features")
|
log.Info().Msg("Basic operations applied, applying extra features")
|
||||||
g.ApplyInterface(func(models.INotification) {}, models.Notification{})
|
g.ApplyInterface(func(models.INotification) {}, models.Notification{})
|
||||||
g.ApplyInterface(func(models.IUser) {}, models.User{})
|
g.ApplyInterface(func(models.IUser) {}, models.User{})
|
||||||
|
g.ApplyInterface(func(models.IAccessToken) {}, models.AccessToken{})
|
||||||
|
|
||||||
log.Info().Msg("Extra features applied, starting generation")
|
log.Info().Msg("Extra features applied, starting generation")
|
||||||
g.Execute()
|
g.Execute()
|
||||||
|
|
19
storage-new/cleaners/ExpireAccessTokens.go
Normal file
19
storage-new/cleaners/ExpireAccessTokens.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package cleaners
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mstar.dev/mstar/linstrom/storage-new/dbgen"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
cleanerBuilders = append(cleanerBuilders, buildExpireAccessTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tickExpireAccessTokens(now time.Time) {
|
||||||
|
dbgen.AccessToken.Where(dbgen.AccessToken.ExpiresAt.Lt(time.Now())).Delete()
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildExpireAccessTokens() (onTick func(time.Time), name string, tickSpeed time.Duration) {
|
||||||
|
return tickExpireAccessTokens, "expire-access-tokens", time.Hour
|
||||||
|
}
|
70
storage-new/cleaners/manager.go
Normal file
70
storage-new/cleaners/manager.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package cleaners
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CleanerManager struct {
|
||||||
|
activeCleaners map[string]bool
|
||||||
|
activeCleanerLock sync.Mutex
|
||||||
|
exitChans []chan any
|
||||||
|
}
|
||||||
|
|
||||||
|
var cleanerBuilders = []func() (onTick func(time.Time), name string, tickSpeed time.Duration){}
|
||||||
|
|
||||||
|
func NewManager() *CleanerManager {
|
||||||
|
activeCleaners := make(map[string]bool)
|
||||||
|
exitChans := []chan any{}
|
||||||
|
cm := &CleanerManager{
|
||||||
|
activeCleaners: activeCleaners,
|
||||||
|
exitChans: exitChans,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch all cleaner tickers in a new goroutine each
|
||||||
|
for _, builder := range cleanerBuilders {
|
||||||
|
exitChan := make(chan any, 1)
|
||||||
|
onTick, name, tickSpeed := builder()
|
||||||
|
cm.exitChans = append(cm.exitChans, exitChan)
|
||||||
|
go cm.tickOrExit(tickSpeed, name, exitChan, onTick)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CleanerManager) Stop() {
|
||||||
|
for _, exitChan := range m.exitChans {
|
||||||
|
exitChan <- 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CleanerManager) tickOrExit(
|
||||||
|
tickSpeed time.Duration,
|
||||||
|
name string,
|
||||||
|
exitChan chan any,
|
||||||
|
onTick func(time.Time),
|
||||||
|
) {
|
||||||
|
ticker := time.Tick(tickSpeed)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case now := <-ticker:
|
||||||
|
go m.wrapOnTick(name, now, onTick)
|
||||||
|
case <-exitChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CleanerManager) wrapOnTick(name string, now time.Time, onTick func(time.Time)) {
|
||||||
|
m.activeCleanerLock.Lock()
|
||||||
|
if m.activeCleaners[name] {
|
||||||
|
m.activeCleanerLock.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.activeCleaners[name] = true
|
||||||
|
m.activeCleanerLock.Unlock()
|
||||||
|
onTick(now)
|
||||||
|
m.activeCleanerLock.Lock()
|
||||||
|
m.activeCleaners[name] = false
|
||||||
|
m.activeCleanerLock.Unlock()
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ package dbgen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.mstar.dev/mstar/linstrom/storage-new/models"
|
"git.mstar.dev/mstar/linstrom/storage-new/models"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
@ -435,6 +436,25 @@ type IAccessTokenDo interface {
|
||||||
Returning(value interface{}, columns ...string) IAccessTokenDo
|
Returning(value interface{}, columns ...string) IAccessTokenDo
|
||||||
UnderlyingDB() *gorm.DB
|
UnderlyingDB() *gorm.DB
|
||||||
schema.Tabler
|
schema.Tabler
|
||||||
|
|
||||||
|
GetTokenIfValid(token string) (result *models.AccessToken, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the data for a token if it hasn't expired yet
|
||||||
|
//
|
||||||
|
// SELECT * FROM @@table WHERE token = @token AND expires_at < NOW() LIMIT 1
|
||||||
|
func (a accessTokenDo) GetTokenIfValid(token string) (result *models.AccessToken, err error) {
|
||||||
|
var params []interface{}
|
||||||
|
|
||||||
|
var generateSQL strings.Builder
|
||||||
|
params = append(params, token)
|
||||||
|
generateSQL.WriteString("SELECT * FROM access_tokens WHERE token = ? AND expires_at < NOW() LIMIT 1 ")
|
||||||
|
|
||||||
|
var executeSQL *gorm.DB
|
||||||
|
executeSQL = a.UnderlyingDB().Raw(generateSQL.String(), params...).Take(&result) // ignore_security_alert
|
||||||
|
err = executeSQL.Error
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a accessTokenDo) Debug() IAccessTokenDo {
|
func (a accessTokenDo) Debug() IAccessTokenDo {
|
||||||
|
|
|
@ -2,6 +2,8 @@ package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gen"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccessToken maps a unique token to one account.
|
// AccessToken maps a unique token to one account.
|
||||||
|
@ -18,3 +20,10 @@ type AccessToken struct {
|
||||||
// at a point in the future this server should never reach
|
// at a point in the future this server should never reach
|
||||||
ExpiresAt time.Time `gorm:"default:TIMESTAMP WITH TIME ZONE '9999-12-30 23:59:59+00'"`
|
ExpiresAt time.Time `gorm:"default:TIMESTAMP WITH TIME ZONE '9999-12-30 23:59:59+00'"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type IAccessToken interface {
|
||||||
|
// Get the data for a token
|
||||||
|
//
|
||||||
|
// SELECT * FROM @@table WHERE token = @token
|
||||||
|
GetTokenIfValid(token string) (*gen.T, error)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue