Exemplo n.º 1
0
def test_closest_nodes_only_additional(empty_routing_table):
    """With an empty routing table, only the additional nodes are yielded.

    They must come out ordered by ``compute_distance`` to the target,
    smallest distance first.
    """
    target = NodeIDFactory()
    additional_nodes = [NodeIDFactory() for _ in range(10)]

    def distance_to_target(node):
        return compute_distance(target, node)

    actual = list(
        iter_closest_nodes(target, empty_routing_table, additional_nodes))
    expected = sorted(additional_nodes, key=distance_to_target)
    assert actual == expected
Exemplo n.º 2
0
def test_closest_nodes_only_routing(empty_routing_table):
    """Nodes known only to the routing table are yielded sorted by distance."""
    target = NodeIDFactory()
    routing_nodes = [NodeIDFactory() for _ in range(10)]
    for routing_node in routing_nodes:
        empty_routing_table.update(routing_node)

    # No additional candidates: everything must come from the table itself.
    no_additional_nodes = []
    actual = list(
        iter_closest_nodes(target, empty_routing_table, no_additional_nodes))
    expected = sorted(
        routing_nodes, key=lambda candidate: compute_distance(target, candidate))
    assert actual == expected
Exemplo n.º 3
0
def test_iter_around(routing_table, center_node_id):
    """Check the order in which ``iter_nodes_around`` yields nodes.

    Nodes are inserted at increasing log distances from a reference ID, so
    seen from that reference the insertion order is already closest-first.
    """
    reference_node_id = NodeIDFactory.at_log_distance(center_node_id, 100)
    log_distances = (1, 2, 100, 200)
    node_ids = tuple(
        NodeIDFactory.at_log_distance(reference_node_id, log_distance)
        for log_distance in log_distances)
    for node_id in node_ids:
        routing_table.update(node_id)

    def around(node_id):
        return tuple(routing_table.iter_nodes_around(node_id))

    assert around(reference_node_id) == node_ids
    # The nearest inserted node sees the same ordering as the reference.
    assert around(node_ids[0]) == node_ids
    # From the farthest inserted node the ordering must differ.
    assert around(node_ids[-1]) != node_ids
Exemplo n.º 4
0
def test_get_nodes_at_log_distance(routing_table, center_node_id, bucket_size):
    """Only nodes at exactly the requested log distance are returned."""

    def make_nodes(log_distance, count):
        return tuple(
            NodeIDFactory.at_log_distance(center_node_id, log_distance)
            for _ in range(count))

    nodes = make_nodes(200, bucket_size)
    # Neighbouring distances on both sides must not leak into the result.
    farther_nodes = make_nodes(201, 5)
    closer_nodes = make_nodes(199, 5)
    for node_id in (*nodes, *farther_nodes, *closer_nodes):
        routing_table.update(node_id)

    assert set(routing_table.get_nodes_at_log_distance(200)) == set(nodes)
Exemplo n.º 5
0
def test_lookup_generator_mixed(empty_routing_table):
    """Candidates split between table and extra list are merged by distance."""
    target = NodeIDFactory()
    candidates = [NodeIDFactory() for _ in range(10)]
    candidates.sort(key=lambda node: compute_distance(node, target))

    # Interleave the two sources so the merge has to alternate between them.
    routing_indices = (0, 1, 2, 6, 7)
    in_routing_table = [candidates[i] for i in routing_indices]
    in_additional = [
        candidates[i] for i in range(len(candidates))
        if i not in routing_indices
    ]
    for node in in_routing_table:
        empty_routing_table.update(node)

    merged = list(
        iter_closest_nodes(target, empty_routing_table, in_additional))
    assert merged == candidates
Exemplo n.º 6
0
def test_update(routing_table, center_node_id):
    """Updating a node places it first in its bucket's ordering."""
    first = NodeIDFactory.at_log_distance(center_node_id, 200)
    second = NodeIDFactory.at_log_distance(center_node_id, 200)

    def bucket():
        return routing_table.get_nodes_at_log_distance(200)

    routing_table.update(first)
    routing_table.update(second)
    assert bucket() == (second, first)

    # Re-updating the node that is already first leaves the order unchanged.
    routing_table.update(second)
    assert bucket() == (second, first)

    # Re-updating the trailing node moves it to the front.
    routing_table.update(first)
    assert bucket() == (first, second)
Exemplo n.º 7
0
def test_iter_all_random(routing_table, center_node_id):
    """``iter_all_random`` visits every stored node once, in a varying order.

    Covers both the buckets and the replacement caches. The test checks two
    things: the iteration order differs from insertion order, and two
    consecutive iterations produce different orders.
    """
    nodes_in_insertion_order = []
    # Use a relatively high number of nodes here, otherwise two consecutive
    # calls could yield nodes in the same order.
    for _ in range(100):
        node_id = NodeIDFactory()
        routing_table.update(node_id)
        nodes_in_insertion_order.append(node_id)

    nodes_in_iteration_order = list(routing_table.iter_all_random())

    # Every node held in a bucket or replacement cache is visited.
    table_length = sum(
        len(bucket_or_cache) for bucket_or_cache in itertools.chain(
            routing_table.buckets, routing_table.replacement_caches))
    assert (len(nodes_in_iteration_order) == table_length ==
            len(nodes_in_insertion_order))
    # No node is returned twice.
    assert len(set(nodes_in_iteration_order)) == len(nodes_in_iteration_order)
    # The iteration order is not the order in which nodes were inserted.
    assert nodes_in_iteration_order != nodes_in_insertion_order

    second_iteration_order = list(routing_table.iter_all_random())

    # Repeated calls yield the same nodes, but in a different order.
    assert set(nodes_in_iteration_order) == set(second_iteration_order)
    assert nodes_in_iteration_order != second_iteration_order
Exemplo n.º 8
0
async def test_ping_handler_updates_routing_table(
    ping_handler_service,
    inbound_message_channels,
    outbound_message_channels,
    local_enr,
    remote_enr,
    routing_table,
):
    """An inbound PING moves its sender to the front of its bucket."""
    remote_node_id = remote_enr.node_id
    distance = compute_log_distance(remote_node_id, local_enr.node_id)
    # Add a second node to the same bucket so the reordering is observable.
    other_node_id = NodeIDFactory.at_log_distance(local_enr.node_id, distance)
    routing_table.update(other_node_id)
    assert routing_table.get_nodes_at_log_distance(distance) == (
        other_node_id,
        remote_node_id,
    )

    inbound_message = InboundMessageFactory(
        message=PingMessageFactory(),
        sender_node_id=remote_node_id,
    )
    await inbound_message_channels[0].send(inbound_message)
    # Give the ping handler service a chance to process the message.
    await wait_all_tasks_blocked()

    assert routing_table.get_nodes_at_log_distance(distance) == (
        remote_node_id,
        other_node_id,
    )
Exemplo n.º 9
0
def invalid_node_id(request, bob, bob_network):
    """Return a malformed node-identifier string for ``request.param``.

    Each supported parameter value selects a different kind of invalid
    input. Any other value raises ``Exception``.
    """
    kind = request.param

    if kind == "unknown-endpoint":
        # A freshly generated node ID (presumably not known to the network).
        return NodeIDFactory().hex()
    if kind == "too-short":
        return (b"\x01" * 31).hex()
    if kind == "too-long":
        return (b"\x01" * 33).hex()
    if kind == "enode-missing-scheme":
        return f"{bob.node_id.hex()}@{bob.endpoint}"
    if kind == "enode-missing-endpoint":
        return f"enode://{bob.node_id.hex()}@"
    if kind == "enode-bad-nodeid":
        short_node_id = b"\x01" * 31
        return f"enode://{short_node_id.hex()}@{bob.endpoint}"
    if kind == "enr-without-prefix":
        # Drop the first four characters of the ENR's repr.
        return repr(bob.enr)[4:]
    if kind == "enr-without-endpoint":
        # Set the ip/udp/tcp entries to None before rendering the ENR.
        bob_network.enr_manager.update(
            (b"ip", None),
            (b"udp", None),
            (b"tcp", None),
        )
        return repr(bob_network.enr_manager.enr)
    raise Exception(f"Unhandled param: {kind}")
Exemplo n.º 10
0
def test_at_log_distance():
    """``at_log_distance`` yields a node at exactly the requested log distance.

    Property-style check over many random (node, distance) pairs with
    distances in the inclusive range 1..256. Each failure reports the inputs
    that produced it, so a random failure can be reproduced.
    """
    for _ in range(10000):
        node = NodeIDFactory()
        distance = random.randint(1, 256)
        other = at_log_distance(node, distance)
        actual = compute_log_distance(node, other)
        assert actual == distance, (node, distance, other)
Exemplo n.º 11
0
def test_is_empty(routing_table):
    """``is_empty`` is true initially, false after an update, true after removal."""
    added_node_id = NodeIDFactory()

    assert routing_table.is_empty
    routing_table.update(added_node_id)
    assert not routing_table.is_empty
    routing_table.remove(added_node_id)
    assert routing_table.is_empty
Exemplo n.º 12
0
def test_content_storage_closest_and_furthest_iteration(base_storage):
    """``iter_closest`` / ``iter_furthest`` order keys by content distance.

    The expected orders are computed with ``compute_content_distance``
    between the target and each key's content ID.
    """
    content_keys = tuple(b"key-" + bytes([i]) for i in range(32))
    for index, key in enumerate(content_keys):
        base_storage.set_content(key, b"dummy-" + bytes([index]))

    target = NodeIDFactory()

    def distance_from_target(content_key):
        return compute_content_distance(
            target, content_key_to_content_id(content_key))

    expected_closest = tuple(sorted(content_keys, key=distance_from_target))
    expected_furthest = tuple(
        sorted(content_keys, key=distance_from_target, reverse=True))

    actual_closest = tuple(base_storage.iter_closest(target))
    actual_furthest = tuple(base_storage.iter_furthest(target))

    assert actual_closest == expected_closest
    assert actual_furthest == expected_furthest
Exemplo n.º 13
0
def test_request_tracker_reserve_request_id_generated():
    """A generated request id is active only while its reservation is held."""
    request_tracker = RequestTracker()
    peer_node_id = NodeIDFactory()

    reservation = request_tracker.reserve_request_id(peer_node_id)
    with reservation as request_id:
        assert request_tracker.is_request_id_active(peer_node_id, request_id)

    # Leaving the context releases the id.
    assert not request_tracker.is_request_id_active(peer_node_id, request_id)
Exemplo n.º 14
0
def test_least_recently_updated_distance(routing_table, center_node_id):
    """Track which non-empty bucket was updated least recently.

    ``get_least_recently_updated_log_distance`` raises ``ValueError`` when
    no bucket holds a node.
    """
    least_recent = routing_table.get_least_recently_updated_log_distance

    with pytest.raises(ValueError):
        least_recent()

    far_node = NodeIDFactory.at_log_distance(center_node_id, 200)
    routing_table.update(far_node)
    assert least_recent() == 200

    near_node = NodeIDFactory.at_log_distance(center_node_id, 100)
    routing_table.update(near_node)
    assert least_recent() == 200

    # Touching the distance-200 bucket leaves distance 100 as the stalest.
    routing_table.update(far_node)
    assert least_recent() == 100

    routing_table.remove(far_node)
    assert least_recent() == 100

    routing_table.remove(near_node)
    with pytest.raises(ValueError):
        least_recent()
Exemplo n.º 15
0
def test_add(routing_table, center_node_id):
    """New nodes are placed first in their bucket, which holds a bounded count."""

    def at_distance(log_distance):
        return routing_table.get_nodes_at_log_distance(log_distance)

    assert at_distance(255) == ()

    first = NodeIDFactory.at_log_distance(center_node_id, 255)
    routing_table.update(first)
    assert at_distance(255) == (first, )

    second = NodeIDFactory.at_log_distance(center_node_id, 255)
    routing_table.update(second)
    assert at_distance(255) == (second, first)

    # A third node does not appear in the bucket (the fixture's bucket size
    # looks like 2 — presumably it goes to the replacement cache).
    third = NodeIDFactory.at_log_distance(center_node_id, 255)
    routing_table.update(third)
    assert at_distance(255) == (second, first)

    closest = NodeIDFactory.at_log_distance(center_node_id, 1)
    routing_table.update(closest)
    assert at_distance(1) == (closest, )
Exemplo n.º 16
0
async def test_pool_get_session_by_endpoint(tester, initiator, pool, events):
    """``get_sessions_for_endpoint`` returns exactly the sessions on one endpoint.

    Four sessions share ``endpoint``: initiated locally or remotely, each
    with the handshake either incomplete or complete. Four more use other
    endpoints and must be excluded. Every handshake is bounded by
    ``trio.fail_after`` so a stalled handshake fails the test instead of
    hanging it.
    """
    endpoint = EndpointFactory()

    # A: initiated locally, handshake incomplete
    remote_a = tester.node(endpoint=endpoint)
    session_a = pool.initiate_session(endpoint, remote_a.node_id)

    # B: initiated locally, handshake complete
    remote_b = tester.node(endpoint=endpoint)
    driver_b = tester.session_pair(initiator, remote_b)
    with trio.fail_after(1):
        await driver_b.handshake()
    session_b = driver_b.initiator.session

    # C: initiated remotely, handshake incomplete
    session_c = pool.receive_session(endpoint)

    # D: initiated remotely, handshake complete
    remote_d = tester.node(endpoint=endpoint)
    driver_d = tester.session_pair(remote_d, initiator)
    with trio.fail_after(1):
        await driver_d.handshake()
    session_d = driver_d.recipient.session

    # Some other sessions with non-matching endpoints before handshake
    session_e = pool.receive_session(EndpointFactory())
    session_f = pool.initiate_session(EndpointFactory(), NodeIDFactory())

    # Some other sessions with non-matching endpoints after handshake
    driver_g = tester.session_pair(initiator)
    with trio.fail_after(1):
        await driver_g.handshake()
    session_g = driver_g.initiator.session

    driver_h = tester.session_pair(recipient=initiator)
    with trio.fail_after(1):
        await driver_h.handshake()
    session_h = driver_h.recipient.session

    endpoint_matches = pool.get_sessions_for_endpoint(endpoint)
    assert len(endpoint_matches) == 4
    assert session_a in endpoint_matches
    assert session_b in endpoint_matches
    assert session_c in endpoint_matches
    assert session_d in endpoint_matches

    assert session_e not in endpoint_matches
    assert session_f not in endpoint_matches
    assert session_g not in endpoint_matches
    assert session_h not in endpoint_matches
Exemplo n.º 17
0
def test_remove(routing_table, center_node_id):
    """Removal from a bucket promotes a cached node; removed cached nodes stay gone."""
    node_ids = [
        NodeIDFactory.at_log_distance(center_node_id, 200) for _ in range(4)
    ]
    node_id_1, node_id_2, node_id_3, node_id_4 = node_ids
    for node_id in node_ids:
        routing_table.update(node_id)

    def bucket_200():
        return routing_table.get_nodes_at_log_distance(200)

    # Only node_id_1 and node_id_2 are visible in the bucket.
    assert bucket_200() == (node_id_2, node_id_1)

    # node_id_4 sits in the replacement cache; once removed it must not be
    # promoted later.
    routing_table.remove(node_id_4)
    # Freeing a bucket slot lets node_id_3 take it.
    routing_table.remove(node_id_2)
    assert bucket_200() == (node_id_1, node_id_3)

    routing_table.remove(node_id_3)
    assert bucket_200() == (node_id_1, )

    routing_table.remove(node_id_1)
    assert bucket_200() == ()

    # Removing an absent node must not raise.
    routing_table.remove(node_id_1)
Exemplo n.º 18
0
def test_request_tracker_reserve_request_id_provided():
    """A caller-supplied request id is used verbatim while reserved.

    The id is inactive before the reservation, active inside the ``with``
    block, and inactive again after it.
    """
    tracker = RequestTracker()

    node_id = NodeIDFactory()

    # Spell every byte as a hex escape; a bare "\04" is an octal escape.
    request_id = b"\x01\x02\x03\x04"

    assert not tracker.is_request_id_active(node_id, request_id)

    with tracker.reserve_request_id(node_id, request_id) as actual_request_id:
        assert actual_request_id == request_id
        assert tracker.is_request_id_active(node_id, request_id)
    assert not tracker.is_request_id_active(node_id, request_id)
Exemplo n.º 19
0
async def test_subscribe_delivers_messages():
    """A subscription receives only matching messages, and only while open."""
    node_id = NodeIDFactory()
    endpoint = EndpointFactory()
    manager = SubscriptionManager()

    def inbound(message):
        return InboundMessage(
            message=message,
            sender_node_id=node_id,
            sender_endpoint=endpoint,
        )

    async with manager.subscribe(MessageForTesting) as subscription:
        # Nothing has been fed yet.
        with pytest.raises(trio.WouldBlock):
            subscription.receive_nowait()

        # Only the MessageForTesting instance should reach the subscription.
        manager.feed_subscriptions(inbound(OtherMessageForTesting(1234)))
        manager.feed_subscriptions(inbound(MessageForTesting(1234)))

        with trio.fail_after(1):
            received = await subscription.receive()

        assert isinstance(received.message, MessageForTesting)
        assert received.message.id == 1234

        with pytest.raises(trio.WouldBlock):
            subscription.receive_nowait()

    # Once the context has exited, the subscription is closed.
    manager.feed_subscriptions(inbound(MessageForTesting(1234)))
    with pytest.raises(trio.ClosedResourceError):
        subscription.receive_nowait()
Exemplo n.º 20
0
async def test_rpc_tableInfo_web3(w3, routing_table, rpc_server):
    """Check ``w3.discv5.get_routing_table_info`` for a partially filled table.

    Six buckets (log distances 256 down to 251) are filled to different
    levels relative to ``bucket_size``. For each bucket, the reported
    ``idx``, ``is_full``, node count and replacement-cache count are
    compared with what was inserted.
    """
    local_node_id = routing_table.center_node_id
    bucket_size = routing_table.bucket_size

    # (log distance, nodes inserted). Trailing comments give the resulting
    # bucket / replacement-cache split for a bucket size of 16.
    fill_plan = (
        (256, bucket_size * 2),  # 16 / 16
        (255, int(bucket_size * 1.5)),  # 16 / 8
        (254, int(bucket_size * 1.25)),  # 16 / 4
        (253, bucket_size),  # 16 / 0
        (252, bucket_size // 2),  # 8 / 0
        (251, bucket_size // 4),  # 4 / 0
    )
    for log_distance, count in fill_plan:
        for _ in range(count):
            routing_table.update(
                NodeIDFactory.at_log_distance(local_node_id, log_distance))

    # Run in a worker thread (the web3 call is presumably blocking).
    table_info = await trio.to_thread.run_sync(w3.discv5.get_routing_table_info)
    assert table_info.center_node_id == routing_table.center_node_id
    assert table_info.bucket_size == bucket_size
    assert table_info.num_buckets == routing_table.num_buckets
    assert len(table_info.buckets) == 6

    # (log distance, is_full, nodes in bucket, nodes in replacement cache)
    expected_buckets = (
        (256, True, bucket_size, bucket_size),
        (255, True, bucket_size, bucket_size // 2),
        (254, True, bucket_size, bucket_size // 4),
        (253, True, bucket_size, 0),
        (252, False, bucket_size // 2, 0),
        (251, False, bucket_size // 4, 0),
    )
    for idx, is_full, num_nodes, num_cached in expected_buckets:
        # Each bucket is checked against its own replacement cache.
        bucket = table_info.buckets[idx]
        assert bucket.idx == idx
        assert bucket.is_full is is_full, f"bucket {idx}"
        assert len(bucket.nodes) == num_nodes, f"bucket {idx}"
        if num_cached:
            assert len(bucket.replacement_cache) == num_cached, f"bucket {idx}"
        else:
            assert not bucket.replacement_cache, f"bucket {idx}"
Exemplo n.º 21
0
def center_node_id():
    """Return a freshly generated node ID, used by tests as the center node."""
    node_id = NodeIDFactory()
    return node_id
Exemplo n.º 22
0
async def test_subscribe_filters_by_node_id_and_endpoint():
    """Subscriptions filter on message type, optionally node ID and endpoint.

    Subscription a filters on type only; b adds the node ID; c adds both the
    node ID and the endpoint.
    """
    node_id = NodeIDFactory()
    endpoint = EndpointFactory()
    manager = SubscriptionManager()

    def feed(message, sender_node_id, sender_endpoint):
        manager.feed_subscriptions(
            InboundMessage(
                message=message,
                sender_node_id=sender_node_id,
                sender_endpoint=sender_endpoint,
            )
        )

    async with AsyncExitStack() as stack:

        subscription_a = await stack.enter_async_context(
            manager.subscribe(MessageForTesting)
        )
        subscription_b = await stack.enter_async_context(
            manager.subscribe(MessageForTesting, node_id=node_id)
        )
        subscription_c = await stack.enter_async_context(
            manager.subscribe(MessageForTesting, node_id=node_id, endpoint=endpoint)
        )
        subscriptions = (subscription_a, subscription_b, subscription_c)

        for subscription in subscriptions:
            with pytest.raises(trio.WouldBlock):
                subscription.receive_nowait()

        # Wrong message type: matches no subscription.
        feed(OtherMessageForTesting(1234), node_id, endpoint)
        # Right message type only.
        feed(MessageForTesting(1), NodeIDFactory(), EndpointFactory())
        # Right message type AND node_id.
        feed(MessageForTesting(2), node_id, EndpointFactory())
        # Right message type AND node_id AND endpoint.
        feed(MessageForTesting(3), node_id, endpoint)

        with trio.fail_after(1):
            received_a = [await subscription_a.receive() for _ in range(3)]
            received_b = [await subscription_b.receive() for _ in range(2)]
            received_c = [await subscription_c.receive() for _ in range(1)]

        # Every subscription has now been drained.
        for subscription in subscriptions:
            with pytest.raises(trio.WouldBlock):
                subscription.receive_nowait()

        assert [item.message.id for item in received_a] == [1, 2, 3]
        assert [item.message.id for item in received_b] == [2, 3]
        assert [item.message.id for item in received_c] == [3]
Exemplo n.º 23
0
def test_fill_bucket(routing_table, center_node_id, bucket_size):
    """After 2 * bucket_size inserts at one distance, the bucket holds bucket_size."""
    log_distance = 200
    assert not routing_table.get_nodes_at_log_distance(log_distance)

    new_node_ids = (
        NodeIDFactory.at_log_distance(center_node_id, log_distance)
        for _ in range(2 * bucket_size))
    for node_id in new_node_ids:
        routing_table.update(node_id)

    bucket = routing_table.get_nodes_at_log_distance(log_distance)
    assert len(bucket) == bucket_size
Exemplo n.º 24
0
def test_closest_nodes_empty(empty_routing_table):
    """Nothing is yielded when neither the table nor the extra list has nodes."""
    target = NodeIDFactory()
    closest = list(iter_closest_nodes(target, empty_routing_table, []))
    assert not closest
Exemplo n.º 25
0
async def test_rpc_tableInfo(make_request, routing_table):
    """Check the ``discv5_routingTableInfo`` response for a partially filled table."""
    center = routing_table.center_node_id
    size = routing_table.bucket_size

    # Populate six buckets. Trailing comments give bucket / replacement-cache
    # occupancy for a bucket size of 16.
    for log_distance, count in (
        (256, size * 2),  # 16/16
        (255, int(size * 1.5)),  # 16/8
        (254, int(size * 1.25)),  # 16/4
        (253, size),  # 16
        (252, size // 2),  # 8
        (251, size // 4),  # 4
    ):
        for _ in range(count):
            routing_table.update(NodeIDFactory.at_log_distance(center, log_distance))

    table_info = await make_request("discv5_routingTableInfo")
    assert decode_hex(table_info["center_node_id"]) == routing_table.center_node_id
    assert table_info["bucket_size"] == size
    assert table_info["num_buckets"] == routing_table.num_buckets
    assert len(table_info["buckets"]) == 6

    # The response keys buckets by the string form of the distance.
    buckets = {
        distance: table_info["buckets"][str(distance)]
        for distance in (256, 255, 254, 253, 252, 251)
    }

    # distance -> (is_full, nodes in bucket, nodes in replacement cache)
    expectations = {
        256: (True, size, size),
        255: (True, size, size // 2),
        254: (True, size, size // 4),
        253: (True, size, 0),
        252: (False, size // 2, 0),
        251: (False, size // 4, 0),
    }
    for distance, (is_full, node_count, cache_count) in expectations.items():
        bucket = buckets[distance]
        assert bucket["idx"] == distance
        assert bucket["is_full"] is is_full
        assert len(bucket["nodes"]) == node_count
        if cache_count:
            assert len(bucket["replacement_cache"]) == cache_count
        else:
            assert not bucket["replacement_cache"]