def test_wrap():
    """
    Test the .on() wrap functionality for AutogradTensor
    """
    x_tensor = torch.Tensor([1, 2, 3])
    x = AutogradTensor().on(x_tensor)
    assert isinstance(x, torch.Tensor)
    assert isinstance(x.child, AutogradTensor)
    assert isinstance(x.child.child, torch.Tensor)
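
# Illustrative sketch (not part of the test suite): test_wrap above exercises the default
# wrap=True behaviour of .on(), which puts a plain torch.Tensor shell on top of the chain.
# The backward_func entries further down call .on(..., wrap=False) instead, which skips that
# outer shell and returns the Syft tensor itself with the given tensor as its child.
# The function name below is hypothetical and exists only for this example.
def example_wrap_modes():
    x_tensor = torch.Tensor([1, 2, 3])

    wrapped = AutogradTensor().on(x_tensor)  # default wrap=True
    # chain: torch.Tensor (wrapper) -> AutogradTensor -> torch.Tensor (data)
    assert isinstance(wrapped, torch.Tensor)
    assert isinstance(wrapped.child, AutogradTensor)
    assert isinstance(wrapped.child.child, torch.Tensor)

    unwrapped = AutogradTensor().on(x_tensor, wrap=False)  # no outer wrapper
    # chain: AutogradTensor -> torch.Tensor (data)
    assert isinstance(unwrapped, AutogradTensor)
    assert isinstance(unwrapped.child, torch.Tensor)
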
    LargePrecisionTensor: lambda i: i._internal_representation_to_large_ints(),
    "my_syft_tensor_type": lambda i: i.child,
}

# Dict to return the proper lambda function for the right torch or syft tensor type
backward_func = {
    TorchTensor: lambda i: i.wrap(),
    torch.Tensor: lambda i: i.wrap(),
    torch.nn.Parameter: lambda i: torch.nn.Parameter(data=i),
    PointerTensor: lambda i: i,
    LoggingTensor: lambda i: LoggingTensor().on(i, wrap=False),
    FixedPrecisionTensor: lambda i, **kwargs: FixedPrecisionTensor(**kwargs).on(i, wrap=False),
    LargePrecisionTensor: lambda i, **kwargs: LargePrecisionTensor(**kwargs).on(
        LargePrecisionTensor.create_tensor_from_numpy(i, **kwargs), wrap=False
    ),
    AutogradTensor: lambda i: AutogradTensor(data=i).on(i, wrap=False),
    AdditiveSharingTensor: lambda i, **kwargs: AdditiveSharingTensor(**kwargs).on(i, wrap=False),
    MultiPointerTensor: lambda i, **kwargs: MultiPointerTensor(**kwargs).on(i, wrap=False),
    "my_syft_tensor_type": lambda i, **kwargs: "my_syft_tensor_type(**kwargs).on(i, wrap=False)",
}

# Methods or functions whose signature changes a lot and that we don't want to "cache", because
# they have an arbitrary number of tensors in args which can trigger unexpected behaviour
variable_methods = {"__getitem__", "_getitem_public", "view", "permute", "add_", "sub_"}
variable_functions = {"torch.unbind", "unbind", "torch.stack", "stack", "torch.mean", "torch.sum"}


def hook_method_args(attr, method_self, args, kwargs):
    """Method arguments are sometimes simple types (such as strings or ints) but
    sometimes they are custom Syft tensors such as wrappers (torch.Tensor) or LoggingTensor
    or some other tensor type. Complex types (which have a .child attribute) need to