async def test_connection_tracker_server_and_client(event_loop, event_bus):
    """Exercise a ``ConnectionTrackerServer`` through a ``ConnectionTrackerClient``.

    A blacklist entry written directly to the local ``MemoryConnectionTracker``
    must be readable through the client over the event bus. An entry written
    through the client must end up in the local tracker.
    """
    local_tracker = MemoryConnectionTracker()
    first_node = NodeFactory()
    local_tracker.record_blacklist(first_node, 60, "testing")

    # Sanity check: the entry is visible when reading the tracker directly.
    local_ids = await local_tracker.get_blacklisted()
    assert first_node.id in local_ids

    server = ConnectionTrackerServer(event_bus, local_tracker)

    # Keep the server running for the duration of this block.
    async with background_asyncio_service(server):
        client = ConnectionTrackerClient(
            event_bus,
            config=BroadcastConfig(filter_endpoint=NETWORKING_EVENTBUS_ENDPOINT),
        )

        # Block until some endpoint on the bus is subscribed to
        # GetBlacklistedPeersRequest (presumably the server started above) so
        # the client's request below has a handler.
        await event_bus.wait_until_any_endpoint_subscribed_to(
            GetBlacklistedPeersRequest)

        # Reads: the locally recorded entry is visible through the client.
        remote_ids = await client.get_blacklisted()
        assert first_node.id in remote_ids

        # Writes: an entry recorded through the client reaches the tracker.
        second_node = NodeFactory()
        client.record_blacklist(second_node, 60, "testing")
        # record_blacklist is not awaited here, so yield briefly to let its
        # broadcast be processed before reading back.
        await asyncio.sleep(0.01)

        remote_ids = await client.get_blacklisted()
        local_ids = await local_tracker.get_blacklisted()
        assert second_node.id in local_ids
        assert remote_ids == local_ids

        assert sorted(local_ids) == sorted([first_node.id, second_node.id])
async def test_records_failures():
    """A recorded ``HandshakeFailure`` makes ``should_connect_to`` reject the node.

    NOTE(review): another ``test_records_failures`` is defined later in this
    file. That later definition rebinds the module-level name, so pytest only
    collects it and this body never runs. Rename one of them if both are meant
    to run.
    """
    tracker = MemoryConnectionTracker()

    peer = NodeFactory()
    allowed_before = await tracker.should_connect_to(peer)
    assert allowed_before is True

    tracker.record_failure(peer, HandshakeFailure())

    allowed_after = await tracker.should_connect_to(peer)
    assert allowed_after is False
    # The failure is stored under the node's URI in this API version.
    assert tracker._record_exists(peer.uri())
async def test_records_failures():
    """Recording a ``HandshakeFailure`` adds the node to the blacklist."""
    tracker = MemoryConnectionTracker()

    peer = NodeFactory()
    ids_before = await tracker.get_blacklisted()
    assert peer.id not in ids_before

    tracker.record_failure(peer, HandshakeFailure())

    ids_after = await tracker.get_blacklisted()
    assert peer.id in ids_after
    # The underlying record is keyed by the node id.
    assert tracker._record_exists(peer.id)
async def test_memory_does_not_persist():
    """A failure recorded in one ``MemoryConnectionTracker`` is unknown to another.

    NOTE(review): a second ``test_memory_does_not_persist`` later in this file
    rebinds this name, so pytest only collects the later definition and this
    body never runs.
    """
    peer = NodeFactory()

    first = MemoryConnectionTracker()
    allowed = await first.should_connect_to(peer)
    assert allowed is True
    first.record_failure(peer, HandshakeFailure())
    allowed = await first.should_connect_to(peer)
    assert allowed is False

    # A separately constructed tracker has no record of that failure...
    second = MemoryConnectionTracker()
    allowed = await second.should_connect_to(peer)
    assert allowed is True
    # ...while the original tracker still rejects the node.
    allowed = await first.should_connect_to(peer)
    assert allowed is False
async def test_timeout_works():
    """A failure record stops blocking the node once its expiry is in the past."""
    peer = NodeFactory()

    tracker = MemoryConnectionTracker()
    allowed = await tracker.should_connect_to(peer)
    assert allowed is True

    tracker.record_failure(peer, HandshakeFailure())
    allowed = await tracker.should_connect_to(peer)
    assert allowed is False

    # Instead of sleeping, move the stored expiry two minutes into the past
    # and persist the change through the tracker's session.
    # NOTE(review): assumes the HandshakeFailure timeout is shorter than
    # 120 seconds — confirm against the failure-timeout table.
    stored = tracker._get_record(peer.uri())
    stored.expires_at -= datetime.timedelta(seconds=120)
    tracker.session.add(stored)
    tracker.session.commit()

    allowed = await tracker.should_connect_to(peer)
    assert allowed is True
async def test_memory_does_not_persist():
    """Blacklist entries in one ``MemoryConnectionTracker`` are absent from a new one."""
    peer = NodeFactory()

    first = MemoryConnectionTracker()
    first_ids = await first.get_blacklisted()
    assert peer.id not in first_ids
    first.record_failure(peer, HandshakeFailure())
    first_ids = await first.get_blacklisted()
    assert peer.id in first_ids

    # A separately constructed tracker has no entry for the node...
    second = MemoryConnectionTracker()
    second_ids = await second.get_blacklisted()
    assert peer.id not in second_ids

    # ...and the first tracker still does.
    first_ids = await first.get_blacklisted()
    assert peer.id in first_ids
# Example #7
# 0
    def _get_blacklist_tracker(cls, boot_info: BootInfo) -> BaseConnectionTracker:
        """Build the connection tracker chosen by ``boot_info.args.network_tracking_backend``.

        :param boot_info: boot information whose parsed ``args`` select the backend.
        :return: ``SQLiteConnectionTracker`` backed by the class's database
            session for ``SQLITE3``, ``MemoryConnectionTracker`` for ``MEMORY``,
            or ``NoopConnectionTracker`` for ``DO_NOT_TRACK``.
        :raises Exception: if the backend is none of the ``TrackingBackend``
            members handled here.
        """
        tracking_backend = boot_info.args.network_tracking_backend

        # Identity checks are evaluated in the same order as before; only the
        # SQLite branch needs a database session, so it is opened lazily.
        if tracking_backend is TrackingBackend.SQLITE3:
            return SQLiteConnectionTracker(cls._get_database_session(boot_info))
        if tracking_backend is TrackingBackend.MEMORY:
            return MemoryConnectionTracker()
        if tracking_backend is TrackingBackend.DO_NOT_TRACK:
            return NoopConnectionTracker()

        # Any other value (e.g. a backend added to the enum without a branch
        # here) is treated as an invariant violation.
        raise Exception(f"INVARIANT: {tracking_backend}")
async def test_get_blacklisted():
    """Check which blacklist entries are reported by ``get_blacklisted``.

    An entry recorded with a 10-second timeout is reported. One recorded with
    a zero-second timeout is not.
    """
    active, expired = NodeFactory(), NodeFactory()
    tracker = MemoryConnectionTracker()

    tracker.record_blacklist(active, timeout_seconds=10, reason='')
    tracker.record_blacklist(expired, timeout_seconds=0, reason='')

    reported = await tracker.get_blacklisted()

    assert reported == (active.id,)