563 lines
14 KiB
Go
563 lines
14 KiB
Go
|
package osin
|
||
|
|
||
|
import (
|
||
|
"crypto/sha256"
|
||
|
"encoding/base64"
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// AccessRequestType is the value of the OAuth `grant_type` parameter sent to
// the token endpoint.
type AccessRequestType string

// Grant types known to this package. The string values are the exact
// `grant_type` values compared against the incoming request.
const (
	AUTHORIZATION_CODE AccessRequestType = "authorization_code"
	REFRESH_TOKEN      AccessRequestType = "refresh_token"
	PASSWORD           AccessRequestType = "password"
	CLIENT_CREDENTIALS AccessRequestType = "client_credentials"
	ASSERTION          AccessRequestType = "assertion"

	// IMPLICIT has no case in HandleAccessRequest's dispatch, so it can never
	// be served through the token endpoint handled in this file.
	// NOTE(review): presumably an internal marker for the implicit flow
	// handled elsewhere; confirm against the authorize endpoint code.
	IMPLICIT AccessRequestType = "__implicit"
)
|
||
|
|
||
|
// AccessRequest is a request for access tokens
|
||
|
type AccessRequest struct {
|
||
|
Type AccessRequestType
|
||
|
Code string
|
||
|
Client Client
|
||
|
AuthorizeData *AuthorizeData
|
||
|
AccessData *AccessData
|
||
|
|
||
|
// Force finish to use this access data, to allow access data reuse
|
||
|
ForceAccessData *AccessData
|
||
|
RedirectUri string
|
||
|
Scope string
|
||
|
Username string
|
||
|
Password string
|
||
|
AssertionType string
|
||
|
Assertion string
|
||
|
|
||
|
// Set if request is authorized
|
||
|
Authorized bool
|
||
|
|
||
|
// Token expiration in seconds. Change if different from default
|
||
|
Expiration int32
|
||
|
|
||
|
// Set if a refresh token should be generated
|
||
|
GenerateRefresh bool
|
||
|
|
||
|
// Data to be passed to storage. Not used by the library.
|
||
|
UserData interface{}
|
||
|
|
||
|
// HttpRequest *http.Request for special use
|
||
|
HttpRequest *http.Request
|
||
|
|
||
|
// Optional code_verifier as described in rfc7636
|
||
|
CodeVerifier string
|
||
|
}
|
||
|
|
||
|
// AccessData represents an access grant (tokens, expiration, client, etc)
|
||
|
type AccessData struct {
|
||
|
// Client information
|
||
|
Client Client
|
||
|
|
||
|
// Authorize data, for authorization code
|
||
|
AuthorizeData *AuthorizeData
|
||
|
|
||
|
// Previous access data, for refresh token
|
||
|
AccessData *AccessData
|
||
|
|
||
|
// Access token
|
||
|
AccessToken string
|
||
|
|
||
|
// Refresh Token. Can be blank
|
||
|
RefreshToken string
|
||
|
|
||
|
// Token expiration in seconds
|
||
|
ExpiresIn int32
|
||
|
|
||
|
// Requested scope
|
||
|
Scope string
|
||
|
|
||
|
// Redirect Uri from request
|
||
|
RedirectUri string
|
||
|
|
||
|
// Date created
|
||
|
CreatedAt time.Time
|
||
|
|
||
|
// Data to be passed to storage. Not used by the library.
|
||
|
UserData interface{}
|
||
|
}
|
||
|
|
||
|
// IsExpired returns true if access expired
|
||
|
func (d *AccessData) IsExpired() bool {
|
||
|
return d.IsExpiredAt(time.Now())
|
||
|
}
|
||
|
|
||
|
// IsExpiredAt returns true if access expires at time 't'
|
||
|
func (d *AccessData) IsExpiredAt(t time.Time) bool {
|
||
|
return d.ExpireAt().Before(t)
|
||
|
}
|
||
|
|
||
|
// ExpireAt returns the expiration date
|
||
|
func (d *AccessData) ExpireAt() time.Time {
|
||
|
return d.CreatedAt.Add(time.Duration(d.ExpiresIn) * time.Second)
|
||
|
}
|
||
|
|
||
|
// AccessTokenGen generates access tokens
|
||
|
type AccessTokenGen interface {
|
||
|
GenerateAccessToken(data *AccessData, generaterefresh bool) (accesstoken string, refreshtoken string, err error)
|
||
|
}
|
||
|
|
||
|
// HandleAccessRequest is the http.HandlerFunc for handling access token requests
|
||
|
func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// Only allow GET or POST
|
||
|
if r.Method == "GET" {
|
||
|
if !s.Config.AllowGetAccessRequest {
|
||
|
w.SetError(E_INVALID_REQUEST, "")
|
||
|
w.InternalError = errors.New("Request must be POST")
|
||
|
return nil
|
||
|
}
|
||
|
} else if r.Method != "POST" {
|
||
|
w.SetError(E_INVALID_REQUEST, "")
|
||
|
w.InternalError = errors.New("Request must be POST")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
err := r.ParseForm()
|
||
|
if err != nil {
|
||
|
w.SetError(E_INVALID_REQUEST, "")
|
||
|
w.InternalError = err
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
grantType := AccessRequestType(r.Form.Get("grant_type"))
|
||
|
if s.Config.AllowedAccessTypes.Exists(grantType) {
|
||
|
switch grantType {
|
||
|
case AUTHORIZATION_CODE:
|
||
|
return s.handleAuthorizationCodeRequest(w, r)
|
||
|
case REFRESH_TOKEN:
|
||
|
return s.handleRefreshTokenRequest(w, r)
|
||
|
case PASSWORD:
|
||
|
return s.handlePasswordRequest(w, r)
|
||
|
case CLIENT_CREDENTIALS:
|
||
|
return s.handleClientCredentialsRequest(w, r)
|
||
|
case ASSERTION:
|
||
|
return s.handleAssertionRequest(w, r)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
w.SetError(E_UNSUPPORTED_GRANT_TYPE, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// get client authentication
|
||
|
auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
|
||
|
if auth == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret := &AccessRequest{
|
||
|
Type: AUTHORIZATION_CODE,
|
||
|
Code: r.Form.Get("code"),
|
||
|
CodeVerifier: r.Form.Get("code_verifier"),
|
||
|
RedirectUri: r.Form.Get("redirect_uri"),
|
||
|
GenerateRefresh: true,
|
||
|
Expiration: s.Config.AccessExpiration,
|
||
|
HttpRequest: r,
|
||
|
}
|
||
|
|
||
|
// "code" is required
|
||
|
if ret.Code == "" {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must have a valid client
|
||
|
if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must be a valid authorization code
|
||
|
var err error
|
||
|
ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code)
|
||
|
if err != nil {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
w.InternalError = err
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AuthorizeData == nil {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AuthorizeData.Client == nil {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AuthorizeData.Client.GetRedirectUri() == "" {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AuthorizeData.IsExpiredAt(s.Now()) {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// code must be from the client
|
||
|
if ret.AuthorizeData.Client.GetId() != ret.Client.GetId() {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// check redirect uri
|
||
|
if ret.RedirectUri == "" {
|
||
|
ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator)
|
||
|
}
|
||
|
if err = ValidateUriList(ret.Client.GetRedirectUri(), ret.RedirectUri, s.Config.RedirectUriSeparator); err != nil {
|
||
|
w.SetError(E_INVALID_REQUEST, "")
|
||
|
w.InternalError = err
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AuthorizeData.RedirectUri != ret.RedirectUri {
|
||
|
w.SetError(E_INVALID_REQUEST, "")
|
||
|
w.InternalError = errors.New("Redirect uri is different")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Verify PKCE, if present in the authorization data
|
||
|
if len(ret.AuthorizeData.CodeChallenge) > 0 {
|
||
|
// https://tools.ietf.org/html/rfc7636#section-4.1
|
||
|
if matched := pkceMatcher.MatchString(ret.CodeVerifier); !matched {
|
||
|
w.SetError(E_INVALID_REQUEST, "code_verifier invalid (rfc7636)")
|
||
|
w.InternalError = errors.New("code_verifier has invalid format")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// https: //tools.ietf.org/html/rfc7636#section-4.6
|
||
|
codeVerifier := ""
|
||
|
switch ret.AuthorizeData.CodeChallengeMethod {
|
||
|
case "", PKCE_PLAIN:
|
||
|
codeVerifier = ret.CodeVerifier
|
||
|
case PKCE_S256:
|
||
|
hash := sha256.Sum256([]byte(ret.CodeVerifier))
|
||
|
codeVerifier = base64.RawURLEncoding.EncodeToString(hash[:])
|
||
|
default:
|
||
|
w.SetError(E_INVALID_REQUEST, "code_challenge_method transform algorithm not supported (rfc7636)")
|
||
|
return nil
|
||
|
}
|
||
|
if codeVerifier != ret.AuthorizeData.CodeChallenge {
|
||
|
w.SetError(E_INVALID_GRANT, "code_verifier invalid (rfc7636)")
|
||
|
w.InternalError = errors.New("code_verifier failed comparison with code_challenge")
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// set rest of data
|
||
|
ret.Scope = ret.AuthorizeData.Scope
|
||
|
ret.UserData = ret.AuthorizeData.UserData
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
// extraScopes reports whether refresh_scopes asks for any scope that is not
// contained in access_scopes.
//
// Both arguments are scope lists separated by single spaces. Empty entries
// (from leading, trailing or repeated spaces, or an empty string) are
// ignored, so an empty refresh_scopes never yields true. Scope names are
// compared exactly, including case.
func extraScopes(access_scopes, refresh_scopes string) bool {
	// Set of scopes that were originally granted.
	granted := make(map[string]struct{})
	for _, scope := range strings.Split(access_scopes, " ") {
		if scope != "" {
			granted[scope] = struct{}{}
		}
	}

	// Any requested scope outside that set is an escalation.
	for _, scope := range strings.Split(refresh_scopes, " ") {
		if scope == "" {
			continue
		}
		if _, ok := granted[scope]; !ok {
			return true
		}
	}
	return false
}
|
||
|
|
||
|
func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// get client authentication
|
||
|
auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
|
||
|
if auth == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret := &AccessRequest{
|
||
|
Type: REFRESH_TOKEN,
|
||
|
Code: r.Form.Get("refresh_token"),
|
||
|
Scope: r.Form.Get("scope"),
|
||
|
GenerateRefresh: true,
|
||
|
Expiration: s.Config.AccessExpiration,
|
||
|
HttpRequest: r,
|
||
|
}
|
||
|
|
||
|
// "refresh_token" is required
|
||
|
if ret.Code == "" {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must have a valid client
|
||
|
if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must be a valid refresh code
|
||
|
var err error
|
||
|
ret.AccessData, err = w.Storage.LoadRefresh(ret.Code)
|
||
|
if err != nil {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
w.InternalError = err
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AccessData == nil {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AccessData.Client == nil {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if ret.AccessData.Client.GetRedirectUri() == "" {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// client must be the same as the previous token
|
||
|
if ret.AccessData.Client.GetId() != ret.Client.GetId() {
|
||
|
w.SetError(E_INVALID_CLIENT, "")
|
||
|
w.InternalError = errors.New("Client id must be the same from previous token")
|
||
|
return nil
|
||
|
|
||
|
}
|
||
|
|
||
|
// set rest of data
|
||
|
ret.RedirectUri = ret.AccessData.RedirectUri
|
||
|
ret.UserData = ret.AccessData.UserData
|
||
|
if ret.Scope == "" {
|
||
|
ret.Scope = ret.AccessData.Scope
|
||
|
}
|
||
|
|
||
|
if extraScopes(ret.AccessData.Scope, ret.Scope) {
|
||
|
w.SetError(E_ACCESS_DENIED, "")
|
||
|
w.InternalError = errors.New("the requested scope must not include any scope not originally granted by the resource owner")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// get client authentication
|
||
|
auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
|
||
|
if auth == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret := &AccessRequest{
|
||
|
Type: PASSWORD,
|
||
|
Username: r.Form.Get("username"),
|
||
|
Password: r.Form.Get("password"),
|
||
|
Scope: r.Form.Get("scope"),
|
||
|
GenerateRefresh: true,
|
||
|
Expiration: s.Config.AccessExpiration,
|
||
|
HttpRequest: r,
|
||
|
}
|
||
|
|
||
|
// "username" and "password" is required
|
||
|
if ret.Username == "" || ret.Password == "" {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must have a valid client
|
||
|
if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// set redirect uri
|
||
|
ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator)
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// get client authentication
|
||
|
auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
|
||
|
if auth == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret := &AccessRequest{
|
||
|
Type: CLIENT_CREDENTIALS,
|
||
|
Scope: r.Form.Get("scope"),
|
||
|
GenerateRefresh: false,
|
||
|
Expiration: s.Config.AccessExpiration,
|
||
|
HttpRequest: r,
|
||
|
}
|
||
|
|
||
|
// must have a valid client
|
||
|
if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// set redirect uri
|
||
|
ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator)
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest {
|
||
|
// get client authentication
|
||
|
auth := getClientAuth(w, r, s.Config.AllowClientSecretInParams)
|
||
|
if auth == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret := &AccessRequest{
|
||
|
Type: ASSERTION,
|
||
|
Scope: r.Form.Get("scope"),
|
||
|
AssertionType: r.Form.Get("assertion_type"),
|
||
|
Assertion: r.Form.Get("assertion"),
|
||
|
GenerateRefresh: false, // assertion should NOT generate a refresh token, per the RFC
|
||
|
Expiration: s.Config.AccessExpiration,
|
||
|
HttpRequest: r,
|
||
|
}
|
||
|
|
||
|
// "assertion_type" and "assertion" is required
|
||
|
if ret.AssertionType == "" || ret.Assertion == "" {
|
||
|
w.SetError(E_INVALID_GRANT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// must have a valid client
|
||
|
if ret.Client = getClient(auth, w.Storage, w); ret.Client == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// set redirect uri
|
||
|
ret.RedirectUri = FirstUri(ret.Client.GetRedirectUri(), s.Config.RedirectUriSeparator)
|
||
|
|
||
|
return ret
|
||
|
}
|
||
|
|
||
|
func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) {
|
||
|
// don't process if is already an error
|
||
|
if w.IsError {
|
||
|
return
|
||
|
}
|
||
|
redirectUri := r.Form.Get("redirect_uri")
|
||
|
// Get redirect uri from AccessRequest if it's there (e.g., refresh token request)
|
||
|
if ar.RedirectUri != "" {
|
||
|
redirectUri = ar.RedirectUri
|
||
|
}
|
||
|
if ar.Authorized {
|
||
|
var ret *AccessData
|
||
|
var err error
|
||
|
|
||
|
if ar.ForceAccessData == nil {
|
||
|
// generate access token
|
||
|
ret = &AccessData{
|
||
|
Client: ar.Client,
|
||
|
AuthorizeData: ar.AuthorizeData,
|
||
|
AccessData: ar.AccessData,
|
||
|
RedirectUri: redirectUri,
|
||
|
CreatedAt: s.Now(),
|
||
|
ExpiresIn: ar.Expiration,
|
||
|
UserData: ar.UserData,
|
||
|
Scope: ar.Scope,
|
||
|
}
|
||
|
|
||
|
// generate access token
|
||
|
ret.AccessToken, ret.RefreshToken, err = s.AccessTokenGen.GenerateAccessToken(ret, ar.GenerateRefresh)
|
||
|
if err != nil {
|
||
|
w.SetError(E_SERVER_ERROR, "")
|
||
|
w.InternalError = err
|
||
|
return
|
||
|
}
|
||
|
} else {
|
||
|
ret = ar.ForceAccessData
|
||
|
}
|
||
|
|
||
|
// save access token
|
||
|
if err = w.Storage.SaveAccess(ret); err != nil {
|
||
|
w.SetError(E_SERVER_ERROR, "")
|
||
|
w.InternalError = err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// remove authorization token
|
||
|
if ret.AuthorizeData != nil {
|
||
|
w.Storage.RemoveAuthorize(ret.AuthorizeData.Code)
|
||
|
}
|
||
|
|
||
|
// remove previous access token
|
||
|
if ret.AccessData != nil && !s.Config.RetainTokenAfterRefresh {
|
||
|
if ret.AccessData.RefreshToken != "" {
|
||
|
w.Storage.RemoveRefresh(ret.AccessData.RefreshToken)
|
||
|
}
|
||
|
w.Storage.RemoveAccess(ret.AccessData.AccessToken)
|
||
|
}
|
||
|
|
||
|
// output data
|
||
|
w.Output["access_token"] = ret.AccessToken
|
||
|
w.Output["token_type"] = s.Config.TokenType
|
||
|
w.Output["expires_in"] = ret.ExpiresIn
|
||
|
if ret.RefreshToken != "" {
|
||
|
w.Output["refresh_token"] = ret.RefreshToken
|
||
|
}
|
||
|
if ret.Scope != "" {
|
||
|
w.Output["scope"] = ret.Scope
|
||
|
}
|
||
|
} else {
|
||
|
w.SetError(E_ACCESS_DENIED, "")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Helper Functions
|
||
|
|
||
|
// getClient looks up and authenticates the basic auth using the given
|
||
|
// storage. Sets an error on the response if auth fails or a server error occurs.
|
||
|
func getClient(auth *BasicAuth, storage Storage, w *Response) Client {
|
||
|
client, err := storage.GetClient(auth.Username)
|
||
|
if err == ErrNotFound {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
if err != nil {
|
||
|
w.SetError(E_SERVER_ERROR, "")
|
||
|
w.InternalError = err
|
||
|
return nil
|
||
|
}
|
||
|
if client == nil {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if !CheckClientSecret(client, auth.Password) {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if client.GetRedirectUri() == "" {
|
||
|
w.SetError(E_UNAUTHORIZED_CLIENT, "")
|
||
|
return nil
|
||
|
}
|
||
|
return client
|
||
|
}
|