Esempio n. 1
0
    def visit_Call(self, node):
        """
        Visit a gast.Call node.

        Children are visited only when the call targets a paddle API, so that
        gast.Attribute and gast.Name nodes inside it (e.g. `tensor.shape`) can
        be replaced if necessary. Any other call is returned without visiting
        its children. The call node itself is always returned unchanged.
        """
        assert isinstance(node, gast.Call)
        if not is_paddle_api(node):
            # Non-paddle calls: nothing inside them needs replacing.
            return node
        self.generic_visit(node)
        return node

    def _used_by_paddle_api(self, node):
        """
        Whether node is used in paddle api as arguments.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
               because the role of node is not arguments but `gast.Call.func`.
            4) Return False in `paddle.add(self.x)` where node is `paddle` (gast.Name),
               because node is nested inside `gast.Call.func`, not the arguments.

        Only the nearest enclosing gast.Call is inspected. Returns False when
        node has no entry in `self.node_to_wrapper_map` or has no enclosing
        gast.Call at all.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
                # Compare the Call's direct child on the path from `node`
                # (rather than `node` itself) so that nodes nested inside
                # `func` are filtered too. AST nodes are compared by identity.
                if is_paddle_api(parent_node) and parent_node.func is not wrapper_node.node:
                    return True
                return False
            wrapper_node = wrapper_node.parent

        return False
Esempio n. 3
0
def is_to_variable(node):
    """
    Return whether ``node`` is a call that converts data into a variable/tensor.

    - For dygraph APIs, the callee source text must end with "to_variable".
    - For other paddle APIs, the callee source text must end with "to_tensor".
    - Any other call returns False.
    """
    assert isinstance(node, gast.Call)
    callee_src = utils.ast_to_source_code(node.func).strip()

    if utils.is_dygraph_api(node):
        suffix = "to_variable"
    elif utils.is_paddle_api(node):
        suffix = "to_tensor"
    else:
        return False

    return callee_src.endswith(suffix)
Esempio n. 4
0
def is_grad_api_node(node):
    """
    Return whether ``node`` calls a paddle gradient API.

    A call qualifies when it is a paddle API whose callee source text ends
    with "grad". Callee text containing "no_grad" triggers a warning and is
    never treated as a gradient API. Non-paddle calls return False.
    """
    assert isinstance(node, gast.Call)
    callee_src = utils.ast_to_source_code(node.func).strip()
    if not utils.is_paddle_api(node):
        return False
    if 'no_grad' not in callee_src:
        return callee_src.endswith("grad")
    warnings.warn(
        "paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
    )
    return False
Esempio n. 5
0
    def visit_Call(self, node):
        """
        Visit a gast.Call node, then route its callee through `convert_call`.

        Children are visited first. Paddle API calls and calls recognised by
        `self._is_builtin_call` are returned untouched. For every other call,
        `node.func` is replaced with the parsed expression
        `fluid.dygraph.dygraph_to_static.convert_call(<original callee>)`.
        The (possibly modified) node is returned.
        """
        self.generic_visit(node)
        # `or` short-circuits, so _is_builtin_call runs only for non-paddle calls.
        if is_paddle_api(node) or self._is_builtin_call(node):
            return node

        callee_src = ast_to_source_code(node.func).strip()
        wrapped_src = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
            callee_src)
        node.func = gast.parse(wrapped_src).body[0].value
        return node
Esempio n. 6
0
    def _used_by_paddle_api(self, node):
        """
        Whether node is used in paddle api as arguments.
        For example:
            1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
            2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
            3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute)
               or `paddle` (gast.Name), because node is (part of) `gast.Call.func`,
               not an argument.

        Only the nearest enclosing gast.Call is inspected. Returns False when
        node has no entry in `self.node_to_wrapper_map` or has no enclosing
        gast.Call at all.
        """
        assert isinstance(node, (gast.Attribute, gast.Name))
        wrapper_node = self.node_to_wrapper_map.get(node)
        if not wrapper_node:
            # Transformed node is not in node_to_wrapper_map
            return False
        while wrapper_node.parent:
            parent_node = wrapper_node.parent.node
            if isinstance(parent_node, gast.Call):
                # Filter the case where node is, or sits inside, `gast.Call.func`:
                # the Call's direct child on the path from `node` must not be
                # the callee. AST nodes are compared by identity.
                if is_paddle_api(parent_node) and parent_node.func is not wrapper_node.node:
                    return True
                return False
            wrapper_node = wrapper_node.parent

        return False
Esempio n. 7
0
    def _no_need_convert_call(self, node):
        """
        Determines whether a called function can be left untransformed by `convert_call`.

        Returns True (no transformation needed) when the callee is:
          1. a paddle API, or
          2. a Python builtin function other than `len`.
        Returns False otherwise, including when the callee expression cannot
        be resolved or evaluated here.
        """
        assert isinstance(node, gast.Call)
        if is_paddle_api(node):
            return True

        func_str = ast_to_source_code(node.func).strip()
        try:
            from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin
            # NOTE(review): eval() runs the callee source text from the code
            # being transformed, resolved in this module's scope rather than
            # the user's. Names it cannot resolve raise and fall through to
            # `return False`. It is evaluated only once, so any side effects
            # of the callee expression happen a single time.
            func = eval(func_str)
            return is_builtin(func) and not is_builtin_len(func)
        except Exception:
            # Best-effort check: anything unresolvable is treated as needing conversion.
            return False
 def visit_Call(self, node):
     """
     Visit a gast.Call node, descending only into paddle API calls.

     Inside paddle API calls, gast.Attribute and gast.Name children are
     visited so `var.shape` can be replaced if necessary. The call node
     itself is always returned.
     """
     if not is_paddle_api(node):
         # Don't have to visit other APIs
         return node
     self.generic_visit(node)
     return node
Esempio n. 9
0
 def visit_Call(self, node):
     """
     Delegate to `self._visit_Call`, then count the call if it is a paddle API.

     `self.is_control_flow_num` is incremented by one for every paddle API
     call seen. The node is returned unchanged.
     """
     self._visit_Call(node)
     calls_paddle_api = is_paddle_api(node)
     if calls_paddle_api:
         self.is_control_flow_num += 1
     return node