Example #1
0
    def wrapper(self, *args, **kwargs):
        # Proxy a call to ``func``.  In server mode the function runs
        # locally.  Otherwise the call is sent over ``self.conn``
        # (connecting first if needed) and the reply is interpreted:
        #   'RES' -- success; payload[0] is the return value.
        #   'EXC' -- the remote function raised; the exception type named
        #            by payload[0] is looked up and raised locally with
        #            payload[1] as its argument.  If that type can't be
        #            resolved, a generic Exception is raised instead.
        #   'ERR' or any other command -- the connection is closed and a
        #            generic Exception is raised.
        if self.mode == 'server':
            # In server mode, call the function
            return func(self, *args, **kwargs)

        # Make sure we're connected
        if not self.conn:
            self.connect()

        # Call the remote function
        self.conn.send('CALL', func.__name__, args, kwargs)

        # Receive the response
        cmd, payload = self.conn.recv()
        if cmd == 'ERR':
            self.close()
            raise Exception("Catastrophic error from server: %s" % payload[0])
        elif cmd == 'EXC':
            # find_entrypoint() returns None when the name can't be
            # resolved; "raise None(...)" would only produce an unrelated
            # TypeError, so fall back to a generic Exception that keeps
            # both the remote type name and its message.
            exc_type = utils.find_entrypoint(None, payload[0])
            if exc_type is None:
                raise Exception("Unknown exception %s from server: %s" %
                                (payload[0], payload[1]))
            raise exc_type(payload[1])
        elif cmd != 'RES':
            self.close()
            raise Exception("Invalid command response from server: %s" % cmd)

        return payload[0]
Example #2
0
    def test_no_group(self, mock_iter_entry_points, mock_parse):
        # Without a group, the name goes straight to the entry-point
        # parser (as "x=<name>"); the group registry is never queried.
        found = utils.find_entrypoint(None, 'spam:ni', compat=False)
        loader = mock_parse.return_value.load

        self.assertEqual(found, 'class')
        self.assertFalse(mock_iter_entry_points.called)
        mock_parse.assert_called_once_with('x=spam:ni')
        loader.assert_called_once_with(False)
Example #3
0
    def wrapper(self, *args, **kwargs):
        # Proxy a call to ``func``.  In server mode the function runs
        # locally.  Otherwise the call is sent over ``self.conn``
        # (connecting first if needed) and the reply is interpreted:
        #   'RES' -- success; payload[0] is the return value.
        #   'EXC' -- the remote function raised; the exception type named
        #            by payload[0] is looked up and raised locally with
        #            payload[1] as its argument.  If that type can't be
        #            resolved, a generic Exception is raised instead.
        #   'ERR' or any other command -- the connection is closed and a
        #            generic Exception is raised.
        if self.mode == 'server':
            # In server mode, call the function
            return func(self, *args, **kwargs)

        # Make sure we're connected
        if not self.conn:
            self.connect()

        # Call the remote function
        self.conn.send('CALL', func.__name__, args, kwargs)

        # Receive the response
        cmd, payload = self.conn.recv()
        if cmd == 'ERR':
            self.close()
            raise Exception("Catastrophic error from server: %s" %
                            payload[0])
        elif cmd == 'EXC':
            # find_entrypoint() returns None when the name can't be
            # resolved; "raise None(...)" would only produce an unrelated
            # TypeError, so fall back to a generic Exception that keeps
            # both the remote type name and its message.
            exc_type = utils.find_entrypoint(None, payload[0])
            if exc_type is None:
                raise Exception("Unknown exception %s from server: %s" %
                                (payload[0], payload[1]))
            raise exc_type(payload[1])
        elif cmd != 'RES':
            self.close()
            raise Exception("Invalid command response from server: %s" % cmd)

        return payload[0]
Example #4
0
    def test_straight_load(self, mock_iter_entry_points):
        # A normal group lookup loads and returns the first entry point
        # the registry yields.
        found = utils.find_entrypoint('test.group', 'endpoint')
        first_ep = mock_iter_entry_points.return_value[0]

        self.assertEqual(found, 'ep1')
        mock_iter_entry_points.assert_called_once_with(
            'test.group', 'endpoint')
        first_ep.load.assert_called_once_with()
Example #5
0
    def test_with_compat_importerror(self, mock_iter_entry_points, mock_parse):
        # In compat mode a "module:attr" style name skips the group lookup
        # and goes to the parser; the test name suggests the patched
        # load() raises ImportError, which must turn into a None result.
        found = utils.find_entrypoint('test.group', 'spam:ni')
        loader = mock_parse.return_value.load

        self.assertEqual(found, None)
        self.assertFalse(mock_iter_entry_points.called)
        mock_parse.assert_called_once_with('x=spam:ni')
        loader.assert_called_once_with(False)
Example #6
0
    def test_no_compat(self, mock_iter_entry_points, mock_parse):
        # With compat=False even a "module:attr" style name is only
        # searched for within the group; the parser fallback is unused.
        found = utils.find_entrypoint('test.group', 'spam:ni', compat=False)
        loader = mock_parse.return_value.load

        self.assertEqual(found, None)
        mock_iter_entry_points.assert_called_once_with(
            'test.group', 'spam:ni')
        self.assertFalse(mock_parse.called)
        self.assertFalse(loader.called)
Example #7
0
    def hydrate(cls, db, limit):
        """
        Given a limit dict, as generated by dehydrate(), generate an
        appropriate instance of Limit (or a subclass).  If the
        required limit class cannot be found, returns None.

        :param db: Passed through as the first argument of the limit
                   class constructor.
        :param limit: Dict whose 'limit_class' key names the class to
                      instantiate; every other key becomes a keyword
                      argument.  The caller's dict is not modified.
        """

        # Work on a shallow copy: pop() below would otherwise strip
        # 'limit_class' from the caller's dict as a side effect
        limit = dict(limit)

        # Extract the limit name from the keyword arguments
        cls_name = limit.pop('limit_class')

        # Is it in the registry yet?  If not, attempt the import; the
        # return value is ignored because the class is then read back
        # from the registry (presumably importing it registers it --
        # confirm against the registry mechanism)
        if cls_name not in cls._registry:
            utils.find_entrypoint(None, cls_name)

        # Look it up in the registry
        limit_cls = cls._registry.get(cls_name)

        # Instantiate the thing
        return limit_cls(db, **limit) if limit_cls else None
Example #8
0
    def test_skip_errors(self, mock_iter_entry_points):
        # Entry points whose load() fails (presumably the first two here)
        # are skipped; the first successful load wins and later entry
        # points are never touched.
        found = utils.find_entrypoint('test.group', 'endpoint')

        self.assertEqual(found, 'ep3')
        mock_iter_entry_points.assert_called_once_with(
            'test.group', 'endpoint')
        eps = mock_iter_entry_points.return_value
        for ep in (eps[0], eps[1], eps[2]):
            ep.load.assert_called_once_with()
        self.assertFalse(eps[3].load.called)
Example #9
0
    def hydrate(cls, db, limit):
        """
        Given a limit dict, as generated by dehydrate(), generate an
        appropriate instance of Limit (or a subclass).  If the
        required limit class cannot be found, returns None.

        :param db: Passed through as the first argument of the limit
                   class constructor.
        :param limit: Dict whose 'limit_class' key names the class to
                      instantiate; every other key becomes a keyword
                      argument.  The caller's dict is not modified.
        """

        # Work on a shallow copy: pop() below would otherwise strip
        # 'limit_class' from the caller's dict as a side effect
        limit = dict(limit)

        # Extract the limit name from the keyword arguments
        cls_name = limit.pop('limit_class')

        # Is it in the registry yet?  If not, attempt the import; the
        # return value is ignored because the class is then read back
        # from the registry (presumably importing it registers it --
        # confirm against the registry mechanism)
        if cls_name not in cls._registry:
            utils.find_entrypoint(None, cls_name)

        # Look it up in the registry
        limit_cls = cls._registry.get(cls_name)

        # Instantiate the thing
        return limit_cls(db, **limit) if limit_cls else None
Example #10
0
def turnstile_filter(global_conf, **local_conf):
    """
    Factory producing a Turnstile middleware wrapper.

    :param global_conf: Global configuration; not used here.
    :param local_conf: Middleware configuration.  A "turnstile" key, if
        present, names an entrypoint in the "turnstile.middleware" group
        to use in place of TurnstileMiddleware.

    :returns: A function that takes the application and returns an
        instance of the selected middleware class, constructed with the
        application and ``local_conf``.
    """
    # Resolve the class once, when the factory itself is called
    if "turnstile" in local_conf:
        middleware_class = utils.find_entrypoint("turnstile.middleware", local_conf["turnstile"], required=True)
    else:
        middleware_class = TurnstileMiddleware

    def wrapper(app):
        return middleware_class(app, local_conf)

    return wrapper
Example #11
0
def turnstile_filter(global_conf, **local_conf):
    """
    Factory producing a Turnstile middleware wrapper.

    :param global_conf: Global configuration; not used here.
    :param local_conf: Middleware configuration.  A 'turnstile' key, if
        present, names an entrypoint in the 'turnstile.middleware'
        group to use in place of TurnstileMiddleware.

    :returns: A function that takes the application and returns an
        instance of the selected middleware class, constructed with the
        application and ``local_conf``.
    """
    # Resolve the class once, when the factory itself is called
    if 'turnstile' in local_conf:
        middleware_class = utils.find_entrypoint(
            'turnstile.middleware', local_conf['turnstile'], required=True)
    else:
        middleware_class = TurnstileMiddleware

    def wrapper(app):
        return middleware_class(app, local_conf)

    return wrapper
Example #12
0
def initialize(config):
    """
    Create a Redis client from a configuration mapping.

    Recognized keys:

    * ``redis_client`` -- entrypoint in 'turnstile.redis_client' naming
      the client class (default ``redis.StrictRedis``).
    * ``host``, ``port``, ``db``, ``password``, ``socket_timeout``,
      ``unix_socket_path`` -- connection parameters, converted with
      ``str``/``int`` as appropriate.  At least one of ``host`` or
      ``unix_socket_path`` is required.
    * ``connection_pool.<option>`` -- connection pool options;
      ``connection_class`` and ``parser_class`` are resolved as
      entrypoints, ``max_connections`` is converted to ``int``, and any
      other option is passed through unchanged.
    * ``connection_pool`` -- entrypoint in 'turnstile.connection_pool'
      naming a custom pool class.

    When a pool is used, the connection parameters go to the pool and
    the client receives only ``connection_pool``.

    :raises redis.ConnectionError: if neither ``host`` nor
        ``unix_socket_path`` is configured.
    """
    # Choose the client class
    if 'redis_client' in config:
        client_class = utils.find_entrypoint('turnstile.redis_client',
                                             config['redis_client'],
                                             required=True)
    else:
        client_class = redis.StrictRedis

    # Convert whichever recognized connection parameters are present
    conversions = [('host', str), ('port', int), ('db', int),
                   ('password', str), ('socket_timeout', int),
                   ('unix_socket_path', str)]
    client_args = dict((name, convert(config[name]))
                       for name, convert in conversions
                       if name in config)

    # A TCP host or a UNIX socket path is the bare minimum
    if not ('host' in client_args or 'unix_socket_path' in client_args):
        raise redis.ConnectionError("No host specified for redis database")

    # Gather "connection_pool.<option>" settings; some options name
    # entrypoints that must be resolved to classes
    entrypoint_groups = {
        'connection_class': 'turnstile.connection_class',
        'parser_class': 'turnstile.parser_class',
    }
    pool_args = {}
    for key, value in config.items():
        if not key.startswith('connection_pool.'):
            continue
        option = key.partition('.')[2]
        if option in entrypoint_groups:
            pool_args[option] = utils.find_entrypoint(
                entrypoint_groups[option], value, required=True)
        elif option == 'max_connections':
            pool_args[option] = int(value)
        else:
            pool_args[option] = value

    # An explicitly configured pool class wins; otherwise use the stock
    # pool, but only when some pool options were given
    if 'connection_pool' in config:
        pool_class = utils.find_entrypoint('turnstile.connection_pool',
                                           config['connection_pool'],
                                           required=True)
    elif pool_args:
        pool_class = redis.ConnectionPool
    else:
        pool_class = None

    if pool_class:
        # Connection parameters move to the pool, overriding any
        # same-named pool options
        pool_args.update(client_args)

        # Pick a connection class unless one was configured explicitly
        if 'connection_class' not in pool_args:
            if 'unix_socket_path' in pool_args:
                # Socket connections: drop host/port and hand the socket
                # path over under the name 'path'
                pool_args.pop('host', None)
                pool_args.pop('port', None)
                pool_args['path'] = pool_args.pop('unix_socket_path')
                pool_args['connection_class'] = \
                    redis.UnixDomainSocketConnection
            else:
                pool_args['connection_class'] = redis.Connection

        client_args = {'connection_pool': pool_class(**pool_args)}

    return client_class(**client_args)
Example #13
0
    def __init__(self, app, local_conf):
        """
        Initialize the turnstile middleware.

        Wraps ``app``, builds the configuration from ``local_conf``,
        resolves the request pre- and post-processors and the optional
        alternate formatter, and starts the control daemon (a
        RemoteControlDaemon when the 'control' section's 'remote' option
        is true, otherwise a ControlDaemon).  The database handle is not
        opened here; it starts as None and is loaded lazily.

        Processors are configured in one of two ways:

        * 'enable' -- whitespace-separated names, each looked up in both
          the 'turnstile.preprocessor' and 'turnstile.postprocessor'
          groups and used wherever found.  Postprocessors are stored in
          the reverse of the listed order.
        * otherwise 'preprocess' / 'postprocess' -- separate name lists
          that must resolve in their respective groups, kept in order.
        """
        # Wrapped application and limit/mapper state
        self.app = app
        self.limits = []
        self.limit_sum = None
        self.mapper = None
        self.mapper_lock = eventlet.semaphore.Semaphore()

        self.conf = config.Config(conf_dict=local_conf)

        # Opened on first use rather than here
        self._db = None

        self.preprocessors = []
        self.postprocessors = []
        enabled = self.conf.get('enable')
        if enabled is None:
            # Classic syntax: explicit lists; every name must resolve
            self.preprocessors.extend(
                utils.find_entrypoint('turnstile.preprocessor', name,
                                      required=True)
                for name in self.conf.get('preprocess', '').split())
            self.postprocessors.extend(
                utils.find_entrypoint('turnstile.postprocessor', name,
                                      required=True)
                for name in self.conf.get('postprocess', '').split())
        else:
            # Enabler syntax: a name may supply a preprocessor, a
            # postprocessor, or both
            for name in enabled.split():
                pre = utils.find_entrypoint('turnstile.preprocessor', name,
                                            compat=False)
                if pre:
                    self.preprocessors.append(pre)

                post = utils.find_entrypoint('turnstile.postprocessor',
                                             name, compat=False)
                if post:
                    # Prepend, so postprocessors end up reversed
                    self.postprocessors.insert(0, post)

        formatter_name = self.conf.get('formatter')
        if not formatter_name:
            self.formatter = self.format_delay
        else:
            custom = utils.find_entrypoint('turnstile.formatter',
                                           formatter_name,
                                           required=True)
            # Pass self.conf.status ahead of the five formatter arguments
            self.formatter = lambda a, b, c, d, e: custom(
                self.conf.status, a, b, c, d, e)

        remote_wanted = self.conf.to_bool(
            self.conf['control'].get('remote', 'no'), False)
        if remote_wanted:
            daemon_class = remote.RemoteControlDaemon
        else:
            daemon_class = control.ControlDaemon
        self.control_daemon = daemon_class(self, self.conf)
        self.control_daemon.start()

        LOG.info("Turnstile middleware initialized")
Example #14
0
def parse_limit_node(db, idx, limit):
    """
    Given an XML node describing a limit, return a Limit object.

    The node's 'class' attribute names the limit class, looked up in
    the 'turnstile.limit' entrypoint group.  Each <attr> child supplies
    one constructor keyword argument; its 'name' must appear in the
    class's ``attrs`` mapping, whose descriptor decides the conversion:

    * ``list`` -- each <value> child is converted with the descriptor's
      'subtype' (``str`` by default);
    * ``dict`` -- each <value key="..."> child becomes one entry, its
      text converted with 'subtype' (``str`` by default);
    * ``bool`` -- the text is parsed with ``config.Config.to_bool()``;
    * anything else -- the text is passed to the descriptor's 'type'
      (``str`` by default).

    Unrecognized elements, unknown attributes and values that fail to
    convert are reported with ``warnings.warn()`` and skipped.

    :param db: Handle for the Redis database.
    :param idx: The index of the limit in the XML file; used for error
                reporting.
    :param limit: The XML node describing the limit.

    :returns: ``klass(db, **attrs)`` for the resolved limit class.
    :raises TypeError: If an attribute whose descriptor has no 'default'
                       was not supplied.
    """

    # First, try to import the class; this will raise ImportError if
    # we can't import it
    klass = utils.find_entrypoint('turnstile.limit',
                                  limit.get('class'),
                                  required=True)

    # Build the list of required attributes
    required = set(k for k, v in klass.attrs.items() if 'default' not in v)

    # Now, use introspection on the class to interpret the attributes.
    # Conversions below catch TypeError as well as ValueError: an empty
    # element has text None, and e.g. int(None) raises TypeError, which
    # would otherwise abort the whole parse instead of being warned
    # about and skipped like any other invalid value.
    # NOTE(review): str(None) yields the string 'None' for empty
    # elements of str type -- confirm whether '' was intended.
    attrs = {}
    for child in limit:
        # Basic validation of child elements
        if child.tag != 'attr':
            warnings.warn("Unrecognized element %r while parsing limit at "
                          "index %d; ignoring..." % (child.tag, idx))
            continue

        # Get the attribute name
        attr = child.get('name')

        # Be liberal in what we accept--ignore unrecognized attributes
        # (with a warning)
        if attr not in klass.attrs:
            warnings.warn("Limit at index %d does not accept an attribute "
                          "%r; ignoring..." % (idx, attr))
            continue

        # OK, get the attribute descriptor
        desc = klass.attrs[attr]

        # Grab the attribute type
        attr_type = desc.get('type', str)

        if attr_type == list:
            # Lists are expressed as <value> child elements; any other
            # child element is warned about and skipped
            subtype = desc.get('subtype', str)
            value = []
            try:
                for j, grandchild in enumerate(child):
                    if grandchild.tag != 'value':
                        warnings.warn("Unrecognized element %r while parsing "
                                      "%r attribute of limit at index %d; "
                                      "ignoring element..." %
                                      (grandchild.tag, attr, idx))
                        continue

                    value.append(subtype(grandchild.text))
            except (TypeError, ValueError):
                # One bad element discards the whole list attribute
                warnings.warn("Invalid value %r while parsing element %d "
                              "of %r attribute of limit at index %d; "
                              "ignoring attribute..." %
                              (grandchild.text, j, attr, idx))
                continue
        elif attr_type == dict:
            # Dicts are expressed as <value> child elements; each
            # element's 'key' attribute supplies the dict key
            subtype = desc.get('subtype', str)
            value = {}
            for grandchild in child:
                if grandchild.tag != 'value':
                    warnings.warn("Unrecognized element %r while parsing "
                                  "%r attribute of limit at index %d; "
                                  "ignoring element..." %
                                  (grandchild.tag, attr, idx))
                    continue
                elif 'key' not in grandchild.attrib:
                    warnings.warn("Missing 'key' attribute of 'value' "
                                  "element while parsing %r attribute of "
                                  "limit at index %d; ignoring element..." %
                                  (attr, idx))
                    continue

                try:
                    value[grandchild.get('key')] = subtype(grandchild.text)
                except (TypeError, ValueError):
                    # Only this entry is dropped; the rest of the dict
                    # is kept
                    warnings.warn(
                        "Invalid value %r while parsing %r element "
                        "of %r attribute of limit at index %d; "
                        "ignoring element..." %
                        (grandchild.text, grandchild.get('key'), attr, idx))
                    continue
        elif attr_type == bool:
            try:
                value = config.Config.to_bool(child.text)
            except ValueError:
                warnings.warn("Unrecognized boolean value %r while parsing "
                              "%r attribute of limit at index %d; "
                              "ignoring..." % (child.text, attr, idx))
                continue
        else:
            # Simple type conversion
            try:
                value = attr_type(child.text)
            except (TypeError, ValueError):
                warnings.warn("Invalid value %r while parsing %r attribute "
                              "of limit at index %d; ignoring..." %
                              (child.text, attr, idx))
                continue

        # Save the attribute
        attrs[attr] = value

        # Remove from the required set
        required.discard(attr)

    # Did we get all required attributes?
    if required:
        raise TypeError("Missing required attributes %s" %
                        (', '.join(repr(a) for a in sorted(required))))

    # OK, instantiate and return the class
    return klass(db, **attrs)
Example #15
0
def parse_limit_node(db, idx, limit):
    """
    Given an XML node describing a limit, return a Limit object.

    The node's "class" attribute names the limit class, looked up in
    the "turnstile.limit" entrypoint group.  Each <attr> child supplies
    one constructor keyword argument; its "name" must appear in the
    class's ``attrs`` mapping, whose descriptor decides the conversion:

    * ``list`` -- each <value> child is converted with the descriptor's
      "subtype" (``str`` by default);
    * ``dict`` -- each <value key="..."> child becomes one entry, its
      text converted with "subtype" (``str`` by default);
    * ``bool`` -- the text is parsed with ``config.Config.to_bool()``;
    * anything else -- the text is passed to the descriptor's "type"
      (``str`` by default).

    Unrecognized elements, unknown attributes and values that fail to
    convert are reported with ``warnings.warn()`` and skipped.

    :param db: Handle for the Redis database.
    :param idx: The index of the limit in the XML file; used for error
                reporting.
    :param limit: The XML node describing the limit.

    :returns: ``klass(db, **attrs)`` for the resolved limit class.
    :raises TypeError: If an attribute whose descriptor has no "default"
                       was not supplied.
    """

    # First, try to import the class; this will raise ImportError if
    # we can't import it
    klass = utils.find_entrypoint("turnstile.limit", limit.get("class"), required=True)

    # Build the list of required attributes
    required = set(k for k, v in klass.attrs.items() if "default" not in v)

    # Now, use introspection on the class to interpret the attributes.
    # Conversions below catch TypeError as well as ValueError: an empty
    # element has text None, and e.g. int(None) raises TypeError, which
    # would otherwise abort the whole parse instead of being warned
    # about and skipped like any other invalid value.
    # NOTE(review): str(None) yields the string "None" for empty
    # elements of str type -- confirm whether "" was intended.
    attrs = {}
    for child in limit:
        # Basic validation of child elements
        if child.tag != "attr":
            warnings.warn("Unrecognized element %r while parsing limit at " "index %d; ignoring..." % (child.tag, idx))
            continue

        # Get the attribute name
        attr = child.get("name")

        # Be liberal in what we accept--ignore unrecognized attributes
        # (with a warning)
        if attr not in klass.attrs:
            warnings.warn("Limit at index %d does not accept an attribute " "%r; ignoring..." % (idx, attr))
            continue

        # OK, get the attribute descriptor
        desc = klass.attrs[attr]

        # Grab the attribute type
        attr_type = desc.get("type", str)

        if attr_type == list:
            # Lists are expressed as <value> child elements; any other
            # child element is warned about and skipped
            subtype = desc.get("subtype", str)
            value = []
            try:
                for j, grandchild in enumerate(child):
                    if grandchild.tag != "value":
                        warnings.warn(
                            "Unrecognized element %r while parsing "
                            "%r attribute of limit at index %d; "
                            "ignoring element..." % (grandchild.tag, attr, idx)
                        )
                        continue

                    value.append(subtype(grandchild.text))
            except (TypeError, ValueError):
                # One bad element discards the whole list attribute
                warnings.warn(
                    "Invalid value %r while parsing element %d "
                    "of %r attribute of limit at index %d; "
                    "ignoring attribute..." % (grandchild.text, j, attr, idx)
                )
                continue
        elif attr_type == dict:
            # Dicts are expressed as <value> child elements; each
            # element's "key" attribute supplies the dict key
            subtype = desc.get("subtype", str)
            value = {}
            for grandchild in child:
                if grandchild.tag != "value":
                    warnings.warn(
                        "Unrecognized element %r while parsing "
                        "%r attribute of limit at index %d; "
                        "ignoring element..." % (grandchild.tag, attr, idx)
                    )
                    continue
                elif "key" not in grandchild.attrib:
                    warnings.warn(
                        "Missing 'key' attribute of 'value' "
                        "element while parsing %r attribute of "
                        "limit at index %d; ignoring element..." % (attr, idx)
                    )
                    continue

                try:
                    value[grandchild.get("key")] = subtype(grandchild.text)
                except (TypeError, ValueError):
                    # Only this entry is dropped; the rest of the dict is kept
                    warnings.warn(
                        "Invalid value %r while parsing %r element "
                        "of %r attribute of limit at index %d; "
                        "ignoring element..." % (grandchild.text, grandchild.get("key"), attr, idx)
                    )
                    continue
        elif attr_type == bool:
            try:
                value = config.Config.to_bool(child.text)
            except ValueError:
                warnings.warn(
                    "Unrecognized boolean value %r while parsing "
                    "%r attribute of limit at index %d; "
                    "ignoring..." % (child.text, attr, idx)
                )
                continue
        else:
            # Simple type conversion
            try:
                value = attr_type(child.text)
            except (TypeError, ValueError):
                warnings.warn(
                    "Invalid value %r while parsing %r attribute "
                    "of limit at index %d; ignoring..." % (child.text, attr, idx)
                )
                continue

        # Save the attribute
        attrs[attr] = value

        # Remove from the required set
        required.discard(attr)

    # Did we get all required attributes?
    if required:
        raise TypeError("Missing required attributes %s" % (", ".join(repr(a) for a in sorted(required))))

    # OK, instantiate and return the class
    return klass(db, **attrs)
Example #16
0
    def test_no_endpoints(self, mock_iter_entry_points):
        # When the group lookup yields nothing (presumably the patched
        # iter_entry_points returns an empty sequence), the result is None.
        found = utils.find_entrypoint('test.group', 'endpoint')
        self.assertEqual(found, None)

        mock_iter_entry_points.assert_called_once_with(
            'test.group', 'endpoint')
Example #17
0
    def __init__(self, app, local_conf):
        """
        Initialize the turnstile middleware.

        Wraps ``app``, builds the configuration from ``local_conf``,
        resolves the request pre- and post-processors and the optional
        alternate formatter, and starts the control daemon (a
        RemoteControlDaemon when the "control" section's "remote" option
        is true, otherwise a ControlDaemon).  The database handle is not
        opened here; it starts as None and is loaded lazily.

        Processors are configured in one of two ways:

        * "enable" -- whitespace-separated names, each looked up in both
          the "turnstile.preprocessor" and "turnstile.postprocessor"
          groups and used wherever found.  Postprocessors are stored in
          the reverse of the listed order.
        * otherwise "preprocess" / "postprocess" -- separate name lists
          that must resolve in their respective groups, kept in order.
        """
        # Wrapped application and limit/mapper state
        self.app = app
        self.limits = []
        self.limit_sum = None
        self.mapper = None
        self.mapper_lock = eventlet.semaphore.Semaphore()

        self.conf = config.Config(conf_dict=local_conf)

        # Opened on first use rather than here
        self._db = None

        self.preprocessors = []
        self.postprocessors = []
        enabled = self.conf.get("enable")
        if enabled is None:
            # Classic syntax: explicit lists; every name must resolve
            self.preprocessors.extend(
                utils.find_entrypoint("turnstile.preprocessor", name, required=True)
                for name in self.conf.get("preprocess", "").split()
            )
            self.postprocessors.extend(
                utils.find_entrypoint("turnstile.postprocessor", name, required=True)
                for name in self.conf.get("postprocess", "").split()
            )
        else:
            # Enabler syntax: a name may supply a preprocessor, a
            # postprocessor, or both
            for name in enabled.split():
                pre = utils.find_entrypoint("turnstile.preprocessor", name, compat=False)
                if pre:
                    self.preprocessors.append(pre)

                post = utils.find_entrypoint("turnstile.postprocessor", name, compat=False)
                if post:
                    # Prepend, so postprocessors end up reversed
                    self.postprocessors.insert(0, post)

        formatter_name = self.conf.get("formatter")
        if not formatter_name:
            self.formatter = self.format_delay
        else:
            custom = utils.find_entrypoint("turnstile.formatter", formatter_name, required=True)
            # Pass self.conf.status ahead of the five formatter arguments
            self.formatter = lambda a, b, c, d, e: custom(self.conf.status, a, b, c, d, e)

        remote_wanted = self.conf.to_bool(self.conf["control"].get("remote", "no"), False)
        if remote_wanted:
            daemon_class = remote.RemoteControlDaemon
        else:
            daemon_class = control.ControlDaemon
        self.control_daemon = daemon_class(self, self.conf)
        self.control_daemon.start()

        LOG.info("Turnstile middleware initialized")
Example #18
0
def initialize(config):
    """
    Create a Redis client from a configuration mapping.

    * ``redis_client`` -- entrypoint in 'turnstile.redis_client' naming
      the client class (default ``redis.StrictRedis``).
    * Keys in ``REDIS_CONFIGS`` -- connection parameters, each converted
      with the type ``REDIS_CONFIGS`` maps it to.  At least one of
      ``host`` or ``unix_socket_path`` must be among them.
    * ``connection_pool.<option>`` -- connection pool options;
      ``connection_class`` and ``parser_class`` are resolved as
      entrypoints, ``max_connections`` is converted to ``int``, and any
      other option is passed through unchanged.
    * ``connection_pool`` -- entrypoint in 'turnstile.connection_pool'
      naming a custom pool class.
    * Every other key that is neither a ``connection_pool.`` option nor
      listed in ``REDIS_CONFIGS``/``REDIS_EXCLUDES`` is passed straight
      to the client constructor (presumably REDIS_EXCLUDES covers
      'redis_client' and 'connection_pool' -- verify).

    When a pool is used, the connection parameters go to the pool and
    the client receives ``connection_pool`` plus the extra keys.

    :raises redis.ConnectionError: if neither ``host`` nor
        ``unix_socket_path`` is configured.
    """
    # Choose the client class
    if 'redis_client' in config:
        client_class = utils.find_entrypoint('turnstile.redis_client',
                                             config['redis_client'],
                                             required=True)
    else:
        client_class = redis.StrictRedis

    # Convert whichever recognized connection parameters are present
    client_args = dict((name, convert(config[name]))
                       for name, convert in REDIS_CONFIGS.items()
                       if name in config)

    # A TCP host or a UNIX socket path is the bare minimum
    if not ('host' in client_args or 'unix_socket_path' in client_args):
        raise redis.ConnectionError("No host specified for redis database")

    # Split the remaining settings into pool options and pass-through
    # client arguments; some pool options name entrypoints
    entrypoint_groups = {
        'connection_class': 'turnstile.connection_class',
        'parser_class': 'turnstile.parser_class',
    }
    pool_args = {}
    extra_args = {}
    for key, value in config.items():
        if key.startswith('connection_pool.'):
            option = key.partition('.')[2]
            if option in entrypoint_groups:
                pool_args[option] = utils.find_entrypoint(
                    entrypoint_groups[option], value, required=True)
            elif option == 'max_connections':
                pool_args[option] = int(value)
            else:
                pool_args[option] = value
        elif not (key in REDIS_CONFIGS or key in REDIS_EXCLUDES):
            extra_args[key] = value

    # An explicitly configured pool class wins; otherwise use the stock
    # pool, but only when some pool options were given
    if 'connection_pool' in config:
        pool_class = utils.find_entrypoint('turnstile.connection_pool',
                                           config['connection_pool'],
                                           required=True)
    elif pool_args:
        pool_class = redis.ConnectionPool
    else:
        pool_class = None

    if pool_class:
        # Connection parameters move to the pool, overriding any
        # same-named pool options
        pool_args.update(client_args)

        # Pick a connection class unless one was configured explicitly
        if 'connection_class' not in pool_args:
            if 'unix_socket_path' in pool_args:
                # Socket connections: drop host/port and hand the socket
                # path over under the name 'path'
                pool_args.pop('host', None)
                pool_args.pop('port', None)
                pool_args['path'] = pool_args.pop('unix_socket_path')
                pool_args['connection_class'] = \
                    redis.UnixDomainSocketConnection
            else:
                pool_args['connection_class'] = redis.Connection

        client_args = {'connection_pool': pool_class(**pool_args)}

    # Unrecognized settings always go straight to the client
    client_args.update(extra_args)
    return client_class(**client_args)
Example #19
0
    def listen(self):
        """
        Listen for and dispatch control messages.

        Opens the 'control' database handle and subscribes to the
        channel named by the 'channel' option of the 'control'
        configuration section (default 'control'), passing that
        section's 'shard_hint' to pubsub() when it is set.  Each
        'message'/'pmessage' on the channel is parsed as
        "command[:arg[:arg...]]"; the command is resolved from
        ``self._commands`` or else the 'turnstile.command' entrypoint
        group, and called as ``func(self, *args)``.  Empty commands are
        skipped.  Underscore-prefixed (internal) commands, unknown
        commands and exceptions raised by a command are logged and
        otherwise ignored.  Loops for as long as pubsub.listen() keeps
        yielding messages.
        """
        # A dedicated 'control' handle lets this long-lived listener use
        # its own database settings
        db = self.config.get_database('control')

        pubsub_kwargs = {}
        if 'shard_hint' in self.config['control']:
            pubsub_kwargs['shard_hint'] = self.config['control']['shard_hint']
        pubsub = db.pubsub(**pubsub_kwargs)

        channel = self.config['control'].get('channel', 'control')
        pubsub.subscribe(channel)

        for msg in pubsub.listen():
            # Skip subscription notices and other channels' traffic
            if (msg['type'] not in ('pmessage', 'message') or
                    msg['channel'] != channel):
                continue

            command, _sep, args = msg['data'].partition(':')
            if not command:
                continue

            if command[0] == '_':
                LOG.error("Cannot call internal command %r" % command)
                continue

            # Resolve and cache the command.  Misses are cached too (as
            # None), so each distinct unknown name stays in _commands.
            if command not in self._commands:
                self._commands[command] = utils.find_entrypoint(
                    'turnstile.command', command, compat=False)
            func = self._commands[command]

            if not func:
                LOG.error("No such command %r" % command)
                continue

            arglist = args.split(':') if args else []
            try:
                func(self, *arglist)
            except Exception:
                LOG.exception("Failed to execute command %r arguments %r" %
                              (command, arglist))