Beispiel #1
0
    def start(self):
        """
        Call this function to start the component

        Binds ``listen_to`` to ``listen_address`` and then handles
        ``self.messages`` broker messages. Each broker message is translated,
        parsed and scattered; every scattered message goes through
        ``handle_stream``, which returns the topic the message is published
        under on ``listen_to``. When ``self.reply`` is set, an answer is read
        back after every publication, translated and passed to
        ``handle_feedback``. ``reply_feedback()`` is then sent to the broker.

        :return: The name of the component.
        """
        message = PalmMessage()

        self.listen_to.bind(self.listen_address)
        self.logger.info('{} successfully started'.format(self.name))

        for _ in range(self.messages):
            self.logger.debug('{} blocked waiting for broker'.format(
                self.name))
            message_data = self.broker.recv()
            self.logger.debug('{} Got message from broker'.format(self.name))
            message_data = self._translate_from_broker(message_data)
            message.ParseFromString(message_data)

            for scattered in self.scatter(message):
                topic, scattered = self.handle_stream(scattered)
                # Publish the scattered, stream-handled message. Serializing
                # ``message`` here would resend the unscattered input and
                # discard whatever handle_stream did.
                self.listen_to.send_multipart(
                    [topic.encode('utf-8'),
                     scattered.SerializeToString()])
                self.logger.debug('Component {} Sent message. Topic {}'.format(
                    self.name, topic))

                if self.reply:
                    feedback = self.listen_to.recv()
                    feedback = self._translate_to_broker(feedback)
                    self.handle_feedback(feedback)

            self.broker.send(self.reply_feedback())

        return self.name
Beispiel #2
0
            def do_POST(self):
                """
                Handle a POST whose body is a serialized ``PalmMessage``.

                Note that this http server always replies: the 200 status is
                sent before the body is processed. The message is scattered;
                every part that ``_translate_to_broker`` keeps (truthy
                result) is sent to the broker, and the broker's answer is
                passed to ``handle_feedback``. The response body is
                ``reply_feedback()``, or ``b'0'`` when scattering produced
                no messages.
                """
                message = PalmMessage()
                self.send_response(200)
                self.end_headers()
                # A request without Content-Length used to raise TypeError on
                # int(None) after the 200 status was already sent; treat a
                # missing header as an empty body.
                content_length = int(self.headers.get('Content-Length', 0))
                message_data = self.rfile.read(content_length)

                message.ParseFromString(message_data)
                # Materialize the scattered messages. If ``scatter`` returns a
                # generator, the emptiness check below would otherwise always
                # see a truthy object. ``or ()`` keeps a None/empty return
                # mapping to "no messages".
                scattered_messages = list(scatter(message) or ())

                if not scattered_messages:
                    self.wfile.write(b'0')

                else:
                    for scattered in scattered_messages:
                        scattered = _translate_to_broker(scattered)

                        if scattered:
                            broker.send(scattered.SerializeToString())
                            handle_feedback(broker.recv())

                    self.wfile.write(reply_feedback())
Beispiel #3
0
    def start(self):
        """
        Call this function to start the component

        Binds or connects ``listen_to`` (chosen by ``self.bind``), then
        serves ``self.messages`` broker messages. Each one is translated,
        parsed and scattered. Every scattered message is forwarded on
        ``listen_to``; if ``self.reply`` is set, the answer is translated and
        handed to ``handle_feedback``. ``reply_feedback()`` goes back to the
        broker after each broker message.

        :return: The name of the component.
        """
        incoming = PalmMessage()

        attach = self.listen_to.bind if self.bind else self.listen_to.connect
        attach(self.listen_address)

        self.logger.info('{} successfully started'.format(self.name))

        for _ in range(self.messages):
            self.logger.debug('{} blocked waiting for broker'.format(self.name))
            raw = self.broker.recv()
            self.logger.debug('{} Got message from broker'.format(self.name))
            incoming.ParseFromString(self._translate_from_broker(raw))

            for part in self.scatter(incoming):
                self.listen_to.send(part.SerializeToString())
                self.logger.debug('{} Sent message'.format(self.name))

                if not self.reply:
                    continue
                answer = self.listen_to.recv()
                self.handle_feedback(self._translate_to_broker(answer))

            self.broker.send(self.reply_feedback())

        return self.name
Beispiel #4
0
 def dummy_response():
     """Echo one request received on inproc://broker and return its payload."""
     echo_socket = zmq_context.socket(zmq.REP)
     echo_socket.bind('inproc://broker')
     request_bytes = echo_socket.recv()
     decoded = PalmMessage()
     decoded.ParseFromString(request_bytes)
     # The reply is the exact bytes that were received.
     echo_socket.send(request_bytes)
     return decoded.payload
Beispiel #5
0
def fake_terminator():
    """
    Receive one message on ``inproc://terminator``.

    The message is acknowledged with ``b'0'`` and returned parsed as a
    ``PalmMessage``.
    """
    terminator = zmq_context.socket(zmq.REP)
    terminator.bind('inproc://terminator')

    received = PalmMessage()
    received.ParseFromString(terminator.recv())
    print("Got the message at the terminator: ")

    terminator.send(b'0')
    return received
Beispiel #6
0
def test_gateway_dealer():
    """
    Test function for the complete gateway with a dummy router.

    Runs a dummy broker (ROUTER bound to ``inproc://broker``), a dummy
    client (REQ connected to ``inproc://gateway_router``), a
    ``GatewayDealer`` and a ``GatewayRouter`` concurrently. It then checks
    that the reply received by the client carries the original payload.
    """
    cache = DictDB()

    def dummy_response():
        dummy_router = zmq_context.socket(zmq.ROUTER)
        dummy_router.bind('inproc://broker')
        [target, empty, message] = dummy_router.recv_multipart()
        dummy_router.send_multipart([target, empty, b'0'])

        broker_message = PalmMessage()
        broker_message.ParseFromString(message)

        dummy_router.send_multipart([b'gateway_dealer', empty, message])
        [target, message] = dummy_router.recv_multipart()

    def dummy_initiator():
        dummy_client = zmq_context.socket(zmq.REQ)
        dummy_client.identity = b'0'
        dummy_client.connect('inproc://gateway_router')
        message = PalmMessage()
        message.client = dummy_client.identity
        message.pipeline = '0'
        message.function = 'f.servername'
        message.stage = 1
        message.payload = b'This is a message'
        dummy_client.send(message.SerializeToString())
        return dummy_client.recv()

    dealer = GatewayDealer(cache=cache, logger=logging, messages=1)
    router = GatewayRouter(cache=cache, logger=logging, messages=2)

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(dummy_response),
            executor.submit(dummy_initiator),
            executor.submit(dealer.start),
            executor.submit(router.start)
        ]
        initiator = futures[1]

        # Report failures of the helper threads without aborting the test.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(exc)

    # Use the client's own reply. The first truthy result in completion
    # order can also be a component's return value, and as_completed order
    # is not deterministic. If the client failed, this re-raises the real
    # error instead of an IndexError.
    reply = initiator.result()
    message = PalmMessage()
    message.ParseFromString(reply)
    assert message.payload == b'This is a message'
Beispiel #7
0
    def dummy_response():
        """
        Play the broker for one round trip on ``inproc://broker``.

        Acknowledges the first request with ``b'0'``, forwards its body to
        the ``gateway_dealer`` identity, then waits for one two-frame reply.
        """
        broker = zmq_context.socket(zmq.ROUTER)
        broker.bind('inproc://broker')

        sender_id, delimiter, body = broker.recv_multipart()
        broker.send_multipart([sender_id, delimiter, b'0'])

        # The parsed message is not used further; parsing presumably serves
        # as a sanity check that the body is a valid PalmMessage.
        PalmMessage().ParseFromString(body)

        broker.send_multipart([b'gateway_dealer', delimiter, body])
        sender_id, body = broker.recv_multipart()
Beispiel #8
0
    def eval(self,
             function,
             payload: bytes,
             messages: int = 1,
             cache: str = ''):
        """
        Execute single job.

        :param function: String or list of strings following the format
            ``server.function``. A list is joined with spaces into a
            pipelined job.
        :param payload: Binary message to be sent
        :param messages: Number of messages expected to be sent back to the
            client
        :param cache: Cache data included in the message
        :return: If messages=1, the result data. If messages > 1, a list with
            the results
        """
        push_socket = zmq_context.socket(zmq.PUSH)
        push_socket.connect(self.push_address)

        sub_socket = zmq_context.socket(zmq.SUB)
        sub_socket.setsockopt_string(zmq.SUBSCRIBE, self.uuid)
        sub_socket.connect(self.sub_address)

        try:
            if isinstance(function, list):
                # Pipelined job.
                function = ' '.join(function)

            message = PalmMessage()
            message.function = function
            message.stage = 0
            message.pipeline = self.pipeline
            message.client = self.uuid
            message.payload = payload
            if cache:
                message.cache = cache

            push_socket.send(message.SerializeToString())

            result = []
            for _ in range(messages):
                [_client, message_data] = sub_socket.recv_multipart()
                message.ParseFromString(message_data)
                result.append(message.payload)
        finally:
            # Every call opens its own pair of sockets; close them so that
            # repeated calls do not accumulate open sockets on the context.
            push_socket.close()
            sub_socket.close()

        if messages == 1:
            return result[0]

        return result
Beispiel #9
0
    def job(self,
            function,
            generator,
            messages: int = sys.maxsize,
            cache: str = ''):
        """
        Submit a job with multiple messages to a server.

        :param function: String, or list of strings, following the format
            ``server.function``. A list is joined with spaces into a
            pipelined job.
        :param generator: A generator that yields a series of binary
            messages. It is handed to ``self._sender`` in a background thread.
        :param messages: Number of messages expected to be sent back to the
            client. Defaults to infinity (sys.maxsize)
        :param cache: Cache data included in the message
        :return: an iterator with the payloads that are sent back to the
            client.
        :raises ValueError: if a message addressed to another client arrives.
        """
        pusher = zmq_context.socket(zmq.PUSH)
        pusher.connect(self.push_address)

        subscriber = zmq_context.socket(zmq.SUB)
        subscriber.setsockopt_string(zmq.SUBSCRIBE, self.uuid)
        subscriber.connect(self.sub_address)

        # A list describes a pipeline; a plain string is a single-stage job
        # and is used as is.
        if type(function) == list:
            function = ' '.join(function)

        # Sockets are not thread safe: after this point only the background
        # sender touches the push socket.
        Thread(target=self._sender,
               args=(pusher, function, generator, cache)).start()

        for _ in range(messages):
            client, raw = subscriber.recv_multipart()
            if client.decode('utf-8') != self.uuid:
                raise ValueError(
                    'The client got a message that does not belong')

            reply = PalmMessage()
            reply.ParseFromString(raw)
            yield reply.payload
Beispiel #10
0
    def start(self):
        """
        Call this function to start the component

        For each of ``self.messages`` broker messages: receive it, translate
        and parse it, then POST every scattered message concurrently (up to
        ``self.max_workers`` threads) to ``self.url``. When ``self.reply`` is
        set, each HTTP response body (``None`` for a failed request) is
        translated and passed to ``handle_feedback``. The broker then
        receives ``reply_feedback()`` if the last processed feedback is
        truthy. Otherwise it receives the translated incoming data.
        """
        message = PalmMessage()

        def load_url(url, data):
            request = Request(url, data=data)
            # Close the response (and its connection) once the body is read.
            with urlopen(request) as response:
                return response.read()

        for _ in range(self.messages):
            self.logger.debug('{} blocked waiting for broker'.format(
                self.name))
            message_data = self.broker.recv()
            self.logger.debug('{} Got message from broker'.format(self.name))
            message_data = self._translate_from_broker(message_data)
            message.ParseFromString(message_data)

            # Reset for every broker message. If scatter yields nothing,
            # ``feedback`` would otherwise be unbound (NameError on the first
            # message) or left over from the previous message.
            feedback = None

            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=self.max_workers) as executor:
                future_to = [
                    executor.submit(load_url, self.url,
                                    scattered.SerializeToString())
                    for scattered in self.scatter(message)
                ]
                for future in concurrent.futures.as_completed(future_to):
                    try:
                        feedback = future.result()
                    except Exception:
                        self.logger.error('HttpConnection generated an error')
                        lines = traceback.format_exception(*sys.exc_info())
                        self.logger.exception(lines[0])
                        feedback = None

                    if self.reply:
                        feedback = self._translate_to_broker(feedback)
                        self.handle_feedback(feedback)

            if feedback:
                self.broker.send(self.reply_feedback())
            else:
                self.broker.send(message_data)
Beispiel #11
0
    def recv(self):
        """
        Serve one request of the cache service.

        Receives a serialized ``PalmMessage`` on ``listen_to``. The part of
        ``message.function`` after the first dot selects the instruction:

        * ``set``: store ``payload`` under ``message.cache`` (or under a new
          uuid4 key when it is empty) and answer with the key.
        * ``get``: answer with the value stored under the key given in
          ``payload``, or ``b''`` when that value is missing or empty.
        * ``delete``: delete the key given in ``payload`` and answer with it.

        Any other instruction, including a function without a dot, is logged
        and answered with ``b''``. Every branch sets a value that is sent
        back on ``listen_to``.
        """
        message_data = self.listen_to.recv()
        message = PalmMessage()
        message.ParseFromString(message_data)

        # Tolerate a function name without a dot. Indexing [1] directly
        # raised IndexError, so the request never got an answer.
        parts = message.function.split('.')
        instruction = parts[1] if len(parts) > 1 else ''

        if instruction == 'set':
            if message.cache:
                key = message.cache
            else:
                key = str(uuid4())

            self.logger.debug('Cache Service: Set key {}'.format(key))
            value = message.payload
            self.cache.set(key, value)
            return_value = key.encode('utf-8')

        elif instruction == 'get':
            key = message.payload.decode('utf-8')
            self.logger.debug('Cache Service: Get key {}'.format(key))
            value = self.cache.get(key)
            if not value:
                self.logger.error('key {} not present'.format(key))
                return_value = b''
            else:
                return_value = value

        elif instruction == 'delete':
            key = message.payload.decode('utf-8')
            self.logger.debug('Cache Service: Delete key {}'.format(key))
            self.cache.delete(key)
            return_value = key.encode('utf-8')

        else:
            # This branch is about an unsupported instruction, not a
            # missing key.
            self.logger.error('Cache {}: Unknown instruction {}'.format(
                self.name, instruction))
            return_value = b''

        if isinstance(return_value, str):
            self.listen_to.send_string(return_value)
        else:
            self.listen_to.send(return_value)
Beispiel #12
0
    def do_GET(self):
        """
        Forward an HTTP request to the gateway router and relay the answer.

        ``self.path`` is mapped to a function through ``path_parser``. If one
        is found, a ``PalmMessage`` is sent over a fresh REQ socket whose
        identity doubles as ``message.client``. Its payload is the request
        body, or ``b'No Payload'`` when there is no content-length header.
        The payload of the reply is written back with status 200. Otherwise
        a 404 with body ``b'Not found'`` is returned.
        """
        function = self.path_parser(self.path)

        if not function:
            self.send_response(404)
            self.send_header('Content-type', 'text/plain')
            self.end_headers()
            # The body is plain bytes. Writing ``.payload`` of a bytes
            # object raised AttributeError.
            self.wfile.write(b'Not found')
            return

        socket = zmq_context.socket(zmq.REQ)
        try:
            # This is the identity of the socket and the client.
            identity = str(uuid4()).encode('utf-8')
            socket.identity = identity
            socket.connect(self.gateway_router_address)

            message = PalmMessage()
            message.pipeline = str(uuid4())
            message.function = function

            content_length = self.headers.get('content-length')
            if content_length:
                message.payload = self.rfile.read(int(content_length))
            else:
                message.payload = b'No Payload'

            message.stage = 0
            # Uses the same identity as the socket to tell the gateway
            # router where it has to route to
            message.client = identity

            socket.send(message.SerializeToString())
            message.ParseFromString(socket.recv())
        finally:
            # Close the per-request socket even if sending/receiving fails.
            socket.close()

        self.send_response(200)
        self.send_header('Content-type', 'text/plain')
        self.end_headers()
        self.wfile.write(message.payload)
        return
Beispiel #13
0
def fake_server(messages=1):
    """
    Minimal stand-in for a server.

    Binds a REP socket to ``inproc://db``, a PULL socket to
    ``inproc://pull`` and a PUB socket to ``inproc://pub``, then waits one
    second. Each of the next ``messages`` pulled messages is republished on
    the PUB socket, using the message's ``client`` field as topic.
    """
    # Bound but never read from in this function.
    db_socket = zmq_context.socket(zmq.REP)
    db_socket.bind('inproc://db')

    puller = zmq_context.socket(zmq.PULL)
    puller.bind('inproc://pull')

    publisher = zmq_context.socket(zmq.PUB)
    publisher.bind('inproc://pub')

    # PUB-SUB takes a while
    time.sleep(1.0)

    for index in range(messages):
        raw = puller.recv()
        print(index)
        parsed = PalmMessage()
        parsed.ParseFromString(raw)
        publisher.send_multipart([parsed.client.encode('utf-8'), raw])
Beispiel #14
0
    def start(self):
        """
        Call this function to start the component

        Connects ``listen_to`` to ``listen_address``. For each of
        ``self.messages`` rounds it receives a two-frame message from the
        broker and parses the second frame. ``_translate_from_broker`` then
        gives a routing target and a message. Every scattered message is
        sent to that target as ``[target, b'', data]``, and the broker is
        answered with an empty frame.
        """
        message = PalmMessage()
        self.listen_to.connect(self.listen_address)

        for _ in range(self.messages):
            self.logger.debug(
                'Component {} blocked waiting for broker'.format(self.name))
            _me, payload = self.broker.recv_multipart()
            message.ParseFromString(payload)
            self.logger.debug(
                'Component {} Got message from broker'.format(self.name))
            # Rebinds ``message``; the translated object is what gets parsed
            # into on the next round.
            target, message = self._translate_from_broker(message)

            for piece in self.scatter(message):
                frames = [target.encode('utf-8'), b'',
                          piece.SerializeToString()]
                self.listen_to.send_multipart(frames)
                self.logger.debug('Component {} sent message'.format(self.name))

            self.broker.send(b'')
Beispiel #15
0
    def handle(self):
        """
        Serve one HTTP request carrying a serialized ``PalmMessage``.

        Only POST is accepted. The part of ``message.function`` after the
        first dot names a method of this object. That method is called with
        the payload, and its result becomes the payload of the serialized
        reply. Any failure yields a 500 with an empty body.

        :return: ``(status, response_body)``
        """
        if self.request.method != 'POST':
            return '405 Method not allowed', b''

        try:
            message = PalmMessage()
            message.ParseFromString(self.request.data)

            # This exports the message information
            self.message = message
            instruction = message.function.split('.')[1]
            message.payload = getattr(self, instruction)(message.payload)
            return '200 OK', message.SerializeToString()
        except Exception:
            return '500 Internal Server Error', b''
Beispiel #16
0
    def start(self):
        """
        Call this function to start the component

        Binds ``listen_to`` and processes ``self.messages`` multipart
        messages:

        * Three frames ``[target, empty, data]``: ``data`` is parsed,
          scattered and translated, and each part is sent to the broker. The
          broker's acknowledgement is awaited after every part.
        * Four frames whose first frame is ``b'dealer'``: the remaining
          frames are sent back out through ``listen_to``.
        * Anything else is ignored.
        """
        message = PalmMessage()
        self.listen_to.bind(self.listen_address)
        self.logger.info('Launch component {}'.format(self.name))

        for _ in range(self.messages):
            self.logger.debug('Component {} blocked waiting messages'.format(self.name))
            response = self.listen_to.recv_multipart()

            # If the message is from anything but the dealer, send it to the
            # router.
            if len(response) == 3:
                [target, empty, message_data] = response
                self.logger.debug('{} Got inbound message'.format(self.name))

                try:
                    message.ParseFromString(message_data)
                    for scattered in self.scatter(message):
                        scattered = self._translate_to_broker(scattered)
                        self.broker.send(scattered.SerializeToString())
                        self.logger.debug(
                            'Component {} blocked waiting for broker'.format(
                                self.name))
                        self.broker.recv()

                # Exception, not a bare except: KeyboardInterrupt and
                # SystemExit must not be swallowed by the service loop.
                except Exception:
                    self.logger.error('Error in scatter function')
                    lines = traceback.format_exception(*sys.exc_info())
                    self.logger.exception(lines[0])

            # This is what's different. The response to be sent from the router
            # is what it gets from the dealer.
            elif len(response) == 4 and response[0] == b'dealer':
                self.listen_to.send_multipart(response[1:])
Beispiel #17
0
    def start(self):
        """
        Call this function to start the component

        Binds or connects ``listen_to`` (depending on ``self.bind``) and
        handles ``self.messages`` inbound messages. Each one is parsed and
        scattered; every part is translated and sent to the broker, and the
        broker's answer is passed to ``handle_feedback``. When
        ``self.reply`` is set, ``reply_feedback()`` is sent back on
        ``listen_to``. If any of those steps raised, ``b'0'`` is sent instead.

        :return: The name of the component.
        """
        message = PalmMessage()

        if self.bind:
            self.listen_to.bind(self.listen_address)
        else:
            self.listen_to.connect(self.listen_address)

        self.logger.info('{} successfully started'.format(self.name))
        for _ in range(self.messages):
            self.logger.debug('{} blocked waiting messages'.format(self.name))
            message_data = self.listen_to.recv()
            self.logger.debug('{} Got inbound message'.format(self.name))

            try:
                message.ParseFromString(message_data)
                for scattered in self.scatter(message):
                    scattered = self._translate_to_broker(scattered)
                    self.broker.send(scattered.SerializeToString())
                    self.logger.debug('{} blocked waiting for broker'.format(
                        self.name))
                    self.handle_feedback(self.broker.recv())

                if self.reply:
                    self.listen_to.send(self.reply_feedback())
            # Exception, not a bare except: a bare except also swallowed
            # KeyboardInterrupt/SystemExit and kept the loop running.
            except Exception:
                self.logger.error('Exception in scatter or routing.')
                lines = traceback.format_exception(*sys.exc_info())
                self.logger.exception(lines[0])

                if self.reply:
                    self.listen_to.send(b'0')

        return self.name
Beispiel #18
0
class Worker(object):
    """
    Standalone worker for the standalone master.

    Pulls ``PalmMessage`` jobs, runs the method named by the current stage
    of ``message.function`` on the payload, and pushes the message back with
    the result as payload. It can also set/get/delete keys through the
    master's db service.

    :param name: Name assigned to this worker server. A random uuid4 string
        is used when it is empty.
    :param db_address: Address of the db service of the master. Mandatory.
    :param push_address: Address of the master's push socket; this worker's
        PULL socket connects to it. If left blank, fetches it from the master
    :param pull_address: Address of the master's pull socket; this worker's
        PUSH socket connects to it. If left blank, fetches it from the master
    :param log_level: Log level for this server.
    :param messages: Number of messages before it is shut down.
    :raises ValueError: if ``db_address`` is empty.
    """
    def __init__(self, name='', db_address='', push_address=None,
                 pull_address=None, log_level=logging.INFO,
                 messages=sys.maxsize):

        self.uuid = str(uuid4())

        # Give a random name if not given
        if not name:
            self.name = self.uuid
        else:
            self.name = name

        # Configure the log handler. Use the resolved name: the raw ``name``
        # argument is '' for an unnamed worker, and getLogger('') is the ROOT
        # logger, which would get this handler and level.
        self.logger = logging.getLogger(name=self.name)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        # Configure the connections.
        self.push_address = push_address
        self.pull_address = pull_address

        if not db_address:
            raise ValueError('db_address argument is mandatory')

        self.db_address = db_address
        self.db = zmq_context.socket(zmq.REQ)
        self.db.connect(db_address)

        self._get_config_from_master()

        # The worker pulls from the master's push address and pushes to the
        # master's pull address.
        self.pull = zmq_context.socket(zmq.PULL)
        self.pull.connect(self.push_address)

        self.push = zmq_context.socket(zmq.PUSH)
        self.push.connect(self.pull_address)

        self.messages = messages
        self.message = PalmMessage()

    def _get_config_from_master(self):
        """
        Fill in missing push/pull addresses from the master's db service.

        :return: dict with the resulting ``push_address`` and
            ``pull_address``.
        """
        if not self.push_address:
            self.push_address = self.get('worker_push_address').decode('utf-8')
            self.logger.info(
                'Got worker push address: {}'.format(self.push_address))

        if not self.pull_address:
            self.pull_address = self.get('worker_pull_address').decode('utf-8')
            self.logger.info(
                'Got worker pull address: {}'.format(self.pull_address))

        return {'push_address': self.push_address,
                'pull_address': self.pull_address}

    def _exec_function(self):
        """
        Waits for a message and return the result

        The message is parsed into ``self.message``. The method named by the
        current stage of its function is called with the payload. The result
        is ``b'0'`` if the message cannot be decoded, the method does not
        exist, or the method raises.

        :raises ValueError: if the function/stage cannot be resolved.
        """
        message_data = self.pull.recv()
        self.logger.debug('{} Got a message'.format(self.name))
        result = b'0'
        try:
            self.message.ParseFromString(message_data)
            try:
                if ' ' in self.message.function:
                    instruction = self.message.function.split()[
                        self.message.stage].split('.')[1]
                else:
                    instruction = self.message.function.split('.')[1]
            except IndexError:
                raise ValueError('Pipeline call not correct. Review the '
                                 'config in your client')

            try:
                user_function = getattr(self, instruction)
                self.logger.debug('Looking for {}'.format(instruction))
                try:
                    result = user_function(self.message.payload)
                    self.logger.debug('{} Ok'.format(instruction))
                # Exception, not a bare except, so KeyboardInterrupt and
                # SystemExit still stop the worker.
                except Exception:
                    self.logger.error(
                        '{} User function {} gave an error'.format(
                            self.name, instruction)
                    )
                    lines = traceback.format_exception(*sys.exc_info())
                    self.logger.exception(lines[0])

            except AttributeError:
                self.logger.error(
                    'Function {} was not found'.format(instruction)
                )
        except DecodeError:
            self.logger.error('Message could not be decoded')
        return result

    def start(self):
        """
        Starts the server

        For ``self.messages`` jobs, replaces the payload of the received
        message with the result and pushes the whole message back.
        """
        for _ in range(self.messages):
            self.message.payload = self._exec_function()
            self.push.send(self.message.SerializeToString())

    def set(self, value, key=None):
        """
        Sets a key value pair in the remote database.

        :param value: Bytes to store (sent as the message payload).
        :param key: Optional key, sent as ``message.cache``. When omitted,
            the key is left to the db service.
        :return: The key returned by the service, decoded as UTF-8.
        """
        message = PalmMessage()
        message.pipeline = str(uuid4())
        message.client = self.uuid
        message.stage = 0
        message.function = '.'.join(['_', 'set'])
        message.payload = value
        if key:
            message.cache = key

        self.db.send(message.SerializeToString())
        return self.db.recv().decode('utf-8')

    def get(self, key):
        """
        Gets a value from server's internal cache

        :param key: Key for the data to be selected.
        :return: The raw bytes answered by the db service.
        """
        message = PalmMessage()
        message.pipeline = str(uuid4())
        message.client = self.uuid
        message.stage = 0
        message.function = '.'.join(['_', 'get'])
        message.payload = key.encode('utf-8')
        self.db.send(message.SerializeToString())
        return self.db.recv()

    def delete(self, key):
        """
        Deletes data in the server's internal cache.

        :param key: Key of the data to be deleted
        :return: The answer of the db service, decoded as UTF-8.
        """
        message = PalmMessage()
        message.pipeline = str(uuid4())
        message.client = self.uuid
        message.stage = 0
        message.function = '.'.join(['_', 'delete'])
        message.payload = key.encode('utf-8')
        self.db.send(message.SerializeToString())
        return self.db.recv().decode('utf-8')
Beispiel #19
0
    def job_list(self,
                 function,
                 list_generator,
                 messages: int=sys.maxsize,
                 workers: int=sys.maxsize,
                 cache: str=''):
        """
        Submit a job with multiple messages to a server.

        The elements of ``list_generator`` are passed to ``self._sender`` in
        consecutive chunks of ``workers`` elements. A single background
        thread sends all chunks, one after the other, while this generator
        collects the answers.

        :param function: String or list of strings following the format
            ``server.function``. A list is joined with spaces into a
            pipelined job.
        :param list_generator: Sliceable sequence (e.g. a list) with the
            binary messages to send.
        :param messages: Number of messages expected to be sent back to the
            client. Defaults to infinity (sys.maxsize)
        :param workers: Chunk size, i.e. number of elements handed to each
            ``_sender`` call. Defaults to sys.maxsize (a single chunk).
        :param cache: Cache data included in the message
        :return: an iterator with the payloads that are sent back to the
            client.
        :raises ValueError: if a message addressed to another client arrives.
        """
        push_socket = zmq_context.socket(zmq.PUSH)
        push_socket.connect(self.push_address)

        sub_socket = zmq_context.socket(zmq.SUB)
        sub_socket.setsockopt_string(zmq.SUBSCRIBE, self.uuid)
        sub_socket.connect(self.sub_address)

        if isinstance(function, list):
            # Pipelined job.
            function = ' '.join(function)

        if messages != sys.maxsize:
            self.logger.debug('There are {} messages'.format(messages))
        if workers != sys.maxsize:
            self.logger.debug('There are {} workers'.format(workers))
        self.logger.debug('Sending jobs to workers')

        # Chunks starting past the end of the list are empty. Stop there
        # instead of walking up to ``messages`` (sys.maxsize by default) and
        # starting a thread for every empty chunk.
        chunk_starts = range(0, min(messages, len(list_generator)), workers)

        def _send_chunks():
            # Send every chunk from this one thread. ZeroMQ sockets are not
            # thread safe, so push_socket must not be used by several
            # concurrently running sender threads.
            for start in chunk_starts:
                self.logger.debug(
                    'JSONs from {} to {}'.format(start, start + workers))
                chunk = list_generator[start:start + workers]
                self._sender(push_socket, function,
                             (element for element in chunk), cache)

        # Sender runs in background.
        Thread(target=_send_chunks).start()

        for _ in range(messages):
            [client, message_data] = sub_socket.recv_multipart()
            if not client.decode('utf-8') == self.uuid:
                raise ValueError('The client got a message that does not belong')

            message = PalmMessage()
            message.ParseFromString(message_data)
            # Yield every payload, including the last one. ``return value``
            # inside a generator ends iteration without that value ever
            # reaching a ``for`` loop consumer.
            yield message.payload
Beispiel #20
0
class Sink(Server):
    """
    Minimal server that acts as a sink of multiple streams.

    It subscribes to the pub sockets of several previous servers, runs the
    requested method on every message it receives and publishes the result
    on its own pub socket.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param list sub_addresses: List of addresses of the pub socket of the
        previous servers
    :param str pub_address: Address of the pub socket
    :param list previous: List of names of the previous servers. The name at
        each position is used as the subscription topic for the address at
        the same position of ``sub_addresses``.
    :param to_client: True if the message is sent back to the client.
        Defaults to True
    :param log_level: Minimum output log level. Defaults to INFO
    :param int messages: Total number of poll rounds that the server
        processes. Defaults to Infty. Useful for debugging.
    """
    def __init__(self, name, db_address,
                 sub_addresses, pub_address, previous, to_client=True,
                 log_level=logging.INFO, messages=sys.maxsize):
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.sub_addresses = sub_addresses
        self.pub_address = pub_address
        self.pipelined = not to_client
        self.message = None

        self.cache.set('name', name.encode('utf-8'))
        for i, address in enumerate(sub_addresses):
            self.cache.set('sub_address_{}'.format(i), address.encode('utf-8'))

        self.cache.set('pub_address', pub_address.encode('utf-8'))

        self.logger = logging.getLogger(name=name)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
        )
        self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        self.sub_sockets = list()

        # Simple type checks
        assert type(previous) == list
        assert type(sub_addresses) == list

        # One SUB socket per upstream server, subscribed to that server's name.
        for address, prev in zip(self.sub_addresses, previous):
            self.sub_sockets.append(zmq_context.socket(zmq.SUB))
            self.sub_sockets[-1].setsockopt_string(zmq.SUBSCRIBE, prev)
            self.sub_sockets[-1].connect(address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

        self.poller = zmq.Poller()

        for sock in self.sub_sockets:
            self.poller.register(sock, zmq.POLLIN)

    def _execution_handler(self):
        """
        Main loop: poll all SUB sockets and process the ready ones.

        Runs ``self.messages`` poll rounds. For each ready socket, the second
        frame is parsed as a ``PalmMessage``. The function of the current
        stage (``server.function``, or the ``stage``-th element of a
        space-separated pipeline) is resolved and, when its server part
        matches ``self.name``, the method of that name is called with the
        payload. The result (``b'0'`` on any error) becomes the new payload,
        which goes through ``handle_stream`` and is published under the
        returned topic. A ``None`` result publishes nothing.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            locked_socks = dict(self.poller.poll())

            for sock in self.sub_sockets:
                if sock in locked_socks:
                    message_data = sock.recv_multipart()[1]

                    self.logger.debug('Got message {}'.format(i + 1))
                    result = b'0'
                    self.message = PalmMessage()
                    try:
                        self.message.ParseFromString(message_data)

                        # Handle the fact that the message may be a complete pipeline
                        try:
                            if ' ' in self.message.function:
                                [server, function] = self.message.function.split()[
                                    self.message.stage].split('.')
                            else:
                                [server, function] = self.message.function.split('.')
                        except IndexError:
                            raise ValueError('Pipeline call not correct. Review the '
                                             'config in your client')

                        if not self.name == server:
                            self.logger.error('You called {}, instead of {}'.format(
                                server, self.name))
                        else:
                            try:
                                user_function = getattr(self, function)
                                self.logger.debug('Looking for {}'.format(function))
                                try:
                                    result = user_function(self.message.payload)
                                # Exception, not a bare except, so that
                                # KeyboardInterrupt/SystemExit are not swallowed.
                                except Exception:
                                    self.logger.error('User function gave an error')
                                    exc_type, exc_value, exc_traceback = sys.exc_info()
                                    lines = traceback.format_exception(
                                        exc_type, exc_value, exc_traceback)
                                    for line in lines:
                                        self.logger.exception(line)

                            # getattr signals a missing attribute with
                            # AttributeError. KeyError never fired, so an unknown
                            # function crashed the whole loop.
                            except AttributeError:
                                self.logger.error(
                                    'Function {} was not found'.format(function)
                                )
                    except DecodeError:
                        self.logger.error('Message could not be decoded')

                    # Do nothing if the function returns no value
                    if result is None:
                        continue

                    self.message.payload = result

                    topic, self.message = self.handle_stream(self.message)

                    self.pub_socket.send_multipart(
                        [topic.encode('utf-8'), self.message.SerializeToString()]
                    )
Beispiel #21
0
class Server(object):
    """
    Standalone and minimal server that replies single requests.

    Each incoming message names a ``server.function`` call. For pipelines it
    names a space-separated list of such calls, and the entry at index
    ``message.stage`` is used. When the server part matches :attr:`name`, the
    method called ``function`` on this object is invoked with the message
    payload. The result is published on the topic chosen by
    :meth:`handle_stream`.

    :param str name: Name of the server
    :param str db_address: ZeroMQ address of the cache service.
    :param str pull_address: Address of the pull socket
    :param str pub_address: Address of the pub socket
    :param pipelined: True if the server is chained to another server.
    :param log_level: Minimum output log level.
    :param int messages: Total number of messages that the server processes.
        Useful for debugging.
    """
    def __init__(self, name, db_address,
                 pull_address, pub_address, pipelined=False,
                 log_level=logging.INFO, messages=sys.maxsize):
        self.name = name
        self.cache = DictDB()
        self.db_address = db_address
        self.pull_address = pull_address
        self.pub_address = pub_address
        self.pipelined = pipelined
        # Message currently being processed by the execution loop.
        self.message = None

        # Values served by the cache service started in start().
        self.cache.set('name', name.encode('utf-8'))
        self.cache.set('pull_address', pull_address.encode('utf-8'))
        self.cache.set('pub_address', pub_address.encode('utf-8'))

        # NOTE(review): a handler is added on every construction, so creating
        # two servers with the same name duplicates log lines.
        self.logger = logging.getLogger(name=name)
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
        )
        self.logger.addHandler(handler)
        self.logger.setLevel(log_level)

        self.messages = messages

        self.pull_socket = zmq_context.socket(zmq.PULL)
        self.pull_socket.bind(self.pull_address)

        self.pub_socket = zmq_context.socket(zmq.PUB)
        self.pub_socket.bind(self.pub_address)

    def handle_stream(self, message):
        """
        Handle the stream of messages.

        :param message: The message about to be sent to the next step in the
            cluster
        :return: topic (str) and message (PalmMessage)

        The default behaviour is the following. If you leave this function
        unchanged and pipeline is set to False, the topic is the ID of the
        client, which makes the message return to the client. If the pipeline
        parameter is set to True, the topic is set as the name of the server and
        the step of the message is incremented by one.

        You can alter this default behaviour by overriding this function.
        Take into account that the message is also available in this function,
        and you can change other parameters like the stage or the function.
        """
        if self.pipelined:
            topic = self.name
            message.stage += 1
        else:
            topic = message.client

        return topic, message

    def echo(self, payload):
        """
        Echo utility function that returns the unchanged payload. This
        function is useful when the server is there as just to modify the
        stream of messages.

        :return: payload (bytes)
        """
        return payload

    @staticmethod
    def _resolve_call(function, stage):
        """
        Split a ``server.function`` call into its two parts.

        :param str function: The call string of a message. It is either a
            single ``server.function`` or a space-separated pipeline of them.
        :param int stage: Index of the pipeline step to resolve. It is only
            used when ``function`` contains a space.
        :return: tuple ``(server, function)`` of strings.
        :raises ValueError: if ``stage`` is out of range or the step is not
            exactly ``server.function``.
        """
        try:
            step = function.split()[stage] if ' ' in function else function
            # Unpacking fails with ValueError on zero or several dots.
            server, user_function = step.split('.')
        except (IndexError, ValueError) as exc:
            raise ValueError('Pipeline call not correct. Review the '
                             'config in your client') from exc
        return server, user_function

    def _invoke(self, function, payload, default):
        """
        Call the method named ``function`` of this server with ``payload``.

        :param str function: Name of the attribute to look up on ``self``.
        :param payload: Payload passed as the only argument.
        :param default: Value returned when the call cannot be served.
        :return: The method's return value, or ``default`` if the method does
            not exist or raised. Both failure cases are logged.
        """
        # NOTE(review): any attribute of the server can be named by a client,
        # including start() or private helpers. Consider restricting this.
        self.logger.debug('Looking for {}'.format(function))
        try:
            user_function = getattr(self, function)
        except AttributeError:
            # getattr signals a missing attribute with AttributeError.
            self.logger.error(
                'Function {} was not found'.format(function)
            )
            return default

        try:
            return user_function(payload)
        except Exception:
            self.logger.error('User function gave an error')
            for line in traceback.format_exception(*sys.exc_info()):
                self.logger.error(line.strip('\n'))
            return default

    def _execution_handler(self):
        """
        Main loop. Receives ``self.messages`` messages from the pull socket,
        runs the requested function and publishes one reply per message.

        When a message cannot be decoded, addresses another server, or its
        function is missing or fails, the reply payload is ``b'0'``.

        :raises ValueError: if a message carries a malformed pipeline call.
            NOTE(review): this ends the loop for every client, not just the
            one that sent the bad message.
        """
        for i in range(self.messages):
            self.logger.debug('Server waiting for messages')
            message_data = self.pull_socket.recv()
            self.logger.debug('Got message {}'.format(i + 1))
            result = b'0'
            self.message = PalmMessage()
            try:
                self.message.ParseFromString(message_data)
                server, function = self._resolve_call(
                    self.message.function, self.message.stage)

                if not self.name == server:
                    self.logger.error('You called {}, instead of {}'.format(
                        server, self.name))
                else:
                    result = self._invoke(
                        function, self.message.payload, result)
            except DecodeError:
                self.logger.error('Message could not be decoded')

            # NOTE(review): a user function returning None is assigned as-is.
            # Confirm the payload field accepts that.
            self.message.payload = result

            topic, self.message = self.handle_stream(self.message)

            self.pub_socket.send_multipart(
                [topic.encode('utf-8'), self.message.SerializeToString()]
            )

    def start(self, cache_messages=sys.maxsize):
        """
        Start the server. Runs the cache service and the execution loop in
        two worker threads and waits for both to finish.

        :param cache_messages: Number of messages the cache service handles
            before it shuts down. Useful for debugging
        :return: The server name, UTF-8 encoded.
        """
        threads = []

        cache = CacheService('cache', self.db_address, logger=self.logger,
                             cache=self.cache, messages=cache_messages)

        threads.append(cache.start)
        threads.append(self._execution_handler)

        with concurrent.futures.ThreadPoolExecutor(max_workers=len(threads)) as executor:
            results = [executor.submit(thread) for thread in threads]
            for future in concurrent.futures.as_completed(results):
                try:
                    future.result()
                except Exception:
                    # Report the crash and keep waiting on the other component.
                    self.logger.error(
                        'This is critical, one of the components of the '
                        'server died')
                    lines = traceback.format_exception(*sys.exc_info())
                    for line in lines:
                        self.logger.error(line.strip('\n'))

        return self.name.encode('utf-8')
Beispiel #22
0
def test_multiple_clients():
    """
    Two clients share one master (router, pull/pub services and cache).
    Each client must get back exactly the payloads it sent, in order.

    Exceptions raised by the services are printed and do not stop the test.
    A client that fails, or receives the wrong payloads, fails the test.
    """
    cache = DictDB()
    cache.set('name', b'master')
    cache.set('pub_address', pub_address.encode('utf-8'))
    cache.set('pull_address', pull_address.encode('utf-8'))

    router = Router(logger=logging, cache=cache, messages=4)
    router.register_inbound('puller', route='publisher')
    router.register_outbound('publisher')

    cache_service = CacheService('db', db_address,
                                 cache=cache,
                                 logger=logging,
                                 messages=6,
                                 )

    puller = PullService('puller',
                         pull_address,
                         broker_address=router.inbound_address,
                         logger=logging,
                         cache=cache,
                         messages=4)

    publisher = PubService('publisher',
                           pub_address,
                           broker_address=router.outbound_address,
                           logger=logging,
                           cache=cache,
                           messages=4)

    def client1_job():
        client = Client('master',
                        db_address,
                        session=None)
        return [r for r in client.job('master.something', [b'1', b'2'], messages=2)]

    def client2_job():
        client = Client('master',
                        db_address,
                        session=None)
        return [r for r in client.job('master.something', [b'3', b'4'], messages=2)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
        services = [
            executor.submit(cache_service.start),
            executor.submit(puller.start),
            executor.submit(publisher.start),
            executor.submit(router.start),
        ]
        # Map each client future to the payloads that client must get back.
        expected = {
            executor.submit(client1_job): [b'1', b'2'],
            executor.submit(client2_job): [b'3', b'4'],
        }
        results = services + list(expected)

        completed = 0
        for future in concurrent.futures.as_completed(results):
            completed += 1
            if future in expected:
                got = []
                for raw in future.result():
                    message = PalmMessage()
                    message.ParseFromString(raw)
                    got.append(message.payload)
                # Kept outside any try/except, so a mismatch fails the test
                # instead of being printed and ignored.
                assert got == expected[future]
            else:
                try:
                    future.result()
                except Exception:
                    lines = traceback.format_exception(*sys.exc_info())
                    print(*lines)

    assert completed == len(results)