def test_rate_limiting_custom_backend(self):
        """Test rate limiter with a custom backend."""
        # Intentionally set a bad host
        self.app.config.setdefault('RATELIMITER_BACKEND_OPTIONS',
                                   {'host': 'invalid-host'})
        # Generate a good Redis client and attached backend
        redis = StrictRedis(host=environ.get('REDIS_HOST', 'localhost'))
        SimpleRedisBackend = RateLimiter.get_backend('SimpleRedisBackend')
        redis_backend = SimpleRedisBackend(cache=redis)
        RateLimiter(self.app, backend=redis_backend)

        @self.app.route('/custom-backend-limit')
        @ratelimit(1, 10)
        def test_custom_backend_limit():
            return 'custom-backend-limit'

        with self.app.test_client() as c:
            res = c.get('/custom-backend-limit')
            self.assertEqual(request.endpoint, 'test_custom_backend_limit')
            self.assertEqual(g._rate_limit_info.remaining, 0)
            self.assertEqual(g._rate_limit_info.limit_exceeded, False)
            self.assertEqual(res.status_code, 200)

            res = c.get('/custom-backend-limit')
            self.assertEqual(g._rate_limit_info.remaining, 0)
            self.assertEqual(g._rate_limit_info.limit_exceeded, True)
            self.assertEqual(res.status_code, 429)
Exemple #2
0
    def test_backend_with_app(self):
        """Test simple redis backend with app."""
        cache = Cache(self.app,
                      config={
                          'CACHE_TYPE':
                          'redis',
                          'CACHE_REDIS_HOST':
                          environ.get('REDIS_HOST', 'localhost')
                      })

        self.app.config.setdefault('RATELIMITER_BACKEND',
                                   'FlaskCacheRedisBackend')
        self.app.config.setdefault(
            'RATELIMITER_BACKEND_OPTIONS', {
                'cache': cache,
                'host': environ.get('REDIS_HOST', 'localhost')
            })
        r = RateLimiter(self.app)

        limit_exceeded, remaining, reset = r.backend.update(
            'flask_cache_backend', 3, 5)
        self.assertEqual(limit_exceeded, False)
        self.assertEqual(remaining, 2)

        limit_exceeded, remaining, reset = r.backend.update(
            'flask_cache_backend', 3, 5)
        self.assertEqual(limit_exceeded, False)
        self.assertEqual(remaining, 1)

        limit_exceeded, remaining, reset = r.backend.update(
            'flask_cache_backend', 3, 5)
        self.assertEqual(limit_exceeded, True)
        self.assertEqual(remaining, 0)
    def test_flask_cache_prefix(self):
        """Test rate limiter with Flask-Cache Redis backend."""
        cache = Cache(self.app, config={'CACHE_TYPE': 'redis'})
        prefix = 'flask_cache_prefix'

        self.app.config.setdefault('CACHE_KEY_PREFIX', prefix)
        self.app.config.setdefault('RATELIMITER_BACKEND',
                                   'FlaskCacheRedisBackend')
        self.app.config.setdefault('RATELIMITER_BACKEND_OPTIONS',
                                   {'cache': cache})
        RateLimiter(self.app)
        self.assertEqual(self.app.config['RATELIMITER_KEY_PREFIX'], prefix)
    def test_ratelimit_headers(self):
        """Test rate limiter backend options."""
        self.app.config.setdefault('RATELIMITER_BACKEND_OPTIONS',
                                   {'host': environ.get('REDIS_HOST',
                                                        'localhost')})
        RateLimiter(self.app)

        @self.app.route('/limit3')
        @ratelimit(2, 10)
        def test_limit3():
            return 'limit'

        with self.app.test_client() as c:
            res = c.get('/limit3')
            self.assertEqual(request.endpoint, 'test_limit3')
            self.assertEqual(int(res.headers.get('X-RateLimit-Limit')), 2)
            self.assertEqual(int(res.headers.get('X-RateLimit-Remaining')), 1)
            self.assertIsNot(res.headers.get('X-RateLimit-Reset', None), None)
    def test_rate_limiting_simple_redis_backend(self):
        """Test rate limiter with simple Redis backend."""
        self.app.config.setdefault('RATELIMITER_BACKEND_OPTIONS',
                                   {'host': environ.get('REDIS_HOST',
                                                        'localhost')})
        RateLimiter(self.app)

        @self.app.route('/limit')
        @ratelimit(2, 10)
        def test_limit():
            return 'limit'

        with self.app.test_client() as c:
            res = c.get('/limit')
            self.assertEqual(request.endpoint, 'test_limit')
            self.assertEqual(g._rate_limit_info.limit, 2)
            self.assertEqual(g._rate_limit_info.remaining, 1)
            self.assertEqual(g._rate_limit_info.limit_exceeded, False)
            self.assertEqual(g._rate_limit_info.per, 10)
            self.assertEqual(g._rate_limit_info.send_x_headers, True)
            self.assertEqual(res.status_code, 200)

            res = c.get('/limit')
            self.assertEqual(g._rate_limit_info.limit, 2)
            self.assertEqual(g._rate_limit_info.remaining, 0)
            self.assertEqual(g._rate_limit_info.limit_exceeded, False)
            self.assertEqual(g._rate_limit_info.per, 10)
            self.assertEqual(g._rate_limit_info.send_x_headers, True)
            self.assertEqual(res.status_code, 200)

            res = c.get('/limit')
            self.assertEqual(g._rate_limit_info.limit, 2)
            self.assertEqual(g._rate_limit_info.remaining, 0)
            self.assertEqual(g._rate_limit_info.limit_exceeded, True)
            self.assertEqual(g._rate_limit_info.per, 10)
            self.assertEqual(g._rate_limit_info.send_x_headers, True)
            self.assertIn(six.b('Rate limit was exceeded'), res.data)
            self.assertEqual(res.status_code, 429)

            res = c.get('/limit')
            self.assertEqual(g._rate_limit_info.limit, 2)
            self.assertEqual(g._rate_limit_info.remaining, 0)
            self.assertEqual(g._rate_limit_info.limit_exceeded, True)
            self.assertEqual(res.status_code, 429)
    def test_dynamic_ratelimit(self):
        """Test that passing callable limit to @ratelimit behaves correctly"""
        self.app.config.setdefault('RATELIMITER_BACKEND_OPTIONS',
                                   {'host': environ.get('REDIS_HOST',
                                                        'localhost')})
        RateLimiter(self.app)

        def dynamic_limit(default, **kwargs):
            """
            A function that returns a rate limit based on some incoming
            information
            The function may be passed key_func() and scope_func() as kwargs to
            help to do that, but in this test we directly use the http header
            to test functionality
            """
            self.assertIn('key', kwargs)
            self.assertIn('scope', kwargs)
            factor = request.headers.get('multiply_ratelimit')
            if factor is None:
                return default
            return default * int(factor)

        @self.app.route('/dynamic-limit')
        @ratelimit(lambda **kwargs: dynamic_limit(default=2, **kwargs), 10)
        def test_dynamic_limit():
            return 'limit'

        with self.app.test_client() as c:
            res = c.get('/dynamic-limit')
            self.assertEqual(g._rate_limit_info.limit, 2)
            self.assertEqual(g._rate_limit_info.remaining, 1)
            self.assertEqual(g._rate_limit_info.limit_exceeded, False)
            self.assertEqual(g._rate_limit_info.per, 10)
            self.assertEqual(g._rate_limit_info.send_x_headers, True)
            self.assertEqual(res.status_code, 200)

            res = c.get('/dynamic-limit', headers={'multiply_ratelimit': 10})
            self.assertEqual(g._rate_limit_info.limit, 20)
            self.assertEqual(g._rate_limit_info.remaining, 18)
            self.assertEqual(g._rate_limit_info.limit_exceeded, False)
            self.assertEqual(g._rate_limit_info.per, 10)
            self.assertEqual(g._rate_limit_info.send_x_headers, True)
            self.assertEqual(res.status_code, 200)
    def test_ratelimit_no_headers_sent(self):
        """Test rate limiter header sending."""
        self.app.config.setdefault('RATELIMITER_INJECT_X_HEADERS', False)
        self.app.config.setdefault(
            'RATELIMITER_BACKEND_OPTIONS',
            {'host': environ.get('REDIS_HOST', 'localhost')})
        RateLimiter(self.app)

        @self.app.route('/limit4')
        @ratelimit(2, 10)
        def test_limit4():
            return 'limit'

        with self.app.test_client() as c:
            res = c.get('/limit4')
            self.assertEqual(request.endpoint, 'test_limit4')
            self.assertEqual(res.headers.get('X-RateLimit-Limit', None), None)
            self.assertEqual(res.headers.get('X-RateLimit-Remaining', None),
                             None)
 def test_get_incorrect_backend(self):
     """Test get_backend function with incorrect input."""
     self.assertEqual(RateLimiter.get_backend('CrazyBackendX').__name__,
                      'SimpleRedisBackend')
 def test_get_correct_backend(self):
     """Test get_backend function with correct input."""
     backend = RateLimiter.get_backend('SimpleRedisBackend')
     self.assertEqual(backend.__name__, 'SimpleRedisBackend')