コード例 #1
0
ファイル: global_var.py プロジェクト: MeteoSwiss-APN/dawn
def main(args: argparse.Namespace):
    """Build, compile and write out a stencil that scales ``in`` by ``dt``.

    The stencil body is the single statement
    ``out[0, 0, 0] = dt * in[1, 0, 0]`` inside a forward vertical region,
    where ``dt`` is accessed as an external variable and registered in the
    global variable map with the value 0.5.  The SIR is compiled with the
    ``CXXNaive`` backend and the generated code is written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is printed as JSON before compiling.
    """
    full_range = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0)

    # out[0, 0, 0] = dt * in[1, 0, 0]
    target = serial_utils.make_field_access_expr("out", [0, 0, 0])
    scaled_input = serial_utils.make_binary_operator(
        serial_utils.make_var_access_expr("dt", is_external=True),
        "*",
        serial_utils.make_field_access_expr("in", [1, 0, 0]),
    )
    statement = serial_utils.make_assignment_stmt(target, scaled_input, "=")
    stencil_body = serial_utils.make_ast([statement])

    region = serial_utils.make_vertical_region_decl_stmt(
        stencil_body, full_range, AST.VerticalRegion.Forward)

    # value of the external variable referenced by the stencil body
    global_vars = AST.GlobalVariableMap()
    global_vars.map["dt"].double_value = 0.5

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([region]),
                [
                    serial_utils.make_field(
                        field_name,
                        serial_utils.make_field_dimensions_cartesian())
                    for field_name in ("in", "out")
                ],
            )
        ],
        global_variables=global_vars,
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)
コード例 #2
0
ファイル: copy_stencil.py プロジェクト: MeteoSwiss-APN/dawn
def main(args: argparse.Namespace):
    """Build, compile and write out a copy stencil using the CUDA backend.

    The stencil body is the single statement ``out[0, 0, 0] = in[1, 0, 0]``
    inside a forward vertical region.  The generated code is written to
    ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is printed as JSON before compiling.
    """
    full_range = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0)

    # create the out = in[i+1] statement
    copy_stmt = serial_utils.make_assignment_stmt(
        serial_utils.make_field_access_expr("out", [0, 0, 0]),
        serial_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    stencil_body = serial_utils.make_ast([copy_stmt])

    region = serial_utils.make_vertical_region_decl_stmt(
        stencil_body, full_range, AST.VerticalRegion.Forward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([region]),
                [
                    serial_utils.make_field(
                        field_name,
                        serial_utils.make_field_dimensions_cartesian())
                    for field_name in ("in", "out")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)
コード例 #3
0
 def visit_FieldDecl(self, node: gt_ir.FieldDecl, **kwargs: Any) -> None:
     """Register a fresh ``self.field_info`` entry for the field ``node``.

     The entry is keyed by ``node.name`` and holds the SIR field declaration
     (marked temporary when ``node.is_api`` is falsy), an unset ``access``,
     a zero extent and an empty ``inputs`` set.
     """
     # NOTE Add unstructured support here
     # One flag per entry of DOMAIN_AXES: 1 if the field has that axis, else 0.
     axis_mask = [int(ax in node.axes) for ax in DOMAIN_AXES]
     field_decl = sir_utils.make_field(
         name=node.name,
         dimensions=sir_utils.make_field_dimensions_cartesian(axis_mask),
         is_temporary=not node.is_api,
     )
     self.field_info[node.name] = {
         "field_decl": field_decl,
         "access": None,
         "extent": gt_definitions.Extent.zeros(),
         "inputs": set(),
     }
コード例 #4
0
ファイル: utils.py プロジェクト: circleci/dawn
def make_copy_stencil_sir(name=None):
    """Build the SIR of a copy stencil (``out[0, 0, 0] = in[1, 0, 0]``).

    Args:
        name: Name given to the stencil.  When ``None``, ``"copy_stencil"``
            is used.  The file name passed to ``make_sir`` is
            ``f"{name}.cpp"``.

    Returns:
        The value returned by ``sir_utils.make_sir``.
    """
    stencil_name = "copy_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    full_range = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # create the out = in[i+1] statement
    copy_stmt = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    stencil_body = sir_utils.make_ast([copy_stmt])

    region = sir_utils.make_vertical_region_decl_stmt(
        stencil_body, full_range, SIR.VerticalRegion.Forward)

    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([region]),
                [
                    sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                    for field_name in ("in", "out")
                ],
            )
        ],
    )
コード例 #5
0
def main(args: argparse.Namespace):
    """Build, compile and write out a three-region tridiagonal-solve stencil.

    The stencil works on the fields ``a``, ``b``, ``c`` and ``d`` and consists
    of three vertical regions, in this order:

    1. forward, interval ``(Start, End, 0, 0)``: ``c = c / b``
    2. forward, interval ``(Start, End, 1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, interval ``(Start, End, 0, -1)``: ``d -= c * d[0, 0, 1]``

    The SIR is compiled with ``backend="cuda"`` and the generated code is
    written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is pretty-printed before compiling.
    """
    # Short local names for the expression builders used throughout.
    access = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator
    assign = sir_utils.make_assignment_stmt
    var = sir_utils.make_var_access_expr

    # ---- First vertical region statement ----
    first_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)
    first_body = sir_utils.make_ast([
        assign(access("c"), binop(access("c"), "/", access("b")), "="),
    ])
    first_region = sir_utils.make_vertical_region_decl_stmt(
        first_body, first_interval, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    second_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 1, 0)

    # NOTE(review): "m" is declared with BuiltinType.Integer although its
    # initializer is built from a Float literal -- confirm the intended type.
    m_type = sir_utils.make_type(SIR.BuiltinType.Integer)
    # 1.0 / (b - a * c[0, 0, -1])
    m_value = binop(
        sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
        "/",
        binop(
            access("b"),
            "-",
            binop(access("a"), "*", access("c", [0, 0, -1])),
        ),
    )
    declare_m = sir_utils.make_var_decl_stmt(
        m_type, "m", 0, "=", sir_utils.make_expr(m_value))
    # c = c * m
    scale_c = assign(access("c"), binop(access("c"), "*", var("m")), "=")
    # d = (d - a * d[0, 0, -1]) * m
    update_d = assign(
        access("d"),
        binop(
            binop(
                access("d"),
                "-",
                binop(access("a"), "*", access("d", [0, 0, -1])),
            ),
            "*",
            var("m"),
        ),
        "=",
    )
    second_body = sir_utils.make_ast([declare_m, scale_c, update_d])
    second_region = sir_utils.make_vertical_region_decl_stmt(
        second_body, second_interval, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    third_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, -1)
    third_body = sir_utils.make_ast([
        assign(
            access("d"),
            binop(access("c"), "*", access("d", [0, 0, 1])),
            "-=",
        ),
    ])
    third_region = sir_utils.make_vertical_region_decl_stmt(
        third_body, third_interval, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast(
                    [first_region, second_region, third_region]),
                [
                    sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                    for field_name in ("a", "b", "c", "d")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    generated = dawn4py.compile(sir, backend="cuda")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)
コード例 #6
0
def main(args: argparse.Namespace):
    """Build, compile and write out the ``global_indexing`` stencil.

    The stencil AST is the statement returned by
    ``create_vertical_region_stmt()`` followed by eight boundary-correction
    regions: four restricted along a single axis (``i`` or ``j``) and four
    restricted along both.  Each restriction is either the interval
    ``(Start, Start, 0, 1)`` or ``(End, End, -1, 0)``.  The SIR is compiled
    with ``backend="c++-naive"`` and written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is pretty-printed before compiling.
    """

    def at_start():
        # A new interval object per call, as in a literal make_interval call.
        return sir_utils.make_interval(
            SIR.Interval.Start, SIR.Interval.Start, 0, 1)

    def at_end():
        # A new interval object per call, as in a literal make_interval call.
        return sir_utils.make_interval(
            SIR.Interval.End, SIR.Interval.End, -1, 0)

    correct = create_boundary_correction_region

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                "global_indexing",
                sir_utils.make_ast([
                    create_vertical_region_stmt(),
                    # corrections restricted along one axis only
                    correct(value="4", i_interval=at_end()),
                    correct(value="8", i_interval=at_start()),
                    correct(value="6", j_interval=at_end()),
                    correct(value="2", j_interval=at_start()),
                    # corrections restricted along both axes
                    correct(value="1",
                            j_interval=at_start(),
                            i_interval=at_start()),
                    correct(value="3",
                            j_interval=at_start(),
                            i_interval=at_end()),
                    correct(value="7",
                            j_interval=at_end(),
                            i_interval=at_start()),
                    correct(value="5",
                            j_interval=at_end(),
                            i_interval=at_end()),
                ]),
                [
                    sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                    for field_name in ("in", "out")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    generated = dawn4py.compile(sir, backend="c++-naive")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)
コード例 #7
0
def main(args: argparse.Namespace):
    """Build a tridiagonal-solve stencil and compile it to CUDA.

    The stencil works on the fields ``a``, ``b``, ``c`` and ``d`` and consists
    of three vertical regions, in this order:

    1. forward, interval ``(Start, End, 0, 0)``: ``c = c / b``
    2. forward, interval ``(Start, End, 1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, interval ``(Start, End, 0, -1)``: ``d -= c * d[0, 0, 1]``

    Compilation uses the default pass groups with
    ``PassGroup.MultiStageMerger`` inserted at index 1, and the ``CUDA``
    backend.  The generated code is written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is printed as JSON before compiling.
    """
    # Short local names for the expression builders used throughout.
    access = serial_utils.make_field_access_expr
    binop = serial_utils.make_binary_operator
    assign = serial_utils.make_assignment_stmt
    var = serial_utils.make_var_access_expr

    # ---- First vertical region statement ----
    first_interval = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0)
    first_body = serial_utils.make_ast([
        assign(access("c"), binop(access("c"), "/", access("b")), "="),
    ])
    first_region = serial_utils.make_vertical_region_decl_stmt(
        first_body, first_interval, AST.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    second_interval = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 1, 0)

    # NOTE(review): "m" is declared with BuiltinType.Integer although its
    # initializer is built from a Float literal -- confirm the intended type.
    m_type = serial_utils.make_type(AST.BuiltinType.Integer)
    # 1.0 / (b - a * c[0, 0, -1])
    m_value = binop(
        serial_utils.make_literal_access_expr("1.0", AST.BuiltinType.Float),
        "/",
        binop(
            access("b"),
            "-",
            binop(access("a"), "*", access("c", [0, 0, -1])),
        ),
    )
    declare_m = serial_utils.make_var_decl_stmt(
        m_type, "m", 0, "=", serial_utils.make_expr(m_value))
    # c = c * m
    scale_c = assign(access("c"), binop(access("c"), "*", var("m")), "=")
    # d = (d - a * d[0, 0, -1]) * m
    update_d = assign(
        access("d"),
        binop(
            binop(
                access("d"),
                "-",
                binop(access("a"), "*", access("d", [0, 0, -1])),
            ),
            "*",
            var("m"),
        ),
        "=",
    )
    second_body = serial_utils.make_ast([declare_m, scale_c, update_d])
    second_region = serial_utils.make_vertical_region_decl_stmt(
        second_body, second_interval, AST.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    third_interval = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, -1)
    third_body = serial_utils.make_ast([
        assign(
            access("d"),
            binop(access("c"), "*", access("d", [0, 0, 1])),
            "-=",
        ),
    ])
    third_region = serial_utils.make_vertical_region_decl_stmt(
        third_body, third_interval, AST.VerticalRegion.Backward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast(
                    [first_region, second_region, third_region]),
                [
                    serial_utils.make_field(
                        field_name,
                        serial_utils.make_field_dimensions_cartesian())
                    for field_name in ("a", "b", "c", "d")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    groups = dawn4py.default_pass_groups()
    groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    generated = dawn4py.compile(
        sir, groups=groups, backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)
コード例 #8
0
def main(args: argparse.Namespace):
    """Build a Laplacian stencil, dump its SIR to disk and compile it.

    The stencil body is the single statement::

        out = (in * -4.0
               + (in[1, 0, 0] + (in[-1, 0, 0]
                                 + (in[0, 1, 0] + in[0, -1, 0]))))
              / (dx * dx)

    inside a forward vertical region, where ``dx`` is accessed as an external
    variable and registered in the global variable map with the value 0.0.
    The serialized SIR is written to ``./laplacian_stencil_from_python.sir``;
    the code generated by the ``CXXNaive`` backend is written to
    ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is pretty-printed before compiling.
    """
    interval = serial_utils.make_interval(SIR.Interval.Start, SIR.Interval.End,
                                          0, 0)

    # create the laplace statement
    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out", [0, 0, 0]),
            serial_utils.make_binary_operator(
                # numerator: in * -4.0 + sum of the four horizontal neighbours
                serial_utils.make_binary_operator(
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [0, 0, 0]),
                        "*",
                        serial_utils.make_literal_access_expr(
                            "-4.0", serial_utils.BuiltinType.Float),
                    ),
                    "+",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [1, 0, 0]),
                        "+",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr(
                                "in", [-1, 0, 0]),
                            "+",
                            serial_utils.make_binary_operator(
                                serial_utils.make_field_access_expr(
                                    "in", [0, 1, 0]),
                                "+",
                                serial_utils.make_field_access_expr(
                                    "in", [0, -1, 0]),
                            ),
                        ),
                    ),
                ),
                "/",
                # denominator: dx * dx
                serial_utils.make_binary_operator(
                    serial_utils.make_var_access_expr("dx", is_external=True),
                    "*",
                    serial_utils.make_var_access_expr("dx", is_external=True),
                ),
            ),
            "=",
        ),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    # NOTE(review): "dx" is registered as 0.0 while the stencil divides by
    # dx * dx -- presumably the value is overridden before the stencil runs;
    # confirm.
    stencils_globals = serial_utils.GlobalVariableMap()
    stencils_globals.map["dx"].double_value = 0.0

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "out", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "in", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
        global_variables=stencils_globals,
    )

    # print the SIR
    if args.verbose:
        serial_utils.pprint(sir)

    # serialize the SIR to file
    serialized_sir = serial_utils.to_json(sir)
    # The file is opened in binary mode, so a text payload has to be encoded
    # first; a bytes payload is written unchanged.
    if isinstance(serialized_sir, str):
        serialized_sir = serialized_sir.encode("utf-8")
    # The context manager closes the file even if the write raises.
    with open("./laplacian_stencil_from_python.sir", "wb") as sir_file:
        sir_file.write(serialized_sir)

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
コード例 #9
0
ファイル: utils.py プロジェクト: circleci/dawn
def make_hori_diff_stencil_sir(name=None):
    """Build the SIR of a two-statement horizontal-diffusion stencil.

    Both statements have the same shape, with ``src`` read at the centre and
    at the four offsets ``[1, 0, 0]``, ``[-1, 0, 0]``, ``[0, 1, 0]`` and
    ``[0, -1, 0]``::

        dst = -4.0 * src + coeff * (src[1,0,0] + (src[-1,0,0]
                                    + (src[0,1,0] + src[0,-1,0])))

    The first statement computes ``lap`` from ``in``; the second computes
    ``out`` from ``lap``.  ``lap`` is declared as a temporary field.

    Args:
        name: Name given to the stencil.  When ``None``,
            ``"hori_diff_stencil"`` is used.  The file name passed to
            ``make_sir`` is ``f"{name}.cpp"``.

    Returns:
        The value returned by ``sir_utils.make_sir``.
    """
    stencil_name = "hori_diff_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    full_range = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    access = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator

    def stencil_stmt(dst, src):
        # dst = -4.0 * src + coeff * (sum of the four offset reads of src)
        return sir_utils.make_assignment_stmt(
            access(dst),
            binop(
                binop(
                    sir_utils.make_literal_access_expr(
                        "-4.0", SIR.BuiltinType.Float),
                    "*",
                    access(src),
                ),
                "+",
                binop(
                    access("coeff"),
                    "*",
                    binop(
                        access(src, [1, 0, 0]),
                        "+",
                        binop(
                            access(src, [-1, 0, 0]),
                            "+",
                            binop(
                                access(src, [0, 1, 0]),
                                "+",
                                access(src, [0, -1, 0]),
                            ),
                        ),
                    ),
                ),
            ),
            "=",
        )

    # create the stencil body AST
    stencil_body = sir_utils.make_ast([
        stencil_stmt("lap", "in"),
        stencil_stmt("out", "lap"),
    ])

    region = sir_utils.make_vertical_region_decl_stmt(
        stencil_body, full_range, SIR.VerticalRegion.Forward)

    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([region]),
                [
                    *(sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                      for field_name in ("in", "out", "coeff")),
                    sir_utils.make_field(
                        "lap",
                        sir_utils.make_field_dimensions_cartesian(),
                        is_temporary=True),
                ],
            )
        ],
    )
コード例 #10
0
ファイル: utils.py プロジェクト: circleci/dawn
def make_tridiagonal_solve_stencil_sir(name=None):
    """Build the SIR of a three-region tridiagonal-solve stencil.

    The stencil works on the fields ``a``, ``b``, ``c`` and ``d`` and consists
    of three vertical regions, in this order:

    1. forward, interval ``(Start, End, 0, 0)``: ``c = c / b``
    2. forward, interval ``(Start, End, 1, 0)``:
       ``m = 1.0 / (b - a * c[0, 0, -1])``, ``c = c * m``,
       ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, interval ``(Start, End, 0, -1)``: ``d -= c * d[0, 0, 1]``

    Args:
        name: Name given to the stencil.  When ``None``,
            ``"tridiagonal_solve_stencil"`` is used.  The file name passed to
            ``make_sir`` is ``f"{name}.cpp"``.

    Returns:
        The value returned by ``sir_utils.make_sir``.
    """
    stencil_name = "tridiagonal_solve_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    # Short local names for the expression builders used throughout.
    access = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator
    assign = sir_utils.make_assignment_stmt
    var = sir_utils.make_var_access_expr

    # ---- First vertical region statement ----
    first_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)
    first_body = sir_utils.make_ast([
        assign(access("c"), binop(access("c"), "/", access("b")), "="),
    ])
    first_region = sir_utils.make_vertical_region_decl_stmt(
        first_body, first_interval, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement ----
    second_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 1, 0)

    # NOTE(review): "m" is declared with BuiltinType.Integer although its
    # initializer is built from a Float literal -- confirm the intended type.
    m_type = sir_utils.make_type(SIR.BuiltinType.Integer)
    # 1.0 / (b - a * c[0, 0, -1])
    m_value = binop(
        sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
        "/",
        binop(
            access("b"),
            "-",
            binop(access("a"), "*", access("c", [0, 0, -1])),
        ),
    )
    declare_m = sir_utils.make_var_decl_stmt(
        m_type, "m", 0, "=", sir_utils.make_expr(m_value))
    # c = c * m
    scale_c = assign(access("c"), binop(access("c"), "*", var("m")), "=")
    # d = (d - a * d[0, 0, -1]) * m
    update_d = assign(
        access("d"),
        binop(
            binop(
                access("d"),
                "-",
                binop(access("a"), "*", access("d", [0, 0, -1])),
            ),
            "*",
            var("m"),
        ),
        "=",
    )
    second_body = sir_utils.make_ast([declare_m, scale_c, update_d])
    second_region = sir_utils.make_vertical_region_decl_stmt(
        second_body, second_interval, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement ----
    third_interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, -1)
    third_body = sir_utils.make_ast([
        assign(
            access("d"),
            binop(access("c"), "*", access("d", [0, 0, 1])),
            "-=",
        ),
    ])
    third_region = sir_utils.make_vertical_region_decl_stmt(
        third_body, third_interval, SIR.VerticalRegion.Backward)

    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast(
                    [first_region, second_region, third_region]),
                [
                    sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                    for field_name in ("a", "b", "c", "d")
                ],
            )
        ],
    )
コード例 #11
0
ファイル: hori_diff_stencil.py プロジェクト: havogt/dawn
def main(args: argparse.Namespace):
    """Build a horizontal-diffusion stencil and compile it to CUDA.

    Both statements of the stencil body have the same shape, with ``src``
    read at the centre and at the four offsets ``[1, 0, 0]``, ``[-1, 0, 0]``,
    ``[0, 1, 0]`` and ``[0, -1, 0]``::

        dst = -4.0 * src + coeff * (src[1,0,0] + (src[-1,0,0]
                                    + (src[0,1,0] + src[0,-1,0])))

    The first statement computes ``lap`` from ``in``; the second computes
    ``out`` from ``lap``.  ``lap`` is declared as a temporary field.  The
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments.  Only ``args.verbose`` is read;
            when it is truthy the SIR is pretty-printed before compiling.
    """
    full_range = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    access = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator

    def stencil_stmt(dst, src):
        # dst = -4.0 * src + coeff * (sum of the four offset reads of src)
        return sir_utils.make_assignment_stmt(
            access(dst),
            binop(
                binop(
                    sir_utils.make_literal_access_expr(
                        "-4.0", SIR.BuiltinType.Float),
                    "*",
                    access(src),
                ),
                "+",
                binop(
                    access("coeff"),
                    "*",
                    binop(
                        access(src, [1, 0, 0]),
                        "+",
                        binop(
                            access(src, [-1, 0, 0]),
                            "+",
                            binop(
                                access(src, [0, 1, 0]),
                                "+",
                                access(src, [0, -1, 0]),
                            ),
                        ),
                    ),
                ),
            ),
            "=",
        )

    # create the stencil body AST
    stencil_body = sir_utils.make_ast([
        stencil_stmt("lap", "in"),
        stencil_stmt("out", "lap"),
    ])

    region = sir_utils.make_vertical_region_decl_stmt(
        stencil_body, full_range, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast([region]),
                [
                    *(sir_utils.make_field(
                        field_name,
                        sir_utils.make_field_dimensions_cartesian())
                      for field_name in ("in", "out", "coeff")),
                    sir_utils.make_field(
                        "lap",
                        sir_utils.make_field_dimensions_cartesian(),
                        is_temporary=True),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)