Example #1
    @staticmethod
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        # flatten (saved inputs, incoming gradients) into one flat tensor list
        payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))

        # run the backward pass on the remote expert via gRPC
        grad_inputs = ctx.stub.backward(
            runtime_pb2.ExpertRequest(
                uid=ctx.uid,
                tensors=[serialize_torch_tensor(tensor)
                         for tensor in payload]))

        deserialized_grad_inputs = [
            deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors
        ]
        # one gradient per forward argument: (dummy, uid, stub, *inputs)
        return (DUMMY, None, None, *deserialized_grad_inputs)
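Note that backward returns exactly one gradient per argument of the matching forward (dummy, uid, stub, *inputs); uid and stub are not tensors, so their gradient slots are None. Below is a minimal self-contained sketch of the same contract with the gRPC stub swapped for a local callable; LocalModuleCall and fn are illustrative names, not hivemind API:

import torch
from typing import Callable, Optional, Tuple

DUMMY = torch.empty(0, requires_grad=True)  # stand-in for the snippet's DUMMY

class LocalModuleCall(torch.autograd.Function):
    """Toy stand-in: same argument/return contract, local fn instead of gRPC."""

    @staticmethod
    def forward(ctx, dummy: torch.Tensor, uid: str, fn: Callable,
                *inputs: torch.Tensor) -> torch.Tensor:
        ctx.uid, ctx.fn = uid, fn
        ctx.save_for_backward(*inputs)
        return fn(*[tensor.detach() for tensor in inputs])

    @staticmethod
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        # re-run fn locally instead of calling ctx.stub.backward over gRPC
        inputs = [t.detach().requires_grad_(True) for t in ctx.saved_tensors]
        with torch.enable_grad():
            output = ctx.fn(*inputs)
        grads = torch.autograd.grad(output, inputs, grad_outputs[0])
        # one gradient per forward argument: (dummy, uid, fn, *inputs)
        return (DUMMY, None, None, *grads)

x = torch.randn(3, requires_grad=True)
y = LocalModuleCall.apply(DUMMY, "expert.0", lambda t: 2.0 * t, x)
y.sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))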
Example #2
    @staticmethod
    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        inputs_and_grad_outputs = tuple(
            nested_flatten((ctx.saved_tensors, grad_outputs)))
        # pair each tensor with its schema entry to pick a compression codec
        backward_schema = tuple(
            nested_flatten(
                (ctx.info["forward_schema"], ctx.info["outputs_schema"])))
        serialized_tensors = [
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
        ]

        grad_inputs = ctx.stub.backward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))

        deserialized_grad_inputs = [
            deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors
        ]
        # one gradient per forward argument: (dummy, uid, stub, info, *inputs)
        return (DUMMY, None, None, None, *deserialized_grad_inputs)
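Unlike Example #1, this variant chooses a per-tensor compression codec by zipping the flattened tensors against the flattened schema. A toy sketch of that pairing; CompressionType, TensorDescriptor, and serialize below are hypothetical simplifications of the real protobuf enum, schema entries, and serialize_torch_tensor:

from dataclasses import dataclass
from enum import Enum
import torch

class CompressionType(Enum):  # stand-in for the protobuf compression enum
    NONE = 0
    FLOAT16 = 1

@dataclass
class TensorDescriptor:  # stand-in for a schema entry with a .compression field
    compression: CompressionType = CompressionType.NONE

def serialize(tensor: torch.Tensor, compression: CompressionType) -> bytes:
    # toy serialize_torch_tensor: optionally downcast, then dump raw bytes
    if compression is CompressionType.FLOAT16:
        tensor = tensor.to(torch.float16)
    return tensor.numpy().tobytes()

# zip flattened tensors with the flattened schema, as the snippet does
tensors = (torch.randn(4), torch.randn(2, 2))
schema = (TensorDescriptor(CompressionType.FLOAT16), TensorDescriptor())
serialized = [serialize(t, proto.compression) for t, proto in zip(tensors, schema)]
print([len(blob) for blob in serialized])  # [8, 16]: float16 halves the payload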
Example #3
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub,
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        inputs = tuple(
            map(torch.Tensor.detach,
                inputs))  # detach to avoid pickling the computation graph
        ctx.uid, ctx.stub = uid, stub
        ctx.save_for_backward(*inputs)

        # run the forward pass on the remote expert via gRPC
        outputs = stub.forward(
            runtime_pb2.ExpertRequest(
                uid=ctx.uid,
                tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))

        deserialized_outputs = [
            deserialize_torch_tensor(tensor) for tensor in outputs.tensors
        ]

        return tuple(deserialized_outputs)
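The inputs are detached before save_for_backward so that serialization never drags the autograd graph across the wire. A tiny self-contained round trip illustrating the point; serialize_tensor and deserialize_tensor are toy stand-ins for serialize_torch_tensor and deserialize_torch_tensor:

import io
import torch

def serialize_tensor(tensor: torch.Tensor) -> bytes:
    # toy stand-in: detach, then dump the bare tensor with torch.save
    buffer = io.BytesIO()
    torch.save(tensor.cpu().detach(), buffer)
    return buffer.getvalue()

def deserialize_tensor(blob: bytes) -> torch.Tensor:
    # toy stand-in for deserialize_torch_tensor
    return torch.load(io.BytesIO(blob))

x = torch.randn(3, requires_grad=True)
restored = deserialize_tensor(serialize_tensor(x))
assert torch.equal(restored, x.detach()) and not restored.requires_grad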
Example #4
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, uid: str,
                stub: runtime_grpc.ConnectionHandlerStub, info: Dict[str, Any],
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)

        # serialize each input with the compression codec from its schema entry
        serialized_tensors = [
            serialize_torch_tensor(inp, proto.compression) for inp, proto in
            zip(inputs, nested_flatten(info["forward_schema"]))
        ]

        outputs = stub.forward(
            runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))

        deserialized_outputs = [
            deserialize_torch_tensor(tensor) for tensor in outputs.tensors
        ]

        return tuple(deserialized_outputs)
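The zip against nested_flatten(info["forward_schema"]) lines up only because nested_flatten yields leaves in a deterministic order. A toy version of the helper (the real hivemind implementation may differ in detail):

def nested_flatten(obj):
    # yield the leaves of nested tuples/lists/dicts in a deterministic order
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from nested_flatten(item)
    elif isinstance(obj, dict):
        for key in sorted(obj):  # sorted keys keep the order reproducible
            yield from nested_flatten(obj[key])
    else:
        yield obj

forward_schema = {"x": "desc_x", "extras": ("desc_a", "desc_b")}
print(list(nested_flatten(forward_schema)))  # ['desc_a', 'desc_b', 'desc_x']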
Example #5
    async def rpc_download_state(
        self, request: averaging_pb2.DownloadRequest,
        context: grpc.ServicerContext
    ) -> AsyncIterator[averaging_pb2.DownloadData]:
        """
        Get the up-to-date trainer state from a peer.
        The state consists of two parts: (serialized_metadata, tensors)

         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
        chunk_size_bytes = self.matchmaking_kwargs.get(
            'chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
        metadata, tensors = await self._get_current_state_from_host_process()

        for tensor in tensors:
            for part in split_for_streaming(serialize_torch_tensor(tensor),
                                            chunk_size_bytes):
                if metadata is not None:
                    # metadata rides along with the very first chunk only
                    yield averaging_pb2.DownloadData(tensor_part=part,
                                                     metadata=metadata)
                    metadata = None
                else:
                    yield averaging_pb2.DownloadData(tensor_part=part)
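On the receiving side, a client reads metadata from the first DownloadData message and concatenates the tensor_part chunks before deserializing. A self-contained sketch of the split/combine round trip; combine_from_streaming here is a toy counterpart written for this example:

from typing import Iterable, Iterator

def split_for_streaming(blob: bytes, chunk_size: int) -> Iterator[bytes]:
    # toy stand-in for hivemind's split_for_streaming
    for start in range(0, len(blob), chunk_size):
        yield blob[start:start + chunk_size]

def combine_from_streaming(chunks: Iterable[bytes]) -> bytes:
    # reassemble exactly what the server yielded, chunk by chunk
    return b"".join(chunks)

payload = bytes(range(10)) * 100  # pretend this is a serialized tensor
chunks = list(split_for_streaming(payload, chunk_size=256))
assert combine_from_streaming(chunks) == payload
print(f"{len(chunks)} chunks of <= 256 bytes each")  # 4 chunks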