def handle_add_torch(args):
    """Stream handler: unpack a msgpack list of serialized tensors, sum them, return the serialized sum.

    `args` is a msgpack-encoded list of bytes, each a serialized runtime_pb2.Tensor.
    """
    raw_tensors = MSGPackSerializer.loads(args)

    def _parse(raw: bytes):
        proto = runtime_pb2.Tensor()
        proto.ParseFromString(raw)
        return deserialize_torch_tensor(proto)

    total = _parse(raw_tensors[0])
    for raw in raw_tensors[1:]:
        total = total + _parse(raw)
    return serialize_torch_tensor(total).SerializeToString()
def split_for_streaming(serialized_tensor: runtime_pb2.Tensor, chunk_size_bytes: int) -> Iterator[runtime_pb2.Tensor]:
    """ Split serialized_tensor into multiple chunks for gRPC streaming """
    data = memoryview(serialized_tensor.buffer)  # zero-copy view over the payload
    offsets = range(0, len(data), chunk_size_bytes)
    # The first chunk carries all tensor metadata plus the total chunk count;
    # subsequent chunks carry only raw buffer bytes.
    yield runtime_pb2.Tensor(
        compression=serialized_tensor.compression,
        buffer=data[:chunk_size_bytes].tobytes(),
        chunks=len(offsets),
        size=serialized_tensor.size,
        dtype=serialized_tensor.dtype,
        requires_grad=serialized_tensor.requires_grad,
    )
    for offset in offsets[1:]:
        yield runtime_pb2.Tensor(buffer=data[offset:offset + chunk_size_bytes].tobytes())
def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                           allow_inplace=False) -> runtime_pb2.Tensor:
    """Serialize a CPU torch tensor into a runtime_pb2.Tensor, optionally compressing it.

    :param tensor: a float32 (for compressed modes) CPU tensor to serialize
    :param compression_type: NONE, FLOAT16, or MEANSTD_LAST_AXIS_FLOAT16
        (per-last-axis mean/std normalization followed by fp16 cast)
    :param allow_inplace: if True, the compression branches may mutate `tensor` in place
    :returns: a protobuf message holding the raw bytes plus shape/dtype metadata
    :raises ValueError: if compression_type is not one of the handled values
    """
    assert tensor.device == torch.device('cpu')
    if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
        assert tensor.dtype == torch.float32
        tensor = tensor if allow_inplace else tensor.clone()
        means = torch.mean(tensor, dim=-1, keepdim=True)
        tensor.sub_(means)
        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
        # Bug fix: a constant row has std == 0, and dividing by it produced NaN/inf.
        # Clamp to float32 machine epsilon (same guard as the MEANSTD_16BIT variant).
        stds.clamp_min_(torch.finfo(torch.float32).eps)
        tensor.div_(stds)
        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
        # Payload layout: fp16 normalized values, then fp32 means, then fp32 stds.
        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
        proto = runtime_pb2.Tensor(compression=compression_type, buffer=data, size=tensor.shape,
                                   dtype='compressed_float32', requires_grad=tensor.requires_grad)
    elif compression_type == CompressionType.FLOAT16:
        assert tensor.dtype == torch.float32
        tensor = tensor if allow_inplace else tensor.clone()
        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
        data = tensor.numpy().tobytes()
        proto = runtime_pb2.Tensor(compression=compression_type, buffer=data, size=tensor.shape,
                                   dtype='clamped_float32', requires_grad=tensor.requires_grad)
    elif compression_type == CompressionType.NONE:
        array = tensor.numpy()
        proto = runtime_pb2.Tensor(compression=compression_type, buffer=array.tobytes(), size=array.shape,
                                   dtype=array.dtype.name, requires_grad=tensor.requires_grad)
    else:
        raise ValueError(f"Unknown compression type: {compression_type}")
    return proto
def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
    """Pack an uncompressed CPU torch tensor into a runtime_pb2.Tensor protobuf."""
    as_numpy = tensor.numpy()
    return runtime_pb2.Tensor(
        buffer=as_numpy.tobytes(),
        size=as_numpy.shape,
        dtype=as_numpy.dtype.name,
        requires_grad=tensor.requires_grad,
    )
def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
    """ Restore a result of split_into_chunks into a single serialized tensor """
    chunks = iter(stream)
    head = next(chunks)
    # The first chunk carries the metadata; copy it, then splice all buffers together.
    combined = runtime_pb2.Tensor()
    combined.CopyFrom(head)
    pieces = [head.buffer]
    pieces.extend(part.buffer for part in chunks)
    combined.buffer = b''.join(pieces)
    return combined
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    """End-to-end check: a P2P client calls a remote squaring handler and receives test_input ** 2."""
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    client = await P2P.create(bootstrap=True, bootstrap_peers=bootstrap_from([server]))
    await client.wait_for_at_least_n_peers(1)

    request = serialize_torch_tensor(test_input).SerializeToString()
    response = await client.call_peer_handler(server.id, handler_name, request)

    response_proto = runtime_pb2.Tensor()
    response_proto.ParseFromString(response)
    assert torch.allclose(deserialize_torch_tensor(response_proto), expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()
def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                           allow_inplace=False) -> runtime_pb2.Tensor:
    """Serialize a CPU torch tensor into a runtime_pb2.Tensor using the requested compression scheme.

    :param tensor: a CPU tensor (float32 for every compressed mode)
    :param compression_type: NONE, FLOAT16, MEANSTD_16BIT, QUANTILE_8BIT, or UNIFORM_8BIT
    :param allow_inplace: if True, the fp16-based branches may mutate `tensor` in place
    :returns: protobuf message with the packed payload plus shape/dtype metadata
    :raises ValueError: on an unrecognized compression_type
    """
    assert tensor.device == torch.device('cpu')

    if compression_type == CompressionType.MEANSTD_16BIT:
        assert tensor.dtype == torch.float32
        work = tensor if allow_inplace else tensor.clone()
        means = torch.mean(work, dim=-1, keepdim=True)
        work.sub_(means)
        stds = torch.square(work).sum(dim=-1, keepdim=True).div_(work.shape[-1]).sqrt_()
        stds.clamp_min_(FP32_EPS)  # guard: constant rows would otherwise divide by zero
        work.div_(stds)
        work = work.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
        # Payload layout: fp16 normalized values, then fp32 means, then fp32 stds.
        payload = b''.join((work.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))
        return runtime_pb2.Tensor(compression=compression_type, buffer=payload, size=work.shape,
                                  dtype='compressed_float32', requires_grad=work.requires_grad)

    if compression_type == CompressionType.FLOAT16:
        assert tensor.dtype == torch.float32
        work = tensor if allow_inplace else tensor.clone()
        work = work.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
        return runtime_pb2.Tensor(compression=compression_type, buffer=work.numpy().tobytes(),
                                  size=work.shape, dtype='clamped_float32',
                                  requires_grad=work.requires_grad)

    if compression_type == CompressionType.NONE:
        as_numpy = tensor.numpy()
        return runtime_pb2.Tensor(compression=compression_type, buffer=as_numpy.tobytes(),
                                  size=as_numpy.shape, dtype=as_numpy.dtype.name,
                                  requires_grad=tensor.requires_grad)

    if compression_type in (CompressionType.QUANTILE_8BIT, CompressionType.UNIFORM_8BIT):
        assert tensor.dtype == torch.float32
        if compression_type == CompressionType.QUANTILE_8BIT:
            quantized, lookup = _quantile_encode_approx(tensor.detach(), NUM_BITS_QUANTILE_COMPRESSION)
        elif compression_type == CompressionType.UNIFORM_8BIT:
            quantized, lookup = _uint8_uniform_buckets_encode(tensor.detach(), UNIFORM_BUCKETS_STD_RANGE)
        # Payload layout: lookup table, then uint8 codes.
        payload = b''.join((lookup.numpy().tobytes(), quantized.numpy().astype(np.uint8).tobytes()))
        return runtime_pb2.Tensor(compression=compression_type, buffer=payload, size=tensor.shape,
                                  dtype='compressed_float32', requires_grad=tensor.requires_grad)

    raise ValueError(f"Unknown compression type: {compression_type}")
def handle_square_torch(x):
    """Stream handler: parse a serialized tensor, square it elementwise, return the serialized result."""
    request = runtime_pb2.Tensor()
    request.ParseFromString(x)
    squared = deserialize_torch_tensor(request) ** 2
    return serialize_torch_tensor(squared).SerializeToString()