Example #1
def test_method_on_log_chain():
    """
    Test method call on a chain including a log tensor
    """

    # Build a long chain tensor Wrapper>LoggingTensor>TorchTensor
    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)
    y = x.mul(x)

    assert (y.child.child == x_tensor.mul(x_tensor)).all()
Example #2
def test_print_log_chain():
    """
    Test printing a chain including a LoggingTensor
    """

    # Build a long chain tensor Wrapper>LoggingTensor>TorchTensor
    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)

    assert isinstance(x.__str__(), str)
    assert isinstance(x.__repr__(), str)
Example #3
def test_send_get_log_chain(workers):
    """
    Test sending and getting back a chain including a LoggingTensor
    """

    # Build a long chain tensor Wrapper>LoggingTensor>TorchTensor
    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)
    x_ptr = x.send(workers["bob"])
    x_back = x_ptr.get()

    assert (x_back.child.child == x_tensor).all()
Example #4
def test_remote_method_on_log_chain(workers):
    """
    Test remote method call on a chain including a log tensor
    """

    # Build a long chain tensor Wrapper>LoggingTensor>TorchTensor
    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)
    x_ptr = x.send(workers["bob"])
    y_ptr = F.relu(x_ptr)
    y = y_ptr.get()

    assert (y.child.child == F.relu(x_tensor)).all()
Example #5
def test_function_on_log_chain():
    """
    Test torch function call on a chain including a log tensor
    """

    x = LoggingTensor().on(torch.Tensor([1, -1, 3]))
    y = F.relu(x)

    assert (y.child.child == torch.Tensor([1, 0, 3])).all()
Example #6
def _detail_log_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> LoggingTensor:
    """
    This function reconstructs a LogTensor given it's attributes in form of a tuple.
    Args:
        worker: the worker doing the deserialization
        tensor_tuple: a tuple holding the attributes of the LogTensor
    Returns:
        LoggingTensor: a LogTensor
    Examples:
        logtensor = _detail_log_tensor(data)
    """
    obj_id, chain = tensor_tuple

    tensor = LoggingTensor(owner=worker, id=obj_id)

    if chain is not None:
        chain = _detail(worker, chain)
        tensor.child = chain

    return tensor
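
For orientation, the serialization side that produces tensor_tuple is not shown in this
listing. Below is a minimal sketch of what it might look like, inferred from the
(obj_id, chain) layout unpacked above; the name _simplify_log_tensor and the generic
_simplify helper are assumptions, not confirmed API:

def _simplify_log_tensor(tensor: LoggingTensor) -> tuple:
    # Hypothetical counterpart to _detail_log_tensor: pack the tensor's id
    # and its (recursively simplified) child chain into the tuple format
    # unpacked above.
    chain = None
    if hasattr(tensor, "child"):
        chain = _simplify(tensor.child)  # assumed generic simplify helper
    return (tensor.id, chain)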
Example #7
def test_wrap():
    """
    Test the .on() wrap functionality for LoggingTensor
    """

    x_tensor = torch.Tensor([1, 2, 3])
    x = LoggingTensor().on(x_tensor)

    assert isinstance(x, torch.Tensor)
    assert isinstance(x.child, LoggingTensor)
    assert isinstance(x.child.child, torch.Tensor)
Example #8
def test_hook_module_functional_on_log_chain(attr):
    """
    Test a hooked torch.nn.functional call on a chain including a log tensor
    """

    attr = getattr(F, attr)
    x = torch.Tensor([1, -1, 3, 4])
    expected = attr(x)

    x_log = LoggingTensor().on(x)
    res_log = attr(x_log)
    res = res_log.child.child

    assert (res == expected).all()
Example #9
def test_implicit_garbage_collect_logging_on_pointer(workers):
    """
    Tests whether GCing a LoggingTensor on a PointerTensor
    garbage collects the remote object too
    """
    x = torch.Tensor([1, 2])
    x_id = x.id

    x = x.send(workers["bob"])
    x = LoggingTensor().on(x)
    assert x_id in workers["bob"]._objects

    x = "open-source"

    assert x_id not in workers["bob"]._objects
Example #10
def test_explicit_garbage_collect_logging_on_pointer(workers):
    """
    Tests whether deleting a LoggingTensor on a PointerTensor
    garbage collects the remote object too
    """
    bob = workers["bob"]

    x = torch.Tensor([1, 2])
    x_id = x.id

    x = x.send(bob)
    x = LoggingTensor().on(x)
    assert x_id in bob._objects

    del x

    assert x_id not in bob._objects
Example #11
    LoggingTensor: lambda i: i.child,
    FixedPrecisionTensor: lambda i: i.child,
    AutogradTensor: lambda i: i.child,
    AdditiveSharingTensor: lambda i: i.child,
    MultiPointerTensor: lambda i: i.child,
    LargePrecisionTensor: lambda i: i._internal_representation_to_large_ints(),
    "my_syft_tensor_type": lambda i: i.child,
}

# Dict to return the proper lambda function for the right torch or syft tensor type
backward_func = {
    TorchTensor: lambda i: i.wrap(),
    torch.Tensor: lambda i: i.wrap(),
    torch.nn.Parameter: lambda i: torch.nn.Parameter(data=i),
    PointerTensor: lambda i: i,
    LoggingTensor: lambda i: LoggingTensor().on(i, wrap=False),
    FixedPrecisionTensor: lambda i, **kwargs: FixedPrecisionTensor(**kwargs).on(i, wrap=False),
    LargePrecisionTensor: lambda i, **kwargs: LargePrecisionTensor(**kwargs).on(
        LargePrecisionTensor.create_tensor_from_numpy(i, **kwargs), wrap=False
    ),
    AutogradTensor: lambda i: AutogradTensor(data=i).on(i, wrap=False),
    AdditiveSharingTensor: lambda i, **kwargs: AdditiveSharingTensor(**kwargs).on(i, wrap=False),
    MultiPointerTensor: lambda i, **kwargs: MultiPointerTensor(**kwargs).on(i, wrap=False),
    "my_syft_tensor_type": lambda i, **kwargs: "my_syft_tensor_type(**kwargs).on(i, wrap=False)",
}
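
As a rough illustration of how the two dicts fit together (the unwrapping entries at the
top of this example and backward_func below them), the helper here is a hypothetical
sketch, not part of the library: it unwraps a syft tensor, applies a plain torch
operation, and rewraps the result with the matching backward_func entry.

def apply_native_op(wrapped, op):
    # Hypothetical round trip: drop down to the raw torch.Tensor, compute,
    # then rebuild the same syft tensor type around the result.
    raw = wrapped.child                        # e.g. LoggingTensor -> torch.Tensor
    result = op(raw)                           # native torch computation
    return backward_func[type(wrapped)](result)

# e.g. apply_native_op(LoggingTensor().on(torch.Tensor([1, -1, 3]), wrap=False), F.relu)
# returns a LoggingTensor whose child is torch.Tensor([1, 0, 3])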

# Methods or functions whose signatures vary a lot and that we don't want to "cache", because
# they can take an arbitrary number of tensors in their args, which can trigger unexpected behaviour
variable_methods = {"__getitem__", "_getitem_public", "view", "permute", "add_", "sub_"}
variable_functions = {"torch.unbind", "unbind", "torch.stack", "stack", "torch.mean", "torch.sum"}
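
To see why these are excluded from caching: a function like torch.stack takes a
variable-length sequence of tensors, so an argument-unwrapping plan recorded for a
two-tensor call would not fit a three-tensor call. A minimal sketch (the unwrap_all
helper is hypothetical):

a = LoggingTensor().on(torch.Tensor([1, 2]), wrap=False)
b = LoggingTensor().on(torch.Tensor([3, 4]), wrap=False)

def unwrap_all(args):
    # Re-derive the unwrap step on every call: len(args) is not fixed.
    return tuple(t.child for t in args)

torch.stack(unwrap_all((a, b)))     # two tensors this call...
torch.stack(unwrap_all((a, b, a)))  # ...three the next, so no cached plan fits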