def _simplify_dictionary(worker: AbstractWorker, my_dict: Dict, shallow: bool = False) -> Tuple:
    """
    Turn a dict into a tuple of simplified ``(key, value)`` pairs.

    Every key is passed through ``serde._simplify``. Values are passed through
    ``serde._simplify`` as well, unless ``shallow`` is True, in which case
    they are emitted unchanged. The reverse operation is
    ``_detail_dictionary``.

    Args:
        worker (AbstractWorker): the worker forwarded to ``serde._simplify``.
        my_dict: A dictionary of python objects.
        shallow (bool): when True, values are left as-is (keys are still
            simplified).

    Returns:
        Tuple: a tuple of 2-tuples ``(simplified_key, value_or_simplified_value)``
        in the dict's iteration order.
    """
    # Keys are always simplified; values only when a deep simplification is requested.
    return tuple(
        (serde._simplify(worker, key), value if shallow else serde._simplify(worker, value))
        for key, value in my_dict.items()
    )
def simplify(worker: BaseWorker, subpipeline: "SubPipeline") -> tuple:
    """Simplifies a SubPipeline object.

    This requires simplifying each underlying pipe component.

    Args:
        worker (BaseWorker): The worker on which the simplify operation
            is carried out.
        subpipeline (SubPipeline): the SubPipeline object to simplify.

    Returns:
        (tuple): ``(id, client_id, pipe_names, simple_pipes)``. The first
        three entries are the simplified ``id``, ``client_id`` and
        ``pipe_names`` attributes. ``simple_pipes`` is a list holding one
        ``(msgpack_code, simplified_pipe)`` pair per pipe.
    """
    # Simplify the SubPipeline's own attributes. The local is named
    # ``subpipeline_id`` so it does not shadow the ``id`` builtin.
    subpipeline_id = serde._simplify(worker, subpipeline.id)
    client_id = serde._simplify(worker, subpipeline.client_id)
    pipe_names = serde._simplify(worker, subpipeline.pipe_names)

    # Pair each simplified pipe with its msgpack code (presumably used on the
    # detail side to pick the pipe type to rebuild -- confirm against detail()).
    simple_pipes = [
        (pipe.get_msgpack_code()["code"], pipe.simplify(worker, pipe))
        for pipe in subpipeline.subpipeline
    ]

    return (subpipeline_id, client_id, pipe_names, simple_pipes)
def _simplify_torch_tensor(worker: AbstractWorker, tensor: torch.Tensor) -> tuple:
    """
    Simplify a torch tensor into a tuple for serialization.

    The tensor data is serialized by ``_serialize_tensor``. The gradient,
    child chain, tags, description and the worker's serializer setting are
    simplified alongside it. The gradient is simplified recursively with this
    same function.

    Args:
        worker (AbstractWorker): the worker performing the serialization.
        tensor (torch.Tensor): an input tensor to be serialized.

    Returns:
        tuple: a 7-tuple
        ``(id, tensor_bin, chain, grad_chain, tags, description, serializer)``:

        * ``id``: ``tensor.id``;
        * ``tensor_bin``: the output of ``_serialize_tensor``;
        * ``chain``: the simplified ``tensor.child``, or None when the tensor
          has no ``child`` attribute;
        * ``grad_chain``: this function applied to ``tensor.grad``, or None
          when there is no gradient or when ``tensor.child`` is a PointerTensor;
        * ``tags``, ``description``, ``serializer``: the simplified
          ``tensor.tags``, ``tensor.description`` and ``worker.serializer``.
    """
    tensor_bin = _serialize_tensor(worker, tensor)

    # The gradient is handled explicitly because, per the original author,
    # torch.save does not seem to include ``.grad`` by default.
    # It is skipped when the tensor wraps a PointerTensor.
    skip_grad = tensor.grad is None or (
        hasattr(tensor, "child") and isinstance(tensor.child, PointerTensor)
    )
    grad_chain = None if skip_grad else _simplify_torch_tensor(worker, tensor.grad)

    # TODO fix pointer bug: the original author suspected it originates in
    # this simplification of ``tensor.child``.
    chain = serde._simplify(worker, tensor.child) if hasattr(tensor, "child") else None

    return (
        tensor.id,
        tensor_bin,
        chain,
        grad_chain,
        serde._simplify(worker, tensor.tags),
        serde._simplify(worker, tensor.description),
        serde._simplify(worker, worker.serializer),
    )
def _simplify_torch_parameter(worker: AbstractWorker, param: torch.nn.Parameter) -> tuple:
    """
    Simplify a torch Parameter into a tuple for serialization.

    Args:
        worker (AbstractWorker): the worker performing the serialization.
        param (torch.nn.Parameter): an input Parameter to be serialized.

    Returns:
        tuple: a 4-tuple ``(id, tensor_ser, requires_grad, grad_ser)``:

        * ``id``: ``param.id``;
        * ``tensor_ser``: ``param.data`` passed through ``serde._simplify``;
        * ``requires_grad``: the Parameter's ``requires_grad`` flag;
        * ``grad_ser``: ``param.grad`` simplified with
          ``_simplify_torch_tensor``, or None when there is no gradient or the
          gradient's ``child`` is a PointerTensor.
    """
    tensor_ser = serde._simplify(worker, param.data)

    grad = param.grad
    # Gradients wrapping a PointerTensor are not serialized.
    if grad is None or (hasattr(grad, "child") and isinstance(grad.child, PointerTensor)):
        grad_ser = None
    else:
        grad_ser = _simplify_torch_tensor(worker, grad)

    return (param.id, tensor_ser, param.requires_grad, grad_ser)
def _simplify_collection(worker: AbstractWorker, my_collection: Collection, shallow: bool = False) -> Tuple:
    """
    Simplify every item of a list, set or tuple and return them as a tuple.

    Each item is passed through ``serde._simplify`` in iteration order. When
    ``shallow`` is True the items are not simplified; the collection is only
    converted to a tuple. The reverse operation depends on the original type:
    ``_detail_collection_list``, ``_detail_collection_set`` or
    ``_detail_collection_tuple``.

    Args:
        worker (AbstractWorker): the worker forwarded to ``serde._simplify``.
        my_collection (Collection): a collection of python objects.
        shallow (bool): when True, return the items unchanged.

    Returns:
        Tuple: a tuple of (possibly simplified) items.
    """
    if shallow:
        # Leave the contents untouched; only the container becomes a tuple.
        return tuple(my_collection)

    return tuple(serde._simplify(worker, item) for item in my_collection)
def test_plan_torch_function_no_args(workers):
    """Plans that build tensors with torch functions not fed by the plan input
    must give the same result after a simplify/detail round trip."""
    bob, _ = workers["bob"], workers["alice"]
    from syft.serde.msgpack import serde

    def roundtrip(plan):
        # Simplify the plan, then detail it straight back on bob.
        return serde._detail(bob, serde._simplify(bob, plan))

    # Case 1: constant created with th.tensor inside the plan.
    @sy.func2plan(args_shape=[(1, )])
    def serde_plan(x):
        y = th.tensor([-1])
        z = x + y
        return z

    detailed = roundtrip(serde_plan)
    inp = th.tensor([1.0])
    expected = serde_plan(inp)
    actual = detailed(inp)
    assert actual == expected == th.tensor([0.0])

    # Case 2: range created with th.arange inside the plan.
    @sy.func2plan(args_shape=[(1, )])
    def serde_plan(x):
        y = th.arange(3)
        z = y + x
        return z

    detailed = roundtrip(serde_plan)
    inp = th.tensor([1.0])
    expected = serde_plan(inp)
    actual = detailed(inp)
    assert (actual == expected).all()
    assert (actual == th.tensor([1, 2, 3])).all()

    # Case 3: seeded random tensor created inside the plan.
    @sy.func2plan(args_shape=[(1, )])
    def serde_plan(x):
        th.manual_seed(14)
        y = th.randint(2, size=(1, ), dtype=th.uint8)
        y = y + 10
        return y

    detailed = roundtrip(serde_plan)
    inp = th.tensor([1.0])
    expected = serde_plan(inp)
    actual = detailed(inp)
    assert actual == expected and actual >= 10
def _simplify_ndarray(worker: AbstractWorker, my_array: numpy.ndarray) -> Tuple[bytes, Tuple, Tuple]:
    """
    Capture a numpy array's raw bytes together with the shape and dtype needed
    to rebuild it.

    Args:
        worker (AbstractWorker): the worker forwarded to ``serde._simplify``
            for the shape and dtype name.
        my_array (numpy.ndarray): a numpy array.

    Returns:
        tuple: ``(arr_bytes, arr_shape, arr_dtype)``. ``arr_bytes`` is
        ``my_array.tobytes()`` (numpy's default C order). ``arr_shape`` is the
        simplified ``my_array.shape``. ``arr_dtype`` is the simplified
        ``my_array.dtype.name``.

    Examples:
        arr_representation = _simplify_ndarray(worker, numpy.random.random([1000, 1000]))
    """
    arr_bytes = my_array.tobytes()
    arr_shape = serde._simplify(worker, my_array.shape)
    arr_dtype = serde._simplify(worker, my_array.dtype.name)
    return (arr_bytes, arr_shape, arr_dtype)
def simplified_tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> tuple:
    """Serialize a tensor using only native python types.

    The tensor is described as ``(size, dtype_string, flat_values)``, where
    ``size`` is a tuple of dimensions, ``dtype_string`` comes from
    ``TORCH_DTYPE_STR`` and ``flat_values`` is the flattened tensor as a python
    list. That triple is then passed through ``serde._simplify``.

    A tensor that requires grad is detached first and a warning is emitted,
    because this serializer only supports tensors that do not require grad.
    """
    if not tensor.requires_grad:
        plain = tensor
    else:
        warnings.warn(
            "Torch to native serializer can only be used with tensors that do not require grad. "
            "Detaching tensor to continue"
        )
        plain = tensor.detach()

    shape = tuple(plain.size())
    dtype_str = TORCH_DTYPE_STR[plain.dtype]
    values = plain.flatten().tolist()
    return serde._simplify(worker, (shape, dtype_str, values))
def simplify(worker: BaseWorker, simple_tagger: "SimpleTagger"):
    """Simplifies a SimpleTagger object.

    Args:
        worker (BaseWorker): The worker on which the simplify operation
            is carried out.
        simple_tagger (SimpleTagger): the SimpleTagger object to simplify.

    Returns:
        (tuple): the simplified ``attribute``, ``lookups``, ``tag``,
        ``default_tag`` and ``case_sensitive`` properties, in that order.
    """
    # Each property is read and simplified in turn, in this fixed order.
    property_names = ("attribute", "lookups", "tag", "default_tag", "case_sensitive")
    return tuple(serde._simplify(worker, getattr(simple_tagger, name)) for name in property_names)
def test_plan_execute_locally_ambiguous_output(workers):
    """A plan whose body computes an extra, unreturned value must return the
    same output after a simplify/detail round trip."""
    bob, _ = workers["bob"], workers["alice"]

    @sy.func2plan(args_shape=[(1, )])
    def serde_plan(x):
        x = x + x
        # ``y`` is presumably left unused on purpose: it is the "ambiguous"
        # candidate output that the plan must not return.
        y = x * 2
        return x

    detailed_plan = serde._detail(bob, serde._simplify(bob, serde_plan))

    inp = th.tensor([2.3])
    expected = serde_plan(inp)
    assert detailed_plan(inp) == expected
def test_plan_several_output_action(workers):
    """A plan using an op with several outputs (``torch.split``) must give the
    same result before and after a simplify/detail round trip."""
    bob = workers["bob"]

    @sy.func2plan(args_shape=[(4, )])
    def serde_plan(x, torch=th):
        y, z = torch.split(x, 2)
        return y + z

    serde_plan_simplified = serde._simplify(bob, serde_plan)
    serde_plan_detailed = serde._detail(bob, serde_plan_simplified)

    t = th.tensor([1, 2, 3, 4])
    # Reference result from the ORIGINAL plan; comparing the detailed plan with
    # itself would make the equality check tautological.
    expected = serde_plan(t)
    actual = serde_plan_detailed(t)
    assert (actual == th.tensor([4, 6])).all()
    assert (actual == expected).all()
def test_plan_fixed_len_loop(workers):
    """A plan containing a fixed-length python loop must give the same result
    before and after a simplify/detail round trip."""
    bob = workers["bob"]

    @sy.func2plan(args_shape=[(1, )])
    def serde_plan(x):
        for _ in range(10):
            x = x + 1
        return x

    serde_plan_simplified = serde._simplify(bob, serde_plan)
    serde_plan_detailed = serde._detail(bob, serde_plan_simplified)

    t = th.tensor([1.0])
    # Reference result from the ORIGINAL plan; comparing the detailed plan with
    # itself would make the equality check tautological.
    expected = serde_plan(t)
    actual = serde_plan_detailed(t)
    assert actual == expected
    # Ten increments of 1 applied to 1.0.
    assert actual == th.tensor([11.0])
def test_plan_with_comp(workers):
    """A plan performing an element-wise comparison must give the same result
    before and after a simplify/detail round trip."""
    bob = workers["bob"]

    @sy.func2plan(args_shape=[(2, ), (2, )])
    def serde_plan(x, y):
        z = x > y
        return z

    serde_plan_simplified = serde._simplify(bob, serde_plan)
    serde_plan_detailed = serde._detail(bob, serde_plan_simplified)

    t1 = th.tensor([2.0, 0.0])
    t2 = th.tensor([1.0, 1.0])
    # Reference result from the ORIGINAL plan; comparing the detailed plan with
    # itself would make the equality check tautological.
    expected = serde_plan(t1, t2)
    actual = serde_plan_detailed(t1, t2)
    assert (actual == expected).all()
def test_plan_execute_locally_ambiguous_input(workers):
    """A plan that reuses its inputs in several operations and returns the
    results in a shuffled order must behave the same after a simplify/detail
    round trip."""
    bob, _ = workers["bob"], workers["alice"]

    @sy.func2plan(args_shape=[(1, ), (1, ), (1, )])
    def serde_plan(x, y, z):
        a = x + x  # 2
        b = x + z  # 4
        c = y + z  # 5
        return c, b, a  # 5, 4, 2

    detailed_plan = serde._detail(bob, serde._simplify(bob, serde_plan))

    # Inputs 1, 2, 3 give the values noted in the comments above.
    t1, t2, t3 = th.tensor([1]), th.tensor([2]), th.tensor([3])
    expected = serde_plan(t1, t2, t3)
    assert detailed_plan(t1, t2, t3) == expected
def _simplify_numpy_number(
    worker: AbstractWorker, numpy_nb: Union[numpy.int32, numpy.int64, numpy.float32, numpy.float64]
) -> Tuple[bytes, Tuple]:
    """
    Capture a numpy scalar's raw bytes together with the dtype needed to
    rebuild it.

    Args:
        worker (AbstractWorker): the worker forwarded to ``serde._simplify``
            for the dtype name.
        numpy_nb (e.g. numpy.float64): a numpy number.

    Returns:
        tuple: ``(nb_bytes, nb_dtype)``. ``nb_bytes`` is
        ``numpy_nb.tobytes()`` and ``nb_dtype`` is the simplified
        ``numpy_nb.dtype.name``.

    Examples:
        np_representation = _simplify_numpy_number(worker, numpy.float64(2.3))
    """
    nb_bytes = numpy_nb.tobytes()
    nb_dtype = serde._simplify(worker, numpy_nb.dtype.name)
    return (nb_bytes, nb_dtype)
def _simplify_torch_device(worker: AbstractWorker, device: torch.device) -> Tuple:
    """
    Simplify a ``torch.device`` into a 1-tuple.

    Only ``device.type`` is kept, passed through ``serde._simplify``; no other
    attribute of the device (such as its index) is included.

    Args:
        worker (AbstractWorker): the worker forwarded to ``serde._simplify``.
        device (torch.device): the device to simplify.

    Returns:
        Tuple: ``(simplified_device_type,)``.
    """
    return (serde._simplify(worker, device.type),)