示例#1
0
 def visit_VarDecl(self, node: VarDecl) -> gtir.ScalarDecl:
     """Lower a scalar variable declaration to a ``gtir.ScalarDecl``.

     The data type is translated through its integer value, because the
     DefIR and GTC ``DataType`` enums share the same numeric IDs.
     """
     gtc_dtype = common.DataType(int(node.data_type.value))
     source_loc = common.location_to_source_location(node.loc)
     return gtir.ScalarDecl(name=node.name, dtype=gtc_dtype, loc=source_loc)
示例#2
0
 def visit_NativeFuncCall(self,
                          node: NativeFuncCall) -> gtir.NativeFuncCall:
     """Lower a native function call to a ``gtir.NativeFuncCall``.

     The callee is translated through ``GT4PY_NATIVE_FUNC_TO_GTIR``. Each
     argument is lowered recursively, in order.
     """
     gtir_func = self.GT4PY_NATIVE_FUNC_TO_GTIR[node.func]
     lowered_args = list(map(self.visit, node.args))
     source_loc = common.location_to_source_location(node.loc)
     return gtir.NativeFuncCall(func=gtir_func,
                                args=lowered_args,
                                loc=source_loc)
示例#3
0
 def visit_FieldRef(self, node: FieldRef) -> gtir.FieldAccess:
     """Lower a field reference to a ``gtir.FieldAccess``.

     The offset is converted by ``transform_offset``. Every entry of
     ``data_index`` is lowered recursively, in order.
     """
     gtir_offset = self.transform_offset(node.offset)
     index_exprs = [self.visit(idx_expr) for idx_expr in node.data_index]
     source_loc = common.location_to_source_location(node.loc)
     return gtir.FieldAccess(name=node.name,
                             offset=gtir_offset,
                             data_index=index_exprs,
                             loc=source_loc)
示例#4
0
 def visit_TernaryOpExpr(self, node: TernaryOpExpr) -> gtir.TernaryOp:
     """Lower ``cond ? then : else`` to a ``gtir.TernaryOp``.

     The condition, then-expression and else-expression are lowered in
     that order.
     """
     cond_expr = self.visit(node.condition)
     then_value = self.visit(node.then_expr)
     else_value = self.visit(node.else_expr)
     source_loc = common.location_to_source_location(node.loc)
     return gtir.TernaryOp(cond=cond_expr,
                           true_expr=then_value,
                           false_expr=else_value,
                           loc=source_loc)
示例#5
0
 def visit_BinOpExpr(
         self,
         node: BinOpExpr) -> Union[gtir.BinaryOp, gtir.NativeFuncCall]:
     """Lower a binary operation.

     ``POW`` and ``MOD`` are emitted as a ``gtir.NativeFuncCall`` of the
     ``common.NativeFunction`` member with the same name. Every other
     operator becomes a ``gtir.BinaryOp`` translated through
     ``GT4PY_OP_TO_GTIR_OP``. In both cases the left operand is lowered
     before the right one.
     """
     if node.op not in (BinaryOperator.POW, BinaryOperator.MOD):
         lhs_expr = self.visit(node.lhs)
         rhs_expr = self.visit(node.rhs)
         return gtir.BinaryOp(
             left=lhs_expr,
             right=rhs_expr,
             op=self.GT4PY_OP_TO_GTIR_OP[node.op],
             loc=common.location_to_source_location(node.loc),
         )

     # No infix GTIR operator is used here; map by enum member name instead.
     native_func = common.NativeFunction[node.op.name]
     operands = [self.visit(node.lhs), self.visit(node.rhs)]
     return gtir.NativeFuncCall(
         func=native_func,
         args=operands,
         loc=common.location_to_source_location(node.loc),
     )
示例#6
0
 def visit_Assign(self, node: Assign) -> gtir.ParAssignStmt:
     """Lower an assignment to a ``gtir.ParAssignStmt``.

     The target must be a ``FieldRef`` or a ``VarRef``. This is only
     checked with ``assert``, which is skipped under ``python -O``.
     """
     assert isinstance(node.target, (FieldRef, VarRef))
     target_expr = self.visit(node.target)
     value_expr = self.visit(node.value)
     source_loc = common.location_to_source_location(node.loc)
     return gtir.ParAssignStmt(left=target_expr,
                               right=value_expr,
                               loc=source_loc)
示例#7
0
 def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]:
     """Lower an ``if`` statement.

     The node class depends on the lowered condition's ``kind``: a
     ``FIELD`` condition yields a ``gtir.FieldIfStmt``, anything else a
     ``gtir.ScalarIfStmt``. Both branches are wrapped in
     ``gtir.BlockStmt``. The false branch is ``None`` when ``else_body``
     is falsy.
     """
     cond_expr = self.visit(node.condition)
     if_stmt_cls = (gtir.FieldIfStmt
                    if cond_expr.kind == ExprKind.FIELD else gtir.ScalarIfStmt)

     then_block = gtir.BlockStmt(body=self.visit(node.main_body))
     else_block = None
     if node.else_body:
         else_block = gtir.BlockStmt(body=self.visit(node.else_body))

     return if_stmt_cls(
         cond=cond_expr,
         true_branch=then_block,
         false_branch=else_block,
         loc=common.location_to_source_location(node.loc),
     )
示例#8
0
 def visit_FieldDecl(self, node: FieldDecl) -> gtir.FieldDecl:
     """Lower a field declaration to a ``gtir.FieldDecl``.

     ``dimensions`` is a list of three booleans. Each one tells whether
     the corresponding axis "I", "J", "K" (in that order) appears in
     ``node.axes``. The dtype is translated by numeric value, because the
     DefIR and GTC ``DataType`` enums share the same IDs.
     """
     axis_mask = [axis in node.axes for axis in ("I", "J", "K")]
     gtc_dtype = common.DataType(int(node.data_type.value))
     return gtir.FieldDecl(
         name=node.name,
         dtype=gtc_dtype,
         dimensions=axis_mask,
         data_dims=node.data_dims,
         loc=common.location_to_source_location(node.loc),
     )
示例#9
0
 def visit_ComputationBlock(self,
                            node: ComputationBlock) -> gtir.VerticalLoop:
     """Lower a computation block to a ``gtir.VerticalLoop``.

     ``FieldDecl`` and ``VarDecl`` statements in the body declare
     temporaries. Instead of being visited, each one becomes a
     ``gtir.FieldDecl`` with all three dimensions enabled and is
     collected into the loop's ``temporaries``. Every other statement is
     lowered, in order, into the loop body. The interval visitor is
     expected to return a ``(start, end)`` pair.
     """
     body_stmts = []
     temp_decls = []
     for stmt in node.body.stmts:
         if not isinstance(stmt, (FieldDecl, VarDecl)):
             body_stmts.append(self.visit(stmt))
             continue

         temp_dtype = common.DataType(int(stmt.data_type.value))
         if temp_dtype == common.DataType.DEFAULT:
             # TODO: DEFAULT will become a frontend choice. Outside GTC the
             # backend resolves it, so temporaries fall back to FLOAT64 here.
             # See https://github.com/GridTools/gtc/issues/100
             temp_dtype = cast(common.DataType, common.DataType.FLOAT64)
         temp_decls.append(
             gtir.FieldDecl(
                 name=stmt.name,
                 dtype=temp_dtype,
                 dimensions=(True, True, True),
                 loc=common.location_to_source_location(stmt.loc),
             ))

     interval_start, interval_end = self.visit(node.interval)
     gtir_interval = gtir.Interval(
         start=interval_start,
         end=interval_end,
         loc=common.location_to_source_location(node.interval.loc),
     )
     loop_order = self.GT4PY_ITERATIONORDER_TO_GTIR_LOOPORDER[
         node.iteration_order]
     return gtir.VerticalLoop(
         interval=gtir_interval,
         loop_order=loop_order,
         body=body_stmts,
         temporaries=temp_decls,
         loc=common.location_to_source_location(node.loc),
     )
示例#10
0
 def visit_StencilDefinition(self, node: StencilDefinition) -> gtir.Stencil:
     """Lower a complete stencil definition to a ``gtir.Stencil``.

     API fields and scalar parameters are lowered first and indexed by
     name. Computations whose body has no statements are skipped. Each
     entry of ``api_signature`` is then visited, with the merged
     name -> lowered-declaration mapping passed as ``all_params``.
     """
     field_params = {f.name: self.visit(f) for f in node.api_fields}
     scalar_params = {p.name: self.visit(p) for p in node.parameters}
     # Build the merged lookup once instead of once per signature entry.
     # Scalar parameters override fields on a name clash, as before.
     # NOTE(review): the same dict object is now shared by every visit
     # call; this assumes the argument visitor only reads it.
     all_params = {**field_params, **scalar_params}
     vertical_loops = [
         self.visit(c) for c in node.computations if c.body.stmts
     ]
     # Visited after the computations, matching the original order.
     params = [
         self.visit(arg, all_params=all_params)
         for arg in node.api_signature
     ]
     return gtir.Stencil(
         name=node.name,
         params=params,
         vertical_loops=vertical_loops,
         loc=common.location_to_source_location(node.loc),
     )
示例#11
0
 def visit_VarRef(self, node: VarRef, **kwargs):
     """Lower a scalar variable reference to a ``gtir.ScalarAccess``.

     Any extra keyword arguments are accepted and ignored.
     """
     var_name = node.name
     source_loc = common.location_to_source_location(node.loc)
     return gtir.ScalarAccess(name=var_name, loc=source_loc)
示例#12
0
 def visit_While(self, node: While) -> gtir.While:
     """Lower a ``while`` loop to a ``gtir.While``.

     The condition is lowered before the body. The body is passed through
     exactly as returned by ``self.visit``.
     """
     loop_cond = self.visit(node.condition)
     loop_body = self.visit(node.body)
     source_loc = common.location_to_source_location(node.loc)
     return gtir.While(cond=loop_cond, body=loop_body, loc=source_loc)
示例#13
0
 def visit_Cast(self, node: Cast) -> gtir.Cast:
     """Lower a type cast to a ``gtir.Cast``.

     The target dtype is translated through its integer value, because the
     DefIR and GTC ``DataType`` enums share the same numeric IDs. The
     ``int(...)`` normalization matches every other dtype conversion in
     this visitor.
     """
     return gtir.Cast(
         dtype=common.DataType(int(node.data_type.value)),
         expr=self.visit(node.expr),
         loc=common.location_to_source_location(node.loc),
     )
示例#14
0
 def visit_UnaryOpExpr(self, node: UnaryOpExpr) -> gtir.UnaryOp:
     """Lower a unary operation to a ``gtir.UnaryOp``.

     The operator is translated through ``GT4PY_UNARYOP_TO_GTIR``, and the
     operand is lowered recursively.
     """
     gtir_op = self.GT4PY_UNARYOP_TO_GTIR[node.op]
     operand = self.visit(node.arg)
     source_loc = common.location_to_source_location(node.loc)
     return gtir.UnaryOp(op=gtir_op, expr=operand, loc=source_loc)