hanayo/vendor/github.com/felipeweb/osin-mysql/mysql.go
2019-02-23 13:29:15 +00:00

364 lines
11 KiB
Go

// Package mysql is an osin storage implementation for MySQL.
package mysql
import (
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/RangelReale/osin"
"github.com/ansel1/merry"
"github.com/felipeweb/gopher-utils"
// driver for mysql db
_ "github.com/go-sql-driver/mysql"
)
// schemas holds the CREATE TABLE IF NOT EXISTS statements for the client,
// authorize, access, refresh and expires tables. CreateSchemas executes them
// in this order, after replacing the "{prefix}" placeholder in each statement
// with Storage.tablePrefix.
// Token/id columns are declared BINARY — presumably so that lookups compare
// byte-wise (case-sensitively); confirm against the MySQL collation docs.
// The expires table is indexed on both expires_at and token.
var schemas = []string{`CREATE TABLE IF NOT EXISTS {prefix}client (
id varchar(255) BINARY NOT NULL PRIMARY KEY,
secret varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}authorize (
client varchar(255) BINARY NOT NULL,
code varchar(255) BINARY NOT NULL PRIMARY KEY,
expires_in int(10) NOT NULL,
scope varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL,
state varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
created_at timestamp NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}access (
client varchar(255) BINARY NOT NULL,
authorize varchar(255) BINARY NOT NULL,
previous varchar(255) BINARY NOT NULL,
access_token varchar(255) BINARY NOT NULL PRIMARY KEY,
refresh_token varchar(255) BINARY NOT NULL,
expires_in int(10) NOT NULL,
scope varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
created_at timestamp NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}refresh (
token varchar(255) BINARY NOT NULL PRIMARY KEY,
access varchar(255) BINARY NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}expires (
id int(11) NOT NULL PRIMARY KEY AUTO_INCREMENT,
token varchar(255) BINARY NOT NULL,
expires_at timestamp NOT NULL,
INDEX expires_index (expires_at),
INDEX token_expires_index (token)
)`,
}
// Storage implements interface "github.com/RangelReale/osin".Storage and interface "github.com/felipeweb/osin-mysql/storage".Storage
//
// It issues all queries through a caller-supplied *sql.DB and prepends
// tablePrefix to every table name it touches.
type Storage struct {
	db          *sql.DB // database handle supplied to New; never closed by Storage
	tablePrefix string  // prepended to the client/authorize/access/refresh/expires table names
}
// New builds a Storage that runs its queries against db and prefixes every
// table name with tablePrefix. The caller keeps ownership of db.
func New(db *sql.DB, tablePrefix string) *Storage {
	return &Storage{
		db:          db,
		tablePrefix: tablePrefix,
	}
}
// CreateSchemas creates the storage tables if they do not exist yet, naming
// each one with the configured table prefix. Statements run in the order of
// the schemas slice; the first failure is logged (with its index, the
// statement, and the error) and returned, and the remaining statements are
// not executed.
func (s *Storage) CreateSchemas() error {
	for i, tmpl := range schemas {
		// n = -1 replaces every occurrence instead of relying on a
		// hard-coded count, so templates may use the prefix freely.
		schema := strings.Replace(tmpl, "{prefix}", s.tablePrefix, -1)
		if _, err := s.db.Exec(schema); err != nil {
			log.Printf("Error creating schema %d: %s: %v", i, schema, err)
			return merry.Wrap(err)
		}
	}
	return nil
}
// Clone returns a Storage usable by another request. Every method goes
// through the shared *sql.DB handle held by s, so no copy is made and the
// receiver itself is returned.
func (s *Storage) Clone() osin.Storage {
	var same osin.Storage = s
	return same
}
// Close releases resources held by the Storage. It is intentionally a no-op:
// the *sql.DB was supplied by the caller through New and is left open.
func (s *Storage) Close() {}
// GetClient fetches the client whose id matches id. It returns
// osin.ErrNotFound when no row matches. The stored extra column is exposed
// as the client's UserData (a string).
func (s *Storage) GetClient(id string) (osin.Client, error) {
	query := fmt.Sprintf("SELECT id, secret, redirect_uri, extra FROM %sclient WHERE id=?", s.tablePrefix)
	var (
		client osin.DefaultClient
		extra  string
	)
	err := s.db.QueryRow(query, id).Scan(&client.Id, &client.Secret, &client.RedirectUri, &extra)
	switch {
	case err == sql.ErrNoRows:
		return nil, osin.ErrNotFound
	case err != nil:
		return nil, merry.Wrap(err)
	}
	client.UserData = extra
	return &client, nil
}
// UpdateClient overwrites the secret, redirect URI and extra data of the
// stored client that has c's id. The user data is stringified with
// gopher_utils.ToStr before being written.
func (s *Storage) UpdateClient(c osin.Client) error {
	extra := gopher_utils.ToStr(c.GetUserData())
	query := fmt.Sprintf("UPDATE %sclient SET secret=?, redirect_uri=?, extra=? WHERE id=?", s.tablePrefix)
	_, err := s.db.Exec(query, c.GetSecret(), c.GetRedirectUri(), extra, c.GetId())
	if err == nil {
		return nil
	}
	return merry.Wrap(err)
}
// CreateClient inserts c as a new client row; the user data is stringified
// with gopher_utils.ToStr. Any database error is returned wrapped by merry.
func (s *Storage) CreateClient(c osin.Client) error {
	extra := gopher_utils.ToStr(c.GetUserData())
	query := fmt.Sprintf("INSERT INTO %sclient (id, secret, redirect_uri, extra) VALUES (?, ?, ?, ?)", s.tablePrefix)
	_, err := s.db.Exec(query, c.GetId(), c.GetSecret(), c.GetRedirectUri(), extra)
	if err == nil {
		return nil
	}
	return merry.Wrap(err)
}
// RemoveClient deletes the client row whose id equals id. Deleting a
// non-existent client is not an error.
func (s *Storage) RemoveClient(id string) (err error) {
	query := fmt.Sprintf("DELETE FROM %sclient WHERE id=?", s.tablePrefix)
	_, err = s.db.Exec(query, id)
	if err == nil {
		return nil
	}
	return merry.Wrap(err)
}
// SaveAuthorize saves authorize data together with its row in the expires
// table. Both inserts run in a single transaction, so a failure never leaves
// an authorize row without its matching expires row. data.Client must be
// non-nil.
func (s *Storage) SaveAuthorize(data *osin.AuthorizeData) (err error) {
	// Validate up front: data.Client.GetId() below would panic on nil.
	if data.Client == nil {
		return merry.New("data.Client must not be nil")
	}
	extra := gopher_utils.ToStr(data.UserData)

	tx, err := s.db.Begin()
	if err != nil {
		return merry.Wrap(err)
	}
	// Roll back on any failure below so the transaction (and its connection)
	// is released. A failed Commit already ends the transaction, making
	// Rollback return sql.ErrTxDone, which is ignored; any other rollback
	// error replaces err.
	defer func() {
		if err == nil {
			return
		}
		if rbe := tx.Rollback(); rbe != nil && rbe != sql.ErrTxDone {
			err = merry.Wrap(rbe)
		}
	}()

	if _, err = tx.Exec(
		fmt.Sprintf("INSERT INTO %sauthorize (client, code, expires_in, scope, redirect_uri, state, created_at, extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", s.tablePrefix),
		data.Client.GetId(),
		data.Code,
		data.ExpiresIn,
		data.Scope,
		data.RedirectUri,
		data.State,
		data.CreatedAt,
		extra,
	); err != nil {
		return merry.Wrap(err)
	}
	// Same statement as AddExpireAtData, issued on tx so that it commits or
	// rolls back together with the authorize row.
	if _, err = tx.Exec(
		fmt.Sprintf("INSERT INTO %sexpires(token, expires_at) VALUES(?, ?)", s.tablePrefix),
		data.Code,
		data.ExpireAt(),
	); err != nil {
		return merry.Wrap(err)
	}
	if err = tx.Commit(); err != nil {
		return merry.Wrap(err)
	}
	return nil
}
// LoadAuthorize looks up AuthorizeData by its code and attaches the owning
// client (loaded with GetClient).
// It returns osin.ErrNotFound when no row matches the code. It also returns
// an error once the code's ExpireAt is in the past.
func (s *Storage) LoadAuthorize(code string) (*osin.AuthorizeData, error) {
	var (
		data     osin.AuthorizeData
		extra    string
		clientID string
	)
	query := fmt.Sprintf("SELECT client, code, expires_in, scope, redirect_uri, state, created_at, extra FROM %sauthorize WHERE code=? LIMIT 1", s.tablePrefix)
	err := s.db.QueryRow(query, code).Scan(&clientID, &data.Code, &data.ExpiresIn, &data.Scope, &data.RedirectUri, &data.State, &data.CreatedAt, &extra)
	switch {
	case err == sql.ErrNoRows:
		return nil, osin.ErrNotFound
	case err != nil:
		return nil, merry.Wrap(err)
	}
	data.UserData = extra

	client, err := s.GetClient(clientID)
	if err != nil {
		return nil, err
	}
	if expiry := data.ExpireAt(); expiry.Before(time.Now()) {
		return nil, merry.Errorf("Token expired at %s.", expiry.String())
	}
	data.Client = client
	return &data, nil
}
// RemoveAuthorize deletes the authorization code and then its row in the
// expires table. The two deletes are separate statements, not one
// transaction; the first failure is returned wrapped by merry.
func (s *Storage) RemoveAuthorize(code string) (err error) {
	deleteQuery := fmt.Sprintf("DELETE FROM %sauthorize WHERE code=?", s.tablePrefix)
	if _, err = s.db.Exec(deleteQuery, code); err != nil {
		return merry.Wrap(err)
	}
	err = s.RemoveExpireAtData(code)
	if err != nil {
		return merry.Wrap(err)
	}
	return nil
}
// SaveAccess writes AccessData.
// If RefreshToken is not blank, it must save in a way that can be loaded using LoadRefresh.
//
// The refresh row (when RefreshToken is set), the access row and the
// expires row are written in one transaction. Either all of them are
// committed, or the transaction is rolled back. data.Client must be non-nil.
func (s *Storage) SaveAccess(data *osin.AccessData) (err error) {
	// Validate before opening a transaction so this early return cannot
	// leave a transaction (and its connection) dangling.
	if data.Client == nil {
		return merry.New("data.Client must not be nil")
	}
	prev := ""
	if data.AccessData != nil {
		prev = data.AccessData.AccessToken
	}
	authorizeData := &osin.AuthorizeData{}
	if data.AuthorizeData != nil {
		authorizeData = data.AuthorizeData
	}
	extra := gopher_utils.ToStr(data.UserData)

	tx, err := s.db.Begin()
	if err != nil {
		return merry.Wrap(err)
	}
	// On any failure below, roll back so the transaction is always finished.
	// saveRefresh may already have rolled back, and a failed Commit also ends
	// the transaction. In both cases Rollback returns sql.ErrTxDone, which is
	// ignored; any other rollback error replaces err.
	defer func() {
		if err == nil {
			return
		}
		if rbe := tx.Rollback(); rbe != nil && rbe != sql.ErrTxDone {
			err = merry.Wrap(rbe)
		}
	}()

	if data.RefreshToken != "" {
		if err = s.saveRefresh(tx, data.RefreshToken, data.AccessToken); err != nil {
			return err
		}
	}
	if _, err = tx.Exec(
		fmt.Sprintf("INSERT INTO %saccess (client, authorize, previous, access_token, refresh_token, expires_in, scope, redirect_uri, created_at, extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", s.tablePrefix),
		data.Client.GetId(), authorizeData.Code, prev, data.AccessToken, data.RefreshToken, data.ExpiresIn, data.Scope, data.RedirectUri, data.CreatedAt, extra,
	); err != nil {
		return merry.Wrap(err)
	}
	// Same statement as AddExpireAtData, but issued on tx so the expires row
	// is not committed independently of the access row.
	if _, err = tx.Exec(
		fmt.Sprintf("INSERT INTO %sexpires(token, expires_at) VALUES(?, ?)", s.tablePrefix),
		data.AccessToken, data.ExpireAt(),
	); err != nil {
		return merry.Wrap(err)
	}
	if err = tx.Commit(); err != nil {
		return merry.Wrap(err)
	}
	return nil
}
// LoadAccess retrieves access data by token. Client information MUST be loaded together.
// AuthorizeData and AccessData DON'T NEED to be loaded if not easily available.
// Optionally can return error if expired.
//
// It returns osin.ErrNotFound when no access row matches code. The linked
// authorize data and previous access data are loaded on a best-effort basis:
// lookup errors leave the corresponding field nil. The previous token is
// loaded recursively, so a chain of refreshed tokens is walked until a token
// with no predecessor is reached.
func (s *Storage) LoadAccess(code string) (*osin.AccessData, error) {
	var extra, cid, prevAccessToken, authorizeCode string
	var result osin.AccessData
	if err := s.db.QueryRow(
		fmt.Sprintf("SELECT client, authorize, previous, access_token, refresh_token, expires_in, scope, redirect_uri, created_at, extra FROM %saccess WHERE access_token=? LIMIT 1", s.tablePrefix),
		code,
	).Scan(
		&cid,
		&authorizeCode,
		&prevAccessToken,
		&result.AccessToken,
		&result.RefreshToken,
		&result.ExpiresIn,
		&result.Scope,
		&result.RedirectUri,
		&result.CreatedAt,
		&extra,
	); err == sql.ErrNoRows {
		return nil, osin.ErrNotFound
	} else if err != nil {
		return nil, merry.Wrap(err)
	}
	result.UserData = extra
	client, err := s.GetClient(cid)
	if err != nil {
		return nil, err
	}
	result.Client = client
	// Errors are deliberately ignored: these links are optional (see the
	// contract above). Empty values mean nothing was recorded, so the
	// guaranteed-miss queries are skipped. The self-reference check stops
	// unbounded recursion on corrupt data.
	if authorizeCode != "" {
		result.AuthorizeData, _ = s.LoadAuthorize(authorizeCode)
	}
	if prevAccessToken != "" && prevAccessToken != code {
		result.AccessData, _ = s.LoadAccess(prevAccessToken)
	}
	return &result, nil
}
// RemoveAccess deletes the access row for the given token and then its row
// in the expires table. The deletes are separate statements, not one
// transaction; the first failure is returned wrapped by merry.
func (s *Storage) RemoveAccess(code string) (err error) {
	deleteQuery := fmt.Sprintf("DELETE FROM %saccess WHERE access_token=?", s.tablePrefix)
	if _, err = s.db.Exec(deleteQuery, code); err != nil {
		return merry.Wrap(err)
	}
	err = s.RemoveExpireAtData(code)
	if err != nil {
		return merry.Wrap(err)
	}
	return nil
}
// LoadRefresh resolves a refresh token to the access token it was issued
// with, then returns that access data via LoadAccess.
// It returns osin.ErrNotFound when the refresh token is unknown.
func (s *Storage) LoadRefresh(code string) (*osin.AccessData, error) {
	var accessToken string
	query := fmt.Sprintf("SELECT access FROM %srefresh WHERE token=? LIMIT 1", s.tablePrefix)
	err := s.db.QueryRow(query, code).Scan(&accessToken)
	switch {
	case err == sql.ErrNoRows:
		return nil, osin.ErrNotFound
	case err != nil:
		return nil, merry.Wrap(err)
	default:
		return s.LoadAccess(accessToken)
	}
}
// RemoveRefresh deletes the refresh-token row identified by code. Removing
// an unknown token is not an error.
func (s *Storage) RemoveRefresh(code string) error {
	query := fmt.Sprintf("DELETE FROM %srefresh WHERE token=?", s.tablePrefix)
	if _, err := s.db.Exec(query, code); err != nil {
		return merry.Wrap(err)
	}
	return nil
}
// CreateClientWithInformation assembles an *osin.DefaultClient from the
// given fields. It only builds the value in memory; nothing is written to
// the database (use CreateClient for that).
func (s *Storage) CreateClientWithInformation(id string, secret string, redirectURI string, userData interface{}) osin.Client {
	client := new(osin.DefaultClient)
	client.Id = id
	client.Secret = secret
	client.RedirectUri = redirectURI
	client.UserData = userData
	return client
}
// saveRefresh records, inside tx, that the refresh token maps to the given
// access token. If the insert fails it rolls tx back itself. A rollback
// error takes precedence over the insert error.
func (s *Storage) saveRefresh(tx *sql.Tx, refresh, access string) (err error) {
	query := fmt.Sprintf("INSERT INTO %srefresh (token, access) VALUES (?, ?)", s.tablePrefix)
	if _, err = tx.Exec(query, refresh, access); err == nil {
		return nil
	}
	if rbe := tx.Rollback(); rbe != nil {
		return merry.Wrap(rbe)
	}
	return merry.Wrap(err)
}
// AddExpireAtData inserts a row into the expires table recording that code
// expires at expireAt. The insert runs directly on the *sql.DB, outside any
// transaction.
func (s *Storage) AddExpireAtData(code string, expireAt time.Time) error {
	stmt := fmt.Sprintf("INSERT INTO %sexpires(token, expires_at) VALUES(?, ?)", s.tablePrefix)
	_, err := s.db.Exec(stmt, code, expireAt)
	if err == nil {
		return nil
	}
	return merry.Wrap(err)
}
// RemoveExpireAtData deletes every row in the expires table whose token
// equals code.
func (s *Storage) RemoveExpireAtData(code string) error {
	stmt := fmt.Sprintf("DELETE FROM %sexpires WHERE token=?", s.tablePrefix)
	_, err := s.db.Exec(stmt, code)
	if err == nil {
		return nil
	}
	return merry.Wrap(err)
}