示例#1
0
def _detail_tf_tensor(worker, tensor_tuple) -> tf.Tensor:
    """
    Rebuild a local TensorFlow tensor from its serialized tuple form.

    The tensor bytes are decoded with ``tf.io.parse_tensor`` and the result is
    then registered through ``initialize_tensor`` with the Syft TensorFlow hook.

    Args:
        worker: the worker passed to ``syft.serde._detail`` and set as owner
            of the rebuilt tensor.
        tensor_tuple (tuple): a 4-tuple ``(id, tensor_bin, dtype_enum, chain)``
            holding the tensor ID, the serialized TensorFlow tensor bytes, the
            serialized dtype, and the serialized chain of tensor abstractions
            (or ``None`` when there is no chain).

    Returns:
        tf.Tensor: the deserialized TF tensor.
    """
    tensor_id, serialized_bytes, dtype_enum, serialized_chain = tensor_tuple

    # The dtype arrives in serialized form and must be detailed before parsing.
    out_dtype = syft.serde._detail(worker, dtype_enum)
    result = tf.io.parse_tensor(serialized_bytes, out_dtype)

    initialize_tensor(
        hook=syft.tensorflow.hook,
        obj=result,
        owner=worker,
        id=tensor_id,
        init_args=[],
        init_kwargs={},
    )

    # Without a chain the tensor is returned as-is; with one, it becomes a
    # wrapper whose child is the detailed chain.
    if serialized_chain is None:
        return result

    result.child = syft.serde._detail(worker, serialized_chain)
    result.is_wrapper = True
    return result
示例#2
0
def _detail_torch_tensor(worker: AbstractWorker,
                         tensor_tuple: tuple) -> torch.Tensor:
    """
    Convert a serialized torch tensor tuple back into a torch tensor.

    Args:
        worker (AbstractWorker): the worker that owns the rebuilt tensor and
            is used to detail the chain.
        tensor_tuple (tuple): a 6-tuple
            ``(id, tensor_bin, chain, grad_chain, tags, description)``:

            - id: the tensor ID.
            - tensor_bin: the binary for the PyTorch object, decoded by
              ``_deserialize_tensor``.
            - chain: the serialized chain of tensor abstractions, or ``None``.
            - grad_chain: a tuple in this same format describing ``.grad``
              (which may itself carry ``.grad.grad``, etc.), or ``None``.
            - tags: an iterable of tags (``bytes`` entries are decoded as
              UTF-8), or ``None``.
            - description: a ``str`` or UTF-8 ``bytes`` description, or
              ``None``.

    Returns:
        torch.Tensor: a torch tensor that was serialized
    """

    tensor_id, tensor_bin, chain, grad_chain, tags, description = tensor_tuple

    tensor = _deserialize_tensor(tensor_bin)

    # note we need to do this explicitly because torch.load does not
    # include .grad information
    if grad_chain is not None:
        tensor.grad = _detail_torch_tensor(worker, grad_chain)

    initialize_tensor(
        hook_self=syft.torch.hook,
        cls=tensor,
        is_tensor=True,
        owner=worker,
        id=tensor_id,
        init_args=[],
        # The keyword is init_kwargs (paired with init_args); passing
        # ``kwargs=`` does not reach that parameter.
        init_kwargs={},
    )

    if tags is not None:
        # Tags may arrive as bytes after serialization; normalize to str.
        tensor.tags = [
            tag.decode("utf-8") if isinstance(tag, bytes) else tag
            for tag in tags
        ]

    if description is not None:
        if isinstance(description, bytes):
            description = description.decode("utf-8")
        tensor.description = description

    if chain is not None:
        chain = syft.serde._detail(worker, chain)
        tensor.child = chain
        tensor.is_wrapper = True

    return tensor
示例#3
0
def _detail_torch_tensor(worker: AbstractWorker,
                         tensor_tuple: tuple) -> torch.Tensor:
    """
    Rebuild a torch tensor from the tuple produced by its serializer.

    Args:
        worker (AbstractWorker): the worker that owns the rebuilt tensor and
            is passed to every ``serde._detail`` call.
        tensor_tuple (tuple): a 9-tuple holding, in order: the tensor ID, the
            tensor binary, the serialized chain of tensor abstractions (or
            ``None``), a grad tuple in this same format (or ``None``), then
            the serialized tags, description, serializer, origin and
            id_at_origin. The chain, tags, description, serializer, origin
            and id_at_origin entries are detailed with ``serde._detail``.

    Returns:
        torch.Tensor: the reconstructed torch tensor.
    """
    (tensor_id, payload, child_chain, grad_tuple, raw_tags, raw_description,
     raw_serializer, raw_origin, raw_id_at_origin) = tensor_tuple

    # The serializer is itself serialized; detail it before decoding payload.
    serializer = serde._detail(worker, raw_serializer)
    tensor = _deserialize_tensor(worker, serializer, payload)

    # torch.load does not include .grad information, so it is rebuilt
    # explicitly from its own tuple.
    if grad_tuple is not None:
        tensor.grad = _detail_torch_tensor(worker, grad_tuple)

    initialize_tensor(
        hook=syft.torch.hook,
        obj=tensor,
        owner=worker,
        id=tensor_id,
        init_args=[],
        init_kwargs={},
    )

    if child_chain is not None:
        tensor.child = serde._detail(worker, child_chain)
        tensor.is_wrapper = True

    # Detail the remaining metadata fields in order and attach them.
    metadata = (
        ("tags", raw_tags),
        ("description", raw_description),
        ("origin", raw_origin),
        ("id_at_origin", raw_id_at_origin),
    )
    for attr_name, raw_value in metadata:
        setattr(tensor, attr_name, serde._detail(worker, raw_value))

    return tensor
示例#4
0
 def new___init__(cls,
                  *args,
                  owner=None,
                  id=None,
                  register=True,
                  **kwargs):
     """Replacement ``__init__`` that delegates setup to ``initialize_tensor``.

     Args:
         cls: the object being initialized (bound like ``self``; presumably
             the new tensor instance — confirm against where this is
             installed).
         *args: positional constructor arguments, forwarded as ``init_args``.
         owner: the worker to own the tensor, forwarded to
             ``initialize_tensor`` (previously accepted but dropped).
         id: the tensor ID, forwarded to ``initialize_tensor``.
         register: accepted for signature compatibility; not used here.
         **kwargs: keyword constructor arguments, forwarded as
             ``init_kwargs``.
     """
     initialize_tensor(
         hook_self=hook_self,
         cls=cls,
         owner=owner,
         id=id,
         is_tensor=is_tensor,
         init_args=args,
         init_kwargs=kwargs,
     )
示例#5
0
    def unbufferize(worker: AbstractWorker,
                    protobuf_tensor: "TorchTensorPB") -> torch.Tensor:
        """
        Rebuild a torch tensor from its Protobuf message.

        The payload is read from whichever member of the ``contents`` oneof is
        set, and it is decoded by ``_deserialize_tensor`` with the serializer
        named in the message (per the original notes: Torch or Numpy binary
        representations, or the generic Protobuf message format).

        Args:
            worker (AbstractWorker): the worker that owns the rebuilt tensor.
            protobuf_tensor (TorchTensorPB): Protobuf message of a torch tensor.

        Returns:
            torch.Tensor: the tensor decoded from the message.
        """
        tensor_id = get_protobuf_id(protobuf_tensor.id)

        payload_field = protobuf_tensor.WhichOneof("contents")
        payload = getattr(protobuf_tensor, payload_field)
        serializer = SERIALIZERS_PROTOBUF_TO_SYFT[protobuf_tensor.serializer]
        tensor = _deserialize_tensor(worker, serializer, payload)

        # torch.load does not include .grad information, so the gradient is
        # rebuilt explicitly from its own nested message.
        if protobuf_tensor.HasField("grad_chain"):
            tensor.grad = TorchTensorWrapper.unbufferize(
                worker, protobuf_tensor.grad_chain)

        initialize_tensor(
            hook=syft.torch.hook,
            obj=tensor,
            owner=worker,
            id=tensor_id,
            init_args=[],
            init_kwargs={},
        )

        if protobuf_tensor.HasField("chain"):
            tensor.child = TorchTensorWrapper.unbufferize(
                worker, protobuf_tensor.chain)
            tensor.is_wrapper = True

        tensor.tags = set(protobuf_tensor.tags)
        tensor.description = protobuf_tensor.description
        return tensor