Remove extra column from users table (HSTORE field)

Migrated key/value pairs to specific columns.
This commit is contained in:
Frédéric Guillot 2020-12-21 21:14:10 -08:00 committed by fguillot
parent ae74f94655
commit 83f3ccab0e
19 changed files with 256 additions and 141 deletions

View file

@ -18,17 +18,22 @@ const (
// User represents a user in the system.
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
IsAdmin bool `json:"is_admin"`
Theme string `json:"theme"`
Language string `json:"language"`
Timezone string `json:"timezone"`
EntryDirection string `json:"entry_sorting_direction"`
EntriesPerPage int `json:"entries_per_page"`
LastLoginAt *time.Time `json:"last_login_at"`
Extra map[string]string `json:"extra"`
ID int64 `json:"id"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
IsAdmin bool `json:"is_admin"`
Theme string `json:"theme"`
Language string `json:"language"`
Timezone string `json:"timezone"`
EntryDirection string `json:"entry_sorting_direction"`
Stylesheet string `json:"stylesheet"`
GoogleID string `json:"google_id"`
OpenIDConnectID string `json:"openid_connect_id"`
EntriesPerPage int `json:"entries_per_page"`
KeyboardShortcuts bool `json:"keyboard_shortcuts"`
ShowReadingTime bool `json:"show_reading_time"`
EntrySwipe bool `json:"entry_swipe"`
LastLoginAt *time.Time `json:"last_login_at"`
}
func (u User) String() string {

View file

@ -4,7 +4,9 @@
package database // import "miniflux.app/database"
import "database/sql"
import (
"database/sql"
)
var schemaVersion = len(migrations)
@ -427,4 +429,71 @@ var migrations = []func(tx *sql.Tx) error{
_, err = tx.Exec(sql)
return err
},
func(tx *sql.Tx) (err error) {
_, err = tx.Exec(`
ALTER TABLE users
ADD column stylesheet text not null default '',
ADD column google_id text not null default '',
ADD column openid_connect_id text not null default ''
`)
if err != nil {
return err
}
_, err = tx.Exec(`
DECLARE my_cursor CURSOR FOR
SELECT
id,
COALESCE(extra->'custom_css', '') as custom_css,
COALESCE(extra->'google_id', '') as google_id,
COALESCE(extra->'oidc_id', '') as oidc_id
FROM users
FOR UPDATE
`)
if err != nil {
return err
}
defer tx.Exec("CLOSE my_cursor")
for {
var (
userID int64
customStylesheet string
googleID string
oidcID string
)
if err := tx.QueryRow(`FETCH NEXT FROM my_cursor`).Scan(&userID, &customStylesheet, &googleID, &oidcID); err != nil {
if err == sql.ErrNoRows {
break
}
return err
}
_, err := tx.Exec(
`UPDATE
users
SET
stylesheet=$2,
google_id=$3,
openid_connect_id=$4
WHERE
id=$1
`,
userID, customStylesheet, googleID, oidcID)
if err != nil {
return err
}
}
return err
},
func(tx *sql.Tx) (err error) {
_, err = tx.Exec(`
ALTER TABLE users DROP COLUMN extra;
CREATE UNIQUE INDEX users_google_id_idx ON users(google_id) WHERE google_id <> '';
CREATE UNIQUE INDEX users_openid_connect_id_idx ON users(openid_connect_id) WHERE openid_connect_id <> '';
`)
return err
},
}

View file

@ -213,7 +213,7 @@ List of networks allowed to access the metrics endpoint (comma-separated values)
Default is 127.0.0.1/8\&.
.TP
.B OAUTH2_PROVIDER
OAuth2 provider to use\&. Only google is supported\&.
Possible values are "google" or "oidc"\&.
.TP
.B OAUTH2_CLIENT_ID
OAuth2 client ID\&.

View file

@ -13,25 +13,27 @@ import (
// User represents a user in the system.
type User struct {
ID int64 `json:"id"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
IsAdmin bool `json:"is_admin"`
Theme string `json:"theme"`
Language string `json:"language"`
Timezone string `json:"timezone"`
EntryDirection string `json:"entry_sorting_direction"`
EntriesPerPage int `json:"entries_per_page"`
KeyboardShortcuts bool `json:"keyboard_shortcuts"`
ShowReadingTime bool `json:"show_reading_time"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
Extra map[string]string `json:"extra"`
EntrySwipe bool `json:"entry_swipe"`
ID int64 `json:"id"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
IsAdmin bool `json:"is_admin"`
Theme string `json:"theme"`
Language string `json:"language"`
Timezone string `json:"timezone"`
EntryDirection string `json:"entry_sorting_direction"`
Stylesheet string `json:"stylesheet"`
GoogleID string `json:"google_id"`
OpenIDConnectID string `json:"openid_connect_id"`
EntriesPerPage int `json:"entries_per_page"`
KeyboardShortcuts bool `json:"keyboard_shortcuts"`
ShowReadingTime bool `json:"show_reading_time"`
EntrySwipe bool `json:"entry_swipe"`
LastLoginAt *time.Time `json:"last_login_at,omitempty"`
}
// NewUser returns a new User.
func NewUser() *User {
return &User{Extra: make(map[string]string)}
return &User{}
}
// ValidateUserCreation validates new user.

View file

@ -9,6 +9,8 @@ import (
"encoding/json"
"fmt"
"miniflux.app/model"
"golang.org/x/oauth2"
)
@ -23,15 +25,15 @@ type googleProvider struct {
redirectURL string
}
func (g googleProvider) GetUserExtraKey() string {
func (g *googleProvider) GetUserExtraKey() string {
return "google_id"
}
func (g googleProvider) GetRedirectURL(state string) string {
func (g *googleProvider) GetRedirectURL(state string) string {
return g.config().AuthCodeURL(state)
}
func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
conf := g.config()
token, err := conf.Exchange(ctx, code)
if err != nil {
@ -48,14 +50,22 @@ func (g googleProvider) GetProfile(ctx context.Context, code string) (*Profile,
var user googleProfile
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&user); err != nil {
return nil, fmt.Errorf("unable to unserialize google profile: %v", err)
return nil, fmt.Errorf("oauth2: unable to unserialize google profile: %v", err)
}
profile := &Profile{Key: g.GetUserExtraKey(), ID: user.Sub, Username: user.Email}
return profile, nil
}
func (g googleProvider) config() *oauth2.Config {
func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) {
user.GoogleID = profile.ID
}
func (g *googleProvider) UnsetUserProfileID(user *model.User) {
user.GoogleID = ""
}
func (g *googleProvider) config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: g.redirectURL,
ClientID: g.clientID,

View file

@ -7,6 +7,7 @@ package oauth2 // import "miniflux.app/oauth2"
import (
"context"
"errors"
"miniflux.app/logger"
)
@ -15,8 +16,8 @@ type Manager struct {
providers map[string]Provider
}
// Provider returns the given provider.
func (m *Manager) Provider(name string) (Provider, error) {
// FindProvider returns the given provider.
func (m *Manager) FindProvider(name string) (Provider, error) {
if provider, found := m.providers[name]; found {
return provider, nil
}

View file

@ -6,6 +6,9 @@ package oauth2 // import "miniflux.app/oauth2"
import (
"context"
"miniflux.app/model"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
@ -17,15 +20,15 @@ type oidcProvider struct {
provider *oidc.Provider
}
func (o oidcProvider) GetUserExtraKey() string {
return "oidc_id" // FIXME? add extra options key to allow multiple OIDC providers each with their own extra key?
func (o *oidcProvider) GetUserExtraKey() string {
return "openid_connect_id"
}
func (o oidcProvider) GetRedirectURL(state string) string {
func (o *oidcProvider) GetRedirectURL(state string) string {
return o.config().AuthCodeURL(state)
}
func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
conf := o.config()
token, err := conf.Exchange(ctx, code)
if err != nil {
@ -41,7 +44,15 @@ func (o oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, er
return profile, nil
}
func (o oidcProvider) config() *oauth2.Config {
func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Profile) {
user.OpenIDConnectID = profile.ID
}
func (o *oidcProvider) UnsetUserProfileID(user *model.User) {
user.OpenIDConnectID = ""
}
func (o *oidcProvider) config() *oauth2.Config {
return &oauth2.Config{
RedirectURL: o.redirectURL,
ClientID: o.clientID,

View file

@ -3,11 +3,18 @@
// license that can be found in the LICENSE file.
package oauth2 // import "miniflux.app/oauth2"
import "context"
import (
"context"
"miniflux.app/model"
)
// Provider is an interface for OAuth2 providers.
type Provider interface {
GetUserExtraKey() string
GetRedirectURL(state string) string
GetProfile(ctx context.Context, code string) (*Profile, error)
PopulateUserWithProfileID(user *model.User, profile *Profile)
UnsetUserProfileID(user *model.User)
}

View file

@ -12,7 +12,7 @@ import (
"miniflux.app/logger"
"miniflux.app/model"
"github.com/lib/pq/hstore"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
)
@ -54,32 +54,37 @@ func (s *Storage) AnotherUserExists(userID int64, username string) bool {
// CreateUser creates a new user.
func (s *Storage) CreateUser(user *model.User) (err error) {
password := ""
extra := hstore.Hstore{Map: make(map[string]sql.NullString)}
hashedPassword := ""
if user.Password != "" {
password, err = hashPassword(user.Password)
hashedPassword, err = hashPassword(user.Password)
if err != nil {
return err
}
}
if len(user.Extra) > 0 {
for key, value := range user.Extra {
extra.Map[key] = sql.NullString{String: value, Valid: true}
}
}
query := `
INSERT INTO users
(username, password, is_admin, extra)
(username, password, is_admin, google_id, openid_connect_id)
VALUES
(LOWER($1), $2, $3, $4)
(LOWER($1), $2, $3, $4, $5)
RETURNING
id, username, is_admin, language, theme, timezone, entry_direction, entries_per_page, keyboard_shortcuts, show_reading_time, entry_swipe
id,
username,
is_admin,
language,
theme,
timezone,
entry_direction,
entries_per_page,
keyboard_shortcuts,
show_reading_time,
entry_swipe,
stylesheet,
google_id,
openid_connect_id
`
err = s.db.QueryRow(query, user.Username, password, user.IsAdmin, extra).Scan(
err = s.db.QueryRow(query, user.Username, hashedPassword, user.IsAdmin, user.GoogleID, user.OpenIDConnectID).Scan(
&user.ID,
&user.Username,
&user.IsAdmin,
@ -91,6 +96,9 @@ func (s *Storage) CreateUser(user *model.User) (err error) {
&user.KeyboardShortcuts,
&user.ShowReadingTime,
&user.EntrySwipe,
&user.Stylesheet,
&user.GoogleID,
&user.OpenIDConnectID,
)
if err != nil {
return fmt.Errorf(`store: unable to create user: %v`, err)
@ -101,26 +109,6 @@ func (s *Storage) CreateUser(user *model.User) (err error) {
return nil
}
// UpdateExtraField updates an extra field of the given user.
func (s *Storage) UpdateExtraField(userID int64, field, value string) error {
query := fmt.Sprintf(`UPDATE users SET extra = extra || hstore('%s', $1) WHERE id=$2`, field)
_, err := s.db.Exec(query, value, userID)
if err != nil {
return fmt.Errorf(`store: unable to update user extra field: %v`, err)
}
return nil
}
// RemoveExtraField deletes an extra field for the given user.
func (s *Storage) RemoveExtraField(userID int64, field string) error {
query := `UPDATE users SET extra = delete(extra, $1) WHERE id=$2`
_, err := s.db.Exec(query, field, userID)
if err != nil {
return fmt.Errorf(`store: unable to remove user extra field: %v`, err)
}
return nil
}
// UpdateUser updates a user.
func (s *Storage) UpdateUser(user *model.User) error {
if user.Password != "" {
@ -141,9 +129,12 @@ func (s *Storage) UpdateUser(user *model.User) error {
entries_per_page=$8,
keyboard_shortcuts=$9,
show_reading_time=$10,
entry_swipe=$11
entry_swipe=$11,
stylesheet=$12,
google_id=$13,
openid_connect_id=$14
WHERE
id=$12
id=$15
`
_, err = s.db.Exec(
@ -159,6 +150,9 @@ func (s *Storage) UpdateUser(user *model.User) error {
user.KeyboardShortcuts,
user.ShowReadingTime,
user.EntrySwipe,
user.Stylesheet,
user.GoogleID,
user.OpenIDConnectID,
user.ID,
)
if err != nil {
@ -176,9 +170,12 @@ func (s *Storage) UpdateUser(user *model.User) error {
entries_per_page=$7,
keyboard_shortcuts=$8,
show_reading_time=$9,
entry_swipe=$10
entry_swipe=$10,
stylesheet=$11,
google_id=$12,
openid_connect_id=$13
WHERE
id=$11
id=$14
`
_, err := s.db.Exec(
@ -193,6 +190,9 @@ func (s *Storage) UpdateUser(user *model.User) error {
user.KeyboardShortcuts,
user.ShowReadingTime,
user.EntrySwipe,
user.Stylesheet,
user.GoogleID,
user.OpenIDConnectID,
user.ID,
)
@ -201,10 +201,6 @@ func (s *Storage) UpdateUser(user *model.User) error {
}
}
if err := s.UpdateExtraField(user.ID, "custom_css", user.Extra["custom_css"]); err != nil {
return fmt.Errorf(`store: unable to update user custom css: %v`, err)
}
return nil
}
@ -234,7 +230,9 @@ func (s *Storage) UserByID(userID int64) (*model.User, error) {
show_reading_time,
entry_swipe,
last_login_at,
extra
stylesheet,
google_id,
openid_connect_id
FROM
users
WHERE
@ -259,7 +257,9 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) {
show_reading_time,
entry_swipe,
last_login_at,
extra
stylesheet,
google_id,
openid_connect_id
FROM
users
WHERE
@ -268,8 +268,8 @@ func (s *Storage) UserByUsername(username string) (*model.User, error) {
return s.fetchUser(query, username)
}
// UserByExtraField finds a user by an extra field value.
func (s *Storage) UserByExtraField(field, value string) (*model.User, error) {
// UserByField finds a user by a field value.
func (s *Storage) UserByField(field, value string) (*model.User, error) {
query := `
SELECT
id,
@ -284,13 +284,22 @@ func (s *Storage) UserByExtraField(field, value string) (*model.User, error) {
show_reading_time,
entry_swipe,
last_login_at,
extra
stylesheet,
google_id,
openid_connect_id
FROM
users
WHERE
extra->$1=$2
%s=$1
`
return s.fetchUser(query, field, value)
return s.fetchUser(fmt.Sprintf(query, pq.QuoteIdentifier(field)), value)
}
// AnotherUserWithFieldExists returns true if a user has the value set for the given field.
func (s *Storage) AnotherUserWithFieldExists(userID int64, field, value string) bool {
var result bool
s.db.QueryRow(fmt.Sprintf(`SELECT true FROM users WHERE id <> $1 AND %s=$2`, pq.QuoteIdentifier(field)), userID, value).Scan(&result)
return result
}
// UserByAPIKey returns a User from an API Key.
@ -309,7 +318,9 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) {
u.show_reading_time,
u.entry_swipe,
u.last_login_at,
u.extra
u.stylesheet,
u.google_id,
u.openid_connect_id
FROM
users u
LEFT JOIN
@ -321,8 +332,6 @@ func (s *Storage) UserByAPIKey(token string) (*model.User, error) {
}
func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, error) {
var extra hstore.Hstore
user := model.NewUser()
err := s.db.QueryRow(query, args...).Scan(
&user.ID,
@ -337,7 +346,9 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err
&user.ShowReadingTime,
&user.EntrySwipe,
&user.LastLoginAt,
&extra,
&user.Stylesheet,
&user.GoogleID,
&user.OpenIDConnectID,
)
if err == sql.ErrNoRows {
@ -346,12 +357,6 @@ func (s *Storage) fetchUser(query string, args ...interface{}) (*model.User, err
return nil, fmt.Errorf(`store: unable to fetch user: %v`, err)
}
for key, value := range extra.Map {
if value.Valid {
user.Extra[key] = value.String
}
}
return user, nil
}
@ -404,7 +409,9 @@ func (s *Storage) Users() (model.Users, error) {
show_reading_time,
entry_swipe,
last_login_at,
extra
stylesheet,
google_id,
openid_connect_id
FROM
users
ORDER BY username ASC
@ -417,7 +424,6 @@ func (s *Storage) Users() (model.Users, error) {
var users model.Users
for rows.Next() {
var extra hstore.Hstore
user := model.NewUser()
err := rows.Scan(
&user.ID,
@ -432,19 +438,15 @@ func (s *Storage) Users() (model.Users, error) {
&user.ShowReadingTime,
&user.EntrySwipe,
&user.LastLoginAt,
&extra,
&user.Stylesheet,
&user.GoogleID,
&user.OpenIDConnectID,
)
if err != nil {
return nil, fmt.Errorf(`store: unable to fetch users row: %v`, err)
}
for key, value := range extra.Map {
if value.Valid {
user.Extra[key] = value.String
}
}
users = append(users, user)
}

View file

@ -346,9 +346,9 @@ SOFTWARE.
<meta name="theme-color" content="{{ theme_color .theme }}">
<link rel="stylesheet" type="text/css" href="{{ route "stylesheet" "name" .theme }}?{{ .theme_checksum }}">
{{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }}
{{ if and .user .user.Stylesheet }}
<link rel="stylesheet" type="text/css" href="{{ route "stylesheet" "name" "custom_css" }}">
{{ end }}{{ end }}
{{ end }}
<script type="text/javascript" src="{{ route "javascript" "name" "app" }}?{{ .app_js_checksum }}" defer></script>
<script type="text/javascript" src="{{ route "javascript" "name" "service-worker" }}?{{ .sw_js_checksum }}" defer id="service-worker-script"></script>
@ -524,7 +524,7 @@ var templateCommonMapChecksums = map[string]string{
"feed_menu": "318d8662dda5ca9dfc75b909c8461e79c86fb5082df1428f67aaf856f19f4b50",
"icons": "9a41753778072f286216085d8712495e2ccca20c7a24f5c982775436a3d38579",
"item_meta": "56ab09d7dd46eeb2e2ee11ddcec0c157a5832c896dbd2887d9e2b013680b2af6",
"layout": "65767e7dbebe1f7ed42895ecd5a737b0693e4a2ec35e84e3e391f462beb11977",
"layout": "c4b8c65c0d85ed1aff0550f58b9dbf0768c74f2df6232952c0fe299d4c73d674",
"pagination": "7b61288e86283c4cf0dc83bcbf8bf1c00c7cb29e60201c8c0b633b2450d2911f",
"settings_menu": "e2b777630c0efdbc529800303c01d6744ed3af80ec505ac5a5b3f99c9b989156",
}

View file

@ -35,9 +35,9 @@
<meta name="theme-color" content="{{ theme_color .theme }}">
<link rel="stylesheet" type="text/css" href="{{ route "stylesheet" "name" .theme }}?{{ .theme_checksum }}">
{{ if .user }} {{ if ne (index .user.Extra "custom_css") ("") }}
{{ if and .user .user.Stylesheet }}
<link rel="stylesheet" type="text/css" href="{{ route "stylesheet" "name" "custom_css" }}">
{{ end }}{{ end }}
{{ end }}
<script type="text/javascript" src="{{ route "javascript" "name" "app" }}?{{ .app_js_checksum }}" defer></script>
<script type="text/javascript" src="{{ route "javascript" "name" "service-worker" }}?{{ .sw_js_checksum }}" defer id="service-worker-script"></script>

View file

@ -66,7 +66,7 @@
{{ if hasOAuth2Provider "google" }}
<div class="panel">
{{ if hasKey .user.Extra "google_id" }}
{{ if .user.GoogleID }}
<a href="{{ route "oauth2Unlink" "provider" "google" }}">{{ t "page.settings.unlink_google_account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "google" }}">{{ t "page.settings.link_google_account" }}</a>
@ -74,7 +74,7 @@
</div>
{{ else if hasOAuth2Provider "oidc" }}
<div class="panel">
{{ if hasKey .user.Extra "oidc_id" }}
{{ if .user.OpenIDConnectID }}
<a href="{{ route "oauth2Unlink" "provider" "oidc" }}">{{ t "page.settings.unlink_oidc_account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "oidc" }}">{{ t "page.settings.link_oidc_account" }}</a>

View file

@ -1387,7 +1387,7 @@ var templateViewsMap = map[string]string{
{{ if hasOAuth2Provider "google" }}
<div class="panel">
{{ if hasKey .user.Extra "google_id" }}
{{ if .user.GoogleID }}
<a href="{{ route "oauth2Unlink" "provider" "google" }}">{{ t "page.settings.unlink_google_account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "google" }}">{{ t "page.settings.link_google_account" }}</a>
@ -1395,7 +1395,7 @@ var templateViewsMap = map[string]string{
</div>
{{ else if hasOAuth2Provider "oidc" }}
<div class="panel">
{{ if hasKey .user.Extra "oidc_id" }}
{{ if .user.OpenIDConnectID }}
<a href="{{ route "oauth2Unlink" "provider" "oidc" }}">{{ t "page.settings.unlink_oidc_account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "oidc" }}">{{ t "page.settings.link_oidc_account" }}</a>
@ -1624,7 +1624,7 @@ var templateViewsMapChecksums = map[string]string{
"login": "9165434b2405e9332de4bebbb54a93dc5692276ea72e7c5e07f655a002dfd290",
"search_entries": "6a3e5876cb7541a2f08f56e30ab46a2d7d64894ec5e170f627b2dd674d8aeefe",
"sessions": "5d5c677bddbd027e0b0c9f7a0dd95b66d9d95b4e130959f31fb955b926c2201c",
"settings": "6f77f9431beb9aa2e28840a60a6463d1a9eb0e92c1929b204584c85c71d0c7a3",
"settings": "ef2155983f362ef001e0d9b27536a3bea9e04869ee600e8285fff894c39ba1c1",
"shared_entries": "f87a42bf44dc3606c5a44b185263c1b9a612a8ae194f75061253d4dde7b095a2",
"unread_entries": "21c584da7ca8192655c62f16a7ac92dbbfdf1307588ffe51eb4a8bbf3f9f7526",
"users": "d7ff52efc582bbad10504f4a04fa3adcc12d15890e45dff51cac281e0c446e45",

View file

@ -38,8 +38,8 @@ func (s *SettingsForm) Merge(user *model.User) *model.User {
user.EntriesPerPage = s.EntriesPerPage
user.KeyboardShortcuts = s.KeyboardShortcuts
user.ShowReadingTime = s.ShowReadingTime
user.Extra["custom_css"] = s.CustomCSS
user.EntrySwipe = s.EntrySwipe;
user.Stylesheet = s.CustomCSS
user.EntrySwipe = s.EntrySwipe
if s.Password != "" {
user.Password = s.Password

View file

@ -5,6 +5,7 @@
package ui // import "miniflux.app/ui"
import (
"errors"
"net/http"
"miniflux.app/config"
@ -44,7 +45,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
return
}
authProvider, err := getOAuth2Manager(r.Context()).Provider(provider)
authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider)
if err != nil {
logger.Error("[OAuth2] %v", err)
html.Redirect(w, r, route.Path(h.router, "login"))
@ -61,20 +62,21 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
logger.Info("[OAuth2] [ClientIP=%s] Successful auth for %s", clientIP, profile)
if request.IsAuthenticated(r) {
user, err := h.store.UserByExtraField(profile.Key, profile.ID)
loggedUser, err := h.store.UserByID(request.UserID(r))
if err != nil {
html.ServerError(w, r, err)
return
}
if user != nil {
logger.Error("[OAuth2] User #%d cannot be associated because %s is already associated", request.UserID(r), user.Username)
if h.store.AnotherUserWithFieldExists(loggedUser.ID, profile.Key, profile.ID) {
logger.Error("[OAuth2] User #%d cannot be associated because it is already associated with another user", loggedUser.ID)
sess.NewFlashErrorMessage(printer.Printf("error.duplicate_linked_account"))
html.Redirect(w, r, route.Path(h.router, "settings"))
return
}
if err := h.store.UpdateExtraField(request.UserID(r), profile.Key, profile.ID); err != nil {
authProvider.PopulateUserWithProfileID(loggedUser, profile)
if err := h.store.UpdateUser(loggedUser); err != nil {
html.ServerError(w, r, err)
return
}
@ -84,7 +86,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.store.UserByExtraField(profile.Key, profile.ID)
user, err := h.store.UserByField(profile.Key, profile.ID)
if err != nil {
html.ServerError(w, r, err)
return
@ -96,10 +98,15 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
return
}
if h.store.UserExists(profile.Username) {
html.BadRequest(w, r, errors.New(printer.Printf("error.user_already_exists")))
return
}
user = model.NewUser()
user.Username = profile.Username
user.IsAdmin = false
user.Extra[profile.Key] = profile.ID
authProvider.PopulateUserWithProfileID(user, profile)
if err := h.store.CreateUser(user); err != nil {
html.ServerError(w, r, err)

View file

@ -24,7 +24,7 @@ func (h *handler) oauth2Redirect(w http.ResponseWriter, r *http.Request) {
return
}
authProvider, err := getOAuth2Manager(r.Context()).Provider(provider)
authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider)
if err != nil {
logger.Error("[OAuth2] %v", err)
html.Redirect(w, r, route.Path(h.router, "login"))

View file

@ -24,7 +24,7 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) {
return
}
authProvider, err := getOAuth2Manager(r.Context()).Provider(provider)
authProvider, err := getOAuth2Manager(r.Context()).FindProvider(provider)
if err != nil {
logger.Error("[OAuth2] %v", err)
html.Redirect(w, r, route.Path(h.router, "settings"))
@ -32,6 +32,11 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) {
}
sess := session.New(h.store, request.SessionID(r))
user, err := h.store.UserByID(request.UserID(r))
if err != nil {
html.ServerError(w, r, err)
return
}
hasPassword, err := h.store.HasPassword(request.UserID(r))
if err != nil {
@ -45,7 +50,8 @@ func (h *handler) oauth2Unlink(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.store.RemoveExtraField(request.UserID(r), authProvider.GetUserExtraKey()); err != nil {
authProvider.UnsetUserProfileID(user)
if err := h.store.UpdateUser(user); err != nil {
html.ServerError(w, r, err)
return
}

View file

@ -35,7 +35,7 @@ func (h *handler) showSettingsPage(w http.ResponseWriter, r *http.Request) {
EntriesPerPage: user.EntriesPerPage,
KeyboardShortcuts: user.KeyboardShortcuts,
ShowReadingTime: user.ShowReadingTime,
CustomCSS: user.Extra["custom_css"],
CustomCSS: user.Stylesheet,
EntrySwipe: user.EntrySwipe,
}

View file

@ -18,22 +18,17 @@ func (h *handler) showStylesheet(w http.ResponseWriter, r *http.Request) {
filename := request.RouteStringParam(r, "name")
if filename == "custom_css" {
user, err := h.store.UserByID(request.UserID(r))
if err != nil {
if err != nil || user == nil {
html.NotFound(w, r)
return
}
b := response.New(w, r)
if user == nil {
b.WithHeader("Content-Type", "text/css; charset=utf-8")
b.WithBody("")
b.Write()
return
}
b.WithHeader("Content-Type", "text/css; charset=utf-8")
b.WithBody(user.Extra["custom_css"])
b.WithBody(user.Stylesheet)
b.Write()
return
}
etag, found := static.StylesheetsChecksums[filename]
if !found {
html.NotFound(w, r)