示例#1
0
def test_assumptions():
    """Check `compare_nodes` on trees parsed from equal and unequal source."""
    baseline_source = "foo = 42"

    # Two independent parses of the same source must compare equal.
    first_parse = sri_ast.parse_to_ast(baseline_source)
    second_parse = sri_ast.parse_to_ast(baseline_source)
    assert sri_ast.compare_nodes(first_parse, second_parse)

    # Independent parses of different source must compare not-equal.
    first_parse = sri_ast.parse_to_ast(baseline_source)
    differing_parse = sri_ast.parse_to_ast("bar = 666")
    assert not sri_ast.compare_nodes(first_parse, differing_parse)
示例#2
0
def test_replace_builtin_constant_no(source):
    """Check that `replace_builtin_constants` leaves the tree for `source` as is."""
    reference_tree = sri_ast.parse_to_ast(source)
    candidate_tree = sri_ast.parse_to_ast(source)

    folding.replace_builtin_constants(candidate_tree)

    # The processed tree must still match the untouched parse.
    assert sri_ast.compare_nodes(reference_tree, candidate_tree)
示例#3
0
def test_replace_constant_no(source):
    """Check that replacing `FOO` in `source` does not alter the tree."""
    reference_tree = sri_ast.parse_to_ast(source)
    candidate_tree = sri_ast.parse_to_ast(source)

    folding.replace_constant(
        candidate_tree,
        "FOO",
        sri_ast.Int(value=31337),
        True,
    )

    # The processed tree must still match the untouched parse.
    assert sri_ast.compare_nodes(reference_tree, candidate_tree)
示例#4
0
def test_compare_different_node_clases():
    """Check that an assignment's target and value nodes compare not-equal."""
    module = sri_ast.parse_to_ast("foo = 42")
    left = module.body[0].target
    right = module.body[0].value

    # Both the `!=` operator and `compare_nodes` must report a difference.
    assert left != right
    assert not sri_ast.compare_nodes(left, right)
示例#5
0
def test_replace_builtins(source, original, result):
    """Check that `replace_builtin_functions` turns `original` into `result`.

    `source` is a format template; it is filled once with `original` (the tree
    that gets processed) and once with `result` (the tree expected afterwards).
    """
    processed_tree = sri_ast.parse_to_ast(source.format(original))
    wanted_tree = sri_ast.parse_to_ast(source.format(result))

    folding.replace_builtin_functions(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#6
0
def test_replace_subscripts_simple():
    """Check that indexing a list literal is replaced by the selected element."""
    processed_tree = sri_ast.parse_to_ast("[foo, bar, baz][1]")
    wanted_tree = sri_ast.parse_to_ast("bar")

    folding.replace_subscripts(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#7
0
def test_compare_complex_nodes_same_value():
    """Check that two distinct but identical dict literals compare equal via
    `compare_nodes` while still being unequal under `!=`."""
    module = sri_ast.parse_to_ast(
        "[{'foo':'bar', 43:[1,2,3]}, {'foo':'bar', 43:[1,2,3]}]")
    elements = module.body[0].value.elts
    left, right = elements

    assert left != right
    assert sri_ast.compare_nodes(left, right)
示例#8
0
def test_replace_subscripts_nested():
    """Check that chained subscripts on nested list literals are replaced."""
    processed_tree = sri_ast.parse_to_ast("[[0, 1], [2, 3], [4, 5]][2][1]")
    wanted_tree = sri_ast.parse_to_ast("5")

    folding.replace_subscripts(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#9
0
def test_integration():
    """Check that `fold` reduces a subscripted list of arithmetic to `3`."""
    processed_tree = sri_ast.parse_to_ast("[1+2, 6+7][8-8]")
    wanted_tree = sri_ast.parse_to_ast("3")

    folding.fold(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#10
0
def test_replace_binop_nested():
    """Check that nested literal arithmetic is replaced by its result, `44`."""
    processed_tree = sri_ast.parse_to_ast("((6 + (2**4)) * 4) / 2")
    wanted_tree = sri_ast.parse_to_ast("44")

    folding.replace_literal_ops(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#11
0
def test_replace_binop_simple():
    """Check that `1 + 2` is replaced by the literal `3`."""
    processed_tree = sri_ast.parse_to_ast("1 + 2")
    wanted_tree = sri_ast.parse_to_ast("3")

    folding.replace_literal_ops(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#12
0
def test_replace_literal_ops():
    """Check that `not`, `and` and `or` on boolean literals are replaced."""
    processed_tree = sri_ast.parse_to_ast(
        "[not True, True and False, True or False]")
    wanted_tree = sri_ast.parse_to_ast("[False, False, True]")

    folding.replace_literal_ops(processed_tree)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#13
0
def test_replace_userdefined_constant_no(source):
    """Check that `replace_user_defined_constants` leaves `source` unchanged.

    A `FOO` constant declaration is prepended to `source` before parsing.
    """
    full_source = f"FOO: constant(int128) = 42\n{source}"

    reference_tree = sri_ast.parse_to_ast(full_source)
    candidate_tree = sri_ast.parse_to_ast(full_source)

    folding.replace_user_defined_constants(candidate_tree)

    # The processed tree must still match the untouched parse.
    assert sri_ast.compare_nodes(reference_tree, candidate_tree)
示例#14
0
def test_list_replacement_similar_nodes():
    """Check that `replace_in_tree` swaps only the chosen element of a list
    whose elements all have the same value."""
    processed_tree = sri_ast.parse_to_ast("foo = [1, 1, 1, 1, 1]")
    wanted_tree = sri_ast.parse_to_ast("foo = [1, 1, 31337, 1, 1]")

    # Replace the middle element (index 2) with a freshly parsed literal.
    outgoing = processed_tree.body[0].value.elts[2]
    incoming = sri_ast.parse_to_ast("31337").body[0].value
    processed_tree.replace_in_tree(outgoing, incoming)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#15
0
def test_simple_replacement():
    """Check that `replace_in_tree` can swap the target of an assignment."""
    processed_tree = sri_ast.parse_to_ast("foo = 42")
    wanted_tree = sri_ast.parse_to_ast("bar = 42")

    # Replace the assignment target `foo` with a freshly parsed `bar` node.
    outgoing = processed_tree.body[0].target
    incoming = sri_ast.parse_to_ast("bar").body[0].value
    processed_tree.replace_in_tree(outgoing, incoming)

    assert sri_ast.compare_nodes(processed_tree, wanted_tree)
示例#16
0
    def parse_for(self):
        """Translate the `for` statement held in `self.stmt` into a `repeat` node.

        List iteration (as decided by `self._is_list_iter()`) is delegated to
        `self.parse_for_list()`. Otherwise the iterable must be a call to
        `range` with one or two arguments, handled in three shapes:

        * `range(n)`: start is 0, `n` is the number of rounds.
        * `range(a, b)` where `b` passes `_check_valid_range_constant`:
          start is `a`, the number of rounds is `b - a`.
        * `range(x, x + n)`: start is the parsed expression `x`, `n` is the
          number of rounds; both occurrences of `x` must compare equal.

        Returns:
            The `LLLnode` built from
            `['repeat', <loop variable position>, start, rounds, <parsed body>]`,
            or whatever `parse_for_list()` returns for list iteration.

        Raises:
            StructureException: if the iterable is not a call, is not a call to
                `range`, has an argument count other than 1 or 2, does not
                match one of the accepted two-argument forms, or yields fewer
                than one iteration.
        """
        # Type 0 for, e.g. for i in list(): ...
        if self._is_list_iter():
            return self.parse_for_list()

        # Everything below only accepts `range(...)` calls; reject other
        # iterables with an error that names what was found.
        if not isinstance(self.stmt.iter, sri_ast.Call):
            if isinstance(self.stmt.iter, sri_ast.Subscript):
                raise StructureException("Cannot iterate over a nested list",
                                         self.stmt.iter)
            raise StructureException(
                f"Cannot iterate over '{type(self.stmt.iter).__name__}' object",
                self.stmt.iter)
        # `getattr` with a default covers callees that have no `id` attribute.
        if getattr(self.stmt.iter.func, 'id', None) != "range":
            raise StructureException(
                "Non-literals cannot be used as loop range",
                self.stmt.iter.func)
        if len(self.stmt.iter.args) not in {1, 2}:
            raise StructureException(
                f"Range expects between 1 and 2 arguments, got {len(self.stmt.iter.args)}",
                self.stmt.iter.func)

        # The block scope is keyed on the identity of the statement node.
        block_scope_id = id(self.stmt)
        with self.context.make_blockscope(block_scope_id):
            # Get arg0
            arg0 = self.stmt.iter.args[0]
            num_of_args = len(self.stmt.iter.args)

            # Type 1 for, e.g. for i in range(10): ...
            if num_of_args == 1:
                arg0_val = self._get_range_const_value(arg0)
                start = LLLnode.from_list(0,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = arg0_val

            # Type 2 for, e.g. for i in range(100, 110): ...
            # Taken when the second argument passes the range-constant check
            # (first element of the returned value is truthy).
            elif self._check_valid_range_constant(self.stmt.iter.args[1],
                                                  raise_exception=False)[0]:
                arg0_val = self._get_range_const_value(arg0)
                arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
                start = LLLnode.from_list(arg0_val,
                                          typ='int128',
                                          pos=getpos(self.stmt))
                rounds = LLLnode.from_list(arg1_val - arg0_val,
                                           typ='int128',
                                           pos=getpos(self.stmt))

            # Type 3 for, e.g. for i in range(x, x + 10): ...
            else:
                arg1 = self.stmt.iter.args[1]
                # The second argument must be an addition ...
                if not isinstance(arg1, sri_ast.BinOp) or not isinstance(
                        arg1.op, sri_ast.Add):
                    raise StructureException(
                        ("Two-arg for statements must be of the form `for i "
                         "in range(start, start + rounds): ...`"),
                        arg1,
                    )

                # ... whose left operand matches the first argument.
                if not sri_ast.compare_nodes(arg0, arg1.left):
                    raise StructureException(
                        ("Two-arg for statements of the form `for i in "
                         "range(x, x + y): ...` must have x identical in both "
                         f"places: {sri_ast.ast_to_dict(arg0)} {sri_ast.ast_to_dict(arg1.left)}"
                         ),
                        self.stmt.iter,
                    )

                rounds = self._get_range_const_value(arg1.right)
                start = Expr.parse_value_expr(arg0, self.context)

            # `rounds` is either a plain int or an object exposing `.value`
            # (the Type 2 branch wraps it in an LLLnode); normalise it to a
            # number for the check below.
            r = rounds if isinstance(rounds, int) else rounds.value
            if r < 1:
                raise StructureException(
                    f"For loop has invalid number of iterations ({r}),"
                    " the value must be greater than zero", self.stmt.iter)

            # Register the loop variable as an int128 and flag it in
            # `forvars` before the body is parsed.
            varname = self.stmt.target.id
            pos = self.context.new_variable(varname,
                                            BaseType('int128'),
                                            pos=getpos(self.stmt))
            self.context.forvars[varname] = True
            o = LLLnode.from_list(
                [
                    'repeat', pos, start, rounds,
                    parse_body(self.stmt.body, self.context)
                ],
                typ=None,
                pos=getpos(self.stmt),
            )
            # Drop the loop variable from the context again now that the body
            # has been parsed.
            del self.context.vars[varname]
            del self.context.forvars[varname]

        return o
示例#17
0
def test_compare_nodes():
    """Check that a parsed tree and an `Int` node derived from it differ."""
    parsed_tree = sri_ast.parse_to_ast("foo = 42")
    derived_int = sri_ast.Int.from_node(parsed_tree, value=666)

    assert not sri_ast.compare_nodes(parsed_tree, derived_int)
示例#18
0
def test_compare_same_node():
    """Check that a node is equal to itself under `==` and `compare_nodes`."""
    module = sri_ast.parse_to_ast("42")
    node = module.body[0].value

    assert node == node
    assert sri_ast.compare_nodes(node, node)
示例#19
0
def test_compare_different_nodes_same_value():
    """Check that two distinct `1` literals are unequal under `!=` but equal
    according to `compare_nodes`."""
    module = sri_ast.parse_to_ast("[1, 1]")
    elements = module.body[0].value.elts
    left, right = elements

    assert left != right
    assert sri_ast.compare_nodes(left, right)
示例#20
0
def test_binary_becomes_bytes():
    """Check that a binary literal assigned to `bytes[1]` parses to the same
    tree as the equivalent bytes literal."""
    # The `\x01` escape is resolved by this Python string literal before the
    # source text reaches the parser.
    from_bytes_literal = sri_ast.parse_to_ast("foo: bytes[1] = b'\x01'")
    from_binary_literal = sri_ast.parse_to_ast("foo: bytes[1] = 0b00000001")

    assert sri_ast.compare_nodes(from_bytes_literal, from_binary_literal)