Ejemplo n.º 1
0
    def __init__(self, app=None
                 , key_func=get_ipaddr
                 , global_limits=None
                 , headers_enabled=False
                 , strategy=None
                 , storage_uri=None
                 , storage_options=None
                 , auto_check=True
                 , swallow_errors=False
                 , in_memory_fallback=None
    ):
        """Create the limiter and optionally bind it to ``app``.

        :param app: when truthy, passed to :meth:`init_app` immediately;
            otherwise ``init_app`` must be called later.
        :param key_func: callable producing the rate-limit key; attached to
            the global and in-memory fallback limits built here.
        :param global_limits: iterable of limit strings, each expanded with
            ``parse_many``. ``None`` (the default) means no global limits.
        :param headers_enabled: stored for later use.
        :param strategy: strategy name, stored for later resolution.
        :param storage_uri: storage URI, stored for later resolution.
        :param storage_options: extra storage options. A private copy is
            kept so later in-place updates never touch the caller's mapping.
        :param auto_check: stored for later use.
        :param swallow_errors: stored for later use.
        :param in_memory_fallback: iterable of limit strings in the same
            format as ``global_limits``. ``None`` means none.
        """
        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._global_limits = []
        self._in_memory_fallback = []
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Private copy: this dict may be updated in place later (e.g. by
        # init_app), which must not mutate the caller's dict or a shared
        # default object across limiter instances.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        for limit_string in global_limits or []:
            self._global_limits.extend(
                ExtLimit(limit, key_func, None, False, None, None, None)
                for limit in parse_many(limit_string)
            )
        for limit_string in in_memory_fallback or []:
            self._in_memory_fallback.extend(
                ExtLimit(limit, key_func, None, False, None, None, None)
                for limit in parse_many(limit_string)
            )
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._key_func = key_func
        self._storage_dead = False
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        # Handler whose emit() discards every record.
        class BlackHoleHandler(logging.StreamHandler):
            def emit(*_):
                return
        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 2
0
    def init_app(self, app):
        """Bind the limiter to a Flask application.

        Resolves settings from ``app.config`` (seeding defaults with
        ``setdefault``), builds the storage backend and strategy limiter,
        registers the request hooks and records the limiter in
        ``app.extensions['limiter']``.

        :param app: :class:`flask.Flask` instance to rate limit.
        :raises ConfigurationError: if the resolved strategy is unknown.
        """
        self.enabled = app.config.setdefault(C.ENABLED, True)
        self._swallow_errors = app.config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or app.config.setdefault(C.HEADERS_ENABLED, False)
        )
        # Merge into a fresh dict instead of updating in place: the existing
        # dict may be the caller's object or a shared mutable default, and an
        # in-place update would leak options across limiter instances.
        storage_options = dict(self._storage_options)
        storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
        self._storage_options = storage_options
        self._storage = storage_from_string(
            self._storage_uri
            or app.config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or app.config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError("Invalid rate limiting strategy %s" % strategy)
        self._limiter = STRATEGIES[strategy](self._storage)
        # Header names already present in the mapping win; otherwise take
        # them from the app config (seeding it with the defaults).
        mapping = self._header_mapping
        mapping.update({
            HEADERS.RESET: (
                mapping.get(HEADERS.RESET, None)
                or app.config.setdefault(C.HEADER_RESET, "X-RateLimit-Reset")
            ),
            HEADERS.REMAINING: (
                mapping.get(HEADERS.REMAINING, None)
                or app.config.setdefault(C.HEADER_REMAINING, "X-RateLimit-Remaining")
            ),
            HEADERS.LIMIT: (
                mapping.get(HEADERS.LIMIT, None)
                or app.config.setdefault(C.HEADER_LIMIT, "X-RateLimit-Limit")
            ),
        })

        # Limits from config are used only when none were given explicitly.
        conf_limits = app.config.get(C.GLOBAL_LIMITS, None)
        if not self._global_limits and conf_limits:
            self._global_limits = [
                ExtLimit(
                    limit, self._key_func, None, False, None, None, None
                ) for limit in parse_many(conf_limits)
            ]
        fallback_limits = app.config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = [
                ExtLimit(
                    limit, self._key_func, None, False, None, None, None
                ) for limit in parse_many(fallback_limits)
            ]
        if self._auto_check:
            app.before_request(self.__check_request_limit)
        app.after_request(self.__inject_headers)

        if self._in_memory_fallback:
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover
        app.extensions['limiter'] = self
Ejemplo n.º 3
0
 def _inner(obj):
     """Register ``limit_value`` for a view function or a blueprint.

     Uses ``key_func``, ``limit_value``, ``_scope``, ``per_method``,
     ``methods``, ``error_message``, ``exempt_when`` and ``self`` from the
     enclosing scope (not visible here). A callable ``limit_value`` is
     stored as a single dynamic limit; otherwise it is expanded with
     ``parse_many`` and a ValueError is logged instead of raised.

     :param obj: a :class:`Blueprint` or a view function.
     :returns: ``obj`` unchanged, for both blueprints and view functions.
     """
     func = key_func or self._key_func
     is_bp = isinstance(obj, Blueprint)
     name = obj.name if is_bp else "{}.{}".format(obj.__module__, obj.__name__)
     dynamic_limit, static_limits = None, []
     if callable(limit_value):
         dynamic_limit = ExtLimit(limit_value, func, _scope, per_method,
                                  methods, error_message, exempt_when)
     else:
         try:
             static_limits = [
                 ExtLimit(limit, func, _scope, per_method, methods,
                          error_message, exempt_when)
                 for limit in parse_many(limit_value)
             ]
         except ValueError as e:
             # Name the kind of target actually being configured.
             self.logger.error("failed to configure %s %s (%s)",
                               "blueprint" if is_bp else "view function",
                               name, e)
     if is_bp:
         dynamic_store = self._blueprint_dynamic_limits
         static_store = self._blueprint_limits
     else:
         dynamic_store = self._dynamic_route_limits
         static_store = self._route_limits
     if dynamic_limit:
         dynamic_store.setdefault(name, []).append(dynamic_limit)
     else:
         static_store.setdefault(name, []).extend(static_limits)
     # Return obj for blueprints too, so reassigning the decorator's result
     # never replaces the blueprint with None.
     return obj
Ejemplo n.º 4
0
        def _inner(obj):
            """Wrap view ``obj`` and register ``limit_value`` under its dotted name.

            Relies on ``key_func``, ``limit_value``, ``_scope``,
            ``per_method``, ``methods``, ``error_message``, ``exempt_when``
            and ``self`` from the enclosing scope. A callable
            ``limit_value`` becomes one dynamic limit; anything else is
            expanded with ``parse_many`` (a ValueError is logged, not raised).
            """
            resolved_key_func = key_func or self._key_func
            route_name = "{}.{}".format(obj.__module__, obj.__name__)
            dynamic = None
            static = []
            if not callable(limit_value):
                try:
                    static = [
                        ExtLimit(item, resolved_key_func, _scope, per_method,
                                 methods, error_message, exempt_when)
                        for item in parse_many(limit_value)
                    ]
                except ValueError as err:
                    self.logger.error("failed to configure {} {} ({})".format(
                        "view function", route_name, err))
            else:
                dynamic = ExtLimit(limit_value, resolved_key_func, _scope,
                                   per_method, methods, error_message,
                                   exempt_when)

            @wraps(obj)
            def wrapped_view(*a, **k):
                return obj(*a, **k)

            if not dynamic:
                self._route_limits.setdefault(route_name, []).extend(static)
            else:
                self._dynamic_route_limits.setdefault(route_name, []).append(
                    dynamic)
            return wrapped_view
Ejemplo n.º 5
0
    def init_app(self, app):
        """Bind the limiter to a Sanic application.

        Reads settings from ``app.config`` (seeding defaults with
        ``setdefault``), builds the storage backend and strategy limiter,
        and appends the limit check to ``app.request_middleware``.

        :param app: :class:`sanic.Sanic` instance to rate limit.
        :raises ConfigurationError: if the resolved strategy is unknown.
        """
        self.app = app
        self.enabled = app.config.setdefault(C.ENABLED, True)
        self._swallow_errors = app.config.setdefault(C.SWALLOW_ERRORS,
                                                     self._swallow_errors)
        # Merge into a fresh dict: updating in place would mutate the dict
        # handed to the constructor (possibly a shared default), leaking
        # options between limiter instances.
        storage_options = dict(self._storage_options)
        storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
        self._storage_options = storage_options
        self._storage = storage_from_string(
            self._storage_uri
            or app.config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options)
        strategy = (self._strategy
                    or app.config.setdefault(C.STRATEGY, 'fixed-window'))
        if strategy not in STRATEGIES:
            raise ConfigurationError("Invalid rate limiting strategy %s" %
                                     strategy)
        self._limiter = STRATEGIES[strategy](self._storage)

        # Config limits are used only when none were passed explicitly.
        conf_limits = app.config.get(C.GLOBAL_LIMITS, None)
        if not self._global_limits and conf_limits:
            self._global_limits = [
                ExtLimit(limit, self._key_func, None, False, None, None, None)
                for limit in parse_many(conf_limits)
            ]
        app.request_middleware.append(self.__check_request_limit)
Ejemplo n.º 6
0
    def __init__(self,
                 app=None,
                 key_func=get_ipaddr,
                 global_limits=None,
                 headers_enabled=False,
                 strategy=None,
                 storage_uri=None,
                 storage_options=None,
                 auto_check=True,
                 swallow_errors=False,
                 outside_routes=None
                 ):
        """Create the limiter and optionally bind it to ``app``.

        :param app: when truthy, passed to :meth:`init_app` immediately.
        :param key_func: callable producing the rate-limit key.
        :param global_limits: iterable of limit strings, each expanded with
            ``parse_many``. ``None`` means no global limits.
        :param headers_enabled: stored for later use.
        :param strategy: strategy name, stored for later resolution.
        :param storage_uri: storage URI, stored for later resolution.
        :param storage_options: extra storage options (copied).
        :param auto_check: stored for later use.
        :param swallow_errors: stored for later use.
        :param outside_routes: iterable of route names to start out exempt.
            Copied into a fresh set so instances never share exemptions.
        """
        self.app = app
        self.enabled = True
        self.global_limits = []
        # Fresh set per instance: a shared default set() would accumulate
        # exemptions from every limiter instance.
        self.exempt_routes = set(outside_routes or ())
        self.request_filters = []
        self.request_blockers = []
        self.headers_enabled = headers_enabled
        self.header_mapping = {}
        self.strategy = strategy
        self.storage_uri = storage_uri
        self.storage_options = dict(storage_options or {})
        self.auto_check = auto_check
        self.swallow_errors = swallow_errors
        for limit_string in global_limits or []:
            self.global_limits.extend(
                ExtLimit(limit, key_func, None, False, None, None)
                for limit in parse_many(limit_string)
            )
        self.route_limits = {}
        self.dynamic_route_limits = {}
        self.blueprint_limits = {}
        self.blueprint_dynamic_limits = {}
        self.blueprint_exempt = set()
        self.storage = self.limiter = None
        self.key_func = key_func
        self.logger = logging.getLogger("flask-limiter")

        # Handler whose emit() discards every record.
        class BlackHoleHandler(logging.StreamHandler):

            def emit(*_):
                return
        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 7
0
    def __inner(fn):
        """Wrap ``fn`` and attach a ``LimitWrapper`` built from ``limit_value``.

        Uses ``limit_value``, ``key_function`` and ``scope`` from the
        enclosing scope. Any entries already registered in ``DECORATED``
        for ``fn`` are moved onto the new wrapper, so stacked decorators
        accumulate their limits on the outermost function.
        """
        @wraps(fn)
        def wrapped(*args, **kwargs):
            return fn(*args, **kwargs)

        if fn in DECORATED:
            DECORATED.setdefault(wrapped, DECORATED.pop(fn))
        registered = DECORATED.setdefault(wrapped, [])
        if callable(limit_value):
            entry = LimitWrapper(limit_value, key_function, scope)
        else:
            entry = LimitWrapper(
                list(parse_many(limit_value)), key_function, scope
            )
        registered.append(entry)
        return wrapped
Ejemplo n.º 8
0
    def __init__(self,
                 app=None,
                 key_func=get_ipaddr,
                 global_limits=None,
                 headers_enabled=False,
                 strategy=None,
                 storage_uri=None,
                 storage_options=None,
                 auto_check=True,
                 swallow_errors=False):
        """Create the limiter and optionally bind it to ``app``.

        :param app: when truthy, passed to :meth:`init_app` immediately.
        :param key_func: callable producing the rate-limit key.
        :param global_limits: iterable of limit strings, each expanded with
            ``parse_many``. ``None`` means no global limits.
        :param headers_enabled: stored for later use.
        :param strategy: strategy name, stored for later resolution.
        :param storage_uri: storage URI, stored for later resolution.
        :param storage_options: extra storage options. Copied so that any
            later in-place update cannot mutate the caller's dict or a
            shared default.
        :param auto_check: stored for later use.
        :param swallow_errors: stored for later use.
        """
        self.app = app
        self.enabled = True
        self.global_limits = []
        self.exempt_routes = set()
        self.request_filters = []
        self.headers_enabled = headers_enabled
        self.header_mapping = {}
        self.strategy = strategy
        self.storage_uri = storage_uri
        self.storage_options = dict(storage_options or {})
        self.auto_check = auto_check
        self.swallow_errors = swallow_errors
        for limit_string in global_limits or []:
            self.global_limits.extend(
                ExtLimit(limit, key_func, None, False, None, None)
                for limit in parse_many(limit_string)
            )
        self.route_limits = {}
        self.dynamic_route_limits = {}
        self.blueprint_limits = {}
        self.blueprint_dynamic_limits = {}
        self.blueprint_exempt = set()
        self.storage = self.limiter = None
        self.key_func = key_func
        self.logger = logging.getLogger("flask-limiter")

        # Handler whose emit() discards every record.
        class BlackHoleHandler(logging.StreamHandler):
            def emit(*_):
                return

        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 9
0
    def __init__(self):
        """Configure the rate-limit middleware from ``settings``.

        ``settings`` is presumably the Django settings object (the logger is
        named "djlimiter"); every value has a fallback default.
        A callback given as a dotted-path string is imported. If that import
        fails, the error is logged and rate limiting is disabled.

        :raises ConfigurationError: if the configured strategy is unknown.
        """
        conf_limits = getattr(settings, C.GLOBAL_LIMITS, "")
        callback = getattr(settings, C.CALLBACK, self.__raise_exceeded)
        self.enabled = getattr(settings, C.ENABLED, True)
        self.headers_enabled = getattr(settings, C.HEADERS_ENABLED, False)
        self.strategy = getattr(settings, C.STRATEGY, 'fixed-window')
        if self.strategy not in STRATEGIES:
            raise ConfigurationError("Invalid rate limiting strategy %s" %
                                     self.strategy)
        self.storage = storage_from_string(
            getattr(settings, C.STORAGE_URL, "memory://"))
        self.limiter = STRATEGIES[self.strategy](self.storage)
        self.key_function = getattr(settings, C.DEFAULT_KEY_FUNCTION,
                                    get_ipaddr)
        self.global_limits = []
        if conf_limits:
            self.global_limits = [
                LimitWrapper(list(parse_many(conf_limits)), self.key_function,
                             None, False)
            ]
        self.header_mapping = {
            HEADERS.RESET:
            getattr(settings, C.HEADER_RESET, "X-RateLimit-Reset"),
            HEADERS.REMAINING:
            getattr(settings, C.HEADER_REMAINING, "X-RateLimit-Remaining"),
            HEADERS.LIMIT:
            getattr(settings, C.HEADER_LIMIT, "X-RateLimit-Limit"),
        }
        self.logger = logging.getLogger("djlimiter")
        self.logger.addHandler(BlackHoleHandler())

        if isinstance(callback, six.string_types):
            mod, _, name = callback.rpartition(".")
            try:
                self.callback = getattr(importlib.import_module(mod), name)
            except (ImportError, AttributeError, ValueError):
                # ImportError: the module does not exist.
                # AttributeError: the module lacks ``name``.
                # ValueError: no module part in the path (import_module("")).
                self.logger.error(
                    "Unable to load callback function %s. Rate limiting disabled",
                    callback)
                self.enabled = False
        else:
            self.callback = callback
Ejemplo n.º 10
0
 def _inner(obj):
     """Register ``limit_value`` for a view function or a blueprint.

     Uses ``key_func``, ``limit_value``, ``_scope``, ``per_method``,
     ``methods``, ``error_message``, ``exempt_when`` and ``self`` from the
     enclosing scope. A callable ``limit_value`` is stored as one dynamic
     limit; otherwise it is expanded with ``parse_many`` and a ValueError
     is logged instead of raised.

     :param obj: a :class:`Blueprint` or a view function.
     :returns: the blueprint itself, or a ``wraps``-decorated wrapper of
         the view function.
     """
     func = key_func or self._key_func
     is_route = not isinstance(obj, Blueprint)
     name = "%s.%s" % (obj.__module__, obj.__name__) if is_route else obj.name
     dynamic_limit, static_limits = None, []
     if callable(limit_value):
         dynamic_limit = ExtLimit(limit_value, func, _scope, per_method,
                                  methods, error_message, exempt_when)
     else:
         try:
             static_limits = [ExtLimit(
                 limit, func, _scope, per_method,
                 methods, error_message, exempt_when
             ) for limit in parse_many(limit_value)]
         except ValueError as e:
             self.logger.error(
                 "failed to configure %s %s (%s)",
                 "view function" if is_route else "blueprint", name, e
             )
     if not is_route:
         if dynamic_limit:
             self._blueprint_dynamic_limits.setdefault(name, []).append(
                 dynamic_limit
             )
         else:
             self._blueprint_limits.setdefault(name, []).extend(
                 static_limits
             )
         # Return the blueprint (previously None), so reassigning the
         # decorator's result never discards it.
         return obj

     @wraps(obj)
     def __inner(*a, **k):
         return obj(*a, **k)
     if dynamic_limit:
         self._dynamic_route_limits.setdefault(name, []).append(
             dynamic_limit
         )
     else:
         self._route_limits.setdefault(name, []).extend(
             static_limits
         )
     return __inner
Ejemplo n.º 11
0
    def __init__(self,
                 app=None,
                 key_func=None,
                 global_limits=None,
                 strategy=None,
                 storage_uri=None,
                 storage_options=None,
                 swallow_errors=False):
        """Create the Sanic limiter and optionally bind it to ``app``.

        :param app: when truthy, passed to :meth:`init_app` immediately.
        :param key_func: callable producing the rate-limit key; defaults to
            ``get_remote_address``.
        :param global_limits: iterable of limit strings, each expanded with
            ``parse_many``. ``None`` means no global limits.
        :param strategy: strategy name, stored for later resolution.
        :param storage_uri: storage URI, stored for later resolution.
        :param storage_options: extra storage options. Copied, because the
            dict may be updated in place later (e.g. by ``init_app``).
        :param swallow_errors: stored for later use.
        """
        self.app = app
        self.logger = logging.getLogger("sanic-limiter")

        self.enabled = True
        self._global_limits = []
        self._exempt_routes = set()
        self._request_filters = []
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Private copy so options never leak between instances through a
        # shared default or the caller's dict.
        self._storage_options = dict(storage_options or {})
        self._swallow_errors = swallow_errors
        self._key_func = key_func or get_remote_address
        for limit_string in global_limits or []:
            self._global_limits.extend(
                ExtLimit(limit, self._key_func, None, False, None, None, None)
                for limit in parse_many(limit_string)
            )
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_limits = {}
        self._storage = None
        self._limiter = None
        self._storage_dead = False

        # Handler whose emit() discards every record.
        class BlackHoleHandler(logging.StreamHandler):
            def emit(*_):
                return

        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 12
0
    def __init__(self):
        """Configure the rate-limit middleware from ``settings``.

        ``settings`` is presumably the Django settings object (the logger
        is named "djlimiter"); every value has a fallback default. A
        callback given as a dotted-path string is imported. If that import
        fails, the error is logged and rate limiting is disabled.

        :raises ConfigurationError: if the configured strategy is unknown.
        """
        conf_limits = getattr(settings, C.GLOBAL_LIMITS, "")
        callback = getattr(settings, C.CALLBACK, self.__raise_exceeded)
        self.enabled = getattr(settings, C.ENABLED, True)
        self.headers_enabled = getattr(settings, C.HEADERS_ENABLED, False)
        self.strategy = getattr(settings, C.STRATEGY, 'fixed-window')
        if self.strategy not in STRATEGIES:
            raise ConfigurationError("Invalid rate limiting strategy %s" % self.strategy)
        self.storage = storage_from_string(getattr(settings, C.STORAGE_URL, "memory://"))
        self.limiter = STRATEGIES[self.strategy](self.storage)
        self.key_function = getattr(settings, C.DEFAULT_KEY_FUNCTION, get_ipaddr)
        self.global_limits = []
        if conf_limits:
            self.global_limits = [
                LimitWrapper(
                    list(parse_many(conf_limits)), self.key_function, None, False
                )
            ]
        self.header_mapping = {
            HEADERS.RESET: getattr(settings, C.HEADER_RESET, "X-RateLimit-Reset"),
            HEADERS.REMAINING: getattr(settings, C.HEADER_REMAINING, "X-RateLimit-Remaining"),
            HEADERS.LIMIT: getattr(settings, C.HEADER_LIMIT, "X-RateLimit-Limit"),
        }
        self.logger = logging.getLogger("djlimiter")
        self.logger.addHandler(BlackHoleHandler())

        if isinstance(callback, six.string_types):
            mod, _, name = callback.rpartition(".")
            try:
                self.callback = getattr(importlib.import_module(mod), name)
            except (ImportError, AttributeError, ValueError):
                # ImportError: the module does not exist.
                # AttributeError: the module lacks ``name``.
                # ValueError: no module part in the path (import_module("")).
                self.logger.error(
                    "Unable to load callback function %s. Rate limiting disabled",
                    callback
                )
                self.enabled = False
        else:
            self.callback = callback
Ejemplo n.º 13
0
    def __check_request_limit(self):
        """Enforce the rate limits that apply to the current request.

        Limit selection:
        - Start from the route's static and dynamic limits.
        - Blueprint limits fill in for whichever of the two the route lacks.
        - The global limits apply only when both are still empty.

        Limits restricted to other HTTP methods are skipped; the rest are
        still checked. The smallest limit seen (by ``<``), or the one that
        failed, is stored in ``g.view_rate_limit``. Unexpected errors are
        logged instead of raised when ``self.swallow_errors`` is set. A rate
        limit breach is never swallowed.

        :raises RateLimitExceeded: when one of the selected limits is hit.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = ("%s.%s" % (view_func.__module__, view_func.__name__)
                if view_func else "")
        if (not request.endpoint or not self.enabled
                or view_func == current_app.send_static_file
                or name in self.exempt_routes
                or request.blueprint in self.blueprint_exempt
                or any(fn() for fn in self.request_filters)):
            return
        # Copy so extending with blueprint limits below can never mutate the
        # list stored in self.route_limits.
        limits = list(self.route_limits.get(name) or [])
        dynamic_limits = []
        if name in self.dynamic_route_limits:
            for lim in self.dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods, lim.error_message)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e)
        if request.blueprint:
            if (request.blueprint in self.blueprint_dynamic_limits
                    and not dynamic_limits):
                for lim in self.blueprint_dynamic_limits[request.blueprint]:
                    try:
                        dynamic_limits.extend(
                            ExtLimit(limit, lim.key_func, lim.scope,
                                     lim.per_method, lim.methods,
                                     lim.error_message)
                            for limit in parse_many(lim.limit))
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e)
            if request.blueprint in self.blueprint_limits and not limits:
                limits.extend(self.blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            method = request.method
            for lim in (limits + dynamic_limits or self.global_limits):
                if lim.methods is not None and method.lower() not in lim.methods:
                    # This limit does not cover the current method: skip it
                    # but keep enforcing the remaining limits.
                    continue
                limit_scope = lim.scope or endpoint
                if lim.per_method:
                    limit_scope += ":%s" % method
                # Evaluate the key once per limit instead of once per use.
                key = lim.key_func()
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, key, limit_scope)
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, key, limit_scope)
                    failed_limit = lim
                    limit_for_header = (lim.limit, key, limit_scope)
                    break

            g.view_rate_limit = limit_for_header

            if failed_limit:
                message = failed_limit.error_message
                if message:
                    exc_description = message() if callable(message) else message
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except RateLimitExceeded:
            # A genuine limit breach must reach the client even when
            # swallow_errors is enabled.
            raise
        except Exception:  # noqa
            if self.swallow_errors:
                self.logger.exception("Failed to rate limit. Swallowing error")
            else:
                raise
Ejemplo n.º 14
0
 def get_limits(self, request):
     """Return the limits applicable to ``request``.

     If ``self._limits`` is callable, it is called with ``request`` and its
     result is expanded with ``parse_many`` into a list. Otherwise the stored
     value is returned as-is (the same object, not a copy).
     """
     limits = self._limits
     if not callable(limits):
         return limits
     return list(parse_many(limits(request)))
Ejemplo n.º 15
0
    def __check_request_limit(self):
        """Enforce the rate limits that apply to the current request.

        Limit selection:
        - Start from the route's static and dynamic limits.
        - Blueprint limits fill in for whichever of the two the route lacks.
        - The global limits apply only when both are still empty.

        Fallback storage: while the storage backend is marked dead and a
        fallback limiter exists, the in-memory fallback limits replace that
        selection. This lasts until a permitted backend re-check
        (``self._storage.check()``) succeeds.

        Exempt limits, and limits restricted to other HTTP methods, are
        skipped; the remaining limits are still enforced. The smallest limit
        seen (by ``<``), or the one that failed, is stored in
        ``g.view_rate_limit``.

        Error handling: on an unexpected error with in-memory fallback
        configured, storage is marked dead and the check is re-run once.
        Otherwise the error is logged when ``self._swallow_errors`` is set,
        and re-raised if not.

        :raises RateLimitExceeded: when one of the selected limits is hit.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = ("%s.%s" % (view_func.__module__, view_func.__name__)
                if view_func else "")
        if (not request.endpoint
                or not self.enabled
                or view_func == current_app.send_static_file
                or name in self._exempt_routes
                or request.blueprint in self._blueprint_exempt
                or any(fn() for fn in self._request_filters)):
            return

        def resolve(dynamic, error_template, ident):
            # Expand each dynamic limit via parse_many(lim.limit); a
            # malformed value is logged with error_template and skipped.
            resolved = []
            for lim in dynamic:
                try:
                    resolved.extend(
                        ExtLimit(
                            limit, lim.key_func, lim.scope, lim.per_method,
                            lim.methods, lim.error_message, lim.exempt_when
                        ) for limit in parse_many(lim.limit)
                    )
                except ValueError as e:
                    self.logger.error(error_template, ident, e)
            return resolved

        # Copy so extending with blueprint limits below can never mutate the
        # list stored in self._route_limits.
        limits = list(self._route_limits.get(name) or [])
        dynamic_limits = resolve(
            self._dynamic_route_limits.get(name, []),
            "failed to load ratelimit for view function %s (%s)", name
        )
        blueprint = request.blueprint
        if blueprint:
            if blueprint in self._blueprint_dynamic_limits and not dynamic_limits:
                dynamic_limits = resolve(
                    self._blueprint_dynamic_limits[blueprint],
                    "failed to load ratelimit for blueprint %s (%s)", blueprint
                )
            if blueprint in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if self.__should_check_backend() and self._storage.check():
                    self.logger.info(
                        "Rate limit storage recovered"
                    )
                    self._storage_dead = False
                    self.__check_backend_count = 0
                else:
                    all_limits = self._in_memory_fallback
            if not all_limits:
                all_limits = (limits + dynamic_limits or self._global_limits)
            method = request.method
            for lim in all_limits:
                # Skip (rather than abort on) limits that do not apply, so
                # the remaining limits are still enforced.
                if lim.is_exempt:
                    continue
                if lim.methods is not None and method.lower() not in lim.methods:
                    continue
                limit_scope = lim.scope or endpoint
                if lim.per_method:
                    limit_scope += ":%s" % method
                # Evaluate the key once per limit instead of once per use.
                key = lim.key_func()
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, key, limit_scope)
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s"
                        , lim.limit, key, limit_scope
                    )
                    failed_limit = lim
                    limit_for_header = (lim.limit, key, limit_scope)
                    break

            g.view_rate_limit = limit_for_header

            if failed_limit:
                message = failed_limit.error_message
                if message:
                    exc_description = message() if callable(message) else message
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except Exception as e:  # noqa
            if isinstance(e, RateLimitExceeded):
                raise
            if self._in_memory_fallback and not self._storage_dead:
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                self.__check_request_limit()
            elif self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error"
                )
            else:
                raise
Ejemplo n.º 16
0
    def __init__(self, app=None
                 , key_func=None
                 , global_limits=None
                 , headers_enabled=False
                 , strategy=None
                 , storage_uri=None
                 , storage_options=None
                 , auto_check=True
                 , swallow_errors=False
                 , in_memory_fallback=None
                 , retry_after=None
    ):
        """Create the limiter and optionally bind it to ``app``.

        :param app: when truthy, passed to :meth:`init_app` immediately.
        :param key_func: callable producing the rate-limit key. When not
            given, ``get_ipaddr`` is used and a ``UserWarning`` is emitted.
        :param global_limits: iterable of limit strings, each expanded with
            ``parse_many``. ``None`` means no global limits.
        :param headers_enabled: stored for later use.
        :param strategy: strategy name, stored for later resolution.
        :param storage_uri: storage URI, stored for later resolution.
        :param storage_options: extra storage options. A private copy is
            kept so in-place updates never touch the caller's mapping.
        :param auto_check: stored for later use.
        :param swallow_errors: stored for later use.
        :param in_memory_fallback: iterable of limit strings in the same
            format as ``global_limits``. ``None`` means none.
        :param retry_after: stored for later use.
        """
        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._global_limits = []
        self._in_memory_fallback = []
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Private copy: this dict may be updated in place later (e.g. by
        # init_app), which must not leak options between instances via a
        # shared default or the caller's dict.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration",
                UserWarning
            )

        self._key_func = key_func or get_ipaddr
        for limit_string in global_limits or []:
            self._global_limits.extend(
                ExtLimit(limit, self._key_func, None, False, None, None, None)
                for limit in parse_many(limit_string)
            )
        for limit_string in in_memory_fallback or []:
            self._in_memory_fallback.extend(
                ExtLimit(limit, self._key_func, None, False, None, None, None)
                for limit in parse_many(limit_string)
            )
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        # Handler whose emit() discards every record.
        class BlackHoleHandler(logging.StreamHandler):
            def emit(*_):
                return
        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 17
0
 def test_parse_many_csv(self):
     """A comma-separated string yields one parsed limit per entry, in order."""
     items = parse_many("1 per 3 hour, 1 per second")
     self.assertEqual(len(items), 2)
     three_hours = 3 * 60 * 60
     for item, expected_expiry in zip(items, (three_hours, 1)):
         self.assertEqual(item.get_expiry(), expected_expiry)
Ejemplo n.º 18
0
    def __check_request_limit(self):
        """Apply the configured rate limits to the current Flask request.

        Nothing is checked when the request has no endpoint, the limiter is
        disabled, the view is ``current_app.send_static_file``, the route or
        blueprint is exempt, or any registered request filter returns a
        truthy value.

        The limits evaluated are the route's static limits plus its dynamic
        limits. Blueprint dynamic limits are used only when no dynamic route
        limits were loaded, and blueprint static limits only when the route
        has no static limits. If the combined list is empty,
        ``self.global_limits`` is evaluated instead. Dynamic limit strings
        that fail to parse (``ValueError``) are logged and skipped. A limit
        restricted to other HTTP methods is skipped on its own; the remaining
        limits are still enforced.

        ``g.view_rate_limit`` receives a ``(limit, key, scope)`` tuple for the
        limit that compared lowest with ``<`` so far, or for the limit that was
        exceeded, or ``None`` when no limit was evaluated.

        :raises RateLimitExceeded: when an applicable limit is exhausted. This
            propagates even when ``swallow_errors`` is enabled.
        :raises Exception: any other error raised while evaluating limits is
            re-raised unless ``self.swallow_errors`` is set, in which case it
            is logged and the method returns normally.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = ("%s.%s" % (
                view_func.__module__, view_func.__name__
            ) if view_func else ""
        )
        if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self.exempt_routes
            or request.blueprint in self.blueprint_exempt
            or any(fn() for fn in self.request_filters)
        ):
            return
        # When the route has no (or empty) static limits this yields a fresh
        # list, so extending it with blueprint limits below never mutates
        # ``self.route_limits``.
        limits = (
            name in self.route_limits and self.route_limits[name]
            or []
        )
        dynamic_limits = []
        if name in self.dynamic_route_limits:
            for lim in self.dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(
                            limit, lim.key_func, lim.scope, lim.per_method,
                            lim.methods, lim.error_message
                        ) for limit in parse_many(lim.limit)
                    )
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)"
                        , name, e
                    )
        if request.blueprint:
            if (request.blueprint in self.blueprint_dynamic_limits
                and not dynamic_limits
            ):
                for lim in self.blueprint_dynamic_limits[request.blueprint]:
                    try:
                        dynamic_limits.extend(
                            ExtLimit(
                                limit, lim.key_func, lim.scope, lim.per_method,
                                lim.methods, lim.error_message
                            ) for limit in parse_many(lim.limit)
                        )
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)"
                            , request.blueprint, e
                        )
            if (request.blueprint in self.blueprint_limits
                and not limits
            ):
                limits.extend(self.blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            for lim in (limits + dynamic_limits or self.global_limits):
                if (lim.methods is not None
                        and request.method.lower() not in lim.methods):
                    # This limit only covers other HTTP methods; skip it but
                    # keep enforcing the remaining limits.
                    continue
                limit_scope = lim.scope or endpoint
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                # Evaluate the key once so every use below sees the same value.
                key = lim.key_func()
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, key, limit_scope)
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s"
                        , lim.limit, key, limit_scope
                    )
                    failed_limit = lim
                    limit_for_header = (lim.limit, key, limit_scope)
                    break

            g.view_rate_limit = limit_for_header

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = failed_limit.error_message if not callable(
                        failed_limit.error_message
                    ) else failed_limit.error_message()
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except RateLimitExceeded:
            # A genuine limit breach must never be hidden by swallow_errors.
            raise
        except Exception:  # no qa
            if self.swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error"
                )
            else:
                six.reraise(*sys.exc_info())
Ejemplo n.º 19
0
 def test_parse_many(self):
     """A semicolon-separated limit string yields one item per clause, in order."""
     items = parse_many("1 per 3 hour; 1 per second")
     three_hours = 3 * 60 * 60
     self.assertEqual(len(items), 2)
     self.assertEqual(items[0].get_expiry(), three_hours)
     self.assertEqual(items[1].get_expiry(), 1)
Ejemplo n.º 20
0
    def __check_request_limit(self):
        """Apply the configured rate limits to the current Flask request.

        Nothing is checked when the request has no endpoint, the limiter is
        disabled, the view is ``current_app.send_static_file``, the route or
        blueprint is exempt, or any registered request filter returns a
        truthy value.

        Limit selection:

        * If the storage has been marked dead and a fallback limiter exists,
          the backend is re-probed when ``__should_check_backend()`` allows
          it. If ``self._storage.check()`` succeeds, the dead flag is cleared.
          Otherwise only ``self._in_memory_fallback`` is evaluated.
        * Otherwise ``self._application_limits`` are evaluated first. They are
          followed by the route's static and dynamic limits (with blueprint
          fallbacks) or, when there are none, by ``self._default_limits``.

        A limit reporting ``is_exempt``, or restricted to other HTTP methods,
        is skipped on its own; the remaining limits are still enforced.
        ``g.view_rate_limit`` receives a ``(limit, key, scope)`` tuple, or
        ``None`` when no limit was evaluated.

        :raises RateLimitExceeded: when an applicable limit is exhausted.
        :raises Exception: on any other error, one of two things happens.
            If in-memory fallback limits are configured and the storage is not
            already marked dead, the storage is marked dead and the check is
            retried once. Otherwise the error is re-raised, unless
            ``self._swallow_errors`` is set, in which case it is logged.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = ("%s.%s" % (view_func.__module__, view_func.__name__)
                if view_func else "")
        if (not request.endpoint or not self.enabled
                or view_func == current_app.send_static_file
                or name in self._exempt_routes
                or request.blueprint in self._blueprint_exempt
                or any(fn() for fn in self._request_filters)):
            return
        # When the route has no (or empty) static limits this yields a fresh
        # list, so the blueprint ``extend`` below never mutates stored limits.
        limits = (name in self._route_limits and self._route_limits[name]
                  or [])
        dynamic_limits = []
        if name in self._dynamic_route_limits:
            for lim in self._dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message, lim.exempt_when)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e)
        if request.blueprint:
            if (request.blueprint in self._blueprint_dynamic_limits
                    and not dynamic_limits):
                for lim in self._blueprint_dynamic_limits[request.blueprint]:
                    try:
                        dynamic_limits.extend(
                            ExtLimit(limit, lim.key_func, lim.scope,
                                     lim.per_method, lim.methods,
                                     lim.error_message, lim.exempt_when)
                            for limit in parse_many(lim.limit))
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e)
            if (request.blueprint in self._blueprint_limits and not limits):
                limits.extend(self._blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if self.__should_check_backend() and self._storage.check():
                    self.logger.info("Rate limit storage recovered")
                    self._storage_dead = False
                    self.__check_backend_count = 0
                else:
                    all_limits = self._in_memory_fallback
            if not all_limits:
                all_limits = self._application_limits + (
                    limits + dynamic_limits or self._default_limits)
            for lim in all_limits:
                if lim.is_exempt:
                    # Exemption applies to this limit only; keep checking the
                    # others (returning here would disable all of them).
                    continue
                if (lim.methods is not None
                        and request.method.lower() not in lim.methods):
                    # Limit only covers other HTTP methods.
                    continue
                limit_scope = lim.scope or endpoint
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                # Evaluate the key once so every use below sees the same value.
                key = lim.key_func()
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, key, limit_scope)
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, key, limit_scope)
                    failed_limit = lim
                    limit_for_header = (lim.limit, key, limit_scope)
                    break

            g.view_rate_limit = limit_for_header

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = failed_limit.error_message if not callable(
                        failed_limit.error_message
                    ) else failed_limit.error_message()
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except Exception as e:  # no qa
            if isinstance(e, RateLimitExceeded):
                six.reraise(*sys.exc_info())
            if self._in_memory_fallback and not self._storage_dead:
                # Logger.warn is a deprecated alias (removed in Python 3.13).
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage")
                self._storage_dead = True
                self.__check_request_limit()
            else:
                if self._swallow_errors:
                    self.logger.exception(
                        "Failed to rate limit. Swallowing error")
                else:
                    six.reraise(*sys.exc_info())
Ejemplo n.º 21
0
 def get_limits(self, request):
     """Return the limits applicable to ``request``.

     If ``self._limits`` is callable it is called with ``request`` and the
     result is parsed with ``parse_many`` into a list; otherwise
     ``self._limits`` is returned as-is.
     """
     if not callable(self._limits):
         return self._limits
     return list(parse_many(self._limits(request)))
Ejemplo n.º 22
0
    def __init__(self,
                 app=None,
                 key_func=None,
                 global_limits=None,
                 default_limits=None,
                 application_limits=None,
                 headers_enabled=False,
                 strategy=None,
                 storage_uri=None,
                 storage_options=None,
                 auto_check=True,
                 swallow_errors=False,
                 in_memory_fallback=None,
                 retry_after=None):
        """Configure the limiter and, when ``app`` is given, call ``init_app``.

        :param app: optional application; if truthy, ``init_app(app)`` runs
            at the end of construction.
        :param key_func: callable producing the rate-limit key. When omitted,
            a ``UserWarning`` is emitted and ``get_ipaddr`` is used.
        :param global_limits: limit strings merged with ``default_limits``.
            Passing any calls ``raise_global_limits_warning()`` (presumably a
            deprecation notice; verify).
        :param default_limits: limit strings parsed into
            ``self._default_limits``. Duplicates shared with ``global_limits``
            are dropped; the first occurrence's order is kept.
        :param application_limits: limit strings parsed with the shared
            ``"global"`` scope into ``self._application_limits``.
        :param headers_enabled: stored as ``self._headers_enabled``.
        :param strategy: stored as ``self._strategy``.
        :param storage_uri: stored as ``self._storage_uri``.
        :param storage_options: stored as ``self._storage_options``; defaults
            to a new empty dict per instance.
        :param auto_check: stored as ``self._auto_check``.
        :param swallow_errors: stored as ``self._swallow_errors``.
        :param in_memory_fallback: limit strings parsed into
            ``self._in_memory_fallback``.
        :param retry_after: stored as ``self._retry_after``.

        The list/dict parameters default to ``None`` rather than shared
        mutable objects; ``None`` is treated as empty.
        """
        # Normalise optional collections; list() also lets callers mix tuples
        # and lists for global_limits/default_limits.
        global_limits = list(global_limits or [])
        default_limits = list(default_limits or [])
        application_limits = list(application_limits or [])
        in_memory_fallback = list(in_memory_fallback or [])
        if storage_options is None:
            storage_options = {}

        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._default_limits = []
        self._application_limits = []
        self._in_memory_fallback = []
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        self._storage_options = storage_options
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            # stacklevel=2 attributes the warning to the caller constructing
            # the limiter rather than to this line.
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration", UserWarning,
                stacklevel=2)
        if global_limits:
            self.raise_global_limits_warning()

        self._key_func = key_func or get_ipaddr

        def _build(limit_strings, scope):
            # Parse each limit string (which may hold several limits) into
            # ExtLimit objects bound to this limiter's key function.
            return [
                ExtLimit(item, self._key_func, scope, False, None, None, None)
                for limit_string in limit_strings
                for item in parse_many(limit_string)
            ]

        # De-duplicate while preserving the caller's order: a plain set() would
        # make evaluation order depend on per-process string hashing.
        seen = set()
        unique_defaults = []
        for limit_string in global_limits + default_limits:
            if limit_string not in seen:
                seen.add(limit_string)
                unique_defaults.append(limit_string)

        self._default_limits.extend(_build(unique_defaults, None))
        self._application_limits.extend(_build(application_limits, "global"))
        self._in_memory_fallback.extend(_build(in_memory_fallback, None))
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        class BlackHoleHandler(logging.StreamHandler):
            # Handler that discards every record it receives.
            def emit(*_):
                return

        self.logger.addHandler(BlackHoleHandler())
        if app:
            self.init_app(app)
Ejemplo n.º 23
0
    def __check_request_limit(self, request):
        """Apply the configured rate limits to a Sanic ``request``.

        Only paths found in ``self.app.router.routes_static`` are checked.
        Nothing is checked when the path is empty, the limiter is disabled,
        the view is exempt, or any registered request filter returns a truthy
        value.

        The limits evaluated are the view's static limits plus its dynamic
        limits. Blueprint dynamic limits are used only when no dynamic view
        limits were loaded, and blueprint static limits only when the view has
        no static limits. If the combined list is empty,
        ``self._global_limits`` is evaluated instead. A limit reporting
        ``is_exempt``, or restricted to other HTTP methods, is skipped on its
        own; the remaining limits are still enforced.

        Each key function is called with ``request`` when its first parameter
        has no default value, otherwise with no arguments. A ``None`` key
        skips that limit.

        :raises RateLimitExceeded: when an applicable limit is exhausted.
        :raises Exception: any other error while evaluating limits is
            re-raised unless ``self._swallow_errors`` is set, in which case it
            is logged.
        """
        endpoint = request.path or ""
        view_handler = self.app.router.routes_static.get(endpoint, None)
        if view_handler is None:
            return
        view_func = view_handler.handler
        view_bpname = view_func.__dict__.get('__blueprintname__', None)
        name = ("{}.{}".format(view_func.__module__, view_func.__name__)
                if view_func else "")
        if (not endpoint or not self.enabled or name in self._exempt_routes
                or any(fn() for fn in self._request_filters)):
            return
        # Copy: the blueprint ``extend`` below must never mutate the list
        # stored in ``self._route_limits`` (an empty stored list would
        # otherwise grow on every request).
        limits = list(self._route_limits.get(name) or [])
        dynamic_limits = []
        if name in self._dynamic_route_limits:
            for lim in self._dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(
                        ExtLimit(limit, lim.key_func, lim.scope,
                                 lim.per_method, lim.methods,
                                 lim.error_message, lim.exempt_when)
                        for limit in parse_many(lim.limit))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e)
        if view_bpname:
            if view_bpname in self._blueprint_dynamic_limits and not dynamic_limits:
                for lim in self._blueprint_dynamic_limits[view_bpname]:
                    try:
                        dynamic_limits.extend(
                            ExtLimit(limit, lim.key_func, lim.scope,
                                     lim.per_method, lim.methods,
                                     lim.error_message, lim.exempt_when)
                            for limit in parse_many(lim.limit))
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for view blueprint %s (%s)",
                            view_bpname, e)
            if view_bpname in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[view_bpname])
        failed_limit = None
        try:
            for lim in (limits + dynamic_limits or self._global_limits):
                if lim.is_exempt:
                    # Exemption applies to this limit only; keep checking the
                    # others (returning here would disable all of them).
                    continue
                if (lim.methods is not None
                        and request.method.lower() not in lim.methods):
                    # Limit only covers other HTTP methods.
                    continue
                limit_scope = lim.scope or endpoint
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                # Pass the request only when the key function's first
                # parameter is required; otherwise call it with no arguments.
                first_param = next(
                    iter(inspect.signature(lim.key_func).parameters.values()),
                    None)
                if (first_param is not None
                        and first_param.default is inspect.Parameter.empty):
                    key = lim.key_func(request)
                else:
                    key = lim.key_func()
                if key is None:
                    # Ignore empty result of the key function.
                    continue
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, key, limit_scope)
                    failed_limit = lim
                    break

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = failed_limit.error_message if not callable(
                        failed_limit.error_message
                    ) else failed_limit.error_message()
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except Exception as e:  # no qa
            if isinstance(e, RateLimitExceeded):
                six.reraise(*sys.exc_info())
            if self._swallow_errors:
                self.logger.exception("Failed to rate limit. Swallowing error")
            else:
                six.reraise(*sys.exc_info())