def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    prog_ctx: ProgramContext,
    comp_ctx: GTComputationContext,
    interval: gtcpp.GTInterval,
    **kwargs: Any,
) -> gtcpp.GTStage:
    """Translate an OIR horizontal execution into a gtcpp stage.

    Builds one apply method for ``interval`` from the node's body and
    declarations, registers a functor wrapping it on ``prog_ctx``, and adds
    every accessor that is not a temporary of ``comp_ctx`` as a computation
    argument.

    Args:
        node: The OIR horizontal execution to translate.
        prog_ctx: Program context that receives the generated functor.
        comp_ctx: Computation context that receives the stage's
            non-temporary arguments.
        interval: Interval the generated apply method runs on.
        **kwargs: Forwarded to every nested ``self.visit`` call; must contain
            ``"stencil_symtable"``.

    Returns:
        A ``gtcpp.GTStage`` referencing the registered functor by name, with
        one ``gtcpp.Arg`` per accessor used by the apply method.
    """
    # Precondition check only (stripped under ``python -O``).
    assert "stencil_symtable" in kwargs

    apply_method = gtcpp.GTApplyMethod(
        interval=self.visit(interval, **kwargs),
        body=self.visit(node.body, **kwargs),
        local_variables=self.visit(node.declarations, **kwargs),
    )
    accessors = _extract_accessors(apply_method)
    stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

    # Temporaries are excluded from the computation-level arguments. Build the
    # name set once instead of re-creating a list for every stage argument.
    temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
    comp_ctx.add_arguments(
        {arg for arg in stage_args if arg.name not in temporary_names}
    )

    # NOTE(review): the functor name is derived from id(node), so it is unique
    # among live node objects but not stable across runs — confirm acceptable.
    functor_name = type(node).__name__ + str(id(node))
    prog_ctx.add_functor(
        gtcpp.GTFunctor(
            name=functor_name,
            applies=[apply_method],
            param_list=gtcpp.GTParamList(accessors=accessors),
        )
    )
    return gtcpp.GTStage(functor=functor_name, args=stage_args)
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    prog_ctx: ProgramContext,
    comp_ctx: GTComputationContext,
    interval: gtcpp.GTInterval,
    **kwargs: Any,
) -> gtcpp.GTStage:
    """Translate an OIR horizontal execution into a gtcpp stage.

    The node's body is translated and, if the node carries a mask, wrapped in
    an ``IfStmt`` guarded by the translated mask. The result becomes the body
    of a single apply method for ``interval``. A functor named ``node.id_`` is
    registered on ``prog_ctx``. Every accessor that is not a temporary of
    ``comp_ctx`` is added as a computation argument.

    Args:
        node: The OIR horizontal execution to translate.
        prog_ctx: Program context that receives the generated functor.
        comp_ctx: Computation context that receives the stage's
            non-temporary arguments.
        interval: Interval the generated apply method runs on.
        **kwargs: Forwarded to every nested ``self.visit`` call; must contain
            ``"stencil_symtable"``.

    Returns:
        A ``gtcpp.GTStage`` referencing the functor ``node.id_``, with one
        ``gtcpp.Arg`` per accessor used by the apply method.
    """
    # Precondition check only (stripped under ``python -O``).
    assert "stencil_symtable" in kwargs

    body = self.visit(node.body, **kwargs)
    mask = self.visit(node.mask, **kwargs)
    # Test for presence explicitly rather than relying on the truthiness of a
    # translated IR node.
    if mask is not None:
        body = [gtcpp.IfStmt(cond=mask, true_branch=gtcpp.BlockStmt(body=body))]

    apply_method = gtcpp.GTApplyMethod(
        interval=self.visit(interval, **kwargs),
        body=body,
        local_variables=self.visit(node.declarations, **kwargs),
    )
    accessors = _extract_accessors(apply_method)
    stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

    # Temporaries are excluded from the computation-level arguments. Build the
    # name set once instead of re-creating a list for every stage argument.
    temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
    comp_ctx.add_arguments(
        {arg for arg in stage_args if arg.name not in temporary_names}
    )

    prog_ctx.add_functor(
        gtcpp.GTFunctor(
            name=node.id_,
            applies=[apply_method],
            param_list=gtcpp.GTParamList(accessors=accessors),
        )
    )
    return gtcpp.GTStage(functor=node.id_, args=stage_args)