def new_backward(self, *args, **kwargs):
    """Run the native backward pass while preserving gradient wrapper ids.

    Before the native backward runs, the ``grad`` of every variable that is
    connected to ``self`` (per ``torch_utils.get_connected_variables``) and
    registered in the owning worker's ``_objects`` store is saved and reset to
    ``None``. Afterwards the saved gradient object is re-attached with
    ``assign_grad_`` (so its id is kept, preserving the chain) and the newly
    computed gradient data is copied into it with ``native_set_``.

    The restoration runs in a ``finally`` clause: if the native backward
    raises, the saved gradient objects are still re-attached instead of
    being left as ``None``, and the original exception is then re-raised.

    Args:
        *args: forwarded unchanged to ``self.native_native_backward``.
        **kwargs: forwarded unchanged to ``self.native_native_backward``.
    """
    worker = self.owner

    # Only variables registered with this worker can be fetched by get_obj.
    variable_ids = [
        var_id
        for var_id in torch_utils.get_connected_variables(self)
        if var_id in worker._objects
    ]

    # Save every gradient object (to keep its id) and clear it before the
    # native backward runs.
    saved_grads = {}
    for variable_id in variable_ids:
        var = worker.get_obj(variable_id).parent
        assert var.id == variable_id
        saved_grads[variable_id] = var.grad
        var.grad = None

    try:
        self.native_native_backward(*args, **kwargs)
    finally:
        # Put back the original grad envelope, then insert the freshly
        # computed grad value into it (first the envelope, then the data).
        for variable_id in variable_ids:
            var = worker.get_obj(variable_id).parent
            saved_grad = saved_grads[variable_id]
            if saved_grad is not None:
                computed_grad = var.grad
                var.assign_grad_(saved_grad)
                if computed_grad is not None:
                    var.grad.data.native_set_(computed_grad.data)
            # Make sure everything ends up owned by this worker.
            torch_utils.enforce_owner(var, worker)
def _execute_numpy_call(self, attr, self_, *args, **kwargs):
    """Dispatch a NumPy-style command to ``sy.array.handle_call``.

    No torch command guard, unwrapping or re-wrapping is applied: ``attr`` is
    used verbatim as the command name, the arguments are forwarded as-is, and
    the handler's result is returned after its owner is enforced.

    Args:
        attr: name of the command to execute; stored unchanged under the
            ``"command"`` key.
        self_: the object the method is called on, or ``None`` when the
            command is not a method call (``has_self`` is then ``False``).
        *args: positional arguments of the command.
        **kwargs: keyword arguments of the command.

    Returns:
        The value returned by ``sy.array.handle_call``, with
        ``torch_utils.enforce_owner`` applied for this worker.
    """
    has_self = self_ is not None

    raw_command = {
        "command": attr,
        "has_self": has_self,
        "args": args,
        "kwargs": kwargs,
    }
    if has_self:
        raw_command["self"] = self_

    # The worker performing the operation is passed explicitly as owner:
    # registering tensors with the right worker is unreliable, and virtual
    # workers add further ambiguity.
    result = sy.array.handle_call(raw_command, owner=self)
    torch_utils.enforce_owner(result, self)
    return result
def _execute_call(self, attr, self_, *args, **kwargs):
    """Transmit a torch call to the appropriate tensor type for handling.

    The command is guarded (as a tensor method when ``self_`` is given,
    otherwise as a torch module function), unwrapped to the child level via
    ``torch_utils.prepare_child_command``, and executed by the returned child
    type's ``handle_call`` with this worker as owner.

    Args:
        attr: name of the method / function to call.
        self_: the object the method is called on, or ``None`` for a
            module-level function.
        *args: positional arguments of the call.
        **kwargs: keyword arguments of the call.

    Returns:
        For commands whose arguments are all torch tensors/variables, a torch
        wrapper around the result (re-bound to ``self_`` for in-place
        methods). Otherwise the handler's raw result. In both cases
        ``torch_utils.enforce_owner`` is applied for this worker.
    """
    # A self_ without a child is not a torch wrapper: wrap it (and the
    # positional args) one level higher and retry. Plain numeric scalars are
    # passed through unchanged since they have no ``wrap`` method.
    # TODO: kwargs and complex/nested args are not wrapped.
    if self_ is not None and self_.child is None:
        new_args = [
            arg if isinstance(arg, (int, float)) else arg.wrap(True)
            for arg in args
        ]
        return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

    # Distinguish a command carrying torch tensors (e.g. sent by a client or
    # received from another worker) from one carrying syft tensors, which
    # happens when a SyftTensor overloads a function (for instance
    # _PlusIsMinusTensor overloads add and replaces it by sub).
    # NOTE(review): detection relies on the helper raising AssertionError; if
    # it uses bare `assert`, running under `python -O` would classify every
    # command as a torch command — confirm.
    try:
        torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
        is_torch_command = True
    except AssertionError:
        is_torch_command = False

    has_self = self_ is not None
    guard_category = "tensorvar_methods" if has_self else "torch_modules"
    command = torch._command_guard(attr, guard_category)

    raw_command = {
        "command": command,
        "has_self": has_self,
        "args": args,
        "kwargs": kwargs,
    }
    if has_self:
        raw_command["self"] = self_

    # Torch and syft commands are unwrapped identically; for syft commands
    # the returned child_type is the syft class that redirected the call.
    syft_command, child_type = torch_utils.prepare_child_command(
        raw_command, replace_tensorvar_with_child=True
    )

    # The worker performing the operation is passed explicitly as owner:
    # registering tensors with the right worker is unreliable, and virtual
    # workers add further ambiguity.
    result = child_type.handle_call(syft_command, owner=self)

    if not is_torch_command:
        # Syft-level command: return the handler's result without wrapping.
        torch_utils.enforce_owner(result, self)
        return result

    if has_self and utils.is_in_place_method(attr):
        # TODO: fix this properly: don't wrap the same way if syft or Variable
        if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
            wrapper = torch_utils.bind_tensor_nodes(
                raw_command["self"], result.child
            )
        else:
            wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
    else:
        wrapper = torch_utils.wrap_command(result)
    torch_utils.enforce_owner(wrapper, self)
    return wrapper