Пример #1
1
def refactor_import(q: Query, change_spec):
    """
    Rewrite paddle imports so that paddle APIs are referenced by full path.

    1. add "import paddle" if needed.
    2. remove "import paddle.mod" if needed.
    3. remove "import paddle.module as mod", and convert "mod.api" to "paddle.module.api"
    4. remove "from paddle.module import api", and convert "api" to "paddle.module.api"
    5. remove "from paddle.module import api as nick", and convert "nick" to "paddle.module.api"

    Args:
        q: query object; the selector, one filter and three modifiers are
            attached to it, in that order.
        change_spec: accepted to keep the signature; not used by this function.

    Returns:
        The query object ``q`` that was passed in.
    """

    # select import_name and import_from
    pattern = """
        (
            file_input< any* >
         |
            name_import=import_name< 'import' '{name}' >
         |
            as_import=import_name< 'import'
                (
                    module_name='{name}'
                |
                    module_name=dotted_name< {dotted_name} any* >
                |
                    dotted_as_name<
                        (
                            module_name='{name}'
                        |
                            module_name=dotted_name< {dotted_name} any* >
                        )
                        'as' module_nickname=any
                    >
                )
            >
        |
            from_import=import_from< 'from'
                (
                    module_name='{name}'
                |
                    module_name=dotted_name< {dotted_name} any* >
                )
                'import' ['(']
                (
                    import_as_name<
                        module_import=any
                        'as'
                        module_nickname=any
                    >*
                |
                    import_as_names<
                        module_imports=any*
                    >
                |
                    module_import=any
                )
             [')'] >
        |
             leaf_node=NAME
        )
    """
    _kwargs = {}
    _kwargs['name'] = 'paddle'
    _kwargs["dotted_name"] = " ".join(quoted_parts(_kwargs["name"]))
    # The pattern above has no {power_name} placeholder; str.format ignores
    # the extra key, it is kept only to leave the helper calls unchanged.
    _kwargs["power_name"] = " ".join(power_parts(_kwargs["name"]))
    pattern = pattern.format(**_kwargs)

    # filename -> {name used in the file: full dotted path to substitute}
    imports_map = {}
    # files that already contain (or have been given) a bare "import paddle"
    paddle_imported = set()
    # files in which any paddle import was seen
    paddle_found = set()

    def _find_imports(node: LN, capture: Capture, filename: Filename):
        """Collect rename information from import statements; never rejects a node."""
        if not is_import(node) or not capture:
            return True
        if 'name_import' in capture:
            paddle_imported.add(filename)
            paddle_found.add(filename)
        # Every as_import / from_import match is deleted by _remove_import,
        # so the file must be flagged even when the statement introduces no
        # rename (plain "import paddle.mod"); otherwise the import would be
        # removed without "import paddle" being added in its place.
        if 'as_import' in capture or 'from_import' in capture:
            paddle_found.add(filename)
        has_single = 'module_import' in capture
        has_many = 'module_imports' in capture
        has_nickname = 'module_nickname' in capture
        if not (has_single or has_many or has_nickname):
            return True

        paddle_found.add(filename)
        rename_dict = imports_map.setdefault(filename, {})
        module_name = str(capture['module_name']).strip()

        # Name of the single imported object, when it is a plain NAME leaf.
        imported_name = None
        if has_single:
            leaf = capture['module_import']
            if leaf.type == token.NAME:
                imported_name = leaf.value.strip()

        if has_nickname:
            nickname = str(capture['module_nickname']).strip()
            if imported_name is not None:
                # "from paddle.module import api as nick": only the alias is
                # bound in the file, and it stands for paddle.module.api.
                rename_dict[nickname] = module_name + '.' + imported_name
            else:
                # "import paddle.module as nick"
                rename_dict[nickname] = module_name
        elif imported_name is not None:
            # "from paddle.module import api"
            rename_dict[imported_name] = module_name + '.' + imported_name

        if has_many:
            # "from paddle.module import a, b": only plain NAME entries are
            # mapped; aliased entries of a multi-name import are skipped.
            for leaf in capture['module_imports']:
                if leaf.type == token.NAME:
                    old_name = leaf.value.strip()
                    rename_dict[old_name] = module_name + '.' + old_name
        return True

    q.select(pattern).filter(_find_imports)

    # convert to full module path
    def _full_module_path(node: LN, capture: Capture, filename: Filename):
        """Replace a NAME leaf listed in imports_map with its full dotted path."""
        if not (isinstance(node, Leaf) and node.type == token.NAME):
            return
        if filename not in imports_map:
            return
        logger.debug("{} [{}]: {}".format(filename, list(capture), node))

        # skip import statement
        if utils.is_import_node(node):
            return
        # skip left operand in argument list
        if utils.is_argument_node(node) and utils.is_left_operand(node):
            return
        # skip if it's already a full module path
        if node.prev_sibling is not None and node.prev_sibling.type == token.DOT:
            return

        rename_dict = imports_map[filename]
        if node.value in rename_dict:
            # find old_name and new_name
            old_name = node.value
            new_name = rename_dict[old_name]
            if node.parent is not None:
                # Build a fresh subtree for the dotted path and detach it so
                # it can be inserted into this tree.
                _node = utils.code_repr(new_name).children[0].children[0]
                _node.parent = None
                new_node = _node
                # keep the whitespace/comments that preceded the old name
                new_node.children[0].prefix = node.prefix
                if node.parent.type == python_symbols.power:
                    # splice the children in instead of nesting the new node
                    # inside the existing power node
                    node.replace(new_node.children)
                else:
                    node.replace(new_node)
                log_info(
                    filename, node.get_lineno(),
                    "{} -> {}".format(utils.node2code(node),
                                      utils.node2code(new_node)))

    q.modify(_full_module_path)

    # remove as_import and from_import
    def _remove_import(node: LN, capture: Capture, filename: Filename):
        """Delete a captured as_import / from_import statement from the tree."""
        if not is_import(node):
            return
        _node = capture.get('as_import', None) or capture.get(
            'from_import', None)
        if _node is None:
            return
        prefix = _node.prefix
        p = _node.parent
        _node.remove()
        # A detached import node has no parent; fall back to its own line
        # number instead of failing on None.
        lineno = p.get_lineno() if p is not None else _node.get_lineno()
        log_warning(filename, lineno,
                    'remove "{}"'.format(utils.node2code(_node)))
        # delete NEWLINE node after delete as_import or from_import
        if p and p.children and len(
                p.children) == 1 and p.children[0].type == token.NEWLINE:
            p.children[0].remove()
            # restore comment
            following = p.next_sibling
            if following is not None:
                following.prefix = prefix + following.prefix

    q.modify(_remove_import)

    # add "import paddle" if needed
    def _add_import(node: LN, capture: Capture, filename: Filename):
        """Add "import paddle" to files that used paddle without importing it bare."""
        if node.type != python_symbols.file_input:
            return
        if filename in paddle_imported:
            return
        if filename in paddle_found:
            touch_import(None, 'paddle', node, force=True)
            log_info(filename, node.get_lineno(), 'add "import paddle"')
            paddle_imported.add(filename)

    q.modify(_add_import)

    return q
Пример #2
0
def transform_imports(query: 'Query', old_name: str, new_name: str) -> 'Query':
    """Chain one select/modify pair onto ``query`` per entry in ``modifiers``.

    Each entry of the module-level ``modifiers`` is called with ``old_name``
    and ``new_name``; its ``selector`` template is filled with ``name``,
    ``dotted_name`` and ``power_name`` values derived from ``old_name``, and
    the resulting selector and the modifier are attached to the query.

    Args:
        query: query to extend.
        old_name: name passed to each modifier as ``old_name`` and used to
            fill the selector templates.
        new_name: name passed to each modifier as ``new_name``.

    Returns:
        The query returned by the last ``modify`` call, or ``query`` itself
        when ``modifiers`` is empty.
    """
    dotted = ' '.join(bowler_helpers.quoted_parts(old_name))
    power = ' '.join(bowler_helpers.power_parts(old_name))
    substitutions = {
        'name': old_name,
        'dotted_name': dotted,
        'power_name': power,
    }

    for make_modifier in modifiers:
        rename = make_modifier(old_name=old_name, new_name=new_name)
        filled_selector = rename.selector.format(**substitutions)
        selected = query.select(filled_selector)
        query = selected.modify(rename)

    return query