コード例 #1
0
def test_key_func(extension_factory):
    """Exercise a route limit whose key function always returns ``"test"``.

    One hundred requests carrying a forwarded address are accepted; the next
    request, sent without that header, is expected to be rejected.
    """
    app, limiter = extension_factory()

    @app.route("/t1")
    @limiter.limit("100 per minute", lambda: "test")
    def t1():
        return "test"

    forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
    with hiro.Timeline().freeze(), app.test_client() as client:
        # Use up the complete allowance while sending the forwarded header.
        for _ in range(100):
            response = client.get("/t1", headers=forwarded)
            assert 200 == response.status_code
        # The header-less request is counted against the same key.
        assert 429 == client.get("/t1").status_code
コード例 #2
0
ファイル: test_flask_ext.py プロジェクト: YoApp/flask-limiter
    def test_combined_rate_limits(self):
        # Build an app configured through C.GLOBAL_LIMITS and check that the
        # undecorated route (/t2) is rejected on its second request, while the
        # decorated route (/t1) answers its single request normally.
        app, limiter = self.build_app(
            {C.GLOBAL_LIMITS: "1 per hour; 10 per day"}
        )

        @app.route("/t1")
        @limiter.limit("100 per hour;10/minute")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        expectations = (("/t1", 200), ("/t2", 200), ("/t2", 429))
        with hiro.Timeline().freeze(), app.test_client() as client:
            for path, expected in expectations:
                self.assertEqual(expected, client.get(path).status_code)
コード例 #3
0
ファイル: test_flask_ext.py プロジェクト: YoApp/flask-limiter
    def test_key_func(self):
        # The key function always yields "test": 100 requests sent with a
        # forwarded address pass, and the following header-less request is
        # expected to be rejected.
        app, limiter = self.build_app()

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")
        def t1():
            return "test"

        forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
        with hiro.Timeline().freeze(), app.test_client() as client:
            for _ in range(100):
                status = client.get("/t1", headers=forwarded).status_code
                self.assertEqual(200, status)
            self.assertEqual(429, client.get("/t1").status_code)
コード例 #4
0
ファイル: test_decorators.py プロジェクト: imfht/flaskapps
def test_decorated_limits_with_combined_defaults(extension_factory):
    """Check a route limit declared with ``override_defaults=False``.

    The route carries "1/second" while the extension is created with a
    "2/minute" default; the assertions step the clock to observe both.
    """
    app, limiter = extension_factory(default_limits=['2/minute'])

    @app.route("/")
    @limiter.limit("1/second", override_defaults=False)
    def root():
        return "root"

    with hiro.Timeline() as timeline, app.test_client() as client:

        def fetch():
            # Status code of one GET against the decorated route.
            return client.get("/").status_code

        assert 200 == fetch()
        assert 429 == fetch()
        timeline.forward(60)
        assert 200 == fetch()
        timeline.forward(1)
        assert 200 == fetch()
        timeline.forward(1)
        assert 429 == fetch()
コード例 #5
0
    def test_invalid_decorated_static_limits(self):
        # The decorator is given the limit string "2/sec"; the test expects a
        # "failed to configure" log record for it, and expects the second
        # request to be rejected and logged ("exceeded at endpoint").
        app = Flask(__name__)
        limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
        # A mock handler captures every record the limiter's logger emits.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

        # handle() is called with the log record as its first positional arg.
        calls = mock_handler.handle.call_args_list
        # Guard the indexing below so a missing record fails with a clear
        # assertion message instead of an IndexError.
        self.assertGreaterEqual(len(calls), 2)
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)
コード例 #6
0
    def test_fallback_to_memory_backoff_check(self):
        # Exercises the limiter while every Redis command raises, stepping the
        # clock forward in doubling increments between requests.
        #
        # NOTE(review): the `storage_uri=` line below looks truncated by a
        # credential-masking artifact: the rest of the build_app(...) call and
        # the `@` of the `@app.route("/t1")` decorator appear to have been
        # swallowed. Restore the original arguments from upstream before
        # relying on this snippet; they cannot be recovered from this text.
        app, limiter = self.build_app(config={C.ENABLED: True},
                                      default_limits=["5/minute"],
                                      storage_uri="redis://*****:*****@app.route("/t1")
        def t1():
            return "test"

        with app.test_client() as cli:

            # Side effect installed on the patched Redis command executor.
            def raiser(*a):
                raise Exception("redis dead")

            with hiro.Timeline() as timeline:
                with mock.patch(
                        "redis.client.Redis.execute_command") as exec_command:
                    exec_command.side_effect = raiser
                    self.assertEqual(cli.get("/t1").status_code, 200)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    # Advance by 1, 2, 4, 8 and 16 seconds (31 s in total);
                    # every request in between is still expected to be rejected.
                    timeline.forward(1)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    timeline.forward(2)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    timeline.forward(4)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    timeline.forward(8)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    timeline.forward(16)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    # A further 32 s brings the total to 63 s, after which a
                    # request is accepted again.
                    timeline.forward(32)
                    self.assertEqual(cli.get("/t1").status_code, 200)
                # redis back to normal, but exponential backoff will only
                # result in it being marked after pow(2,0) seconds and next
                # check
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(2)
                # Five requests pass and the sixth is rejected, matching the
                # "5/minute" default limit configured above.
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
コード例 #7
0
    def test_custom_error_message(self):
        # Each route is limited to one request per second with its own
        # error_message (static string, callable, shared limit); the 429
        # handler echoes the exception description so the body can be checked.
        app, limiter = self.build_app()

        @app.errorhandler(429)
        def ratelimit_handler(e):
            return make_response(e.description, 429)

        limit_provider = lambda: "1/second"
        message_provider = lambda: "dos"

        @limiter.limit("1/second", error_message="uno")
        @app.route("/t1")
        def t1():
            return "1"

        @limiter.limit(limit_provider, error_message=message_provider)
        @app.route("/t2")
        def t2():
            return "2"

        shared = limiter.shared_limit(
            "1/second", scope='error_message', error_message="tres"
        )

        @app.route("/t3")
        @shared
        def t3():
            return "3"

        expected_bodies = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
        with hiro.Timeline().freeze(), app.test_client() as client:
            for path, body in expected_bodies:
                # The first request consumes the allowance, the second is rejected.
                client.get(path)
                rejected = client.get(path)
                self.assertEqual(429, rejected.status_code)
                self.assertEqual(rejected.data, body)
コード例 #8
0
    def test_flask_restful_resource(self):
        # Registers three flask-restful resources: two carry limits through
        # their `decorators` attribute, the third has none of its own and is
        # checked against the "1/hour" passed as global_limits.
        app, limiter = self.build_app(global_limits=["1/hour"])
        api = restful.Api(app)

        class Va(Resource):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(Resource):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(Resource):
            def get(self):
                return request.method.lower()

        for resource, url in ((Va, "/a"), (Vb, "/b"), (Vc, "/c")):
            api.add_resource(resource, url)

        with hiro.Timeline().freeze() as timeline, app.test_client() as client:
            # /a: two requests pass, then both GET and POST are rejected.
            for expected in (200, 200, 429):
                self.assertEqual(expected, client.get("/a").status_code)
            self.assertEqual(429, client.post("/a").status_code)

            # /b: one request per second passes three times, the fourth fails.
            for _ in range(3):
                self.assertEqual(200, client.get("/b").status_code)
                timeline.forward(1)
            self.assertEqual(429, client.get("/b").status_code)

            # /c: only the first request passes.
            for expected in (200, 429):
                self.assertEqual(expected, client.get("/c").status_code)
コード例 #9
0
    def test_retry_after(self):
        # Reads the Retry-After header from the first response, advances the
        # frozen clock by that many seconds and expects the next request to
        # be accepted.
        app = Flask(__name__)
        _ = Limiter(app,
                    default_limits=["1/minute"],
                    headers_enabled=True,
                    key_func=get_remote_address)

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                resp = cli.get("/t1")
                header = resp.headers.get('Retry-After')
                # Fail with a readable assertion when the header is absent,
                # rather than with a TypeError raised by int(None).
                self.assertIsNotNone(header)
                retry_after = int(header)
                # assertGreater reports both values on failure.
                self.assertGreater(retry_after, 0)
                timeline.forward(retry_after)
                resp = cli.get("/t1")
                self.assertEqual(resp.status_code, 200)
コード例 #10
0
    def test_multiple_decorators(self):
        # Two limits are stacked on one route: one with a constant key
        # function and one using the limiter's default key function.
        app, limiter = self.build_app(key_func=get_ipaddr)

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")  # effectively becomes a limit for all users
        @limiter.limit("50/minute")  # per ip as per default key_func
        def t1():
            return "test"

        with hiro.Timeline().freeze(), app.test_client() as client:
            first_ip = {"X_FORWARDED_FOR": "127.0.0.2"}
            for attempt in range(100):
                # Only the first 50 requests from this address are accepted.
                expected = 200 if attempt < 50 else 429
                self.assertEqual(
                    expected,
                    client.get("/t1", headers=first_ip).status_code,
                )
            # Requests without the header get 50 more before being rejected.
            for _ in range(50):
                self.assertEqual(200, client.get("/t1").status_code)
            self.assertEqual(429, client.get("/t1").status_code)
            # A request from yet another forwarded address is rejected too.
            other_ip = {"X_FORWARDED_FOR": "127.0.0.3"}
            self.assertEqual(429, client.get("/t1", headers=other_ip).status_code)
コード例 #11
0
 def test_fixed_window_with_elastic_expiry_in_memory(self):
     # Drives a FixedWindowElasticExpiryRateLimiter backed by MemoryStorage
     # under a frozen clock and checks the stats returned by
     # get_window_stats (index 0 compared to a timestamp, index 1 to a count).
     limiter = FixedWindowElasticExpiryRateLimiter(MemoryStorage())
     with hiro.Timeline().freeze() as timeline:
         limit = RateLimitItemPerSecond(10, 2)
         start = int(time.time())
         # Build a list (not a generator) so that all ten hits are performed
         # regardless of the individual results.
         outcomes = [limiter.hit(limit) for _ in range(10)]
         self.assertTrue(all(outcomes))
         timeline.forward(1)
         self.assertFalse(limiter.hit(limit))
         self.assertEqual(limiter.get_window_stats(limit)[1], 0)
         # three extensions to the expiry
         self.assertEqual(limiter.get_window_stats(limit)[0], start + 3)
         timeline.forward(1)
         self.assertFalse(limiter.hit(limit))
         timeline.forward(3)
         # Re-base the reference timestamp before the next accepted hit.
         start = int(time.time())
         self.assertTrue(limiter.hit(limit))
         self.assertEqual(limiter.get_window_stats(limit)[1], 9)
         self.assertEqual(limiter.get_window_stats(limit)[0], start + 2)
コード例 #12
0
    def test_retry_after(self):
        # FIXME: this test is not actually running!
        #
        # Reads Retry-After from the first response, moves the frozen clock
        # forward by that amount and expects the following request to succeed.
        app, limiter = self.build_starlette_app(
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t(request: Request):
            return PlainTextResponse("test")

        with hiro.Timeline().freeze() as timeline, TestClient(app) as client:
            first = client.get("/t1")
            retry_after = int(first.headers.get("Retry-After"))
            assert retry_after > 0
            timeline.forward(retry_after)
            second = client.get("/t1")
            assert second.status_code == 200
コード例 #13
0
 def test_moving_window_in_memory(self):
     # Drives a MovingWindowRateLimiter backed by MemoryStorage under a
     # frozen clock; every assertion below depends on the cumulative time
     # advanced so far, so the statement order matters.
     storage = MemoryStorage()
     limiter = MovingWindowRateLimiter(storage)
     with hiro.Timeline().freeze() as timeline:
         limit = RateLimitItemPerMinute(10)
         # Five rounds of two hits each, 10 s apart: 10 hits over 50 s.
         for i in range(0, 5):
             self.assertTrue(limiter.hit(limit))
             self.assertTrue(limiter.hit(limit))
             # Index 1 of the window stats is compared to the hits left.
             self.assertEqual(
                 limiter.get_window_stats(limit)[1], 10 - ((i + 1) * 2))
             timeline.forward(10)
         # The allowance is used up, so a further hit is refused.
         self.assertEqual(limiter.get_window_stats(limit)[1], 0)
         self.assertFalse(limiter.hit(limit))
         # 70 s after the first hits, two slots are reported free again.
         timeline.forward(20)
         self.assertEqual(limiter.get_window_stats(limit)[1], 2)
         # Index 0 of the window stats is compared to a timestamp 30 s ahead.
         self.assertEqual(
             limiter.get_window_stats(limit)[0], int(time.time() + 30))
         # After another 31 s the full allowance of 10 is reported.
         timeline.forward(31)
         self.assertEqual(limiter.get_window_stats(limit)[1], 10)
コード例 #14
0
    def test_decorated_dynamic_limits(self):
        # Limits are supplied by callables: one reads the app config key "X",
        # the other picks a limit string based on the requesting address.
        app, limiter = self.build_app({"X": "2 per second"}, global_limits=["1/second"])

        def request_context_limit():
            per_address = {
                "127.0.0.1": "10 per minute",
                "127.0.0.2": "1 per minute",
            }
            remote_addr = (request.access_route and request.access_route[0]) or request.remote_addr or '127.0.0.1'
            # Addresses without an entry fall back to one request per minute.
            return per_address.get(remote_addr, '1 per minute')

        @app.route("/t1")
        @limiter.limit("20/day")
        @limiter.limit(lambda: current_app.config.get("X"))
        @limiter.limit(request_context_limit)
        def t1():
            return "42"

        @app.route("/t2")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t2():
            return "42"

        first_caller = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
        second_caller = {"X_FORWARDED_FOR": "127.0.0.2"}

        with app.test_client() as client, hiro.Timeline().freeze() as timeline:
            # Ten requests, one second apart, from the first caller all pass.
            for _ in range(10):
                self.assertEqual(client.get("/t1", headers=first_caller).status_code, 200)
                timeline.forward(1)
            self.assertEqual(client.get("/t1", headers=first_caller).status_code, 429)

            # The second caller gets a single request until a minute passes.
            for expected in (200, 429):
                self.assertEqual(client.get("/t1", headers=second_caller).status_code, expected)
            timeline.forward(60)
            self.assertEqual(client.get("/t1", headers=second_caller).status_code, 200)

            # /t2 allows two requests, rejects the third, recovers after 1 s.
            for expected in (200, 200, 429):
                self.assertEqual(client.get("/t2").status_code, expected)
            timeline.forward(1)
            self.assertEqual(client.get("/t2").status_code, 200)
コード例 #15
0
    def test_pluggable_views(self):
        # Registers three class-based views: two carry limits through their
        # `decorators` attribute, the third has none of its own and is checked
        # against the "1/hour" passed as global_limits.
        app, limiter = self.build_app(global_limits=["1/hour"])

        class Va(View):
            methods = ['GET', 'POST']
            decorators = [limiter.limit("2/second")]

            def dispatch_request(self):
                return request.method.lower()

        class Vb(View):
            methods = ['GET']
            decorators = [limiter.limit("1/second, 3/minute")]

            def dispatch_request(self):
                return request.method.lower()

        class Vc(View):
            methods = ['GET']

            def dispatch_request(self):
                return request.method.lower()

        for url, view_cls, view_name in (("/a", Va, "a"), ("/b", Vb, "b"), ("/c", Vc, "c")):
            app.add_url_rule(url, view_func=view_cls.as_view(view_name))

        with hiro.Timeline().freeze() as timeline, app.test_client() as client:
            # /a: two GETs pass, then a POST is rejected.
            for _ in range(2):
                self.assertEqual(200, client.get("/a").status_code)
            self.assertEqual(429, client.post("/a").status_code)

            # /b: one request per second passes three times, the fourth fails.
            for _ in range(3):
                self.assertEqual(200, client.get("/b").status_code)
                timeline.forward(1)
            self.assertEqual(429, client.get("/b").status_code)

            # /c: only the first request passes.
            for expected in (200, 429):
                self.assertEqual(expected, client.get("/c").status_code)
コード例 #16
0
    def test_headers_breach(self):
        # Sends eleven requests one second apart and inspects the rate-limit
        # headers on the last response.
        app, limiter = self.build_starlette_app(
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t(request: Request):
            return PlainTextResponse("test")

        with hiro.Timeline().freeze() as timeline, TestClient(app) as client:
            for _ in range(11):
                resp = client.get("/t1")
                timeline.forward(1)

            headers = resp.headers
            assert headers.get("X-RateLimit-Limit") == "10"
            assert headers.get("X-RateLimit-Remaining") == "0"
            expected_reset = str(int(time.time() + 50))
            assert headers.get("X-RateLimit-Reset") == expected_reset
            assert headers.get("Retry-After") == "50"
コード例 #17
0
def test_fallback_to_memory(extension_factory):
    """Check that requests are still rate limited while Redis calls raise.

    NOTE(review): the ``storage_uri=`` line below looks truncated by a
    credential-masking artifact -- the remainder of the
    ``extension_factory(...)`` call and the ``@`` of the
    ``@app.route("/t1")`` decorator appear to have been swallowed. Restore
    the original arguments from upstream before relying on this snippet;
    they cannot be recovered from this text.
    """
    app, limiter = extension_factory(config={C.ENABLED: True},
                                     global_limits=["2/minute"],
                                     storage_uri="redis://*****:*****@app.route("/t1")
    def t1():
        return "test"

    @app.route("/t2")
    @limiter.limit("1 per minute")
    def t2():
        return "test"

    with app.test_client() as cli:
        # Baseline with the storage working: /t1 allows two requests,
        # /t2 allows one.
        assert cli.get("/t1").status_code == 200
        assert cli.get("/t1").status_code == 200
        assert cli.get("/t1").status_code == 429
        assert cli.get("/t2").status_code == 200
        assert cli.get("/t2").status_code == 429

        # Side effect used to simulate a storage failure.
        def raiser(*a):
            raise Exception("redis dead")

        # With RedisStorage.incr raising, the same request pattern is
        # expected to produce the same status codes.
        with patch('limits.storage.RedisStorage.incr') as hit:
            hit.side_effect = raiser
            assert cli.get("/t1").status_code == 200
            assert cli.get("/t1").status_code == 200
            assert cli.get("/t1").status_code == 429
            assert cli.get("/t2").status_code == 200
            assert cli.get("/t2").status_code == 429
        # Patch removed: advance one second and flush the underlying store
        # before checking /t2 again.
        with hiro.Timeline() as timeline:
            timeline.forward(1)
            limiter._storage.storage.flushall()
            assert cli.get("/t2").status_code == 200
            assert cli.get("/t2").status_code == 429
        # A failing RedisStorage.get must not turn the request into an error.
        with patch('limits.storage.RedisStorage.get') as get:
            get.side_effect = raiser
            assert cli.get("/t1").status_code == 200
コード例 #18
0
def test_shared_limit_with_conditional_deduction(extension_factory):
    """Apply one shared limit with a ``deduct_when`` predicate to a route and a blueprint.

    The predicate passed to ``shared_limit`` is true for responses with
    status 400; the request sequence below checks when 429 starts to appear.
    """
    app, limiter = extension_factory()

    bp = Blueprint("main", __name__)

    limit = limiter.shared_limit(
        "2/minute",
        "not_found",
        deduct_when=lambda response: response.status_code == 400,
    )

    @app.route("/test/<path:path>")
    @limit
    def app_test(path):
        if path == "1":
            return path
        raise BadRequest()

    @bp.route("/test/<path:path>")
    def bp_test(path):
        if path == "1":
            return path
        raise BadRequest()

    limit(bp)

    app.register_blueprint(bp, url_prefix='/bp')

    before_reset = (
        # Successful responses are accepted three times in a row.
        ("/bp/test/1", 200),
        ("/bp/test/1", 200),
        ("/test/1", 200),
        # Two bad requests, one through each entry point.
        ("/bp/test/2", 400),
        ("/test/2", 400),
        # From here on every request is rejected.
        ("/bp/test/2", 429),
        ("/bp/test/1", 429),
        ("/test/1", 429),
        ("/test/2", 429),
    )
    after_reset = (("/bp/test/1", 200), ("/test/1", 200))

    with hiro.Timeline() as timeline, app.test_client() as client:
        for url, expected in before_reset:
            assert client.get(url).status_code == expected
        timeline.forward(60)
        for url, expected in after_reset:
            assert client.get(url).status_code == expected
コード例 #19
0
ファイル: test_regressions.py プロジェクト: imfht/flaskapps
def test_default_limits_with_per_route_limit(extension_factory):
    """Combine ``application_limits`` with a route that has its own limit.

    Both routes are hit twice, the clock is moved forward by a minute, and
    both are expected to answer normally again.
    """
    app, limiter = extension_factory(application_limits=['3/minute'])

    @app.route("/explicit")
    @limiter.limit("1/minute")
    def explicit():
        return "explicit"

    @app.route("/default")
    def default():
        return "default"

    first_round = (
        ("/explicit", 200),
        ("/explicit", 429),
        ("/default", 200),
        ("/default", 429),
    )
    with app.test_client() as client, hiro.Timeline().freeze() as timeline:
        for url, expected in first_round:
            assert expected == client.get(url).status_code
        timeline.forward(60)
        for url in ("/explicit", "/default"):
            assert 200 == client.get(url).status_code
コード例 #20
0
    def test_dynamic_limits(self):
        # The limit string comes from a callable; the app is configured with
        # the "moving-window" strategy and headers enabled.
        app, limiter = self.build_app(
            {C.STRATEGY: "moving-window", C.HEADERS_ENABLED: True}
        )

        def func(*a):
            return "1/second; 2/minute"

        @app.route("/t1")
        @limiter.limit(func)
        def t1():
            return "t1"

        with hiro.Timeline().freeze() as timeline, app.test_client() as client:
            for expected in (200, 429):
                self.assertEqual(client.get("/t1").status_code, expected)
            timeline.forward(2)
            for expected in (200, 429):
                self.assertEqual(client.get("/t1").status_code, expected)
コード例 #21
0
    def test_default_limits_with_per_route_limit(self):
        # Combines application_limits with a route that has its own limit:
        # both routes are hit twice, the clock moves forward a minute, and
        # both are expected to answer normally again.
        app, limiter = self.build_app(application_limits=['3/minute'])

        @app.route("/explicit")
        @limiter.limit("1/minute")
        def explicit():
            return "explicit"

        @app.route("/default")
        def default():
            return "default"

        first_round = (
            ("/explicit", 200),
            ("/explicit", 429),
            ("/default", 200),
            ("/default", 429),
        )
        with app.test_client() as client, hiro.Timeline().freeze() as timeline:
            for url, expected in first_round:
                self.assertEqual(expected, client.get(url).status_code)
            timeline.forward(60)
            for url in ("/explicit", "/default"):
                self.assertEqual(200, client.get(url).status_code)
コード例 #22
0
ファイル: test_regressions.py プロジェクト: imfht/flaskapps
def test_dynamic_limits(extension_factory):
    """Limit a route with a callable that returns the limit string.

    The app is configured with the "moving-window" strategy and headers
    enabled; the same 200/429 pair is expected before and after a 2 s step.
    """
    app, limiter = extension_factory(
        {C.STRATEGY: "moving-window", C.HEADERS_ENABLED: True}
    )

    def func(*a):
        return "1/second; 2/minute"

    @app.route("/t1")
    @limiter.limit(func)
    def t1():
        return "t1"

    with hiro.Timeline().freeze() as timeline, app.test_client() as client:
        for expected in (200, 429):
            assert client.get("/t1").status_code == expected
        timeline.forward(2)
        for expected in (200, 429):
            assert client.get("/t1").status_code == expected
コード例 #23
0
ファイル: test_views.py プロジェクト: imfht/flaskapps
def test_pluggable_views(extension_factory):
    """Check limits attached to class-based views through ``decorators``.

    Two views declare their own limits; the third declares none and is
    checked against the "1/hour" default limit.
    """
    app, limiter = extension_factory(default_limits=["1/hour"])

    class Va(View):
        methods = ['GET', 'POST']
        decorators = [limiter.limit("2/second")]

        def dispatch_request(self):
            return request.method.lower()

    class Vb(View):
        methods = ['GET']
        decorators = [limiter.limit("1/second, 3/minute")]

        def dispatch_request(self):
            return request.method.lower()

    class Vc(View):
        methods = ['GET']

        def dispatch_request(self):
            return request.method.lower()

    for url, view_cls, view_name in (("/a", Va, "a"), ("/b", Vb, "b"), ("/c", Vc, "c")):
        app.add_url_rule(url, view_func=view_cls.as_view(view_name))

    with hiro.Timeline().freeze() as timeline, app.test_client() as client:
        # /a: two GETs pass, then a POST is rejected.
        for _ in range(2):
            assert 200 == client.get("/a").status_code
        assert 429 == client.post("/a").status_code

        # /b: one request per second passes three times, the fourth fails.
        for _ in range(3):
            assert 200 == client.get("/b").status_code
            timeline.forward(1)
        assert 429 == client.get("/b").status_code

        # /c: only the first request passes.
        for expected in (200, 429):
            assert expected == client.get("/c").status_code
コード例 #24
0
    def test_headers_breach(self):
        # Sends eleven requests one second apart and inspects the rate-limit
        # headers on the last response.
        app = Flask(__name__)
        limiter = Limiter(
            app,
            global_limits=["10/minute"],
            headers_enabled=True,
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline, app.test_client() as client:
            for _ in range(11):
                resp = client.get("/t1")
                timeline.forward(1)

            headers = resp.headers
            self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
            self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
            expected_reset = str(int(time.time() + 49))
            self.assertEqual(headers.get('X-RateLimit-Reset'), expected_reset)
コード例 #25
0
    def test_invalid_decorated_dynamic_limits(self):
        # The route's limit comes from a callable returning the config value
        # "2 per sec"; the test expects two "failed to load ratelimit" log
        # records followed by an "exceeded at endpoint" record.
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app, global_limits=["1/second"], key_func=get_remote_address)
        # A mock handler captures every record the limiter's logger emits.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        # handle() is called with the log record as its first positional arg.
        calls = mock_handler.handle.call_args_list
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn("failed to load ratelimit", calls[0][0][0].msg)
        self.assertIn("failed to load ratelimit", calls[1][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[2][0][0].msg)
コード例 #26
0
def test_decorated_limit_immediate(extension_factory):
    """Limit a route with a callable that reads ``g.rate_limit``.

    An outer decorator stores "2/minute" on ``g`` before the limited view is
    invoked; two requests are expected to pass and the third to be rejected.
    """
    app, limiter = extension_factory(default_limits=["1/minute"])

    def append_info(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Publish the limit string before delegating to the wrapped view.
            g.rate_limit = "2/minute"
            return fn(*args, **kwargs)

        return wrapper

    @app.route("/", methods=["GET", "POST"])
    @append_info
    @limiter.limit(lambda: g.rate_limit, per_method=True)
    def root():
        return "root"

    with hiro.Timeline().freeze(), app.test_client() as client:
        for expected in (200, 200, 429):
            assert expected == client.get("/").status_code
コード例 #27
0
 def test_shared_dynamic_limits(self):
     # Enables the djlimiter middleware with a "1/second" global limit and
     # walks through the /shared/* URLs under a frozen clock.
     with hiro.Timeline().freeze() as timeline:
         overrides = {
             "MIDDLEWARE_CLASSES": settings.MIDDLEWARE_CLASSES + ("djlimiter.Limiter", ),
             "RATELIMIT_GLOBAL": "1/second",
         }
         with self.settings(**overrides):
             before_step = (
                 ("/shared/one/", 200),
                 ("/shared/two/", 200),
                 ("/shared/five/", 429),
                 ("/shared/four/", 200),
                 ("/shared/four/", 200),
                 ("/shared/four/", 200),
             )
             for url, expected in before_step:
                 self.assertEqual(self.client.get(url).status_code, expected)
             timeline.forward(1)
             # One more request passes after the step, the next is rejected.
             for expected in (200, 429):
                 self.assertEqual(
                     self.client.get("/shared/four/").status_code, expected)
コード例 #28
0
def test_invalid_decorated_dynamic_limits(caplog):
    """Check the log output when a dynamic limit yields "2 per sec".

    Two "failed to load ratelimit" records are expected, followed by a
    WARNING-level "exceeded at endpoint" record for the rejected request.
    """
    app = Flask(__name__)
    app.config.setdefault("X", "2 per sec")
    limiter = Limiter(
        app,
        default_limits=["1/second"],
        key_func=get_remote_address,
    )

    @app.route("/t1")
    @limiter.limit(lambda: current_app.config.get("X"))
    def t1():
        return "42"

    with app.test_client() as client, hiro.Timeline().freeze():
        assert client.get("/t1").status_code == 200
        assert client.get("/t1").status_code == 429

    # 2 for invalid limit, 1 for warning.
    records = caplog.records
    assert len(records) == 3
    first, second, third = records
    assert "failed to load ratelimit" in first.msg
    assert "failed to load ratelimit" in second.msg
    assert "exceeded at endpoint" in third.msg
    assert third.levelname == 'WARNING'
コード例 #29
0
    def test_multiple_decorators_not_response(self):
        # Two limits are stacked on a FastAPI route that returns a plain dict
        # rather than a Response object: one limit uses a constant key
        # function, the other the limiter's default key function.
        app, limiter = self.build_fastapi_app(key_func=get_ipaddr)

        @app.get("/t1")
        @limiter.limit("100 per minute", lambda: "test"
                       )  # effectively becomes a limit for all users
        @limiter.limit("50/minute")  # per ip as per default key_func
        async def t1(request: Request, response: Response):
            return {"key": "value"}

        with hiro.Timeline().freeze():
            cli = TestClient(app)
            for i in range(0, 100):
                response = cli.get("/t1",
                                   headers={"X_FORWARDED_FOR": "127.0.0.2"})
                # Parenthesize the conditional: without the parentheses the
                # statement parses as `(status == 200) if i < 50 else 429`,
                # so for i >= 50 the assert sees the truthy constant 429 and
                # can never fail.
                expected = 200 if i < 50 else 429
                assert response.status_code == expected
            for i in range(50):
                assert cli.get("/t1").status_code == 200

            assert cli.get("/t1").status_code == 429
            assert (cli.get("/t1", headers={
                "X_FORWARDED_FOR": "127.0.0.3"
            }).status_code == 429)
コード例 #30
0
    def test_invalid_decorated_static_limit_blueprint(self):
        # The limit string "2/sec" is applied to a whole blueprint; the test
        # expects a "failed to configure" log record for it, and expects the
        # second request to be rejected and logged ("exceeded at endpoint").
        app = Flask(__name__)
        limiter = Limiter(app, global_limits=["1/second"])
        # A mock handler captures every record the limiter's logger emits.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

        # handle() is called with the log record as its first positional arg.
        calls = mock_handler.handle.call_args_list
        # Guard the indexing below so a missing record fails with a clear
        # assertion message instead of an IndexError.
        self.assertGreaterEqual(len(calls), 2)
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)