def create_vertical_region_stmt2():
    """Build the second vertical-region statement of the unstructured stencil.

    The region uses an interval created from ``Start``/``End`` with offsets
    ``(1, 0)`` and the ``Forward`` loop order. Its body holds three statements:

    1. declare the ``Float`` local ``m`` as ``1.0 / (b - a * c_shifted)``,
    2. assign ``c = c * m``,
    3. assign ``d = (d - a * d_shifted) * m``,

    where the ``*_shifted`` accesses are created with
    ``make_unstructured_offset(False)`` and a trailing ``-1`` argument.

    Returns:
        The object returned by ``sir_utils.make_vertical_region_decl_stmt``.
    """
    # Short local aliases for the builders that are used repeatedly below.
    field = sir_utils.make_unstructured_field_access_expr
    binop = sir_utils.make_binary_operator
    var = sir_utils.make_var_access_expr
    assign = sir_utils.make_assignment_stmt

    region_interval = sir_utils.make_interval(
        sir_utils.Interval.Start, sir_utils.Interval.End, 1, 0
    )

    # --- statement 1: m = 1.0 / (b - a * c_shifted) ---
    m_type = sir_utils.make_type(sir_utils.BuiltinType.Float)
    one = sir_utils.make_literal_access_expr("1.0", sir_utils.BuiltinType.Float)
    b_here = field("b")
    a_here = field("a")
    c_shifted = field("c", sir_utils.make_unstructured_offset(False), -1)
    a_times_c = binop(a_here, "*", c_shifted)
    denominator = binop(b_here, "-", a_times_c)
    reciprocal = binop(one, "/", denominator)
    declare_m = sir_utils.make_var_decl_stmt(
        m_type, "m", 0, "=", sir_utils.make_expr(reciprocal)
    )

    # --- statement 2: c = c * m ---
    c_target = field("c")
    c_scaled = binop(field("c"), "*", var("m"))
    update_c = assign(c_target, c_scaled, "=")

    # --- statement 3: d = (d - a * d_shifted) * m ---
    d_target = field("d")
    d_here = field("d")
    a_again = field("a")
    d_shifted = field("d", sir_utils.make_unstructured_offset(False), -1)
    a_times_d = binop(a_again, "*", d_shifted)
    d_reduced = binop(d_here, "-", a_times_d)
    d_scaled = binop(d_reduced, "*", var("m"))
    update_d = assign(d_target, d_scaled, "=")

    region_body = sir_utils.make_ast([declare_m, update_c, update_d])

    return sir_utils.make_vertical_region_decl_stmt(
        region_body, region_interval, SIR.VerticalRegion.Forward
    )
def if_stmt():
    """Generate the ``if``-statement test SIR and write it to disk as JSON.

    The stencil named ``"generated"`` has one unstructured field ``in_cell``
    (location type ``Cell``) and a single forward vertical region whose body

    1. declares a ``Float`` local ``out_var_cell`` (no initializer), and
    2. contains an ``if`` statement whose condition is ``out_var_cell`` and
       whose then-block assigns ``out_var_cell = in_cell``.

    The SIR is serialized with ``MessageToJson`` and written to
    ``../input/test_set_stage_location_type_if_stmt.sir`` (relative to the
    current working directory).
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"

    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0
    )

    declare_out_var = serial_utils.make_var_decl_stmt(
        serial_utils.make_type(serial_utils.BuiltinType.Float),
        "out_var_cell",
    )
    conditional = serial_utils.make_if_stmt(
        # Condition: the local variable itself, wrapped as an expression statement.
        serial_utils.make_expr_stmt(
            serial_utils.make_var_access_expr("out_var_cell")
        ),
        # Then-branch: out_var_cell = in_cell
        serial_utils.make_block_stmt(
            serial_utils.make_assignment_stmt(
                serial_utils.make_var_access_expr("out_var_cell"),
                serial_utils.make_field_access_expr("in_cell"),
                "=",
            )
        ),
    )
    body_ast = serial_utils.make_ast([declare_out_var, conditional])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # Use a context manager so the file handle is released even when
    # serialization or the write itself raises.
    with open(outputfile, "w") as f:
        f.write(MessageToJson(sir))
def main(args: argparse.Namespace):
    """Build the three-region Cartesian stencil SIR, compile it and write the code.

    The stencil operates on the Cartesian fields ``a``, ``b``, ``c`` and ``d``
    and consists of three vertical regions:

    1. forward, offsets ``(0, 0)``:  ``c = c / b``
    2. forward, offsets ``(1, 0)``:  ``m = 1.0 / (b - a * c[0, 0, -1])``,
       ``c = c * m``, ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``

    The SIR is compiled with ``dawn4py.compile(..., backend="cuda")`` and the
    returned code is written to the module-level ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments. Only ``args.verbose`` is read; when
            truthy, the SIR is pretty-printed before compilation.
    """
    # Short local aliases for the builders that are used repeatedly below.
    field = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator
    var = sir_utils.make_var_access_expr
    assign = sir_utils.make_assignment_stmt

    # ---- First vertical region statement: c = c / b ----
    interval_1 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
    body_ast_1 = sir_utils.make_ast(
        [assign(field("c"), binop(field("c"), "/", field("b")), "=")]
    )
    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward
    )

    # ---- Second vertical region statement ----
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 1, 0)
    body_ast_2 = sir_utils.make_ast(
        [
            # m = 1.0 / (b - a * c[0, 0, -1])
            # The initializer is built from a Float literal and field values,
            # so ``m`` is declared Float rather than Integer.
            sir_utils.make_var_decl_stmt(
                sir_utils.make_type(SIR.BuiltinType.Float),
                "m",
                0,
                "=",
                sir_utils.make_expr(
                    binop(
                        sir_utils.make_literal_access_expr(
                            "1.0", SIR.BuiltinType.Float
                        ),
                        "/",
                        binop(
                            field("b"),
                            "-",
                            binop(field("a"), "*", field("c", [0, 0, -1])),
                        ),
                    )
                ),
            ),
            # c = c * m
            assign(field("c"), binop(field("c"), "*", var("m")), "="),
            # d = (d - a * d[0, 0, -1]) * m
            assign(
                field("d"),
                binop(
                    binop(
                        field("d"),
                        "-",
                        binop(field("a"), "*", field("d", [0, 0, -1])),
                    ),
                    "*",
                    var("m"),
                ),
                "=",
            ),
        ]
    )
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward
    )

    # ---- Third vertical region statement: d -= c * d[0, 0, 1] ----
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, -1)
    body_ast_3 = sir_utils.make_ast(
        [assign(field("d"), binop(field("c"), "*", field("d", [0, 0, 1])), "-=")]
    )
    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward
    )

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast(
                    [
                        vertical_region_stmt_1,
                        vertical_region_stmt_2,
                        vertical_region_stmt_3,
                    ]
                ),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()
                    ),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    code = dawn4py.compile(sir, backend="cuda")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def main(args: argparse.Namespace):
    """Build the three-region Cartesian stencil SIR, compile it and write the code.

    The stencil operates on the Cartesian fields ``a``, ``b``, ``c`` and ``d``
    and consists of three vertical regions:

    1. forward, offsets ``(0, 0)``:  ``c = c / b``
    2. forward, offsets ``(1, 0)``:  ``m = 1.0 / (b - a * c[0, 0, -1])``,
       ``c = c * m``, ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``

    Compilation uses the default pass groups with
    ``dawn4py.PassGroup.MultiStageMerger`` inserted at index 1 and the CUDA
    code-generation backend. The returned code is written to the module-level
    ``OUTPUT_PATH``.

    Args:
        args: Parsed command-line arguments. Only ``args.verbose`` is read; when
            truthy, the SIR is printed as JSON before compilation.
    """
    # Short local aliases for the builders that are used repeatedly below.
    field = serial_utils.make_field_access_expr
    binop = serial_utils.make_binary_operator
    var = serial_utils.make_var_access_expr
    assign = serial_utils.make_assignment_stmt

    # ---- First vertical region statement: c = c / b ----
    interval_1 = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, 0
    )
    body_ast_1 = serial_utils.make_ast(
        [assign(field("c"), binop(field("c"), "/", field("b")), "=")]
    )
    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, AST.VerticalRegion.Forward
    )

    # ---- Second vertical region statement ----
    interval_2 = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 1, 0
    )
    body_ast_2 = serial_utils.make_ast(
        [
            # m = 1.0 / (b - a * c[0, 0, -1])
            # The initializer is built from a Float literal and field values,
            # so ``m`` is declared Float rather than Integer.
            serial_utils.make_var_decl_stmt(
                serial_utils.make_type(AST.BuiltinType.Float),
                "m",
                0,
                "=",
                serial_utils.make_expr(
                    binop(
                        serial_utils.make_literal_access_expr(
                            "1.0", AST.BuiltinType.Float
                        ),
                        "/",
                        binop(
                            field("b"),
                            "-",
                            binop(field("a"), "*", field("c", [0, 0, -1])),
                        ),
                    )
                ),
            ),
            # c = c * m
            assign(field("c"), binop(field("c"), "*", var("m")), "="),
            # d = (d - a * d[0, 0, -1]) * m
            assign(
                field("d"),
                binop(
                    binop(
                        field("d"),
                        "-",
                        binop(field("a"), "*", field("d", [0, 0, -1])),
                    ),
                    "*",
                    var("m"),
                ),
                "=",
            ),
        ]
    )
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, AST.VerticalRegion.Forward
    )

    # ---- Third vertical region statement: d -= c * d[0, 0, 1] ----
    interval_3 = serial_utils.make_interval(
        AST.Interval.Start, AST.Interval.End, 0, -1
    )
    body_ast_3 = serial_utils.make_ast(
        [assign(field("d"), binop(field("c"), "*", field("d", [0, 0, 1])), "-=")]
    )
    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, AST.VerticalRegion.Backward
    )

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast(
                    [
                        vertical_region_stmt_1,
                        vertical_region_stmt_2,
                        vertical_region_stmt_3,
                    ]
                ),
                [
                    serial_utils.make_field(
                        "a", serial_utils.make_field_dimensions_cartesian()
                    ),
                    serial_utils.make_field(
                        "b", serial_utils.make_field_dimensions_cartesian()
                    ),
                    serial_utils.make_field(
                        "c", serial_utils.make_field_dimensions_cartesian()
                    ),
                    serial_utils.make_field(
                        "d", serial_utils.make_field_dimensions_cartesian()
                    ),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    code = dawn4py.compile(
        sir, groups=pass_groups, backend=dawn4py.CodeGenBackend.CUDA
    )

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def make_tridiagonal_solve_stencil_sir(name=None):
    """Build and return the SIR of the tridiagonal-solve stencil.

    The stencil operates on the Cartesian fields ``a``, ``b``, ``c`` and ``d``
    and consists of three vertical regions:

    1. forward, offsets ``(0, 0)``:  ``c = c / b``
    2. forward, offsets ``(1, 0)``:  ``m = 1.0 / (b - a * c[0, 0, -1])``,
       ``c = c * m``, ``d = (d - a * d[0, 0, -1]) * m``
    3. backward, offsets ``(0, -1)``: ``d -= c * d[0, 0, 1]``

    Args:
        name: Stencil name. When ``None`` the name
            ``"tridiagonal_solve_stencil"`` is used. The SIR file name is
            ``f"{name}.cpp"``.

    Returns:
        The object returned by ``sir_utils.make_sir``.
    """
    stencil_name = name if name is not None else "tridiagonal_solve_stencil"
    output_file = f"{stencil_name}.cpp"

    # Short local aliases for the builders that are used repeatedly below.
    field = sir_utils.make_field_access_expr
    binop = sir_utils.make_binary_operator
    var = sir_utils.make_var_access_expr
    assign = sir_utils.make_assignment_stmt

    # ---- First vertical region statement: c = c / b ----
    interval_1 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
    body_ast_1 = sir_utils.make_ast(
        [assign(field("c"), binop(field("c"), "/", field("b")), "=")]
    )
    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward
    )

    # ---- Second vertical region statement ----
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 1, 0)
    body_ast_2 = sir_utils.make_ast(
        [
            # m = 1.0 / (b - a * c[0, 0, -1])
            # The initializer is built from a Float literal and field values,
            # so ``m`` is declared Float rather than Integer.
            sir_utils.make_var_decl_stmt(
                sir_utils.make_type(SIR.BuiltinType.Float),
                "m",
                0,
                "=",
                sir_utils.make_expr(
                    binop(
                        sir_utils.make_literal_access_expr(
                            "1.0", SIR.BuiltinType.Float
                        ),
                        "/",
                        binop(
                            field("b"),
                            "-",
                            binop(field("a"), "*", field("c", [0, 0, -1])),
                        ),
                    )
                ),
            ),
            # c = c * m
            assign(field("c"), binop(field("c"), "*", var("m")), "="),
            # d = (d - a * d[0, 0, -1]) * m
            assign(
                field("d"),
                binop(
                    binop(
                        field("d"),
                        "-",
                        binop(field("a"), "*", field("d", [0, 0, -1])),
                    ),
                    "*",
                    var("m"),
                ),
                "=",
            ),
        ]
    )
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward
    )

    # ---- Third vertical region statement: d -= c * d[0, 0, 1] ----
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, -1)
    body_ast_3 = sir_utils.make_ast(
        [assign(field("d"), binop(field("c"), "*", field("d", [0, 0, 1])), "-=")]
    )
    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward
    )

    sir = sir_utils.make_sir(
        output_file,
        sir_utils.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast(
                    [
                        vertical_region_stmt_1,
                        vertical_region_stmt_2,
                        vertical_region_stmt_3,
                    ]
                ),
                [
                    sir_utils.make_field(
                        "a", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "b", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "c", sir_utils.make_field_dimensions_cartesian()
                    ),
                    sir_utils.make_field(
                        "d", sir_utils.make_field_dimensions_cartesian()
                    ),
                ],
            )
        ],
    )
    return sir
"generated", sir_utils.make_ast([vertical_region_stmt]), [], ), ], ) f = open(outputfile, "w") f.write(MessageToJson(sir)) f.close() if __name__ == "__main__": # for the "no stmt" test, use the "one stmt" output and delete the statement # (from SIR it is not possible to generate a stage without statements) one_stmt = sir_utils.make_ast([ sir_utils.make_var_decl_stmt( sir_utils.make_type(SIR.BuiltinType.Integer), "a") ]) make_stencil("../input/test_stage_split_all_statements_one_stmt.sir", one_stmt) two_stmts = sir_utils.make_ast([ sir_utils.make_var_decl_stmt( sir_utils.make_type(SIR.BuiltinType.Integer), "a"), sir_utils.make_var_decl_stmt( sir_utils.make_type(SIR.BuiltinType.Integer), "b") ]) make_stencil("../input/test_stage_split_all_statements_two_stmt.sir", two_stmts)