Example #1
import torch

def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
    torch.manual_seed(0)
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X,
                               CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
Example #2
    async def _communicate_with_peer(self, peer_endpoint: Endpoint,
                                     local_part: torch.Tensor) -> torch.Tensor:
        """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
        serialized_tensor_part = serialize_torch_tensor(local_part,
                                                        self.compression_type,
                                                        allow_inplace=False)
        chunks = split_for_streaming(serialized_tensor_part,
                                     self.chunk_size_bytes)

        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
        await stream.write(
            averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                        group_id=self.group_id,
                                        endpoint=self.endpoint,
                                        tensor_part=next(chunks)))
        for chunk in chunks:
            await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
        await stream.done_writing()

        outputs: Sequence[averaging_pb2.AveragingData] = [
            message async for message in stream
        ]
        code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
        if code != averaging_pb2.AVERAGED_PART:
            raise AllreduceException(
                f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                f" allreduce failed")

        averaged_part = deserialize_torch_tensor(
            combine_from_streaming(
                [message.tensor_part for message in outputs]))
        self.register_averaged_part(peer_endpoint, averaged_part)
        return averaged_part
Example #3
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(
        deserialize_torch_tensor(
            serialize_torch_tensor(X, CompressionType.NONE)), X)
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X,
                               CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(
        serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
Example #4
    def _collect_responses(
        task_to_indices: Dict[grpc.Future,
                              Tuple[int, int]], num_samples: int, k_min: int,
        timeout_total: Optional[float], timeout_after_k_min: Optional[float]
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[torch.Tensor, ...]]]:
        """ await up to k_min results and any result submitted within timeout_after_k_min, cancel stragglers """
        timeout_total = float('inf') if timeout_total is None else timeout_total
        timeout_after_k_min = (float('inf') if timeout_after_k_min is None
                               else timeout_after_k_min)
        num_successful_tasks = [0 for _ in range(num_samples)]
        pending_samples = num_samples  # samples for which we have less than k_min results
        finished_indices, finished_outputs = [], []
        t_finish = time.perf_counter() + timeout_total
        pending_tasks = set(task_to_indices.keys())
        finished_tasks = Queue()

        try:
            # the algorithm below is essentially futures.as_completed, but for grpc.Future
            for task in pending_tasks:
                task.add_done_callback(finished_tasks.put)

            for _ in range(len(task_to_indices)):
                timeout = (max(0.0, t_finish - time.perf_counter())
                           if t_finish != float('inf') else None)
                task = finished_tasks.get(timeout=timeout)
                pending_tasks.discard(task)

                if task.exception() or task.cancelled():
                    logger.warning(
                        f"Task {task} failed: {type(task.exception())}")
                    continue

                finished_indices.append(task_to_indices[task])
                finished_outputs.append(
                    tuple(
                        deserialize_torch_tensor(tensor)
                        for tensor in task.result().tensors))

                # count how many successes we have for each input sample
                sample_index = task_to_indices[task][0]
                num_successful_tasks[sample_index] += 1
                if num_successful_tasks[sample_index] == k_min:
                    pending_samples -= 1
                    if pending_samples <= 0:  # every sample has >= k_min results; await stragglers for at most timeout_after_k_min
                        t_finish = min(
                            t_finish,
                            time.perf_counter() + timeout_after_k_min)

        except Empty:
            pass  # we reached t_finish, this is normal behavior
        finally:
            for task in pending_tasks:
                task.cancel()
        return finished_indices, finished_outputs
Example #5
    async def forward(self, request: runtime_pb2.ExpertRequest,
                      context: grpc.ServicerContext):
        inputs = [
            deserialize_torch_tensor(tensor) for tensor in request.tensors
        ]
        future = self.experts[request.uid].forward_pool.submit_task(*inputs)
        serialized_response = [
            serialize_torch_tensor(tensor) for tensor in await future
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #6
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        tensor_part: torch.Tensor = deserialize_torch_tensor(
            combine_from_streaming(stream_messages))
        averaged_part = await self.accumulate_part(source, tensor_part)
        if not self.averaged_part_stream.done():
            serialized_tensor = serialize_torch_tensor(
                averaged_part, self.compression_type, allow_inplace=False)
            stream_chunks = tuple(
                split_for_streaming(serialized_tensor, self.chunk_size_bytes))
            self.averaged_part_stream.set_result(stream_chunks)
            return stream_chunks
        else:
            return self.averaged_part_stream.result()
Example #7
def _process_dispatched_task(task: grpc.Future, detect_anomalies: bool) -> Optional[Tuple[torch.Tensor, ...]]:
    if task.exception() or task.cancelled():
        logger.warning(f"Task {task} failed: {type(task.exception())}")
        return None

    deserialized_outputs = []
    for tensor in task.result().tensors:
        deserialized_tensor = deserialize_torch_tensor(tensor)
        if detect_anomalies and not deserialized_tensor.isfinite().all():
            logger.error(f"Task {task} failed: output tensor contains nan/inf values")
            return None
        deserialized_outputs.append(deserialized_tensor)

    return tuple(deserialized_outputs)
Example #8
    async def backward(self, request: runtime_pb2.ExpertRequest,
                       context: grpc.ServicerContext):
        inputs_and_grad_outputs = [
            deserialize_torch_tensor(tensor) for tensor in request.tensors
        ]
        future = self.experts[request.uid].backward_pool.submit_task(
            *inputs_and_grad_outputs)
        serialized_response = [
            serialize_torch_tensor(tensor,
                                   proto.compression,
                                   allow_inplace=True)
            for tensor, proto in zip(
                await future,
                nested_flatten(self.experts[request.uid].grad_inputs_schema))
        ]
        return runtime_pb2.ExpertResponse(tensors=serialized_response)
Example #9
async def _forward_one_expert(grid_indices: Tuple[int, ...],
                              expert: RemoteExpert,
                              inputs: Tuple[torch.Tensor, ...]):
    stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(
        expert.endpoint, aio=True)
    try:
        outputs = await stub.forward(
            runtime_pb2.ExpertRequest(uid=expert.uid,
                                      tensors=[
                                          serialize_torch_tensor(tensor)
                                          for tensor in inputs
                                      ]))
        return grid_indices, tuple(
            deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(
            f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})"
        )
Example #10
    async def accumulate_part_streaming(
        self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
    ) -> Iterable[runtime_pb2.Tensor]:
        """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
        try:
            tensor_part = deserialize_torch_tensor(
                combine_from_streaming(stream_messages))
        except RuntimeError as e:
            raise AllreduceException(
                f"Could not deserialize tensor part from {source} for streaming {e}"
            )

        averaged_part = await self.accumulate_part(
            source, tensor_part, weight=self.peer_weights[source])
        if not self.averaged_part_stream.done():
            serialized_tensor = serialize_torch_tensor(averaged_part,
                                                       self.compression_type,
                                                       allow_inplace=False)
            stream_chunks = tuple(
                split_for_streaming(serialized_tensor, self.chunk_size_bytes))
            self.averaged_part_stream.set_result(stream_chunks)
            return stream_chunks
        else:
            return self.averaged_part_stream.result()
Example #11
def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
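A minimal usage sketch for benchmark_compression (not part of the original snippets): it assumes the benchmark_compression function defined above is in scope together with the hivemind imports shown in Example #1, and the tensor size and the set of compression types below are purely illustrative.

import torch
from hivemind.proto.runtime_pb2 import CompressionType

X = torch.randn(1000, 1000)  # arbitrary example payload
# compression type names below match the values exercised in the examples above
for name in ('NONE', 'FLOAT16', 'MEANSTD_LAST_AXIS_FLOAT16', 'QUANTILE_8BIT'):
    elapsed = benchmark_compression(X, getattr(CompressionType, name))
    print(f'{name}: {elapsed:.4f} s per serialize/deserialize round-trip')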