示例#1
0
def _detail_pointer_tensor(worker: AbstractWorker,
                           tensor_tuple: tuple) -> PointerTensor:
    """
    Reconstruct a PointerTensor from the tuple produced by its simplifier.

    If the pointer targets ``worker`` itself, the object it points to is
    fetched with ``worker.get_obj`` and returned instead of a pointer. When a
    dotted ``point_to_attr`` path is present, it is followed with ``getattr``
    and a result that is neither a torch tensor nor a wrapper gets wrapped.
    Otherwise a new PointerTensor owned by ``worker`` is returned.

    Args:
        worker: the worker doing the deserialization
        tensor_tuple: the simplified attributes of the PointerTensor, read as
            ``(obj_id, id_at_location, worker_id, point_to_attr, shape, ...)``.
            ``worker_id`` and a non-None ``point_to_attr`` are UTF-8 encoded
            bytes. Elements after index 4 are not used here.
    Returns:
        A PointerTensor when the pointer targets another worker. Otherwise the
        local object (possibly attribute-resolved and/or wrapped), which is
        None if ``worker.get_obj`` returned None.
    Examples:
        ptr = _detail_pointer_tensor(worker, tensor_tuple)
    """
    obj_id = tensor_tuple[0]
    id_at_location = tensor_tuple[1]
    worker_id = tensor_tuple[2].decode("utf-8")
    point_to_attr = tensor_tuple[3]
    shape = tensor_tuple[4]

    if shape is not None:
        shape = torch.Size(shape)

    # The pointer targets another worker: keep it as a pointer.
    if worker_id != worker.id:
        location = syft.torch.hook.local_worker.get_worker(worker_id)
        return PointerTensor(location=location,
                             id_at_location=id_at_location,
                             owner=worker,
                             id=obj_id,
                             shape=shape)

    # The pointer targets this worker: load the object instead of a pointer.
    tensor = worker.get_obj(id_at_location)
    if point_to_attr is None or tensor is None:
        return tensor

    # Follow the dotted attribute path, skipping empty segments.
    for attr in point_to_attr.decode("utf-8").split("."):
        if attr:
            tensor = getattr(tensor, attr)

    # Wrappers and plain torch tensors can be used as-is. Anything else must be
    # wrapped, or it won't be usable by other interfaces. The isinstance check
    # comes first so `is_wrapper` is only read on non-torch objects.
    if (tensor is not None
            and not isinstance(tensor, torch.Tensor)
            and not tensor.is_wrapper):
        tensor = tensor.wrap()

    return tensor
示例#2
0
def test_pointer_found_exception(workers):
    # Raise RemoteTensorFoundError with a pointer and check that the caught
    # error exposes a PointerTensor with the same id via `.pointer`.
    expected_id = int(10e10 * random.random())
    alice, me = workers["alice"], workers["me"]
    pointer = PointerTensor(id=expected_id, location=alice, owner=me)

    caught = None
    try:
        raise RemoteTensorFoundError(pointer)
    except RemoteTensorFoundError as exc:
        caught = exc

    carried = caught.pointer
    assert isinstance(carried, PointerTensor)
    assert carried.id == expected_id
示例#3
0
def test_pointer_tensor_simplify():
    """Test the simplification of PointerTensor.

    ``output[1]`` is expected to hold the pointer's attribute tuple, starting
    with ``(id, id_at_location, id of the worker pointed to, ...)``.
    """

    alice = syft.VirtualWorker(syft.torch.hook, id="alice")
    input_tensor = PointerTensor(id=1000, location=alice, owner=alice)

    output = _simplify(input_tensor)
    attrs = output[1]

    assert attrs[0] == input_tensor.id
    assert attrs[1] == input_tensor.id_at_location
    # Index 2 is the id of the worker the pointer points to (its location),
    # which the detailer compares against the receiving worker. Owner and
    # location coincide here, but asserting against location states the
    # intended contract.
    assert attrs[2] == input_tensor.location.id
示例#4
0
def test_PointerTensor(hook, workers):
    # Serialize a PointerTensor without compression, read it back, and check
    # that id, location id and id_at_location all survive the round trip.
    alice = workers["alice"]
    original = PointerTensor(
        id=1000, location=alice, owner=alice, id_at_location=12345
    )

    payload = serialize(original, compress=False)
    restored = deserialize(payload, compressed=False)

    print(f"t.location - {original.location}")
    print(f"t_serialized_deserialized.location - {restored.location}")

    assert original.id == restored.id
    assert original.location.id == restored.location.id
    assert original.id_at_location == restored.id_at_location
示例#5
0
def test_pointer_found_exception(workers):
    """Check that RemoteTensorFoundError exposes the pointer it was raised with."""
    expected_id = syft.ID_PROVIDER.pop()
    location, owner = workers["alice"], workers["me"]
    pointer = PointerTensor(id=expected_id, location=location, owner=owner)

    try:
        raise RemoteTensorFoundError(pointer)
    except RemoteTensorFoundError as error:
        found = error.pointer
        assert isinstance(found, PointerTensor)
        assert found.id == expected_id
示例#6
0
def test_PointerTensor(hook, workers):
    """Round-trip a PointerTensor through serialize/deserialize without compression."""
    # Temporarily swap in the no-compression scheme. The previous scheme is
    # restored in `finally` so this module-level patch does not leak into
    # tests that run afterwards.
    previous_scheme = syft.serde._apply_compress_scheme
    syft.serde._apply_compress_scheme = apply_no_compression
    try:
        t = PointerTensor(
            id=1000, location=workers["alice"], owner=workers["alice"], id_at_location=12345
        )
        t_serialized = serialize(t)
        t_serialized_deserialized = deserialize(t_serialized)
    finally:
        syft.serde._apply_compress_scheme = previous_scheme

    print(f"t.location - {t.location}")
    print(f"t_serialized_deserialized.location - {t_serialized_deserialized.location}")
    assert t.id == t_serialized_deserialized.id
    assert t.location.id == t_serialized_deserialized.location.id
    assert t.id_at_location == t_serialized_deserialized.id_at_location
示例#7
0
def test_init(workers):
    """Constructing a PointerTensor and converting it to a string must not fail."""
    pointer = PointerTensor(id=1000, location=workers["alice"], owner=workers["me"])
    # Use the builtin rather than calling the dunder directly: str() also
    # checks that __str__ returns an actual str and raises TypeError otherwise.
    str(pointer)
示例#8
0
文件: native.py 项目: tyrinwu/PySyft
    def create_pointer(
        self,
        location: BaseWorker = None,
        id_at_location: (str or int) = None,
        register: bool = False,
        owner: BaseWorker = None,
        ptr_id: (str or int) = None,
        garbage_collect_data: bool = True,
        shape=None,
    ) -> PointerTensor:
        """Creates a pointer to the "self" torch.Tensor object.

        This method is called on a torch.Tensor object, returning a pointer
        to that object. This method is the CORRECT way to create a pointer,
        and the parameters of this method give all possible attributes that
        a pointer can be created with.

        Args:
            location: The BaseWorker object which points to the worker on which
                this pointer's object can be found. Defaults to self.owner,
                so this attribute can usually be left blank. Very rarely you
                may know that you are about to move the Tensor to another
                worker, so you can pre-initialize the location attribute of
                the pointer to some other worker, but this is a rare exception.
            id_at_location: A string or integer id of the tensor being pointed
                to. Defaults to self.id, so you can usually leave this
                parameter as None. The only exception is if you happen to know
                that the ID is going to be something different than self.id,
                but again this is very rare, and setting it most of the time
                means you are probably doing something you shouldn't.
            register: A boolean parameter (default False) that determines
                whether to register the new pointer that gets created. This is
                set to False by default because most of the time a pointer is
                initialized in this way so that it can be sent to someone else
                (i.e., "Oh you need to point to my tensor? let me create a
                pointer and send it to you"). Thus, when a pointer gets
                created, we want to skip being registered on the local worker
                because the pointer is about to be sent elsewhere. However, if
                you are initializing a pointer you intend to keep, then it is
                probably a good idea to register it, especially if there is any
                chance that someone else will initialize a pointer to your
                pointer.
            owner: A BaseWorker parameter to specify the worker on which the
                pointer is located. It is also where the pointer is registered
                if register is set to True. Defaults to self.owner.
            ptr_id: A string or integer parameter to specify the id of the
                pointer in case you wish to set it manually for any special
                reason. If omitted, it defaults to self.id when the location
                is a worker other than self.owner, and to a random integer
                otherwise.
            garbage_collect_data: If true (default), delete the remote tensor
                when the pointer is deleted.
            shape: The shape recorded on the pointer. Defaults to self.shape.

        Returns:
            The newly constructed PointerTensor. Its parent is self, and it
            copies self.tags and self.description. This method does not wrap
            it in a torch.Tensor.
        """
        if owner is None:
            owner = self.owner

        if location is None:
            location = self.owner.id

        # get_worker receives either an id (the location default) or a worker
        # object (the owner default); both are resolved to a worker here.
        owner = self.owner.get_worker(owner)
        location = self.owner.get_worker(location)

        if id_at_location is None:
            id_at_location = self.id

        if ptr_id is None:
            if location.id != self.owner.id:
                ptr_id = self.id
            else:
                ptr_id = int(10e10 * random.random())

        if shape is None:
            shape = self.shape

        # Always build a fresh pointer. The disabled reuse path (a lookup via
        # owner.get_pointer_to) used to leave the result unbound whenever a
        # previous pointer was found, so it was dropped rather than kept as
        # dead code.
        return PointerTensor(
            parent=self,
            location=location,
            id_at_location=id_at_location,
            register=register,
            owner=owner,
            id=ptr_id,
            garbage_collect_data=garbage_collect_data,
            shape=shape,
            tags=self.tags,
            description=self.description,
        )