diff --git a/client/core.go b/client/core.go index 29daeff5..9a7818f2 100644 --- a/client/core.go +++ b/client/core.go @@ -18,17 +18,22 @@ const ( // User represents a user in the system. type User struct { - ID int64 `json:"id"` - Username string `json:"username"` - Password string `json:"password,omitempty"` - IsAdmin bool `json:"is_admin"` - Theme string `json:"theme"` - Language string `json:"language"` - Timezone string `json:"timezone"` - EntryDirection string `json:"entry_sorting_direction"` - EntriesPerPage int `json:"entries_per_page"` - LastLoginAt *time.Time `json:"last_login_at"` - Extra map[string]string `json:"extra"` + ID int64 `json:"id"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + IsAdmin bool `json:"is_admin"` + Theme string `json:"theme"` + Language string `json:"language"` + Timezone string `json:"timezone"` + EntryDirection string `json:"entry_sorting_direction"` + Stylesheet string `json:"stylesheet"` + GoogleID string `json:"google_id"` + OpenIDConnectID string `json:"openid_connect_id"` + EntriesPerPage int `json:"entries_per_page"` + KeyboardShortcuts bool `json:"keyboard_shortcuts"` + ShowReadingTime bool `json:"show_reading_time"` + EntrySwipe bool `json:"entry_swipe"` + LastLoginAt *time.Time `json:"last_login_at"` } func (u User) String() string { diff --git a/database/migrations.go b/database/migrations.go index 01209c65..f5177d06 100644 --- a/database/migrations.go +++ b/database/migrations.go @@ -4,7 +4,9 @@ package database // import "miniflux.app/database" -import "database/sql" +import ( + "database/sql" +) var schemaVersion = len(migrations) @@ -427,4 +429,71 @@ var migrations = []func(tx *sql.Tx) error{ _, err = tx.Exec(sql) return err }, + func(tx *sql.Tx) (err error) { + _, err = tx.Exec(` + ALTER TABLE users + ADD column stylesheet text not null default '', + ADD column google_id text not null default '', + ADD column openid_connect_id text not null default '' + `) + if err != nil { + return err + } + + _, err = tx.Exec(` + DECLARE my_cursor CURSOR FOR + SELECT + id, + COALESCE(extra->'custom_css', '') as custom_css, + COALESCE(extra->'google_id', '') as google_id, + COALESCE(extra->'oidc_id', '') as oidc_id + FROM users + FOR UPDATE + `) + if err != nil { + return err + } + defer tx.Exec("CLOSE my_cursor") + + for { + var ( + userID int64 + customStylesheet string + googleID string + oidcID string + ) + + if err := tx.QueryRow(`FETCH NEXT FROM my_cursor`).Scan(&userID, &customStylesheet, &googleID, &oidcID); err != nil { + if err == sql.ErrNoRows { + break + } + return err + } + + _, err := tx.Exec( + `UPDATE + users + SET + stylesheet=$2, + google_id=$3, + openid_connect_id=$4 + WHERE + id=$1 + `, + userID, customStylesheet, googleID, oidcID) + if err != nil { + return err + } + } + + return err + }, + func(tx *sql.Tx) (err error) { + _, err = tx.Exec(` + ALTER TABLE users DROP COLUMN extra; + CREATE UNIQUE INDEX users_google_id_idx ON users(google_id) WHERE google_id <> ''; + CREATE UNIQUE INDEX users_openid_connect_id_idx ON users(openid_connect_id) WHERE openid_connect_id <> ''; + `) + return err + }, } diff --git a/miniflux.1 b/miniflux.1 index 682a26f7..4e001d5b 100644 --- a/miniflux.1 +++ b/miniflux.1 @@ -213,7 +213,7 @@ List of networks allowed to access the metrics endpoint (comma-separated values) Default is 127.0.0.1/8\&. .TP .B OAUTH2_PROVIDER -OAuth2 provider to use\&. Only google is supported\&. +Possible values are "google" or "oidc"\&. .TP .B OAUTH2_CLIENT_ID OAuth2 client ID\&. diff --git a/model/user.go b/model/user.go index a2c49abb..b51b0d72 100644 --- a/model/user.go +++ b/model/user.go @@ -13,25 +13,27 @@ import ( // User represents a user in the system. type User struct { - ID int64 `json:"id"` - Username string `json:"username"` - Password string `json:"password,omitempty"` - IsAdmin bool `json:"is_admin"` - Theme string `json:"theme"` - Language string `json:"language"` - Timezone string `json:"timezone"` - EntryDirection string `json:"entry_sorting_direction"` - EntriesPerPage int `json:"entries_per_page"` - KeyboardShortcuts bool `json:"keyboard_shortcuts"` - ShowReadingTime bool `json:"show_reading_time"` - LastLoginAt *time.Time `json:"last_login_at,omitempty"` - Extra map[string]string `json:"extra"` - EntrySwipe bool `json:"entry_swipe"` + ID int64 `json:"id"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + IsAdmin bool `json:"is_admin"` + Theme string `json:"theme"` + Language string `json:"language"` + Timezone string `json:"timezone"` + EntryDirection string `json:"entry_sorting_direction"` + Stylesheet string `json:"stylesheet"` + GoogleID string `json:"google_id"` + OpenIDConnectID string `json:"openid_connect_id"` + EntriesPerPage int `json:"entries_per_page"` + KeyboardShortcuts bool `json:"keyboard_shortcuts"` + ShowReadingTime bool `json:"show_reading_time"` + EntrySwipe bool `json:"entry_swipe"` + LastLoginAt *time.Time `json:"last_login_at,omitempty"` } // NewUser returns a new User. func NewUser() *User { - return &User{Extra: make(map[string]string)} + return &User{} } // ValidateUserCreation validates new user. diff --git a/oauth2/google.go b/oauth2/google.go index 87173764..9af11a23 100644 --- a/oauth2/google.go +++ b/oauth2/google.go @@ -9,6 +9,8 @@ import ( "encoding/json" "fmt" + "miniflux.app/model" + "golang.org/x/oauth2" ) @@ -23,15 +25,15 @@ type googleProvider struct { redirectURL string } -func (g googleProvider) GetUserExtraKey() string { +func (g *googleProvider) GetUserExtraKey() string { return "google_id" } -func (g googleProvider) GetRedirectURL(state string) string { +func (g *googleProvider) GetRedirectURL(state string) string { return g.config().AuthCodeURL(state) } -func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { +func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { conf := g.config() token, err := conf.Exchange(ctx, code) if err != nil { @@ -48,14 +50,22 @@ func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile, var user googleProfile decoder := json.NewDecoder(resp.Body) if err := decoder.Decode(&user); err != nil { - return nil, fmt.Errorf("unable to unserialize google profile: %v", err) + return nil, fmt.Errorf("oauth2: unable to unserialize google profile: %v", err) } profile := &Profile{Key: g.GetUserExtraKey(), ID: user.Sub, Username: user.Email} return profile, nil } -func (g googleProvider) config() *oauth2.Config { +func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) { + user.GoogleID = profile.ID +} + +func (g *googleProvider) UnsetUserProfileID(user *model.User) { + user.GoogleID = "" +} + +func (g *googleProvider) config() *oauth2.Config { return &oauth2.Config{ RedirectURL: g.redirectURL, ClientID: g.clientID, diff --git a/oauth2/manager.go b/oauth2/manager.go index ea076217..52b2fc8b 100644 --- a/oauth2/manager.go +++ b/oauth2/manager.go @@ -7,6 +7,7 @@ package oauth2 // import "miniflux.app/oauth2" import ( "context" "errors" + "miniflux.app/logger" ) @@ -15,8 +16,8 @@ type Manager struct { providers map[string]Provider } -// Provider returns the given provider. -func (m *Manager) Provider(name string) (Provider, error) { +// FindProvider returns the given provider. +func (m *Manager) FindProvider(name string) (Provider, error) { if provider, found := m.providers[name]; found { return provider, nil } diff --git a/oauth2/oidc.go b/oauth2/oidc.go index c1e664dd..8701084f 100644 --- a/oauth2/oidc.go +++ b/oauth2/oidc.go @@ -6,6 +6,9 @@ package oauth2 // import "miniflux.app/oauth2" import ( "context" + + "miniflux.app/model" + "github.com/coreos/go-oidc" "golang.org/x/oauth2" ) @@ -17,15 +20,15 @@ type oidcProvider struct { provider *oidc.Provider } -func (o oidcProvider) GetUserExtraKey() string { - return "oidc_id" // FIXME? add extra options key to allow multiple OIDC providers each with their own extra key? +func (o *oidcProvider) GetUserExtraKey() string { + return "openid_connect_id" } -func (o oidcProvider) GetRedirectURL(state string) string { +func (o *oidcProvider) GetRedirectURL(state string) string { return o.config().AuthCodeURL(state) } -func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { +func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { conf := o.config() token, err := conf.Exchange(ctx, code) if err != nil { @@ -41,7 +44,15 @@ func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, er return profile, nil } -func (o oidcProvider) config() *oauth2.Config { +func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) { + user.OpenIDConnectID = profile.ID +} + +func (o *oidcProvider) UnsetUserProfileID(user *model.User) { + user.OpenIDConnectID = "" +} + +func (o *oidcProvider) config() *oauth2.Config { return &oauth2.Config{ RedirectURL: o.redirectURL, ClientID: o.clientID, diff --git a/oauth2/provider.go b/oauth2/provider.go index 4c2e6f02..d767129f 100644 --- a/oauth2/provider.go +++ b/oauth2/provider.go @@ -3,11 +3,18 @@ // license that can be found in the LICENSE file. package oauth2 // import "miniflux.app/oauth2" -import "context" + +import ( + "context" + + "miniflux.app/model" +) // Provider is an interface for OAuth2 providers. type Provider interface { GetUserExtraKey() string GetRedirectURL(state string) string GetProfile(ctx context.Context, code string) (*Profile, error) + PopulateUserWithProfileID(user *model.User, profile *Profile) + UnsetUserProfileID(user *model.User) } diff --git a/storage/user.go b/storage/user.go index bb164ab4..757a1b1e 100644 --- a/storage/user.go +++ b/storage/user.go @@ -12,7 +12,7 @@ import ( "miniflux.app/logger" "miniflux.app/model" - "github.com/lib/pq/hstore" + "github.com/lib/pq" "golang.org/x/crypto/bcrypt" ) @@ -54,32 +54,37 @@ func (s *Storage) AnotherUserExists(userID int64, username string) bool { // CreateUser creates a new user. func (s *Storage) CreateUser(user *model.User) (err error) { - password := "" - extra := hstore.Hstore{Map: make(map[string]sql.NullString)} - + hashedPassword := "" if user.Password != "" { - password, err = hashPassword(user.Password) + hashedPassword, err = hashPassword(user.Password) if err != nil { return err } } - if len(user.Extra) > 0 { - for key, value := range user.Extra { - extra.Map[key] = sql.NullString{String: value, Valid: true} - } - } - query := ` INSERT INTO users - (username, password, is_admin, extra) + (username, password, is_admin, google_id, openid_connect_id) VALUES - (LOWER($1), $2, $3, $4) + (LOWER($1), $2, $3, $4, $5) RETURNING - id, username, is_admin, language, theme, timezone, entry_direction, entries_per_page, keyboard_shortcuts, show_reading_time, entry_swipe + id, + username, + is_admin, + language, + theme, + timezone, + entry_direction, + entries_per_page, + keyboard_shortcuts, + show_reading_time, + entry_swipe, + stylesheet, + google_id, + openid_connect_id ` - err = s.db.QueryRow(query, user.Username, password, user.IsAdmin, extra).Scan( + err = s.db.QueryRow(query, user.Username, hashedPassword, user.IsAdmin, user.GoogleID, user.OpenIDConnectID).Scan( &user.ID, &user.Username, &user.IsAdmin, @@ -91,6 +96,9 @@ func (s *Storage) CreateUser(user *model.User) (err error) { &user.KeyboardShortcuts, &user.ShowReadingTime, &user.EntrySwipe, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err != nil { return fmt.Errorf(`store: unable to create user: %v`, err) @@ -101,26 +109,6 @@ func (s *Storage) CreateUser(user *model.User) (err error) { return nil } -// UpdateExtraField updates an extra field of the given user. -func (s *Storage) UpdateExtraField(userID int64, field, value string) error { - query := fmt.Sprintf(`UPDATE users SET extra = extra || hstore('%s', $1) WHERE id=$2`, field) - _, err := s.db.Exec(query, value, userID) - if err != nil { - return fmt.Errorf(`store: unable to update user extra field: %v`, err) - } - return nil -} - -// RemoveExtraField deletes an extra field for the given user. -func (s *Storage) RemoveExtraField(userID int64, field string) error { - query := `UPDATE users SET extra = delete(extra, $1) WHERE id=$2` - _, err := s.db.Exec(query, field, userID) - if err != nil { - return fmt.Errorf(`store: unable to remove user extra field: %v`, err) - } - return nil -} - // UpdateUser updates a user. func (s *Storage) UpdateUser(user *model.User) error { if user.Password != "" { @@ -141,9 +129,12 @@ func (s *Storage) UpdateUser(user *model.User) error { entries_per_page=$8, keyboard_shortcuts=$9, show_reading_time=$10, - entry_swipe=$11 + entry_swipe=$11, + stylesheet=$12, + google_id=$13, + openid_connect_id=$14 WHERE - id=$12 + id=$15 ` _, err = s.db.Exec( @@ -159,6 +150,9 @@ func (s *Storage) UpdateUser(user *model.User) error { user.KeyboardShortcuts, user.ShowReadingTime, user.EntrySwipe, + user.Stylesheet, + user.GoogleID, + user.OpenIDConnectID, user.ID, ) if err != nil { @@ -176,9 +170,12 @@ func (s *Storage) UpdateUser(user *model.User) error { entries_per_page=$7, keyboard_shortcuts=$8, show_reading_time=$9, - entry_swipe=$10 + entry_swipe=$10, + stylesheet=$11, + google_id=$12, + openid_connect_id=$13 WHERE - id=$11 + id=$14 ` _, err := s.db.Exec( @@ -193,6 +190,9 @@ func (s *Storage) UpdateUser(user *model.User) error { user.KeyboardShortcuts, user.ShowReadingTime, user.EntrySwipe, + user.Stylesheet, + user.GoogleID, + user.OpenIDConnectID, user.ID, ) @@ -201,10 +201,6 @@ func (s *Storage) UpdateUser(user *model.User) error { } } - if err := s.UpdateExtraField(user.ID, "custom_css", user.Extra["custom_css"]); err != nil { - return fmt.Errorf(`store: unable to update user custom css: %v`, err) - } - return nil } @@ -234,7 +230,9 @@ func (s *Storage) UserByID(userID int64) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE @@ -259,7 +257,9 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE @@ -268,8 +268,8 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) { return s.fetchUser(query, username) } -// UserByExtraField finds a user by an extra field value. -func (s *Storage) UserByExtraField(field, value string) (*model.User, error) { +// UserByField finds a user by a field value. +func (s *Storage) UserByField(field, value string) (*model.User, error) { query := ` SELECT id, @@ -284,13 +284,22 @@ func (s *Storage) UserByExtraField(field, value string) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE - extra->$1=$2 + %s=$1 ` - return s.fetchUser(query, field, value) + return s.fetchUser(fmt.Sprintf(query, pq.QuoteIdentifier(field)), value) +} + +// AnotherUserWithFieldExists returns true if a user has the value set for the given field. +func (s *Storage) AnotherUserWithFieldExists(userID int64, field, value string) bool { + var result bool + s.db.QueryRow(fmt.Sprintf(`SELECT true FROM users WHERE id <> $1 AND %s=$2`, pq.QuoteIdentifier(field)), userID, value).Scan(&result) + return result } // UserByAPIKey returns a User from an API Key. @@ -309,7 +318,9 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) { u.show_reading_time, u.entry_swipe, u.last_login_at, - u.extra + u.stylesheet, + u.google_id, + u.openid_connect_id FROM users u LEFT JOIN @@ -321,8 +332,6 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) { } func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, error) { - var extra hstore.Hstore - user := model.NewUser() err := s.db.QueryRow(query, args...).Scan( &user.ID, @@ -337,7 +346,9 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err &user.ShowReadingTime, &user.EntrySwipe, &user.LastLoginAt, - &extra, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err == sql.ErrNoRows { @@ -346,12 +357,6 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err return nil, fmt.Errorf(`store: unable to fetch user: %v`, err) } - for key, value := range extra.Map { - if value.Valid { - user.Extra[key] = value.String - } - } - return user, nil } @@ -404,7 +409,9 @@ func (s *Storage) Users() (model.Users, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users ORDER BY username ASC @@ -417,7 +424,6 @@ func (s *Storage) Users() (model.Users, error) { var users model.Users for rows.Next() { - var extra hstore.Hstore user := model.NewUser() err := rows.Scan( &user.ID, @@ -432,19 +438,15 @@ func (s *Storage) Users() (model.Users, error) { &user.ShowReadingTime, &user.EntrySwipe, &user.LastLoginAt, - &extra, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err != nil { return nil, fmt.Errorf(`store: unable to fetch users row: %v`, err) } - for key, value := range extra.Map { - if value.Valid { - user.Extra[key] = value.String - } - } - users = append(users, user) } diff --git a/template/common.go b/template/common.go index 7b35f46b..79e7b7cf 100644 --- a/template/common.go +++ b/template/common.go @@ -346,9 +346,9 @@ SOFTWARE. - {{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }} + {{ if and .user .user.Stylesheet }} - {{ end }}{{ end }} + {{ end }} @@ -524,7 +524,7 @@ var templateCommonMapChecksums = map[string]string{ "feed_menu": "318d8662dda5ca9dfc75b909c8461e79c86fb5082df1428f67aaf856f19f4b50", "icons": "9a41753778072f286216085d8712495e2ccca20c7a24f5c982775436a3d38579", "item_meta": "56ab09d7dd46eeb2e2ee11ddcec0c157a5832c896dbd2887d9e2b013680b2af6", - "layout": "65767e7dbebe1f7ed42895ecd5a737b0693e4a2ec35e84e3e391f462beb11977", + "layout": "c4b8c65c0d85ed1aff0550f58b9dbf0768c74f2df6232952c0fe299d4c73d674", "pagination": "7b61288e86283c4cf0dc83bcbf8bf1c00c7cb29e60201c8c0b633b2450d2911f", "settings_menu": "e2b777630c0efdbc529800303c01d6744ed3af80ec505ac5a5b3f99c9b989156", } diff --git a/template/html/common/layout.html b/template/html/common/layout.html index bce2d70b..695f5226 100644 --- a/template/html/common/layout.html +++ b/template/html/common/layout.html @@ -35,9 +35,9 @@ - {{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }} + {{ if and .user .user.Stylesheet }} - {{ end }}{{ end }} + {{ end }} diff --git a/template/html/settings.html b/template/html/settings.html index 315ebb0e..fb7647d2 100644 --- a/template/html/settings.html +++ b/template/html/settings.html @@ -66,7 +66,7 @@ {{ if hasOAuth2Provider "google" }}