async def test_dhtnode_replicas():
    dht_size = 20
    initial_peers = 3
    num_replicas = random.randint(1, 20)

    peers = []
    for i in range(dht_size):
        neighbors_i = [f'{LOCALHOST}:{node.port}'
                       for node in random.sample(peers, min(initial_peers, len(peers)))]
        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))

    you = random.choice(peers)
    assert await you.store('key1', 'foo', get_dht_time() + 999)

    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store('key2', 'bar', get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas

    assert await you.store('key2', 'baz', get_dht_time() + 1000)
    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
def test_composite_validator(validators_for_app):
    validator = CompositeValidator(validators_for_app['A'])
    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]

    validator.extend(validators_for_app['B'])
    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
    assert len(validator._validators[0]._schemas) == 2

    local_public_key = validators_for_app['A'][0].local_public_key
    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    # Expect only one signature since the two RSASignatureValidators have been merged
    assert signed_record.value.count(b'[signature:') == 1
    # Expect successful validation since the second SchemaValidator has been merged into the first
    assert validator.validate(signed_record)
    assert validator.strip_value(signed_record) == record.value

    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                       subkey=DHTProtocol.IS_REGULAR_VALUE,
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    assert signed_record.value.count(b'[signature:') == 0
    # Expect failed validation since `unknown_key` is not a part of any schema
    assert not validator.validate(signed_record)
async def test_dhtnode_signatures():
    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
    bob = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(RSAPrivateKey()),
        initial_peers=[f"{LOCALHOST}:{alice.port}"])
    mallory = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(RSAPrivateKey()),
        initial_peers=[f"{LOCALHOST}:{alice.port}"])

    key = b'key'
    subkey = b'protected_subkey' + bob.protocol.record_validator.local_public_key

    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice

    store_ok = await mallory.store(key, b'updated_fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'
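# The "protected" subkey used above is just the application-level subkey with the writer's serialized
# public key appended, which is how RSASignatureValidator decides who may overwrite the record.
# A minimal sketch of that construction (hypothetical helper, not part of hivemind's API):
def make_protected_subkey(plain_subkey: bytes, validator: RSASignatureValidator) -> bytes:
    """ Bind a subkey to the local keypair so only the holder of the matching private key can update it """
    return plain_subkey + validator.local_public_key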
def test_get_store():
    peers = []
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))

    node1, node2 = random.sample(peers, 2)
    assert node1.store('key1', 'value1', expiration_time=hivemind.get_dht_time() + 30)
    assert node1.get('key1').value == 'value1'
    assert node2.get('key1').value == 'value1'
    assert node2.get('key2') is None

    future = node1.get('foo', return_future=True)
    assert future.result() is None

    future = node1.get('foo', return_future=True)
    future.cancel()

    assert node2.store('key1', 123, expiration_time=hivemind.get_dht_time() + 31)
    assert node2.store('key2', 456, expiration_time=hivemind.get_dht_time() + 32)
    assert node1.get('key1', latest=True).value == 123
    assert node1.get('key2').value == 456

    assert node1.store('key2', subkey='subkey1', value=789, expiration_time=hivemind.get_dht_time() + 32)
    assert node2.store('key2', subkey='subkey2', value='pew', expiration_time=hivemind.get_dht_time() + 32)
    found_dict = node1.get('key2', latest=True).value
    assert isinstance(found_dict, dict) and len(found_dict) == 2
    assert found_dict['subkey1'].value == 789 and found_dict['subkey2'].value == 'pew'

    for peer in peers:
        peer.shutdown()
def test_dht_add_validators(validators_for_app):
    # One app may create a DHT with its validators
    dht = hivemind.DHT(start=False, record_validators=validators_for_app['A'])

    # While the DHT process is not started, you can't send a command to append new validators
    with pytest.raises(RuntimeError):
        dht.add_validators(validators_for_app['B'])
    dht.run_in_background(await_ready=True)

    # After starting the process, other apps may add new validators to the existing DHT
    dht.add_validators(validators_for_app['B'])

    assert dht.store('field_a', b'bytes_value', hivemind.get_dht_time() + 10)
    assert dht.get('field_a', latest=True).value == b'bytes_value'

    assert not dht.store('field_a', 666, hivemind.get_dht_time() + 10)
    assert dht.get('field_a', latest=True).value == b'bytes_value'

    local_public_key = validators_for_app['A'][0].local_public_key
    assert dht.store('field_b', 777, hivemind.get_dht_time() + 10, subkey=local_public_key)
    dictionary = dht.get('field_b', latest=True).value
    assert len(dictionary) == 1 and dictionary[local_public_key].value == 777

    assert not dht.store('unknown_key', 666, hivemind.get_dht_time() + 10)
    assert dht.get('unknown_key', latest=True) is None
async def _tester():
    peers = []
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))

    await asyncio.gather(
        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
    )

    you = random.choice(peers)

    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1

    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2

    await asyncio.gather(*futures1.values(), *futures2.values())
    futures3 = await you.get_many(['k3'], return_futures=True)

    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1

    assert (await futures1['k1'])[0] == 123
    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
    test_success.set()
def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
    assert d.get(DHTID.generate("key1"))[0] is None, "Value with smaller exp. time must be deleted"
def test_change_expiration_time():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
async def _tester():
    node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
    node1 = await hivemind.DHTNode.create(
        initial_peers=[f'localhost:{node2.port}'],
        cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 0

    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.cache_refresh_queue) == 3

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
    await asyncio.sleep(4 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)

    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 2
    await asyncio.sleep(3 * T)

    assert len(node1.cache_refresh_queue) == 1
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 0
    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.gather(node1.shutdown(), node2.shutdown())
    test_success.set()
def test_getset_averagers():
    dht = hivemind.DHT(start=True)
    t = hivemind.get_dht_time()
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 60)
    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2', expiration_time=t + 61)
    q1 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost', expiration_time=t + 66)
    q2 = dht.get_averagers('bucket.0b10110', only_active=True)

    dht.declare_averager(group_key='bucket.0b10110', endpoint='localhvost2',
                         looking_for_group=False, expiration_time=t + 61)
    q3 = dht.get_averagers('bucket.0b10110', only_active=True)
    q4 = dht.get_averagers('bucket.0b10110', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
def test_localstorage_freeze():
    d = LocalStorage(maxsize=2)

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
        assert DHTID.generate("key1") in d
        time.sleep(0.03)
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d
def test_get_expired():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
    print("Test get expired passed")
async def test_key_manager():
    key_manager = GroupKeyManager(hivemind.DHT(start=True), endpoint='localhvost',
                                  prefix='test_averaging', initial_group_bits='10110',
                                  target_group_size=2)
    t = hivemind.get_dht_time()
    key = key_manager.current_key
    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 60)
    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61)
    q1 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost', expiration_time=t + 66)
    q2 = await key_manager.get_averagers(key, only_active=True)

    await key_manager.declare_averager(key, 'localhvost2', expiration_time=t + 61, looking_for_group=False)
    q3 = await key_manager.get_averagers(key, only_active=True)
    q4 = await key_manager.get_averagers(key, only_active=False)

    q5 = await key_manager.get_averagers('nonexistent_key.0b0101', only_active=False)

    assert len(q1) == 2 and ('localhvost', t + 60) in q1 and ('localhvost2', t + 61) in q1
    assert len(q2) == 2 and ('localhvost', t + 66) in q2 and ('localhvost2', t + 61) in q2
    assert len(q3) == 1 and ('localhvost', t + 66) in q3
    assert len(q4) == 2 and ('localhvost', t + 66) in q4 and ('localhvost2', t + 61) in q4
    assert len(q5) == 0
def report_training_progress(self):
    """ Declare this trainer's current step and the number of batches accumulated towards the next step """
    current_time = hivemind.get_dht_time()
    local_state_info = [self.local_step, self.local_samples_accumulated,
                        self.performance_ema.samples_per_second, current_time]
    assert self.is_valid_peer_state(local_state_info)
    self.dht.store(self.training_progess_key, subkey=self.trainer_uuid, value=local_state_info,
                   expiration_time=current_time + self.collaboration_args.metadata_expiration,
                   return_future=True)
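# Hedged note on the record written above: each trainer publishes a four-element list
# [local_step, samples_accumulated, samples_per_second, timestamp] under its own subkey (trainer_uuid),
# so a later read of the progress key yields one such entry per active peer. Illustrative value only:
example_peer_state = [10, 256, 32.0, hivemind.get_dht_time()]  # step, samples, samples/sec, report time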
def check_collaboration_state_periodically(self):
    """
    Periodically check the training progress from all peers. Trigger an update after target_batch_size total samples
    """
    while self.is_alive:
        with self.lock:
            self.collaboration_state = self.fetch_collaboration_state()
        time.sleep(max(0, self.collaboration_state.next_fetch_time - hivemind.get_dht_time()))
async def test_dhtnode_blacklist():
    node1 = await hivemind.DHTNode.create(blacklist_time=999)
    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])

    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await node3.shutdown()
    await node4.shutdown()

    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 2

    for banned_peer in node2.blacklist.ban_counter:
        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])

    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
    node3_endpoint = replace_port(node3_endpoint, node3.port)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3_endpoint in node1.blacklist

    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
    node2_endpoint = replace_port(node2_endpoint, node2.port)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node2_endpoint not in node1.blacklist
def update(self, num_processed: int) -> float:
    """
    :param num_processed: how many items were processed since the last call
    :returns: the current estimate of performance (samples per second), but at most 1 / eps
    """
    assert num_processed > 0, f"Can't register processing {num_processed} samples"
    self.timestamp, old_timestamp = hivemind.get_dht_time(), self.timestamp
    seconds_per_sample = max(0, self.timestamp - old_timestamp) / num_processed
    self.ema_seconds_per_sample = self.alpha * seconds_per_sample + (1 - self.alpha) * self.ema_seconds_per_sample
    self.num_updates += 1
    adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
    self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
    return self.samples_per_second
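# Worked example of the bias correction above (hedged: assumes alpha=0.1 and an EMA that starts at zero).
# After one observation of 0.02 s/sample the raw EMA is only 0.1 * 0.02 = 0.002, which would grossly
# overestimate throughput; dividing by (1 - (1 - alpha) ** num_updates) recovers the true 0.02 s/sample,
# i.e. ~50 samples per second, in the same spirit as Adam-style zero-initialization bias correction.
alpha, ema, num_updates = 0.1, 0.0, 0
ema = alpha * 0.02 + (1 - alpha) * ema              # raw EMA after a single observation: 0.002
num_updates += 1
adjusted = ema / (1 - (1 - alpha) ** num_updates)   # bias-corrected estimate: 0.002 / 0.1 = 0.02 s/sample
assert abs(1 / adjusted - 50.0) < 1e-6              # roughly 50 samples per second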
def test_empty_table():
    """ Test RPC methods with an empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
    peer_proc.start(), peer_started.wait()

    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

    empty_item, nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(loop.run_until_complete(protocol.call_store(
        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))), \
        "peer rejected store"

    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
    peer_proc.terminate()
def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                control: transformers.TrainerControl, **kwargs):
    control.should_log = True
    if not self.params_are_finite():
        self.load_from_state(self.previous_state)
        return control
    self.previous_state = self.get_current_state()

    if state.log_history:
        self.loss += state.log_history[-1]['loss']
        self.steps += 1
        if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
            self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
            self.total_samples_processed += self.samples
            samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
            statistics = metrics_utils.LocalMetrics(
                step=self.collaborative_optimizer.local_step,
                samples_per_second=samples_per_second,
                samples_accumulated=self.samples,
                loss=self.loss,
                mini_steps=self.steps)

            logger.info(f"Step {self.collaborative_optimizer.local_step}")
            logger.info(f"Your current contribution: {self.total_samples_processed} samples")
            if self.steps:
                logger.info(f"Local loss: {self.loss / self.steps}")

            self.loss = 0
            self.steps = 0
            if self.collaborative_optimizer.is_synchronized:
                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
                               subkey=self.local_public_key, value=statistics.dict(),
                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                               return_future=True)

    self.samples = self.collaborative_optimizer.local_samples_accumulated

    return control
def _tester():
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

    recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
    assert all(loop.run_until_complete(protocol.call_store(
        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))), \
        "peer rejected store"

    recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration, \
        f"call_find_value expected {value} (expires by {expiration}) " \
        f"but got {recv_value} (expires by {recv_expiration})"

    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
    test_success.set()
def fetch_collaboration_state(self) -> CollaborationState:
    """ Read performance statistics reported by peers, estimate progress towards the next batch """
    target_batch_size = self.collaboration_args.target_batch_size
    response, _expiration = self.dht.get(self.training_progess_key, latest=True) or (None, -float('inf'))
    current_time = hivemind.get_dht_time()

    if not isinstance(response, dict) or len(response) == 0:
        logger.warning(f"Found no active peers: {response}")
        local_eta_next_step = max(0, target_batch_size - self.local_samples_accumulated) \
                              / self.performance_ema.samples_per_second
        return CollaborationState(self.local_step, self.local_samples_accumulated, target_batch_size, 0,
                                  eta_next_step=current_time + local_eta_next_step,
                                  next_fetch_time=current_time + self.collaboration_args.default_refresh_period)

    valid_peer_states = [peer_state.value for peer_state in response.values()
                         if isinstance(peer_state, ValueWithExpiration)
                         and self.is_valid_peer_state(peer_state.value)]

    global_optimizer_step = max(self.local_step, max(step for step, *_ in valid_peer_states))
    num_peers = len(valid_peer_states)

    total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
    for opt_step, samples_accumulated, samples_per_second, timestep in valid_peer_states:
        total_samples_per_second += samples_per_second
        if opt_step == global_optimizer_step:
            total_samples_accumulated += samples_accumulated
            estimated_current_samples += samples_accumulated + max(0, current_time - timestep) * samples_per_second
        # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
        # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.

    estimated_time_to_next_step = max(0, target_batch_size - estimated_current_samples) / total_samples_per_second
    expected_max_peers = max(num_peers + self.collaboration_args.expected_collaboration_drift_peers,
                             num_peers * (1 + self.collaboration_args.expected_collaboration_drift_rate))
    time_to_next_fetch = float(np.clip(a=estimated_time_to_next_step * num_peers / expected_max_peers,
                                       a_min=self.collaboration_args.min_refresh_period,
                                       a_max=self.collaboration_args.max_refresh_period))
    logger.info(f"Collaboration accumulated {total_samples_accumulated} samples from {num_peers} peers; "
                f"ETA {estimated_time_to_next_step:.2f} seconds (refresh in {time_to_next_fetch:.2f}s.)")
    return CollaborationState(global_optimizer_step, total_samples_accumulated, target_batch_size=target_batch_size,
                              num_peers=num_peers, eta_next_step=current_time + estimated_time_to_next_step,
                              next_fetch_time=current_time + time_to_next_fetch)
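# Hedged arithmetic example of the scheduling above (illustrative numbers, not defaults from the code):
# with target_batch_size=4096, an estimated 1024 samples already accumulated and a combined throughput
# of 256 samples/sec, the ETA is (4096 - 1024) / 256 = 12 s; with 4 active peers and expected_max_peers=5,
# each peer would re-fetch roughly every 12 * 4 / 5 = 9.6 s (before clipping to the refresh-period bounds),
# which spreads the DHT polling load across the collaboration as it grows.
eta_example = max(0, 4096 - 1024) / 256        # 12.0 seconds until the next global step
refresh_example = eta_example * 4 / 5          # ~9.6 seconds before this peer polls the DHT again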
def test_localstorage_top():
    d = LocalStorage(maxsize=3)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
    assert d.top()[:2] == (DHTID.generate("key1"), b"val1")

    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
    assert d.top()[:2] == (DHTID.generate("key2"), b"val2")

    del d[DHTID.generate('key2')]
    assert d.top()[:2] == (DHTID.generate("key1"), b"val1_new")

    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
    assert d.top()[:2] == (DHTID.generate("key3"), b"val3")
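# Hedged companion sketch (not part of the original suite): the behavior exercised above suggests that
# LocalStorage keeps entries ordered by expiration time, with top() exposing the entry that expires
# soonest and maxsize eviction removing that same entry first.
def test_localstorage_top_minimal_sketch():
    d = LocalStorage(maxsize=2)
    d.store(DHTID.generate("a"), b"A", get_dht_time() + 1)
    d.store(DHTID.generate("b"), b"B", get_dht_time() + 5)
    assert d.top()[:2] == (DHTID.generate("a"), b"A")  # "a" expires first, so it sits at the top
    d.store(DHTID.generate("c"), b"C", get_dht_time() + 9)  # exceeds maxsize: soonest-expiring "a" is evicted
    assert DHTID.generate("a") not in d and DHTID.generate("b") in d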
async def test_dhtnode_edge_cases():
    peers = []
    for i in range(5):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))

    subkeys = [0, '', False, True, 'abyrvalg', 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]
    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(key=key, subkey=subkey, value=value,
                                         expiration_time=hivemind.get_dht_time() + 999)

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value
def should_perform_step(self):
    return self.samples_accumulated >= self.target_batch_size or hivemind.get_dht_time() >= self.eta_next_step
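# Hedged reading of the condition above: a peer steps as soon as the collaboration has accumulated
# target_batch_size samples, but if progress reports lag it also steps once the DHT clock passes
# eta_next_step, so a single slow or silent peer cannot stall the whole collaboration indefinitely.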
def _tester():
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()

    for listen in [False, True]:  # note: order matters, this test assumes that the first run uses listen=False
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
        assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = hivemind.find_open_port()
        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        if listen:
            loop.run_until_complete(protocol.shutdown())

    print("DHTProtocol test finished successfully!")
    test_success.set()
def test_store():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
def on_train_end(self, *args, **kwargs):
    self.is_alive = False
    logger.info("Sending goodbye to peers")
    self.dht.store(self.training_progess_key, subkey=self.trainer_uuid, value=None,
                   expiration_time=hivemind.get_dht_time() + self.collaboration_args.metadata_expiration)
def test_dht_node():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(dht.keys(), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
        proc.start()
        port = pipe_recv.recv()
        processes.append(proc)
        dht[f"{LOCALHOST}:{port}"] = node_id

    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
                                                cache_refresh_before_expiry=False))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(10):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))

    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
                                                      cache_refresh_before_expiry=False, cache_locally=False))

    for node in [me, that_guy]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"

    assert loop.run_until_complete(detached_node.get("mykey")) is None

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
    now = get_dht_time()
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
def _tester():
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(100):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 20)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    other_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    for node in [me, other_node]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert expiration_time == true_time, "Wrong time"
        assert val == ["Value", 10], "Wrong value"

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value
    test_success.set()
def test_dht_protocol():
    # create the first peer
    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
    peer1_proc.start(), peer1_started.wait()

    # create another peer that connects to the first peer
    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
    peer2_proc.start(), peer2_started.wait()

    loop = asyncio.get_event_loop()
    for listen in [False, True]:  # note: order matters, this test assumes that the first run uses listen=False
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = hivemind.find_open_port()
        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
        assert loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
            expiration_time=[expiration], subkeys=[subkey1]))
        assert loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
            expiration_time=[expiration + 5], subkeys=[subkey2]))
        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

        assert LOCALHOST in loop.run_until_complete(
            protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))

        if listen:
            loop.run_until_complete(protocol.shutdown())

    peer1_proc.terminate()
    peer2_proc.terminate()