Пример #1
0
def test_pointer_tensor_simplify():
    """Test the simplification of PointerTensor.

    ``serde._simplify`` returns a sequence whose second element holds the
    pointer's fields. The first three fields are the pointer id, the id of the
    object at the remote location, and the id of the location worker.
    """

    alice = syft.VirtualWorker(syft.torch.hook, id="alice")
    input_tensor = pointers.PointerTensor(id=1000, location=alice, owner=alice)

    output = serde._simplify(input_tensor)
    fields = output[1]

    assert fields[0] == input_tensor.id
    assert fields[1] == input_tensor.id_at_location
    # The third field is the id of the worker the pointer points *to*: the
    # detailer rebuilds `location` from it. Owner and location are both alice
    # here, so comparing against `owner.id` only passed by coincidence.
    assert fields[2] == input_tensor.location.id
Пример #2
0
def test_pointer_tensor(hook, workers):
    """Round-trip a PointerTensor through serde.serialize / serde.deserialize.

    Compression is switched off for the duration of the test. The previous
    compression scheme is restored afterwards so that this module-level
    override does not leak into other tests.

    Args:
        hook: fixture requested for its setup side effects (presumably hooking
            torch; not used directly here).
        workers: fixture mapping worker names to workers; "alice" is used as
            both owner and location of the pointer.
    """
    original_compress_scheme = serde._apply_compress_scheme
    serde._apply_compress_scheme = serde.apply_no_compression
    try:
        t = pointers.PointerTensor(
            id=1000,
            location=workers["alice"],
            owner=workers["alice"],
            id_at_location=12345,
        )
        t_serialized = serde.serialize(t)
        t_serialized_deserialized = serde.deserialize(t_serialized)
        assert t.id == t_serialized_deserialized.id
        assert t.location.id == t_serialized_deserialized.location.id
        assert t.id_at_location == t_serialized_deserialized.id_at_location
    finally:
        # Undo the global override even if serialization or an assert fails.
        serde._apply_compress_scheme = original_compress_scheme
Пример #3
0
def _detail_pointer_tensor(worker: AbstractWorker,
                           tensor_tuple: tuple) -> pointers.PointerTensor:
    """
    Reconstruct a PointerTensor from the tuple produced by its simplifier.

    The tuple is unpacked positionally, and its fields are passed as keyword
    arguments to ``pointers.PointerTensor``. If the pointer points at
    ``worker`` itself, no pointer is built. Instead the referenced object is
    looked up locally and returned.

    Args:
        worker: the worker doing the deserialization; it becomes the owner of
            the reconstructed pointer.
        tensor_tuple: a 6-tuple ``(obj_id, id_at_location, worker_id,
            point_to_attr, shape, garbage_collect_data)``. ``worker_id`` and
            ``point_to_attr`` may be ``bytes`` (decoded as UTF-8) or ``str``.
            ``shape`` is ``None`` or a sequence accepted by ``torch.Size``.
    Returns:
        When ``worker_id`` differs from ``worker.id``: a pointers.PointerTensor
        owned by ``worker`` whose location is the worker named ``worker_id``.
        Otherwise: the local object stored under ``id_at_location``, after
        following ``point_to_attr`` and wrapping if needed. This may be
        ``None`` if the lookup or attribute path yields ``None``.
    Examples:
        ptr = _detail_pointer_tensor(worker, tensor_tuple)
    """
    obj_id, id_at_location, worker_id, point_to_attr, shape, garbage_collect_data = tensor_tuple

    if isinstance(worker_id, bytes):
        worker_id = worker_id.decode()

    if shape is not None:
        shape = torch.Size(shape)

    # The pointer points back at this very worker: return the local object
    # instead of a pointer to ourselves.
    if worker_id == worker.id:
        tensor = worker.get_obj(id_at_location)

        if point_to_attr is not None and tensor is not None:
            # Accept both bytes and str, mirroring the worker_id handling above.
            if isinstance(point_to_attr, bytes):
                point_to_attr = point_to_attr.decode("utf-8")

            # Follow a dotted attribute path (e.g. "a.b"); empty segments are skipped.
            for attr in point_to_attr.split("."):
                if attr:
                    tensor = getattr(tensor, attr)

            # A wrapper, or a plain torch tensor, can be used as-is. Anything
            # else has to be wrapped, or other interfaces can't use it.
            # The isinstance test runs first because it is cheap and needs no
            # hooked attribute.
            if (
                tensor is not None
                and not isinstance(tensor, torch.Tensor)
                and not tensor.is_wrapper
            ):
                tensor = tensor.wrap()

        return tensor

    # Otherwise rebuild the pointer: owned by `worker`, aimed at `worker_id`.
    location = syft.torch.hook.local_worker.get_worker(worker_id)

    return pointers.PointerTensor(
        location=location,
        id_at_location=id_at_location,
        owner=worker,
        id=obj_id,
        shape=shape,
        garbage_collect_data=garbage_collect_data,
    )
Пример #4
0
def test_build_rule_syft_tensors_and_pointers():
    """build_rule marks both a torch tensor and a PointerTensor with 1.

    The non-tensor value 42 is marked with 0, and the nesting of the
    arguments is preserved in the returned rule.
    """
    ptr = pointers.PointerTensor(
        id=1000,
        location="location",
        owner="owner",
        garbage_collect_data=False,
    )
    plain = torch.tensor([1, 2])
    args = ([plain, ptr], 42)

    expected_rule = ([1, 1], 0)
    assert hook_args.build_rule(args) == expected_rule