Example #1
0
 def _to_array_write_node(self, node):
     """Build a ``paddle.tensor.array_write`` call node from a method call.

     ``node`` is expected to be a call of the form ``<array>.<method>(x, ...)``
     (presumably ``<array>.append(x)``). Only the receiver and the first
     positional argument are used. The returned expression writes ``x``
     into ``<array>`` at index ``paddle.tensor.array_length(<array>)``.

     Args:
         node (gast.Call): the call node to convert.

     Returns:
         gast expression node for the generated ``array_write`` call.
     """
     assert isinstance(node, gast.Call)
     # Source text of the receiver and of the first positional argument.
     arr_src = astor.to_source(gast.gast_to_ast(node.func.value))
     item_src = astor.to_source(gast.gast_to_ast(node.args[0]))
     # The write index is the array's current length.
     call_template = ("paddle.tensor.array_write("
                      "x={item}, i=paddle.tensor.array_length({arr}), array={arr})")
     call_code = call_template.format(item=item_src, arr=arr_src)
     module = gast.parse(call_code)
     return module.body[0].value
Example #2
0
def code_gast_ast(source):
    """Round-trip source code through a gast transformation and dump the result.

    The source is dedented and parsed into a gast tree. The tree is passed
    through ``GastNodeTransformer(...).apply()``, converted back to a
    standard-library ``ast`` tree, and rendered with ``ast.dump``.

    Args:
        source (str): Python source code, possibly indented.

    Returns:
        str: ``ast.dump`` of the transformed tree.
    """
    gast_tree = gast.parse(textwrap.dedent(source))
    transformed = GastNodeTransformer(gast_tree).apply()
    return ast.dump(gast.gast_to_ast(transformed))
Example #3
0
    def _visit_Call(self, node):
        """Replace a call to a recorded dygraph object with its static form.

        The callee's source text is looked up via ``self._is_dygraph_forward``.
        On a hit, the matching class node from ``self._get_class_node`` and the
        call are handed to ``utils.to_static_ast``, and that result is returned.
        Otherwise ``node`` is returned unchanged.

        Args:
            node (gast.Call): the call node being visited.

        Returns:
            The converted node, or ``node`` itself when no conversion applies.
        """
        assert isinstance(node, gast.Call)
        # Deliberately not stripped: this text must match however the lookup
        # keys were produced (presumably also raw astor.to_source output).
        callee_src = astor.to_source(gast.gast_to_ast(node.func))

        if not self._is_dygraph_forward(callee_src):
            return node
        return utils.to_static_ast(node, self._get_class_node(callee_src))
Example #4
0
    def _update_class_node_dict(self, node):
        """Record an assignment whose value constructs a convertible dygraph API.

        For an assignment ``target = <call>``, the call is recorded in
        ``self.class_node_dict`` when all of these hold:

        * the value is a call,
        * ``is_to_variable`` rejects it,
        * ``utils.is_dygraph_api`` accepts it,
        * ``utils.dygraph_class_to_static_api`` has a truthy entry for the
          called attribute name.

        Before the call is stored, ``utils.update_args_of_func(call, call,
        "__init__")`` is applied to it. This presumably rewrites its arguments
        based on the class's ``__init__`` signature (unverified here). The
        dictionary key is the source text of the first assignment target.

        Args:
            node (gast.Assign): the assignment node to inspect.

        Returns:
            bool: True if an entry was recorded, otherwise False.
        """
        assert isinstance(node, gast.Assign)
        call = node.value
        if not isinstance(call, gast.Call):
            return False
        if is_to_variable(call):
            return False
        if not utils.is_dygraph_api(call):
            # TODO: node.value is not dygraph class
            return False

        api_name = call.func.attr
        if not utils.dygraph_class_to_static_api.get(api_name):
            return False

        utils.update_args_of_func(call, call, "__init__")
        # Key is the raw (unstripped) source text of the first target.
        target_key = astor.to_source(gast.gast_to_ast(node.targets[0]))
        self.class_node_dict[target_key] = call
        return True
Example #5
0
    def _is_list_append_tensor(self, node):
        """Return whether ``node`` is an ``append`` call on a tracked Python list.

        Matches both forms:
            a.append(b):      a is list, b is Tensor
            self.x.append(b): self.x is list, b is Tensor

        The receiver's source text must be a key of
        ``self.list_name_to_updated``.

        Side effect: on a match, the receiver's entry in
        ``self.list_name_to_updated`` is set to ``True``.

        Args:
            node (gast.Call): the call node to inspect.

        Returns:
            bool: True if ``node`` passes every check below, otherwise False.
        """
        assert isinstance(node, gast.Call)
        # 1. The callee is an attribute named `append`.
        if not isinstance(node.func, gast.Attribute):
            return False
        if node.func.attr != 'append':
            return False

        # 2. The receiver is a tracked python list. strip() normalizes
        #    whitespace (e.g. a trailing newline) around the generated source.
        value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
        if value_name not in self.list_name_to_updated:
            return False

        # 3. Python's list.append() takes exactly one argument.
        if len(node.args) != 1:
            return False

        # TODO(liym27): The arg of append() should be a Tensor. Static type
        #   analysis is often wrong about the arg's type, so that check is
        #   intentionally disabled and any single argument is accepted.
        #   A future check could consult self.scope_var_type_dict for a
        #   gast.Name arg: reject NodeVarType.NUMPY_NDARRAY, and require
        #   NodeVarType.TENSOR or NodeVarType.PADDLE_RETURN_TYPES. Names that
        #   are arguments of the decorated function may be missing from that
        #   dict. It should also handle gast.Call args to Paddle APIs,
        #   e.g. list_a.append(paddle.reshape(x)).
        self.list_name_to_updated[value_name] = True
        return True
Example #6
0
def get_source_code(func):
    """Return ``func``'s source re-generated through a gast/astor round trip.

    The source is fetched with ``inspect.getsource``, dedented, parsed with
    ``gast.parse``, converted to a standard-library ``ast`` tree, and
    rendered back to text with ``astor.to_source``.

    Args:
        func: any object accepted by ``inspect.getsource``.

    Returns:
        str: the regenerated source code.
    """
    dedented = textwrap.dedent(inspect.getsource(func))
    std_tree = gast.gast_to_ast(gast.parse(dedented))
    return astor.to_source(std_tree)