Example #1
    def _execute_call(self, attr, self_, *args, **kwargs):
        """
        Transmit the call to the appropriate TensorType for handling
        """

        # Distinguish between a command with torch tensors (e.g. when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance, _PlusIsMinusTensor
        # overloads add and replaces it with sub)
        try:
            torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
            is_torch_command = True
        except AssertionError:
            is_torch_command = False

        has_self = self_ is not None

        if has_self:
            command = torch._command_guard(attr, torch.tensorvar_methods)
        else:
            command = torch._command_guard(attr, torch.torch_modules)

        raw_command = {
            'command': command,
            'has_self': has_self,
            'args': args,
            'kwargs': kwargs
        }
        if has_self:
            raw_command['self'] = self_
        if is_torch_command:
            # Unwrap the torch wrapper
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True)
        else:
            # Get the next syft class
            # The actual syft class is the one which performed the redirection (see the _PlusIsMinusTensor example)
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True)

            torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because having Virtual workers creates even more ambiguity, we specify the worker
        # performing the operation
        result = child_type.handle_call(syft_command, owner=self)

        torch_utils.enforce_owner((raw_command, result), self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                wrapper = torch_utils.wrap_command_with(
                    result, raw_command['self'])
            else:
                wrapper = torch_utils.wrap_command(result)
            return wrapper
        else:
            # We don't need to wrap
            return result
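
The try/except at the top of this example is an EAFP idiom: an assertion helper raises, and the caller folds the exception into a boolean flag. A minimal, self-contained sketch of that pattern, where assert_all_tensors and is_torch_command are hypothetical stand-ins for torch_utils.assert_has_only_torch_tensorvars and the flag logic above:

def assert_all_tensors(values, tensor_types):
    # Raise AssertionError on the first value that is not an accepted tensor type
    for v in values:
        assert isinstance(v, tensor_types), f"{type(v).__name__} is not a tensor type"

def is_torch_command(args, kwargs, tensor_types):
    # Fold the assertion into a boolean, mirroring the try/except above
    try:
        assert_all_tensors(list(args) + list(kwargs.values()), tensor_types)
        return True
    except AssertionError:
        return False

# e.g. with (int, float) standing in for the real torch tensor types:
# is_torch_command((1, 2.0), {}, (int, float)) -> True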
Example #2
    def _execute_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling."""

        # If self_.child is None, then self_ is not a torch wrapper,
        # so we wrap it and re-execute one level higher. TODO: not ok for complex args
        if self_ is not None and self_.child is None:
            new_args = [
                arg.wrap(True) if not isinstance(arg, int) else arg for arg in args
            ]
            return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

        # Distinguish between a command with torch tensors (e.g. when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance, _PlusIsMinusTensor
        # overloads add and replaces it with sub)
        try:
            torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
            is_torch_command = True
        except AssertionError:
            is_torch_command = False

        has_self = self_ is not None

        if has_self:
            command = torch._command_guard(attr, "tensorvar_methods")
        else:
            command = torch._command_guard(attr, "torch_modules")

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_
        if is_torch_command:
            # Unwrap the torch wrapper
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )
        else:
            # Get the next syft class
            # The actual syft class is the one which performed the redirection (see the _PlusIsMinusTensor example)
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )

            # torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because having Virtual workers creates even more ambiguity, we specify the worker
        # performing the operation

        result = child_type.handle_call(syft_command, owner=self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                # TODO: fix this properly: don't wrap the same way if syft or Variable
                if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
                    wrapper = torch_utils.bind_tensor_nodes(
                        raw_command["self"], result.child
                    )
                else:
                    wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
            else:
                wrapper = torch_utils.wrap_command(result)
            torch_utils.enforce_owner(wrapper, self)
            return wrapper
        else:
            # We don't need to wrap
            torch_utils.enforce_owner(result, self)
            return result
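
Compared with Example #1, this variant first "lifts" a bare receiver: if self_ has no child, it re-wraps self_ and the args and retries one level higher. A minimal sketch of that lift-and-retry recursion, using a hypothetical Wrapper class in place of the real torch wrapper and .wrap(True):

class Wrapper:
    def __init__(self, child):
        self.child = child

def execute(self_, *args):
    if getattr(self_, "child", None) is None:
        # Bare receiver: wrap it (and any non-int args), then retry one level higher
        new_args = [a if isinstance(a, int) else Wrapper(a) for a in args]
        return execute(Wrapper(self_), *new_args)
    return ("dispatch", self_, args)  # stand-in for the normal dispatch path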
Example #3
def _is_command_valid_guard(command, allowed):
    """Return whether `command` passes torch._command_guard for `allowed`."""
    try:
        torch._command_guard(command, allowed)
    except RuntimeError:
        return False
    return True
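
The helper above converts the RuntimeError raised by torch._command_guard into a boolean so callers can branch instead of catching. A self-contained illustration of the same guard-to-boolean conversion, with a hypothetical allow-list standing in for the real guard (the command names and the "tensorvar_methods" key are illustrative, not verified API):

ALLOWED = {"tensorvar_methods": {"add", "sub", "mul"}}

def _command_guard(command, allowed):
    # Raise, as torch._command_guard does, when the command is not allowed
    if command not in ALLOWED.get(allowed, set()):
        raise RuntimeError(f"Command {command} is not an allowed {allowed} entry")
    return command

def _is_command_valid_guard(command, allowed):
    try:
        _command_guard(command, allowed)
    except RuntimeError:
        return False
    return True

assert _is_command_valid_guard("add", "tensorvar_methods")
assert not _is_command_valid_guard("rm_rf", "tensorvar_methods")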
Example #4
    def handle_call(cls, syft_command, owner):
        """
        Execute a forwarded command on the native tensor with native operations.
        Receive a syft command and an owner, and converts it into command with
        native torch args. Excute native operations and converts it back into
        syft response using _LocalTensors.
        """
        tensor_command, torch_type = torch_utils.prepare_child_command(
            syft_command, replace_tensorvar_with_child=True)
        torch_utils.assert_has_only_torch_tensorvars(tensor_command)

        attr = tensor_command['command']
        args = tensor_command['args']
        kwargs = tensor_command['kwargs']
        has_self = tensor_command['has_self']

        if has_self:
            self = tensor_command['self']
            attr = torch._command_guard(attr, torch.tensorvar_methods)
            command = getattr(self, "native_" + attr)
        else:
            attr = torch._command_guard(attr, torch.torch_modules)
            elems = attr.split('.')
            elems[-1] = 'native_' + elems[-1]
            native_func_name = '.'.join(elems)
            command = eval(native_func_name)

        response = command(*args, **kwargs)

        # TODO : control registration process
        if response is None:
            return response

        if owner.id != owner.hook.local_worker.id:
            if isinstance(response, (int, float, bool)):
                response = sy.zeros(1) + response
            elif isinstance(response, (np.ndarray, )):
                response = sy.FloatTensor(response)
        else:
            if isinstance(response, (int, float, bool, np.ndarray)):
                return response

        # If the command is an in-place method, wrap self and return
        if has_self and utils.is_in_place_method(attr):
            # wrap the main element
            torch_utils.wrap_command_with(response, syft_command['self'])

            if torch_utils.is_variable(response):
                # Also wrap the data if it's a variable (don't use wrap_command_with: the chain is not well formed yet)
                syft_command['self'].child.data = response.data
                response.data.parent = syft_command['self'].child.data.parent
                # And wrap the grad if there is one
                if response.grad is not None:
                    if response.grad.data.dim() > 0:
                        syft_command['self'].child.grad = response.grad
                    else:
                        syft_command['self'].child.grad.native_set_()
                    response.grad.parent = syft_command['self'].child.grad.parent
                # Finally, fix the links .data and .grad
                if response.grad is None:
                    torch_utils.link_var_chain_to_data_chain(
                        syft_command['self'], response.data.child)
                else:
                    torch_utils.link_var_chain_to_data_and_grad_chains(
                        syft_command['self'], response.data.child,
                        response.grad.child)

            return_response = syft_command['self']
        # Else, the response is not self. Iterate over the response(s) and wrap each with a syft tensor
        else:
            responses = response if isinstance(response, tuple) else (response,)
            syft_responses = []
            for resp in responses:
                if resp is None:  # Don't wrap None
                    syft_responses.append(resp)
                    continue

                if isinstance(resp, (int, float, bool)):
                    # If not the final worker, convert into a FloatTensor, which comes with a _LocalTensor
                    if owner.id != owner.hook.local_worker.id:
                        resp = sy.zeros(1) + resp
                    else:  # Else don't wrap it
                        syft_responses.append(resp)
                        continue

                syft_response = sy._LocalTensor(
                    child=resp,
                    parent=resp,
                    owner=owner,
                    torch_type='syft.' + type(resp).__name__,
                )

                if torch_utils.is_variable(resp):
                    if resp.grad is None:
                        torch_utils.link_var_chain_to_data_chain(
                            syft_response, resp.data.child)
                    else:
                        torch_utils.link_var_chain_to_data_and_grad_chains(
                            syft_response, resp.data.child, resp.grad.child)

                syft_responses.append(syft_response)

            return_response = tuple(syft_responses) if len(syft_responses) > 1 else syft_responses[0]

        return return_response
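
The module-function branch of handle_call rewrites a dotted command path so its last segment points at the saved native implementation (eval() then resolves that name). A small self-contained sketch of just the rewrite:

def to_native_name(attr):
    # 'torch.add' -> 'torch.native_add', mirroring the split/join in handle_call
    elems = attr.split('.')
    elems[-1] = 'native_' + elems[-1]
    return '.'.join(elems)

assert to_native_name('torch.add') == 'torch.native_add'
assert to_native_name('torch.nn.functional.relu') == 'torch.nn.functional.native_relu'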