def register(self, result):
    """Register ``result`` (and the syft objects it wraps) with this worker.

    Dispatch, in order:
      * ``sy._SyftTensor`` instances and numpy arrays are registered directly
        via ``register_object``;
      * torch tensors register their ``child``;
      * Variables register their ``child`` and ``data.child``, then (after
        creating the gradient with ``init_grad_`` when it is missing) the
        gradient's ``child`` and ``data.child``;
      * lists, tuples, sets, bytearrays and ranges are walked element by
        element;
      * ``None`` is ignored.

    Raises:
        TypeError: if ``result`` (or an element of a container) has a type
            not listed above.
    """
    if issubclass(type(result), sy._SyftTensor):
        self.register_object(result)
        return

    if torch_utils.is_tensor(result):
        self.register_object(result.child)
        return

    if torch_utils.is_variable(result):
        self.register(result.child)
        self.register(result.data.child)
        # The gradient may not exist yet; create it so it gets registered too.
        if not hasattr(result, "grad") or result.grad is None:
            result.init_grad_()
        self.register(result.grad.child)
        self.register(result.grad.data.child)
        return

    # Iterable types that are not JSON serializable: register each element.
    if isinstance(result, (list, tuple, set, bytearray, range)):
        for element in result:
            self.register(element)
        return

    if result is None:
        # Nothing to register.
        return

    if isinstance(result, np.ndarray):
        self.register_object(result)
        return

    raise TypeError("The type", type(result), "is not supported at the moment")
def python_encode(self, obj, private_local):
    """Encode ``obj`` into a JSON-friendly structure.

    * int, float, str and None are returned unchanged.
    * Ellipsis becomes the string ``"..."``.
    * dicts and lists are encoded element-wise.
    * tuples, sets, bytearrays, ranges and slices are wrapped in a dict keyed
      by ``get_serialized_key(obj)``.
    * numpy arrays, torch Variables, torch tensors and syft tensors are
      serialised through their own ``ser`` method.
    * generators are dropped (a warning is logged and ``[]`` is returned).
    * workers are encoded as ``{'__worker__': id}``.

    When ``self.retrieve_pointers`` is set, a ``sy._PointerTensor`` found at
    the tail of a tensor chain is appended to ``self.found_pointers``.

    Args:
        obj: the object to encode.
        private_local: forwarded as ``private=`` to every ``ser`` call.

    Returns:
        The encoded object.

    Raises:
        ValueError: if ``obj`` (or a nested element) has an unhandled type.
    """
    # Case of basic types
    if isinstance(obj, (int, float, str)) or obj is None:
        return obj
    # Ellipsis
    elif isinstance(obj, type(...)):
        return "..."
    # np array (the only ndarray branch; a duplicate further down was unreachable)
    elif isinstance(obj, np.ndarray):
        return obj.ser(private=private_local, to_json=False)
    # Dict
    elif isinstance(obj, dict):
        return {k: self.python_encode(v, private_local) for k, v in obj.items()}
    # List
    elif isinstance(obj, list):
        return [self.python_encode(i, private_local) for i in obj]
    # Iterables non json-serializable
    elif isinstance(obj, (tuple, set, bytearray, range)):
        key = get_serialized_key(obj)
        return {key: [self.python_encode(i, private_local) for i in obj]}
    # Variables, Tensors and sy._SyftTensor (Pointer, Local)
    # [Note: the syft-tensor case shouldn't be called on a regular chain
    # with end=tensorvar]
    elif (
        torch_utils.is_variable(obj)
        or torch_utils.is_tensor(obj)
        or torch_utils.is_syft_tensor(obj)
    ):
        tail_object = torch_utils.find_tail_of_chain(obj)
        if self.retrieve_pointers and isinstance(tail_object, sy._PointerTensor):
            self.found_pointers.append(tail_object)
        # Variables are serialised as the head of their chain.
        if torch_utils.is_variable(obj):
            return obj.ser(private=private_local, is_head=True)
        return obj.ser(private=private_local)
    # Slice
    elif isinstance(obj, slice):
        key = get_serialized_key(obj)
        return {key: {'args': [obj.start, obj.stop, obj.step]}}
    # Generator
    elif isinstance(obj, types.GeneratorType):
        logging.warning("Generator args can't be transmitted")
        return []
    # worker
    elif isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
        return {'__worker__': obj.id}
    # Else log the error
    else:
        raise ValueError('Unhandled type', type(obj))
def de_register_object(self, obj, _recurse_torch_objs=True):
    """Unregister ``obj`` and strip the attributes that mark it as registered.

    For non-torch-tensor objects, the ``id`` (after removing it from the
    registry with ``rm_obj``), ``owner`` and ``child`` attributes are deleted
    when present. A torch tensor keeps its attributes. Only its child is
    unregistered, and only on the first level (``_recurse_torch_objs`` is
    set to False for that recursive call). The attribute-deletion approach
    follows this StackOverflow post: https://goo.gl/CBEKLK

    Args:
        obj: the object to unregister.
        _recurse_torch_objs: internal flag; when False, a torch tensor's
            child is not visited.
    """
    torch_wrapper = torch_utils.is_tensor(obj)

    if not torch_wrapper:
        if hasattr(obj, "id"):
            self.rm_obj(obj.id)
            del obj.id
        if hasattr(obj, "owner"):
            del obj.owner

    if not hasattr(obj, "child"):
        return

    if obj.child is not None:
        if not torch_wrapper:
            self.de_register_object(
                obj.child, _recurse_torch_objs=_recurse_torch_objs
            )
        elif _recurse_torch_objs:
            # Only descend one level below a torch tensor.
            self.de_register_object(obj.child, _recurse_torch_objs=False)

    if not torch_wrapper:
        delattr(obj, "child")
def python_encode(self, obj, private_local):
    """Encode ``obj`` into a JSON-friendly structure.

    Dicts, lists and non-JSON iterables (tuple, set, bytearray, range) are
    encoded recursively; the latter are wrapped in a dict keyed by
    ``get_serialized_key``. Syft tensors, Variables, tensors and numpy arrays
    use their ``ser`` method. Ellipsis becomes ``"..."``, slices keep their
    start/stop/step, generators are dropped with a warning, and workers are
    reduced to their id. When ``self.retrieve_pointers`` is set, pointer
    tensors found at the tail of a tensor chain are collected into
    ``self.found_pointers``.

    Raises:
        ValueError: for any unhandled type.
    """

    def _track_pointer(tensor_like):
        # Remember the remote pointer sitting at the end of the chain, if asked to.
        tail = torch_utils.find_tail_of_chain(tensor_like)
        if self.retrieve_pointers and isinstance(tail, sy._PointerTensor):
            self.found_pointers.append(tail)

    # /!\ Branches are sorted by frequency.
    if isinstance(obj, dict):
        return {
            key: self.python_encode(value, private_local)
            for key, value in obj.items()
        }

    # sy._SyftTensor (Pointer, Local)
    if issubclass(type(obj), sy._SyftTensor):
        _track_pointer(obj)
        return obj.ser(private=private_local)

    if obj is None or isinstance(obj, (int, float, str)):
        return obj

    if isinstance(obj, list):
        return [self.python_encode(item, private_local) for item in obj]

    # Iterables that JSON cannot represent natively.
    if isinstance(obj, (tuple, set, bytearray, range)):
        serialized_key = get_serialized_key(obj)
        return {
            serialized_key: [self.python_encode(item, private_local) for item in obj]
        }

    if torch_utils.is_variable(obj):
        _track_pointer(obj)
        return obj.ser(private=private_local, is_head=True)

    if torch_utils.is_tensor(obj):
        _track_pointer(obj)
        return obj.ser(private=private_local)

    if obj is Ellipsis:
        return "..."

    if isinstance(obj, np.ndarray):
        return obj.ser(private=private_local, to_json=False)

    if isinstance(obj, slice):
        serialized_key = get_serialized_key(obj)
        return {serialized_key: {"args": [obj.start, obj.stop, obj.step]}}

    if isinstance(obj, types.GeneratorType):
        logging.warning("Generator args can't be transmitted")
        return []

    if isinstance(obj, (sy.SocketWorker, sy.VirtualWorker)):
        return {"__worker__": obj.id}

    raise ValueError("Unhandled type", type(obj))
def _execute_method_call(self, *args, **kwargs):
    """Execute the hooked method ``attr`` on ``self`` through the local worker.

    ``attr`` and ``hook_self`` are free variables taken from the enclosing
    scope. If the worker raises ``NotImplementedError``, the call is retried
    on ``self.child``. When ``self`` is not a torch tensor, that result is
    re-wrapped in ``type(self)`` and, if the new object has a
    ``second_constructor``, replaced by its return value.
    """
    local_worker = hook_self.local_worker
    try:
        return local_worker._execute_call(attr, self, *args, **kwargs)
    except NotImplementedError:
        # Fall back to running the same method one level down the chain.
        outcome = _execute_method_call(self.child, *args, **kwargs)
        if torch_utils.is_tensor(self):
            return outcome
        outcome = type(self)(outcome)
        if hasattr(outcome, "second_constructor"):
            outcome = outcome.second_constructor()
        return outcome
def de_register(self, obj):
    """Un-register ``obj`` and the syft objects it wraps.

    Syft tensors and numpy arrays are removed with ``rm_obj(obj.id)``. Torch
    tensors recurse into ``child``; Variables recurse into ``child`` and
    ``data.child``. Lists, tuples, sets, bytearrays and ranges are walked
    element by element, and ``None`` is ignored.

    Raises:
        TypeError: for any other type.
    """
    if issubclass(type(obj), sy._SyftTensor):
        # TODO: rm .data, .grad, .grad.data if any
        self.rm_obj(obj.id)
        return

    if torch_utils.is_tensor(obj):
        self.de_register(obj.child)
        return

    if torch_utils.is_variable(obj):
        # TODO: rm .grad, .grad.data if any
        self.de_register(obj.child)
        self.de_register(obj.data.child)
        return

    # Iterable types that are not JSON serializable.
    if isinstance(obj, (list, tuple, set, bytearray, range)):
        for element in obj:
            self.de_register(element)
        return

    if obj is None:
        return

    if isinstance(obj, np.ndarray):
        self.rm_obj(obj.id)
        return

    raise TypeError("The type", type(obj), "is not supported at the moment")
def share(self, *workers):
    """Call ``share(*workers)`` on the tensors held by this object.

    * Every tensor in ``_the_plot`` is cast with ``.long()``, shared, and
      stored back.
    * The board, the layered board and the backdrop curtain have ``share``
      called on them directly; the return values are not kept.
    * Board layers and tensor attributes of the non-string entries of
      ``_update_groups`` are cast with ``.long()``, shared, and stored back,
      unless their ``child`` is already a ``sy._SNNTensor``.
    """
    plot = self._the_plot
    for name, tensor in plot.items():
        plot[name] = tensor.long().share(*workers)

    board = self._board
    board.board.share(*workers)
    board.layered_board.share(*workers)

    layers = board.layers
    for name, tensor in layers.items():
        if isinstance(tensor.child, sy._SNNTensor):
            continue  # already shared
        layers[name] = tensor.long().share(*workers)

    for group in self._update_groups:
        for sub_group in group:
            for entry in sub_group:
                if isinstance(entry, str):
                    continue
                attributes = entry.__dict__
                # Overwriting existing keys does not resize the dict, so
                # iterating over it while assigning is safe.
                for attr_name, value in attributes.items():
                    if utils.is_tensor(value) and not isinstance(
                        value.child, sy._SNNTensor
                    ):
                        attributes[attr_name] = value.long().share(*workers)

    self._backdrop.curtain.share(*workers)
def send(self, location):
    """Call ``send(location)`` on the tensors held by this object.

    * Entries of ``_the_plot`` are replaced by the value ``send`` returns.
    * The board, the layered board and the backdrop curtain are sent with
      their return values discarded.
    * Board layers and tensor attributes of the non-string entries of
      ``_update_groups`` are sent, return value discarded, unless their
      ``child`` is already a ``sy._PointerTensor``.
    """
    plot = self._the_plot
    for name, tensor in plot.items():
        plot[name] = tensor.send(location)

    board = self._board
    board.board.send(location)
    board.layered_board.send(location)
    for layer in board.layers.values():
        if not isinstance(layer.child, sy._PointerTensor):
            layer.send(location)

    for group in self._update_groups:
        for sub_group in group:
            for entry in sub_group:
                if isinstance(entry, str):
                    continue
                for value in entry.__dict__.values():
                    if utils.is_tensor(value) and not isinstance(
                        value.child, sy._PointerTensor
                    ):
                        value.send(location)

    self._backdrop.curtain.send(location)
def _execute_call(self, attr, self_, *args, **kwargs):
    """Transmit the call to the appropriate TensorType for handling.

    Args:
        attr: name of the command (tensor method or torch function) to run.
        self_: the object the method is called on, or None for a call that
            has no ``self`` (looked up in ``"torch_modules"`` instead of
            ``"tensorvar_methods"``).
        *args, **kwargs: the call's arguments.

    Returns:
        The result of ``child_type.handle_call``. When every argument is a
        torch tensor/variable, the result is re-wrapped: bound onto
        ``self_`` for in-place methods, otherwise wrapped with
        ``wrap_command``. This worker is enforced as the owner in all cases.
    """
    # If self_.child is None, then self_ is not a torch wrapper and we need to
    # execute one level higher.
    # TODO: not ok for complex args
    if self_ is not None and self_.child is None:
        new_args = [
            arg.wrap(True) if not isinstance(arg, int) else arg for arg in args
        ]
        return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

    # Distinguish between a command with torch tensors (like when called by the
    # client, or received from another worker), and a command with syft tensors,
    # which can occur when a function is overloaded by a SyftTensor (for
    # instance _PlusIsMinusTensor overloads add and replaces it by sub).
    try:
        torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
        is_torch_command = True
    except AssertionError:
        is_torch_command = False

    has_self = self_ is not None
    command = torch._command_guard(
        attr, "tensorvar_methods" if has_self else "torch_modules"
    )

    raw_command = {
        "command": command,
        "has_self": has_self,
        "args": args,
        "kwargs": kwargs,
    }
    if has_self:
        raw_command["self"] = self_

    # For a torch command this unwraps the torch wrapper; for a syft command it
    # gets the next syft class (the one which redirected, see the _PlusIsMinus
    # example). Both cases used identical calls, so they are merged here.
    syft_command, child_type = torch_utils.prepare_child_command(
        raw_command, replace_tensorvar_with_child=True
    )

    # Because registration of tensors with the right worker is unreliable, and
    # virtual workers add more ambiguity, the performing worker is explicit.
    result = child_type.handle_call(syft_command, owner=self)

    if not is_torch_command:
        # Syft-level results don't need to be wrapped.
        torch_utils.enforce_owner(result, self)
        return result

    # Wrap the result
    if has_self and utils.is_in_place_method(attr):
        # TODO: fix this properly: don't wrap the same way if syft or Variable
        if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
            wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result.child)
        else:
            wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
    else:
        wrapper = torch_utils.wrap_command(result)
    torch_utils.enforce_owner(wrapper, self)
    return wrapper
def python_decode(self, dct):
    """Recursively decode a value produced by ``python_encode``.

    This is called on every decoded value. Plain ints, floats and strings are
    returned unchanged, except that the string ``'...'`` becomes ``Ellipsis``.
    Lists are decoded element-wise and ``None`` is passed through. For dicts,
    special keys are checked to see whether the value must be re-cast to
    another type (e.g. tuple, range, slice, numpy array, torch Variable).

    Args:
        dct: the value to decode.

    Returns:
        The decoded Python object. Plain dicts are decoded in place and
        returned.

    Raises:
        TypeError: if ``dct`` is of an unhandled type, or a ``__<type>__``
            marker names an unsupported type.
    """
    # PLAN A: see if the dct object is not actually a dictionary and address
    # each case.
    if isinstance(dct, (int, str, float)):
        # Ellipsis is encoded as the string '...'.
        if dct == '...':
            return ...
        return dct
    if isinstance(dct, list):
        return [self.python_decode(o) for o in dct]
    if dct is None:
        return None
    if not isinstance(dct, dict):
        print(type(dct))
        raise TypeError('Type not handled', dct)

    # PLAN B: if the dct object IS a dictionary, check whether it has a "type" key.
    if 'type' in dct:
        if dct['type'] == "numpy.array":
            # The same dct describes both cases below: the object is built here,
            # on the receiving end of a message, either as the actual array or
            # as a pointer to it, depending on self.acquire. When building the
            # array, its id is set to dct['id']. When building a pointer,
            # id_at_location is set to dct['id'] and dct['owner'] becomes the
            # pointer's location.
            if self.acquire:
                return array(dct['data'], id=dct['id'], owner=self.worker)
            return array_ptr(dct['data'], owner=self.worker,
                             location=self.worker.get_worker(dct['owner']),
                             id_at_location=dct['id'])
        elif dct['type'] == 'numpy.array_ptr':
            return self.worker.get_obj(dct['id_at_location'])

    # PLAN C: as a last resort, use a regex to look for a "__<type>__" key.
    # TODO: Plan C should never be called - but is used extensively in
    # PySyft's PyTorch integration.
    pat = re.compile(r'__(.+)__')
    for key, obj in dct.items():
        match = pat.search(key)
        if match is None:
            # Ordinary entry: decode its value in place. Overwriting an existing
            # key does not resize the dict, so iterating over it stays valid.
            dct[key] = self.python_decode(obj)
            continue

        obj_type = match.group(1)

        # Case of an iterable type that JSON cannot represent natively. An
        # explicit constructor table replaces the former eval(obj_type).
        if obj_type in ('tuple', 'set', 'bytearray', 'range'):
            items = [self.python_decode(o) for o in obj]
            if obj_type == 'range':
                # A range is encoded as the list of its elements, which range()
                # cannot accept directly; rebuild it from the progression.
                if not items:
                    return range(0)
                step = items[1] - items[0] if len(items) > 1 else 1
                return range(items[0], items[-1] + step, step)
            constructors = {'tuple': tuple, 'set': set, 'bytearray': bytearray}
            return constructors[obj_type](items)
        # Case of a slice
        if obj_type == 'slice':
            return slice(*obj['args'])
        # Case of a worker
        if obj_type == 'worker':
            return self.worker.get_worker(obj)
        # Case of a tensor
        if torch_utils.is_tensor(obj_type):
            return torch.guard['syft.' + obj_type].deser(
                {key: obj}, self.worker, self.acquire)
        # Case of a Variable
        if torch_utils.is_variable(obj_type):
            return sy.Variable.deser({key: obj}, self.worker, self.acquire,
                                     is_head=True)
        # Case of a Syft tensor
        if torch_utils.is_syft_tensor(obj_type):
            return sy._SyftTensor.deser_routing({key: obj}, self.worker,
                                                self.acquire)
        raise TypeError('The special object type', obj_type, 'is not supported')
    return dct