예제 #1
0
    def if_stmt(self, condition: expr, body: t.List, orelse: t.List):
        """Build an if-statement node from a condition and two statement lists.

        The condition is converted with ``self.expression`` and wrapped by
        ``make_expr_stmt``. ``body`` and ``orelse`` are each converted with
        ``self.statements`` and wrapped by ``make_block_stmt``.

        :param condition: the condition expression to convert.
        :param body: statements forming the second argument of ``make_if_stmt``.
        :param orelse: statements forming the third (else) argument of
            ``make_if_stmt``.
        :returns: whatever ``make_if_stmt`` builds from the three parts.
        """
        cond_stmt = make_expr_stmt(self.expression(condition))
        then_block = make_block_stmt(self.statements(body))
        else_block = make_block_stmt(self.statements(orelse))
        return make_if_stmt(cond_stmt, then_block, else_block)
예제 #2
0
 def visit_BlockStmt(self,
                     node: gt_ir.BlockStmt,
                     *,
                     make_block: bool = True,
                     **kwargs: Any) -> SIR.BlockStmt:
     """Translate a block statement, dropping field declarations.

     Every statement in ``node.stmts`` that is not a ``gt_ir.FieldDecl`` is
     passed to ``self.visit``, in order.

     :param node: the block whose statements are translated.
     :param make_block: when true (the default), wrap the visited statements
         with ``sir_utils.make_block_stmt``; when false, return them as a
         plain list.
     :param kwargs: accepted but not used here.
     :returns: the wrapped block statement, or a plain list of visited
         statements when ``make_block`` is false. The return annotation
         does not reflect this second case.
     """
     visited = []
     for stmt in node.stmts:
         # Field declarations are skipped rather than translated.
         if isinstance(stmt, gt_ir.FieldDecl):
             continue
         visited.append(self.visit(stmt))

     if not make_block:
         return visited
     return sir_utils.make_block_stmt(visited)
예제 #3
0
def if_stmt():
    """Write a test SIR whose vertical region contains an if-statement.

    The stencil ``"generated"`` declares one unstructured field ``in_cell``
    located on cells. Its single vertical region (``SIR.VerticalRegion.Forward``,
    from ``SIR.Interval.Start`` to ``SIR.Interval.End``) declares a float
    variable ``out_var_cell``. Under an if whose condition is that variable,
    it assigns ``in_cell`` to the variable.

    The SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_if_stmt.sir``, a path relative to
    the current working directory.
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    body_ast = sir_utils.make_ast([
        sir_utils.make_var_decl_stmt(
            sir_utils.make_type(sir_utils.BuiltinType.Float), "out_var_cell"),
        sir_utils.make_if_stmt(
            sir_utils.make_expr_stmt(
                sir_utils.make_var_access_expr("out_var_cell")),
            sir_utils.make_block_stmt(
                sir_utils.make_assignment_stmt(
                    sir_utils.make_var_access_expr("out_var_cell"),
                    sir_utils.make_field_access_expr("in_cell"),
                    "=",
                ))),
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "in_cell",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1),
                    ),
                ],
            ),
        ],
    )

    # Serialize before opening the file, so a failure in MessageToJson does
    # not leave a truncated output file behind. The context manager closes
    # the file even if the write raises (the original open/close pair leaked
    # the handle on error).
    sir_json = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(sir_json)
예제 #4
0
def main(args: argparse.Namespace):
    """Compile a stencil with nested if-statements over edge fields.

    The vertical region body is, in pseudo-code::

        a = b / c + 5
        if (d) { a = b }
        else { if (e) { c = a + 1 } }

    All five fields ``a``-``e`` are declared as unstructured fields on edges.
    The SIR is compiled with the ``CXXNaiveIco`` backend, and the generated
    code is written to ``OUTPUT_PATH``.

    :param args: parsed command-line arguments. ``args.verbose`` enables
        printing the SIR as JSON before compilation.
    """

    def access(name):
        # Field-access expression on the named field.
        return serial_utils.make_field_access_expr(name)

    def float_literal(text):
        # Float literal expression built from its textual value.
        return serial_utils.make_literal_access_expr(text,
                                                     AST.BuiltinType.Float)

    def assign(lhs, rhs):
        # Plain assignment statement ``lhs = rhs``.
        return serial_utils.make_assignment_stmt(lhs, rhs, "=")

    interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End,
                                          0, 0)

    # a = b / c + 5
    b_over_c = serial_utils.make_binary_operator(access("b"), "/", access("c"))
    assign_a = assign(
        access("a"),
        serial_utils.make_binary_operator(b_over_c, "+", float_literal("5")))

    # then-branch: { a = b }
    then_block = serial_utils.make_block_stmt(assign(access("a"), access("b")))

    # innermost branch: { c = a + 1 }
    inner_block = serial_utils.make_block_stmt(
        assign(
            access("c"),
            serial_utils.make_binary_operator(access("a"), "+",
                                              float_literal("1"))))

    # else-branch: { if (e) { c = a + 1 } }
    else_block = serial_utils.make_block_stmt(
        serial_utils.make_if_stmt(serial_utils.make_expr_stmt(access("e")),
                                  inner_block))

    outer_if = serial_utils.make_if_stmt(
        serial_utils.make_expr_stmt(access("d")), then_block, else_block)

    body_ast = serial_utils.make_ast([assign_a, outer_if])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, AST.VerticalRegion.Forward)

    # Every field has the same edge-located unstructured dimensions.
    edge_fields = [
        serial_utils.make_field(
            field_name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Edge")], 1),
        ) for field_name in ("a", "b", "c", "d", "e")
    ]

    stencil = serial_utils.make_stencil(
        OUTPUT_NAME,
        serial_utils.make_ast([vertical_region_stmt]),
        edge_fields,
    )

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [stencil],
    )

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)