hanayo/vendor/github.com/johnniedoe/contrib/cors/config.go
2019-02-23 13:29:15 +00:00

110 lines
2.4 KiB
Go

package cors
import (
"net/http"
"strconv"
"strings"
"time"
)
type settings struct {
allowAllOrigins bool
allowedOriginFunc func(string) bool
allowedOrigins []string
allowedMethods []string
allowedHeaders []string
exposedHeaders []string
normalHeaders http.Header
preflightHeaders http.Header
}
func newSettings(c Config) *settings {
if err := c.Validate(); err != nil {
panic(err.Error())
}
return &settings{
allowedOriginFunc: c.AllowOriginFunc,
allowAllOrigins: c.AllowAllOrigins,
allowedOrigins: c.AllowedOrigins,
allowedMethods: distinct(c.AllowedMethods),
allowedHeaders: distinct(c.AllowedHeaders),
normalHeaders: generateNormalHeaders(c),
preflightHeaders: generatePreflightHeaders(c),
}
}
func (c *settings) validateOrigin(origin string) (string, bool) {
if c.allowAllOrigins {
return "*", true
}
if c.allowedOriginFunc != nil {
return origin, c.allowedOriginFunc(origin)
}
for _, value := range c.allowedOrigins {
if value == origin {
return origin, true
}
}
return "", false
}
func (c *settings) validateMethod(method string) bool {
// TODO!!!
return true
}
func (c *settings) validateHeader(header string) bool {
// TODO!!!
return true
}
func generateNormalHeaders(c Config) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if len(c.ExposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", "))
}
return headers
}
func generatePreflightHeaders(c Config) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if len(c.AllowedMethods) > 0 {
headers.Set("Access-Control-Allow-Methods", strings.Join(c.AllowedMethods, ", "))
}
if len(c.AllowedHeaders) > 0 {
headers.Set("Access-Control-Allow-Headers", strings.Join(c.AllowedHeaders, ", "))
}
if c.MaxAge > time.Duration(0) {
headers.Set("Access-Control-Max-Age", strconv.FormatInt(int64(c.MaxAge/time.Second), 10))
}
return headers
}
func distinct(s []string) []string {
m := map[string]bool{}
for _, v := range s {
if _, seen := m[v]; !seen {
s[len(m)] = v
m[v] = true
}
}
return s[:len(m)]
}
func parse(content string) []string {
if len(content) == 0 {
return nil
}
parts := strings.Split(content, ",")
for i := 0; i < len(parts); i++ {
parts[i] = strings.TrimSpace(parts[i])
}
return parts
}