diff --git a/client/core.go b/client/core.go index 29daeff5..9a7818f2 100644 --- a/client/core.go +++ b/client/core.go @@ -18,17 +18,22 @@ const ( // User represents a user in the system. type User struct { - ID int64 `json:"id"` - Username string `json:"username"` - Password string `json:"password,omitempty"` - IsAdmin bool `json:"is_admin"` - Theme string `json:"theme"` - Language string `json:"language"` - Timezone string `json:"timezone"` - EntryDirection string `json:"entry_sorting_direction"` - EntriesPerPage int `json:"entries_per_page"` - LastLoginAt *time.Time `json:"last_login_at"` - Extra map[string]string `json:"extra"` + ID int64 `json:"id"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + IsAdmin bool `json:"is_admin"` + Theme string `json:"theme"` + Language string `json:"language"` + Timezone string `json:"timezone"` + EntryDirection string `json:"entry_sorting_direction"` + Stylesheet string `json:"stylesheet"` + GoogleID string `json:"google_id"` + OpenIDConnectID string `json:"openid_connect_id"` + EntriesPerPage int `json:"entries_per_page"` + KeyboardShortcuts bool `json:"keyboard_shortcuts"` + ShowReadingTime bool `json:"show_reading_time"` + EntrySwipe bool `json:"entry_swipe"` + LastLoginAt *time.Time `json:"last_login_at"` } func (u User) String() string { diff --git a/database/migrations.go b/database/migrations.go index 01209c65..f5177d06 100644 --- a/database/migrations.go +++ b/database/migrations.go @@ -4,7 +4,9 @@ package database // import "miniflux.app/database" -import "database/sql" +import ( + "database/sql" +) var schemaVersion = len(migrations) @@ -427,4 +429,71 @@ var migrations = []func(tx *sql.Tx) error{ _, err = tx.Exec(sql) return err }, + func(tx *sql.Tx) (err error) { + _, err = tx.Exec(` + ALTER TABLE users + ADD column stylesheet text not null default '', + ADD column google_id text not null default '', + ADD column openid_connect_id text not null default '' + `) + if err != nil { + return err + } + + _, err = tx.Exec(` + DECLARE my_cursor CURSOR FOR + SELECT + id, + COALESCE(extra->'custom_css', '') as custom_css, + COALESCE(extra->'google_id', '') as google_id, + COALESCE(extra->'oidc_id', '') as oidc_id + FROM users + FOR UPDATE + `) + if err != nil { + return err + } + defer tx.Exec("CLOSE my_cursor") + + for { + var ( + userID int64 + customStylesheet string + googleID string + oidcID string + ) + + if err := tx.QueryRow(`FETCH NEXT FROM my_cursor`).Scan(&userID, &customStylesheet, &googleID, &oidcID); err != nil { + if err == sql.ErrNoRows { + break + } + return err + } + + _, err := tx.Exec( + `UPDATE + users + SET + stylesheet=$2, + google_id=$3, + openid_connect_id=$4 + WHERE + id=$1 + `, + userID, customStylesheet, googleID, oidcID) + if err != nil { + return err + } + } + + return err + }, + func(tx *sql.Tx) (err error) { + _, err = tx.Exec(` + ALTER TABLE users DROP COLUMN extra; + CREATE UNIQUE INDEX users_google_id_idx ON users(google_id) WHERE google_id <> ''; + CREATE UNIQUE INDEX users_openid_connect_id_idx ON users(openid_connect_id) WHERE openid_connect_id <> ''; + `) + return err + }, } diff --git a/miniflux.1 b/miniflux.1 index 682a26f7..4e001d5b 100644 --- a/miniflux.1 +++ b/miniflux.1 @@ -213,7 +213,7 @@ List of networks allowed to access the metrics endpoint (comma-separated values) Default is 127.0.0.1/8\&. .TP .B OAUTH2_PROVIDER -OAuth2 provider to use\&. Only google is supported\&. +Possible values are "google" or "oidc"\&. .TP .B OAUTH2_CLIENT_ID OAuth2 client ID\&. diff --git a/model/user.go b/model/user.go index a2c49abb..b51b0d72 100644 --- a/model/user.go +++ b/model/user.go @@ -13,25 +13,27 @@ import ( // User represents a user in the system. type User struct { - ID int64 `json:"id"` - Username string `json:"username"` - Password string `json:"password,omitempty"` - IsAdmin bool `json:"is_admin"` - Theme string `json:"theme"` - Language string `json:"language"` - Timezone string `json:"timezone"` - EntryDirection string `json:"entry_sorting_direction"` - EntriesPerPage int `json:"entries_per_page"` - KeyboardShortcuts bool `json:"keyboard_shortcuts"` - ShowReadingTime bool `json:"show_reading_time"` - LastLoginAt *time.Time `json:"last_login_at,omitempty"` - Extra map[string]string `json:"extra"` - EntrySwipe bool `json:"entry_swipe"` + ID int64 `json:"id"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + IsAdmin bool `json:"is_admin"` + Theme string `json:"theme"` + Language string `json:"language"` + Timezone string `json:"timezone"` + EntryDirection string `json:"entry_sorting_direction"` + Stylesheet string `json:"stylesheet"` + GoogleID string `json:"google_id"` + OpenIDConnectID string `json:"openid_connect_id"` + EntriesPerPage int `json:"entries_per_page"` + KeyboardShortcuts bool `json:"keyboard_shortcuts"` + ShowReadingTime bool `json:"show_reading_time"` + EntrySwipe bool `json:"entry_swipe"` + LastLoginAt *time.Time `json:"last_login_at,omitempty"` } // NewUser returns a new User. func NewUser() *User { - return &User{Extra: make(map[string]string)} + return &User{} } // ValidateUserCreation validates new user. diff --git a/oauth2/google.go b/oauth2/google.go index 87173764..9af11a23 100644 --- a/oauth2/google.go +++ b/oauth2/google.go @@ -9,6 +9,8 @@ import ( "encoding/json" "fmt" + "miniflux.app/model" + "golang.org/x/oauth2" ) @@ -23,15 +25,15 @@ type googleProvider struct { redirectURL string } -func (g googleProvider) GetUserExtraKey() string { +func (g *googleProvider) GetUserExtraKey() string { return "google_id" } -func (g googleProvider) GetRedirectURL(state string) string { +func (g *googleProvider) GetRedirectURL(state string) string { return g.config().AuthCodeURL(state) } -func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { +func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { conf := g.config() token, err := conf.Exchange(ctx, code) if err != nil { @@ -48,14 +50,22 @@ func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile, var user googleProfile decoder := json.NewDecoder(resp.Body) if err := decoder.Decode(&user); err != nil { - return nil, fmt.Errorf("unable to unserialize google profile: %v", err) + return nil, fmt.Errorf("oauth2: unable to unserialize google profile: %v", err) } profile := &Profile{Key: g.GetUserExtraKey(), ID: user.Sub, Username: user.Email} return profile, nil } -func (g googleProvider) config() *oauth2.Config { +func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) { + user.GoogleID = profile.ID +} + +func (g *googleProvider) UnsetUserProfileID(user *model.User) { + user.GoogleID = "" +} + +func (g *googleProvider) config() *oauth2.Config { return &oauth2.Config{ RedirectURL: g.redirectURL, ClientID: g.clientID, diff --git a/oauth2/manager.go b/oauth2/manager.go index ea076217..52b2fc8b 100644 --- a/oauth2/manager.go +++ b/oauth2/manager.go @@ -7,6 +7,7 @@ package oauth2 // import "miniflux.app/oauth2" import ( "context" "errors" + "miniflux.app/logger" ) @@ -15,8 +16,8 @@ type Manager struct { providers map[string]Provider } -// Provider returns the given provider. -func (m *Manager) Provider(name string) (Provider, error) { +// FindProvider returns the given provider. +func (m *Manager) FindProvider(name string) (Provider, error) { if provider, found := m.providers[name]; found { return provider, nil } diff --git a/oauth2/oidc.go b/oauth2/oidc.go index c1e664dd..8701084f 100644 --- a/oauth2/oidc.go +++ b/oauth2/oidc.go @@ -6,6 +6,9 @@ package oauth2 // import "miniflux.app/oauth2" import ( "context" + + "miniflux.app/model" + "github.com/coreos/go-oidc" "golang.org/x/oauth2" ) @@ -17,15 +20,15 @@ type oidcProvider struct { provider *oidc.Provider } -func (o oidcProvider) GetUserExtraKey() string { - return "oidc_id" // FIXME? add extra options key to allow multiple OIDC providers each with their own extra key? +func (o *oidcProvider) GetUserExtraKey() string { + return "openid_connect_id" } -func (o oidcProvider) GetRedirectURL(state string) string { +func (o *oidcProvider) GetRedirectURL(state string) string { return o.config().AuthCodeURL(state) } -func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { +func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) { conf := o.config() token, err := conf.Exchange(ctx, code) if err != nil { @@ -41,7 +44,15 @@ func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, er return profile, nil } -func (o oidcProvider) config() *oauth2.Config { +func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) { + user.OpenIDConnectID = profile.ID +} + +func (o *oidcProvider) UnsetUserProfileID(user *model.User) { + user.OpenIDConnectID = "" +} + +func (o *oidcProvider) config() *oauth2.Config { return &oauth2.Config{ RedirectURL: o.redirectURL, ClientID: o.clientID, diff --git a/oauth2/provider.go b/oauth2/provider.go index 4c2e6f02..d767129f 100644 --- a/oauth2/provider.go +++ b/oauth2/provider.go @@ -3,11 +3,18 @@ // license that can be found in the LICENSE file. package oauth2 // import "miniflux.app/oauth2" -import "context" + +import ( + "context" + + "miniflux.app/model" +) // Provider is an interface for OAuth2 providers. type Provider interface { GetUserExtraKey() string GetRedirectURL(state string) string GetProfile(ctx context.Context, code string) (*Profile, error) + PopulateUserWithProfileID(user *model.User, profile *Profile) + UnsetUserProfileID(user *model.User) } diff --git a/storage/user.go b/storage/user.go index bb164ab4..757a1b1e 100644 --- a/storage/user.go +++ b/storage/user.go @@ -12,7 +12,7 @@ import ( "miniflux.app/logger" "miniflux.app/model" - "github.com/lib/pq/hstore" + "github.com/lib/pq" "golang.org/x/crypto/bcrypt" ) @@ -54,32 +54,37 @@ func (s *Storage) AnotherUserExists(userID int64, username string) bool { // CreateUser creates a new user. func (s *Storage) CreateUser(user *model.User) (err error) { - password := "" - extra := hstore.Hstore{Map: make(map[string]sql.NullString)} - + hashedPassword := "" if user.Password != "" { - password, err = hashPassword(user.Password) + hashedPassword, err = hashPassword(user.Password) if err != nil { return err } } - if len(user.Extra) > 0 { - for key, value := range user.Extra { - extra.Map[key] = sql.NullString{String: value, Valid: true} - } - } - query := ` INSERT INTO users - (username, password, is_admin, extra) + (username, password, is_admin, google_id, openid_connect_id) VALUES - (LOWER($1), $2, $3, $4) + (LOWER($1), $2, $3, $4, $5) RETURNING - id, username, is_admin, language, theme, timezone, entry_direction, entries_per_page, keyboard_shortcuts, show_reading_time, entry_swipe + id, + username, + is_admin, + language, + theme, + timezone, + entry_direction, + entries_per_page, + keyboard_shortcuts, + show_reading_time, + entry_swipe, + stylesheet, + google_id, + openid_connect_id ` - err = s.db.QueryRow(query, user.Username, password, user.IsAdmin, extra).Scan( + err = s.db.QueryRow(query, user.Username, hashedPassword, user.IsAdmin, user.GoogleID, user.OpenIDConnectID).Scan( &user.ID, &user.Username, &user.IsAdmin, @@ -91,6 +96,9 @@ func (s *Storage) CreateUser(user *model.User) (err error) { &user.KeyboardShortcuts, &user.ShowReadingTime, &user.EntrySwipe, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err != nil { return fmt.Errorf(`store: unable to create user: %v`, err) @@ -101,26 +109,6 @@ func (s *Storage) CreateUser(user *model.User) (err error) { return nil } -// UpdateExtraField updates an extra field of the given user. -func (s *Storage) UpdateExtraField(userID int64, field, value string) error { - query := fmt.Sprintf(`UPDATE users SET extra = extra || hstore('%s', $1) WHERE id=$2`, field) - _, err := s.db.Exec(query, value, userID) - if err != nil { - return fmt.Errorf(`store: unable to update user extra field: %v`, err) - } - return nil -} - -// RemoveExtraField deletes an extra field for the given user. -func (s *Storage) RemoveExtraField(userID int64, field string) error { - query := `UPDATE users SET extra = delete(extra, $1) WHERE id=$2` - _, err := s.db.Exec(query, field, userID) - if err != nil { - return fmt.Errorf(`store: unable to remove user extra field: %v`, err) - } - return nil -} - // UpdateUser updates a user. func (s *Storage) UpdateUser(user *model.User) error { if user.Password != "" { @@ -141,9 +129,12 @@ func (s *Storage) UpdateUser(user *model.User) error { entries_per_page=$8, keyboard_shortcuts=$9, show_reading_time=$10, - entry_swipe=$11 + entry_swipe=$11, + stylesheet=$12, + google_id=$13, + openid_connect_id=$14 WHERE - id=$12 + id=$15 ` _, err = s.db.Exec( @@ -159,6 +150,9 @@ func (s *Storage) UpdateUser(user *model.User) error { user.KeyboardShortcuts, user.ShowReadingTime, user.EntrySwipe, + user.Stylesheet, + user.GoogleID, + user.OpenIDConnectID, user.ID, ) if err != nil { @@ -176,9 +170,12 @@ func (s *Storage) UpdateUser(user *model.User) error { entries_per_page=$7, keyboard_shortcuts=$8, show_reading_time=$9, - entry_swipe=$10 + entry_swipe=$10, + stylesheet=$11, + google_id=$12, + openid_connect_id=$13 WHERE - id=$11 + id=$14 ` _, err := s.db.Exec( @@ -193,6 +190,9 @@ func (s *Storage) UpdateUser(user *model.User) error { user.KeyboardShortcuts, user.ShowReadingTime, user.EntrySwipe, + user.Stylesheet, + user.GoogleID, + user.OpenIDConnectID, user.ID, ) @@ -201,10 +201,6 @@ func (s *Storage) UpdateUser(user *model.User) error { } } - if err := s.UpdateExtraField(user.ID, "custom_css", user.Extra["custom_css"]); err != nil { - return fmt.Errorf(`store: unable to update user custom css: %v`, err) - } - return nil } @@ -234,7 +230,9 @@ func (s *Storage) UserByID(userID int64) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE @@ -259,7 +257,9 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE @@ -268,8 +268,8 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) { return s.fetchUser(query, username) } -// UserByExtraField finds a user by an extra field value. -func (s *Storage) UserByExtraField(field, value string) (*model.User, error) { +// UserByField finds a user by a field value. +func (s *Storage) UserByField(field, value string) (*model.User, error) { query := ` SELECT id, @@ -284,13 +284,22 @@ func (s *Storage) UserByExtraField(field, value string) (*model.User, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users WHERE - extra->$1=$2 + %s=$1 ` - return s.fetchUser(query, field, value) + return s.fetchUser(fmt.Sprintf(query, pq.QuoteIdentifier(field)), value) +} + +// AnotherUserWithFieldExists returns true if a user has the value set for the given field. +func (s *Storage) AnotherUserWithFieldExists(userID int64, field, value string) bool { + var result bool + s.db.QueryRow(fmt.Sprintf(`SELECT true FROM users WHERE id <> $1 AND %s=$2`, pq.QuoteIdentifier(field)), userID, value).Scan(&result) + return result } // UserByAPIKey returns a User from an API Key. @@ -309,7 +318,9 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) { u.show_reading_time, u.entry_swipe, u.last_login_at, - u.extra + u.stylesheet, + u.google_id, + u.openid_connect_id FROM users u LEFT JOIN @@ -321,8 +332,6 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) { } func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, error) { - var extra hstore.Hstore - user := model.NewUser() err := s.db.QueryRow(query, args...).Scan( &user.ID, @@ -337,7 +346,9 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err &user.ShowReadingTime, &user.EntrySwipe, &user.LastLoginAt, - &extra, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err == sql.ErrNoRows { @@ -346,12 +357,6 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err return nil, fmt.Errorf(`store: unable to fetch user: %v`, err) } - for key, value := range extra.Map { - if value.Valid { - user.Extra[key] = value.String - } - } - return user, nil } @@ -404,7 +409,9 @@ func (s *Storage) Users() (model.Users, error) { show_reading_time, entry_swipe, last_login_at, - extra + stylesheet, + google_id, + openid_connect_id FROM users ORDER BY username ASC @@ -417,7 +424,6 @@ func (s *Storage) Users() (model.Users, error) { var users model.Users for rows.Next() { - var extra hstore.Hstore user := model.NewUser() err := rows.Scan( &user.ID, @@ -432,19 +438,15 @@ func (s *Storage) Users() (model.Users, error) { &user.ShowReadingTime, &user.EntrySwipe, &user.LastLoginAt, - &extra, + &user.Stylesheet, + &user.GoogleID, + &user.OpenIDConnectID, ) if err != nil { return nil, fmt.Errorf(`store: unable to fetch users row: %v`, err) } - for key, value := range extra.Map { - if value.Valid { - user.Extra[key] = value.String - } - } - users = append(users, user) } diff --git a/template/common.go b/template/common.go index 7b35f46b..79e7b7cf 100644 --- a/template/common.go +++ b/template/common.go @@ -346,9 +346,9 @@ SOFTWARE. - {{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }} + {{ if and .user .user.Stylesheet }} - {{ end }}{{ end }} + {{ end }} @@ -524,7 +524,7 @@ var templateCommonMapChecksums = map[string]string{ "feed_menu": "318d8662dda5ca9dfc75b909c8461e79c86fb5082df1428f67aaf856f19f4b50", "icons": "9a41753778072f286216085d8712495e2ccca20c7a24f5c982775436a3d38579", "item_meta": "56ab09d7dd46eeb2e2ee11ddcec0c157a5832c896dbd2887d9e2b013680b2af6", - "layout": "65767e7dbebe1f7ed42895ecd5a737b0693e4a2ec35e84e3e391f462beb11977", + "layout": "c4b8c65c0d85ed1aff0550f58b9dbf0768c74f2df6232952c0fe299d4c73d674", "pagination": "7b61288e86283c4cf0dc83bcbf8bf1c00c7cb29e60201c8c0b633b2450d2911f", "settings_menu": "e2b777630c0efdbc529800303c01d6744ed3af80ec505ac5a5b3f99c9b989156", } diff --git a/template/html/common/layout.html b/template/html/common/layout.html index bce2d70b..695f5226 100644 --- a/template/html/common/layout.html +++ b/template/html/common/layout.html @@ -35,9 +35,9 @@ - {{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }} + {{ if and .user .user.Stylesheet }} - {{ end }}{{ end }} + {{ end }} diff --git a/template/html/settings.html b/template/html/settings.html index 315ebb0e..fb7647d2 100644 --- a/template/html/settings.html +++ b/template/html/settings.html @@ -66,7 +66,7 @@ {{ if hasOAuth2Provider "google" }}
- {{ if hasKey .user.Extra "google_id" }} + {{ if .user.GoogleID }} {{ t "page.settings.unlink_google_account" }} {{ else }} {{ t "page.settings.link_google_account" }} @@ -74,7 +74,7 @@
{{ else if hasOAuth2Provider "oidc" }}
- {{ if hasKey .user.Extra "oidc_id" }} + {{ if .user.OpenIDConnectID }} {{ t "page.settings.unlink_oidc_account" }} {{ else }} {{ t "page.settings.link_oidc_account" }} diff --git a/template/views.go b/template/views.go index 6e4a2b94..1a56e37c 100644 --- a/template/views.go +++ b/template/views.go @@ -1387,7 +1387,7 @@ var templateViewsMap = map[string]string{ {{ if hasOAuth2Provider "google" }}
- {{ if hasKey .user.Extra "google_id" }} + {{ if .user.GoogleID }} {{ t "page.settings.unlink_google_account" }} {{ else }} {{ t "page.settings.link_google_account" }} @@ -1395,7 +1395,7 @@ var templateViewsMap = map[string]string{
{{ else if hasOAuth2Provider "oidc" }}
- {{ if hasKey .user.Extra "oidc_id" }} + {{ if .user.OpenIDConnectID }} {{ t "page.settings.unlink_oidc_account" }} {{ else }} {{ t "page.settings.link_oidc_account" }} @@ -1624,7 +1624,7 @@ var templateViewsMapChecksums = map[string]string{ "login": "9165434b2405e9332de4bebbb54a93dc5692276ea72e7c5e07f655a002dfd290", "search_entries": "6a3e5876cb7541a2f08f56e30ab46a2d7d64894ec5e170f627b2dd674d8aeefe", "sessions": "5d5c677bddbd027e0b0c9f7a0dd95b66d9d95b4e130959f31fb955b926c2201c", - "settings": "6f77f9431beb9aa2e28840a60a6463d1a9eb0e92c1929b204584c85c71d0c7a3", + "settings": "ef2155983f362ef001e0d9b27536a3bea9e04869ee600e8285fff894c39ba1c1", "shared_entries": "f87a42bf44dc3606c5a44b185263c1b9a612a8ae194f75061253d4dde7b095a2", "unread_entries": "21c584da7ca8192655c62f16a7ac92dbbfdf1307588ffe51eb4a8bbf3f9f7526", "users": "d7ff52efc582bbad10504f4a04fa3adcc12d15890e45dff51cac281e0c446e45", diff --git a/ui/form/settings.go b/ui/form/settings.go index b61d1d5a..9eb51f5d 100644 --- a/ui/form/settings.go +++ b/ui/form/settings.go @@ -38,8 +38,8 @@ func (s *SettingsForm) Merge(user *model.User) *model.User { user.EntriesPerPage = s.EntriesPerPage user.KeyboardShortcuts = s.KeyboardShortcuts user.ShowReadingTime = s.ShowReadingTime - user.Extra["custom_css"] = s.CustomCSS - user.EntrySwipe = s.EntrySwipe; + user.Stylesheet = s.CustomCSS + user.EntrySwipe = s.EntrySwipe if s.Password != "" { user.Password = s.Password diff --git a/ui/oauth2_callback.go b/ui/oauth2_callback.go index e0453675..2c70386f 100644 --- a/ui/oauth2_callback.go +++ b/ui/oauth2_callback.go @@ -5,6 +5,7 @@ package ui // import "miniflux.app/ui" import ( + "errors" "net/http" "miniflux.app/config" @@ -44,7 +45,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { return } - authProvider, err := getOAuth2Manager(r.Context()).Provider(provider) + authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider) if err != nil { logger.Error("[OAuth2] %v", err) html.Redirect(w, r, route.Path(h.router, "login")) @@ -61,20 +62,21 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { logger.Info("[OAuth2] [ClientIP=%s] Successful auth for %s", clientIP, profile) if request.IsAuthenticated(r) { - user, err := h.store.UserByExtraField(profile.Key, profile.ID) + loggedUser, err := h.store.UserByID(request.UserID(r)) if err != nil { html.ServerError(w, r, err) return } - if user != nil { - logger.Error("[OAuth2] User #%d cannot be associated because %s is already associated", request.UserID(r), user.Username) + if h.store.AnotherUserWithFieldExists(loggedUser.ID, profile.Key, profile.ID) { + logger.Error("[OAuth2] User #%d cannot be associated because it is already associated with another user", loggedUser.ID) sess.NewFlashErrorMessage(printer.Printf("error.duplicate_linked_account")) html.Redirect(w, r, route.Path(h.router, "settings")) return } - if err := h.store.UpdateExtraField(request.UserID(r), profile.Key, profile.ID); err != nil { + authProvider.PopulateUserWithProfileID(loggedUser, profile) + if err := h.store.UpdateUser(loggedUser); err != nil { html.ServerError(w, r, err) return } @@ -84,7 +86,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { return } - user, err := h.store.UserByExtraField(profile.Key, profile.ID) + user, err := h.store.UserByField(profile.Key, profile.ID) if err != nil { html.ServerError(w, r, err) return @@ -96,10 +98,15 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) { return } + if h.store.UserExists(profile.Username) { + html.BadRequest(w, r, errors.New(printer.Printf("error.user_already_exists"))) + return + } + user = model.NewUser() user.Username = profile.Username user.IsAdmin = false - user.Extra[profile.Key] = profile.ID + authProvider.PopulateUserWithProfileID(user, profile) if err := h.store.CreateUser(user); err != nil { html.ServerError(w, r, err) diff --git a/ui/oauth2_redirect.go b/ui/oauth2_redirect.go index f083c068..b2fe9bc8 100644 --- a/ui/oauth2_redirect.go +++ b/ui/oauth2_redirect.go @@ -24,7 +24,7 @@ func (h *handler) oauth2Redirect(w http.ResponseWriter, r *http.Request) { return } - authProvider, err := getOAuth2Manager(r.Context()).Provider(provider) + authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider) if err != nil { logger.Error("[OAuth2] %v", err) html.Redirect(w, r, route.Path(h.router, "login")) diff --git a/ui/oauth2_unlink.go b/ui/oauth2_unlink.go index f87acf93..46cc0df6 100644 --- a/ui/oauth2_unlink.go +++ b/ui/oauth2_unlink.go @@ -24,7 +24,7 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) { return } - authProvider, err := getOAuth2Manager(r.Context()).Provider(provider) + authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider) if err != nil { logger.Error("[OAuth2] %v", err) html.Redirect(w, r, route.Path(h.router, "settings")) @@ -32,6 +32,11 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) { } sess := session.New(h.store, request.SessionID(r)) + user, err := h.store.UserByID(request.UserID(r)) + if err != nil { + html.ServerError(w, r, err) + return + } hasPassword, err := h.store.HasPassword(request.UserID(r)) if err != nil { @@ -45,7 +50,8 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) { return } - if err := h.store.RemoveExtraField(request.UserID(r), authProvider.GetUserExtraKey()); err != nil { + authProvider.UnsetUserProfileID(user) + if err := h.store.UpdateUser(user); err != nil { html.ServerError(w, r, err) return } diff --git a/ui/settings_show.go b/ui/settings_show.go index d877feff..16efcd8e 100644 --- a/ui/settings_show.go +++ b/ui/settings_show.go @@ -35,7 +35,7 @@ func (h *handler) showSettingsPage(w http.ResponseWriter, r *http.Request) { EntriesPerPage: user.EntriesPerPage, KeyboardShortcuts: user.KeyboardShortcuts, ShowReadingTime: user.ShowReadingTime, - CustomCSS: user.Extra["custom_css"], + CustomCSS: user.Stylesheet, EntrySwipe: user.EntrySwipe, } diff --git a/ui/static_stylesheet.go b/ui/static_stylesheet.go index e4300122..dd9b0e3e 100644 --- a/ui/static_stylesheet.go +++ b/ui/static_stylesheet.go @@ -18,22 +18,17 @@ func (h *handler) showStylesheet(w http.ResponseWriter, r *http.Request) { filename := request.RouteStringParam(r, "name") if filename == "custom_css" { user, err := h.store.UserByID(request.UserID(r)) - if err != nil { + if err != nil || user == nil { html.NotFound(w, r) return } b := response.New(w, r) - if user == nil { - b.WithHeader("Content-Type", "text/css; charset=utf-8") - b.WithBody("") - b.Write() - return - } b.WithHeader("Content-Type", "text/css; charset=utf-8") - b.WithBody(user.Extra["custom_css"]) + b.WithBody(user.Stylesheet) b.Write() return } + etag, found := static.StylesheetsChecksums[filename] if !found { html.NotFound(w, r)