Exemplo n.º 1
0
    def handle_snapshot(self, msg):
        """
        Handle a snapshot (state sync) request.

        Every reply is sent on ``self.collector`` to the identity of the
        client that sent ``msg``:

        * ``internal``/``startsync`` with a graph id payload: that graph's
          messages, then a ``network``/``status`` message. A
          ``RillRuntimeError`` raised while doing so is reported through
          ``self.dispatcher.send_error`` against the original request.
        * ``internal``/``startsync`` with an empty payload: the runtime
          metadata, every component spec and one ``graph``/``graph``
          message per registered graph.

        Both cases finish with ``internal``/``endsync`` carrying
        ``self.dispatcher.revision``. Any other request is dumped, stops
        ``self.loop`` and returns without replying.

        Parameters
        ----------
        msg : Message
            the request; ``msg.identity`` addresses all replies
        """
        from rill.runtime.core import RillRuntimeError

        print("handle_snapshot: %r" % msg)

        # Capture the requester's identity up front. The outgoing messages
        # produced below are built without an identity, so they must not be
        # used to address the replies (the loops previously rebound ``msg``).
        identity = msg.identity

        if (msg.protocol, msg.command) == ('internal', 'startsync'):
            if msg.payload:
                # sync graph state
                graph_id = msg.payload
                print("Graph id: %s" % graph_id)

                try:
                    for graph_msg in self.handler.get_graph_messages(graph_id):
                        graph_msg.sendto(self.collector, identity)

                    # send the network status
                    status = self.handler.get_network_status(graph_id)
                    Message(b'network', b'status',
                            status).sendto(self.collector, identity)

                except RillRuntimeError as err:
                    # report the failure against the request that caused it
                    self.dispatcher.send_error(msg, err)

            else:
                # sync runtime state
                # initial connection
                meta = self.runtime.get_runtime_meta()
                Message(b'runtime', b'runtime',
                        meta).sendto(self.collector, identity)

                # send list of component specs
                # FIXME: move this under 'runtime' protocol?
                for spec_msg in self.handler.get_all_component_specs():
                    spec_msg.sendto(self.collector, identity)

                # send list of graphs
                # FIXME: move this under 'runtime' protocol?
                # FIXME: notify subscribers about new graphs in handle_collect
                for graph_id in self.runtime._graphs.keys():
                    graph = self.runtime.get_graph(graph_id)
                    Message(b'graph', b'graph', {
                        'id': graph_id,
                        'metadata': graph.metadata
                    }).sendto(self.collector, identity)

        else:
            print("E: bad request, aborting")
            dump(msg)
            self.loop.stop()
            return

        # Now send END message with revision number
        logging.info("I: Sending state snapshot=%d", self.dispatcher.revision)
        Message(b'internal', b'endsync',
                self.dispatcher.revision).sendto(self.collector, identity)
Exemplo n.º 2
0
 def get_all_component_specs(self):
     """
     Wrap each component spec reported by the runtime in a message.

     Returns
     -------
     Iterator[Message]
         one ``component``/``component`` message per spec, produced lazily
     """
     specs = self.runtime.get_all_component_specs()
     for component_spec in specs:
         yield Message(b'component', b'component', component_spec)
Exemplo n.º 3
0
    def get_graph_messages(self, graph_id):
        """
        Turn the state of one graph into a lazy stream of ``graph`` messages.

        Parameters
        ----------
        graph_id : str

        Returns
        -------
        Iterator[rill.runtime.plumbing.Message]
            one ``graph`` protocol message per (command, payload) pair
            produced by the module-level ``get_graph_messages`` helper
        """
        graph = self.runtime.get_graph(graph_id)
        pairs = get_graph_messages(graph, graph_id)
        for cmd, body in pairs:
            yield Message(b'graph', cmd, body)
Exemplo n.º 4
0
    def handle_message(self, msg_frames):
        """
        Message received callback (from ROUTER socket).

        The first frame is the identity of the sending dealer. The remaining
        frames are decoded into a ``Message``, whose ``identity`` is set so
        that replies can be routed back. ``internal`` protocol messages go to
        ``handle_snapshot``; everything else goes to ``handle_collect``.

        Parameters
        ----------
        msg_frames : List[bytes]
            raw frames as received; this sequence is not modified
        """
        # first frame of a router message is the identity of the dealer.
        # Index/slice rather than pop(0) so the caller's list is left intact.
        identity = msg_frames[0]
        msg = Message.from_frames(*msg_frames[1:])
        msg.identity = identity

        if msg.protocol == 'internal':
            self.handle_snapshot(msg)
        else:
            self.handle_collect(msg)
Exemplo n.º 5
0
def client_agent_loop(ctx, pipe, on_recv):
    """
    Run the client agent's event loop until its ZeroMQ context terminates.

    The agent moves between three states:

    * STATE_INITIAL: once ``agent.connection`` is set, request a runtime
      snapshot (``internal``/``startsync`` with an empty payload) and switch
      to STATE_SYNCING.
    * STATE_SYNCING: read replies from the connection's publisher socket and
      pass each one to ``on_recv``. This continues until
      ``internal``/``endsync`` arrives. Its payload becomes
      ``agent.revision`` and the agent switches to STATE_ACTIVE. Control
      messages on the agent pipe are not read while syncing.
    * STATE_ACTIVE: if ``agent.graph`` is set, request a snapshot of that
      graph and go back to STATE_SYNCING. Otherwise read updates from the
      subscriber and publisher sockets. Updates newer than
      ``agent.revision`` are forwarded to ``on_recv``, as are error, log and
      component messages.

    Parameters
    ----------
    ctx, pipe
        passed straight to ``ClientAgent`` (presumably the zmq context and
        the control pipe socket -- confirm against ClientAgent)
    on_recv : callable
        called with each ``Message`` accepted from the server
    """
    agent = ClientAgent(ctx, pipe)
    conn = None

    while True:
        # poller for both the pipe and the active server
        poller = zmq.Poller()
        poll_timer = None

        # choose a server socket
        server_sockets = []
        if agent.state == STATE_INITIAL:
            # In this state we ask the server for a snapshot,
            if agent.connection:
                conn = agent.connection
                print("I: waiting for server at %s:%d..." %
                      (conn.address, conn.port))
                # FIXME: why 2?  I think this may have to do with MAX_SERVER
                if conn.requests < 2:
                    Message(b'internal', b'startsync',
                            b'').sendto(conn.publisher)
                    conn.requests += 1
                conn.expiry = time.time() + SERVER_TTL
                print("switching to sync state")
                agent.state = STATE_SYNCING
                server_sockets = [conn.publisher]
        elif agent.state == STATE_SYNCING:
            # In this state we read from snapshot and we expect
            # the server to respond.
            server_sockets = [conn.publisher]
        elif agent.state == STATE_ACTIVE:
            if agent.graph:
                print("switching to graph sync state")
                Message(b'internal', b'startsync', agent.graph,
                        agent.message_id).sendto(conn.publisher)
                # wipe the graph subscription request so that we don't get
                # here unless the graph has changed
                agent.graph = None
                agent.message_id = None
                conn.expiry = time.time() + SERVER_TTL
                agent.state = STATE_SYNCING
                server_sockets = [conn.publisher]
            else:
                # In this state we read from subscriber.
                server_sockets = [conn.subscriber, conn.publisher]

        # we don't process messages from the client until we're done syncing.
        if agent.state != STATE_SYNCING:
            poller.register(agent.pipe, zmq.POLLIN)
        for server_socket in server_sockets:
            poller.register(server_socket, zmq.POLLIN)

        if conn is not None:
            # never block longer than the server's remaining time-to-live
            # (poll timeouts are in milliseconds)
            poll_timer = 1e3 * max(0, conn.expiry - time.time())

        # ------------------------------------------------------------
        # Poll loop
        try:
            items = dict(poller.poll(poll_timer))
        except zmq.ContextTerminated:
            # The context has been shut down: stop the agent. Any other
            # error still propagates to the caller.
            break

        if not items:
            # timed out: yield to other greenlets before polling again
            gevent.sleep(0)
            continue

        for socket in items:
            if socket is agent.pipe:
                print("Control message")
                agent.handle_message()
                continue

            server_socket = socket
            print("Server message")
            msg = Message.from_frames(*server_socket.recv_multipart())
            # Anything from server resets its expiry time
            conn.expiry = time.time() + SERVER_TTL

            if agent.state == STATE_SYNCING:
                conn.requests = 0
                if (msg.protocol, msg.command) == ('internal', 'endsync'):
                    # done syncing
                    assert isinstance(msg.payload, int)
                    agent.revision = msg.payload
                    print("switching to active state")
                    agent.state = STATE_ACTIVE
                    logging.info("I: received from %s:%d snapshot=%d",
                                 conn.address, conn.port,
                                 agent.revision)
                    # FIXME: send componentsready?
                    # self.send('component', 'componentsready')
                else:
                    # %s rather than %d: agent.revision may not be an int
                    # yet during the first sync (depends on ClientAgent)
                    logging.info("I: received from %s:%d %s %s",
                                 conn.address, conn.port, msg,
                                 agent.revision)
                    on_recv(msg)

            elif agent.state == STATE_ACTIVE:
                # Receive message published from server.
                # Discard out-of-revision updates; errors, logs and
                # component messages are always delivered.
                print("msg %r" % msg)
                is_newer = msg.revision > agent.revision
                if (is_newer
                        or msg.command == 'error'
                        or msg.command == 'log'
                        or msg.protocol == 'component'):
                    if is_newer:
                        # Only ever advance the revision: an older error,
                        # log or component message must not lower it and
                        # let stale updates through afterwards.
                        agent.revision = msg.revision

                    on_recv(msg)

                    logging.info("I: received from %s:%d %s",
                                 conn.address, conn.port, msg)
                else:
                    # %s so a missing (None) revision cannot crash the loop
                    print("Sequence is too low: %s < %s" %
                          (msg.revision, agent.revision))
            else:
                raise RuntimeError("This should not be possible")
Exemplo n.º 6
0
def test_runtime_server(Context, ZMQStream, is_socket_type):
    """
    Exercise ``RuntimeServer.handle_collect`` and ``handle_snapshot`` directly.

    Replies are inspected through the ``send_multipart`` calls recorded on
    ``server.publisher`` and ``server.collector``. Those sockets are presumably
    test doubles supplied via the ``Context``/``ZMQStream``/``is_socket_type``
    arguments -- confirm against the test's decorators.
    """
    def request(protocol, command, payload):
        # every request in this test comes from the same (empty) identity
        return Message(protocol=protocol, command=command, payload=payload,
                       id=uuid.uuid1(), identity=b'')

    def sent(socket_name, index=None):
        # decode the frames of one send_multipart call on server.<socket_name>;
        # index=None selects the most recent call
        send = getattr(server, socket_name).send_multipart
        if index is None:
            call = send.call_args
        else:
            call = send.call_args_list[index]
        return Message.from_frames(*call[0][0])

    def reset_sends():
        server.publisher.send_multipart.reset_mock()
        server.collector.send_multipart.reset_mock()

    runtime = Runtime()
    runtime.register_module('tests.components')

    server = RuntimeServer(runtime)

    # addgraph is published, followed by a component message
    test_graph = request('graph', 'addgraph',
                         {'id': 'testgraph', 'name': 'testgraph'})
    server.handle_collect(test_graph)
    server.publisher.send_multipart.assert_called()

    result = sent('publisher', 0)
    assert result.protocol == 'graph'
    assert result.command == 'addgraph'
    assert result.payload['id'] == 'testgraph'
    assert runtime._graphs['testgraph']

    component_result = sent('publisher', 1)
    assert component_result.protocol == 'component'
    assert component_result.command == 'component'

    reset_sends()

    # a graph snapshot replies with 'clear' followed by the network status
    server.handle_snapshot(request('internal', 'startsync', 'testgraph'))

    clear_message = sent('collector', 0)
    assert clear_message.protocol == 'graph'
    assert clear_message.command == 'clear'
    assert clear_message.payload == {
        'id': 'testgraph',
        'name': 'testgraph',
    }

    status_message = sent('collector', 1)
    assert status_message.protocol == 'network'
    assert status_message.command == 'status'
    assert status_message.payload['graph'] == 'testgraph'
    assert status_message.payload['running'] is False
    assert status_message.payload['started'] is False

    reset_sends()

    # adding nodes echoes the request payload back on the publisher
    genarray = request('graph', 'addnode', {
        'graph': 'testgraph',
        'id': 'node1',
        'component': 'tests.components/GenerateArray'
    })
    server.handle_collect(genarray)
    server.publisher.send_multipart.assert_called()
    genarray_result = sent('publisher')
    assert genarray_result.protocol == 'graph'
    assert genarray_result.command == 'addnode'
    assert genarray_result.payload == genarray.payload

    reset_sends()

    repeat = request('graph', 'addnode', {
        'graph': 'testgraph',
        'id': 'node2',
        'component': 'tests.components/Repeat'
    })
    server.handle_collect(repeat)
    server.publisher.send_multipart.assert_called()
    repeat_result = sent('publisher')
    assert repeat_result.protocol == 'graph'
    assert repeat_result.command == 'addnode'
    assert repeat_result.payload == repeat.payload

    reset_sends()

    # connecting the two nodes publishes the same endpoints
    edge = request('graph', 'addedge', {
        'graph': 'testgraph',
        'src': {
            'node': 'node1',
            'port': 'OUT'
        },
        'tgt': {
            'node': 'node2',
            'port': 'in'
        }
    })
    server.handle_collect(edge)
    server.publisher.send_multipart.assert_called()
    edge_result = sent('publisher')
    assert edge_result.payload['src'] == edge.payload['src']
    assert edge_result.payload['tgt'] == edge.payload['tgt']
Exemplo n.º 7
0
def test_runtime_flow():
    """
    End-to-end check of ``RuntimeServer`` and ``RuntimeClient`` on port 4556.

    The test covers, in order:

    * the initial runtime/component sync;
    * ``addgraph``;
    * ``watch`` (graph snapshot: ``clear`` then ``network``/``status``);
    * ``addnode`` and ``addedge``;
    * an invalid ``addinitial`` that must produce an error;
    * ``addinport``, which must republish the node's component spec.

    The client is disconnected and the server stopped even if an assertion
    fails, so a failure does not leave them running for later tests.
    """
    runtime = Runtime()
    runtime.register_module('tests.components')

    server = RuntimeServer(runtime, 4556)

    on_response = MagicMock()
    client = RuntimeClient(on_response)

    gevent.spawn(server.start)
    client.connect('tcp://localhost', 4556)

    try:
        gevent.sleep(.01)
        expected = [{
            'protocol': 'runtime',
            'command': 'runtime',
            'payload': runtime.get_runtime_meta()
        }]

        # next(counter) works on Python 2 and 3, unlike the Python 2-only
        # ``itertools.count(2).next`` bound method.
        message_ids = itertools.count(2)
        for component_spec in runtime.get_all_component_specs():
            expected.append({
                'protocol': 'component',
                'command': 'component',
                'payload': component_spec,
                'message_id': next(message_ids)
            })

        for i, call in enumerate(on_response.call_args_list):
            message = call[0][0].to_dict()
            expected_message = expected[i]
            assert message['payload'] == expected_message['payload']

        test_graph = {
            'protocol': 'graph',
            'command': 'addgraph',
            'payload': {
                'id': 'testgraph',
                'name': 'testgraph'
            },
            'id': uuid.uuid1()
        }

        on_response.reset_mock()

        client.send(Message(**test_graph))
        gevent.sleep(.01)

        add_graph_result = on_response.call_args_list[0][0][0].to_dict()
        assert add_graph_result['protocol'] == test_graph['protocol']
        assert add_graph_result['command'] == test_graph['command']
        assert add_graph_result['payload'] == test_graph['payload']

        component_result = on_response.call_args_list[1][0][0].to_dict()
        assert component_result['protocol'] == 'component'
        assert component_result['command'] == 'component'

        on_response.reset_mock()

        watch_graph = {
            'protocol': 'graph',
            'command': 'watch',
            'payload': {
                'id': 'testgraph'
            },
            'id': uuid.uuid1()
        }

        client.send(Message(**watch_graph))
        gevent.sleep(.01)

        expected_clear = {
            'protocol': 'graph',
            'command': 'clear',
            'payload': {
                'id': 'testgraph',
                'name': 'testgraph',
            }
        }
        clear_message = on_response.call_args_list[0][0][0].to_dict()
        assert clear_message['payload'] == expected_clear['payload']

        status_message = on_response.call_args_list[1][0][0].to_dict()
        assert status_message['protocol'] == 'network'
        assert status_message['command'] == 'status'
        assert status_message['payload']['graph'] == 'testgraph'
        assert status_message['payload']['running'] is False
        assert status_message['payload']['started'] is False

        genarray = {
            'protocol': 'graph',
            'command': 'addnode',
            'payload': {
                'graph': 'testgraph',
                'id': 'node1',
                'component': 'tests.components/GenerateArray'
            },
            'id': uuid.uuid1()
        }
        client.send(Message(**genarray))

        gevent.sleep(.01)

        genarray_message = on_response.call_args[0][0].to_dict()
        assert genarray_message['payload'] == genarray['payload']

        repeat = {
            'protocol': 'graph',
            'command': 'addnode',
            'payload': {
                'graph': 'testgraph',
                'id': 'node2',
                'component': 'tests.components/Repeat'
            },
            'id': uuid.uuid1()
        }
        client.send(Message(**repeat))
        edge = {
            'protocol': 'graph',
            'command': 'addedge',
            'payload': {
                'graph': 'testgraph',
                'src': {
                    'node': 'node1',
                    'port': 'OUT'
                },
                'tgt': {
                    'node': 'node2',
                    'port': 'in'
                }
            },
            'id': uuid.uuid1()
        }

        client.send(Message(**edge))
        gevent.sleep(.01)

        edge_message = on_response.call_args[0][0].to_dict()
        assert edge_message['payload']['src'] == edge['payload']['src']
        assert edge_message['payload']['tgt'] == edge['payload']['tgt']

        on_response.reset_mock()

        # test adding invalid data errors
        addinitial = {
            'protocol': 'graph',
            'command': 'addinitial',
            'payload': {
                'graph': 'testgraph',
                'src': {
                    'data': ['lol']
                },
                'tgt': {
                    'node': 'node1',
                    'port': 'COUNT'
                }
            },
            'id': uuid.uuid1()
        }

        client.send(Message(**addinitial))
        gevent.sleep(.01)

        error_message = on_response.call_args[0][0].to_dict()
        assert error_message['command'] == 'error'
        on_response.reset_mock()

        add_inport = {
            'protocol': 'graph',
            'command': 'addinport',
            'payload': {
                'graph': 'testgraph',
                'node': 'node1',
                'port': 'COUNT',
                'public': 'IN',
                'metadata': {}
            },
            'id': uuid.uuid1()
        }

        client.send(Message(**add_inport))
        gevent.sleep(.01)

        inport_response = on_response.call_args_list[0][0][0]
        assert inport_response.protocol == 'graph'
        assert inport_response.command == 'addinport'

        component_response = on_response.call_args_list[1][0][0]
        assert component_response.protocol == 'component'
        assert component_response.command == 'component'
        assert len(component_response.payload['inPorts']) == 2
    finally:
        client.disconnect()
        server.stop()
Exemplo n.º 8
0
 def on_message(self, message, **kwargs):
     """
     Decode an incoming JSON message and forward it to the runtime client.

     Parameters
     ----------
     message : str
         JSON text whose top-level object is passed as keyword arguments to
         ``Message``; empty/falsy messages are logged and otherwise ignored
     **kwargs
         accepted and ignored
     """
     # Lazy %-style logging arguments: nothing is formatted when debug logging
     # is off. Unlike ``'...{}'.format(...)`` on a byte-string template, this
     # does not raise UnicodeEncodeError for non-ASCII unicode on Python 2.
     self.logger.debug('INCOMING: %s', message)
     if message:
         self.client.send(Message(**json.loads(message)))
Exemplo n.º 9
0
def _iter_client_messages():
    """
    Lazily yield one ``graph`` message per (command, payload) pair that
    ``get_graph_messages`` produces for the first graph returned by
    ``get_graph("My Graph")``, labelled with ``GRAPH_ID``.

    Each message gets a fresh ``uuid1`` id and the fixed identity ``'foo'``.
    """
    first_graph = get_graph("My Graph")[0]
    pairs = get_graph_messages(first_graph, GRAPH_ID)
    for cmd, body in pairs:
        message = Message('graph', cmd, body, id=uuid.uuid1())
        message.identity = 'foo'
        yield message