def test_key_func(extension_factory):
    """A constant key function makes every client share one bucket.

    The limit on ``/t1`` is keyed by the literal ``"test"``, so 100 requests
    from one forwarded address use it up and the following request (sent
    without any forwarding header) is still rejected.
    """
    app, limiter = extension_factory()

    @app.route("/t1")
    @limiter.limit("100 per minute", lambda: "test")
    def t1():
        return "test"

    forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for _ in range(100):
                status = cli.get("/t1", headers=forwarded).status_code
                assert 200 == status
            assert 429 == cli.get("/t1").status_code
def test_combined_rate_limits(self):
    """Routes without their own limit fall back to the global limits.

    ``/t2`` carries no decorator, so it is governed by the global
    "1 per hour" limit and its second request is rejected.
    """
    app, limiter = self.build_app(
        {C.GLOBAL_LIMITS: "1 per hour; 10 per day"}
    )

    @app.route("/t1")
    @limiter.limit("100 per hour;10/minute")
    def t1():
        return "t1"

    @app.route("/t2")
    def t2():
        return "t2"

    expectations = [("/t1", 200), ("/t2", 200), ("/t2", 429)]
    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for path, expected in expectations:
                self.assertEqual(expected, cli.get(path).status_code)
def test_key_func(self):
    """A constant key function makes every client share one bucket.

    The limit on ``/t1`` is keyed by the literal ``"test"``, so 100 requests
    from one forwarded address use it up and the following request (sent
    without any forwarding header) is still rejected.
    """
    app, limiter = self.build_app()

    @app.route("/t1")
    @limiter.limit("100 per minute", lambda: "test")
    def t1():
        return "test"

    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for _ in range(100):
                response = cli.get(
                    "/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}
                )
                self.assertEqual(200, response.status_code)
            self.assertEqual(429, cli.get("/t1").status_code)
def test_decorated_limits_with_combined_defaults(extension_factory):
    """A decorated limit with ``override_defaults=False`` stacks on the defaults.

    ``/`` is bound by its own "1/second" limit and by the "2/minute"
    default. The second request in the same second is rejected. After the
    minute rolls over, two requests one second apart succeed, and the third
    request in that minute exceeds the "2/minute" default.
    """
    app, limiter = extension_factory(default_limits=['2/minute'])

    @app.route("/")
    @limiter.limit("1/second", override_defaults=False)
    def root():
        return "root"

    # Freeze the clock so that only the explicit ``timeline.forward`` calls
    # move time. With a running clock, the "same second" assertions depend on
    # how quickly the requests happen to execute.
    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            assert 200 == cli.get("/").status_code
            assert 429 == cli.get("/").status_code
            timeline.forward(60)
            assert 200 == cli.get("/").status_code
            timeline.forward(1)
            assert 200 == cli.get("/").status_code
            timeline.forward(1)
            assert 429 == cli.get("/").status_code
def test_invalid_decorated_static_limits(self):
    """An unparsable decorated limit is logged and the default limit applies.

    ``"2/sec"`` cannot be configured, so ``/t1`` is limited by the
    "1/second" default. The log records the configuration failure and then
    the limit breach.
    """
    app = Flask(__name__)
    limiter = Limiter(
        app, default_limits=["1/second"], key_func=get_remote_address
    )
    log_spy = mock.Mock()
    log_spy.level = logging.INFO
    limiter.logger.addHandler(log_spy)

    @app.route("/t1")
    @limiter.limit("2/sec")
    def t1():
        return "42"

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    handled = log_spy.handle.call_args_list
    self.assertIn("failed to configure", handled[0][0][0].msg)
    self.assertIn("exceeded at endpoint", handled[1][0][0].msg)
def test_fallback_to_memory_backoff_check(self):
    """Storage failures switch the limiter to in-memory limits until redis is re-checked.

    While every redis command raises, requests are limited in memory. After
    redis recovers, the first request is still answered from memory, because
    the backend is only re-checked after the backoff interval. Once it has
    been re-checked, the "5/minute" default is enforced again.
    """
    # NOTE(review): in this copy, credential redaction masked the text between
    # ``redis://`` and ``@app.route``. That destroyed the storage URI and the
    # close of this call, leaving a syntax error. The host/port below is an
    # assumption. ``in_memory_fallback=["1/minute"]`` is inferred from the
    # assertions: the second request while redis is down gets 429, which the
    # "5/minute" default alone would not produce. Confirm both against
    # version control.
    app, limiter = self.build_app(
        config={C.ENABLED: True},
        default_limits=["5/minute"],
        storage_uri="redis://localhost:6379",
        in_memory_fallback=["1/minute"],
    )

    @app.route("/t1")
    def t1():
        return "test"

    with app.test_client() as cli:
        def raiser(*a):
            raise Exception("redis dead")

        with hiro.Timeline() as timeline:
            with mock.patch(
                "redis.client.Redis.execute_command"
            ) as exec_command:
                exec_command.side_effect = raiser
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(2)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(4)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(8)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(16)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(32)
                self.assertEqual(cli.get("/t1").status_code, 200)
            # redis back to normal, but exponential backoff will only
            # result in it being marked after pow(2,0) seconds and next
            # check
            self.assertEqual(cli.get("/t1").status_code, 429)
            timeline.forward(2)
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
def test_custom_error_message(self):
    """``error_message`` (a string or a callable) becomes the 429 description.

    The 429 handler echoes ``e.description``, so the response body shows
    which message each limit was configured with.
    """
    app, limiter = self.build_app()

    @app.errorhandler(429)
    def ratelimit_handler(e):
        return make_response(e.description, 429)

    @limiter.limit("1/second", error_message="uno")
    @app.route("/t1")
    def t1():
        return "1"

    @limiter.limit(lambda: "1/second", error_message=lambda: "dos")
    @app.route("/t2")
    def t2():
        return "2"

    shared = limiter.shared_limit(
        "1/second", scope='error_message', error_message="tres"
    )

    @app.route("/t3")
    @shared
    def t3():
        return "3"

    cases = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for path, body in cases:
                cli.get(path)  # consumes the single 1/second allowance
                resp = cli.get(path)
                self.assertEqual(429, resp.status_code)
                self.assertEqual(resp.data, body)
def test_flask_restful_resource(self):
    """Limits declared through ``Resource.decorators`` apply to flask-restful views.

    ``Va`` shares one "2/second" budget across GET and POST. ``Vb`` is held
    to "1/second, 3/minute". ``Vc`` has no decorators and is governed by the
    global "1/hour".
    """
    app, limiter = self.build_app(global_limits=["1/hour"])
    api = restful.Api(app)

    class Va(Resource):
        decorators = [limiter.limit("2/second")]

        def get(self):
            return request.method.lower()

        def post(self):
            return request.method.lower()

    class Vb(Resource):
        decorators = [limiter.limit("1/second, 3/minute")]

        def get(self):
            return request.method.lower()

    class Vc(Resource):
        def get(self):
            return request.method.lower()

    for resource, url in ((Va, "/a"), (Vb, "/b"), (Vc, "/c")):
        api.add_resource(resource, url)

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(429, cli.get("/a").status_code)
            self.assertEqual(429, cli.post("/a").status_code)
            self.assertEqual(200, cli.get("/b").status_code)
            for expected in (200, 200, 429):
                timeline.forward(1)
                self.assertEqual(expected, cli.get("/b").status_code)
            self.assertEqual(200, cli.get("/c").status_code)
            self.assertEqual(429, cli.get("/c").status_code)
def test_retry_after(self):
    """Waiting for the advertised ``Retry-After`` lets the next request through."""
    app = Flask(__name__)
    _ = Limiter(
        app,
        default_limits=["1/minute"],
        headers_enabled=True,
        key_func=get_remote_address,
    )

    @app.route("/t1")
    def t():
        return "test"

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            first = cli.get("/t1")
            wait = int(first.headers.get('Retry-After'))
            self.assertTrue(wait > 0)
            timeline.forward(wait)
            self.assertEqual(cli.get("/t1").status_code, 200)
def test_multiple_decorators(self):
    """Stacked limits with different keys are enforced independently.

    The outer limit is keyed by the constant ``"test"``, so every client
    shares it. The inner one uses the app's ``get_ipaddr`` key function.
    """
    app, limiter = self.build_app(key_func=get_ipaddr)

    @app.route("/t1")
    @limiter.limit("100 per minute", lambda: "test")  # effectively becomes a limit for all users
    @limiter.limit("50/minute")  # per ip as per default key_func
    def t1():
        return "test"

    forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
    other_client = {"X_FORWARDED_FOR": "127.0.0.3"}
    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for attempt in range(100):
                expected = 200 if attempt < 50 else 429
                self.assertEqual(
                    expected, cli.get("/t1", headers=forwarded).status_code
                )
            for _ in range(50):
                self.assertEqual(200, cli.get("/t1").status_code)
            self.assertEqual(429, cli.get("/t1").status_code)
            self.assertEqual(
                429, cli.get("/t1", headers=other_client).status_code
            )
def test_fixed_window_with_elastic_expiry_in_memory(self):
    """Hits, including rejected ones, push the elastic window's expiry out.

    After ten hits on a "10 per 2 seconds" limit, a rejected hit one second
    later moves the expiry to ``start + 3``. Once the window has fully
    lapsed, a fresh hit opens a new window expiring two seconds later.
    """
    storage = MemoryStorage()
    limiter = FixedWindowElasticExpiryRateLimiter(storage)
    with hiro.Timeline().freeze() as timeline:
        start = int(time.time())
        limit = RateLimitItemPerSecond(10, 2)
        outcomes = [limiter.hit(limit) for _ in range(10)]
        self.assertTrue(all(outcomes))
        timeline.forward(1)
        self.assertFalse(limiter.hit(limit))
        self.assertEqual(limiter.get_window_stats(limit)[1], 0)
        # three extensions to the expiry
        self.assertEqual(limiter.get_window_stats(limit)[0], start + 3)
        timeline.forward(1)
        self.assertFalse(limiter.hit(limit))
        timeline.forward(3)
        start = int(time.time())
        self.assertTrue(limiter.hit(limit))
        self.assertEqual(limiter.get_window_stats(limit)[1], 9)
        self.assertEqual(limiter.get_window_stats(limit)[0], start + 2)
def test_retry_after(self):
    """Waiting for the advertised ``Retry-After`` lets the next request through."""
    # FIXME: this test is not actually running!
    app, limiter = self.build_starlette_app(
        headers_enabled=True, key_func=get_remote_address
    )

    @app.route("/t1")
    @limiter.limit("1/minute")
    def t(request: Request):
        return PlainTextResponse("test")

    with hiro.Timeline().freeze() as timeline:
        with TestClient(app) as cli:
            wait = int(cli.get("/t1").headers.get("Retry-After"))
            assert wait > 0
            timeline.forward(wait)
            assert cli.get("/t1").status_code == 200
def test_moving_window_in_memory(self):
    """Capacity in a moving window returns as individual hits age out.

    Pairs of hits every ten seconds exhaust a "10 per minute" limit. Twenty
    seconds after the last pair, only the oldest pair has aged out. Once all
    hits are more than a minute old, full capacity is restored.
    """
    storage = MemoryStorage()
    limiter = MovingWindowRateLimiter(storage)
    with hiro.Timeline().freeze() as timeline:
        limit = RateLimitItemPerMinute(10)
        for consumed in range(2, 11, 2):
            self.assertTrue(limiter.hit(limit))
            self.assertTrue(limiter.hit(limit))
            self.assertEqual(
                limiter.get_window_stats(limit)[1], 10 - consumed
            )
            timeline.forward(10)
        self.assertEqual(limiter.get_window_stats(limit)[1], 0)
        self.assertFalse(limiter.hit(limit))
        timeline.forward(20)
        self.assertEqual(limiter.get_window_stats(limit)[1], 2)
        self.assertEqual(
            limiter.get_window_stats(limit)[0], int(time.time() + 30)
        )
        timeline.forward(31)
        self.assertEqual(limiter.get_window_stats(limit)[1], 10)
def test_decorated_dynamic_limits(self):
    """Callable limits are resolved per request from app config or request data.

    ``/t1`` stacks a static "20/day", the config-driven ``X`` limit, and a
    per-client limit chosen from the first address in the access route.
    ``/t2`` only has the config-driven limit.
    """
    app, limiter = self.build_app(
        {"X": "2 per second"}, global_limits=["1/second"]
    )

    def request_context_limit():
        per_address = {
            "127.0.0.1": "10 per minute",
            "127.0.0.2": "1 per minute",
        }
        client = (
            (request.access_route and request.access_route[0])
            or request.remote_addr
            or '127.0.0.1'
        )
        return per_address.get(client, '1 per minute')

    @app.route("/t1")
    @limiter.limit("20/day")
    @limiter.limit(lambda: current_app.config.get("X"))
    @limiter.limit(request_context_limit)
    def t1():
        return "42"

    @app.route("/t2")
    @limiter.limit(lambda: current_app.config.get("X"))
    def t2():
        return "42"

    generous = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
    strict = {"X_FORWARDED_FOR": "127.0.0.2"}
    with app.test_client() as cli:
        with hiro.Timeline().freeze() as timeline:
            for _ in range(10):
                self.assertEqual(
                    cli.get("/t1", headers=generous).status_code, 200
                )
                timeline.forward(1)
            self.assertEqual(cli.get("/t1", headers=generous).status_code, 429)
            self.assertEqual(cli.get("/t1", headers=strict).status_code, 200)
            self.assertEqual(cli.get("/t1", headers=strict).status_code, 429)
            timeline.forward(60)
            self.assertEqual(cli.get("/t1", headers=strict).status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)
            timeline.forward(1)
            self.assertEqual(cli.get("/t2").status_code, 200)
def test_pluggable_views(self):
    """Limits listed in ``View.decorators`` apply to class-based views.

    ``Vc`` declares no limits, so it is governed by the global "1/hour".
    """
    app, limiter = self.build_app(global_limits=["1/hour"])

    class Va(View):
        methods = ['GET', 'POST']
        decorators = [limiter.limit("2/second")]

        def dispatch_request(self):
            return request.method.lower()

    class Vb(View):
        methods = ['GET']
        decorators = [limiter.limit("1/second, 3/minute")]

        def dispatch_request(self):
            return request.method.lower()

    class Vc(View):
        methods = ['GET']

        def dispatch_request(self):
            return request.method.lower()

    for view_class, endpoint in ((Va, "a"), (Vb, "b"), (Vc, "c")):
        app.add_url_rule(
            "/" + endpoint, view_func=view_class.as_view(endpoint)
        )

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(429, cli.post("/a").status_code)
            self.assertEqual(200, cli.get("/b").status_code)
            for expected in (200, 200, 429):
                timeline.forward(1)
                self.assertEqual(expected, cli.get("/b").status_code)
            self.assertEqual(200, cli.get("/c").status_code)
            self.assertEqual(429, cli.get("/c").status_code)
def test_headers_breach(self):
    """After 11 requests one second apart, the headers describe the breached limit.

    The "10 per minute" limit is the one exceeded, so it is reported with
    zero remaining.
    """
    app, limiter = self.build_starlette_app(
        headers_enabled=True, key_func=get_remote_address
    )

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t(request: Request):
        return PlainTextResponse("test")

    with hiro.Timeline().freeze() as timeline:
        with TestClient(app) as cli:
            for _ in range(11):
                last = cli.get("/t1")
                timeline.forward(1)
            headers = last.headers
            assert headers.get("X-RateLimit-Limit") == "10"
            assert headers.get("X-RateLimit-Remaining") == "0"
            assert headers.get("X-RateLimit-Reset") == str(
                int(time.time() + 50)
            )
            assert headers.get("Retry-After") == "50"
def test_fallback_to_memory(extension_factory):
    """Limits keep being enforced in memory while the redis backend fails.

    While ``RedisStorage.incr`` raises, ``/t1`` (2/minute default) and
    ``/t2`` (1/minute) are still limited. Once the patch is lifted and the
    backend is flushed, ``/t2`` is limited through redis again.
    """
    # NOTE(review): in this copy, credential redaction masked the text between
    # ``redis://`` and ``@app.route``. That destroyed the storage URI and the
    # close of this call, leaving a syntax error. The host/port below is an
    # assumption. ``in_memory_fallback_enabled=True`` is inferred because the
    # dead-redis phase reproduces the configured limits. Any other arguments
    # that were lost (e.g. header settings) must be restored from version
    # control.
    app, limiter = extension_factory(
        config={C.ENABLED: True},
        global_limits=["2/minute"],
        storage_uri="redis://localhost:6379",
        in_memory_fallback_enabled=True,
    )

    @app.route("/t1")
    def t1():
        return "test"

    @app.route("/t2")
    @limiter.limit("1 per minute")
    def t2():
        return "test"

    with app.test_client() as cli:
        assert cli.get("/t1").status_code == 200
        assert cli.get("/t1").status_code == 200
        assert cli.get("/t1").status_code == 429
        assert cli.get("/t2").status_code == 200
        assert cli.get("/t2").status_code == 429

        def raiser(*a):
            raise Exception("redis dead")

        with patch('limits.storage.RedisStorage.incr') as hit:
            hit.side_effect = raiser
            assert cli.get("/t1").status_code == 200
            assert cli.get("/t1").status_code == 200
            assert cli.get("/t1").status_code == 429
            assert cli.get("/t2").status_code == 200
            assert cli.get("/t2").status_code == 429
        # Outside the incr patch: /t2 gets 200 below only once writes reach
        # redis again (the in-memory /t2 bucket is already exhausted).
        with hiro.Timeline() as timeline:
            timeline.forward(1)
            limiter._storage.storage.flushall()
            assert cli.get("/t2").status_code == 200
            assert cli.get("/t2").status_code == 429
        with patch('limits.storage.RedisStorage.get') as get:
            get.side_effect = raiser
            assert cli.get("/t1").status_code == 200
def test_shared_limit_with_conditional_deduction(extension_factory):
    """A shared limit with ``deduct_when`` only counts matching responses.

    Only responses with status 400 are deducted from the "2/minute" budget
    that ``/test/<path>`` and the ``/bp`` blueprint share. After two 400s,
    every route in the shared scope returns 429 until the minute elapses.
    """
    app, limiter = extension_factory()
    bp = Blueprint("main", __name__)
    limit = limiter.shared_limit(
        "2/minute", "not_found",
        deduct_when=lambda response: response.status_code == 400
    )

    @app.route("/test/<path:path>")
    @limit
    def app_test(path):
        if path != "1":
            raise BadRequest()
        return path

    @bp.route("/test/<path:path>")
    def bp_test(path):
        if path != "1":
            raise BadRequest()
        return path

    limit(bp)
    app.register_blueprint(bp, url_prefix='/bp')

    # Freeze the clock so that only ``timeline.forward`` moves time, matching
    # the other timeline-driven tests and keeping the window deterministic.
    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            assert cli.get("/bp/test/1").status_code == 200
            assert cli.get("/bp/test/1").status_code == 200
            assert cli.get("/test/1").status_code == 200
            assert cli.get("/bp/test/2").status_code == 400
            assert cli.get("/test/2").status_code == 400
            assert cli.get("/bp/test/2").status_code == 429
            assert cli.get("/bp/test/1").status_code == 429
            assert cli.get("/test/1").status_code == 429
            assert cli.get("/test/2").status_code == 429
            timeline.forward(60)
            assert cli.get("/bp/test/1").status_code == 200
            assert cli.get("/test/1").status_code == 200
def test_default_limits_with_per_route_limit(extension_factory):
    """Application limits are enforced alongside an explicit per-route limit.

    ``/default`` has no limit of its own. It is rejected on its second
    request because the "3/minute" application limit has already been
    consumed by the earlier ``/explicit`` requests as well.
    """
    app, limiter = extension_factory(application_limits=['3/minute'])

    @app.route("/explicit")
    @limiter.limit("1/minute")
    def explicit():
        return "explicit"

    @app.route("/default")
    def default():
        return "default"

    sequence = (
        ("/explicit", 200),
        ("/explicit", 429),
        ("/default", 200),
        ("/default", 429),
    )
    with app.test_client() as cli:
        with hiro.Timeline().freeze() as timeline:
            for path, expected in sequence:
                assert expected == cli.get(path).status_code
            timeline.forward(60)
            assert 200 == cli.get("/explicit").status_code
            assert 200 == cli.get("/default").status_code
def test_dynamic_limits(self):
    """A callable returning several limits is enforced with the moving window.

    The callable yields "1/second; 2/minute". The second request in a second
    is rejected, and the request after a two-second pause succeeds.
    """
    app, limiter = self.build_app({
        C.STRATEGY: "moving-window",
        C.HEADERS_ENABLED: True
    })

    def limit_string(*a):
        return "1/second; 2/minute"

    @app.route("/t1")
    @limiter.limit(limit_string)
    def t1():
        return "t1"

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            def expect(status):
                self.assertEqual(cli.get("/t1").status_code, status)

            expect(200)
            expect(429)
            timeline.forward(2)
            expect(200)
            expect(429)
def test_default_limits_with_per_route_limit(self):
    """Application limits are enforced alongside an explicit per-route limit.

    ``/default`` has no limit of its own. It is rejected on its second
    request because the "3/minute" application limit has already been
    consumed by the earlier ``/explicit`` requests as well.
    """
    app, limiter = self.build_app(application_limits=['3/minute'])

    @app.route("/explicit")
    @limiter.limit("1/minute")
    def explicit():
        return "explicit"

    @app.route("/default")
    def default():
        return "default"

    sequence = (
        ("/explicit", 200),
        ("/explicit", 429),
        ("/default", 200),
        ("/default", 429),
    )
    with app.test_client() as cli:
        with hiro.Timeline().freeze() as timeline:
            for path, expected in sequence:
                self.assertEqual(expected, cli.get(path).status_code)
            timeline.forward(60)
            self.assertEqual(200, cli.get("/explicit").status_code)
            self.assertEqual(200, cli.get("/default").status_code)
def test_dynamic_limits(extension_factory):
    """A callable returning several limits is enforced with the moving window.

    The callable yields "1/second; 2/minute". The second request in a second
    is rejected, and the request after a two-second pause succeeds.
    """
    app, limiter = extension_factory({
        C.STRATEGY: "moving-window",
        C.HEADERS_ENABLED: True
    })

    def limit_string(*a):
        return "1/second; 2/minute"

    @app.route("/t1")
    @limiter.limit(limit_string)
    def t1():
        return "t1"

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            for expected in (200, 429):
                assert cli.get("/t1").status_code == expected
            timeline.forward(2)
            for expected in (200, 429):
                assert cli.get("/t1").status_code == expected
def test_pluggable_views(extension_factory):
    """Limits listed in ``View.decorators`` apply to class-based views.

    ``Vc`` declares no limits, so it is governed by the "1/hour" default.
    """
    app, limiter = extension_factory(default_limits=["1/hour"])

    class Va(View):
        methods = ['GET', 'POST']
        decorators = [limiter.limit("2/second")]

        def dispatch_request(self):
            return request.method.lower()

    class Vb(View):
        methods = ['GET']
        decorators = [limiter.limit("1/second, 3/minute")]

        def dispatch_request(self):
            return request.method.lower()

    class Vc(View):
        methods = ['GET']

        def dispatch_request(self):
            return request.method.lower()

    for view_class, endpoint in ((Va, "a"), (Vb, "b"), (Vc, "c")):
        app.add_url_rule(
            "/" + endpoint, view_func=view_class.as_view(endpoint)
        )

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            assert 200 == cli.get("/a").status_code
            assert 200 == cli.get("/a").status_code
            assert 429 == cli.post("/a").status_code
            assert 200 == cli.get("/b").status_code
            for expected in (200, 200, 429):
                timeline.forward(1)
                assert expected == cli.get("/b").status_code
            assert 200 == cli.get("/c").status_code
            assert 429 == cli.get("/c").status_code
def test_headers_breach(self):
    """After 11 requests one second apart, the headers describe the breached limit.

    The "10 per minute" limit is the one exceeded, so it is reported with
    zero remaining.
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            for _ in range(11):
                last = cli.get("/t1")
                timeline.forward(1)
            expected_headers = {
                'X-RateLimit-Limit': '10',
                'X-RateLimit-Remaining': '0',
                'X-RateLimit-Reset': str(int(time.time() + 49)),
            }
            for name, value in expected_headers.items():
                self.assertEqual(last.headers.get(name), value)
def test_invalid_decorated_dynamic_limits(self):
    """An unparsable dynamic limit is logged on every load; the global limit still applies.

    The config value "2 per sec" cannot be parsed, so ``/t1`` is limited by
    the "1/second" global limit.
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(
        app, global_limits=["1/second"], key_func=get_remote_address
    )
    log_spy = mock.Mock()
    log_spy.level = logging.INFO
    limiter.logger.addHandler(log_spy)

    @app.route("/t1")
    @limiter.limit(lambda: current_app.config.get("X"))
    def t1():
        return "42"

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # 2 for invalid limit, 1 for warning.
    self.assertEqual(log_spy.handle.call_count, 3)
    handled = log_spy.handle.call_args_list
    fragments = (
        "failed to load ratelimit",
        "failed to load ratelimit",
        "exceeded at endpoint",
    )
    for index, fragment in enumerate(fragments):
        self.assertIn(fragment, handled[index][0][0].msg)
def test_decorated_limit_immediate(extension_factory):
    """A dynamic limit can read request state set by an outer wrapper.

    ``append_info`` stores "2/minute" on ``g`` before the limited view runs.
    The route is therefore governed by that value rather than by the
    "1/minute" default.
    """
    app, limiter = extension_factory(default_limits=["1/minute"])

    def append_info(fn):
        @wraps(fn)
        def with_rate_limit(*args, **kwargs):
            g.rate_limit = "2/minute"
            return fn(*args, **kwargs)
        return with_rate_limit

    @app.route("/", methods=["GET", "POST"])
    @append_info
    @limiter.limit(lambda: g.rate_limit, per_method=True)
    def root():
        return "root"

    with hiro.Timeline().freeze():
        with app.test_client() as cli:
            for expected in (200, 200, 429):
                assert expected == cli.get("/").status_code
def test_shared_dynamic_limits(self):
    """Exercise the ``/shared/*`` views through the djlimiter middleware.

    The middleware is installed with a "1/second" global limit. The limits
    of the individual views are defined outside this test.
    """
    middleware = settings.MIDDLEWARE_CLASSES + ("djlimiter.Limiter", )
    sequence = (
        ("/shared/one/", 200),
        ("/shared/two/", 200),
        ("/shared/five/", 429),
        ("/shared/four/", 200),
        ("/shared/four/", 200),
        ("/shared/four/", 200),
    )
    with hiro.Timeline().freeze() as timeline:
        with self.settings(MIDDLEWARE_CLASSES=middleware,
                           RATELIMIT_GLOBAL="1/second"):
            for path, expected in sequence:
                self.assertEqual(self.client.get(path).status_code, expected)
            timeline.forward(1)
            self.assertEqual(
                self.client.get("/shared/four/").status_code, 200)
            self.assertEqual(
                self.client.get("/shared/four/").status_code, 429)
def test_invalid_decorated_dynamic_limits(caplog):
    """An unparsable dynamic limit is logged on every load; the default limit still applies.

    The config value "2 per sec" cannot be parsed, so ``/t1`` is limited by
    the "1/second" default. The breach is logged at WARNING level.
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(
        app, default_limits=["1/second"], key_func=get_remote_address
    )

    @app.route("/t1")
    @limiter.limit(lambda: current_app.config.get("X"))
    def t1():
        return "42"

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            assert cli.get("/t1").status_code == 200
            assert cli.get("/t1").status_code == 429
    # 2 for invalid limit, 1 for warning.
    records = caplog.records
    assert len(records) == 3
    assert "failed to load ratelimit" in records[0].msg
    assert "failed to load ratelimit" in records[1].msg
    assert "exceeded at endpoint" in records[2].msg
    assert records[2].levelname == 'WARNING'
def test_multiple_decorators_not_response(self):
    """Stacked limits work on a FastAPI endpoint that returns a plain dict.

    The outer limit is keyed by the constant ``"test"``, so every client
    shares it. The inner one uses the ``get_ipaddr`` key function. The first
    50 requests from one address pass and the next 50 are rejected.
    """
    app, limiter = self.build_fastapi_app(key_func=get_ipaddr)

    @app.get("/t1")
    @limiter.limit("100 per minute", lambda: "test")  # effectively becomes a limit for all users
    @limiter.limit("50/minute")  # per ip as per default key_func
    async def t1(request: Request, response: Response):
        return {"key": "value"}

    with hiro.Timeline().freeze():
        cli = TestClient(app)
        for i in range(0, 100):
            response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
            # The conditional must be parenthesised: ``a == b if c else d``
            # parses as ``(a == b) if c else d``. That would assert the truthy
            # constant 429, never failing, for every i >= 50.
            assert response.status_code == (200 if i < 50 else 429)
        for i in range(50):
            assert cli.get("/t1").status_code == 200
        assert cli.get("/t1").status_code == 429
        assert (cli.get("/t1", headers={
            "X_FORWARDED_FOR": "127.0.0.3"
        }).status_code == 429)
def test_invalid_decorated_static_limit_blueprint(self):
    """An unparsable limit applied to a blueprint is logged; the global limit applies.

    ``"2/sec"`` cannot be configured, so the blueprint's ``/t1`` is limited
    by the "1/second" global limit.
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["1/second"])
    log_spy = mock.Mock()
    log_spy.level = logging.INFO
    limiter.logger.addHandler(log_spy)
    blueprint = Blueprint("bp1", __name__)

    @blueprint.route("/t1")
    def t1():
        return "42"

    limiter.limit("2/sec")(blueprint)
    app.register_blueprint(blueprint)

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    handled = log_spy.handle.call_args_list
    self.assertIn("failed to configure", handled[0][0][0].msg)
    self.assertIn("exceeded at endpoint", handled[1][0][0].msg)