def main(args: argparse.Namespace):
    """Generate C++ for an Edge<-Cell neighbour-reduction stencil.

    Builds an unstructured SIR whose single statement assigns to the edge
    field ``out`` a ``+`` reduction over the cell field ``in``, with the
    float literal ``1.0`` passed as the reduction's initial value. The SIR is
    compiled with the ``c++-naive-ico`` backend and the result is written to
    ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` pretty-prints the
            SIR before compilation.
    """
    full_column = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
    cell = SIR.LocationType.Value("Cell")
    edge = SIR.LocationType.Value("Edge")

    # out = reduce(+, 1.0, in) from each edge over its neighbouring cells
    neighbour_sum = sir_utils.make_reduction_over_neighbor_expr(
        "+",
        sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
        sir_utils.make_field_access_expr("in"),
        lhs_location=edge,
        rhs_location=cell,
    )
    statements = [
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"), neighbour_sum, "="
        )
    ]
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast(statements), full_column, SIR.VerticalRegion.Forward
    )

    in_field = sir_utils.make_field(
        "in", sir_utils.make_field_dimensions_unstructured([cell], 1)
    )
    out_field = sir_utils.make_field(
        "out", sir_utils.make_field_dimensions_unstructured([edge], 1)
    )
    stencil = sir_utils.make_stencil(
        OUTPUT_NAME, sir_utils.make_ast([region]), [in_field, out_field]
    )
    sir = sir_utils.make_sir(
        OUTPUT_FILE, SIR.GridType.Value("Unstructured"), [stencil]
    )

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend="c++-naive-ico")

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def two_copies_mixed():
    """Serialize the IIR of a stencil mixing a cell copy and an edge copy.

    A single forward region over the whole vertical domain holds
    ``out_cell = in_cell`` followed by ``out_edge = in_edge``. The stencil is
    compiled with the ``backend`` name from the enclosing scope and
    ``serialize_iir=True``. The produced ``StageMergerTestTwoCopiesMixed.0.iir``
    is then moved to ``../input/StageMergerTestTwoCopiesMixed.iir``.
    """
    outputfile = "StageMergerTestTwoCopiesMixed"
    cell = SIR.LocationType.Value("Cell")
    edge = SIR.LocationType.Value("Edge")

    def copy_stmt(target, source):
        # target = source, both accessed at the current location
        return sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr(target),
            sir_utils.make_field_access_expr(source),
            "=",
        )

    def unstructured_field(field_name, location):
        return sir_utils.make_field(
            field_name, sir_utils.make_field_dimensions_unstructured([location], 1)
        )

    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([
            copy_stmt("out_cell", "in_cell"),
            copy_stmt("out_edge", "in_edge"),
        ]),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
    )
    fields = [
        unstructured_field("in_cell", cell),
        unstructured_field("out_cell", cell),
        unstructured_field("in_edge", edge),
        unstructured_field("out_edge", edge),
    ]
    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [sir_utils.make_stencil("generated", sir_utils.make_ast([region]), fields)],
    )

    dawn4py.compile(sir, backend=backend, serialize_iir=True, output_file=outputfile)
    os.rename(f"{outputfile}.0.iir", f"../input/{outputfile}.iir")
def function_call():
    """Write the SIR JSON for a stencil that calls a stencil function.

    The stencil function ``f`` takes a single cell-field argument ``out`` and
    assigns it the float literal ``2.0``. The stencil ``generated`` calls
    ``f(out_cell)`` in a forward region over the whole vertical domain. The SIR
    is serialized with ``MessageToJson`` to
    ``../input/test_set_stage_location_type_function_call.sir``.
    """
    outputfile = "../input/test_set_stage_location_type_function_call.sir"
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # stencil function f(out): out = 2.0
    fun_ast = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"),
            sir_utils.make_literal_access_expr(
                value="2.0", type=sir_utils.BuiltinType.Float),
            "=",
        ),
    ])
    arg_field = sir_utils.make_field(
        "out",
        sir_utils.make_field_dimensions_unstructured(
            [SIR.LocationType.Value("Cell")], 1))
    fun = sir_utils.make_stencil_function(
        name='f',
        asts=[fun_ast],
        intervals=[interval],
        arguments=[sir_utils.make_stencil_function_arg(arg_field)])

    # stencil body: f(out_cell);
    body_ast = sir_utils.make_ast([
        sir_utils.make_expr_stmt(expr=sir_utils.make_stencil_fun_call_expr(
            callee="f",
            arguments=[sir_utils.make_field_access_expr("out_cell")])),
    ])
    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [
                    sir_utils.make_field(
                        "out_cell",
                        sir_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1),
                    ),
                ],
            ),
        ],
        functions=[fun])

    # Serialize before opening so a failure cannot leave a truncated file, and
    # use a context manager so the handle is closed even if the write raises.
    payload = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(payload)
def main(args: argparse.Namespace):
    """Generate naive ICO C++ for a loop statement over the Cell->Edge chain.

    ``out = in`` is wrapped in ``make_loop_stmt`` over ``[Cell, Edge]``.
    ``out`` is declared with the two-location dimensions ``[Cell, Edge]`` and
    ``in`` with ``[Edge]``. The generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` pretty-prints the SIR.
    """
    def cell_edge_chain():
        # fresh list for every use
        return [SIR.LocationType.Value("Cell"), SIR.LocationType.Value("Edge")]

    copy_stmt = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out"),
        sir_utils.make_field_access_expr("in"),
        "=",
    )
    loop = sir_utils.make_loop_stmt([copy_stmt], cell_edge_chain())
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([loop]),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
    )

    fields = [
        sir_utils.make_field(
            "out",
            sir_utils.make_field_dimensions_unstructured(cell_edge_chain(), 1),
        ),
        sir_utils.make_field(
            "in",
            sir_utils.make_field_dimensions_unstructured(
                [SIR.LocationType.Value("Edge")], 1),
        ),
    ]
    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Unstructured"),
        [sir_utils.make_stencil(OUTPUT_NAME, sir_utils.make_ast([region]), fields)],
    )

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def main(args: argparse.Namespace):
    """Generate naive ICO C++ for ``x = 1`` on edges.

    The single statement assigns the ``Double`` literal ``1`` to ``x`` in a
    forward region over the whole vertical domain. The stencil declares the
    edge fields ``in``, ``out`` and ``x``. ``x`` is created with ``True`` as
    the third ``make_field`` argument (presumably the temporary flag). The
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` prints the SIR as JSON.
    """
    def edge_field(field_name, *extra):
        return serial_utils.make_field(
            field_name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Edge")], 1),
            *extra,
        )

    assignment = serial_utils.make_assignment_stmt(
        serial_utils.make_field_access_expr("x"),
        serial_utils.make_literal_access_expr("1", AST.BuiltinType.Double),
        "=",
    )
    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([assignment]),
        serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0),
        AST.VerticalRegion.Forward,
    )
    stencil = serial_utils.make_stencil(
        OUTPUT_NAME,
        serial_utils.make_ast([region]),
        [edge_field("in"), edge_field("out"), edge_field("x", True)],
    )
    sir = serial_utils.make_sir(
        OUTPUT_FILE, AST.GridType.Value("Unstructured"), [stencil])

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def main(args: argparse.Namespace):
    """Generate CUDA ICO code for ``out = dt * in`` with ``dt`` a global.

    ``dt`` is read as an external variable and registered in the SIR's global
    variable map with the double value ``0.5``. ``in`` and ``out`` are edge
    fields. The generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` prints the SIR as JSON.
    """
    scaled_in = serial_utils.make_binary_operator(
        serial_utils.make_var_access_expr("dt", is_external=True),
        "*",
        serial_utils.make_field_access_expr("in"),
    )
    statements = [
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"), scaled_in, "="),
    ]
    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast(statements),
        serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0),
        AST.VerticalRegion.Forward,
    )

    # not named ``globals`` so the builtin is not shadowed
    global_vars = AST.GlobalVariableMap()
    global_vars.map["dt"].double_value = 0.5

    fields = [
        serial_utils.make_field(
            field_name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Edge")], 1),
        )
        for field_name in ("in", "out")
    ]
    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [serial_utils.make_stencil(
            OUTPUT_NAME, serial_utils.make_ast([region]), fields)],
        global_variables=global_vars,
    )

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDAIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def make_unstructured_stencil_sir(name=None):
    """Build an unstructured SIR with an Edge->Cell neighbour reduction.

    The stencil assigns to the edge field ``out`` a ``+`` reduction of the
    cell field ``in`` over the ``[Edge, Cell]`` chain, with the float literal
    ``1.0`` as initial value.

    Args:
        name: stencil name; ``None`` selects ``"unstructured_stencil"``. The
            SIR filename is ``"<name>.cpp"``.

    Returns:
        The constructed SIR message.
    """
    stencil_name = "unstructured_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"
    edge = SIR.LocationType.Value('Edge')
    cell = SIR.LocationType.Value('Cell')

    # out = reduce(+, 1.0, in) over Edge -> Cell neighbours
    reduction = sir_utils.make_reduction_over_neighbor_expr(
        "+",
        sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
        sir_utils.make_field_access_expr("in"),
        chain=[edge, cell],
    )
    statements = [
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"), reduction, "="),
    ]
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast(statements),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
    )

    fields = [
        sir_utils.make_field(
            "in", sir_utils.make_field_dimensions_unstructured([cell], 1)),
        sir_utils.make_field(
            "out", sir_utils.make_field_dimensions_unstructured([edge], 1)),
    ]
    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Unstructured"),
        [sir_utils.make_stencil(stencil_name, sir_utils.make_ast([region]), fields)],
    )
def vertical_loop(self, order, body, upper=None, lower=None, var: str = None):
    """Translate a vertical loop into a SIR vertical region declaration.

    Args:
        order: ``"levels_upward"`` (forward) or ``"levels_downward"``
            (backward); any other value raises ``KeyError``.
        body: statements lowered through ``self.statements``.
        upper: upper bound passed to ``self.vertical_interval_bound``;
            ``None`` means ``Interval.End`` with offset 0.
        lower: lower bound passed to ``self.vertical_interval_bound``;
            ``None`` means ``Interval.Start`` with offset 0.
        var: forwarded to ``self.ctx.vertical_region`` while the body is lowered.

    Returns:
        The vertical region declaration statement.
    """
    # lower bound is resolved before the upper one, as callers may rely on it
    lower_level, lower_offset = (
        (sir.Interval.Start, 0) if lower is None else self.vertical_interval_bound(lower)
    )
    upper_level, upper_offset = (
        (sir.Interval.End, 0) if upper is None else self.vertical_interval_bound(upper)
    )
    loop_orders = {
        "levels_upward": sir.VerticalRegion.Forward,
        "levels_downward": sir.VerticalRegion.Backward,
    }

    with self.ctx.vertical_region(var):
        region_ast = make_ast(self.statements(body))
        region_interval = make_interval(
            lower_level, upper_level, lower_offset, upper_offset)
        return make_vertical_region_decl_stmt(
            region_ast, region_interval, loop_orders[order])
def visit_StencilDefinition(
        self, node: gt_ir.StencilDefinition, **kwargs: Any) -> Tuple[SIR.SIR, Dict[str, Any]]:
    """Lower a gt4py stencil definition to a Cartesian SIR.

    Collects field information, updates its extents, lowers every
    computation of ``node`` and wraps the result in a single stencil named
    after the last dotted component of ``node.name``. The SIR uses the
    filename ``"<gt4py>"`` and declares no stencil functions.

    Returns:
        A ``(sir, field_info)`` pair; ``field_info`` is the collector result
        after ``_update_field_extents`` has processed it.
    """
    global_variables = self._make_global_variables(node.parameters, node.externals)

    field_info = FieldInfoCollector.apply(node)
    self._update_field_extents(field_info)
    field_decls = [field_info[key]["field_decl"] for key in field_info]

    computations_ast = sir_utils.make_ast(
        [self.visit(computation) for computation in node.computations])
    short_name = node.name.rsplit(".", 1)[-1]
    stencil = sir_utils.make_stencil(
        name=short_name, ast=computations_ast, fields=field_decls)

    sir = sir_utils.make_sir(
        filename="<gt4py>",
        grid_type=SIR.GridType.Value("Cartesian"),
        stencils=[stencil],
        functions=[],
        global_variables=global_variables,
    )
    return sir, field_info
def create_vertical_region_stmt1():
    """Create the first-level normalisation region.

    On the first vertical level only (interval Start..Start), divides the
    unstructured fields ``c`` and ``d`` in place by ``b``:
    ``c = c / b`` then ``d = d / b``.
    """
    def divide_by_b(field_name):
        return serial_utils.make_assignment_stmt(
            serial_utils.make_unstructured_field_access_expr(field_name),
            serial_utils.make_binary_operator(
                serial_utils.make_unstructured_field_access_expr(field_name),
                "/",
                serial_utils.make_unstructured_field_access_expr("b"),
            ),
            "=",
        )

    first_level = serial_utils.make_interval(
        serial_utils.Interval.Start, serial_utils.Interval.Start, 0, 0)
    return serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([divide_by_b("c"), divide_by_b("d")]),
        first_level,
        AST.VerticalRegion.Forward,
    )
def if_stmt():
    """Write the SIR JSON for a stencil containing an ``if`` on a local variable.

    The body declares a local float ``out_var_cell`` with no initialiser,
    then ``if (out_var_cell) { out_var_cell = in_cell; }``. It lives in a
    forward region over the whole vertical domain. The only declared field is
    the cell field ``in_cell``. The SIR is written as JSON to
    ``../input/test_set_stage_location_type_if_stmt.sir``.
    """
    outputfile = "../input/test_set_stage_location_type_if_stmt.sir"
    interval = serial_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)
    body_ast = serial_utils.make_ast(
        [
            serial_utils.make_var_decl_stmt(
                serial_utils.make_type(serial_utils.BuiltinType.Float),
                "out_var_cell"),
            serial_utils.make_if_stmt(
                serial_utils.make_expr_stmt(
                    serial_utils.make_var_access_expr("out_var_cell")),
                serial_utils.make_block_stmt(serial_utils.make_assignment_stmt(
                    serial_utils.make_var_access_expr("out_var_cell"),
                    serial_utils.make_field_access_expr("in_cell"),
                    "=",
                ))),
        ]
    )

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    sir = serial_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                "generated",
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "in_cell",
                        serial_utils.make_field_dimensions_unstructured(
                            [SIR.LocationType.Value("Cell")], 1
                        ),
                    ),
                ],
            ),
        ],
    )

    # Serialize before opening so a failure cannot leave a truncated file, and
    # use a context manager so the handle is closed even if the write raises.
    payload = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(payload)
def main(args: argparse.Namespace):
    """Generate CUDA code for the Cartesian shift stencil ``out = in[i+1]``.

    ``out`` is written at offset ``[0, 0, 0]`` from ``in`` read at offset
    ``[1, 0, 0]``, in a forward region over the whole vertical domain. The
    generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` prints the SIR as JSON.
    """
    shifted_copy = serial_utils.make_assignment_stmt(
        serial_utils.make_field_access_expr("out", [0, 0, 0]),
        serial_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([shifted_copy]),
        serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0),
        AST.VerticalRegion.Forward,
    )
    fields = [
        serial_utils.make_field(
            field_name, serial_utils.make_field_dimensions_cartesian())
        for field_name in ("in", "out")
    ]
    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [serial_utils.make_stencil(
            OUTPUT_NAME, serial_utils.make_ast([region]), fields)],
    )

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDA)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def stencil(self, name: str, body: t.List, fields: t.List):
    """Translate a stencil definition into a SIR stencil.

    Opens a new scope and declares every entry of ``fields`` in it. It then
    lowers ``body`` with ``in_stencil_root_scope=True``. The SIR stencil's
    field list is the ``.sir`` of each ``DuskField``/``DuskIndexField``
    symbol in the current scope once the body has been lowered.
    """
    field_kinds = (DuskField, DuskIndexField)
    with self.ctx.scope.new_scope():
        for declaration in fields:
            self.field_declaration(declaration)

        stencil_ast = make_ast(self.statements(body, in_stencil_root_scope=True))

        # gathered while the scope is still open
        sir_fields = [
            symbol.sir
            for symbol in self.ctx.scope.current_scope
            if isinstance(symbol, field_kinds)
        ]
        return make_stencil(name, stencil_ast, sir_fields)
def visit_ComputationBlock(self, node: gt_ir.ComputationBlock, **kwargs: Any) -> SIR.VerticalRegionDeclStmt:
    """Lower a gt4py computation block to a SIR vertical region.

    ``IterationOrder.BACKWARD`` maps to ``SIR.VerticalRegion.Backward``;
    every other iteration order maps to ``SIR.VerticalRegion.Forward``.
    """
    interval = self.visit(node.interval)
    statements = self.visit(node.body, make_block=False)
    body_ast = sir_utils.make_ast(statements)

    if node.iteration_order == gt_ir.IterationOrder.BACKWARD:
        loop_order = SIR.VerticalRegion.Backward
    else:
        loop_order = SIR.VerticalRegion.Forward

    return sir_utils.make_vertical_region_decl_stmt(body_ast, interval, loop_order)
def make_copy_stencil_sir(name=None):
    """Build a Cartesian SIR for the shifted copy ``out = in[i+1]``.

    Args:
        name: stencil name; ``None`` selects ``"copy_stencil"``. The SIR
            filename is ``"<name>.cpp"``.

    Returns:
        The constructed SIR message.
    """
    stencil_name = "copy_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    # out[0,0,0] = in[1,0,0]
    shifted_copy = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_field_access_expr("in", [1, 0, 0]),
        "=",
    )
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([shifted_copy]),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
    )
    fields = [
        sir_utils.make_field(field_name, sir_utils.make_field_dimensions_cartesian())
        for field_name in ("in", "out")
    ]
    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Cartesian"),
        [sir_utils.make_stencil(stencil_name, sir_utils.make_ast([region]), fields)],
    )
def main(args: argparse.Namespace):
    """Generate naive ICO C++ for the three-region solver on cell fields.

    The stencil body is the three regions returned by
    ``create_vertical_region_stmt1/2/3`` in that order. It operates on the
    cell fields ``a``, ``b``, ``c`` and ``d``. The generated code is written
    to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` prints the SIR as JSON.
    """
    regions = serial_utils.make_ast([
        create_vertical_region_stmt1(),
        create_vertical_region_stmt2(),
        create_vertical_region_stmt3(),
    ])
    cell_fields = [
        serial_utils.make_field(
            field_name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Cell")], 1),
        )
        for field_name in ("a", "b", "c", "d")
    ]
    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [serial_utils.make_stencil(OUTPUT_NAME, regions, cell_fields)],
    )

    if args.verbose:
        print(MessageToJson(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def main(args: argparse.Namespace):
    """Generate naive ICO C++ for the three-region solver on cell fields.

    The stencil body is the three regions returned by
    ``create_vertical_region_stmt1/2/3`` in that order. It operates on the
    cell fields ``a``, ``b``, ``c`` and ``d``. The generated code is written
    to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` pretty-prints the SIR.
    """
    regions = sir_utils.make_ast([
        create_vertical_region_stmt1(),
        create_vertical_region_stmt2(),
        create_vertical_region_stmt3(),
    ])
    cell_fields = [
        sir_utils.make_field(
            field_name,
            sir_utils.make_field_dimensions_unstructured(
                [SIR.LocationType.Value("Cell")], 1),
        )
        for field_name in ("a", "b", "c", "d")
    ]
    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Unstructured"),
        [sir_utils.make_stencil(OUTPUT_NAME, regions, cell_fields)],
    )

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend="c++-naive-ico")

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as handle:
        handle.write(code)
def create_boundary_correction_region(
    value="0", i_interval=None, j_interval=None
) -> SIR.VerticalRegionDeclStmt:
    """Create a forward region that sets ``out`` to a constant.

    Args:
        value: text of the float literal assigned to ``out`` (default ``"0"``).
        i_interval: passed through as ``IRange`` (presumably restricting the
            horizontal i-range; ``None`` is forwarded unchanged).
        j_interval: passed through as ``JRange``, like ``i_interval``.

    Returns:
        A region over the whole vertical domain containing
        ``out[0, 0, 0] = value``.
    """
    assignment = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Float),
        "=",
    )
    return sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([assignment]),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
        IRange=i_interval,
        JRange=j_interval,
    )
def create_vertical_region_stmt2():
    """Create the forward-sweep region for levels Start+1 .. End.

    For every level except the first it declares a local float
    ``m = 1.0 / (b - a * c[k-1])``, then updates ``c = c * m`` and
    ``d = (d - a * d[k-1]) * m``. All fields are unstructured accesses.
    """
    field = sir_utils.make_unstructured_field_access_expr

    def level_below(field_name):
        # horizontal offset flag False, vertical offset -1 (k-1)
        return field(field_name, sir_utils.make_unstructured_offset(False), -1)

    def times_m(expr):
        return sir_utils.make_binary_operator(
            expr, "*", sir_utils.make_var_access_expr("m"))

    # m = 1.0 / (b - a * c[k-1])
    denominator = sir_utils.make_binary_operator(
        field("b"),
        "-",
        sir_utils.make_binary_operator(field("a"), "*", level_below("c")),
    )
    m_value = sir_utils.make_binary_operator(
        sir_utils.make_literal_access_expr("1.0", sir_utils.BuiltinType.Float),
        "/",
        denominator,
    )
    declare_m = sir_utils.make_var_decl_stmt(
        sir_utils.make_type(sir_utils.BuiltinType.Float),
        "m", 0, "=",
        sir_utils.make_expr(m_value),
    )

    # c = c * m
    update_c = sir_utils.make_assignment_stmt(field("c"), times_m(field("c")), "=")

    # d = (d - a * d[k-1]) * m
    d_residual = sir_utils.make_binary_operator(
        field("d"),
        "-",
        sir_utils.make_binary_operator(field("a"), "*", level_below("d")),
    )
    update_d = sir_utils.make_assignment_stmt(field("d"), times_m(d_residual), "=")

    return sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([declare_m, update_c, update_d]),
        sir_utils.make_interval(sir_utils.Interval.Start, sir_utils.Interval.End, 1, 0),
        SIR.VerticalRegion.Forward,
    )
def create_vertical_region_stmt() -> SIR.VerticalRegionDeclStmt:
    """Create a forward region over the whole vertical domain doing ``out = in``.

    Both accesses use the zero offset ``[0, 0, 0]``, so this is a point-wise
    copy (no shift).
    """
    pointwise_copy = sir_utils.make_assignment_stmt(
        sir_utils.make_field_access_expr("out", [0, 0, 0]),
        sir_utils.make_field_access_expr("in", [0, 0, 0]),
        "=",
    )
    whole_column = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
    return sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([pointwise_copy]), whole_column, SIR.VerticalRegion.Forward)
def create_vertical_region_stmt3():
    """Create the backward-substitution region for levels Start .. End-1.

    Iterating backward, it applies ``d -= c * d[k+1]`` to the unstructured
    fields ``c`` and ``d``.
    """
    # horizontal offset flag False, vertical offset +1 (k+1)
    d_above = sir_utils.make_unstructured_field_access_expr(
        "d", sir_utils.make_unstructured_offset(False), 1)
    back_substitution = sir_utils.make_assignment_stmt(
        sir_utils.make_unstructured_field_access_expr("d"),
        sir_utils.make_binary_operator(
            sir_utils.make_unstructured_field_access_expr("c"), "*", d_above),
        "-=",
    )
    all_but_top = sir_utils.make_interval(
        sir_utils.Interval.Start, sir_utils.Interval.End, 0, -1)
    return sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast([back_substitution]),
        all_but_top,
        SIR.VerticalRegion.Backward,
    )
def make_stencil(outputfile, body_ast):
    """Wrap ``body_ast`` in an unstructured stencil and write its SIR as JSON.

    The body becomes a single forward region over the whole vertical domain
    of a stencil named ``"generated"`` that declares no fields.

    Args:
        outputfile: path of the JSON file to write; also used as the SIR filename.
        body_ast: AST placed in the vertical region.
    """
    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    sir = sir_utils.make_sir(
        outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                "generated",
                sir_utils.make_ast([vertical_region_stmt]),
                [],
            ),
        ],
    )

    # Serialize before opening so a failure cannot leave a truncated file, and
    # use a context manager so the handle is closed even if the write raises.
    payload = MessageToJson(sir)
    with open(outputfile, "w") as f:
        f.write(payload)
def main(args: argparse.Namespace):
    """Generate CUDA code for a Cartesian tridiagonal solve on fields a, b, c, d.

    The three vertical regions follow the Thomas-algorithm layout. Presumably
    ``a``/``b``/``c`` are the sub-, main and super-diagonal and ``d`` is the
    right-hand side that ends up holding the solution (confirm with callers).

    1. first level only, forward:      ``c = c / b``; ``d = d / b``
    2. levels Start+1 .. End, forward: ``m = 1.0 / (b - a * c[k-1])``;
       ``c = c * m``; ``d = (d - a * d[k-1]) * m``
    3. levels Start .. End-1, backward: ``d -= c * d[k+1]``

    The generated code is written to ``OUTPUT_PATH``.

    Args:
        args: parsed command-line options; ``args.verbose`` pretty-prints the SIR.
    """

    # ---- First vertical region statement: normalise the first level ----
    # Must cover only the first level: region 2 scales c again on every other
    # level, so dividing by b everywhere would scale c twice. d needs the same
    # normalisation for the forward sweep to start from d[0] / b[0].
    interval_1 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.Start, 0, 0)
    body_ast_1 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr(field_name),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr(field_name),
                "/",
                sir_utils.make_field_access_expr("b"),
            ),
            "=",
        )
        for field_name in ("c", "d")
    ])
    vertical_region_stmt_1 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, SIR.VerticalRegion.Forward)

    # ---- Second vertical region statement: forward sweep ----
    interval_2 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 1, 0)
    body_ast_2 = sir_utils.make_ast([
        # m holds 1.0 / (...), so it must be a floating-point temporary
        sir_utils.make_var_decl_stmt(
            sir_utils.make_type(SIR.BuiltinType.Float),
            "m",
            0,
            "=",
            sir_utils.make_expr(
                sir_utils.make_binary_operator(
                    sir_utils.make_literal_access_expr("1.0", SIR.BuiltinType.Float),
                    "/",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("b"),
                        "-",
                        sir_utils.make_binary_operator(
                            sir_utils.make_field_access_expr("a"),
                            "*",
                            sir_utils.make_field_access_expr("c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("c"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"), "*",
                sir_utils.make_var_access_expr("m")),
            "=",
        ),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr("d"),
                    "-",
                    sir_utils.make_binary_operator(
                        sir_utils.make_field_access_expr("a"),
                        "*",
                        sir_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                sir_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, SIR.VerticalRegion.Forward)

    # ---- Third vertical region statement: backward substitution ----
    interval_3 = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, -1)
    body_ast_3 = sir_utils.make_ast([
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("d"),
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("c"),
                "*",
                sir_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])
    vertical_region_stmt_3 = sir_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, SIR.VerticalRegion.Backward)

    sir = sir_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            sir_utils.make_stencil(
                OUTPUT_NAME,
                sir_utils.make_ast([
                    vertical_region_stmt_1,
                    vertical_region_stmt_2,
                    vertical_region_stmt_3,
                ]),
                [
                    sir_utils.make_field(
                        field_name, sir_utils.make_field_dimensions_cartesian())
                    for field_name in ("a", "b", "c", "d")
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        sir_utils.pprint(sir)

    # compile
    code = dawn4py.compile(sir, backend="cuda")

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def make_hori_diff_stencil_sir(name=None):
    """Build a Cartesian SIR for a two-stage horizontal diffusion stencil.

    Both stages evaluate the same five-point expression
    ``-4.0 * f + coeff * (f[i+1] + (f[i-1] + (f[j+1] + f[j-1])))``.
    The first stage uses ``f = in`` and writes the temporary ``lap``; the
    second uses ``f = lap`` and writes ``out``.

    Args:
        name: stencil name; ``None`` selects ``"hori_diff_stencil"``. The SIR
            filename is ``"<name>.cpp"``.

    Returns:
        The constructed SIR message.
    """
    stencil_name = "hori_diff_stencil" if name is None else name
    sir_filename = f"{stencil_name}.cpp"

    def five_point(src):
        # coeff * (src[1,0,0] + (src[-1,0,0] + (src[0,1,0] + src[0,-1,0])))
        neighbour_sum = sir_utils.make_binary_operator(
            sir_utils.make_field_access_expr(src, [1, 0, 0]),
            "+",
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr(src, [-1, 0, 0]),
                "+",
                sir_utils.make_binary_operator(
                    sir_utils.make_field_access_expr(src, [0, 1, 0]),
                    "+",
                    sir_utils.make_field_access_expr(src, [0, -1, 0]),
                ),
            ),
        )
        centre = sir_utils.make_binary_operator(
            sir_utils.make_literal_access_expr("-4.0", SIR.BuiltinType.Float),
            "*",
            sir_utils.make_field_access_expr(src),
        )
        return sir_utils.make_binary_operator(
            centre,
            "+",
            sir_utils.make_binary_operator(
                sir_utils.make_field_access_expr("coeff"), "*", neighbour_sum),
        )

    stages = [
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("lap"), five_point("in"), "="),
        sir_utils.make_assignment_stmt(
            sir_utils.make_field_access_expr("out"), five_point("lap"), "="),
    ]
    region = sir_utils.make_vertical_region_decl_stmt(
        sir_utils.make_ast(stages),
        sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0),
        SIR.VerticalRegion.Forward,
    )

    fields = [
        sir_utils.make_field(field_name, sir_utils.make_field_dimensions_cartesian())
        for field_name in ("in", "out", "coeff")
    ]
    fields.append(
        sir_utils.make_field(
            "lap", sir_utils.make_field_dimensions_cartesian(), is_temporary=True))

    return sir_utils.make_sir(
        sir_filename,
        sir_utils.GridType.Value("Cartesian"),
        [sir_utils.make_stencil(stencil_name, sir_utils.make_ast([region]), fields)],
    )
def main():
    """Generate naive ICO C++ for the ICON diamond Laplacian with Smagorinsky terms.

    Builds an unstructured SIR for ``ICON_laplacian_diamond_stencil``. Its
    single forward region (whole vertical domain) computes:

    * ``vn_vert``: the Edge->Cell->Vertex field, filled by a loop statement
      as ``u_vert * primal_normal_x + v_vert * primal_normal_y``;
    * ``dvt_tang`` / ``dvt_norm``: weighted diamond reductions of
      ``u_vert * dual_normal_x + v_vert * dual_normal_y``;
    * ``kh_smag_1`` / ``kh_smag_2``: squared Smagorinsky terms, combined as
      ``kh_smag = diff_multfac_smag * sqrt(kh_smag_1 + kh_smag_2)``;
    * ``nabla2``: a weighted diamond reduction of ``4.0 * vn_vert`` minus the
      ``8.0 * vn`` terms.

    The SIR is dumped as JSON to ``ICON_laplacian_diamond_stencil.sir`` for
    debugging. The generated code is written to
    ``ICON_laplacian_diamond_stencil.cpp``.
    """
    stencil_name = "ICON_laplacian_diamond_stencil"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    edge = SIR.LocationType.Value("Edge")
    cell = SIR.LocationType.Value("Cell")
    vertex = SIR.LocationType.Value("Vertex")

    # ---- small builders for the repetitive SIR nodes ----
    def diamond():
        # fresh Edge -> Cell -> Vertex neighbour chain for every use
        return [edge, cell, vertex]

    def field(field_name):
        return sir_utils.make_field_access_expr(field_name)

    def chain_access(field_name):
        # offset [True, 0], the form this stencil uses for vertex/chain accesses
        return sir_utils.make_field_access_expr(field_name, [True, 0])

    def double(value):
        return sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Double)

    def binop(lhs, op, rhs):
        return sir_utils.make_binary_operator(lhs, op, rhs)

    def squared(field_name):
        return binop(field(field_name), "*", field(field_name))

    def assign(target, rhs):
        return sir_utils.make_assignment_stmt(field(target), rhs, "=")

    def vertex_dot(normal_x, normal_y):
        # u_vert * normal_x + v_vert * normal_y along the neighbour chain
        return binop(
            binop(chain_access("u_vert"), "*", chain_access(normal_x)),
            "+",
            binop(chain_access("v_vert"), "*", chain_access(normal_y)),
        )

    def diamond_sum(rhs, weights):
        # weighted "+" reduction over the diamond, initial value 0.0
        return sir_utils.make_reduction_over_neighbor_expr(
            op="+",
            init=double("0.0"),
            rhs=rhs,
            chain=diamond(),
            weights=weights,
        )

    def literal_weights(*values):
        return [double(value) for value in values]

    interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    statements = [
        # fill sparse dimension vn_vert using the loop concept
        sir_utils.make_loop_stmt(
            [assign("vn_vert", vertex_dot("primal_normal_x", "primal_normal_y"))],
            diamond(),
        ),
        # dvt_tang for smagorinsky
        assign("dvt_tang", diamond_sum(
            vertex_dot("dual_normal_x", "dual_normal_y"),
            literal_weights("-1.0", "1.0", "0.0", "0.0"))),
        assign("dvt_tang", binop(
            field("dvt_tang"), "*", field("tangent_orientation"))),
        # dvt_norm for smagorinsky
        assign("dvt_norm", diamond_sum(
            vertex_dot("dual_normal_x", "dual_normal_y"),
            literal_weights("0.0", "0.0", "-1.0", "1.0"))),
        # compute smagorinsky
        assign("kh_smag_1", diamond_sum(
            field("vn_vert"),
            literal_weights("-1.0", "1.0", "0.0", "0.0"))),
        assign("kh_smag_1", binop(
            binop(
                binop(field("kh_smag_1"), "*", field("tangent_orientation")),
                "*",
                field("inv_primal_edge_length")),
            "+",
            binop(field("dvt_norm"), "*", field("inv_vert_vert_length")))),
        assign("kh_smag_1", squared("kh_smag_1")),
        assign("kh_smag_2", diamond_sum(
            field("vn_vert"),
            literal_weights("0.0", "0.0", "-1.0", "1.0"))),
        assign("kh_smag_2", binop(
            binop(field("kh_smag_2"), "*", field("inv_vert_vert_length")),
            "+",
            binop(field("dvt_tang"), "*", field("inv_primal_edge_length")))),
        assign("kh_smag_2", squared("kh_smag_2")),
        # kh_smag = diff_multfac_smag * sqrt(kh_smag_1 + kh_smag_2)
        assign("kh_smag", binop(
            field("diff_multfac_smag"),
            "*",
            sir_utils.make_fun_call_expr(
                "math::sqrt",
                [binop(field("kh_smag_1"), "+", field("kh_smag_2"))]))),
        # compute nabla2 using the diamond reduction
        assign("nabla2", diamond_sum(
            binop(double("4.0"), "*", field("vn_vert")),
            [
                squared("inv_primal_edge_length"),
                squared("inv_primal_edge_length"),
                squared("inv_vert_vert_length"),
                squared("inv_vert_vert_length"),
            ])),
        assign("nabla2", binop(
            field("nabla2"),
            "-",
            binop(
                binop(
                    binop(double("8.0"), "*", field("vn")),
                    "*",
                    squared("inv_primal_edge_length")),
                "+",
                binop(
                    binop(double("8.0"), "*", field("vn")),
                    "*",
                    squared("inv_vert_vert_length"))))),
    ]
    body_ast = sir_utils.make_ast(statements)

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    # field declarations, in declaration order
    field_locations = [
        ("diff_multfac_smag", [edge]),
        ("tangent_orientation", [edge]),
        ("inv_primal_edge_length", [edge]),
        ("inv_vert_vert_length", [edge]),
        ("u_vert", [vertex]),
        ("v_vert", [vertex]),
        ("primal_normal_x", diamond()),
        ("primal_normal_y", diamond()),
        ("dual_normal_x", diamond()),
        ("dual_normal_y", diamond()),
        ("vn_vert", diamond()),
        ("vn", [edge]),
        ("dvt_tang", [edge]),
        ("dvt_norm", [edge]),
        ("kh_smag_1", [edge]),
        ("kh_smag_2", [edge]),
        ("kh_smag", [edge]),
        ("nabla2", [edge]),
    ]
    fields = [
        sir_utils.make_field(
            field_name,
            sir_utils.make_field_dimensions_unstructured(locations, 1))
        for field_name, locations in field_locations
    ]

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                fields,
            ),
        ],
    )

    # write SIR to file (for debugging purposes); the context manager closes
    # the handle even if the write fails
    sir_json = MessageToJson(sir)
    with open(sir_outputfile, "w") as f:
        f.write(sir_json)

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)
def main(args: argparse.Namespace):
    """Generate C++ code for the Cartesian ``global_indexing`` stencil.

    The stencil body starts with the region returned by
    ``create_vertical_region_stmt()`` and is followed by eight regions from
    ``create_boundary_correction_region``, each tagged with its own ``value``
    string and restricted along i and/or j to an interval anchored at the start
    or end of that axis. With ``args.verbose`` the SIR is pretty-printed. It is
    then compiled with the ``c++-naive`` backend and written to ``OUTPUT_PATH``.
    """

    def _anchored(side):
        # Interval anchored at the "start" (offsets 0, 1) or "end"
        # (offsets -1, 0) of an axis.
        if side == "start":
            return sir_utils.make_interval(
                SIR.Interval.Start, SIR.Interval.Start, 0, 1)
        return sir_utils.make_interval(SIR.Interval.End, SIR.Interval.End, -1, 0)

    # (value, i side, j side). None means the keyword is not passed at all,
    # leaving that axis to create_boundary_correction_region's own default.
    corrections = [
        ("4", "end", None),
        ("8", "start", None),
        ("6", None, "end"),
        ("2", None, "start"),
        ("1", "start", "start"),
        ("3", "end", "start"),
        ("7", "start", "end"),
        ("5", "end", "end"),
    ]

    stencil_body = [create_vertical_region_stmt()]
    for value, i_side, j_side in corrections:
        region_kwargs = {"value": value}
        if j_side is not None:
            region_kwargs["j_interval"] = _anchored(j_side)
        if i_side is not None:
            region_kwargs["i_interval"] = _anchored(i_side)
        stencil_body.append(create_boundary_correction_region(**region_kwargs))

    fields = [
        sir_utils.make_field(name, sir_utils.make_field_dimensions_cartesian())
        for name in ("in", "out")
    ]
    stencil = sir_utils.make_stencil(
        "global_indexing", sir_utils.make_ast(stencil_body), fields)
    sir = sir_utils.make_sir(
        OUTPUT_FILE, SIR.GridType.Value("Cartesian"), [stencil])

    if args.verbose:
        sir_utils.pprint(sir)

    code = dawn4py.compile(sir, backend="c++-naive")

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as out_file:
        out_file.write(code)
def main(args: argparse.Namespace):
    """Build the Cartesian Laplacian stencil SIR, save it as JSON and generate C++ code.

    Over the full vertical interval (forward loop order) the stencil assigns,
    using ``[i, j, k]`` offset notation for field accesses::

        out[0,0,0] = (in[0,0,0] * -4.0
                      + (in[1,0,0] + (in[-1,0,0] + (in[0,1,0] + in[0,-1,0]))))
                     / (dx * dx)

    ``dx`` is accessed as an external variable and registered as a global with
    value 0.0. With ``args.verbose`` the SIR is pretty-printed. It is then
    written to ``./laplacian_stencil_from_python.sir`` via
    ``serial_utils.to_json``, compiled with the naive C++ backend, and the
    generated code is written to ``OUTPUT_PATH``.
    """
    interval = serial_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # create the laplace statement
    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out", [0, 0, 0]),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [0, 0, 0]),
                        "*",
                        serial_utils.make_literal_access_expr(
                            "-4.0", serial_utils.BuiltinType.Float),
                    ),
                    "+",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("in", [1, 0, 0]),
                        "+",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr("in", [-1, 0, 0]),
                            "+",
                            serial_utils.make_binary_operator(
                                serial_utils.make_field_access_expr("in", [0, 1, 0]),
                                "+",
                                serial_utils.make_field_access_expr("in", [0, -1, 0]),
                            ),
                        ),
                    ),
                ),
                "/",
                serial_utils.make_binary_operator(
                    serial_utils.make_var_access_expr("dx", is_external=True),
                    "*",
                    serial_utils.make_var_access_expr("dx", is_external=True),
                ),
            ),
            "=",
        ),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    # "dx" is read as an external variable above, so it is declared in the
    # SIR's global-variable map (with value 0.0).
    stencils_globals = serial_utils.GlobalVariableMap()
    stencils_globals.map["dx"].double_value = 0.0

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "out", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "in", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
        global_variables=stencils_globals,
    )

    # print the SIR
    if args.verbose:
        serial_utils.pprint(sir)

    # serialize the SIR to file
    sir_json = serial_utils.to_json(sir)
    # The file is opened in binary mode. If to_json returns text (as protobuf's
    # MessageToJson does), write() would raise TypeError, so encode it first.
    # Bytes results are written unchanged.
    if isinstance(sir_json, str):
        sir_json = sir_json.encode("utf-8")
    with open("./laplacian_stencil_from_python.sir", "wb") as sir_file:
        sir_file.write(sir_json)

    # compile
    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def main(args: argparse.Namespace):
    """Generate naive ICO C++ code for a stage-merger test with three vertical regions.

    All three regions share one full-column interval, run forward, and write the
    Cell field ``out``. They differ in the extra interval passed as the fourth
    argument of ``make_vertical_region_decl_stmt``:

    1. ``out = in_1``        with ``make_magic_num_interval(0, 1, 0, 0)``
    2. ``out = out + in_2``  with ``make_interval(2, 3, 0, 0)``
    3. ``out = in_3``        with ``make_interval(3, 4, 0, 0)`` (a plain overwrite)

    With ``args.verbose`` the SIR is printed as JSON. The StageMerger and
    MultiStageMerger pass groups are inserted after the first default pass group
    before compiling, and the generated code is written to ``OUTPUT_PATH``.
    """

    def _assign(lhs_name, rhs):
        # Wrap a single `lhs = rhs` assignment into its own AST.
        return serial_utils.make_ast([
            serial_utils.make_assignment_stmt(
                serial_utils.make_field_access_expr(lhs_name), rhs, "=")
        ])

    full_column = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0)
    forward = AST.VerticalRegion.Forward

    # Region 1: out = in_1
    copy_in_1 = _assign("out", serial_utils.make_field_access_expr("in_1"))
    region_1 = serial_utils.make_vertical_region_decl_stmt(
        copy_in_1, full_column, forward,
        serial_utils.make_magic_num_interval(0, 1, 0, 0))

    # Region 2: out = out + in_2 (should be merge-able to last stage)
    accumulate_in_2 = _assign(
        "out",
        serial_utils.make_binary_operator(
            serial_utils.make_field_access_expr("out"),
            "+",
            serial_utils.make_field_access_expr("in_2"),
        ),
    )
    region_2 = serial_utils.make_vertical_region_decl_stmt(
        accumulate_in_2, full_column, forward,
        serial_utils.make_interval(2, 3, 0, 0))

    # Region 3: out = in_3 (overwrites out; it does not accumulate)
    copy_in_3 = _assign("out", serial_utils.make_field_access_expr("in_3"))
    region_3 = serial_utils.make_vertical_region_decl_stmt(
        copy_in_3, full_column, forward,
        serial_utils.make_interval(3, 4, 0, 0))

    cell_fields = [
        serial_utils.make_field(
            name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Cell")], 1),
        )
        for name in ("in_1", "in_2", "in_3", "out")
    ]

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([region_1, region_2, region_3]),
                cell_fields,
            )
        ],
    )

    if args.verbose:
        print(MessageToJson(sir))

    # Both inserts target index 1. StageMerger goes in last, so it ends up
    # immediately ahead of MultiStageMerger.
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    pass_groups.insert(1, dawn4py.PassGroup.StageMerger)
    code = dawn4py.compile(
        sir, groups=pass_groups, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as generated:
        generated.write(code)
def main(args: argparse.Namespace):
    """Build a Cartesian tridiagonal-solve stencil SIR and generate CUDA code for it.

    The three vertical regions follow the Thomas algorithm along k. In that
    recurrence, ``a``, ``b`` and ``c`` act as the lower, main and upper
    diagonal coefficients. ``d`` holds the right-hand side and is overwritten
    with the solution:

    1. forward, first level only:      ``c = c / b``; ``d = d / b``
    2. forward, k_start+1 .. k_end:    ``m = 1.0 / (b - a * c[k-1])``;
                                       ``c = c * m``;
                                       ``d = (d - a * d[k-1]) * m``
    3. backward, k_start .. k_end-1:   ``d -= c * d[k+1]``

    With ``args.verbose`` the SIR is printed as JSON. The MultiStageMerger
    pass group is inserted after the first default pass group, the SIR is
    compiled with the CUDA backend, and the code is written to ``OUTPUT_PATH``.
    """
    # ---- First vertical region statement: normalize the first level ----
    # Only the first level. Region 2 starts at k_start + 1 and scales c and d
    # by the pivot there itself. Dividing every level by b here (a
    # Start..End interval) would corrupt the elimination.
    interval_1 = serial_utils.make_interval(AST.Interval.Start, AST.Interval.Start, 0, 0)
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
        # The right-hand side needs the same normalization as c. Otherwise the
        # unscaled d[k_start] feeds into the recurrence of region 2.
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("d"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        ),
    ])

    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, AST.VerticalRegion.Forward)

    # ---- Second vertical region statement: forward elimination ----
    interval_2 = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 1, 0)
    body_ast_2 = serial_utils.make_ast([
        # m is a reciprocal pivot and must be floating point. An integer
        # declaration would truncate 1.0 / (...) in the generated code.
        serial_utils.make_var_decl_stmt(
            serial_utils.make_type(AST.BuiltinType.Float),
            "m",
            0,
            "=",
            serial_utils.make_expr(
                serial_utils.make_binary_operator(
                    serial_utils.make_literal_access_expr(
                        "1.0", AST.BuiltinType.Float),
                    "/",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("b"),
                        "-",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr("a"),
                            "*",
                            serial_utils.make_field_access_expr(
                                "c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "*",
                serial_utils.make_var_access_expr("m")),
            "=",
        ),
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_field_access_expr("d"),
                    "-",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("a"),
                        "*",
                        serial_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                serial_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])

    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, AST.VerticalRegion.Forward)

    # ---- Third vertical region statement: back substitution ----
    interval_3 = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, -1)
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "*",
                serial_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, AST.VerticalRegion.Backward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([
                    vertical_region_stmt_1,
                    vertical_region_stmt_2,
                    vertical_region_stmt_3,
                ]),
                [
                    serial_utils.make_field(
                        "a", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "b", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "c", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "d", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # compile
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    code = dawn4py.compile(
        sir, groups=pass_groups, backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
def main():
    """Build the ICON vector-Laplacian stencil SIR and generate naive ICO C++ code.

    In one forward vertical region over the full column the stencil computes:

    * ``rot_vec``: sum over the Vertex->Edge neighbour chain of ``vec * geofac_rot``
    * ``div_vec``: sum over the Cell->Edge neighbour chain of ``vec * geofac_div``
    * ``nabla2t1_vec``: sum over the Edge->Vertex chain of ``rot_vec``, with
      weights (-1.0, 1.0), then
      ``(tangent_orientation * nabla2t1_vec) / primal_edge_length``
    * ``nabla2t2_vec``: sum over the Edge->Cell chain of ``div_vec``, with
      weights (-1.0, 1.0), then divided by ``dual_edge_length``
    * ``nabla2_vec = nabla2t2_vec - nabla2t1_vec``

    The SIR is dumped as JSON to ``ICON_laplacian_stencil.sir`` for debugging.
    The generated code is written to ``ICON_laplacian_stencil.cpp``.
    """
    stencil_name = "ICON_laplacian_stencil"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    def _field(name):
        # Access to a field at the current location (no offset).
        return sir_utils.make_field_access_expr(name)

    def _double(text):
        # Double-precision literal from its textual form.
        return sir_utils.make_literal_access_expr(text, SIR.BuiltinType.Double)

    def _loc(name):
        # Location-type enum value ("Cell", "Edge" or "Vertex").
        return SIR.LocationType.Value(name)

    interval = sir_utils.make_interval(
        SIR.Interval.Start, SIR.Interval.End, 0, 0)

    body_ast = sir_utils.make_ast(
        [
            # rot_vec = sum_{Vertex->Edge} vec * geofac_rot
            sir_utils.make_assignment_stmt(
                _field("rot_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=_double("0.0"),
                    rhs=sir_utils.make_binary_operator(
                        _field("vec"), "*", _field("geofac_rot")),
                    chain=[_loc("Vertex"), _loc("Edge")],
                ),
                "=",
            ),
            # div_vec = sum_{Cell->Edge} vec * geofac_div
            sir_utils.make_assignment_stmt(
                _field("div_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=_double("0.0"),
                    rhs=sir_utils.make_binary_operator(
                        _field("vec"), "*", _field("geofac_div")),
                    chain=[_loc("Cell"), _loc("Edge")],
                ),
                "=",
            ),
            # nabla2t1_vec = weighted (-1, 1) Edge->Vertex sum of rot_vec
            sir_utils.make_assignment_stmt(
                _field("nabla2t1_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=_double("0.0"),
                    rhs=_field("rot_vec"),
                    chain=[_loc("Edge"), _loc("Vertex")],
                    weights=[_double("-1.0"), _double("1.0")],
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                _field("nabla2t1_vec"),
                sir_utils.make_binary_operator(
                    sir_utils.make_binary_operator(
                        _field("tangent_orientation"), "*", _field("nabla2t1_vec")),
                    "/",
                    _field("primal_edge_length")),
                "=",
            ),
            # nabla2t2_vec = weighted (-1, 1) Edge->Cell sum of div_vec
            sir_utils.make_assignment_stmt(
                _field("nabla2t2_vec"),
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=_double("0.0"),
                    rhs=_field("div_vec"),
                    chain=[_loc("Edge"), _loc("Cell")],
                    weights=[_double("-1.0"), _double("1.0")],
                ),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                _field("nabla2t2_vec"),
                sir_utils.make_binary_operator(
                    _field("nabla2t2_vec"), "/", _field("dual_edge_length")),
                "=",
            ),
            sir_utils.make_assignment_stmt(
                _field("nabla2_vec"),
                sir_utils.make_binary_operator(
                    _field("nabla2t2_vec"), "-", _field("nabla2t1_vec")),
                "=",
            ),
        ]
    )

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    # (field name, location chain) in the order the stencil declares them.
    field_locations = [
        ("vec", ["Edge"]),
        ("div_vec", ["Cell"]),
        ("rot_vec", ["Vertex"]),
        ("nabla2t1_vec", ["Edge"]),
        ("nabla2t2_vec", ["Edge"]),
        ("nabla2_vec", ["Edge"]),
        ("primal_edge_length", ["Edge"]),
        ("dual_edge_length", ["Edge"]),
        ("tangent_orientation", ["Edge"]),
        ("geofac_rot", ["Vertex", "Edge"]),
        ("geofac_div", ["Cell", "Edge"]),
    ]
    fields = [
        sir_utils.make_field(
            name,
            sir_utils.make_field_dimensions_unstructured(
                [_loc(location) for location in locations], 1
            ),
        )
        for name, locations in field_locations
    ]

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                fields,
            ),
        ],
    )

    # write SIR to file (for debugging purposes); the context manager closes
    # the file even if serialization or the write raises.
    with open(sir_outputfile, "w") as sir_file:
        sir_file.write(MessageToJson(sir))

    # compile
    code = dawn4py.compile(sir, backend="c++-naive-ico")

    # write to file
    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)