Exemplo n.º 1
0
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        """Obtain input gradients by delegating the backward pass to the remote expert.

        The tensors saved in forward and the incoming ``grad_outputs`` are flattened into one
        sequence, serialized, and sent via ``ctx.stub.backward`` for expert ``ctx.uid``.

        :returns: ``DUMMY`` followed by two ``None`` placeholders (presumably matching a forward
          signature of ``(ctx, dummy, uid, stub, *inputs)`` — TODO confirm), then the gradients
          deserialized from the server response
        """
        flat_tensors = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        request = runtime_pb2.ExpertRequest(
            uid=ctx.uid,
            tensors=[serialize_torch_tensor(item) for item in flat_tensors],
        )
        response = ctx.stub.backward(request)

        input_grads = tuple(map(deserialize_torch_tensor, response.tensors))
        return (DUMMY, None, None) + input_grads
Exemplo n.º 2
0
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        """Obtain input gradients from the remote expert, compressing each tensor per its schema.

        The saved forward inputs and the incoming ``grad_outputs`` are flattened together. Each
        flattened tensor is paired with the matching entry of the flattened
        ``(ctx.info["forward_schema"], ctx.info["outputs_schema"])`` and serialized with that
        entry's ``compression``. The result is sent via ``ctx.stub.backward``.

        :returns: ``DUMMY`` followed by three ``None`` placeholders (presumably for the
          ``uid``, ``stub`` and ``info`` forward arguments — TODO confirm), then the gradients
          deserialized from the server response
        """
        info = ctx.info
        outgoing = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
        descriptors = tuple(
            nested_flatten((info["forward_schema"], info["outputs_schema"])))

        # NOTE(review): zip() stops at the shorter sequence, so a schema shorter than the
        # payload would silently drop trailing tensors from the request.
        packed = [
            serialize_torch_tensor(item, descriptor.compression)
            for item, descriptor in zip(outgoing, descriptors)
        ]

        response = ctx.stub.backward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=packed))

        input_grads = tuple(map(deserialize_torch_tensor, response.tensors))
        return (DUMMY, None, None, None) + input_grads
Exemplo n.º 3
0
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub,
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Run the expert's forward pass remotely and return its outputs.

        ``inputs`` are the flattened input tensors that follow the expert's
        ``info['input_schema']``. They are detached, saved for backward, serialized, and sent
        to ``stub.forward`` for expert ``uid``. ``dummy`` is not read here (presumably it only
        anchors the call in the autograd graph — TODO confirm).

        :returns: the tensors deserialized from the server response, as a tuple
        """
        # Detach so that the computation graph is not pickled along with the tensors.
        detached = tuple(torch.Tensor.detach(item) for item in inputs)

        ctx.uid, ctx.stub = uid, stub
        ctx.save_for_backward(*detached)

        request = runtime_pb2.ExpertRequest(
            uid=ctx.uid,
            tensors=list(map(serialize_torch_tensor, detached)),
        )
        response = stub.forward(request)

        return tuple(map(deserialize_torch_tensor, response.tensors))
Exemplo n.º 4
0
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub, info: Dict[str, Any],
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Run the expert's forward pass remotely, compressing inputs per ``info["forward_schema"]``.

        ``inputs`` are the flattened input tensors that follow the expert's
        ``info['input_schema']``. Each is moved to CPU, detached, and saved for backward. The
        ``uid``, ``stub`` and ``info`` arguments are stored on ``ctx``. Each input is serialized
        with the ``compression`` of the matching flattened ``forward_schema`` entry and sent to
        ``stub.forward``.

        :returns: the tensors deserialized from the server response, as a tuple
        """
        # Move to CPU and detach so that the computation graph is not pickled.
        cpu_inputs = tuple(item.cpu().detach() for item in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*cpu_inputs)

        # NOTE(review): zip() stops at the shorter sequence; inputs beyond the schema length
        # would not be sent.
        packed = [
            serialize_torch_tensor(item, descriptor.compression)
            for item, descriptor in zip(cpu_inputs, nested_flatten(info["forward_schema"]))
        ]

        reply = stub.forward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=packed))

        return tuple(deserialize_torch_tensor(item) for item in reply.tensors)
Exemplo n.º 5
0
    async def _load_state_from_peers(self, future: MPFuture):
        """Download averager state from the highest-priority peer that can serve it.

        Candidate peers are read from the DHT key ``f"{prefix}.all_averagers"``. Only entries
        whose value is a ``ValueWithExpiration`` holding a number (the priority) are kept.
        Peers are tried in descending priority order, skipping ``self.endpoint``. The first
        peer that streams its state successfully resolves ``future`` with
        ``(metadata, tensors)`` and refreshes ``self.last_updated``. A failure with one peer is
        logged, and the next peer is tried.

        :param future: resolved with ``(metadata, tensors)`` on success, or with ``None`` if no
          peer could serve the state. It is also resolved with ``None`` if this coroutine fails
          unexpectedly, so a caller waiting on it is never left blocked.
        """
        try:
            key_manager = self._matchmaking.group_key_manager
            peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers",
                                            latest=True) or ({}, None)
            # Validate the raw DHT value *before* iterating it. A corrupted (non-dict) entry
            # would otherwise raise on .items() instead of reaching the branch below.
            if isinstance(peer_priority, dict):
                peer_priority = {
                    peer: float(info.value)
                    for peer, info in peer_priority.items()
                    if isinstance(info, ValueWithExpiration)
                    and isinstance(info.value, (float, int))
                }

            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                logger.info(
                    f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}."
                )
                future.set_result(None)
                return

            for peer in sorted(peer_priority.keys(),
                               key=peer_priority.get,
                               reverse=True):
                if peer == self.endpoint:
                    continue
                logger.info(f"Downloading parameters from peer {peer}")
                try:
                    metadata, tensors = await self._download_state_from_peer(peer)
                except grpc.aio.AioRpcError as e:
                    logger.info(f"Failed to download state from {peer} - {e}")
                    continue
                except Exception as e:
                    # A malformed response from one (untrusted) peer must not prevent trying
                    # the remaining peers.
                    logger.exception(f"Failed to download state from {peer} - {e!r}")
                    continue
                future.set_result((metadata, tensors))
                self.last_updated = get_dht_time()
                return

            logger.warning(
                "Averager could not load state from peers: found no active peers."
            )
            future.set_result(None)
        finally:
            # Guarantee the caller is never left waiting on an unresolved future.
            if not future.done():
                future.set_result(None)

    async def _download_state_from_peer(self, peer):
        """Stream the full state from a single peer and return ``(metadata, tensors)``.

        A message whose ``tensor_part.dtype`` is set starts a new tensor. At that point the
        parts accumulated so far are combined and deserialized, and this happens once more at
        the end of the stream. ``metadata`` comes from the last message that carried it, or is
        ``None`` if none did. It is local to this call, so nothing leaks between peers.

        Raises whatever the RPC, ``self.serializer.loads`` or deserialization raises (e.g.
        ``grpc.aio.AioRpcError``).
        """
        leader_stub = ChannelCache.get_stub(
            peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
        stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
        metadata, tensors, current_tensor_parts = None, [], []
        try:
            async for message in stream:
                if message.metadata:
                    metadata = self.serializer.loads(message.metadata)
                if message.tensor_part.dtype and current_tensor_parts:
                    # tensor_part.dtype indicates the start of a new tensor, so wrap up the previous one
                    tensors.append(
                        deserialize_torch_tensor(
                            combine_from_streaming(current_tensor_parts)))
                    current_tensor_parts = []
                current_tensor_parts.append(message.tensor_part)
            if current_tensor_parts:
                tensors.append(
                    deserialize_torch_tensor(
                        combine_from_streaming(current_tensor_parts)))
        except Exception:
            # Abort a call we stopped reading from, so that awaiting its status below cannot
            # stall. cancel() has no effect if the RPC has already terminated.
            stream.cancel()
            raise
        finally:
            await stream.code()
        return metadata, tensors