def test_pointer_tensor_simplify():
    """Test the simplification of PointerTensor.

    The second element of the simplified form must carry, in order, the
    pointer's id, its id_at_location and the id of its owner.
    """
    alice = syft.VirtualWorker(syft.torch.hook, id="alice")
    ptr = pointers.PointerTensor(id=1000, location=alice, owner=alice)

    simplified = serde._simplify(ptr)
    attributes = simplified[1]

    assert attributes[0] == ptr.id
    assert attributes[1] == ptr.id_at_location
    assert attributes[2] == ptr.owner.id
def test_pointer_tensor(hook, workers):
    """Round-trip a PointerTensor through serde.serialize / serde.deserialize.

    The test swaps ``serde._apply_compress_scheme`` for
    ``serde.apply_no_compression`` while it runs. The previous value is
    restored afterwards so the module-global override does not leak into
    other tests.
    """
    # Save the global so it can be put back even if an assertion fails.
    previous_compress_scheme = serde._apply_compress_scheme
    serde._apply_compress_scheme = serde.apply_no_compression
    try:
        t = pointers.PointerTensor(
            id=1000,
            location=workers["alice"],
            owner=workers["alice"],
            id_at_location=12345,
        )
        t_serialized = serde.serialize(t)
        t_serialized_deserialized = serde.deserialize(t_serialized)

        assert t.id == t_serialized_deserialized.id
        assert t.location.id == t_serialized_deserialized.location.id
        assert t.id_at_location == t_serialized_deserialized.id_at_location
    finally:
        serde._apply_compress_scheme = previous_compress_scheme
def _detail_pointer_tensor(worker: AbstractWorker, tensor_tuple: tuple) -> pointers.PointerTensor:
    """Reconstruct a PointerTensor from its simplified tuple form.

    ``tensor_tuple`` is unpacked positionally as
    ``(obj_id, id_at_location, worker_id, point_to_attr, shape, garbage_collect_data)``.

    If the pointer points at ``worker`` itself, no pointer is built. Instead
    the referenced object is looked up on ``worker`` with
    ``worker.get_obj(id_at_location)``. If ``point_to_attr`` is given, its
    dot-separated attribute path is then followed, and the resulting object
    is returned. Otherwise a new PointerTensor owned by ``worker`` is
    returned, located at the worker identified by ``worker_id``.

    Args:
        worker: the worker doing the deserialization.
        tensor_tuple: a tuple holding the PointerTensor attributes in the
            order listed above. ``worker_id`` and ``point_to_attr`` may be
            given as ``bytes`` (decoded as UTF-8) or as ``str``. ``shape``,
            when not None, is converted with ``torch.Size``.

    Returns:
        Either the local object referenced by the pointer (when it points at
        ``worker``; the code guards against this being ``None``) or a
        pointers.PointerTensor.

    Examples:
        ptr = _detail_pointer_tensor(worker, tensor_tuple)
    """
    # TODO: fix the comment of the matching simplifier as well.
    obj_id, id_at_location, worker_id, point_to_attr, shape, garbage_collect_data = tensor_tuple

    if isinstance(worker_id, bytes):
        worker_id = worker_id.decode()

    if shape is not None:
        shape = torch.Size(shape)

    # If the pointer received is pointing at the current worker, we load the
    # tensor instead of building a pointer to ourselves.
    if worker_id == worker.id:
        tensor = worker.get_obj(id_at_location)

        if point_to_attr is not None and tensor is not None:
            # Accept bytes or str, consistent with worker_id handling above,
            # instead of assuming point_to_attr is always bytes.
            if isinstance(point_to_attr, bytes):
                point_to_attr = point_to_attr.decode("utf-8")

            # Follow the dot-separated attribute path; empty segments are skipped.
            for attr in point_to_attr.split("."):
                if len(attr) > 0:
                    tensor = getattr(tensor, attr)

            if tensor is not None:
                if not tensor.is_wrapper and not isinstance(tensor, torch.Tensor):
                    # If the tensor is a wrapper, it doesn't need to be wrapped.
                    # If the tensor isn't a wrapper, BUT it's just a plain torch
                    # tensor, it doesn't need to be wrapped either.
                    # If the tensor is not a wrapper BUT it's also not a torch
                    # tensor, then it needs to be wrapped or else it won't be
                    # able to be used by other interfaces.
                    tensor = tensor.wrap()

        return tensor

    # Else we keep the same Pointer.
    # NOTE(review): point_to_attr is not forwarded to the reconstructed
    # pointer here — confirm this is intended.
    location = syft.torch.hook.local_worker.get_worker(worker_id)

    ptr = pointers.PointerTensor(
        location=location,
        id_at_location=id_at_location,
        owner=worker,
        id=obj_id,
        shape=shape,
        garbage_collect_data=garbage_collect_data,
    )

    return ptr
def test_build_rule_syft_tensors_and_pointers():
    """build_rule maps a torch tensor and a PointerTensor to 1, and the int 42 to 0."""
    ptr = pointers.PointerTensor(
        id=1000,
        location="location",
        owner="owner",
        garbage_collect_data=False,
    )
    args = ([torch.tensor([1, 2]), ptr], 42)
    expected_rule = ([1, 1], 0)

    assert hook_args.build_rule(args) == expected_rule