def test_return_simple(self):
    """A lone trailing ``return`` is rewritten into flag/value bookkeeping.

    The expected canonical form stores the result in ``returned_value``,
    raises the ``returned_1`` flag, and ends with a single
    ``return returned_value``.
    """
    source = utils.clip_head("""
    def func(a, b):
        return a + b
    value = func(3, 39)
    """)
    expected = utils.clip_head("""
    def func(a, b):
        returned_value = None
        returned_1 = False
        returned_1 = True
        returned_value = a + b
        return returned_value
    value = func(3, 39)
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['value'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_continue_nested(self):
    """``continue`` in two nested loops becomes per-loop ``continued_N`` guards.

    Each loop body gets its own flag; the statements that followed the
    ``continue`` are wrapped in ``if not continued_N:``.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            continue
        for j in range(10):
            if j == 5:
                continue
            x += i * j
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        continued_0 = False
        if i == 5:
            continued_0 = True
        if not continued_0:
            for j in range(10):
                continued_1 = False
                if j == 5:
                    continued_1 = True
                if not continued_1:
                    x += i * j
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_break(self):
    """``break`` is replaced by a ``breaked_0`` flag plus a trailing exit.

    The statements after the ``break`` are guarded by ``if not breaked_0:``.
    The loop body then ends with ``keepgoing = not breaked_0`` and a
    flag-checked ``break``.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            break
        x += i
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        breaked_0 = False
        if i == 5:
            breaked_0 = True
        if not breaked_0:
            x += i
        keepgoing = not breaked_0
        if breaked_0:
            break
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['x', 'i'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_continue_break_nested(self):
    """Mixed ``continue``/``break`` in the outer loop plus ``break`` in the inner.

    The outer body tracks both ``continued_0`` and ``breaked_1``. The inner
    loop gets its own ``breaked_2``. Each loop ends with its own
    ``keepgoing``/``break`` pair.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            continue
        if i == 6:
            break
        for j in range(10):
            if j == 5:
                break
            x += i * j
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        breaked_1 = False
        continued_0 = False
        if i == 5:
            continued_0 = True
        if not continued_0:
            if i == 6:
                breaked_1 = True
        if not continued_0 and not breaked_1:
            for j in range(10):
                breaked_2 = False
                if j == 5:
                    breaked_2 = True
                if not breaked_2:
                    x += i * j
                keepgoing = not breaked_2
                if breaked_2:
                    break
        keepgoing = not breaked_1
        if breaked_1:
            break
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_return(self):
    """``return`` from inside a loop becomes flag bookkeeping plus a ``break``.

    The in-loop ``return i`` sets ``returned_1``/``returned_value`` and
    leaves the loop through a flag-checked ``break``. The fall-through
    ``return 0`` is guarded by ``if not returned_1:``. The function then
    ends with one ``return returned_value``.
    """
    source = utils.clip_head("""
    def func(a, b):
        for i in range(a):
            if i == b:
                return i
        return 0
    value = 0
    for a in range(10):
        for b in range(10):
            value += func(a, b)
    """)
    expected = utils.clip_head("""
    def func(a, b):
        returned_value = None
        returned_1 = False
        for i in range(a):
            if i == b:
                returned_1 = True
                returned_value = i
            keepgoing = not returned_1
            if returned_1:
                break
        if not returned_1:
            returned_1 = True
            returned_value = 0
        return returned_value
    value = 0
    for a in range(10):
        for b in range(10):
            value += func(a, b)
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['value'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_usub(self):
    """Unary minus applied to a literal canonicalizes to one negative constant.

    Parsing ``-3`` and running the canonicalizer must produce a module
    holding a single expression statement whose value is ``Constant(-3)``.
    """
    source_tree = gast.ast_to_gast(ast.parse("-3"))
    # Expected tree uses the gast >= 0.3 node layout: ``Constant`` carries
    # ``kind`` and ``Module`` carries ``type_ignores``. There was formerly a
    # second ``test_usub`` built on ``gast.Num``. It was shadowed by this
    # definition and never executed, so only this one is kept.
    expected_tree = gast.Module(
        body=[gast.Expr(value=gast.Constant(value=-3, kind=None))],
        type_ignores=[])
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Structural check in both argument orders, like every other test here.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)
def test_return_continue(self):
    """``continue`` in two nested loops combined with a ``return`` from the inner.

    Each loop gets its own ``continued_N`` flag. The inner ``return x``
    raises ``returned_1``, and both loops exit through flag-checked
    ``break`` statements. The trailing ``return x`` is guarded by
    ``if not returned_1:``.
    """
    source = utils.clip_head("""
    def func(a, b, c):
        x = 0
        for i in range(a):
            if i == b:
                continue
            for j in range(a):
                if j == c:
                    continue
                if j == b:
                    return x
                x += i * j
        return x
    value = 0
    for a in range(10):
        for b in range(10):
            for c in range(10):
                value += func(a, b, c)
    """)
    expected = utils.clip_head("""
    def func(a, b, c):
        returned_value = None
        returned_1 = False
        x = 0
        for i in range(a):
            continued_1 = False
            if i == b:
                continued_1 = True
            if not continued_1:
                for j in range(a):
                    continued_2 = False
                    if j == c:
                        continued_2 = True
                    if not continued_2:
                        if j == b:
                            if not continued_2:
                                returned_1 = True
                                returned_value = x
                    if not continued_2 and not returned_1:
                        x += i * j
                    keepgoing = not returned_1
                    if returned_1:
                        break
            keepgoing = not returned_1
            if returned_1:
                break
        if not returned_1:
            returned_1 = True
            returned_value = x
        return returned_value
    value = 0
    for a in range(10):
        for b in range(10):
            for c in range(10):
                value += func(a, b, c)
    """)
    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    canonical_tree = self.canonicalizer.visit(source_tree)

    # Semantic check: both snippets are compared on the listed names.
    assert_semantically_equals(source, expected, ['value'])
    # Structural check, asserted in both argument orders.
    assert compare_ast(canonical_tree, expected_tree)
    assert compare_ast(expected_tree, canonical_tree)