def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    prog_ctx: ProgramContext,
    comp_ctx: GTComputationContext,
    interval: gtcpp.GTInterval,
    **kwargs: Any,
) -> gtcpp.GTStage:
    """Lower an OIR horizontal execution into a GT stage.

    The visited body (wrapped in a ``gtcpp.IfStmt`` when the visited mask is
    truthy) becomes the single apply method of a new ``gtcpp.GTFunctor``
    named ``node.id_``, which is registered on ``prog_ctx``. Each accessor
    extracted from that apply method becomes a stage argument. Arguments
    whose name does not match a temporary of ``comp_ctx`` are also added
    to ``comp_ctx`` via ``add_arguments``.

    Args:
        node: The horizontal execution to lower.
        prog_ctx: Program context that receives the generated functor.
        comp_ctx: Computation context that receives the non-temporary
            arguments.
        interval: Interval the generated apply method is attached to.
        **kwargs: Forwarded to every nested ``visit`` call; must contain
            ``"stencil_symtable"``.

    Returns:
        A ``gtcpp.GTStage`` referring to functor ``node.id_`` with the
        extracted stage arguments.
    """
    assert "stencil_symtable" in kwargs
    body = self.visit(node.body, **kwargs)
    mask = self.visit(node.mask, **kwargs)
    if mask:
        # Guard the whole execution body with the mask condition.
        body = [gtcpp.IfStmt(cond=mask, true_branch=gtcpp.BlockStmt(body=body))]

    apply_method = gtcpp.GTApplyMethod(
        interval=self.visit(interval, **kwargs),
        body=body,
        local_variables=self.visit(node.declarations, **kwargs),
    )
    accessors = _extract_accessors(apply_method)
    stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

    # Temporaries of the computation are excluded from its argument set.
    # Build the name lookup once instead of re-creating a list per argument.
    temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
    comp_ctx.add_arguments(
        {arg for arg in stage_args if arg.name not in temporary_names}
    )

    # NOTE: the original had a stray trailing comma here, which wrapped the
    # call in a discarded 1-tuple; it is removed.
    prog_ctx.add_functor(
        gtcpp.GTFunctor(
            name=node.id_,
            applies=[apply_method],
            param_list=gtcpp.GTParamList(accessors=accessors),
        )
    )
    return gtcpp.GTStage(functor=node.id_, args=stage_args)
def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> gtcpp.IfStmt:
    """Lower an OIR masked statement into a ``gtcpp.IfStmt``.

    The visited mask becomes the condition. The visited body is placed
    inside a ``gtcpp.BlockStmt`` that forms the true branch.

    Args:
        node: The masked statement to lower.
        **kwargs: Forwarded to the nested ``visit`` calls.

    Returns:
        The conditional statement guarding the body.
    """
    condition = self.visit(node.mask, **kwargs)
    guarded_stmts = self.visit(node.body, **kwargs)
    branch = gtcpp.BlockStmt(body=guarded_stmts)
    return gtcpp.IfStmt(cond=condition, true_branch=branch)
def visit_HorizontalRestriction(
    self, node: oir.HorizontalRestriction, **kwargs: Any
) -> gtcpp.IfStmt:
    """Lower an OIR horizontal restriction into a ``gtcpp.IfStmt``.

    The restriction's mask is converted to a condition expression by
    ``self._mask_to_expr``, using the ``"comp_ctx"`` entry of ``kwargs``.
    The visited body forms the true branch.

    Args:
        node: The horizontal restriction to lower.
        **kwargs: Forwarded to the nested ``visit`` call; must contain
            ``"comp_ctx"`` (a missing key raises ``KeyError``).

    Returns:
        The conditional statement guarding the restricted body.
    """
    condition = self._mask_to_expr(node.mask, kwargs["comp_ctx"])
    restricted_stmts = self.visit(node.body, **kwargs)
    branch = gtcpp.BlockStmt(body=restricted_stmts)
    return gtcpp.IfStmt(cond=condition, true_branch=branch)