def __init__(self, app=None, key_func=get_ipaddr, global_limits=None,
             headers_enabled=False, strategy=None, storage_uri=None,
             storage_options=None, auto_check=True, swallow_errors=False,
             in_memory_fallback=None):
    """
    Create the rate limiter extension.

    :param app: optional :class:`flask.Flask` application; when given,
     :meth:`init_app` is called right away.
    :param key_func: callable returning the rate-limit key for a request.
    :param global_limits: iterable of limit strings (each parsed with
     ``parse_many``) applied as global limits.
    :param headers_enabled: initial value of the headers-enabled flag.
    :param strategy: rate limiting strategy name, stored for ``init_app``.
    :param storage_uri: storage URI, stored for ``init_app``.
    :param storage_options: extra keyword arguments for the storage. The
     mapping is copied so later merges with app config never modify the
     caller's dict or a dict shared between limiter instances.
    :param auto_check: whether ``init_app`` should hook the request check.
    :param swallow_errors: initial value of the swallow-errors flag.
    :param in_memory_fallback: iterable of limit strings used while the
     primary storage is unreachable.
    """
    self.app = app
    self.logger = logging.getLogger("flask-limiter")
    self.enabled = True
    self._global_limits = []
    self._in_memory_fallback = []
    self._exempt_routes = set()
    self._request_filters = []
    self._headers_enabled = headers_enabled
    self._header_mapping = {}
    self._strategy = strategy
    self._storage_uri = storage_uri
    # Private copy: init_app() merges app config into this dict.
    self._storage_options = dict(storage_options or {})
    self._auto_check = auto_check
    self._swallow_errors = swallow_errors
    for limit_string in global_limits or ():
        self._global_limits.extend(
            ExtLimit(limit, key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    for limit_string in in_memory_fallback or ():
        self._in_memory_fallback.extend(
            ExtLimit(limit, key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    self._route_limits = {}
    self._dynamic_route_limits = {}
    self._blueprint_limits = {}
    self._blueprint_dynamic_limits = {}
    self._blueprint_exempt = set()
    self._storage = self._limiter = None
    self._key_func = key_func
    self._storage_dead = False
    self._fallback_limiter = None
    self.__check_backend_count = 0
    self.__last_check_backend = time.time()

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def init_app(self, app):
    """
    :param app: :class:`flask.Flask` instance to rate limit.

    Reads settings from ``app.config`` (constructor values take precedence
    where they are checked first), creates the storage and strategy,
    registers the request hooks and stores this extension in
    ``app.extensions['limiter']``.

    :raises ConfigurationError: if the configured strategy is unknown.
    """
    self.enabled = app.config.setdefault(C.ENABLED, True)
    self._swallow_errors = app.config.setdefault(
        C.SWALLOW_ERRORS, self._swallow_errors
    )
    self._headers_enabled = (
        self._headers_enabled
        or app.config.setdefault(C.HEADERS_ENABLED, False)
    )
    # Merge into a fresh dict rather than updating in place: the stored
    # dict may be the caller's (or a shared default), and mutating it would
    # leak this app's options into other limiters.
    storage_options = dict(self._storage_options)
    storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
    self._storage_options = storage_options
    self._storage = storage_from_string(
        self._storage_uri
        or app.config.setdefault(C.STORAGE_URL, 'memory://'),
        **self._storage_options
    )
    strategy = (
        self._strategy
        or app.config.setdefault(C.STRATEGY, 'fixed-window')
    )
    if strategy not in STRATEGIES:
        raise ConfigurationError(
            "Invalid rate limiting strategy %s" % strategy)
    self._limiter = STRATEGIES[strategy](self._storage)
    # Header names already present in the mapping win over app config.
    self._header_mapping.update({
        HEADERS.RESET: self._header_mapping.get(HEADERS.RESET, None)
        or app.config.setdefault(C.HEADER_RESET, "X-RateLimit-Reset"),
        HEADERS.REMAINING: self._header_mapping.get(HEADERS.REMAINING, None)
        or app.config.setdefault(C.HEADER_REMAINING, "X-RateLimit-Remaining"),
        HEADERS.LIMIT: self._header_mapping.get(HEADERS.LIMIT, None)
        or app.config.setdefault(C.HEADER_LIMIT, "X-RateLimit-Limit"),
    })

    # Config-provided limits are only used when none were given in code.
    conf_limits = app.config.get(C.GLOBAL_LIMITS, None)
    if not self._global_limits and conf_limits:
        self._global_limits = [
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(conf_limits)
        ]
    fallback_limits = app.config.get(C.IN_MEMORY_FALLBACK, None)
    if not self._in_memory_fallback and fallback_limits:
        self._in_memory_fallback = [
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(fallback_limits)
        ]
    if self._auto_check:
        app.before_request(self.__check_request_limit)
    app.after_request(self.__inject_headers)

    if self._in_memory_fallback:
        self._fallback_storage = MemoryStorage()
        self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)

    # purely for backward compatibility as stated in flask documentation
    if not hasattr(app, 'extensions'):
        app.extensions = {}  # pragma: no cover
    app.extensions['limiter'] = self
def _inner(obj):
    """
    Register ``limit_value`` for ``obj`` (a view function or Blueprint).

    A callable ``limit_value`` is stored as a dynamic limit; otherwise it is
    parsed into static limits. A parse failure is logged (naming the kind
    of object that failed) and no static limits are added. ``obj`` is
    returned unchanged.
    """
    func = key_func or self._key_func
    is_bp = isinstance(obj, Blueprint)
    name = obj.name if is_bp else "{}.{}".format(obj.__module__, obj.__name__)
    dynamic_limit, static_limits = None, []
    if callable(limit_value):
        dynamic_limit = ExtLimit(limit_value, func, _scope, per_method,
                                 methods, error_message, exempt_when)
    else:
        try:
            static_limits = [
                ExtLimit(limit, func, _scope, per_method, methods,
                         error_message, exempt_when)
                for limit in parse_many(limit_value)
            ]
        except ValueError as e:
            # Report blueprints as blueprints, not as view functions.
            self.logger.error("failed to configure {} {} ({})".format(
                "blueprint" if is_bp else "view function", name, e))
    if is_bp:
        dynamic_store = self._blueprint_dynamic_limits
        static_store = self._blueprint_limits
    else:
        dynamic_store = self._dynamic_route_limits
        static_store = self._route_limits
    if dynamic_limit:
        dynamic_store.setdefault(name, []).append(dynamic_limit)
    else:
        static_store.setdefault(name, []).extend(static_limits)
    return obj
def _inner(obj):
    """
    Record ``limit_value`` against the view ``obj`` and return a wrapper.

    Callable limits become a single dynamic limit; anything else is parsed
    into static limits, with parse errors logged and ignored.
    """
    resolved_key_func = key_func or self._key_func
    route_name = "{}.{}".format(obj.__module__, obj.__name__)
    shared_args = (resolved_key_func, _scope, per_method, methods,
                   error_message, exempt_when)
    dynamic_limit = None
    static_limits = []
    if not callable(limit_value):
        try:
            static_limits = [ExtLimit(item, *shared_args)
                             for item in parse_many(limit_value)]
        except ValueError as e:
            self.logger.error("failed to configure {} {} ({})".format(
                "view function", route_name, e))
    else:
        dynamic_limit = ExtLimit(limit_value, *shared_args)

    @wraps(obj)
    def __inner(*a, **k):
        return obj(*a, **k)

    if not dynamic_limit:
        self._route_limits.setdefault(route_name, []).extend(static_limits)
    else:
        self._dynamic_route_limits.setdefault(route_name, []).append(
            dynamic_limit)
    return __inner
def init_app(self, app):
    """
    :param app: :class:`sanic.Sanic` instance to rate limit.

    Reads settings from ``app.config``, creates the storage and strategy,
    and appends the request check to ``app.request_middleware``.

    :raises ConfigurationError: if the configured strategy is unknown.
    """
    self.app = app
    self.enabled = app.config.setdefault(C.ENABLED, True)
    self._swallow_errors = app.config.setdefault(
        C.SWALLOW_ERRORS, self._swallow_errors)
    # Merge into a fresh dict: updating the stored one in place would
    # modify the caller's dict or a default shared between instances.
    storage_options = dict(self._storage_options)
    storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
    self._storage_options = storage_options
    self._storage = storage_from_string(
        self._storage_uri or app.config.setdefault(C.STORAGE_URL, 'memory://'),
        **self._storage_options)
    strategy = self._strategy or app.config.setdefault(
        C.STRATEGY, 'fixed-window')
    if strategy not in STRATEGIES:
        raise ConfigurationError(
            "Invalid rate limiting strategy %s" % strategy)
    self._limiter = STRATEGIES[strategy](self._storage)
    # Config-provided limits are only used when none were given in code.
    conf_limits = app.config.get(C.GLOBAL_LIMITS, None)
    if not self._global_limits and conf_limits:
        self._global_limits = [
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(conf_limits)
        ]
    app.request_middleware.append(self.__check_request_limit)
def __init__(self, app=None, key_func=get_ipaddr, global_limits=None,
             headers_enabled=False, strategy=None, storage_uri=None,
             storage_options=None, auto_check=True, swallow_errors=False,
             outside_routes=None):
    """
    Create the rate limiter extension.

    :param app: optional application; when given, :meth:`init_app` is
     called right away.
    :param key_func: callable returning the rate-limit key.
    :param global_limits: iterable of limit strings parsed with
     ``parse_many`` into global limits.
    :param storage_options: extra storage arguments; copied so the caller's
     dict (or a shared default) is never modified.
    :param outside_routes: set of route names to exempt. A caller-supplied
     set is used as-is; when omitted, each instance gets its own new set
     rather than one shared default.
    """
    self.app = app
    self.enabled = True
    self.global_limits = []
    self.exempt_routes = set() if outside_routes is None else outside_routes
    self.request_filters = []
    self.request_blockers = []
    self.headers_enabled = headers_enabled
    self.header_mapping = {}
    self.strategy = strategy
    self.storage_uri = storage_uri
    self.storage_options = dict(storage_options or {})
    self.auto_check = auto_check
    self.swallow_errors = swallow_errors
    for limit_string in global_limits or ():
        self.global_limits.extend(
            ExtLimit(limit, key_func, None, False, None, None)
            for limit in parse_many(limit_string)
        )
    self.route_limits = {}
    self.dynamic_route_limits = {}
    self.blueprint_limits = {}
    self.blueprint_dynamic_limits = {}
    self.blueprint_exempt = set()
    self.storage = self.limiter = None
    self.key_func = key_func
    self.logger = logging.getLogger("flask-limiter")

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def __inner(fn):
    """
    Wrap ``fn`` and attach ``limit_value`` to the wrapper in ``DECORATED``.

    Limits already registered for ``fn`` (by an earlier decorator) are moved
    onto the new wrapper so stacked decorators accumulate their limits.
    """
    @wraps(fn)
    def _inner(*args, **kwargs):
        return fn(*args, **kwargs)

    if fn in DECORATED:
        DECORATED.setdefault(_inner, DECORATED.pop(fn))
    registered = DECORATED.setdefault(_inner, [])
    if not callable(limit_value):
        new_wrapper = LimitWrapper(
            list(parse_many(limit_value)), key_function, scope)
    else:
        new_wrapper = LimitWrapper(limit_value, key_function, scope)
    registered.append(new_wrapper)
    return _inner
def __init__(self, app=None, key_func=get_ipaddr, global_limits=None,
             headers_enabled=False, strategy=None, storage_uri=None,
             storage_options=None, auto_check=True, swallow_errors=False):
    """
    Create the rate limiter extension.

    :param app: optional application; when given, :meth:`init_app` is
     called right away.
    :param key_func: callable returning the rate-limit key.
    :param global_limits: iterable of limit strings parsed with
     ``parse_many`` into global limits.
    :param storage_options: extra storage arguments; copied so the caller's
     dict (or a shared default) is never modified through this instance.
    """
    self.app = app
    self.enabled = True
    self.global_limits = []
    self.exempt_routes = set()
    self.request_filters = []
    self.headers_enabled = headers_enabled
    self.header_mapping = {}
    self.strategy = strategy
    self.storage_uri = storage_uri
    self.storage_options = dict(storage_options or {})
    self.auto_check = auto_check
    self.swallow_errors = swallow_errors
    for limit_string in global_limits or ():
        self.global_limits.extend(
            ExtLimit(limit, key_func, None, False, None, None)
            for limit in parse_many(limit_string)
        )
    self.route_limits = {}
    self.dynamic_route_limits = {}
    self.blueprint_limits = {}
    self.blueprint_dynamic_limits = {}
    self.blueprint_exempt = set()
    self.storage = self.limiter = None
    self.key_func = key_func
    self.logger = logging.getLogger("flask-limiter")

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def __init__(self):
    """
    Build the limiter from Django ``settings``.

    Reads global limits, strategy, storage URL, default key function,
    header names and the exceeded-callback from ``settings`` (each with a
    fallback). A callback given as a dotted string is imported; if it
    cannot be loaded the error is logged and rate limiting is disabled.

    :raises ConfigurationError: if the configured strategy is unknown.
    """
    conf_limits = getattr(settings, C.GLOBAL_LIMITS, "")
    callback = getattr(settings, C.CALLBACK, self.__raise_exceeded)
    self.enabled = getattr(settings, C.ENABLED, True)
    self.headers_enabled = getattr(settings, C.HEADERS_ENABLED, False)
    self.strategy = getattr(settings, C.STRATEGY, 'fixed-window')
    if self.strategy not in STRATEGIES:
        raise ConfigurationError(
            "Invalid rate limiting strategy %s" % self.strategy)
    self.storage = storage_from_string(
        getattr(settings, C.STORAGE_URL, "memory://"))
    self.limiter = STRATEGIES[self.strategy](self.storage)
    self.key_function = getattr(settings, C.DEFAULT_KEY_FUNCTION, get_ipaddr)
    self.global_limits = []
    if conf_limits:
        self.global_limits = [
            LimitWrapper(list(parse_many(conf_limits)), self.key_function,
                         None, False)
        ]
    self.header_mapping = {
        HEADERS.RESET: getattr(settings, C.HEADER_RESET, "X-RateLimit-Reset"),
        HEADERS.REMAINING: getattr(settings, C.HEADER_REMAINING,
                                   "X-RateLimit-Remaining"),
        HEADERS.LIMIT: getattr(settings, C.HEADER_LIMIT, "X-RateLimit-Limit"),
    }
    self.logger = logging.getLogger("djlimiter")
    self.logger.addHandler(BlackHoleHandler())
    if isinstance(callback, six.string_types):
        mod, _, name = callback.rpartition(".")
        try:
            self.callback = getattr(importlib.import_module(mod), name)
        except (ImportError, AttributeError, ValueError):
            # ImportError: module missing; ValueError: a name without a dot
            # yields an empty module name; AttributeError: name not in module.
            self.logger.error(
                "Unable to load callback function %s. Rate limiting disabled",
                callback)
            self.enabled = False
    else:
        self.callback = callback
def _inner(obj):
    """
    Register ``limit_value`` for ``obj`` (a view function or Blueprint).

    View functions are wrapped with ``functools.wraps`` and the wrapper is
    returned. Blueprints are returned unchanged rather than ``None``, so
    ``bp = limit(...)(bp)`` also works. Static limit strings that fail to
    parse are logged and ignored.
    """
    func = key_func or self._key_func
    is_route = not isinstance(obj, Blueprint)
    name = "%s.%s" % (obj.__module__, obj.__name__) if is_route else obj.name
    dynamic_limit, static_limits = None, []
    if callable(limit_value):
        dynamic_limit = ExtLimit(limit_value, func, _scope, per_method,
                                 methods, error_message, exempt_when)
    else:
        try:
            static_limits = [
                ExtLimit(limit, func, _scope, per_method, methods,
                         error_message, exempt_when)
                for limit in parse_many(limit_value)
            ]
        except ValueError as e:
            self.logger.error(
                "failed to configure %s %s (%s)",
                "view function" if is_route else "blueprint", name, e
            )
    if not is_route:
        if dynamic_limit:
            self._blueprint_dynamic_limits.setdefault(name, []).append(
                dynamic_limit)
        else:
            self._blueprint_limits.setdefault(name, []).extend(static_limits)
        return obj

    @wraps(obj)
    def __inner(*a, **k):
        return obj(*a, **k)

    if dynamic_limit:
        self._dynamic_route_limits.setdefault(name, []).append(dynamic_limit)
    else:
        self._route_limits.setdefault(name, []).extend(static_limits)
    return __inner
def __init__(self, app=None, key_func=None, global_limits=None,
             strategy=None, storage_uri=None, storage_options=None,
             swallow_errors=False):
    """
    Create the Sanic rate limiter extension.

    :param app: optional Sanic application; when given, :meth:`init_app`
     is called right away.
    :param key_func: callable returning the rate-limit key; defaults to
     ``get_remote_address``.
    :param global_limits: iterable of limit strings parsed with
     ``parse_many`` into global limits.
    :param storage_options: extra storage arguments; copied so that merging
     app config later never modifies the caller's dict or a shared default.
    """
    self.app = app
    self.logger = logging.getLogger("sanic-limiter")
    self.enabled = True
    self._global_limits = []
    self._exempt_routes = set()
    self._request_filters = []
    self._strategy = strategy
    self._storage_uri = storage_uri
    self._storage_options = dict(storage_options or {})
    self._swallow_errors = swallow_errors
    self._key_func = key_func or get_remote_address
    for limit_string in global_limits or ():
        self._global_limits.extend(
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    self._route_limits = {}
    self._dynamic_route_limits = {}
    self._blueprint_dynamic_limits = {}
    self._blueprint_limits = {}
    self._storage = None
    self._limiter = None
    self._storage_dead = False

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def __init__(self):
    """
    Build the limiter from Django ``settings``.

    Global limits, strategy, storage URL, default key function, header
    names and the exceeded-callback are read from ``settings``, each with a
    fallback. A dotted-string callback is imported; if loading fails for
    any reason the problem is logged and rate limiting is disabled.

    :raises ConfigurationError: if the configured strategy is unknown.
    """
    conf_limits = getattr(settings, C.GLOBAL_LIMITS, "")
    callback = getattr(settings, C.CALLBACK, self.__raise_exceeded)
    self.enabled = getattr(settings, C.ENABLED, True)
    self.headers_enabled = getattr(settings, C.HEADERS_ENABLED, False)
    self.strategy = getattr(settings, C.STRATEGY, 'fixed-window')
    if self.strategy not in STRATEGIES:
        raise ConfigurationError(
            "Invalid rate limiting strategy %s" % self.strategy)
    self.storage = storage_from_string(
        getattr(settings, C.STORAGE_URL, "memory://"))
    self.limiter = STRATEGIES[self.strategy](self.storage)
    self.key_function = getattr(settings, C.DEFAULT_KEY_FUNCTION, get_ipaddr)
    self.global_limits = []
    if conf_limits:
        self.global_limits = [
            LimitWrapper(
                list(parse_many(conf_limits)), self.key_function, None, False
            )
        ]
    self.header_mapping = {
        HEADERS.RESET: getattr(settings, C.HEADER_RESET, "X-RateLimit-Reset"),
        HEADERS.REMAINING: getattr(settings, C.HEADER_REMAINING,
                                   "X-RateLimit-Remaining"),
        HEADERS.LIMIT: getattr(settings, C.HEADER_LIMIT, "X-RateLimit-Limit"),
    }
    self.logger = logging.getLogger("djlimiter")
    self.logger.addHandler(BlackHoleHandler())
    if isinstance(callback, six.string_types):
        mod, _, name = callback.rpartition(".")
        try:
            self.callback = getattr(importlib.import_module(mod), name)
        except (ImportError, AttributeError, ValueError):
            # ImportError: module missing; ValueError: undotted name gives
            # an empty module name; AttributeError: name not in module.
            self.logger.error(
                "Unable to load callback function %s. Rate limiting disabled",
                callback
            )
            self.enabled = False
    else:
        self.callback = callback
def __check_request_limit(self):
    """
    Enforce the rate limits that apply to the current Flask request.

    Route limits (static and dynamic) are used; blueprint limits fill in
    only when the route has none of the respective kind; when neither
    exists the global limits apply. The tightest limit checked (or the one
    that failed) is stored on ``g.view_rate_limit`` as
    ``(limit, key, scope)``.

    :raises RateLimitExceeded: when a limit is hit. This is re-raised even
     when ``swallow_errors`` is set; only other errors are swallowed.
    """
    endpoint = request.endpoint or ""
    view_func = current_app.view_functions.get(endpoint, None)
    name = ("%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else "")
    if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self.exempt_routes
            or request.blueprint in self.blueprint_exempt
            or any(fn() for fn in self.request_filters)):
        return
    # Fresh list: blueprint limits may be appended below.
    limits = list(self.route_limits.get(name) or [])
    dynamic_limits = []
    if name in self.dynamic_route_limits:
        for lim in self.dynamic_route_limits[name]:
            try:
                dynamic_limits.extend(
                    ExtLimit(limit, lim.key_func, lim.scope, lim.per_method,
                             lim.methods, lim.error_message)
                    for limit in parse_many(lim.limit))
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for view function %s (%s)",
                    name, e)
    if request.blueprint:
        if (request.blueprint in self.blueprint_dynamic_limits
                and not dynamic_limits):
            for lim in self.blueprint_dynamic_limits[request.blueprint]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for blueprint %s (%s)",
                        request.blueprint, e)
        if request.blueprint in self.blueprint_limits and not limits:
            limits.extend(self.blueprint_limits[request.blueprint])

    failed_limit = None
    limit_for_header = None
    try:
        for lim in (limits + dynamic_limits or self.global_limits):
            limit_scope = lim.scope or endpoint
            # A method-restricted limit only skips itself; the remaining
            # limits must still be enforced for this request.
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            # Evaluate the key once so header, hit and log all agree.
            key = lim.key_func()
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = (lim.limit, key, limit_scope)
            if not self.limiter.hit(lim.limit, key, limit_scope):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, key, limit_scope)
                failed_limit = lim
                limit_for_header = (lim.limit, key, limit_scope)
                break
        g.view_rate_limit = limit_for_header
        if failed_limit:
            if failed_limit.error_message:
                exc_description = failed_limit.error_message if not callable(
                    failed_limit.error_message
                ) else failed_limit.error_message()
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
    except Exception as e:  # no qa
        # A genuine limit breach must never be hidden by swallow_errors.
        if isinstance(e, RateLimitExceeded):
            six.reraise(*sys.exc_info())
        if self.swallow_errors:
            self.logger.exception("Failed to rate limit. Swallowing error")
        else:
            six.reraise(*sys.exc_info())
def get_limits(self, request):
    """
    Return the limits this wrapper represents for ``request``.

    A callable ``self._limits`` is called with ``request`` and its result is
    parsed with ``parse_many`` into a list; any other value is returned as-is.
    """
    if not callable(self._limits):
        return self._limits
    return list(parse_many(self._limits(request)))
def __check_request_limit(self):
    """
    Enforce the rate limits that apply to the current Flask request.

    Route limits (static and dynamic) are used; blueprint limits fill in
    only when the route has none of the respective kind; otherwise the
    global limits apply. While the storage is marked dead and a fallback
    limiter exists, the in-memory fallback limits are used until the
    backend check reports recovery. The tightest limit checked (or the one
    that failed) is stored on ``g.view_rate_limit`` as
    ``(limit, key, scope)``.

    On a non rate-limit error the storage is marked dead and the check is
    retried once when fallback limits are configured; otherwise the error
    is swallowed or re-raised depending on ``_swallow_errors``.

    :raises RateLimitExceeded: when a limit is hit (never swallowed).
    """
    endpoint = request.endpoint or ""
    view_func = current_app.view_functions.get(endpoint, None)
    name = ("%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else "")
    if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)):
        return
    # Fresh list: blueprint limits may be appended below.
    limits = list(self._route_limits.get(name) or [])
    dynamic_limits = []
    if name in self._dynamic_route_limits:
        for lim in self._dynamic_route_limits[name]:
            try:
                dynamic_limits.extend(
                    ExtLimit(limit, lim.key_func, lim.scope, lim.per_method,
                             lim.methods, lim.error_message, lim.exempt_when)
                    for limit in parse_many(lim.limit))
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for view function %s (%s)",
                    name, e)
    if request.blueprint:
        if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits):
            for lim in self._blueprint_dynamic_limits[request.blueprint]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message, lim.exempt_when)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for blueprint %s (%s)",
                        request.blueprint, e)
        if request.blueprint in self._blueprint_limits and not limits:
            limits.extend(self._blueprint_limits[request.blueprint])

    failed_limit = None
    limit_for_header = None
    try:
        all_limits = []
        if self._storage_dead and self._fallback_limiter:
            if self.__should_check_backend() and self._storage.check():
                self.logger.info("Rate limit storage recovered")
                self._storage_dead = False
                self.__check_backend_count = 0
            else:
                all_limits = self._in_memory_fallback
        if not all_limits:
            all_limits = (limits + dynamic_limits or self._global_limits)
        for lim in all_limits:
            limit_scope = lim.scope or endpoint
            # Exempt or method-restricted limits skip only themselves; the
            # remaining limits must still be enforced for this request.
            if lim.is_exempt:
                continue
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            # Evaluate the key once so header, hit and log all agree.
            key = lim.key_func()
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = (lim.limit, key, limit_scope)
            if not self.limiter.hit(lim.limit, key, limit_scope):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, key, limit_scope)
                failed_limit = lim
                limit_for_header = (lim.limit, key, limit_scope)
                break
        g.view_rate_limit = limit_for_header
        if failed_limit:
            if failed_limit.error_message:
                exc_description = failed_limit.error_message if not callable(
                    failed_limit.error_message
                ) else failed_limit.error_message()
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
    except Exception as e:  # no qa
        if isinstance(e, RateLimitExceeded):
            six.reraise(*sys.exc_info())
        if self._in_memory_fallback and not self._storage_dead:
            # Logger.warn is a deprecated alias of Logger.warning.
            self.logger.warning(
                "Rate limit storage unreachable - falling back to"
                " in-memory storage")
            self._storage_dead = True
            self.__check_request_limit()
        else:
            if self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error")
            else:
                six.reraise(*sys.exc_info())
def __init__(self, app=None, key_func=None, global_limits=None,
             headers_enabled=False, strategy=None, storage_uri=None,
             storage_options=None, auto_check=True, swallow_errors=False,
             in_memory_fallback=None, retry_after=None):
    """
    Create the rate limiter extension.

    :param app: optional :class:`flask.Flask` application; when given,
     :meth:`init_app` is called right away.
    :param key_func: callable returning the rate-limit key. When omitted a
     ``UserWarning`` is emitted and ``get_ipaddr`` is used.
    :param global_limits: iterable of limit strings parsed with
     ``parse_many`` into global limits.
    :param storage_options: extra storage arguments; copied so that later
     merges with app config never modify the caller's dict or a shared
     default.
    :param in_memory_fallback: iterable of limit strings used while the
     primary storage is unreachable.
    :param retry_after: stored as ``_retry_after``.
    """
    self.app = app
    self.logger = logging.getLogger("flask-limiter")
    self.enabled = True
    self._global_limits = []
    self._in_memory_fallback = []
    self._exempt_routes = set()
    self._request_filters = []
    self._headers_enabled = headers_enabled
    self._header_mapping = {}
    self._retry_after = retry_after
    self._strategy = strategy
    self._storage_uri = storage_uri
    self._storage_options = dict(storage_options or {})
    self._auto_check = auto_check
    self._swallow_errors = swallow_errors
    if not key_func:
        # stacklevel=2 attributes the warning to the code constructing the
        # limiter rather than to this library line.
        warnings.warn(
            "Use of the default `get_ipaddr` function is discouraged."
            " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
            " for the recommended configuration",
            UserWarning,
            stacklevel=2
        )
    self._key_func = key_func or get_ipaddr
    for limit_string in global_limits or ():
        self._global_limits.extend(
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    for limit_string in in_memory_fallback or ():
        self._in_memory_fallback.extend(
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    self._route_limits = {}
    self._dynamic_route_limits = {}
    self._blueprint_limits = {}
    self._blueprint_dynamic_limits = {}
    self._blueprint_exempt = set()
    self._storage = self._limiter = None
    self._storage_dead = False
    self._fallback_limiter = None
    self.__check_backend_count = 0
    self.__last_check_backend = time.time()

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def test_parse_many_csv(self):
    """A comma-separated string yields one limit per entry, in order."""
    parsed = parse_many("1 per 3 hour, 1 per second")
    expected_expiries = (3 * 60 * 60, 1)
    self.assertEqual(len(parsed), len(expected_expiries))
    for position, seconds in enumerate(expected_expiries):
        self.assertEqual(parsed[position].get_expiry(), seconds)
def __check_request_limit(self):
    """
    Enforce the rate limits that apply to the current Flask request.

    Route limits (static and dynamic) are used; blueprint limits fill in
    only when the route has none of the respective kind; otherwise the
    global limits apply. The tightest limit checked (or the one that
    failed) is stored on ``g.view_rate_limit`` as ``(limit, key, scope)``.

    :raises RateLimitExceeded: when a limit is hit. This is re-raised even
     when ``swallow_errors`` is set; only other errors are swallowed.
    """
    endpoint = request.endpoint or ""
    view_func = current_app.view_functions.get(endpoint, None)
    name = ("%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else "")
    if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self.exempt_routes
            or request.blueprint in self.blueprint_exempt
            or any(fn() for fn in self.request_filters)):
        return
    # Fresh list: blueprint limits may be appended below.
    limits = list(self.route_limits.get(name) or [])
    dynamic_limits = []
    if name in self.dynamic_route_limits:
        for lim in self.dynamic_route_limits[name]:
            try:
                dynamic_limits.extend(
                    ExtLimit(limit, lim.key_func, lim.scope, lim.per_method,
                             lim.methods, lim.error_message)
                    for limit in parse_many(lim.limit))
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for view function %s (%s)",
                    name, e)
    if request.blueprint:
        if (request.blueprint in self.blueprint_dynamic_limits
                and not dynamic_limits):
            for lim in self.blueprint_dynamic_limits[request.blueprint]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for blueprint %s (%s)",
                        request.blueprint, e)
        if request.blueprint in self.blueprint_limits and not limits:
            limits.extend(self.blueprint_limits[request.blueprint])

    failed_limit = None
    limit_for_header = None
    try:
        for lim in (limits + dynamic_limits or self.global_limits):
            limit_scope = lim.scope or endpoint
            # A method-restricted limit only skips itself; the remaining
            # limits must still be enforced for this request.
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            # Evaluate the key once so header, hit and log all agree.
            key = lim.key_func()
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = (lim.limit, key, limit_scope)
            if not self.limiter.hit(lim.limit, key, limit_scope):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, key, limit_scope)
                failed_limit = lim
                limit_for_header = (lim.limit, key, limit_scope)
                break
        g.view_rate_limit = limit_for_header
        if failed_limit:
            if failed_limit.error_message:
                exc_description = failed_limit.error_message if not callable(
                    failed_limit.error_message
                ) else failed_limit.error_message()
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
    except Exception as e:  # no qa
        # A genuine limit breach must never be hidden by swallow_errors.
        if isinstance(e, RateLimitExceeded):
            six.reraise(*sys.exc_info())
        if self.swallow_errors:
            self.logger.exception("Failed to rate limit. Swallowing error")
        else:
            six.reraise(*sys.exc_info())
def test_parse_many(self):
    """A semicolon-separated string yields one limit per entry, in order."""
    three_hours = 3 * 60 * 60
    parsed = parse_many("1 per 3 hour; 1 per second")
    self.assertEqual(len(parsed), 2)
    first, last = parsed[0], parsed[-1]
    self.assertEqual(first.get_expiry(), three_hours)
    self.assertEqual(last.get_expiry(), 1)
def __check_request_limit(self):
    """
    Enforce the rate limits that apply to the current Flask request.

    Application-wide limits always apply. They are combined with the route
    limits (static and dynamic); blueprint limits fill in only when the
    route has none of the respective kind, and the default limits apply
    when neither exists. While the storage is marked dead and a fallback
    limiter exists, the in-memory fallback limits replace all of these until
    the backend check reports recovery. The tightest limit checked (or the
    one that failed) is stored on ``g.view_rate_limit`` as
    ``(limit, key, scope)``.

    On a non rate-limit error the storage is marked dead and the check is
    retried once when fallback limits are configured; otherwise the error
    is swallowed or re-raised depending on ``_swallow_errors``.

    :raises RateLimitExceeded: when a limit is hit (never swallowed).
    """
    endpoint = request.endpoint or ""
    view_func = current_app.view_functions.get(endpoint, None)
    name = ("%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else "")
    if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)):
        return
    # Fresh list: blueprint limits may be appended below.
    limits = list(self._route_limits.get(name) or [])
    dynamic_limits = []
    if name in self._dynamic_route_limits:
        for lim in self._dynamic_route_limits[name]:
            try:
                dynamic_limits.extend(
                    ExtLimit(limit, lim.key_func, lim.scope, lim.per_method,
                             lim.methods, lim.error_message, lim.exempt_when)
                    for limit in parse_many(lim.limit))
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for view function %s (%s)",
                    name, e)
    if request.blueprint:
        if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits):
            for lim in self._blueprint_dynamic_limits[request.blueprint]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message, lim.exempt_when)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for blueprint %s (%s)",
                        request.blueprint, e)
        if request.blueprint in self._blueprint_limits and not limits:
            limits.extend(self._blueprint_limits[request.blueprint])

    failed_limit = None
    limit_for_header = None
    try:
        all_limits = []
        if self._storage_dead and self._fallback_limiter:
            if self.__should_check_backend() and self._storage.check():
                self.logger.info("Rate limit storage recovered")
                self._storage_dead = False
                self.__check_backend_count = 0
            else:
                all_limits = self._in_memory_fallback
        if not all_limits:
            all_limits = self._application_limits + (
                limits + dynamic_limits or self._default_limits)
        for lim in all_limits:
            limit_scope = lim.scope or endpoint
            # Exempt or method-restricted limits skip only themselves; the
            # remaining limits must still be enforced for this request.
            if lim.is_exempt:
                continue
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            # Evaluate the key once so header, hit and log all agree.
            key = lim.key_func()
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = (lim.limit, key, limit_scope)
            if not self.limiter.hit(lim.limit, key, limit_scope):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, key, limit_scope)
                failed_limit = lim
                limit_for_header = (lim.limit, key, limit_scope)
                break
        g.view_rate_limit = limit_for_header
        if failed_limit:
            if failed_limit.error_message:
                exc_description = failed_limit.error_message if not callable(
                    failed_limit.error_message
                ) else failed_limit.error_message()
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
    except Exception as e:  # no qa
        if isinstance(e, RateLimitExceeded):
            six.reraise(*sys.exc_info())
        if self._in_memory_fallback and not self._storage_dead:
            # Logger.warn is a deprecated alias of Logger.warning.
            self.logger.warning(
                "Rate limit storage unreachable - falling back to"
                " in-memory storage")
            self._storage_dead = True
            self.__check_request_limit()
        else:
            if self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error")
            else:
                six.reraise(*sys.exc_info())
def get_limits(self, request):
    """
    Resolve the limits for ``request``.

    Non-callable ``self._limits`` values are returned unchanged; a callable
    is invoked with ``request`` and its output parsed via ``parse_many``.
    """
    resolved = self._limits
    if callable(resolved):
        resolved = list(parse_many(resolved(request)))
    return resolved
def __init__(self, app=None, key_func=None, global_limits=None,
             default_limits=None, application_limits=None,
             headers_enabled=False, strategy=None, storage_uri=None,
             storage_options=None, auto_check=True, swallow_errors=False,
             in_memory_fallback=None, retry_after=None):
    """
    Create the rate limiter extension.

    :param app: optional :class:`flask.Flask` application; when given,
     :meth:`init_app` is called right away.
    :param key_func: callable returning the rate-limit key. When omitted a
     ``UserWarning`` is emitted and ``get_ipaddr`` is used.
    :param global_limits: limit strings merged into ``default_limits``;
     passing any triggers ``raise_global_limits_warning``.
    :param default_limits: limit strings applied to routes without limits
     of their own. Duplicates across both lists are dropped, keeping the
     first occurrence so the resulting order is deterministic.
    :param application_limits: limit strings applied with the shared
     scope ``"global"``.
    :param storage_options: extra storage arguments; copied so later merges
     with app config never modify the caller's dict or a shared default.
    :param in_memory_fallback: limit strings used while the primary
     storage is unreachable.
    :param retry_after: stored as ``_retry_after``.
    """
    self.app = app
    self.logger = logging.getLogger("flask-limiter")
    self.enabled = True
    self._default_limits = []
    self._application_limits = []
    self._in_memory_fallback = []
    self._exempt_routes = set()
    self._request_filters = []
    self._headers_enabled = headers_enabled
    self._header_mapping = {}
    self._retry_after = retry_after
    self._strategy = strategy
    self._storage_uri = storage_uri
    self._storage_options = dict(storage_options or {})
    self._auto_check = auto_check
    self._swallow_errors = swallow_errors
    if not key_func:
        warnings.warn(
            "Use of the default `get_ipaddr` function is discouraged."
            " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
            " for the recommended configuration",
            UserWarning, stacklevel=2)
    if global_limits:
        self.raise_global_limits_warning()
    self._key_func = key_func or get_ipaddr
    # Order-preserving de-duplication; a plain set() would make the limit
    # order depend on string hashing and vary between runs.
    seen = set()
    for limit_string in list(global_limits or []) + list(default_limits or []):
        if limit_string in seen:
            continue
        seen.add(limit_string)
        self._default_limits.extend(
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    for limit_string in application_limits or ():
        self._application_limits.extend(
            ExtLimit(limit, self._key_func, "global", False, None, None, None)
            for limit in parse_many(limit_string)
        )
    for limit_string in in_memory_fallback or ():
        self._in_memory_fallback.extend(
            ExtLimit(limit, self._key_func, None, False, None, None, None)
            for limit in parse_many(limit_string)
        )
    self._route_limits = {}
    self._dynamic_route_limits = {}
    self._blueprint_limits = {}
    self._blueprint_dynamic_limits = {}
    self._blueprint_exempt = set()
    self._storage = self._limiter = None
    self._storage_dead = False
    self._fallback_limiter = None
    self.__check_backend_count = 0
    self.__last_check_backend = time.time()

    class BlackHoleHandler(logging.StreamHandler):
        # Discards every record; presumably attached so the named logger
        # always has at least one handler.
        def emit(*_):
            return

    self.logger.addHandler(BlackHoleHandler())
    if app:
        self.init_app(app)
def __check_request_limit(self, request):
    """
    Enforce the rate limits that apply to a Sanic ``request``.

    Only handlers found in ``self.app.router.routes_static`` are checked.
    Route limits (static and dynamic) are used; blueprint limits fill in
    only when the route has none of the respective kind; otherwise the
    global limits apply. A key function whose first parameter has no
    default is called with ``request``; otherwise it is called without
    arguments. A ``None`` key skips that limit.

    :raises RateLimitExceeded: when a limit is hit (never swallowed).
    """
    endpoint = request.path or ""
    view_handler = self.app.router.routes_static.get(endpoint, None)
    if view_handler is None:
        return
    view_func = view_handler.handler
    view_bpname = view_func.__dict__.get('__blueprintname__', None)
    name = ("{}.{}".format(view_func.__module__, view_func.__name__)
            if view_func else "")
    if (not endpoint or not self.enabled or name in self._exempt_routes
            or any(fn() for fn in self._request_filters)):
        return
    # Copy: blueprint limits may be appended below and must not be written
    # back into the stored route-limit list.
    limits = list(self._route_limits.get(name, []))
    dynamic_limits = []
    if name in self._dynamic_route_limits:
        for lim in self._dynamic_route_limits[name]:
            try:
                dynamic_limits.extend(
                    ExtLimit(limit, lim.key_func, lim.scope, lim.per_method,
                             lim.methods, lim.error_message, lim.exempt_when)
                    for limit in parse_many(lim.limit))
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for view function %s (%s)",
                    name, e)
    if view_bpname:
        if view_bpname in self._blueprint_dynamic_limits and not dynamic_limits:
            for lim in self._blueprint_dynamic_limits[view_bpname]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message, lim.exempt_when)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view blueprint %s (%s)",
                        view_bpname, e)
        if view_bpname in self._blueprint_limits and not limits:
            limits.extend(self._blueprint_limits[view_bpname])

    failed_limit = None
    try:
        for lim in (limits + dynamic_limits or self._global_limits):
            limit_scope = lim.scope or endpoint
            # Exempt or method-restricted limits skip only themselves; the
            # remaining limits must still be enforced for this request.
            if lim.is_exempt:
                continue
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            params = list(inspect.signature(lim.key_func).parameters.values())
            if params and params[0].default is inspect.Parameter.empty:
                key = lim.key_func(request)
            else:
                key = lim.key_func()
            if key is None:
                # Ignore empty result of the key function.
                continue
            if not self.limiter.hit(lim.limit, key, limit_scope):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, key, limit_scope)
                failed_limit = lim
                break
        if failed_limit:
            if failed_limit.error_message:
                exc_description = failed_limit.error_message if not callable(
                    failed_limit.error_message
                ) else failed_limit.error_message()
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
    except Exception as e:  # no qa
        if isinstance(e, RateLimitExceeded):
            six.reraise(*sys.exc_info())
        if self._swallow_errors:
            self.logger.exception("Failed to rate limit. Swallowing error")
        else:
            six.reraise(*sys.exc_info())