def wrapper(self, *args, **kwargs):
    """
    Invoke ``func`` locally or forward the call to the remote server.

    In 'server' mode the wrapped function is called directly.  Otherwise
    the connection is opened if necessary, a ``CALL`` message carrying the
    function name, positional and keyword arguments is sent, and the reply
    ``(cmd, payload)`` is interpreted:

    * ``RES`` -- the call succeeded; ``payload[0]`` is returned.
    * ``EXC`` -- the remote function raised; ``payload[0]`` names the
      exception class and ``payload[1]`` is its argument.  It is re-raised
      locally if the name resolves to an exception class, otherwise a
      generic Exception is raised.
    * ``ERR`` -- the connection is closed and an Exception is raised.
    * anything else -- the connection is closed and an Exception is raised.
    """
    if self.mode == 'server':
        # In server mode, call the function
        return func(self, *args, **kwargs)

    # Make sure we're connected
    if not self.conn:
        self.connect()

    # Call the remote function
    self.conn.send('CALL', func.__name__, args, kwargs)

    # Receive the response
    cmd, payload = self.conn.recv()
    if cmd == 'ERR':
        self.close()
        raise Exception("Catastrophic error from server: %s" % payload[0])
    elif cmd == 'EXC':
        exc_type = utils.find_entrypoint(None, payload[0])
        # Only instantiate something that is genuinely an exception class.
        # find_entrypoint() may return None (class not importable), and an
        # arbitrary server-named callable must not be invoked with
        # server-supplied data.
        # NOTE(review): find_entrypoint() still imports the named module.
        if not (isinstance(exc_type, type) and
                issubclass(exc_type, BaseException)):
            raise Exception("Unrecognized exception type %r from server: %s" %
                            (payload[0], payload[1]))
        raise exc_type(payload[1])
    elif cmd != 'RES':
        self.close()
        raise Exception("Invalid command response from server: %s" % cmd)

    return payload[0]
def test_no_group(self, mock_iter_entry_points, mock_parse):
    """Without a group, the spec is parsed and loaded directly."""
    loaded = utils.find_entrypoint(None, 'spam:ni', compat=False)
    self.assertEqual(loaded, 'class')

    # No group given, so entry point iteration must not happen at all
    self.assertFalse(mock_iter_entry_points.called)

    # The spec is wrapped as a dummy "x=" entry point, which is then
    # loaded with a single False argument
    mock_parse.assert_called_once_with('x=spam:ni')
    parsed_entry = mock_parse.return_value
    parsed_entry.load.assert_called_once_with(False)
def test_straight_load(self, mock_iter_entry_points):
    """A grouped lookup returns 'ep1', loading only the first entry point."""
    found = utils.find_entrypoint('test.group', 'endpoint')
    self.assertEqual(found, 'ep1')

    mock_iter_entry_points.assert_called_once_with('test.group', 'endpoint')
    candidates = mock_iter_entry_points.return_value
    candidates[0].load.assert_called_once_with()
def test_with_compat_importerror(self, mock_iter_entry_points, mock_parse):
    """
    In compat mode a 'module:name' spec is parsed directly, ignoring the
    group; here the lookup yields None (presumably the mocked load raises
    ImportError -- the fixture is not visible here).
    """
    outcome = utils.find_entrypoint('test.group', 'spam:ni')
    self.assertEqual(outcome, None)

    # The group's entry points are never consulted...
    self.assertFalse(mock_iter_entry_points.called)

    # ...the spec is parsed as a dummy "x=" entry point and loaded instead
    mock_parse.assert_called_once_with('x=spam:ni')
    parsed_entry = mock_parse.return_value
    parsed_entry.load.assert_called_once_with(False)
def test_no_compat(self, mock_iter_entry_points, mock_parse):
    """With compat=False, even a 'module:name' spec is looked up in the group."""
    outcome = utils.find_entrypoint('test.group', 'spam:ni', compat=False)
    self.assertEqual(outcome, None)

    # The spec is used as an entry point name within the group...
    mock_iter_entry_points.assert_called_once_with('test.group', 'spam:ni')

    # ...and is never parsed or loaded directly
    self.assertFalse(mock_parse.called)
    parsed_entry = mock_parse.return_value
    self.assertFalse(parsed_entry.load.called)
def hydrate(cls, db, limit):
    """
    Given a limit dict, as generated by dehydrate(), generate an
    appropriate instance of Limit (or a subclass).  If the required
    limit class cannot be found, returns None.

    The caller's ``limit`` dict is left unmodified.

    :param db: Database handle passed as the first constructor argument.
    :param limit: Dict with a 'limit_class' key naming the class; every
                  other key is passed to the class constructor as a
                  keyword argument.

    :returns: An instance of the named class, or None if the name is not
              in ``cls._registry`` even after attempting to load it.
    :raises KeyError: If ``limit`` has no 'limit_class' key.
    """
    # Work on a copy so pop() does not strip 'limit_class' from the
    # caller's dict (which would break a second hydrate() of the same dict)
    limit_kwargs = dict(limit)
    cls_name = limit_kwargs.pop('limit_class')

    # Is it in the registry yet?  If not, try to load it.  The return value
    # is ignored; presumably loading the class registers it in
    # cls._registry -- confirm against the registry mechanism.
    if cls_name not in cls._registry:
        utils.find_entrypoint(None, cls_name)

    # Look it up in the registry and instantiate it if found
    limit_cls = cls._registry.get(cls_name)
    return limit_cls(db, **limit_kwargs) if limit_cls else None
def test_skip_errors(self, mock_iter_entry_points):
    """Failing entry points are skipped until one loads ('ep3')."""
    found = utils.find_entrypoint('test.group', 'endpoint')
    self.assertEqual(found, 'ep3')
    mock_iter_entry_points.assert_called_once_with('test.group', 'endpoint')

    candidates = mock_iter_entry_points.return_value
    # The first three entry points are each tried exactly once...
    for candidate in (candidates[0], candidates[1], candidates[2]):
        candidate.load.assert_called_once_with()
    # ...and the search stops before the fourth
    self.assertFalse(candidates[3].load.called)
def turnstile_filter(global_conf, **local_conf):
    """
    Factory function for turnstile.

    Returns a callable that takes the application and returns a
    middleware instance constructed as ``middleware_class(app,
    local_conf)``.  The class is TurnstileMiddleware unless the
    "turnstile" option names another one in the "turnstile.middleware"
    entrypoint group.
    """
    middleware_class = TurnstileMiddleware
    if "turnstile" in local_conf:
        # An alternative middleware class was requested
        middleware_class = utils.find_entrypoint(
            "turnstile.middleware", local_conf["turnstile"], required=True)

    def make_middleware(app):
        return middleware_class(app, local_conf)

    return make_middleware
def turnstile_filter(global_conf, **local_conf):
    """
    Factory function for turnstile.

    Returns a callable that takes the application and returns a
    middleware instance constructed as ``middleware_class(app,
    local_conf)``.  The class is TurnstileMiddleware unless the
    'turnstile' option names another one in the 'turnstile.middleware'
    entrypoint group.
    """
    middleware_class = TurnstileMiddleware
    if 'turnstile' in local_conf:
        # An alternative middleware class was requested
        middleware_class = utils.find_entrypoint(
            'turnstile.middleware', local_conf['turnstile'], required=True)

    def make_middleware(app):
        return middleware_class(app, local_conf)

    return make_middleware
def initialize(config):
    """
    Create a Redis client object from the configuration.

    The recognized connection options (host, port, db, password,
    socket_timeout, unix_socket_path) are type-converted.  Keys with the
    'connection_pool.' prefix configure a connection pool.  If pool
    options or an explicit 'connection_pool' class are given, the
    connection options go to the pool and the pool goes to the client.
    Otherwise the options go straight to the client.

    :param config: Mapping of configuration options.

    :returns: An instance of the client class ('redis_client' entrypoint,
              or redis.StrictRedis by default).
    :raises redis.ConnectionError: If neither 'host' nor
                                   'unix_socket_path' is configured.
    """
    # Choose the client class
    if 'redis_client' in config:
        client_class = utils.find_entrypoint('turnstile.redis_client',
                                             config['redis_client'],
                                             required=True)
    else:
        client_class = redis.StrictRedis

    # Type-convert the recognized connection options, in a fixed order
    conversions = (
        ('host', str),
        ('port', int),
        ('db', int),
        ('password', str),
        ('socket_timeout', int),
        ('unix_socket_path', str),
    )
    conn_args = dict((name, convert(config[name]))
                     for name, convert in conversions
                     if name in config)

    # At a minimum we need somewhere to connect to
    if not ('host' in conn_args or 'unix_socket_path' in conn_args):
        raise redis.ConnectionError("No host specified for redis database")

    # Gather 'connection_pool.*' options, resolving class names
    prefix = 'connection_pool.'
    pool_args = {}
    for key, raw in config.items():
        if not key.startswith(prefix):
            continue
        option = key[len(prefix):]
        if option == 'connection_class':
            setting = utils.find_entrypoint(
                'turnstile.connection_class', raw, required=True)
        elif option == 'max_connections':
            setting = int(raw)
        elif option == 'parser_class':
            setting = utils.find_entrypoint(
                'turnstile.parser_class', raw, required=True)
        else:
            setting = raw
        pool_args[option] = setting

    # Pool options imply the default pool; an explicit class overrides it
    pool_class = redis.ConnectionPool if pool_args else None
    if 'connection_pool' in config:
        pool_class = utils.find_entrypoint('turnstile.connection_pool',
                                           config['connection_pool'],
                                           required=True)

    if not pool_class:
        return client_class(**conn_args)

    # With a pool, connection options go to the pool instead of the client
    pool_args.update(conn_args)
    if 'connection_class' not in pool_args:
        if 'unix_socket_path' in pool_args:
            # Unix-socket connections take a 'path' and no host/port
            pool_args.pop('host', None)
            pool_args.pop('port', None)
            pool_args['path'] = pool_args.pop('unix_socket_path')
            pool_args['connection_class'] = redis.UnixDomainSocketConnection
        else:
            pool_args['connection_class'] = redis.Connection

    return client_class(connection_pool=pool_class(**pool_args))
def __init__(self, app, local_conf):
    """
    Initialize the turnstile middleware.

    Stores the application, builds the configuration from ``local_conf``,
    loads the request pre- and post-processors and the optional custom
    formatter, then starts the control daemon.  The database handle is
    not created here; ``self._db`` starts as None and is lazy-loaded.

    :param app: The application being wrapped.
    :param local_conf: Configuration dictionary handed to config.Config.
    """
    # Application and limit bookkeeping
    self.app = app
    self.limits = []
    self.limit_sum = None
    self.mapper = None
    self.mapper_lock = eventlet.semaphore.Semaphore()

    self.conf = config.Config(conf_dict=local_conf)

    # Database handle is created on demand
    self._db = None

    enable = self.conf.get('enable')
    if enable is None:
        # Classic syntax: separate 'preprocess' / 'postprocess' lists, and
        # every named entrypoint must exist
        self.preprocessors = [
            utils.find_entrypoint('turnstile.preprocessor', name,
                                  required=True)
            for name in self.conf.get('preprocess', '').split()]
        self.postprocessors = [
            utils.find_entrypoint('turnstile.postprocessor', name,
                                  required=True)
            for name in self.conf.get('postprocess', '').split()]
    else:
        # Enabler syntax: each name may provide a preprocessor, a
        # postprocessor, or both; lookups that find nothing are skipped
        self.preprocessors = []
        self.postprocessors = []
        for name in enable.split():
            pre = utils.find_entrypoint('turnstile.preprocessor', name,
                                        compat=False)
            if pre:
                self.preprocessors.append(pre)
            post = utils.find_entrypoint('turnstile.postprocessor', name,
                                         compat=False)
            if post:
                # Prepend: postprocessors end up in reverse enable order
                self.postprocessors.insert(0, post)

    # Optional alternative formatter; it receives self.conf.status
    # (read at call time) ahead of the five formatter arguments
    formatter_name = self.conf.get('formatter')
    if formatter_name:
        custom_formatter = utils.find_entrypoint('turnstile.formatter',
                                                 formatter_name,
                                                 required=True)
        self.formatter = lambda a, b, c, d, e: custom_formatter(
            self.conf.status, a, b, c, d, e)
    else:
        self.formatter = self.format_delay

    # Pick and start the control daemon
    use_remote = self.conf.to_bool(
        self.conf['control'].get('remote', 'no'), False)
    daemon_class = (remote.RemoteControlDaemon if use_remote
                    else control.ControlDaemon)
    self.control_daemon = daemon_class(self, self.conf)
    self.control_daemon.start()

    LOG.info("Turnstile middleware initialized")
def parse_limit_node(db, idx, limit):
    """
    Given an XML node describing a limit, return a Limit object.

    The node's 'class' attribute names the limit class, looked up in the
    'turnstile.limit' entrypoint group.  Each child <attr name="...">
    element supplies one constructor argument.  It is converted according
    to the 'type' entry (and, for list/dict, the 'subtype' entry) of the
    class's ``attrs`` descriptor.  Unrecognized elements, unknown
    attributes and unconvertible values are reported with a warning and
    skipped.

    :param db: Handle for the Redis database.
    :param idx: The index of the limit in the XML file; used for error
                reporting.
    :param limit: The XML node describing the limit.

    :returns: ``klass(db, **attributes)`` for the resolved limit class.
    :raises TypeError: If an attribute whose descriptor lacks a 'default'
                       is not supplied.
    """
    # Import the class; with required=True a failed lookup raises
    # (ImportError, per the original notes) rather than returning None
    limit_class = utils.find_entrypoint('turnstile.limit',
                                        limit.get('class'), required=True)
    descriptors = limit_class.attrs

    # Attributes without a default must be provided by the XML
    missing = set(name for name, spec in descriptors.items()
                  if 'default' not in spec)

    kwargs = {}
    for node in limit:
        if node.tag != 'attr':
            warnings.warn("Unrecognized element %r while parsing limit at "
                          "index %d; ignoring..." % (node.tag, idx))
            continue

        name = node.get('name')
        if name not in descriptors:
            # Be liberal in what we accept
            warnings.warn("Limit at index %d does not accept an attribute "
                          "%r; ignoring..." % (idx, name))
            continue

        spec = descriptors[name]
        kind = spec.get('type', str)

        if kind == list:
            # Items are <value> children; other child tags are warned
            # about and skipped
            convert = spec.get('subtype', str)
            items = []
            try:
                for pos, item in enumerate(node):
                    if item.tag != 'value':
                        warnings.warn("Unrecognized element %r while parsing "
                                      "%r attribute of limit at index %d; "
                                      "ignoring element..." %
                                      (item.tag, name, idx))
                        continue
                    items.append(convert(item.text))
            except ValueError:
                # A single bad item discards the whole attribute
                warnings.warn("Invalid value %r while parsing element %d "
                              "of %r attribute of limit at index %d; "
                              "ignoring attribute..." %
                              (item.text, pos, name, idx))
                continue
            value = items
        elif kind == dict:
            # Entries are <value key="..."> children
            convert = spec.get('subtype', str)
            value = {}
            for item in node:
                if item.tag != 'value':
                    warnings.warn("Unrecognized element %r while parsing "
                                  "%r attribute of limit at index %d; "
                                  "ignoring element..." %
                                  (item.tag, name, idx))
                    continue
                if 'key' not in item.attrib:
                    warnings.warn("Missing 'key' attribute of 'value' "
                                  "element while parsing %r attribute of "
                                  "limit at index %d; ignoring element..." %
                                  (name, idx))
                    continue
                key = item.get('key')
                try:
                    value[key] = convert(item.text)
                except ValueError:
                    # Only the bad entry is dropped
                    warnings.warn("Invalid value %r while parsing %r element "
                                  "of %r attribute of limit at index %d; "
                                  "ignoring element..." %
                                  (item.text, key, name, idx))
        elif kind == bool:
            try:
                value = config.Config.to_bool(node.text)
            except ValueError:
                warnings.warn("Unrecognized boolean value %r while parsing "
                              "%r attribute of limit at index %d; "
                              "ignoring..." % (node.text, name, idx))
                continue
        else:
            # Plain type conversion
            try:
                value = kind(node.text)
            except ValueError:
                warnings.warn("Invalid value %r while parsing %r attribute "
                              "of limit at index %d; ignoring..." %
                              (node.text, name, idx))
                continue

        kwargs[name] = value
        missing.discard(name)

    if missing:
        raise TypeError("Missing required attributes %s" %
                        (', '.join(repr(a) for a in sorted(missing))))

    return limit_class(db, **kwargs)
def parse_limit_node(db, idx, limit):
    """
    Given an XML node describing a limit, return a Limit object.

    The node's "class" attribute names the limit class, looked up in the
    "turnstile.limit" entrypoint group.  Each child <attr name="...">
    element supplies one constructor argument.  It is converted according
    to the "type" entry (and, for list/dict, the "subtype" entry) of the
    class's ``attrs`` descriptor.  Unrecognized elements, unknown
    attributes and unconvertible values are reported with a warning and
    skipped.

    :param db: Handle for the Redis database.
    :param idx: The index of the limit in the XML file; used for error
                reporting.
    :param limit: The XML node describing the limit.

    :returns: ``klass(db, **attributes)`` for the resolved limit class.
    :raises TypeError: If an attribute whose descriptor lacks a "default"
                       is not supplied.
    """
    # Import the class; with required=True a failed lookup raises
    # (ImportError, per the original notes) rather than returning None
    limit_class = utils.find_entrypoint("turnstile.limit",
                                        limit.get("class"), required=True)
    descriptors = limit_class.attrs

    # Attributes without a default must be provided by the XML
    missing = set(name for name, spec in descriptors.items()
                  if "default" not in spec)

    kwargs = {}
    for node in limit:
        if node.tag != "attr":
            warnings.warn("Unrecognized element %r while parsing limit at "
                          "index %d; ignoring..." % (node.tag, idx))
            continue

        name = node.get("name")
        if name not in descriptors:
            # Be liberal in what we accept
            warnings.warn("Limit at index %d does not accept an attribute "
                          "%r; ignoring..." % (idx, name))
            continue

        spec = descriptors[name]
        kind = spec.get("type", str)

        if kind == list:
            # Items are <value> children; other child tags are warned
            # about and skipped
            convert = spec.get("subtype", str)
            items = []
            try:
                for pos, item in enumerate(node):
                    if item.tag != "value":
                        warnings.warn("Unrecognized element %r while parsing "
                                      "%r attribute of limit at index %d; "
                                      "ignoring element..." %
                                      (item.tag, name, idx))
                        continue
                    items.append(convert(item.text))
            except ValueError:
                # A single bad item discards the whole attribute
                warnings.warn("Invalid value %r while parsing element %d "
                              "of %r attribute of limit at index %d; "
                              "ignoring attribute..." %
                              (item.text, pos, name, idx))
                continue
            value = items
        elif kind == dict:
            # Entries are <value key="..."> children
            convert = spec.get("subtype", str)
            value = {}
            for item in node:
                if item.tag != "value":
                    warnings.warn("Unrecognized element %r while parsing "
                                  "%r attribute of limit at index %d; "
                                  "ignoring element..." %
                                  (item.tag, name, idx))
                    continue
                if "key" not in item.attrib:
                    warnings.warn("Missing 'key' attribute of 'value' "
                                  "element while parsing %r attribute of "
                                  "limit at index %d; ignoring element..." %
                                  (name, idx))
                    continue
                key = item.get("key")
                try:
                    value[key] = convert(item.text)
                except ValueError:
                    # Only the bad entry is dropped
                    warnings.warn("Invalid value %r while parsing %r element "
                                  "of %r attribute of limit at index %d; "
                                  "ignoring element..." %
                                  (item.text, key, name, idx))
        elif kind == bool:
            try:
                value = config.Config.to_bool(node.text)
            except ValueError:
                warnings.warn("Unrecognized boolean value %r while parsing "
                              "%r attribute of limit at index %d; "
                              "ignoring..." % (node.text, name, idx))
                continue
        else:
            # Plain type conversion
            try:
                value = kind(node.text)
            except ValueError:
                warnings.warn("Invalid value %r while parsing %r attribute "
                              "of limit at index %d; ignoring..." %
                              (node.text, name, idx))
                continue

        kwargs[name] = value
        missing.discard(name)

    if missing:
        raise TypeError("Missing required attributes %s" %
                        (", ".join(repr(a) for a in sorted(missing))))

    return limit_class(db, **kwargs)
def test_no_endpoints(self, mock_iter_entry_points):
    """
    A grouped lookup that finds nothing returns None (the mocked entry
    point list for this case is configured outside this view).
    """
    found = utils.find_entrypoint('test.group', 'endpoint')
    self.assertEqual(found, None)

    mock_iter_entry_points.assert_called_once_with('test.group', 'endpoint')
def __init__(self, app, local_conf):
    """
    Initialize the turnstile middleware.

    Stores the application, builds the configuration from ``local_conf``,
    loads the request pre- and post-processors and the optional custom
    formatter, then starts the control daemon.  The database handle is
    not created here; ``self._db`` starts as None and is lazy-loaded.

    :param app: The application being wrapped.
    :param local_conf: Configuration dictionary handed to config.Config.
    """
    # Application and limit bookkeeping
    self.app = app
    self.limits = []
    self.limit_sum = None
    self.mapper = None
    self.mapper_lock = eventlet.semaphore.Semaphore()

    self.conf = config.Config(conf_dict=local_conf)

    # Database handle is created on demand
    self._db = None

    enable = self.conf.get("enable")
    if enable is None:
        # Classic syntax: separate "preprocess" / "postprocess" lists, and
        # every named entrypoint must exist
        self.preprocessors = [
            utils.find_entrypoint("turnstile.preprocessor", name,
                                  required=True)
            for name in self.conf.get("preprocess", "").split()]
        self.postprocessors = [
            utils.find_entrypoint("turnstile.postprocessor", name,
                                  required=True)
            for name in self.conf.get("postprocess", "").split()]
    else:
        # Enabler syntax: each name may provide a preprocessor, a
        # postprocessor, or both; lookups that find nothing are skipped
        self.preprocessors = []
        self.postprocessors = []
        for name in enable.split():
            pre = utils.find_entrypoint("turnstile.preprocessor", name,
                                        compat=False)
            if pre:
                self.preprocessors.append(pre)
            post = utils.find_entrypoint("turnstile.postprocessor", name,
                                         compat=False)
            if post:
                # Prepend: postprocessors end up in reverse enable order
                self.postprocessors.insert(0, post)

    # Optional alternative formatter; it receives self.conf.status
    # (read at call time) ahead of the five formatter arguments
    formatter_name = self.conf.get("formatter")
    if formatter_name:
        custom_formatter = utils.find_entrypoint("turnstile.formatter",
                                                 formatter_name,
                                                 required=True)
        self.formatter = lambda a, b, c, d, e: custom_formatter(
            self.conf.status, a, b, c, d, e)
    else:
        self.formatter = self.format_delay

    # Pick and start the control daemon
    use_remote = self.conf.to_bool(
        self.conf["control"].get("remote", "no"), False)
    daemon_class = (remote.RemoteControlDaemon if use_remote
                    else control.ControlDaemon)
    self.control_daemon = daemon_class(self, self.conf)
    self.control_daemon.start()

    LOG.info("Turnstile middleware initialized")
def initialize(config):
    """
    Create a Redis client object from the configuration.

    Options listed in REDIS_CONFIGS are type-converted with the callable
    REDIS_CONFIGS maps them to.  Keys with the 'connection_pool.' prefix
    configure a connection pool.  Any other key not in REDIS_CONFIGS or
    REDIS_EXCLUDES is passed through unchanged as an extra client keyword
    argument, applied last.  If pool options or an explicit
    'connection_pool' class are given, the converted connection options go
    to the pool and the pool goes to the client.

    :param config: Mapping of configuration options.

    :returns: An instance of the client class ('redis_client' entrypoint,
              or redis.StrictRedis by default).
    :raises redis.ConnectionError: If neither 'host' nor
                                   'unix_socket_path' ends up among the
                                   converted options.
    """
    # Choose the client class
    if 'redis_client' in config:
        client_class = utils.find_entrypoint('turnstile.redis_client',
                                             config['redis_client'],
                                             required=True)
    else:
        client_class = redis.StrictRedis

    # Type-convert the recognized connection options
    conn_args = dict((name, convert(config[name]))
                     for name, convert in REDIS_CONFIGS.items()
                     if name in config)

    # At a minimum we need somewhere to connect to
    if not ('host' in conn_args or 'unix_socket_path' in conn_args):
        raise redis.ConnectionError("No host specified for redis database")

    # Split remaining keys into pool options and pass-through client kwargs
    prefix = 'connection_pool.'
    pool_args = {}
    passthrough = {}
    for key, raw in config.items():
        if key.startswith(prefix):
            option = key[len(prefix):]
            if option == 'connection_class':
                setting = utils.find_entrypoint(
                    'turnstile.connection_class', raw, required=True)
            elif option == 'max_connections':
                setting = int(raw)
            elif option == 'parser_class':
                setting = utils.find_entrypoint(
                    'turnstile.parser_class', raw, required=True)
            else:
                setting = raw
            pool_args[option] = setting
        elif key not in REDIS_CONFIGS and key not in REDIS_EXCLUDES:
            passthrough[key] = raw

    # Pool options imply the default pool; an explicit class overrides it
    pool_class = redis.ConnectionPool if pool_args else None
    if 'connection_pool' in config:
        pool_class = utils.find_entrypoint('turnstile.connection_pool',
                                           config['connection_pool'],
                                           required=True)

    if pool_class:
        # With a pool, connection options go to the pool, not the client
        pool_args.update(conn_args)
        if 'connection_class' not in pool_args:
            if 'unix_socket_path' in pool_args:
                # Unix-socket connections take a 'path' and no host/port
                pool_args.pop('host', None)
                pool_args.pop('port', None)
                pool_args['path'] = pool_args.pop('unix_socket_path')
                pool_args['connection_class'] = \
                    redis.UnixDomainSocketConnection
            else:
                pool_args['connection_class'] = redis.Connection
        client_args = {'connection_pool': pool_class(**pool_args)}
    else:
        client_args = conn_args

    # Pass-through options are applied last and override the above
    client_args.update(passthrough)
    return client_class(**client_args)
def listen(self):
    """
    Listen for incoming control messages and dispatch them.

    The database handle comes from ``self.config.get_database('control')``.
    If the 'control' configuration section has a 'shard_hint', it is passed
    to pubsub().  The channel subscribed to is that section's 'channel'
    option ('control' by default).

    Message data has the form 'command[:arg1[:arg2...]]'.  Empty commands
    are ignored, and commands beginning with '_' are refused with an error
    log.  Other commands are looked up in ``self._commands``, falling back
    to the 'turnstile.command' entrypoint group.  The lookup result,
    including None, is cached in ``self._commands``.  The command is called
    as ``func(self, *args)``; any exception it raises is logged and
    listening continues.
    """
    # A dedicated database handle lets this long-lived listener use
    # different database options
    db = self.config.get_database('control')
    control_conf = self.config['control']

    pubsub_kwargs = {}
    if 'shard_hint' in control_conf:
        pubsub_kwargs['shard_hint'] = control_conf['shard_hint']
    pubsub = db.pubsub(**pubsub_kwargs)

    channel = control_conf.get('channel', 'control')
    pubsub.subscribe(channel)

    for msg in pubsub.listen():
        # Only data messages on our channel are of interest
        if (msg['type'] not in ('pmessage', 'message') or
                msg['channel'] != channel):
            continue

        command, _sep, args = msg['data'].partition(':')
        if not command:
            continue

        if command.startswith('_'):
            LOG.error("Cannot call internal command %r" % command)
            continue

        try:
            handler = self._commands[command]
        except KeyError:
            # Fall back to an entrypoint and remember the outcome
            handler = utils.find_entrypoint('turnstile.command', command,
                                            compat=False)
            self._commands[command] = handler

        if not handler:
            LOG.error("No such command %r" % command)
            continue

        arglist = args.split(':') if args else []
        try:
            handler(self, *arglist)
        except Exception:
            LOG.exception("Failed to execute command %r arguments %r" %
                          (command, arglist))