Exemple #1
0
    def __init__(self,
                 logging_level=logging.INFO,
                 router_messages=sys.maxsize):
        """
        Build an empty server skeleton: cache, part registries, console
        logger and router.

        :param logging_level: Level applied to the server logger.
        :param router_messages: Forwarded to ``Router`` as ``messages``.
        """
        # Name of the server
        self.name = ''

        # Logging level for the server
        self.logging_level = logging_level

        # Basic Key-value database for storage
        self.cache = DictDB()

        # Registered parts, keyed by part name
        self.inbound_components = {}
        self.outbound_components = {}
        self.bypass_components = {}

        # Basic console logging. With an empty name getLogger returns the
        # root logger, which is shared by every instance, so only attach a
        # stdout handler if one is not already there; otherwise each new
        # instance would duplicate every log line.
        self.logger = logging.getLogger(name=self.name)
        if not any(isinstance(h, logging.StreamHandler)
                   and h.stream is sys.stdout
                   for h in self.logger.handlers):
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            self.logger.addHandler(handler)
        self.logger.setLevel(self.logging_level)

        # Finally, the router
        self.router = Router(logger=self.logger,
                             cache=self.cache,
                             messages=router_messages)
Exemple #2
0
    def __init__(self, name, db_address,
                 pull_address, pub_address, pipelined=False,
                 log_level=logging.INFO, messages=sys.maxsize):
        """
        Standalone server with a PULL socket for input and a PUB socket
        for output.

        :param str name: Name of the server; also used as the logger name.
        :param str db_address: ZeroMQ address of the cache service.
        :param str pull_address: Address the PULL socket binds to.
        :param str pub_address: Address the PUB socket binds to.
        :param pipelined: True if the server is chained to another server.
        :param log_level: Minimum output log level.
        :param int messages: Total number of messages that the server
            processes. Useful for debugging.
        """
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.pull_address = pull_address
        self.pub_address = pub_address
        self.pipelined = pipelined
        self.message = None

        # Configuration values stored (as UTF-8 bytes) in the local cache.
        self.cache.set('name', name.encode('utf-8'))
        self.cache.set('pull_address', pull_address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        # getLogger returns the same logger for the same name, so attach the
        # stdout handler only once; otherwise every new instance with this
        # name would duplicate each log line.
        self.logger = logging.getLogger(name=name)
        if not any(isinstance(h, logging.StreamHandler)
                   and h.stream is sys.stdout
                   for h in self.logger.handlers):
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        self.pull_socket = zmq_context.socket(zmq.PULL)
        self.pull_socket.bind(self.pull_address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)
Exemple #3
0
    def __init__(self, name, db_address,
                 sub_address, pub_address, previous, to_client=True,
                 log_level=logging.INFO, messages=sys.maxsize):
        """
        Server that receives on a SUB socket and sends on a PUB socket.

        :param str name: Name of the server; also used as the logger name.
        :param str db_address: ZeroMQ address of the cache service.
        :param str sub_address: Address the SUB socket connects to.
        :param str pub_address: Address the PUB socket binds to.
        :param str previous: Topic the SUB socket subscribes to.
        :param to_client: True if the message is sent back to the client;
            stored inverted as ``self.pipelined``.
        :param log_level: Minimum output log level.
        :param int messages: Stored as ``self.messages``.
        """
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.sub_address = sub_address
        self.pub_address = pub_address
        self.pipelined = not to_client
        self.message = None

        # Configuration values stored (as UTF-8 bytes) in the local cache.
        self.cache.set('name', name.encode('utf-8'))
        self.cache.set('sub_address', sub_address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        # getLogger returns the same logger for the same name, so attach the
        # stdout handler only once; otherwise every new instance with this
        # name would duplicate each log line.
        self.logger = logging.getLogger(name=name)
        if not any(isinstance(h, logging.StreamHandler)
                   and h.stream is sys.stdout
                   for h in self.logger.handlers):
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        self.sub_socket = zmq_context.socket(zmq.SUB)
        self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, previous)
        self.sub_socket.connect(self.sub_address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)
Exemple #4
0
    def __init__(self, name: str, sub_address: str, pub_address: str,
                 worker_pull_address: str, worker_push_address: str, db_address: str,
                 previous: str, pipelined: bool = False, cache: object = None,
                 log_level: int = logging.INFO):
        """
        Hub server wired from four routed parts plus a cache part.

        Inbound 'Sub' routes to outbound 'WorkerPush'; inbound 'WorkerPull'
        routes to outbound 'Pub'; 'Cache' is registered as a bypass part.

        :param name: Name of the hub.
        :param sub_address: Address for the 'Sub' inbound part.
        :param pub_address: Address for the 'Pub' outbound part.
        :param worker_pull_address: Address for the 'WorkerPull' part.
        :param worker_push_address: Address for the 'WorkerPush' part.
        :param db_address: Address for the 'Cache' bypass part.
        :param previous: Forwarded to the 'Sub' part as ``previous``.
        :param pipelined: Stored on the hub and forwarded to the 'Pub' part.
        :param cache: Key-value store for the hub and its parts. A new
            ``DictDB`` is created per instance when omitted.
        :param log_level: Logging level passed to the parent constructor.
        """
        super(Hub, self).__init__(logging_level=log_level)
        self.name = name
        # Build the DictDB here rather than as the default value: a default
        # is evaluated once, so every Hub created without ``cache`` would
        # share a single store.
        self.cache = DictDB() if cache is None else cache
        # NOTE(review): if the parent is ServerTemplate, its __init__ already
        # handed its own cache to self.router, so the router may not see this
        # cache - confirm this is intended.
        self.pipelined = pipelined

        self.register_inbound(
            SubConnection, 'Sub', sub_address, route='WorkerPush',
            previous=previous)
        self.register_inbound(
            WorkerPullService, 'WorkerPull', worker_pull_address, route='Pub')
        self.register_outbound(
            WorkerPushService, 'WorkerPush', worker_push_address)
        self.register_outbound(
            PubService, 'Pub', pub_address, log='to_sink', pipelined=pipelined)
        self.register_bypass(
            CacheService, 'Cache', db_address)
        self.preset_cache(name=name,
                          db_address=db_address,
                          sub_address=sub_address,
                          pub_address=pub_address,
                          worker_pull_address=worker_pull_address,
                          worker_push_address=worker_push_address)

        # Monkey patches: the 'Sub' part scatters through self.scatter, and
        # the 'Pub' part scatters through self.gather and handles the stream
        # through self.handle_stream.
        self.inbound_components['Sub'].scatter = self.scatter
        self.outbound_components['Pub'].scatter = self.gather
        self.outbound_components['Pub'].handle_stream = self.handle_stream
Exemple #5
0
    def __init__(self, name, db_address,
                 sub_addresses, pub_address, previous, to_client=True,
                 log_level=logging.INFO, messages=sys.maxsize):
        """
        Server that listens to several SUB sockets, polled together, and
        sends on a single PUB socket.

        :param str name: Name of the server; also used as the logger name.
        :param str db_address: ZeroMQ address of the cache service.
        :param list sub_addresses: Addresses the SUB sockets connect to.
        :param str pub_address: Address the PUB socket binds to.
        :param list previous: Subscription topic for each entry of
            ``sub_addresses``, in the same order.
        :param to_client: True if the message is sent back to the client;
            stored inverted as ``self.pipelined``.
        :param log_level: Minimum output log level.
        :param int messages: Stored as ``self.messages``.
        :raises TypeError: if ``previous`` or ``sub_addresses`` is not a list.
        :raises ValueError: if they do not have the same length.
        """
        # Validate explicitly: ``assert`` is stripped under ``python -O``,
        # and the zip() below would silently drop unmatched entries.
        if not isinstance(previous, list):
            raise TypeError('previous must be a list')
        if not isinstance(sub_addresses, list):
            raise TypeError('sub_addresses must be a list')
        if len(previous) != len(sub_addresses):
            raise ValueError(
                'sub_addresses and previous must have the same length')

        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.sub_addresses = sub_addresses
        self.pub_address = pub_address
        self.pipelined = not to_client
        self.message = None

        # Configuration values stored (as UTF-8 bytes) in the local cache.
        self.cache.set('name', name.encode('utf-8'))
        for i, address in enumerate(sub_addresses):
            self.cache.set('sub_address_{}'.format(i), address.encode('utf-8'))

        self.cache.set('pub_address', pub_address.encode('utf-8'))

        # getLogger returns the same logger for the same name, so attach the
        # stdout handler only once; otherwise every new instance with this
        # name would duplicate each log line.
        self.logger = logging.getLogger(name=name)
        if not any(isinstance(h, logging.StreamHandler)
                   and h.stream is sys.stdout
                   for h in self.logger.handlers):
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        # One SUB socket per (address, topic) pair.
        self.sub_sockets = list()
        for address, prev in zip(self.sub_addresses, previous):
            self.sub_sockets.append(zmq_context.socket(zmq.SUB))
            self.sub_sockets[-1].setsockopt_string(zmq.SUBSCRIBE, prev)
            self.sub_sockets[-1].connect(address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

        # Poll all SUB sockets for incoming messages together.
        self.poller = zmq.Poller()

        for sock in self.sub_sockets:
            self.poller.register(sock, zmq.POLLIN)
Exemple #6
0
def test_gateway_dealer():
    """
    Test function for the complete gateway with a dummy router.

    A REQ client sends a PalmMessage to ``inproc://gateway_router``. A dummy
    ROUTER bound to ``inproc://broker`` sends the message on to the
    ``gateway_dealer`` identity. The client must receive the original payload.
    """
    cache = DictDB()

    def dummy_response():
        dummy_router = zmq_context.socket(zmq.ROUTER)
        dummy_router.bind('inproc://broker')
        try:
            [target, empty, message] = dummy_router.recv_multipart()
            # Unblock the sender of the first request.
            dummy_router.send_multipart([target, empty, b'0'])

            # Raises if the forwarded bytes are not a valid PalmMessage.
            broker_message = PalmMessage()
            broker_message.ParseFromString(message)

            # Hand the same message to the dealer part by its identity.
            dummy_router.send_multipart([b'gateway_dealer', empty, message])
            [target, message] = dummy_router.recv_multipart()
        finally:
            # Close explicitly so the inproc endpoint is released instead of
            # waiting for garbage collection; another test binds it as well.
            dummy_router.close()

    def dummy_initiator():
        dummy_client = zmq_context.socket(zmq.REQ)
        dummy_client.identity = b'0'
        dummy_client.connect('inproc://gateway_router')
        try:
            message = PalmMessage()
            message.client = dummy_client.identity
            message.pipeline = '0'
            message.function = 'f.servername'
            message.stage = 1
            message.payload = b'This is a message'
            dummy_client.send(message.SerializeToString())
            return dummy_client.recv()
        finally:
            dummy_client.close()

    got = []

    dealer = GatewayDealer(cache=cache, logger=logging, messages=1)
    router = GatewayRouter(cache=cache, logger=logging, messages=2)

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        results = [
            executor.submit(dummy_response),
            executor.submit(dummy_initiator),
            executor.submit(dealer.start),
            executor.submit(router.start)
        ]

        for future in concurrent.futures.as_completed(results):
            try:
                result = future.result()
                if result:
                    got.append(result)

            except Exception as exc:
                print(exc)

    # Fail with a clear message instead of an IndexError on got[0].
    assert got, 'The client did not receive any reply'
    message = PalmMessage()
    message.ParseFromString(got[0])
    assert message.payload == b'This is a message'
Exemple #7
0
 def __init__(self,
              name='',
              listen_address='inproc://gateway_router',
              hostname='',
              port=8888,
              cache=DictDB(),
              logger=None):
     """
     Configure the ``MyHandler`` class and create a ``MyServer`` on
     ``(hostname, port)`` that uses it.

     :param name: Accepted but not used by this constructor.
     :param listen_address: Stored on ``MyHandler`` as
         ``gateway_router_address``.
     :param hostname: Host part of the server address.
     :param port: Port of the server address; also kept as ``self.port``.
     :param cache: Accepted but not used by this constructor.
     :param logger: Stored on ``MyHandler`` and on this instance.
     """
     handler_cls = MyHandler
     # These are class attributes, so they are shared by every user of
     # MyHandler, not only by this gateway instance.
     handler_cls.gateway_router_address = listen_address
     handler_cls.logger = logger
     self.handler = handler_cls
     self.server = MyServer((hostname, port), handler_cls)
     self.port = port
     self.logger = logger
Exemple #8
0
def test_get_config():
    """
    Boot a ``Client`` against a ``CacheService`` preloaded with the master
    server configuration, and check that the three tasks complete.
    """
    cache = DictDB()
    cache.set('name', b'master')
    cache.set('pub_address', pub_address.encode('utf-8'))
    cache.set('pull_address', pull_address.encode('utf-8'))

    # Separate name for the service so ``cache`` keeps meaning the DictDB.
    cache_service = CacheService('db', db_address,
                                 cache=cache,
                                 logger=logging,
                                 messages=3,
                                 )

    def boot_client():
        client = Client('master',
                        db_address,
                        session=None)
        return client.push_address, client.sub_address

    def broker():
        socket = zmq_context.socket(zmq.ROUTER)
        socket.bind(broker_address)
        # Close explicitly rather than relying on garbage collection to
        # release broker_address, which other tests bind too.
        socket.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        results = [
            executor.submit(cache_service.start),
            executor.submit(boot_client),
            executor.submit(broker)
        ]

        # This works because servers do not return values.
        for i, future in enumerate(concurrent.futures.as_completed(results)):
            try:
                result = future.result()
                print(result)

            except Exception as exc:
                print(exc)
                lines = traceback.format_exception(*sys.exc_info())
                print(*lines)

        assert i == 2
Exemple #9
0
def test_gateway_http():
    """
    Test function for the complete gateway with a dummy router.

    An HTTP GET to the HttpGateway goes through the GatewayRouter to a
    dummy ROUTER bound to ``inproc://broker``. The expected response body is
    ``'No Payload'``.
    """
    cache = DictDB()

    def dummy_response():
        dummy_router = zmq_context.socket(zmq.ROUTER)
        dummy_router.bind('inproc://broker')
        try:
            [target, empty, message] = dummy_router.recv_multipart()
            # Unblock the sender of the first request.
            dummy_router.send_multipart([target, empty, b'0'])

            # Raises if the forwarded bytes are not a valid PalmMessage.
            broker_message = PalmMessage()
            broker_message.ParseFromString(message)

            # Hand the same message to the dealer part by its identity.
            dummy_router.send_multipart([b'gateway_dealer', empty, message])
            [target, message] = dummy_router.recv_multipart()
        finally:
            # Close explicitly so the inproc endpoint is released instead of
            # waiting for garbage collection; another test binds it as well.
            dummy_router.close()

    def dummy_initiator():
        # Bound the wait so an unresponsive gateway fails the test instead
        # of hanging it forever.
        r = requests.get('http://localhost:8888/function', timeout=10)
        return r.text

    got = []

    dealer = GatewayDealer(cache=cache, logger=logging, messages=1)
    router = GatewayRouter(cache=cache, logger=logging, messages=2)
    http = HttpGateway(cache=cache, logger=logging)

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        results = [
            executor.submit(dummy_response),
            executor.submit(dummy_initiator),
            executor.submit(dealer.start),
            executor.submit(router.start),
            executor.submit(http.debug)
        ]

        for future in concurrent.futures.as_completed(results):
            try:
                result = future.result()
                if result:
                    got.append(result)

            except Exception as exc:
                print(exc)

    # Fail with a clear message instead of an IndexError on got[0].
    assert got, 'The HTTP client did not receive any reply'
    assert got[0] == 'No Payload'
Exemple #10
0
 def __init__(self,
              name='gateway_router',
              listen_address='inproc://gateway_router',
              broker_address="inproc://broker",
              cache=None,
              logger=None,
              messages=sys.maxsize):
     """
     Gateway router part, built on a ROUTER socket with ``bind=True`` on
     ``listen_address`` and the broker at ``broker_address``.

     :param name: The part is always registered as ``'gateway_router'``;
         a different non-empty value only triggers a warning.
     :param listen_address: Address the part listens on.
     :param broker_address: Address of the broker.
     :param cache: Key-value store for the part. A new ``DictDB`` is
         created per instance when omitted.
     :param logger: Logger passed to the parent part.
     :param messages: Passed to the parent part as ``messages``.
     """
     # Build the DictDB here rather than as the default value: a default is
     # evaluated once, so every instance without ``cache`` would share it.
     if cache is None:
         cache = DictDB()
     super(GatewayRouter, self).__init__(
         'gateway_router',
         listen_address,
         zmq.ROUTER,
         reply=False,
         broker_address=broker_address,
         bind=True,
         cache=cache,
         logger=logger,
         messages=messages,
         )
     # The name is fixed, so only warn when the caller asked for a
     # different one (the default value must not trigger the warning).
     if name and name != 'gateway_router':
         self.logger.warning('Gateway router part is called "gateway_router",')
         self.logger.warning('check that you have called this way')
Exemple #11
0
class ServerTemplate(object):
    """
    Low-level tool to build a server from parts.

    :param logging_level: A correct logging level from the logging module.
        Defaults to INFO.
    :param router_messages: Passed to the ``Router`` as ``messages``.
        Defaults to ``sys.maxsize``.

    It has important attributes that you may want to override, like

    :cache: The key-value database that the server should use
    :logging_level: Controls the log output of the server.
    :router: Here's the router, you may want to change its attributes too.

    """
    def __init__(self,
                 logging_level=logging.INFO,
                 router_messages=sys.maxsize):
        # Name of the server
        self.name = ''

        # Logging level for the server
        self.logging_level = logging_level

        # Basic Key-value database for storage
        self.cache = DictDB()

        # Registered parts, keyed by the name given at registration
        self.inbound_components = {}
        self.outbound_components = {}
        self.bypass_components = {}

        # Basic console logging. An empty name selects the root logger.
        self.logger = self._console_logger(self.name, self.logging_level)

        # Finally, the router
        self.router = Router(logger=self.logger,
                             cache=self.cache,
                             messages=router_messages)

    @staticmethod
    def _console_logger(name, level):
        """
        Return ``logging.getLogger(name)`` set to ``level`` with a stdout
        handler attached.

        ``getLogger`` returns the same object for the same name (the root
        logger for ``''``), so a stdout handler is only added if the logger
        does not already have one. Otherwise every new server would
        duplicate each log line.
        """
        logger = logging.getLogger(name=name)
        has_stdout_handler = any(
            isinstance(h, logging.StreamHandler) and h.stream is sys.stdout
            for h in logger.handlers)
        if not has_stdout_handler:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        logger.setLevel(level)
        return logger

    def register_inbound(self,
                         part,
                         name='',
                         listen_address='',
                         route='',
                         block=False,
                         log='',
                         **kwargs):
        """
        Register inbound part to this server.

        :param part: part class
        :param name: Name of the part
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param route: Outbound part it routes to
        :param block: True if the part blocks waiting for a response
        :param log: Log message in DEBUG level for each message processed.
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        if 'cache' not in kwargs:
            kwargs['cache'] = self.cache

        instance = part(name,
                        listen_address,
                        broker_address=self.router.inbound_address,
                        logger=self.logger,
                        **kwargs)

        self.router.register_inbound(name, route=route, block=block, log=log)

        self.inbound_components[name] = instance

    def register_outbound(self,
                          part,
                          name='',
                          listen_address='',
                          route='',
                          log='',
                          **kwargs):
        """
        Register outbound part to this server

        :param part: part class
        :param name: Name of the part
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param route: Outbound part it routes the response (if there is) to
        :param log: Log message in DEBUG level for each message processed
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        if 'cache' not in kwargs:
            kwargs['cache'] = self.cache

        instance = part(name,
                        listen_address,
                        broker_address=self.router.outbound_address,
                        logger=self.logger,
                        **kwargs)

        self.router.register_outbound(name, route=route, log=log)

        self.outbound_components[name] = instance

    def register_bypass(self, part, name='', listen_address='', **kwargs):
        """
        Register a bypass part to this server. Unlike inbound and outbound
        parts, it is not given a broker address and is not registered with
        the router.

        :param part: part class
        :param name: part name
        :param listen_address: Valid ZeroMQ address listening to the exterior
        :param kwargs: Additional keyword arguments to pass to the part
        """
        # Inject the server cache in case it is not configured for the component
        if 'cache' not in kwargs:
            kwargs['cache'] = self.cache

        instance = part(name, listen_address, logger=self.logger, **kwargs)

        self.bypass_components[name] = instance

    def preset_cache(self, **kwargs):
        """
        Store each keyword argument in the server cache. Useful for
        configuration variables that the workers or the clients fetch
        straight from the cache.

        ``str`` values (including subclasses) are UTF-8 encoded to ``bytes``
        before being stored; any other value is stored unchanged.

        :param kwargs: Cache keys and their values.
        """
        for key, value in kwargs.items():
            if isinstance(value, str):
                value = value.encode('utf-8')
            self.cache.set(key, value)

    def start(self):
        """
        Start the server with all its parts.

        The router and each registered part's ``start`` run in their own
        worker thread. This call returns once all of them have finished. A
        part that raises is logged at ERROR level rather than re-raised.
        """
        threads = []

        self.logger.info("Starting the router")
        threads.append(self.router.start)

        for name, part in self.inbound_components.items():
            self.logger.info("Starting inbound part {}".format(name))
            threads.append(part.start)

        for name, part in self.outbound_components.items():
            self.logger.info("Starting outbound part {}".format(name))
            threads.append(part.start)

        for name, part in self.bypass_components.items():
            self.logger.info("Starting bypass part {}".format(name))
            threads.append(part.start)

        # ``threads`` always holds at least the router, so max_workers >= 1.
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(threads)) as executor:
            results = [executor.submit(thread) for thread in threads]
            for future in concurrent.futures.as_completed(results):
                try:
                    future.result()
                except Exception:
                    self.logger.error(
                        'This is critical, one of the parts died')
                    lines = traceback.format_exception(*sys.exc_info())
                    for line in lines:
                        self.logger.error(line.strip('\n'))
Exemple #12
0
def test_send_job():
    """
    Send one job from a Client through PullService and PubService, with a
    hand-written ROUTER standing in for the broker, and check that all five
    tasks finish.
    """
    master_cache = DictDB()
    master_cache.set('name', b'master')
    master_cache.set('pub_address', pub_address.encode('utf-8'))
    master_cache.set('pull_address', pull_address.encode('utf-8'))

    cache_service = CacheService('db', db_address, cache=master_cache,
                                 logger=logging, messages=3)
    puller = PullService('puller', pull_address, logger=logging,
                         cache=master_cache, messages=1)
    publisher = PubService('publisher', pub_address, logger=logging,
                           cache=master_cache, messages=1)

    def run_client_job():
        client = Client('master', db_address, session=None)
        return list(client.job('master.something', [b'1'], messages=1))

    def fake_broker():
        router_socket = zmq_context.socket(zmq.ROUTER)
        router_socket.bind(broker_address)
        frames = router_socket.recv_multipart()
        # Echo the request back to unblock its sender, then forward the
        # payload frame to the 'publisher' identity. Here you see why the
        # actual router is complicated.
        router_socket.send_multipart(frames)
        router_socket.send_multipart([b'publisher', b'', frames[2]])
        router_socket.close()
        return b'router'

    tasks = (cache_service.start, puller.start, publisher.start,
             run_client_job, fake_broker)

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(task) for task in tasks]

        # This works because servers do not return values.
        for finished, future in enumerate(
                concurrent.futures.as_completed(futures)):
            try:
                print(future.result())
            except Exception as exc:
                print(exc)
                print(*traceback.format_exception(*sys.exc_info()))

        assert finished == 4
Exemple #13
0
def test_multiple_clients():
    """
    Two clients send two jobs each through a Router, PullService and
    PubService. Each client must get back exactly its own payloads.
    """
    cache = DictDB()
    cache.set('name', b'master')
    cache.set('pub_address', pub_address.encode('utf-8'))
    cache.set('pull_address', pull_address.encode('utf-8'))

    router = Router(logger=logging, cache=cache, messages=4)
    router.register_inbound('puller', route='publisher')
    router.register_outbound('publisher')

    cache_service = CacheService('db', db_address,
                                 cache=cache,
                                 logger=logging,
                                 messages=6,
                                 )

    puller = PullService('puller',
                         pull_address,
                         broker_address=router.inbound_address,
                         logger=logging,
                         cache=cache,
                         messages=4)

    publisher = PubService('publisher',
                           pub_address,
                           broker_address=router.outbound_address,
                           logger=logging,
                           cache=cache,
                           messages=4)

    def client1_job():
        client = Client('master',
                        db_address,
                        session=None)
        return [r for r in client.job('master.something', [b'1', b'2'], messages=2)]

    def client2_job():
        client = Client('master',
                        db_address,
                        session=None)
        return [r for r in client.job('master.something', [b'3', b'4'], messages=2)]

    # Payload lists returned by the clients, checked after all tasks finish.
    client_payloads = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
        results = [
            executor.submit(cache_service.start),
            executor.submit(puller.start),
            executor.submit(publisher.start),
            executor.submit(router.start),
            executor.submit(client1_job),
            executor.submit(client2_job),
        ]

        # This works because servers do not return values.
        for i, future in enumerate(concurrent.futures.as_completed(results)):
            # Only the call to future.result() is guarded. The payload check
            # must stay outside the ``except Exception`` below; otherwise a
            # failing assertion would be caught and merely printed.
            try:
                result = future.result()
            except Exception as exc:
                print(exc)
                lines = traceback.format_exception(*sys.exc_info())
                print(*lines)
                continue

            if isinstance(result, list):
                got = []
                for r in result:
                    message = PalmMessage()
                    message.ParseFromString(r)
                    got.append(message.payload)
                client_payloads.append(got)

    for got in client_payloads:
        assert got == [b'1', b'2'] or got == [b'3', b'4']
    assert i == 5
Exemple #14
0
class Server(object):
    """
    Standalone and minimal server that replies single requests.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param str pull_address: Address of the pull socket
    :param str pub_address: Address of the pub socket
    :param pipelined: True if the server is chained to another server.
    :param log_level: Minimum output log level.
    :param int messages: Total number of messages that the server processes.
        Useful for debugging.
    """
    def __init__(self, name, db_address,
                 pull_address, pub_address, pipelined=False,
                 log_level=logging.INFO, messages=sys.maxsize):
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.pull_address = pull_address
        self.pub_address = pub_address
        self.pipelined = pipelined
        self.message = None

        # Configuration values served by the cache service created in start().
        self.cache.set('name', name.encode('utf-8'))
        self.cache.set('pull_address', pull_address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        self.logger = self._console_logger(name, log_level)

        self.messages = messages

        self.pull_socket = zmq_context.socket(zmq.PULL)
        self.pull_socket.bind(self.pull_address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

    @staticmethod
    def _console_logger(name, level):
        """
        Return ``logging.getLogger(name)`` set to ``level`` with a stdout
        handler attached.

        ``getLogger`` returns the same object for the same name, so the
        stdout handler is only added once. Otherwise re-creating a server
        with the same name would duplicate every log line.
        """
        logger = logging.getLogger(name=name)
        has_stdout_handler = any(
            isinstance(h, logging.StreamHandler) and h.stream is sys.stdout
            for h in logger.handlers)
        if not has_stdout_handler:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(
                logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
            )
            logger.addHandler(handler)
        logger.setLevel(level)
        return logger

    def handle_stream(self, message):
        """
        Handle the stream of messages.

        :param message: The message about to be sent to the next step in the
            cluster
        :return: topic (str) and message (PalmMessage)

        The default behaviour is the following. If you leave this function
        unchanged and pipeline is set to False, the topic is the ID of the
        client, which makes the message return to the client. If the pipeline
        parameter is set to True, the topic is set as the name of the server and
        the step of the message is incremented by one.

        You can alter this default behaviour by overriding this function.
        Take into account that the message is also available in this function,
        and you can change other parameters like the stage or the function.
        """
        if self.pipelined:
            topic = self.name
            message.stage += 1
        else:
            topic = message.client

        return topic, message

    def echo(self, payload):
        """
        Echo utility function that returns the unchanged payload. This
        function is useful when the server is there as just to modify the
        stream of messages.

        :return: payload (bytes)
        """
        return payload

    @staticmethod
    def _split_target(function_spec, stage):
        """
        Return ``(server, function)`` parsed from a ``'server.function'`` spec.

        A spec containing spaces is a pipeline. The element at index
        ``stage`` is the one parsed.

        :raises ValueError: if the spec is malformed or ``stage`` is out of
            range for the pipeline.
        """
        try:
            if ' ' in function_spec:
                server, function = function_spec.split()[stage].split('.')
            else:
                server, function = function_spec.split('.')
        except IndexError as err:
            raise ValueError('Pipeline call not correct. Review the '
                             'config in your client') from err
        return server, function

    def _run_user_function(self, function, payload):
        """
        Call the method named ``function`` of this server on ``payload``.

        :return: The method's result, or ``b'0'`` if no such method exists
            or the method raises. Both failures are logged.
        """
        try:
            user_function = getattr(self, function)
        except AttributeError:
            # getattr reports a missing name with AttributeError, not
            # KeyError, so this is the exception to catch here.
            self.logger.error('Function {} was not found'.format(function))
            return b'0'

        self.logger.debug('Looking for {}'.format(function))
        try:
            return user_function(payload)
        except Exception:
            self.logger.error('User function gave an error')
            for line in traceback.format_exception(*sys.exc_info()):
                self.logger.error(line.strip('\n'))
            return b'0'

    def _execution_handler(self):
        """
        Main loop: for up to ``self.messages`` messages, receive from the PULL
        socket, run the requested function, and publish the result on the
        PUB socket under the topic chosen by ``handle_stream``.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            message_data = self.pull_socket.recv()
            self.logger.debug('Got message {}'.format(i + 1))
            # Payload sent back whenever the call cannot be completed.
            result = b'0'
            self.message = PalmMessage()
            try:
                self.message.ParseFromString(message_data)

                # Handle the fact that the message may be a complete pipeline
                server, function = self._split_target(
                    self.message.function, self.message.stage)

                if server != self.name:
                    self.logger.error('You called {}, instead of {}'.format(
                        server, self.name))
                else:
                    result = self._run_user_function(
                        function, self.message.payload)
            except DecodeError:
                self.logger.error('Message could not be decoded')

            self.message.payload = result

            topic, self.message = self.handle_stream(self.message)

            self.pub_socket.send_multipart(
                [topic.encode('utf-8'), self.message.SerializeToString()]
            )

    def start(self, cache_messages=sys.maxsize):
        """
        Start the server

        :param cache_messages: Number of messages the cache service handles
            before it shuts down. Useful for debugging
        :return: The server name encoded as UTF-8 bytes, once both the cache
            service and the execution loop have finished.
        """
        threads = []

        cache = CacheService('cache', self.db_address, logger=self.logger,
                             cache=self.cache, messages=cache_messages)

        threads.append(cache.start)
        threads.append(self._execution_handler)

        with concurrent.futures.ThreadPoolExecutor(max_workers=len(threads)) as executor:
            results = [executor.submit(thread) for thread in threads]
            for future in concurrent.futures.as_completed(results):
                try:
                    future.result()
                except Exception:
                    self.logger.error(
                        'This is critical, one of the components of the '
                        'server died')
                    lines = traceback.format_exception(*sys.exc_info())
                    for line in lines:
                        self.logger.error(line.strip('\n'))

        return self.name.encode('utf-8')
Exemple #15
0
class Sink(Server):
    """
    Minimal server that acts as a sink of multiple streams.

    One SUB socket is created for each (address, topic) pair taken from
    ``sub_addresses`` and ``previous``. Pairing uses ``zip``, so surplus
    entries in the longer list are ignored. Every socket is registered on a
    poller. Each received message is dispatched to the method of this server
    named in the message's ``function`` field. The result is re-published on
    the PUB socket bound to ``pub_address``.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param list sub_addresses: List of addresses of the pub socket of the
        previous servers
    :param str pub_address: Address of the pub socket
    :param list previous: List of names of the previous servers. Each name
        is used as the subscription topic of the socket connected to the
        address at the same position in ``sub_addresses``.
    :param to_client: True if the message is sent back to the client.
        Defaults to True
    :param log_level: Minimum output log level. Defaults to INFO
    :param int messages: Number of poll rounds the server runs before its
        handler returns. Each round may process one message per ready
        socket. Defaults to Infty. Useful for debugging.
    :raises TypeError: if ``sub_addresses`` or ``previous`` is not a list.
    """
    def __init__(self, name, db_address,
                 sub_addresses, pub_address, previous, to_client=True,
                 log_level=logging.INFO, messages=sys.maxsize):
        # Validate before anything uses the arguments. A plain ``assert``
        # would be stripped when Python runs with -O.
        if not isinstance(previous, list):
            raise TypeError('previous must be a list')
        if not isinstance(sub_addresses, list):
            raise TypeError('sub_addresses must be a list')

        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.sub_addresses = sub_addresses
        self.pub_address = pub_address
        self.pipelined = not to_client
        self.message = None

        self.cache.set('name', name.encode('utf-8'))
        for i, address in enumerate(sub_addresses):
            self.cache.set('sub_address_{}'.format(i), address.encode('utf-8'))

        self.cache.set('pub_address', pub_address.encode('utf-8'))

        self.logger = logging.getLogger(name=name)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
        )
        self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        # One SUB socket per upstream server, subscribed to that server's name.
        self.sub_sockets = list()
        for address, prev in zip(self.sub_addresses, previous):
            self.sub_sockets.append(zmq_context.socket(zmq.SUB))
            self.sub_sockets[-1].setsockopt_string(zmq.SUBSCRIBE, prev)
            self.sub_sockets[-1].connect(address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

        self.poller = zmq.Poller()

        for sock in self.sub_sockets:
            self.poller.register(sock, zmq.POLLIN)

    def _execution_handler(self):
        """
        Poll the SUB sockets and process incoming messages.

        Runs ``self.messages`` poll rounds. In each round, every ready SUB
        socket contributes one message. The message is parsed as a
        ``PalmMessage`` and dispatched to the method named by its
        ``function`` field. Unless that method returns ``None``, the message
        is re-published through ``self.pub_socket`` under the topic returned
        by ``handle_stream``.

        :raises ValueError: if the ``function`` field of a message cannot be
            resolved into a ``server.function`` pair.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            locked_socks = dict(self.poller.poll())

            for sock in self.sub_sockets:
                if sock not in locked_socks:
                    continue

                # Frames are [topic, payload]; only the payload is parsed.
                message_data = sock.recv_multipart()[1]

                self.logger.debug('Got message {}'.format(i + 1))
                # Payload that is published when no user function ran.
                result = b'0'
                self.message = PalmMessage()
                try:
                    self.message.ParseFromString(message_data)

                    # A space-separated function field is a complete pipeline;
                    # select the entry for the current stage.
                    try:
                        if ' ' in self.message.function:
                            [server, function] = self.message.function.split()[
                                self.message.stage].split('.')
                        else:
                            [server, function] = self.message.function.split('.')
                    except IndexError:
                        raise ValueError('Pipeline call not correct. Review the '
                                         'config in your client')

                    if server != self.name:
                        self.logger.error('You called {}, instead of {}'.format(
                            server, self.name))
                    else:
                        try:
                            # getattr signals a missing name with
                            # AttributeError, not KeyError.
                            user_function = getattr(self, function)
                        except AttributeError:
                            self.logger.error(
                                'Function {} was not found'.format(function)
                            )
                        else:
                            self.logger.debug('Looking for {}'.format(function))
                            try:
                                result = user_function(self.message.payload)
                            except Exception:
                                # One record holding the message plus the
                                # traceback, logged at ERROR level.
                                self.logger.exception('User function gave an error')
                except DecodeError:
                    self.logger.error('Message could not be decoded')

                # Do nothing if the function returns no value
                if result is None:
                    continue

                self.message.payload = result

                topic, self.message = self.handle_stream(self.message)

                self.pub_socket.send_multipart(
                    [topic.encode('utf-8'), self.message.SerializeToString()]
                )
Exemple #16
0
import concurrent.futures
import time
import zmq
import logging
import sys
import traceback

# Logger 'test_service_pub': DEBUG level, timestamped records written to stdout.
logger = logging.getLogger('test_service_pub')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# Key-value store handed to the PubService as its cache.
cache = DictDB()

# In-process ZeroMQ endpoints used by the service below.
listen_address = 'inproc://pub1'
broker_address = 'inproc://broker1'

# NOTE(review): messages=1 presumably limits the service to a single
# message - confirm against the PubService implementation.
pub_service = PubService('pull_service',
                         listen_address=listen_address,
                         broker_address=broker_address,
                         logger=logger,
                         cache=cache,
                         messages=1)


def fake_router():
    original_message = PalmMessage()
    original_message.pipeline = 'pipeline'