def make_stencil(fields, statements):
    """Return a ``sir.Stencil`` named "stencil" holding one vertical region.

    Args:
        fields: Forwarded unchanged as the stencil's ``params``.
        statements: Statements placed in the block that forms the body of
            the vertical region.

    Returns:
        A ``sir.Stencil`` whose control-flow AST contains a single
        ``sir.VerticalRegionDeclStmt``. The region uses a default-constructed
        ``sir.Interval()`` and ``common.LoopOrder.FORWARD``.
    """
    # Inner AST: the caller's statements wrapped in one block.
    body_ast = sir.AST(root=sir.BlockStmt(statements=statements))

    # The vertical region that executes the inner AST.
    region = sir.VerticalRegion(
        ast=body_ast,
        interval=sir.Interval(),
        loop_order=common.LoopOrder.FORWARD,
    )
    region_decl = sir.VerticalRegionDeclStmt(vertical_region=region)

    # Outer (control-flow) AST: a block containing only the region declaration.
    control_flow_root = sir.BlockStmt(statements=[region_decl])
    return sir.Stencil(
        name="stencil",
        ast=sir.AST(root=control_flow_root),
        params=fields,
    )
from devtools import debug # noqa: F401 import eve # noqa: F401 from gtc import common, sir from gtc.unstructured import naive_codegen, sir_to_naive field_acc_a = sir.FieldAccessExpr(name="field_a", vertical_offset=0, horizontal_offset=sir.ZeroOffset()) field_acc_b = sir.FieldAccessExpr(name="field_b", vertical_offset=0, horizontal_offset=sir.ZeroOffset()) assign_expr = sir.AssignmentExpr(left=field_acc_a, op="=", right=field_acc_b) assign_expr_stmt = sir.ExprStmt(expr=assign_expr) root = sir.BlockStmt(statements=[assign_expr_stmt]) ast = sir.AST(root=root) vert_decl_stmt = sir.VerticalRegionDeclStmt(vertical_region=sir.VerticalRegion( ast=ast, interval=sir.Interval(), loop_order=common.LoopOrder.FORWARD)) ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt])) field_a = sir.Field( name="field_a", is_temporary=False, field_dimensions=sir.FieldDimensions( horizontal_dimension=sir.UnstructuredDimension( dense_location_type=sir.LocationType.Cell)), ) field_b = sir.Field( name="field_b",
# ),
# ]
# )

# pnabla_MYY = pnabla_MYY / vol; every field access uses vertical offset 0
# and a sir.ZeroOffset() horizontal offset.
assign_pnabla_MYY_vol = sir.ExprStmt(
    expr=sir.AssignmentExpr(
        left=sir.FieldAccessExpr(
            name="pnabla_MYY",
            vertical_offset=0,
            horizontal_offset=sir.ZeroOffset(),
        ),
        op="=",
        right=sir.BinaryOperator(
            left=sir.FieldAccessExpr(
                name="pnabla_MYY",
                vertical_offset=0,
                horizontal_offset=sir.ZeroOffset(),
            ),
            op="/",
            right=sir.FieldAccessExpr(
                name="vol",
                vertical_offset=0,
                horizontal_offset=sir.ZeroOffset(),
            ),
        ),
    )
)
statements.append(assign_pnabla_MYY_vol)

# Build the statement block once and use it as the AST root. Previously
# `block` was created but never used, and an identical BlockStmt was
# constructed a second time inline.
block = sir.BlockStmt(statements=statements)
ast = sir.AST(root=block)

# Wrap the statements in a single vertical region (default-constructed
# interval, FORWARD loop order) and place its declaration in the stencil's
# control-flow AST.
vert_decl_stmt = sir.VerticalRegionDeclStmt(
    vertical_region=sir.VerticalRegion(
        ast=ast,
        interval=sir.Interval(),
        loop_order=common.LoopOrder.FORWARD,
    )
)
ctrl_flow_ast = sir.AST(root=sir.BlockStmt(statements=[vert_decl_stmt]))
stencil = sir.Stencil(name="nabla", ast=ctrl_flow_ast, params=fields)

# Apply InferLocalVariableLocationTypeTransformation to the stencil, visit
# the result with SirToNaive, and print what NaiveCodeGenerator.apply returns.
var_loc_type_inferred = InferLocalVariableLocationTypeTransformation.apply(stencil)
naive_ir = sir_to_naive.SirToNaive().visit(var_loc_type_inferred)
print(naive_codegen.NaiveCodeGenerator.apply(naive_ir))