コード例 #1
0
ファイル: oir_to_gtcpp.py プロジェクト: stubbiali/gt4py
    def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> gtcpp.Program:
        """Build a ``gtcpp.Program`` from an ``oir.Stencil``.

        Every entry of ``node.declarations`` is asserted to be an
        ``oir.Temporary``; the visited declarations are then registered as
        temporaries on a fresh computation context. The vertical loops are
        visited with both contexts, and whatever the contexts hold afterwards
        is assembled into the returned program.

        Args:
            node: Stencil to convert.
            **kwargs: Forwarded unchanged to the visit of
                ``node.vertical_loops``.

        Returns:
            A program named after ``node`` that carries the visited
            parameters, the functors of the program context and the GT
            computation call.
        """
        program_context = self.ProgramContext()
        name_factory = symbol_name_creator(collect_symbol_names(node))
        computation_context = self.GTComputationContext(
            create_symbol_name=name_factory
        )

        # Only temporaries are expected among the stencil declarations.
        assert all(
            isinstance(declaration, oir.Temporary)
            for declaration in node.declarations
        )
        visited_declarations = self.visit(node.declarations)
        computation_context.add_temporaries(visited_declarations)

        # Both contexts are handed to the vertical-loop visit; their
        # attributes are read only after this call has returned.
        visited_loops = self.visit(
            node.vertical_loops,
            prog_ctx=program_context,
            comp_ctx=computation_context,
            **kwargs,
        )

        computation_call = gtcpp.GTComputationCall(
            arguments=computation_context.arguments,
            extra_decls=computation_context.extra_decls,
            temporaries=computation_context.temporaries,
            multi_stages=visited_loops,
        )
        return gtcpp.Program(
            name=node.name,
            parameters=self.visit(node.params),
            functors=program_context.functors,
            gt_computation=computation_call,
        )
コード例 #2
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> cuir.Program:
     """Convert an ``oir.Stencil`` into a ``cuir.Program``.

     The vertical loops are visited with the stencil's symbol table, a
     symbol-name creator seeded with the names in that table, and an
     initially empty set of accessed field names. Afterwards, only the
     declarations whose name is found in that set are visited and kept as
     temporaries.

     Args:
         node: Stencil to convert.
         **kwargs: Accepted but not used by this method.

     Returns:
         A program named after ``node`` with the visited parameters, the
         retained temporaries and the kernels.
     """
     touched_fields: Set[str] = set()
     make_name = symbol_name_creator(set(node.symtable_))
     kernels = self.visit(
         node.vertical_loops,
         symtable=node.symtable_,
         new_symbol_name=make_name,
         accessed_fields=touched_fields,
     )

     # Must run after the kernels were visited: the filter reads the set
     # that was shared with that visit (presumably filled in there).
     temporaries = []
     for declaration in node.declarations:
         if declaration.name not in touched_fields:
             continue
         temporaries.append(self.visit(declaration))

     params = self.visit(node.params)
     return cuir.Program(
         name=node.name,
         params=params,
         temporaries=temporaries,
         kernels=kernels,
     )
コード例 #3
0
ファイル: oir_to_cuir.py プロジェクト: fthaler/gt4py
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> cuir.Program:
     """Convert an ``oir.Stencil`` into a ``cuir.Program``.

     Horizontal block extents are computed for the stencil first. The
     vertical loops are then visited with a symbol-name creator seeded from
     ``kwargs["symtable"]``, an initially empty set of accessed field
     names, the block extents, and every received keyword argument.
     Afterwards, only the declarations whose name is found in that set are
     visited and kept as temporaries.

     Args:
         node: Stencil to convert.
         **kwargs: Must contain a ``"symtable"`` entry; all entries are
             forwarded to the visit of ``node.vertical_loops``.

     Returns:
         A program named after ``node`` with the visited parameters, the
         retained temporaries and the kernels.
     """
     extents = compute_horizontal_block_extents(node)
     touched_fields: Set[str] = set()
     make_name = symbol_name_creator(set(kwargs["symtable"]))
     kernels = self.visit(
         node.vertical_loops,
         new_symbol_name=make_name,
         accessed_fields=touched_fields,
         block_extents=extents,
         **kwargs,
     )

     # Must run after the kernels were visited: the filter reads the set
     # that was shared with that visit (presumably filled in there).
     temporaries = []
     for declaration in node.declarations:
         if declaration.name not in touched_fields:
             continue
         temporaries.append(self.visit(declaration))

     params = self.visit(node.params)
     return cuir.Program(
         name=node.name,
         params=params,
         temporaries=temporaries,
         kernels=kernels,
     )
コード例 #4
0
 def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> cuir.Program:
     """Convert an ``oir.Stencil`` into a ``cuir.Program``.

     Horizontal block extents are computed for the stencil first, and a
     ``Context`` is created with a symbol-name creator seeded from the
     symbol names collected on ``node``. The vertical loops are visited
     with that context, the block extents and every received keyword
     argument. Afterwards, only the declarations whose name is found in
     ``ctx.accessed_fields`` are visited and kept as temporaries.

     Args:
         node: Stencil to convert.
         **kwargs: Forwarded unchanged to the visit of
             ``node.vertical_loops``.

     Returns:
         A program named after ``node`` with the visited parameters, the
         values of ``ctx.positionals`` as a list, the retained temporaries
         and the kernels.
     """
     extents = compute_horizontal_block_extents(node)
     make_name = symbol_name_creator(collect_symbol_names(node))
     ctx = self.Context(new_symbol_name=make_name)
     kernels = self.visit(
         node.vertical_loops,
         ctx=ctx,
         block_extents=extents,
         **kwargs,
     )

     # Must run after the kernels were visited: the filter reads state of
     # the context that was shared with that visit.
     temporaries = []
     for declaration in node.declarations:
         if declaration.name not in ctx.accessed_fields:
             continue
         temporaries.append(self.visit(declaration))

     params = self.visit(node.params)
     positionals = list(ctx.positionals.values())
     return cuir.Program(
         name=node.name,
         params=params,
         positionals=positionals,
         temporaries=temporaries,
         kernels=kernels,
     )