예제 #1
0
    def if_stmt(self, condition: expr, body: t.List, orelse: t.List):
        """Translate an if-statement's parts into a SIR if-statement.

        ``condition`` is lowered with ``self.expression`` and wrapped in an
        expression statement. ``body`` and ``orelse`` are each lowered with
        ``self.statements`` and wrapped in a block statement. The three
        pieces are then combined with ``make_if_stmt``.
        """
        # Distinct local names: the parameters hold source-side nodes,
        # while these hold the lowered SIR statements.
        cond_stmt = make_expr_stmt(self.expression(condition))
        then_block = make_block_stmt(self.statements(body))
        else_block = make_block_stmt(self.statements(orelse))
        return make_if_stmt(cond_stmt, then_block, else_block)
예제 #2
0
def function_call():
    """Write a SIR test input exercising a stencil-function call on cells.

    Builds a stencil function ``f`` with a single cell-located field argument
    ``out``; its body assigns the float literal ``2.0`` to ``out``. The stencil
    ``generated`` declares the cell field ``out_cell`` and, inside a
    ``SIR.VerticalRegion.Forward`` region over the interval
    ``SIR.Interval.Start``..``SIR.Interval.End`` (zero offsets), evaluates the
    call ``f(out_cell)`` as an expression statement.

    The resulting SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_function_call.sir``, a path
    relative to the current working directory.
    """
    outputfile = "../input/test_set_stage_location_type_function_call.sir"

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0,
                                       0)

    fun_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"),
            sir_utils.make_literal_access_expr(
                value="2.0", type=sir_utils.BuiltinType.Float),
            "=",
        ),
    ])

    arg_field = sir_utils.make_field(
        "out",
        sir_utils.make_field_dimensions_unstructured(
            [SIR.LocationType.Value("Cell")], 1))

    fun = sir_utils.make_stencil_function(
        name='f',
        asts=[fun_ast],
        intervals=[interval],
        arguments=[sir_utils.make_stencil_function_arg(arg_field)])

    body_ast = sir_utils.make_ast([
        sir_utils.make_expr_stmt(expr=sir_utils.make_stencil_fun_call_expr(
            callee="f",
            arguments=[sir_utils.make_field_access_expr("out_cell")])),
    ])

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"), [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "out_cell",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1),
                    ),
                ],
            ),
        ],
        functions=[fun])

    # The context manager closes the file even if serialization or the
    # write raises. The previous open()/close() pair leaked the handle then.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
예제 #3
0
def if_stmt():
    """Write a SIR test input exercising an if-statement on cells.

    The stencil ``generated`` declares the cell field ``in_cell``. Inside a
    ``SIR.VerticalRegion.Forward`` region over the interval
    ``SIR.Interval.Start``..``SIR.Interval.End`` (zero offsets), its body is::

        float out_var_cell;
        if (out_var_cell) { out_var_cell = in_cell; }

    There is no else branch. The resulting SIR is serialized with
    ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_if_stmt.sir``, a path relative to
    the current working directory.
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"

    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    body_ast = serial_utils.make_ast(
        [
            serial_utils.make_var_decl_stmt(
                serial_utils.make_type(serial_utils.BuiltinType.Float),
                "out_var_cell"),
            serial_utils.make_if_stmt(
                serial_utils.make_expr_stmt(
                    serial_utils.make_var_access_expr("out_var_cell")),
                serial_utils.make_block_stmt(
                    serial_utils.make_assignment_stmt(
                        serial_utils.make_var_access_expr("out_var_cell"),
                        serial_utils.make_field_access_expr("in_cell"),
                        "=",
                    )),
            ),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # The context manager closes the file even if serialization or the
    # write raises. The previous open()/close() pair leaked the handle then.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
예제 #4
0
 def visit_If(self, node: gt_ir.If, **kwargs: Any) -> SIR.IfStmt:
     """Lower a ``gt_ir.If`` node to a SIR ``IfStmt``.

     The condition is visited and wrapped in an expression statement before
     being passed as the if-condition. ``node.main_body`` and
     ``node.else_body`` are each lowered through ``self.visit``. ``kwargs``
     is not used by this method.
     """
     # Arguments are evaluated left to right: the condition first, then the
     # main body, then the else body.
     return sir_utils.make_if_stmt(
         sir_utils.make_expr_stmt(self.visit(node.condition)),
         self.visit(node.main_body),
         self.visit(node.else_body),
     )
예제 #5
0
def main(args: argparse.Namespace):
    """Build an unstructured SIR with nested if-statements, compile it, save it.

    All five fields ``a``..``e`` are located on edges. Inside a
    ``AST.VerticalRegion.Forward`` region over the interval
    ``AST.Interval.Start``..``AST.Interval.End`` (zero offsets), the stencil
    body is::

        a = b / c + 5;
        if (d) { a = b; }
        else { if (e) { c = a + 1; } }

    The SIR is compiled with ``dawn4py`` using the ``CXXNaiveIco`` backend, and
    the generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options. Only ``args.verbose`` is read;
            when truthy, the SIR is printed as JSON before compiling.
    """
    interval = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0
    )

    def field_ref(name):
        # Shorthand for a field access expression.
        return serial_utils.make_field_access_expr(name)

    def float_lit(text):
        # Shorthand for a float literal expression.
        return serial_utils.make_literal_access_expr(text, AST.BuiltinType.Float)

    # a = b / c + 5
    assign_a = serial_utils.make_assignment_stmt(
        field_ref("a"),
        serial_utils.make_binary_operator(
            serial_utils.make_binary_operator(field_ref("b"), "/", field_ref("c")),
            "+",
            float_lit("5"),
        ),
        "=",
    )

    # then-branch of the outer if: { a = b; }
    then_block = serial_utils.make_block_stmt(
        serial_utils.make_assignment_stmt(field_ref("a"), field_ref("b"), "=")
    )

    # then-branch of the inner if: { c = a + 1; }
    inner_then_block = serial_utils.make_block_stmt(
        serial_utils.make_assignment_stmt(
            field_ref("c"),
            serial_utils.make_binary_operator(field_ref("a"), "+", float_lit("1")),
            "=",
        )
    )

    outer_cond = serial_utils.make_expr_stmt(field_ref("d"))
    # else-branch of the outer if: { if (e) { c = a + 1; } }
    else_block = serial_utils.make_block_stmt(
        serial_utils.make_if_stmt(
            serial_utils.make_expr_stmt(field_ref("e")), inner_then_block
        )
    )

    body_ast = serial_utils.make_ast(
        [
            assign_a,
            serial_utils.make_if_stmt(outer_cond, then_block, else_block),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, AST.VerticalRegion.Forward
    )

    stencil_ast = serial_utils.make_ast([vertical_region_stmt])
    edge_fields = [
        serial_utils.make_field(
            name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Edge")], 1
            ),
        )
        for name in ("a", "b", "c", "d", "e")
    ]
    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [serial_utils.make_stencil(OUTPUT_NAME, stencil_ast, edge_fields)],
    )

    if args.verbose:
        print(MessageToJson(sir))

    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(generated)