Improve OAuth2 integration

This commit is contained in:
Frédéric Guillot 2017-11-24 16:09:10 -08:00
parent cc6d272eb7
commit 747da03e4c
16 changed files with 120 additions and 15 deletions

View file

@ -26,7 +26,7 @@ func Hash(value string) string {
func GenerateRandomBytes(size int) []byte { func GenerateRandomBytes(size int) []byte {
b := make([]byte, size) b := make([]byte, size)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
panic(fmt.Errorf("Unable to generate random string: %v", err)) panic(err)
} }
return b return b

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.610818223 -0800 PST m=+0.024503556 // 2017-11-24 16:04:49.318661623 -0800 PST m=+0.006828741
package locale package locale
@ -146,12 +146,14 @@ var Translations = map[string]string{
"This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.", "This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.",
"Add to Miniflux": "Ajouter à Miniflux", "Add to Miniflux": "Ajouter à Miniflux",
"Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan", "Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan",
"Sign in with Google": "Se connecter avec Google" "Sign in with Google": "Se connecter avec Google",
"Unlink my Google account": "Dissocier mon compte Google",
"Link my Google account": "Associer mon compte Google"
} }
`, `,
} }
var TranslationsChecksums = map[string]string{ var TranslationsChecksums = map[string]string{
"en_US": "6fe95384260941e8a5a3c695a655a932e0a8a6a572c1e45cb2b1ae8baa01b897", "en_US": "6fe95384260941e8a5a3c695a655a932e0a8a6a572c1e45cb2b1ae8baa01b897",
"fr_FR": "f413b0bc103b2ab689d52da2e17c5e718a91f5dc4138dc601beaae4ec0cfc1af", "fr_FR": "f438ed9116ecc7b71412581255dd9b1332cacd9e2876615b03ec65e4b500bf02",
} }

View file

@ -130,5 +130,7 @@
"This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.", "This special link allows you to subscribe to a website directly by using a bookmark in your web browser.": "Ce lien spécial vous permet de vous abonner à un site web directement en utilisant un marque page dans votre navigateur web.",
"Add to Miniflux": "Ajouter à Miniflux", "Add to Miniflux": "Ajouter à Miniflux",
"Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan", "Refresh all feeds in the background": "Actualiser tous les abonnements en arrière-plan",
"Sign in with Google": "Se connecter avec Google" "Sign in with Google": "Se connecter avec Google",
"Unlink my Google account": "Dissocier mon compte Google",
"Link my Google account": "Associer mon compte Google"
} }

View file

@ -23,6 +23,10 @@ type googleProvider struct {
redirectURL string redirectURL string
} }
func (g googleProvider) GetUserExtraKey() string {
return "google_id"
}
func (g googleProvider) GetRedirectURL(state string) string { func (g googleProvider) GetRedirectURL(state string) string {
return g.config().AuthCodeURL(state) return g.config().AuthCodeURL(state)
} }
@ -48,7 +52,7 @@ func (g googleProvider) GetProfile(code string) (*Profile, error) {
return nil, fmt.Errorf("unable to unserialize google profile: %v", err) return nil, fmt.Errorf("unable to unserialize google profile: %v", err)
} }
profile := &Profile{Key: "google_id", ID: user.Sub, Username: user.Email} profile := &Profile{Key: g.GetUserExtraKey(), ID: user.Sub, Username: user.Email}
return profile, nil return profile, nil
} }

View file

@ -6,6 +6,7 @@ package oauth2
// Provider is an interface for OAuth2 providers. // Provider is an interface for OAuth2 providers.
type Provider interface { type Provider interface {
GetUserExtraKey() string
GetRedirectURL(state string) string GetRedirectURL(state string) string
GetProfile(code string) (*Profile, error) GetProfile(code string) (*Profile, error)
} }

View file

@ -124,6 +124,7 @@ func getRoutes(cfg *config.Config, store *storage.Storage, feedHandler *feed.Han
router.Handle("/import", uiHandler.Use(uiController.Import)).Name("import").Methods("GET") router.Handle("/import", uiHandler.Use(uiController.Import)).Name("import").Methods("GET")
router.Handle("/upload", uiHandler.Use(uiController.UploadOPML)).Name("uploadOPML").Methods("POST") router.Handle("/upload", uiHandler.Use(uiController.UploadOPML)).Name("uploadOPML").Methods("POST")
router.Handle("/oauth2/{provider}/unlink", uiHandler.Use(uiController.OAuth2Unlink)).Name("oauth2Unlink").Methods("GET")
router.Handle("/oauth2/{provider}/redirect", uiHandler.Use(uiController.OAuth2Redirect)).Name("oauth2Redirect").Methods("GET") router.Handle("/oauth2/{provider}/redirect", uiHandler.Use(uiController.OAuth2Redirect)).Name("oauth2Redirect").Methods("GET")
router.Handle("/oauth2/{provider}/callback", uiHandler.Use(uiController.OAuth2Callback)).Name("oauth2Callback").Methods("GET") router.Handle("/oauth2/{provider}/callback", uiHandler.Use(uiController.OAuth2Callback)).Name("oauth2Callback").Methods("GET")

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.595540312 -0800 PST m=+0.009225645 // 2017-11-24 16:04:49.314940117 -0800 PST m=+0.003107235
package static package static

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.596955262 -0800 PST m=+0.010640595 // 2017-11-24 16:04:49.315340301 -0800 PST m=+0.003507419
package static package static

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.598697812 -0800 PST m=+0.012383145 // 2017-11-24 16:04:49.316027642 -0800 PST m=+0.004194760
package static package static

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.609659332 -0800 PST m=+0.023344665 // 2017-11-24 16:04:49.318279667 -0800 PST m=+0.006446785
package template package template

View file

@ -63,4 +63,14 @@
</div> </div>
</form> </form>
{{ if hasOAuth2Provider "google" }}
<div class="panel">
{{ if hasKey .user.Extra "google_id" }}
<a href="{{ route "oauth2Unlink" "provider" "google" }}">{{ t "Unlink my Google account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "google" }}">{{ t "Link my Google account" }}</a>
{{ end }}
</div>
{{ end }}
{{ end }} {{ end }}

View file

@ -40,6 +40,13 @@ func (e *Engine) parseAll() {
"hasOAuth2Provider": func(provider string) bool { "hasOAuth2Provider": func(provider string) bool {
return e.cfg.Get("OAUTH2_PROVIDER", "") == provider return e.cfg.Get("OAUTH2_PROVIDER", "") == provider
}, },
"hasKey": func(dict map[string]string, key string) bool {
log.Println(dict)
if value, found := dict[key]; found {
return value != ""
}
return false
},
"route": func(name string, args ...interface{}) string { "route": func(name string, args ...interface{}) string {
return route.GetRoute(e.router, name, args...) return route.GetRoute(e.router, name, args...)
}, },

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.601583424 -0800 PST m=+0.015268757 // 2017-11-24 16:04:49.316644027 -0800 PST m=+0.004811145
package template package template
@ -921,6 +921,16 @@ var templateViewsMap = map[string]string{
</div> </div>
</form> </form>
{{ if hasOAuth2Provider "google" }}
<div class="panel">
{{ if hasKey .user.Extra "google_id" }}
<a href="{{ route "oauth2Unlink" "provider" "google" }}">{{ t "Unlink my Google account" }}</a>
{{ else }}
<a href="{{ route "oauth2Redirect" "provider" "google" }}">{{ t "Link my Google account" }}</a>
{{ end }}
</div>
{{ end }}
{{ end }} {{ end }}
`, `,
"unread": `{{ define "title"}}{{ t "Unread Items" }} {{ if gt .countUnread 0 }}({{ .countUnread }}){{ end }} {{ end }} "unread": `{{ define "title"}}{{ t "Unread Items" }} {{ if gt .countUnread 0 }}({{ .countUnread }}){{ end }} {{ end }}
@ -1052,7 +1062,7 @@ var templateViewsMapChecksums = map[string]string{
"integrations": "c485d6d9ed996635e55e73320610e6bcb01a41c1153e8e739ae2294b0b14b243", "integrations": "c485d6d9ed996635e55e73320610e6bcb01a41c1153e8e739ae2294b0b14b243",
"login": "04f3ce79bfa5753f69e0d956c2a8999c0da549c7925634a3e8134975da0b0e0f", "login": "04f3ce79bfa5753f69e0d956c2a8999c0da549c7925634a3e8134975da0b0e0f",
"sessions": "878dbe8f8ea783b44130c495814179519fa5c3aa2666ac87508f94d58dd008bf", "sessions": "878dbe8f8ea783b44130c495814179519fa5c3aa2666ac87508f94d58dd008bf",
"settings": "a972fb5767fd32522648149880e40607ed8bbed7a389038bbab6b08539ac2893", "settings": "1e2df11f5436eb2d05ae1fae30dd6f1362613011edbfcc79ae8b23854fa348b4",
"unread": "b6f9be1a72188947c75a6fdcac6ff7878db7745f9efa46318e0433102892a722", "unread": "b6f9be1a72188947c75a6fdcac6ff7878db7745f9efa46318e0433102892a722",
"users": "44677e28bb5347799ed0020c90ec785aadec4b1454446d92411cfdaf6e32110b", "users": "44677e28bb5347799ed0020c90ec785aadec4b1454446d92411cfdaf6e32110b",
} }

View file

@ -71,6 +71,17 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re
return return
} }
if ctx.IsAuthenticated() {
user := ctx.LoggedUser()
if err := c.store.UpdateExtraField(user.ID, profile.Key, profile.ID); err != nil {
response.HTML().ServerError(err)
return
}
response.Redirect(ctx.Route("settings"))
return
}
user, err := c.store.GetUserByExtraField(profile.Key, profile.ID) user, err := c.store.GetUserByExtraField(profile.Key, profile.ID)
if err != nil { if err != nil {
response.HTML().ServerError(err) response.HTML().ServerError(err)
@ -78,6 +89,11 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re
} }
if user == nil { if user == nil {
if c.cfg.GetInt("OAUTH2_USER_CREATION", 0) == 0 {
response.HTML().Forbidden()
return
}
user = model.NewUser() user = model.NewUser()
user.Username = profile.Username user.Username = profile.Username
user.IsAdmin = false user.IsAdmin = false
@ -114,6 +130,32 @@ func (c *Controller) OAuth2Callback(ctx *core.Context, request *core.Request, re
response.Redirect(ctx.Route("unread")) response.Redirect(ctx.Route("unread"))
} }
// OAuth2Unlink unlink an account from the external provider.
func (c *Controller) OAuth2Unlink(ctx *core.Context, request *core.Request, response *core.Response) {
provider := request.StringParam("provider", "")
if provider == "" {
log.Println("[OAuth2] Invalid or missing provider")
response.Redirect(ctx.Route("login"))
return
}
authProvider, err := getOAuth2Manager(c.cfg).Provider(provider)
if err != nil {
log.Println("[OAuth2]", err)
response.Redirect(ctx.Route("settings"))
return
}
user := ctx.LoggedUser()
if err := c.store.RemoveExtraField(user.ID, authProvider.GetUserExtraKey()); err != nil {
response.HTML().ServerError(err)
return
}
response.Redirect(ctx.Route("settings"))
return
}
func getOAuth2Manager(cfg *config.Config) *oauth2.Manager { func getOAuth2Manager(cfg *config.Config) *oauth2.Manager {
return oauth2.NewManager( return oauth2.NewManager(
cfg.Get("OAUTH2_CLIENT_ID", ""), cfg.Get("OAUTH2_CLIENT_ID", ""),

View file

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// 2017-11-22 22:11:44.590706207 -0800 PST m=+0.004391540 // 2017-11-24 16:04:49.314265268 -0800 PST m=+0.002432386
package sql package sql

View file

@ -74,6 +74,24 @@ func (s *Storage) CreateUser(user *model.User) (err error) {
return nil return nil
} }
func (s *Storage) UpdateExtraField(userID int64, field, value string) error {
query := fmt.Sprintf(`UPDATE users SET extra = hstore('%s', $1) WHERE id=$2`, field)
_, err := s.db.Exec(query, value, userID)
if err != nil {
return fmt.Errorf("unable to update user extra field: %v", err)
}
return nil
}
func (s *Storage) RemoveExtraField(userID int64, field string) error {
query := `UPDATE users SET extra = delete(extra, $1) WHERE id=$2`
_, err := s.db.Exec(query, field, userID)
if err != nil {
return fmt.Errorf("unable to remove user extra field: %v", err)
}
return nil
}
func (s *Storage) UpdateUser(user *model.User) error { func (s *Storage) UpdateUser(user *model.User) error {
defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:UpdateUser] username=%s", user.Username)) defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:UpdateUser] username=%s", user.Username))
user.Username = strings.ToLower(user.Username) user.Username = strings.ToLower(user.Username)
@ -104,14 +122,22 @@ func (s *Storage) GetUserById(userID int64) (*model.User, error) {
defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:GetUserById] userID=%d", userID)) defer helper.ExecutionTime(time.Now(), fmt.Sprintf("[Storage:GetUserById] userID=%d", userID))
var user model.User var user model.User
row := s.db.QueryRow("SELECT id, username, is_admin, theme, language, timezone FROM users WHERE id = $1", userID) var extra hstore.Hstore
err := row.Scan(&user.ID, &user.Username, &user.IsAdmin, &user.Theme, &user.Language, &user.Timezone) row := s.db.QueryRow("SELECT id, username, is_admin, theme, language, timezone, extra FROM users WHERE id = $1", userID)
err := row.Scan(&user.ID, &user.Username, &user.IsAdmin, &user.Theme, &user.Language, &user.Timezone, &extra)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
return nil, fmt.Errorf("unable to fetch user: %v", err) return nil, fmt.Errorf("unable to fetch user: %v", err)
} }
user.Extra = make(map[string]string)
for key, value := range extra.Map {
if value.Valid {
user.Extra[key] = value.String
}
}
return &user, nil return &user, nil
} }