def _detail_torch_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> torch.Tensor:
    """Rebuild a torch tensor, plus its syft metadata, from a serialized tuple.

    The raw tensor payload is turned back into a tensor by
    ``_deserialize_tensor``. The id, gradient, tags, description and child
    chain are then re-attached.

    Args:
        worker: worker passed as ``owner`` to ``initialize_tensor`` and
            forwarded to the recursive gradient / child-chain deserialization.
        tensor_tuple (tuple): a 6-tuple
            ``(tensor_id, tensor_bin, chain, grad_chain, tags, description)``:

            * ``tensor_id`` - id assigned to the rebuilt tensor.
            * ``tensor_bin`` - binary payload of the PyTorch tensor, decoded
              by ``_deserialize_tensor``.
            * ``chain`` - serialized chain of tensor abstractions, decoded
              with ``syft.serde._detail`` and stored as ``.child``; or ``None``.
            * ``grad_chain`` - a tuple of this same layout describing
              ``.grad`` (which may itself carry ``.grad.grad``, etc.); or
              ``None``.
            * ``tags`` - iterable of tags, each ``str`` or UTF-8 ``bytes``;
              or ``None``.
            * ``description`` - ``str`` or UTF-8 ``bytes``; or ``None``.

    Returns:
        torch.Tensor: the reconstructed tensor.

    Raises:
        ValueError: if ``tensor_tuple`` does not have exactly six elements
            (raised by the tuple unpacking).
    """
    tensor_id, tensor_bin, chain, grad_chain, tags, description = tensor_tuple

    tensor = _deserialize_tensor(tensor_bin)

    # The gradient is restored explicitly: per the original note, torch.load
    # does not include .grad information, so it travels as its own nested
    # tuple and is rebuilt recursively.
    if grad_chain is not None:
        tensor.grad = _detail_torch_tensor(worker, grad_chain)

    # NOTE(review): this call passes ``owner=`` and ``kwargs=``. Another call
    # site of ``initialize_tensor`` passes ``init_kwargs=`` and no owner.
    # Confirm which keywords ``initialize_tensor`` actually accepts.
    initialize_tensor(
        hook_self=syft.torch.hook,
        cls=tensor,
        torch_tensor=True,
        owner=worker,
        id=tensor_id,
        init_args=[],
        kwargs={},
    )

    if tags is not None:
        # Tags may arrive as bytes; normalize every entry to str and always
        # store a list.
        tensor.tags = [
            tag.decode("utf-8") if isinstance(tag, bytes) else tag for tag in tags
        ]

    if description is not None:
        if isinstance(description, bytes):
            description = description.decode("utf-8")
        tensor.description = description

    if chain is not None:
        # A child chain means this tensor wraps another abstraction.
        tensor.child = syft.serde._detail(worker, chain)
        tensor.is_wrapper = True

    return tensor
def new___init__(cls, *args, owner=None, id=None, register=True, **kwargs):
    """Delegate construction of ``cls`` to ``initialize_tensor``.

    Presumably installed as a hooked tensor ``__init__`` (TODO confirm
    against the enclosing hook code). ``hook_self`` and ``torch_tensor`` are
    not parameters. They are resolved from an enclosing or module scope that
    is not visible here.

    Args:
        cls: the object being initialized, forwarded as ``cls``.
        *args: positional constructor arguments, forwarded as ``init_args``.
        owner: accepted but NOT forwarded to ``initialize_tensor``.
        id: forwarded unchanged as ``id``.
        register: accepted but not used.
        **kwargs: remaining keyword arguments, forwarded as ``init_kwargs``.

    Returns:
        None.
    """
    forwarded = {
        "hook_self": hook_self,
        "cls": cls,
        "id": id,
        "torch_tensor": torch_tensor,
        "init_args": args,
        "init_kwargs": kwargs,
    }
    initialize_tensor(**forwarded)