예제 #1
0
def _create_simple_stmt_node_and_insert_behind(code, node):
    if node is None or node.type != python_symbols.simple_stmt:
        return
    simple_stmt_node = Node(python_symbols.simple_stmt, [utils.newline_node(node)])
    _node = utils.code_repr(code).children[0].children[0]
    _node.parent = None
    simple_stmt_node.insert_child(0, _node)
    simple_stmt_node.prefix = utils.get_indent(node)
    utils.insert_node_behind(node, simple_stmt_node)
예제 #2
0
    def _remove_with_dygraph_guard(node: LN, capture: Capture,
                                   filename: Filename):
        """Unwrap a captured ``with`` statement into explicit static toggles.

        The ``with`` node is replaced, at the same position in its parent, by
        ``paddle.disable_static<arg_list>``, then the statements of the
        with-suite (de-indented one level via ``utils.dec_indent``), then
        ``paddle.enable_static()``.

        Args:
            node: the matched node; only passed to ``utils.newline_node`` to
                build the newline leaves of the generated statements.
            capture: match captures; ``'with'``, ``'arg_list'`` and
                ``'suite'`` are read.
            filename: not used by this function.
        """
        with_node = capture['with']
        # Keep a handle on the parent: with_node is detached before splicing.
        parent = with_node.parent

        # Build `paddle.disable_static<args>`; it takes over the with
        # statement's prefix so whatever preceded the `with` is preserved.
        arg_list_nodes = capture['arg_list']
        simple_stmt_disable_static = Node(python_symbols.simple_stmt,
                                          [utils.newline_node(node)])
        _node = utils.code_repr('paddle.disable_static' +
                                str(arg_list_nodes)).children[0].children[0]
        _node.parent = None  # detach from the throw-away parse result
        simple_stmt_disable_static.insert_child(0, _node)
        simple_stmt_disable_static.prefix = with_node.prefix

        # Build `paddle.enable_static()`, indented like the with statement.
        simple_stmt_enable_static = Node(python_symbols.simple_stmt,
                                         [utils.newline_node(node)])
        _node = utils.code_repr(
            'paddle.enable_static()').children[0].children[0]
        _node.parent = None
        simple_stmt_enable_static.insert_child(0, _node)
        simple_stmt_enable_static.prefix = utils.get_indent(with_node)

        suite_node = capture['suite']

        # Remove the suite's first NEWLINE leaf (the one after `with ...:`).
        for child in suite_node.children:
            if isinstance(child, Leaf) and child.type == token.NEWLINE:
                child.remove()
                break

        # Remove the first INDENT leaf and hand its prefix + indentation to
        # the following sibling, which would otherwise lose its indent.
        for child in suite_node.children:
            if isinstance(child, Leaf) and child.type == token.INDENT:
                if child.next_sibling is not None:
                    child.next_sibling.prefix = child.prefix + child.value
                child.remove()
                break

        # Move the prefixes of the suite's trailing DEDENT leaves (in their
        # original order) onto the node following the with statement, before
        # the closing DEDENT is removed below.
        for leaf in reversed(list(suite_node.leaves())):
            if leaf.type != token.DEDENT:
                break
            next_node = with_node.next_sibling
            next_node.prefix = leaf.prefix + next_node.prefix
            leaf.prefix = ""

        # Remove the last DEDENT leaf, which closed the with block.
        for child in suite_node.children[::-1]:
            if isinstance(child, Leaf) and child.type == token.DEDENT:
                child.remove()
                break

        # Un-indent every remaining leaf of the suite by one level.
        for leaf in suite_node.leaves():
            if leaf.type == token.INDENT:
                leaf.value = utils.dec_indent(leaf.value)
            else:
                leaf.prefix = utils.dec_indent(leaf.prefix)

        # Detach the with statement; remove() returns its former index in
        # parent.children, which is where the replacement sequence goes.
        idx = with_node.remove()
        # Snapshot the suite's children so the loop below does not depend on
        # how insert_child treats a child's previous parent.
        replacement = ([simple_stmt_disable_static]
                       + list(suite_node.children)
                       + [simple_stmt_enable_static])
        for offset, child in enumerate(replacement):
            parent.insert_child(idx + offset, child)
예제 #3
0
def sort_imports(root, capture, filename):
    """Sort the top-of-file import statements of ``root`` in place.

    The statements returned by ``get_top_import_nodes(root)`` are classified
    into groups (docstring, ``__future__``, stdlib, third party, first
    party). They are ordered by ``(group, plain import before from-import,
    module, first imported name)``. Then they are removed from the tree and
    re-inserted at the front of ``root`` in that order. A blank line is
    added between groups and between different first-party root packages.
    Imported names inside one statement are sorted via
    ``_sort_imported_names``.

    Args:
        root: tree whose top-level import statements are sorted.
        capture: not used by this function.
        filename: not used by this function.

    Raises:
        ValueError: an import statement has a shape this function does not
            recognise.
        AssertionError: a collected statement node does not have exactly
            two children (statement + newline).
    """
    statement_nodes = get_top_import_nodes(root)

    module_imports = []

    # * Go through all top-of-file imports.
    # * Index them by module name.
    # * Do inline sorting of imported names (`import b, c, a` --> `import a, b, c`)
    for i, stmt in enumerate(statement_nodes):
        first_name = None
        imp = stmt.children[0]
        if imp.type == syms.import_name:
            module_node = imp.children[1]
            if module_node.type == TOKEN.NAME:
                # 'import os'
                module = module_node.value
            elif module_node.type == syms.dotted_name:
                # 'import x.y'
                module = str(module_node)
            elif module_node.type == syms.dotted_as_name:
                # 'import os as OS' or 'import os.path as osp'.  children[0]
                # may be a dotted_name node, which has no `.value`, so join
                # the values of its leaves instead.
                module = ''.join(
                    leaf.value for leaf in module_node.children[0].leaves())
            elif module_node.type == syms.dotted_as_names:
                # 'import os, io'
                module = _sort_imported_names(module_node)
            else:
                raise ValueError(f"Unknown import format: {imp}")
        elif imp.type == syms.import_from:
            children = imp.children
            # Locate the `import` keyword instead of assuming it is
            # children[2]: relative imports ('from .. import x') put their
            # leading-dot leaves before it.
            import_idx = next(
                (j for j, child in enumerate(children)
                 if child.type == TOKEN.NAME and child.value == 'import'),
                None)
            if import_idx is None:
                raise ValueError(f"Unknown import format: {imp}")
            # 'from x import z', 'from x.y import z', 'from .x import z'
            module = ''.join(
                leaf.value
                for child in children[1:import_idx]
                for leaf in child.leaves())
            names = [n for n in children[import_idx + 1:]
                     if n.type != TOKEN.LPAR]
            if names[0].type == TOKEN.NAME:
                # 'from x import y'
                first_name = names[0].value
            elif names[0].type == syms.import_as_name:
                # 'from x import y as z'
                first_name = names[0].children[0].value
            elif names[0].type == syms.import_as_names:
                # 'from x import y, z'
                # 'from x import y as a, z as b'
                first_name = _sort_imported_names(names[0])
            elif names[0].type == TOKEN.STAR:
                # 'from x import *'
                first_name = '*'
            else:
                raise ValueError(f"Unknown import format: {imp}")
        else:
            # top-of-module docstring. float to top.
            module = ''

        module = module.strip()
        root_module_name = module.split('.')[0]

        # do 'from ...' imports after 'import ...' imports.
        from_ = 1 if first_name is not None else 0

        if module.startswith('.'):
            # Relative import: it resolves inside the current package.  Its
            # root_module_name is '', which must not be mistaken for the
            # docstring case (that would float it above `__future__`).
            group = Groups.FIRST_PARTY
        elif root_module_name == '':
            # module docstring
            group = Groups.DOCSTRING
        elif root_module_name == '__future__':
            # special case; must come first
            group = Groups.FUTURE
        elif root_module_name in STDLIB_MODULES:
            # stdlib modules
            group = Groups.STDLIB
        elif root_module_name not in cfg['first_party_modules']:
            # third party modules
            group = Groups.THIRD_PARTY
        else:
            # first party modules
            group = Groups.FIRST_PARTY

        # note: the `i` here is for a weird edge case where you try to sort
        # two of the same exact import.
        # turns out, Node instances aren't comparable, so we get an error if
        # the sort ever has to compare them.
        # So we insert a unique integer before them, thus preventing us ever having to
        # compare the node instances.
        module_imports.append((group, from_, module.lower(), first_name
                               and first_name.lower(), i, stmt))

    # Now sort the various lines we've encountered.
    module_imports.sort()

    # Now, clear out all the statements from the parse tree
    for n in statement_nodes:
        n.remove()

    # Then repopulate the tree with the sorted nodes, cleaning up whitespace as we go.
    last_group = 0
    last_root_module_name = None
    for i, (group, from_, module_lower, first_name_lower, _,
            stmt_node) in enumerate(module_imports):
        assert len(stmt_node.children) == 2
        root_module_name = module_lower.split('.')[0]
        import_node = stmt_node.children[0]
        newline_node = stmt_node.children[1]
        prefix = strip_prefix(import_node.prefix)
        if i != 0:
            if last_group != group:
                # add a space between groups.
                prefix = f'\n{prefix}'
            elif (last_root_module_name != root_module_name
                  and group == Groups.FIRST_PARTY):
                # also add a space between different first-party projects.
                prefix = f'\n{prefix}'

        new_stmt = Node(
            syms.simple_stmt,
            [import_node.clone(), newline_node.clone()])
        new_stmt.prefix = prefix
        root.insert_child(i, new_stmt)
        last_group = group
        last_root_module_name = root_module_name