예제 #1
0
    def testRemoveAlias(self):
        """Removing the first alias of a from-import keeps the remaining one."""
        module = pasta.parse("from a import b, c")
        from_import = module.body[0]
        # Detach `b`, the first of the two imported names.
        ast_utils.remove_child(from_import, from_import.names[0])

        self.assertEqual(pasta.dump(module), "from a import c")
예제 #2
0
def remove_import_alias_node(sc, node):
    """Drop one import alias, or its whole import statement if it is the last.

    Arguments:
      sc: (scope.Scope) Scope computed on whole tree of the code being modified.
      node: (ast.alias) The alias to remove; `sc.parent(node)` must be the
        import statement whose `names` list holds it.
    """
    statement = sc.parent(node)
    if len(statement.names) != 1:
        # Other aliases remain, so only this alias is detached.
        ast_utils.remove_child(statement, node)
        return
    # This was the sole alias: remove the statement itself from its parent.
    ast_utils.remove_child(sc.parent(statement), statement)
예제 #3
0
def remove_import(sc, alias_to_remove):
  """Remove an import alias; remove the whole import if it was the only one.

  Arguments:
    sc: (scope.Scope) Scope computed on whole tree of the code being modified.
    alias_to_remove: (ast.alias) The import alias node to remove. It is
      handled as a single node whose parent exposes a `names` list.
  """
  stmt = sc.parent(alias_to_remove)
  is_only_alias = len(stmt.names) == 1
  if is_only_alias:
    # Nothing would be left in the statement, so the statement itself goes.
    container, child = sc.parent(stmt), stmt
  else:
    container, child = stmt, alias_to_remove
  ast_utils.remove_child(container, child)
예제 #4
0
        def testRemoveFromBlock(self):
            src = """\
if a:
  print("foo!")
  x = 1"""
            tree = pasta.parse(src, py_ver)
            if_block = tree.body[0]
            print_stmt = if_block.body[0]
            ast_utils.remove_child(if_block, print_stmt, py_ver=py_ver)

            expected = """\
if a:
  x = 1"""
            self.assertEqual(pasta.dump(tree, py_ver), expected)
예제 #5
0
        def testRemoveChildMethod(self):
            src = """\
class C():
  def f(x):
    return x + 2
  def g(x):
    return x + 3"""
            tree = pasta.parse(src, py_ver)
            class_node = tree.body[0]
            meth1_node = class_node.body[0]

            ast_utils.remove_child(class_node, meth1_node, py_ver=py_ver)

            result = pasta.dump(tree, py_ver)
            expected = """\
class C():
  def g(x):
    return x + 3"""
            self.assertEqual(result, expected)
예제 #6
0
def inline_name(t, name, py_ver=sys.version_info[:2]):
    """Inline a constant name into a module.

    Every read of `name` is replaced by a deep copy of the value assigned to
    it, and the assignment is then removed (or, when the assignment has
    several targets, only `name` is dropped from its target list).

    Arguments:
      t: The tree to modify in place; it is analyzed with `scope.analyze`.
      name: (str) The name to inline, looked up in the analyzed scope's
        `names` mapping.
      py_ver: Python version forwarded to `ast_utils.remove_child`; defaults
        to the running interpreter's (major, minor).

    Raises:
      InlineError: if the definition of `name` is not a Name node, is not the
        target of an assignment, is not assigned at module level, or has a
        reference with a Store context among its recorded reads.
    """
    name_types = (ast27.Name, ast3.Name)
    sc = scope.analyze(t)
    name_node = sc.names[name]

    # The name must be a Name node (not a FunctionDef, etc.)
    if not isinstance(name_node.definition, name_types):
        raise InlineError('%r is not a constant; it has type %r' %
                          (name, type(name_node.definition)))

    assign_node = sc.parent(name_node.definition)
    if not isinstance(assign_node, (ast27.Assign, ast3.Assign)):
        raise InlineError('%r is not declared in an assignment' % name)

    value = assign_node.value
    if not isinstance(sc.parent(assign_node), (ast27.Module, ast3.Module)):
        raise InlineError('%r is not a top-level name' % name)

    # If the name is written anywhere else in this module, it is not constant.
    # This is checked for all references before any of them is rewritten, so
    # a failure leaves the tree untouched.
    store_types = (ast27.Store, ast3.Store)
    if any(isinstance(getattr(ref, 'ctx', None), store_types)
           for ref in name_node.reads):
        raise InlineError('%r is not a constant' % name)

    # Replace all reads of the name with a copy of its value; each read gets
    # its own copy so the resulting nodes are not shared within the tree.
    for ref in name_node.reads:
        ast_utils.replace_child(sc.parent(ref), ref, copy.deepcopy(value))

    # Remove the assignment to this name
    if len(assign_node.targets) == 1:
        ast_utils.remove_child(sc.parent(assign_node),
                               assign_node,
                               py_ver=py_ver)
    else:
        # Chained assignment (a = b = value): keep the statement and drop only
        # this name. The targets are typed_ast nodes (the checks above only
        # accept ast27/ast3 nodes), so they must be matched against the
        # typed_ast Name types rather than the stdlib `ast.Name`.
        assign_node.targets = [
            tgt for tgt in assign_node.targets
            if not (isinstance(tgt, name_types) and tgt.id == name)
        ]