def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored),
                          tensor,
                          rtol=0,
                          atol=1e-2)

    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    scalar = torch.tensor(1.)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
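The serialize/split/combine/deserialize round trip exercised above can be condensed into one helper. A minimal sketch, assuming the same hivemind utilities used in the test (roundtrip_streaming is a hypothetical name, not a hivemind function):

def roundtrip_streaming(tensor, compression, chunk_size):
    # serialize, split into chunks of at most chunk_size bytes, recombine, deserialize
    serialized = serialize_torch_tensor(tensor, compression)
    chunks = list(hivemind.split_for_streaming(serialized, chunk_size))
    # the chunk count is the ceiling of len(buffer) / chunk_size
    assert len(chunks) == (len(serialized.buffer) - 1) // chunk_size + 1
    return deserialize_torch_tensor(hivemind.combine_from_streaming(chunks))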
    async def rpc_download_state(
        self, request: averaging_pb2.DownloadRequest,
        context: grpc.ServicerContext
    ) -> AsyncIterator[averaging_pb2.DownloadData]:
        """
        Get the up-to-date trainer state from a peer.
        The state consists of two parts: (serialized_metadata, tensors)

         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
        if not self.allow_state_sharing:
            return  # deny request and direct peer to the next prospective averager
        chunk_size_bytes = self.matchmaking_kwargs.get(
            'chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
        metadata, tensors = await self._get_current_state_from_host_process()

        for tensor in tensors:
            for part in split_for_streaming(serialize_torch_tensor(tensor),
                                            chunk_size_bytes):
                if metadata is not None:
                    yield averaging_pb2.DownloadData(tensor_part=part,
                                                     metadata=metadata)
                    metadata = None
                else:
                    yield averaging_pb2.DownloadData(tensor_part=part)
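To illustrate the receiving side of this protocol, here is a rough, hypothetical sketch for the single-tensor case, assuming grpc.aio conventions for the stub call (download_single_tensor_sketch is not a hivemind function); a real downloader must additionally detect where one tensor ends and the next begins:

async def download_single_tensor_sketch(peer_stub):
    # collect the metadata blob (attached to the first chunk only) and every tensor chunk,
    # then reassemble the tensor with combine_from_streaming
    metadata, parts = None, []
    async for message in peer_stub.rpc_download_state(averaging_pb2.DownloadRequest()):
        if message.metadata:
            metadata = message.metadata
        parts.append(message.tensor_part)
    return metadata, deserialize_torch_tensor(combine_from_streaming(parts))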
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor,
                                                    allow_inplace=False)
    chunks1 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(
        np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(
        np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(
        hivemind.utils.split_for_streaming(serialized_tensor_part, 10**9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor,
                                                    CompressionType.FLOAT16,
                                                    allow_inplace=False)
    chunks4 = list(
        hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)
    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor,
                              deserialize_torch_tensor(combined),
                              rtol=1e-5,
                              atol=1e-8)

    assert torch.allclose(tensor,
                          deserialize_torch_tensor(combined4),
                          rtol=1e-3,
                          atol=1e-3)

    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()
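For context, a minimal local round trip through handle_add_torch, with no networking involved (the tensor values are only for illustration):

tensors = [torch.ones(2, 3), torch.full((2, 3), 2.0)]
request = MSGPackSerializer.dumps(
    [serialize_torch_tensor(t).SerializeToString() for t in tensors])
response = runtime_pb2.Tensor()
response.ParseFromString(handle_add_torch(request))
assert torch.allclose(deserialize_torch_tensor(response), torch.full((2, 3), 3.0))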
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, \
            "Auxiliary peers are disallowed from sending tensors"
        if peer_endpoint == self.endpoint:
            return await self.accumulate_part(
                self.endpoint,
                local_part,
                weight=self.peer_weights[self.endpoint])
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        try:
            averaged_part = local_part + deserialize_torch_tensor(
                combine_from_streaming(
                    [message.tensor_part for message in outputs]))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize averaged part from {peer_endpoint}: {e}"
            )

        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
    async def backward(self, request: runtime_pb2.ExpertRequest,
                       context: grpc.ServicerContext):
        inputs_and_grad_outputs = [
            deserialize_torch_tensor(tensor) for tensor in request.tensors
        ]
        future = self.experts[request.uid].backward_pool.submit_task(
            *inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
            for tensor, proto in zip(
                await future,
                nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        try:
            tensor_part = deserialize_torch_tensor(
                combine_from_streaming(stream_messages))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize tensor part from {source} for streaming: {e}"
            )

        averaged_part = await self.accumulate_part(
            source, tensor_part, weight=self.peer_weights[source])
        serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part,
                                                   self.compression_type,
                                                   allow_inplace=False)
        stream_chunks = tuple(
            split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        return stream_chunks
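The snippet below is purely hypothetical glue (handle_aggregation_stream is not part of the library) showing how a stream of AveragingData messages could be unpacked, handed to accumulate_part_streaming, and wrapped back into response messages:

    async def handle_aggregation_stream(self, source: Endpoint,
                                        messages: Iterable[averaging_pb2.AveragingData]):
        # extract the serialized tensor chunks and pass them to accumulate_part_streaming
        tensor_chunks = [message.tensor_part for message in messages]
        averaged_chunks = await self.accumulate_part_streaming(source, tensor_chunks)
        # wrap each outgoing chunk in an AveragingData message marked as an averaged part
        return [averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART,
                                            tensor_part=chunk)
                for chunk in averaged_chunks]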
async def test_call_peer_torch_square(test_input,
                                      expected,
                                      handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)

    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()
async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = bootstrap_from([server])
    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)

    await client.wait_for_at_least_n_peers(1)

    inp = [
        serialize_torch_tensor(i).SerializeToString()
        for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]
    ]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b'something went wrong :('

    await server.stop_listening()
    await server_primary.shutdown()
    await client_primary.shutdown()
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(
            serialize_torch_tensor(zeros, compression_type)).isfinite().all()
def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor**2
    return serialize_torch_tensor(result).SerializeToString()
def benchmark_compression(tensor: torch.Tensor,
                          compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
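A small usage sketch (not part of the benchmark itself) that loops the helper above over every compression type; CompressionType.Name follows the standard protobuf enum API:

tensor = torch.randn(1024, 1024)
for compression_type in CompressionType.values():
    elapsed = benchmark_compression(tensor, compression_type)
    print(f"{CompressionType.Name(compression_type)}: {elapsed:.4f} s")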
    def backward(
            cls, ctx,
            *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        (info, backward_k_min, backward_timeout, timeout_after_k_min,
         expert_per_sample, detect_anomalies) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads

        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0)
                                  for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0)
                                        for tensor in flat_grad_outputs_cpu))
        backward_schema = tuple(
            nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(),
                                                    alive_jj.cpu().numpy(),
                                                    inputs_per_expert,
                                                    grad_outputs_per_expert):
            expert = expert_per_sample[i.item()][j.item()]
            stub = _get_expert_stub(expert.endpoint)
            inputs_and_grad_outputs = tuple(
                nested_flatten((inputs_ij, grad_outputs_ij)))
            tensors_serialized = [
                serialize_torch_tensor(tensor, proto.compression)
                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
            ]
            new_task = stub.backward.future(
                runtime_pb2.ExpertRequest(uid=expert.uid,
                                          tensors=tensors_serialized))
            pending_tasks[new_task] = (i, j)

        survivor_inds, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout,
            timeout_after_k_min, detect_anomalies)
        if len(survivor_inds) < backward_k_min:
            raise TimeoutError(
                f"Backward pass: less than {backward_k_min} experts responded within timeout."
            )

        # assemble responses
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, dtype=torch.long),
            list(zip(*survivor_inds)) or ([], []))

        survivor_grad_inputs_stacked = (torch.cat(grad_inputs)
                                        for grad_inputs in zip(
                                            *survivor_grad_inputs))
        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

        grad_inputs = nested_map(
            lambda descr: descr.make_empty(
                num_samples, device=flat_grad_outputs[0].device).zero_(),
            list(nested_flatten(info['forward_schema'])))

        for grad_input, survivor_grad_stacked in zip(
                grad_inputs, survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *grad_input.shape[1:]),
                device=survivor_grad_stacked.device,
                dtype=survivor_grad_stacked.dtype)
            grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
            grad_input.copy_(
                grad_input_per_expert.to(
                    flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, None,
                *grad_inputs)
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]],
                k_min: int, backward_k_min: int, timeout_after_k_min: float,
                forward_timeout: Optional[float],
                backward_timeout: Optional[float], detect_anomalies: bool,
                allow_zero_outputs: bool, info: Dict[str, Any],
                *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(
            map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(
            zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(
            flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                input_tensors = [
                    serialize_torch_tensor(tensor, proto.compression)
                    for tensor, proto in zip(
                        flat_inputs_per_sample[i],
                        nested_flatten(info['forward_schema']))
                ]
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(
                    expert.endpoint)
                new_task = stub.forward.future(
                    runtime_pb2.ExpertRequest(uid=expert.uid,
                                              tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        responded_inds, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout,
            timeout_after_k_min, detect_anomalies)
        if len(responded_inds) < k_min:
            raise TimeoutError(
                f"Forward pass: less than {k_min} experts responded within timeout.")

        if not isinstance(info['outputs_schema'], tuple):
            outputs_schema = (info['outputs_schema'], )
        else:
            outputs_schema = info['outputs_schema']
        outputs = nested_map(
            lambda descriptor: descriptor.make_empty(
                num_samples, max_experts, device=flat_inputs[0].device).zero_(),
            outputs_schema)

        # assemble responses
        if len(responded_inds) > 0 or allow_zero_outputs:
            batch_inds, expert_inds = map(
                lambda x: torch.as_tensor(
                    x, device=flat_inputs[0].device, dtype=torch.long),
                list(zip(*responded_inds)) or ([], []))

            alive_flat_outputs_stacked = (torch.cat(outputs)
                                          for outputs in zip(
                                              *alive_flat_outputs))
            # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

            for output, response_stacked in zip(outputs,
                                                alive_flat_outputs_stacked):
                output[batch_inds, expert_inds] = response_stacked.to(output.device)

        else:
            raise RuntimeError(
                'Forward pass: 0 experts responded within timeout and allow_zero_outputs is False'
            )

        mask = torch.zeros([num_samples, max_experts],
                           dtype=torch.bool,
                           device=flat_inputs[0].device)
        mask[batch_inds, expert_inds] = True

        # save individual outputs for backward pass
        ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout,
                                  timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)

        return (mask, ) + outputs