Beispiel #1
0
    def test_return_simple(self):
        """A single trailing ``return`` is rewritten into flag/value bookkeeping.

        The expected output initialises ``returned_value`` and a
        ``returned_1`` flag, assigns both where the original returned, and
        ends the function with one ``return returned_value``.
        """
        source = utils.clip_head("""
        def func(a, b):
            return a + b

        value = func(3, 39)
        """)
        expected_source = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False

            returned_1 = True
            returned_value = a + b
            
            return returned_value

        value = func(3, 39)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # The hand-written expectation must compute the same `value` as the input.
        assert_semantically_equals(source, expected_source, ['value'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)
Beispiel #2
0
    def test_continue_nested(self):
        """``continue`` in nested loops becomes per-loop ``continued_N`` flags.

        In the expected output, each loop body resets its own flag, sets it
        where the original continued, and guards the remaining statements
        with ``if not`` on that flag.
        """
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            for j in range(10):
                if j == 5:
                    continue
                x += i * j
        """)
        expected_source = utils.clip_head("""
        x = 0
        for i in range(10):
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                for j in range(10):
                    continued_1 = False
                    if j == 5:
                        continued_1 = True
                    if not continued_1:
                        x += i * j
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # Final values of the accumulator and both loop variables must agree.
        assert_semantically_equals(source, expected_source, ['x', 'i', 'j'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)
Beispiel #3
0
    def test_break(self):
        """``break`` is replaced by a ``breaked_N`` flag and a trailing conditional break.

        In the expected output, the statements after the original ``break`` are
        guarded by ``if not breaked_0``. The loop body then ends with
        ``keepgoing = not breaked_0`` followed by ``if breaked_0: break``.
        """
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                break
            x += i
        """)
        expected_source = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_0 = False
            if i == 5:
                breaked_0 = True
            if not breaked_0:
                x += i
            keepgoing = not breaked_0
            if breaked_0:
                break
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # `i` is included so the iteration at which the loop stops is checked too.
        assert_semantically_equals(source, expected_source, ['x', 'i'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)
Beispiel #4
0
    def test_continue_break_nested(self):
        """``continue`` and ``break`` in the outer loop plus ``break`` in the inner loop.

        In the expected output, every statement after a rewritten jump is
        guarded by the conjunction of the outstanding flags, for example
        ``if not continued_0 and not breaked_1``. Each loop ends with its own
        ``keepgoing``/conditional ``break`` pair.
        """
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            if i == 6:
                break
            for j in range(10):
                if j == 5:
                    break
                x += i * j
        """)
        expected_source = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_1 = False
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                if i == 6:
                    breaked_1 = True
            if not continued_0 and not breaked_1:
                for j in range(10):
                    breaked_2 = False
                    if j == 5:
                        breaked_2 = True
                    if not breaked_2:
                        x += i * j
                    keepgoing = not breaked_2
                    if breaked_2:
                        break
            keepgoing = not breaked_1
            if breaked_1:
                break
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # Final values of the accumulator and both loop variables must agree.
        assert_semantically_equals(source, expected_source, ['x', 'i', 'j'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)
Beispiel #5
0
    def test_return(self):
        """An early ``return`` inside a loop plus a fallback ``return`` after it.

        In the expected output, the in-loop return sets ``returned_1`` and
        ``returned_value`` and leaves the loop through a conditional ``break``.
        The fallback assignment is guarded by ``if not returned_1``, and the
        function ends with a single ``return returned_value``.
        """
        source = utils.clip_head("""
        def func(a, b):
            for i in range(a):
                if i == b:
                    return i
            return 0

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        expected_source = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False
            for i in range(a):
                if i == b:
                    returned_1 = True
                    returned_value = i
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = 0
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # `value` sums func over a 10x10 grid, exercising both return paths.
        assert_semantically_equals(source, expected_source, ['value'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)
 def test_usub(self):
     """Unary minus on a numeric literal canonicalizes to one negative constant.

     ``ast.parse("-3")`` yields a unary-minus node wrapping ``3``. The
     canonicalizer is expected to fold it into a single literal ``-3``.

     NOTE: a second method named ``test_usub`` follows in this listing. If both
     end up in the same class, the later definition shadows this one.
     """
     orig_ast = gast.ast_to_gast(ast.parse("-3"))
     # gast models literals as Constant(value, kind), not the legacy Num node.
     # Module also carries a type_ignores list, which ast_to_gast fills with [].
     target_ast = gast.Module(
         body=[gast.Expr(value=gast.Constant(value=-3, kind=None))],
         type_ignores=[])
     assert compare_ast(self.canonicalizer.visit(orig_ast), target_ast)
 def test_usub(self):
     """Unary minus on the literal ``3`` should canonicalize to ``Constant(-3)``."""
     parsed = gast.ast_to_gast(ast.parse("-3"))
     folded_literal = gast.Constant(value=-3, kind=None)
     expected = gast.Module(
         body=[gast.Expr(value=folded_literal)],
         type_ignores=[],
     )
     canonical = self.canonicalizer.visit(parsed)
     assert compare_ast(canonical, expected)
Beispiel #8
0
    def test_return_continue(self):
        """``return`` inside doubly nested loops combined with ``continue`` at both levels.

        In the expected output, the ``returned_1`` flag is propagated out of
        every enclosing loop by a ``keepgoing``/conditional ``break`` pair. The
        trailing fallback return is guarded by ``if not returned_1``.
        """
        source = utils.clip_head("""
        def func(a, b, c):
            x = 0
            for i in range(a):
                if i == b:
                    continue
                for j in range(a):
                    if j == c:
                        continue
                    if j == b:
                        return x
                    x += i * j
            return x

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        expected_source = utils.clip_head("""
        def func(a, b, c):
            returned_value = None
            returned_1 = False
            x = 0
            for i in range(a):
                continued_1 = False
                if i == b:
                    continued_1 = True
                if not continued_1:
                    for j in range(a):
                        continued_2 = False
                        if j == c:
                            continued_2 = True
                        if not continued_2:
                            if j == b:
                                if not continued_2:
                                    returned_1 = True
                                    returned_value = x
                        if not continued_2 and not returned_1:
                            x += i * j
                        keepgoing = not returned_1
                        if returned_1:
                            break
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = x
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(code))
            for code in (source, expected_source)
        )
        canonical_tree = self.canonicalizer.visit(source_tree)

        # `value` sums func over a 10x10x10 grid, exercising every jump combination.
        assert_semantically_equals(source, expected_source, ['value'])
        # Structural equality is checked in both directions.
        for lhs, rhs in ((canonical_tree, expected_tree),
                         (expected_tree, canonical_tree)):
            assert compare_ast(lhs, rhs)