def test_method_on_log_chain():
    """
    Check that a tensor method invoked on a Wrapper>LoggingTensor>TorchTensor
    chain yields the same values as invoking it on the bare torch tensor.
    """
    base = torch.Tensor([1, 2, 3])
    # Wrapping produces the chain Wrapper > LoggingTensor > TorchTensor.
    wrapped = LoggingTensor().on(base)
    product = wrapped.mul(wrapped)
    expected = base.mul(base)
    # Two .child hops reach the innermost torch tensor of the result chain.
    assert (product.child.child == expected).all()
def test_print_log_chain():
    """
    Test that the string representations (__str__ and __repr__) of a chain
    including a LoggingTensor are actual strings.
    """
    # Build a chain tensor Wrapper>LoggingTensor>TorchTensor
    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)
    # Call the dunders directly (rather than str()/repr()) so a non-str return
    # value shows up as an assertion failure instead of a TypeError.
    assert isinstance(x.__str__(), str)
    assert isinstance(x.__repr__(), str)
def test_send_get_log_chain(workers):
    """
    Round-trip a Wrapper>LoggingTensor>TorchTensor chain to the worker "bob"
    and back, then check the innermost tensor still holds the original values.
    """
    original = torch.Tensor([1, 2, 3])
    logged = LoggingTensor().on(original)
    remote = logged.send(workers["bob"])
    returned = remote.get()
    assert (returned.child.child == original).all()
def test_remote_method_on_log_chain(workers):
    """
    Apply F.relu through a pointer to a remote Wrapper>LoggingTensor>TorchTensor
    chain and check the retrieved result matches relu on the local tensor.
    """
    original = torch.Tensor([1, 2, 3])
    logged = LoggingTensor().on(original)
    remote = logged.send(workers["bob"])
    # The function is dispatched through the pointer to the remote chain.
    remote_result = F.relu(remote)
    result = remote_result.get()
    expected = F.relu(original)
    assert (result.child.child == expected).all()
def test_function_on_log_chain():
    """
    Check that a torch functional call (F.relu) applied to a chain containing a
    LoggingTensor produces the expected values in the innermost tensor.
    """
    logged = LoggingTensor().on(torch.Tensor([1, -1, 3]))
    activated = F.relu(logged)
    # relu clamps the single negative entry to zero.
    expected = torch.Tensor([1, 0, 3])
    assert (activated.child.child == expected).all()
def _detail_log_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> LoggingTensor:
    """
    This function reconstructs a LoggingTensor given its attributes in the form of a tuple.

    Args:
        worker: the worker doing the deserialization; it becomes the owner of
            the reconstructed tensor.
        tensor_tuple: a 2-tuple ``(obj_id, chain)`` where ``obj_id`` is the id to
            give the LoggingTensor and ``chain`` is its child in the form accepted
            by ``_detail``, or None when there is no child to restore.

    Returns:
        LoggingTensor: the reconstructed LoggingTensor.

    Examples:
        logtensor = _detail_log_tensor(worker, data)
    """
    obj_id, chain = tensor_tuple

    tensor = LoggingTensor(owner=worker, id=obj_id)

    # The child is handed to the generic _detail dispatcher so that whatever
    # tensor type sits below the LoggingTensor is rebuilt before being attached.
    if chain is not None:
        chain = _detail(worker, chain)
        tensor.child = chain

    return tensor
def test_wrap():
    """
    Check that .on() wraps a LoggingTensor so that the resulting chain is
    torch.Tensor (wrapper) > LoggingTensor > torch.Tensor.
    """
    raw = torch.Tensor([1, 2, 3])
    wrapped = LoggingTensor().on(raw)
    # Walk the chain from the outer wrapper inward, checking each layer's type.
    assert isinstance(wrapped, torch.Tensor)
    logging_layer = wrapped.child
    assert isinstance(logging_layer, LoggingTensor)
    assert isinstance(logging_layer.child, torch.Tensor)
def test_hook_module_functional_on_log_chain(attr):
    """
    For the torch.nn.functional function whose name is ``attr``, check that
    calling it on a LoggingTensor chain matches calling it on a plain tensor.
    """
    # Resolve the function once into its own local instead of rebinding `attr`.
    func = getattr(F, attr)
    plain = torch.Tensor([1, -1, 3, 4])
    expected = func(plain)
    logged = LoggingTensor().on(plain)
    logged_result = func(logged)
    result = logged_result.child.child
    assert (result == expected).all()
def test_implicit_garbage_collect_logging_on_pointer(workers):
    """
    Tests whether GCing a LoggingTensor on a PointerTensor
    garbage collects the remote object too.

    Dropping the last local reference (by rebinding ``x``) must remove the
    object from bob's ``_objects`` store.
    """
    x = torch.Tensor([1, 2])
    # Record the id before sending: bob's _objects is checked for this id below.
    x_id = x.id
    # NOTE: x is deliberately rebound at every step rather than using new names,
    # so that x is the only local reference to the chain. Presumably this relies
    # on CPython reference counting collecting the chain as soon as x is rebound.
    x = x.send(workers["bob"])
    x = LoggingTensor().on(x)
    assert x_id in workers["bob"]._objects
    # Rebind x to an unrelated value to drop the last reference to the
    # LoggingTensor > PointerTensor chain; the remote object should go with it.
    x = "open-source"
    assert x_id not in workers["bob"]._objects
def test_explicit_garbage_collect_logging_on_pointer(workers):
    """
    Tests whether deleting a LoggingTensor on a PointerTensor
    garbage collects the remote object too.

    An explicit ``del`` of the only local reference must remove the object
    from bob's ``_objects`` store.
    """
    bob = workers["bob"]

    x = torch.Tensor([1, 2])
    # Record the id before sending: bob's _objects is checked for this id below.
    x_id = x.id
    # NOTE: x is deliberately rebound at every step so it remains the only local
    # reference to the chain; otherwise `del x` would not release it.
    x = x.send(bob)
    x = LoggingTensor().on(x)
    assert x_id in bob._objects

    # Deleting the last reference to the LoggingTensor > PointerTensor chain is
    # expected to trigger removal of the remote object on bob.
    del x

    assert x_id not in bob._objects
LoggingTensor: lambda i: i.child, FixedPrecisionTensor: lambda i: i.child, AutogradTensor: lambda i: i.child, AdditiveSharingTensor: lambda i: i.child, MultiPointerTensor: lambda i: i.child, LargePrecisionTensor: lambda i: i._internal_representation_to_large_ints(), "my_syft_tensor_type": lambda i: i.child, } # Dict to return the proper lambda function for the right torch or syft tensor type backward_func = { TorchTensor: lambda i: i.wrap(), torch.Tensor: lambda i: i.wrap(), torch.nn.Parameter: lambda i: torch.nn.Parameter(data=i), PointerTensor: lambda i: i, LoggingTensor: lambda i: LoggingTensor().on(i, wrap=False), FixedPrecisionTensor: lambda i, **kwargs: FixedPrecisionTensor(**kwargs).on(i, wrap=False), LargePrecisionTensor: lambda i, **kwargs: LargePrecisionTensor(**kwargs).on( LargePrecisionTensor.create_tensor_from_numpy(i, **kwargs), wrap=False ), AutogradTensor: lambda i: AutogradTensor(data=i).on(i, wrap=False), AdditiveSharingTensor: lambda i, **kwargs: AdditiveSharingTensor(**kwargs).on(i, wrap=False), MultiPointerTensor: lambda i, **kwargs: MultiPointerTensor(**kwargs).on(i, wrap=False), "my_syft_tensor_type": lambda i, **kwargs: "my_syft_tensor_type(**kwargs).on(i, wrap=False)", } # Methods or functions whose signature changes a lot and that we don't want to "cache", because # they have an arbitrary number of tensors in args which can trigger unexpected behaviour variable_methods = {"__getitem__", "_getitem_public", "view", "permute", "add_", "sub_"} variable_functions = {"torch.unbind", "unbind", "torch.stack", "stack", "torch.mean", "torch.sum"}