def test_args_to_input_spec(self):
        a_spec = InputSpec([None, 10], name='a')
        b_spec = InputSpec([10], name='b')

        a_tensor = paddle.static.data(name='a_var', shape=[4, 10])
        b_tensor = paddle.static.data(name='b_var', shape=[4, 10])
        kwargs = {'c': 1, 'd': 2}

        # case 1
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec(
            (a_tensor, b_tensor, 1, 2), {})

        self.assertTrue(len(input_with_spec) == 4)
        self.assertTrue(input_with_spec[0] == a_spec)  # a
        self.assertTrue(input_with_spec[1] == b_spec)  # b
        self.assertTrue(input_with_spec[2] == 1)  # c
        self.assertTrue(input_with_spec[3] == 2)  # d

        # case 2
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        input_with_spec, _ = foo_spec.args_to_input_spec((a_tensor, b_tensor),
                                                         {})
        self.assertTrue(len(input_with_spec) == 2)
        self.assertTrue(input_with_spec[0] == a_spec)  # a
        self.assertTupleEqual(input_with_spec[1].shape, (4, 10))  # b.shape
        self.assertEqual(input_with_spec[1].name, 'b_var')  # b.name

        # case 3
        # a ValueError is raised if kwargs are passed when `input_spec` is set
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
        with self.assertRaises(ValueError):
            input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor),
                                                          {'c': 4})

        # case 4
        # assert len(args) >= len(self._input_spec)
        foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
        with self.assertRaises(ValueError):
            input_with_spec = foo_spec.args_to_input_spec((a_tensor, ), {})
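
# NOTE: `foo_func` is not defined in this snippet. A minimal sketch of a
# definition compatible with the four cases above (an assumption for
# illustration, not the original test fixture):
def foo_func(a, b, c=1, d=2):
    z = a + b
    return z
# With such a signature, case 3 raises ValueError because keyword arguments are
# rejected once `input_spec` is set, and case 4 raises ValueError because fewer
# positional arguments are passed than there are entries in `input_spec`.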
Example 2
class StaticLayer(object):
    """
    Wrapper class to manage program conversion of a decorated function.

    """

    def __init__(self, function, input_spec=None):
        """
        Initializes a `StaticLayer`.

        Args:
            function(callable): A function or method that will be converted into static program.
            input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None.
        """
        # save the instance `self` while decorating a method of a class.
        if inspect.ismethod(function):
            self._dygraph_function = getattr(function, '__func__')
            self._class_instance = getattr(function, '__self__')
        else:
            self._dygraph_function = function
            self._class_instance = None

        self._input_spec = input_spec
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
        self._program_trans = ProgramTranslator()

    def __get__(self, instance, owner):
        """
        Overrides this method to resolve the class instance and call the bound method correctly.

        For example:
            
            '''
            class Net(Layer):
                def __init__(self):
                    pass
                
                @paddle.jit.to_static
                def forward(self, x, y):
                    return x + y

            net = Net()
            out = net(x, y)
            '''
        
        In the above case, `net(x, y)` first calls `net.forward(x, y)`, which is a bound method
        of the `Net` instance. After being decorated by `@paddle.jit.to_static`, `__get__` is called
        first to resolve the class instance correctly instead of the `StaticLayer` instance.
        """
        if instance not in self._descriptor_cache:
            if instance is None:
                return self
            # Note(Aurelius84): Construct a new instance of StaticLayer when we
            # first encounter the bound method of a layer, and cache it.
            new_static_layer = self._clone()
            new_static_layer._class_instance = instance
            self._descriptor_cache[instance] = new_static_layer

        return self._descriptor_cache[instance]

    def _clone(self):
        return self.__class__(self._dygraph_function, self._input_spec)

    def __call__(self, *args, **kwargs):
        """
        Supports calling the returned instance directly with input `args` and `kwargs`.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of decorated function.
        """

        # 1. call the dygraph function directly if `declarative` is not enabled
        if not self._program_trans.enable_declarative:
            logging_utils.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return self._call_dygraph_function(*args, **kwargs)

        if not in_dygraph_mode() and self._program_trans.enable_declarative:
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
        try:
            concrete_program, partial_program_layer = self.get_concrete_program(
                *args, **kwargs)

            # 3. synchronize self.training attribute.
            if isinstance(self._class_instance, layers.Layer):
                partial_program_layer.training = self._class_instance.training

            # 4. return outputs.
            return partial_program_layer(args)
        except Exception as e:
            if not hasattr(e, ERROR_DATA):
                # runtime error
                attach_error_data(e, in_runtime=True)
            error_data = getattr(e, ERROR_DATA, None)
            if error_data:
                new_exception = error_data.create_exception()
                if six.PY3:
                    # NOTE(liym27):
                    # 1. Why `raise new_exception from None`?
                    #   In Python 3, by default, a new exception is raised with trace information of the caught exception.
                    #   This only raises new_exception and hides unwanted implementation details from tracebacks of the
                    #   caught exception.
                    # 2. Use exec to bypass syntax error checking in Python 2.

                    six.exec_("raise new_exception from None")
                else:
                    raise new_exception
            else:
                raise

    def _call_dygraph_function(self, *args, **kwargs):
        """
        Calls dygraph function directly and returns the outputs.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of dygraph function.
        """
        if self._class_instance is not None:
            dygraph_function = self._dygraph_function.__get__(
                self._class_instance)
        else:
            dygraph_function = self._dygraph_function

        return dygraph_function(*args, **kwargs)

    def get_concrete_program(self, *args, **kwargs):
        """
        Returns traced concrete program and inner executable partial layer.

        Args:
            *args(tuple): input argument values or InputSpec.
            **kwargs(dict): input keyword argument values.

        Returns:
            Traced ConcreteProgram and executable translated Layer.
        """
        # 1. unify args/kwargs and replace Tensor with InputSpec
        if len(args) != len(self._function_spec.args_name):
            args, kwargs = self._function_spec.unified_args_and_kwargs(args,
                                                                       kwargs)
        input_with_spec = self._function_spec.args_to_input_spec(args, kwargs)

        # 2. generate cache key
        cache_key = CacheKey(self._function_spec, input_with_spec,
                             self._class_instance)

        # 3. check whether hit the cache or build a new program for the input arguments
        concrete_program, partial_program_layer = self._program_cache[cache_key]
        return concrete_program, partial_program_layer

    def get_traced_count(self):
        """
        Returns the number of traced programs for the decorated function.
        """
        return len(self._program_cache)

    @property
    def code(self):
        """
        Returns the source code of transformed static function for debugging.
        """
        static_func = convert_to_static(self._dygraph_function)
        source_code = func_to_source_code(static_func)
        return source_code

    @property
    def dygraph_function(self):
        """
        Returns the original decorated function.
        """
        return self._dygraph_function

    @property
    def concrete_program(self):
        """
        Returns the recent ConcreteProgram instance of the decorated function.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.jit import to_static
                from paddle.static import InputSpec

                paddle.disable_static()

                def foo(x, y):
                    z = x + y
                    return z
                
                # usage 1:
                decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
                print(decorated_foo.concrete_program)

                # usage 2:
                decorated_foo = to_static(foo)
                out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
                print(decorated_foo.concrete_program)
        """
        # If `input_spec` is specified, the length of the program cache will always be 1;
        # otherwise, return the last one.
        cached_program_len = len(self._program_cache)
        # If `input_spec` is specified, convert the dygraph layers into a static Program.
        if cached_program_len == 0:
            input_spec = self._function_spec.input_spec
            has_input_spec = (input_spec is not None and len(input_spec) > 0)
            if has_input_spec:
                concrete_program, _ = self.get_concrete_program(*input_spec)
                return concrete_program
            else:
                raise ValueError(
                    "No valid transformed program for {}.\n\t    Please specify `input_spec` in `@paddle.jit.to_static` or feed input tensors to call the decorated function at least once.\n".
                    format(self._function_spec))
        # If more than one program has been cached, return the most recently converted program by default.
        elif cached_program_len > 1:
            logging.warning(
                "Current {} has more than one cached program: {}; the last traced program will be returned by default.".
                format(self._function_spec, cached_program_len))

        cache_key, (concrete_program,
                    partial_layer) = self._program_cache.last()
        return concrete_program

    @property
    def inputs(self):
        """
        Returns input tensors of the most recently converted static program.
        """
        concrete_program = self.concrete_program
        inputs = [
            var for var in flatten(concrete_program.inputs)
            if isinstance(var, framework.Variable)
        ]
        return inputs

    @property
    def outputs(self):
        """
        Returns output tensors of the most recently converted static program.
        """
        concrete_program = self.concrete_program
        outputs = [
            var for var in flatten(concrete_program.outputs)
            if isinstance(var, framework.Variable)
        ]

        return outputs

    @property
    def main_program(self):
        """
        Returns the most recently converted static main program.
        """
        concrete_program = self.concrete_program
        main_program = concrete_program.main_program
        return main_program

    @property
    def program_cache(self):
        return self._program_cache

    @property
    def function_spec(self):
        return self._function_spec
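
# A brief usage sketch for the wrapper above (a hedged illustration, not part of
# the original source). It mirrors the example in the `concrete_program`
# docstring: decorate a plain function with `paddle.jit.to_static`, call it once
# to trigger tracing, then inspect the cached program.
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

paddle.disable_static()

def add(x, y):
    return x + y

static_add = to_static(
    add, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
# Calling the wrapper traces the function and caches the generated program.
out = static_add(paddle.rand([10]), paddle.rand([10]))
print(static_add.get_traced_count())  # number of cached programs
print(static_add.code)                # source of the transformed static function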
Example 3
class StaticFunction(object):
    """
    Wrapper class to manage program conversion of a decorated function.

    """
    def __init__(self, function, input_spec=None, **kwargs):
        """
        Initializes a `StaticFunction`.

        Args:
            function(callable): A function or method that will be converted into static program.
            input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None.
            **kwargs(dict): other arguments such as `build_strategy`, etc.
        """
        # save the instance `self` while decorating a method of a class.
        if inspect.ismethod(function):
            self._dygraph_function = getattr(function, '__func__')
            self._class_instance = getattr(function, '__self__')
        else:
            self._dygraph_function = function
            self._class_instance = None

        self._input_spec = input_spec
        self._function_spec = FunctionSpec(function, input_spec)
        self._program_cache = ProgramCache()
        self._descriptor_cache = weakref.WeakKeyDictionary()
        # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
        self._program_trans = ProgramTranslator()
        self._kwargs = kwargs
        self._training = True
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0

    def train(self):
        if isinstance(self._class_instance,
                      layers.Layer) and not self._class_instance.training:
            raise RuntimeError(
                "Failed to switch train mode. {} is a Layer's method, "
                "please use Layer.train() to switch train mode.".format(
                    self.dygraph_function))
        self._training = True

    def eval(self):
        if isinstance(self._class_instance,
                      layers.Layer) and self._class_instance.training:
            raise RuntimeError(
                "Failed to switch eval mode. {} is a Layer's method, "
                "please use Layer.eval() to switch eval mode.".format(
                    self.dygraph_function))
        self._training = False

    def __get__(self, instance, owner):
        """
        Overrides this method to resolve the class instance and call the bound method correctly.

        For example:
            
            '''
            class Net(Layer):
                def __init__(self):
                    pass
                
                @paddle.jit.to_static
                def forward(self, x, y):
                    return x + y

            net = Net()
            out = net(x, y)
            '''
        
        In the above case, `net(x, y)` first calls `net.forward(x, y)`, which is a bound method
        of the `Net` instance. After being decorated by `@paddle.jit.to_static`, `__get__` is called
        first to resolve the class instance correctly instead of the `StaticFunction` instance.
        """
        if instance not in self._descriptor_cache:
            if instance is None:
                return self
            # Note(Aurelius84): Construct a new instance of StaticFunction when we
            # first encounter the bound method of a layer, and cache it.
            new_static_layer = self._clone()
            new_static_layer._class_instance = instance
            self._descriptor_cache[instance] = new_static_layer

        return self._descriptor_cache[instance]

    def _clone(self):
        return self.__class__(self._dygraph_function, self._input_spec)

    def __call__(self, *args, **kwargs):
        """
        Supports calling the returned instance directly with input `args` and `kwargs`.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of decorated function.
        """

        # 1. call the dygraph function directly if `to_static` is not enabled
        if not self._program_trans.enable_to_static:
            # NOTE(liym27):
            # Here calls `warnings.warn` but not `logging_utils.warn` because by default warnings.warn(message)
            # will show up **only once**. StaticFunction.__call__ will run many times, so it is appropriate to
            # display this warning message only once.
            logging_utils.warn(
                "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
                "We will just return dygraph output. If you would like to get static graph output, please call API "
                "ProgramTranslator.enable(True)")
            return self._call_dygraph_function(*args, **kwargs)

        if not _non_static_mode():
            raise RuntimeError(
                "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
                "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
                "following API: paddle.disable_static().".format(
                    self.dygraph_function))

        # 2. trace ops from dygraph layers and cache the generated program.
        args, kwargs = self._function_spec.unified_args_and_kwargs(
            args, kwargs)

        try:
            concrete_program, partial_program_layer = self.get_concrete_program(
                *args, **kwargs, is_train=self._is_train_mode())

            # 3. synchronize self.training attribute.
            if isinstance(self._class_instance, layers.Layer):
                partial_program_layer.training = self._class_instance.training
            else:
                partial_program_layer.training = self._training

            partial_program_layer._cuda_graph_capture_mode = self._cuda_graph_capture_mode
            partial_program_layer._cuda_graph_pool_id = self._cuda_graph_pool_id

            # 4. return outputs.
            try:
                return partial_program_layer(args)
            except Exception as e:
                if not hasattr(e, error.ERROR_DATA):
                    # runtime error
                    error.attach_error_data(e, in_runtime=True)
                # always re-raise so the outer handler can process the error data
                raise
        except Exception as e:
            error_data = getattr(e, error.ERROR_DATA, None)
            if error_data:
                error_data.raise_new_exception()
            else:
                logging_utils.warn(
                    "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
                    " if you can't handle this {} yourself.".format(type(e)))
                raise e

    def _is_train_mode(self):
        if self._class_instance is not None:
            return self._class_instance.training
        else:
            return self._training

    def _call_dygraph_function(self, *args, **kwargs):
        """
        Calls dygraph function directly and returns the outputs.

        Args:
            *args(tuple): tuple of all input arguments from original decorated function.
            **kwargs(dict): dict of all input keyword arguments from original decorated function.

        Return:
            Outputs of dygraph function.
        """
        if self._class_instance is not None:
            dygraph_function = self._dygraph_function.__get__(
                self._class_instance)
        else:
            dygraph_function = self._dygraph_function

        return dygraph_function(*args, **kwargs)

    def get_concrete_program(self, *args, **kwargs):
        """
        Returns traced concrete program and inner executable partial layer.

        Args:
            *args(tuple): input argument values or InputSpec.
            **kwargs(dict): input keyword argument values.

        Returns:
            Traced ConcreteProgram and executable translated Layer.
        """

        with_hook = kwargs.pop("with_hook", False)
        is_train = kwargs.pop("is_train", True)
        # 1. unify args/kwargs and replace Tensor with InputSpec
        if len(args) != len(self._function_spec.args_name):
            args, kwargs = self._function_spec.unified_args_and_kwargs(
                args, kwargs)
        input_args_with_spec, input_kwargs_with_spec = self._function_spec.args_to_input_spec(
            args, kwargs)

        # 2. generate cache key
        cache_key = CacheKey(self._function_spec,
                             input_args_with_spec,
                             input_kwargs_with_spec,
                             self._class_instance,
                             **self._kwargs,
                             with_hook=with_hook,
                             is_train=is_train)

        # 3. check whether hit the cache or build a new program for the input arguments
        concrete_program, partial_program_layer = self._program_cache[
            cache_key]
        return concrete_program, partial_program_layer

    def get_traced_count(self):
        """
        Returns the number of traced programs for the decorated function.
        """
        return len(self._program_cache)

    @property
    def code(self):
        """
        Returns the source code of transformed static function for debugging.
        """
        static_func = convert_to_static(self._dygraph_function)
        source_code = func_to_source_code(static_func)
        return source_code

    @property
    def dygraph_function(self):
        """
        Returns the original decorated function.
        """
        return self._dygraph_function

    @property
    def concrete_program(self):
        """
        Returns the recent ConcreteProgram instance of the decorated function.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.jit import to_static
                from paddle.static import InputSpec

                paddle.disable_static()

                def foo(x, y):
                    z = x + y
                    return z
                
                # usage 1:
                decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
                print(decorated_foo.concrete_program)

                # usage 2:
                decorated_foo = to_static(foo)
                out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
                print(decorated_foo.concrete_program)
        """
        return self.concrete_program_specify_input_spec(input_spec=None)

    def concrete_program_specify_input_spec(self,
                                            input_spec=None,
                                            with_hook=False):
        """
        Returns the recent ConcreteProgram instance of the decorated function while
        specifying `input_spec`. If `self._function_spec` already has an `input_spec`,
        the compatibility between the given `input_spec` and
        `self._function_spec.input_spec` is checked. If the given `input_spec` is
        None, this method uses `self._function_spec.input_spec`.

        Args:
            input_spec (list[InputSpec], optional): Describes the input of
                the translated function.
        """
        # If `input_spec` is specified, the length of the program cache will always be 1;
        # otherwise, return the last one.
        cached_program_len = len(self._program_cache)
        # If `input_spec` is specified, convert the dygraph layers into a static Program.
        if cached_program_len == 0:
            desired_input_spec = input_spec
            if self._function_spec.input_spec is not None:
                if input_spec is not None and not input_specs_compatible(
                        flatten(input_spec),
                        flatten(self._function_spec.input_spec)):
                    raise ValueError(
                        "The `input_spec`: {} used to construct concrete_program conflicts with the `input_spec`: {} in `@paddle.jit.to_static`"
                        .format(input_spec, self._function_spec.input_spec))
                # NOTE(chenweihang): we should always translate the program based on the `input_spec`
                # decorated on forward if it is valid
                desired_input_spec = self._function_spec.input_spec
                if input_spec is not None:
                    logging_utils.warn(
                        "\n\nYou have specified `input_spec` both in the function definition (higher priority) and `paddle.jit.save` (will be ignored).\n\n\t Using: {}\n\n\t Ignore: {}\n"
                        .format(desired_input_spec, input_spec))

            has_input_spec = (desired_input_spec is not None)
            if has_input_spec:
                concrete_program, _ = self.get_concrete_program(
                    *desired_input_spec,
                    with_hook=with_hook,
                    is_train=self._is_train_mode())
                return concrete_program
            else:
                raise ValueError(
                    "No valid transformed program for {}.\n\t    Please specify `input_spec` in `@paddle.jit.to_static` or feed input tensors to call the decorated function at least once.\n"
                    .format(self._function_spec))
        elif with_hook:
            cache_key = self._program_cache._recent_cache_key
            cache_key.kwargs["with_hook"] = True
            concrete_program, _ = self._program_cache[cache_key]
            return concrete_program

        # If more than one program has been cached, return the most recently converted program by default.
        elif cached_program_len > 1:
            logging_utils.warn(
                "Current {} has more than one cached program: {}; the last traced program will be returned by default."
                .format(self._function_spec, cached_program_len))

        cache_key, (concrete_program,
                    partial_layer) = self._program_cache.last()
        return concrete_program

    @property
    def inputs(self):
        """
        Returns input tensors of the most recently converted static program.
        """
        concrete_program = self.concrete_program
        inputs = [
            var for var in flatten(concrete_program.inputs)
            if isinstance(var, framework.Variable)
        ]
        return inputs

    @property
    def outputs(self):
        """
        Returns output tensors of the most recently converted static program.
        """
        concrete_program = self.concrete_program
        outputs = [
            var for var in flatten(concrete_program.outputs)
            if isinstance(var, framework.Variable)
        ]

        return outputs

    @property
    def main_program(self):
        """
        Returns the most recently converted static main program.
        """
        concrete_program = self.concrete_program
        main_program = concrete_program.main_program
        return main_program

    @property
    def program_cache(self):
        return self._program_cache

    @property
    def function_spec(self):
        return self._function_spec
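
# A brief usage sketch for the newer wrapper above (a hedged illustration, not
# part of the original source). For a plain decorated function, the wrapper's
# own train()/eval() switch the `training` flag used when tracing; for a Layer
# method, Layer.train()/Layer.eval() must be used instead, as the RuntimeError
# in train()/eval() above enforces.
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

def foo(x, y):
    return x + y

decorated_foo = to_static(
    foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
decorated_foo.eval()                   # trace subsequent programs in eval mode
print(decorated_foo.concrete_program)  # built from `input_spec`, no call needed
decorated_foo.train()                  # switch back to train mode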