Implement rate limiting
- 60 requests per minute for requests without a valid API token - 2000 requests per minute per user for requests with a valid API token
This commit is contained in:
parent
0a870ee742
commit
faf948b037
@ -19,11 +19,12 @@ func Method(f func(md common.MethodData) common.CodeMessager, privilegesNeeded .
|
||||
}
|
||||
|
||||
func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMessager, privilegesNeeded ...int) {
|
||||
rateLimiter()
|
||||
|
||||
data, err := ioutil.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
}
|
||||
c.Request.Body.Close()
|
||||
|
||||
token := ""
|
||||
switch {
|
||||
@ -50,6 +51,8 @@ func initialCaretaker(c *gin.Context, f func(md common.MethodData) common.CodeMe
|
||||
}
|
||||
}
|
||||
|
||||
perUserRequestLimiter(md.ID(), c.Request.Header.Get("X-Real-IP"))
|
||||
|
||||
missingPrivileges := 0
|
||||
for _, privilege := range privilegesNeeded {
|
||||
if int(md.User.Privileges)&privilege == 0 {
|
||||
@ -96,7 +99,7 @@ func mkjson(c *gin.Context, data interface{}) {
|
||||
exported, err := json.MarshalIndent(data, "", "\t")
|
||||
if err != nil {
|
||||
c.Error(err)
|
||||
exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong.", "data": null }`)
|
||||
exported = []byte(`{ "code": 500, "message": "something has gone really really really really really really wrong." }`)
|
||||
}
|
||||
cb := c.Query("callback")
|
||||
willcb := cb != "" &&
|
||||
|
@ -9,6 +9,9 @@ import (
|
||||
// PeppyMethod generates a method for the peppyapi
|
||||
func PeppyMethod(a func(c *gin.Context, db *sql.DB)) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
rateLimiter()
|
||||
perUserRequestLimiter(0, c.Request.Header.Get("X-Real-IP"))
|
||||
|
||||
// I have no idea how, but I manged to accidentally string the first 4
|
||||
// letters of the alphabet into a single function call.
|
||||
a(c, db)
|
||||
|
86
app/rate_limiter.go
Normal file
86
app/rate_limiter.go
Normal file
@ -0,0 +1,86 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const reqsPerSecond = 5000
|
||||
const sleepTime = time.Second / reqsPerSecond
|
||||
|
||||
var limiter = make(chan struct{}, reqsPerSecond)
|
||||
|
||||
func setUpLimiter() {
|
||||
for i := 0; i < reqsPerSecond; i++ {
|
||||
limiter <- struct{}{}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
limiter <- struct{}{}
|
||||
time.Sleep(sleepTime)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func rateLimiter() {
|
||||
<-limiter
|
||||
}
|
||||
func perUserRequestLimiter(uid int, ip string) {
|
||||
if uid == 0 {
|
||||
defaultLimiter.Request("ip:"+ip, 60)
|
||||
} else {
|
||||
defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000)
|
||||
}
|
||||
}
|
||||
|
||||
var defaultLimiter = &specificRateLimiter{
|
||||
Map: make(map[string]chan struct{}),
|
||||
Mutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
type specificRateLimiter struct {
|
||||
Map map[string]chan struct{}
|
||||
Mutex *sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *specificRateLimiter) Request(u string, perMinute int) {
|
||||
s.Mutex.RLock()
|
||||
c, exists := s.Map[u]
|
||||
s.Mutex.RUnlock()
|
||||
if !exists {
|
||||
c = makePrefilledChan(perMinute)
|
||||
s.Mutex.Lock()
|
||||
s.Map[u] = c
|
||||
s.Mutex.Unlock()
|
||||
<-c
|
||||
go s.filler(u, perMinute)
|
||||
}
|
||||
<-c
|
||||
}
|
||||
|
||||
func (s *specificRateLimiter) filler(el string, perMinute int) {
|
||||
s.Mutex.RLock()
|
||||
c := s.Map[el]
|
||||
s.Mutex.RUnlock()
|
||||
for {
|
||||
select {
|
||||
case c <- struct{}{}:
|
||||
time.Sleep(time.Minute / time.Duration(perMinute))
|
||||
default: // c is full
|
||||
s.Mutex.Lock()
|
||||
close(c)
|
||||
delete(s.Map, el)
|
||||
s.Mutex.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makePrefilledChan(l int) chan struct{} {
|
||||
c := make(chan struct{}, l)
|
||||
for i := 0; i < l; i++ {
|
||||
c <- struct{}{}
|
||||
}
|
||||
return c
|
||||
}
|
@ -18,6 +18,9 @@ var db *sql.DB
|
||||
// Start begins taking HTTP connections.
|
||||
func Start(conf common.Conf, dbO *sql.DB) *gin.Engine {
|
||||
db = dbO
|
||||
|
||||
setUpLimiter()
|
||||
|
||||
r := gin.Default()
|
||||
r.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user