Implement rate limiting
- 60 requests per minute for requests without a valid API token - 2000 requests per minute per user for requests with a valid API token
This commit is contained in:
		@@ -19,11 +19,12 @@ func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded .
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) {
 | 
			
		||||
	rateLimiter()
 | 
			
		||||
 | 
			
		||||
	data, err := ioutil.ReadAll(c.Request.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
	c.Request.Body.Close()
 | 
			
		||||
 | 
			
		||||
	token := ""
 | 
			
		||||
	switch {
 | 
			
		||||
@@ -50,6 +51,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	perUserRequestLimiter(md.ID(), c.Request.Header.Get("X-Real-IP"))
 | 
			
		||||
 | 
			
		||||
	missingPrivileges := 0
 | 
			
		||||
	for _, privilege := range privilegesNeeded {
 | 
			
		||||
		if int(md.User.Privileges)&privilege == 0 {
 | 
			
		||||
@@ -96,7 +99,7 @@ func mkjson(c *gin.Context, data interface{}) {
 | 
			
		||||
	exported, err := json.MarshalIndent(data, "", "\t")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.Error(err)
 | 
			
		||||
		exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong.", "data": null }`)
 | 
			
		||||
		exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`)
 | 
			
		||||
	}
 | 
			
		||||
	cb := c.Query("callback")
 | 
			
		||||
	willcb := cb != "" &&
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,9 @@ import (
 | 
			
		||||
// PeppyMethod generates a method for the peppyapi
 | 
			
		||||
func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		rateLimiter()
 | 
			
		||||
		perUserRequestLimiter(0, c.Request.Header.Get("X-Real-IP"))
 | 
			
		||||
 | 
			
		||||
		// I have no idea how, but I manged to accidentally string the first 4
 | 
			
		||||
		// letters of the alphabet into a single function call.
 | 
			
		||||
		a(c, db)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										86
									
								
								app/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								app/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
			
		||||
package app
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const reqsPerSecond = 5000
 | 
			
		||||
const sleepTime = time.Second / reqsPerSecond
 | 
			
		||||
 | 
			
		||||
var limiter = make(chan struct{}, reqsPerSecond)
 | 
			
		||||
 | 
			
		||||
func setUpLimiter() {
 | 
			
		||||
	for i := 0; i < reqsPerSecond; i++ {
 | 
			
		||||
		limiter <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			limiter <- struct{}{}
 | 
			
		||||
			time.Sleep(sleepTime)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func rateLimiter() {
 | 
			
		||||
	<-limiter
 | 
			
		||||
}
 | 
			
		||||
func perUserRequestLimiter(uid int, ip string) {
 | 
			
		||||
	if uid == 0 {
 | 
			
		||||
		defaultLimiter.Request("ip:"+ip, 60)
 | 
			
		||||
	} else {
 | 
			
		||||
		defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var defaultLimiter = &specificRateLimiter{
 | 
			
		||||
	Map:   make(map[string]chan struct{}),
 | 
			
		||||
	Mutex: &sync.RWMutex{},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type specificRateLimiter struct {
 | 
			
		||||
	Map   map[string]chan struct{}
 | 
			
		||||
	Mutex *sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *specificRateLimiter) Request(u string, perMinute int) {
 | 
			
		||||
	s.Mutex.RLock()
 | 
			
		||||
	c, exists := s.Map[u]
 | 
			
		||||
	s.Mutex.RUnlock()
 | 
			
		||||
	if !exists {
 | 
			
		||||
		c = makePrefilledChan(perMinute)
 | 
			
		||||
		s.Mutex.Lock()
 | 
			
		||||
		s.Map[u] = c
 | 
			
		||||
		s.Mutex.Unlock()
 | 
			
		||||
		<-c
 | 
			
		||||
		go s.filler(u, perMinute)
 | 
			
		||||
	}
 | 
			
		||||
	<-c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *specificRateLimiter) filler(el string, perMinute int) {
 | 
			
		||||
	s.Mutex.RLock()
 | 
			
		||||
	c := s.Map[el]
 | 
			
		||||
	s.Mutex.RUnlock()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case c <- struct{}{}:
 | 
			
		||||
			time.Sleep(time.Minute / time.Duration(perMinute))
 | 
			
		||||
		default: // c is full
 | 
			
		||||
			s.Mutex.Lock()
 | 
			
		||||
			close(c)
 | 
			
		||||
			delete(s.Map, el)
 | 
			
		||||
			s.Mutex.Unlock()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func makePrefilledChan(l int) chan struct{} {
 | 
			
		||||
	c := make(chan struct{}, l)
 | 
			
		||||
	for i := 0; i < l; i++ {
 | 
			
		||||
		c <- struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
@@ -18,6 +18,9 @@ var db *sql.DB
 | 
			
		||||
// Start begins taking HTTP connections.
 | 
			
		||||
func Start(conf common.Conf, dbO *sql.DB) *gin.Engine {
 | 
			
		||||
	db = dbO
 | 
			
		||||
 | 
			
		||||
	setUpLimiter()
 | 
			
		||||
 | 
			
		||||
	r := gin.Default()
 | 
			
		||||
	r.Use(gzip.Gzip(gzip.DefaultCompression))
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user