Implement rate limiting
- 60 requests per minute for requests without a valid API token - 2000 requests per minute per user for requests with a valid API token
This commit is contained in:
		| @@ -19,11 +19,12 @@ func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded . | ||||
| } | ||||
|  | ||||
| func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) { | ||||
| 	rateLimiter() | ||||
|  | ||||
| 	data, err := ioutil.ReadAll(c.Request.Body) | ||||
| 	if err != nil { | ||||
| 		c.Error(err) | ||||
| 	} | ||||
| 	c.Request.Body.Close() | ||||
|  | ||||
| 	token := "" | ||||
| 	switch { | ||||
| @@ -50,6 +51,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	perUserRequestLimiter(md.ID(), c.Request.Header.Get("X-Real-IP")) | ||||
|  | ||||
| 	missingPrivileges := 0 | ||||
| 	for _, privilege := range privilegesNeeded { | ||||
| 		if int(md.User.Privileges)&privilege == 0 { | ||||
| @@ -96,7 +99,7 @@ func mkjson(c *gin.Context, data interface{}) { | ||||
| 	exported, err := json.MarshalIndent(data, "", "\t") | ||||
| 	if err != nil { | ||||
| 		c.Error(err) | ||||
| 		exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong.", "data": null }`) | ||||
| 		exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`) | ||||
| 	} | ||||
| 	cb := c.Query("callback") | ||||
| 	willcb := cb != "" && | ||||
|   | ||||
| @@ -9,6 +9,9 @@ import ( | ||||
| // PeppyMethod generates a method for the peppyapi | ||||
| func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		rateLimiter() | ||||
| 		perUserRequestLimiter(0, c.Request.Header.Get("X-Real-IP")) | ||||
|  | ||||
| 		// I have no idea how, but I manged to accidentally string the first 4 | ||||
| 		// letters of the alphabet into a single function call. | ||||
| 		a(c, db) | ||||
|   | ||||
							
								
								
									
										86
									
								
								app/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								app/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| package app | ||||
|  | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| const reqsPerSecond = 5000 | ||||
| const sleepTime = time.Second / reqsPerSecond | ||||
|  | ||||
| var limiter = make(chan struct{}, reqsPerSecond) | ||||
|  | ||||
| func setUpLimiter() { | ||||
| 	for i := 0; i < reqsPerSecond; i++ { | ||||
| 		limiter <- struct{}{} | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			limiter <- struct{}{} | ||||
| 			time.Sleep(sleepTime) | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func rateLimiter() { | ||||
| 	<-limiter | ||||
| } | ||||
| func perUserRequestLimiter(uid int, ip string) { | ||||
| 	if uid == 0 { | ||||
| 		defaultLimiter.Request("ip:"+ip, 60) | ||||
| 	} else { | ||||
| 		defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var defaultLimiter = &specificRateLimiter{ | ||||
| 	Map:   make(map[string]chan struct{}), | ||||
| 	Mutex: &sync.RWMutex{}, | ||||
| } | ||||
|  | ||||
| type specificRateLimiter struct { | ||||
| 	Map   map[string]chan struct{} | ||||
| 	Mutex *sync.RWMutex | ||||
| } | ||||
|  | ||||
| func (s *specificRateLimiter) Request(u string, perMinute int) { | ||||
| 	s.Mutex.RLock() | ||||
| 	c, exists := s.Map[u] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	if !exists { | ||||
| 		c = makePrefilledChan(perMinute) | ||||
| 		s.Mutex.Lock() | ||||
| 		s.Map[u] = c | ||||
| 		s.Mutex.Unlock() | ||||
| 		<-c | ||||
| 		go s.filler(u, perMinute) | ||||
| 	} | ||||
| 	<-c | ||||
| } | ||||
|  | ||||
| func (s *specificRateLimiter) filler(el string, perMinute int) { | ||||
| 	s.Mutex.RLock() | ||||
| 	c := s.Map[el] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case c <- struct{}{}: | ||||
| 			time.Sleep(time.Minute / time.Duration(perMinute)) | ||||
| 		default: // c is full | ||||
| 			s.Mutex.Lock() | ||||
| 			close(c) | ||||
| 			delete(s.Map, el) | ||||
| 			s.Mutex.Unlock() | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func makePrefilledChan(l int) chan struct{} { | ||||
| 	c := make(chan struct{}, l) | ||||
| 	for i := 0; i < l; i++ { | ||||
| 		c <- struct{}{} | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| @@ -18,6 +18,9 @@ var db *sql.DB | ||||
| // Start begins taking HTTP connections. | ||||
| func Start(conf common.Conf, dbO *sql.DB) *gin.Engine { | ||||
| 	db = dbO | ||||
|  | ||||
| 	setUpLimiter() | ||||
|  | ||||
| 	r := gin.Default() | ||||
| 	r.Use(gzip.Gzip(gzip.DefaultCompression)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user