def _execute_call(self, attr, self_, *args, **kwargs):
    """Forward a call to the TensorType responsible for handling it.

    The call is packed into a ``raw_command`` dict, unwrapped one level with
    ``torch_utils.prepare_child_command`` and executed by the child type's
    ``handle_call`` on behalf of this worker.

    Args:
        attr: name of the method (when ``self_`` is given) or of the module
            function to call; checked with ``torch._command_guard`` against
            ``torch.tensorvar_methods`` or ``torch.torch_modules``.
        self_: object the method is called on, or ``None`` for a function call.
        *args: positional arguments of the call.
        **kwargs: keyword arguments of the call.

    Returns:
        When every argument is a torch tensor/variable, the child result
        wrapped with ``torch_utils.wrap_command_with`` (in-place method on
        ``self_``) or ``torch_utils.wrap_command``; otherwise the raw child
        result.
    """
    # A command carries torch tensors when it comes from a client or another
    # worker, and syft tensors when a SyftTensor overloads a function (e.g.
    # _PlusIsMinusTensor replaces add by sub).
    # NOTE(review): this relies on the helper raising AssertionError; if it is
    # implemented with a bare `assert`, `python -O` would disable the check.
    try:
        torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
    except AssertionError:
        is_torch_command = False
    else:
        is_torch_command = True

    has_self = self_ is not None
    allowed = torch.tensorvar_methods if has_self else torch.torch_modules
    command = torch._command_guard(attr, allowed)

    raw_command = {
        'command': command,
        'has_self': has_self,
        'args': args,
        'kwargs': kwargs,
    }
    if has_self:
        raw_command['self'] = self_

    # Torch commands get their torch wrapper stripped, while syft commands are
    # sent to the next syft class (the one which redirected the call). Both
    # cases are prepared by the very same call.
    syft_command, child_type = torch_utils.prepare_child_command(
        raw_command, replace_tensorvar_with_child=True)
    torch_utils.assert_has_only_syft_tensors(syft_command)

    # Registering tensors with the right worker is unreliable (and Virtual
    # workers add ambiguity), so the worker performing the operation is
    # passed explicitly.
    result = child_type.handle_call(syft_command, owner=self)
    torch_utils.enforce_owner((raw_command, result), self)

    if not is_torch_command:
        # Syft-level callers get the result without any wrapping.
        return result
    if has_self and utils.is_in_place_method(attr):
        return torch_utils.wrap_command_with(result, raw_command['self'])
    return torch_utils.wrap_command(result)
def _execute_call(self, attr, self_, *args, **kwargs):
    """Transmit the call to the appropriate TensorType for handling.

    Args:
        attr: name of the method (when ``self_`` is given) or of the module
            function to call; checked by ``torch._command_guard`` against
            "tensorvar_methods" or "torch_modules" respectively.
        self_: object the method is called on, or ``None`` for a function call.
        *args: positional arguments of the call.
        **kwargs: keyword arguments of the call.

    Returns:
        The result of the child type's ``handle_call``. The result is wrapped
        back only when every argument is a torch tensor/variable:
        ``bind_tensor_nodes`` onto ``self_`` for in-place methods, and
        ``wrap_command`` otherwise. For syft-level commands it is returned
        as-is. In both cases ``torch_utils.enforce_owner`` is applied with
        this worker.
    """
    if self_ is not None and self_.child is None:
        # self_ has no child, so it is not a torch wrapper: wrap it and the
        # positional args, then execute the call one level higher.
        # Python scalars (int, bool, float) have no ``wrap`` method, so they
        # are forwarded unchanged.
        # TODO: not ok for complex args (e.g. kwargs are not wrapped).
        wrapped_args = [
            arg if isinstance(arg, (int, float)) else arg.wrap(True)
            for arg in args
        ]
        return self._execute_call(attr, self_.wrap(True), *wrapped_args, **kwargs)

    # Distinguish between a command with torch tensors (sent by a client or
    # received from another worker) and a command with syft tensors. The
    # latter occurs when a SyftTensor overloads a function (e.g.
    # _PlusIsMinusTensor replaces add by sub).
    # NOTE(review): relies on the helper raising AssertionError; if it uses a
    # bare `assert`, running under `python -O` disables this detection.
    try:
        torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
    except AssertionError:
        is_torch_command = False
    else:
        is_torch_command = True

    has_self = self_ is not None
    command = torch._command_guard(
        attr, "tensorvar_methods" if has_self else "torch_modules"
    )

    raw_command = {
        "command": command,
        "has_self": has_self,
        "args": args,
        "kwargs": kwargs,
    }
    if has_self:
        raw_command["self"] = self_

    # Torch commands get their torch wrapper removed; syft commands go to the
    # next syft class (the one which redirected the call, see _PlusIsMinus).
    # Both cases need the identical preparation call, so it is done once.
    # NOTE: the assert_has_only_syft_tensors check on syft_command is disabled.
    syft_command, child_type = torch_utils.prepare_child_command(
        raw_command, replace_tensorvar_with_child=True
    )

    # Registering tensors with the right worker is unreliable (and Virtual
    # workers add ambiguity), so the worker performing the operation is
    # passed explicitly.
    result = child_type.handle_call(syft_command, owner=self)

    if not is_torch_command:
        # Syft-level command: the result does not need wrapping.
        torch_utils.enforce_owner(result, self)
        return result

    if has_self and utils.is_in_place_method(attr):
        # TODO: fix this properly: don't wrap the same way if syft or Variable
        if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
            wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result.child)
        else:
            wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
    else:
        wrapper = torch_utils.wrap_command(result)
    torch_utils.enforce_owner(wrapper, self)
    return wrapper
def _is_command_valid_guard(command, allowed):
    """Tell whether ``command`` passes ``torch._command_guard`` for ``allowed``.

    Args:
        command: the command name to check.
        allowed: the allow-list argument forwarded to ``torch._command_guard``.

    Returns:
        False when the guard raises RuntimeError, True otherwise. Any other
        exception raised by the guard propagates to the caller.
    """
    try:
        torch._command_guard(command, allowed)
    except RuntimeError:
        is_valid = False
    else:
        is_valid = True
    return is_valid
def handle_call(cls, syft_command, owner):
    """Execute a forwarded command on the native tensor with native operations.

    Converts the syft command into one with native torch arguments, then runs
    the matching ``native_*`` method or function. The result is converted back
    into a syft response made of ``_LocalTensor`` objects.

    Args:
        syft_command: command dict. ``tensor_command`` (derived from it by
            ``prepare_child_command``) must provide 'command', 'args',
            'kwargs', 'has_self' and, for method calls, 'self'. For in-place
            methods ``syft_command['self']`` is also read directly.
        owner: worker performing the operation. ``owner.id`` is compared with
            ``owner.hook.local_worker.id``, and ``owner`` becomes the owner of
            the ``_LocalTensor`` objects created here.

    Returns:
        - None if the native call returned None;
        - on the local worker, int/float/bool/np.ndarray results unchanged;
        - for in-place methods, ``syft_command['self']`` re-linked to the result;
        - otherwise each result (or each element of a tuple result) wrapped in
          ``sy._LocalTensor``. Several elements are returned as a tuple; a
          single element is returned bare, so a 1-tuple is unpacked.
    """
    # Strip one more level so that only native torch tensors/variables remain.
    # (torch_type is not used below.)
    tensor_command, torch_type = torch_utils.prepare_child_command(
        syft_command, replace_tensorvar_with_child=True)
    torch_utils.assert_has_only_torch_tensorvars(tensor_command)

    attr = tensor_command['command']
    args = tensor_command['args']
    kwargs = tensor_command['kwargs']
    has_self = tensor_command['has_self']

    if has_self:
        # Method call: look up the "native_" version of the method on the
        # receiver (bound to the local name `self`).
        self = tensor_command['self']
        attr = torch._command_guard(attr, torch.tensorvar_methods)
        command = getattr(self, "native_" + attr)
    else:
        # Function call: prefix the last dotted component with "native_"
        # (e.g. "a.b.f" -> "a.b.native_f") and resolve the resulting name.
        attr = torch._command_guard(attr, torch.torch_modules)
        elems = attr.split('.')
        elems[-1] = 'native_' + elems[-1]
        native_func_name = '.'.join(elems)
        # NOTE(review): eval() on a name derived from a (possibly remote)
        # command. It is only reached after torch._command_guard, which
        # presumably restricts attr to torch.torch_modules; confirm the guard
        # rejects everything else, or resolve the name with getattr instead.
        command = eval(native_func_name)

    response = command(*args, **kwargs)

    # TODO : control registration process
    if response is None:
        return response

    if owner.id != owner.hook.local_worker.id:
        # Not the local worker: turn Python scalars and numpy arrays into
        # tensors (sy.zeros(1) + value / sy.FloatTensor) so they get wrapped below.
        if isinstance(response, (int, float, bool)):
            response = sy.zeros(1) + response
        elif isinstance(response, (np.ndarray, )):
            response = sy.FloatTensor(response)
    else:
        # Local worker: plain scalars and numpy arrays are returned unwrapped.
        if isinstance(response, (int, float, bool, np.ndarray)):
            return response

    # If the command is an in-place method, wrap self and return
    if has_self and utils.is_in_place_method(attr):
        # wrap the main element
        torch_utils.wrap_command_with(response, syft_command['self'])

        if torch_utils.is_variable(response):
            # Also wrap the data if it's a variable (don't use wrap_command_with: the chain is not well formed yet)
            syft_command['self'].child.data = response.data
            response.data.parent = syft_command['self'].child.data.parent
            # And wrap the grad if there is one
            if response.grad is not None:
                if response.grad.data.dim() > 0:
                    syft_command['self'].child.grad = response.grad
                else:
                    # 0-dim grad: the existing grad is reset via native_set_()
                    # instead of being replaced (presumably emptied; confirm).
                    syft_command['self'].child.grad.native_set_()
                response.grad.parent = syft_command['self'].child.grad.parent
            # Finally, fix the links .data and .grad
            if response.grad is None:
                torch_utils.link_var_chain_to_data_chain(
                    syft_command['self'], response.data.child)
            else:
                torch_utils.link_var_chain_to_data_and_grad_chains(
                    syft_command['self'], response.data.child, response.grad.child)

        return_response = syft_command['self']
    # Else, the response is not self: iterate over the response(s) and wrap
    # each one with a syft tensor
    else:
        responses = response if isinstance(response, tuple) else (response, )
        syft_responses = []
        for resp in responses:
            if resp is None:
                # Don't wrap None
                syft_responses.append(resp)
                continue

            if isinstance(resp, (int, float, bool)):
                # if not final worker, convert into Float Tensor, which comes with a _LocalTensor
                if owner.id != owner.hook.local_worker.id:
                    resp = sy.zeros(1) + resp
                else:
                    # Else don't wrap it
                    syft_responses.append(resp)
                    continue

            syft_response = sy._LocalTensor(child=resp, parent=resp, owner=owner,
                                            torch_type='syft.' + type(resp).__name__)

            if torch_utils.is_variable(resp):
                # Variables additionally get their .data (and .grad) chains linked.
                if resp.grad is None:
                    torch_utils.link_var_chain_to_data_chain(
                        syft_response, resp.data.child)
                else:
                    torch_utils.link_var_chain_to_data_and_grad_chains(
                        syft_response, resp.data.child, resp.grad.child)

            syft_responses.append(syft_response)

        # A single response is returned bare rather than as a 1-tuple.
        # NOTE(review): an empty tuple response leaves syft_responses empty,
        # so syft_responses[0] would raise IndexError here.
        return_response = tuple(syft_responses) if len(
            syft_responses) > 1 else syft_responses[0]

    return return_response