def create_vertical_region_stmt2():
    """Create the forward vertical region of the tridiagonal solve stencil.

    The region uses the interval ``(Start, End)`` with offsets ``(1, 0)`` and
    runs with ``SIR.VerticalRegion.Forward``. Its body is, in order::

        float m = 1.0 / (b - a * c')
        c = c * m
        d = (d - a * d') * m

    where ``c'`` / ``d'`` are unstructured field accesses built with
    ``make_unstructured_offset(False)`` and an extra offset argument of -1
    (presumably the vertical offset, i.e. the level below -- confirm against
    ``sir_utils``).

    Returns:
        The statement produced by ``sir_utils.make_vertical_region_decl_stmt``.
    """
    interval = sir_utils.make_interval(sir_utils.Interval.Start,
                                       sir_utils.Interval.End, 1, 0)

    # Every field access in this region goes through the same factory.
    field = sir_utils.make_unstructured_field_access_expr

    # --- float m = 1.0 / (b - a * c') ---
    m_type = sir_utils.make_type(sir_utils.BuiltinType.Float)
    one = sir_utils.make_literal_access_expr("1.0",
                                             sir_utils.BuiltinType.Float)
    b_here = field("b")
    a_here = field("a")
    c_below = field("c", sir_utils.make_unstructured_offset(False), -1)
    denominator = sir_utils.make_binary_operator(
        b_here, "-", sir_utils.make_binary_operator(a_here, "*", c_below))
    m_decl = sir_utils.make_var_decl_stmt(
        m_type, "m", 0, "=",
        sir_utils.make_expr(
            sir_utils.make_binary_operator(one, "/", denominator)))

    # --- c = c * m ---
    c_update = sir_utils.make_assignment_stmt(
        field("c"),
        sir_utils.make_binary_operator(field("c"), "*",
                                       sir_utils.make_var_access_expr("m")),
        "=")

    # --- d = (d - a * d') * m ---
    d_target = field("d")
    d_here = field("d")
    a_times_d_below = sir_utils.make_binary_operator(
        field("a"), "*",
        field("d", sir_utils.make_unstructured_offset(False), -1))
    d_update = sir_utils.make_assignment_stmt(
        d_target,
        sir_utils.make_binary_operator(
            sir_utils.make_binary_operator(d_here, "-", a_times_d_below),
            "*",
            sir_utils.make_var_access_expr("m")),
        "=")

    body_ast = sir_utils.make_ast([m_decl, c_update, d_update])
    return sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)
# Example #2
def if_stmt():
    """Write the SIR input for the "set stage location type / if stmt" test.

    Builds an unstructured-grid SIR containing one stencil named
    ``generated``. The stencil has a single forward vertical region over the
    interval ``(Start, End)`` with offsets ``(0, 0)``. The region:

    * declares a float variable ``out_var_cell``;
    * inside an ``if`` whose condition is ``out_var_cell``, assigns the cell
      field ``in_cell`` to ``out_var_cell``.

    The SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_if_stmt.sir`` (relative to the
    current working directory).
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"

    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    body_ast = serial_utils.make_ast(
        [
            serial_utils.make_var_decl_stmt(
                serial_utils.make_type(serial_utils.BuiltinType.Float),
                "out_var_cell"),
            serial_utils.make_if_stmt(
                # condition: the variable itself, wrapped as an expression stmt
                serial_utils.make_expr_stmt(
                    serial_utils.make_var_access_expr("out_var_cell")),
                # then-branch: out_var_cell = in_cell
                serial_utils.make_block_stmt(
                    serial_utils.make_assignment_stmt(
                        serial_utils.make_var_access_expr("out_var_cell"),
                        serial_utils.make_field_access_expr("in_cell"),
                        "=",
                    )),
            ),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # The context manager closes the file even if serialization or the
    # write raises.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))


def main(args: argparse.Namespace):
    """Build the tridiagonal-solve SIR, compile it for CUDA and write the code.

    The Cartesian stencil ``OUTPUT_NAME`` acts on fields ``a``, ``b``, ``c``
    and ``d`` through three vertical regions:

    1. forward, interval offsets ``(0, 0)``: ``c = c / b``;
    2. forward, interval offsets ``(1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``;
    3. backward, interval offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``.

    The SIR is compiled with ``dawn4py.compile(..., backend="cuda")``. The
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments. ``args.verbose`` enables
            pretty-printing of the SIR before compilation.
    """

    # ---- First vertical region statement ----
    # c = c / b
    interval_1 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, 0)
    body_ast_1 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        )
    ])

    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    # m = 1.0 / (b - a * c[0, 0, -1]); c = c * m; d = (d - a * d[0, 0, -1]) * m
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         1, 0)

    body_ast_2 = sir_utils.make_ast([
        # m holds a reciprocal (1.0 / ...), so it must be a floating-point
        # variable; declaring it Integer would truncate the value.
        sir_utils.make_var_decl_stmt(
            sir_utils.make_type(SIR.BuiltinType.Float),
            "m",
            0,
            "=",
            sir_utils.make_expr(
                sir_utils.make_binary_operator(
                    sir_utils.make_literal_access_expr("1.0",
                                                       SIR.BuiltinType.Float),
                    "/",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("b"),
                        "-",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr("a"),
                            "*",
                            sir_utils.make_field_access_expr("c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"), "*",
                sir_utils.make_var_access_expr("m")),
            "=",
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("d"),
                    "-",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("a"),
                        "*",
                        sir_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                sir_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    # d -= c * d[0, 0, 1]  (backward sweep)
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, -1)
    body_ast_3 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "*",
                sir_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    code = dawn4py.compile(sir, backend="cuda")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
# Example #4
def main(args: argparse.Namespace):
    """Build the tridiagonal-solve SIR, compile it for CUDA and write the code.

    The Cartesian stencil ``OUTPUT_NAME`` acts on fields ``a``, ``b``, ``c``
    and ``d`` through three vertical regions:

    1. forward, interval offsets ``(0, 0)``: ``c = c / b``;
    2. forward, interval offsets ``(1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``;
    3. backward, interval offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``.

    The SIR is compiled with dawn4py's default pass groups plus
    ``MultiStageMerger`` (inserted at index 1), using the CUDA backend. The
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments. ``args.verbose`` enables printing
            the SIR as JSON before compilation.
    """

    # ---- First vertical region statement ----
    # c = c / b
    interval_1 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 0, 0)
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        )
    ])

    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, AST.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    # m = 1.0 / (b - a * c[0, 0, -1]); c = c * m; d = (d - a * d[0, 0, -1]) * m
    interval_2 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 1, 0)

    body_ast_2 = serial_utils.make_ast([
        # m holds a reciprocal (1.0 / ...), so it must be a floating-point
        # variable; declaring it Integer would truncate the value.
        serial_utils.make_var_decl_stmt(
            serial_utils.make_type(AST.BuiltinType.Float),
            "m",
            0,
            "=",
            serial_utils.make_expr(
                serial_utils.make_binary_operator(
                    serial_utils.make_literal_access_expr(
                        "1.0", AST.BuiltinType.Float),
                    "/",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("b"),
                        "-",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr("a"),
                            "*",
                            serial_utils.make_field_access_expr(
                                "c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"), "*",
                serial_utils.make_var_access_expr("m")),
            "=",
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_field_access_expr("d"),
                    "-",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("a"),
                        "*",
                        serial_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                serial_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, AST.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    # d -= c * d[0, 0, 1]  (backward sweep)
    interval_3 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 0, -1)
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "*",
                serial_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, AST.VerticalRegion.Backward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    serial_utils.make_field(
                        "a", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "b", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "c", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "d", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile: default pass groups with MultiStageMerger inserted at index 1,
    # i.e. right after the first default group
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    code = dawn4py.compile(sir,
                           groups=pass_groups,
                           backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
# Example #5
def make_tridiagonal_solve_stencil_sir(name=None):
    """Return the SIR of a Cartesian tridiagonal-solve stencil.

    The stencil acts on fields ``a``, ``b``, ``c`` and ``d`` through three
    vertical regions:

    1. forward, interval offsets ``(0, 0)``: ``c = c / b``;
    2. forward, interval offsets ``(1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``;
    3. backward, interval offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``.

    Args:
        name: Stencil name. Defaults to ``"tridiagonal_solve_stencil"`` when
            ``None``. The SIR's file name is set to ``f"{name}.cpp"``.

    Returns:
        The SIR object built by ``sir_utils.make_sir``.
    """
    output_name = name if name is not None else "tridiagonal_solve_stencil"
    output_file = f"{output_name}.cpp"

    # ---- First vertical region statement ----
    # c = c / b
    interval_1 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, 0)
    body_ast_1 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        )
    ])

    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    # m = 1.0 / (b - a * c[0, 0, -1]); c = c * m; d = (d - a * d[0, 0, -1]) * m
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         1, 0)

    body_ast_2 = sir_utils.make_ast([
        # m holds a reciprocal (1.0 / ...), so it must be a floating-point
        # variable; declaring it Integer would truncate the value.
        sir_utils.make_var_decl_stmt(
            sir_utils.make_type(SIR.BuiltinType.Float),
            "m",
            0,
            "=",
            sir_utils.make_expr(
                sir_utils.make_binary_operator(
                    sir_utils.make_literal_access_expr("1.0",
                                                       SIR.BuiltinType.Float),
                    "/",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("b"),
                        "-",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr("a"),
                            "*",
                            sir_utils.make_field_access_expr("c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"), "*",
                sir_utils.make_var_access_expr("m")),
            "=",
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("d"),
                    "-",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("a"),
                        "*",
                        sir_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                sir_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    # d -= c * d[0, 0, 1]  (backward sweep)
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                         0, -1)
    body_ast_3 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "*",
                sir_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        output_file,
        sir_utils.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                output_name,
                sir_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    return sir
# Example #6
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [],
            ),
        ],
    )
    f = open(outputfile, "w")
    f.write(MessageToJson(sir))
    f.close()


if __name__ == "__main__":
    # The "no stmt" test has no generator of its own. It reuses the "one stmt"
    # output with the statement deleted, because a stage without statements
    # cannot be produced from SIR.
    #
    # Each case pairs an output path with the names of the Integer variables
    # declared (in order) in the stage body.
    for output_path, var_names in (
        ("../input/test_stage_split_all_statements_one_stmt.sir", ("a",)),
        ("../input/test_stage_split_all_statements_two_stmt.sir", ("a", "b")),
    ):
        stage_body = sir_utils.make_ast([
            sir_utils.make_var_decl_stmt(
                sir_utils.make_type(SIR.BuiltinType.Integer), var_name)
            for var_name in var_names
        ])
        make_stencil(output_path, stage_body)