예제 #1
0
def test_assumptions():
    """Sanity-check the AST comparison helper that the other tests rely on."""
    # Two ASTs parsed separately from identical source must compare equal.
    shared_source = "foo = 42"
    assert vy_ast.compare_nodes(
        vy_ast.parse_to_ast(shared_source), vy_ast.parse_to_ast(shared_source)
    )

    # Two ASTs parsed separately from differing source must compare not-equal.
    lhs, rhs = vy_ast.parse_to_ast("foo = 42"), vy_ast.parse_to_ast("bar = 666")
    assert not vy_ast.compare_nodes(lhs, rhs)
예제 #2
0
파일: local.py 프로젝트: vyperlang/vyper
def _check_iterator_modification(
        target_node: vy_ast.VyperNode,
        search_node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]:
    """
    Find a place under ``search_node`` where ``target_node`` is modified.

    Every descendant of ``search_node`` with the same class as ``target_node``
    that ``vy_ast.compare_nodes`` deems equal is a candidate. A candidate
    counts as modified when either:

    * it sits inside the ``target`` of its nearest enclosing ``Assign`` or
      ``AugAssign`` (e.g. ``x = ...``, ``x += ...``, ``x[i] = ...``), or
    * it sits inside the ``value`` of its nearest enclosing ``Attribute`` node
      whose ``attr`` is ``append``, ``pop`` or ``extend``
      (e.g. ``x.append(y)``, ``x[i].pop()``).

    Returns:
        The first such candidate, or ``None`` if no modification is found.
    """
    matches = [
        candidate
        for candidate in search_node.get_descendants(type(target_node))
        if vy_ast.compare_nodes(target_node, candidate)
    ]

    for match in matches:
        # Assignment target. Searching the target's descendants (self included)
        # also catches writes through a subscript such as `self.my_array[i] = x`.
        statement = match.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
        if statement and match in statement.target.get_descendants(include_self=True):
            return match

        # In-place list mutation through a method attribute. Searching the
        # attribute owner's descendants also catches `self.my_array[i].append(x)`.
        attribute = match.get_ancestor(vy_ast.Attribute)
        if attribute is None:
            continue
        if match not in attribute.value.get_descendants(include_self=True):
            continue
        if attribute.attr in ("append", "pop", "extend"):
            return match

    return None
예제 #3
0
def test_replace_literal_ops():
    """Boolean operations on literals fold down to their literal results."""
    folded = vy_ast.parse_to_ast("[not True, True and False, True or False]")
    reference = vy_ast.parse_to_ast("[False, False, True]")

    folding.replace_literal_ops(folded)
    assert vy_ast.compare_nodes(folded, reference)
예제 #4
0
def test_replace_builtins(source, original, result):
    """Folding builtin calls in ``original`` yields the AST of ``result``."""
    before, after = (
        vy_ast.parse_to_ast(source.format(snippet)) for snippet in (original, result)
    )
    folding.replace_builtin_functions(before)
    assert vy_ast.compare_nodes(before, after)
예제 #5
0
def test_replace_binop_simple():
    """A single literal addition folds to its sum."""
    actual = vy_ast.parse_to_ast("1 + 2")
    folding.replace_literal_ops(actual)
    assert vy_ast.compare_nodes(actual, vy_ast.parse_to_ast("3"))
예제 #6
0
def test_replace_binop_nested():
    """Nested literal arithmetic folds all the way down to a single value."""
    actual = vy_ast.parse_to_ast("((6 + (2**4)) * 4) / 2")
    folding.replace_literal_ops(actual)
    assert vy_ast.compare_nodes(actual, vy_ast.parse_to_ast("44"))
예제 #7
0
def test_replace_subscripts_simple():
    """Indexing a literal list with a literal index folds to the selected element."""
    actual = vy_ast.parse_to_ast("[foo, bar, baz][1]")
    folding.replace_subscripts(actual)
    assert vy_ast.compare_nodes(actual, vy_ast.parse_to_ast("bar"))
예제 #8
0
def test_replace_builtin_constant_no(source):
    """Folding builtin constants must leave ``source`` unchanged."""
    pristine, candidate = vy_ast.parse_to_ast(source), vy_ast.parse_to_ast(source)
    folding.replace_builtin_constants(candidate)
    assert vy_ast.compare_nodes(pristine, candidate)
예제 #9
0
def test_replace_subscripts_nested():
    """Chained literal subscripts on nested literal lists fold to one element."""
    actual = vy_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]")
    folding.replace_subscripts(actual)
    assert vy_ast.compare_nodes(actual, vy_ast.parse_to_ast("5"))
예제 #10
0
def test_replace_constant_no(source):
    """Attempting to replace ``FOO`` must leave ``source`` unchanged."""
    pristine, candidate = vy_ast.parse_to_ast(source), vy_ast.parse_to_ast(source)
    replacement = vy_ast.Int(value=31337)
    folding.replace_constant(candidate, "FOO", replacement, True)
    assert vy_ast.compare_nodes(pristine, candidate)
예제 #11
0
def test_integration():
    """The full folding pass combines arithmetic and subscript folding."""
    actual = vy_ast.parse_to_ast("[1+2, 6+7][8-8]")
    folding.fold(actual)
    assert vy_ast.compare_nodes(actual, vy_ast.parse_to_ast("3"))
예제 #12
0
def test_compare_different_node_clases():
    """An assignment's target and value are nodes of different classes and never compare equal."""
    statement = vy_ast.parse_to_ast("foo = 42").body[0]
    target_node, value_node = statement.target, statement.value

    assert target_node != value_node
    assert not vy_ast.compare_nodes(target_node, value_node)
예제 #13
0
def test_compare_complex_nodes_same_value():
    """Two distinct but structurally identical dict literals compare equal by value."""
    module = vy_ast.parse_to_ast(
        "[{'foo':'bar', 43:[1,2,3]}, {'foo':'bar', 43:[1,2,3]}]")
    first, second = module.body[0].value.elts

    assert first != second
    assert vy_ast.compare_nodes(first, second)
예제 #14
0
def test_replace_userdefined_constant_no(source):
    """With ``FOO`` declared, folding user-defined constants must leave ``source`` unchanged."""
    full_source = f"FOO: constant(int128) = 42\n{source}"

    baseline = vy_ast.parse_to_ast(full_source)
    candidate = vy_ast.parse_to_ast(full_source)
    folding.replace_user_defined_constants(candidate)

    assert vy_ast.compare_nodes(baseline, candidate)
예제 #15
0
def test_simple_replacement():
    """Replacing the assignment target ``foo`` with a parsed ``bar`` yields ``bar = 42``."""
    tree = vy_ast.parse_to_ast("foo = 42")
    expected = vy_ast.parse_to_ast("bar = 42")

    replacement = vy_ast.parse_to_ast("bar").body[0].value
    tree.replace_in_tree(tree.body[0].target, replacement)

    assert vy_ast.compare_nodes(tree, expected)
예제 #16
0
def test_list_replacement_similar_nodes():
    """Only the chosen element is replaced, even though all elements compare equal."""
    tree = vy_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]")
    expected = vy_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]")

    third_element = tree.body[0].value.elements[2]
    tree.replace_in_tree(third_element, vy_ast.parse_to_ast("31337").body[0].value)

    assert vy_ast.compare_nodes(tree, expected)
예제 #17
0
def test_replace_userdefined_attribute(source):
    """After folding the ``ADDR`` constant, ``source[0]`` must match the AST of ``source[1]``."""
    preamble = f"ADDR: constant(address) = {dummy_address}"
    before_src, after_src = (f"{preamble}\n{source[i]}" for i in (0, 1))

    folded_ast = vy_ast.parse_to_ast(before_src)
    folding.replace_user_defined_constants(folded_ast)

    assert vy_ast.compare_nodes(folded_ast, vy_ast.parse_to_ast(after_src))
예제 #18
0
def _check_iterator_assign(
    target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode
) -> Optional[vy_ast.VyperNode]:
    """
    Find an assignment to ``target_node`` somewhere under ``search_node``.

    Candidates are descendants of ``search_node`` with the same class as
    ``target_node`` that ``vy_ast.compare_nodes`` deems equal. A candidate is
    reported when it appears within the ``target`` (itself or any descendant)
    of its nearest enclosing ``Assign`` or ``AugAssign`` statement.

    Returns:
        The first such candidate, or ``None`` if there is none.
    """
    equivalents = [
        candidate
        for candidate in search_node.get_descendants(type(target_node))
        if vy_ast.compare_nodes(target_node, candidate)
    ]

    for candidate in equivalents:
        statement = candidate.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
        if not statement:
            continue
        # Descendant search also catches subscripted writes like `x[i] = y`.
        if candidate in statement.target.get_descendants(include_self=True):
            return candidate

    return None
예제 #19
0
def test_replace_userdefined_struct(source):
    """After folding struct constant ``FOO``, ``source[0]`` must match the AST of ``source[1]``."""
    header = """
struct Foo:
    a: uint256
    b: uint256

FOO: constant(Foo) = Foo({a: 123, b: 456})
    """
    before_src, after_src = (f"{header}\n{source[i]}" for i in (0, 1))

    folded_ast = vy_ast.parse_to_ast(before_src)
    folding.replace_user_defined_constants(folded_ast)

    assert vy_ast.compare_nodes(folded_ast, vy_ast.parse_to_ast(after_src))
예제 #20
0
def test_replace_userdefined_nested_struct(source):
    """After folding nested-struct constant ``FOO``, ``source[0]`` must match the AST of ``source[1]``."""
    header = """
struct Bar:
    b1: uint256
    b2: uint256

struct Foo:
    f1: Bar
    f2: uint256

FOO: constant(Foo) = Foo({f1: Bar({b1: 123, b2: 456}), f2: 789})
    """
    before_src, after_src = (f"{header}\n{source[i]}" for i in (0, 1))

    folded_ast = vy_ast.parse_to_ast(before_src)
    folding.replace_user_defined_constants(folded_ast)

    assert vy_ast.compare_nodes(folded_ast, vy_ast.parse_to_ast(after_src))
예제 #21
0
def test_binary_becomes_bytes():
    """A binary literal assigned to ``Bytes[1]`` parses the same as the bytes literal."""
    from_bytes_literal, from_binary_literal = (
        vy_ast.parse_to_ast(src)
        for src in ("foo: Bytes[1] = b'\x01'", "foo: Bytes[1] = 0b00000001")
    )
    assert vy_ast.compare_nodes(from_bytes_literal, from_binary_literal)
예제 #22
0
def test_compare_nodes():
    """A node built with ``Int.from_node`` does not compare equal to the node it came from."""
    module_node = vy_ast.parse_to_ast("foo = 42")
    assert not vy_ast.compare_nodes(
        module_node, vy_ast.Int.from_node(module_node, value=666)
    )
예제 #23
0
def test_compare_different_nodes_same_value():
    """Two separate nodes holding the same literal are not ``==``, yet compare_nodes matches them."""
    first, second = vy_ast.parse_to_ast("[1, 1]").body[0].value.elts

    assert first != second
    assert vy_ast.compare_nodes(first, second)
예제 #24
0
def test_compare_same_node():
    """A node equals itself under both ``==`` and compare_nodes."""
    literal = vy_ast.parse_to_ast("42").body[0].value

    assert literal == literal
    assert vy_ast.compare_nodes(literal, literal)
예제 #25
0
    def visit_For(self, node):
        """
        Validate a ``for`` statement and type-check its body.

        The iterable may be:

        * ``range(N)`` -- ``N`` must be a positive numeric literal;
        * ``range(A, B)`` with literal ``A`` -- ``B`` must be an integer
          literal greater than ``A``;
        * ``range(x, x + N)`` -- ``N`` must be an integer literal >= 1;
        * any variable or literal list (but not a nested list).

        The loop body is type-checked once per candidate type of the iterator
        variable, each time in a fresh scope with the variable marked
        immutable. The first candidate type that checks cleanly is accepted.

        Raises:
            StructureException, IteratorException, StateAccessViolation,
            InvalidLiteral, InvalidType: if the iterable is malformed.
            ImmutableViolation: if the loop body assigns to the iterated value,
                or calls a ``self`` function that does (directly or through
                its recursive calls).
            TypeMismatch: if the body fails type checking for every candidate
                iterator type.
        """
        if isinstance(node.iter, vy_ast.Subscript):
            raise StructureException("Cannot iterate over a nested list", node.iter)

        if isinstance(node.iter, vy_ast.Call):
            # iteration via range()
            if node.iter.get("func.id") != "range":
                raise IteratorException(
                    "Cannot iterate over the result of a function call", node.iter
                )
            validate_call_args(node.iter, (1, 2))

            args = node.iter.args
            if len(args) == 1:
                # range(CONSTANT): the bound must be a positive numeric literal
                if not isinstance(args[0], vy_ast.Num):
                    raise StateAccessViolation("Value must be a literal", node)
                if args[0].value <= 0:
                    raise StructureException("For loop must have at least 1 iteration", args[0])
                validate_expected_type(args[0], Uint256Definition())
                type_list = get_possible_types_from_node(args[0])
            else:
                validate_expected_type(args[0], IntegerAbstractType())
                type_list = get_common_types(*args)
                if not isinstance(args[0], vy_ast.Constant):
                    # range(x, x + CONSTANT): the second argument must be the
                    # first argument plus a positive integer literal
                    if not isinstance(args[1], vy_ast.BinOp) or not isinstance(
                        args[1].op, vy_ast.Add
                    ):
                        raise StructureException(
                            "Second element must be the first element plus a literal value",
                            args[0],
                        )
                    if not vy_ast.compare_nodes(args[0], args[1].left):
                        raise StructureException(
                            "First and second variable must be the same", args[1].left
                        )
                    if not isinstance(args[1].right, vy_ast.Int):
                        raise InvalidLiteral("Literal must be an integer", args[1].right)
                    if args[1].right.value < 1:
                        raise StructureException(
                            f"For loop has invalid number of iterations ({args[1].right.value}),"
                            " the value must be greater than zero",
                            args[1].right,
                        )
                else:
                    # range(CONSTANT, CONSTANT): the end must be a larger integer literal
                    if not isinstance(args[1], vy_ast.Int):
                        raise InvalidType("Value must be a literal integer", args[1])
                    validate_expected_type(args[1], IntegerAbstractType())
                    if args[0].value >= args[1].value:
                        raise StructureException("Second value must be > first value", args[1])

        else:
            # iteration over a variable or literal list
            type_list = [
                i.value_type
                for i in get_possible_types_from_node(node.iter)
                if isinstance(i, ArrayDefinition)
            ]

        if not type_list:
            raise InvalidType("Not an iterable type", node.iter)

        if next((i for i in type_list if isinstance(i, ArrayDefinition)), False):
            raise StructureException("Cannot iterate over a nested list", node.iter)

        if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
            # check for references to the iterated value within the body of the loop
            assign = _check_iterator_assign(node.iter, node)
            if assign:
                raise ImmutableViolation("Cannot modify array during iteration", assign)

        if node.iter.get("value.id") == "self":
            # check if iterated value may be modified by function calls inside the loop
            iter_name = node.iter.attr
            for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}):
                fn_name = call_node.func.attr

                fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0]
                if _check_iterator_assign(node.iter, fn_node):
                    # check for direct modification
                    raise ImmutableViolation(
                        f"Cannot call '{fn_name}' inside for loop, it potentially "
                        f"modifies iterated storage variable '{iter_name}'",
                        call_node,
                    )

                for name in self.namespace["self"].members[fn_name].recursive_calls:
                    # check for indirect modification
                    fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0]
                    if _check_iterator_assign(node.iter, fn_node):
                        raise ImmutableViolation(
                            f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' "
                            f"which potentially modifies iterated storage variable '{iter_name}'",
                            call_node,
                        )

        for_loop_exceptions = []
        iter_name = node.target.id
        for type_ in type_list:
            # type check the for loop body using each possible type for iterator value
            type_ = copy.deepcopy(type_)
            type_.is_immutable = True

            with self.namespace.enter_scope():
                try:
                    self.namespace[iter_name] = type_
                except VyperException as exc:
                    raise exc.with_annotation(node) from None

                try:
                    for n in node.body:
                        self.visit(n)
                    # the body type-checks with this candidate type: accept it
                    return
                except TypeMismatch as exc:
                    for_loop_exceptions.append(exc)

        if len(set(str(i) for i in for_loop_exceptions)) == 1:
            # if every attempt at type checking raised the same exception
            raise for_loop_exceptions[0]

        # return an aggregate TypeMismatch that shows all possible exceptions
        # depending on which type is used
        types_str = [str(i) for i in type_list]
        # list every candidate type: all but the last comma-separated, then
        # "or <last>" (e.g. "A, B or C"); slicing [:1] would drop middle types
        given_str = f"{', '.join(types_str[:-1])} or {types_str[-1]}"
        raise TypeMismatch(
            f"Iterator value '{iter_name}' may be cast as {given_str}, "
            "but type checking fails with all possible types:",
            node,
            *(
                (f"Casting '{iter_name}' as {type_}: {exc.message}", exc.annotations[0])
                for type_, exc in zip(type_list, for_loop_exceptions)
            ),
        )
예제 #26
0
    def parse_for(self):
        """
        Compile the ``for`` statement in ``self.stmt`` into an LLL ``repeat`` node.

        Supported forms (checked in this order):

        * iteration over a list -- detected by ``self._is_list_iter()`` and
          delegated to ``self.parse_for_list()``;
        * ``range(N)`` -- starts at 0 and runs ``N`` rounds, where ``N`` is
          resolved by ``self._get_range_const_value``;
        * ``range(A, B)`` -- used when the second argument passes
          ``self._check_valid_range_constant``; starts at ``A`` and runs
          ``B - A`` rounds;
        * ``range(x, x + N)`` -- ``x`` is any expression (it must appear
          identically in both arguments); runs ``N`` rounds.

        The loop variable is registered in the context as ``int128`` for the
        duration of the body and removed afterwards.

        Returns:
            The ``repeat`` LLLnode (or the result of ``parse_for_list()``).

        Raises:
            StructureException: if the iterable is not a list or a ``range``
                call, ``range`` gets other than 1 or 2 arguments, a two-arg
                range has an unsupported shape, or the computed number of
                rounds is less than 1. The ``_get_range_const_value`` /
                ``_check_valid_range_constant`` helpers may raise as well.
        """
        # Type 0 for: iteration over a list value (see `_is_list_iter()`)
        if self._is_list_iter():
            return self.parse_for_list()

        # Anything else must be a call to `range`
        if not isinstance(self.stmt.iter, vy_ast.Call):
            if isinstance(self.stmt.iter, vy_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        # Open a block scope keyed on this statement's identity
        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            # (selected when the second argument is a valid range constant;
            # index [0] reads the validity flag from the helper's result)
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                # number of rounds is end - start, wrapped as an LLLnode
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                # the second argument must literally be `<expr> + <constant>`
                if not isinstance(arg1, vy_ast.BinOp) or not isinstance(
                        arg1.op, vy_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # NOTE(review): this message embeds raw AST dicts, which is
                # verbose for end users; consider a friendlier rendering.
                if not vy_ast.compare_nodes(arg0, arg1.left):
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {vy_ast.ast_to_dict(arg0)} {vy_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                # the start value is a runtime expression, not a constant
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is an LLLnode for Type 2; otherwise it is whatever
            # `_get_range_const_value` returned (presumably an int — confirm)
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            # Register the loop variable as int128. `forvars` presumably lets
            # other code reject assignments to it — TODO confirm.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # The loop variable goes out of scope after the body. NOTE(review):
            # not in a `finally`, so if parse_body raises these entries remain.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o