async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
    """
    Ask a potential leader to accept this peer into its allreduce group.

    :param leader: request this peer to be your leader for allreduce
    :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
    :returns: if leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
    :note: this function does not guarantee that your group leader is the same as :leader: parameter
      The originally specified leader can disband group and redirect us to a different leader
    """
    # Precondition: we must be actively matchmaking and not already following someone.
    assert self.is_looking_for_group and self.current_leader is None
    call: Optional[grpc.aio.UnaryStreamCall] = None
    try:
        # Hold the join lock while sending the request and reading the leader's first reply,
        # so concurrent join/accept operations cannot interleave with this handshake.
        async with self.lock_request_join_group:
            leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
            call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                client_mode=self.client_mode, gather=self.data_for_gather))
            # First message: leader's accept/reject verdict; bounded by request_timeout.
            message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

            if message.code == averaging_pb2.ACCEPTED:
                logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
                self.current_leader = leader
                self.was_accepted_to_group.set()
                # If we had followers of our own, dissolve our group: we are now a follower ourselves.
                if len(self.current_followers) > 0:
                    await self.leader_disband_group()

        if message.code != averaging_pb2.ACCEPTED:
            code = averaging_pb2.MessageCode.Name(message.code)
            logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
            return None

        # We were accepted: stop searching for other leaders while we wait for the leader to
        # either begin allreduce or disband, up to the declared expiration (plus grace timeout).
        async with self.potential_leaders.pause_search():
            time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
            message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)

            if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                # Re-acquire the join lock so group assembly is atomic w.r.t. other join requests.
                async with self.lock_request_join_group:
                    return await self.follower_assemble_group(leader, message)

        if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
            if message.suggested_leader and message.suggested_leader != self.endpoint:
                logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                # Reset leader and cancel the old stream before recursively retrying with the
                # suggested leader; the finally block below will also clean up afterwards.
                self.current_leader = None
                call.cancel()
                return await self.request_join_group(message.suggested_leader, expiration_time)
            else:
                logger.debug(f"{self} - leader disbanded group")
                return None

        logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
        return None
    except asyncio.TimeoutError:
        logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
        if call is not None:
            call.cancel()
        return None
    finally:
        # Always clear follower state on exit (success, rejection, timeout, or redirect);
        # awaiting call.code() drains/finalizes the RPC so the channel is left clean.
        self.was_accepted_to_group.clear()
        self.current_leader = None
        if call is not None:
            await call.code()
def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
    """
    Create a gRPC stub to access remote expert or use previously created stub from a process-wide cache

    :param endpoint: network address of the remote expert's connection handler
    :param extra_options: additional (key, value) gRPC channel options appended after the defaults
    """
    # Unlimited message sizes by default; callers may append further channel options.
    default_options: Tuple[Tuple[str, Any], ...] = (
        ('grpc.max_send_message_length', -1),
        ('grpc.max_receive_message_length', -1),
    )
    combined_options = default_options + extra_options
    return ChannelCache.get_stub(endpoint, runtime_grpc.ConnectionHandlerStub, aio=False,
                                 options=combined_options)
async def _load_state_from_peers(self, future: MPFuture):
    """
    Try to download the latest training state (metadata + tensors) from other averagers.

    Peers are taken from the DHT key ``f"{prefix}.all_averagers"`` and tried in order of
    descending priority; the first successful download is delivered via ``future`` as a
    ``(metadata, tensors)`` tuple. If no peer can serve the state, ``future`` resolves to None.

    :param future: multiprocessing future that receives the result (or None on failure)
    """
    key_manager = self._matchmaking.group_key_manager
    peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
    # Keep only well-formed entries: a ValueWithExpiration whose value is a numeric priority.
    peer_priority = {
        peer: float(info.value)
        for peer, info in peer_priority.items()
        if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
    }

    # Fix: the previous `not isinstance(peer_priority, dict)` check was dead code --
    # peer_priority is always a dict here (it was just built by the comprehension above).
    if not peer_priority:
        logger.info(
            f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}."
        )
        future.set_result(None)
        return

    # NOTE(review): metadata persists across peer attempts; if one peer sends metadata and
    # then fails mid-stream, a later peer's tensors may be paired with the earlier metadata.
    # Preserved as-is to keep behavior unchanged -- confirm whether this is intended.
    metadata = None
    for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
        if peer == self.endpoint:
            continue  # never download from ourselves
        logger.info(f"Downloading parameters from peer {peer}")
        stream = None
        try:
            leader_stub = ChannelCache.get_stub(
                peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
            stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
            current_tensor_parts, tensors = [], []
            async for message in stream:
                if message.metadata:
                    metadata = self.serializer.loads(message.metadata)
                if message.tensor_part.dtype and current_tensor_parts:
                    # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                    tensors.append(
                        deserialize_torch_tensor(
                            combine_from_streaming(current_tensor_parts)))
                    current_tensor_parts = []
                current_tensor_parts.append(message.tensor_part)
            # Flush the trailing tensor (the stream does not send an explicit terminator).
            if current_tensor_parts:
                tensors.append(
                    deserialize_torch_tensor(
                        combine_from_streaming(current_tensor_parts)))
            future.set_result((metadata, tensors))
            self.last_updated = get_dht_time()
            return
        except grpc.aio.AioRpcError as e:
            # Best-effort: log and try the next peer rather than failing outright.
            logger.info(f"Failed to download state from {peer} - {e}")
        finally:
            if stream is not None:
                await stream.code()  # finalize the RPC so the channel is left clean

    # Fix: was a `for/else` with no `break` -- the else ran unconditionally whenever the
    # loop finished without returning; plain post-loop code is equivalent and clearer.
    logger.warning(
        "Averager could not load state from peers: found no active peers."
    )
    future.set_result(None)