async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id,
                                                   endpoint=self.endpoint, tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    await stream.done_writing()

    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part
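# For reference: a minimal, self-contained sketch of the chunking round trip used above.
# It assumes serialize_torch_tensor / deserialize_torch_tensor and split_for_streaming /
# combine_from_streaming are importable from hivemind.utils (import paths may differ
# between versions); nothing in this sketch is part of the original method.
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import (combine_from_streaming, deserialize_torch_tensor,
                            serialize_torch_tensor, split_for_streaming)

tensor = torch.randn(1024, 1024)
serialized = serialize_torch_tensor(tensor, CompressionType.NONE, allow_inplace=False)

# Cut the serialized tensor into ~64 KiB protobuf chunks, as done before
# writing to the gRPC stream in _communicate_with_peer.
chunks = list(split_for_streaming(serialized, 2 ** 16))

# The receiving side glues the chunks back together and deserializes them.
restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert torch.allclose(restored, tensor)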
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    averaged_part = await self.accumulate_part(source, tensor_part)
    if not self.averaged_part_stream.done():
        serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        self.averaged_part_stream.set_result(stream_chunks)
        return stream_chunks
    else:
        return self.averaged_part_stream.result()
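# The averaged_part_stream future above acts as a one-shot cache: whichever request
# finishes accumulation first also pays the serialization cost, and any later request
# for the same part reuses the already-split chunks. A minimal sketch of that pattern
# (SerializeOnce and its methods are illustrative names, not part of the original class):
import asyncio
from typing import Callable, Sequence, Tuple


class SerializeOnce:
    """Serialize a result at most once and hand the cached chunks to later callers."""

    def __init__(self) -> None:
        self._chunks: asyncio.Future = asyncio.Future()

    def get_chunks(self, serialize: Callable[[], Sequence[bytes]]) -> Tuple[bytes, ...]:
        # There is no await between the done() check and set_result(), so within
        # a single asyncio event loop only the first caller runs serialize().
        if not self._chunks.done():
            chunks = tuple(serialize())
            self._chunks.set_result(chunks)
            return chunks
        return self._chunks.result()


async def _demo() -> None:
    cache = SerializeOnce()
    first = cache.get_chunks(lambda: [b"chunk0", b"chunk1"])
    second = cache.get_chunks(lambda: [b"never serialized"])
    assert first == second == (b"chunk0", b"chunk1")


asyncio.run(_demo())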
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    try:
        tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize streamed tensor part from {source}: {e}") from e

    averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
    serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type,
                                               allow_inplace=False)
    stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
    return stream_chunks
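# In this variant the reply carries the difference (averaged_part - tensor_part) rather
# than the averaged part itself, so the sender must add the delta back onto the part it
# originally transmitted. A hedged sketch of that caller-side reconstruction, mirroring
# _communicate_with_peer; `outputs` and `local_part` are assumed names from the calling
# code, not a verbatim excerpt:
delta = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
averaged_part = local_part + delta  # recover the average from the transmitted difference
# A likely motivation for sending the difference is that any lossy compression then
# acts on the (typically small) residual instead of the full tensor values.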