Beispiel #1
0
    def vertical_loop(self,
                      order,
                      body,
                      upper=None,
                      lower=None,
                      var: str = None):
        """Build a vertical-region declaration statement for a loop over levels.

        Args:
            order: Iteration direction key. ``"levels_upward"`` selects
                ``sir.VerticalRegion.Forward`` and ``"levels_downward"``
                selects ``sir.VerticalRegion.Backward``; any other value
                raises ``KeyError`` (after the body has been translated).
            body: Loop body, translated through ``self.statements``.
            upper: Optional upper bound, resolved by
                ``self.vertical_interval_bound``; when ``None`` the region ends
                at ``sir.Interval.End`` with offset 0.
            lower: Optional lower bound, resolved the same way; when ``None``
                the region starts at ``sir.Interval.Start`` with offset 0.
            var: Forwarded to ``self.ctx.vertical_region`` for the duration of
                the body translation.

        Returns:
            Whatever ``make_vertical_region_decl_stmt`` returns.
        """
        # Resolve both bounds to (level, offset) pairs, defaulting to the
        # whole column.
        lower_level, lower_offset = (
            (sir.Interval.Start, 0) if lower is None
            else self.vertical_interval_bound(lower))
        upper_level, upper_offset = (
            (sir.Interval.End, 0) if upper is None
            else self.vertical_interval_bound(upper))

        directions = {
            "levels_upward": sir.VerticalRegion.Forward,
            "levels_downward": sir.VerticalRegion.Backward,
        }
        with self.ctx.vertical_region(var):
            # The body is translated while the vertical-region context is
            # active; the direction lookup happens afterwards.
            body_ast = make_ast(self.statements(body))
            interval = make_interval(lower_level, upper_level,
                                     lower_offset, upper_offset)
            return make_vertical_region_decl_stmt(body_ast, interval,
                                                  directions[order])
Beispiel #2
0
def main(args: argparse.Namespace):
    """Build an unstructured edge reduction stencil, compile it, write the code.

    The stencil assigns to the edge field ``out`` a ``+`` reduction, with
    initial value ``1.0``, of the cell field ``in`` over the cells neighbouring
    each edge. It runs over the full vertical range in forward order and is
    compiled with the ``"c++-naive-ico"`` backend.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` pretty-prints the
            SIR before compilation.
    """
    full_column = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # out = reduce(+, init=1.0) of in over the Edge -> Cell neighbours
    target = sir_utils.make_field_access_expr("out")
    reduction = sir_utils.make_reduction_over_neighbor_expr(
        "+",
        sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
        sir_utils.make_field_access_expr("in"),
        lhs_location=SIR.LocationType.Value("Edge"),
        rhs_location=SIR.LocationType.Value("Cell"),
    )
    body_ast = sir_utils.make_ast(
        [sir_utils.make_assignment_stmt(target, reduction, "=")])

    region = sir_utils.make_vertical_region_decl_stmt(
        body_ast, full_column, SIR.VerticalRegion.Forward)

    grid = SIR.GridType.Value("Unstructured")
    stencil_ast = sir_utils.make_ast([region])
    fields = [
        sir_utils.make_field(
            "in",
            sir_utils.make_field_dimensions_unstructured(
                [SIR.LocationType.Value("Cell")], 1),
        ),
        sir_utils.make_field(
            "out",
            sir_utils.make_field_dimensions_unstructured(
                [SIR.LocationType.Value("Edge")], 1),
        ),
    ]
    stencil = sir_utils.make_stencil(OUTPUT_NAME, stencil_ast, fields)
    sir = sir_utils.make_sir(OUTPUT_FILE, grid, [stencil])

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend="c++-naive-ico")

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
def create_vertical_region_stmt1():
    """Create a forward vertical region restricted to the first level.

    The interval is ``Start`` .. ``Start`` with offsets 0. On that level both
    ``c`` and ``d`` are divided in place by ``b`` (unstructured field
    accesses).
    """
    first_level = serial_utils.make_interval(serial_utils.Interval.Start,
                                             serial_utils.Interval.Start, 0, 0)

    def divide_by_b(field_name):
        # field_name = field_name / b
        return serial_utils.make_assignment_stmt(
            serial_utils.make_unstructured_field_access_expr(field_name),
            serial_utils.make_binary_operator(
                serial_utils.make_unstructured_field_access_expr(field_name),
                "/",
                serial_utils.make_unstructured_field_access_expr("b"),
            ),
            "=",
        )

    body_ast = serial_utils.make_ast([divide_by_b("c"), divide_by_b("d")])
    return serial_utils.make_vertical_region_decl_stmt(
        body_ast, first_level, AST.VerticalRegion.Forward)
def two_copies_mixed():
    """Generate the IIR for a stencil with one cell copy and one edge copy.

    The unstructured stencil ``"generated"`` has a single forward vertical
    region over the full column. That region performs ``out_cell = in_cell``
    and ``out_edge = in_edge``.

    The SIR is compiled with ``backend``, a name not defined in this function
    and resolved from the enclosing module scope, with IIR serialization
    enabled. The serialized ``StageMergerTestTwoCopiesMixed.0.iir`` is then
    moved to ``../input/StageMergerTestTwoCopiesMixed.iir``.
    """
    outputfile = "StageMergerTestTwoCopiesMixed"
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                       0, 0)

    def copy_stmt(dst, src):
        # dst = src
        return sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr(dst),
            sir_utils.make_field_access_expr(src),
            "=",
        )

    def unstructured_field(field_name, location):
        return sir_utils.make_field(
            field_name,
            sir_utils.make_field_dimensions_unstructured(
                [SIR.LocationType.Value(location)], 1),
        )

    body_ast = sir_utils.make_ast([
        copy_stmt("out_cell", "in_cell"),
        copy_stmt("out_edge", "in_edge"),
    ])
    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    stencil = sir_utils.make_stencil(
        "generated",
        sir_utils.make_ast([vertical_region_stmt]),
        [
            unstructured_field("in_cell", "Cell"),
            unstructured_field("out_cell", "Cell"),
            unstructured_field("in_edge", "Edge"),
            unstructured_field("out_edge", "Edge"),
        ],
    )
    sir = sir_utils.make_sir(outputfile, SIR.GridType.Value("Unstructured"),
                             [stencil])

    dawn4py.compile(sir,
                    backend=backend,
                    serialize_iir=True,
                    output_file=outputfile)
    os.rename(f"{outputfile}.0.iir", f"../input/{outputfile}.iir")
Beispiel #5
0
def function_call():
    """Write a SIR JSON test input in which a stencil calls a stencil function.

    The stencil function ``f`` has one unstructured cell-field argument
    ``out`` and assigns it the float literal ``2.0`` over the full column.
    The stencil ``"generated"`` declares the cell field ``out_cell`` and, in
    a single forward vertical region over the full column, evaluates
    ``f(out_cell)`` as an expression statement.

    The SIR is serialised with ``MessageToJson`` to
    ``../input/test_set_stage_location_type_function_call.sir``.
    """
    outputfile = "../input/test_set_stage_location_type_function_call.sir"

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    # Body of f: out = 2.0
    fun_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"),
            sir_utils.make_literal_access_expr(
                value="2.0", type=sir_utils.BuiltinType.Float),
            "=",
        ),
    ])

    arg_field = sir_utils.make_field(
        "out",
        sir_utils.make_field_dimensions_unstructured(
            [SIR.LocationType.Value("Cell")], 1))

    fun = sir_utils.make_stencil_function(
        name='f',
        asts=[fun_ast],
        intervals=[interval],
        arguments=[sir_utils.make_stencil_function_arg(arg_field)])

    # Stencil body: f(out_cell)
    body_ast = sir_utils.make_ast([
        sir_utils.make_expr_stmt(expr=sir_utils.make_stencil_fun_call_expr(
            callee="f",
            arguments=[sir_utils.make_field_access_expr("out_cell")])),
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"), [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "out_cell",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1),
                    ),
                ],
            ),
        ],
        functions=[fun])

    # Context manager guarantees the file is closed even if serialisation or
    # the write raises.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
Beispiel #6
0
def main(args: argparse.Namespace):
    """Build, compile and write a stencil containing a Cell -> Edge loop.

    The single forward vertical region covers the full column. It holds a
    loop statement over the location chain ``[Cell, Edge]`` whose body
    assigns ``out = in``. ``out`` is declared over the locations
    ``[Cell, Edge]`` and ``in`` over ``[Edge]``.

    Code is generated with ``dawn4py.CodeGenBackend.CXXNaiveIco``.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` pretty-prints the
            SIR before compilation.
    """
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                       0, 0)
    cell = SIR.LocationType.Value("Cell")
    edge = SIR.LocationType.Value("Edge")

    copy_stmt = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out"),
        sir_utils.make_field_access_expr("in"),
        "=",
    )
    loop = sir_utils.make_loop_stmt([copy_stmt], [cell, edge])
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([loop]), interval, SIR.VerticalRegion.Forward)

    fields = [
        sir_utils.make_field(
            "out",
            sir_utils.make_field_dimensions_unstructured([cell, edge], 1)),
        sir_utils.make_field(
            "in", sir_utils.make_field_dimensions_unstructured([edge], 1)),
    ]
    stencil = sir_utils.make_stencil(OUTPUT_NAME, sir_utils.make_ast([region]),
                                     fields)
    sir = sir_utils.make_sir(OUTPUT_FILE, SIR.GridType.Value("Unstructured"),
                             [stencil])

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
def main(args: argparse.Namespace):
    """Build, compile and write a stencil that fills the edge field ``x``.

    The single forward vertical region over the full column assigns the
    double literal ``"1"`` to ``x``. The stencil declares the edge fields
    ``in``, ``out`` and ``x``. ``x`` receives an extra positional ``True`` in
    ``make_field``, presumably marking it as a temporary (confirm against
    ``serial_utils``).

    Code is generated with ``dawn4py.CodeGenBackend.CXXNaiveIco``.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` prints the SIR
            as JSON before compilation.
    """
    def edge_field(field_name, *extra):
        # Unstructured field on the Edge location; ``extra`` is forwarded
        # verbatim to make_field.
        return serial_utils.make_field(
            field_name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Edge")], 1),
            *extra,
        )

    full_column = serial_utils.make_interval(AST.Interval.Start,
                                             AST.Interval.End, 0, 0)
    fill_x = serial_utils.make_assignment_stmt(
        serial_utils.make_field_access_expr("x"),
        serial_utils.make_literal_access_expr("1", AST.BuiltinType.Double),
        "=",
    )
    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([fill_x]), full_column,
        AST.VerticalRegion.Forward)

    stencil = serial_utils.make_stencil(
        OUTPUT_NAME,
        serial_utils.make_ast([region]),
        [edge_field("in"), edge_field("out"), edge_field("x", True)],
    )
    sir = serial_utils.make_sir(OUTPUT_FILE,
                                AST.GridType.Value("Unstructured"), [stencil])

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
def main(args: argparse.Namespace):
    """Build, compile and write a stencil computing ``out = dt * in`` on edges.

    ``dt`` is accessed as an external variable (``is_external=True``) and is
    registered in the SIR's global variable map with the double value 0.5.
    The single forward vertical region covers the full column. Code is
    generated with ``dawn4py.CodeGenBackend.CUDAIco``.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` prints the SIR
            as JSON before compilation.
    """
    interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End,
                                          0, 0)

    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_binary_operator(
                serial_utils.make_var_access_expr("dt", is_external=True), "*",
                serial_utils.make_field_access_expr("in")), "="),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, AST.VerticalRegion.Forward)

    # Named ``global_vars`` rather than ``globals`` to avoid shadowing the
    # builtin ``globals()``.
    global_vars = AST.GlobalVariableMap()
    global_vars.map["dt"].double_value = 0.5

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"), [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1),
                    ),
                    serial_utils.make_field(
                        "out",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1),
                    ),
                ],
            )
        ],
        global_variables=global_vars)

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDAIco)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
    def visit_ComputationBlock(self, node: gt_ir.ComputationBlock,
                               **kwargs: Any) -> SIR.VerticalRegionDeclStmt:
        """Translate a ``gt_ir`` computation block into a SIR vertical region.

        The interval and body are translated by visiting them. The body is
        visited with ``make_block=False`` and wrapped via
        ``sir_utils.make_ast``.

        ``gt_ir.IterationOrder.BACKWARD`` maps to
        ``SIR.VerticalRegion.Backward``; every other iteration order maps to
        ``SIR.VerticalRegion.Forward``.
        """
        interval = self.visit(node.interval)
        statements = self.visit(node.body, make_block=False)
        body_ast = sir_utils.make_ast(statements)

        if node.iteration_order == gt_ir.IterationOrder.BACKWARD:
            direction = SIR.VerticalRegion.Backward
        else:
            direction = SIR.VerticalRegion.Forward

        return sir_utils.make_vertical_region_decl_stmt(body_ast, interval,
                                                        direction)
Beispiel #10
0
def make_unstructured_stencil_sir(name=None):
    """Return an unstructured SIR that reduces a cell field onto edges.

    The stencil assigns to the edge field ``out`` a ``+`` reduction, starting
    from ``1.0``, of the cell field ``in`` over the ``[Edge, Cell]`` neighbour
    chain. It runs over the full vertical range in forward order.

    Args:
        name: Stencil name; ``"unstructured_stencil"`` when ``None``. The SIR
            file name is ``f"{name}.cpp"``.

    Returns:
        The SIR object built by ``sir_utils.make_sir``.
    """
    stencil_name = name if name is not None else "unstructured_stencil"
    output_file = f"{stencil_name}.cpp"
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    # out = reduce(+, init=1.0) of in over the cells neighbouring each edge
    body_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"),
            sir_utils.make_reduction_over_neighbor_expr(
                "+",
                sir_utils.make_literal_access_expr("1.0",
                                                   SIR.BuiltinType.Float),
                sir_utils.make_field_access_expr("in"),
                chain=[
                    SIR.LocationType.Value('Edge'),
                    SIR.LocationType.Value('Cell')
                ]),
            "=",
        )
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        output_file,
        sir_utils.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "in",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value('Cell')], 1)),
                    sir_utils.make_field(
                        "out",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value('Edge')], 1))
                ],
            )
        ],
    )

    return sir
Beispiel #11
0
def create_boundary_correction_region(
    value="0", i_interval=None, j_interval=None
) -> SIR.VerticalRegionDeclStmt:
    """Create a forward vertical region that assigns a constant to ``out``.

    The region spans the full column (``Start`` .. ``End``, offsets 0) and
    contains ``out[0, 0, 0] = value``, with ``value`` emitted as a Float
    literal.

    Args:
        value: Literal text assigned to ``out``.
        i_interval: Forwarded as ``IRange`` to
            ``make_vertical_region_decl_stmt``; presumably restricts the
            horizontal I extent (``None`` is passed through unchanged).
        j_interval: Forwarded as ``JRange``, analogous to ``i_interval``.

    Returns:
        The vertical region declaration statement.
    """
    full_column = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
    assign_constant = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Float),
        "=",
    )
    return sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([assign_constant]),
        full_column,
        SIR.VerticalRegion.Forward,
        IRange=i_interval,
        JRange=j_interval,
    )
def create_vertical_region_stmt2():
    """Create the forward vertical region over levels ``Start + 1`` .. ``End``.

    For each level it performs the following:

    1. Declares a Float local ``m = 1.0 / (b - a * c[k-1])``.
    2. Updates ``c = c * m``.
    3. Updates ``d = (d - a * d[k-1]) * m``.

    Here ``x[k-1]`` is an unstructured access with vertical offset -1 and
    ``make_unstructured_offset(False)`` as the horizontal offset.
    """
    field = sir_utils.make_unstructured_field_access_expr
    binop = sir_utils.make_binary_operator

    def level_below(field_name):
        # Access one level down (vertical offset -1).
        return field(field_name, sir_utils.make_unstructured_offset(False), -1)

    interval = sir_utils.make_interval(sir_utils.Interval.Start,
                                       sir_utils.Interval.End, 1, 0)

    # m = 1.0 / (b - a * c[k-1])
    m_type = sir_utils.make_type(sir_utils.BuiltinType.Float)
    one = sir_utils.make_literal_access_expr("1.0", sir_utils.BuiltinType.Float)
    denominator = binop(field("b"), "-",
                        binop(field("a"), "*", level_below("c")))
    declare_m = sir_utils.make_var_decl_stmt(
        m_type, "m", 0, "=", sir_utils.make_expr(binop(one, "/", denominator)))

    # c = c * m
    update_c = sir_utils.make_assignment_stmt(
        field("c"), binop(field("c"), "*", sir_utils.make_var_access_expr("m")),
        "=")

    # d = (d - a * d[k-1]) * m
    d_target = field("d")
    d_corrected = binop(field("d"), "-",
                        binop(field("a"), "*", level_below("d")))
    update_d = sir_utils.make_assignment_stmt(
        d_target,
        binop(d_corrected, "*", sir_utils.make_var_access_expr("m")),
        "=")

    body_ast = sir_utils.make_ast([declare_m, update_c, update_d])
    return sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)
Beispiel #13
0
def if_stmt():
    """Write a SIR JSON test input containing an ``if`` statement.

    The stencil ``"generated"`` declares the cell field ``in_cell``. Its
    single forward vertical region over the full column does two things:

    1. Declares a Float local ``out_var_cell``. Only the type and name are
       passed, so it presumably has no initializer.
    2. Executes ``if (out_var_cell) { out_var_cell = in_cell; }``.

    The SIR is serialised with ``MessageToJson`` to
    ``../input/test_set_stage_location_type_if_stmt.sir``.
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"

    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    declare_var = serial_utils.make_var_decl_stmt(
        serial_utils.make_type(serial_utils.BuiltinType.Float),
        "out_var_cell")
    condition = serial_utils.make_expr_stmt(
        serial_utils.make_var_access_expr("out_var_cell"))
    then_block = serial_utils.make_block_stmt(
        serial_utils.make_assignment_stmt(
            serial_utils.make_var_access_expr("out_var_cell"),
            serial_utils.make_field_access_expr("in_cell"),
            "=",
        ))
    body_ast = serial_utils.make_ast(
        [
            declare_var,
            serial_utils.make_if_stmt(condition, then_block),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # Context manager closes the file even if the write raises.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
Beispiel #14
0
def main(args: argparse.Namespace):
    """Build a Cartesian shifted-copy stencil, compile it for CUDA, write it.

    The stencil performs ``out[0, 0, 0] = in[1, 0, 0]`` (i.e. ``out =
    in[i+1]``) over the full column in forward order. Code is generated with
    ``dawn4py.CodeGenBackend.CUDA``.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` prints the SIR
            as JSON before compilation.
    """
    def cartesian_field(field_name):
        return serial_utils.make_field(
            field_name, serial_utils.make_field_dimensions_cartesian())

    full_column = serial_utils.make_interval(AST.Interval.Start,
                                             AST.Interval.End, 0, 0)
    shifted_copy = serial_utils.make_assignment_stmt(
        serial_utils.make_field_access_expr("out", [0, 0, 0]),
        serial_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([shifted_copy]), full_column,
        AST.VerticalRegion.Forward)

    stencil = serial_utils.make_stencil(
        OUTPUT_NAME,
        serial_utils.make_ast([region]),
        [cartesian_field("in"), cartesian_field("out")],
    )
    sir = serial_utils.make_sir(OUTPUT_FILE, AST.GridType.Value("Cartesian"),
                                [stencil])

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDA)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
Beispiel #15
0
def create_vertical_region_stmt() -> SIR.VerticalRegionDeclStmt:
    """Create a forward vertical region that copies ``in`` into ``out``.

    The region spans the full column (``Start`` .. ``End``, offsets 0) and
    holds the single statement ``out[0, 0, 0] = in[0, 0, 0]``.

    Returns:
        The vertical region declaration statement.
    """

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    # create the out = in statement (no offset on either side)
    body_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out", [0, 0, 0]),
            sir_utils.make_field_access_expr("in", [0, 0, 0]),
            "=",
        )
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)
    return vertical_region_stmt
def create_vertical_region_stmt3():
    """Create the backward vertical region over levels ``Start`` .. ``End - 1``.

    Each level performs ``d -= c * d[k+1]``. Here ``d[k+1]`` is an
    unstructured access to ``d`` with vertical offset +1 and
    ``make_unstructured_offset(False)`` as the horizontal offset.
    """
    field = sir_utils.make_unstructured_field_access_expr

    interval = sir_utils.make_interval(sir_utils.Interval.Start,
                                       sir_utils.Interval.End, 0, -1)

    target = field("d")
    product = sir_utils.make_binary_operator(
        field("c"), "*",
        field("d", sir_utils.make_unstructured_offset(False), 1))
    body_ast = sir_utils.make_ast(
        [sir_utils.make_assignment_stmt(target, product, "-=")])

    return sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Backward)
Beispiel #17
0
def make_stencil(outputfile, body_ast):
    """Wrap ``body_ast`` in an unstructured stencil and write its SIR as JSON.

    The body is placed in a single forward vertical region covering the full
    column (``Start`` .. ``End``, offsets 0). That region belongs to a
    stencil named ``"generated"`` that declares no fields.

    Args:
        outputfile: Used both as the SIR's file name and as the path the
            JSON (``MessageToJson``) is written to.
        body_ast: Region body passed directly to
            ``make_vertical_region_decl_stmt``. Presumably a
            ``sir_utils.make_ast`` result, like the other bodies in this
            file.
    """
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [],
            ),
        ],
    )
    # Context manager closes the file even if the write raises.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
Beispiel #18
0
def make_copy_stencil_sir(name=None):
    """Return a Cartesian SIR whose stencil performs ``out = in[i+1]``.

    The single forward vertical region covers the full column and assigns
    ``out[0, 0, 0] = in[1, 0, 0]``. Both fields use Cartesian dimensions.

    Args:
        name: Stencil name; ``"copy_stencil"`` when ``None``. The SIR file
            name is ``f"{name}.cpp"``.

    Returns:
        The SIR object built by ``sir_utils.make_sir``.
    """
    stencil_name = "copy_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    full_column = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                          0, 0)
    shifted_copy = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([shifted_copy]), full_column,
        SIR.VerticalRegion.Forward)

    fields = [
        sir_utils.make_field(field_name,
                             sir_utils.make_field_dimensions_cartesian())
        for field_name in ("in", "out")
    ]
    stencil = sir_utils.make_stencil(stencil_name,
                                     sir_utils.make_ast([region]), fields)
    return sir_utils.make_sir(sir_filename,
                              sir_utils.GridType.Value("Cartesian"), [stencil])
def main(args: argparse.Namespace):
    """Build, compile (CUDA) and write a Cartesian tridiagonal-solve stencil.

    In every column the stencil solves the tridiagonal system with
    sub-diagonal ``a``, diagonal ``b``, super-diagonal ``c`` and right-hand
    side ``d``. It uses forward elimination followed by back substitution
    (Thomas algorithm). The solution ends up in ``d``; ``c`` is overwritten
    with the modified super-diagonal.

    Three vertical regions are generated:

    1. First level only, forward: ``c = c / b`` and ``d = d / b``.
    2. Levels ``Start + 1`` .. ``End``, forward, with a Float temporary
       ``m = 1.0 / (b - a * c[k-1])``: ``c = c * m`` and
       ``d = (d - a * d[k-1]) * m``.
    3. Levels ``Start`` .. ``End - 1``, backward: ``d -= c * d[k+1]``.

    Args:
        args: Parsed command-line arguments; ``args.verbose`` pretty-prints the
            SIR before compilation.
    """

    # ---- First vertical region statement ----
    # Restricted to the first level. Region 2 derives every other level of
    # ``c`` and ``d`` from their original values, so normalising them by
    # ``b`` on all levels would corrupt the forward sweep.
    interval_1 = sir_utils.make_interval(SIR.Interval.Start,
                                         SIR.Interval.Start, 0, 0)
    body_ast_1 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("d"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
    ])

    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement (forward sweep) ----
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         1, 0)

    body_ast_2 = sir_utils.make_ast([
        # m = 1.0 / (b - a * c[k-1]). ``m`` holds a reciprocal, so it must be
        # a floating-point temporary; an Integer type would truncate it.
        sir_utils.make_var_decl_stmt(
            sir_utils.make_type(SIR.BuiltinType.Float),
            "m",
            0,
            "=",
            sir_utils.make_expr(
                sir_utils.make_binary_operator(
                    sir_utils.make_literal_access_expr("1.0",
                                                       SIR.BuiltinType.Float),
                    "/",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("b"),
                        "-",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr("a"),
                            "*",
                            sir_utils.make_field_access_expr("c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        # c = c * m
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"), "*",
                sir_utils.make_var_access_expr("m")),
            "=",
        ),
        # d = (d - a * d[k-1]) * m
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("d"),
                    "-",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("a"),
                        "*",
                        sir_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                sir_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement (back substitution) ----
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, -1)
    body_ast_3 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "*",
                sir_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    code = dawn4py.compile(sir, backend="cuda")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
Beispiel #20
0
def main():
    """Build, dump and compile the ``general_weights`` unstructured stencil.

    The stencil computes the edge field ``nabla2`` as a ``+`` reduction,
    initialised with the Double literal ``0.0``, of ``vn_vert`` over the
    ``[Edge, Cell, Vertex]`` neighbour chain. Each of the four reduction
    terms is weighted by the edge field ``inv_primal_edge_length``.

    The SIR is written as JSON to ``general_weights.sir`` for debugging. The
    code generated with ``dawn4py.CodeGenBackend.CXXNaiveIco`` is written to
    ``general_weights.cpp``.
    """
    stencil_name = "general_weights"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    body_ast = sir_utils.make_ast([
        # compute nabla2 using the diamond reduction
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("nabla2"),
            sir_utils.make_reduction_over_neighbor_expr(
                op="+",
                init=sir_utils.make_literal_access_expr(
                    "0.0", SIR.BuiltinType.Double),
                rhs=sir_utils.make_field_access_expr("vn_vert"),
                chain=[
                    SIR.LocationType.Value("Edge"),
                    SIR.LocationType.Value("Cell"),
                    SIR.LocationType.Value("Vertex")
                ],
                weights=[
                    sir_utils.make_field_access_expr("inv_primal_edge_length"),
                    sir_utils.make_field_access_expr("inv_primal_edge_length"),
                    sir_utils.make_field_access_expr("inv_primal_edge_length"),
                    sir_utils.make_field_access_expr("inv_primal_edge_length")
                ]),
            "=",
        ),
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "inv_primal_edge_length",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Edge")], 1),
                    ),
                    sir_utils.make_field(
                        "vn_vert",
                        sir_utils.make_field_dimensions_unstructured([
                            SIR.LocationType.Value("Edge"),
                            SIR.LocationType.Value("Cell"),
                            SIR.LocationType.Value("Vertex")
                        ], 1),
                    ),
                    sir_utils.make_field(
                        "nabla2",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Edge")], 1),
                    ),
                ],
            ),
        ],
    )

    # write SIR to file (for debugging purposes); the context manager closes
    # the file even if the write raises.
    with open(sir_outputfile, "w") as f:
        f.write(MessageToJson(sir))

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)
Beispiel #21
0
def main(args: argparse.Namespace):
    """Compile the Cartesian horizontal-diffusion stencil with the CUDA backend.

    One forward vertical region over the full vertical interval holds two
    statements, both built from the same 5-point operator:

        lap = -4.0 * in  + coeff * (in[1,0,0]  + (in[-1,0,0]  + (in[0,1,0]  + in[0,-1,0])))
        out = -4.0 * lap + coeff * (lap[1,0,0] + (lap[-1,0,0] + (lap[0,1,0] + lap[0,-1,0])))

    ``lap`` is declared as a temporary field. If ``args.verbose`` is set the
    SIR is pretty-printed before compilation. The generated source is written
    to ``OUTPUT_PATH``.
    """

    def five_point(src):
        # Build -4.0 * src + coeff * (src[1,0,0] + (src[-1,0,0] + (src[0,1,0] + src[0,-1,0])))
        # creating the sub-expressions in the same order as a fully nested call.
        centre = sir_utils.make_binary_operator(
            sir_utils.make_literal_access_expr("-4.0", SIR.BuiltinType.Float),
            "*",
            sir_utils.make_field_access_expr(src),
        )
        coeff = sir_utils.make_field_access_expr("coeff")
        ip1 = sir_utils.make_field_access_expr(src, [1, 0, 0])
        im1 = sir_utils.make_field_access_expr(src, [-1, 0, 0])
        jp1 = sir_utils.make_field_access_expr(src, [0, 1, 0])
        jm1 = sir_utils.make_field_access_expr(src, [0, -1, 0])
        neighbours = sir_utils.make_binary_operator(
            ip1,
            "+",
            sir_utils.make_binary_operator(
                im1, "+", sir_utils.make_binary_operator(jp1, "+", jm1)),
        )
        return sir_utils.make_binary_operator(
            centre, "+", sir_utils.make_binary_operator(coeff, "*", neighbours))

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # lap is computed from in, then out from lap
    statements = [
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr(target), five_point(source), "=")
        for target, source in (("lap", "in"), ("out", "lap"))
    ]
    body_ast = sir_utils.make_ast(statements)

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    cartesian_dims = sir_utils.make_field_dimensions_cartesian
    fields = [
        sir_utils.make_field(field_name, cartesian_dims())
        for field_name in ("in", "out", "coeff")
    ]
    fields.append(
        sir_utils.make_field("lap", cartesian_dims(), is_temporary=True))

    stencil = sir_utils.make_stencil(
        OUTPUT_NAME, sir_utils.make_ast([vertical_region_stmt]), fields)
    sir = sir_utils.make_sir(
        OUTPUT_FILE, SIR.GridType.Value("Cartesian"), [stencil])

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDA)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
def main(args: argparse.Namespace):
    """Build a small unstructured reduction stencil, dump its SIR and compile it.

    The SIR is written as JSON to ``unstructured_stencil.sir``. It is then
    compiled with the ``CXXNaiveIco`` backend, and the generated code is
    written to ``OUTPUT_PATH``. ``args`` belongs to the common entry-point
    signature but is not read here.
    """
    interval = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0)

    # out = "+"-reduction over the Cell -> Edge -> Cell neighbour chain, built
    # from `in` (accessed with make_unstructured_offset(False)) and the literal
    # 1.0. NOTE(review): which of the two serial_utils treats as the reduced
    # operand and which as the initial value depends on its positional
    # argument order -- confirm before relying on it.
    body_ast = serial_utils.make_ast(
        [
            serial_utils.make_assignment_stmt(
                serial_utils.make_unstructured_field_access_expr("out"),
                serial_utils.make_reduction_over_neighbor_expr(
                    "+",
                    serial_utils.make_unstructured_field_access_expr(
                        "in", horizontal_offset=serial_utils.make_unstructured_offset(False)),
                    serial_utils.make_literal_access_expr(
                        "1.0", AST.BuiltinType.Float),
                    chain=[AST.LocationType.Value(
                        "Cell"), AST.LocationType.Value("Edge"), AST.LocationType.Value("Cell")],
                ),
                "=",
            )
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, AST.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Cell")], 1
                        ),
                    ),
                    serial_utils.make_field(
                        "out",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # write the SIR as JSON (for debugging); `with` closes the file even if
    # the write raises
    with open("unstructured_stencil.sir", "w") as f:
        f.write(MessageToJson(sir))

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
Beispiel #23
0
def make_hori_diff_stencil_sir(name=None):
    """Return the SIR of a Cartesian horizontal-diffusion stencil.

    One forward vertical region over the full vertical interval applies the
    same 5-point operator twice:

        lap = -4.0 * in  + coeff * (in[1,0,0]  + (in[-1,0,0]  + (in[0,1,0]  + in[0,-1,0])))
        out = -4.0 * lap + coeff * (lap[1,0,0] + (lap[-1,0,0] + (lap[0,1,0] + lap[0,-1,0])))

    ``lap`` is declared as a temporary field.

    Args:
        name: stencil name; defaults to ``"hori_diff_stencil"`` when ``None``.
            The SIR's filename is ``f"{name}.cpp"``.

    Returns:
        The SIR built by ``sir_utils.make_sir``.
    """
    stencil_name = name if name is not None else "hori_diff_stencil"
    output_file = f"{stencil_name}.cpp"

    def five_point(src):
        # -4.0 * src + coeff * (src[1,0,0] + (src[-1,0,0] + (src[0,1,0] + src[0,-1,0])))
        return sir_utils.make_binary_operator(
            sir_utils.make_binary_operator(
                sir_utils.make_literal_access_expr("-4.0", SIR.BuiltinType.Float),
                "*",
                sir_utils.make_field_access_expr(src),
            ),
            "+",
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("coeff"),
                "*",
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr(src, [1, 0, 0]),
                    "+",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr(src, [-1, 0, 0]),
                        "+",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr(src, [0, 1, 0]),
                            "+",
                            sir_utils.make_field_access_expr(src, [0, -1, 0]),
                        ),
                    ),
                ),
            ),
        )

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    # create the stencil body AST: lap from in, then out from lap
    body_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("lap"), five_point("in"), "="),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"), five_point("lap"), "="),
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    # GridType is taken from the SIR module, like every other enum used here
    sir = sir_utils.make_sir(
        output_file,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "in", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "out", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "coeff", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "lap",
                        sir_utils.make_field_dimensions_cartesian(),
                        is_temporary=True),
                ],
            )
        ],
    )

    return sir
def main():
    """Build, dump and compile the ICON diamond-Laplacian / Smagorinsky stencil.

    One forward vertical region over the full vertical interval computes:

    * ``vn_vert``: inside a loop over the Edge -> Cell -> Vertex chain,
      ``u_vert * primal_normal_x + v_vert * primal_normal_y``.
    * ``dvt_tang`` / ``dvt_norm``: weighted "+"-reductions of
      ``u_vert * dual_normal_x + v_vert * dual_normal_y`` over that chain.
      ``dvt_tang`` is then multiplied by ``tangent_orientation``.
    * ``kh_smag_1`` / ``kh_smag_2``: squared combinations of weighted
      ``vn_vert`` reductions, which are merged into
      ``kh_smag = diff_multfac_smag * math::sqrt(kh_smag_1 + kh_smag_2)``.
    * ``nabla2``: the diamond reduction of ``4.0 * vn_vert``, minus
      ``8.0 * vn`` times the squared inverse lengths.

    The SIR is written as JSON to ``ICON_laplacian_diamond_stencil.sir``. It
    is compiled with the ``CXXNaiveIco`` backend, and the generated code is
    written to ``ICON_laplacian_diamond_stencil.cpp``.
    """
    stencil_name = "ICON_laplacian_diamond_stencil"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    edge = SIR.LocationType.Value("Edge")
    cell = SIR.LocationType.Value("Cell")
    vertex = SIR.LocationType.Value("Vertex")

    # --- small builders for the repetitive parts of the AST ---------------

    def diamond_chain():
        # Fresh Edge -> Cell -> Vertex neighbour chain for every use.
        return [edge, cell, vertex]

    def access(name, offset=None):
        # Plain field access, or one with an explicit offset.
        if offset is None:
            return sir_utils.make_field_access_expr(name)
        return sir_utils.make_field_access_expr(name, offset)

    def per_neighbour(name):
        # Offset [True, 0]; presumably "horizontal neighbour access, vertical
        # offset 0" in sir_utils' unstructured convention -- confirm.
        return access(name, [True, 0])

    def double(value):
        return sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Double)

    def doubles(*values):
        return [double(v) for v in values]

    def binop(lhs, op, rhs):
        return sir_utils.make_binary_operator(lhs, op, rhs)

    def squared(name):
        return binop(access(name), "*", access(name))

    def assign(target, rhs):
        return sir_utils.make_assignment_stmt(access(target), rhs, "=")

    def diamond_sum(rhs, weights):
        # "+"-reduction over the diamond chain starting from 0.0. `weights`
        # holds one expression per neighbour (four in every use below).
        return sir_utils.make_reduction_over_neighbor_expr(
            op="+",
            init=double("0.0"),
            rhs=rhs,
            chain=diamond_chain(),
            weights=weights,
        )

    def projected_velocity(normal_x, normal_y):
        # u_vert * normal_x + v_vert * normal_y, all read per neighbour.
        return binop(
            binop(per_neighbour("u_vert"), "*", per_neighbour(normal_x)),
            "+",
            binop(per_neighbour("v_vert"), "*", per_neighbour(normal_y)),
        )

    # --- stencil body -------------------------------------------------------

    body_ast = sir_utils.make_ast(
        [
            # fill sparse dimension vn_vert using the loop concept
            sir_utils.make_loop_stmt(
                [assign("vn_vert",
                        projected_velocity("primal_normal_x", "primal_normal_y"))],
                diamond_chain(),
            ),
            # dvt_tang for smagorinsky
            assign(
                "dvt_tang",
                diamond_sum(
                    projected_velocity("dual_normal_x", "dual_normal_y"),
                    doubles("-1.0", "1.0", "0.0", "0.0"),
                ),
            ),
            assign("dvt_tang",
                   binop(access("dvt_tang"), "*", access("tangent_orientation"))),
            # dvt_norm for smagorinsky
            assign(
                "dvt_norm",
                diamond_sum(
                    projected_velocity("dual_normal_x", "dual_normal_y"),
                    doubles("0.0", "0.0", "-1.0", "1.0"),
                ),
            ),
            # compute smagorinsky
            assign(
                "kh_smag_1",
                diamond_sum(access("vn_vert"),
                            doubles("-1.0", "1.0", "0.0", "0.0")),
            ),
            assign(
                "kh_smag_1",
                binop(
                    binop(
                        binop(access("kh_smag_1"), "*",
                              access("tangent_orientation")),
                        "*",
                        access("inv_primal_edge_length"),
                    ),
                    "+",
                    binop(access("dvt_norm"), "*", access("inv_vert_vert_length")),
                ),
            ),
            assign("kh_smag_1", squared("kh_smag_1")),
            assign(
                "kh_smag_2",
                # weights literals without stray whitespace (was " 1.0")
                diamond_sum(access("vn_vert"),
                            doubles("0.0", "0.0", "-1.0", "1.0")),
            ),
            assign(
                "kh_smag_2",
                binop(
                    binop(access("kh_smag_2"), "*", access("inv_vert_vert_length")),
                    "+",
                    binop(access("dvt_tang"), "*", access("inv_primal_edge_length")),
                ),
            ),
            assign("kh_smag_2", squared("kh_smag_2")),
            # kh_smag = diff_multfac_smag * sqrt(kh_smag_1 + kh_smag_2); the
            # square root is emitted as a call to math::sqrt
            assign(
                "kh_smag",
                binop(
                    access("diff_multfac_smag"),
                    "*",
                    sir_utils.make_fun_call_expr(
                        "math::sqrt",
                        [binop(access("kh_smag_1"), "+", access("kh_smag_2"))],
                    ),
                ),
            ),
            # compute nabla2 using the diamond reduction
            assign(
                "nabla2",
                diamond_sum(
                    binop(double("4.0"), "*", access("vn_vert")),
                    [
                        squared("inv_primal_edge_length"),
                        squared("inv_primal_edge_length"),
                        squared("inv_vert_vert_length"),
                        squared("inv_vert_vert_length"),
                    ],
                ),
            ),
            # nabla2 -= 8*vn*inv_primal_edge_length^2 + 8*vn*inv_vert_vert_length^2
            assign(
                "nabla2",
                binop(
                    access("nabla2"),
                    "-",
                    binop(
                        binop(binop(double("8.0"), "*", access("vn")),
                              "*", squared("inv_primal_edge_length")),
                        "+",
                        binop(binop(double("8.0"), "*", access("vn")),
                              "*", squared("inv_vert_vert_length")),
                    ),
                ),
            ),
        ]
    )

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    # (name, location chain) of every stencil field, in declaration order
    field_specs = [
        ("diff_multfac_smag", [edge]),
        ("tangent_orientation", [edge]),
        ("inv_primal_edge_length", [edge]),
        ("inv_vert_vert_length", [edge]),
        ("u_vert", [vertex]),
        ("v_vert", [vertex]),
        ("primal_normal_x", [edge, cell, vertex]),
        ("primal_normal_y", [edge, cell, vertex]),
        ("dual_normal_x", [edge, cell, vertex]),
        ("dual_normal_y", [edge, cell, vertex]),
        ("vn_vert", [edge, cell, vertex]),
        ("vn", [edge]),
        ("dvt_tang", [edge]),
        ("dvt_norm", [edge]),
        ("kh_smag_1", [edge]),
        ("kh_smag_2", [edge]),
        ("kh_smag", [edge]),
        ("nabla2", [edge]),
    ]
    fields = [
        sir_utils.make_field(
            field_name,
            sir_utils.make_field_dimensions_unstructured(locations, 1),
        )
        for field_name, locations in field_specs
    ]

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                fields,
            ),
        ],
    )

    # write SIR to file (for debugging purposes); `with` closes it on error
    with open(sir_outputfile, "w") as f:
        f.write(MessageToJson(sir))

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)
Beispiel #25
0
def main(args: argparse.Namespace):
    """Build a column-wise tridiagonal solver (Thomas algorithm) and compile it to CUDA.

    The stencil works independently on every column of the Cartesian fields
    ``a`` (sub-diagonal), ``b`` (diagonal), ``c`` (super-diagonal) and ``d``
    (right-hand side):

    1. level ``k_start`` only:          ``c = c / b``; ``d = d / b``
    2. levels ``k_start+1 .. k_end``:   ``m = 1.0 / (b - a * c[k-1])``;
       ``c = c * m``; ``d = (d - a * d[k-1]) * m``   (forward sweep)
    3. levels ``k_end-1 .. k_start``:   ``d -= c * d[k+1]`` (backward sweep)

    After the backward sweep ``d`` holds the solution of the system.

    The ``MultiStageMerger`` pass group is inserted at position 1 of the
    default pass groups, and the generated code is written to ``OUTPUT_PATH``.

    :param args: parsed command-line arguments; only ``args.verbose`` is read
        (when set, the SIR is printed as JSON).
    """

    # ---- First vertical region statement ----
    # Only the first level (k_start .. k_start, inclusive) is normalised.
    # Running this over the whole column would divide every level by b and
    # corrupt the recurrence evaluated by the second region.
    interval_1 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.Start, 0, 0)
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
        # The right-hand side of the first row must be normalised as well.
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("d"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
    ])

    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, AST.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    # Forward sweep over k_start+1 .. k_end.
    interval_2 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 1, 0)

    body_ast_2 = serial_utils.make_ast([
        serial_utils.make_var_decl_stmt(
            # m = 1.0 / (...) is a real-valued scaling factor; an Integer
            # declaration would truncate it.
            serial_utils.make_type(AST.BuiltinType.Float),
            "m",
            0,
            "=",
            serial_utils.make_expr(
                serial_utils.make_binary_operator(
                    serial_utils.make_literal_access_expr(
                        "1.0", AST.BuiltinType.Float),
                    "/",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("b"),
                        "-",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr("a"),
                            "*",
                            serial_utils.make_field_access_expr(
                                "c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"), "*",
                serial_utils.make_var_access_expr("m")),
            "=",
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_field_access_expr("d"),
                    "-",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("a"),
                        "*",
                        serial_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                serial_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, AST.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    # Backward substitution over k_end-1 .. k_start.
    interval_3 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 0, -1)
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "*",
                serial_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, AST.VerticalRegion.Backward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    serial_utils.make_field(
                        "a", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "b", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "c", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "d", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    code = dawn4py.compile(sir,
                           groups=pass_groups,
                           backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def main():
    """Build the ICON vector Laplacian stencil and generate naive ICO C++ code for it.

    Over the whole vertical domain of an unstructured grid the stencil computes:

    - ``rot_vec`` (Vertex): sum over Vertex->Edge neighbours of ``vec * geofac_rot``
    - ``div_vec`` (Cell): sum over Cell->Edge neighbours of ``vec * geofac_div``
    - ``nabla2t1_vec`` (Edge): Edge->Vertex sum of ``rot_vec`` with weights
      (-1, +1), then ``tangent_orientation * nabla2t1_vec / primal_edge_length``
    - ``nabla2t2_vec`` (Edge): Edge->Cell sum of ``div_vec`` with weights
      (-1, +1), then divided by ``dual_edge_length``
    - ``nabla2_vec`` (Edge): ``nabla2t2_vec - nabla2t1_vec``

    The SIR is written as JSON to ``ICON_laplacian_stencil.sir`` and the
    generated code to ``ICON_laplacian_stencil.cpp``, both in the current
    working directory.
    """
    stencil_name = "ICON_laplacian_stencil"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    def double_literal(value):
        # Build a fresh Double-typed literal node for every use.
        return sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Double)

    def location_chain(*names):
        # Translate location names ("Edge", "Cell", "Vertex") into SIR enum values.
        return [SIR.LocationType.Value(name) for name in names]

    def unstructured_field(name, *locations):
        # Declare a field living on the given (possibly sparse) location chain.
        return sir_utils.make_field(
            name,
            sir_utils.make_field_dimensions_unstructured(
                location_chain(*locations), 1
            ),
        )

    interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    body_ast = sir_utils.make_ast(
        [
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("rot_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("vec"),
                        "*",
                        sir_utils.make_field_access_expr("geofac_rot")),
                    chain=location_chain("Vertex", "Edge"),
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("div_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("vec"),
                        "*",
                        sir_utils.make_field_access_expr("geofac_div")),
                    chain=location_chain("Cell", "Edge"),
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("nabla2t1_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=sir_utils.make_field_access_expr("rot_vec"),
                    chain=location_chain("Edge", "Vertex"),
                    weights=[double_literal("-1.0"), double_literal("1.0")],
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("nabla2t1_vec"),
                sir_utils.make_binary_operator(
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr(
                            "tangent_orientation"),
                        "*",
                        sir_utils.make_field_access_expr("nabla2t1_vec")),
                    "/",
                    sir_utils.make_field_access_expr("primal_edge_length")),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("nabla2t2_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=sir_utils.make_field_access_expr("div_vec"),
                    chain=location_chain("Edge", "Cell"),
                    weights=[double_literal("-1.0"), double_literal("1.0")],
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("nabla2t2_vec"),
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("nabla2t2_vec"),
                    "/",
                    sir_utils.make_field_access_expr("dual_edge_length")),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("nabla2_vec"),
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("nabla2t2_vec"),
                    "-",
                    sir_utils.make_field_access_expr("nabla2t1_vec")),
                "=",
            ),
        ]
    )

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    unstructured_field("vec", "Edge"),
                    unstructured_field("div_vec", "Cell"),
                    unstructured_field("rot_vec", "Vertex"),
                    unstructured_field("nabla2t1_vec", "Edge"),
                    unstructured_field("nabla2t2_vec", "Edge"),
                    unstructured_field("nabla2_vec", "Edge"),
                    unstructured_field("primal_edge_length", "Edge"),
                    unstructured_field("dual_edge_length", "Edge"),
                    unstructured_field("tangent_orientation", "Edge"),
                    # sparse fields: one value per (vertex|cell, edge) neighbour pair
                    unstructured_field("geofac_rot", "Vertex", "Edge"),
                    unstructured_field("geofac_div", "Cell", "Edge"),
                ],
            ),
        ],
    )

    # write SIR to file (for debugging purposes); the context manager closes
    # the file even if serialization raises
    with open(sir_outputfile, "w") as f:
        f.write(MessageToJson(sir))

    # compile (backend given as the enum, like the other examples in this file)
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)
Beispiel #27
0
def make_tridiagonal_solve_stencil_sir(name=None):
    """Return the SIR of a column-wise tridiagonal solver (Thomas algorithm).

    The stencil operates on the Cartesian fields ``a`` (sub-diagonal), ``b``
    (diagonal), ``c`` (super-diagonal) and ``d`` (right-hand side):

    1. level ``k_start`` only:          ``c = c / b``; ``d = d / b``
    2. levels ``k_start+1 .. k_end``:   ``m = 1.0 / (b - a * c[k-1])``;
       ``c = c * m``; ``d = (d - a * d[k-1]) * m``   (forward sweep)
    3. levels ``k_end-1 .. k_start``:   ``d -= c * d[k+1]`` (backward sweep)

    After the backward sweep ``d`` holds the solution of the system.

    :param name: stencil name; defaults to ``"tridiagonal_solve_stencil"``.
        The SIR filename is ``f"{name}.cpp"``.
    :returns: the SIR message produced by ``sir_utils.make_sir``.
    """
    OUTPUT_NAME = name if name is not None else "tridiagonal_solve_stencil"
    OUTPUT_FILE = f"{OUTPUT_NAME}.cpp"

    # ---- First vertical region statement ----
    # Only the first level (k_start .. k_start, inclusive) is normalised.
    # Running this over the whole column would divide every level by b and
    # corrupt the recurrence evaluated by the second region.
    interval_1 = sir_utils.make_interval(SIR.Interval.Start,
                                         SIR.Interval.Start, 0, 0)
    body_ast_1 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
        # The right-hand side of the first row must be normalised as well.
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("d"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
    ])

    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    # Forward sweep over k_start+1 .. k_end.
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         1, 0)

    body_ast_2 = sir_utils.make_ast([
        sir_utils.make_var_decl_stmt(
            # m = 1.0 / (...) is a real-valued scaling factor; an Integer
            # declaration would truncate it.
            sir_utils.make_type(SIR.BuiltinType.Float),
            "m",
            0,
            "=",
            sir_utils.make_expr(
                sir_utils.make_binary_operator(
                    sir_utils.make_literal_access_expr("1.0",
                                                       SIR.BuiltinType.Float),
                    "/",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("b"),
                        "-",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr("a"),
                            "*",
                            sir_utils.make_field_access_expr("c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"), "*",
                sir_utils.make_var_access_expr("m")),
            "=",
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("d"),
                    "-",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("a"),
                        "*",
                        sir_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                sir_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    # Backward substitution over k_end-1 .. k_start.
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, -1)
    body_ast_3 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "*",
                sir_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    return sir
def sparse_temporary():
    """Generate the ``AlsoDemoteWeight`` IIR test input.

    Builds an unstructured stencil named ``"generated"`` that:

    - copies the Edge field ``test`` into the temporary Edge field ``tempF``;
    - computes ``outF`` (Edge) as a ``+`` reduction of the Cell field ``inF``
      over the Edge -> Cell chain, starting from ``0.``, with every weight
      being ``tempF``.

    The SIR is lowered with ``dawn4py.lower_and_optimize`` using no pass
    groups. The JSON of the resulting ``"generated"`` stencil instantiation is
    written to ``../input/AlsoDemoteWeight.iir``, relative to the current
    working directory.
    """
    outputfile = "AlsoDemoteWeight"
    interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End,
                                          0, 0)

    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_unstructured_field_access_expr("tempF"),
            serial_utils.make_unstructured_field_access_expr("test"),
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_unstructured_field_access_expr("outF"),
            serial_utils.make_reduction_over_neighbor_expr(
                "+",
                serial_utils.make_unstructured_field_access_expr("inF"),
                serial_utils.make_literal_access_expr("0.",
                                                      AST.BuiltinType.Double),
                [
                    AST.LocationType.Value("Edge"),
                    AST.LocationType.Value("Cell")
                ],
                weights=[
                    serial_utils.make_unstructured_field_access_expr("tempF")
                    for _ in range(4)
                ]), "="),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, AST.VerticalRegion.Forward)

    sir = serial_utils.make_sir(
        outputfile,
        AST.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "inF",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Cell")], 1),
                    ),
                    serial_utils.make_field(
                        "outF",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1),
                    ),
                    serial_utils.make_field(
                        "test",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1),
                    ),
                    serial_utils.make_field(
                        "tempF",
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1),
                        is_temporary=True),
                ],
            ),
        ],
    )
    sim = dawn4py.lower_and_optimize(sir, groups=[])
    # Write straight to the destination. Writing into the CWD first and then
    # os.rename-ing left a stray file behind when the rename failed, and
    # os.rename cannot overwrite an existing destination on Windows.
    target = os.path.join("..", "input", outputfile + ".iir")
    with open(target, mode="w") as f:
        f.write(MessageToJson(sim["generated"]))
Beispiel #29
0
def main(args: argparse.Namespace):
    """Build a 2D Laplacian stencil on a Cartesian grid and compile it to naive C++.

    Over the whole vertical domain the stencil computes
    ``out = (-4 * in + in[i+1] + in[i-1] + in[j+1] + in[j-1]) / (dx * dx)``,
    where ``dx`` is an external global variable registered in the SIR with
    the value 0.0.

    The SIR is serialized as JSON to ``./laplacian_stencil_from_python.sir``
    and the generated code is written to ``OUTPUT_PATH``.

    :param args: parsed command-line arguments; only ``args.verbose`` is read
        (when set, the SIR is pretty-printed).
    """
    interval = serial_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                          0, 0)

    # create the laplace statement
    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out", [0, 0, 0]),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [0, 0, 0]),
                        "*",
                        serial_utils.make_literal_access_expr(
                            "-4.0", serial_utils.BuiltinType.Float),
                    ),
                    "+",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [1, 0, 0]),
                        "+",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr(
                                "in", [-1, 0, 0]),
                            "+",
                            serial_utils.make_binary_operator(
                                serial_utils.make_field_access_expr(
                                    "in", [0, 1, 0]),
                                "+",
                                serial_utils.make_field_access_expr(
                                    "in", [0, -1, 0]),
                            ),
                        ),
                    ),
                ),
                "/",
                serial_utils.make_binary_operator(
                    serial_utils.make_var_access_expr("dx", is_external=True),
                    "*",
                    serial_utils.make_var_access_expr("dx", is_external=True),
                ),
            ),
            "=",
        ),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    stencils_globals = serial_utils.GlobalVariableMap()
    stencils_globals.map["dx"].double_value = 0.0

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "out", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "in", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
        global_variables=stencils_globals,
    )

    # print the SIR
    if args.verbose:
        serial_utils.pprint(sir)

    # serialize the SIR to file. to_json yields JSON text (a str, like
    # protobuf's MessageToJson), so the file is opened in text mode: a "wb"
    # handle raises TypeError on str. The context manager closes it on error.
    with open("./laplacian_stencil_from_python.sir", "w") as sir_file:
        sir_file.write(serial_utils.to_json(sir))

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def main(args: argparse.Namespace):
    """Build an unstructured stencil with three restricted regions and compile it.

    All three vertical regions cover the full vertical domain (Start..End),
    operate on Cell fields, and run Forward. They differ in the extra interval
    passed as the fourth argument of ``make_vertical_region_decl_stmt``
    (presumably the horizontal iteration space; confirm against dawn4py):

    1. ``out = in_1``       with ``make_magic_num_interval(0, 1, 0, 0)``
    2. ``out = out + in_2`` with ``make_interval(2, 3, 0, 0)``
    3. ``out = in_3``       with ``make_interval(3, 4, 0, 0)``

    ``StageMerger`` and ``MultiStageMerger`` are inserted at position 1 of the
    default pass groups, ending up in that order. The naive ICO C++ code is
    written to ``OUTPUT_PATH``.

    :param args: parsed command-line arguments; only ``args.verbose`` is read
        (when set, the SIR is printed as JSON).
    """
    interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End,
                                          0, 0)

    # out = in_1 on inner cells
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_field_access_expr("in_1"),
            "=",
        )
    ])
    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval, AST.VerticalRegion.Forward,
        serial_utils.make_magic_num_interval(0, 1, 0, 0))

    # out = out + in_2 on inner cells
    #   should be merge-able to last stage
    body_ast_2 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("out"),
                "+",
                serial_utils.make_field_access_expr("in_2"),
            ), "=")
    ])
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval, AST.VerticalRegion.Forward,
        serial_utils.make_interval(2, 3, 0, 0))

    # out = in_3 on lateral boundary cells
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_field_access_expr("in_3"),
            "=",
        )
    ])
    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval, AST.VerticalRegion.Forward,
        serial_utils.make_interval(3, 4, 0, 0))

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                # every field lives on cells; order matches the declaration order
                [
                    serial_utils.make_field(
                        field_name,
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Cell")], 1),
                    )
                    for field_name in ("in_1", "in_2", "in_3", "out")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile; both inserts go to index 1, so StageMerger ends up before
    # MultiStageMerger
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    pass_groups.insert(1, dawn4py.PassGroup.StageMerger)
    code = dawn4py.compile(sir,
                           groups=pass_groups,
                           backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)