def _detail_additive_shared_tensor(
    worker: AbstractWorker, tensor_tuple: tuple
) -> AdditiveSharingTensor:
    """Rebuild an AdditiveSharingTensor from its simplified tuple form.

    Args:
        worker: the worker performing the deserialization. It becomes the
            owner of the rebuilt tensor and is used to resolve the crypto
            provider via ``worker.get_worker``.
        tensor_tuple: a 4-tuple ``(tensor_id, field, crypto_provider, chain)``
            holding the attributes of the AdditiveSharingTensor. ``chain`` is
            the serialized child (or ``None``); when present it is passed
            through ``_detail`` and attached as the tensor's ``child``.

    Returns:
        AdditiveSharingTensor: the reconstructed tensor.

    Examples:
        shared_tensor = _detail_additive_shared_tensor(worker, data)
    """
    tensor_id, field, crypto_provider, chain = tensor_tuple

    # The crypto provider arrives in serialized form (presumably a worker id;
    # confirm against the matching simplify function) and is resolved to a
    # worker through the deserializing worker.
    provider = worker.get_worker(crypto_provider)
    shared_tensor = AdditiveSharingTensor(
        owner=worker, id=tensor_id, field=field, crypto_provider=provider
    )

    # Without a serialized child there is nothing left to attach.
    if chain is None:
        return shared_tensor

    shared_tensor.child = _detail(worker, chain)
    return shared_tensor
def test_wrap(workers):
    """Check that ``.on()`` wraps an AdditiveSharingTensor in a torch.Tensor.

    The returned object should be a torch.Tensor wrapper whose ``child`` is
    the AdditiveSharingTensor, which in turn has a torch.Tensor ``child``.
    The ``workers`` fixture is requested but not used by this test.
    """
    data = torch.Tensor([1, 2, 3])
    wrapper = AdditiveSharingTensor().on(data)

    # Walk the chain: wrapper -> AdditiveSharingTensor -> torch.Tensor.
    assert isinstance(wrapper, torch.Tensor)
    shared = wrapper.child
    assert isinstance(shared, AdditiveSharingTensor)
    assert isinstance(shared.child, torch.Tensor)
"my_syft_tensor_type": lambda i: i.child, } # Dict to return the proper lambda function for the right torch or syft tensor type backward_func = { TorchTensor: lambda i: i.wrap(), torch.Tensor: lambda i: i.wrap(), torch.nn.Parameter: lambda i: torch.nn.Parameter(data=i), PointerTensor: lambda i: i, LoggingTensor: lambda i: LoggingTensor().on(i, wrap=False), FixedPrecisionTensor: lambda i, **kwargs: FixedPrecisionTensor(**kwargs).on(i, wrap=False), LargePrecisionTensor: lambda i, **kwargs: LargePrecisionTensor(**kwargs).on( LargePrecisionTensor.create_tensor_from_numpy(i, **kwargs), wrap=False ), AutogradTensor: lambda i: AutogradTensor(data=i).on(i, wrap=False), AdditiveSharingTensor: lambda i, **kwargs: AdditiveSharingTensor(**kwargs).on(i, wrap=False), MultiPointerTensor: lambda i, **kwargs: MultiPointerTensor(**kwargs).on(i, wrap=False), "my_syft_tensor_type": lambda i, **kwargs: "my_syft_tensor_type(**kwargs).on(i, wrap=False)", } # Methods or functions whose signature changes a lot and that we don't want to "cache", because # they have an arbitrary number of tensors in args which can trigger unexpected behaviour variable_methods = {"__getitem__", "_getitem_public", "view", "permute", "add_", "sub_"} variable_functions = {"torch.unbind", "unbind", "torch.stack", "stack", "torch.mean", "torch.sum"} def hook_method_args(attr, method_self, args, kwargs): """Method arguments are sometimes simple types (such as strings or ints) but sometimes they are custom Syft tensors such as wrappers (torch.Tensor) or LoggingTensor or some other tensor type. Complex types (which have a .child attribute) need to have arguments converted from the arg to arg.child so that the types match as the