Implement rate limiting

- 60 requests per minute for requests without a valid API token
- 2000 requests per minute per user for requests with a valid API token
This commit is contained in:
Howl 2016-07-06 16:33:58 +02:00
parent 0a870ee742
commit faf948b037
4 changed files with 97 additions and 2 deletions

View File

@ -19,11 +19,12 @@ func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded .
} }
func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) { func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) {
rateLimiter()
data, err := ioutil.ReadAll(c.Request.Body) data, err := ioutil.ReadAll(c.Request.Body)
if err != nil { if err != nil {
c.Error(err) c.Error(err)
} }
c.Request.Body.Close()
token := "" token := ""
switch { switch {
@ -50,6 +51,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
} }
} }
perUserRequestLimiter(md.ID(), c.Request.Header.Get("X-Real-IP"))
missingPrivileges := 0 missingPrivileges := 0
for _, privilege := range privilegesNeeded { for _, privilege := range privilegesNeeded {
if int(md.User.Privileges)&privilege == 0 { if int(md.User.Privileges)&privilege == 0 {
@ -96,7 +99,7 @@ func mkjson(c *gin.Context, data interface{}) {
exported, err := json.MarshalIndent(data, "", "\t") exported, err := json.MarshalIndent(data, "", "\t")
if err != nil { if err != nil {
c.Error(err) c.Error(err)
exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong.", "data": null }`) exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`)
} }
cb := c.Query("callback") cb := c.Query("callback")
willcb := cb != "" && willcb := cb != "" &&

View File

@ -9,6 +9,9 @@ import (
// PeppyMethod generates a method for the peppyapi // PeppyMethod generates a method for the peppyapi
func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc { func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
rateLimiter()
perUserRequestLimiter(0, c.Request.Header.Get("X-Real-IP"))
// I have no idea how, but I manged to accidentally string the first 4 // I have no idea how, but I manged to accidentally string the first 4
// letters of the alphabet into a single function call. // letters of the alphabet into a single function call.
a(c, db) a(c, db)

86
app/rate_limiter.go Normal file
View File

@ -0,0 +1,86 @@
package app
import (
"strconv"
"sync"
"time"
)
const reqsPerSecond = 5000
const sleepTime = time.Second / reqsPerSecond
var limiter = make(chan struct{}, reqsPerSecond)
func setUpLimiter() {
for i := 0; i < reqsPerSecond; i++ {
limiter <- struct{}{}
}
go func() {
for {
limiter <- struct{}{}
time.Sleep(sleepTime)
}
}()
}
func rateLimiter() {
<-limiter
}
func perUserRequestLimiter(uid int, ip string) {
if uid == 0 {
defaultLimiter.Request("ip:"+ip, 60)
} else {
defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000)
}
}
var defaultLimiter = &specificRateLimiter{
Map: make(map[string]chan struct{}),
Mutex: &sync.RWMutex{},
}
type specificRateLimiter struct {
Map map[string]chan struct{}
Mutex *sync.RWMutex
}
func (s *specificRateLimiter) Request(u string, perMinute int) {
s.Mutex.RLock()
c, exists := s.Map[u]
s.Mutex.RUnlock()
if !exists {
c = makePrefilledChan(perMinute)
s.Mutex.Lock()
s.Map[u] = c
s.Mutex.Unlock()
<-c
go s.filler(u, perMinute)
}
<-c
}
func (s *specificRateLimiter) filler(el string, perMinute int) {
s.Mutex.RLock()
c := s.Map[el]
s.Mutex.RUnlock()
for {
select {
case c <- struct{}{}:
time.Sleep(time.Minute / time.Duration(perMinute))
default: // c is full
s.Mutex.Lock()
close(c)
delete(s.Map, el)
s.Mutex.Unlock()
return
}
}
}
func makePrefilledChan(l int) chan struct{} {
c := make(chan struct{}, l)
for i := 0; i < l; i++ {
c <- struct{}{}
}
return c
}

View File

@ -18,6 +18,9 @@ var db *sql.DB
// Start begins taking HTTP connections. // Start begins taking HTTP connections.
func Start(conf common.Conf, dbO *sql.DB) *gin.Engine { func Start(conf common.Conf, dbO *sql.DB) *gin.Engine {
db = dbO db = dbO
setUpLimiter()
r := gin.Default() r := gin.Default()
r.Use(gzip.Gzip(gzip.DefaultCompression)) r.Use(gzip.Gzip(gzip.DefaultCompression))