From faf948b0374259ad2f576175c6e0cbed12fe9eb6 Mon Sep 17 00:00:00 2001 From: Howl Date: Wed, 6 Jul 2016 16:33:58 +0200 Subject: [PATCH] Implement rate limiting - 60 requests per minute for requests without a valid API token - 2000 requests per minute per user for requests with a valid API token --- app/method.go | 7 ++-- app/peppy_method.go | 3 ++ app/rate_limiter.go | 86 +++++++++++++++++++++++++++++++++++++++++++++ app/start.go | 3 ++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 app/rate_limiter.go diff --git a/app/method.go b/app/method.go index 74c25f3..d4d4379 100644 --- a/app/method.go +++ b/app/method.go @@ -19,11 +19,12 @@ func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded . } func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) { + rateLimiter() + data, err := ioutil.ReadAll(c.Request.Body) if err != nil { c.Error(err) } - c.Request.Body.Close() token := "" switch { @@ -50,6 +51,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe } } + perUserRequestLimiter(md.ID(), c.Request.Header.Get("X-Real-IP")) + missingPrivileges := 0 for _, privilege := range privilegesNeeded { if int(md.User.Privileges)&privilege == 0 { @@ -96,7 +99,7 @@ func mkjson(c *gin.Context, data interface{}) { exported, err := json.MarshalIndent(data, "", "\t") if err != nil { c.Error(err) - exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong.", "data": null }`) + exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`) } cb := c.Query("callback") willcb := cb != "" && diff --git a/app/peppy_method.go b/app/peppy_method.go index 608ef84..e6c5780 100644 --- a/app/peppy_method.go +++ b/app/peppy_method.go @@ -9,6 +9,9 @@ import ( // PeppyMethod generates a method for the peppyapi func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc { return func(c *gin.Context) { + rateLimiter() + perUserRequestLimiter(0, c.Request.Header.Get("X-Real-IP")) + // I have no idea how, but I manged to accidentally string the first 4 // letters of the alphabet into a single function call. a(c, db) diff --git a/app/rate_limiter.go b/app/rate_limiter.go new file mode 100644 index 0000000..e162c92 --- /dev/null +++ b/app/rate_limiter.go @@ -0,0 +1,86 @@ +package app + +import ( + "strconv" + "sync" + "time" +) + +const reqsPerSecond = 5000 +const sleepTime = time.Second / reqsPerSecond + +var limiter = make(chan struct{}, reqsPerSecond) + +func setUpLimiter() { + for i := 0; i < reqsPerSecond; i++ { + limiter <- struct{}{} + } + go func() { + for { + limiter <- struct{}{} + time.Sleep(sleepTime) + } + }() +} + +func rateLimiter() { + <-limiter +} +func perUserRequestLimiter(uid int, ip string) { + if uid == 0 { + defaultLimiter.Request("ip:"+ip, 60) + } else { + defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000) + } +} + +var defaultLimiter = &specificRateLimiter{ + Map: make(map[string]chan struct{}), + Mutex: &sync.RWMutex{}, +} + +type specificRateLimiter struct { + Map map[string]chan struct{} + Mutex *sync.RWMutex +} + +func (s *specificRateLimiter) Request(u string, perMinute int) { + s.Mutex.RLock() + c, exists := s.Map[u] + s.Mutex.RUnlock() + if !exists { + c = makePrefilledChan(perMinute) + s.Mutex.Lock() + s.Map[u] = c + s.Mutex.Unlock() + <-c + go s.filler(u, perMinute) + } + <-c +} + +func (s *specificRateLimiter) filler(el string, perMinute int) { + s.Mutex.RLock() + c := s.Map[el] + s.Mutex.RUnlock() + for { + select { + case c <- struct{}{}: + time.Sleep(time.Minute / time.Duration(perMinute)) + default: // c is full + s.Mutex.Lock() + close(c) + delete(s.Map, el) + s.Mutex.Unlock() + return + } + } +} + +func makePrefilledChan(l int) chan struct{} { + c := make(chan struct{}, l) + for i := 0; i < l; i++ { + c <- struct{}{} + } + return c +} diff --git a/app/start.go b/app/start.go index 14862f0..bc0850b 100644 --- a/app/start.go +++ b/app/start.go @@ -18,6 +18,9 @@ var db *sql.DB // Start begins taking HTTP connections. func Start(conf common.Conf, dbO *sql.DB) *gin.Engine { db = dbO + + setUpLimiter() + r := gin.Default() r.Use(gzip.Gzip(gzip.DefaultCompression))