def test_assumptions():
    """compare_nodes must depend only on source content, not on which parse produced it."""
    # two separate parses of the same source should compare equal
    first = vy_ast.parse_to_ast("foo = 42")
    second = vy_ast.parse_to_ast("foo = 42")
    assert vy_ast.compare_nodes(first, second)

    # two separate parses of different sources should compare not-equal
    first = vy_ast.parse_to_ast("foo = 42")
    second = vy_ast.parse_to_ast("bar = 666")
    assert not vy_ast.compare_nodes(first, second)
def _check_iterator_modification(
    target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode
) -> Optional[vy_ast.VyperNode]:
    """Find a node under ``search_node`` that may modify ``target_node``.

    Only descendants of ``search_node`` that have the same type as
    ``target_node`` and that ``vy_ast.compare_nodes`` reports as equal to it
    are considered. Of those, the first one returned is one that is either:

    * inside the ``target`` of an enclosing ``Assign`` / ``AugAssign``, or
    * inside the ``value`` of its nearest enclosing ``Attribute`` whose
      ``attr`` is ``append``, ``pop`` or ``extend``.

    Returns ``None`` when no such node is found.
    """
    mutating_attrs = ("append", "pop", "extend")

    for candidate in search_node.get_descendants(type(target_node)):
        if not vy_ast.compare_nodes(target_node, candidate):
            continue

        # searching the whole target subtree also catches indexed writes such
        # as ``self.my_array[i] = x``
        enclosing_assign = candidate.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
        if enclosing_assign and candidate in enclosing_assign.target.get_descendants(
            include_self=True
        ):
            return candidate

        # searching the whole value subtree also catches calls on an element,
        # e.g. ``self.my_array[i].append(x)``
        enclosing_attr = candidate.get_ancestor(vy_ast.Attribute)
        if (
            enclosing_attr is not None
            and candidate in enclosing_attr.value.get_descendants(include_self=True)
            and enclosing_attr.attr in mutating_attrs
        ):
            return candidate

    return None
def test_replace_literal_ops():
    """Boolean operations on literals inside a list are folded to constants."""
    folded = vy_ast.parse_to_ast("[not True, True and False, True or False]")
    wanted = vy_ast.parse_to_ast("[False, False, True]")

    folding.replace_literal_ops(folded)

    assert vy_ast.compare_nodes(folded, wanted)
def test_replace_builtins(source, original, result):
    """Folding builtins turns ``source`` filled with ``original`` into ``source`` filled with ``result``."""
    before_ast, after_ast = (vy_ast.parse_to_ast(source.format(fill)) for fill in (original, result))

    folding.replace_builtin_functions(before_ast)

    assert vy_ast.compare_nodes(before_ast, after_ast)
def test_replace_binop_simple():
    """Adding two integer literals folds to their sum."""
    sum_expr, sum_value = vy_ast.parse_to_ast("1 + 2"), vy_ast.parse_to_ast("3")
    folding.replace_literal_ops(sum_expr)
    assert vy_ast.compare_nodes(sum_expr, sum_value)
def test_replace_binop_nested():
    """Nested literal arithmetic folds down to a single integer.

    ((6 + 2**4) * 4) / 2 == (22 * 4) / 2 == 44
    """
    nested = vy_ast.parse_to_ast("((6 + (2**4)) * 4) / 2")
    flat = vy_ast.parse_to_ast("44")
    folding.replace_literal_ops(nested)
    assert vy_ast.compare_nodes(nested, flat)
def test_replace_subscripts_simple():
    """Indexing a literal list with a literal index folds to the selected element."""
    indexed = vy_ast.parse_to_ast("[foo, bar, baz][1]")
    element = vy_ast.parse_to_ast("bar")
    folding.replace_subscripts(indexed)
    assert vy_ast.compare_nodes(indexed, element)
def test_replace_builtin_constant_no(source):
    """Folding builtin constants must leave the parametrized ``source`` AST untouched."""
    reference, candidate = (vy_ast.parse_to_ast(source) for _ in range(2))
    folding.replace_builtin_constants(candidate)
    assert vy_ast.compare_nodes(reference, candidate)
def test_replace_subscripts_nested():
    """Chained subscripts on nested literal lists fold to the innermost element."""
    chained = vy_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]")
    innermost = vy_ast.parse_to_ast("5")
    folding.replace_subscripts(chained)
    assert vy_ast.compare_nodes(chained, innermost)
def test_replace_constant_no(source):
    """Asking to replace ``FOO`` with ``31337`` must leave the parametrized ``source`` unchanged.

    NOTE(review): the trailing ``True`` argument's meaning is defined by
    ``folding.replace_constant`` (not visible here) - confirm there.
    """
    reference, candidate = (vy_ast.parse_to_ast(source) for _ in range(2))
    folding.replace_constant(candidate, "FOO", vy_ast.Int(value=31337), True)
    assert vy_ast.compare_nodes(reference, candidate)
def test_integration():
    """``fold`` applies arithmetic and subscript folding together: ``[1+2, 6+7][8-8]`` -> ``3``."""
    combined = vy_ast.parse_to_ast("[1+2, 6+7][8-8]")
    answer = vy_ast.parse_to_ast("3")
    folding.fold(combined)
    assert vy_ast.compare_nodes(combined, answer)
def test_compare_different_node_clases():
    """The target and value of ``foo = 42`` are neither identical nor structurally equal."""
    assign_stmt = vy_ast.parse_to_ast("foo = 42").body[0]
    name_node, value_node = assign_stmt.target, assign_stmt.value

    assert name_node != value_node
    assert not vy_ast.compare_nodes(name_node, value_node)
def test_compare_complex_nodes_same_value():
    """Two distinct but identically-written dict literals compare structurally equal."""
    module = vy_ast.parse_to_ast(
        "[{'foo':'bar', 43:[1,2,3]}, {'foo':'bar', 43:[1,2,3]}]"
    )
    first_dict, second_dict = module.body[0].value.elts

    assert first_dict != second_dict
    assert vy_ast.compare_nodes(first_dict, second_dict)
def test_replace_userdefined_constant_no(source):
    """With ``FOO`` declared as a constant, folding must not alter the parametrized ``source``."""
    full_source = "FOO: constant(int128) = 42\n" + source
    reference = vy_ast.parse_to_ast(full_source)
    candidate = vy_ast.parse_to_ast(full_source)

    folding.replace_user_defined_constants(candidate)

    assert vy_ast.compare_nodes(reference, candidate)
def test_simple_replacement():
    """Swapping the assignment target ``foo`` for ``bar`` yields the AST of ``bar = 42``."""
    tree = vy_ast.parse_to_ast("foo = 42")
    wanted = vy_ast.parse_to_ast("bar = 42")
    target = tree.body[0].target
    replacement = vy_ast.parse_to_ast("bar").body[0].value

    tree.replace_in_tree(target, replacement)

    assert vy_ast.compare_nodes(tree, wanted)
def test_list_replacement_similar_nodes():
    """Only the chosen element is replaced, even though its siblings are structurally identical."""
    tree = vy_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]")
    wanted = vy_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]")
    middle = tree.body[0].value.elements[2]
    replacement = vy_ast.parse_to_ast("31337").body[0].value

    tree.replace_in_tree(middle, replacement)

    assert vy_ast.compare_nodes(tree, wanted)
def test_replace_userdefined_attribute(source):
    """After folding, ``source[0]`` (using the ``ADDR`` constant) must match ``source[1]``."""
    preamble = f"ADDR: constant(address) = {dummy_address}"

    folded_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[0]}")
    folding.replace_user_defined_constants(folded_ast)
    expected_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[1]}")

    assert vy_ast.compare_nodes(folded_ast, expected_ast)
def _check_iterator_assign(
    target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode
) -> Optional[vy_ast.VyperNode]:
    """Return the first copy of ``target_node`` under ``search_node`` that is assigned to.

    A "copy" is any descendant of ``search_node`` with the same type as
    ``target_node`` that ``vy_ast.compare_nodes`` reports as equal to it. It
    counts as assigned when it lies anywhere within the ``target`` of an
    enclosing ``Assign`` / ``AugAssign``, so indexed writes also match.
    Returns ``None`` when no such node exists.
    """
    matches = [
        descendant
        for descendant in search_node.get_descendants(type(target_node))
        if vy_ast.compare_nodes(target_node, descendant)
    ]

    def _is_assignment_target(candidate):
        owner = candidate.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
        return bool(owner) and candidate in owner.target.get_descendants(include_self=True)

    return next((m for m in matches if _is_assignment_target(m)), None)
def test_replace_userdefined_struct(source):
    """After folding the struct constant ``FOO``, ``source[0]`` must match ``source[1]``."""
    preamble = """
struct Foo:
    a: uint256
    b: uint256

FOO: constant(Foo) = Foo({a: 123, b: 456})
    """

    folded_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[0]}")
    folding.replace_user_defined_constants(folded_ast)
    expected_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[1]}")

    assert vy_ast.compare_nodes(folded_ast, expected_ast)
def test_replace_userdefined_nested_struct(source):
    """After folding the nested-struct constant ``FOO``, ``source[0]`` must match ``source[1]``."""
    preamble = """
struct Bar:
    b1: uint256
    b2: uint256

struct Foo:
    f1: Bar
    f2: uint256

FOO: constant(Foo) = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789})
    """

    folded_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[0]}")
    folding.replace_user_defined_constants(folded_ast)
    expected_ast = vy_ast.parse_to_ast(f"{preamble}\n{source[1]}")

    assert vy_ast.compare_nodes(folded_ast, expected_ast)
def test_binary_becomes_bytes():
    """A binary literal assigned to ``Bytes[1]`` parses equal to the matching bytes literal."""
    bytes_form = vy_ast.parse_to_ast("foo: Bytes[1] = b'\x01'")
    binary_form = vy_ast.parse_to_ast("foo: Bytes[1] = 0b00000001")
    assert vy_ast.compare_nodes(bytes_form, binary_form)
def test_compare_nodes():
    """A node rebuilt via ``from_node`` with a different value must not compare equal.

    The ``Int`` literal itself is used as the source node. Comparing the
    enclosing ``Module`` against an ``Int`` would pass on the class mismatch
    alone and never exercise value comparison.
    """
    old_node = vy_ast.parse_to_ast("foo = 42").body[0].value
    new_node = vy_ast.Int.from_node(old_node, value=666)

    assert not vy_ast.compare_nodes(old_node, new_node)
    # positive control: rebuilding with the same value must compare equal,
    # proving the negative result above comes from the value, not the class
    assert vy_ast.compare_nodes(old_node, vy_ast.Int.from_node(old_node, value=42))
def test_compare_different_nodes_same_value():
    """Two distinct ``1`` literals are different objects yet structurally equal."""
    elements = vy_ast.parse_to_ast("[1, 1]").body[0].value.elts
    first_one, second_one = elements

    assert first_one != second_one
    assert vy_ast.compare_nodes(first_one, second_one)
def test_compare_same_node():
    """A node is equal to itself both by ``==`` and by ``compare_nodes``."""
    literal = vy_ast.parse_to_ast("42").body[0].value

    assert literal == literal
    assert vy_ast.compare_nodes(literal, literal)
def visit_For(self, node):
    """Validate a ``for`` statement and type-check its body.

    Iterator validation:

    * a ``Subscript`` iterator is rejected as a nested list;
    * a ``Call`` iterator must be ``range`` with 1 or 2 arguments, in one of
      the forms ``range(CONST)``, ``range(x, x + CONST)`` or
      ``range(CONST, CONST)``;
    * any other iterator must have at least one possible ``ArrayDefinition``
      type, and its element types become the candidate loop-variable types.

    Modification checks: if the iterator is a ``Name``/``Attribute``, the loop
    body may not assign to it. If it is ``self.<attr>``, calls to ``self``
    functions in the body that assign to it, directly or via their
    ``recursive_calls``, are also rejected.

    Finally, the body is visited once per candidate type, with an immutable
    deep copy of that type bound to the loop variable in a fresh scope. The
    method returns on the first type that checks cleanly.

    Raises:
        StructureException, IteratorException, StateAccessViolation,
        InvalidLiteral, InvalidType: on a malformed iterator / range.
        ImmutableViolation: if the iterated value may be modified in the loop.
        TypeMismatch: if the body fails to type-check with every candidate
            type. If every attempt failed with the same message, the first
            exception is re-raised; otherwise an aggregate is raised.
    """
    if isinstance(node.iter, vy_ast.Subscript):
        raise StructureException("Cannot iterate over a nested list", node.iter)

    if isinstance(node.iter, vy_ast.Call):
        # iteration via range()
        if node.iter.get("func.id") != "range":
            raise IteratorException(
                "Cannot iterate over the result of a function call", node.iter
            )
        validate_call_args(node.iter, (1, 2))

        args = node.iter.args
        if len(args) == 1:
            # range(CONSTANT)
            if not isinstance(args[0], vy_ast.Num):
                raise StateAccessViolation("Value must be a literal", node)
            if args[0].value <= 0:
                raise StructureException("For loop must have at least 1 iteration", args[0])
            validate_expected_type(args[0], Uint256Definition())
            type_list = get_possible_types_from_node(args[0])
        else:
            validate_expected_type(args[0], IntegerAbstractType())
            type_list = get_common_types(*args)
            if not isinstance(args[0], vy_ast.Constant):
                # range(x, x + CONSTANT)
                if not isinstance(args[1], vy_ast.BinOp) or not isinstance(
                    args[1].op, vy_ast.Add
                ):
                    raise StructureException(
                        "Second element must be the first element plus a literal value",
                        args[0],
                    )
                if not vy_ast.compare_nodes(args[0], args[1].left):
                    raise StructureException(
                        "First and second variable must be the same", args[1].left
                    )
                if not isinstance(args[1].right, vy_ast.Int):
                    raise InvalidLiteral("Literal must be an integer", args[1].right)
                if args[1].right.value < 1:
                    raise StructureException(
                        f"For loop has invalid number of iterations ({args[1].right.value}),"
                        " the value must be greater than zero",
                        args[1].right,
                    )
            else:
                # range(CONSTANT, CONSTANT)
                if not isinstance(args[1], vy_ast.Int):
                    raise InvalidType("Value must be a literal integer", args[1])
                validate_expected_type(args[1], IntegerAbstractType())
                if args[0].value >= args[1].value:
                    raise StructureException("Second value must be > first value", args[1])

    else:
        # iteration over a variable or literal list
        type_list = [
            i.value_type
            for i in get_possible_types_from_node(node.iter)
            if isinstance(i, ArrayDefinition)
        ]

    if not type_list:
        raise InvalidType("Not an iterable type", node.iter)

    # any() tests the isinstance result directly; the previous next(..., False)
    # depended on the truthiness of the matched type object itself
    if any(isinstance(i, ArrayDefinition) for i in type_list):
        raise StructureException("Cannot iterate over a nested list", node.iter)

    if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
        # check for references to the iterated value within the body of the loop
        assign = _check_iterator_assign(node.iter, node)
        if assign:
            raise ImmutableViolation("Cannot modify array during iteration", assign)

    if node.iter.get("value.id") == "self":
        # check if iterated value may be modified by function calls inside the loop
        iter_name = node.iter.attr
        for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}):
            fn_name = call_node.func.attr

            fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0]
            if _check_iterator_assign(node.iter, fn_node):
                # check for direct modification
                raise ImmutableViolation(
                    f"Cannot call '{fn_name}' inside for loop, it potentially "
                    f"modifies iterated storage variable '{iter_name}'",
                    call_node,
                )

            for name in self.namespace["self"].members[fn_name].recursive_calls:
                # check for indirect modification
                fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0]
                if _check_iterator_assign(node.iter, fn_node):
                    raise ImmutableViolation(
                        f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' "
                        f"which potentially modifies iterated storage variable '{iter_name}'",
                        call_node,
                    )

    for_loop_exceptions = []
    iter_name = node.target.id
    for type_ in type_list:
        # type check the for loop body using each possible type for iterator value
        type_ = copy.deepcopy(type_)
        type_.is_immutable = True

        with self.namespace.enter_scope():
            try:
                self.namespace[iter_name] = type_
            except VyperException as exc:
                raise exc.with_annotation(node) from None

            try:
                for n in node.body:
                    self.visit(n)
                return
            except TypeMismatch as exc:
                for_loop_exceptions.append(exc)

    if len(set(str(i) for i in for_loop_exceptions)) == 1:
        # if every attempt at type checking raised the same exception
        raise for_loop_exceptions[0]

    # return an aggregate TypeMismatch that shows all possible exceptions
    # depending on which type is used
    types_str = [str(i) for i in type_list]
    # list every candidate but the last, then "or <last>"; slicing with [:1]
    # here would silently drop all middle candidates when there are 3+ types
    given_str = f"{', '.join(types_str[:-1])} or {types_str[-1]}"
    raise TypeMismatch(
        f"Iterator value '{iter_name}' may be cast as {given_str}, "
        "but type checking fails with all possible types:",
        node,
        *(
            (f"Casting '{iter_name}' as {type_}: {exc.message}", exc.annotations[0])
            for type_, exc in zip(type_list, for_loop_exceptions)
        ),
    )
def parse_for(self):
    """Compile the ``for`` statement in ``self.stmt`` into a ``repeat`` LLL node.

    List iteration is delegated to ``self.parse_for_list()``. Otherwise the
    iterator must be a ``range(...)`` call with one or two arguments, handled
    as one of:

    * Type 1 ``range(N)``: start ``0``, ``N`` rounds.
    * Type 2 ``range(A, B)``: used when the second argument passes
      ``_check_valid_range_constant``; start ``A``, ``B - A`` rounds.
    * Type 3 ``range(x, x + N)``: runtime start ``x``, ``N`` rounds. The same
      expression ``x`` must appear in both places.

    The loop variable is registered in ``self.context`` as ``int128`` while the
    body is compiled, and removed from ``context.vars`` / ``context.forvars``
    afterwards.

    Returns:
        The ``LLLnode`` built from ``['repeat', pos, start, rounds, body]``.

    Raises:
        StructureException: if the iterator is not a call (with a dedicated
            message for subscripts), the call is not ``range``, the argument
            count is not 1 or 2, a Type 3 range is not of the form
            ``x, x + N`` with identical ``x``, or the number of rounds is < 1.
            The ``_get_range_const_value`` / ``_check_valid_range_constant``
            helpers (not visible here) may raise further errors.
    """
    # Type 0 for, e.g. for i in list(): ...
    if self._is_list_iter():
        return self.parse_for_list()

    if not isinstance(self.stmt.iter, vy_ast.Call):
        if isinstance(self.stmt.iter, vy_ast.Subscript):
            raise StructureException("Cannot iterate over a nested list", self.stmt.iter)
        raise StructureException(
            f"Cannot iterate over '{type(self.stmt.iter).__name__}' object", self.stmt.iter)
    if getattr(self.stmt.iter.func, 'id', None) != "range":
        raise StructureException(
            "Non-literals cannot be used as loop range", self.stmt.iter.func)
    if len(self.stmt.iter.args) not in {1, 2}:
        raise StructureException(
            f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
            self.stmt.iter.func)

    # the statement object's id keys the block scope opened for the loop
    block_scope_id = id(self.stmt)
    with self.context.make_blockscope(block_scope_id):
        # Get arg0
        arg0 = self.stmt.iter.args[0]
        num_of_args = len(self.stmt.iter.args)

        # Type 1 for, e.g. for i in range(10): ...
        if num_of_args == 1:
            arg0_val = self._get_range_const_value(arg0)
            start = LLLnode.from_list(0, typ='int128', pos=getpos(self.stmt))
            rounds = arg0_val

        # Type 2 for, e.g. for i in range(100, 110): ...
        # (raise_exception=False: only probe whether arg 1 is a valid constant;
        # presumably element [0] of the result is the success flag - confirm)
        elif self._check_valid_range_constant(self.stmt.iter.args[1], raise_exception=False)[0]:
            arg0_val = self._get_range_const_value(arg0)
            arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
            start = LLLnode.from_list(arg0_val, typ='int128', pos=getpos(self.stmt))
            rounds = LLLnode.from_list(arg1_val - arg0_val, typ='int128', pos=getpos(self.stmt))

        # Type 3 for, e.g. for i in range(x, x + 10): ...
        else:
            arg1 = self.stmt.iter.args[1]
            if not isinstance(arg1, vy_ast.BinOp) or not isinstance(
                    arg1.op, vy_ast.Add):
                raise StructureException(
                    (
                        "Two-arg for statements must be of the form `for i "
                        "in range(start, start + rounds): ...`"
                    ),
                    arg1,
                )
            # structural comparison: both occurrences of ``x`` must be the same expression
            if not vy_ast.compare_nodes(arg0, arg1.left):
                raise StructureException(
                    (
                        "Two-arg for statements of the form `for i in "
                        "range(x, x + y): ...` must have x identical in both "
                        f"places: {vy_ast.ast_to_dict(arg0)} {vy_ast.ast_to_dict(arg1.left)}"
                    ),
                    self.stmt.iter,
                )

            rounds = self._get_range_const_value(arg1.right)
            start = Expr.parse_value_expr(arg0, self.context)

        # Type 2 wraps ``rounds`` in an LLLnode while Types 1/3 keep the helper's
        # return value (presumably a plain int); normalise before the bounds check
        r = rounds if isinstance(rounds, int) else rounds.value
        if r < 1:
            raise StructureException(
                f"For loop has invalid number of iterations ({r}),"
                " the value must be greater than zero",
                self.stmt.iter)

        # register the loop variable before compiling the body so the body can
        # reference it; forvars presumably marks it as a loop variable - confirm
        varname = self.stmt.target.id
        pos = self.context.new_variable(varname, BaseType('int128'), pos=getpos(self.stmt))
        self.context.forvars[varname] = True
        o = LLLnode.from_list(
            ['repeat', pos, start, rounds, parse_body(self.stmt.body, self.context)],
            typ=None,
            pos=getpos(self.stmt),
        )
        # the loop variable goes out of scope once the body has been compiled
        del self.context.vars[varname]
        del self.context.forvars[varname]

    return o