def python_encode(self, obj, private_local):
    """Encode an arbitrary Python/torch object into a JSON-serializable form.

    Containers are encoded recursively; tensors, variables, syft tensors and
    numpy arrays delegate to their own ``ser`` methods.  While walking tensor
    chains, any ``sy._PointerTensor`` found at the tail is collected into
    ``self.found_pointers`` when ``self.retrieve_pointers`` is set.

    :param obj: the object to encode
    :param private_local: forwarded to the ``ser`` call of encodable objects
    :return: a JSON-serializable structure
    :raises ValueError: when the object's type has no encoding rule
    """
    # JSON-native scalars (and None) need no translation.
    if isinstance(obj, (int, float, str)) or obj is None:
        return obj
    # The Ellipsis singleton travels as the literal string "...".
    if isinstance(obj, type(...)):
        return "..."
    if isinstance(obj, np.ndarray):
        return obj.ser(private=private_local, to_json=False)
    # Torch tensors and Variables: record a pointer tail if requested, then
    # delegate to the object's own serializer.
    if torch_utils.is_tensor(obj) or torch_utils.is_variable(obj):
        chain_tail = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(chain_tail, sy._PointerTensor):
            self.found_pointers.append(chain_tail)
        if torch_utils.is_variable(obj):
            return obj.ser(private=private_local, is_head=True)
        return obj.ser(private=private_local)
    # sy._SyftTensor (Pointer, Local)
    # [Note: shouldn't be called on regular chain with end=tensorvar]
    if torch_utils.is_syft_tensor(obj):
        chain_tail = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(chain_tail, sy._PointerTensor):
            self.found_pointers.append(chain_tail)
        return obj.ser(private=private_local)
    # Plain lists map element-wise.
    if isinstance(obj, list):
        return [self.python_encode(item, private_local) for item in obj]
    # Iterables with no JSON counterpart are tagged with their type name.
    if isinstance(obj, (tuple, set, bytearray, range)):
        tag = '__' + type(obj).__name__ + '__'
        return {tag: [self.python_encode(item, private_local) for item in obj]}
    # Slices carry their three defining arguments under a tagged key.
    if isinstance(obj, slice):
        tag = '__' + type(obj).__name__ + '__'
        return {tag: {'args': [obj.start, obj.stop, obj.step]}}
    # Dicts map value-wise; keys are kept as-is.
    if isinstance(obj, dict):
        return {k: self.python_encode(v, private_local) for k, v in obj.items()}
    # Generators cannot be serialized; warn and transmit an empty list.
    if isinstance(obj, types.GeneratorType):
        logging.warning("Generator args can't be transmitted")
        return []
    # Workers are referenced by id only.
    if isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
        return {'__worker__': obj.id}
    # Else log the error
    raise ValueError('Unhandled type', type(obj))
def register(self, result):
    """Recursively register *result* (and its syft chain) with this worker.

    Syft tensors and numpy arrays register directly; torch tensors register
    their child node; variables register their own, data and grad children
    (initializing the grad first when it is missing); iterables recurse
    element-wise; ``None`` is a no-op.

    :param result: the object to register
    :raises TypeError: for unsupported types
    """
    if issubclass(type(result), sy._SyftTensor):
        self.register_object(result)
        return
    if torch_utils.is_tensor(result):
        # The torch tensor is only a wrapper; the syft node underneath it
        # is what gets registered.
        self.register_object(result.child)
        return
    if torch_utils.is_variable(result):
        self.register(result.child)
        self.register(result.data.child)
        # Make sure a grad exists before registering its chain.
        if not hasattr(result, "grad") or result.grad is None:
            result.init_grad_()
        self.register(result.grad.child)
        self.register(result.grad.data.child)
        return
    # Case of a iter type non json serializable
    if isinstance(result, (list, tuple, set, bytearray, range)):
        for element in result:
            self.register(element)
        return
    if result is None:
        return  # nothing to register
    if isinstance(result, np.ndarray):
        self.register_object(result)
        return
    raise TypeError("The type", type(result), "is not supported at the moment")
def python_encode(self, obj, private_local):
    """Encode an arbitrary Python/torch object into a JSON-serializable form.

    Containers recurse; tensors, variables, syft tensors and numpy arrays
    delegate to their own ``ser`` methods.  While walking tensor chains,
    a ``sy._PointerTensor`` tail is collected into ``self.found_pointers``
    when ``self.retrieve_pointers`` is set.

    :param obj: the object to encode
    :param private_local: forwarded to the ``ser`` call of encodable objects
    :return: a JSON-serializable structure
    :raises ValueError: when the object's type has no encoding rule
    """
    # /!\ Sort by frequency: branch order is a deliberate hot-path
    # optimization — most-encoded types are tested first.
    # Dict
    if isinstance(obj, dict):
        return {
            k: self.python_encode(v, private_local) for k, v in obj.items()
        }
    # sy._SyftTensor (Pointer, Local)
    elif issubclass(type(obj), sy._SyftTensor):
        tail_object = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(tail_object, sy._PointerTensor):
            self.found_pointers.append(tail_object)
        return obj.ser(private=private_local)
    # Case of basic types (JSON-native scalars and None pass through)
    elif isinstance(obj, (int, float, str)) or obj is None:
        return obj
    # List — encodes element-wise
    elif isinstance(obj, list):
        return [self.python_encode(i, private_local) for i in obj]
    # Iterables non json-serializable: tagged with a serialized type key
    elif isinstance(obj, (tuple, set, bytearray, range)):
        key = get_serialized_key(obj)
        return {key: [self.python_encode(i, private_local) for i in obj]}
    # Variable — serialized as the head of its chain
    elif torch_utils.is_variable(obj):
        tail_object = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(tail_object, sy._PointerTensor):
            self.found_pointers.append(tail_object)
        return obj.ser(private=private_local, is_head=True)
    # Tensors
    elif torch_utils.is_tensor(obj):
        tail_object = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(tail_object, sy._PointerTensor):
            self.found_pointers.append(tail_object)
        return obj.ser(private=private_local)
    # Ellipsis — the singleton travels as the literal string "..."
    elif isinstance(obj, type(...)):
        return "..."
    # np.array
    elif isinstance(obj, np.ndarray):
        return obj.ser(private=private_local, to_json=False)
    # Slice — carries its three defining arguments under a tagged key
    elif isinstance(obj, slice):
        key = get_serialized_key(obj)
        return {key: {"args": [obj.start, obj.stop, obj.step]}}
    # Generator — cannot be serialized; warn and send an empty list
    elif isinstance(obj, types.GeneratorType):
        logging.warning("Generator args can't be transmitted")
        return []
    # worker — referenced by id only
    elif isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
        return {"__worker__": obj.id}
    # Else log the error
    else:
        raise ValueError("Unhandled type", type(obj))
def de_register(self, obj):
    """Remove *obj* (and, where applicable, its chain) from this worker's registry.

    Syft tensors and numpy arrays are removed by id; torch tensors delegate
    to their child; variables delegate to their own and their data's child;
    iterables recurse element-wise; ``None`` is a no-op.

    :param obj: the object to un-register
    :raises TypeError: for unsupported types
    """
    if issubclass(type(obj), sy._SyftTensor):
        # TODO: rm .data, .grad, .grad.data if any
        self.rm_obj(obj.id)
        return
    if torch_utils.is_tensor(obj):
        self.de_register(obj.child)
        return
    if torch_utils.is_variable(obj):
        # TODO: rm .grad, .grad.data if any
        self.de_register(obj.child)
        self.de_register(obj.data.child)
        return
    # Case of a iter type non json serializable
    if isinstance(obj, (list, tuple, set, bytearray, range)):
        for element in obj:
            self.de_register(element)
        return
    if obj is None:
        return  # nothing to un-register
    if isinstance(obj, np.ndarray):
        self.rm_obj(obj.id)
        return
    raise TypeError("The type", type(obj), "is not supported at the moment")
def _execute_call(self, attr, self_, *args, **kwargs): """Transmit the call to the appropriate TensorType for handling.""" # if this is none - then it means that self_ is not a torch wrapper # and we need to execute one level higher TODO: not ok for complex args if self_ is not None and self_.child is None: new_args = [ arg.wrap(True) if not isinstance(arg, int) else arg for arg in args ] return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs) # Distinguish between a command with torch tensors (like when called by the client, # or received from another worker), and a command with syft tensor, which can occur # when a function is overloaded by a SyftTensor (for instance _PlusIsMinusTensor # overloads add and replace it by sub) try: torch_utils.assert_has_only_torch_tensorvars((args, kwargs)) is_torch_command = True except AssertionError: is_torch_command = False has_self = self_ is not None if has_self: command = torch._command_guard(attr, "tensorvar_methods") else: command = torch._command_guard(attr, "torch_modules") raw_command = { "command": command, "has_self": has_self, "args": args, "kwargs": kwargs, } if has_self: raw_command["self"] = self_ if is_torch_command: # Unwrap the torch wrapper syft_command, child_type = torch_utils.prepare_child_command( raw_command, replace_tensorvar_with_child=True ) else: # Get the next syft class # The actual syft class is the one which redirected (see the _PlusIsMinus ex.) 
syft_command, child_type = torch_utils.prepare_child_command( raw_command, replace_tensorvar_with_child=True ) # torch_utils.assert_has_only_syft_tensors(syft_command) # Note: because we have pb of registration of tensors with the right worker, # and because having Virtual workers creates even more ambiguity, we specify the worker # performing the operation result = child_type.handle_call(syft_command, owner=self) if is_torch_command: # Wrap the result if has_self and utils.is_in_place_method(attr): # TODO: fix this properly: don't wrap the same way if syft or Variable if torch_utils.is_variable(result) or torch_utils.is_tensor(result): wrapper = torch_utils.bind_tensor_nodes( raw_command["self"], result.child ) else: wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result) else: wrapper = torch_utils.wrap_command(result) torch_utils.enforce_owner(wrapper, self) return wrapper else: # We don't need to wrap torch_utils.enforce_owner(result, self) return result
def python_decode(self, dct): """ Is called on every dict found. We check if some keys correspond to special keywords referring to a type we need to re-cast (e.g. tuple, or torch Variable). """ # PLAN A: See if the dct object is not actually a dictionary and address # each case. if isinstance(dct, (int, str, float)): # a very strange special case if (dct == '...'): return ... return dct if isinstance(dct, (list, )): return [self.python_decode(o) for o in dct] if dct is None: return None if not isinstance(dct, dict): print(type(dct)) raise TypeError('Type not handled', dct) # PLAN B: If the dct object IS a dictionary, check to see if it has a "type" key if ('type' in dct): if dct['type'] == "numpy.array": # at first glance, the following if statement might seem a bit confusing # since the dct object is identical for both. Basically, the pointer object # is created here (on the receiving end of a message) as opposed to on the sending # side. We decide whether to use the dictionary to construct a pointer or the # actual tensor based on wehther self.acquire is true. Note that this changes # how dct['id'] is used. If creating an actual tensor, the tensor id is set to dct['id] # otherwise, id_at_location is set to be dct['id']. Similarly with dct['owner']. # if we intend to receive the tensor itself, construct an array if (self.acquire): return array(dct['data'], id=dct['id'], owner=self.worker) # if we intend to create a pointer, construct a pointer. Note that else: return array_ptr(dct['data'], owner=self.worker, location=self.worker.get_worker( dct['owner']), id_at_location=dct['id']) elif dct['type'] == 'numpy.array_ptr': return self.worker.get_obj(dct['id_at_location']) # Plan C: As a last resort, use a Regex to try to find a type somewhere. 
# TODO: Plan C should never be called - but is used extensively in PySyft's PyTorch integratio pat = re.compile('__(.+)__') for key, obj in dct.items(): if pat.search(key) is not None: obj_type = pat.search(key).group(1) # Case of a tensor if torch_utils.is_tensor(obj_type): o = torch.guard['syft.' + obj_type].deser({key: obj}, self.worker, self.acquire) return o # Case of a Variable elif torch_utils.is_variable(obj_type): return sy.Variable.deser({key: obj}, self.worker, self.acquire, is_head=True) # Case of a Syft tensor elif torch_utils.is_syft_tensor(obj_type): return sy._SyftTensor.deser_routing({key: obj}, self.worker, self.acquire) # Case of a iter type non json serializable elif obj_type in ('tuple', 'set', 'bytearray', 'range'): return eval(obj_type)([self.python_decode(o) for o in obj]) # Case of a slice elif obj_type == 'slice': return slice(*obj['args']) # Case of a worker elif obj_type == 'worker': return self.worker.get_worker(obj) else: raise TypeError('The special object type', obj_type, 'is not supported') else: dct[key] = self.python_decode(obj) return dct