async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                allow_retries: bool, timeout: Optional[float]):
    """Run one averaging round: find a group via matchmaking, run allreduce, apply the result.

    Retries on recoverable failures until the future is resolved (or retries are exhausted /
    timeout elapsed), then resolves ``future`` with the gathered metadata or an exception.

    :param future: cross-process future through which the result or error is delivered
    :param gather_binary: serialized user payload to exchange with groupmates during matchmaking
    :param weight: averaging weight of this peer, shipped to groupmates inside data_for_gather
    :param allow_retries: if False, the first caught error resolves the future with that exception
    :param timeout: overall deadline (seconds); once exceeded, no further retries are attempted
    """
    loop = asyncio.get_event_loop()
    start_time = get_dht_time()
    group_id = None
    try:
        # keep attempting rounds until the future is resolved (success, failure, or cancellation)
        while not future.done():
            try:
                self._pending_group_assembled.clear()
                # payload shared with groupmates: [weight, throughput, mode, user data]
                data_for_gather = self.serializer.dumps([
                    weight, self._throughput, self.mode.value, gather_binary
                ])
                group_info = await self._matchmaking.look_for_group(
                    timeout=timeout, data_for_gather=data_for_gather)
                if group_info is None:
                    raise AllreduceException(
                        "Averaging step failed: could not find a group.")
                group_id = group_info.group_id
                allreduce_runner = await self._make_allreduce_runner(
                    group_info, **self.allreduce_kwargs)
                # register the runner so incoming peer requests can be routed to it
                self._running_groups[group_id] = allreduce_runner
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_runner.run(),
                                       self._allreduce_timeout)
                if self.mode != AveragingMode.AUX:
                    # apply averaged tensors in a worker thread (AUX mode skips the local update)
                    await loop.run_in_executor(None, self.update_tensors,
                                               allreduce_runner)

                # averaging is finished, exit the loop
                future.set_result(allreduce_runner.gathered)

            except (AllreduceException, MatchmakingException, AssertionError,
                    StopAsyncIteration, InternalError, asyncio.CancelledError,
                    asyncio.InvalidStateError, grpc.RpcError,
                    grpc.aio.AioRpcError) as e:
                # recoverable failure: retry unless retries are disabled or the deadline passed
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None
                                         and timeout < time_elapsed):
                    logger.exception(f"Averager caught {repr(e)}")
                    future.set_exception(e)
                else:
                    logger.warning(f"Averager caught {repr(e)}, retrying")

            finally:
                # always deregister the group and unblock waiters, even after an error
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()

    except BaseException as e:
        # unexpected error: mirror it into the future for the caller, then re-raise
        if not future.done():
            future.set_exception(e)
        raise
    finally:
        # safety net: the future must never be left pending when this coroutine exits
        if not future.done():
            future.set_exception(
                RuntimeError(
                    "Internal sanity check failed: averager.step left future pending."
                    " Please report this to hivemind issues."))
async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
    """Look up *key* on the DHT node and deliver the outcome through *future*.

    Any failure (including cancellation) is mirrored into the future and re-raised.
    """
    try:
        lookup_result = await node.get(key, latest=latest, **kwargs)
        if not future.done():
            future.set_result(lookup_result)
    except BaseException as err:
        if not future.done():
            future.set_exception(err)
        raise
async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                 subkey: Optional[Subkey], future: MPFuture, **kwargs):
    """Write a (key, value) entry to the DHT node and deliver the outcome through *future*.

    Any failure (including cancellation) is mirrored into the future and re-raised.
    """
    try:
        store_outcome = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
        if not future.done():
            future.set_result(store_outcome)
    except BaseException as err:
        if not future.done():
            future.set_exception(err)
        raise
async def _set_group_bits(self, group_bits: str, future: MPFuture):
    """Overwrite the matchmaking group bits, signalling completion through *future*."""
    try:
        self._matchmaking.group_key_manager.group_bits = group_bits
        future.set_result(None)
    except Exception as exc:
        if not future.done():
            future.set_exception(exc)
async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool,
                timeout: Optional[float]):
    """Run one averaging round: match into a group, run allreduce, update local tensors.

    Resolves ``future`` with a dict {endpoint: gathered data} on success, with None when
    matchmaking keeps failing and retries are exhausted / timed out, or with an exception
    on any unexpected error.

    :param future: cross-process future through which the result is delivered
    :param gather: user data serialized and exchanged with groupmates during matchmaking
    :param allow_retries: if False, the first matchmaking/allreduce failure ends the step
    :param timeout: overall deadline (seconds) for finding a group and retrying
    """
    loop = asyncio.get_event_loop()
    start_time = get_dht_time()
    group_id = None
    # keep attempting rounds until the future is resolved
    while not future.done():
        try:
            self._pending_group_assembled.clear()
            gather_binary = self.serializer.dumps(gather)
            allreduce_group = await self._matchmaking.look_for_group(timeout=timeout,
                                                                     data_for_gather=gather_binary)
            if allreduce_group is None:
                raise AllreduceException("Averaging step failed: could not find a group.")

            group_id = allreduce_group.group_id
            # register the group so incoming peer requests can be routed to it
            self._running_groups[group_id] = allreduce_group
            self._pending_group_assembled.set()
            await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
            # apply the averaged tensors in a worker thread to avoid blocking the event loop
            await loop.run_in_executor(None, self.update_tensors, allreduce_group)

            # averaging is finished, exit the loop
            gathered_items = map(self.serializer.loads, allreduce_group.gathered)
            gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
            future.set_result(gathered_data_by_peer)

        except (AllreduceException, MatchmakingException):
            # recoverable failure: retry silently unless retries are off or the deadline passed,
            # in which case the step resolves with None (not an exception)
            time_elapsed = get_dht_time() - start_time
            if not allow_retries or (timeout is not None and timeout < time_elapsed):
                future.set_result(None)

        except Exception as e:
            # unexpected error: deliver it to the caller and re-raise locally
            future.set_exception(e)
            raise
        finally:
            # always deregister the group and unblock waiters, even after an error
            _ = self._running_groups.pop(group_id, None)
            self._pending_group_assembled.set()
async def _declare_averager(self, node: DHTNode, *, group_key: str, endpoint: Endpoint,
                            expiration_time: DHTExpiration, looking_for_group: bool,
                            future: MPFuture):
    """Publish this averager's active/inactive status under *group_key*; report store success via *future*."""
    try:
        if not looking_for_group:
            # when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
            expiration_time = float(nextafter(expiration_time, float('inf')))
        store_ok = await node.store(key=group_key, subkey=endpoint, value=looking_for_group,
                                    expiration_time=expiration_time)
        future.set_result(store_ok)
    except Exception as exc:
        if not future.done():
            future.set_exception(exc)
async def _get_averagers(self, node: DHTNode, *, group_key: str, only_active: bool, future: MPFuture):
    """Collect (endpoint, expiration_time) pairs registered under *group_key*, delivering them via *future*.

    With only_active=True, entries whose stored value is not True are skipped.
    """
    try:
        result = await node.get(group_key, latest=True)
        if result is None:
            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
            future.set_result([])
            return
        assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
                                               f"but got {result.value} of type {type(result.value)}."
        matched = []
        for endpoint, entry in result.value.items():
            if only_active and entry.value is not True:
                continue
            matched.append((endpoint, entry.expiration_time))
        future.set_result(matched)
    except Exception as exc:
        if not future.done():
            future.set_exception(exc)
async def _load_state_from_peers(self, future: MPFuture):
    """Try to download (metadata, tensors) from other averagers, best peer first.

    Peers are ranked by the priority value they published under ``{prefix}.all_averagers``;
    the first successful download resolves ``future`` with (metadata, tensors). If no peer
    can serve its state, the future is resolved with None (never left pending).
    """
    try:
        key_manager = self._matchmaking.group_key_manager
        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers",
                                        latest=True) or ({}, None)
        # keep only well-formed entries: a ValueWithExpiration holding a numeric priority
        peer_priority = {
            peer: float(info.value)
            for peer, info in peer_priority.items()
            if isinstance(info, ValueWithExpiration)
            and isinstance(info.value, (float, int))
        }

        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
            logger.info(
                f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}."
            )
            future.set_result(None)
            return

        metadata = None
        # try peers in descending priority order, skipping ourselves
        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
            if peer != self.endpoint:
                logger.info(f"Downloading parameters from peer {peer}")
                stream = None
                try:
                    stub = ChannelCache.get_stub(
                        peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                    stream = stub.rpc_download_state(
                        averaging_pb2.DownloadRequest())
                    current_tensor_parts, tensors = [], []
                    # reassemble tensors from the streamed parts
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    # flush the last partially-accumulated tensor, if any
                    if current_tensor_parts:
                        tensors.append(
                            deserialize_torch_tensor(
                                combine_from_streaming(current_tensor_parts)))

                    if not metadata:
                        logger.debug(f"Peer {peer} did not send its state.")
                        continue

                    logger.info(f"Finished downloading state from {peer}")
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except BaseException as e:
                    # this peer failed; log and move on to the next-best one
                    logger.exception(
                        f"Failed to download state from {peer} - {repr(e)}")
                finally:
                    # await the final RPC status so the stream is properly closed
                    if stream is not None:
                        await stream.code()

    finally:
        # guarantee the future is resolved even when every peer failed
        if not future.done():
            logger.warning(
                "Averager could not load state from peers: all requests have failed.")
            future.set_result(None)