Beispiel #1
0
    def __init__(self, function, input_spec=None, **kwargs):
        """
        Sets up a `StaticFunction` wrapper around `function`.

        Args:
            function(callable): A function or method that will be converted into static program.
            input_spec(list[InputSpec]): list of InputSpec describing the `shape/dtype/name` of each input argument, default None.
            **kwargs(dict): extra options (e.g. `build_strategy`), stored untouched on the instance.
        """
        # When a bound method is decorated, split it into the plain function
        # and the instance it is bound to; otherwise there is no instance.
        is_bound_method = inspect.ismethod(function)
        self._dygraph_function = function.__func__ if is_bound_method else function
        self._class_instance = function.__self__ if is_bound_method else None

        self._input_spec = input_spec
        # The spec is built from the object as passed in (bound method included).
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
        self._program_trans = ProgramTranslator()
        self._kwargs = kwargs
        # Initial values of the remaining instance state.
        self._training = True
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
    def test_verify_input_spec(self):
        """Checks that `FunctionSpec` validates the container type of `input_spec`."""
        a_spec = InputSpec([None, 10], name='a')
        b_spec = InputSpec([10], name='b')

        # type(input_spec) should be list or tuple
        with self.assertRaises(TypeError):
            FunctionSpec(foo_func, input_spec=a_spec)

        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        # assertEqual reports the actual length on failure, unlike assertTrue.
        self.assertEqual(len(foo_spec.flat_input_spec), 2)
    def test_unified_args_and_kwargs(self):
        """Checks how `unified_args_and_kwargs` merges positional and keyword arguments."""
        foo_spec = FunctionSpec(foo_func)
        # case 1: foo(10, 20, c=4)
        args, kwargs = foo_spec.unified_args_and_kwargs([10, 20], {'c': 4})
        self.assertTupleEqual(args, (10, 20, 4, 2))
        # assertEqual shows the offending length when the check fails.
        self.assertEqual(len(kwargs), 0)

        # case 2: foo(a=10, b=20, d=4)
        args, kwargs = foo_spec.unified_args_and_kwargs([], {
            'a': 10,
            'b': 20,
            'd': 4
        })
        self.assertTupleEqual(args, (10, 20, 1, 4))
        self.assertEqual(len(kwargs), 0)

        # case 3: foo(10, b=20)
        args, kwargs = foo_spec.unified_args_and_kwargs([10], {'b': 20})
        self.assertTupleEqual(args, (10, 20, 1, 2))
        self.assertEqual(len(kwargs), 0)

        # assert len(self._arg_names) >= len(args)
        with self.assertRaises(ValueError):
            foo_spec.unified_args_and_kwargs([10, 20, 30, 40, 50], {'c': 4})

        # assert arg_name should be in kwargs
        with self.assertRaises(ValueError):
            foo_spec.unified_args_and_kwargs([10], {'c': 4})
Beispiel #4
0
    def __init__(self, function, input_spec=None):
        """
        Sets up a `StaticLayer` wrapper around `function`.

        Args:
            function(callable): A function or method that will be converted into static program.
            input_spec(list[InputSpec]): list of InputSpec describing the `shape/dtype/name` of each input argument, default None.
        """
        # When a bound method is decorated, split it into the plain function
        # and the instance it is bound to; otherwise there is no instance.
        is_bound_method = inspect.ismethod(function)
        self._dygraph_function = function.__func__ if is_bound_method else function
        self._class_instance = function.__self__ if is_bound_method else None

        self._input_spec = input_spec
        # The spec is built from the object as passed in (bound method included).
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
        self._program_trans = ProgramTranslator()
    def test_args_to_input_spec(self):
        """Checks how `args_to_input_spec` maps call arguments onto the declared specs."""
        a_spec = InputSpec([None, 10], name='a')
        b_spec = InputSpec([10], name='b')

        a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
        b_tensor = paddle.static.data(name='b_var', shape=[4, 10])
        kwargs = {'c': 1, 'd': 2}

        # case 1
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec(
            (a_tensor, b_tensor, 1, 2), {})

        # assertEqual prints both operands on failure, unlike assertTrue(x == y).
        self.assertEqual(len(input_with_spec), 4)
        self.assertEqual(input_with_spec[0], a_spec)  # a
        self.assertEqual(input_with_spec[1], b_spec)  # b
        self.assertEqual(input_with_spec[2], 1)  # c
        self.assertEqual(input_with_spec[3], 2)  # d

        # case 2
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec((a_tensor, b_tensor),
                                                         {})
        self.assertEqual(len(input_with_spec), 2)
        self.assertEqual(input_with_spec[0], a_spec)  # a
        self.assertTupleEqual(input_with_spec[1].shape, (4, 10))  # b.shape
        self.assertEqual(input_with_spec[1].name, 'b_var')  # b.name

        # case 3
        # assert kwargs is None if set `input_spec`
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        with self.assertRaises(ValueError):
            foo_spec.args_to_input_spec((a_tensor, b_tensor), {'c': 4})

        # case 4
        # assert len(args) >= len(self._input_spec)
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        with self.assertRaises(ValueError):
            foo_spec.args_to_input_spec((a_tensor, ), {})
 def test_constructor(self):
     """Checks the attributes exposed by a `FunctionSpec` built without `input_spec`."""
     foo_spec = FunctionSpec(foo_func)
     args_name = foo_spec.args_name
     self.assertListEqual(args_name, ['a', 'b', 'c', 'd'])
     # Dedicated assert methods report the offending values on failure.
     self.assertEqual(foo_spec.dygraph_function, foo_func)
     self.assertIsNone(foo_spec.input_spec)
Beispiel #7
0
    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and input/output variables from
        dygraph function. The users can use the program to run by executor.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            tuple of (main_program, startup_program, inputs, outputs) whose
            types are (Program, Program, list of Variable, list of Variable).
            main_program: the converted main program.
            startup_program: the converted startup program.
            inputs: list of input Variables which need to be fed.
            outputs: list of output Variables which users can fetch.

            When `self.enable_declarative` is false, a warning is issued and
            the result of `dygraph_func(*args, **kwargs)` is returned instead
            of the tuple above.

        Raises:
            AssertionError: if `dygraph_func` is not callable.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
                print([i.name for i in inputs])
                # ['feed_0'] the feed input variable name representing x
                print([o.name for o in outputs])
                # ['_generated_var_4'] the fetch output variable name representing x_v

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
        if not self.enable_declarative:
            # The trailing space after the first literal keeps the two
            # implicitly concatenated sentences from running together.
            warnings.warn(
                "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func)
        cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
                                                getattr(dygraph_func,
                                                        '__self__', None))
        # Only the concrete program is needed here; the second element of the
        # cached pair is ignored.
        concrete_program, _ = self._program_cache[cache_key]

        # Note: concrete_program hold all input/output infos include non-Variable
        input_vars = [
            var for var in concrete_program.inputs
            if isinstance(var, framework.Variable)
        ]
        output_vars = [
            var for var in concrete_program.outputs
            if isinstance(var, framework.Variable)
        ]

        return concrete_program.main_program, \
               concrete_program.startup_program, \
               input_vars, \
               output_vars
Beispiel #8
0
    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Runs `dygraph_func` through its translated static-graph form and
        returns the resulting dygraph VarBase, so the numerical result is
        calculated by declarative mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing digital
                result. When `self.enable_declarative` is false, a warning is
                issued and the plain result of `dygraph_func(*args, **kwargs)`
                is returned.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                with fluid.dygraph.guard():
                    x = np.ones([1, 2])
                    x_v = prog_trans.get_output(func, x)
                    print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"

        translation_enabled = self.enable_declarative
        if not translation_enabled:
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        # Build the cache key from the function, its arguments and, for a
        # bound method, the instance it is bound to.
        spec = FunctionSpec(dygraph_func)
        owner = getattr(dygraph_func, '__self__', None)
        key = CacheKey.from_func_and_args(spec, args, kwargs, owner)
        _unused_program, program_layer = self._program_cache[key]

        if args and isinstance(args[0], layers.Layer):
            # A leading Layer argument supplies the training flag and is
            # dropped from the inputs handed to the program layer.
            program_layer.training = args[0].training
            args = args[1:]

        try:
            return program_layer(args)
        except BaseException as exc:
            # NOTE:
            # 1. An error raised at compile time should already carry ERROR_DATA;
            # 2. An error raised at runtime gets ERROR_DATA attached here.
            if not hasattr(exc, ERROR_DATA):
                # runtime error
                attach_error_data(exc, in_runtime=True)
            raise
Beispiel #9
0
class StaticLayer(object):
    """
    Wraps a decorated function and manages its conversion into static programs.

    """

    def __init__(self, function, input_spec=None):
        """
        Sets up a `StaticLayer` wrapper around `function`.

        Args:
            function(callable): A function or method that will be converted into static program.
            input_spec(list[InputSpec]): list of InputSpec describing the `shape/dtype/name` of each input argument, default None.
        """
        # When a bound method is decorated, split it into the plain function
        # and the instance it is bound to; otherwise there is no instance.
        is_bound_method = inspect.ismethod(function)
        self._dygraph_function = function.__func__ if is_bound_method else function
        self._class_instance = function.__self__ if is_bound_method else None

        self._input_spec = input_spec
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: the translator's `enable_declarative` switch is read in `__call__`.
        self._program_trans = ProgramTranslator()

    def __get__(self, instance, owner):
        """
        Descriptor hook that returns a per-instance `StaticLayer`, so a decorated
        method is called with the right class instance.

        For example:

            '''
            class Net(Layer):
                def __init__(self):
                    pass

                @paddle.jit.to_static
                def forward(self, x, y):
                    return x + y

            net = Net()
            out = net(x, y)
            '''

        In the case above, `net(x, y)` first calls `net.forward(x, y)`, a bound method
        of the `Net` instance. Because `forward` is decorated by `@paddle.jit.to_static`,
        `__get__` runs first and hands back a `StaticLayer` tied to that instance
        rather than the shared one stored on the class.
        """
        # Access through the class (no instance) yields the shared wrapper.
        if instance is None:
            return self

        per_instance = self._descriptor_cache
        if instance not in per_instance:
            # Note(Aurelius84): build a new StaticLayer the first time the bound
            # function of a layer is seen, then cache it for later lookups.
            bound_layer = self._clone()
            bound_layer._class_instance = instance
            per_instance[instance] = bound_layer

        return per_instance[instance]

    def _clone(self):
        """Creates a new wrapper of the same class for the same function and input spec."""
        wrapper_cls = self.__class__
        return wrapper_cls(self._dygraph_function, self._input_spec)

    def __call__(self, *args, **kwargs):
        """
        Calls the decorated function with `args` and `kwargs` through its static program.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of decorated function.
        """
        translator = self._program_trans

        # 1. fall back to the plain dygraph function when `declarative` is disabled
        if not translator.enable_declarative:
            logging_utils.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return self._call_dygraph_function(*args, **kwargs)

        if not in_dygraph_mode() and translator.enable_declarative:
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
        try:
            _program, program_layer = self.get_concrete_program(*args,
                                                                **kwargs)

            # 3. mirror the owning Layer's training flag onto the program layer.
            owner = self._class_instance
            if isinstance(owner, layers.Layer):
                program_layer.training = owner.training

            # 4. return outputs.
            return program_layer(args)
        except Exception as exc:
            if not hasattr(exc, ERROR_DATA):
                # runtime error
                attach_error_data(exc, in_runtime=True)
            error_data = getattr(exc, ERROR_DATA, None)
            if not error_data:
                raise

            # The local must keep the name `new_exception`: the statement
            # executed below refers to it by that name.
            new_exception = error_data.create_exception()
            if six.PY3:
                # NOTE(liym27):
                # 1. Why `raise new_exception from None`?
                #   In Python 3 a new exception is by default raised together with the trace
                #   of the caught one. `from None` raises only new_exception and keeps the
                #   unwanted implementation details out of the traceback.
                # 2. exec is used so that Python 2 does not reject the syntax.

                six.exec_("raise new_exception from None")
            else:
                raise new_exception

    def _call_dygraph_function(self, *args, **kwargs):
        """
        Invokes the original dygraph function and returns its outputs.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of dygraph function.
        """
        target = self._dygraph_function
        if self._class_instance is not None:
            # Re-bind the plain function to the stored instance.
            target = target.__get__(self._class_instance)

        return target(*args, **kwargs)

    def get_concrete_program(self, *args, **kwargs):
        """
        Looks up the traced concrete program and its executable partial layer.

        Args:
            *args(tuple): input arguments values or InputSpec
            **kwargs(dict) : input kwargs values.

        Returns:
            Traced ConcreteProgram and executable translated Layer.
        """
        spec = self._function_spec

        # 1. unify args/kwargs and replace Tensor with InputSpec
        if len(args) != len(spec.args_name):
            args, kwargs = spec.unified_args_and_kwargs(args, kwargs)
        input_with_spec = spec.args_to_input_spec(args, kwargs)

        # 2. generate cache key
        cache_key = CacheKey(spec, input_with_spec, self._class_instance)

        # 3. index the program cache with the key for these input arguments
        concrete_program, partial_program_layer = self._program_cache[cache_key]
        return concrete_program, partial_program_layer

    def get_traced_count(self):
        """
        Returns how many programs are currently held in the program cache.
        """
        cache = self._program_cache
        return len(cache)

    @property
    def code(self):
        """
        Returns the source code of transformed static function for debugging.
        """
        return func_to_source_code(convert_to_static(self._dygraph_function))

    @property
    def dygraph_function(self):
        """
        Returns the original decorated function.
        """
        return self._dygraph_function

    @property
    def concrete_program(self):
        """
        Returns the most recent ConcreteProgram of the decorated function.

        With an empty cache, the program is built from `input_spec` when one was
        given; otherwise a ValueError is raised.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.jit import to_static
                from paddle.static import InputSpec

                paddle.disable_static()

                def foo(x, y):
                    z = x + y
                    return z

                # usage 1:
                decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
                print(decorated_foo.concrete_program)

                # usage 2:
                decorated_foo = to_static(foo)
                out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
                print(decorated_foo.concrete_program)
        """
        num_cached = len(self._program_cache)

        # Nothing cached yet: convert from `input_spec` if there is one.
        if num_cached == 0:
            input_spec = self._function_spec.input_spec
            if input_spec is None or len(input_spec) == 0:
                raise ValueError(
                    "No valid transformed program for {}.\n\t    Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n"
                    .format(self._function_spec))
            traced_program, _ = self.get_concrete_program(*input_spec)
            return traced_program

        # Several programs cached: warn, then fall through to the last one.
        if num_cached > 1:
            logging.warning(
                "Current {} has more than one cached programs: {}, the last traced progam will be return by default."
                .format(self._function_spec, num_cached))

        _key, (latest_program, _layer) = self._program_cache.last()
        return latest_program

    @property
    def inputs(self):
        """
        Returns input tensors of recent converted static program.
        """
        # Keep only the Variable entries of the flattened inputs.
        return [
            item for item in flatten(self.concrete_program.inputs)
            if isinstance(item, framework.Variable)
        ]

    @property
    def outputs(self):
        """
        Returns output tensors of recent converted static program.
        """
        # Keep only the Variable entries of the flattened outputs.
        return [
            item for item in flatten(self.concrete_program.outputs)
            if isinstance(item, framework.Variable)
        ]

    @property
    def main_program(self):
        """
        Returns recent converted static main program.
        """
        return self.concrete_program.main_program

    @property
    def program_cache(self):
        """Returns the program cache held by this wrapper."""
        return self._program_cache

    @property
    def function_spec(self):
        """Returns the `FunctionSpec` built for the decorated function."""
        return self._function_spec
Beispiel #10
0
    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Runs `dygraph_func` through its translated static-graph form and
        returns the resulting dygraph Tensor, so the numerical result is
        calculated by static graph mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args (tuple): the input argument of dygraph_func.
            **kwargs (dict): the input argument of dygraph_func.

        Returns:
            Tensor or tuple of Tensors: the dygraph Tensor containing digital result.
            When `self.enable_to_static` is false, a warning is logged and the
            plain result of `dygraph_func(*args, **kwargs)` is returned.

        Examples:
            .. code-block:: python

                import paddle


                def func(x):
                    if paddle.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v


                prog_trans = paddle.jit.ProgramTranslator()

                x = paddle.ones([1, 2])
                x_v = prog_trans.get_output(func, x)
                print(x_v)  # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"

        if not self.enable_to_static:
            # NOTE(review): the comment previously at this spot said
            # `warnings.warn` is used (so the message shows up only once), but
            # the code calls `logging_utils.warn` -- confirm which is intended.
            logging_utils.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
                "We will just return dygraph output. "
                "Please call ProgramTranslator.enable(True) if you would like to get static output."
            )
            return dygraph_func(*args, **kwargs)

        try:
            # Build the cache key from the function, its arguments and, for a
            # bound method, the instance it is bound to.
            spec = FunctionSpec(dygraph_func)
            owner = getattr(dygraph_func, '__self__', None)
            key = CacheKey.from_func_and_args(spec, args, kwargs, owner)
            _unused_program, program_layer = self._program_cache[key]

            if args and isinstance(args[0], layers.Layer):
                # A leading Layer argument supplies the training flag and is
                # dropped from the inputs handed to the program layer.
                program_layer.training = args[0].training
                args = args[1:]

            try:
                return program_layer(args)
            except BaseException as runtime_exc:
                # NOTE:
                # 1. An error raised at compile time should already carry ERROR_DATA;
                # 2. An error raised at runtime gets ERROR_DATA attached here.
                if not hasattr(runtime_exc, error.ERROR_DATA):
                    # runtime error
                    error.attach_error_data(runtime_exc, in_runtime=True)
                raise
        except BaseException as exc:
            error_data = getattr(exc, error.ERROR_DATA, None)
            if not error_data:
                # No error data attached: point the user at the issue tracker
                # and re-raise the exception itself.
                logging_utils.warn(
                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
                    " if you can't handle this {} yourself.".format(type(exc)))
                raise exc
            error_data.raise_new_exception()
Beispiel #11
0
class StaticFunction(object):
    """
    Wrapper class to Manage program conversion of decorated function.

    """
    def __init__(self, function, input_spec=None, **kwargs):
        """
        Builds a `StaticFunction` around a dygraph callable.

        Args:
            function(callable): Function or bound method that will be converted into a static program.
            input_spec(list[InputSpec]): InputSpec entries describing `shape/dtype/name` of each input argument, default None.
            **kwargs(dict): extra options such as `build_strategy`; they are stored and later passed to the cache key.
        """
        # When a bound method is decorated, keep the plain function and the
        # instance it is bound to separately.
        decorating_method = inspect.ismethod(function)
        self._dygraph_function = function.__func__ if decorating_method else function
        self._class_instance = function.__self__ if decorating_method else None

        # Plain state: user options and the default (training) mode.
        self._input_spec = input_spec
        self._kwargs = kwargs
        self._training = True
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0

        # Helper objects used for tracing and caching programs.
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        # Maps a class instance to the StaticFunction cloned for it (see `__get__`).
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
        self._program_trans = ProgramTranslator()

    def train(self):
        """
        Switches this function to train mode by setting `self._training` to True.

        Raises:
            RuntimeError: if the function is bound to a `layers.Layer` whose
                `training` attribute equals False; the message asks the
                caller to use `Layer.train()` instead.
        """
        owner = self._class_instance
        bound_to_layer = isinstance(owner, layers.Layer)
        if bound_to_layer and owner.training == False:
            template = ("Failed to switch train mode. {} is a Layer's method, "
                        "please use Layer.train() to switch train mode.")
            raise RuntimeError(template.format(self.dygraph_function))
        self._training = True

    def eval(self):
        """
        Switches this function to eval mode by setting `self._training` to False.

        Raises:
            RuntimeError: if the function is bound to a `layers.Layer` whose
                `training` attribute equals True; the message asks the
                caller to use `Layer.eval()` instead.
        """
        owner = self._class_instance
        bound_to_layer = isinstance(owner, layers.Layer)
        if bound_to_layer and owner.training == True:
            template = ("Failed to switch eval mode. {} is a Layer's method, "
                        "please use Layer.eval() to switch eval mode.")
            raise RuntimeError(template.format(self.dygraph_function))
        self._training = False

    def __get__(self, instance, owner):
        """
        Overrides this method to parse the class instance and call bound method correctly.

        For example:
            
            '''
            class Net(Layer):
                def __init__(self):
                    pass
                
                @paddle.jit.to_static
                def forward(self, x, y):
                    return x + y

            net = Net()
            out = net(x, y)
            '''
        
        In above case, `net(x, y)` will call `net.forward(x, y)` firstly that is a bound method
        of `Net` instance. After decorated by `@paddle.jit.to_static`, it will firstly to call `__get__`
        to parse the class instance correctly instead of the `StaticFunction` instance.
        """
        if instance not in self._descriptor_cache:
            if instance is None:
                return self
            # Note(Aurelius84): To construct new instance of StaticFunction when we
            # first encouter the bound function of layer and cache it.
            new_static_layer = self._clone()
            new_static_layer._class_instance = instance
            self._descriptor_cache[instance] = new_static_layer

        return self._descriptor_cache[instance]

    def _clone(self):
        return self.__class__(self._dygraph_function, self._input_spec)

    def __call__(self, *args, **kwargs):
        """
        Supports to call the returned instance with input `args` and `kwargs` directly.

        When `ProgramTranslator.enable_to_static` is off, the original dygraph
        function is called. Otherwise the arguments are unified, a concrete
        program is fetched (or traced) through `get_concrete_program`, and the
        resulting partial program layer is executed.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of decorated function.

        Raises:
            RuntimeError: if to_static is enabled but `_non_static_mode()` is false.
            Exception: errors raised while building or running the program are
                propagated, either via `error_data.raise_new_exception()` when
                error data is attached, or re-raised unchanged otherwise.
        """

        # 1. call dygraph function directly if not enable `declarative`
        if not self._program_trans.enable_to_static:
            # NOTE(review): an earlier note here (by liym27) stated that
            # `warnings.warn` is used so the message shows up only once, but the
            # call below is `logging_utils.warn` -- confirm which one is intended,
            # since `__call__` may run many times.
            logging_utils.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
                "We will just return dygraph output. If you would like to get static graph output, please call API "
                "ProgramTranslator.enable(True)")
            return self._call_dygraph_function(*args, **kwargs)

        if not _non_static_mode():
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(
            args, kwargs)

        try:
            concrete_program, partial_program_layer = self.get_concrete_program(
                *args, **kwargs, is_train=self._is_train_mode())

            # 3. synchronize self.training attribute.
            if isinstance(self._class_instance, layers.Layer):
                partial_program_layer.training = self._class_instance.training
            else:
                partial_program_layer.training = self._training

            partial_program_layer._cuda_graph_capture_mode = self._cuda_graph_capture_mode
            partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id

            # 4. return outputs.
            try:
                return partial_program_layer(args)
            except Exception as e:
                if not hasattr(e, error.ERROR_DATA):
                    # runtime error
                    error.attach_error_data(e, in_runtime=True)
                # Re-raise unconditionally. Raising only inside the `if` above
                # would swallow exceptions that already carry error data and
                # make this call silently return None.
                raise
        except Exception as e:
            error_data = getattr(e, error.ERROR_DATA, None)
            if error_data:
                error_data.raise_new_exception()
            else:
                logging_utils.warn(
                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
                    " if you can't handle this {} yourself.".format(type(e)))
                raise e

    def _is_train_mode(self):
        if self._class_instance is not None:
            return self._class_instance.training
        else:
            return self._training

    def _call_dygraph_function(self, *args, **kwargs):
        """
        Calls dygraph function directly and returns the outputs.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyward arguments from original decorated function. 

        Return:
            Outputs of dygraph function.
        """
        if self._class_instance is not None:
            dygraph_function = self._dygraph_function.__get__(
                self._class_instance)
        else:
            dygraph_function = self._dygraph_function

        return dygraph_function(*args, **kwargs)

    def get_concrete_program(self, *args, **kwargs):
        """
        Looks up (through the program cache) the concrete program for the given inputs.

        The reserved keyword arguments `with_hook` (default False) and
        `is_train` (default True) are removed from `kwargs` and only used to
        build the cache key.

        Args:
            *args(tuple): input arguments values or InputSpec
            **kwargs(dict) : input kwargs values.

        Returns:
            Traced ConcreteProgram and executable translated Layer.
        """
        # Strip the control options so they are not treated as function inputs.
        with_hook = kwargs.pop("with_hook", False)
        is_train = kwargs.pop("is_train", True)

        spec = self._function_spec

        # 1. unify args/kwargs and replace Tensor with InputSpec
        if len(args) != len(spec.args_name):
            args, kwargs = spec.unified_args_and_kwargs(args, kwargs)
        spec_args, spec_kwargs = spec.args_to_input_spec(args, kwargs)

        # 2. generate cache key
        key = CacheKey(spec,
                       spec_args,
                       spec_kwargs,
                       self._class_instance,
                       **self._kwargs,
                       with_hook=with_hook,
                       is_train=is_train)

        # 3. indexing the cache either hits an entry or builds a new program
        program, program_layer = self._program_cache[key]
        return program, program_layer

    def get_traced_count(self):
        """
        Returns the number of traced programs for the decorated function.
        """
        return len(self._program_cache)

    @property
    def code(self):
        """
        Source code of the transformed static function, intended for debugging.

        The dygraph function is converted with `convert_to_static` and the
        result is rendered with `func_to_source_code`.
        """
        return func_to_source_code(convert_to_static(self._dygraph_function))

    @property
    def dygraph_function(self):
        """
        Returns the original decorated function.
        """
        return self._dygraph_function

    @property
    def concrete_program(self):
        """
        Returns recent ConcreteProgram instance of decorated function.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.jit import to_static
                from paddle.static import InputSpec

                paddle.disable_static()

                def foo(x, y):
                    z = x + y
                    return z
                
                # usage 1:
                decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
                print(decorated_foo.concrete_program)

                # usage 2:
                decorated_foo = to_static(foo)
                out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
                print(decorated_foo.concrete_program)
        """
        return self.concrete_program_specify_input_spec(input_spec=None)

    def concrete_program_specify_input_spec(self,
                                            input_spec=None,
                                            with_hook=False):
        """
        Returns recent ConcreteProgram instance of decorated function while
        specifying input_spec. If the self._function_spec already has
        input_spec, it will check the compatibility of input input_spec and
        the self._function_spec.input_spec. If input input_spec=None, then
        this method uses self._function_spec.input_spec

        args:
            input_spec (list[InputSpec], optional): Describes the input of
                the translate function.
        """
        # if specific the `input_spec`, the length of program_cache will always 1,
        # else, return the last one.
        cached_program_len = len(self._program_cache)
        # If specific `input_spec`, apply convertion from dygraph layers into static Program.
        if cached_program_len == 0:
            desired_input_spec = input_spec
            if self._function_spec.input_spec is not None:
                if input_spec is not None and not input_specs_compatible(
                        flatten(input_spec),
                        flatten(self._function_spec.input_spec)):
                    raise ValueError(
                        "The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`"
                        .format(input_spec, self._function_spec.input_spec))
                # NOTE(chenweihang): we should always translated program based on the `input_spec`
                # decorated on forward if it is valid
                desired_input_spec = self._function_spec.input_spec
                if input_spec is not None:
                    logging_utils.warn(
                        "\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n"
                        .format(desired_input_spec, input_spec))

            has_input_spec = (desired_input_spec is not None)
            if has_input_spec:
                concrete_program, _ = self.get_concrete_program(
                    *desired_input_spec,
                    with_hook=with_hook,
                    is_train=self._is_train_mode())
                return concrete_program
            else:
                raise ValueError(
                    "No valid transformed program for {}.\n\t    Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n"
                    .format(self._function_spec))
        elif with_hook:
            cache_key = self._program_cache._recent_cache_key
            cache_key.kwargs["with_hook"] = True
            concrete_program, _ = self._program_cache[cache_key]
            return concrete_program

        # If more than one programs have been cached, return the recent converted program by default.
        elif cached_program_len > 1:
            logging_utils.warn(
                "Current {} has more than one cached programs: {}, the last traced progam will be return by default."
                .format(self._function_spec, cached_program_len))

        cache_key, (concrete_program,
                    partial_layer) = self._program_cache.last()
        return concrete_program

    @property
    def inputs(self):
        """
        Input variables of the most recently converted static program.

        Returns:
            A list with every `framework.Variable` found in the flattened
            `inputs` of `self.concrete_program`.
        """
        flat_inputs = flatten(self.concrete_program.inputs)
        return [
            item for item in flat_inputs
            if isinstance(item, framework.Variable)
        ]

    @property
    def outputs(self):
        """
        Output variables of the most recently converted static program.

        Returns:
            A list with every `framework.Variable` found in the flattened
            `outputs` of `self.concrete_program`.
        """
        flat_outputs = flatten(self.concrete_program.outputs)
        return [
            item for item in flat_outputs
            if isinstance(item, framework.Variable)
        ]

    @property
    def main_program(self):
        """
        Returns recent converted static main program.
        """
        concrete_program = self.concrete_program
        main_program = concrete_program.main_program
        return main_program

    @property
    def program_cache(self):
        return self._program_cache

    @property
    def function_spec(self):
        return self._function_spec