def make_stencil(fields, statements, name="stencil", loop_order=None):
    """Wrap ``statements`` in a single vertical region and return a stencil.

    The statements become the body of one ``sir.VerticalRegion`` over the
    default ``sir.Interval()``. That region's declaration is the only
    statement in the stencil's control-flow AST.

    Args:
        fields: Stencil parameters, passed through unchanged as ``params``.
        statements: Statements forming the body of the vertical region.
        name: Name of the resulting stencil. Defaults to ``"stencil"``, the
            value previously hard-coded here.
        loop_order: Vertical loop order of the region. ``None`` selects
            ``common.LoopOrder.FORWARD``, the previous fixed behavior.

    Returns:
        The assembled ``sir.Stencil``.
    """
    if loop_order is None:
        # Resolved at call time rather than in the signature, so the
        # default keeps the original FORWARD behavior without evaluating
        # `common` when the function is defined.
        loop_order = common.LoopOrder.FORWARD

    region_ast = sir.AST(root=sir.BlockStmt(statements=statements))
    vert_decl_stmt = sir.VerticalRegionDeclStmt(
        vertical_region=sir.VerticalRegion(
            ast=region_ast, interval=sir.Interval(), loop_order=loop_order))
    # The control-flow AST contains only the vertical-region declaration.
    ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt]))
    return sir.Stencil(name=name, ast=ctrl_flow_ast, params=fields)
# Build a "copy" stencil from the statement block `root`, convert it with
# SirToNaive, and print the generator's output.
# NOTE(review): `root` (and `debug`, `sir_to_naive`, `naive_codegen`) are
# defined outside this view -- confirm `root` is the sir.BlockStmt intended
# as the vertical region body.
ast = sir.AST(root=root)
# One vertical region over the default sir.Interval(), iterated FORWARD.
vert_decl_stmt = sir.VerticalRegionDeclStmt(vertical_region=sir.VerticalRegion(
    ast=ast, interval=sir.Interval(), loop_order=common.LoopOrder.FORWARD))
# The stencil's control-flow AST holds only the vertical-region declaration.
ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt]))
# Both stencil parameters are non-temporary fields whose horizontal
# dimension is unstructured with dense location type Cell.
field_a = sir.Field(
    name="field_a",
    is_temporary=False,
    field_dimensions=sir.FieldDimensions(
        horizontal_dimension=sir.UnstructuredDimension(
            dense_location_type=sir.LocationType.Cell)),
)
field_b = sir.Field(
    name="field_b",
    is_temporary=False,
    field_dimensions=sir.FieldDimensions(
        horizontal_dimension=sir.UnstructuredDimension(
            dense_location_type=sir.LocationType.Cell)),
)
stencil = sir.Stencil(name="copy", ast=ctrl_flow_ast, params=[field_a, field_b])
# Dump the SIR stencil for inspection (`debug` comes from outside this view;
# presumably a pretty-printer -- confirm).
debug(stencil)
# Convert the SIR stencil with SirToNaive, run the naive code generator on
# the result, and print it (the variable name suggests C++ source).
naive = sir_to_naive.SirToNaive().visit(stencil)
cpp = naive_codegen.NaiveCodeGenerator.apply(naive)
print(cpp)
# ), # ] # ) assign_pnabla_MYY_vol = sir.ExprStmt(expr=sir.AssignmentExpr( left=sir.FieldAccessExpr(name="pnabla_MYY", vertical_offset=0, horizontal_offset=sir.ZeroOffset()), op="=", right=sir.BinaryOperator( left=sir.FieldAccessExpr(name="pnabla_MYY", vertical_offset=0, horizontal_offset=sir.ZeroOffset()), op="/", right=sir.FieldAccessExpr( name="vol", vertical_offset=0, horizontal_offset=sir.ZeroOffset()), ), )) statements.append(assign_pnabla_MYY_vol) block = sir.BlockStmt(statements=statements) ast = sir.AST(root=sir.BlockStmt(statements=statements)) vert_decl_stmt = sir.VerticalRegionDeclStmt(vertical_region=sir.VerticalRegion( ast=ast, interval=sir.Interval(), loop_order=common.LoopOrder.FORWARD)) ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt])) stencil = sir.Stencil(name="nabla", ast=ctrl_flow_ast, params=fields) var_loc_type_inferred = InferLocalVariableLocationTypeTransformation.apply( stencil) naive_ir = sir_to_naive.SirToNaive().visit(var_loc_type_inferred) print(naive_codegen.NaiveCodeGenerator.apply(naive_ir))