110 lines
2.4 KiB
Go
110 lines
2.4 KiB
Go
package cors
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// settings is the precomputed CORS state that newSettings builds from a Config.
type settings struct {
	// allowAllOrigins makes validateOrigin accept every origin and answer "*".
	allowAllOrigins bool
	// allowedOriginFunc, when non-nil, decides origin validity on its own and
	// takes precedence over allowedOrigins.
	allowedOriginFunc func(string) bool

	// Allow-lists for origins, methods and headers, plus exposed header names.
	allowedOrigins, allowedMethods, allowedHeaders, exposedHeaders []string

	// Header sets prepared once up front: normalHeaders comes from
	// generateNormalHeaders, preflightHeaders from generatePreflightHeaders.
	normalHeaders, preflightHeaders http.Header
}
|
|
|
|
func newSettings(c Config) *settings {
|
|
if err := c.Validate(); err != nil {
|
|
panic(err.Error())
|
|
}
|
|
return &settings{
|
|
allowedOriginFunc: c.AllowOriginFunc,
|
|
allowAllOrigins: c.AllowAllOrigins,
|
|
allowedOrigins: c.AllowedOrigins,
|
|
allowedMethods: distinct(c.AllowedMethods),
|
|
allowedHeaders: distinct(c.AllowedHeaders),
|
|
normalHeaders: generateNormalHeaders(c),
|
|
preflightHeaders: generatePreflightHeaders(c),
|
|
}
|
|
}
|
|
|
|
func (c *settings) validateOrigin(origin string) (string, bool) {
|
|
if c.allowAllOrigins {
|
|
return "*", true
|
|
}
|
|
if c.allowedOriginFunc != nil {
|
|
return origin, c.allowedOriginFunc(origin)
|
|
}
|
|
for _, value := range c.allowedOrigins {
|
|
if value == origin {
|
|
return origin, true
|
|
}
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func (c *settings) validateMethod(method string) bool {
|
|
// TODO!!!
|
|
return true
|
|
}
|
|
|
|
func (c *settings) validateHeader(header string) bool {
|
|
// TODO!!!
|
|
return true
|
|
}
|
|
|
|
func generateNormalHeaders(c Config) http.Header {
|
|
headers := make(http.Header)
|
|
if c.AllowCredentials {
|
|
headers.Set("Access-Control-Allow-Credentials", "true")
|
|
}
|
|
if len(c.ExposedHeaders) > 0 {
|
|
headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", "))
|
|
}
|
|
return headers
|
|
}
|
|
|
|
func generatePreflightHeaders(c Config) http.Header {
|
|
headers := make(http.Header)
|
|
if c.AllowCredentials {
|
|
headers.Set("Access-Control-Allow-Credentials", "true")
|
|
}
|
|
if len(c.AllowedMethods) > 0 {
|
|
headers.Set("Access-Control-Allow-Methods", strings.Join(c.AllowedMethods, ", "))
|
|
}
|
|
if len(c.AllowedHeaders) > 0 {
|
|
headers.Set("Access-Control-Allow-Headers", strings.Join(c.AllowedHeaders, ", "))
|
|
}
|
|
if c.MaxAge > time.Duration(0) {
|
|
headers.Set("Access-Control-Max-Age", strconv.FormatInt(int64(c.MaxAge/time.Second), 10))
|
|
}
|
|
return headers
|
|
}
|
|
|
|
// distinct returns the elements of s with duplicates removed, keeping the
// first occurrence of each value in its original order. It returns nil when
// s is empty.
//
// The result is always a freshly allocated slice and s is never modified:
// callers pass slices owned by a Config (see newSettings), so overwriting
// their backing array in place would corrupt the caller's data and any later
// read of the same field.
func distinct(s []string) []string {
	if len(s) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(s))
	out := make([]string, 0, len(s))
	for _, v := range s {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
|
|
|
|
// parse splits a comma-separated list into its items with surrounding
// whitespace trimmed. An empty input yields nil. Empty items (e.g. from
// "a,,b") are kept as "" entries.
func parse(content string) []string {
	if content == "" {
		return nil
	}
	fields := strings.Split(content, ",")
	for i, f := range fields {
		fields[i] = strings.TrimSpace(f)
	}
	return fields
}
|