def test_cyclic_assignment(self):
        """Two locals that are only assigned from each other must be rejected.

        ``var`` is assigned from ``var2`` and ``var2`` from ``var``; no field
        is involved. Applying the transformation is expected to raise
        ``AnalysisException``.
        """
        declarations = [
            make_var_decl(name="var"),
            make_var_decl(name="var2"),
        ]
        # Each variable reads the other one, forming a cycle.
        cyclic_assignments = [
            make_assign_to_local_var("var", make_var_acc("var2")),
            make_assign_to_local_var("var2", make_var_acc("var")),
        ]
        stencil = make_stencil(fields=[], statements=declarations + cyclic_assignments)

        with pytest.raises(AnalysisException):
            InferLocalVariableLocationTypeTransformation.apply(stencil)
    def test_incompatible_location(self):
        """Assigning an edge-derived local to a cell-derived local must fail.

        ``var_edge`` is assigned from ``field_edge`` (declared with
        ``sir.LocationType.Edge``) and ``var_cell`` from ``field_cell``
        (declared with ``make_field``'s default location). The final
        assignment ``var_cell = var_edge`` is expected to make the
        transformation raise ``AnalysisException``.
        """
        edge_field = make_field("field_edge", sir.LocationType.Edge)
        cell_field = make_field("field_cell")

        body = [
            make_var_decl(name="var_edge"),
            make_var_decl(name="var_cell"),
            make_assign_to_local_var("var_edge", make_field_acc("field_edge")),
            make_assign_to_local_var("var_cell", make_field_acc("field_cell")),
            # The conflicting statement: var_cell now also reads var_edge.
            make_assign_to_local_var("var_cell", make_var_acc("var_edge")),
        ]
        stencil = make_stencil(fields=[edge_field, cell_field], statements=body)

        with pytest.raises(AnalysisException):
            InferLocalVariableLocationTypeTransformation.apply(stencil)
    def test_simple_assignment(self):
        """A local assigned from a field access gets location type ``Cell``.

        The first ``sir.VarDeclStmt`` found in the transformed stencil is
        checked.

        NOTE(review): another method named ``test_simple_assignment`` appears
        further down in this view; if both belong to the same class, the
        later definition replaces this one -- confirm and rename one of them.
        """
        declare_var = make_var_decl(name="var")
        assign_from_field = make_assign_to_local_var("var", make_field_acc("field"))
        stencil = make_stencil(
            fields=[make_field("field")],
            statements=[declare_var, assign_from_field],
        )

        transformed = InferLocalVariableLocationTypeTransformation.apply(stencil)

        found_decls = FindNodes.by_type(sir.VarDeclStmt, transformed)
        first_decl = found_decls[0]
        assert first_decl.location_type == sir.LocationType.Cell
    def test_simple_assignment(self):
        """A local assigned from a field access gets location type ``Cell``.

        The declaration is located by walking the transformed tree with
        ``iter_tree()`` and keeping ``sir.VarDeclStmt`` nodes; the first match
        is checked.

        NOTE(review): a method with the same name appears earlier in this
        view; if both belong to the same class, this definition replaces the
        earlier one -- confirm and rename one of them.
        """
        declare_var = make_var_decl(name="var")
        assign_from_field = make_assign_to_local_var("var", make_field_acc("field"))
        stencil = make_stencil(
            fields=[make_field("field")],
            statements=[declare_var, assign_from_field],
        )

        transformed = InferLocalVariableLocationTypeTransformation.apply(stencil)

        var_decl_nodes = transformed.iter_tree().if_isinstance(sir.VarDeclStmt)
        first_decl = var_decl_nodes.to_list()[0]
        assert first_decl.location_type == sir.LocationType.Cell
    def test_chain_assignment(self):
        """Location type carries over from a field to a chain of locals.

        ``var`` is assigned from ``field``; ``another_var`` is declared with
        ``var`` as its initializer. Exactly two ``sir.VarDeclStmt`` nodes are
        expected in the result, both with location type ``Cell``.
        """
        body = [
            make_var_decl(name="var"),
            make_assign_to_local_var("var", make_field_acc("field")),
            # Second local is initialized directly from the first one.
            make_var_decl(name="another_var", dtype=float_type, init=make_var_acc("var")),
        ]
        stencil = make_stencil(fields=[make_field("field")], statements=body)

        transformed = InferLocalVariableLocationTypeTransformation.apply(stencil)

        found_decls = FindNodes.by_type(sir.VarDeclStmt, transformed)
        assert len(found_decls) == 2
        assert all(decl.location_type == sir.LocationType.Cell for decl in found_decls)
    def test_reduction(self):
        """A local assigned from a neighbor reduction gets location ``Edge``.

        The reduction uses the chain ``[Edge, Cell]`` with literal ``rhs`` and
        ``init``; the first ``sir.VarDeclStmt`` in the result is expected to
        carry ``sir.LocationType.Edge``.
        """
        declare_var = make_var_decl(name="var")
        edge_to_cell_sum = sir.ReductionOverNeighborExpr(
            op="+",
            rhs=make_literal(),
            init=make_literal(),
            chain=[sir.LocationType.Edge, sir.LocationType.Cell],
        )
        assign_reduction = make_assign_to_local_var("var", edge_to_cell_sum)
        stencil = make_stencil(fields=[], statements=[declare_var, assign_reduction])

        transformed = InferLocalVariableLocationTypeTransformation.apply(stencil)

        found_decls = FindNodes.by_type(sir.VarDeclStmt, transformed)
        first_decl = found_decls[0]
        assert first_decl.location_type == sir.LocationType.Edge
    def test_var_type_not_deducible(self):
        """A local that is declared but never assigned must be rejected.

        The stencil holds a single declaration of ``var`` and no fields;
        applying the transformation is expected to raise
        ``AnalysisException``.
        """
        lone_declaration = make_var_decl(name="var")
        stencil = make_stencil(fields=[], statements=[lone_declaration])

        with pytest.raises(AnalysisException):
            InferLocalVariableLocationTypeTransformation.apply(stencil)
# Example #8
# 0
#             ),
#         ]
#     )
# Build the statement `pnabla_MYY = pnabla_MYY / vol`. Every field access uses
# vertical offset 0 and a zero horizontal offset.
pnabla_MYY_target = sir.FieldAccessExpr(
    name="pnabla_MYY", vertical_offset=0, horizontal_offset=sir.ZeroOffset()
)
pnabla_MYY_over_vol = sir.BinaryOperator(
    left=sir.FieldAccessExpr(
        name="pnabla_MYY", vertical_offset=0, horizontal_offset=sir.ZeroOffset()
    ),
    op="/",
    right=sir.FieldAccessExpr(
        name="vol", vertical_offset=0, horizontal_offset=sir.ZeroOffset()
    ),
)
assign_pnabla_MYY_vol = sir.ExprStmt(
    expr=sir.AssignmentExpr(left=pnabla_MYY_target, op="=", right=pnabla_MYY_over_vol)
)
statements.append(assign_pnabla_MYY_vol)

# NOTE(review): `block` is not referenced again in the visible code; the AST
# root below is a second BlockStmt built from the same `statements` list.
block = sir.BlockStmt(statements=statements)
ast = sir.AST(root=sir.BlockStmt(statements=statements))

# Wrap the statement AST in a forward-ordered vertical region, then put that
# region's declaration into the control-flow AST of the "nabla" stencil.
vertical_region = sir.VerticalRegion(
    ast=ast, interval=sir.Interval(), loop_order=common.LoopOrder.FORWARD
)
vert_decl_stmt = sir.VerticalRegionDeclStmt(vertical_region=vertical_region)
ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt]))
stencil = sir.Stencil(name="nabla", ast=ctrl_flow_ast, params=fields)

# Infer local-variable location types, run the SirToNaive visitor on the
# result, and print what the naive code generator produces.
var_loc_type_inferred = InferLocalVariableLocationTypeTransformation.apply(stencil)
naive_ir = sir_to_naive.SirToNaive().visit(var_loc_type_inferred)
generated_code = naive_codegen.NaiveCodeGenerator.apply(naive_ir)
print(generated_code)