Example #1
    def _execute_numpy_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling."""

        # Distinguish between a command with torch tensors (as when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance, _PlusIsMinusTensor
        # overloads add and replaces it with sub).
        # try:
        #     torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
        #     is_torch_command = True
        # except AssertionError:
        is_torch_command = False

        has_self = self_ is not None

        # if has_self:
        #     command = torch._command_guard(attr, 'tensorvar_methods')
        # else:
        #     command = torch._command_guard(attr, 'torch_modules')
        command = attr

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_

        # if is_torch_command:
        #     # Unwrap the torch wrapper
        #     syft_command, child_type = torch_utils.prepare_child_command(
        #         raw_command, replace_tensorvar_with_child=True)
        # else:
        #     # Get the next syft class
        #     # The actual syft class is the one that redirected the call (see the _PlusIsMinusTensor example)
        #     syft_command, child_type = torch_utils.prepare_child_command(
        #         raw_command, replace_tensorvar_with_child=True)
        #
        #     torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because having Virtual workers creates even more ambiguity, we explicitly
        # specify the worker performing the operation.
        # torch_utils.enforce_owner(raw_command, self)

        result = sy.array.handle_call(raw_command, owner=self)

        torch_utils.enforce_owner(result, self)

        # NOTE: is_torch_command is hardcoded to False above, so this branch is
        # currently unreachable.
        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                result = torch_utils.bind_tensor_nodes(raw_command["self"],
                                                       result)
            else:
                result = torch_utils.wrap_command(result)

        return result
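
For reference, here is a minimal sketch of the raw_command dict this method assembles before dispatching to sy.array.handle_call. The values shown are hypothetical stand-ins, not real PySyft tensor objects:

    # Illustrative shape of the raw_command dict built above; "tensor" and
    # "other" are hypothetical placeholders for actual array/tensor objects.
    raw_command = {
        "command": "add",       # the attr being executed
        "has_self": True,       # True for method-style calls like tensor.add(other)
        "args": ("other",),     # positional arguments, forwarded as-is
        "kwargs": {},           # keyword arguments, forwarded as-is
        "self": "tensor",       # included only when has_self is True
    }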
Example #2
    def _execute_call(self, attr, self_, *args, **kwargs):
        """Transmit the call to the appropriate TensorType for handling."""

        # If self_.child is None, then self_ is not a torch wrapper,
        # so we need to execute one level higher. TODO: not ok for complex args
        if self_ is not None and self_.child is None:
            new_args = [
                arg.wrap(True) if not isinstance(arg, int) else arg for arg in args
            ]
            return self._execute_call(attr, self_.wrap(True), *new_args, **kwargs)

        # Distinguish between a command with torch tensors (as when called by the client
        # or received from another worker) and a command with syft tensors, which can occur
        # when a function is overloaded by a SyftTensor (for instance, _PlusIsMinusTensor
        # overloads add and replaces it with sub).
        try:
            torch_utils.assert_has_only_torch_tensorvars((args, kwargs))
            is_torch_command = True
        except AssertionError:
            is_torch_command = False

        has_self = self_ is not None

        if has_self:
            command = torch._command_guard(attr, "tensorvar_methods")
        else:
            command = torch._command_guard(attr, "torch_modules")

        raw_command = {
            "command": command,
            "has_self": has_self,
            "args": args,
            "kwargs": kwargs,
        }
        if has_self:
            raw_command["self"] = self_
        if is_torch_command:
            # Unwrap the torch wrapper
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )
        else:
            # Get the next syft class.
            # The actual syft class is the one that redirected the call
            # (see the _PlusIsMinusTensor example).
            syft_command, child_type = torch_utils.prepare_child_command(
                raw_command, replace_tensorvar_with_child=True
            )

            # torch_utils.assert_has_only_syft_tensors(syft_command)

        # Note: because we have problems registering tensors with the right worker,
        # and because having Virtual workers creates even more ambiguity, we explicitly
        # specify the worker performing the operation.

        result = child_type.handle_call(syft_command, owner=self)

        if is_torch_command:
            # Wrap the result
            if has_self and utils.is_in_place_method(attr):
                # TODO: fix this properly: don't wrap syft results and torch Variables the same way
                if torch_utils.is_variable(result) or torch_utils.is_tensor(result):
                    wrapper = torch_utils.bind_tensor_nodes(
                        raw_command["self"], result.child
                    )
                else:
                    wrapper = torch_utils.bind_tensor_nodes(raw_command["self"], result)
            else:
                wrapper = torch_utils.wrap_command(result)
            torch_utils.enforce_owner(wrapper, self)
            return wrapper
        else:
            # We don't need to wrap
            torch_utils.enforce_owner(result, self)
            return result
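
To make the unwrap/dispatch/re-wrap flow above easier to follow, here is a self-contained toy model of the same pattern. All class and function names here are illustrative stand-ins, not PySyft API:

# Toy model of the dispatch pattern in _execute_call: unwrap the torch-style
# wrapper, let the child (syft-level) type handle the call, then wrap the
# result back up. Wrapper and Child are hypothetical, not PySyft classes.
class Child:
    def __init__(self, value):
        self.value = value

    def handle_call(self, attr, *args):
        # Stand-in for child_type.handle_call(syft_command, owner=self)
        return Child(getattr(self.value, attr)(*[a.value for a in args]))

class Wrapper:
    def __init__(self, child):
        self.child = child

def execute_call(attr, self_, *args):
    # Unwrap, cf. prepare_child_command(..., replace_tensorvar_with_child=True)
    child_args = [a.child for a in args]
    result = self_.child.handle_call(attr, *child_args)
    # Re-wrap the result on the way out, cf. torch_utils.wrap_command(result)
    return Wrapper(result)

# Usage with plain ints standing in for tensors:
a, b = Wrapper(Child(2)), Wrapper(Child(3))
assert execute_call("__add__", a, b).child.value == 5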