class DaCeBindingsCodegen:
    """Generate the C++ binding source that wraps a DaCe SDFG.

    The source is produced by rendering ``mako_template`` with the SDFG name,
    the module name, the parameter declarations of the Python-facing entry
    point (``entry_params``) and the arguments forwarded to the SDFG
    (``sid_params``). ``apply`` additionally formats the result as LLVM-style C++.
    """

    def __init__(self, backend):
        self.backend = backend
        # Per-instance counter; each generated SID gets its own stride_kind_index.
        self._unique_index: int = 0

    def unique_index(self) -> int:
        """Increment the per-instance counter and return it (first call returns 1)."""
        self._unique_index += 1
        return self._unique_index

    mako_template = bindings_main_template()

    def generate_entry_params(self, gtir: gtir.Stencil, sdfg: dace.SDFG):
        """Return the C++ parameter declarations of the Python-facing entry point.

        Every SDFG call argument that is an array becomes a ``py::buffer`` plus a
        ``std::array<gt::int_t, ndim>`` origin. Every argument that is a symbol whose
        name does not start with ``__`` becomes ``<ctype> <name>``. The result
        follows the order of ``gtir.params``; params without a declaration are skipped.
        """
        import dace.data

        res = {}
        for name in sdfg.signature_arglist(with_types=False, for_call=True):
            if name in sdfg.arrays:
                data = sdfg.arrays[name]
                assert isinstance(data, dace.data.Array)
                res[name] = "py::buffer {name}, std::array<gt::int_t,{ndim}> {name}_origin".format(
                    name=name,
                    ndim=len(data.shape),
                )
            elif name in sdfg.symbols and not name.startswith("__"):
                res[name] = "{dtype} {name}".format(dtype=sdfg.symbols[name].ctype, name=name)
        return [res[node.name] for node in gtir.params if node.name in res]

    def generate_sid_params(self, sdfg: dace.SDFG):
        """Return the arguments forwarded to the SDFG call.

        First, every non-transient array is converted with ``pybuffer_to_sid``, in
        ``sdfg.arrays`` order. Then every symbol whose name does not start with
        ``__`` is appended by name.
        """
        # Must precede any use of ``dace`` below: this import binds ``dace`` locally.
        import dace.data

        # One symbol per domain axis (__I, __J, __K). An array spans an axis when
        # one of its shape expressions mentions that symbol.
        domain_symbols = [dace.symbolic.pystr_to_symbolic(f"__{dim.upper()}") for dim in "ijk"]

        res = []
        for name, array in sdfg.arrays.items():
            if array.transient:
                continue
            domain_dim_flags = tuple(
                any(
                    symbol in s.free_symbols
                    for s in array.shape
                    if hasattr(s, "free_symbols")  # plain ints have no free symbols
                )
                for symbol in domain_symbols
            )
            data_ndim = len(array.shape) - sum(array_dimensions(array))
            sid_def = pybuffer_to_sid(
                name=name,
                ctype=array.dtype.ctype,
                domain_dim_flags=domain_dim_flags,
                data_ndim=data_ndim,
                stride_kind_index=self.unique_index(),
                backend=self.backend,
            )
            res.append(sid_def)

        # pass scalar parameters as variables
        res.extend(name for name in sdfg.symbols if not name.startswith("__"))
        return res

    def generate_sdfg_bindings(self, gtir, sdfg, module_name):
        """Render ``mako_template`` for ``sdfg`` and return the unformatted source."""
        return self.mako_template.render_values(
            name=sdfg.name,
            module_name=module_name,
            entry_params=self.generate_entry_params(gtir, sdfg),
            sid_params=self.generate_sid_params(sdfg),
        )

    @classmethod
    def apply(cls, gtir: gtir.Stencil, sdfg: dace.SDFG, module_name: str, *, backend) -> str:
        """Generate the bindings for ``sdfg`` and return them formatted as LLVM-style C++."""
        generated_code = cls(backend).generate_sdfg_bindings(gtir, sdfg, module_name=module_name)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code
class GTCCudaBindingsCodegen(codegen.TemplatedGenerator):
    """Template-driven generator of C++ bindings for a CUIR program.

    ``visit_Program`` renders the program's ``params`` twice. The first pass
    (``external_arg=True``) produces the declarations of the Python-facing entry
    point. The second pass (``external_arg=False``) produces the arguments
    forwarded to the stencil. Both lists are handed to the ``Program`` template.
    """

    def __init__(self, backend):
        # Counter feeding a distinct stride_kind_index to each generated SID.
        self._unique_index: int = 0
        self.backend = backend

    def unique_index(self) -> int:
        """Advance the per-instance counter and return its new value (starts at 1)."""
        self._unique_index = self._unique_index + 1
        return self._unique_index

    def visit_DataType(self, dtype: DataType, **kwargs):
        # Delegate the C++ spelling of the data type to the CUIR code generator.
        cuir_generator = cuir_codegen.CUIRCodegen()
        return cuir_generator.visit_DataType(dtype)

    def visit_FieldDecl(self, node: cuir.FieldDecl, **kwargs):
        """Render a field declaration for one of the two parameter passes.

        Returns ``None`` when ``external_arg`` is absent. When it is truthy, the
        result is a ``py::buffer`` parameter plus its origin array. Otherwise the
        result is the SID expression built by ``pybuffer_to_sid``.
        """
        if "external_arg" not in kwargs:
            return None
        sid_ndim = node.dimensions.count(True) + len(node.data_dims)
        if not kwargs["external_arg"]:
            return pybuffer_to_sid(
                name=node.name,
                ctype=self.visit(node.dtype),
                domain_dim_flags=node.dimensions,
                data_ndim=len(node.data_dims),
                stride_kind_index=self.unique_index(),
                backend=self.backend,
            )
        return "py::buffer {name}, std::array<gt::int_t,{sid_ndim}> {name}_origin".format(
            name=node.name,
            sid_ndim=sid_ndim,
        )

    def visit_ScalarDecl(self, node: cuir.ScalarDecl, **kwargs):
        """Render a scalar declaration for one of the two parameter passes.

        Returns ``None`` when ``external_arg`` is absent. When it is truthy, the
        result is a typed C++ parameter. Otherwise the scalar is wrapped in
        ``gridtools::stencil::make_global_parameter``.
        """
        if "external_arg" not in kwargs:
            return None
        if kwargs["external_arg"]:
            return "{dtype} {name}".format(name=node.name, dtype=self.visit(node.dtype))
        return "gridtools::stencil::make_global_parameter({name})".format(name=node.name)

    def visit_Program(self, node: cuir.Program, **kwargs):
        """Render the program template with both parameter lists; needs ``module_name``."""
        assert "module_name" in kwargs
        # Keyword arguments are evaluated left to right: entry params first, then sids.
        return self.generic_visit(
            node,
            entry_params=self.visit(node.params, external_arg=True, **kwargs),
            sid_params=self.visit(node.params, external_arg=False, **kwargs),
            **kwargs,
        )

    Program = bindings_main_template()

    @classmethod
    def apply(cls, root, *, module_name="stencil", backend, **kwargs) -> str:
        """Generate bindings for ``root``, formatting them as LLVM-style C++ by default.

        Pass ``format_source=False`` to get the raw template output instead.
        """
        generator = cls(backend)
        code = generator.visit(root, module_name=module_name, **kwargs)
        if not kwargs.get("format_source", True):
            return code
        return codegen.format_source("cpp", code, style="LLVM")