예제 #1
0
def test_wrap():
    """
    Verify that AutogradTensor().on() builds a wrapped tensor chain.

    Expected layering after wrapping:
        torch.Tensor (wrapper) -> AutogradTensor -> torch.Tensor
    """
    plain = torch.Tensor([1, 2, 3])
    wrapper = AutogradTensor().on(plain)

    # The outermost object must still look like a regular torch.Tensor.
    assert isinstance(wrapper, torch.Tensor)

    # Walk down the .child chain one layer at a time.
    autograd_layer = wrapper.child
    assert isinstance(autograd_layer, AutogradTensor)
    assert isinstance(autograd_layer.child, torch.Tensor)
예제 #2
0
    LargePrecisionTensor: lambda i: i._internal_representation_to_large_ints(),
    "my_syft_tensor_type": lambda i: i.child,
}

# Dict to return the proper lambda function for the right torch or syft tensor type.
# Each value is a callable whose first positional argument `i` is the object to be
# re-wrapped as the key type. Entries whose lambda accepts **kwargs forward those
# keyword arguments to the tensor constructor (and, for LargePrecisionTensor, also to
# create_tensor_from_numpy); entries without **kwargs accept only `i`.
# NOTE(review): presumably the inverse of the forward mapping defined just above
# (its head is outside this view) — confirm against callers.
backward_func = {
    # Torch tensors are re-wrapped through their own .wrap() method.
    TorchTensor: lambda i: i.wrap(),
    torch.Tensor: lambda i: i.wrap(),
    torch.nn.Parameter: lambda i: torch.nn.Parameter(data=i),
    # Pointers are returned as-is.
    PointerTensor: lambda i: i,
    # Syft tensor types: build a fresh instance and attach `i` via .on(..., wrap=False).
    LoggingTensor: lambda i: LoggingTensor().on(i, wrap=False),
    FixedPrecisionTensor: lambda i, **kwargs: FixedPrecisionTensor(**kwargs).on(i, wrap=False),
    # `i` is first converted with create_tensor_from_numpy before being attached.
    LargePrecisionTensor: lambda i, **kwargs: LargePrecisionTensor(**kwargs).on(
        LargePrecisionTensor.create_tensor_from_numpy(i, **kwargs), wrap=False
    ),
    # `i` is passed both as the constructor's `data` argument and to .on().
    AutogradTensor: lambda i: AutogradTensor(data=i).on(i, wrap=False),
    AdditiveSharingTensor: lambda i, **kwargs: AdditiveSharingTensor(**kwargs).on(i, wrap=False),
    MultiPointerTensor: lambda i, **kwargs: MultiPointerTensor(**kwargs).on(i, wrap=False),
    # NOTE(review): template entry — the key is a string rather than a type, and the lambda
    # returns a string literal instead of a tensor. It appears to be a placeholder showing
    # how a new syft tensor type would be registered.
    "my_syft_tensor_type": lambda i, **kwargs: "my_syft_tensor_type(**kwargs).on(i, wrap=False)",
}

# Names whose call signatures vary too much to "cache": they can take an arbitrary
# number of tensors in their args, which can trigger unexpected behaviour if a rule
# computed for one call is reused for another.
variable_methods = {
    "__getitem__",
    "_getitem_public",
    "view",
    "permute",
    "add_",
    "sub_",
}
variable_functions = {
    "torch.unbind",
    "unbind",
    "torch.stack",
    "stack",
    "torch.mean",
    "torch.sum",
}


def hook_method_args(attr, method_self, args, kwargs):
    """Method arguments are sometimes simple types (such as strings or ints) but
    sometimes they are custom Syft tensors such as wrappers (torch.Tensor) or LoggingTensor
    or some other tensor type. Complex types (which have a .child attribute) need to