def test_split_nested_imports(self):
    """Check split_import on imports nested inside compound statements.

    Each template places ``import aaa, bbb, ccc`` inside a different kind of
    block (function, class, if/elif/else, try/except/finally, for/else,
    while). After splitting off ``bbb``, the module body must still hold a
    single statement, and the tree must contain two Import nodes: the first
    with ``aaa, ccc`` and the second with ``bbb``.
    """
    test_cases = (
        'def foo():\n  {import_stmt}\n',
        'class Foo(object):\n  {import_stmt}\n',
        'if foo:\n  {import_stmt}\nelse:\n  pass\n',
        'if foo:\n  pass\nelse:\n  {import_stmt}\n',
        'if foo:\n  pass\nelif bar:\n  {import_stmt}\n',
        'try:\n  {import_stmt}\nexcept:\n  pass\n',
        'try:\n  pass\nexcept:\n  {import_stmt}\n',
        'try:\n  pass\nfinally:\n  {import_stmt}\n',
        'for i in foo:\n  {import_stmt}\n',
        'for i in foo:\n  pass\nelse:\n  {import_stmt}\n',
        'while foo:\n  {import_stmt}\n',
    )

    for template in test_cases:
      # Build the source before entering the try block so that ``src`` is
      # always bound when the failure message below is formatted.
      src = template.format(import_stmt='import aaa, bbb, ccc')
      try:
        t = ast.parse(src)
        sc = scope.analyze(t)
        import_node = ast_utils.find_nodes_by_type(t, ast.Import)[0]
        import_utils.split_import(sc, import_node, import_node.names[1])

        split_import_nodes = ast_utils.find_nodes_by_type(t, ast.Import)
        self.assertEqual(1, len(t.body))
        self.assertEqual(2, len(split_import_nodes))
        self.assertEqual([alias.name for alias in split_import_nodes[0].names],
                         ['aaa', 'ccc'])
        self.assertEqual([alias.name for alias in split_import_nodes[1].names],
                         ['bbb'])
      except Exception:
        # Report errors and failed assertions together with the offending
        # source. ``Exception`` (rather than a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate instead of being turned
        # into test failures.
        self.fail('Failed while executing case:\n%s\nCaused by:\n%s' %
                  (src, traceback.format_exc()))
  def test_split_imports_with_alias(self):
    """Splitting one alias out of an aliased import keeps its ``asname``."""
    tree = ast.parse('import aaa as a, bbb as b, ccc as c\n')
    original_import = tree.body[0]
    file_scope = scope.analyze(tree)
    alias_to_split = original_import.names[1]
    import_utils.split_import(file_scope, original_import, alias_to_split)

    def imported_names(import_stmt):
      # Module names listed by an import statement, in order.
      return [alias.name for alias in import_stmt.names]

    # The remaining aliases stay in the first statement; the split-out alias
    # lands in the statement that follows it.
    self.assertEqual(2, len(tree.body))
    self.assertEqual(imported_names(tree.body[0]), ['aaa', 'ccc'])
    self.assertEqual(imported_names(tree.body[1]), ['bbb'])
    self.assertEqual(tree.body[1].names[0].asname, 'b')
  def test_split_normal_import(self):
    """Splitting a plain ``import`` yields a second ``ast.Import`` statement."""
    module = ast.parse('import aaa, bbb, ccc\n')
    first_stmt = module.body[0]
    analyzed = scope.analyze(module)
    import_utils.split_import(analyzed, first_stmt, first_stmt.names[1])

    self.assertEqual(2, len(module.body))
    # Safe to index both statements now that the length has been checked.
    kept, moved = module.body[0], module.body[1]
    self.assertEqual(ast.Import, type(moved))
    self.assertEqual([item.name for item in kept.names], ['aaa', 'ccc'])
    self.assertEqual([item.name for item in moved.names], ['bbb'])
  def test_split_from_import(self):
    """Splitting a ``from ... import`` keeps the module on both statements.

    After splitting ``ccc`` out of ``from aaa import bbb, ccc, ddd`` the
    module body must hold two ImportFrom-style statements for module ``aaa``:
    the first with ``bbb, ddd`` and the second with ``ccc``.
    """
    src = 'from aaa import bbb, ccc, ddd\n'
    t = ast.parse(src)
    import_node = t.body[0]
    sc = scope.analyze(t)
    import_utils.split_import(sc, import_node, import_node.names[1])

    self.assertEqual(2, len(t.body))
    self.assertEqual(ast.ImportFrom, type(t.body[1]))
    self.assertEqual(t.body[0].module, 'aaa')
    self.assertEqual(t.body[1].module, 'aaa')
    self.assertEqual([alias.name for alias in t.body[0].names], ['bbb', 'ddd'])
    # Also verify the split-out statement, as the sibling split tests do.
    self.assertEqual([alias.name for alias in t.body[1].names], ['ccc'])
  def test_split_imports_multiple(self):
    """Two successive splits from one import leave three import statements."""
    tree = ast.parse('import aaa, bbb, ccc\n')
    source_import = tree.body[0]
    # Grab both aliases before any split is performed.
    second_alias = source_import.names[1]
    third_alias = source_import.names[2]
    file_scope = scope.analyze(tree)
    for alias_to_split in (second_alias, third_alias):
      import_utils.split_import(file_scope, source_import, alias_to_split)

    self.assertEqual(3, len(tree.body))
    # Expected layout: the original statement, then the most recently split
    # alias, then the one split first.
    expected_names = (['aaa'], ['ccc'], ['bbb'])
    for position, expected in enumerate(expected_names):
      stmt = tree.body[position]
      self.assertEqual([alias.name for alias in stmt.names], expected)
# Example #6
0
def _rename_name_in_importfrom(sc, node, old_name, new_name):
    if old_name == new_name:
        return False

    module_parts = node.module.split('.')
    old_parts = old_name.split('.')
    new_parts = new_name.split('.')

    # If just the module is changing, rename it
    if module_parts[:len(old_parts)] == old_parts:
        node.module = '.'.join(new_parts + module_parts[len(old_parts):])
        return True

    # Find the alias node to be changed
    for alias_to_change in node.names:
        if alias_to_change.name == old_parts[-1]:
            break
    else:
        return False

    alias_to_change.name = new_parts[-1]

    # Split the import if the package has changed
    if module_parts != new_parts[:-1]:
        if len(node.names) > 1:
            new_import = import_utils.split_import(sc, node, alias_to_change)
            new_import.module = '.'.join(new_parts[:-1])
        else:
            node.module = '.'.join(new_parts[:-1])

    return True