Ejemplo n.º 1
0
    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        *,
        prog_ctx: ProgramContext,
        comp_ctx: GTComputationContext,
        interval: gtcpp.GTInterval,
        **kwargs: Any,
    ) -> gtcpp.GTStage:
        """Translate an OIR horizontal execution into a GridTools stage.

        The node's body and declarations are lowered into a single
        ``GTApplyMethod`` over ``interval``. The accessors extracted from that
        apply method become the stage arguments; those whose names are not
        temporaries of ``comp_ctx`` are registered as computation arguments,
        and a functor wrapping the apply method is registered on ``prog_ctx``.

        Args:
            node: The horizontal execution to translate.
            prog_ctx: Program-level context that receives the generated functor.
            comp_ctx: Computation-level context that receives the non-temporary
                arguments and exposes ``temporaries``.
            interval: Interval the generated apply method is defined on.
            **kwargs: Forwarded to every nested ``visit`` call; must contain
                ``"stencil_symtable"``.

        Returns:
            A ``GTStage`` that references the registered functor by name and
            has one ``Arg`` per extracted accessor.
        """
        assert "stencil_symtable" in kwargs
        apply_method = gtcpp.GTApplyMethod(
            interval=self.visit(interval, **kwargs),
            body=self.visit(node.body, **kwargs),
            local_variables=self.visit(node.declarations, **kwargs),
        )
        accessors = _extract_accessors(apply_method)
        stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

        # Build the temporary-name lookup once; filtering against a list that
        # is re-created for every stage argument is needlessly O(n * m).
        temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
        comp_ctx.add_arguments(
            {arg for arg in stage_args if arg.name not in temporary_names}
        )

        # NOTE(review): id(node) varies between interpreter runs, so this
        # functor name is not stable across code-generation invocations.
        functor_name = type(node).__name__ + str(id(node))
        prog_ctx.add_functor(
            gtcpp.GTFunctor(
                name=functor_name,
                applies=[apply_method],
                param_list=gtcpp.GTParamList(accessors=accessors),
            )
        )

        return gtcpp.GTStage(functor=functor_name, args=stage_args)
Ejemplo n.º 2
0
    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        *,
        prog_ctx: ProgramContext,
        comp_ctx: GTComputationContext,
        interval: gtcpp.GTInterval,
        **kwargs: Any,
    ) -> gtcpp.GTStage:
        """Translate an OIR horizontal execution into a GridTools stage.

        The node's body is lowered and, when the node carries a mask, wrapped
        in an ``IfStmt`` guarded by the translated mask. The result and the
        node's declarations form a single ``GTApplyMethod`` over ``interval``.
        The accessors extracted from that apply method become the stage
        arguments; those whose names are not temporaries of ``comp_ctx`` are
        registered as computation arguments, and a functor named ``node.id_``
        is registered on ``prog_ctx``.

        Args:
            node: The horizontal execution to translate.
            prog_ctx: Program-level context that receives the generated functor.
            comp_ctx: Computation-level context that receives the non-temporary
                arguments and exposes ``temporaries``.
            interval: Interval the generated apply method is defined on.
            **kwargs: Forwarded to every nested ``visit`` call; must contain
                ``"stencil_symtable"``.

        Returns:
            A ``GTStage`` that references the functor ``node.id_`` and has one
            ``Arg`` per extracted accessor.
        """
        assert "stencil_symtable" in kwargs
        body = self.visit(node.body, **kwargs)
        mask = self.visit(node.mask, **kwargs)
        # An absent mask visits to None; test for that explicitly rather than
        # relying on the truthiness of a translated expression node.
        if mask is not None:
            body = [
                gtcpp.IfStmt(cond=mask, true_branch=gtcpp.BlockStmt(body=body))
            ]
        apply_method = gtcpp.GTApplyMethod(
            interval=self.visit(interval, **kwargs),
            body=body,
            local_variables=self.visit(node.declarations, **kwargs),
        )
        accessors = _extract_accessors(apply_method)
        stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

        # Build the temporary-name lookup once; filtering against a list that
        # is re-created for every stage argument is needlessly O(n * m).
        temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
        comp_ctx.add_arguments(
            {arg for arg in stage_args if arg.name not in temporary_names}
        )

        prog_ctx.add_functor(
            gtcpp.GTFunctor(
                name=node.id_,
                applies=[apply_method],
                param_list=gtcpp.GTParamList(accessors=accessors),
            )
        )

        return gtcpp.GTStage(functor=node.id_, args=stage_args)