def test_get_config():
    """
    Boot a client against a cache service preloaded with the server config.

    The cache service answers ``messages=3`` requests; the client only needs
    to read its push and sub addresses from it. Results and errors of every
    worker are printed, and all three workers must complete.
    """
    cache = DictDB()
    cache.set('name', b'master')
    cache.set('pub_address', pub_address.encode('utf-8'))
    cache.set('pull_address', pull_address.encode('utf-8'))

    # Keep the DictDB and the service in separate names instead of rebinding
    # ``cache`` (consistent with the other tests in this module).
    cache_service = CacheService('db', db_address,
                                 cache=cache,
                                 logger=logging,
                                 messages=3,
                                 )

    def boot_client():
        client = Client('master', db_address, session=None)
        return client.push_address, client.sub_address

    def broker():
        # Stand-in router. Close the socket explicitly so broker_address is
        # released deterministically; other tests in this module bind the
        # same address.
        socket = zmq_context.socket(zmq.ROUTER)
        try:
            socket.bind(broker_address)
        finally:
            socket.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        results = [
            executor.submit(cache_service.start),
            executor.submit(boot_client),
            executor.submit(broker)
        ]

        # This works because servers do not return values.
        for i, future in enumerate(concurrent.futures.as_completed(results)):
            try:
                result = future.result()
                print(result)

            except Exception as exc:
                print(exc)
                lines = traceback.format_exception(*sys.exc_info())
                print(*lines)

    assert i == 2
class ServerTemplate(object):
    """
    Low-level tool to build a server from parts.

    :param logging_level: A correct logging level from the logging module.
        Defaults to INFO.
    :param router_messages: Passed to the router as ``messages``.
        Defaults to ``sys.maxsize``.

    It has important attributes that you may want to override, like

    :cache: The key-value database that the server should use
    :logging_level: Controls the log output of the server.
    :router: Here's the router, you may want to change its attributes too.
    """
    def __init__(self, logging_level=logging.INFO,
                 router_messages=sys.maxsize):
        # Name of the server
        self.name = ''
        # Logging level for the server
        self.logging_level = logging_level
        # Basic Key-value database for storage
        self.cache = DictDB()
        # Registered parts, keyed by part name
        self.inbound_components = {}
        self.outbound_components = {}
        self.bypass_components = {}

        # Basic console logging.
        # NOTE(review): self.name is still '' here, so this returns the *root*
        # logger, and a fresh stdout handler is added on every instantiation
        # (several templates in one process duplicate each log line).
        # Confirm whether a named logger was intended.
        self.logger = logging.getLogger(name=self.name)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        self.logger.addHandler(handler)
        self.logger.setLevel(self.logging_level)

        # Finally, the router
        self.router = Router(logger=self.logger, cache=self.cache,
                             messages=router_messages)

    def register_inbound(self, part, name='', listen_address='', route='',
                         block=False, log='', **kwargs):
        """
        Register inbound part to this server.

        :param part: part class
        :param name: Name of the part
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param route: Outbound part it routes to
        :param block: True if the part blocks waiting for a response
        :param log: Log message in DEBUG level for each message processed.
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        kwargs.setdefault('cache', self.cache)

        instance = part(name,
                        listen_address,
                        broker_address=self.router.inbound_address,
                        logger=self.logger,
                        **kwargs)

        self.router.register_inbound(name, route=route, block=block, log=log)
        self.inbound_components[name] = instance

    def register_outbound(self, part, name='', listen_address='', route='',
                          log='', **kwargs):
        """
        Register outbound part to this server

        :param part: part class
        :param name: Name of the part
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param route: Outbound part it routes the response (if there is) to
        :param log: Log message in DEBUG level for each message processed
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        kwargs.setdefault('cache', self.cache)

        instance = part(name,
                        listen_address,
                        broker_address=self.router.outbound_address,
                        logger=self.logger,
                        **kwargs)

        self.router.register_outbound(name, route=route, log=log)
        self.outbound_components[name] = instance

    def register_bypass(self, part, name='', listen_address='', **kwargs):
        """
        Register a bypass part to this server

        Bypass parts get no broker address and are not registered with the
        router; they are only started alongside the other parts.

        :param part: part class
        :param name: part name
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        kwargs.setdefault('cache', self.cache)

        instance = part(name, listen_address, logger=self.logger, **kwargs)
        self.bypass_components[name] = instance

    def preset_cache(self, **kwargs):
        """
        Send the following keyword arguments as cache variables. Useful for
        configuration variables that the workers or the clients fetch
        straight from the cache.

        :param kwargs: Cache keys and values. ``str`` values (including
            ``str`` subclasses) are stored UTF-8 encoded; any other value is
            stored unchanged.
        """
        for arg, val in kwargs.items():
            # isinstance, not ``type(val) == str``: str subclasses must be
            # encoded too.
            if isinstance(val, str):
                val = val.encode('utf-8')
            self.cache.set(arg, val)

    def start(self):
        """
        Start the server with all its parts.

        The router and every registered part run their ``start`` in a worker
        thread each; this call blocks until all of them return. A part that
        raises is logged as critical, and the remaining parts keep running.
        """
        threads = []

        self.logger.info("Starting the router")
        threads.append(self.router.start)

        for name, part in self.inbound_components.items():
            self.logger.info("Starting inbound part {}".format(name))
            threads.append(part.start)

        for name, part in self.outbound_components.items():
            self.logger.info("Starting outbound part {}".format(name))
            threads.append(part.start)

        for name, part in self.bypass_components.items():
            self.logger.info("Starting bypass part {}".format(name))
            threads.append(part.start)

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(threads)) as executor:
            results = [executor.submit(thread) for thread in threads]
            for future in concurrent.futures.as_completed(results):
                try:
                    future.result()
                except Exception:
                    self.logger.error(
                        'This is critical, one of the parts died')
                    lines = traceback.format_exception(*sys.exc_info())
                    for line in lines:
                        self.logger.error(line.strip('\n'))
def test_send_job():
    """
    Push a single job through the cache service, the pull and pub services
    and a hand-rolled ROUTER socket standing in for the real router.

    Every worker's result (or error and traceback) is printed; the test
    requires all five workers to complete.
    """
    store = DictDB()
    for key, value in (('name', b'master'),
                       ('pub_address', pub_address.encode('utf-8')),
                       ('pull_address', pull_address.encode('utf-8'))):
        store.set(key, value)

    cache_service = CacheService('db', db_address, cache=store,
                                 logger=logging, messages=3)
    puller = PullService('puller', pull_address, logger=logging,
                         cache=store, messages=1)
    publisher = PubService('publisher', pub_address, logger=logging,
                           cache=store, messages=1)

    def client_job():
        client = Client('master', db_address, session=None)
        return list(client.job('master.something', [b'1'], messages=1))

    def broker():
        router_socket = zmq_context.socket(zmq.ROUTER)
        router_socket.bind(broker_address)
        frames = router_socket.recv_multipart()
        # Unblock the sender by echoing its frames back, then hand the third
        # frame to the 'publisher' identity. Here you see why the actual
        # router is complicated.
        router_socket.send_multipart(frames)
        router_socket.send_multipart([b'publisher', b'', frames[2]])
        router_socket.close()
        return b'router'

    parts = (cache_service.start, puller.start, publisher.start,
             client_job, broker)
    completed = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(part) for part in parts]
        # Every part is expected to finish: the services are built with a
        # finite ``messages`` quota.
        for future in concurrent.futures.as_completed(futures):
            completed += 1
            try:
                print(future.result())
            except Exception as exc:
                print(exc)
                print(*traceback.format_exception(*sys.exc_info()))

    assert completed == 5
def test_multiple_clients():
    """
    Two clients send jobs concurrently through the cache service, puller,
    router and publisher; each must get back exactly its own payloads.

    Worker errors are printed. Payload checks run only after every worker
    has finished, outside any ``except Exception``, so a mismatch actually
    fails the test.
    """
    cache = DictDB()
    cache.set('name', b'master')
    cache.set('pub_address', pub_address.encode('utf-8'))
    cache.set('pull_address', pull_address.encode('utf-8'))

    router = Router(logger=logging, cache=cache, messages=4)
    router.register_inbound('puller', route='publisher')
    router.register_outbound('publisher')

    cache_service = CacheService('db', db_address,
                                 cache=cache,
                                 logger=logging,
                                 messages=6,
                                 )

    puller = PullService('puller', pull_address,
                         broker_address=router.inbound_address,
                         logger=logging,
                         cache=cache,
                         messages=4)

    publisher = PubService('publisher', pub_address,
                           broker_address=router.outbound_address,
                           logger=logging,
                           cache=cache,
                           messages=4)

    def client1_job():
        client = Client('master', db_address, session=None)
        return [r for r in client.job('master.something', [b'1', b'2'],
                                      messages=2)]

    def client2_job():
        client = Client('master', db_address, session=None)
        return [r for r in client.job('master.something', [b'3', b'4'],
                                      messages=2)]

    # Raw replies of the client jobs (the only workers returning lists).
    client_replies = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
        results = [
            executor.submit(cache_service.start),
            executor.submit(puller.start),
            executor.submit(publisher.start),
            executor.submit(router.start),
            executor.submit(client1_job),
            executor.submit(client2_job),
        ]

        # This works because servers do not return values.
        for i, future in enumerate(concurrent.futures.as_completed(results)):
            try:
                result = future.result()
            except Exception as exc:
                print(exc)
                lines = traceback.format_exception(*sys.exc_info())
                print(*lines)
                continue

            if isinstance(result, list):
                client_replies.append(result)

    assert i == 5

    # Decode and check outside the try above: AssertionError subclasses
    # Exception, so checking inside it would only print the failure.
    for replies in client_replies:
        got = []
        for r in replies:
            message = PalmMessage()
            message.ParseFromString(r)
            got.append(message.payload)
        assert got in ([b'1', b'2'], [b'3', b'4'])
class Server(object):
    """
    Standalone and minimal server that replies single requests.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param str pull_address: Address of the pull socket
    :param str pub_address: Address of the pub socket
    :param pipelined: True if the server is chained to another server.
    :param log_level: Minimum output log level.
    :param int messages: Total number of messages that the server processes.
        Useful for debugging.

    A message whose ``function`` field is ``'<name>.<method>'`` makes the
    server call ``self.<method>(payload)`` and publish the return value, so
    methods added to a subclass become callable by clients.
    """
    def __init__(self, name, db_address, pull_address, pub_address,
                 pipelined=False, log_level=logging.INFO,
                 messages=sys.maxsize):
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.pull_address = pull_address
        self.pub_address = pub_address
        self.pipelined = pipelined
        self.message = None

        # Server name and socket addresses, stored in the DictDB that the
        # cache service started by start() serves.
        self.cache.set('name', name.encode('utf-8'))
        self.cache.set('pull_address', pull_address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        self.logger = logging.getLogger(name=name)
        # getLogger returns the same object for the same name: attach the
        # console handler only once, otherwise every Server built with this
        # name adds another handler and each record is printed repeatedly.
        if not self.logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        self.pull_socket = zmq_context.socket(zmq.PULL)
        self.pull_socket.bind(self.pull_address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

    def handle_stream(self, message):
        """
        Handle the stream of messages.

        :param message: The message about to be sent to the next step in the
            cluster
        :return: topic (str) and message (PalmMessage)

        The default behaviour is the following. If you leave this function
        unchanged and pipeline is set to False, the topic is the ID of the
        client, which makes the message return to the client. If the pipeline
        parameter is set to True, the topic is set as the name of the server
        and the step of the message is incremented by one.

        You can alter this default behaviour by overriding this function.
        Take into account that the message is also available in this function,
        and you can change other parameters like the stage or the function.
        """
        if self.pipelined:
            topic = self.name
            message.stage += 1
        else:
            topic = message.client

        return topic, message

    def echo(self, payload):
        """
        Echo utility function that returns the unchanged payload. This
        function is useful when the server is there as just to modify the
        stream of messages.

        :return: payload (bytes)
        """
        return payload

    def _execution_handler(self):
        """
        Process ``self.messages`` messages from the pull socket.

        Each message is parsed, ``server.function`` is taken from its
        ``function`` field (entry number ``stage`` when the field holds a
        space-separated pipeline), ``self.<function>(payload)`` is called, and
        the message is published on the pub socket under the topic chosen by
        ``handle_stream``. If the message cannot be decoded, targets another
        server, names a missing method, or the method raises, the error is
        logged and ``b'0'`` is published as the payload.

        :raises ValueError: If the ``function`` field cannot be resolved to
            ``server.function``; this ends the handler.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            message_data = self.pull_socket.recv()
            self.logger.debug('Got message {}'.format(i + 1))
            result = b'0'
            self.message = PalmMessage()
            try:
                self.message.ParseFromString(message_data)

                # Handle the fact that the message may be a complete pipeline
                try:
                    if ' ' in self.message.function:
                        [server, function] = self.message.function.split()[
                            self.message.stage].split('.')
                    else:
                        [server, function] = self.message.function.split('.')
                except (IndexError, ValueError) as exc:
                    # IndexError: stage past the pipeline end; ValueError:
                    # an entry that is not exactly 'server.function'.
                    raise ValueError('Pipeline call not correct. Review the '
                                     'config in your client') from exc

                if not self.name == server:
                    self.logger.error('You called {}, instead of {}'.format(
                        server, self.name))
                else:
                    # getattr signals a missing attribute with AttributeError.
                    try:
                        user_function = getattr(self, function)
                    except AttributeError:
                        self.logger.error(
                            'Function {} was not found'.format(function)
                        )
                    else:
                        self.logger.debug('Looking for {}'.format(function))
                        try:
                            result = user_function(self.message.payload)
                        except Exception:
                            self.logger.error('User function gave an error')
                            lines = traceback.format_exception(*sys.exc_info())
                            for line in lines:
                                self.logger.error(line.strip('\n'))

            except DecodeError:
                self.logger.error('Message could not be decoded')

            self.message.payload = result

            topic, self.message = self.handle_stream(self.message)

            self.pub_socket.send_multipart(
                [topic.encode('utf-8'), self.message.SerializeToString()]
            )

    def start(self, cache_messages=sys.maxsize):
        """
        Start the server

        Runs a CacheService serving ``self.cache`` on ``db_address`` and the
        message loop, each in its own worker thread, and blocks until both
        return. A component that raises is logged as critical.

        :param cache_messages: Number of messages the cache service handles
            before it shuts down. Useful for debugging
        :return: The server name, UTF-8 encoded.
        """
        cache = CacheService('cache',
                             self.db_address,
                             logger=self.logger,
                             cache=self.cache,
                             messages=cache_messages)
        threads = [cache.start, self._execution_handler]

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(threads)) as executor:
            results = [executor.submit(thread) for thread in threads]
            for future in concurrent.futures.as_completed(results):
                try:
                    future.result()
                except Exception:
                    self.logger.error(
                        'This is critical, one of the components of the '
                        'server died')
                    lines = traceback.format_exception(*sys.exc_info())
                    for line in lines:
                        self.logger.error(line.strip('\n'))

        return self.name.encode('utf-8')
class Sink(Server):
    """
    Minimal server that acts as a sink of multiple streams.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param list sub_addresses: List of addresses of the pub socket of the
        previous servers
    :param str pub_address: Address of the pub socket
    :param list previous: List of names of the previous servers. Each name is
        the subscription topic on the matching entry of ``sub_addresses``.
    :param to_client: True if the message is sent back to the client.
        Defaults to True
    :param log_level: Minimum output log level. Defaults to INFO
    :param int messages: Total number of poll rounds that the server
        processes. Defaults to ``sys.maxsize``. Useful for debugging.
    :raises TypeError: If ``sub_addresses`` or ``previous`` is not a list.
    """
    def __init__(self, name, db_address, sub_addresses, pub_address, previous,
                 to_client=True, log_level=logging.INFO, messages=sys.maxsize):
        # Simple type checks. Raise instead of assert so they survive
        # ``python -O``; done first so nothing is set up for bad input.
        if not isinstance(previous, list):
            raise TypeError('previous must be a list')
        if not isinstance(sub_addresses, list):
            raise TypeError('sub_addresses must be a list')

        # Server.__init__ is not called (it binds a PULL socket); the
        # attributes the inherited methods use are set here instead.
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.sub_addresses = sub_addresses
        self.pub_address = pub_address
        # handle_stream routes back to the client unless pipelined.
        self.pipelined = not to_client
        self.message = None

        self.cache.set('name', name.encode('utf-8'))
        for i, address in enumerate(sub_addresses):
            self.cache.set('sub_address_{}'.format(i), address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        self.logger = logging.getLogger(name=name)
        # getLogger returns the same object for the same name: attach the
        # console handler only once to avoid duplicated log lines.
        if not self.logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        # One SUB socket per upstream server, subscribed to that server's
        # name. zip() stops at the shorter list, so surplus entries in either
        # list are ignored.
        self.sub_sockets = list()
        for address, prev in zip(self.sub_addresses, previous):
            sub_socket = zmq_context.socket(zmq.SUB)
            sub_socket.setsockopt_string(zmq.SUBSCRIBE, prev)
            sub_socket.connect(address)
            self.sub_sockets.append(sub_socket)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

        self.poller = zmq.Poller()
        for sock in self.sub_sockets:
            self.poller.register(sock, zmq.POLLIN)

    def _execution_handler(self):
        """
        Process ``self.messages`` poll rounds over the subscribed streams.

        Every socket that is readable in a round yields one message, so a
        round may handle several messages. Each one is dispatched like in
        ``Server``: ``self.<function>(payload)`` is called and the message is
        republished under the topic chosen by ``handle_stream``. If the
        function returns None, nothing is published for that message.
        Errors are logged and ``b'0'`` is published as the payload.

        :raises ValueError: If the ``function`` field cannot be resolved to
            ``server.function``; this ends the handler.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            locked_socks = dict(self.poller.poll())

            for sock in self.sub_sockets:
                if sock not in locked_socks:
                    continue

                # Frame 0 is the topic, frame 1 the serialized message.
                message_data = sock.recv_multipart()[1]
                self.logger.debug('Got message {}'.format(i + 1))
                result = b'0'
                self.message = PalmMessage()
                try:
                    self.message.ParseFromString(message_data)

                    # Handle the fact that the message may be a complete pipeline
                    try:
                        if ' ' in self.message.function:
                            [server, function] = self.message.function.split()[
                                self.message.stage].split('.')
                        else:
                            [server, function] = self.message.function.split('.')
                    except (IndexError, ValueError) as exc:
                        # IndexError: stage past the pipeline end; ValueError:
                        # an entry that is not exactly 'server.function'.
                        raise ValueError('Pipeline call not correct. Review the '
                                         'config in your client') from exc

                    if not self.name == server:
                        self.logger.error('You called {}, instead of {}'.format(
                            server, self.name))
                    else:
                        # getattr signals a missing attribute with AttributeError.
                        try:
                            user_function = getattr(self, function)
                        except AttributeError:
                            self.logger.error(
                                'Function {} was not found'.format(function)
                            )
                        else:
                            self.logger.debug('Looking for {}'.format(function))
                            try:
                                result = user_function(self.message.payload)
                            except Exception:
                                self.logger.error('User function gave an error')
                                lines = traceback.format_exception(
                                    *sys.exc_info())
                                for line in lines:
                                    self.logger.error(line.strip('\n'))

                except DecodeError:
                    self.logger.error('Message could not be decoded')

                # Do nothing if the function returns no value
                if result is None:
                    continue

                self.message.payload = result

                topic, self.message = self.handle_stream(self.message)

                self.pub_socket.send_multipart(
                    [topic.encode('utf-8'), self.message.SerializeToString()]
                )