Example #1
0
class OpenStackRateLimitMiddleware(object):
    """
    OpenStack Rate Limit Middleware enforces configurable rate limits.

    Per combination of:
      service         ( compute, identity, object-store, .. )
      scope           ( initiator|target project uid, initiator host address )
      target_type_uri ( service/compute/servers, service/storage/block/volumes,.. )
      action          ( create, read, update, delete, authenticate, .. )
    """
    def __init__(self, app, **conf):
        """
        Create the middleware from the WSGI (paste.ini) configuration.

        Sets up the logger, the StatsD metrics client, the custom responses, the
        white-/blacklists, the redis backend and the provider for rate limits.

        :param app: the wrapped WSGI application
        :param conf: the configuration options given via paste.ini
        """
        self.app = app
        # Configuration via paste.ini.
        self.__conf = conf
        self.logger = log.Logger(conf.get('log_name', __name__))

        # StatsD is used to emit metrics.
        # NOTE(review): unlike the backend_* options below, the default port is passed to
        # conf.get() and not to common.to_int() - confirm what to_int returns for an unparsable value.
        statsd_host = self.__conf.get('statsd_host', '127.0.0.1')
        statsd_port = common.to_int(self.__conf.get('statsd_port', 9125))
        statsd_prefix = self.__conf.get('statsd_prefix',
                                        common.Constants.metric_prefix)

        # Init StatsD client.
        # The environment variables STATSD_HOST, STATSD_PORT, STATSD_PREFIX take precedence
        # over the values found in the paste.ini.
        self.metricsClient = DogStatsd(
            host=os.getenv('STATSD_HOST', statsd_host),
            port=int(os.getenv('STATSD_PORT', statsd_port)),
            namespace=os.getenv('STATSD_PREFIX', statsd_prefix))

        # Get backend configuration.
        # Backend is used to store count of requests.
        self.backend_host = self.__conf.get('backend_host', '127.0.0.1')
        self.backend_port = common.to_int(self.__conf.get('backend_port'),
                                          6379)
        self.logger.debug("using backend '{0}' on '{1}:{2}'".format(
            'redis', self.backend_host, self.backend_port))
        backend_timeout_seconds = common.to_int(
            self.__conf.get('backend_timeout_seconds'), 20)
        backend_max_connections = common.to_int(
            self.__conf.get('backend_max_connections'), 100)

        # Load configuration file.
        # A configuration error is only logged. The middleware continues with an empty configuration.
        self.config = {}
        config_file = self.__conf.get('config_file', None)
        if config_file:
            try:
                self.config = common.load_config(config_file)
            except errors.ConfigError as e:
                self.logger.warning("error loading configuration: {0}".format(
                    str(e)))

        self.service_type = self.__conf.get('service_type', None)

        # This is required to trim the prefix from the target_type_uri.
        # Example:
        #   service_type      = identity
        #   cadf_service_name = data/security
        #   target_type_uri   = data/security/auth/tokens -> auth/tokens
        # If not configured the name is looked up by the service type.
        self.cadf_service_name = self.__conf.get('cadf_service_name', None)
        if common.is_none_or_unknown(self.cadf_service_name):
            self.cadf_service_name = common.CADF_SERVICE_TYPE_PREFIX_MAP.get(
                self.service_type, None)

        # Use configured parameters or ensure defaults.
        max_sleep_time_seconds = common.to_int(
            self.__conf.get(common.Constants.max_sleep_time_seconds), 20)
        log_sleep_time_seconds = common.to_int(
            self.__conf.get(common.Constants.log_sleep_time_seconds), 10)

        # Setup ratelimit and blacklist response.
        # Reads self.config, so the configuration file has to be loaded before.
        self._setup_response()

        # White-/blacklist can contain project, domain, user ids or the client ip address.
        # Don't apply rate limits to localhost.
        default_whitelist = ['127.0.0.1', 'localhost']
        config_whitelist = self.config.get('whitelist', [])
        self.whitelist = default_whitelist + config_whitelist
        self.whitelist_users = self.config.get('whitelist_users', [])

        self.blacklist = self.config.get('blacklist', [])
        self.blacklist_users = self.config.get('blacklist_users', [])

        # Mapping of potentially multiple CADF actions to one action.
        self.rate_limit_groups = self.config.get('groups', {})

        # Configurable scope in which a rate limit is applied. Defaults to initiator project id.
        # Rate limits are applied based on the tuple of (rate_limit_by, action, target_type_uri).
        self.rate_limit_by = self.__conf.get(
            'rate_limit_by', common.Constants.initiator_project_id)

        # Accuracy of the request timestamps used. Defaults to nanosecond accuracy.
        clock_accuracy = int(
            1 / units.Units.parse(self.__conf.get('clock_accuracy', '1ns')))

        # The backend is handed the rate limit response created by _setup_response() above.
        self.backend = rate_limit_backend.RedisBackend(
            host=self.backend_host,
            port=self.backend_port,
            rate_limit_response=self.ratelimit_response,
            max_sleep_time_seconds=max_sleep_time_seconds,
            log_sleep_time_seconds=log_sleep_time_seconds,
            timeout_seconds=backend_timeout_seconds,
            max_connections=backend_max_connections,
            clock_accuracy=clock_accuracy,
        )

        # Test if the backend is ready.
        # An unavailable backend is only logged and does not abort the setup.
        is_available, msg = self.backend.is_available()
        if not is_available:
            self.logger.warning(
                "rate limit not possible. the backend is not available: {0}".
                format(msg))

        # Provider for rate limits. Defaults to configuration file.
        # Also supports Limes.
        configuration_ratelimit_provider = provider.ConfigurationRateLimitProvider(
            service_type=self.service_type)

        # Force load of rate limits from configuration file.
        # NOTE(review): config_file is None if not configured - confirm the provider handles that.
        configuration_ratelimit_provider.read_rate_limits_from_config(
            config_file)
        self.ratelimit_provider = configuration_ratelimit_provider

        # If limes is enabled and we want to rate limit by initiator|target project id,
        # Set LimesRateLimitProvider as the provider for rate limits.
        # NOTE(review): the raw option is tested for truthiness. If paste.ini delivers strings,
        # a configured 'False' would enable Limes - confirm whether a bool conversion is needed.
        limes_enabled = self.__conf.get('limes_enabled', False)
        if limes_enabled:
            self.__setup_limes_ratelimit_provider()

        self.logger.info("OpenStack Rate Limit Middleware ready for requests.")

    def _setup_response(self):
        """
        Setup the responses returned for rate limited and blacklisted requests.

        Starts with the default RateLimitExceededResponse and BlacklistResponse and replaces
        each of them if a complete custom response is found in the configuration.
        Any error while doing so is logged and the responses built so far are kept.
        """
        ratelimit_response = response.RateLimitExceededResponse()
        blacklist_response = response.BlacklistResponse()

        def customized(config_key, response_cls, fallback):
            # Build the custom response for the given configuration key or keep the fallback.
            custom_config = self.config.get(config_key)
            if not custom_config:
                return fallback

            status, status_code, headers, body, json_body = \
                response.response_parameters_from_config(custom_config)

            # A custom response requires the status, the code and some kind of body.
            if not (status and status_code and (body or json_body)):
                return fallback

            return response_cls(
                status=status,
                status_code=status_code,
                headerlist=headers,
                body=body,
                json_body=json_body)

        try:
            ratelimit_response = customized(
                common.Constants.ratelimit_response,
                response.RateLimitExceededResponse,
                ratelimit_response)
            blacklist_response = customized(
                common.Constants.blacklist_response,
                response.BlacklistResponse,
                blacklist_response)

        except Exception as error:
            self.logger.debug(
                "error configuring custom responses. falling back to defaults: {0}"
                .format(str(error)))

        finally:
            # Always expose both responses, whether customized or not.
            self.ratelimit_response = ratelimit_response
            self.blacklist_response = blacklist_response

    def __setup_limes_ratelimit_provider(self):
        """
        Use Limes as the provider for rate limits.

        If the provider cannot be created the error is logged and the
        currently set provider is kept.
        """
        try:
            settings = self.__conf
            limes_options = dict(
                service_type=self.service_type,
                redis_host=self.backend_host,
                redis_port=self.backend_port,
                refresh_interval_seconds=settings.get(
                    common.Constants.limes_refresh_interval_seconds, 300),
                limes_api_uri=settings.get(common.Constants.limes_api_uri),
                auth_url=settings.get('identity_auth_url'),
                username=settings.get('username'),
                user_domain_name=settings.get('user_domain_name'),
                password=settings.get('password'),
                domain_name=settings.get('domain_name'),
            )
            self.ratelimit_provider = provider.LimesRateLimitProvider(
                **limes_options)

        except Exception as error:
            message = "failed to setup limes rate limit provider: {0}".format(
                str(error))
            self.logger.debug(message)

    @classmethod
    def factory(cls, global_config, **local_config):
        conf = global_config.copy()
        conf.update(local_config)

        def limiter(app):
            return cls(app, **conf)

        return limiter

    def _rate_limit(self, scope, action, target_type_uri, **kwargs):
        """
        Check the whitelist, blacklist, global and local ratelimits.

        The checks are done in this order and the first one that matches decides the result:
        user whitelist, user blacklist (both only if a username is given), scope whitelist,
        scope blacklist, global rate limit, local rate limit.

        :param scope: the scope of the request
        :param action: the action of the request
        :param target_type_uri: the target type URI of the request
        :param kwargs: optional 'username' (name of the initiator user) and
                       'scope_name_key' (key of the scope in the format $domainName/$projectName)
        :return: None or BlacklistResponse or RateLimitResponse
        """
        # Labels used for all metrics.
        metric_labels = [
            'service:{0}'.format(self.service_type),
            'service_name:{0}'.format(self.cadf_service_name),
            'action:{0}'.format(action),
            '{0}:{1}'.format(self.rate_limit_by, scope),
            'target_type_uri:{0}'.format(target_type_uri)
        ]
        # These are copies. Labels appended to metric_labels below (action_group,
        # initiator_user_name, initiator_project_name) are not part of them.
        # NOTE(review): confirm the rate limit metrics are meant to omit those labels.
        global_metric_labels = metric_labels + ['level:global']
        local_metric_labels = metric_labels + ['level:local']

        # Check whether a set of CADF actions are accounted together.
        # From here on the action of the group is used for the lookup of rate limits.
        new_action = self.get_action_from_rate_limit_groups(action)
        if not common.is_none_or_unknown(new_action):
            action = new_action
            metric_labels.append('action_group:{0}'.format(action))

        # Get CADF service name and trim from target_type_uri.
        # The provider and backend get the trimmed URI. Metrics and logs use the original one.
        trimmed_target_type_uri = target_type_uri
        if not common.is_none_or_unknown(self.cadf_service_name):
            trimmed_target_type_uri = self._trim_cadf_service_prefix_from_target_type_uri(
                self.cadf_service_name, target_type_uri)

        username = kwargs.get('username', None)
        # If we have the username: Check whether the user is white- or blacklisted.
        # The whitelist is checked first and wins over the blacklist.
        if username:
            metric_labels.append('initiator_user_name:{}'.format(username))
            if self.is_user_whitelisted(username):
                self.logger.debug(
                    "user {0} is whitelisted. skipping rate limit".format(
                        username))
                self.metricsClient.increment(
                    common.Constants.metric_requests_whitelisted_total,
                    tags=metric_labels)
                return None

            if self.is_user_blacklisted(username):
                self.logger.debug(
                    "user {0} is blacklisted. returning BlacklistResponse".
                    format(username))
                self.metricsClient.increment(
                    common.Constants.metric_requests_blacklisted_total,
                    tags=metric_labels)
                return self.blacklist_response

        # The key of the scope in the format $domainName/projectName.
        scope_name_key = kwargs.get('scope_name_key', None)
        if scope_name_key:
            metric_labels.append(
                'initiator_project_name:{}'.format(scope_name_key))

        # Check whitelist. If scope is whitelisted break here and don't apply any rate limits.
        # Both the scope and the scope name key are checked.
        if self.is_scope_whitelisted(scope) or self.is_scope_whitelisted(
                scope_name_key):
            self.logger.debug(
                "scope {0} (key: {1}) is whitelisted. skipping rate limit".
                format(scope, scope_name_key))
            self.metricsClient.increment(
                common.Constants.metric_requests_whitelisted_total,
                tags=metric_labels)
            return None

        # Check blacklist. If scope is blacklisted return BlacklistResponse.
        if self.is_scope_blacklisted(scope) or self.is_scope_blacklisted(
                scope_name_key):
            self.logger.debug(
                "scope {0} (key: {1}) is blacklisted. returning BlacklistResponse"
                .format(scope, scope_name_key))
            self.metricsClient.increment(
                common.Constants.metric_requests_blacklisted_total,
                tags=metric_labels)
            return self.blacklist_response

        # Get global rate limits from the provider.
        global_rate_limit = self.ratelimit_provider.get_global_rate_limits(
            action, trimmed_target_type_uri)

        # Don't rate limit if limit=-1 or unknown.
        if not common.is_unlimited(global_rate_limit):
            self.logger.debug(
                "global rate limit configured for request with action '{0}', target type URI '{1}': '{2}'"
                .format(action, target_type_uri, global_rate_limit))

            # Check global rate limits.
            # Global rate limits enforce a backend protection by counting all requests independent of their scope.
            rate_limit_response = self.backend.rate_limit(
                scope=None,
                action=action,
                target_type_uri=trimmed_target_type_uri,
                max_rate_string=global_rate_limit)
            if rate_limit_response:
                self.metricsClient.increment(
                    common.Constants.metric_requests_ratelimit_total,
                    tags=global_metric_labels)
                return rate_limit_response

        # Get local (for a certain scope) rate limits from provider.
        # Only reached if the global rate limit did not produce a response.
        local_rate_limit = self.ratelimit_provider.get_local_rate_limits(
            scope, action, trimmed_target_type_uri)

        # Don't rate limit for rate_limit=-1 or if unknown.
        if not common.is_unlimited(local_rate_limit):
            self.logger.debug(
                "local rate limit configured for request with action '{0}', target type URI '{1}', scope '{2}': '{3}'"
                .format(action, target_type_uri, scope, local_rate_limit))

            # Check local (for a specific scope) rate limits.
            rate_limit_response = self.backend.rate_limit(
                scope=scope,
                action=action,
                target_type_uri=trimmed_target_type_uri,
                max_rate_string=local_rate_limit)
            if rate_limit_response:
                self.metricsClient.increment(
                    common.Constants.metric_requests_ratelimit_total,
                    tags=local_metric_labels)
                return rate_limit_response

        # Neither white-/blacklisted nor rate limited.
        return None

    def __call__(self, environ, start_response):
        """
        WSGI entry point. Wraps environ in webob.Request.

        :param environ: the WSGI environment dict
        :param start_response: WSGI callable
        """
        # Save the app's response so it can be returned easily.
        resp = self.app

        try:
            self.metricsClient.open_buffer()

            # If the service type and/or service name is not configured,
            # attempt to extract watcher classification from environ and set it.
            self._set_service_type_and_name(environ)

            # Get openstack-watcher-middleware classification from requests environ.
            scope, action, target_type_uri = self.get_scope_action_target_type_uri_from_environ(
                environ)

            # Don't rate limit if any of scope, action, target type URI cannot be determined.
            if common.is_none_or_unknown(scope) or \
               common.is_none_or_unknown(action) or \
               common.is_none_or_unknown(target_type_uri):
                path = str(environ.get('PATH_INFO', common.Constants.unknown))
                method = str(
                    environ.get('REQUEST_METHOD', common.Constants.unknown))
                self.logger.debug(
                    "unknown request: action: {0}, target_type_uri: {1}, scope: {2}, method: {3}, path: {4}"
                    .format(action, target_type_uri, scope, method, path))

                self.metricsClient.increment(
                    common.Constants.metric_requests_unknown_classification,
                    tags=[
                        'service:{0}'.format(self.service_type),
                        'service_name:{0}'.format(self.cadf_service_name),
                    ])
                return

            # Returns a RateLimitResponse or BlacklistResponse or None, in which case the original response is returned.
            rate_limit_response = self._rate_limit(
                scope=scope,
                action=action,
                target_type_uri=target_type_uri,
                scope_name_key=self._get_scope_name_key_from_environ(environ),
                username=self._get_username_from_environ(environ),
            )
            if rate_limit_response:
                rate_limit_response.set_environ(environ)
                resp = rate_limit_response

        except Exception as e:
            self.metricsClient.increment(common.Constants.metric_errors_total)
            self.logger.debug("checking rate limits failed with: {0}".format(
                str(e)))

        finally:
            self.metricsClient.close_buffer()
            return resp(environ, start_response)

    def is_scope_blacklisted(self, key_to_check):
        """
        Check whether a scope (user_id, project_id or client ip) is blacklisted.

        :param key_to_check: the user, project uid or client ip
        :return: bool whether the key is blacklisted
        """
        for entry in self.blacklist:
            if entry == key_to_check:
                return True
        return False

    def is_user_blacklisted(self, user_to_check):
        """
        Check whether a user is blacklisted.

        :param user_to_check: the name of the user to check
        :return: bool whether user is blacklisted
        """
        for u in self.blacklist_users:
            if str(u).lower() == str(user_to_check).lower():
                return True
        return False

    def is_scope_whitelisted(self, key_to_check):
        """
        Check whether a scope (user_id, project_id or client ip) is whitelisted.

        :param key_to_check: the user, project uid or client ip
        :return: bool whether the key is whitelisted
        """
        for entry in self.whitelist:
            if entry == key_to_check:
                return True
        return False

    def is_user_whitelisted(self, user_to_check):
        """
        Check whether a user is whitelisted.

        :param user_to_check: the name of the user to check
        :return: bool whether user is whitelisted
        """
        for u in self.whitelist_users:
            if str(u).lower() == str(user_to_check).lower():
                return True
        return False

    def get_scope_action_target_type_uri_from_environ(self, environ):
        """
        Get the scope, action, target type URI from the request environ.

        :param environ: the request environ
        :return: tuple of scope, action, target type URI
        """
        action = target_type_uri = scope = None
        try:
            # Get the CADF action.
            env_action = environ.get('WATCHER.ACTION')
            if not common.is_none_or_unknown(env_action):
                action = env_action

            # Get the target type URI.
            env_target_type_uri = environ.get('WATCHER.TARGET_TYPE_URI')
            if not common.is_none_or_unknown(env_target_type_uri):
                target_type_uri = env_target_type_uri

            # Get scope from request environment, which might be an initiator.project_id, target.project_id, etc. .
            env_scope = self._get_scope_from_environ(environ)
            if not common.is_none_or_unknown(env_scope):
                scope = env_scope

        except Exception as e:
            self.logger.debug(
                "error while getting scope, action, target type URI from environ: {0}"
                .format(str(e)))

        finally:
            self.logger.debug(
                'got WATCHER.* attributes from environ: action: {0}, target_type_uri: {1}, scope: {2}'
                .format(action, target_type_uri, scope))
            return scope, action, target_type_uri

    def _get_scope_from_environ(self, environ):
        """
        Read the scope of the request from its environ.

        Which WATCHER.* attribute is used depends on the configured rate_limit_by:
        the target project id, the initiator host address or, in any other case,
        the initiator project id.

        :param environ: the requests environ
        :return: the scope or None if it is missing or unknown
        """
        # Pick the attribute holding the scope.
        if self.rate_limit_by == common.Constants.target_project_id:
            environ_key = 'WATCHER.TARGET_PROJECT_ID'
        elif self.rate_limit_by == common.Constants.initiator_host_address:
            environ_key = 'WATCHER.INITIATOR_HOST_ADDRESS'
        else:
            environ_key = 'WATCHER.INITIATOR_PROJECT_ID'

        candidate = environ.get(environ_key, None)

        # Ensure the scope is not 'unknown'.
        return None if common.is_none_or_unknown(candidate) else candidate

    def _get_scope_name_key_from_environ(self, environ):
        """
        Build the key '$domainName/$projectName' from the WATCHER attributes of the request environ.

        The name of the initiator project domain is preferred over the name of the initiator domain.

        :param environ: the request environ
        :return: the key or None if the project or domain name is missing or unknown
        """
        project_name = environ.get('WATCHER.INITIATOR_PROJECT_NAME', None)
        project_domain = environ.get(
            'WATCHER.INITIATOR_PROJECT_DOMAIN_NAME', None)
        initiator_domain = environ.get('WATCHER.INITIATOR_DOMAIN_NAME', None)
        domain_name = project_domain or initiator_domain

        # Both parts are required to build the key.
        has_both = not (common.is_none_or_unknown(project_name)
                        or common.is_none_or_unknown(domain_name))
        if has_both:
            return '{0}/{1}'.format(domain_name, project_name)
        return None

    def _get_username_from_environ(self, environ):
        """
        Read the name of the initiator user from the WATCHER attributes of the request environ.

        :param environ: the request environ
        :return: the username or None if it is missing or unknown
        """
        candidate = environ.get('WATCHER.INITIATOR_USER_NAME', None)
        return None if common.is_none_or_unknown(candidate) else candidate

    def _set_service_type_and_name(self, environ):
        """
        Adopt the service type and CADF service name from the watcher classification in the WSGI environ.

        Only values that are not yet set (none or unknown) are taken over.
        Adopted values are passed on to the rate limit provider.

        :param environ: the request WSGI environment
        """
        # Service type: keep a configured one, otherwise use the watcher's.
        if common.is_none_or_unknown(self.service_type):
            watcher_service_type = environ.get('WATCHER.SERVICE_TYPE')
            if not common.is_none_or_unknown(watcher_service_type):
                self.service_type = watcher_service_type
                self.ratelimit_provider.service_type = watcher_service_type

        # CADF service name: keep a configured one, otherwise use the watcher's.
        if common.is_none_or_unknown(self.cadf_service_name):
            watcher_service_name = environ.get('WATCHER.CADF_SERVICE_NAME')
            if not common.is_none_or_unknown(watcher_service_name):
                self.cadf_service_name = watcher_service_name
                self.ratelimit_provider.cadf_service_name = watcher_service_name

    def _trim_cadf_service_prefix_from_target_type_uri(self, prefix,
                                                       target_type_uri):
        """
        Get cadf service name and trim from target_type_uri.

        Example:
            target_type_uri:            service/storage/object/account/container/object
            cadf_service_name:          service/storage/object
            => trimmed_target_type_uri: account/container/object

        :param prefix: the cadf service name prefixing the target_type_uri
        :param target_type_uri: the target_type_uri with the prefix
        :return: target_type_uri without prefix
        """
        target_type_uri_without_prefix = target_type_uri
        try:
            without_prefix = target_type_uri.split(prefix)
            if len(without_prefix) != 2:
                raise IndexError
            target_type_uri_without_prefix = without_prefix[-1].lstrip('/')
        except IndexError as e:
            self.logger.warning(
                "rate limiting might not be possible. cannot trim prefix '{0}' from target_type_uri '{1}': {2}"
                .format(prefix, target_type_uri, str(e)))
        finally:
            return target_type_uri_without_prefix

    def get_action_from_rate_limit_groups(self, action):
        """
        Multiple CADF actions can be grouped and accounted as one entity.

        :param action: the original CADF action
        :return: the original action or action as per grouping
        """
        for group in self.rate_limit_groups:
            if action in self.rate_limit_groups[group]:
                return group
        return action
class OpenStackWatcherMiddleware(object):
    """
    OpenStack Watcher Middleware

    Watches OpenStack traffic and classifies according to CADF standard.

    For every request the initiator, target, target type URI and CADF action
    are determined and written to the WSGI environ using 'WATCHER.*' keys.
    After the wrapped application has been called, request metrics are
    emitted via StatsD.
    """
    def __init__(self, app, config, logger=logging.getLogger(__name__)):
        """
        :param app: the wrapped WSGI application
        :param config: the WSGI (paste) configuration dictionary
        :param logger: the logger used by this middleware
        """
        self.logger = logger
        self.app = app
        self.wsgi_config = config
        self.watcher_config = {}

        self.cadf_service_name = self.wsgi_config.get('cadf_service_name',
                                                      None)
        self.service_type = self.wsgi_config.get('service_type',
                                                 taxonomy.UNKNOWN)
        # get the project uid from the request path or from the token (default)
        self.is_project_id_from_path = common.string_to_bool(
            self.wsgi_config.get('target_project_id_from_path', 'False'))
        # get the project id from the service catalog (see documentation on keystone auth_token middleware)
        self.is_project_id_from_service_catalog = common.string_to_bool(
            self.wsgi_config.get('target_project_id_from_service_catalog',
                                 'False'))

        # whether to include the target project id in the metrics
        self.is_include_target_project_id_in_metric = common.string_to_bool(
            self.wsgi_config.get('include_target_project_id_in_metric',
                                 'True'))
        # whether to include the target domain id in the metrics
        # NOTE(review): this flag is read here but not used by __call__.
        self.is_include_target_domain_id_in_metric = common.string_to_bool(
            self.wsgi_config.get('include_target_domain_id_in_metric', 'True'))
        # whether to include the initiator user id in the metrics
        self.is_include_initiator_user_id_in_metric = common.string_to_bool(
            self.wsgi_config.get('include_initiator_user_id_in_metric',
                                 'False'))

        # the optional watcher configuration file provides custom actions,
        # path keywords, keyword exclusions and regex path mappings
        config_file_path = config.get('config_file', None)
        if config_file_path:
            try:
                self.watcher_config = load_config(config_file_path)
            except errors.ConfigError as e:
                self.logger.debug("custom actions not available: %s", str(e))

        custom_action_config = self.watcher_config.get('custom_actions', {})
        path_keywords = self.watcher_config.get('path_keywords', {})
        keyword_exclusions = self.watcher_config.get('keyword_exclusions', {})
        regex_mapping = self.watcher_config.get('regex_path_mapping', {})

        # init the strategy used to determine the target type uri
        strat = STRATEGIES.get(self.service_type, strategies.BaseCADFStrategy)

        # set custom prefix to target type URI or use defaults
        target_type_uri_prefix = common.SERVICE_TYPE_CADF_PREFIX_MAP.get(
            self.service_type, 'service/{0}'.format(self.service_type))

        if self.cadf_service_name:
            target_type_uri_prefix = self.cadf_service_name

        self.strategy = strat(target_type_uri_prefix=target_type_uri_prefix,
                              path_keywords=path_keywords,
                              keyword_exclusions=keyword_exclusions,
                              custom_action_config=custom_action_config,
                              regex_mapping=regex_mapping)

        self.metric_client = DogStatsd(
            host=self.wsgi_config.get("statsd_host", "127.0.0.1"),
            port=int(self.wsgi_config.get("statsd_port", 9125)),
            namespace=self.wsgi_config.get("statsd_namespace",
                                           "openstack_watcher"))

    @classmethod
    def factory(cls, global_config, **local_config):
        """
        Paste filter factory. Local configuration overrides global one.

        :param global_config: the global paste configuration
        :param local_config: the filter specific configuration
        :return: a callable wrapping a WSGI app in this middleware
        """
        conf = global_config.copy()
        conf.update(local_config)

        def watcher(app):
            return cls(app, conf)

        return watcher

    def __call__(self, environ, start_response):
        """
        WSGI entry point. Wraps environ in webob.Request

        :param environ: the WSGI environment dict
        :param start_response: WSGI callable
        :return: the response of the wrapped application
        """

        # capture start timestamp
        start = time.time()

        req = Request(environ)

        # determine initiator based on token context
        initiator_project_id = self.get_safe_from_environ(
            environ, 'HTTP_X_PROJECT_ID')
        initiator_project_name = self.get_safe_from_environ(
            environ, 'HTTP_X_PROJECT_NAME')
        initiator_project_domain_id = self.get_safe_from_environ(
            environ, 'HTTP_X_PROJECT_DOMAIN_ID')
        initiator_project_domain_name = self.get_safe_from_environ(
            environ, 'HTTP_X_PROJECT_DOMAIN_NAME')
        initiator_domain_id = self.get_safe_from_environ(
            environ, 'HTTP_X_DOMAIN_ID')
        initiator_domain_name = self.get_safe_from_environ(
            environ, 'HTTP_X_DOMAIN_NAME')
        initiator_user_id = self.get_safe_from_environ(environ,
                                                       'HTTP_X_USER_ID')
        initiator_user_name = self.get_safe_from_environ(
            environ, 'HTTP_X_USER_NAME')
        initiator_user_domain_id = self.get_safe_from_environ(
            environ, 'HTTP_X_USER_DOMAIN_ID')
        initiator_user_domain_name = self.get_safe_from_environ(
            environ, 'HTTP_X_USER_DOMAIN_NAME')
        initiator_host_address = req.client_addr or taxonomy.UNKNOWN

        # determine target based on request path or keystone.token_info
        target_project_id = taxonomy.UNKNOWN
        if self.is_project_id_from_path:
            target_project_id = self.get_target_project_uid_from_path(req.path)
        elif self.is_project_id_from_service_catalog:
            target_project_id = self.get_target_project_id_from_keystone_token_info(
                environ.get('keystone.token_info'))

        # default target_project_id to initiator_project_id if still unknown
        if not target_project_id or target_project_id == taxonomy.UNKNOWN:
            target_project_id = initiator_project_id

        # determine target.type_uri for request
        target_type_uri = self.determine_target_type_uri(req)

        # determine cadf_action for request. consider custom action config.
        cadf_action = self.determine_cadf_action(req, target_type_uri)

        # if authentication request consider project, domain and user in body
        if self.service_type == 'identity' and cadf_action == taxonomy.ACTION_AUTHENTICATE:
            initiator_project_id, initiator_domain_id, initiator_user_id = \
                self.get_project_domain_and_user_id_from_keystone_authentication_request(req)

        cadf_service_name = self.strategy.get_cadf_service_name()

        # set environ for initiator
        environ['WATCHER.INITIATOR_PROJECT_ID'] = initiator_project_id
        environ['WATCHER.INITIATOR_PROJECT_NAME'] = initiator_project_name
        environ[
            'WATCHER.INITIATOR_PROJECT_DOMAIN_ID'] = initiator_project_domain_id
        environ[
            'WATCHER.INITIATOR_PROJECT_DOMAIN_NAME'] = initiator_project_domain_name
        environ['WATCHER.INITIATOR_DOMAIN_ID'] = initiator_domain_id
        environ['WATCHER.INITIATOR_DOMAIN_NAME'] = initiator_domain_name
        environ['WATCHER.INITIATOR_USER_ID'] = initiator_user_id
        environ['WATCHER.INITIATOR_USER_NAME'] = initiator_user_name
        environ['WATCHER.INITIATOR_USER_DOMAIN_ID'] = initiator_user_domain_id
        environ[
            'WATCHER.INITIATOR_USER_DOMAIN_NAME'] = initiator_user_domain_name
        environ['WATCHER.INITIATOR_HOST_ADDRESS'] = initiator_host_address

        # set environ for target
        environ['WATCHER.TARGET_PROJECT_ID'] = target_project_id
        environ['WATCHER.TARGET_TYPE_URI'] = target_type_uri

        # general cadf attributes
        environ['WATCHER.ACTION'] = cadf_action
        environ['WATCHER.SERVICE_TYPE'] = self.service_type
        environ['WATCHER.CADF_SERVICE_NAME'] = cadf_service_name

        # labels applied to all metrics emitted by this middleware
        labels = [
            "service_name:{0}".format(cadf_service_name),
            "service:{0}".format(self.service_type),
            "action:{0}".format(cadf_action),
            "target_type_uri:{0}".format(target_type_uri),
        ]

        # additional labels not needed in all metrics
        detail_labels = labels + [
            "initiator_project_id:{0}".format(initiator_project_id),
            "initiator_domain_id:{0}".format(initiator_domain_id),
        ]

        # include the target project id in metric
        if self.is_include_target_project_id_in_metric:
            detail_labels.append(
                "target_project_id:{0}".format(target_project_id))

        # include initiator user id
        if self.is_include_initiator_user_id_in_metric:
            detail_labels.append(
                "initiator_user_id:{0}".format(initiator_user_id))

        # if swift request: determine target.container_id based on request path
        if common.is_swift_request(
                req.path) or self.service_type == 'object-store':
            _, target_container_id = self.get_target_account_container_id_from_request(
                req)
            environ['WATCHER.TARGET_CONTAINER_ID'] = target_container_id

        self.logger.debug(
            'got request with initiator_project_id: {0}, initiator_domain_id: {1}, initiator_user_id: {2}, '
            'target_project_id: {3}, action: {4}, target_type_uri: {5}'.format(
                initiator_project_id, initiator_domain_id, initiator_user_id,
                target_project_id, cadf_action, target_type_uri))

        # capture the response status
        response_wrapper = {}

        try:

            def _start_response_wrapper(status, headers, exc_info=None):
                response_wrapper.update(status=status,
                                        headers=headers,
                                        exc_info=exc_info)
                return start_response(status, headers, exc_info)

            return self.app(environ, _start_response_wrapper)
        finally:
            # metrics are submitted on a best-effort basis and must never
            # break the request
            try:
                self.metric_client.open_buffer()

                status = response_wrapper.get('status')
                if status:
                    status_code = status.split()[0]
                else:
                    status_code = taxonomy.UNKNOWN

                labels.append("status:{0}".format(status_code))
                detail_labels.append("status:{0}".format(status_code))

                self.metric_client.timing(
                    'api_requests_duration_seconds',
                    int(round(1000 * (time.time() - start))),
                    tags=labels)
                self.metric_client.increment('api_requests_total',
                                             tags=detail_labels)
            except Exception as e:
                self.logger.debug("failed to submit metrics for %s: %s" %
                                  (str(labels), str(e)))
            finally:
                self.metric_client.close_buffer()

    def get_safe_from_environ(self, environ, key, default=taxonomy.UNKNOWN):
        """
        get value for a key from the environ dict ensuring it's never None or an empty string

        :param environ: the request environ
        :param key: the key in the environ dictionary
        :param default: return value if key not found, or its value is falsy
        :return: the value to the key or default
        """
        try:
            value = environ.get(key, default)
            if value:
                return value
        except Exception as e:
            self.logger.debug("error getting '{0}' from environ: {1}".format(
                key, e))
        return default

    def get_target_project_uid_from_path(self, path):
        """
        get the project uid from the path, which should look like
        ../v1.2/<project_uid>/.. or ../v1/AUTH_<project_uid>/..

        Errors are logged at debug level and never raised.

        :param path: the request path containing a project uid
        :return: the project uid or unknown
        """
        project_uid = taxonomy.UNKNOWN
        try:
            if common.is_swift_request(
                    path) and self.strategy.name == 'object-store':
                project_uid = self.strategy.get_swift_project_id_from_path(
                    path)
            else:
                # the path must be passed on; calling the helper without it
                # always failed and left the project uid unknown
                project_uid = common.get_project_id_from_os_path(path)
        except Exception as e:
            self.logger.debug(
                "error obtaining target.project_id from request path '{0}': {1}"
                .format(path, e))

        if not project_uid or project_uid == taxonomy.UNKNOWN:
            self.logger.debug(
                "unable to obtain target.project_id from request path '{0}'"
                .format(path))
        else:
            self.logger.debug(
                "request path '{0}' contains target.project_id '{1}'".
                format(path, project_uid))
        return project_uid

    def get_target_project_id_from_keystone_token_info(self, token_info):
        """
        the token info dict contains the service catalog, in which the project specific
        endpoint urls per service can be found.

        Only catalog entries matching the service type of this middleware are
        considered. Errors are logged at debug level and never raised.

        :param token_info: token info dictionary (may be None)
        :return: the project id or unknown
        """
        project_id = taxonomy.UNKNOWN
        try:
            service_catalog = (token_info or {}).get('token',
                                                     {}).get('catalog', [])

            for service in service_catalog or []:
                svc_type = service.get('type', None)
                if not svc_type or svc_type != self.service_type:
                    continue

                svc_endpoints = service.get('endpoints', None)
                if not svc_endpoints:
                    continue

                candidate = self._get_project_id_from_service_endpoints(
                    svc_endpoints)
                # keep looking in further catalog entries of the same type
                # until an actual project id was found
                if candidate and candidate != taxonomy.UNKNOWN:
                    project_id = candidate
                    break

        except Exception as e:
            self.logger.debug(
                'unable to get target.project_id from service catalog: %s',
                str(e))

        if project_id == taxonomy.UNKNOWN:
            self.logger.debug(
                "unable to get target.project_id '{0}' for service type '{1}' from service catalog"
                .format(project_id, self.service_type))
        else:
            self.logger.debug(
                "got target.project_id '{0}' for service type '{1}' from service catalog"
                .format(project_id, self.service_type))
        return project_id

    def _get_project_id_from_service_endpoints(self,
                                               endpoint_list,
                                               endpoint_type=None):
        """
        get the project id from an endpoint url for a given type | type = {public,internal,admin}

        Errors are logged at debug level and never raised.

        :param endpoint_list: list of endpoints
        :param endpoint_type: optional endpoint type
        :return: the project id or unknown
        """
        project_id = taxonomy.UNKNOWN
        try:
            for endpoint in endpoint_list:
                url = endpoint.get('url', None)
                interface = endpoint.get('interface', None)
                if not url or not interface:
                    continue

                if self.strategy.name == 'object-store':
                    project_id = self.strategy.get_swift_project_id_from_path(
                        url)
                else:
                    project_id = common.get_project_id_from_os_path(url)

                # break here if endpoint_type is given and types match
                if endpoint_type and endpoint_type.lower() == interface.lower():
                    break
                # break here if no endpoint_type given but project id was found
                elif not endpoint_type and project_id != taxonomy.UNKNOWN:
                    break
        except Exception as e:
            self.logger.debug(
                "error getting project id from endpoints for service type '{0}': {1}"
                .format(self.service_type, e))

        if project_id == taxonomy.UNKNOWN:
            self.logger.debug(
                "found no project id in endpoints for service type '{0}'".
                format(self.service_type))
        else:
            self.logger.debug(
                "found target project id '{0}' in endpoints for service type '{1}'"
                .format(project_id, self.service_type))
        return project_id

    def get_project_domain_and_user_id_from_keystone_authentication_request(
            self, req):
        """
        get project, domain, user id from authentication request.
        used in combination with client_addr to determine which client authenticates in which scope

        Always returns a 3-tuple. Values that cannot be determined are unknown.

        :param req: the request
        :return: project_id, domain_id, user_id
        """
        project_id = domain_id = user_id = taxonomy.UNKNOWN
        try:
            body = req.json
            json_body_dict = common.load_json_dict(body) if body else None
            if json_body_dict:
                project_id = common.find_project_id_in_auth_dict(
                    json_body_dict)
                domain_id = common.find_domain_id_in_auth_dict(json_body_dict)
                user_id = common.find_user_id_in_auth_dict(json_body_dict)
        except Exception as e:
            self.logger.debug(
                'unable to parse keystone authentication request body: {0}'.
                format(str(e)))
        return project_id, domain_id, user_id

    def get_target_account_container_id_from_request(self, req):
        """
        get swift account id, container name from request

        :param req: the request
        :return: account uid, container name or unknown
        """
        # break here if we don't have the object-store strategy
        if self.strategy.name != 'object-store':
            return taxonomy.UNKNOWN, taxonomy.UNKNOWN

        account_id, container_id, _ = self.strategy.get_swift_account_container_object_id_from_path(
            req.path)
        return account_id, container_id

    def determine_target_type_uri(self, req):
        """
        determine the target type uri as per concrete strategy

        :param req: the request
        :return: the target type uri or taxonomy.UNKNOWN
        """
        target_type_uri = self.strategy.determine_target_type_uri(req)
        self.logger.debug(
            "target type URI of requests '{0} {1}' is '{2}'".format(
                req.method, req.path, target_type_uri))
        return target_type_uri

    def determine_cadf_action(self, req, target_type_uri=None):
        """
        determine the cadf action for a request as per concrete strategy.

        the strategy is expected to attempt the following order:
        (1) return custom action if one is configured
        (2) if /action, /os-instance-action request, return action from request body
        (3) return action based on request method

        :param req: the request
        :param target_type_uri: the target type URI
        :return: the cadf action or unknown
        """
        cadf_action = self.strategy.determine_cadf_action(req, target_type_uri)
        self.logger.debug("cadf action for '{0} {1}' is '{2}'".format(
            req.method, req.path, cadf_action))
        return cadf_action
class StatsdMiddleware(object):
    """
    WSGI middleware emitting request and response metrics via StatsD.

    Counts requests and responses and measures the latency per API, where
    the API is derived from the request path after applying the configured
    replace strategy.
    """
    def __init__(self,
                 app,
                 statsd_host='localhost',
                 statsd_port='8125',
                 statsd_prefix='openstack',
                 statsd_replace='id'):
        """
        Environment variables STATSD_HOST, STATSD_PORT, STATSD_PREFIX and
        STATSD_REPLACE take precedence over the given arguments.

        :param app: the wrapped WSGI application
        :param statsd_host: host of the StatsD server
        :param statsd_port: port of the StatsD server
        :param statsd_prefix: namespace prepended to all metrics
        :param statsd_replace: configuration of the path replace strategy
        """
        self.app = app
        self.replace_strategy = _ReplaceStrategy(
            os.getenv('STATSD_REPLACE', statsd_replace))
        self.client = DogStatsd(host=os.getenv('STATSD_HOST', statsd_host),
                                port=int(os.getenv('STATSD_PORT',
                                                   statsd_port)),
                                namespace=os.getenv('STATSD_PREFIX',
                                                    statsd_prefix))

    @classmethod
    def factory(cls, global_config, **local_config):
        """
        Paste filter factory. Only the local configuration is used.

        :param global_config: the global paste configuration (ignored)
        :param local_config: keyword arguments passed to the constructor
        :return: a callable wrapping a WSGI app in this middleware
        """
        def _factory(app):
            return cls(app, **local_config)

        return _factory

    def process_response(self,
                         start,
                         environ,
                         response_wrapper,
                         exception=None):
        """
        Emit the response counter and latency metrics for a request.

        :param start: timestamp taken when the request started
        :param environ: the WSGI environment dict
        :param response_wrapper: dict holding the intercepted response status
        :param exception: optional exception raised while handling the
                          request; its class name is appended to the api tag
        """
        self.client.increment('responses_total')

        status = response_wrapper.get('status')
        if status:
            status_code = status.split()[0]
        else:
            status_code = 'none'

        method = environ['REQUEST_METHOD']

        # cleanse request path
        path = urlparse.urlparse(environ['SCRIPT_NAME'] +
                                 environ['PATH_INFO']).path
        # strip extensions
        path = splitext(path)[0]

        # replace parts of the path with constants based on strategy
        path = self.replace_strategy.apply(path)

        # strip trailing slashes and backslashes
        # (the backslash is escaped: '\/' is an invalid escape sequence)
        parts = path.rstrip('\\/').split('/')
        if exception:
            parts.append(exception.__class__.__name__)
        api = '/'.join(parts)

        self.client.timing('latency_by_api',
                           time.time() - start,
                           tags=['method:%s' % method,
                                 'api:%s' % api])

        self.client.increment('responses_by_api',
                              tags=[
                                  'method:%s' % method,
                                  'api:%s' % api,
                                  'status:%s' % status_code
                              ])

    def __call__(self, environ, start_response):
        """
        WSGI entry point. Yields the response of the wrapped application
        and emits metrics once the response has been consumed or failed.

        :param environ: the WSGI environment dict
        :param start_response: WSGI callable
        """
        response_interception = {}

        def start_response_wrapper(status, response_headers, exc_info=None):
            response_interception.update(status=status,
                                         response_headers=response_headers,
                                         exc_info=exc_info)
            return start_response(status, response_headers, exc_info)

        start = time.time()
        try:
            self.client.open_buffer()
            self.client.increment('requests_total')

            response = self.app(environ, start_response_wrapper)
            try:
                for event in response:
                    yield event
            finally:
                if hasattr(response, 'close'):
                    response.close()
        except Exception as exception:
            self.process_response(start, environ, response_interception,
                                  exception)
            raise
        else:
            # kept outside of the try body so that an error while emitting
            # the metrics is not processed a second time as if the
            # application had failed
            self.process_response(start, environ, response_interception)
        finally:
            self.client.close_buffer()