Example #1
0
def assert_ast_eq(testcase, orig_ast, expected_ast):
    """
    Fail ``testcase`` when the two AST trees do not compare equal.

    :param testcase: object whose ``fail(msg)`` method is called on mismatch
    :param orig_ast: tree rendered under the "left" heading of the message
    :param expected_ast: tree rendered under the "right" heading of the message
    """
    if cmp_ast(orig_ast, expected_ast):
        return

    # Render both trees (left first) so the failure message can show them.
    left_dump, right_dump = (
        str_ast(tree, indent=' ', newline='\n')
        for tree in (orig_ast, expected_ast)
    )
    testcase.fail(
        'AST Trees are not equal\n'
        '## left ########### \n%s\n'
        '## right ########### \n%s' % (left_dump, right_dump)
    )
Example #2
0
    def UNPACK_SEQUENCE(self, instr):
        """
        Turn a sequence-unpacking instruction into an ``_ast.Assign`` node.

        The next ``instr.oparg`` instructions are taken from ``self.ilst`` and
        visited; the first target of each resulting node becomes an element of
        a ``Store`` tuple. That tuple is then made a target of the assignment
        pushed back onto ``self.ast_stack``.

        :param instr: instruction providing ``oparg`` (element count) and ``lineno``
        """
        count = instr.oparg

        targets = []
        # ``targets`` is shared with the tuple node; the loop below fills it in.
        unpack_target = _ast.Tuple(elts=targets,
                                   ctx=_ast.Store(),
                                   lineno=instr.lineno,
                                   col_offset=0)

        for _ in range(count):
            store_instr = self.ilst.pop(0)
            # A ``None`` placeholder is pushed before visiting; whatever is on
            # top afterwards must expose ``.targets``.
            self.ast_stack.append(None)
            self.visit(store_instr)
            targets.append(self.ast_stack.pop().targets[0])

        source = self.ast_stack.pop()
        if not isinstance(source, _ast.Assign):
            # Plain value on the stack: wrap it in a fresh assignment.
            result = _ast.Assign(targets=[unpack_target],
                                 value=source,
                                 lineno=instr.lineno,
                                 col_offset=0)
        else:
            # Existing assignment: add the tuple as one more target.
            result = source
            result.targets.append(unpack_target)

            # The entry below the assignment must match the assigned value.
            duplicate = self.ast_stack.pop()
            assert cmp_ast(result.value, duplicate)

        self.ast_stack.append(result)
    def UNPACK_SEQUENCE(self, instr):
        """
        Turn a sequence-unpacking instruction into an ``_ast.Assign`` node.

        The next ``instr.oparg`` instructions are taken from ``self.ilst`` and
        visited; the first target of each resulting node becomes an element of
        a ``Store`` tuple. That tuple is then made a target of the assignment
        handed to ``self.push_ast_item``.

        :param instr: instruction providing ``oparg`` (element count) and ``lineno``
        """
        count = instr.oparg

        targets = []
        # ``targets`` is shared with the tuple node; the loop below fills it in.
        unpack_target = _ast.Tuple(
            elts=targets, ctx=_ast.Store(), lineno=instr.lineno, col_offset=0
        )

        for _ in range(count):
            store_instr = self.ilst.pop(0)
            # A ``None`` placeholder is pushed before visiting; whatever is on
            # top afterwards must expose ``.targets``.
            self.push_ast_item(None)
            self.visit(store_instr)
            targets.append(self.pop_ast_item().targets[0])

        source = self.pop_ast_item()
        if not isinstance(source, _ast.Assign):
            # Plain value on the stack: wrap it in a fresh assignment.
            result = _ast.Assign(
                targets=[unpack_target],
                value=source,
                lineno=instr.lineno,
                col_offset=0,
            )
        else:
            # Existing assignment: add the tuple as one more target.
            result = source
            result.targets.append(unpack_target)

            # The entry below the assignment must match the assigned value.
            duplicate = self.pop_ast_item()
            assert cmp_ast(result.value, duplicate)

        self.push_ast_item(result)
Example #4
0
    def test_to_file(self) -> None:
        """
        Tests that `emit.file` writes `class_ast` to disk, and that the output
        with and without black differs as text yet parses to the same AST
        """

        with TemporaryDirectory() as tempdir:
            filename = os.path.join(tempdir, "delete_me.py")

            def emit_and_read(skip_black):
                """Emit `class_ast` to `filename` and return what was written"""
                emit.file(class_ast, filename, skip_black=skip_black)
                with open(filename, "rt") as f:
                    return f.read()

            try:
                ugly = emit_and_read(True)

                # Start the second emit from a clean slate
                os.remove(filename)

                blacked = emit_and_read(False)

                self.assertNotEqual(ugly, blacked)
                same_tree = cmp_ast(ast.parse(ugly), ast.parse(blacked))
                self.assertTrue(same_tree, "Ugly AST doesn't match blacked AST")
            finally:
                if os.path.isfile(filename):
                    os.remove(filename)
Example #5
0
    def test_parse_to_scalar(self) -> None:
        """ Test various inputs and outputs for `parse_to_scalar` """
        for fst, snd in (
            (5, 5),
            ("5", "5"),
            (set_value(5), 5),
            (ast.Expr(None), NoneStr),
        ):
            # One sub-test per pair so a failure names the offending input
            with self.subTest(value=fst):
                self.assertEqual(parse_to_scalar(fst), snd)

        # An expression statement wrapping a list literal comes back as a `List` node
        self.assertEqual(
            get_value(parse_to_scalar(ast.parse("[5]").body[0]).elts[0]), 5
        )
        self.assertTrue(
            cmp_ast(
                parse_to_scalar(ast.parse("[5]").body[0]),
                List([set_value(5)], Load()),
            )
        )

        # A whole module is rendered back to its source text
        self.assertEqual(parse_to_scalar(ast.parse("[5]")), "[5]")

        # Unsupported input types are rejected
        self.assertRaises(NotImplementedError, parse_to_scalar, memoryview(b""))
Example #6
0
def assert_ast_eq(testcase, orig_ast, expected_ast):
    """
    Report a test failure unless ``orig_ast`` and ``expected_ast`` compare equal.

    :param testcase: object whose ``fail(msg)`` method is called on mismatch
    :param orig_ast: tree printed in the "left" section of the message
    :param expected_ast: tree printed in the "right" section of the message
    """
    trees_match = cmp_ast(orig_ast, expected_ast)
    if trees_match:
        return

    template = ('AST Trees are not equal\n'
                '## left ########### \n%s\n'
                '## right ########### \n%s')
    # Left tree is rendered first, then the right one.
    rendered = [
        str_ast(tree, indent=' ', newline='\n')
        for tree in (orig_ast, expected_ast)
    ]
    testcase.fail(template % tuple(rendered))
Example #7
0
def run_ast_test(test_case_instance, gen_ast, gold, skip_black=False):
    """
    Asserts that `gen_ast` matches the `gold` standard, first as emitted source then as an AST

    :param test_case_instance: instance of `TestCase`
    :type test_case_instance: ```unittest.TestCase```

    :param gen_ast: generated AST; a `str` is parsed and its first statement is used
    :type gen_ast: ```Union[str, ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param gold: mocked AST
    :type gold: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param skip_black: When true the emitted source is compared as-is, otherwise both sides are
      formatted with black before comparison
    :type skip_black: ```bool```
    """
    if isinstance(gen_ast, str):
        gen_ast = ast.parse(gen_ast).body[0]

    assert gen_ast is not None, "gen_ast is None"
    assert gold is not None, "gold is None"

    # Compare deep copies rather than the objects handed in by the caller
    gen_ast, gold = deepcopy(gen_ast), deepcopy(gold)

    # Pick how the emitted source is normalised before the textual comparison
    if skip_black:
        normalise = identity
    else:
        normalise = partial(
            format_str,
            mode=Mode(
                target_versions=set(),
                line_length=60,
                is_pyi=False,
                string_normalization=False,
            ),
        )

    expected_source = normalise(source_transformer.to_code(gold))
    generated_source = normalise(source_transformer.to_code(gen_ast))
    test_case_instance.assertEqual(expected_source, generated_source)

    trees_match = cmp_ast(gen_ast, gold)
    test_case_instance.assertTrue(
        trees_match, "Generated AST doesn't match reference AST"
    )
Example #8
0
def run_ast_test(test_case_instance, gen_ast, gold, skip_black=False):
    """
    Asserts that `gen_ast` matches the `gold` standard: docstrings, emitted source, then AST

    :param test_case_instance: instance of `TestCase`
    :type test_case_instance: ```unittest.TestCase```

    :param gen_ast: generated AST; a `str` is parsed and its first statement is used
    :type gen_ast: ```Union[str, ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param gold: mocked AST
    :type gold: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param skip_black: Whether to skip black
    :type skip_black: ```bool```
    """
    if isinstance(gen_ast, str):
        gen_ast = ast.parse(gen_ast).body[0]

    assert gen_ast is not None, "gen_ast is None"
    assert gold is not None, "gold is None"

    # The docstring nodes may be removed below, so work on copies
    gen_ast, gold = deepcopy(gen_ast), deepcopy(gold)

    has_body = hasattr(gen_ast, "body") and len(gen_ast.body) > 0
    if has_body:
        gen_docstring = ast.get_docstring(gen_ast)
        gold_docstring = ast.get_docstring(gold)
        both_documented = gen_docstring is not None and gold_docstring is not None
        if both_documented:
            test_case_instance.assertEqual(
                gold_docstring.strip(), gen_docstring.strip()
            )
            # Docstrings were just compared modulo surrounding whitespace; drop them from
            # both trees so indentation differences do not fail the checks below
            del gen_ast.body[0]
            del gold.body[0]

    def passthrough(source, **_):
        """Stand-in for `format_str` that returns the source untouched"""
        return source

    formatter = passthrough if skip_black else format_str
    render = partial(
        formatter,
        mode=Mode(
            target_versions=set(),
            line_length=60,
            is_pyi=False,
            string_normalization=False,
        ),
    )

    expected_source = render(doctrans.source_transformer.to_code(gold))
    generated_source = render(doctrans.source_transformer.to_code(gen_ast))
    test_case_instance.assertEqual(expected_source, generated_source)

    trees_match = cmp_ast(gen_ast, gold)
    test_case_instance.assertTrue(
        trees_match, "Generated AST doesn't match reference AST"
    )
Example #9
0
    def assertAstEqual(self, left, right):
        """
        Fail unless ``left`` and ``right`` are AST nodes that compare equal.

        :param left: generated node
        :param right: expected node
        :raises self.failureException: when either argument is not an
            ``_ast.AST`` instance, or when ``cmp_ast`` reports a difference
        """
        for candidate in (left, right):
            if not isinstance(candidate, _ast.AST):
                # One-element tuple: a bare ``% candidate`` would unpack a
                # tuple argument and raise TypeError instead of failing.
                raise self.failureException(
                    "%s is not an _ast.AST instance" % (candidate,))

        if cmp_ast(left, right):
            return

        # Dump both trees on a single line each for the failure message.
        lstream = StringIO()
        print_ast(left, indent='', file=lstream, newline='')

        rstream = StringIO()
        print_ast(right, indent='', file=rstream, newline='')

        msg = 'Ast Not Equal:\nGenerated: %r\nExpected:  %r' % (
            lstream.getvalue(), rstream.getvalue())
        raise self.failureException(msg)
Example #10
0
    def STORE_SLICE_3(self, instr):
        """
        obj[lower:upper] = expr

        Pops upper bound, lower bound, subscripted object and assigned
        expression (in that order) from ``self.ast_stack`` and pushes the
        resulting assignment node.
        """
        upper_bound = self.ast_stack.pop()
        lower_bound = self.ast_stack.pop()
        container = self.ast_stack.pop()
        rhs = self.ast_stack.pop()

        location = {'lineno': instr.lineno, 'col_offset': 0}
        slice_node = _ast.Slice(lower=lower_bound, step=None, upper=upper_bound, **location)
        target = _ast.Subscript(value=container, slice=slice_node, ctx=_ast.Store(), **location)

        if not isinstance(rhs, _ast.AugAssign):
            statement = _ast.Assign(targets=[target], value=rhs, **location)
        else:
            # An augmented assignment already carries its target; it must
            # match the subscript rebuilt from the stack.
            statement = rhs
            targets_match = cmp_ast(rhs.target, target)

            assert targets_match

        self.ast_stack.append(statement)
Example #11
0
 def test_get_at_root(self) -> None:
     """ Tests that `get_at_root` finds the single import in the `eval.py` mock """
     mock_filename = path.join(path.dirname(__file__), "mocks", "eval.py")
     with open(mock_filename) as f:
         module = ast.parse(f.read())
     imports = get_at_root(module, (Import, ImportFrom))

     self.assertIsInstance(imports, list)
     self.assertEqual(len(imports), 1)

     expected_alias = ast.alias(
         asname=None,
         name="doctrans.tests.mocks",
         identifier=None,
         identifier_name=None,
     )
     expected_import = ast.Import(names=[expected_alias], alias=None)
     self.assertTrue(cmp_ast(imports[0], expected_import))
Example #12
0
    def assertAstEqual(self, left, right):
        """
        Fail unless ``left`` and ``right`` are AST nodes that compare equal.

        :param left: generated node
        :param right: expected node
        :raises self.failureException: when either argument is not an
            ``_ast.AST`` instance, or when ``cmp_ast`` reports a difference
        """
        if not isinstance(left, _ast.AST):
            # ``(left,)`` keeps a tuple argument from being unpacked by ``%``,
            # which would raise TypeError instead of a test failure.
            raise self.failureException("%s is not an _ast.AST instance" %
                                        (left,))
        if not isinstance(right, _ast.AST):
            raise self.failureException("%s is not an _ast.AST instance" %
                                        (right,))

        if cmp_ast(left, right):
            return

        # Dump both trees on a single line each for the failure message.
        lstream = StringIO()
        print_ast(left, indent='', file=lstream, newline='')

        rstream = StringIO()
        print_ast(right, indent='', file=rstream, newline='')

        msg = 'Ast Not Equal:\nGenerated: %r\nExpected:  %r' % (
            lstream.getvalue(), rstream.getvalue())
        raise self.failureException(msg)
    def STORE_SLICE_3(self, instr):
        """
        obj[lower:upper] = expr

        Pops upper bound, lower bound, subscripted object and assigned
        expression (in that order) via ``self.pop_ast_item`` and pushes the
        resulting assignment node with ``self.push_ast_item``.
        """
        upper_bound = self.pop_ast_item()
        lower_bound = self.pop_ast_item()
        container = self.pop_ast_item()
        rhs = self.pop_ast_item()

        location = {'lineno': instr.lineno, 'col_offset': 0}
        slice_node = _ast.Slice(
            lower=lower_bound, step=None, upper=upper_bound, **location
        )
        target = _ast.Subscript(
            value=container, slice=slice_node, ctx=_ast.Store(), **location
        )

        if not isinstance(rhs, _ast.AugAssign):
            statement = _ast.Assign(targets=[target], value=rhs, **location)
        else:
            # An augmented assignment already carries its target; it must
            # match the subscript rebuilt from the stack.
            statement = rhs
            targets_match = cmp_ast(rhs.target, target)

            assert targets_match

        self.push_ast_item(statement)
Example #14
0
def _conform_filename(
    filename,
    search,
    emit_func,
    replacement_node_ir,
    type_wanted,
):
    """
    Conform the given file to the `intermediate_repr`

    :param filename: Location of file
    :type filename: ```str```

    :param search: Search query, e.g., ['node_name', 'function_name', 'arg_name']
    :type search: ```List[str]```

    :param emit_func: Called with `replacement_node_ir` (and keyword options) to build the node to write
    :type emit_func: ```Callable[..., AST]```

    :param replacement_node_ir: Replace what is found with the contents of this param
    :type replacement_node_ir: ```dict```

    :param type_wanted: Exact type the node built by `emit_func` must have
    :type type_wanted: ```AST```

    :return: filename, whether the file was modified
    :rtype: ```Tuple[str, bool]```
    """
    filename = path.realpath(path.expanduser(filename))

    # Nothing on disk yet: write a brand new file
    if not path.isfile(filename):
        fresh_node = emit_func(
            replacement_node_ir,
            emit_default_doc=False,  # emit_func.__name__ == "class_"
        )
        emit.file(fresh_node, filename=filename, mode="wt", skip_black=False)
        return filename, True

    with open(filename, "rt") as f:
        parsed_ast = ast_parse(f.read(), filename=filename)
    assert isinstance(parsed_ast, Module)

    original_node = find_in_ast(search, parsed_ast)
    emit_options = _default_options(
        node=original_node, search=search, type_wanted=type_wanted
    )()
    replacement_node = emit_func(replacement_node_ir, **emit_options)

    # File exists but the searched node does not: append the new node
    if original_node is None:
        emit.file(replacement_node, filename=filename, mode="a", skip_black=False)
        return filename, True

    assert len(search) > 0
    assert type(replacement_node) == type_wanted, "Expected {!r} got {!r}".format(
        type_wanted, type(replacement_node).__name__
    )

    # Already conforming: leave the file alone
    if cmp_ast(original_node, replacement_node):
        return filename, False

    rewriter = RewriteAtQuery(
        search=search,
        replacement_node=replacement_node,
    )
    rewriter.visit(parsed_ast)

    status = "modified" if rewriter.replaced else "unchanged"
    print(status, filename, sep="\t")
    if rewriter.replaced:
        emit.file(parsed_ast, filename, mode="wt", skip_black=False)

    return filename, rewriter.replaced