def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
    torch.manual_seed(0)
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    # the first message carries the metadata (message code, group id, sender endpoint) along with the first chunk
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                                   group_id=self.group_id, endpoint=self.endpoint,
                                                   tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    await stream.done_writing()

    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part

def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta

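# Usage sketch (not part of the test suite): the same serialize/deserialize round-trip can be used to
# inspect the mean squared error of each compression mode on an arbitrary tensor. The function name is
# hypothetical; it assumes the same imports as the tests above (torch, CompressionType,
# serialize_torch_tensor, deserialize_torch_tensor).
def compression_mse_sketch(tensor: torch.Tensor):
    for compression in (CompressionType.NONE, CompressionType.FLOAT16,
                        CompressionType.MEANSTD_LAST_AXIS_FLOAT16, CompressionType.QUANTILE_8BIT):
        restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, compression))
        print(CompressionType.Name(compression), (restored - tensor).square().mean().item())
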
def _collect_responses(task_to_indices: Dict[grpc.Future, Tuple[int, int]], num_samples: int, k_min: int,
                       timeout_total: Optional[float], timeout_after_k_min: Optional[float]
                       ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
    """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
    timeout_total = float('inf') if timeout_total is None else timeout_total
    timeout_after_k_min = float('inf') if timeout_after_k_min is None else timeout_after_k_min
    num_successful_tasks = [0 for _ in range(num_samples)]
    pending_samples = num_samples  # samples for which we have fewer than k_min results
    finished_indices, finished_outputs = [], []
    t_finish = time.perf_counter() + timeout_total
    pending_tasks = set(task_to_indices.keys())
    finished_tasks = Queue()

    try:  # the algorithm below is essentially futures.as_completed, but for grpc.Future
        for task in pending_tasks:
            task.add_done_callback(finished_tasks.put)

        for _ in range(len(task_to_indices)):
            timeout = max(0.0, t_finish - time.perf_counter()) if t_finish != float('inf') else None
            task = finished_tasks.get(timeout=timeout)
            pending_tasks.discard(task)

            if task.exception() or task.cancelled():
                logger.warning(f"Task {task} failed: {type(task.exception())}")
                continue

            finished_indices.append(task_to_indices[task])
            finished_outputs.append(tuple(deserialize_torch_tensor(tensor) for tensor in task.result().tensors))

            # count how many successes we have for each input sample
            sample_index = task_to_indices[task][0]
            num_successful_tasks[sample_index] += 1
            if num_successful_tasks[sample_index] == k_min:
                pending_samples -= 1
                if pending_samples <= 0:  # every sample has >= k_min results, await stragglers for at most timeout_after_k_min
                    t_finish = min(t_finish, time.perf_counter() + timeout_after_k_min)

    except Empty:
        pass  # we reached t_finish, this is normal behavior
    finally:
        for task in pending_tasks:
            task.cancel()
    return finished_indices, finished_outputs

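# Minimal, self-contained sketch of the callback-plus-queue pattern used above, with
# concurrent.futures.Future standing in for grpc.Future (both expose add_done_callback).
# The names and timing values here are illustrative, not part of hivemind.
import time
from queue import Queue, Empty
from concurrent.futures import ThreadPoolExecutor

def as_completed_sketch(futures, timeout=None):
    finished = Queue()
    for fut in futures:
        fut.add_done_callback(finished.put)  # the queue receives futures in completion order
    deadline = None if timeout is None else time.perf_counter() + timeout
    for _ in range(len(futures)):
        remaining = None if deadline is None else max(0.0, deadline - time.perf_counter())
        try:
            yield finished.get(timeout=remaining)
        except Empty:
            return  # deadline reached: stop yielding, the caller cancels stragglers

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(time.sleep, delay) for delay in (0.1, 0.2, 0.3)]
    for fut in as_completed_sketch(futures, timeout=0.25):
        print(fut)
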
async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].forward_pool.submit_task(*inputs)
    serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)

async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    averaged_part = await self.accumulate_part(source, tensor_part)
    if not self.averaged_part_stream.done():
        serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        self.averaged_part_stream.set_result(stream_chunks)
        return stream_chunks
    else:
        return self.averaged_part_stream.result()

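# Sketch of the chunking round-trip these handlers rely on: a serialized tensor is split into
# fixed-size protobuf chunks and reassembled on the receiving side. It assumes split_for_streaming and
# combine_from_streaming from hivemind.utils, as used above; the chunk size is chosen arbitrarily here.
original = torch.randn(1024)
serialized = serialize_torch_tensor(original, CompressionType.NONE)
chunks = list(split_for_streaming(serialized, 16 * 1024))
restored = deserialize_torch_tensor(combine_from_streaming(iter(chunks)))
assert torch.allclose(restored, original)
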
def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
    if task.exception() or task.cancelled():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

    deserialized_outputs = []
    for tensor in task.result().tensors:
        deserialized_tensor = deserialize_torch_tensor(tensor)
        if detect_anomalies and not deserialized_tensor.isfinite().all():
            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
            return None
        deserialized_outputs.append(deserialized_tensor)

    return tuple(deserialized_outputs)

async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
    serialized_response = [
        serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
        for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
    ]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)

async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
    stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
    try:
        outputs = await stub.forward(runtime_pb2.ExpertRequest(
            uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
        return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")

async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    try:
        tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming: {e}")

    averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
    if not self.averaged_part_stream.done():
        serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        self.averaged_part_stream.set_result(stream_chunks)
        return stream_chunks
    else:
        return self.averaged_part_stream.result()

def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t

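# Possible driver for the helper above (illustrative only): time the serialize/deserialize round-trip
# for each compression mode on one tensor. The tensor shape and the set of modes are assumptions.
tensor = torch.randn(1000, 1000)
for compression in (CompressionType.NONE, CompressionType.FLOAT16,
                    CompressionType.MEANSTD_LAST_AXIS_FLOAT16, CompressionType.QUANTILE_8BIT):
    print(CompressionType.Name(compression), benchmark_compression(tensor, compression))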