async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """
    Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors

    :param peer_endpoint: the groupmate that is responsible for averaging this part of the tensor
    :param local_part: this peer's slice of the concatenated tensors, to be averaged by the remote peer
    :returns: the averaged tensor part received back from the peer
    :raises AllreduceException: if the peer responds with any code other than AVERAGED_PART
        (or with no messages at all, which is treated as INTERNAL_ERROR)
    """
    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    # Protocol: the first message carries the routing metadata (code, group id, sender endpoint)
    # together with the first chunk; every subsequent message carries only a bare tensor chunk.
    # NOTE(review): next(chunks) assumes split_for_streaming yields at least one chunk — confirm.
    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                                   group_id=self.group_id, endpoint=self.endpoint,
                                                   tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    # Half-close our side of the bidirectional stream so the peer knows all chunks were sent.
    await stream.done_writing()

    # Collect the peer's entire response; the first reply message carries the status code.
    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    # Reassemble the streamed chunks into one serialized tensor and decode it.
    averaged_part = deserialize_torch_tensor(combine_from_streaming(
        [message.tensor_part for message in outputs]))
    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part
async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
    """Open a one-off aggregation stream to ``peer_endpoint``, send a single message with the
    given error code (tagged with our group id and endpoint), and close the write side."""
    error_message = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code)
    call = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    await call.write(error_message)
    await call.done_writing()
async def rpc_aggregate_part(self, request: averaging_pb2.AveragingData, context: grpc.ServicerContext):
    """Dispatch a groupmate's tensor part to the running allreduce for its group and return the result."""
    if request.group_id not in self._running_groups and not self._pending_group_assembled.is_set():
        # Special case: the leader admitted us to a group AND began allreduce right away, but its
        # response carrying the group id was delayed, so another peer reached us first.
        # Wait for group assembly to finish before concluding that the group id is unknown.
        await self._pending_group_assembled.wait()

    group = self._running_groups.get(request.group_id)
    if group is None:
        return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
    return await group.rpc_aggregate_part(request, context)
async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
) -> AsyncIterator[averaging_pb2.AveragingData]:
    """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result

    The first incoming message carries the metadata (code, group id, sender endpoint) and the first
    tensor chunk; the remaining messages carry bare chunks. The response mirrors that shape: the
    first reply carries the status code and first averaged chunk, followed by bare averaged chunks.
    """
    # NOTE(review): anext is a builtin only in Python 3.10+; presumably a project helper here — confirm.
    request: averaging_pb2.AveragingData = await anext(stream)
    if request.group_id != self.group_id:
        yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
    elif request.code == averaging_pb2.PART_FOR_AVERAGING:
        try:
            # Drain the rest of the stream: the sender's part is (first chunk, *remaining chunks).
            tensor_chunks = (request.tensor_part, *[msg.tensor_part async for msg in stream])
            averaged_chunks = iter(await self.accumulate_part_streaming(request.endpoint, tensor_chunks))
            # First reply carries the success code together with the first averaged chunk.
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART,
                                              tensor_part=next(averaged_chunks))
            for averaged_chunk in averaged_chunks:
                yield averaging_pb2.AveragingData(tensor_part=averaged_chunk)
        except Exception as e:
            # Record the failure locally so the whole allreduce aborts, then notify the sender.
            self.set_exception(e)
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
    else:
        # The peer sent an error/status code instead of a tensor part: propagate the failure.
        error_code = averaging_pb2.MessageCode.Name(request.code)
        logger.debug(f"{self} - peer {request.endpoint} sent {error_code}, allreduce cannot continue")
        self.set_exception(AllreduceException(f"peer {request.endpoint} sent {error_code}."))
        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData],
                             context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.AveragingData]:
    """Route an incoming tensor-part stream to the allreduce runner for its group and relay the replies."""
    first_message = await anext(stream)

    if first_message.group_id not in self._running_groups:
        # The leader may have admitted us to a group AND started allreduce right away, while its
        # response carrying the group id was delayed — so a groupmate can reach us first.
        # Wait until group assembly finishes before concluding the group id is unknown.
        await self._pending_group_assembled.wait()

    matching_group = self._running_groups.get(first_message.group_id)
    if matching_group is None:
        yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        return

    # Re-attach the already-consumed first message in front of the remaining stream and delegate.
    async for reply in matching_group.rpc_aggregate_part(achain(aiter(first_message), stream), context):
        yield reply