diff --git a/server/ratelimiter.py b/server/ratelimiter.py index c62f07887e55748a35ef366b3369178395784d51..a916f874f7b9db0b6bb8eab9bf173ddbd8c40228 100644 --- a/server/ratelimiter.py +++ b/server/ratelimiter.py @@ -1,51 +1,70 @@ -from functools import wraps import time -from sanic.response import json -class RateLimiter: +class TokenBucket: + ''' + Implementation of TokenBucket + ''' + def __init__(self,tokens:int,time_for_token:int): + ''' + tokens:int - maximum amount of tokens + time_for_token: - time in which 1 token is added + ''' + self.token_koef = tokens/time_for_token + self.tokens = tokens + self.last_check = time.time() + + def handle(self) -> bool: + current_time = time.time() + time_delta = current_time - self.last_check + self.last_check = current_time + + self.tokens += time_delta*self.token_koef + + if self.tokens > self.max_tokens: + self.tokens = self.max_tokens + + if self.tokens < 1: + return False + + self.tokens -= 1 + return True + +class RateLimited(Exception): def __init__(self): - self.storage = {} + super().__init__("Call was rate limited") - - async def limit(self, calls, per_second, func, request, *args, **kwargs): +GLOBAL_BUCKETS = {} - current_time = time.time() - cell = self.storage.get(request.ip) +def rate_limit(tokens:int,time_for_token:int): + def wrapper(function): + def wrapped(*args,**kwargs): + fn_name = function.__name__ + try: + bucket = GLOBAL_BUCKETS[fn_name] + except: + bucket = TokenBucket(tokens,time_for_token) + GLOBAL_BUCKETS[fn_name] = bucket - if not cell: - cell = [calls-1, current_time] - self.storage[request.ip] = cell - return await func(request,*args,**kwargs) + if not bucket.handle(): + raise RateLimited() - time_delta = current_time - cell[-1] - to_add = int(time_delta*(calls/per_second)) - cell[0] += to_add + return function(*args,**kwargs) + return wrapped + return wrapper - if cell[0] > calls: - cell[0] = calls - - if cell[0] <= 0: - return json({"success": False, "ratelimit": True}) - self.storage[request.ip][0] -= 1 - self.storage[request.ip][1] = current_time - return await func(request, *args, **kwargs) + +if __name__ == "__main__": + @rate_limit(3,50) + def func1(a): + print("Function1",a) + + for i in range(10): + try: + a = func1(555) + except RateLimited: + print("RT") + + time.sleep(2) -class EndpointLimiter: - def __init__(self): - self.funcs = {} - - - def limit(self, calls, per_second): - def decorator(func): - @wraps(func) - async def wrapper(request, *args, **kwargs): - try: - return await self.funcs[func.__name__].limit(calls, per_second, func, request, *args, **kwargs) - except KeyError: - rate_limiter = RateLimiter() - self.funcs[func.__name__] = rate_limiter - return await self.funcs[func.__name__].limit(calls, per_second, func, request, *args, **kwargs) - return wrapper - return decorator