diff --git a/README.md b/README.md index 308efcb4..807042dd 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,11 @@ TODO - [ ] Custom entries sorting - [ ] Webpage scraper (Readability) - [X] Bookmarklet -- [ ] External integrations (Pinboard, Wallabag...) +- [ ] External integrations (Pinboard, Instapaper, Pocket?) - [ ] Gzip compression - [X] Integration tests - [X] Flush history - [X] OAuth2 -- [ ] Bookmarks - [ ] Touch events - [ ] Fever API? diff --git a/model/token.go b/model/token.go new file mode 100644 index 00000000..5626a77d --- /dev/null +++ b/model/token.go @@ -0,0 +1,11 @@ +// Copyright 2017 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package model + +// Token represents a CSRF token in the system. +type Token struct { + ID string + Value string +} diff --git a/server/core/context.go b/server/core/context.go index 217e4d46..393060d3 100644 --- a/server/core/context.go +++ b/server/core/context.go @@ -82,7 +82,7 @@ func (c *Context) UserLanguage() string { // CsrfToken returns the current CSRF token. func (c *Context) CsrfToken() string { - if v := c.request.Context().Value(middleware.CsrfContextKey); v != nil { + if v := c.request.Context().Value(middleware.TokenContextKey); v != nil { return v.(string) } diff --git a/server/middleware/context_keys.go b/server/middleware/context_keys.go index c011fbb6..3099322a 100644 --- a/server/middleware/context_keys.go +++ b/server/middleware/context_keys.go @@ -21,6 +21,6 @@ var ( // IsAuthenticatedContextKey is the context key used to store the authentication flag. IsAuthenticatedContextKey = &contextKey{"IsAuthenticated"} - // CsrfContextKey is the context key used to store CSRF token. - CsrfContextKey = &contextKey{"CSRF"} + // TokenContextKey is the context key used to store CSRF token. + TokenContextKey = &contextKey{"CSRF"} ) diff --git a/server/middleware/csrf.go b/server/middleware/csrf.go deleted file mode 100644 index 0c07e428..00000000 --- a/server/middleware/csrf.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2017 Frédéric Guillot. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package middleware - -import ( - "context" - "log" - "net/http" - - "github.com/miniflux/miniflux2/helper" -) - -// Csrf is a middleware that handle CSRF tokens. -func Csrf(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var csrfToken string - - csrfCookie, err := r.Cookie("csrfToken") - if err == http.ErrNoCookie || csrfCookie.Value == "" { - csrfToken = helper.GenerateRandomString(64) - cookie := &http.Cookie{ - Name: "csrfToken", - Value: csrfToken, - Path: "/", - Secure: r.URL.Scheme == "https", - HttpOnly: true, - } - - http.SetCookie(w, cookie) - } else { - csrfToken = csrfCookie.Value - } - - ctx := r.Context() - ctx = context.WithValue(ctx, CsrfContextKey, csrfToken) - - w.Header().Add("Vary", "Cookie") - isTokenValid := csrfToken == r.FormValue("csrf") || csrfToken == r.Header.Get("X-Csrf-Token") - - if r.Method == "POST" && !isTokenValid { - log.Println("[Middleware:CSRF] Invalid or missing CSRF token!") - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("Invalid or missing CSRF token!")) - } else { - next.ServeHTTP(w, r.WithContext(ctx)) - } - }) -} diff --git a/server/middleware/token.go b/server/middleware/token.go new file mode 100644 index 00000000..e250633f --- /dev/null +++ b/server/middleware/token.go @@ -0,0 +1,81 @@ +// Copyright 2017 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package middleware + +import ( + "context" + "log" + "net/http" + + "github.com/miniflux/miniflux2/model" + "github.com/miniflux/miniflux2/storage" +) + +// TokenMiddleware represents a token middleware. +type TokenMiddleware struct { + store *storage.Storage +} + +// Handler execute the middleware. +func (t *TokenMiddleware) Handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + token := t.getTokenValueFromCookie(r) + + if token == nil { + log.Println("[Middleware:Token] Token not found") + token, err = t.store.CreateToken() + if err != nil { + log.Println(err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + cookie := &http.Cookie{ + Name: "tokenID", + Value: token.ID, + Path: "/", + Secure: r.URL.Scheme == "https", + HttpOnly: true, + } + + http.SetCookie(w, cookie) + } else { + log.Println("[Middleware:Token]", token) + } + + isTokenValid := token.Value == r.FormValue("csrf") || token.Value == r.Header.Get("X-Csrf-Token") + + if r.Method == "POST" && !isTokenValid { + log.Println("[Middleware:CSRF] Invalid or missing CSRF token!") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Invalid or missing CSRF token!")) + } else { + ctx := r.Context() + ctx = context.WithValue(ctx, TokenContextKey, token.Value) + next.ServeHTTP(w, r.WithContext(ctx)) + } + }) +} + +func (t *TokenMiddleware) getTokenValueFromCookie(r *http.Request) *model.Token { + tokenCookie, err := r.Cookie("tokenID") + if err == http.ErrNoCookie { + return nil + } + + token, err := t.store.Token(tokenCookie.Value) + if err != nil { + log.Println(err) + return nil + } + + return token +} + +// NewTokenMiddleware returns a new TokenMiddleware. +func NewTokenMiddleware(s *storage.Storage) *TokenMiddleware { + return &TokenMiddleware{store: s} +} diff --git a/server/routes.go b/server/routes.go index 8c584fa6..903f24c4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -37,7 +37,7 @@ func getRoutes(cfg *config.Config, store *storage.Storage, feedHandler *feed.Han uiHandler := core.NewHandler(store, router, templateEngine, translator, middleware.NewChain( middleware.NewSessionMiddleware(store, router).Handler, - middleware.Csrf, + middleware.NewTokenMiddleware(store).Handler, )) router.Handle("/v1/users", apiHandler.Use(apiController.CreateUser)).Methods("POST") diff --git a/sql/schema_version_3.sql b/sql/schema_version_3.sql new file mode 100644 index 00000000..d58e35d9 --- /dev/null +++ b/sql/schema_version_3.sql @@ -0,0 +1,6 @@ +create table tokens ( + id text not null, + value text not null, + created_at timestamp with time zone not null default now(), + primary key(id, value) +); \ No newline at end of file diff --git a/sql/sql.go b/sql/sql.go index 14457c08..1f6b597b 100644 --- a/sql/sql.go +++ b/sql/sql.go @@ -1,5 +1,5 @@ // Code generated by go generate; DO NOT EDIT. -// 2017-11-27 21:07:53.208711992 -0800 PST m=+0.002898220 +// 2017-12-01 21:46:13.639273113 -0800 PST m=+0.002204900 package sql @@ -112,9 +112,16 @@ create table feed_icons ( alter table users add column extra hstore; create index users_extra_idx on users using gin(extra); `, + "schema_version_3": `create table tokens ( + id text not null, + value text not null, + created_at timestamp with time zone not null default now(), + primary key(id, value) +);`, } var SqlMapChecksums = map[string]string{ "schema_version_1": "cb85ca7dd97a6e1348e00b65ea004253a7165bed9a772746613276e47ef93213", "schema_version_2": "e8e9ff32478df04fcddad10a34cba2e8bb1e67e7977b5bd6cdc4c31ec94282b4", + "schema_version_3": "a54745dbc1c51c000f74d4e5068f1e2f43e83309f023415b1749a47d5c1e0f12", } diff --git a/storage/migration.go b/storage/migration.go index 994e2dd0..5060a34b 100644 --- a/storage/migration.go +++ b/storage/migration.go @@ -12,7 +12,7 @@ import ( "github.com/miniflux/miniflux2/sql" ) -const schemaVersion = 2 +const schemaVersion = 3 // Migrate run database migrations. func (s *Storage) Migrate() { diff --git a/storage/token.go b/storage/token.go new file mode 100644 index 00000000..dd14704d --- /dev/null +++ b/storage/token.go @@ -0,0 +1,48 @@ +// Copyright 2017 Frédéric Guillot. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package storage + +import ( + "database/sql" + "fmt" + + "github.com/miniflux/miniflux2/helper" + "github.com/miniflux/miniflux2/model" +) + +// CreateToken creates a new token. +func (s *Storage) CreateToken() (*model.Token, error) { + token := model.Token{ + ID: helper.GenerateRandomString(32), + Value: helper.GenerateRandomString(64), + } + + query := "INSERT INTO tokens (id, value) VALUES ($1, $2)" + _, err := s.db.Exec(query, token.ID, token.Value) + if err != nil { + return nil, fmt.Errorf("unable to create token: %v", err) + } + + return &token, nil +} + +// Token returns a Token. +func (s *Storage) Token(id string) (*model.Token, error) { + var token model.Token + + query := "SELECT id, value FROM tokens WHERE id=$1" + err := s.db.QueryRow(query, id).Scan( + &token.ID, + &token.Value, + ) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("token not found: %s", id) + } else if err != nil { + return nil, fmt.Errorf("unable to fetch token: %v", err) + } + + return &token, nil +}