Code Example #1
    def __call__(self, *args, **kwargs):
        """
        Supports calling the returned instance with input `args` and `kwargs` directly.

        Args:
            *args(tuple): tuple of all input arguments from the original decorated function.
            **kwargs(dict): dict of all input keyword arguments from the original decorated function.

        Returns:
            Outputs of the decorated function.
        """

        # 1. Call the dygraph function directly if `declarative` is not enabled
        if not self._program_trans.enable_to_static:
            # NOTE(liym27):
            # `warnings.warn` is called here instead of `logging_utils.warn` because, by default, warnings.warn(message)
            # shows up **only once**. Since StaticFunction.__call__ runs many times, it is appropriate to
            # display this warning message only once.
            warnings.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
                "We will just return dygraph output. If you would like to get static graph output, please call API "
                "ProgramTranslator.enable(True)")
            return self._call_dygraph_function(*args, **kwargs)

        if not in_dygraph_mode():
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(
            args, kwargs)

        try:
            concrete_program, partial_program_layer = self.get_concrete_program(
                *args, **kwargs)

            # 3. synchronize self.training attribute.
            if isinstance(self._class_instance, layers.Layer):
                partial_program_layer.training = self._class_instance.training

            # 4. return outputs.
            try:
                return partial_program_layer(args)
            except Exception as e:
                if not hasattr(e, error.ERROR_DATA):
                    # runtime error
                    error.attach_error_data(e, in_runtime=True)
                raise
        except Exception as e:
            error_data = getattr(e, error.ERROR_DATA, None)
            if error_data:
                error_data.raise_new_exception()
            else:
                logging_utils.warn(
                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
                    " if you can't handle this {} yourself.".format(type(e)))
                raise e
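
Supplementary usage sketch (not part of the original source): a minimal example, assuming a Paddle 2.x installation, of the `@paddle.jit.to_static` decorator whose `StaticFunction.__call__` is shown above. Calling the decorated function goes through the dispatch in step 1 and the tracing/caching in step 2.

    import paddle

    @paddle.jit.to_static
    def add_one(x):
        # The first call traces this function into a static program via
        # get_concrete_program() and caches it; later calls reuse the cache.
        return x + 1

    x = paddle.ones([2, 2])
    print(add_one(x).numpy())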
Code Example #2
    def from_func_spec(func_spec, input_spec, class_instance):
        """
        Builds the main_program with specialized inputs and returns the outputs
        of the program as fetch_list.

        Args:
            func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
            input_spec(list[InputSpec]): The list of InputSpec describing the shape and dtype of each input.
            class_instance: The class instance that the decorated method is bound to, or None for a plain function.
        """
        # Transforms dygraph function into static function and caches it.
        dygraph_function = func_spec.dygraph_function
        static_func = convert_to_static(dygraph_function)

        main_program, startup_program = framework.Program(), framework.Program()
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        main_program.random_seed = framework.default_main_program().random_seed
        startup_program.random_seed = framework.default_startup_program(
        ).random_seed

        with framework.program_guard(main_program, startup_program):
            with _switch_declarative_mode_guard_(is_declarative=True):
                # 1. Adds `fluid.data` layers for input if needed
                inputs = func_spec.to_static_inputs_with_spec(input_spec,
                                                              main_program)
                if class_instance:
                    inputs = tuple([class_instance] + list(inputs))

                # 2. Gets all ParamBases and buffered VarBases in the function
                all_parameters_and_buffers = list(
                    get_parameters(class_instance).values()) + list(
                        get_buffers(class_instance).values())

                # 3. Builds program only once and returns the output Variables.
                with param_guard(get_parameters(
                        class_instance, False)), param_guard(
                            get_buffers(class_instance, False)):
                    try:
                        outputs = static_func(*inputs)
                    except BaseException as e:
                        # NOTE: If e is raised at compile time, e should be attached to ERROR_DATA here.
                        attach_error_data(e)
                        raise

                if not isinstance(outputs,
                                  (tuple, list)) and outputs is not None:
                    outputs = [outputs]

        main_program = update_op_callstack_with_origin_info(main_program)

        return ConcreteProgram(
            inputs=inputs,
            outputs=outputs,
            parameters=all_parameters_and_buffers,
            function=dygraph_function,
            main_program=main_program,
            startup_program=startup_program)
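
Supplementary usage sketch (not part of the original source): `from_func_spec` is an internal builder, and user code reaches it indirectly through `input_spec`. The sketch below, assuming a Paddle 2.x installation and a hypothetical `double` function, shows how an `InputSpec` list is supplied; it is this spec that step 1 turns into `fluid.data`-style input layers.

    import paddle
    from paddle.static import InputSpec

    # One InputSpec per positional input; None marks a variable batch dimension.
    x_spec = InputSpec(shape=[None, 8], dtype='float32', name='x')

    @paddle.jit.to_static(input_spec=[x_spec])
    def double(x):
        return x * 2

    out = double(paddle.ones([4, 8]))
    print(out.shape)  # [4, 8]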
Code Example #3
    def __call__(self, *args, **kwargs):
        """
        Supports calling the returned instance with input `args` and `kwargs` directly.

        Args:
            *args(tuple): tuple of all input arguments from the original decorated function.
            **kwargs(dict): dict of all input keyword arguments from the original decorated function.

        Returns:
            Outputs of the decorated function.
        """

        # 1. Call the dygraph function directly if `declarative` is not enabled
        if not self._program_trans.enable_declarative:
            logging_utils.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return self._call_dygraph_function(*args, **kwargs)

        if not in_dygraph_mode() and self._program_trans.enable_declarative:
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
        try:
            concrete_program, partial_program_layer = self.get_concrete_program(
                *args, **kwargs)

            # 3. synchronize self.training attribute.
            if isinstance(self._class_instance, layers.Layer):
                partial_program_layer.training = self._class_instance.training

            # 4. return outputs.
            return partial_program_layer(args)
        except Exception as e:
            if not hasattr(e, ERROR_DATA):
                # runtime error
                attach_error_data(e, in_runtime=True)
            error_data = getattr(e, ERROR_DATA, None)
            if error_data:
                new_exception = error_data.create_exception()
                if six.PY3:
                    # NOTE(liym27):
                    # 1. Why `raise new_exception from None`?
                    #   In Python 3, by default, a new exception is raised with the trace information of the caught
                    #   exception. Raising with `from None` raises only new_exception and hides unwanted implementation
                    #   details from the traceback of the caught exception.
                    # 2. Use exec to bypass syntax error checking in Python 2.

                    six.exec_("raise new_exception from None")
                else:
                    raise new_exception
            else:
                raise
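
Supplementary sketch (not part of the original source): the NOTE above motivates the `raise new_exception from None` idiom. The snippet below is a standalone, plain Python 3 illustration of that idiom, unrelated to Paddle's error machinery; `from None` suppresses exception chaining so only the user-facing exception is reported.

    def translated_call():
        try:
            # Stand-in for an error raised inside the generated static program.
            raise ValueError("internal error in transformed code")
        except ValueError:
            # Without `from None`, Python 3 would also print the ValueError
            # traceback under "During handling of the above exception ...".
            raise RuntimeError("user-facing error with original source info") from None

    try:
        translated_call()
    except RuntimeError as e:
        print(e)            # user-facing error with original source info
        print(e.__cause__)  # None: explicit chaining suppressed by `from None`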
Code Example #4
    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph VarBase for the dygraph function. The dygraph
        function will be translated into a static graph function so that the
        underlying numerical result will be calculated in declarative mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input arguments of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing the numerical
                result.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                with fluid.dygraph.guard():
                    x = np.ones([1, 2])
                    x_v = prog_trans.get_output(func, x)
                    print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func)
        cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
                                                getattr(dygraph_func,
                                                        '__self__', None))
        _, partial_program_layer = self._program_cache[cache_key]

        if args and isinstance(args[0], layers.Layer):
            # Synchronize self.training attribute.
            partial_program_layer.training = args[0].training
            args = args[1:]
        try:
            return partial_program_layer(args)

        except BaseException as e:
            # NOTE:
            # 1. If e is raised at compile time, e should have been attached to ERROR_DATA before;
            # 2. If e is raised at runtime, e should be attached to ERROR_DATA here.
            if not hasattr(e, ERROR_DATA):
                # runtime error
                attach_error_data(e, in_runtime=True)
            raise
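
Supplementary usage sketch (not part of the original source): the warning branch at the top of `get_output` refers to the translator toggle. A small fluid-era sketch, assuming the `ProgramTranslator.enable()` method behaves as described in the warning message, of disabling and re-enabling translation around `get_output`:

    import numpy as np
    import paddle.fluid as fluid

    def func(x):
        x = fluid.dygraph.to_variable(x)
        return x + 1

    prog_trans = fluid.dygraph.ProgramTranslator()
    prog_trans.enable(False)  # get_output warns and returns the plain dygraph result

    with fluid.dygraph.guard():
        out = prog_trans.get_output(func, np.ones([1, 2]))
        print(out.numpy())  # [[2. 2.]]

    prog_trans.enable(True)   # restore static-graph translation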
Code Example #5
    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph Tensor for the dygraph function. The dygraph
        function will be translated into a static graph function so that the
        underlying numerical result will be calculated in static graph mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args (tuple): the input arguments of dygraph_func.
            **kwargs (dict): the input keyword arguments of dygraph_func.

        Returns:
            Tensor or tuple of Tensors: the dygraph Tensor containing the numerical result.

        Examples:
            .. code-block:: python

                import paddle


                def func(x):
                    if paddle.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v


                prog_trans = paddle.jit.ProgramTranslator()

                x = paddle.ones([1, 2])
                x_v = prog_trans.get_output(func, x)
                print(x_v)  # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"

        if not self.enable_to_static:
            # NOTE: by default, `warnings.warn(message)` would show this warning **only once**.
            logging_utils.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
                "We will just return dygraph output. "
                "Please call ProgramTranslator.enable(True) if you would like to get static output."
            )
            return dygraph_func(*args, **kwargs)
        try:
            function_spec = FunctionSpec(dygraph_func)
            cache_key = CacheKey.from_func_and_args(
                function_spec, args, kwargs,
                getattr(dygraph_func, '__self__', None))
            _, partial_program_layer = self._program_cache[cache_key]

            if args and isinstance(args[0], layers.Layer):
                # Synchronize self.training attribute.
                partial_program_layer.training = args[0].training
                args = args[1:]
            try:
                return partial_program_layer(args)
            except BaseException as e:
                # NOTE:
                # 1. If e is raised at compile time, e should have been attached to ERROR_DATA before;
                # 2. If e is raised at runtime, e should be attached to ERROR_DATA here.
                if not hasattr(e, error.ERROR_DATA):
                    # runtime error
                    error.attach_error_data(e, in_runtime=True)
                raise
        except BaseException as e:
            error_data = getattr(e, error.ERROR_DATA, None)
            if error_data:
                error_data.raise_new_exception()
            else:
                logging_utils.warn(
                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
                    " if you can't handle this {} yourself.".format(type(e)))
                raise e
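
Supplementary sketch (not part of the original source): the outer `except` above relies on an error-data object attached to the exception. Below is a generic illustration of that attach-and-translate pattern; `ERROR_DATA_ATTR` and `DemoErrorData` are placeholder names invented for this sketch and are not Paddle's actual `error` module.

    ERROR_DATA_ATTR = "__demo_error_data__"  # placeholder for error.ERROR_DATA

    class DemoErrorData:
        """Placeholder for the object created by error.attach_error_data."""

        def __init__(self, origin_exc, in_runtime):
            self.origin_exc = origin_exc
            self.in_runtime = in_runtime

        def raise_new_exception(self):
            # Surface a user-facing error and hide the internal traceback.
            raise RuntimeError("translated error: {}".format(
                self.origin_exc)) from None

    def demo_attach_error_data(e, in_runtime=False):
        if not hasattr(e, ERROR_DATA_ATTR):
            setattr(e, ERROR_DATA_ATTR, DemoErrorData(e, in_runtime))

    # Mirror of the two-level try/except structure in get_output above.
    try:
        try:
            try:
                raise ValueError("runtime failure inside the traced program")
            except BaseException as e:
                demo_attach_error_data(e, in_runtime=True)
                raise
        except BaseException as e:
            error_data = getattr(e, ERROR_DATA_ATTR, None)
            if error_data is not None:
                error_data.raise_new_exception()
            else:
                raise
    except RuntimeError as translated:
        print(translated)  # translated error: runtime failure inside the traced program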
Code Example #6
    def from_func_spec(func_spec, input_spec, input_kwargs_spec,
                       class_instance, **kwargs):
        """
        Builds the main_program with specialized inputs and returns the outputs
        of the program as fetch_list.

        Args:
            func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
            input_spec(list[InputSpec]): The list of InputSpec describing the shape and dtype of each positional input.
            input_kwargs_spec(dict): The InputSpec for each keyword input.
            class_instance: The class instance that the decorated method is bound to, or None for a plain function.
        """
        # verify the instance is initialized in imperative mode.
        _verify_init_in_dynamic_mode(class_instance)

        # Transforms dygraph function into static function and caches it.
        dygraph_function = func_spec.dygraph_function
        static_func = convert_to_static(dygraph_function)
        # Apply pre/post hooks for the outermost layer
        hook_helper = HookHelper(dygraph_function, class_instance,
                                 kwargs.get("with_hook", False))

        main_program, startup_program = framework.Program(), framework.Program(
        )
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        main_program.random_seed = framework.default_main_program().random_seed
        startup_program.random_seed = framework.default_startup_program(
        ).random_seed

        from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
        with framework.program_guard(main_program, startup_program):
            with _switch_declarative_mode_guard_(is_declarative=True):
                # 1. Adds `fluid.data` layers for input if needed
                static_inputs = func_spec.to_static_inputs_with_spec(
                    input_spec, main_program)
                _kwargs = func_spec.to_static_inputs_with_spec(
                    input_kwargs_spec, main_program)
                if class_instance:
                    static_inputs = tuple([class_instance] +
                                          list(static_inputs))

                # 2. Gets all ParamBases and buffered VarBases in the function
                all_parameters_and_buffers = _extract_indeed_params_buffers(
                    class_instance)

                # 3. Builds program only once and returns the output Variables.
                with param_guard(get_parameters(
                        class_instance, False)), param_guard(
                            get_buffers(class_instance, False)):
                    try:
                        # Only for jit.save; does nothing during the train and eval process
                        inputs = hook_helper.apply_pre_hooks(static_inputs)
                        if _kwargs:
                            outputs = static_func(*inputs, **_kwargs)
                        else:
                            outputs = static_func(*inputs)
                        outputs = hook_helper.apply_post_hooks(inputs, outputs)
                    except BaseException as e:
                        # NOTE: If e is raised at compile time, e should be attached to ERROR_DATA here.
                        error.attach_error_data(e)
                        error_data = getattr(e, error.ERROR_DATA, None)
                        if error_data:
                            error_data.raise_new_exception()
                        raise

                if outputs is not None:
                    need_wrap_into_list = not isinstance(
                        outputs, (tuple, list)) or len(outputs) == 1
                    if need_wrap_into_list:
                        outputs = [outputs]

        main_program = update_op_callstack_with_origin_info(main_program)

        return ConcreteProgram(inputs=static_inputs,
                               outputs=outputs,
                               parameters=all_parameters_and_buffers,
                               function=dygraph_function,
                               main_program=main_program,
                               startup_program=startup_program,
                               **kwargs)
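
Supplementary sketch (not part of the original source): the `HookHelper` step above replays the outermost layer's forward hooks when tracing for `jit.save`. Below is a minimal Paddle 2.x sketch, with a hypothetical `Net` class, of registering the kind of forward pre/post hooks that step applies.

    import paddle

    class Net(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    def pre_hook(layer, inputs):
        # May inspect or rewrite the inputs before forward() runs.
        return inputs

    def post_hook(layer, inputs, outputs):
        # May inspect or rewrite the outputs after forward() runs.
        return outputs

    net = Net()
    net.register_forward_pre_hook(pre_hook)
    net.register_forward_post_hook(post_hook)

    out = net(paddle.ones([3, 4]))
    print(out.shape)  # [3, 2]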