Ejemplo n.º 1
0
def _detail_plan_pointer(worker: AbstractWorker,
                         plan_pointer_tuple: tuple) -> PointerTensor:
    """
    Reconstruct a PlanPointer from its attributes stored in a tuple.

    If the pointer targets ``worker`` itself, the referenced object is
    fetched from ``worker`` and returned instead of a pointer.

    Args:
        worker: the worker doing the deserialization
        plan_pointer_tuple: a tuple whose first three elements are
            ``(obj_id, id_at_location, worker_id)``. ``id_at_location`` and
            ``worker_id`` may arrive as ``bytes``; they are then decoded as
            UTF-8, other types are used unchanged.
    Returns:
        The object stored on ``worker`` under ``id_at_location`` when the
        pointer targets ``worker``; otherwise a new, registered PlanPointer.
        NOTE(review): the return annotation says PointerTensor, but the
        remote branch builds a PlanPointer — confirm and update the hint.
    Examples:
        ptr = _detail_plan_pointer(worker, data)
    """
    # TODO: fix the comment of the matching simplifier

    def _as_str(value):
        # Serialized strings may come back as raw bytes; leave any other
        # type (e.g. non-string ids) untouched instead of crashing on it.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    obj_id = plan_pointer_tuple[0]
    id_at_location = _as_str(plan_pointer_tuple[1])
    worker_id = _as_str(plan_pointer_tuple[2])

    # The pointer targets this very worker: hand back the object itself.
    if worker_id == worker.id:
        return worker.get_obj(id_at_location)

    # Otherwise rebuild (and register) the pointer to the remote location.
    location = syft.torch.hook.local_worker.get_worker(worker_id)
    return PlanPointer(location=location,
                       id_at_location=id_at_location,
                       owner=worker,
                       id=obj_id,
                       register=True)
Ejemplo n.º 2
0
def _detail_pointer_tensor(worker: AbstractWorker,
                           tensor_tuple: tuple) -> PointerTensor:
    """
    Reconstruct a PointerTensor from its attributes stored in a tuple.

    If the pointer targets ``worker`` itself, the referenced object is
    loaded from ``worker`` instead (following ``point_to_attr`` if set).

    Args:
        worker: the worker doing the deserialization
        tensor_tuple: a tuple whose first five elements are
            ``(obj_id, id_at_location, worker_id, point_to_attr, shape)``.
            ``id_at_location``, ``worker_id`` and ``point_to_attr`` may arrive
            as ``bytes``; they are then decoded as UTF-8, other types are used
            unchanged. ``shape`` may be None.
    Returns:
        When the pointer targets ``worker``: the stored object, or the
        attribute reached through the dotted ``point_to_attr`` path, wrapped
        via ``.wrap()`` if it is neither a wrapper nor a torch tensor.
        Otherwise a new PointerTensor to the remote location.
    Examples:
        ptr = _detail_pointer_tensor(worker, data)
    """
    # TODO: fix the comment of the matching simplifier

    def _as_str(value):
        # Serialized strings may come back as raw bytes; leave any other
        # type (e.g. non-string ids) untouched instead of crashing on it.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    obj_id = tensor_tuple[0]
    id_at_location = _as_str(tensor_tuple[1])
    worker_id = _as_str(tensor_tuple[2])
    point_to_attr = _as_str(tensor_tuple[3])
    shape = tensor_tuple[4]

    if worker_id != worker.id:
        # Pointer to another worker: keep it as a pointer.
        if shape is not None:
            shape = torch.Size(shape)
        location = syft.torch.hook.local_worker.get_worker(worker_id)
        return PointerTensor(location=location,
                             id_at_location=id_at_location,
                             owner=worker,
                             id=obj_id,
                             shape=shape)

    # Pointer to this worker: load the object instead of a pointer.
    tensor = worker.get_obj(id_at_location)
    if point_to_attr is None or tensor is None:
        return tensor

    # point_to_attr is a dot-separated attribute path; empty segments
    # (e.g. from a leading/trailing dot) are skipped.
    for attr in point_to_attr.split("."):
        if attr:
            tensor = getattr(tensor, attr)

    # A wrapper doesn't need wrapping, and neither does a plain torch tensor.
    # Anything else must be wrapped or it can't be used by other interfaces.
    # The torch check comes first so is_wrapper is only read on non-torch
    # objects.
    if (tensor is not None
            and not isinstance(tensor, torch.Tensor)
            and not tensor.is_wrapper):
        tensor = tensor.wrap()

    return tensor