def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    scalar = torch.tensor(1.)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()
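# A minimal sanity check for handle_add_torch (a sketch, not part of the original tests):
# it assumes MSGPackSerializer.dumps is available for packing the argument list, serializes
# two tensors the same way a caller would, and verifies that the handler returns their
# elementwise sum.
def _check_handle_add_torch():
    first, second = torch.randn(3, 3), torch.randn(3, 3)
    packed_args = MSGPackSerializer.dumps(
        [serialize_torch_tensor(t).SerializeToString() for t in (first, second)])
    response = runtime_pb2.Tensor()
    response.ParseFromString(handle_add_torch(packed_args))
    assert torch.allclose(deserialize_torch_tensor(response), first + second)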
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)

    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
    assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
    if peer_endpoint == self.endpoint:
        return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])

    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                                   group_id=self.group_id, endpoint=self.endpoint,
                                                   tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    await stream.done_writing()

    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    try:
        averaged_part = local_part + deserialize_torch_tensor(
            combine_from_streaming([message.tensor_part for message in outputs]))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")

    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part
async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
    serialized_response = [
        serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
        for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
    ]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor]]:
    if task.exception() or task.cancelled():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

    deserialized_outputs = []
    for tensor in task.result().tensors:
        deserialized_tensor = deserialize_torch_tensor(tensor)
        if detect_anomalies and not deserialized_tensor.isfinite().all():
            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
            return None
        deserialized_outputs.append(deserialized_tensor)

    return tuple(deserialized_outputs)
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    try:
        tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming: {e}")

    averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
    serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type,
                                               allow_inplace=False)
    stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
    return stream_chunks
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
async def _load_state_from_peers(self, future: MPFuture):
    try:
        key_manager = self._matchmaking.group_key_manager
        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
        peer_priority = {
            peer: float(info.value) for peer, info in peer_priority.items()
            if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
        }

        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
            logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
            future.set_result(None)
            return

        metadata = None
        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
            if peer != self.endpoint:
                logger.info(f"Downloading parameters from peer {peer}")
                stream = None
                try:
                    stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                    stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                    current_tensor_parts, tensors = [], []
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    if current_tensor_parts:
                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                    if not metadata:
                        logger.debug(f"Peer {peer} did not send its state.")
                        continue

                    logger.info(f"Finished downloading state from {peer}")
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except BaseException as e:
                    logger.exception(f"Failed to download state from {peer} - {repr(e)}")
                finally:
                    if stream is not None:
                        await stream.code()

    finally:
        if not future.done():
            logger.warning("Averager could not load state from peers: all requests have failed.")
            future.set_result(None)
def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor ** 2
    return serialize_torch_tensor(result).SerializeToString()
def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
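# A minimal driver sketch (an assumption, not part of the original file): it reuses
# benchmark_compression above and the protobuf CompressionType enum to time a full
# serialize/deserialize round trip for every available codec on one random tensor.
if __name__ == "__main__":
    tensor = torch.randn(1000, 1000)
    for name, compression_type in CompressionType.items():
        print(f"{name}: {benchmark_compression(tensor, compression_type):.5f} s")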