예제 #1
0
    def test_invalid_decorated_static_limit_blueprint(self):
        """An invalid static limit on a blueprint is logged and ignored.

        ``"2/sec"`` is applied to the whole blueprint via
        ``limiter.limit(...)(bp)``. The test expects a "failed to configure"
        log record for it. It also expects the ``1/second`` default limit to
        reject the second request made within the same frozen second, which
        produces an "exceeded at endpoint" record.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          default_limits=["1/second"],
                          key_func=get_remote_address)
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # logging passes the LogRecord as handle()'s first positional arg.
        calls = mock_handler.handle.call_args_list
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)
예제 #2
0
    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """An invalid dynamic limit on a blueprint is logged, not enforced.

        The blueprint limit is a callable that reads ``current_app.config``
        (``"2 per sec"``), so it can only be evaluated inside an app context.
        The test expects two "failed to load ratelimit" records followed by
        one "exceeded at endpoint" record, and expects the ``1/second``
        global limit to reject the second request.
        """
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app, global_limits=["1/second"])
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        self.assertEqual(mock_handler.handle.call_count, 3)
        # logging passes the LogRecord as handle()'s first positional arg.
        messages = [entry[0][0].msg
                    for entry in mock_handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
예제 #3
0
    def test_multiple_apps(self):
        """A single Limiter can be initialised on several Flask apps.

        Both apps get the ``1/second`` default, and each declares its own
        per-route limits. A frozen timeline is advanced manually and every
        route's status code is checked against its own decorator.
        """
        first_app = Flask(__name__)
        second_app = Flask(__name__)

        limiter = Limiter(default_limits=["1/second"],
                          key_func=get_remote_address)
        for target_app in (first_app, second_app):
            limiter.init_app(target_app)

        # View function names are kept as-is: Flask derives endpoint names
        # (and therefore limit keys) from them.
        @first_app.route("/ping")
        def ping():
            return "PONG"

        @first_app.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @second_app.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @second_app.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def expect(client, path, status):
            """GET *path* through *client* and assert the status code."""
            self.assertEqual(client.get(path).status_code, status)

        with hiro.Timeline().freeze() as timeline:
            with first_app.test_client() as client:
                expect(client, "/ping", 200)
                expect(client, "/ping", 429)
                timeline.forward(1)
                expect(client, "/ping", 200)
                expect(client, "/slowping", 200)
                timeline.forward(59)
                expect(client, "/slowping", 429)
                timeline.forward(1)
                expect(client, "/slowping", 200)
            with second_app.test_client() as client:
                expect(client, "/ping", 200)
                expect(client, "/ping", 200)
                expect(client, "/ping", 429)
                timeline.forward(1)
                expect(client, "/ping", 200)
                expect(client, "/slowping", 200)
                timeline.forward(59)
                expect(client, "/slowping", 200)
                expect(client, "/slowping", 429)
                timeline.forward(1)
                expect(client, "/slowping", 200)
예제 #4
0
    def test_multiple_apps(self):
        """One Limiter shared by two apps applies each app's own limits.

        Both apps inherit the ``1/second`` global limit. Their decorated
        routes override it, and the frozen timeline is stepped by hand to
        check each window.
        """
        app_one = Flask(__name__)
        app_two = Flask(__name__)

        limiter = Limiter(global_limits=["1/second"])
        limiter.init_app(app_one)
        limiter.init_app(app_two)

        # Keep the view names: they become the Flask endpoint names.
        @app_one.route("/ping")
        def ping():
            return "PONG"

        @app_one.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @app_two.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app_two.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def check(client, path, expected_status):
            """Assert that GET *path* answers with *expected_status*."""
            self.assertEqual(client.get(path).status_code, expected_status)

        with hiro.Timeline().freeze() as timeline:
            with app_one.test_client() as client:
                check(client, "/ping", 200)
                check(client, "/ping", 429)
                timeline.forward(1)
                check(client, "/ping", 200)
                check(client, "/slowping", 200)
                timeline.forward(59)
                check(client, "/slowping", 429)
                timeline.forward(1)
                check(client, "/slowping", 200)
            with app_two.test_client() as client:
                check(client, "/ping", 200)
                check(client, "/ping", 200)
                check(client, "/ping", 429)
                timeline.forward(1)
                check(client, "/ping", 200)
                check(client, "/slowping", 200)
                timeline.forward(59)
                check(client, "/slowping", 200)
                check(client, "/slowping", 429)
                timeline.forward(1)
                check(client, "/slowping", 200)
예제 #5
0
    def test_whitelisting(self):
        """Requests accepted by a ``request_filter`` bypass rate limiting.

        Under a ``1/minute`` global limit, ordinary requests are throttled.
        Requests carrying the header ``internal: true`` keep succeeding
        even after the window has been used up.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          global_limits=["1/minute"],
                          headers_enabled=True)

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            # The comparison already yields the bool the filter must return.
            return request.headers.get("internal") == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/").status_code, 200)
                self.assertEqual(cli.get("/").status_code, 429)
                timeline.forward(60)
                self.assertEqual(cli.get("/").status_code, 200)

                # The new window is now used up, yet filtered requests pass.
                for _ in range(10):
                    self.assertEqual(
                        cli.get("/", headers={
                            "internal": "true"
                        }).status_code, 200)
예제 #6
0
    def test_custom_headers_from_config(self):
        """Rate-limit header names can be overridden via app config.

        Eleven requests are issued one second apart. The limit reported
        afterwards is ``10`` (the per-minute one), under the configured
        ``X-Limit`` / ``X-Remaining`` / ``X-Reset`` names.
        """
        app = Flask(__name__)
        for option, header_name in ((C.HEADER_LIMIT, "X-Limit"),
                                    (C.HEADER_REMAINING, "X-Remaining"),
                                    (C.HEADER_RESET, "X-Reset")):
            app.config.setdefault(option, header_name)
        limiter = Limiter(app,
                          default_limits=["10/minute"],
                          headers_enabled=True,
                          key_func=get_remote_address)

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                headers = resp.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(headers.get('X-Reset'),
                                 str(int(time.time() + 50)))
예제 #7
0
    def test_custom_headers_from_setter(self):
        """Header names can be remapped on the limiter after construction.

        Time is frozen at the epoch (``freeze(0)``), so the expected
        ``Retry-After`` value is a fixed HTTP date. That header is rendered
        as a date because the limiter is built with
        ``retry_after='http-date'``.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          default_limits=["10/minute"],
                          headers_enabled=True,
                          key_func=get_remote_address,
                          retry_after='http-date')
        for header_kind, header_name in ((HEADERS.RESET, 'X-Reset'),
                                         (HEADERS.LIMIT, 'X-Limit'),
                                         (HEADERS.REMAINING, 'X-Remaining')):
            limiter._header_mapping[header_kind] = header_name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze(0) as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                headers = resp.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(headers.get('X-Reset'),
                                 str(int(time.time() + 50)))
                self.assertEqual(headers.get('Retry-After'),
                                 'Thu, 01 Jan 1970 00:01:01 GMT')
예제 #8
0
    def test_headers_no_breach(self):
        """Rate-limit headers are attached even when no limit is exceeded.

        ``/t1`` reports the ``10/minute`` global limit. ``/t2`` reports
        ``2`` with a 1-second reset, i.e. its ``2/second`` limit.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          global_limits=["10/minute"],
                          headers_enabled=True)

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("2/second; 5 per minute; 10/hour")
        def t2():
            return "test"

        def check(response, limit, remaining, reset_after):
            """Assert the three X-RateLimit-* headers of *response*."""
            headers = response.headers
            self.assertEqual(headers.get('X-RateLimit-Limit'), limit)
            self.assertEqual(headers.get('X-RateLimit-Remaining'),
                             remaining)
            self.assertEqual(headers.get('X-RateLimit-Reset'),
                             str(int(time.time() + reset_after)))

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                check(cli.get("/t1"), '10', '9', 60)
                check(cli.get("/t2"), '2', '1', 1)
예제 #9
0
    def test_invalid_decorated_dynamic_limits(self):
        """An invalid dynamic route limit is logged and the default applies.

        The route limit is a callable that reads ``current_app.config``
        (``"2 per sec"``). The test expects two "failed to load ratelimit"
        records and one "exceeded at endpoint" record. It also expects the
        ``1/second`` default limit to reject the second request.
        """
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app,
                          default_limits=["1/second"],
                          key_func=get_remote_address)
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        # logging passes the LogRecord as handle()'s first positional arg.
        messages = [entry[0][0].msg
                    for entry in mock_handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
예제 #10
0
    def test_headers_breach(self):
        """Headers describe the breached limit once a request is rejected.

        Eleven requests one second apart exhaust ``10 per minute``. The
        final response reports that limit with no remaining hits and a
        49-second wait.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          global_limits=["10/minute"],
                          headers_enabled=True,
                          key_func=get_remote_address)

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                headers = resp.headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
                self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
                self.assertEqual(headers.get('X-RateLimit-Reset'),
                                 str(int(time.time() + 49)))
                self.assertEqual(headers.get('Retry-After'), '49')
예제 #11
0
    def test_conditional_limits(self):
        """``exempt_when`` is consulted on every request to a limited route.

        ``/limited`` is always limited and ``/unlimited`` is always exempt.
        ``/conditional`` follows the current value of ``is_exempt``, which
        the test flips between requests.
        """
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.limit("1 per day")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.limit("1 per day", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        is_exempt = False

        def currently_exempt():
            # Closure over the local, so later reassignments are observed.
            return is_exempt

        @app.route("/conditional")
        @limiter.limit("1 per day", exempt_when=currently_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            def status(path):
                return cli.get(path).status_code

            self.assertEqual(status("/limited"), 200)
            self.assertEqual(status("/limited"), 429)

            self.assertEqual(status("/unlimited"), 200)
            self.assertEqual(status("/unlimited"), 200)

            self.assertEqual(status("/conditional"), 200)
            self.assertEqual(status("/conditional"), 429)
            is_exempt = True
            self.assertEqual(status("/conditional"), 200)
            is_exempt = False
            self.assertEqual(status("/conditional"), 429)
예제 #12
0
 def test_constructor_arguments_over_config(self):
     """Explicit Limiter constructor arguments win over app config values.

     The config requests ``fixed-window-elastic-expiry`` and a Redis
     storage URL. Limiters built with ``strategy=`` / ``storage_uri=`` are
     expected to use their own arguments instead.
     """
     app = Flask(__name__)
     app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")

     strategy_limiter = Limiter(strategy='moving-window')
     strategy_limiter.init_app(app)
     # Added only after the first init_app, so only the second limiter
     # below can see this storage URL.
     app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
     self.assertIs(type(strategy_limiter.limiter), MovingWindowRateLimiter)

     storage_limiter = Limiter(storage_uri='memcached://localhost:11211')
     storage_limiter.init_app(app)
     self.assertIs(type(storage_limiter.storage), MemcachedStorage)
예제 #13
0
 def build_app(self, config=None, **limiter_args):
     """Create a Flask app with a Limiter whose logger has a mock handler.

     :param config: optional mapping of config keys to values. Each entry
         is applied with ``app.config.setdefault``, so keys already present
         in ``app.config`` are not overwritten.
     :param limiter_args: keyword arguments forwarded to ``Limiter``.
     :return: an ``(app, limiter)`` tuple.
     """
     # A ``None`` default avoids one ``{}`` object being shared by every call.
     if config is None:
         config = {}
     app = Flask(__name__)
     for k, v in config.items():
         app.config.setdefault(k, v)
     limiter = Limiter(app, **limiter_args)
     # The mock records calls made to it by the limiter's logger; it is
     # reachable afterwards through ``limiter.logger.handlers``.
     mock_handler = mock.Mock()
     mock_handler.level = logging.INFO
     limiter.logger.addHandler(mock_handler)
     return app, limiter
예제 #14
0
    def test_invalid_decorated_static_limit_blueprint(self):
        """An invalid static limit on a blueprint is logged and ignored.

        ``"2/sec"`` is applied to the whole blueprint. The test expects a
        "failed to configure" record for it. It also expects the
        ``1/second`` global limit to reject the second request made within
        the same frozen second, which produces an "exceeded at endpoint"
        record.
        """
        app = Flask(__name__)
        limiter = Limiter(app, global_limits=["1/second"])
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # logging passes the LogRecord as handle()'s first positional arg.
        calls = mock_handler.handle.call_args_list
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)
예제 #15
0
    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """An invalid dynamic limit on a blueprint is logged, not enforced.

        The blueprint limit is a callable that reads ``current_app.config``
        (``"2 per sec"``). The test expects two "failed to load ratelimit"
        records and one "exceeded at endpoint" record. It also expects the
        ``1/second`` global limit to reject the second request.
        """
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app, global_limits=["1/second"],
                          key_func=get_remote_address)
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        self.assertEqual(mock_handler.handle.call_count, 3)
        # logging passes the LogRecord as handle()'s first positional arg.
        messages = [entry[0][0].msg
                    for entry in mock_handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
예제 #16
0
 def test_constructor_arguments_over_config(self):
     """Constructor-supplied strategy and storage override app config.

     App config asks for ``fixed-window-elastic-expiry`` and Redis. The
     limiters below pass ``strategy=`` / ``storage_uri=`` explicitly, and
     the test checks that those arguments are the ones used.
     """
     app = Flask(__name__)
     app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")

     moving_window = Limiter(strategy='moving-window')
     moving_window.init_app(app)
     # Set after the first init_app: only the next limiter can observe it.
     app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
     self.assertIs(type(moving_window.limiter), MovingWindowRateLimiter)

     memcached = Limiter(storage_uri='memcached://localhost:11211')
     memcached.init_app(app)
     self.assertIs(type(memcached.storage), MemcachedStorage)
예제 #17
0
    def test_logging(self):
        """Exceeding a limit emits exactly one record on the limiter logger."""
        app = Flask(__name__)
        limiter = Limiter(app)
        recorder = mock.Mock()
        recorder.level = logging.INFO
        limiter.logger.addHandler(recorder)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "test"

        with app.test_client() as client:
            first = client.get("/t1")
            self.assertEqual(200, first.status_code)
            second = client.get("/t1")
            self.assertEqual(429, second.status_code)
        self.assertEqual(recorder.handle.call_count, 1)
예제 #18
0
    def test_retry_after(self):
        """Waiting for the advertised Retry-After lets the next request pass.

        The ``1/minute`` default limit is used up by the first request. The
        test advances frozen time by the ``Retry-After`` value from that
        response and expects the follow-up request to succeed.
        """
        app = Flask(__name__)
        _ = Limiter(app,
                    default_limits=["1/minute"],
                    headers_enabled=True,
                    key_func=get_remote_address)

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                first = cli.get("/t1")
                wait_seconds = int(first.headers.get('Retry-After'))
                self.assertGreater(wait_seconds, 0)
                timeline.forward(wait_seconds)
                self.assertEqual(cli.get("/t1").status_code, 200)
예제 #19
0
    def test_reuse_logging(self):
        """Handlers copied from ``app.logger`` receive limiter log records.

        A mock handler is put on the app logger, and every app-logger
        handler is also attached to the limiter logger. One breach of the
        ``1/minute`` limit should reach the mock exactly once.
        """
        app = Flask(__name__)
        shared_handler = mock.Mock()
        shared_handler.level = logging.INFO
        app.logger.addHandler(shared_handler)
        limiter = Limiter(app)
        for existing_handler in app.logger.handlers:
            limiter.logger.addHandler(existing_handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "42"

        with app.test_client() as client:
            for _ in range(2):
                client.get("/t1")

        self.assertEqual(shared_handler.handle.call_count, 1)
예제 #20
0
    def test_invalid_decorated_static_limits(self):
        """An invalid static route limit is logged and the global limit applies.

        ``"2/sec"`` is decorated onto ``/t1``. The test expects a "failed to
        configure" record for it. It also expects the ``1/second`` global
        limit to reject the second request in the same frozen second, which
        produces an "exceeded at endpoint" record.
        """
        app = Flask(__name__)
        limiter = Limiter(app, global_limits=["1/second"])
        # Record every call the limiter's logger makes to this handler.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as cli:
            # Freeze time so both requests land in the same 1-second window.
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # logging passes the LogRecord as handle()'s first positional arg.
        calls = mock_handler.handle.call_args_list
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)
예제 #21
0
    def test_conditional_shared_limits(self):
        """Shared limits honour ``exempt_when`` on a per-route basis.

        All three routes share the ``test_scope`` budget of ``1 per day``.
        Exempt requests do not consume that budget. Once ``/limited`` has
        used it up, only an exempt ``/conditional`` request still succeeds.
        """
        app = Flask(__name__)
        limiter = Limiter(app)
        scope = "test_scope"

        @app.route("/limited")
        @limiter.shared_limit("1 per day", scope)
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.shared_limit("1 per day", scope, exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        is_exempt = False

        def currently_exempt():
            # Closure over the local, so later reassignments are observed.
            return is_exempt

        @app.route("/conditional")
        @limiter.shared_limit("1 per day", scope,
                              exempt_when=currently_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            def status(path):
                return cli.get(path).status_code

            self.assertEqual(status("/unlimited"), 200)
            self.assertEqual(status("/unlimited"), 200)

            self.assertEqual(status("/limited"), 200)
            self.assertEqual(status("/limited"), 429)

            self.assertEqual(status("/conditional"), 429)
            is_exempt = True
            self.assertEqual(status("/conditional"), 200)
            is_exempt = False
            self.assertEqual(status("/conditional"), 429)
예제 #22
0
    def test_custom_headers_from_setter(self):
        """Header names can be remapped through ``limiter.header_mapping``.

        After eleven requests one second apart, the per-minute limit is
        reported under the remapped ``X-*`` names with a 49-second reset.
        """
        app = Flask(__name__)
        limiter = Limiter(app,
                          global_limits=["10/minute"],
                          headers_enabled=True)
        renames = ((HEADERS.RESET, 'X-Reset'),
                   (HEADERS.LIMIT, 'X-Limit'),
                   (HEADERS.REMAINING, 'X-Remaining'))
        for header_kind, header_name in renames:
            limiter.header_mapping[header_kind] = header_name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                headers = resp.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(headers.get('X-Reset'),
                                 str(int(time.time() + 49)))