Example #1
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        # the first message carries the group metadata together with the first tensor chunk;
        # subsequent messages carry only the remaining chunks
        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        # collect the peer's reply: the averaged tensor part, streamed back as chunks
        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        averaged_part = deserialize_torch_tensor(
            combine_from_streaming(
                [message.tensor_part for message in outputs]))
        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
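
Example #1 relies on hivemind's streaming serialization helpers: serialize the tensor once, split the result into bounded-size protobuf chunks, and reassemble them on the other side. Below is a minimal round-trip sketch of that helper chain; the import paths are an assumption and may differ between hivemind versions.

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.grpc import (serialize_torch_tensor, deserialize_torch_tensor,
                                 split_for_streaming, combine_from_streaming)

tensor = torch.randn(1024)
# serialize once, then split into chunks small enough for streaming RPC messages
serialized = serialize_torch_tensor(tensor, CompressionType.NONE, allow_inplace=False)
chunks = list(split_for_streaming(serialized, 2 ** 16))
# the receiver reassembles the chunks and deserializes them back into a tensor
restored = deserialize_torch_tensor(combine_from_streaming(iter(chunks)))
assert torch.allclose(tensor, restored)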
Example #2
    async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                        ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
        averaged_part = await self.accumulate_part(source, tensor_part)
        if not self.averaged_part_stream.done():
            serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
            stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
            self.averaged_part_stream.set_result(stream_chunks)
            return stream_chunks
        else:
            return self.averaged_part_stream.result()
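
In Example #2, `self.averaged_part_stream` is presumably an asyncio future used as a one-shot cache: the first caller serializes the averaged part into chunks and stores them, and every later or concurrent caller for the same part reuses that result instead of serializing again. A minimal sketch of that pattern with illustrative names, not hivemind API:

import asyncio

class ChunkCache:
    """ one-shot cache: the first caller stores the serialized chunks, later callers reuse them """

    def __init__(self):
        # like averaged_part_stream, the future is assumed to be created inside a running event loop
        self._chunks_future = asyncio.Future()

    def get_or_set(self, make_chunks):
        if not self._chunks_future.done():
            chunks = tuple(make_chunks())  # serialize and split exactly once
            self._chunks_future.set_result(chunks)
            return chunks
        return self._chunks_future.result()


async def demo():
    cache = ChunkCache()
    first = cache.get_or_set(lambda: [b"chunk-0", b"chunk-1"])
    second = cache.get_or_set(lambda: [b"recomputed"])  # ignored: the cached chunks are returned
    assert first == second

asyncio.run(demo())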
Example #3
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        try:
            tensor_part = deserialize_torch_tensor(
                combine_from_streaming(stream_messages))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize tensor part from {source} for streaming: {e}"
            ) from e

        averaged_part = await self.accumulate_part(
            source, tensor_part, weight=self.peer_weights[source])
        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part,
                                                   self.compression_type,
                                                   allow_inplace=False)
        stream_chunks = tuple(
            split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        return stream_chunks
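
Unlike Example #2, this variant streams back the compressed difference `averaged_part - tensor_part` rather than the averaged tensor itself, so the peer that sent the part presumably reconstructs the average by adding the received delta onto its local copy. A hypothetical caller-side helper under that assumption (the import path may differ between hivemind versions):

from hivemind.utils.grpc import combine_from_streaming, deserialize_torch_tensor

def reconstruct_averaged_part(local_part, received_chunks):
    """ add the streamed delta back onto the tensor part this caller originally sent """
    delta = deserialize_torch_tensor(combine_from_streaming(received_chunks))
    return local_part + delta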