def _detail_tf_tensor(worker, tensor_tuple) -> tf.Tensor:
    """Rebuild a local TensorFlow tensor from its serialized tuple form via ``tf.io``.

    Args:
        worker: the worker that will own the reconstructed tensor.
        tensor_tuple (tuple): a 4-tuple ``(id, tensor_bin, dtype_enum, chain)``:
            ``id`` is the tensor id, ``tensor_bin`` is the serialized TensorFlow
            tensor handed to ``tf.io.parse_tensor``, ``dtype_enum`` is the
            simplified dtype (expanded with ``syft.serde._detail``) and
            ``chain`` is the serialized chain of tensor abstractions, or None.

    Returns:
        tf.Tensor: the deserialized tensor.
    """
    obj_id, payload, simplified_dtype, child_chain = tensor_tuple

    dtype = syft.serde._detail(worker, simplified_dtype)
    result = tf.io.parse_tensor(payload, dtype)

    initialize_tensor(
        hook=syft.tensorflow.hook,
        obj=result,
        owner=worker,
        id=obj_id,
        init_args=[],
        init_kwargs={},
    )

    # Only a tensor that carries a chain of abstractions is turned into a wrapper.
    if child_chain is None:
        return result

    result.child = syft.serde._detail(worker, child_chain)
    result.is_wrapper = True
    return result
def _detail_torch_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> torch.Tensor:
    """Convert a serialized torch tensor tuple back into a torch tensor.

    Args:
        worker: the worker that will own the reconstructed tensor.
        tensor_tuple (tuple): a 6-tuple
            ``(id, tensor_bin, chain, grad_chain, tags, description)``:
            ``tensor_bin`` is the binary payload passed to
            ``_deserialize_tensor``; ``chain`` is the serialized chain of
            tensor abstractions, or None; ``grad_chain`` is a serialized
            gradient tuple of the same format (detailed recursively), or None;
            ``tags`` is an iterable of tags whose ``bytes`` entries are decoded
            as UTF-8, or None; ``description`` is a ``str``/``bytes``
            description, or None.

    Returns:
        torch.Tensor: the deserialized tensor.
    """
    tensor_id, tensor_bin, chain, grad_chain, tags, description = tensor_tuple

    tensor = _deserialize_tensor(tensor_bin)

    # note we need to do this explicitly because torch.load does not
    # include .grad information
    if grad_chain is not None:
        tensor.grad = _detail_torch_tensor(worker, grad_chain)

    # NOTE(review): this call passes `kwargs={}` while the same-style
    # initialize_tensor call in new___init__ passes `init_kwargs=`; confirm the
    # parameter name for the initialize_tensor version this targets.
    initialize_tensor(
        hook_self=syft.torch.hook,
        cls=tensor,
        is_tensor=True,
        owner=worker,
        id=tensor_id,
        init_args=[],
        kwargs={},
    )

    if tags is not None:
        # Tags may arrive as bytes after serialization; normalize them to str.
        tensor.tags = [
            tag.decode("utf-8") if isinstance(tag, bytes) else tag for tag in tags
        ]

    if description is not None:
        if isinstance(description, bytes):
            description = description.decode("utf-8")
        tensor.description = description

    if chain is not None:
        tensor.child = syft.serde._detail(worker, chain)
        tensor.is_wrapper = True

    return tensor
def _detail_torch_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> torch.Tensor:
    """Rebuild a local torch tensor from its serialized tuple form.

    Args:
        worker: the worker that will own the reconstructed tensor.
        tensor_tuple (tuple): a 9-tuple
            ``(id, tensor_bin, chain, grad_chain, tags, description,
            serializer, origin, id_at_origin)``. ``serializer`` is detailed
            with ``serde._detail`` and, together with ``tensor_bin``, handed to
            ``_deserialize_tensor``. ``grad_chain`` is a serialized gradient
            tuple of the same format (detailed recursively), or None.
            ``chain`` is the serialized chain of tensor abstractions, or None.
            ``tags``, ``description``, ``origin`` and ``id_at_origin`` are each
            passed through ``serde._detail`` and stored on the tensor.

    Returns:
        torch.Tensor: the deserialized tensor.
    """
    (
        obj_id,
        payload,
        child_chain,
        grad_tuple,
        raw_tags,
        raw_description,
        raw_serializer,
        raw_origin,
        raw_id_at_origin,
    ) = tensor_tuple

    detailed_serializer = serde._detail(worker, raw_serializer)
    result = _deserialize_tensor(worker, detailed_serializer, payload)

    # The binary payload does not carry .grad, so the gradient is rebuilt
    # separately from its own serialized tuple.
    if grad_tuple is not None:
        result.grad = _detail_torch_tensor(worker, grad_tuple)

    initialize_tensor(
        hook=syft.torch.hook,
        obj=result,
        owner=worker,
        id=obj_id,
        init_args=[],
        init_kwargs={},
    )

    if child_chain is not None:
        result.child = serde._detail(worker, child_chain)
        result.is_wrapper = True

    # Remaining metadata is detailed and attached in this fixed order.
    for attr_name, raw_value in (
        ("tags", raw_tags),
        ("description", raw_description),
        ("origin", raw_origin),
        ("id_at_origin", raw_id_at_origin),
    ):
        setattr(result, attr_name, serde._detail(worker, raw_value))

    return result
def new___init__(cls, *args, owner=None, id=None, register=True, **kwargs):
    """Route construction through ``initialize_tensor`` (presumably installed as a tensor type's ``__init__``).

    ``hook_self`` and ``is_tensor`` are not defined in this block; they are
    looked up from an enclosing/global scope.

    Args:
        cls: the object being initialized, forwarded as ``cls``.
        *args: positional arguments, forwarded as ``init_args``.
        owner: accepted but not forwarded.
        id: forwarded as ``id``.
        register: accepted but not forwarded.
        **kwargs: remaining keyword arguments, forwarded as ``init_kwargs``.
    """
    # NOTE(review): `owner` and `register` are silently dropped here; confirm
    # that is intentional.
    initialize_tensor(
        init_kwargs=kwargs,
        init_args=args,
        is_tensor=is_tensor,
        id=id,
        cls=cls,
        hook_self=hook_self,
    )
def unbufferize(worker: AbstractWorker, protobuf_tensor: "TorchTensorPB") -> torch.Tensor:
    """Convert a Protobuf torch tensor message back into a torch tensor.

    The tensor contents can be deserialized from binary representations
    produced by Torch or Numpy, or from the generic Protobuf message format.
    The serializer used is read from ``protobuf_tensor.serializer`` and mapped
    through ``SERIALIZERS_PROTOBUF_TO_SYFT``.

    Args:
        worker: the worker that will own the reconstructed tensor.
        protobuf_tensor: a ``TorchTensorPB`` message (not raw bytes).

    Returns:
        torch.Tensor: the reconstructed tensor, with ``tags`` set to a ``set``
        of the message's tags and ``description`` copied from the message.

    Raises:
        TypeError: if no field of the ``contents`` oneof is set.
    """
    tensor_id = get_protobuf_id(protobuf_tensor.id)
    tags = protobuf_tensor.tags
    description = protobuf_tensor.description

    # WhichOneof returns None when no member of the oneof is set; without this
    # check, getattr(protobuf_tensor, None) would fail with an obscure error.
    contents_type = protobuf_tensor.WhichOneof("contents")
    if contents_type is None:
        raise TypeError("TorchTensorPB message has no 'contents' field set")
    serialized_tensor = getattr(protobuf_tensor, contents_type)

    serializer = SERIALIZERS_PROTOBUF_TO_SYFT[protobuf_tensor.serializer]
    tensor = _deserialize_tensor(worker, serializer, serialized_tensor)

    # note we need to do this explicitly because torch.load does not
    # include .grad information
    if protobuf_tensor.HasField("grad_chain"):
        grad_chain = protobuf_tensor.grad_chain
        tensor.grad = TorchTensorWrapper.unbufferize(worker, grad_chain)

    initialize_tensor(
        hook=syft.torch.hook,
        obj=tensor,
        owner=worker,
        id=tensor_id,
        init_args=[],
        init_kwargs={},
    )

    if protobuf_tensor.HasField("chain"):
        chain = protobuf_tensor.chain
        tensor.child = TorchTensorWrapper.unbufferize(worker, chain)
        tensor.is_wrapper = True

    tensor.tags = set(tags)
    tensor.description = description
    return tensor