Code example #1
async def test_dhtnode_replicas():
    dht_size = 20
    initial_peers = 3
    num_replicas = random.randint(1, 20)

    peers = []
    for i in range(dht_size):
        neighbors_i = [
            f'{LOCALHOST}:{node.port}'
            for node in random.sample(peers, min(initial_peers, len(peers)))
        ]
        peers.append(await DHTNode.create(initial_peers=neighbors_i,
                                          num_replicas=num_replicas))

    you = random.choice(peers)
    assert await you.store('key1', 'foo', get_dht_time() + 999)

    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store('key2', 'bar', get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas

    assert await you.store('key2', 'baz', get_dht_time() + 1000)
    assert sum(
        len(peer.protocol.storage)
        for peer in peers) == total_size, "total size should not have changed"
Code example #2
def test_composite_validator(validators_for_app):
    validator = CompositeValidator(validators_for_app['A'])
    assert ([type(item) for item in validator._validators] ==
        [SchemaValidator, RSASignatureValidator])

    validator.extend(validators_for_app['B'])
    assert ([type(item) for item in validator._validators] ==
        [SchemaValidator, RSASignatureValidator])
    assert len(validator._validators[0]._schemas) == 2

    local_public_key = validators_for_app['A'][0].local_public_key
    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    # Expect only one signature since the two RSASignatureValidators have been merged
    assert signed_record.value.count(b'[signature:') == 1
    # Expect successful validation since the second SchemaValidator has been merged into the first
    assert validator.validate(signed_record)
    assert validator.strip_value(signed_record) == record.value

    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                       subkey=DHTProtocol.IS_REGULAR_VALUE,
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    assert signed_record.value.count(b'[signature:') == 0
    # Expect failed validation since `unknown_key` is not a part of any schema
    assert not validator.validate(signed_record)
Code example #3
async def test_dhtnode_signatures():
    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
    bob = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(RSAPrivateKey()),
        initial_peers=[f"{LOCALHOST}:{alice.port}"])
    mallory = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(RSAPrivateKey()),
        initial_peers=[f"{LOCALHOST}:{alice.port}"])

    key = b'key'
    subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key

    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice

    store_ok = await mallory.store(key, b'updated_fake_value',
                                   hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
Code example #4
def test_get_store():
    peers = []
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))

    node1, node2 = random.sample(peers, 2)
    assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
    assert node1.get('key1').value == 'value1'
    assert node2.get('key1').value == 'value1'
    assert node2.get('key2') is None

    future = node1.get('foo', return_future=True)
    assert future.result() is None

    future = node1.get('foo', return_future=True)
    future.cancel()

    assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
    assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
    assert node1.get('key1', latest=True).value == 123
    assert node1.get('key2').value == 456

    assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
    assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
    found_dict = node1.get('key2', latest=True).value
    assert isinstance(found_dict, dict) and len(found_dict) == 2
    assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'

    for peer in peers:
        peer.shutdown()
Code example #5
def test_dht_add_validators(validators_for_app):
    # One app may create a DHT with its validators
    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])

    # While the DHT process is not started, you can't send a command to append new validators
    with pytest.raises(RuntimeError):
        dht.add_validators(validators_for_app['B'])
    dht.run_in_background(await_ready=True)

    # After starting the process, other apps may add new validators to the existing DHT
    dht.add_validators(validators_for_app['B'])

    assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
    assert dht.get('field_a', latest=True).value == b'bytes_value'

    assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
    assert dht.get('field_a', latest=True).value == b'bytes_value'

    local_public_key = validators_for_app['A'][0].local_public_key
    assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
    dictionary = dht.get('field_b', latest=True).value
    assert (len(dictionary) == 1 and
            dictionary[local_public_key].value == 777)

    assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
    assert dht.get('unknown_key', latest=True) is None
Code example #6
    async def _tester():
        peers = []
        for i in range(10):
            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))

        await asyncio.gather(
            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
        )

        you = random.choice(peers)

        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1

        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2

        await asyncio.gather(*futures1.values(), *futures2.values())
        futures3 = await you.get_many(['k3'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1

        assert (await futures1['k1'])[0] == 123
        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
        test_success.set()
Code example #7
File: test_dht.py  Project: yuejiesong1900/hivemind
def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger expiration time must be kept"
    assert d.get(DHTID.generate("key1"))[0] is None, "Value with smaller expiration time must be deleted"
Code example #8
def test_change_expiration_time():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
Code example #9
    async def _tester():
        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T,
                                              reuse_get_requests=False)
        node1 = await hivemind.DHTNode.create(
            initial_peers=[f'localhost:{node2.port}'],
            cache_refresh_before_expiry=5 * T,
            listen=False,
            reuse_get_requests=False)
        await node2.store('k', [123, 'value'],
                          expiration_time=hivemind.get_dht_time() + 7 * T)
        await node2.store('k2', [654, 'value'],
                          expiration_time=hivemind.get_dht_time() + 7 * T)
        await node2.store('k3', [654, 'value'],
                          expiration_time=hivemind.get_dht_time() + 15 * T)
        await node1.get_many(['k', 'k2', 'k3', 'k4'])
        assert len(node1.protocol.cache) == 3
        assert len(node1.cache_refresh_queue) == 0

        await node1.get_many(['k', 'k2', 'k3', 'k4'])
        assert len(node1.cache_refresh_queue) == 3

        await node2.store('k', [123, 'value'],
                          expiration_time=hivemind.get_dht_time() + 12 * T)
        await asyncio.sleep(4 * T)
        await node1.get('k')
        await asyncio.sleep(1 * T)

        assert len(node1.protocol.cache) == 3
        assert len(node1.cache_refresh_queue) == 2
        await asyncio.sleep(3 * T)

        assert len(node1.cache_refresh_queue) == 1

        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0
        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0

        await node2.store('k', [123, 'value'],
                          expiration_time=hivemind.get_dht_time() + 10 * T)
        await node1.get('k')
        await asyncio.sleep(1 * T)
        assert len(node1.cache_refresh_queue) == 0
        await node1.get('k')
        await asyncio.sleep(1 * T)
        assert len(node1.cache_refresh_queue) == 1

        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0

        await asyncio.gather(node1.shutdown(), node2.shutdown())
        test_success.set()
Code example #10
def test_getset_averagers():
    dht = hivemind.DHT(start=True)

    t = hivemind.get_dht_time()
    dht.declare_averager(group_key='bucket.0b10110',
                         endpoint='localhvost',
                         expiration_time=t + 60)
    dht.declare_averager(group_key='bucket.0b10110',
                         endpoint='localhvost2',
                         expiration_time=t + 61)

    q1 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110',
                         endpoint='localhvost',
                         expiration_time=t + 66)
    q2 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110',
                         endpoint='localhvost2',
                         looking_for_group=False,
                         expiration_time=t + 61)
    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
    q4 = dht.get_averagers('bucket.0b10110', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
Code example #11
def test_localstorage_freeze():
    d = LocalStorage(maxsize=2)

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
        assert DHTID.generate("key1") in d
        time.sleep(0.03)
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d
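
For reference, here is a hypothetical, minimal sketch of the freeze() pattern exercised above: eviction of expired entries is suspended while the context manager is active and resumes once it exits. This is an illustration only, not the hivemind LocalStorage implementation; every name in it (TinyStorage, _evict_expired) is made up, and the maxsize-based eviction from the second half of the test is not modelled.

import time
from contextlib import contextmanager

class TinyStorage:
    """Illustrative store whose freeze() postpones expiration-based eviction."""
    def __init__(self):
        self.data = {}       # key -> (value, expiration_time)
        self.frozen = False  # while True, expired entries are kept around

    def store(self, key, value, expiration_time):
        self.data[key] = (value, expiration_time)

    def _evict_expired(self):
        if not self.frozen:
            now = time.time()
            self.data = {k: v for k, v in self.data.items() if v[1] > now}

    def __contains__(self, key):
        self._evict_expired()
        return key in self.data

    @contextmanager
    def freeze(self):
        self.frozen = True
        try:
            yield self
        finally:
            self.frozen = False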
Code example #12
File: test_dht.py  Project: yuejiesong1900/hivemind
def test_get_expired():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(
        DHTID.generate("key")) == (None, None), "Expired value must be deleted"
    print("Test get expired passed")
Code example #13
async def test_key_manager():
    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
                                  prefix='test_averaging', initial_group_bits='10110',
                                  target_group_size=2)

    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)

    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
    assert len(q5) == 0
Code example #14
 def report_training_progress(self):
     """ Declare this trainer's current step and the number of batches accumulated towards the next step """
     current_time = hivemind.get_dht_time()
     local_state_info = [self.local_step, self.local_samples_accumulated,
                         self.performance_ema.samples_per_second, current_time]
     assert self.is_valid_peer_state(local_state_info)
     self.dht.store(self.training_progess_key, subkey=self.trainer_uuid, value=local_state_info,
                    expiration_time=current_time + self.collaboration_args.metadata_expiration, return_future=True)
Code example #15
 def check_collaboration_state_periodically(self):
     """
     Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
     """
     while self.is_alive:
         with self.lock:
             self.collaboration_state = self.fetch_collaboration_state()
         time.sleep(max(0, self.collaboration_state.next_fetch_time - hivemind.get_dht_time()))
Code example #16
async def test_dhtnode_blacklist():
    node1 = await hivemind.DHTNode.create(blacklist_time=999)
    node2 = await hivemind.DHTNode.create(
        blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
    node3 = await hivemind.DHTNode.create(
        blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
    node4 = await hivemind.DHTNode.create(
        blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])

    assert await node2.store('abc',
                             123,
                             expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await node3.shutdown()
    await node4.shutdown()

    assert await node2.store('def',
                             456,
                             expiration_time=hivemind.get_dht_time() + 99)

    assert len(node2.blacklist.ban_counter) == 2

    for banned_peer in node2.blacklist.ban_counter:
        assert any(
            banned_peer.endswith(str(port))
            for port in [node3.port, node4.port])

    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(
        f"{hivemind.LOCALHOST}:{node1.port}")
    node3_endpoint = replace_port(node3_endpoint, node3.port)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3_endpoint in node1.blacklist

    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(
        f"{hivemind.LOCALHOST}:{node1.port}")
    node2_endpoint = replace_port(node2_endpoint, node2.port)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node2_endpoint not in node1.blacklist
Code example #17
 def update(self, num_processed: int) -> float:
     """
     :param num_processed: how many items were processed since last call
     :returns: current estimate of performance (samples per second), but at most 1 / eps
     """
     assert num_processed > 0, f"Can't register processing {num_processed} samples"
     self.timestamp, old_timestamp = hivemind.get_dht_time(), self.timestamp
     seconds_per_sample = max(
         0, self.timestamp - old_timestamp) / num_processed
     self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (
         1 - self.alpha) * self.ema_seconds_per_sample
     self.num_updates += 1
     adjusted_seconds_per_sample = self.ema_seconds_per_sample / (
         1 - (1 - self.alpha)**self.num_updates)
     self.samples_per_second = 1 / max(adjusted_seconds_per_sample,
                                       self.eps)
     return self.samples_per_second
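
The arithmetic above is an exponential moving average of seconds-per-sample with the same zero-initialization bias correction used in Adam: the raw EMA is divided by 1 - (1 - alpha) ** num_updates before being inverted into samples per second. A standalone sketch of that calculation, with made-up alpha, eps and timing values:

alpha, eps = 0.1, 1e-9  # illustration values, not the project's defaults
ema_seconds_per_sample, num_updates = 0.0, 0
for seconds_per_sample in (0.50, 0.48, 0.52, 0.49):  # pretend per-sample timings
    ema_seconds_per_sample = alpha * seconds_per_sample + (1 - alpha) * ema_seconds_per_sample
    num_updates += 1
    # without this correction the EMA would be biased towards its zero initialization
    corrected = ema_seconds_per_sample / (1 - (1 - alpha) ** num_updates)
    print(f"samples/sec ~ {1 / max(corrected, eps):.2f}")  # ~2 samples/sec for ~0.5 s/sample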
Code example #18
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener,
                           args=(peer_port, peer_id, peer_started),
                           daemon=True)
    peer_proc.start(), peer_started.wait()

    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(
        DHTProtocol.create(DHTID.generate(),
                           bucket_size=20,
                           depth_modulo=5,
                           wait_timeout=5,
                           num_replicas=3,
                           listen=False))

    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

    empty_item, nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(
        loop.run_until_complete(
            protocol.call_store(f'{LOCALHOST}:{peer_port}', [key],
                                [hivemind.MSGPackSerializer.dumps(value)],
                                expiration))), "peer rejected store"

    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert loop.run_until_complete(
        protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
    assert loop.run_until_complete(
        protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
    peer_proc.terminate()
Code example #19
    def on_step_end(self, args: TrainingArguments,
                    state: transformers.TrainerState,
                    control: transformers.TrainerControl, **kwargs):
        control.should_log = True
        if not self.params_are_finite():
            self.load_from_state(self.previous_state)
            return control
        self.previous_state = self.get_current_state()

        if state.log_history:
            self.loss += state.log_history[-1]['loss']
            self.steps += 1
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                self.total_samples_processed += self.samples
                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                statistics = metrics_utils.LocalMetrics(
                    step=self.collaborative_optimizer.local_step,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps)
                logger.info(f"Step {self.collaborative_optimizer.local_step}")
                logger.info(
                    f"Your current contribution: {self.total_samples_processed} samples"
                )
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps}")

                self.loss = 0
                self.steps = 0
                if self.collaborative_optimizer.is_synchronized:
                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                                   subkey=self.local_public_key,
                                   value=statistics.dict(),
                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                                   return_future=True)

        self.samples = self.collaborative_optimizer.local_samples_accumulated

        return control
Code example #20
File: test_dht.py  Project: yuejiesong1900/hivemind
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        protocol = loop.run_until_complete(
            DHTProtocol.create(DHTID.generate(),
                               bucket_size=20,
                               depth_modulo=5,
                               wait_timeout=5,
                               num_replicas=3,
                               listen=False))

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(
            nodes_found) == 0
        assert all(
            loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer_port}', [key],
                                    [hivemind.MSGPackSerializer.dumps(value)],
                                    expiration))), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(
            protocol.call_ping(
                f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
        test_success.set()
Code example #21
    def fetch_collaboration_state(self) -> CollaborationState:
        """ Read performance statistics reported by peers, estimate progress towards next batch """
        target_batch_size = self.collaboration_args.target_batch_size
        response, _expiration = self.dht.get(self.training_progess_key, latest=True) or (None, -float('inf'))
        current_time = hivemind.get_dht_time()

        if not isinstance(response, dict) or len(response) == 0:
            logger.warning(f"Found no active peers: {response}")
            local_eta_next_step = max(0, target_batch_size - self.local_samples_accumulated) / self.performance_ema.samples_per_second
            return CollaborationState(self.local_step, self.local_samples_accumulated, target_batch_size, 0,
                                      eta_next_step=current_time + local_eta_next_step,
                                      next_fetch_time=current_time + self.collaboration_args.default_refresh_period)

        valid_peer_states = [peer_state.value for peer_state in response.values()
                             if isinstance(peer_state, ValueWithExpiration)
                             and self.is_valid_peer_state(peer_state.value)]
        global_optimizer_step = max(self.local_step, max(step for step, *_ in valid_peer_states))

        num_peers = len(valid_peer_states)
        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0

        for opt_step, samples_accumulated, samples_per_second, timestep in valid_peer_states:
            total_samples_per_second += samples_per_second
            if opt_step == global_optimizer_step:
                total_samples_accumulated += samples_accumulated
                estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

        estimated_time_to_next_step = max(0, target_batch_size - estimated_current_samples) / total_samples_per_second

        expected_max_peers = max(num_peers + self.collaboration_args.expected_collaboration_drift_peers,
                                 num_peers * (1 + self.collaboration_args.expected_collaboration_drift_rate))
        time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                           a_min=self.collaboration_args.min_refresh_period,
                                           a_max=self.collaboration_args.max_refresh_period))
        logger.info(f"Collaboration accumulated {total_samples_accumulated} samples from {num_peers} peers; "
                    f"ETA {estimated_time_to_next_step:.2f} seconds (refresh in {time_to_next_fetch:.2f}s.)")
        return CollaborationState(global_optimizer_step, total_samples_accumulated, target_batch_size=target_batch_size,
                                  num_peers=num_peers, eta_next_step=current_time + estimated_time_to_next_step,
                                  next_fetch_time=current_time + time_to_next_fetch)
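
For reference, the refresh interval computed at the end of fetch_collaboration_state scales the estimated time to the next optimizer step by num_peers / expected_max_peers and clips it to a configured range. A small numeric illustration with assumed values (not the project's defaults):

import numpy as np

estimated_time_to_next_step = 40.0       # seconds until target_batch_size is expected to be reached
num_peers, expected_max_peers = 8, 10.0  # current peers vs. peers expected after drift
min_refresh_period, max_refresh_period = 0.5, 30.0

time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                   a_min=min_refresh_period, a_max=max_refresh_period))
print(time_to_next_fetch)  # 40.0 * 8 / 10 = 32.0, clipped down to 30.0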
Code example #22
def test_localstorage_top():
    d = LocalStorage(maxsize=3)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")

    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")

    del d[DHTID.generate('key2')]
    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")
    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize

    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
Code example #23
async def test_dhtnode_edge_cases():
    peers = []
    for i in range(5):
        neighbors_i = [
            f'{LOCALHOST}:{node.port}'
            for node in random.sample(peers, min(3, len(peers)))
        ]
        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i,
                                                   parallel_rpc=4))

    subkeys = [0, '', False, True, 'abyrvalg', 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]
    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(
            key=key,
            subkey=subkey,
            value=value,
            expiration_time=hivemind.get_dht_time() + 999)

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value
Code example #24
 def should_perform_step(self):
     return self.samples_accumulated >= self.target_batch_size or hivemind.get_dht_time() >= self.eta_next_step
Code example #25
File: test_dht.py  Project: yuejiesong1900/hivemind
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
            protocol = loop.run_until_complete(
                DHTProtocol.create(DHTID.generate(),
                                   bucket_size=20,
                                   depth_modulo=5,
                                   wait_timeout=5,
                                   num_replicas=3,
                                   listen=listen))
            print(f"Self id={protocol.node_id}", flush=True)

            assert loop.run_until_complete(
                protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

            key, value, expiration = DHTID.generate(), [
                random.random(), {
                    'ololo': 'pyshpysh'
                }
            ], get_dht_time() + 1e3
            store_ok = loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer1_port}', [key],
                                    [hivemind.MSGPackSerializer.dumps(value)],
                                    expiration))
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

            assert recv_value == value and recv_expiration == expiration, \
                f"call_find_value expected {value} (expires by {expiration}) " \
                f"but got {recv_value} (expires by {recv_expiration})"

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer2_port}',
                                   [dummy_key]))[dummy_key]
            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

            # cause a non-response by querying a nonexistent peer
            dummy_port = hivemind.find_open_port()
            assert loop.run_until_complete(
                protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

            if listen:
                loop.run_until_complete(protocol.shutdown())
            print("DHTProtocol test finished successfully!")
            test_success.set()
Code example #26
def test_store():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
Code example #27
 def on_train_end(self, *args, **kwargs):
     self.is_alive = False
     logger.info("Sending goodbye to peers")
     self.dht.store(self.training_progess_key, subkey=self.trainer_uuid, value=None,
                    expiration_time=hivemind.get_dht_time() + self.collaboration_args.metadata_expiration)
Code example #28
def test_dht_node():
    # create a dht with 50 nodes + your 51st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(dht.keys(), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node,
                          args=(node_id, peers, pipe_send),
                          daemon=True)
        proc.start()
        port = pipe_recv.recv()
        processes.append(proc)
        dht[f"{LOCALHOST}:{port}"] = node_id

    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(
        DHTNode.create(initial_peers=random.sample(dht.keys(), 5),
                       parallel_rpc=10,
                       cache_refresh_before_expiry=False))

    # test 1: find self
    nearest = loop.run_until_complete(
        me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(
        nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(
            found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(10):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id],
                                  k_nearest=k_nearest,
                                  exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(
            nearest_nodes
        ) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0
                      ), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1,
                                      all_node_ids,
                                      key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1

        jaccard_numerator += len(
            set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):",
          jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(
        me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(
        set.difference(set(nearest.keys()),
                       set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes(
        [dummy]))[dummy]
    assert len(nearest) == 1 and nearest[
        detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
    nearest = loop.run_until_complete(
        detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    that_guy = loop.run_until_complete(
        DHTNode.create(initial_peers=random.sample(dht.keys(), 3),
                       parallel_rpc=10,
                       cache_refresh_before_expiry=False,
                       cache_locally=False))

    for node in [me, that_guy]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, f"Wrong time"

    assert loop.run_until_complete(detached_node.get("mykey")) is None

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(
        me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
    now = get_dht_time()
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey1,
                 value=123,
                 expiration_time=now + 10))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey2,
                 value=456,
                 expiration_time=now + 20))
    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    assert not loop.run_until_complete(
        me.store(
            upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey2,
                 value=567,
                 expiration_time=now + 30))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey3,
                 value=890,
                 expiration_time=now + 50))
    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
Code example #29
File: test_dht.py  Project: yuejiesong1900/hivemind
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(
            DHTNode.create(initial_peers=random.sample(dht.keys(), 5),
                           parallel_rpc=10))

        # test 1: find self
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(
            nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for i in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(
                found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for i in range(100):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 20)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id],
                                      k_nearest=k_nearest,
                                      exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(
                nearest_nodes
            ) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0
                          ), "results must be sorted by distance"

            ref_nearest = heapq.nsmallest(k_nearest + 1,
                                          all_node_ids,
                                          key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1

            jaccard_numerator += len(
                set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):",
              jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(
            set.difference(set(nearest.keys()),
                           set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        other_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(
            other_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[
            other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
        nearest = loop.run_until_complete(
            other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6: store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(
            me.store("mykey", ["Value", 10], true_time))
        for node in [me, other_node]:
            val, expiration_time = loop.run_until_complete(node.get("mykey"))
            assert expiration_time == true_time, "Wrong time"
            assert val == ["Value", 10], "Wrong value"

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(
            me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        test_success.set()
Code example #30
def test_dht_protocol():
    # create the first peer
    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer1_proc = mp.Process(target=run_protocol_listener,
                            args=(peer1_port, peer1_id, peer1_started),
                            daemon=True)
    peer1_proc.start(), peer1_started.wait()

    # create another peer that connects to the first peer
    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer2_proc = mp.Process(target=run_protocol_listener,
                            args=(peer2_port, peer2_id, peer2_started),
                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'},
                            daemon=True)
    peer2_proc.start(), peer2_started.wait()

    loop = asyncio.get_event_loop()
    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
        protocol = loop.run_until_complete(
            DHTProtocol.create(DHTID.generate(),
                               bucket_size=20,
                               depth_modulo=5,
                               wait_timeout=5,
                               num_replicas=3,
                               listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(
            protocol.call_store(f'{LOCALHOST}:{peer1_port}', [key],
                                [hivemind.MSGPackSerializer.dumps(value)],
                                expiration))
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes,
         recv_expiration), nodes_found = loop.run_until_complete(
             protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer2_port}',
                               [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = hivemind.find_open_port()
        assert loop.run_until_complete(
            protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
        assert loop.run_until_complete(
            protocol.call_store(
                f'{LOCALHOST}:{peer1_port}',
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value1)],
                expiration_time=[expiration],
                subkeys=[subkey1]))
        assert loop.run_until_complete(
            protocol.call_store(
                f'{LOCALHOST}:{peer1_port}',
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value2)],
                expiration_time=[expiration + 5],
                subkeys=[subkey2]))
        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}',
                               [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1),
                                           expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2),
                                           expiration + 5)

        assert LOCALHOST in loop.run_until_complete(
            protocol.get_outgoing_request_endpoint(
                f'{LOCALHOST}:{peer1_port}'))

        if listen:
            loop.run_until_complete(protocol.shutdown())

    peer1_proc.terminate()
    peer2_proc.terminate()