def test_invalid_decorated_static_limit_blueprint(self):
    """An unparsable static limit applied to a blueprint is logged and skipped.

    ``"2/sec"`` is attached to the blueprint through ``limiter.limit(...)(bp)``.
    The test expects a "failed to configure" log record. It also expects the
    second request inside one frozen second to be rejected by the
    ``1/second`` default, logged as "exceeded at endpoint".

    NOTE(review): a method with this same name is defined again further down
    in this file. If both live in the same class body, the later ``def``
    replaces this one and this version never runs. Confirm which
    Flask-Limiter API (``default_limits``/``key_func`` vs ``global_limits``)
    the suite targets, then rename or drop one of them.
    """
    app = Flask(__name__)
    limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level rather than an auto-created attribute.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)
    bp = Blueprint("bp1", __name__)

    @bp.route("/t1")
    def t1():
        return "42"

    limiter.limit("2/sec")(bp)
    app.register_blueprint(bp)
    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # Each recorded call is handle(record): args[0] is the LogRecord.
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to configure", records[0][0][0].msg)
    self.assertIn("exceeded at endpoint", records[1][0][0].msg)
def test_invalid_decorated_dynamic_limits_blueprint(self):
    """A callable blueprint limit that yields an invalid string is logged.

    The callable returns ``app.config["X"]``, which is ``"2 per sec"``. The
    test expects two "failed to load ratelimit" records across two requests.
    It then expects the second request to be rejected under the ``1/second``
    global limit and logged as "exceeded at endpoint". That makes three
    handler calls in total.

    NOTE(review): a method with this same name is defined again further down
    in this file. If both live in the same class body, the later ``def``
    replaces this one and this version never runs. Rename or drop one of
    them once the targeted Flask-Limiter API is confirmed.
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(app, global_limits=["1/second"])
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)
    bp = Blueprint("bp1", __name__)

    @bp.route("/t1")
    def t1():
        return "42"

    # The callable reads ``current_app``, so it can only be evaluated inside
    # an application context.
    limiter.limit(lambda: current_app.config.get("X"))(bp)
    app.register_blueprint(bp)
    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # Two load failures plus one breach warning.
    self.assertEqual(mock_handler.handle.call_count, 3)
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to load ratelimit", records[0][0][0].msg)
    self.assertIn("failed to load ratelimit", records[1][0][0].msg)
    self.assertIn("exceeded at endpoint", records[2][0][0].msg)
def test_multiple_apps(self):
    """A single Limiter initialised on two apps enforces each app's own limits.

    ``first_app`` relies on the ``1/second`` default for ``/ping`` and
    decorates ``/slowping`` with ``1/minute``. ``second_app`` decorates its
    routes with ``2/second`` and ``2/minute``. Time is frozen and advanced
    manually so window boundaries are deterministic.

    NOTE(review): another ``test_multiple_apps`` is defined right after this
    one. If both live in the same class body, the later definition replaces
    this one and it never runs.
    """
    first_app = Flask(__name__)
    second_app = Flask(__name__)
    limiter = Limiter(default_limits=["1/second"], key_func=get_remote_address)
    limiter.init_app(first_app)
    limiter.init_app(second_app)

    @first_app.route("/ping")
    def ping():
        return "PONG"

    @first_app.route("/slowping")
    @limiter.limit("1/minute")
    def slow_ping():
        return "PONG"

    @second_app.route("/ping")
    @limiter.limit("2/second")
    def ping_2():
        return "PONG"

    @second_app.route("/slowping")
    @limiter.limit("2/minute")
    def slow_ping_2():
        return "PONG"

    def expect(client, path, code):
        self.assertEqual(client.get(path).status_code, code)

    with hiro.Timeline().freeze() as clock:
        with first_app.test_client() as client:
            expect(client, "/ping", 200)
            expect(client, "/ping", 429)
            clock.forward(1)
            expect(client, "/ping", 200)
            expect(client, "/slowping", 200)
            clock.forward(59)
            expect(client, "/slowping", 429)
            clock.forward(1)
            expect(client, "/slowping", 200)
        with second_app.test_client() as client:
            expect(client, "/ping", 200)
            expect(client, "/ping", 200)
            expect(client, "/ping", 429)
            clock.forward(1)
            expect(client, "/ping", 200)
            expect(client, "/slowping", 200)
            clock.forward(59)
            expect(client, "/slowping", 200)
            expect(client, "/slowping", 429)
            clock.forward(1)
            expect(client, "/slowping", 200)
def test_multiple_apps(self):
    """Limits registered through one Limiter stay separate across two apps.

    The limiter is created without an app and then bound to both. Its
    ``1/second`` global limit governs the first app's ``/ping``. The
    decorated ``1/minute``, ``2/second`` and ``2/minute`` limits govern the
    remaining routes. The frozen timeline is advanced by hand between
    requests.
    """
    app_a = Flask(__name__)
    app_b = Flask(__name__)
    limiter = Limiter(global_limits=["1/second"])
    limiter.init_app(app_a)
    limiter.init_app(app_b)

    @app_a.route("/ping")
    def ping():
        return "PONG"

    @app_a.route("/slowping")
    @limiter.limit("1/minute")
    def slow_ping():
        return "PONG"

    @app_b.route("/ping")
    @limiter.limit("2/second")
    def ping_2():
        return "PONG"

    @app_b.route("/slowping")
    @limiter.limit("2/minute")
    def slow_ping_2():
        return "PONG"

    def check(client, path, expected_status):
        self.assertEqual(client.get(path).status_code, expected_status)

    with hiro.Timeline().freeze() as frozen:
        with app_a.test_client() as client:
            check(client, "/ping", 200)
            check(client, "/ping", 429)
            frozen.forward(1)
            check(client, "/ping", 200)
            check(client, "/slowping", 200)
            frozen.forward(59)
            check(client, "/slowping", 429)
            frozen.forward(1)
            check(client, "/slowping", 200)
        with app_b.test_client() as client:
            check(client, "/ping", 200)
            check(client, "/ping", 200)
            check(client, "/ping", 429)
            frozen.forward(1)
            check(client, "/ping", 200)
            check(client, "/slowping", 200)
            frozen.forward(59)
            check(client, "/slowping", 200)
            check(client, "/slowping", 429)
            frozen.forward(1)
            check(client, "/slowping", 200)
def test_whitelisting(self):
    """Requests accepted by a ``request_filter`` bypass the global limit.

    Under a ``1/minute`` global limit, the second plain request is expected
    to be rejected. Once the window has passed, ten requests carrying the
    header ``internal: true`` are all expected to be served.
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["1/minute"], headers_enabled=True)

    @app.route("/")
    def t():
        return "test"

    @limiter.request_filter
    def w():
        # Return the comparison directly instead of an if/return True/False.
        return request.headers.get("internal", None) == "true"

    with hiro.Timeline().freeze() as timeline:
        with app.test_client() as cli:
            self.assertEqual(cli.get("/").status_code, 200)
            self.assertEqual(cli.get("/").status_code, 429)
            timeline.forward(60)
            self.assertEqual(cli.get("/").status_code, 200)
            for _ in range(10):
                self.assertEqual(
                    cli.get("/", headers={"internal": "true"}).status_code, 200)
def test_custom_headers_from_config(self):
    """Rate-limit header names can be overridden through Flask config.

    The three header-name config keys are set before the Limiter is created.
    After 11 requests spaced one second apart, the custom headers are
    expected to report the exhausted 10-per-minute limit.
    """
    app = Flask(__name__)
    header_overrides = (
        (C.HEADER_LIMIT, "X-Limit"),
        (C.HEADER_REMAINING, "X-Remaining"),
        (C.HEADER_RESET, "X-Reset"),
    )
    for config_key, header_name in header_overrides:
        app.config.setdefault(config_key, header_name)
    limiter = Limiter(
        app,
        default_limits=["10/minute"],
        headers_enabled=True,
        key_func=get_remote_address,
    )

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            for _ in range(11):
                last_response = client.get("/t1")
                clock.forward(1)
            headers = last_response.headers
            self.assertEqual(headers.get('X-Limit'), '10')
            self.assertEqual(headers.get('X-Remaining'), '0')
            self.assertEqual(headers.get('X-Reset'), str(int(time.time() + 50)))
def test_custom_headers_from_setter(self):
    """Header names can be overridden by editing ``_header_mapping`` directly.

    Time is frozen at epoch 0. After 11 requests spaced one second apart, the
    renamed headers are expected to report the exhausted 10-per-minute
    limit. ``Retry-After`` is expected in HTTP-date form because the limiter
    is built with ``retry_after='http-date'``.

    NOTE(review): a method with this same name is defined again further down
    in this file (using the public ``header_mapping``). If both live in the
    same class body, the later one replaces this one and it never runs.
    """
    app = Flask(__name__)
    limiter = Limiter(
        app,
        default_limits=["10/minute"],
        headers_enabled=True,
        key_func=get_remote_address,
        retry_after='http-date',
    )
    renamed = (
        (HEADERS.RESET, 'X-Reset'),
        (HEADERS.LIMIT, 'X-Limit'),
        (HEADERS.REMAINING, 'X-Remaining'),
    )
    for header_key, header_name in renamed:
        limiter._header_mapping[header_key] = header_name

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze(0) as clock:
        with app.test_client() as client:
            for _ in range(11):
                last_response = client.get("/t1")
                clock.forward(1)
            headers = last_response.headers
            self.assertEqual(headers.get('X-Limit'), '10')
            self.assertEqual(headers.get('X-Remaining'), '0')
            self.assertEqual(headers.get('X-Reset'), str(int(time.time() + 50)))
            self.assertEqual(headers.get('Retry-After'), 'Thu, 01 Jan 1970 00:01:01 GMT')
def test_headers_no_breach(self):
    """A single request to each route reports that route's governing limit.

    ``/t1`` has only the ``10/minute`` global limit. ``/t2`` is decorated with
    several limits, and the headers are expected to reflect ``2/second``.
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

    @app.route("/t1")
    def t1():
        return "test"

    @app.route("/t2")
    @limiter.limit("2/second; 5 per minute; 10/hour")
    def t2():
        return "test"

    # (path, expected limit, expected remaining, seconds until reset)
    expectations = (
        ("/t1", '10', '9', 60),
        ("/t2", '2', '1', 1),
    )
    with hiro.Timeline().freeze():
        with app.test_client() as client:
            for path, limit, remaining, reset_in in expectations:
                headers = client.get(path).headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), limit)
                self.assertEqual(headers.get('X-RateLimit-Remaining'), remaining)
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + reset_in)),
                )
def test_invalid_decorated_dynamic_limits(self):
    """A callable route limit that yields an invalid string is logged.

    The callable returns ``app.config["X"]``, which is ``"2 per sec"``. The
    test expects one "failed to load ratelimit" record for each of the two
    requests. It then expects the second request to be rejected by the
    ``1/second`` default and logged as "exceeded at endpoint".
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)

    @app.route("/t1")
    @limiter.limit(lambda: current_app.config.get("X"))
    def t1():
        return "42"

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # 2 for invalid limit, 1 for warning.
    self.assertEqual(mock_handler.handle.call_count, 3)
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to load ratelimit", records[0][0][0].msg)
    self.assertIn("failed to load ratelimit", records[1][0][0].msg)
    self.assertIn("exceeded at endpoint", records[2][0][0].msg)
def test_headers_breach(self):
    """Headers after exhausting the per-minute limit, including Retry-After.

    Eleven requests are made one second apart. The last response is
    expected to report the 10-per-minute limit with nothing remaining and a
    reset 49 seconds in the future.
    """
    app = Flask(__name__)
    limiter = Limiter(
        app,
        global_limits=["10/minute"],
        headers_enabled=True,
        key_func=get_remote_address,
    )

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            for _ in range(11):
                last_response = client.get("/t1")
                clock.forward(1)
            expected_headers = (
                ('X-RateLimit-Limit', '10'),
                ('X-RateLimit-Remaining', '0'),
                ('X-RateLimit-Reset', str(int(time.time() + 49))),
                ('Retry-After', '49'),
            )
            for header_name, expected_value in expected_headers:
                self.assertEqual(last_response.headers.get(header_name), expected_value)
def test_conditional_limits(self):
    """Limits declared with ``exempt_when`` are skipped while it returns True.

    ``/limited`` is always limited, ``/unlimited`` is always exempt, and
    ``/conditional`` follows the local ``is_exempt`` flag.
    """
    app = Flask(__name__)
    limiter = Limiter(app, key_func=get_remote_address)

    @app.route("/limited")
    @limiter.limit("1 per day")
    def limited_route():
        return "passed"

    @app.route("/unlimited")
    @limiter.limit("1 per day", exempt_when=lambda: True)
    def never_limited_route():
        return "should always pass"

    # The lambda below closes over this local rather than copying its value,
    # so rebinding it later in this function toggles the exemption.
    is_exempt = False

    @app.route("/conditional")
    @limiter.limit("1 per day", exempt_when=lambda: is_exempt)
    def conditionally_limited_route():
        return "conditional"

    with app.test_client() as client:
        def status_of(path):
            return client.get(path).status_code

        self.assertEqual(status_of("/limited"), 200)
        self.assertEqual(status_of("/limited"), 429)
        self.assertEqual(status_of("/unlimited"), 200)
        self.assertEqual(status_of("/unlimited"), 200)
        self.assertEqual(status_of("/conditional"), 200)
        self.assertEqual(status_of("/conditional"), 429)
        is_exempt = True
        self.assertEqual(status_of("/conditional"), 200)
        is_exempt = False
        self.assertEqual(status_of("/conditional"), 429)
def test_constructor_arguments_over_config(self):
    """Explicit Limiter constructor arguments take precedence over app config.

    The first limiter is given ``strategy='moving-window'`` while config
    names a fixed-window strategy. ``STORAGE_URL`` is only set after the
    first ``init_app`` call. The second limiter is then initialised with
    both that config value and an explicit memcached ``storage_uri``, and
    the explicit URI is expected to win.
    """
    app = Flask(__name__)
    app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")
    strategy_limiter = Limiter(strategy='moving-window')
    strategy_limiter.init_app(app)
    app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
    self.assertEqual(type(strategy_limiter.limiter), MovingWindowRateLimiter)

    storage_limiter = Limiter(storage_uri='memcached://localhost:11211')
    storage_limiter.init_app(app)
    self.assertEqual(type(storage_limiter.storage), MemcachedStorage)
def build_app(self, config=None, **limiter_args):
    """Create a Flask app with a Limiter bound to it, for use in tests.

    :param config: optional mapping of Flask config keys to values. Each one
        is applied with ``setdefault`` before the Limiter is constructed.
        ``None`` (the default) applies no config.
    :param limiter_args: keyword arguments forwarded unchanged to ``Limiter``.
    :return: ``(app, limiter)``. A ``mock.Mock`` handler with level INFO has
        been added to ``limiter.logger``.
    """
    app = Flask(__name__)
    # ``None`` sentinel instead of a ``{}`` default, so no dict object is
    # shared between calls.
    for key, value in (config or {}).items():
        app.config.setdefault(key, value)
    limiter = Limiter(app, **limiter_args)
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)
    return app, limiter
def test_invalid_decorated_static_limit_blueprint(self):
    """An unparsable static limit applied to a blueprint is logged and skipped.

    ``"2/sec"`` is attached to the blueprint through ``limiter.limit(...)(bp)``.
    The test expects a "failed to configure" log record. It also expects the
    second request inside one frozen second to be rejected by the
    ``1/second`` global limit, logged as "exceeded at endpoint".
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["1/second"])
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)
    bp = Blueprint("bp1", __name__)

    @bp.route("/t1")
    def t1():
        return "42"

    limiter.limit("2/sec")(bp)
    app.register_blueprint(bp)
    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # Each recorded call is handle(record): args[0] is the LogRecord.
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to configure", records[0][0][0].msg)
    self.assertIn("exceeded at endpoint", records[1][0][0].msg)
def test_invalid_decorated_dynamic_limits_blueprint(self):
    """A callable blueprint limit that yields an invalid string is logged.

    The callable returns ``app.config["X"]``, which is ``"2 per sec"``. The
    test expects two "failed to load ratelimit" records across the two
    requests. It then expects the second request to be rejected under the
    ``1/second`` global limit, logged as "exceeded at endpoint".
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(app, global_limits=["1/second"], key_func=get_remote_address)
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)
    bp = Blueprint("bp1", __name__)

    @bp.route("/t1")
    def t1():
        return "42"

    # The callable reads ``current_app``, so it can only be evaluated inside
    # an application context.
    limiter.limit(lambda: current_app.config.get("X"))(bp)
    app.register_blueprint(bp)
    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # Two load failures plus one breach warning.
    self.assertEqual(mock_handler.handle.call_count, 3)
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to load ratelimit", records[0][0][0].msg)
    self.assertIn("failed to load ratelimit", records[1][0][0].msg)
    self.assertIn("exceeded at endpoint", records[2][0][0].msg)
def test_logging(self):
    """One allowed and one rejected request yield exactly one log record.

    The route is limited to ``1/minute``, so the second request is expected
    to be rejected with 429.
    """
    app = Flask(__name__)
    limiter = Limiter(app)
    capture = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    capture.level = logging.INFO
    limiter.logger.addHandler(capture)

    @app.route("/t1")
    @limiter.limit("1/minute")
    def t1():
        return "test"

    with app.test_client() as client:
        first_status = client.get("/t1").status_code
        self.assertEqual(200, first_status)
        second_status = client.get("/t1").status_code
        self.assertEqual(429, second_status)
        self.assertEqual(capture.handle.call_count, 1)
def test_retry_after(self):
    """Waiting the advertised Retry-After seconds lets the next request through.

    The first response is expected to carry a positive integer
    ``Retry-After``. After the frozen clock is advanced by exactly that
    amount, a second request is expected to succeed.
    """
    app = Flask(__name__)
    Limiter(app, default_limits=["1/minute"], headers_enabled=True, key_func=get_remote_address)

    @app.route("/t1")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            first = client.get("/t1")
            wait_seconds = int(first.headers.get('Retry-After'))
            self.assertTrue(wait_seconds > 0)
            clock.forward(wait_seconds)
            second = client.get("/t1")
            self.assertEqual(second.status_code, 200)
def test_reuse_logging(self):
    """Handlers copied from ``app.logger`` receive the limiter's log output.

    Every handler on the app logger is also added to ``limiter.logger``.
    Two requests are then made against a ``1/minute`` route, and the app's
    handler is expected to be called exactly once.
    """
    app = Flask(__name__)
    shared_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    shared_handler.level = logging.INFO
    app.logger.addHandler(shared_handler)
    limiter = Limiter(app)
    for existing in app.logger.handlers:
        limiter.logger.addHandler(existing)

    @app.route("/t1")
    @limiter.limit("1/minute")
    def t1():
        return "42"

    with app.test_client() as client:
        for _ in range(2):
            client.get("/t1")
        self.assertEqual(shared_handler.handle.call_count, 1)
def test_invalid_decorated_static_limits(self):
    """An unparsable static route limit is logged and skipped.

    ``@limiter.limit("2/sec")`` is expected to produce a "failed to
    configure" log record. The second request inside one frozen second is
    expected to be rejected by the ``1/second`` global limit and logged as
    "exceeded at endpoint".
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["1/second"])
    mock_handler = mock.Mock()
    # logging compares each record's level with ``handler.level``, so the
    # Mock needs a real integer level.
    mock_handler.level = logging.INFO
    limiter.logger.addHandler(mock_handler)

    @app.route("/t1")
    @limiter.limit("2/sec")
    def t1():
        return "42"

    with app.test_client() as cli:
        with hiro.Timeline().freeze():
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
    # Each recorded call is handle(record): args[0] is the LogRecord.
    records = mock_handler.handle.call_args_list
    self.assertIn("failed to configure", records[0][0][0].msg)
    self.assertIn("exceeded at endpoint", records[1][0][0].msg)
def test_conditional_shared_limits(self):
    """``exempt_when`` works on limits that share one scope across routes.

    All three routes share ``"1 per day"`` under ``"test_scope"``. Requests to
    the always-exempt ``/unlimited`` are expected not to use up the quota.
    Once ``/limited`` consumes it, ``/conditional`` is expected to be
    rejected unless ``is_exempt`` is True.
    """
    app = Flask(__name__)
    limiter = Limiter(app)

    @app.route("/limited")
    @limiter.shared_limit("1 per day", "test_scope")
    def limited_route():
        return "passed"

    @app.route("/unlimited")
    @limiter.shared_limit("1 per day", "test_scope", exempt_when=lambda: True)
    def never_limited_route():
        return "should always pass"

    # Read lazily by the lambda below; rebinding it later in this function
    # changes whether /conditional is exempt.
    is_exempt = False

    @app.route("/conditional")
    @limiter.shared_limit("1 per day", "test_scope", exempt_when=lambda: is_exempt)
    def conditionally_limited_route():
        return "conditional"

    with app.test_client() as client:
        def status_of(path):
            return client.get(path).status_code

        self.assertEqual(status_of("/unlimited"), 200)
        self.assertEqual(status_of("/unlimited"), 200)
        self.assertEqual(status_of("/limited"), 200)
        self.assertEqual(status_of("/limited"), 429)
        self.assertEqual(status_of("/conditional"), 429)
        is_exempt = True
        self.assertEqual(status_of("/conditional"), 200)
        is_exempt = False
        self.assertEqual(status_of("/conditional"), 429)
def test_custom_headers_from_setter(self):
    """Header names can be overridden through ``limiter.header_mapping``.

    After 11 requests spaced one second apart, the renamed headers are
    expected to report the exhausted 10-per-minute limit with a reset 49
    seconds ahead.
    """
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)
    renamed = (
        (HEADERS.RESET, 'X-Reset'),
        (HEADERS.LIMIT, 'X-Limit'),
        (HEADERS.REMAINING, 'X-Remaining'),
    )
    for header_key, header_name in renamed:
        limiter.header_mapping[header_key] = header_name

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            for _ in range(11):
                last_response = client.get("/t1")
                clock.forward(1)
            headers = last_response.headers
            self.assertEqual(headers.get('X-Limit'), '10')
            self.assertEqual(headers.get('X-Remaining'), '0')
            self.assertEqual(headers.get('X-Reset'), str(int(time.time() + 49)))