def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> npir.Computation:
    """Lower an ``oir.Stencil`` into an ``npir.Computation``.

    ``compute_extents(node)`` yields a pair ``(field_extents, block_extents)``.
    The first is forwarded when visiting field and temporary declarations,
    the second when visiting the vertical loops.

    Args:
        node: The stencil to translate.
        **kwargs: Forwarded to the visits of scalar parameters, temporary
            declarations and vertical loops. API field declarations are
            visited with ``field_extents`` only.

    Returns:
        The computation assembled from the visited pieces.
    """
    field_extents, block_extents = compute_extents(node)

    # Every parameter name is kept, in declaration order, regardless of kind.
    arguments = [param.name for param in node.params]

    # Split the parameters by kind first; scalars are still visited before
    # fields, exactly as two separate passes over ``node.params`` would do.
    scalar_params = [p for p in node.params if isinstance(p, oir.ScalarDecl)]
    field_params = [p for p in node.params if isinstance(p, oir.FieldDecl)]

    param_decls = [self.visit(scalar, **kwargs) for scalar in scalar_params]
    api_field_decls = [
        self.visit(field, field_extents=field_extents) for field in field_params
    ]
    temp_decls = [
        self.visit(temporary, field_extents=field_extents, **kwargs)
        for temporary in node.declarations
    ]

    visited_loops = self.visit(
        node.vertical_loops, block_extents=block_extents, **kwargs
    )
    vertical_passes = utils.flatten_list(visited_loops)

    return npir.Computation(
        arguments=arguments,
        api_field_decls=api_field_decls,
        param_decls=param_decls,
        temp_decls=temp_decls,
        vertical_passes=vertical_passes,
    )
def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop:
    """Translate a ``gtir.VerticalLoop`` into an ``oir.VerticalLoop``.

    Each statement of ``node.body`` becomes its own
    ``oir.HorizontalExecution``. The loop's temporaries are appended to
    ``ctx.temp_fields`` as ``oir.Temporary`` declarations.

    Args:
        node: The GTIR vertical loop to lower.
        ctx: Lowering context; its local scalars are reset before each
            statement and its ``temp_fields`` list is extended.

    Returns:
        An ``oir.VerticalLoop`` with a single section covering the visited
        interval.
    """
    executions: List[oir.HorizontalExecution] = []
    for body_stmt in node.body:
        # Reset before visiting; ``ctx.local_scalars`` is read afterwards.
        # NOTE(review): presumably this scopes scalars to one execution --
        # confirm against ``Context``.
        ctx.reset_local_scalars()
        visited = self.visit(body_stmt, ctx=ctx)
        if isinstance(visited, oir.Stmt):
            visited = [visited]
        executions.append(
            oir.HorizontalExecution(
                body=utils.flatten_list(visited),
                declarations=ctx.local_scalars,
            )
        )

    new_temporaries = [
        oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
        for tmp in node.temporaries
    ]
    ctx.temp_fields += new_temporaries

    section = oir.VerticalLoopSection(
        interval=self.visit(node.interval),
        horizontal_executions=executions,
        loc=node.loc,
    )
    return oir.VerticalLoop(loop_order=node.loop_order, sections=[section])
def visit_HorizontalRestriction(
    self, node: oir.HorizontalRestriction, *, extent: Extent, **kwargs: Any
) -> Any:
    """Lower the body of an ``oir.HorizontalRestriction`` under its mask.

    The mask is computed by ``compute_relative_mask(extent, node.mask)``.

    Args:
        node: The restriction whose body is lowered.
        extent: Extent passed to ``compute_relative_mask``.
        **kwargs: Forwarded to the visit of ``node.body``.

    Returns:
        The flattened result of visiting ``node.body`` with
        ``horizontal_mask`` set, or ``NOTHING`` when no relative mask could
        be computed (``compute_relative_mask`` returned ``None``).
    """
    relative_mask = compute_relative_mask(extent, node.mask)
    if relative_mask is not None:
        lowered_body = self.visit(
            node.body, horizontal_mask=relative_mask, **kwargs
        )
        return utils.flatten_list(lowered_body)

    # NOTE(review): presumably the visitor framework drops NOTHING -- confirm.
    return NOTHING
def visit_HorizontalRestriction(
    self, node: gtir.HorizontalRestriction, **kwargs: Any
) -> oir.HorizontalRestriction:
    """Translate a ``gtir.HorizontalRestriction`` into its OIR counterpart.

    Every statement in ``node.body`` is visited. A visit may yield either a
    single ``oir.Stmt`` or a (possibly nested) collection of them; all
    results are flattened into one body list. The mask is copied unchanged.

    Args:
        node: The GTIR restriction to lower.
        **kwargs: Forwarded to the visit of each body statement.

    Returns:
        An ``oir.HorizontalRestriction`` with the same mask and the
        flattened, lowered body.
    """
    lowered_body = []
    for child in node.body:
        visited = self.visit(child, **kwargs)
        if isinstance(visited, oir.Stmt):
            # Wrap a lone statement so it can be flattened like a list.
            visited = [visited]
        lowered_body.extend(utils.flatten_list(visited))

    return oir.HorizontalRestriction(mask=node.mask, body=lowered_body)
def _all_local_scalars_are_unique_type(stencil: npir.Computation) -> bool:
    """Check that each local-scalar name maps to a single dtype.

    Collects the ``declarations`` of every ``npir.HorizontalBlock`` found in
    the tree of ``stencil`` and compares the dtypes of declarations that
    share a name.

    Args:
        stencil: The computation whose horizontal blocks are inspected.

    Returns:
        ``False`` as soon as two declarations with the same name have
        different dtypes, ``True`` otherwise.
    """
    horizontal_blocks = stencil.iter_tree().if_isinstance(npir.HorizontalBlock)
    per_block_decls = horizontal_blocks.getattr("declarations").to_list()
    all_declarations = utils.flatten_list(per_block_decls)

    dtype_by_name: Dict[str, common.DataType] = {}
    for declaration in all_declarations:
        try:
            known_dtype = dtype_by_name[declaration.name]
        except KeyError:
            # First occurrence of this name: remember its dtype.
            dtype_by_name[declaration.name] = declaration.dtype
            continue
        if declaration.dtype != known_dtype:
            return False
    return True
def visit_While(
    self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any
) -> npir.While:
    """Lower an ``oir.While`` into an ``npir.While``.

    The condition is visited with the incoming ``mask``. The body is visited
    with a mask that is the loop condition, AND-ed with the incoming mask
    when one is present.

    Args:
        node: The OIR while-loop to lower.
        mask: Optional mask expression from an enclosing construct.
        **kwargs: Forwarded to the visits of the condition and the body.

    Returns:
        An ``npir.While`` with the lowered condition and the flattened,
        lowered body.
    """
    cond = self.visit(node.cond, mask=mask, **kwargs)

    # Test for presence explicitly: ``mask`` is Optional, and relying on the
    # truthiness of an expression node would treat a falsy node as "no mask".
    if mask is not None:
        body_mask = npir.VectorLogic(
            op=common.LogicalOperator.AND, left=mask, right=cond
        )
    else:
        body_mask = cond

    body = utils.flatten_list(self.visit(node.body, mask=body_mask, **kwargs))
    return npir.While(cond=cond, body=body)
def visit_MaskStmt(
    self,
    node: oir.MaskStmt,
    *,
    mask: Optional[npir.Expr] = None,
    **kwargs: Any,
) -> List[npir.Stmt]:
    """Lower the body of an ``oir.MaskStmt`` under its mask expression.

    The statement's own mask is visited. If an outer ``mask`` is present,
    the two are combined with a logical AND. The body is then visited with
    the resulting mask.

    Args:
        node: The masked statement to lower.
        mask: Optional mask expression from an enclosing construct.
        **kwargs: Forwarded to the visits of ``node.mask`` and ``node.body``.

    Returns:
        The flattened list produced by visiting ``node.body``.
    """
    mask_expr = self.visit(node.mask, **kwargs)

    # Test for presence explicitly: ``mask`` is Optional, and relying on the
    # truthiness of an expression node would treat a falsy node as "no mask".
    if mask is not None:
        mask_expr = npir.VectorLogic(
            op=common.LogicalOperator.AND, left=mask, right=mask_expr
        )

    return utils.flatten_list(self.visit(node.body, mask=mask_expr, **kwargs))
def _impl(
    cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
) -> RootValidatorValuesType:
    """Reject models containing nodes whose ``dtype`` is falsy.

    Walks the tree of every ``Node`` found among the (flattened) field
    values and inspects each sub-node that has a ``dtype`` attribute.

    Args:
        cls: The model class being validated (unused).
        values: Mapping of field values under validation.

    Returns:
        ``values`` unchanged when every inspected node has a truthy dtype.

    Raises:
        ValueError: If at least one node has a falsy ``dtype``; the message
            lists the offending nodes.
    """
    candidates: List[Node] = []
    for value in flatten_list(values.values()):
        if not isinstance(value, Node):
            continue
        candidates.extend(value.iter_tree().if_hasattr("dtype"))

    # Truthiness is used on purpose here: it rejects ``None`` as well as any
    # other falsy dtype value.
    nodes_without_dtype = [
        candidate for candidate in candidates if not candidate.dtype
    ]
    if nodes_without_dtype:
        raise ValueError(
            "Nodes without dtype detected {}".format(nodes_without_dtype))
    return values
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    block_extents: Optional[Dict[int, Extent]] = None,
    **kwargs: Any,
) -> npir.HorizontalBlock:
    """Lower an ``oir.HorizontalExecution`` into an ``npir.HorizontalBlock``.

    Args:
        node: The horizontal execution to lower.
        block_extents: Optional mapping from ``id(node)`` to that node's
            extent. When missing or empty, a zero extent
            ``((0, 0), (0, 0))`` is used.
        **kwargs: Forwarded to the visits of the body and the declarations.

    Returns:
        The block built from the flattened body, the visited declarations
        and the selected extent.
    """
    body = utils.flatten_list(self.visit(node.body, **kwargs))

    # A truthiness test keeps an empty mapping on the default branch instead
    # of attempting a lookup that would fail.
    extent = block_extents[id(node)] if block_extents else ((0, 0), (0, 0))

    declarations = self.visit(node.declarations, **kwargs)
    return npir.HorizontalBlock(
        declarations=declarations, body=body, extent=extent
    )
def visit_While(self, node: gtir.While, *, mask: oir.Expr = None, **kwargs: Any):
    """Translate a ``gtir.While`` into OIR.

    Body statements are visited with ``**kwargs`` only, and the condition is
    visited with no extra arguments; ``mask`` is forwarded to neither. When
    a mask is supplied, the loop condition is AND-ed with it and the loop is
    wrapped in an ``oir.MaskStmt`` guarded by the same mask.

    Args:
        node: The GTIR while-loop to lower.
        mask: Optional mask expression from an enclosing construct;
            ``None`` means the loop is unmasked.
        **kwargs: Forwarded to the visit of each body statement.

    Returns:
        An ``oir.While`` when ``mask`` is ``None``, otherwise an
        ``oir.MaskStmt`` whose body is the single lowered ``oir.While``.
    """
    body_stmts = []
    for stmt in node.body:
        visited = self.visit(stmt, **kwargs)
        if isinstance(visited, oir.Stmt):
            # Wrap a lone statement so it can be flattened like a list.
            visited = [visited]
        body_stmts.extend(utils.flatten_list(visited))

    cond = self.visit(node.cond)

    # Both mask-dependent steps use the same presence test, so the condition
    # and the wrapper can never disagree about whether a mask applies.
    has_mask = mask is not None
    if has_mask:
        cond = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=cond)

    while_stmt = oir.While(cond=cond, body=body_stmts, loc=node.loc)
    if has_mask:
        return oir.MaskStmt(body=[while_stmt], mask=mask, loc=node.loc)
    return while_stmt