def test_invalid_config(self):
    """Querying a config built from malformed overrides yields the group default.

    The override map holds a non-category key/value pair under ``GET``; asking
    for an unknown method/category pair must return the ``default`` group's
    ORGANIZATION limit.
    """
    malformed = {"GET": {"invalid": "invalid"}}
    config = RateLimitConfig(group="default", limit_overrides=malformed)  # type: ignore
    actual = config.get_rate_limit("bloop", "badcategory")
    expected = get_default_rate_limits_for_group("default", RateLimitCategory.ORGANIZATION)
    assert actual == expected
def test_override(self):
    """Only the overridden (GET, IP) pair changes; other pairs use group defaults."""
    config = RateLimitConfig(
        group="default",
        limit_overrides={"GET": {RateLimitCategory.IP: RateLimit(1, 1)}},
    )
    assert config.get_rate_limit("GET", RateLimitCategory.IP) == RateLimit(1, 1)

    # Pairs that are not overridden must fall through to the group's defaults.
    fallback_pairs = (
        ("POST", RateLimitCategory.IP),
        ("GET", RateLimitCategory.ORGANIZATION),
    )
    for http_method, category in fallback_pairs:
        assert config.get_rate_limit(
            http_method, category
        ) == get_default_rate_limits_for_group("default", category)
def get_rate_limit_config(endpoint: Type[object]) -> RateLimitConfig | None:
    """Resolve the rate limit configuration to use when checking ``endpoint``.

    The endpoint's ``rate_limits`` attribute is used when present; otherwise
    ``DEFAULT_RATE_LIMIT_CONFIG`` is substituted. Whichever value is found is
    passed through ``RateLimitConfig.from_rate_limit_override_dict`` before
    being returned.
    """
    fallback = DEFAULT_RATE_LIMIT_CONFIG
    declared = getattr(endpoint, "rate_limits", fallback)
    return RateLimitConfig.from_rate_limit_override_dict(declared)
class RateLimitedEndpoint(Endpoint):
    # Test endpoint: every category on GET is configured with RateLimit(0, 1).
    # NOTE(review): presumably that means zero requests per window, so every GET
    # is expected to be throttled — confirm against RateLimit's definition.
    enforce_rate_limit = True
    permission_classes = (AllowAny,)
    rate_limits = RateLimitConfig(
        group="foo",
        limit_overrides={
            "GET": {
                RateLimitCategory.IP: RateLimit(0, 1),
                RateLimitCategory.USER: RateLimit(0, 1),
                RateLimitCategory.ORGANIZATION: RateLimit(0, 1),
            },
        },
    )

    def get(self, request):
        body = {"ok": True}
        return Response(body)
class ConcurrentRateLimitedEndpoint(Endpoint):
    # Test endpoint whose GET handler sleeps, so overlapping requests can be
    # exercised against the CONCURRENT_RATE_LIMIT setting below.
    enforce_rate_limit = True
    permission_classes = (AllowAny,)
    rate_limits = RateLimitConfig(
        group="foo",
        limit_overrides={
            # A fresh RateLimit instance per category, in IP/USER/ORGANIZATION order.
            "GET": {
                category: RateLimit(20, 1, CONCURRENT_RATE_LIMIT)
                for category in (
                    RateLimitCategory.IP,
                    RateLimitCategory.USER,
                    RateLimitCategory.ORGANIZATION,
                )
            },
        },
    )

    def get(self, request):
        # Hold the request open so concurrent calls overlap.
        sleep(CONCURRENT_ENDPOINT_DURATION)
        body = {"ok": True}
        return Response(body)
class ParentEndpoint(Endpoint):
    # Overrides the IP-category limit for GET requests with RateLimit(100, 5).
    rate_limits = RateLimitConfig(
        group="foo", limit_overrides={"GET": {RateLimitCategory.IP: RateLimit(100, 5)}}
    )


# ChildEndpoint must be defined after ParentEndpoint. A class's base list is
# evaluated when the ``class`` statement executes, so referencing the parent
# before it exists raises NameError at import time.
class ChildEndpoint(ParentEndpoint):
    # Shadows the inherited config with an empty GET override map.
    # NOTE(review): presumably this makes GET fall back to the "foo" group's
    # defaults rather than the parent's IP override — confirm.
    rate_limits = RateLimitConfig(group="foo", limit_overrides={"GET": {}})
def test_backwards_compatibility(self):
    """A legacy override dict converts to an equivalent ``default``-group config."""
    legacy_overrides = {"GET": {RateLimitCategory.IP: RateLimit(1, 1)}}
    converted = RateLimitConfig.from_rate_limit_override_dict(legacy_overrides)
    expected = RateLimitConfig(group="default", limit_overrides=legacy_overrides)
    assert converted == expected
def test_defaults(self):
    """An unconfigured RateLimitConfig yields a RateLimit for every category/method."""
    config = RateLimitConfig()
    http_methods = ("POST", "GET", "PUT", "DELETE")
    for category in RateLimitCategory:
        for http_method in http_methods:
            limit = config.get_rate_limit(http_method, category)
            assert isinstance(limit, RateLimit), (http_method, category)
def test_grouping(self, *m):
    """A config in the ``blz`` group picks up that group's ORGANIZATION limit."""
    # NOTE(review): ``*m`` presumably collects mocks injected by a decorator
    # outside this view (e.g. one defining the "blz" group limits) — confirm.
    config = RateLimitConfig(group="blz")
    org_limit = config.get_rate_limit("GET", RateLimitCategory.ORGANIZATION)
    assert org_limit == RateLimit(420, 69)