Beispiel #1
0
    async def _handle_new_peer(self, peer_id: ID) -> None:
        """
        Open a pubsub stream to ``peer_id``, send it our hello packet and
        register the peer with the router.

        Handled failures are logged at debug level and leave no entry for
        ``peer_id`` in ``self.peers``:
        - ``SwarmException`` while opening the stream,
        - ``StreamClosed`` while writing the hello packet,
        - any ``Exception`` raised by ``self.router.add_peer``.

        :param peer_id: ID of the peer to add
        """
        try:
            stream: INetStream = await self.host.new_stream(
                peer_id, self.protocols)
        except SwarmException as error:
            logger.debug("fail to add new peer %s, error %s", peer_id, error)
            return

        self.peers[peer_id] = stream

        # Send hello packet
        hello = self.get_hello_packet()
        try:
            await stream.write(
                encode_varint_prefixed(hello.SerializeToString()))
        except StreamClosed:
            logger.debug("Fail to add new peer %s: stream closed", peer_id)
            # ``pop`` rather than ``del``: while this coroutine was suspended
            # in ``write``, another task may already have removed the peer,
            # and ``del`` would then raise KeyError out of this handler.
            self.peers.pop(peer_id, None)
            return
        # TODO: Check EOF of this stream.
        # TODO: Check if the peer in black list.
        try:
            self.router.add_peer(peer_id, stream.get_protocol())
        except Exception as error:
            # A router failure means the peer is not added: log it and undo
            # the registration above (again tolerant of a concurrent removal).
            # NOTE(review): the stream itself is not closed here — confirm
            # whether it should be.
            logger.debug("fail to add new peer %s, error %s", peer_id, error)
            self.peers.pop(peer_id, None)
            return

        logger.debug("added new peer %s", peer_id)
Beispiel #2
0
    async def handle_iwant(
        self, iwant_msg: rpc_pb2.ControlIWant, sender_peer_id: ID
    ) -> None:
        """
        Forward all requested messages that are present in mcache to the
        requesting peer, packed into a single RPC.

        IDs that cannot be parsed are skipped. Nothing is sent when none of
        the requested messages are cached, or when the requester no longer
        has a stream in ``self.pubsub.peers``. If the write fails because the
        stream is closed, the peer is passed to ``self.pubsub._handle_dead_peer``.

        :param iwant_msg: IWANT control message listing the wanted message IDs
        :param sender_peer_id: peer that sent the IWANT and receives the reply
        """
        msgs_to_forward: List[rpc_pb2.Message] = []
        for raw_msg_id in iwant_msg.messageIDs:
            # FIXME: Update type of message ID
            # FIXME: Find a better way to parse the msg ids
            # These IDs come from a remote peer. ``literal_eval`` does not
            # execute code, but it raises on malformed input. Skip such
            # entries instead of letting one bad ID abort the whole reply.
            try:
                msg_id_iwant: Any = literal_eval(raw_msg_id)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                logger.debug(
                    "Ignoring malformed message ID %r in iwant from %s",
                    raw_msg_id,
                    sender_peer_id,
                )
                continue

            # Check if the wanted message ID is present in mcache
            msg = self.mcache.get(msg_id_iwant)

            # Cache hit
            if msg:
                # Add message to list of messages to forward to requesting peers
                msgs_to_forward.append(msg)

        if not msgs_to_forward:
            # Nothing we have was requested; don't send an empty RPC.
            return

        # Forward messages to requesting peer
        # Should this just be publishing? No
        # because then the message will forwarded to peers in the topics contained in the messages.
        # We should
        # 1) Package these messages into a single packet
        packet: rpc_pb2.RPC = rpc_pb2.RPC()

        packet.publish.extend(msgs_to_forward)

        # 2) Serialize that packet
        rpc_msg: bytes = packet.SerializeToString()

        # 3) Get the stream to this peer. It may already have been removed,
        # in which case indexing would raise KeyError.
        if sender_peer_id not in self.pubsub.peers:
            logger.debug(
                "Fail to respond to iwant request from %s: peer record does not exist",
                sender_peer_id,
            )
            return
        peer_stream = self.pubsub.peers[sender_peer_id]

        # 4) And write the packet to the stream
        try:
            await peer_stream.write(encode_varint_prefixed(rpc_msg))
        except StreamClosed:
            logger.debug(
                "Fail to respond to iwant request from %s: stream closed",
                sender_peer_id,
            )
            self.pubsub._handle_dead_peer(sender_peer_id)
Beispiel #3
0
    async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
        """
        Invoked to forward a new message that has been validated.

        The message is put into ``self.mcache``. It is then written, as one
        varint-prefixed RPC, to every peer selected by ``_get_peers_to_send``
        that still has a stream in ``self.pubsub.peers``. Peers whose stream
        turns out to be closed are passed to ``self.pubsub._handle_dead_peer``.

        :param msg_forwarder: peer ID of the peer that forwarded the message to us
        :param pubsub_msg: the validated pubsub message
        """
        self.mcache.put(pubsub_msg)

        # Snapshot the recipients before writing anything. ``_handle_dead_peer``
        # below (and other tasks running while we await ``write``) can mutate
        # pubsub bookkeeping, and ``_get_peers_to_send`` may be a generator
        # lazily iterating that same state.
        peers_to_send = list(
            self._get_peers_to_send(
                pubsub_msg.topicIDs,
                msg_forwarder=msg_forwarder,
                origin=ID(pubsub_msg.from_id),
            )
        )
        rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])

        logger.debug("publishing message %s", pubsub_msg)

        # Every recipient receives identical bytes, so serialize and frame once.
        encoded_rpc = encode_varint_prefixed(rpc_msg.SerializeToString())

        for peer_id in peers_to_send:
            # The peer may have been dropped (e.g. by ``_handle_dead_peer`` on an
            # earlier iteration) since the recipient list was computed.
            if peer_id not in self.pubsub.peers:
                continue
            stream = self.pubsub.peers[peer_id]
            # FIXME: We should add a `WriteMsg` similar to write delimited messages.
            #   Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
            # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
            try:
                await stream.write(encoded_rpc)
            except StreamClosed:
                logger.debug("Fail to publish message to %s: stream closed", peer_id)
                self.pubsub._handle_dead_peer(peer_id)
Beispiel #4
0
    async def publish(self, msg_forwarder: ID,
                      pubsub_msg: rpc_pb2.Message) -> None:
        """
        Invoked to forward a new message that has been validated.
        This is where the "flooding" part of floodsub happens

        With flooding, routing is almost trivial: for each incoming message,
        forward to all known peers in the topic. There is a bit of logic,
        as the router maintains a timed cache of previous messages,
        so that seen messages are not further forwarded.
        It also never forwards a message back to the source
        or the peer that forwarded the message.

        Recipients without a stream in ``self.pubsub.peers`` at the time they
        are reached are skipped.

        :param msg_forwarder: peer ID of the peer who forwards the message to us
        :param pubsub_msg: pubsub message in protobuf.
        """

        # Snapshot the recipients before awaiting any write. Other tasks may
        # change pubsub bookkeeping while we are suspended, and
        # ``_get_peers_to_send`` may be a generator reading that state lazily.
        peers_to_send = list(self._get_peers_to_send(
            pubsub_msg.topicIDs,
            msg_forwarder=msg_forwarder,
            origin=ID(pubsub_msg.from_id),
        ))
        rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])

        logger.debug("publishing message %s", pubsub_msg)

        # Every recipient gets identical bytes, so serialize and frame once.
        encoded_rpc = encode_varint_prefixed(rpc_msg.SerializeToString())

        for peer_id in peers_to_send:
            # The peer may have disconnected since it was selected; indexing
            # ``self.pubsub.peers`` would then raise KeyError.
            if peer_id not in self.pubsub.peers:
                continue
            stream = self.pubsub.peers[peer_id]
            # FIXME: We should add a `WriteMsg` similar to write delimited messages.
            #   Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
            await stream.write(encoded_rpc)
Beispiel #5
0
    async def message_all_peers(self, raw_msg: bytes) -> None:
        """
        Broadcast a message to peers

        The same varint-prefixed frame is written to the stream of every
        peer present in ``self.peers`` when the broadcast starts.

        :param raw_msg: raw contents of the message to broadcast
        """
        # Every peer receives identical bytes, so frame the message once.
        encoded_msg = encode_varint_prefixed(raw_msg)

        # Iterate over a snapshot. Other tasks may add or remove entries in
        # ``self.peers`` while this coroutine is suspended in ``write``, and
        # mutating a dict while iterating it raises RuntimeError.
        for stream in list(self.peers.values()):
            await stream.write(encoded_msg)
Beispiel #6
0
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
    """
    Broadcasting an empty RPC must deliver the same varint-prefixed bytes
    to each of ten fake peer streams.
    """
    floodsub_pubsub = pubsubs_fsub[0]

    fake_streams = {}
    for index in range(10):
        # Peer IDs shaped like a sha2-256 multihash: 0x12 code, 0x20 length,
        # followed by a 32-byte big-endian counter.
        fake_id = ID(b"\x12\x20" + index.to_bytes(32, "big"))
        fake_streams[fake_id] = FakeNetStream()
    monkeypatch.setattr(floodsub_pubsub, "peers", fake_streams)

    payload = rpc_pb2.RPC().SerializeToString()
    expected_frame = encode_varint_prefixed(payload)

    await floodsub_pubsub.message_all_peers(payload)

    for fake_stream in fake_streams.values():
        received = await fake_stream.read()
        assert received == expected_frame
Beispiel #7
0
    async def _handle_new_peer(self, peer_id: ID) -> None:
        """
        Open a stream to ``peer_id``, register it in ``self.peers``, send our
        hello packet and add the peer to the router.

        If building or sending the hello packet, or ``router.add_peer``,
        raises, the ``self.peers`` entry created here is removed before the
        exception propagates. This avoids leaving a half-initialised peer behind.

        :param peer_id: ID of the peer to add
        """
        stream: INetStream = await self.host.new_stream(
            peer_id, self.protocols)

        self.peers[peer_id] = stream

        try:
            # Send hello packet
            hello = self.get_hello_packet()
            await stream.write(
                encode_varint_prefixed(hello.SerializeToString()))
            # TODO: Check EOF of this stream.
            # TODO: Check if the peer in black list.
            self.router.add_peer(peer_id, stream.get_protocol())
        except BaseException:
            # BaseException so cancellation also cleans up. ``pop`` because
            # another task may already have removed the entry while we were
            # awaiting ``write``. The original exception is always re-raised.
            self.peers.pop(peer_id, None)
            raise
Beispiel #8
0
    async def emit_control_message(self, control_msg: rpc_pb2.ControlMessage,
                                   to_peer: ID) -> None:
        """
        Wrap ``control_msg`` in an RPC envelope and write it, varint-prefixed,
        to the stream that ``self.pubsub.peers`` holds for ``to_peer``.

        :param control_msg: control message to send
        :param to_peer: ID of the destination peer
        """
        envelope = rpc_pb2.RPC()
        envelope.control.CopyFrom(control_msg)
        serialized = envelope.SerializeToString()

        destination = self.pubsub.peers[to_peer]
        await destination.write(encode_varint_prefixed(serialized))
Beispiel #9
0
    async def message_all_peers(self, raw_msg: bytes) -> None:
        """
        Broadcast a message to peers.

        The same varint-prefixed frame is written to every peer present in
        ``self.peers`` when the broadcast starts. A peer whose stream turns
        out to be closed is passed to ``_handle_dead_peer``, and the
        broadcast continues with the remaining peers.

        :param raw_msg: raw contents of the message to broadcast
        """
        # Every peer receives identical bytes, so frame the message once.
        encoded_msg = encode_varint_prefixed(raw_msg)

        # Iterate over a snapshot. ``_handle_dead_peer`` (and other tasks that
        # run while we await ``write``) may modify ``self.peers``, which would
        # otherwise raise "dictionary changed size during iteration".
        for peer_id, stream in list(self.peers.items()):
            # Skip peers removed by someone else since the snapshot was taken.
            if peer_id not in self.peers:
                continue
            # Write message to stream
            try:
                await stream.write(encoded_msg)
            except StreamClosed:
                logger.debug("Fail to message peer %s: stream closed", peer_id)
                # Use the key this peer is registered under in ``self.peers``.
                self._handle_dead_peer(peer_id)
Beispiel #10
0
async def test_message_all_peers(monkeypatch, security_protocol):
    """
    With a single mocked peer whose stream is one end of a stream pair,
    broadcasting an empty RPC must make the varint-prefixed bytes readable
    from the other end.
    """
    async with PubsubFactory.create_batch_with_floodsub(
            1, security_protocol=security_protocol
    ) as pubsubs_fsub, net_stream_pair_factory(
            security_protocol=security_protocol) as stream_pair:
        pubsub = pubsubs_fsub[0]
        sending_end = stream_pair[0]
        receiving_end = stream_pair[1]
        fake_peer = IDFactory()
        with monkeypatch.context() as patcher:
            # Replace the real peer table with one fake peer.
            patcher.setattr(pubsub, "peers", {fake_peer: sending_end})

            payload = rpc_pb2.RPC().SerializeToString()
            await pubsub.message_all_peers(payload)

            received = await receiving_end.read(MAX_READ_LEN)
            assert received == encode_varint_prefixed(payload)
Beispiel #11
0
    async def send_message(self, flag: HeaderTags, data: Optional[bytes],
                           stream_id: StreamID) -> int:
        """
        Send a message over the connection.

        The frame is a uvarint-encoded header followed by ``data`` passed
        through ``encode_varint_prefixed``. The header is
        ``(stream_id.channel_id << 3) | flag.value``.

        :param flag: header tag; its ``value`` is OR-ed into the header after
            the channel id has been shifted left by 3
        :param data: data to send in the message; ``None`` is sent as empty bytes
        :param stream_id: stream the message is in; its ``channel_id`` forms
            the upper part of the header
        :return: whatever ``self.write_to_stream`` returns for the frame
            (presumably the number of bytes written — confirm)
        """
        # << by 3, then or with flag
        header = encode_uvarint((stream_id.channel_id << 3) | flag.value)

        if data is None:
            data = b""

        frame = header + encode_varint_prefixed(data)

        return await self.write_to_stream(frame)
Beispiel #12
0
    async def emit_control_message(
        self, control_msg: rpc_pb2.ControlMessage, to_peer: ID
    ) -> None:
        """
        Wrap ``control_msg`` in an RPC packet and write it to ``to_peer``.

        If ``to_peer`` has no stream in ``self.pubsub.peers``, the message is
        dropped and logged at debug level. If the write fails because the
        stream is closed, the peer is passed to ``self.pubsub._handle_dead_peer``.

        :param control_msg: control message to send
        :param to_peer: ID of the destination peer
        """
        # Add control message to packet
        packet: rpc_pb2.RPC = rpc_pb2.RPC()
        packet.control.CopyFrom(control_msg)

        rpc_msg: bytes = packet.SerializeToString()

        # Get stream for peer from pubsub. The peer may already have been
        # removed (e.g. by ``_handle_dead_peer``) by the time this runs;
        # indexing would then raise KeyError.
        if to_peer not in self.pubsub.peers:
            logger.debug(
                "Fail to emit control message to %s: peer record does not exist",
                to_peer,
            )
            return
        peer_stream = self.pubsub.peers[to_peer]

        # Write rpc to stream
        try:
            await peer_stream.write(encode_varint_prefixed(rpc_msg))
        except StreamClosed:
            logger.debug("Fail to emit control message to %s: stream closed", to_peer)
            self.pubsub._handle_dead_peer(to_peer)
Beispiel #13
0
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
    """
    Feed RPCs through a fake stream consumed by ``continuously_read_stream``
    and check which (mocked) pubsub handler each kind of RPC reaches.

    Each negative check waits for the full 1s timeout, so this test takes
    several seconds by design.
    """
    stream = FakeNetStream()

    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)

    event_push_msg = asyncio.Event()
    event_handle_subscription = asyncio.Event()
    event_handle_rpc = asyncio.Event()

    # The mocks only record that they were invoked.
    async def mock_push_msg(msg_forwarder, msg):
        event_push_msg.set()

    def mock_handle_subscription(origin_id, sub_message):
        event_handle_subscription.set()

    async def mock_handle_rpc(rpc, sender_peer_id):
        event_handle_rpc.set()

    monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
    monkeypatch.setattr(
        pubsubs_fsub[0], "handle_subscription", mock_handle_subscription
    )
    monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)

    async def wait_for_event_occurring(event):
        """Wait up to 1s for ``event`` to be set; clear it either way."""
        try:
            await asyncio.wait_for(event.wait(), timeout=1)
        except asyncio.TimeoutError as error:
            event.clear()
            raise asyncio.TimeoutError(
                f"Event {event} is not set before the timeout. "
                "This indicates the mocked functions are not called properly."
            ) from error
        else:
            event.clear()

    async def write_rpc(rpc):
        """Write ``rpc`` into the fake stream, varint-prefixed."""
        await stream.write(encode_varint_prefixed(rpc.SerializeToString()))

    # Kick off the task `continuously_read_stream`
    task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
    try:
        # Test: `push_msg` is called when publishing to a subscribed topic.
        publish_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
        )
        await write_rpc(publish_subscribed_topic)
        await wait_for_event_occurring(event_push_msg)
        # Make sure the other events are not emitted.
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_handle_subscription)
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_handle_rpc)

        # Test: `push_msg` is not called when publishing to a topic-not-subscribed.
        publish_not_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
        )
        await write_rpc(publish_not_subscribed_topic)
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_push_msg)

        # Test: `handle_subscription` is called when a subscription message is received.
        subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
        await write_rpc(subscription_msg)
        await wait_for_event_occurring(event_handle_subscription)
        # Make sure the other events are not emitted.
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_push_msg)
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_handle_rpc)

        # Test: `handle_rpc` is called when a control message is received.
        control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
        await write_rpc(control_msg)
        await wait_for_event_occurring(event_handle_rpc)
        # Make sure the other events are not emitted.
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_push_msg)
        with pytest.raises(asyncio.TimeoutError):
            await wait_for_event_occurring(event_handle_subscription)
    finally:
        # Stop the reader even when an assertion above fails, so the task
        # does not outlive the test.
        task.cancel()
Beispiel #14
0
async def test_continuously_read_stream(monkeypatch, nursery,
                                        security_protocol):
    """
    Write RPCs into one end of a stream pair while ``continuously_read_stream``
    reads the other end, and check which (mocked) pubsub handler each kind of
    RPC is dispatched to.
    """
    async def wait_for_event_occurring(event):
        """Wait at most 0.1s for ``event``; ``trio.fail_after`` raises
        ``trio.TooSlowError`` if it is not set in time."""
        # Scheduling point before waiting, presumably to let the reader task
        # run first.
        # NOTE(review): ``trio.hazmat`` was renamed ``trio.lowlevel`` in later
        # trio releases — confirm against the pinned trio version.
        await trio.hazmat.checkpoint()
        with trio.fail_after(0.1):
            await event.wait()

    class Events(NamedTuple):
        # One event per mocked handler; each is set when its mock is called.
        push_msg: trio.Event
        handle_subscription: trio.Event
        handle_rpc: trio.Event

    @contextmanager
    def mock_methods():
        """
        Temporarily replace ``push_msg``, ``handle_subscription`` and
        ``router.handle_rpc`` with mocks that set fresh events, and yield
        those events. The patches are undone when the ``with`` block exits.
        """
        # Fresh events per use, so one scenario cannot leak into the next.
        event_push_msg = trio.Event()
        event_handle_subscription = trio.Event()
        event_handle_rpc = trio.Event()

        async def mock_push_msg(msg_forwarder, msg):
            event_push_msg.set()
            await trio.hazmat.checkpoint()

        def mock_handle_subscription(origin_id, sub_message):
            event_handle_subscription.set()

        async def mock_handle_rpc(rpc, sender_peer_id):
            event_handle_rpc.set()
            await trio.hazmat.checkpoint()

        # ``pubsubs_fsub`` is a closure variable bound by the ``async with``
        # below; this helper is only called inside that block.
        with monkeypatch.context() as m:
            m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
            m.setattr(pubsubs_fsub[0], "handle_subscription",
                      mock_handle_subscription)
            m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
            yield Events(event_push_msg, event_handle_subscription,
                         event_handle_rpc)

    async with PubsubFactory.create_batch_with_floodsub(
            1, security_protocol=security_protocol
    ) as pubsubs_fsub, net_stream_pair_factory(
            security_protocol=security_protocol) as stream_pair:
        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        # Kick off the task `continuously_read_stream`
        # stream_pair[0] is read by pubsub; the test writes into stream_pair[1].
        nursery.start_soon(pubsubs_fsub[0].continuously_read_stream,
                           stream_pair[0])

        # Test: `push_msg` is called when publishing to a subscribed topic.
        publish_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])])
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(
                    publish_subscribed_topic.SerializeToString()))
            await wait_for_event_occurring(events.push_msg)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_subscription)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_rpc)

        # Test: `push_msg` is not called when publishing to a topic-not-subscribed.
        publish_not_subscribed_topic = rpc_pb2.RPC(
            publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])])
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(
                    publish_not_subscribed_topic.SerializeToString()))
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)

        # Test: `handle_subscription` is called when a subscription message is received.
        subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(subscription_msg.SerializeToString()))
            await wait_for_event_occurring(events.handle_subscription)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_rpc)

        # Test: `handle_rpc` is called when a control message is received.
        control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
        with mock_methods() as events:
            await stream_pair[1].write(
                encode_varint_prefixed(control_msg.SerializeToString()))
            await wait_for_event_occurring(events.handle_rpc)
            # Make sure the other events are not emitted.
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.push_msg)
            with pytest.raises(trio.TooSlowError):
                await wait_for_event_occurring(events.handle_subscription)
Beispiel #15
0
 def encode_msg(self, msg: bytes) -> bytes:
     """
     Frame ``msg`` with ``encode_varint_prefixed`` after checking its size.

     :param msg: raw payload to frame
     :return: the framed payload
     :raises MessageTooLarge: if ``len(msg)`` exceeds ``self.max_msg_size``
     """
     size = len(msg)
     if size > self.max_msg_size:
         limit = self.max_msg_size
         raise MessageTooLarge(f"msg_len={size} > max_msg_size={limit}")
     return encode_varint_prefixed(msg)