async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
    async with self.lock_declare:
        try:
            while True:
                await self.running.wait()

                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
                self.declared_group_key = group_key = key_manager.current_key
                self.declared_expiration_time = new_expiration_time
                self.declared_expiration.set()
                await key_manager.declare_averager(group_key, self.endpoint, expiration_time=new_expiration_time)
                await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                if self.running.is_set() and len(self.leader_queue) == 0:
                    await key_manager.update_key_on_not_enough_peers()
        except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
        finally:
            if self.declared_group_key is not None:
                prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                self.declared_group_key, self.declared_expiration_time = None, float('inf')
                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
                await key_manager.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
                                                   looking_for_group=False)
async def _update_queue_periodically(self, key_manager: GroupKeyManager):
    try:
        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
        while get_dht_time() < self.search_end_time:
            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
            self.max_assured_time = max(self.max_assured_time,
                                        get_dht_time() + self.averaging_expiration - DISCREPANCY)

            self.leader_queue.clear()
            for peer, peer_expiration_time in new_peers:
                if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
                    continue
                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

            self.update_finished.set()

            # wrap the event waiters in tasks: asyncio.wait() rejects bare coroutines on python>=3.11
            await asyncio.wait(
                {asyncio.create_task(self.running.wait()), asyncio.create_task(self.update_triggered.wait())},
                return_when=asyncio.ALL_COMPLETED,
                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
            self.update_triggered.clear()
    except (concurrent.futures.CancelledError, asyncio.CancelledError):
        return  # note: this is a compatibility layer for python3.7
    except Exception as e:
        logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
        raise
async def _update_queue_periodically(self, group_key: GroupKey):
    DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
    while get_dht_time() < self.search_end_time:
        new_peers = await self.dht.get_averagers(group_key, only_active=True, return_future=True)
        self.max_assured_time = max(self.max_assured_time,
                                    get_dht_time() + self.averaging_expiration - DISCREPANCY)

        self.leader_queue.clear()
        for peer, peer_expiration_time in new_peers:
            if peer == self.endpoint or (peer, peer_expiration_time) in self.past_attempts:
                continue
            self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
            self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)

        self.update_finished.set()

        # wrap the event waiters in tasks: asyncio.wait() rejects bare coroutines on python>=3.11
        await asyncio.wait(
            {asyncio.create_task(self.running.wait()), asyncio.create_task(self.update_triggered.wait())},
            return_when=asyncio.ALL_COMPLETED,
            timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
        self.update_triggered.clear()
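# A minimal self-contained sketch (not hivemind code) of the wait pattern used in both
# versions of _update_queue_periodically above: block until *both* events are set, or
# give up once the timeout elapses. The names `running` and `update_triggered` merely
# mirror the fields above for illustration.
import asyncio

async def _demo_wait_for_both_events(timeout: float = 5.0) -> None:
    running, update_triggered = asyncio.Event(), asyncio.Event()
    running.set()
    update_triggered.set()
    done, pending = await asyncio.wait(
        {asyncio.create_task(running.wait()), asyncio.create_task(update_triggered.wait())},
        return_when=asyncio.ALL_COMPLETED, timeout=timeout)
    assert len(done) == 2 and not pending  # both events were already set, so neither waiter timed out

# asyncio.run(_demo_wait_for_both_events())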
async def _declare_averager_periodically(self, group_key: GroupKey):
    async with self.lock_declare:
        try:
            while True:
                await self.running.wait()
                new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
                self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
                self.declared_expiration.set()
                await self.dht.declare_averager(group_key, self.endpoint, new_expiration_time,
                                                looking_for_group=True, return_future=True)
                await asyncio.sleep(self.declared_expiration_time - get_dht_time())
        except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
        finally:
            if self.declared_group_key is not None:
                prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                self.declared_group_key, self.declared_expiration_time = None, float('inf')
                self.leader_queue, self.max_assured_time = TimedStorage[Endpoint, DHTExpiration](), float('-inf')
                await self.dht.declare_averager(prev_declared_key, self.endpoint, prev_expiration_time,
                                                looking_for_group=False, return_future=True)
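# The schema tests below rely on a pydantic model named SampleSchema that is defined
# elsewhere in the test suite. A sketch inferred from the assertions (field names are
# the tested keys; the exact field types are an assumption):
#
# class SampleSchema(BaseModel):
#     experiment_name: StrictBytes                        # plain bytes value, no subkeys
#     n_batches: Dict[bytes, conint(ge=0, strict=True)]   # dictionary of non-negative ints
#     signed_data: Dict[BytesWithPublicKey, bytes]        # subkeys must embed a public key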
def test_sending_validator_instance_between_processes():
    alice = hivemind.DHT(start=True)
    bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])

    alice.add_validators([SchemaValidator(SampleSchema)])
    bob.add_validators([SchemaValidator(SampleSchema)])

    assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
    assert not bob.store('experiment_name', 777, get_dht_time() + 10)
    assert alice.get('experiment_name', latest=True).value == b'foo_bar'
async def test_expecting_public_keys(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Subkeys expected to contain a public key
    # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
    assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                           subkey=b'uid[owner:public-key]')
    assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                               subkey=b'uid-without-public-key')

    for peer in [alice, bob]:
        dictionary = (await peer.get('signed_data', latest=True)).value
        assert (len(dictionary) == 1 and
                dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
async def test_expecting_regular_value(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Regular value (bytes) expected
    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
    assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
    assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10, subkey=b'subkey')

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    assert not await bob.store('experiment_name', [], get_dht_time() + 10)
    assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
def request_expiration_time(self) -> float:
    """ this averager's current expiration time - used to send join requests to leaders """
    if isfinite(self.declared_expiration_time):
        return self.declared_expiration_time
    else:
        return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
    """
    :param leader: request this peer to be your leader for allreduce
    :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
    :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
    :note: this function does not guarantee that your group leader is the same as the :leader: parameter
      The originally specified leader can disband the group and redirect us to a different leader
    """
    assert self.is_looking_for_group and self.current_leader is None
    call: Optional[grpc.aio.UnaryStreamCall] = None
    try:
        async with self.lock_request_join_group:
            leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
            call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                client_mode=self.client_mode, gather=self.data_for_gather))
            message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

            if message.code == averaging_pb2.ACCEPTED:
                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
                self.current_leader = leader
                self.was_accepted_to_group.set()
                if len(self.current_followers) > 0:
                    await self.leader_disband_group()

        if message.code != averaging_pb2.ACCEPTED:
            code = averaging_pb2.MessageCode.Name(message.code)
            logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
            return None

        async with self.potential_leaders.pause_search():
            time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
            message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)

            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                async with self.lock_request_join_group:
                    return await self.follower_assemble_group(leader, message)

        if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
            if message.suggested_leader and message.suggested_leader != self.endpoint:
                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                self.current_leader = None
                call.cancel()
                return await self.request_join_group(message.suggested_leader, expiration_time)
            else:
                logger.debug(f"{self} - leader disbanded group")
                return None

        logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
        return None
    except asyncio.TimeoutError:
        logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
        if call is not None:
            call.cancel()
        return None
    finally:
        self.was_accepted_to_group.clear()
        self.current_leader = None
        if call is not None:
            await call.code()
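# Minimal sketch (an assumption, not hivemind API) of the read-with-timeout pattern used
# above: bound each read on the leader's response stream with wait_for, and cancel the
# underlying call if the peer goes silent so the channel is not left dangling.
import asyncio

async def _read_or_cancel(call, timeout: float):
    """Return the next streamed message, or None (after cancelling the call) on timeout."""
    try:
        return await asyncio.wait_for(call.read(), timeout=timeout)
    except asyncio.TimeoutError:
        call.cancel()
        return None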
@contextlib.asynccontextmanager
async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
    async with self.lock_search:
        self.running.set()
        self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
        update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
        if declare:
            declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))

        try:
            yield self
        finally:
            if not update_queue_task.done():
                update_queue_task.cancel()
            if declare and not declare_averager_task.done():
                declare_averager_task.cancel()
            for field in (self.past_attempts, self.leader_queue, self.running,
                          self.update_finished, self.update_triggered, self.declared_expiration):
                field.clear()
            self.max_assured_time = float('-inf')
            self.search_end_time = float('inf')
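# Self-contained sketch (not hivemind code) of the pattern begin_search builds on: an
# @asynccontextmanager that owns a background task and guarantees it is cancelled when
# the `async with` block exits, even if the body raises.
import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def background_task(coro):
    task = asyncio.create_task(coro)
    try:
        yield task
    finally:
        if not task.done():
            task.cancel()  # cancellation is delivered once the event loop resumes the task

# usage (update_loop is a hypothetical coroutine function):
#     async with background_task(update_loop()) as task:
#         ...  # the loop runs concurrently and is cancelled on exit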
async def test_keys_outside_schema(dht_nodes_with_schema):
    class Schema(BaseModel):
        some_field: StrictInt

    class MergedSchema(BaseModel):
        another_field: StrictInt

    for allow_extra_keys in [False, True]:
        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
        assert validator.merge_with(SchemaValidator(MergedSchema, allow_extra_keys=False))

        alice = await DHTNode.create(record_validator=validator)
        bob = await DHTNode.create(record_validator=validator,
                                   initial_peers=[f"{LOCALHOST}:{alice.port}"])

        store_ok = await bob.store('unknown_key', b'foo_bar', get_dht_time() + 10)
        assert store_ok == allow_extra_keys

        for peer in [alice, bob]:
            result = await peer.get('unknown_key', latest=True)
            if allow_extra_keys:
                assert result.value == b'foo_bar'
            else:
                assert result is None
def test_rsa_signature_validator():
    receiver_validator = RSASignatureValidator()
    sender_validator = RSASignatureValidator(RSAPrivateKey())
    mallory_validator = RSASignatureValidator(RSAPrivateKey())

    plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                             expiration_time=get_dht_time() + 10)
    protected_records = [
        dataclasses.replace(plain_record, key=plain_record.key + sender_validator.local_public_key),
        dataclasses.replace(plain_record, subkey=plain_record.subkey + sender_validator.local_public_key),
    ]

    # test 1: Non-protected record (no signature added)
    assert sender_validator.sign_value(plain_record) == plain_record.value
    assert receiver_validator.validate(plain_record)

    # test 2: Correct signatures
    signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
                      for record in protected_records]
    for record in signed_records:
        assert receiver_validator.validate(record)
        assert receiver_validator.strip_value(record) == b'value'

    # test 3: Invalid signatures
    signed_records = protected_records  # without a signature
    signed_records += [dataclasses.replace(record, value=record.value + b'[signature:INVALID_BYTES]')
                       for record in protected_records]  # with an invalid signature
    signed_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
                       for record in protected_records]  # with someone else's signature
    for record in signed_records:
        assert not receiver_validator.validate(record)
async def test_prefix():
    class Schema(BaseModel):
        field: StrictInt

    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')

    alice = await DHTNode.create(record_validator=validator)
    bob = await DHTNode.create(record_validator=validator,
                               initial_peers=[f"{LOCALHOST}:{alice.port}"])

    assert await bob.store('prefix_field', 777, get_dht_time() + 10)
    assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
    assert not await bob.store('field', 777, get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('prefix_field', latest=True)).value == 777
        assert (await peer.get('field', latest=True)) is None
async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                         ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
    """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
    try:
        async with self.lock_request_join_group:
            reason_to_reject = self._check_reasons_to_reject(request)
            if reason_to_reject is not None:
                yield reason_to_reject
                return

            self.current_followers[request.endpoint] = request
            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

            if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                # outcome 1: we have assembled a full group and are ready for allreduce
                await self.leader_assemble_group()

        # wait for the group to be assembled or disbanded
        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
        # wrap the event waiter in a task: asyncio.wait() rejects bare coroutines on python>=3.11
        await asyncio.wait({self.assembled_group, asyncio.create_task(self.was_accepted_to_group.wait())},
                           return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
        if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
            async with self.lock_request_join_group:
                if self.assembled_group.done():
                    pass  # this covers a rare case when the group is assembled while the event loop was busy
                elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                    # outcome 2: the time is up, run allreduce with what we have or disband
                    await self.leader_assemble_group()
                else:
                    await self.leader_disband_group()

        if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
                or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
            if self.current_leader is not None:
                # outcome 3: found by a leader with higher priority, send our followers to him
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
                                                      suggested_leader=self.current_leader)
                return
            else:
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
                return

        allreduce_group = self.assembled_group.result()
        yield averaging_pb2.MessageFromLeader(
            code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
            ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
            part_sizes=allreduce_group.part_sizes, gathered=allreduce_group.gathered,
            group_key_seed=allreduce_group.group_key_seed)
    except (concurrent.futures.CancelledError, asyncio.CancelledError):
        return  # note: this is a compatibility layer for python3.7
    except Exception as e:
        logger.exception(e)
        yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)
    finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
        self.current_followers.pop(request.endpoint, None)
        self.follower_was_discarded.set()
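# Follower's-eye view of the exchange implemented above (a sketch for orientation, not
# generated code): the leader streams MessageFromLeader responses to one JoinRequest.
#
#   JoinRequest -> <rejection code>    # _check_reasons_to_reject refused the request
#               -> ACCEPTED            # queued as a follower; wait for the group
#                  -> BEGIN_ALLREDUCE  # outcomes 1/2: group assembled, averaging starts
#                  -> GROUP_DISBANDED  # outcome 3: retry with suggested_leader, if any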
def test_signing_in_different_process():
    parent_conn, child_conn = mp.Pipe()
    process = mp.Process(target=get_signed_record, args=[child_conn])
    process.start()

    validator = RSASignatureValidator()
    parent_conn.send(validator)

    record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
                       value=b'value', expiration_time=get_dht_time() + 10)
    parent_conn.send(record)

    signed_record = parent_conn.recv()
    assert b'[signature:' in signed_record.value
    assert validator.validate(signed_record)
def test_validator_instance_is_picklable():
    # Needs to be picklable because the validator instance may be sent between processes
    original_validator = RSASignatureValidator()
    unpickled_validator = pickle.loads(pickle.dumps(original_validator))

    # To check that the private key was pickled and unpickled correctly, we sign a record
    # with the original public key using the unpickled validator and then validate the signature
    record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
                       value=b'value', expiration_time=get_dht_time() + 10)
    signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))

    assert b'[signature:' in signed_record.value
    assert original_validator.validate(signed_record)
    assert unpickled_validator.validate(signed_record)
async def test_merging_schema_validators(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    class TrivialValidator(RecordValidatorBase):
        def validate(self, record: DHTRecord) -> bool:
            return True

    second_validator = TrivialValidator()
    # Can't merge with a validator of a different type
    assert not alice.protocol.record_validator.merge_with(second_validator)

    class SecondSchema(BaseModel):
        some_field: StrictInt
        another_field: str

    class ThirdSchema(BaseModel):
        another_field: StrictInt  # allow it to be a StrictInt as well

    for schema in [SecondSchema, ThirdSchema]:
        new_validator = SchemaValidator(schema, allow_extra_keys=False)
        for peer in [alice, bob]:
            assert peer.protocol.record_validator.merge_with(new_validator)

    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)

    assert await bob.store('some_field', 777, get_dht_time() + 10)
    assert not await bob.store('some_field', 'string_value', get_dht_time() + 10)

    assert await bob.store('another_field', 42, get_dht_time() + 10)
    assert await bob.store('another_field', 'string_value', get_dht_time() + 10)

    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
    assert await bob.store('unknown_key', 999, get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
        assert (await peer.get('some_field', latest=True)).value == 777
        assert (await peer.get('another_field', latest=True)).value == 'string_value'
        assert (await peer.get('unknown_key', latest=True)).value == 999
async def test_expecting_dictionary(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Dictionary (bytes -> non-negative int) expected
    assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
    assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
    assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', 666, get_dht_time() + 10)
    assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
    assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)

    # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
    assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', [], get_dht_time() + 10)
    assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)

    # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
    assert not await bob.store('n_batches', '', get_dht_time() + 10)

    for peer in [alice, bob]:
        dictionary = (await peer.get('n_batches', latest=True)).value
        assert (len(dictionary) == 2 and
                dictionary[b'uid1'].value == 777 and
                dictionary[b'uid2'].value == 778)