364 lines
11 KiB
Go
364 lines
11 KiB
Go
|
// Package mysql is an osin storage implementation for MySQL.
|
||
|
package mysql
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"log"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/RangelReale/osin"
|
||
|
"github.com/ansel1/merry"
|
||
|
"github.com/felipeweb/gopher-utils"
|
||
|
// driver for mysql db
|
||
|
_ "github.com/go-sql-driver/mysql"
|
||
|
)
|
||
|
|
||
|
// schemas holds the CREATE TABLE statements executed by CreateSchemas, in
// order. Each statement contains one "{prefix}" placeholder that is replaced
// with the Storage's table prefix before execution.
//
// Tables:
//   - client:    registered clients (id, secret, redirect_uri, extra user data)
//   - authorize: authorization codes, keyed by code
//   - access:    access tokens, keyed by access_token, with links to the
//     authorize code and the previous access token
//   - refresh:   refresh token -> access token mapping
//   - expires:   token expiry timestamps, indexed by expires_at and by token
//
// NOTE(review): created_at/expires_at are TIMESTAMP NOT NULL without an
// explicit DEFAULT; on MySQL servers with explicit_defaults_for_timestamp
// disabled the first such column may get an implicit DEFAULT/ON UPDATE
// CURRENT_TIMESTAMP — confirm this is acceptable.
var schemas = []string{`CREATE TABLE IF NOT EXISTS {prefix}client (
id varchar(255) BINARY NOT NULL PRIMARY KEY,
secret varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}authorize (
client varchar(255) BINARY NOT NULL,
code varchar(255) BINARY NOT NULL PRIMARY KEY,
expires_in int(10) NOT NULL,
scope varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL,
state varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
created_at timestamp NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}access (
client varchar(255) BINARY NOT NULL,
authorize varchar(255) BINARY NOT NULL,
previous varchar(255) BINARY NOT NULL,
access_token varchar(255) BINARY NOT NULL PRIMARY KEY,
refresh_token varchar(255) BINARY NOT NULL,
expires_in int(10) NOT NULL,
scope varchar(255) NOT NULL,
redirect_uri varchar(255) NOT NULL,
extra varchar(255) NOT NULL,
created_at timestamp NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}refresh (
token varchar(255) BINARY NOT NULL PRIMARY KEY,
access varchar(255) BINARY NOT NULL
)`, `CREATE TABLE IF NOT EXISTS {prefix}expires (
id int(11) NOT NULL PRIMARY KEY AUTO_INCREMENT,
token varchar(255) BINARY NOT NULL,
expires_at timestamp NOT NULL,
INDEX expires_index (expires_at),
INDEX token_expires_index (token)
)`,
}
|
||
|
|
||
|
// Storage implements interface "github.com/RangelReale/osin".Storage and interface "github.com/felipeweb/osin-mysql/storage".Storage
//
// Every query is executed on db, and every table name in the generated SQL
// is prefixed with tablePrefix.
type Storage struct {
	// db is the database handle all queries are executed on.
	db *sql.DB
	// tablePrefix is prepended to each table name (client, authorize,
	// access, refresh, expires) when building SQL statements.
	tablePrefix string
}
|
||
|
|
||
|
// New returns a new mysql storage instance.
|
||
|
func New(db *sql.DB, tablePrefix string) *Storage {
|
||
|
return &Storage{db, tablePrefix}
|
||
|
}
|
||
|
|
||
|
// CreateSchemas creates the schemata, if they do not exist yet in the database. Returns an error if something went wrong.
|
||
|
func (s *Storage) CreateSchemas() error {
|
||
|
for k, schema := range schemas {
|
||
|
schema := strings.Replace(schema, "{prefix}", s.tablePrefix, 4)
|
||
|
if _, err := s.db.Exec(schema); err != nil {
|
||
|
log.Printf("Error creating schema %d: %s", k, schema)
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Clone the storage if needed. For example, using mgo, you can clone the session with session.Clone
// to avoid concurrent access problems.
// This is to avoid cloning the connection at each method access.
// Can return itself if not a problem.
//
// This implementation returns the receiver itself: all state is the shared
// *sql.DB plus an immutable table prefix, and the database/sql package
// documents *sql.DB as safe for concurrent use by multiple goroutines.
func (s *Storage) Clone() osin.Storage {
	return s
}
|
||
|
|
||
|
// Close the resources the Storage potentially holds (using Clone for example)
//
// This is a no-op. Clone returns the same Storage, so there is nothing
// per-clone to release, and the *sql.DB given to New is not closed here;
// presumably whoever created it closes it — confirm at call sites.
func (s *Storage) Close() {
}
|
||
|
|
||
|
// GetClient loads the client by id
|
||
|
func (s *Storage) GetClient(id string) (osin.Client, error) {
|
||
|
row := s.db.QueryRow(fmt.Sprintf("SELECT id, secret, redirect_uri, extra FROM %sclient WHERE id=?", s.tablePrefix), id)
|
||
|
var c osin.DefaultClient
|
||
|
var extra string
|
||
|
|
||
|
if err := row.Scan(&c.Id, &c.Secret, &c.RedirectUri, &extra); err == sql.ErrNoRows {
|
||
|
return nil, osin.ErrNotFound
|
||
|
} else if err != nil {
|
||
|
return nil, merry.Wrap(err)
|
||
|
}
|
||
|
c.UserData = extra
|
||
|
return &c, nil
|
||
|
}
|
||
|
|
||
|
// UpdateClient updates the client (identified by it's id) and replaces the values with the values of client.
|
||
|
func (s *Storage) UpdateClient(c osin.Client) error {
|
||
|
data := gopher_utils.ToStr(c.GetUserData())
|
||
|
|
||
|
if _, err := s.db.Exec(fmt.Sprintf("UPDATE %sclient SET secret=?, redirect_uri=?, extra=? WHERE id=?", s.tablePrefix), c.GetSecret(), c.GetRedirectUri(), data, c.GetId()); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// CreateClient stores the client in the database and returns an error, if something went wrong.
|
||
|
func (s *Storage) CreateClient(c osin.Client) error {
|
||
|
data := gopher_utils.ToStr(c.GetUserData())
|
||
|
|
||
|
if _, err := s.db.Exec(fmt.Sprintf("INSERT INTO %sclient (id, secret, redirect_uri, extra) VALUES (?, ?, ?, ?)", s.tablePrefix), c.GetId(), c.GetSecret(), c.GetRedirectUri(), data); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// RemoveClient removes a client (identified by id) from the database. Returns an error if something went wrong.
|
||
|
func (s *Storage) RemoveClient(id string) (err error) {
|
||
|
if _, err = s.db.Exec(fmt.Sprintf("DELETE FROM %sclient WHERE id=?", s.tablePrefix), id); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SaveAuthorize saves authorize data.
|
||
|
func (s *Storage) SaveAuthorize(data *osin.AuthorizeData) (err error) {
|
||
|
extra := gopher_utils.ToStr(data.UserData)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if _, err = s.db.Exec(
|
||
|
fmt.Sprintf("INSERT INTO %sauthorize (client, code, expires_in, scope, redirect_uri, state, created_at, extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", s.tablePrefix),
|
||
|
data.Client.GetId(),
|
||
|
data.Code,
|
||
|
data.ExpiresIn,
|
||
|
data.Scope,
|
||
|
data.RedirectUri,
|
||
|
data.State,
|
||
|
data.CreatedAt,
|
||
|
extra,
|
||
|
); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
if err = s.AddExpireAtData(data.Code, data.ExpireAt()); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// LoadAuthorize looks up AuthorizeData by a code.
|
||
|
// Client information MUST be loaded together.
|
||
|
// Optionally can return error if expired.
|
||
|
func (s *Storage) LoadAuthorize(code string) (*osin.AuthorizeData, error) {
|
||
|
var data osin.AuthorizeData
|
||
|
var extra string
|
||
|
var cid string
|
||
|
if err := s.db.QueryRow(fmt.Sprintf("SELECT client, code, expires_in, scope, redirect_uri, state, created_at, extra FROM %sauthorize WHERE code=? LIMIT 1", s.tablePrefix), code).Scan(&cid, &data.Code, &data.ExpiresIn, &data.Scope, &data.RedirectUri, &data.State, &data.CreatedAt, &extra); err == sql.ErrNoRows {
|
||
|
return nil, osin.ErrNotFound
|
||
|
} else if err != nil {
|
||
|
return nil, merry.Wrap(err)
|
||
|
}
|
||
|
data.UserData = extra
|
||
|
|
||
|
c, err := s.GetClient(cid)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if data.ExpireAt().Before(time.Now()) {
|
||
|
return nil, merry.Errorf("Token expired at %s.", data.ExpireAt().String())
|
||
|
}
|
||
|
|
||
|
data.Client = c
|
||
|
return &data, nil
|
||
|
}
|
||
|
|
||
|
// RemoveAuthorize revokes or deletes the authorization code.
|
||
|
func (s *Storage) RemoveAuthorize(code string) (err error) {
|
||
|
if _, err = s.db.Exec(fmt.Sprintf("DELETE FROM %sauthorize WHERE code=?", s.tablePrefix), code); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
if err = s.RemoveExpireAtData(code); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SaveAccess writes AccessData.
|
||
|
// If RefreshToken is not blank, it must save in a way that can be loaded using LoadRefresh.
|
||
|
func (s *Storage) SaveAccess(data *osin.AccessData) (err error) {
|
||
|
prev := ""
|
||
|
authorizeData := &osin.AuthorizeData{}
|
||
|
|
||
|
if data.AccessData != nil {
|
||
|
prev = data.AccessData.AccessToken
|
||
|
}
|
||
|
|
||
|
if data.AuthorizeData != nil {
|
||
|
authorizeData = data.AuthorizeData
|
||
|
}
|
||
|
|
||
|
extra := gopher_utils.ToStr(data.UserData)
|
||
|
|
||
|
tx, err := s.db.Begin()
|
||
|
if err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
|
||
|
if data.RefreshToken != "" {
|
||
|
if err := s.saveRefresh(tx, data.RefreshToken, data.AccessToken); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if data.Client == nil {
|
||
|
return merry.New("data.Client must not be nil")
|
||
|
}
|
||
|
|
||
|
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %saccess (client, authorize, previous, access_token, refresh_token, expires_in, scope, redirect_uri, created_at, extra) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", s.tablePrefix), data.Client.GetId(), authorizeData.Code, prev, data.AccessToken, data.RefreshToken, data.ExpiresIn, data.Scope, data.RedirectUri, data.CreatedAt, extra)
|
||
|
if err != nil {
|
||
|
if rbe := tx.Rollback(); rbe != nil {
|
||
|
return merry.Wrap(rbe)
|
||
|
}
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
|
||
|
if err = s.AddExpireAtData(data.AccessToken, data.ExpireAt()); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
if err = tx.Commit(); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// LoadAccess retrieves access data by token. Client information MUST be loaded together.
|
||
|
// AuthorizeData and AccessData DON'T NEED to be loaded if not easily available.
|
||
|
// Optionally can return error if expired.
|
||
|
func (s *Storage) LoadAccess(code string) (*osin.AccessData, error) {
|
||
|
var extra, cid, prevAccessToken, authorizeCode string
|
||
|
var result osin.AccessData
|
||
|
|
||
|
if err := s.db.QueryRow(
|
||
|
fmt.Sprintf("SELECT client, authorize, previous, access_token, refresh_token, expires_in, scope, redirect_uri, created_at, extra FROM %saccess WHERE access_token=? LIMIT 1", s.tablePrefix),
|
||
|
code,
|
||
|
).Scan(
|
||
|
&cid,
|
||
|
&authorizeCode,
|
||
|
&prevAccessToken,
|
||
|
&result.AccessToken,
|
||
|
&result.RefreshToken,
|
||
|
&result.ExpiresIn,
|
||
|
&result.Scope,
|
||
|
&result.RedirectUri,
|
||
|
&result.CreatedAt,
|
||
|
&extra,
|
||
|
); err == sql.ErrNoRows {
|
||
|
return nil, osin.ErrNotFound
|
||
|
} else if err != nil {
|
||
|
return nil, merry.Wrap(err)
|
||
|
}
|
||
|
|
||
|
result.UserData = extra
|
||
|
client, err := s.GetClient(cid)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
result.Client = client
|
||
|
result.AuthorizeData, _ = s.LoadAuthorize(authorizeCode)
|
||
|
prevAccess, _ := s.LoadAccess(prevAccessToken)
|
||
|
result.AccessData = prevAccess
|
||
|
return &result, nil
|
||
|
}
|
||
|
|
||
|
// RemoveAccess revokes or deletes an AccessData.
|
||
|
func (s *Storage) RemoveAccess(code string) (err error) {
|
||
|
if _, err = s.db.Exec(fmt.Sprintf("DELETE FROM %saccess WHERE access_token=?", s.tablePrefix), code); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
if err = s.RemoveExpireAtData(code); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// LoadRefresh retrieves refresh AccessData. Client information MUST be loaded together.
|
||
|
// AuthorizeData and AccessData DON'T NEED to be loaded if not easily available.
|
||
|
// Optionally can return error if expired.
|
||
|
func (s *Storage) LoadRefresh(code string) (*osin.AccessData, error) {
|
||
|
row := s.db.QueryRow(fmt.Sprintf("SELECT access FROM %srefresh WHERE token=? LIMIT 1", s.tablePrefix), code)
|
||
|
var access string
|
||
|
if err := row.Scan(&access); err == sql.ErrNoRows {
|
||
|
return nil, osin.ErrNotFound
|
||
|
} else if err != nil {
|
||
|
return nil, merry.Wrap(err)
|
||
|
}
|
||
|
return s.LoadAccess(access)
|
||
|
}
|
||
|
|
||
|
// RemoveRefresh revokes or deletes refresh AccessData.
|
||
|
func (s *Storage) RemoveRefresh(code string) error {
|
||
|
_, err := s.db.Exec(fmt.Sprintf("DELETE FROM %srefresh WHERE token=?", s.tablePrefix), code)
|
||
|
if err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// CreateClientWithInformation Makes easy to create a osin.DefaultClient
|
||
|
func (s *Storage) CreateClientWithInformation(id string, secret string, redirectURI string, userData interface{}) osin.Client {
|
||
|
return &osin.DefaultClient{
|
||
|
Id: id,
|
||
|
Secret: secret,
|
||
|
RedirectUri: redirectURI,
|
||
|
UserData: userData,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Storage) saveRefresh(tx *sql.Tx, refresh, access string) (err error) {
|
||
|
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %srefresh (token, access) VALUES (?, ?)", s.tablePrefix), refresh, access)
|
||
|
if err != nil {
|
||
|
if rbe := tx.Rollback(); rbe != nil {
|
||
|
return merry.Wrap(rbe)
|
||
|
}
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddExpireAtData add info in expires table
|
||
|
func (s *Storage) AddExpireAtData(code string, expireAt time.Time) error {
|
||
|
if _, err := s.db.Exec(
|
||
|
fmt.Sprintf("INSERT INTO %sexpires(token, expires_at) VALUES(?, ?)", s.tablePrefix),
|
||
|
code,
|
||
|
expireAt,
|
||
|
); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// RemoveExpireAtData remove info in expires table
|
||
|
func (s *Storage) RemoveExpireAtData(code string) error {
|
||
|
if _, err := s.db.Exec(
|
||
|
fmt.Sprintf("DELETE FROM %sexpires WHERE token=?", s.tablePrefix),
|
||
|
code,
|
||
|
); err != nil {
|
||
|
return merry.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|