Create `limit' package, remove need of login_attempts.go
This commit is contained in:
		| @@ -1,10 +1,10 @@ | ||||
| package app | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.zxq.co/ripple/rippleapi/limit" | ||||
| ) | ||||
|  | ||||
| const reqsPerSecond = 5000 | ||||
| @@ -29,74 +29,8 @@ func rateLimiter() { | ||||
| } | ||||
| func perUserRequestLimiter(uid int, ip string) { | ||||
| 	if uid == 0 { | ||||
| 		defaultLimiter.Request("ip:"+ip, 60) | ||||
| 		limit.Request("ip:"+ip, 60) | ||||
| 	} else { | ||||
| 		defaultLimiter.Request("user:"+strconv.Itoa(uid), 2000) | ||||
| 		limit.Request("user:"+strconv.Itoa(uid), 2000) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var defaultLimiter = &specificRateLimiter{ | ||||
| 	Map:   make(map[string]chan struct{}), | ||||
| 	Mutex: &sync.RWMutex{}, | ||||
| } | ||||
|  | ||||
| type specificRateLimiter struct { | ||||
| 	Map   map[string]chan struct{} | ||||
| 	Mutex *sync.RWMutex | ||||
| } | ||||
|  | ||||
| func (s *specificRateLimiter) Request(u string, perMinute int) { | ||||
| 	s.Mutex.RLock() | ||||
| 	c, exists := s.Map[u] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	if !exists { | ||||
| 		c = makePrefilledChan(perMinute) | ||||
| 		s.Mutex.Lock() | ||||
| 		// Now that we have exclusive read and write-access, we want to | ||||
| 		// make sure we don't overwrite an existing channel. Otherwise, | ||||
| 		// race conditions and panic happen. | ||||
| 		if cNew, exists := s.Map[u]; exists { | ||||
| 			c = cNew | ||||
| 			s.Mutex.Unlock() | ||||
| 		} else { | ||||
| 			s.Map[u] = c | ||||
| 			s.Mutex.Unlock() | ||||
| 			<-c | ||||
| 			go s.filler(u, perMinute) | ||||
| 		} | ||||
| 	} | ||||
| 	<-c | ||||
| } | ||||
|  | ||||
| func (s *specificRateLimiter) filler(el string, perMinute int) { | ||||
| 	defer func() { | ||||
| 		r := recover() | ||||
| 		if r != nil { | ||||
| 			fmt.Println(r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	s.Mutex.RLock() | ||||
| 	c := s.Map[el] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case c <- struct{}{}: | ||||
| 			time.Sleep(time.Minute / time.Duration(perMinute)) | ||||
| 		default: // c is full | ||||
| 			s.Mutex.Lock() | ||||
| 			close(c) | ||||
| 			delete(s.Map, el) | ||||
| 			s.Mutex.Unlock() | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func makePrefilledChan(l int) chan struct{} { | ||||
| 	c := make(chan struct{}, l) | ||||
| 	for i := 0; i < l; i++ { | ||||
| 		c <- struct{}{} | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|   | ||||
| @@ -1,61 +0,0 @@ | ||||
| package v1 | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type failedAttempt struct { | ||||
| 	attempt time.Time | ||||
| 	ID      int | ||||
| } | ||||
|  | ||||
| var failedAttempts []failedAttempt | ||||
| var failedAttemptsMutex = new(sync.RWMutex) | ||||
|  | ||||
| // removeUseless removes the expired attempts in failedAttempts | ||||
| func removeUseless() { | ||||
| 	for { | ||||
| 		failedAttemptsMutex.RLock() | ||||
| 		var localCopy = make([]failedAttempt, len(failedAttempts)) | ||||
| 		copy(localCopy, failedAttempts) | ||||
| 		failedAttemptsMutex.RUnlock() | ||||
| 		var newStartFrom int | ||||
| 		for k, v := range localCopy { | ||||
| 			if time.Since(v.attempt) > time.Minute*10 { | ||||
| 				newStartFrom = k + 1 | ||||
| 			} else { | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		copySl := localCopy[newStartFrom:] | ||||
| 		failedAttemptsMutex.Lock() | ||||
| 		failedAttempts = make([]failedAttempt, len(copySl)) | ||||
| 		for i, v := range copySl { | ||||
| 			failedAttempts[i] = v | ||||
| 		} | ||||
| 		failedAttemptsMutex.Unlock() | ||||
| 		time.Sleep(time.Minute * 10) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func addFailedAttempt(uid int) { | ||||
| 	failedAttemptsMutex.Lock() | ||||
| 	failedAttempts = append(failedAttempts, failedAttempt{ | ||||
| 		attempt: time.Now(), | ||||
| 		ID:      uid, | ||||
| 	}) | ||||
| 	failedAttemptsMutex.Unlock() | ||||
| } | ||||
|  | ||||
| func nFailedAttempts(uid int) int { | ||||
| 	var count int | ||||
| 	failedAttemptsMutex.RLock() | ||||
| 	for _, i := range failedAttempts { | ||||
| 		if i.ID == uid && time.Since(i.attempt) < time.Minute*10 { | ||||
| 			count++ | ||||
| 		} | ||||
| 	} | ||||
| 	failedAttemptsMutex.RUnlock() | ||||
| 	return count | ||||
| } | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"github.com/jmoiron/sqlx" | ||||
|  | ||||
| 	"git.zxq.co/ripple/rippleapi/common" | ||||
| 	"git.zxq.co/ripple/rippleapi/limit" | ||||
| 	"git.zxq.co/ripple/schiavolib" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| @@ -76,7 +77,7 @@ func TokenNewPOST(md common.MethodData) common.CodeMessager { | ||||
| 	} | ||||
| 	privileges := common.UserPrivileges(privilegesRaw) | ||||
|  | ||||
| 	if nFailedAttempts(r.ID) > 20 { | ||||
| 	if !limit.NonBlockingRequest(fmt.Sprintf("loginattempt:%d:%s", r.ID, md.C.ClientIP()), 5) { | ||||
| 		return common.SimpleResponse(429, "You've made too many login attempts. Try again later.") | ||||
| 	} | ||||
|  | ||||
| @@ -85,7 +86,6 @@ func TokenNewPOST(md common.MethodData) common.CodeMessager { | ||||
| 	} | ||||
| 	if err := bcrypt.CompareHashAndPassword([]byte(pw), []byte(fmt.Sprintf("%x", md5.Sum([]byte(data.Password))))); err != nil { | ||||
| 		if err == bcrypt.ErrMismatchedHashAndPassword { | ||||
| 			go addFailedAttempt(r.ID) | ||||
| 			return common.SimpleResponse(403, "That password doesn't match!") | ||||
| 		} | ||||
| 		md.Err(err) | ||||
|   | ||||
							
								
								
									
										122
									
								
								limit/limit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								limit/limit.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,122 @@ | ||||
| package limit | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // Request is a Request with DefaultLimiter. | ||||
| func Request(u string, perMinute int) { DefaultLimiter.Request(u, perMinute) } | ||||
|  | ||||
| // NonBlockingRequest is a NonBlockingRequest with DefaultLimiter. | ||||
| func NonBlockingRequest(u string, perMinute int) bool { | ||||
| 	return DefaultLimiter.NonBlockingRequest(u, perMinute) | ||||
| } | ||||
|  | ||||
| // DefaultLimiter is the RateLimiter used by the package-level | ||||
| // Request and NonBlockingRequest. | ||||
| var DefaultLimiter = &RateLimiter{ | ||||
| 	Map:   make(map[string]chan struct{}), | ||||
| 	Mutex: &sync.RWMutex{}, | ||||
| } | ||||
|  | ||||
| // RateLimiter is a simple rate limiter. | ||||
| type RateLimiter struct { | ||||
| 	Map   map[string]chan struct{} | ||||
| 	Mutex *sync.RWMutex | ||||
| } | ||||
|  | ||||
| // Request is a simple request. Blocks until it can make the request. | ||||
| func (s *RateLimiter) Request(u string, perMinute int) { | ||||
| 	s.request(u, perMinute, true) | ||||
| } | ||||
|  | ||||
| // NonBlockingRequest checks if it can do a request. If it can't, it returns | ||||
| // false, else it returns true if the request succeded. | ||||
| func (s *RateLimiter) NonBlockingRequest(u string, perMinute int) bool { | ||||
| 	return s.request(u, perMinute, false) | ||||
| } | ||||
|  | ||||
| func (s *RateLimiter) request(u string, perMinute int, blocking bool) bool { | ||||
| 	s.check() | ||||
| 	s.Mutex.RLock() | ||||
| 	c, exists := s.Map[u] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	if !exists { | ||||
| 		c = makePrefilledChan(perMinute) | ||||
| 		s.Mutex.Lock() | ||||
| 		// Now that we have exclusive read and write-access, we want to | ||||
| 		// make sure we don't overwrite an existing channel. Otherwise, | ||||
| 		// race conditions and panic happen. | ||||
| 		if cNew, exists := s.Map[u]; exists { | ||||
| 			c = cNew | ||||
| 			s.Mutex.Unlock() | ||||
| 		} else { | ||||
| 			s.Map[u] = c | ||||
| 			s.Mutex.Unlock() | ||||
| 			<-c | ||||
| 			go s.filler(u, perMinute) | ||||
| 		} | ||||
| 	} | ||||
| 	return rcv(c, blocking) | ||||
| } | ||||
|  | ||||
| // rcv receives from a channel, but if blocking is true it waits til something | ||||
| // is received and always returns true, otherwise if it can't receive it | ||||
| // returns false. | ||||
| func rcv(c chan struct{}, blocking bool) bool { | ||||
| 	if blocking { | ||||
| 		<-c | ||||
| 		return true | ||||
| 	} | ||||
| 	select { | ||||
| 	case <-c: | ||||
| 		return true | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *RateLimiter) filler(el string, perMinute int) { | ||||
| 	defer func() { | ||||
| 		r := recover() | ||||
| 		if r != nil { | ||||
| 			fmt.Println(r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	s.Mutex.RLock() | ||||
| 	c := s.Map[el] | ||||
| 	s.Mutex.RUnlock() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case c <- struct{}{}: | ||||
| 			time.Sleep(time.Minute / time.Duration(perMinute)) | ||||
| 		default: // c is full | ||||
| 			s.Mutex.Lock() | ||||
| 			close(c) | ||||
| 			delete(s.Map, el) | ||||
| 			s.Mutex.Unlock() | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // check makes sure the map and the mutex are properly initialised. | ||||
| func (s *RateLimiter) check() { | ||||
| 	if s.Map == nil { | ||||
| 		s.Map = make(map[string]chan struct{}) | ||||
| 	} | ||||
| 	if s.Mutex { | ||||
| 		s.Mutex = new(sync.RWMutex) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func makePrefilledChan(l int) chan struct{} { | ||||
| 	c := make(chan struct{}, l) | ||||
| 	for i := 0; i < l; i++ { | ||||
| 		c <- struct{}{} | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
		Reference in New Issue
	
	Block a user