Exemplo n.º 1
0
    def execute_command(self, message: tuple) -> PointerTensor:
        """
        Executes commands received from other workers.

        Args:
            message: A tuple ``((command_name, _self, args, kwargs), return_ids)``.
                When ``_self`` is not None, ``command_name`` is called as a
                method on ``_self``. Otherwise ``command_name`` is a dotted
                path (e.g. "torch.nn.functional.relu") resolved attribute by
                attribute starting from this worker. ``args``/``kwargs`` are
                forwarded to the call.

        Returns:
            Whatever ``hook_args.register_response`` produces for the result
            (intended to be a pointer to it). Returns None when the command
            is an in-place method or when the call itself returned None.

        Raises:
            ResponseSignatureError: If ``register_response`` rejects the plain
                list of ``return_ids``. The response is then re-registered
                through an ``IdProvider`` seeded with ``return_ids``, and the
                raised error carries that provider's ``generated`` ids.
        """
        (command_name, _self, args, kwargs), return_ids = message

        if _self is not None:
            # Method call on a concrete object.
            if sy.torch.is_inplace_method(command_name):
                # Called only for its effect on _self; nothing is sent back.
                getattr(_self, command_name)(*args, **kwargs)
                return
            response = getattr(_self, command_name)(*args, **kwargs)
        else:
            # At this point the command is ALWAYS a path to a function
            # (e.g. torch.nn.functional.relu), so fetch it and run it.
            # NOTE(review): command_guard presumably rejects commands outside
            # the allowed "torch_modules" set — confirm.
            sy.torch.command_guard(command_name, "torch_modules")

            command = self
            for path in command_name.split("."):
                command = getattr(command, path)

            response = command(*args, **kwargs)

        # Some functions don't return anything (such as .backward()).
        if response is None:
            return None

        # Register the response and create pointers for tensor elements.
        try:
            return sy.frameworks.torch.hook_args.register_response(
                command_name, response, list(return_ids), self
            )
        except ResponseSignatureError:
            # The plain id list was rejected: register again through an
            # IdProvider seeded with return_ids, then report the ids it
            # generated to the caller via the exception.
            id_provider = IdProvider(return_ids)
            sy.frameworks.torch.hook_args.register_response(
                command_name, response, id_provider, self
            )
            raise ResponseSignatureError(id_provider.generated)
Exemplo n.º 2
0
    def execute_command(self, message: tuple) -> PointerTensor:
        """
        Executes commands received from other workers.

        Args:
            message: A tuple ``((command_name, _self, args, kwargs), return_ids)``.
                ``_self`` selects what the command runs on:
                  * None: ``command_name`` is a dotted function path
                    (e.g. "torch.nn.functional.relu") resolved attribute by
                    attribute starting from this worker;
                  * an int: an object id looked up with ``BaseWorker.get_obj``;
                    if nothing is found the command is silently dropped;
                  * the string "self": this worker;
                  * any other object: used directly as the method target.

        Returns:
            Whatever ``hook_args.register_response`` produces for the result
            (intended to be a pointer to it). Returns None when the command
            is an in-place method, when an int target is not found, or when
            the call itself returned None.

        Raises:
            TypeError: If the method call fails with a TypeError and no
                argument is ``bytes``, or if the retry with decoded
                arguments still fails.
            ResponseSignatureError: If ``register_response`` rejects the plain
                list of ``return_ids``; the error carries the ids recorded by
                ``sy.ID_PROVIDER`` while re-registering the response.
        """
        (command_name, _self, args, kwargs), return_ids = message

        # Handle methods
        if _self is not None:
            if type(_self) == int:
                _self = BaseWorker.get_obj(self, _self)
                if _self is None:
                    return
            # Deliberately not an `elif`: this check also applies to an
            # object that was just looked up by id.
            if type(_self) == str and _self == "self":
                _self = self
            if sy.framework.is_inplace_method(command_name):
                # TODO[jvmancuso]: figure out a good way to generalize the
                # above check (#2530)
                getattr(_self, command_name)(*args, **kwargs)
                return
            else:
                try:
                    response = getattr(_self, command_name)(*args, **kwargs)
                except TypeError:
                    # HACK: some string arguments arrive as raw bytes; they
                    # should be deserialized properly upstream instead.
                    # Only retry when there is actually something to decode.
                    # Otherwise the retry would run the identical call a
                    # second time (repeating any side effects) and fail again.
                    if not any(isinstance(arg, bytes) for arg in args):
                        raise
                    new_args = [
                        arg.decode("utf-8") if isinstance(arg, bytes) else arg
                        for arg in args
                    ]
                    response = getattr(_self, command_name)(*new_args,
                                                            **kwargs)
        # Handle functions
        else:
            # At this point, the command is ALWAYS a path to a
            # function (i.e., torch.nn.functional.relu). Thus,
            # we need to fetch this function and run it.

            sy.framework.command_guard(command_name)

            command = self
            for path in command_name.split("."):
                command = getattr(command, path)

            response = command(*args, **kwargs)

        # Some functions don't return anything (such as .backward()).
        if response is None:
            return None

        # Register the response and create pointers for tensor elements.
        try:
            return hook_args.register_response(command_name, response,
                                               list(return_ids), self)
        except ResponseSignatureError:
            # Re-register through the global id provider, seeded with the
            # caller's ids, and report the ids it actually handed out.
            # NOTE(review): sy.ID_PROVIDER is shared state; if
            # register_response raises here, the queued ids / recording mode
            # may be left behind — confirm whether cleanup is needed.
            return_id_provider = sy.ID_PROVIDER
            return_id_provider.set_next_ids(return_ids, check_ids=False)
            return_id_provider.start_recording_ids()
            hook_args.register_response(command_name, response,
                                        return_id_provider, self)
            new_ids = return_id_provider.get_recorded_ids()
            raise ResponseSignatureError(new_ids)
Exemplo n.º 3
0
    def execute_computation_action(self, action: ComputationAction) -> PointerTensor:
        """
        Executes a computation action received from another worker.

        Args:
            action: The action to run.
                ``name`` is the method or function name.
                ``target`` selects what it runs on (see below).
                ``args``/``kwargs`` are forwarded to the call.
                ``return_ids`` are used to register the response.
                ``return_value`` controls whether the registered response
                is handed back.

                ``target`` may be:
                  * None: ``name`` is a dotted function path (e.g.
                    "torch.nn.functional.relu") resolved attribute by
                    attribute starting from ``self.worker``;
                  * an int: an object id looked up with ``self.get_obj``;
                    if nothing is found the action is silently dropped;
                  * the string "self": ``self.worker``;
                  * any other string: a query for ``self.worker.search``,
                    which must match exactly one object;
                  * any other object: used directly as the method target.

        Returns:
            The registered response when ``action.return_value`` is truthy,
            or when the response is an int/float/bool/str. Otherwise None.
            Also None for in-place methods, for int targets that are not
            found, and for calls that return None.

        Raises:
            ValueError: If a string target matches zero or several objects.
            TypeError: If the method call fails with a TypeError and no
                argument is ``bytes``, or if the retry with decoded
                arguments still fails.
            ResponseSignatureError: If ``register_response`` rejects the plain
                list of ``return_ids``; the error carries the ids recorded by
                ``sy.ID_PROVIDER`` while re-registering the response.
        """
        op_name = action.name
        _self = action.target
        args_ = action.args
        kwargs_ = action.kwargs
        return_ids = action.return_ids
        return_value = action.return_value

        # Handle methods
        if _self is not None:
            if type(_self) == int:
                _self = self.get_obj(_self)
                if _self is None:
                    return
            elif isinstance(_self, str):
                if _self == "self":
                    _self = self.worker
                else:
                    res: list = self.worker.search(_self)
                    if len(res) != 1:
                        raise ValueError(
                            f"Searching for {_self} on {self.worker.id}. /!\\ {len(res)} found"
                        )
                    _self = res[0]
            if sy.framework.is_inplace_method(op_name):
                # TODO[jvmancuso]: figure out a good way to generalize the
                # above check (#2530)
                getattr(_self, op_name)(*args_, **kwargs_)
                return
            else:
                try:
                    response = getattr(_self, op_name)(*args_, **kwargs_)
                except TypeError:
                    # HACK: some string arguments arrive as raw bytes; they
                    # should be deserialized properly upstream instead.
                    # Only retry when there is actually something to decode.
                    # Otherwise the retry would run the identical call a
                    # second time (repeating any side effects) and fail again.
                    if not any(isinstance(arg, bytes) for arg in args_):
                        raise
                    new_args = [
                        arg.decode("utf-8") if isinstance(arg, bytes) else arg for arg in args_
                    ]
                    response = getattr(_self, op_name)(*new_args, **kwargs_)
        # Handle functions
        else:
            # At this point, the command is ALWAYS a path to a
            # function (i.e., torch.nn.functional.relu). Thus,
            # we need to fetch this function and run it.

            sy.framework.command_guard(op_name)

            command = self.worker
            for path in op_name.split("."):
                command = getattr(command, path)

            response = command(*args_, **kwargs_)

        # Some functions don't return anything (such as .backward()).
        if response is None:
            return None

        # Register the response and create pointers for tensor elements.
        try:
            response = hook_args.register_response(
                op_name, response, list(return_ids), self.worker
            )
            # Primitive results are always handed back, even when
            # return_value is False.
            # TODO: confirm this exception to return_value is intended.
            if return_value or isinstance(response, (int, float, bool, str)):
                return response
            return None
        except ResponseSignatureError:
            # Re-register through the global id provider, seeded with the
            # caller's ids, and report the ids it actually handed out.
            # NOTE(review): sy.ID_PROVIDER is shared state; if
            # register_response raises here, the queued ids / recording mode
            # may be left behind — confirm whether cleanup is needed.
            return_id_provider = sy.ID_PROVIDER
            return_id_provider.set_next_ids(return_ids, check_ids=False)
            return_id_provider.start_recording_ids()
            hook_args.register_response(
                op_name, response, return_id_provider, self.worker
            )
            new_ids = return_id_provider.get_recorded_ids()
            raise ResponseSignatureError(new_ids)