Пример #1
0
    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        *,
        prog_ctx: ProgramContext,
        comp_ctx: GTComputationContext,
        interval: gtcpp.GTInterval,
        **kwargs: Any,
    ) -> gtcpp.GTStage:
        """Lower an OIR horizontal execution into a ``gtcpp.GTStage``.

        The visited body (wrapped in an ``IfStmt`` when the node has a mask)
        becomes the single apply method of a new functor named ``node.id_``.
        The functor is registered on ``prog_ctx`` and every stage argument
        that does not name a temporary of ``comp_ctx`` is added to
        ``comp_ctx`` as a computation argument.

        Args:
            node: The horizontal execution to lower.
            prog_ctx: Program-level context that receives the new functor.
            comp_ctx: Computation-level context that receives the
                non-temporary arguments.
            interval: Vertical interval, visited and attached to the apply
                method.
            **kwargs: Forwarded to child visits; must contain
                ``"stencil_symtable"``.

        Returns:
            A ``gtcpp.GTStage`` referring to the functor by ``node.id_`` with
            one ``gtcpp.Arg`` per extracted accessor.

        Note:
            ``prog_ctx``, ``comp_ctx`` and ``interval`` are bound as named
            parameters, so they are *not* part of the ``kwargs`` forwarded to
            child visits. NOTE(review): a child visitor that reads
            ``kwargs["comp_ctx"]`` would not find it here -- confirm that no
            such node can appear below a horizontal execution.
        """
        assert "stencil_symtable" in kwargs
        body = self.visit(node.body, **kwargs)
        mask = self.visit(node.mask, **kwargs)
        if mask:
            # Guard the whole body with the mask condition.
            body = [
                gtcpp.IfStmt(cond=mask, true_branch=gtcpp.BlockStmt(body=body))
            ]
        apply_method = gtcpp.GTApplyMethod(
            interval=self.visit(interval, **kwargs),
            body=body,
            local_variables=self.visit(node.declarations, **kwargs),
        )
        accessors = _extract_accessors(apply_method)
        stage_args = [gtcpp.Arg(name=acc.name) for acc in accessors]

        # Build the temporary-name lookup once instead of re-creating a list
        # (and scanning it linearly) for every stage argument.
        temporary_names = {tmp.name for tmp in comp_ctx.temporaries}
        comp_ctx.add_arguments(
            {arg for arg in stage_args if arg.name not in temporary_names}
        )

        # No trailing comma after the call: the original turned this statement
        # into a discarded one-element tuple.
        prog_ctx.add_functor(
            gtcpp.GTFunctor(
                name=node.id_,
                applies=[apply_method],
                param_list=gtcpp.GTParamList(accessors=accessors),
            )
        )

        return gtcpp.GTStage(functor=node.id_, args=stage_args)
Пример #2
0
 def visit_MaskStmt(self, node: oir.MaskStmt,
                    **kwargs: Any) -> gtcpp.IfStmt:
     """Translate a masked statement into a ``gtcpp.IfStmt``.

     The visited mask becomes the condition and the visited body is wrapped
     in a ``gtcpp.BlockStmt`` used as the true branch.
     """
     # Visit the mask before the body, matching the original evaluation order.
     condition = self.visit(node.mask, **kwargs)
     guarded_block = gtcpp.BlockStmt(body=self.visit(node.body, **kwargs))
     return gtcpp.IfStmt(cond=condition, true_branch=guarded_block)
Пример #3
0
 def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction,
                                 **kwargs: Any) -> gtcpp.IfStmt:
     """Translate a horizontal restriction into a ``gtcpp.IfStmt``.

     The condition is produced by ``self._mask_to_expr`` from the node's mask
     and the ``comp_ctx`` entry of ``kwargs`` (a ``KeyError`` is raised when
     that entry is missing). The visited body is wrapped in a
     ``gtcpp.BlockStmt`` used as the true branch.
     """
     computation_ctx = kwargs["comp_ctx"]
     condition = self._mask_to_expr(node.mask, computation_ctx)
     restricted_body = self.visit(node.body, **kwargs)
     return gtcpp.IfStmt(
         cond=condition,
         true_branch=gtcpp.BlockStmt(body=restricted_body),
     )