def test_split_nested_imports(self):
    """Splitting works when the import is nested inside a compound statement.

    Each template places ``import aaa, bbb, ccc`` inside a different kind of
    block. After splitting off ``bbb``, the module must still have exactly one
    top-level statement and contain two ``Import`` nodes (as returned by
    ``ast_utils.find_nodes_by_type``): ``[aaa, ccc]`` followed by ``[bbb]``.
    """
    test_cases = (
        'def foo():\n {import_stmt}\n',
        'class Foo(object):\n {import_stmt}\n',
        'if foo:\n {import_stmt}\nelse:\n pass\n',
        'if foo:\n pass\nelse:\n {import_stmt}\n',
        'if foo:\n pass\nelif bar:\n {import_stmt}\n',
        'try:\n {import_stmt}\nexcept:\n pass\n',
        'try:\n pass\nexcept:\n {import_stmt}\n',
        'try:\n pass\nfinally:\n {import_stmt}\n',
        'for i in foo:\n {import_stmt}\n',
        'for i in foo:\n pass\nelse:\n {import_stmt}\n',
        'while foo:\n {import_stmt}\n',
    )
    for template in test_cases:
        # Built outside the try so the failure message below always reports
        # the case actually being run (never a stale or unbound ``src``).
        src = template.format(import_stmt='import aaa, bbb, ccc')
        try:
            t = ast.parse(src)
            sc = scope.analyze(t)
            import_node = ast_utils.find_nodes_by_type(t, ast.Import)[0]
            import_utils.split_import(sc, import_node, import_node.names[1])

            split_import_nodes = ast_utils.find_nodes_by_type(t, ast.Import)
            self.assertEqual(1, len(t.body))
            self.assertEqual(2, len(split_import_nodes))
            self.assertEqual([alias.name for alias in split_import_nodes[0].names],
                             ['aaa', 'ccc'])
            self.assertEqual([alias.name for alias in split_import_nodes[1].names],
                             ['bbb'])
        except Exception:
            # Re-report any error (including assertion failures) together with
            # the offending source. Catching Exception rather than using a bare
            # except lets KeyboardInterrupt/SystemExit propagate.
            self.fail('Failed while executing case:\n%s\nCaused by:\n%s' %
                      (src, traceback.format_exc()))
def test_split_imports_with_alias(self):
    """Splitting an aliased name keeps its ``as`` clause on the new statement."""
    tree = ast.parse('import aaa as a, bbb as b, ccc as c\n')
    original = tree.body[0]
    scope_info = scope.analyze(tree)
    import_utils.split_import(scope_info, original, original.names[1])

    self.assertEqual(2, len(tree.body))
    kept, split_off = tree.body
    self.assertEqual([a.name for a in kept.names], ['aaa', 'ccc'])
    self.assertEqual([a.name for a in split_off.names], ['bbb'])
    self.assertEqual(split_off.names[0].asname, 'b')
def test_split_normal_import(self):
    """Splitting a plain ``import`` yields a second ``ast.Import`` statement."""
    tree = ast.parse('import aaa, bbb, ccc\n')
    original = tree.body[0]
    scope_info = scope.analyze(tree)
    import_utils.split_import(scope_info, original, original.names[1])

    self.assertEqual(2, len(tree.body))
    kept, split_off = tree.body
    self.assertEqual(ast.Import, type(split_off))
    self.assertEqual([a.name for a in kept.names], ['aaa', 'ccc'])
    self.assertEqual([a.name for a in split_off.names], ['bbb'])
def test_split_from_import(self):
    """Splitting a ``from`` import keeps the module on both statements.

    Splitting ``ccc`` out of ``from aaa import bbb, ccc, ddd`` should leave
    ``from aaa import bbb, ddd`` followed by ``from aaa import ccc``.
    """
    src = 'from aaa import bbb, ccc, ddd\n'
    t = ast.parse(src)
    import_node = t.body[0]
    sc = scope.analyze(t)
    import_utils.split_import(sc, import_node, import_node.names[1])

    self.assertEqual(2, len(t.body))
    self.assertEqual(ast.ImportFrom, type(t.body[1]))
    self.assertEqual(t.body[0].module, 'aaa')
    self.assertEqual(t.body[1].module, 'aaa')
    self.assertEqual([alias.name for alias in t.body[0].names], ['bbb', 'ddd'])
    # Also verify the split-off statement, as the sibling split tests do.
    self.assertEqual([alias.name for alias in t.body[1].names], ['ccc'])
def test_split_imports_multiple(self):
    """Two successive splits of the same statement produce three statements.

    The expected layout is ``[aaa]``, ``[ccc]``, ``[bbb]``: the name split off
    second appears before the one split off first.
    """
    tree = ast.parse('import aaa, bbb, ccc\n')
    original = tree.body[0]
    _, bbb_alias, ccc_alias = original.names
    scope_info = scope.analyze(tree)
    for alias in (bbb_alias, ccc_alias):
        import_utils.split_import(scope_info, original, alias)

    self.assertEqual(3, len(tree.body))
    expected = [['aaa'], ['ccc'], ['bbb']]
    for stmt, names in zip(tree.body, expected):
        self.assertEqual([a.name for a in stmt.names], names)
def _rename_name_in_importfrom(sc, node, old_name, new_name): if old_name == new_name: return False module_parts = node.module.split('.') old_parts = old_name.split('.') new_parts = new_name.split('.') # If just the module is changing, rename it if module_parts[:len(old_parts)] == old_parts: node.module = '.'.join(new_parts + module_parts[len(old_parts):]) return True # Find the alias node to be changed for alias_to_change in node.names: if alias_to_change.name == old_parts[-1]: break else: return False alias_to_change.name = new_parts[-1] # Split the import if the package has changed if module_parts != new_parts[:-1]: if len(node.names) > 1: new_import = import_utils.split_import(sc, node, alias_to_change) new_import.module = '.'.join(new_parts[:-1]) else: node.module = '.'.join(new_parts[:-1]) return True