def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl:
    """Translate a scalar variable declaration into a ``gtir.ScalarDecl``.

    The GTIR dtype is rebuilt from the integer value of ``node.data_type``;
    this relies on both DataType enums using the same numeric IDs.
    """
    scalar_dtype = common.DataType(int(node.data_type.value))
    source_loc = location_to_source_location(node.loc)
    return gtir.ScalarDecl(name=node.name, dtype=scalar_dtype, loc=source_loc)
def visit_NativeFuncCall(self, node: NativeFuncCall) -> gtir.NativeFuncCall:
    """Translate a native function call into a ``gtir.NativeFuncCall``.

    The function is mapped through ``GT4PY_NATIVE_FUNC_TO_GTIR``, and each
    argument is translated recursively, in order.
    """
    gtir_func = self.GT4PY_NATIVE_FUNC_TO_GTIR[node.func]
    gtir_args = list(map(self.visit, node.args))
    source_loc = location_to_source_location(node.loc)
    return gtir.NativeFuncCall(func=gtir_func, args=gtir_args, loc=source_loc)
def visit_FieldRef(self, node: FieldRef) -> gtir.FieldAccess:
    """Translate a field reference into a ``gtir.FieldAccess``.

    The offset is converted with ``self.transform_offset``, and every entry of
    ``node.data_index`` is translated recursively.
    """
    gtir_offset = self.transform_offset(node.offset)
    gtir_data_index = list(map(self.visit, node.data_index))
    source_loc = location_to_source_location(node.loc)
    return gtir.FieldAccess(
        name=node.name,
        offset=gtir_offset,
        data_index=gtir_data_index,
        loc=source_loc,
    )
def visit_TernaryOpExpr(self, node: TernaryOpExpr) -> gtir.TernaryOp:
    """Translate a ``cond ? then : else`` expression into a ``gtir.TernaryOp``.

    The condition, then-branch and else-branch are visited in that order.
    """
    cond_expr, then_branch, else_branch = map(
        self.visit, (node.condition, node.then_expr, node.else_expr)
    )
    source_loc = location_to_source_location(node.loc)
    return gtir.TernaryOp(
        cond=cond_expr,
        true_expr=then_branch,
        false_expr=else_branch,
        loc=source_loc,
    )
def visit_BinOpExpr(
    self, node: BinOpExpr
) -> Union[gtir.BinaryOp, gtir.NativeFuncCall]:
    """Translate a binary operation.

    ``POW`` and ``MOD`` become a ``gtir.NativeFuncCall``; the native function
    is looked up by the operator's enum name in ``common.NativeFunction``.
    Every other operator becomes a ``gtir.BinaryOp``, mapped through
    ``GT4PY_OP_TO_GTIR_OP``.
    """
    is_native_call = node.op in (BinaryOperator.POW, BinaryOperator.MOD)
    if is_native_call:
        native_func = common.NativeFunction[node.op.name]
        call_args = [self.visit(node.lhs), self.visit(node.rhs)]
        return gtir.NativeFuncCall(
            func=native_func,
            args=call_args,
            loc=location_to_source_location(node.loc),
        )

    left_operand = self.visit(node.lhs)
    right_operand = self.visit(node.rhs)
    gtir_op = self.GT4PY_OP_TO_GTIR_OP[node.op]
    return gtir.BinaryOp(
        left=left_operand,
        right=right_operand,
        op=gtir_op,
        loc=location_to_source_location(node.loc),
    )
def visit_Assign(self, node: Assign) -> gtir.ParAssignStmt:
    """Translate an assignment into a ``gtir.ParAssignStmt``.

    The assignment target must be a ``FieldRef`` or a ``VarRef``; this is
    checked with an ``assert``.
    """
    target = node.target
    assert isinstance(target, (FieldRef, VarRef))
    lhs = self.visit(target)
    rhs = self.visit(node.value)
    return gtir.ParAssignStmt(
        left=lhs,
        right=rhs,
        loc=location_to_source_location(node.loc),
    )
def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]:
    """Translate an ``if`` statement.

    Produces a ``gtir.FieldIfStmt`` when the translated condition has
    ``ExprKind.FIELD``, and a ``gtir.ScalarIfStmt`` otherwise. Each branch body
    is wrapped in a ``gtir.BlockStmt``. ``false_branch`` is ``None`` when
    ``node.else_body`` is falsy.
    """
    cond = self.visit(node.condition)
    # The kind of the translated condition decides which statement type to build.
    stmt_cls = gtir.FieldIfStmt if cond.kind == ExprKind.FIELD else gtir.ScalarIfStmt

    then_block = gtir.BlockStmt(body=self.visit(node.main_body))
    else_block = None
    if node.else_body:
        else_block = gtir.BlockStmt(body=self.visit(node.else_body))

    return stmt_cls(
        cond=cond,
        true_branch=then_block,
        false_branch=else_block,
        loc=location_to_source_location(node.loc),
    )
def visit_FieldDecl(self, node: FieldDecl) -> gtir.FieldDecl:
    """Translate a field declaration into a ``gtir.FieldDecl``.

    ``dimensions`` is a list of three booleans recording whether "I", "J" and
    "K" (in that order) appear in ``node.axes``. The dtype is rebuilt from the
    integer value of ``node.data_type``; this relies on both DataType enums
    using the same numeric IDs.
    """
    axis_mask = [axis in node.axes for axis in ("I", "J", "K")]
    field_dtype = common.DataType(int(node.data_type.value))
    return gtir.FieldDecl(
        name=node.name,
        dtype=field_dtype,
        dimensions=axis_mask,
        data_dims=node.data_dims,
        loc=location_to_source_location(node.loc),
    )
def visit_ComputationBlock(self, node: ComputationBlock) -> gtir.VerticalLoop:
    """Translate a computation block into a ``gtir.VerticalLoop``.

    ``FieldDecl`` and ``VarDecl`` statements in the body are not emitted as
    statements. Each one becomes a temporary ``gtir.FieldDecl`` with
    ``dimensions=(True, True, True)``; a ``DEFAULT`` dtype is replaced by
    ``FLOAT64``. All other statements are translated and kept, in order, as
    the loop body.
    """
    body_stmts = []
    temporary_decls = []
    for stmt in node.body.stmts:
        if not isinstance(stmt, (FieldDecl, VarDecl)):
            body_stmts.append(self.visit(stmt))
            continue

        # Declarations inside the body become temporaries of the vertical loop.
        temp_dtype = common.DataType(int(stmt.data_type.value))
        if temp_dtype == common.DataType.DEFAULT:
            # TODO this will be a frontend choice later
            # in non-GTC parts, this is set in the backend
            temp_dtype = cast(
                common.DataType, common.DataType.FLOAT64
            )  # see https://github.com/GridTools/gtc/issues/100
        temporary_decls.append(
            gtir.FieldDecl(
                name=stmt.name,
                dtype=temp_dtype,
                dimensions=(True, True, True),
                loc=location_to_source_location(stmt.loc),
            )
        )

    start, end = self.visit(node.interval)
    loop_interval = gtir.Interval(
        start=start,
        end=end,
        loc=location_to_source_location(node.interval.loc),
    )
    loop_order = self.GT4PY_ITERATIONORDER_TO_GTIR_LOOPORDER[node.iteration_order]
    return gtir.VerticalLoop(
        interval=loop_interval,
        loop_order=loop_order,
        body=body_stmts,
        temporaries=temporary_decls,
        loc=location_to_source_location(node.loc),
    )
def visit_StencilDefinition(self, node: StencilDefinition) -> gtir.Stencil:
    """Translate a whole stencil definition into a ``gtir.Stencil``.

    - API fields and scalar parameters are translated and indexed by name.
    - Computations with an empty body are skipped; each remaining one
      becomes a vertical loop.
    - Only externals whose value is a ``numbers.Number`` are kept, converted
      with ``_make_literal``.
    - Each API-signature entry yields a ``gtir.Argument``. Its default is
      stringified, or ``""`` when the default is an instance of
      ``type(Empty)``.
    - Each API-signature entry is also visited with the combined
      ``all_params`` mapping to produce the stencil params.
    """
    field_params = {field.name: self.visit(field) for field in node.api_fields}
    scalar_params = {param.name: self.visit(param) for param in node.parameters}
    vertical_loops = [
        self.visit(computation)
        for computation in node.computations
        if computation.body.stmts
    ]

    # Only numeric external values are carried over as GTIR literals.
    externals = (
        {}
        if node.externals is None
        else {
            name: _make_literal(value)
            for name, value in node.externals.items()
            if isinstance(value, numbers.Number)
        }
    )

    signature = []
    for arg_info in node.api_signature:
        default_text = (
            "" if isinstance(arg_info.default, type(Empty)) else str(arg_info.default)
        )
        signature.append(
            gtir.Argument(
                name=arg_info.name,
                is_keyword=arg_info.is_keyword,
                default=default_text,
            )
        )

    params = [
        self.visit(arg_info, all_params={**field_params, **scalar_params})
        for arg_info in node.api_signature
    ]

    return gtir.Stencil(
        name=node.name,
        api_signature=signature,
        params=params,
        vertical_loops=vertical_loops,
        externals=externals,
        sources=node.sources or "",
        docstring=node.docstring,
        loc=location_to_source_location(node.loc),
    )
def visit_VarRef(self, node: VarRef, **kwargs) -> gtir.ScalarAccess:
    """Translate a variable reference into a ``gtir.ScalarAccess``.

    Extra keyword arguments are accepted and ignored.
    """
    source_loc = location_to_source_location(node.loc)
    return gtir.ScalarAccess(name=node.name, loc=source_loc)
def visit_While(self, node: While) -> gtir.While:
    """Translate a ``while`` loop into a ``gtir.While``.

    The condition is visited before the body.
    """
    loop_cond = self.visit(node.condition)
    loop_body = self.visit(node.body)
    return gtir.While(
        cond=loop_cond,
        body=loop_body,
        loc=location_to_source_location(node.loc),
    )
def visit_Cast(self, node: Cast) -> gtir.Cast:
    """Translate a type cast into a ``gtir.Cast``.

    The target dtype is rebuilt from the integer value of ``node.data_type``.
    This matches the conversion used for variable, field and temporary
    declarations, and relies on both DataType enums sharing numeric IDs.
    """
    # Normalize through int() like the other DataType conversions, so every
    # dtype translation follows the same "same numeric ID" path.
    target_dtype = common.DataType(int(node.data_type.value))
    return gtir.Cast(
        dtype=target_dtype,
        expr=self.visit(node.expr),
        loc=location_to_source_location(node.loc),
    )
def visit_UnaryOpExpr(self, node: UnaryOpExpr) -> gtir.UnaryOp:
    """Translate a unary operation into a ``gtir.UnaryOp``.

    The operator is mapped through ``GT4PY_UNARYOP_TO_GTIR``, and the operand
    is translated recursively.
    """
    gtir_op = self.GT4PY_UNARYOP_TO_GTIR[node.op]
    operand = self.visit(node.arg)
    return gtir.UnaryOp(
        op=gtir_op,
        expr=operand,
        loc=location_to_source_location(node.loc),
    )