Example #1
def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored),
                          tensor,
                          rtol=0,
                          atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    scalar = torch.tensor(1.)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
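For orientation, the round trip this test exercises can be sketched as follows (a minimal sketch; the import locations are assumptions and may differ between hivemind versions):

import torch
import hivemind
from hivemind.proto.runtime_pb2 import CompressionType  # assumed import path
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor  # assumed import path

# serialize a tensor into a protobuf message, optionally compressing it
tensor = torch.randn(512, 512)
serialized = serialize_torch_tensor(tensor, CompressionType.NONE)

# split the message into chunks small enough for streaming, then reassemble
chunks = list(hivemind.split_for_streaming(serialized, 16 * 1024))
restored = hivemind.combine_from_streaming(chunks)
assert torch.allclose(deserialize_torch_tensor(restored), tensor)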
Example #2
def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()
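The payload handle_add_torch expects is a MessagePack-packed list of serialized runtime_pb2.Tensor messages; a caller-side sketch (import locations are assumptions and may differ between hivemind versions):

import torch
from hivemind.proto import runtime_pb2
from hivemind.utils import MSGPackSerializer, serialize_torch_tensor, deserialize_torch_tensor  # assumed import path

# pack two tensors in the format the handler parses with MSGPackSerializer.loads
tensors = [torch.ones(3), torch.full((3,), 2.0)]
payload = MSGPackSerializer.dumps(
    [serialize_torch_tensor(t).SerializeToString() for t in tensors])

# the handler returns the elementwise sum as a serialized runtime_pb2.Tensor
response = runtime_pb2.Tensor()
response.ParseFromString(handle_add_torch(payload))
assert torch.allclose(deserialize_torch_tensor(response), torch.full((3,), 3.0))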
Example #3
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor,
                                                    allow_inplace=False)
    chunks1 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(
        np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(
        np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor,
                                                    CompressionType.FLOAT16,
                                                    allow_inplace=False)
    chunks4 = list(
        hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor,
                              deserialize_torch_tensor(combined),
                              rtol=1e-5,
                              atol=1e-8)

    assert torch.allclose(tensor,
                          deserialize_torch_tensor(combined4),
                          rtol=1e-3,
                          atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
Example #4
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, \
            "Auxiliary peers are disallowed from sending tensors"
        if peer_endpoint == self.endpoint:
            return await self.accumulate_part(
                self.endpoint,
                local_part,
                weight=self.peer_weights[self.endpoint])
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        try:
            averaged_part = local_part + deserialize_torch_tensor(
                combine_from_streaming(
                    [message.tensor_part for message in outputs]))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize averaged part from {peer_endpoint}: {e}"
            )

        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
Example #5
    async def backward(self, request: runtime_pb2.ExpertRequest,
                       context: grpc.ServicerContext):
        inputs_and_grad_outputs = [
            deserialize_torch_tensor(tensor) for tensor in request.tensors
        ]
        future = self.experts[request.uid].backward_pool.submit_task(
            *inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor,
                                   proto.compression,
                                   allow_inplace=True)
            for tensor, proto in zip(
                await future,
                nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #6
def _process_dispatched_task(
        task: grpc.Future,
        detect_anomalies: bool) -> Optional[Tuple[torch.Tensor, ...]]:
    if task.exception() or task.cancelled():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

    deserialized_outputs = []
    for tensor in task.result().tensors:
        deserialized_tensor = deserialize_torch_tensor(tensor)
        if detect_anomalies and not deserialized_tensor.isfinite().all():
            logger.error(
                f"Task {task} failed: output tensor contains nan/inf values")
            return None
        deserialized_outputs.append(deserialized_tensor)

    return tuple(deserialized_outputs)
Example #7
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        try:
            tensor_part = deserialize_torch_tensor(
                combine_from_streaming(stream_messages))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize tensor part from {source} for streaming {e}"
            )

        averaged_part = await self.accumulate_part(
            source, tensor_part, weight=self.peer_weights[source])
        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part,
                                                   self.compression_type,
                                                   allow_inplace=False)
        stream_chunks = tuple(
            split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        return stream_chunks
Example #8
async def test_call_peer_torch_square(test_input,
                                      expected,
                                      handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)

    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()
Example #9
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(
            serialize_torch_tensor(zeros, compression_type)).isfinite().all()
Example #10
    async def _load_state_from_peers(self, future: MPFuture):
        try:
            key_manager = self._matchmaking.group_key_manager
            peer_priority, _ = self.dht.get(
                f"{key_manager.prefix}.all_averagers",
                latest=True) or ({}, None)
            peer_priority = {
                peer: float(info.value)
                for peer, info in peer_priority.items()
                if isinstance(info, ValueWithExpiration)
                and isinstance(info.value, (float, int))
            }

            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                logger.info(
                    f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}."
                )
                future.set_result(None)
                return

            metadata = None
            for peer in sorted(peer_priority.keys(),
                               key=peer_priority.get,
                               reverse=True):
                if peer != self.endpoint:
                    logger.info(f"Downloading parameters from peer {peer}")
                    stream = None
                    try:
                        stub = ChannelCache.get_stub(
                            peer,
                            averaging_pb2_grpc.DecentralizedAveragingStub,
                            aio=True)
                        stream = stub.rpc_download_state(
                            averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []
                        async for message in stream:
                            if message.metadata:
                                metadata = self.serializer.loads(
                                    message.metadata)
                            if message.tensor_part.dtype and current_tensor_parts:
                                # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                                tensors.append(
                                    deserialize_torch_tensor(
                                        combine_from_streaming(
                                            current_tensor_parts)))
                                current_tensor_parts = []
                            current_tensor_parts.append(message.tensor_part)
                        if current_tensor_parts:
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(
                                        current_tensor_parts)))

                        if not metadata:
                            logger.debug(
                                f"Peer {peer} did not send its state.")
                            continue

                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
                        self.last_updated = get_dht_time()
                        return
                    except BaseException as e:
                        logger.exception(
                            f"Failed to download state from {peer} - {repr(e)}"
                        )
                    finally:
                        if stream is not None:
                            await stream.code()

        finally:
            if not future.done():
                logger.warning(
                    "Averager could not load state from peers: all requests have failed."
                )
                future.set_result(None)
Example #11
def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor**2
    return serialize_torch_tensor(result).SerializeToString()


def benchmark_compression(tensor: torch.Tensor,
                          compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
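A possible way to use benchmark_compression (a sketch; CompressionType is assumed to be the protobuf enum used in the examples above, so .items() yields (name, value) pairs):

# time each compression type on a large random tensor
X = torch.randn(1000, 1000)
for name, value in CompressionType.items():
    print(f"{name}: {benchmark_compression(X, value) * 1000:.1f} ms")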