hanayo/vendor/github.com/RangelReale/osin/access.go
2019-02-23 13:29:15 +00:00

563 lines
14 KiB
Go

package osin
import (
"crypto/sha256"
"encoding/base64"
"errors"
"net/http"
"strings"
"time"
)
// AccessRequestType is the type of the OAuth `grant_type` parameter sent to
// the token endpoint. HandleAccessRequest dispatches on its value.
type AccessRequestType string

const (
	// AUTHORIZATION_CODE exchanges an authorization code (form field "code")
	// for tokens.
	AUTHORIZATION_CODE AccessRequestType = "authorization_code"
	// REFRESH_TOKEN exchanges a refresh token (form field "refresh_token")
	// for a new access token.
	REFRESH_TOKEN AccessRequestType = "refresh_token"
	// PASSWORD authenticates using the form fields "username" and "password".
	PASSWORD AccessRequestType = "password"
	// CLIENT_CREDENTIALS authenticates the client only, without a resource
	// owner.
	CLIENT_CREDENTIALS AccessRequestType = "client_credentials"
	// ASSERTION exchanges an assertion (form fields "assertion_type" and
	// "assertion") for tokens.
	ASSERTION AccessRequestType = "assertion"
	// IMPLICIT has no handler in HandleAccessRequest; if it arrives as a
	// grant_type it is rejected as unsupported.
	// NOTE(review): presumably used by the authorize endpoint elsewhere —
	// confirm.
	IMPLICIT AccessRequestType = "__implicit"
)
// AccessRequest is a validated request for access tokens.
//
// The grant handlers build it. The caller must set Authorized before passing
// it to FinishAccessRequest.
type AccessRequest struct {
	// Type is the grant type that produced this request.
	Type AccessRequestType
	// Code holds the authorization code (AUTHORIZATION_CODE) or the refresh
	// token (REFRESH_TOKEN) taken from the request form.
	Code string
	// Client is the authenticated client making the request.
	Client Client
	// AuthorizeData is the stored authorization loaded for Code
	// (AUTHORIZATION_CODE only).
	AuthorizeData *AuthorizeData
	// AccessData is the previous access grant loaded from the refresh token
	// (REFRESH_TOKEN only).
	AccessData *AccessData
	// Force finish to use this access data, to allow access data reuse.
	// When non-nil, FinishAccessRequest saves it as-is instead of generating
	// new tokens.
	ForceAccessData *AccessData
	// RedirectUri is the redirect URI resolved for this request.
	RedirectUri string
	// Scope is the requested (or inherited) space-separated scope list.
	Scope string
	// Username is the resource owner's username (PASSWORD only).
	Username string
	// Password is the resource owner's password (PASSWORD only).
	Password string
	// AssertionType is the "assertion_type" form value (ASSERTION only).
	AssertionType string
	// Assertion is the "assertion" form value (ASSERTION only).
	Assertion string
	// Set if request is authorized. FinishAccessRequest responds with
	// access_denied when this is false.
	Authorized bool
	// Token expiration in seconds. Defaults to Config.AccessExpiration;
	// change if a different lifetime is wanted.
	Expiration int32
	// Set if a refresh token should be generated.
	GenerateRefresh bool
	// Opaque data for the caller/storage. The library only copies it between
	// the authorization, request, and access data; it never interprets it.
	UserData interface{}
	// HttpRequest is the originating *http.Request, kept for special use.
	HttpRequest *http.Request
	// Optional code_verifier as described in RFC 7636 (PKCE).
	CodeVerifier string
}
// AccessData represents an issued access grant: tokens, expiration, client,
// and the data it was derived from.
type AccessData struct {
	// Client the grant was issued to.
	Client Client
	// Authorize data, for the authorization code grant. FinishAccessRequest
	// removes the referenced code from storage after saving the grant.
	AuthorizeData *AuthorizeData
	// Previous access data, for the refresh token grant. Unless
	// Config.RetainTokenAfterRefresh is set, FinishAccessRequest removes its
	// tokens from storage.
	AccessData *AccessData
	// Access token.
	AccessToken string
	// Refresh token. Can be blank, in which case none is returned.
	RefreshToken string
	// Token lifetime in seconds, counted from CreatedAt.
	ExpiresIn int32
	// Granted space-separated scope list.
	Scope string
	// Redirect URI from the request.
	RedirectUri string
	// Creation time; the base for ExpireAt.
	CreatedAt time.Time
	// Opaque data for the caller/storage. Not interpreted by the library.
	UserData interface{}
}
// IsExpired reports whether the grant has expired as of the current
// wall-clock time.
func (d *AccessData) IsExpired() bool {
	now := time.Now()
	return d.IsExpiredAt(now)
}

// IsExpiredAt reports whether instant t lies strictly after the grant's
// expiry. At exactly the expiry instant the grant is not yet expired.
func (d *AccessData) IsExpiredAt(t time.Time) bool {
	return t.After(d.ExpireAt())
}

// ExpireAt returns the instant the grant expires: CreatedAt plus ExpiresIn
// seconds.
func (d *AccessData) ExpireAt() time.Time {
	lifetime := time.Second * time.Duration(d.ExpiresIn)
	return d.CreatedAt.Add(lifetime)
}
// AccessTokenGen generates access tokens.
//
// FinishAccessRequest calls GenerateAccessToken with the AccessData about to
// be issued, and stores the returned tokens on it. The implementation is
// expected to produce a refresh token only when generaterefresh is true. An
// empty refreshtoken means none is included in the response. A non-nil err
// makes FinishAccessRequest respond with server_error.
type AccessTokenGen interface {
	GenerateAccessToken(data *AccessData, generaterefresh bool) (accesstoken string, refreshtoken string, err error)
}
// HandleAccessRequest is the token-endpoint entry point for access token
// requests.
//
// It accepts POST, and GET only when Config.AllowGetAccessRequest is set. It
// parses the form and dispatches on "grant_type" to the matching grant
// handler, provided that grant is listed in Config.AllowedAccessTypes. It
// returns nil, with an error recorded on w, when the request is rejected.
func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessRequest {
	switch r.Method {
	case "POST":
		// always accepted
	case "GET":
		if !s.Config.AllowGetAccessRequest {
			w.SetError(E_INVALID_REQUEST, "")
			w.InternalError = errors.New("Request must be POST")
			return nil
		}
	default:
		w.SetError(E_INVALID_REQUEST, "")
		w.InternalError = errors.New("Request must be POST")
		return nil
	}

	if err := r.ParseForm(); err != nil {
		w.SetError(E_INVALID_REQUEST, "")
		w.InternalError = err
		return nil
	}

	grantType := AccessRequestType(r.Form.Get("grant_type"))
	if !s.Config.AllowedAccessTypes.Exists(grantType) {
		w.SetError(E_UNSUPPORTED_GRANT_TYPE, "")
		return nil
	}

	switch grantType {
	case AUTHORIZATION_CODE:
		return s.handleAuthorizationCodeRequest(w, r)
	case REFRESH_TOKEN:
		return s.handleRefreshTokenRequest(w, r)
	case PASSWORD:
		return s.handlePasswordRequest(w, r)
	case CLIENT_CREDENTIALS:
		return s.handleClientCredentialsRequest(w, r)
	case ASSERTION:
		return s.handleAssertionRequest(w, r)
	default:
		// Allowed by configuration but without a handler here (e.g. IMPLICIT).
		w.SetError(E_UNSUPPORTED_GRANT_TYPE, "")
		return nil
	}
}
// handleAuthorizationCodeRequest validates a grant_type=authorization_code
// request. It returns the resulting AccessRequest, or nil with an error
// recorded on w.
//
// Checks, in order:
//   - client authentication and the presence of "code";
//   - client lookup;
//   - loading the authorization data for the code;
//   - that the authorization has a client with a registered redirect URI;
//   - that it has not expired per s.Now;
//   - that it was issued to the authenticating client;
//   - that the redirect URI is registered for the client and equals the one
//     stored with the authorization;
//   - PKCE (RFC 7636), when a code_challenge was stored.
func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest {
	// get client authentication; abort if none could be extracted
	auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
	if auth == nil {
		return nil
	}

	// build the access request from the submitted form values
	ret := &AccessRequest{
		Type:            AUTHORIZATION_CODE,
		Code:            r.Form.Get("code"),
		CodeVerifier:    r.Form.Get("code_verifier"),
		RedirectUri:     r.Form.Get("redirect_uri"),
		GenerateRefresh: true,
		Expiration:      s.Config.AccessExpiration,
		HttpRequest:     r,
	}

	// "code" is required
	if ret.Code == "" {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	// must have a valid client
	if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
		return nil
	}

	// must be a valid authorization code
	var err error
	ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code)
	if err != nil {
		w.SetError(E_INVALID_GRANT, "")
		w.InternalError = err
		return nil
	}
	// the stored authorization must exist and reference a client that has a
	// registered redirect URI
	if ret.AuthorizeData == nil {
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	}
	if ret.AuthorizeData.Client == nil {
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	}
	if ret.AuthorizeData.Client.GetRedirectUri() == "" {
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	}
	// expiry is judged against the server's clock (s.Now)
	if ret.AuthorizeData.IsExpiredAt(s.Now()) {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	// code must be from the client
	if ret.AuthorizeData.Client.GetId() != ret.Client.GetId() {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	// check redirect uri: default to the client's first registered URI when
	// the request omitted it, then require it to be one of the client's URIs
	// and identical to the URI the code was issued for
	if ret.RedirectUri == "" {
		ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator)
	}
	if err = ValidateUriList(ret.Client.GetRedirectUri(), ret.RedirectUri, s.Config.RedirectUriSeparator); err != nil {
		w.SetError(E_INVALID_REQUEST, "")
		w.InternalError = err
		return nil
	}
	if ret.AuthorizeData.RedirectUri != ret.RedirectUri {
		w.SetError(E_INVALID_REQUEST, "")
		w.InternalError = errors.New("Redirect uri is different")
		return nil
	}

	// Verify PKCE, if present in the authorization data
	if len(ret.AuthorizeData.CodeChallenge) > 0 {
		// pkceMatcher (defined elsewhere) checks the verifier's format;
		// see https://tools.ietf.org/html/rfc7636#section-4.1
		if matched := pkceMatcher.MatchString(ret.CodeVerifier); !matched {
			w.SetError(E_INVALID_REQUEST, "code_verifier invalid (rfc7636)")
			w.InternalError = errors.New("code_verifier has invalid format")
			return nil
		}

		// transform the verifier according to the stored method, then compare
		// it with the stored challenge;
		// see https://tools.ietf.org/html/rfc7636#section-4.6
		codeVerifier := ""
		switch ret.AuthorizeData.CodeChallengeMethod {
		case "", PKCE_PLAIN:
			// no recorded method is handled the same as "plain"
			codeVerifier = ret.CodeVerifier
		case PKCE_S256:
			// unpadded base64url of the SHA-256 digest of the verifier
			hash := sha256.Sum256([]byte(ret.CodeVerifier))
			codeVerifier = base64.RawURLEncoding.EncodeToString(hash[:])
		default:
			w.SetError(E_INVALID_REQUEST, "code_challenge_method transform algorithm not supported (rfc7636)")
			return nil
		}
		if codeVerifier != ret.AuthorizeData.CodeChallenge {
			w.SetError(E_INVALID_GRANT, "code_verifier invalid (rfc7636)")
			w.InternalError = errors.New("code_verifier failed comparison with code_challenge")
			return nil
		}
	}

	// set rest of data: scope and user data are inherited from the authorization
	ret.Scope = ret.AuthorizeData.Scope
	ret.UserData = ret.AuthorizeData.UserData
	return ret
}
// extraScopes reports whether refreshScopes names any scope that does not
// appear in accessScopes.
//
// Both arguments are single-space-separated scope lists, as in the OAuth
// "scope" parameter. Empty entries produced by leading, trailing, or repeated
// spaces are ignored. The refresh handler uses this to reject a request that
// would widen the originally granted scope.
func extraScopes(accessScopes, refreshScopes string) bool {
	// Set of scopes already granted; struct{} values take no storage.
	granted := make(map[string]struct{})
	for _, scope := range strings.Split(accessScopes, " ") {
		if scope != "" {
			granted[scope] = struct{}{}
		}
	}

	for _, scope := range strings.Split(refreshScopes, " ") {
		if scope == "" {
			continue
		}
		if _, ok := granted[scope]; !ok {
			return true
		}
	}
	return false
}
// handleRefreshTokenRequest validates a grant_type=refresh_token request.
//
// It authenticates the client and loads the access data bound to the supplied
// refresh token, which must have been issued to that same client. It then
// resolves the scope: an omitted scope inherits the previous grant's scope,
// and a requested scope may not exceed it. It returns nil with an error
// recorded on w on failure.
func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest {
	auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
	if auth == nil {
		return nil
	}

	req := &AccessRequest{
		Type:            REFRESH_TOKEN,
		Code:            r.Form.Get("refresh_token"),
		Scope:           r.Form.Get("scope"),
		GenerateRefresh: true,
		Expiration:      s.Config.AccessExpiration,
		HttpRequest:     r,
	}

	// The refresh token itself is mandatory.
	if req.Code == "" {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	if req.Client = getClient(auth, w.Storage, w); req.Client == nil {
		return nil
	}

	prev, err := w.Storage.LoadRefresh(req.Code)
	if err != nil {
		w.SetError(E_INVALID_GRANT, "")
		w.InternalError = err
		return nil
	}
	// || short-circuits, so later checks only run once earlier ones pass.
	if prev == nil || prev.Client == nil || prev.Client.GetRedirectUri() == "" {
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	}
	// Only the client the token was issued to may redeem it.
	if prev.Client.GetId() != req.Client.GetId() {
		w.SetError(E_INVALID_CLIENT, "")
		w.InternalError = errors.New("Client id must be the same from previous token")
		return nil
	}

	req.AccessData = prev
	req.RedirectUri = prev.RedirectUri
	req.UserData = prev.UserData
	if req.Scope == "" {
		req.Scope = prev.Scope
	}
	if extraScopes(prev.Scope, req.Scope) {
		w.SetError(E_ACCESS_DENIED, "")
		w.InternalError = errors.New("the requested scope must not include any scope not originally granted by the resource owner")
		return nil
	}
	return req
}
// handlePasswordRequest validates a grant_type=password request.
//
// The form must carry both "username" and "password". The client must
// authenticate. The redirect URI defaults to the client's first registered
// URI. Verifying the resource owner's credentials is left to the caller. It
// returns nil with an error recorded on w on failure.
func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest {
	auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
	if auth == nil {
		return nil
	}

	username, password := r.Form.Get("username"), r.Form.Get("password")
	if username == "" || password == "" {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	client := getClient(auth, w.Storage, w)
	if client == nil {
		return nil
	}

	return &AccessRequest{
		Type:            PASSWORD,
		Client:          client,
		Username:        username,
		Password:        password,
		Scope:           r.Form.Get("scope"),
		RedirectUri:     FirstUri(client.GetRedirectUri(), s.Config.RedirectUriSeparator),
		GenerateRefresh: true,
		Expiration:      s.Config.AccessExpiration,
		HttpRequest:     r,
	}
}
// handleClientCredentialsRequest validates a grant_type=client_credentials
// request.
//
// Only the client itself is authenticated. No refresh token is requested. The
// redirect URI defaults to the client's first registered URI. It returns nil
// with an error recorded on w on failure.
func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest {
	auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
	if auth == nil {
		return nil
	}

	client := getClient(auth, w.Storage, w)
	if client == nil {
		return nil
	}

	return &AccessRequest{
		Type:            CLIENT_CREDENTIALS,
		Client:          client,
		Scope:           r.Form.Get("scope"),
		RedirectUri:     FirstUri(client.GetRedirectUri(), s.Config.RedirectUriSeparator),
		GenerateRefresh: false,
		Expiration:      s.Config.AccessExpiration,
		HttpRequest:     r,
	}
}
// handleAssertionRequest validates a grant_type=assertion request.
//
// The form must carry both "assertion_type" and "assertion". The client must
// authenticate. Validating the assertion itself is left to the caller. It
// returns nil with an error recorded on w on failure.
func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest {
	auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
	if auth == nil {
		return nil
	}

	assertionType, assertion := r.Form.Get("assertion_type"), r.Form.Get("assertion")
	if assertionType == "" || assertion == "" {
		w.SetError(E_INVALID_GRANT, "")
		return nil
	}

	client := getClient(auth, w.Storage, w)
	if client == nil {
		return nil
	}

	return &AccessRequest{
		Type:            ASSERTION,
		Client:          client,
		Scope:           r.Form.Get("scope"),
		AssertionType:   assertionType,
		Assertion:       assertion,
		RedirectUri:     FirstUri(client.GetRedirectUri(), s.Config.RedirectUriSeparator),
		GenerateRefresh: false, // assertion should NOT generate a refresh token, per the RFC
		Expiration:      s.Config.AccessExpiration,
		HttpRequest:     r,
	}
}
// FinishAccessRequest completes an access request returned by
// HandleAccessRequest. It does nothing if w already holds an error.
//
// If ar.Authorized is false, it responds with access_denied. Otherwise it:
//   - issues ar.ForceAccessData, or new tokens from s.AccessTokenGen;
//   - saves the grant;
//   - removes the consumed authorization code;
//   - removes the previous access and refresh tokens, unless
//     Config.RetainTokenAfterRefresh is set;
//   - writes the token response into w.Output.
func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) {
	if w.IsError {
		return
	}

	// A redirect URI resolved during validation (e.g. from a refresh token)
	// takes precedence over the raw form value.
	redirectUri := r.Form.Get("redirect_uri")
	if ar.RedirectUri != "" {
		redirectUri = ar.RedirectUri
	}

	if !ar.Authorized {
		w.SetError(E_ACCESS_DENIED, "")
		return
	}

	grant := ar.ForceAccessData
	if grant == nil {
		grant = &AccessData{
			Client:        ar.Client,
			AuthorizeData: ar.AuthorizeData,
			AccessData:    ar.AccessData,
			RedirectUri:   redirectUri,
			CreatedAt:     s.Now(),
			ExpiresIn:     ar.Expiration,
			UserData:      ar.UserData,
			Scope:         ar.Scope,
		}
		var genErr error
		grant.AccessToken, grant.RefreshToken, genErr = s.AccessTokenGen.GenerateAccessToken(grant, ar.GenerateRefresh)
		if genErr != nil {
			w.SetError(E_SERVER_ERROR, "")
			w.InternalError = genErr
			return
		}
	}

	if saveErr := w.Storage.SaveAccess(grant); saveErr != nil {
		w.SetError(E_SERVER_ERROR, "")
		w.InternalError = saveErr
		return
	}

	// The results of the removal calls below are not checked; cleanup
	// happens only after the new grant has been saved.
	if grant.AuthorizeData != nil {
		w.Storage.RemoveAuthorize(grant.AuthorizeData.Code)
	}
	if prev := grant.AccessData; prev != nil && !s.Config.RetainTokenAfterRefresh {
		if prev.RefreshToken != "" {
			w.Storage.RemoveRefresh(prev.RefreshToken)
		}
		w.Storage.RemoveAccess(prev.AccessToken)
	}

	w.Output["access_token"] = grant.AccessToken
	w.Output["token_type"] = s.Config.TokenType
	w.Output["expires_in"] = grant.ExpiresIn
	if grant.RefreshToken != "" {
		w.Output["refresh_token"] = grant.RefreshToken
	}
	if grant.Scope != "" {
		w.Output["scope"] = grant.Scope
	}
}
// Helper Functions
// getClient loads the client named by auth.Username from storage and
// authenticates it against auth.Password.
//
// It returns nil and records unauthorized_client on w when any of these hold:
//   - the client is unknown (ErrNotFound or a nil client);
//   - the secret does not match;
//   - the client has no registered redirect URI.
//
// Any other storage error yields server_error, with the cause stored in
// w.InternalError.
func getClient(auth *BasicAuth, storage Storage, w *Response) Client {
	client, err := storage.GetClient(auth.Username)
	switch {
	case err == ErrNotFound:
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	case err != nil:
		w.SetError(E_SERVER_ERROR, "")
		w.InternalError = err
		return nil
	case client == nil,
		!CheckClientSecret(client, auth.Password),
		client.GetRedirectUri() == "":
		// Case expressions are evaluated left to right and stop at the first
		// match, so client is never dereferenced when it is nil.
		w.SetError(E_UNAUTHORIZED_CLIENT, "")
		return nil
	}
	return client
}