Example #1
0
def _unbufferize_torch_parameter(
        worker: AbstractWorker,
        protobuf_param: ParameterPB) -> torch.nn.Parameter:
    """Rebuild a ``torch.nn.Parameter`` from a ``ParameterPB`` message.

    Args:
        worker: worker forwarded to the generic Protobuf deserializer for the
            parameter's data tensor and, when present, its gradient.
        protobuf_param: the serialized parameter.

    Returns:
        A parameter carrying the deserialized data, the message's
        ``requires_grad`` flag and id, and its gradient if the message has a
        ``grad`` field.
    """
    deserialize = syft.serde.protobuf.serde._unbufferize

    data_tensor = deserialize(worker, protobuf_param.tensor)
    restored = torch.nn.Parameter(
        data_tensor, requires_grad=protobuf_param.requires_grad)
    restored.id = get_protobuf_id(protobuf_param.id)

    # The gradient is optional in the message; only restore it when present.
    if not protobuf_param.HasField("grad"):
        return restored

    restored.grad = deserialize(worker, protobuf_param.grad)
    return restored
Example #2
0
    def unbufferize(worker: AbstractWorker,
                    protobuf_tensor: "TorchTensorPB") -> torch.Tensor:
        """
        Convert a Protobuf torch tensor message back into a Torch tensor.

        The tensor contents are read from whichever field of the
        ``contents`` oneof is set and decoded with the serializer selected
        by ``protobuf_tensor.serializer`` (looked up in
        ``SERIALIZERS_PROTOBUF_TO_SYFT``). Per the original author, the
        contents can be binary representations produced by Torch or Numpy,
        or the generic Protobuf message format for cross-platform
        communication.

        Args:
            worker (AbstractWorker): worker set as the owner of the resulting
                tensor; also forwarded to the nested deserialization calls.
            protobuf_tensor (TorchTensorPB): Protobuf message of a torch tensor.

        Returns:
            tensor (torch.Tensor): the deserialized tensor, with its id, tags
            and description restored, plus its ``.grad`` (from
            ``grad_chain``) and ``.child`` (from ``chain``) when those fields
            are present.

        Raises:
            TypeError: if no field of the ``contents`` oneof is set.
            KeyError: if ``protobuf_tensor.serializer`` is not a key of
                ``SERIALIZERS_PROTOBUF_TO_SYFT``.
        """
        tensor_id = get_protobuf_id(protobuf_tensor.id)
        tags = protobuf_tensor.tags
        description = protobuf_tensor.description

        # WhichOneof returns None when no field of the oneof is set; passing
        # that straight to getattr fails with an opaque "attribute name must
        # be string" error, so report the malformed message explicitly.
        contents_type = protobuf_tensor.WhichOneof("contents")
        if contents_type is None:
            raise TypeError(
                "TorchTensorPB message has no 'contents' field set; "
                "cannot deserialize tensor data"
            )
        serialized_tensor = getattr(protobuf_tensor, contents_type)
        serializer = SERIALIZERS_PROTOBUF_TO_SYFT[protobuf_tensor.serializer]

        tensor = _deserialize_tensor(worker, serializer, serialized_tensor)

        # note we need to do this explicitly because torch.load does not
        # include .grad information
        if protobuf_tensor.HasField("grad_chain"):
            grad_chain = protobuf_tensor.grad_chain
            tensor.grad = TorchTensorWrapper.unbufferize(worker, grad_chain)

        initialize_tensor(
            hook=syft.torch.hook,
            obj=tensor,
            owner=worker,
            id=tensor_id,
            init_args=[],
            init_kwargs={},
        )

        # A serialized child chain means this tensor wraps another object.
        if protobuf_tensor.HasField("chain"):
            chain = protobuf_tensor.chain
            chain = TorchTensorWrapper.unbufferize(worker, chain)
            tensor.child = chain
            tensor.is_wrapper = True

        tensor.tags = set(tags)
        tensor.description = description

        return tensor
Example #3
0
    def unbufferize(worker: AbstractWorker, protobuf_param: ParameterPB) -> torch.nn.Parameter:
        """
        Convert a ParameterPB message back into a torch.nn.Parameter.

        Args:
            worker (AbstractWorker): worker forwarded to the generic Protobuf
                deserializer for the data tensor and, when present, the grad.
            protobuf_param (ParameterPB): the message to deserialize.

        Returns:
            torch.nn.Parameter: parameter whose data, ``requires_grad`` flag
            and id come from ``protobuf_param``, with ``.grad`` restored when
            the message has a ``grad`` field.
        """
        serde = syft.serde.protobuf.serde
        param = torch.nn.Parameter(
            serde._unbufferize(worker, protobuf_param.tensor),
            requires_grad=protobuf_param.requires_grad,
        )
        param.id = get_protobuf_id(protobuf_param.id)

        has_grad = protobuf_param.HasField("grad")
        if has_grad:
            param.grad = serde._unbufferize(worker, protobuf_param.grad)
        return param