def if_stmt(self, condition: expr, body: t.List, orelse: t.List):
    """Lower an ``if`` statement into an if-statement node.

    The condition is converted with ``self.expression`` and wrapped in an
    expression statement. The ``body`` and ``orelse`` statement lists are
    converted with ``self.statements`` and wrapped in block statements.
    ``make_if_stmt`` then combines the three parts.

    Args:
        condition: the condition expression node (presumably a Python
            ``ast.expr`` -- confirm against the caller).
        body: statements of the then-branch.
        orelse: statements of the else-branch. An empty list still produces
            an (empty) block statement.

    Returns:
        Whatever ``make_if_stmt`` builds from the three converted parts.
    """
    # Convert in the same order as before: condition, then-branch, else-branch.
    cond_stmt = make_expr_stmt(self.expression(condition))
    then_block = make_block_stmt(self.statements(body))
    else_block = make_block_stmt(self.statements(orelse))
    return make_if_stmt(cond_stmt, then_block, else_block)
def function_call():
    """Write the SIR test input for a stencil-function call on an unstructured grid.

    Builds two things:

    * a stencil function ``f`` whose single argument ``out`` is a
      cell-located field and whose body assigns the float literal ``2.0``
      to it;
    * a stencil ``generated`` whose forward vertical region calls
      ``f(out_cell)``.

    The SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_function_call.sir``. The path is
    relative to the current working directory.
    """
    outputfile = "../input/test_set_stage_location_type_function_call.sir"
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # Body of the stencil function: out = 2.0
    fun_ast = sir_utils.make_ast(
        [
            sir_utils.make_assignment_stmt(
                sir_utils.make_field_access_expr("out"),
                sir_utils.make_literal_access_expr(
                    value="2.0", type=sir_utils.BuiltinType.Float
                ),
                "=",
            ),
        ]
    )

    # The function's only argument: a cell-located field named "out".
    arg_field = sir_utils.make_field(
        "out",
        sir_utils.make_field_dimensions_unstructured(
            [SIR.LocationType.Value("Cell")], 1
        ),
    )

    fun = sir_utils.make_stencil_function(
        name="f",
        asts=[fun_ast],
        intervals=[interval],
        arguments=[sir_utils.make_stencil_function_arg(arg_field)],
    )

    # Stencil body: f(out_cell)
    body_ast = sir_utils.make_ast(
        [
            sir_utils.make_expr_stmt(
                expr=sir_utils.make_stencil_fun_call_expr(
                    callee="f",
                    arguments=[sir_utils.make_field_access_expr("out_cell")],
                )
            ),
        ]
    )

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "out_cell",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
        functions=[fun],
    )

    # Serialize before opening the file, so a serialization error cannot leave
    # a truncated output file behind. The ``with`` block closes the handle even
    # if the write raises.
    sir_json = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(sir_json)
def if_stmt():
    """Write the SIR test input for an if-statement on an unstructured grid.

    The stencil ``generated`` has one cell-located field ``in_cell``. Its
    forward vertical region does the following:

    * declares a float variable ``out_var_cell``;
    * then, in an ``if (out_var_cell)`` with no else-branch, assigns
      ``out_var_cell = in_cell``.

    The SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_if_stmt.sir``. The path is
    relative to the current working directory.
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"
    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0
    )

    body_ast = serial_utils.make_ast(
        [
            # float out_var_cell;
            serial_utils.make_var_decl_stmt(
                serial_utils.make_type(serial_utils.BuiltinType.Float),
                "out_var_cell",
            ),
            # if (out_var_cell) { out_var_cell = in_cell; }
            serial_utils.make_if_stmt(
                serial_utils.make_expr_stmt(
                    serial_utils.make_var_access_expr("out_var_cell")
                ),
                serial_utils.make_block_stmt(
                    serial_utils.make_assignment_stmt(
                        serial_utils.make_var_access_expr("out_var_cell"),
                        serial_utils.make_field_access_expr("in_cell"),
                        "=",
                    )
                ),
            ),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # Serialize before opening the file, so a serialization error cannot leave
    # a truncated output file behind. The ``with`` block closes the handle even
    # if the write raises.
    sir_json = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(sir_json)
def visit_If(self, node: gt_ir.If, **kwargs: Any) -> SIR.IfStmt:
    """Translate a ``gt_ir.If`` node into an SIR if-statement.

    Three parts are visited, in this order:

    * the condition, which is wrapped in an expression statement;
    * ``main_body``, used as the then-branch;
    * ``else_body``, used as the else-branch.

    The results are passed to ``sir_utils.make_if_stmt``. Extra keyword
    arguments are accepted but not used.
    """
    # Arguments are evaluated left to right: condition, then-branch, else-branch.
    return sir_utils.make_if_stmt(
        sir_utils.make_expr_stmt(self.visit(node.condition)),
        self.visit(node.main_body),
        self.visit(node.else_body),
    )
def main(args: argparse.Namespace):
    """Build an unstructured-grid SIR with nested ifs, compile it, and write the code.

    The single stencil works on five edge-located fields ``a`` to ``e``.
    Its forward vertical region computes::

        a = b / c + 5
        if (d) { a = b } else { if (e) { c = a + 1 } }

    The SIR is compiled with dawn4py's ``CXXNaiveIco`` backend, and the
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line arguments. Only ``args.verbose`` is read;
            when it is truthy, the SIR is printed as JSON.
    """
    full_interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0)

    def field(name):
        # Field-access expression for the field named ``name``.
        return serial_utils.make_field_access_expr(name)

    def float_literal(text):
        # Float literal expression with the given textual value.
        return serial_utils.make_literal_access_expr(text, AST.BuiltinType.Float)

    # a = b / c + 5
    compute_a = serial_utils.make_assignment_stmt(
        field("a"),
        serial_utils.make_binary_operator(
            serial_utils.make_binary_operator(field("b"), "/", field("c")),
            "+",
            float_literal("5"),
        ),
        "=",
    )

    # Then-branch of the outer if: a = b
    then_block = serial_utils.make_block_stmt(
        serial_utils.make_assignment_stmt(field("a"), field("b"), "=")
    )

    # Body of the inner if: c = a + 1
    inner_block = serial_utils.make_block_stmt(
        serial_utils.make_assignment_stmt(
            field("c"),
            serial_utils.make_binary_operator(field("a"), "+", float_literal("1")),
            "=",
        )
    )

    outer_cond = serial_utils.make_expr_stmt(field("d"))
    else_block = serial_utils.make_block_stmt(
        serial_utils.make_if_stmt(serial_utils.make_expr_stmt(field("e")), inner_block)
    )
    body_ast = serial_utils.make_ast(
        [compute_a, serial_utils.make_if_stmt(outer_cond, then_block, else_block)]
    )

    vertical_region = serial_utils.make_vertical_region_decl_stmt(
        body_ast, full_interval, AST.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region]),
                # All five stencil fields live on edges.
                [
                    serial_utils.make_field(
                        name,
                        serial_utils.make_field_dimensions_unstructured(
                            [AST.LocationType.Value("Edge")], 1
                        ),
                    )
                    for name in ("a", "b", "c", "d", "e")
                ],
            ),
        ],
    )

    if args.verbose:
        print(MessageToJson(sir))

    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out:
        out.write(generated)