Example #1
def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
    torch.manual_seed(0)
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X,
                               CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
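
Outside of a test, the same round trip takes two calls: build a protobuf message from a tensor, then rebuild the tensor from the message. A minimal sketch using only the imports already shown in the test above (the tensor shape here is arbitrary):

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

x = torch.randn(16, 8)
proto = serialize_torch_tensor(x, CompressionType.FLOAT16)  # protobuf message holding the compressed payload
x_restored = deserialize_torch_tensor(proto)                # back to a torch.Tensor
print((x_restored - x).square().mean())                     # small, nonzero fp16 rounding error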
Example #2
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        averaged_part = deserialize_torch_tensor(
            combine_from_streaming(
                [message.tensor_part for message in outputs]))
        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
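
Note how the handler above never puts a whole tensor into a single message: it serializes the part once, splits the resulting protobuf into chunks of at most self.chunk_size_bytes, streams the chunks, and the receiving side reassembles them before deserializing. The chunking round trip can be sketched locally without any RPC; this assumes split_for_streaming and combine_from_streaming are importable from hivemind.utils alongside the serialization helpers, and the 64 KiB chunk size is arbitrary:

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import (serialize_torch_tensor, deserialize_torch_tensor,
                            split_for_streaming, combine_from_streaming)

local_part = torch.randn(1024, 1024)
serialized = serialize_torch_tensor(local_part, CompressionType.FLOAT16, allow_inplace=False)

# sender side: split one large protobuf into bounded-size chunks for streaming
chunks = list(split_for_streaming(serialized, 2 ** 16))

# receiver side: reassemble the chunks, then deserialize back into a tensor
restored = deserialize_torch_tensor(combine_from_streaming(chunks))
assert restored.shape == local_part.shape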
Example #3
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X,
                               CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
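
The two thresholds reflect a trade-off: the FLOAT16 schemes roughly halve the payload at the cost of fp16 rounding error (alpha), while QUANTILE_8BIT typically compresses further and is allowed a larger reconstruction error (beta). A quick way to see both sides is to compare serialized size and mean squared error per compression type; this sketch relies only on standard protobuf methods (ByteSize, enum Name) plus the imports from Example #1:

import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

x = torch.randn(128, 128, 64)
for compression in (CompressionType.NONE, CompressionType.MEANSTD_LAST_AXIS_FLOAT16,
                    CompressionType.FLOAT16, CompressionType.QUANTILE_8BIT):
    proto = serialize_torch_tensor(x, compression)
    mse = (deserialize_torch_tensor(proto) - x).square().mean().item()
    # ByteSize() is the total serialized protobuf size, including the compressed payload
    print(f"{CompressionType.Name(compression):>26}: {proto.ByteSize()} bytes, mse={mse:.2e}")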
Example #4
 async def forward(self, request: runtime_pb2.ExpertRequest,
                   context: grpc.ServicerContext):
     """ Deserialize request tensors, run the expert's forward pass, and return serialized outputs """
     inputs = [
         deserialize_torch_tensor(tensor) for tensor in request.tensors
     ]
     future = self.experts[request.uid].forward_pool.submit_task(*inputs)
     serialized_response = [
         serialize_torch_tensor(tensor) for tensor in await future
     ]
     return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #5
    def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        assert not torch.is_grad_enabled()
        (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
         detect_anomalies) = ctx._saved_non_tensors
        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

        dummy_grad_mask, *flat_grad_outputs = raw_grads

        flat_grad_outputs_cpu = []
        for tensor in flat_grad_outputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of gradients has nan/inf values")
            flat_grad_outputs_cpu.append(tensor.cpu())

        num_samples, max_experts = dummy_grad_mask.shape

        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
        backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # dispatch tasks to all remote experts, collect responses
        pending_tasks = {}
        for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                    inputs_per_expert, grad_outputs_per_expert):
            expert = expert_per_sample[i.item()][j.item()]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
            tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                                  for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
            new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
            pending_tasks[new_task] = (i, j)

        backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
            pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
        if len(backward_survivor_indices) == 0:
            raise TimeoutError("Backward pass: no alive experts responded within timeout.")

        # assemble responses
        backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))

        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

        grad_inputs = []
        for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
            grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
                device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
            grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

            # sum gradients from each expert
            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

        return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
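
The assembly step at the end is the heart of this backward pass: each surviving expert's gradient is scattered into a dense [num_samples, max_experts, ...] buffer at its (i, j) grid position, and summing over the expert axis gives every sample the combined gradient of all experts that responded (samples whose experts all failed get zeros). A toy, self-contained sketch of that scatter-and-sum pattern with made-up shapes and illustrative names:

import torch

num_samples, max_experts, dim = 4, 3, 5

# pretend experts (sample 0, slot 1) and (sample 2, slot 2) responded with gradients
survivor_ii = torch.tensor([0, 2])
survivor_jj = torch.tensor([1, 2])
survivor_grads = torch.randn(2, dim)

# scatter each response into its grid cell, then sum over the expert axis
grad_per_expert = torch.zeros(num_samples, max_experts, dim)
grad_per_expert[survivor_ii, survivor_jj] = survivor_grads
grad_input = grad_per_expert.sum(dim=1)
print(grad_input.shape)  # torch.Size([4, 5])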
Example #6
 async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                     ) -> Iterable[runtime_pb2.Tensor]:
     """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
     tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
     averaged_part = await self.accumulate_part(source, tensor_part)
     if not self.averaged_part_stream.done():
         serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
         stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
         self.averaged_part_stream.set_result(stream_chunks)
         return stream_chunks
     else:
         return self.averaged_part_stream.result()
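
The if/else around self.averaged_part_stream is what the docstring means by preventing duplicate work: the first peer to request this part pays for serialization and chunking once and publishes the chunks through a future, and every later request simply reuses that result. A stripped-down sketch of the pattern with plain asyncio and illustrative names (no tensors or RPC):

import asyncio

async def serve_part(num_requests, make_chunks):
    averaged_part_stream = asyncio.Future()  # shared by everyone requesting this part
    responses = []
    for _ in range(num_requests):
        if not averaged_part_stream.done():
            # first requester does the serialization work exactly once and publishes it
            averaged_part_stream.set_result(make_chunks())
        # all requesters, including the first, receive the very same tuple of chunks
        responses.append(averaged_part_stream.result())
    return responses

responses = asyncio.run(serve_part(3, lambda: ("chunk-0", "chunk-1")))
assert responses[0] is responses[1] is responses[2]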
Example #7
    def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
                timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
                detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
        assert not torch.is_grad_enabled()
        num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

        flat_inputs_cpu = []
        for tensor in flat_inputs:
            if detect_anomalies and not tensor.isfinite().all():
                raise ValueError("One of inputs has nan/inf values")
            flat_inputs_cpu.append(tensor.cpu())

        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
        assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

        # dispatch tasks to all remote experts, collect responses
        pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
        for i in range(num_samples):
            for j, expert in enumerate(experts_per_sample[i]):
                input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
                stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                pending_tasks[new_task] = (i, j)

        alive_grid_indices, alive_flat_outputs = cls._collect_responses(
            pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
        if len(alive_grid_indices) == 0:
            raise TimeoutError("Forward pass: no alive experts responded within timeout.")

        # assemble responses
        alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
        mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
        mask[alive_ii, alive_jj] = True

        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

        outputs = []
        for response_stacked in alive_flat_outputs_stacked:
            output = torch.zeros(
                [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
            output[alive_ii, alive_jj] = response_stacked
            outputs.append(output.to(flat_inputs[0].device))

        # save individual outputs for backward pass
        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
        ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                                  detect_anomalies)
        return (mask,) + tuple(outputs)
Example #8
 async def backward(self, request: runtime_pb2.ExpertRequest,
                    context: grpc.ServicerContext):
     """ Deserialize inputs and output gradients, run the expert's backward pass, and return serialized input gradients """
     inputs_and_grad_outputs = [
         deserialize_torch_tensor(tensor) for tensor in request.tensors
     ]
     future = self.experts[request.uid].backward_pool.submit_task(
         *inputs_and_grad_outputs)
     serialized_response = [
         serialize_torch_tensor(tensor,
                                proto.compression,
                                allow_inplace=True)
         for tensor, proto in zip(
             await future,
             nested_flatten(self.experts[request.uid].grad_inputs_schema))
     ]
     return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #9
 async def _forward_one_expert(grid_indices: Tuple[int, ...],
                               expert: RemoteExpert,
                               inputs: Tuple[torch.Tensor]):
     stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(
         expert.endpoint, aio=True)
     try:
         outputs = await stub.forward(
             runtime_pb2.ExpertRequest(uid=expert.uid,
                                       tensors=[
                                           serialize_torch_tensor(tensor)
                                           for tensor in inputs
                                       ]))
         return grid_indices, tuple(
             deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
     except grpc.experimental.aio.AioRpcError as error:
         # on RPC failure, log a warning and fall through, implicitly returning None
         logger.warning(
             f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})"
         )
Example #10
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        try:
            tensor_part = deserialize_torch_tensor(
                combine_from_streaming(stream_messages))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize tensor part from {source} for streaming {e}"
            )

        averaged_part = await self.accumulate_part(
            source, tensor_part, weight=self.peer_weights[source])
        if not self.averaged_part_stream.done():
            serialized_tensor = serialize_torch_tensor(averaged_part,
                                                       self.compression_type,
                                                       allow_inplace=False)
            stream_chunks = tuple(
                split_for_streaming(serialized_tensor, self.chunk_size_bytes))
            self.averaged_part_stream.set_result(stream_chunks)
            return stream_chunks
        else:
            return self.averaged_part_stream.result()

def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    """ Measure the wall-clock time of one serialize + deserialize round trip with the given compression """
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
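
A plausible way to use this helper is to time every compression scheme on the same tensor. The loop below is only a sketch: the tensor shape is arbitrary, a single wall-clock measurement is noisy (a real benchmark would repeat and average), and the imports cover what benchmark_compression itself needs (time, torch, and the two serialization helpers):

import time
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

X = torch.randn(512, 512)
for compression in (CompressionType.NONE, CompressionType.FLOAT16,
                    CompressionType.MEANSTD_LAST_AXIS_FLOAT16, CompressionType.QUANTILE_8BIT):
    elapsed = benchmark_compression(X, compression)
    print(f"{CompressionType.Name(compression)}: {elapsed * 1000:.2f} ms per round trip")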