Example 1
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        # serialize the local tensor part, then split the protobuf into chunks that fit the streaming limit
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        # open a bidirectional stream; the first message carries the metadata
        # (message code, group id, sender endpoint) alongside the first chunk
        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        # subsequent messages carry only the remaining tensor chunks
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        # drain the response stream; the first reply carries the result code
        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        averaged_part = deserialize_torch_tensor(
            combine_from_streaming(
                [message.tensor_part for message in outputs]))
        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
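
For reference, the chunked serialization round trip this example relies on can be exercised in isolation. A minimal sketch, assuming hivemind's utilities (import paths vary across hivemind versions, so treat them as approximate):

import torch
from hivemind.proto import runtime_pb2
from hivemind.utils.grpc import (serialize_torch_tensor, deserialize_torch_tensor,
                                 split_for_streaming, combine_from_streaming)

tensor = torch.randn(1024)
# serialize, then split the protobuf into chunks small enough for streaming
serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.NONE, allow_inplace=False)
chunks = list(split_for_streaming(serialized, chunk_size_bytes=16 * 1024))
# the receiver reassembles the chunks and deserializes; with NONE compression the round trip is exact
restored = deserialize_torch_tensor(combine_from_streaming(iter(chunks)))
assert torch.allclose(tensor, restored)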
Example 2
    async def _send_error_to_peer(self, peer_endpoint: Endpoint,
                                  code: averaging_pb2.MessageCode):
        """ notify a peer that allreduce cannot proceed by sending a single message with an error code """
        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        code=code))
        await stream.done_writing()
Example 3
    async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData,
                                 context: grpc.ServicerContext):
        """ a groupmate sends us a part of their tensor; we should average it with other peers and return the result """
        if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
            # this handles a special case when the leader accepted us into the group AND began allreduce
            # right away, but its response with the group_id was delayed and other peers reached us first
            await self._pending_group_assembled.wait()
        if request.group_id not in self._running_groups:
            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        else:
            return await self._running_groups[request.group_id].rpc_aggregate_part(request, context)
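
The wait-for-assembly logic above hinges on `_pending_group_assembled` behaving like an `asyncio.Event`. A self-contained sketch of that handoff (names like `matchmaker` and `group123` are illustrative, not part of hivemind):

import asyncio

async def matchmaker(assembled: asyncio.Event, running_groups: dict):
    await asyncio.sleep(0.1)               # the leader's delayed response
    running_groups["group123"] = object()  # register the newly assembled group
    assembled.set()                        # release any peers that raced ahead

async def early_peer(assembled: asyncio.Event, running_groups: dict):
    if "group123" not in running_groups and not assembled.is_set():
        await assembled.wait()             # block until group assembly finishes
    assert "group123" in running_groups

async def main():
    assembled, groups = asyncio.Event(), {}
    await asyncio.gather(matchmaker(assembled, groups), early_peer(assembled, groups))

asyncio.run(main())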
Example 4
    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData],
        context: grpc.ServicerContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
        request: averaging_pb2.AveragingData = await anext(stream)

        if request.group_id != self.group_id:
            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
                # gather the remaining chunks, then accumulate this tensor part with the other peers' parts
                tensor_chunks = (request.tensor_part,
                                 *[msg.tensor_part async for msg in stream])
                averaged_chunks = iter(await self.accumulate_part_streaming(
                    request.endpoint, tensor_chunks))
                yield averaging_pb2.AveragingData(
                    code=averaging_pb2.AVERAGED_PART,
                    tensor_part=next(averaged_chunks))
                for averaged_chunk in averaged_chunks:
                    yield averaging_pb2.AveragingData(
                        tensor_part=averaged_chunk)

            except Exception as e:
                self.set_exception(e)
                yield averaging_pb2.AveragingData(
                    code=averaging_pb2.INTERNAL_ERROR)
        else:
            error_code = averaging_pb2.MessageCode.Name(request.code)
            logger.debug(
                f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue"
            )
            self.set_exception(
                AllreduceException(
                    f"peer {request.endpoint} sent {error_code}."))
            yield averaging_pb2.AveragingData(
                code=averaging_pb2.INTERNAL_ERROR)
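
`accumulate_part_streaming` is not shown in these examples. A plausible stand-in, assuming it recombines the incoming chunks, hands the tensor to an accumulator that waits for all groupmates, and streams the average back in chunks; the names and details below are hypothetical, not hivemind's actual implementation:

    async def accumulate_part_streaming(self, source: Endpoint,
                                        stream_messages: Iterable[runtime_pb2.Tensor]
                                        ) -> Sequence[runtime_pb2.Tensor]:
        """ hypothetical sketch: reassemble chunks, average with other peers, return averaged chunks """
        tensor_part = deserialize_torch_tensor(combine_from_streaming(iter(stream_messages)))
        # accumulate_part is assumed to block until every groupmate's part has arrived
        averaged_part = await self.accumulate_part(source, tensor_part)
        serialized = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        return list(split_for_streaming(serialized, self.chunk_size_bytes))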
Example 5
    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData],
        context: grpc.ServicerContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """ a groupmate sends us a part of their tensor; we should average it with other peers and return the result """
        request = await anext(stream)
        if request.group_id not in self._running_groups:
            # this handles a special case when the leader accepted us into the group AND began allreduce
            # right away, but its response with the group_id was delayed and other peers reached us first
            await self._pending_group_assembled.wait()

        group = self._running_groups.get(request.group_id)
        if group is None:
            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
            return

        # re-prepend the already-consumed first message so the group's handler sees the complete stream
        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
            yield message
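
`anext`, `aiter`, and `achain` above are async-iterator helpers; hivemind ships equivalents in hivemind.utils.asyncio, and Python 3.10+ provides `anext` and `aiter` as builtins. The stand-ins below are for illustration only. `achain(aiter(request), stream)` re-prepends the metadata message that `anext` already consumed, so the group's handler receives the stream intact:

from typing import AsyncIterable, AsyncIterator, TypeVar

T = TypeVar("T")

async def aiter(*items: T) -> AsyncIterator[T]:
    """ wrap one or more plain values into an async iterator """
    for item in items:
        yield item

async def achain(*iterables: AsyncIterable[T]) -> AsyncIterator[T]:
    """ chain several async iterables, exhausting each in turn """
    for iterable in iterables:
        async for item in iterable:
            yield item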