def test_assumptions():
    """Sanity-check ``compare_nodes``, which every other test here relies on."""
    # Trees parsed separately from identical source must compare equal.
    assert sri_ast.compare_nodes(
        sri_ast.parse_to_ast("foo = 42"),
        sri_ast.parse_to_ast("foo = 42"),
    )

    # Trees parsed separately from different source must compare not-equal.
    assert not sri_ast.compare_nodes(
        sri_ast.parse_to_ast("foo = 42"),
        sri_ast.parse_to_ast("bar = 666"),
    )
def test_replace_builtin_constant_no(source):
    """``replace_builtin_constants`` must leave *source*'s AST unchanged.

    ``source`` is presumably supplied by a pytest parametrization that is not
    visible in this chunk.
    """
    reference = sri_ast.parse_to_ast(source)
    candidate = sri_ast.parse_to_ast(source)

    # Fold only the second copy; the first serves as the untouched baseline.
    folding.replace_builtin_constants(candidate)

    assert sri_ast.compare_nodes(reference, candidate)
def test_replace_constant_no(source):
    """Replacing ``FOO`` with ``31337`` must be a no-op for *source*.

    ``source`` is presumably supplied by a pytest parametrization that is not
    visible in this chunk.
    """
    baseline = sri_ast.parse_to_ast(source)
    mutated = sri_ast.parse_to_ast(source)

    replacement = sri_ast.Int(value=31337)
    folding.replace_constant(mutated, "FOO", replacement, True)

    assert sri_ast.compare_nodes(baseline, mutated)
def test_compare_different_node_clases():
    """Nodes of different classes (assignment target vs. value) never match."""
    assignment = sri_ast.parse_to_ast("foo = 42").body[0]
    lhs_node, rhs_node = assignment.target, assignment.value

    # They are distinct objects, and structurally different as well.
    assert lhs_node != rhs_node
    assert not sri_ast.compare_nodes(lhs_node, rhs_node)
def test_replace_builtins(source, original, result):
    """Builtin calls in the *original* rendering of *source* fold to *result*.

    *source* is a ``str.format`` template filled once with *original* (the
    expression to fold) and once with *result* (the expected outcome). The
    arguments are presumably provided by a pytest parametrization not visible
    here.
    """
    folded_tree = sri_ast.parse_to_ast(source.format(original))
    wanted_tree = sri_ast.parse_to_ast(source.format(result))

    folding.replace_builtin_functions(folded_tree)

    assert sri_ast.compare_nodes(folded_tree, wanted_tree)
def test_replace_subscripts_simple():
    """Indexing a literal list with a literal index folds to the element."""
    tree, wanted = (
        sri_ast.parse_to_ast("[foo, bar, baz][1]"),
        sri_ast.parse_to_ast("bar"),
    )
    folding.replace_subscripts(tree)
    assert sri_ast.compare_nodes(tree, wanted)
def test_compare_complex_nodes_same_value():
    """Structurally identical nested literals compare equal via ``compare_nodes``."""
    outer_list = sri_ast.parse_to_ast(
        "[{'foo':'bar', 43:[1,2,3]}, {'foo':'bar', 43:[1,2,3]}]"
    ).body[0].value
    first, second = outer_list.elts

    # ``!=`` tells the two nodes apart; only ``compare_nodes`` sees them as equal.
    assert first != second
    assert sri_ast.compare_nodes(first, second)
def test_replace_subscripts_nested():
    """Chained subscripts on a nested literal list fold down to one element."""
    nested = sri_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]")
    folding.replace_subscripts(nested)

    # Element [2][1] of the literal is 5.
    assert sri_ast.compare_nodes(nested, sri_ast.parse_to_ast("5"))
def test_integration():
    """``folding.fold`` reduces ``[1+2, 6+7][8-8]`` all the way to ``3``.

    Reaching ``3`` requires folding both the arithmetic and the subscript.
    """
    program = sri_ast.parse_to_ast("[1+2, 6+7][8-8]")
    wanted = sri_ast.parse_to_ast("3")

    folding.fold(program)

    assert sri_ast.compare_nodes(program, wanted)
def test_replace_binop_nested():
    """Nested arithmetic on literals folds to a single literal.

    ((6 + 2**4) * 4) / 2 == (22 * 4) / 2 == 44
    """
    expr_tree = sri_ast.parse_to_ast("((6 + (2**4)) * 4) / 2")
    answer_tree = sri_ast.parse_to_ast("44")

    folding.replace_literal_ops(expr_tree)

    assert sri_ast.compare_nodes(expr_tree, answer_tree)
def test_replace_binop_simple():
    """``1 + 2`` folds to the literal ``3``."""
    tree, wanted = sri_ast.parse_to_ast("1 + 2"), sri_ast.parse_to_ast("3")
    folding.replace_literal_ops(tree)
    assert sri_ast.compare_nodes(tree, wanted)
def test_replace_literal_ops():
    """Boolean ``not`` / ``and`` / ``or`` on literals fold inside a list."""
    boolean_list = sri_ast.parse_to_ast(
        "[not True, True and False, True or False]"
    )
    folded_list = sri_ast.parse_to_ast("[False, False, True]")

    folding.replace_literal_ops(boolean_list)

    assert sri_ast.compare_nodes(boolean_list, folded_list)
def test_replace_userdefined_constant_no(source):
    """A user-defined constant ``FOO`` must not alter *source* when folded.

    A ``FOO`` constant declaration is prepended to *source* before parsing.
    ``source`` is presumably provided by a pytest parametrization not visible
    here.
    """
    prefixed = f"FOO: constant(int128) = 42\n{source}"

    pristine = sri_ast.parse_to_ast(prefixed)
    processed = sri_ast.parse_to_ast(prefixed)
    folding.replace_user_defined_constants(processed)

    assert sri_ast.compare_nodes(pristine, processed)
def test_list_replacement_similar_nodes():
    """``replace_in_tree`` swaps only the chosen node among identical siblings."""
    tree = sri_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]")
    wanted = sri_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]")

    # Pick the middle ``1``; the other four are structurally identical to it.
    victim = tree.body[0].value.elts[2]
    substitute = sri_ast.parse_to_ast("31337").body[0].value
    tree.replace_in_tree(victim, substitute)

    assert sri_ast.compare_nodes(tree, wanted)
def test_simple_replacement():
    """Replacing an assignment target with another name node rewrites it."""
    tree = sri_ast.parse_to_ast("foo = 42")
    wanted = sri_ast.parse_to_ast("bar = 42")

    tree.replace_in_tree(
        tree.body[0].target,
        sri_ast.parse_to_ast("bar").body[0].value,
    )

    assert sri_ast.compare_nodes(tree, wanted)
def parse_for(self):
    """Compile the ``for`` statement in ``self.stmt`` into an LLL ``repeat`` node.

    Accepted loop shapes:

    * Type 0 -- when ``self._is_list_iter()`` is true, compilation is
      delegated entirely to ``self.parse_for_list()``.
    * Type 1 -- ``for i in range(N)``: start is the literal ``0`` and the
      iteration count comes from ``self._get_range_const_value(N)``.
    * Type 2 -- ``for i in range(A, B)`` where ``B`` passes
      ``self._check_valid_range_constant``: start is ``A`` and the
      iteration count is ``B - A``.
    * Type 3 -- ``for i in range(x, x + N)``: ``x`` is compiled as a runtime
      expression and must be the same node structure in both places;
      the iteration count comes from ``N``.

    Returns:
        The ``LLLnode`` produced by ``LLLnode.from_list(['repeat', ...])``,
        or, for Type 0, whatever ``parse_for_list`` returns.

    Raises:
        StructureException: if the iterable is not a ``range(...)`` call
            (with a dedicated message for subscripts), if ``range`` receives
            other than 1 or 2 arguments, if a non-constant two-argument range
            is not of the form ``range(x, x + N)`` with identical ``x``, or if
            the computed iteration count is less than 1 (e.g. Type 2 with
            ``B <= A``). The ``_get_range_const_value`` /
            ``Expr.parse_value_expr`` helpers may raise their own errors --
            not visible here.
    """
    # Type 0 for, e.g. for i in list(): ...
    if self._is_list_iter():
        return self.parse_for_list()

    # Past this point the only accepted iterable is a call to ``range``.
    if not isinstance(self.stmt.iter, sri_ast.Call):
        if isinstance(self.stmt.iter, sri_ast.Subscript):
            raise StructureException("Cannot iterate over a nested list", self.stmt.iter)
        raise StructureException(
            f"Cannot iterate over '{type(self.stmt.iter).__name__}' object", self.stmt.iter)
    # ``getattr`` default covers callees without an ``id`` (not a plain name).
    if getattr(self.stmt.iter.func, 'id', None) != "range":
        raise StructureException(
            "Non-literals cannot be used as loop range", self.stmt.iter.func)
    if len(self.stmt.iter.args) not in {1, 2}:
        raise StructureException(
            f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
            self.stmt.iter.func)

    # The block scope is keyed by the identity of this statement node.
    block_scope_id = id(self.stmt)
    with self.context.make_blockscope(block_scope_id):
        # Get arg0
        arg0 = self.stmt.iter.args[0]
        num_of_args = len(self.stmt.iter.args)

        # Type 1 for, e.g. for i in range(10): ...
        if num_of_args == 1:
            arg0_val = self._get_range_const_value(arg0)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))
            rounds = arg0_val

        # Type 2 for, e.g. for i in range(100, 110): ...
        # NOTE(review): only the second argument is probed here; element [0]
        # of the helper's result presumably signals "is a valid constant".
        elif self._check_valid_range_constant(self.stmt.iter.args[1], raise_exception=False)[0]:
            arg0_val = self._get_range_const_value(arg0)
            arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
            start = LLLnode.from_list(arg0_val, typ='int128', pos=getpos(self.stmt))
            rounds = LLLnode.from_list(arg1_val - arg0_val, typ='int128', pos=getpos(self.stmt))

        # Type 3 for, e.g. for i in range(x, x + 10): ...
        else:
            arg1 = self.stmt.iter.args[1]
            # The second argument must be an addition ``<something> + <rounds>``.
            if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                    arg1.op, sri_ast.Add):
                raise StructureException(
                    ("Two-arg for statements must be of the form `for i "
                     "in range(start, start + rounds): ...`"),
                    arg1,
                )
            # ... and its left operand must structurally match the first argument.
            if not sri_ast.compare_nodes(arg0, arg1.left):
                raise StructureException(
                    ("Two-arg for statements of the form `for i in "
                     "range(x, x + y): ...` must have x identical in both "
                     f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                     ),
                    self.stmt.iter,
                )
            rounds = self._get_range_const_value(arg1.right)
            # Unlike Types 1/2, the start value is compiled as a runtime expression.
            start = Expr.parse_value_expr(arg0, self.context)

        # ``rounds`` is either a plain int or a node exposing ``.value``
        # (Type 2 always builds an LLLnode); normalize for the bounds check.
        r = rounds if isinstance(rounds, int) else rounds.value
        if r < 1:
            raise StructureException(
                f"For loop has invalid number of iterations ({r}),"
                " the value must be greater than zero", self.stmt.iter)

        # Register the loop variable as an int128 and flag it in ``forvars``
        # while the body is compiled (presumably so the body cannot reassign
        # it -- confirm against the assignment handler).
        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'), pos=getpos(self.stmt))
        self.context.forvars[varname] = True
        o = LLLnode.from_list(
            [
                'repeat', pos, start, rounds,
                parse_body(self.stmt.body, self.context)
            ],
            typ=None,
            pos=getpos(self.stmt),
        )
        # Unregister the loop variable once the body has been compiled.
        del self.context.vars[varname]
        del self.context.forvars[varname]
    return o
def test_compare_nodes():
    """A node rebuilt via ``Int.from_node`` with a new value must not compare equal.

    Two cases are covered:

    * the original check: an ``Int`` built from the whole Module differs from
      that Module (this only exercises the class-mismatch path);
    * a same-class check: the ``42`` literal vs. an ``Int`` rebuilt from it
      with ``value=666``. Only this case catches a ``compare_nodes`` that
      ignored field values.
    """
    module_node = sri_ast.parse_to_ast("foo = 42")
    rebuilt_module = sri_ast.Int.from_node(module_node, value=666)
    assert not sri_ast.compare_nodes(module_node, rebuilt_module)

    # Same node kind, different literal value.
    literal_node = module_node.body[0].value
    rebuilt_literal = sri_ast.Int.from_node(literal_node, value=666)
    assert not sri_ast.compare_nodes(literal_node, rebuilt_literal)
def test_compare_same_node():
    """A node is equal to itself under both ``==`` and ``compare_nodes``."""
    literal = sri_ast.parse_to_ast("42").body[0].value
    same_object = literal

    assert literal == same_object
    assert sri_ast.compare_nodes(literal, same_object)
def test_compare_different_nodes_same_value():
    """Two distinct ``1`` literals: ``!=`` separates them, ``compare_nodes`` does not."""
    elements = sri_ast.parse_to_ast("[1, 1]").body[0].value.elts
    first_one, second_one = elements

    assert first_one != second_one
    assert sri_ast.compare_nodes(first_one, second_one)
def test_binary_becomes_bytes():
    """A ``0b...`` literal assigned to ``bytes[1]`` parses like the bytes literal.

    Note: the first source string is not raw, so Python embeds the actual
    0x01 character between the ``b'...'`` quotes.
    """
    bytes_form = sri_ast.parse_to_ast("foo: bytes[1] = b'\x01'")
    binary_form = sri_ast.parse_to_ast("foo: bytes[1] = 0b00000001")

    assert sri_ast.compare_nodes(bytes_form, binary_form)