class CUIRCodegen(codegen.TemplatedGenerator):
    """Render a ``cuir.Program`` tree as GridTools-based CUDA C++ source.

    Node classes are rendered either through the ``as_fmt``/``as_mako``
    template stored under the node's class name, or through the matching
    ``visit_<Node>`` method when extra context has to be computed first.
    """

    contexts = (traits.SymbolTableTrait.symtable_merger, )

    LocalScalar = as_fmt("{dtype} {name};")

    FieldDecl = as_fmt("{name}")

    ScalarDecl = as_fmt("{name}")

    Temporary = as_fmt("{name}")

    AssignStmt = as_fmt("{left} = {right};")

    MaskStmt = as_mako(
        """
        if (${mask}) {
            ${'\\n'.join(body)}
        }
        """
    )

    While = as_mako(
        """
        while (${cond}) {
            ${'\\n'.join(body)}
        }
        """
    )

    def visit_FieldAccess(self, node: cuir.FieldAccess, **kwargs: Any):
        """Render a field access, appending the rendered data indices.

        Every rendered data index that parses as an integer is emitted with
        a ``_c`` suffix; anything else is passed through unchanged. The
        result is handed to the ``FieldAccess`` template as
        ``this_data_index``.
        """

        def maybe_const(s):
            try:
                return f"{int(s)}_c"
            except ValueError:
                return s

        kwargs["this_data_index"] = "".join(
            ", " + maybe_const(self.visit(index, **kwargs))
            for index in node.data_index)
        return self.generic_visit(node, **kwargs)

    FieldAccess = as_mako("${name}(${offset}${this_data_index})")

    def visit_IJCacheAccess(self, node: cuir.IJCacheAccess,
                            symtable: Dict[str, Any], **kwargs: Any) -> str:
        """Render an access to an ij-cache.

        A cache with zero extent is a plain scalar variable; otherwise the
        cache name is treated as a pointer that is dereferenced directly
        (zero offset) or indexed with the per-cache i/j strides.
        """
        extent = symtable[node.name].extent
        if extent.i == extent.j == (0, 0):
            # cache is scalar
            assert node.offset.i == node.offset.j == 0
            return node.name
        if node.offset.i == node.offset.j == 0:
            return "*" + node.name
        offsets = (f"{o} * {d}_stride_{node.name}"
                   for o, d in zip([node.offset.i, node.offset.j], "ij")
                   if o != 0)
        return node.name + "[" + " + ".join(offsets) + "]"

    KCacheAccess = as_mako(
        "${_this_generator.k_cache_var(name, _this_node.offset.k)}")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}_c, {j}_c, {k}_c")

    VariableKOffset = as_fmt("0_c, 0_c, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    UNARY_OPERATOR_TO_CODE = {
        UnaryOperator.NOT: "!",
        UnaryOperator.NEG: "-",
        UnaryOperator.POS: "+",
    }

    UnaryOp = as_fmt(
        "({_this_generator.UNARY_OPERATOR_TO_CODE[_this_node.op]}{expr})")

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    BUILTIN_LITERAL_TO_CODE = {
        BuiltInLiteral.TRUE: "true",
        BuiltInLiteral.FALSE: "false",
    }

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral,
                             **kwargs: Any) -> str:
        """Return the C++ spelling of a built-in literal.

        Raises:
            NotImplementedError: if the literal has no entry in
                ``BUILTIN_LITERAL_TO_CODE``.
        """
        try:
            return self.BUILTIN_LITERAL_TO_CODE[builtin]
        except KeyError as error:
            raise NotImplementedError(
                "Not implemented BuiltInLiteral encountered.") from error

    Literal = as_mako("static_cast<${dtype}>(${value})")

    NATIVE_FUNCTION_TO_CODE = {
        NativeFunction.ABS: "std::abs",
        NativeFunction.MIN: "std::min",
        NativeFunction.MAX: "std::max",
        NativeFunction.MOD: "std::fmod",
        NativeFunction.SIN: "std::sin",
        NativeFunction.COS: "std::cos",
        NativeFunction.TAN: "std::tan",
        NativeFunction.ARCSIN: "std::asin",
        NativeFunction.ARCCOS: "std::acos",
        NativeFunction.ARCTAN: "std::atan",
        NativeFunction.SINH: "std::sinh",
        NativeFunction.COSH: "std::cosh",
        NativeFunction.TANH: "std::tanh",
        NativeFunction.ARCSINH: "std::asinh",
        NativeFunction.ARCCOSH: "std::acosh",
        NativeFunction.ARCTANH: "std::atanh",
        NativeFunction.SQRT: "std::sqrt",
        NativeFunction.POW: "std::pow",
        NativeFunction.EXP: "std::exp",
        NativeFunction.LOG: "std::log",
        NativeFunction.GAMMA: "std::tgamma",
        NativeFunction.CBRT: "std::cbrt",
        NativeFunction.ISFINITE: "std::isfinite",
        NativeFunction.ISINF: "std::isinf",
        NativeFunction.ISNAN: "std::isnan",
        NativeFunction.FLOOR: "std::floor",
        NativeFunction.CEIL: "std::ceil",
        NativeFunction.TRUNC: "std::trunc",
    }

    def visit_NativeFunction(self, func: NativeFunction,
                             **kwargs: Any) -> str:
        """Return the C++ function name for a native function.

        Raises:
            NotImplementedError: if the function has no entry in
                ``NATIVE_FUNCTION_TO_CODE``.
        """
        try:
            return self.NATIVE_FUNCTION_TO_CODE[func]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented NativeFunction '{func}' encountered."
            ) from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    DATA_TYPE_TO_CODE = {
        DataType.BOOL: "bool",
        DataType.INT8: "std::int8_t",
        DataType.INT16: "std::int16_t",
        DataType.INT32: "std::int32_t",
        DataType.INT64: "std::int64_t",
        DataType.FLOAT32: "float",
        DataType.FLOAT64: "double",
    }

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        """Return the C++ type name for a data type.

        Raises:
            NotImplementedError: if the data type has no entry in
                ``DATA_TYPE_TO_CODE``.
        """
        try:
            return self.DATA_TYPE_TO_CODE[dtype]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented DataType '{dtype.name}' encountered."
            ) from error

    IJExtent = as_fmt("extent<{i[0]}, {i[1]}, {j[0]}, {j[1]}>")

    HorizontalExecution = as_mako(
        """
        // HorizontalExecution ${id(_this_node)}
        if (validator(${extent}())) {
            ${'\\n'.join(declarations)}
            ${'\\n'.join(body)}
        }
        """
    )

    def visit_AxisBound(self, node: cuir.AxisBound, **kwargs: Any) -> str:
        """Render an axis bound as an expression relative to 0 or ``k_size``.

        Raises:
            ValueError: if the level is neither ``START`` nor ``END``.
        """
        if node.level == LevelMarker.START:
            return f"{node.offset}"
        if node.level == LevelMarker.END:
            return f"k_size + {node.offset}"
        raise ValueError("Cannot handle dynamic levels")

    IJCacheDecl = as_mako(
        """
        % if _this_node.extent.i == _this_node.extent.j == (0, 0):
        // scalar ij-cache
        ${dtype} ${name};
        % else:
        // ij-cache in shared memory
        constexpr int ${name}_cache_data_size = (i_block_size_t() + ${-_this_node.extent.i[0] + _this_node.extent.i[1]}) * (j_block_size_t() + ${-_this_node.extent.j[0] + _this_node.extent.j[1]});
        __shared__ ${dtype} ${name}_cache_data[${name}_cache_data_size];
        constexpr int i_stride_${name} = 1;
        constexpr int j_stride_${name} = i_block_size_t() + ${-_this_node.extent.i[0] + _this_node.extent.i[1]};
        ${dtype} *${name} = ${name}_cache_data + (${-_this_node.extent.i[0]} + _i_block) * i_stride_${name} + (${-_this_node.extent.j[0]} + _j_block) * j_stride_${name};
        % endif
        """
    )

    KCacheDecl = as_mako(
        """
        % for var in _this_generator.k_cache_vars(_this_node):
            ${dtype} ${var};
        % endfor
        """
    )

    VerticalLoopSection = as_mako(
        """
        <%def name="sid_shift(step)">
            sid::shift(_ptr, sid::get_stride<dim::k>(m_strides), ${step}_c);
        </%def>
        <%def name="cache_shift(cache_vars)">
            % for dst, src in zip(cache_vars[:-1], cache_vars[1:]):
                ${dst} = ${src};
            % endfor
        </%def>
        // VerticalLoopSection ${id(_this_node)}
        % if order == cuir.LoopOrder.FORWARD:
            for (int _k_block = ${start}; _k_block < ${end}; ++_k_block) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}

                ${sid_shift(1)}
                % for k_cache in k_cache_decls:
                    ${cache_shift(_this_generator.k_cache_vars(k_cache))}
                % endfor
            }
        % elif order == cuir.LoopOrder.BACKWARD:
            for (int _k_block = ${end} - 1; _k_block >= ${start}; --_k_block) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}

                ${sid_shift(-1)}
                % for k_cache in k_cache_decls:
                    ${cache_shift(_this_generator.k_cache_vars(k_cache)[::-1])}
                % endfor
            }
        % else:
            if (_k_block >= ${start} && _k_block < ${end}) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}
            }
        % endif
        """
    )

    @staticmethod
    def k_cache_var(name: str, offset: int) -> str:
        """Return the variable name of a k-cache entry (``<name>p<n>``/``<name>m<n>``)."""
        return name + (f"p{offset}" if offset >= 0 else f"m{-offset}")

    @classmethod
    def k_cache_vars(cls, k_cache: cuir.KCacheDecl) -> List[str]:
        """Return the k-cache variable names for every offset in the k extent."""
        assert k_cache.extent
        return [
            cls.k_cache_var(k_cache.name, offset)
            for offset in range(k_cache.extent.k[0], k_cache.extent.k[1] + 1)
        ]

    def visit_VerticalLoop(self, node: cuir.VerticalLoop, *,
                           symtable: Dict[str, Any],
                           **kwargs: Any) -> Union[str, Collection[str]]:
        """Render a vertical loop.

        Collects, for every accessed field name, the number of data indices
        used in its accesses and passes the mapping to the template as
        ``fields`` together with the loop order and k-cache declarations.
        """
        fields = {
            name: data_dims
            for name, data_dims in node.iter_tree().if_isinstance(
                cuir.FieldAccess).getattr("name", "data_index").map(
                    lambda x: (x[0], len(x[1])))
        }
        return self.generic_visit(
            node,
            fields=fields,
            k_cache_decls=node.k_caches,
            order=node.loop_order,
            symtable=symtable,
            **kwargs,
        )

    VerticalLoop = as_mako(
        """
        template <class Sid>
        struct loop_${id(_this_node)}_f {
            sid::ptr_holder_type<Sid> m_ptr_holder;
            sid::strides_type<Sid> m_strides;
            int k_size;

            template <class Validator>
            GT_FUNCTION_DEVICE void operator()(const int _i_block,
                                               const int _j_block,
                                               Validator validator) const {
                auto _ptr = m_ptr_holder();
                sid::shift(_ptr,
                           sid::get_stride<sid::blocked_dim<dim::i>>(m_strides),
                           blockIdx.x);
                sid::shift(_ptr,
                           sid::get_stride<sid::blocked_dim<dim::j>>(m_strides),
                           blockIdx.y);
                sid::shift(_ptr,
                           sid::get_stride<dim::i>(m_strides),
                           _i_block);
                sid::shift(_ptr,
                           sid::get_stride<dim::j>(m_strides),
                           _j_block);
                % if order == cuir.LoopOrder.PARALLEL:
                const int _k_block = blockIdx.z;
                sid::shift(_ptr,
                           sid::get_stride<dim::k>(m_strides),
                           _k_block);
                % endif

                % for field, data_dims in fields.items():
                const auto ${field} = [&](auto i, auto j, auto k
                    % for i in range(data_dims):
                    , auto dim_${i + 3}
                    % endfor
                    ) -> auto&& {
                    return *sid::multi_shifted<tag::${field}>(
                        device::at_key<tag::${field}>(_ptr),
                        m_strides,
                        tuple_util::device::make<hymap::keys<dim::i, dim::j, dim::k
                        % for i in range(data_dims):
                        , integral_constant<int, ${i + 3}>
                        % endfor
                        >::template values>(i, j, k
                        % for i in range(data_dims):
                        , dim_${i + 3}
                        % endfor
                        ));
                };
                % endfor

                % for ij_cache in ij_caches:
                    ${ij_cache}
                % endfor

                % for k_cache in k_caches:
                    ${k_cache}
                % endfor

                % for section in sections:
                    ${section}
                % endfor
            }
        };
        """
    )

    Kernel = as_mako(
        """
        % for vertical_loop in vertical_loops:
            ${vertical_loop}
        % endfor

        template <${', '.join(f'class Loop{id(vl)}' for vl in _this_node.vertical_loops)}>
        struct kernel_${id(_this_node)}_f {
            % for vertical_loop in _this_node.vertical_loops:
                Loop${id(vertical_loop)} m_${id(vertical_loop)};
            % endfor

            template <class Validator>
            GT_FUNCTION_DEVICE void operator()(const int _i_block,
                                               const int _j_block,
                                               Validator validator) const {
                % for vertical_loop in _this_node.vertical_loops:
                    m_${id(vertical_loop)}(_i_block, _j_block, validator);
                % endfor
            }
        };
        """
    )

    def visit_Program(self, node: cuir.Program,
                      **kwargs: Any) -> Union[str, Collection[str]]:
        """Render the whole program.

        Supplies the ``Program`` template with helper callables
        (``loop_start``, ``loop_fields``, ``ctype``), the rendered union of
        all ``IJExtent`` nodes as ``max_extent`` and the ``cuir`` module
        itself, which the templates use to compare loop orders.
        """

        def loop_start(vertical_loop: cuir.VerticalLoop) -> str:
            # First k index visited by the loop, depending on its direction.
            if vertical_loop.loop_order == cuir.LoopOrder.FORWARD:
                return self.visit(vertical_loop.sections[0].start, **kwargs)
            if vertical_loop.loop_order == cuir.LoopOrder.BACKWARD:
                return self.visit(vertical_loop.sections[0].end,
                                  **kwargs) + " - 1"
            return "0"

        def loop_fields(vertical_loop: cuir.VerticalLoop) -> Set[str]:
            # Names of all fields accessed anywhere inside the loop.
            return (vertical_loop.iter_tree().if_isinstance(
                cuir.FieldAccess).getattr("name").to_set())

        def ctype(symbol: str) -> str:
            # C++ type name of the symbol's dtype, looked up in the symtable.
            return self.visit(kwargs["symtable"][symbol].dtype, **kwargs)

        return self.generic_visit(
            node,
            max_extent=self.visit(
                cuir.IJExtent.zero().union(
                    *node.iter_tree().if_isinstance(cuir.IJExtent)),
                **kwargs),
            loop_start=loop_start,
            loop_fields=loop_fields,
            ctype=ctype,
            cuir=cuir,
            **kwargs,
        )

    Program = as_mako(
        """#include <algorithm>
        #include <array>
        #include <cstdint>
        #include <gridtools/common/cuda_util.hpp>
        #include <gridtools/common/host_device.hpp>
        #include <gridtools/common/hymap.hpp>
        #include <gridtools/common/integral_constant.hpp>
        #include <gridtools/sid/allocator.hpp>
        #include <gridtools/sid/block.hpp>
        #include <gridtools/sid/composite.hpp>
        #include <gridtools/sid/multi_shift.hpp>
        #include <gridtools/stencil/common/dim.hpp>
        #include <gridtools/stencil/common/extent.hpp>
        #include <gridtools/stencil/gpu/launch_kernel.hpp>
        #include <gridtools/stencil/gpu/tmp_storage_sid.hpp>

        namespace ${name}_impl_{
            using namespace gridtools;
            using namespace literals;
            using namespace stencil;

            using domain_t = std::array<unsigned, 3>;
            using i_block_size_t = integral_constant<int, 64>;
            using j_block_size_t = integral_constant<int, 8>;

            template <class Storage>
            auto block(Storage storage) {
                return sid::block(std::move(storage),
                    tuple_util::make<hymap::keys<dim::i, dim::j>::values>(
                        i_block_size_t(), j_block_size_t()));
            }

            namespace tag {
                % for p in set().union(*(loop_fields(v) for k in _this_node.kernels for v in k.vertical_loops)):
                    struct ${p} {};
                % endfor
            }

            % for kernel in kernels:
                ${kernel}
            % endfor

            auto ${name}(domain_t domain){
                return [domain](${','.join(f'auto&& {p}' for p in params)}){
                    auto tmp_alloc = sid::device::make_cached_allocator(&cuda_util::cuda_malloc<char[]>);
                    const int i_size = domain[0];
                    const int j_size = domain[1];
                    const int k_size = domain[2];
                    const int i_blocks = (i_size + i_block_size_t() - 1) / i_block_size_t();
                    const int j_blocks = (j_size + j_block_size_t() - 1) / j_block_size_t();

                    % for tmp in temporaries:
                        auto ${tmp} = gpu_backend::make_tmp_storage<${ctype(tmp)}>(
                            1_c,
                            i_block_size_t(),
                            j_block_size_t(),
                            ${max_extent}(),
                            i_blocks,
                            j_blocks,
                            k_size,
                            tmp_alloc);
                    % endfor

                    % for kernel in _this_node.kernels:

                        // kernel ${id(kernel)}

                        % for vertical_loop in kernel.vertical_loops:
                            // vertical loop ${id(vertical_loop)}

                            assert((${loop_start(vertical_loop)}) >= 0 &&
                                   (${loop_start(vertical_loop)}) < k_size);
                            auto offset_${id(vertical_loop)} = tuple_util::make<hymap::keys<dim::k>::values>(
                                ${loop_start(vertical_loop)}
                            );

                            auto composite_${id(vertical_loop)} = sid::composite::make<
                                    ${', '.join(f'tag::{field}' for field in loop_fields(vertical_loop))}
                                >(

                            % for field in loop_fields(vertical_loop):
                                % if field in params:
                                    block(sid::shift_sid_origin(
                                        ${field},
                                        offset_${id(vertical_loop)}
                                    ))
                                % else:
                                    sid::shift_sid_origin(
                                        ${field},
                                        offset_${id(vertical_loop)}
                                    )
                                % endif
                                ${'' if loop.last else ','}
                            % endfor
                            );
                            using composite_${id(vertical_loop)}_t = decltype(composite_${id(vertical_loop)});
                            loop_${id(vertical_loop)}_f<composite_${id(vertical_loop)}_t> loop_${id(vertical_loop)}{
                                sid::get_origin(composite_${id(vertical_loop)}),
                                sid::get_strides(composite_${id(vertical_loop)}),
                                k_size
                            };

                        % endfor

                        kernel_${id(kernel)}_f<${', '.join(f'decltype(loop_{id(vl)})' for vl in kernel.vertical_loops)}> kernel_${id(kernel)}{
                            ${', '.join(f'loop_{id(vl)}' for vl in kernel.vertical_loops)}
                        };
                        gpu_backend::launch_kernel<${max_extent},
                            i_block_size_t::value, j_block_size_t::value>(
                            i_size,
                            j_size,
                            % if kernel.vertical_loops[0].loop_order == cuir.LoopOrder.PARALLEL:
                                k_size,
                            % else:
                                1,
                            %endif
                            kernel_${id(kernel)},
                            0);
                    % endfor
                };
            }
        }

        using ${name}_impl_::${name};
        """
    )

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Generate (and by default format) the source for a CUIR program.

        Args:
            root: The program to render; must be a ``cuir.Program``.
            **kwargs: Forwarded to the base ``apply``. ``format_source``
                (default ``True``) controls whether the generated code is
                run through ``codegen.format_source``.

        Raises:
            ValueError: if ``root`` is not a ``cuir.Program``.
        """
        if not isinstance(root, cuir.Program):
            raise ValueError("apply() requires cuir.Program root node")
        generated_code = super().apply(root, **kwargs)
        if kwargs.get("format_source", True):
            generated_code = codegen.format_source("cpp",
                                                   generated_code,
                                                   style="LLVM")
        return generated_code
class GTCppCodegen(codegen.TemplatedGenerator):
    """Render a ``gtcpp.Program`` tree as GridTools C++ stencil source.

    Node classes are rendered either through the ``as_fmt``/``as_mako``
    template stored under the node's class name, or through the matching
    ``visit_<Node>`` method.
    """

    GTExtent = as_fmt("extent<{i[0]},{i[1]},{j[0]},{j[1]},{k[0]},{k[1]}>")

    GTAccessor = as_fmt("using {name} = {intent}_accessor<{id}, {extent}, {ndim}>;")

    GTParamList = as_mako(
        """${ '\\n'.join(accessors) }

        using param_list = make_param_list<${ ','.join(a.name for a in _this_node.accessors)}>;
        """
    )

    GTFunctor = as_mako(
        """struct ${ name } {
        ${param_list}

        ${ '\\n'.join(applies) }
    };
    """
    )

    GTLevel = as_fmt("gridtools::stencil::core::level<{splitter}, {offset}, {offset_limit}>")

    GTInterval = as_fmt("gridtools::stencil::core::interval<{from_level}, {to_level}>")

    LocalVarDecl = as_fmt("{dtype} {name};")

    GTApplyMethod = as_mako(
        """
    template<typename Evaluation>
    GT_FUNCTION static void apply(Evaluation eval, ${interval}) {
        ${ ' '.join(local_variables) }
        ${ '\\n'.join(body) }
    }
    """
    )

    AssignStmt = as_fmt("{left} = {right};")

    AccessorRef = as_fmt("eval({name}({', '.join([offset, *data_index])}))")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}, {j}, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    UNARY_OPERATOR_TO_CODE = {
        UnaryOperator.NOT: "!",
        UnaryOperator.NEG: "-",
        UnaryOperator.POS: "+",
    }

    # The operator is looked up on the node itself instead of being rendered
    # through a visitor.
    UnaryOp = as_fmt("({_this_generator.UNARY_OPERATOR_TO_CODE[_this_node.op]}{expr})")

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral, **kwargs: Any) -> str:
        """Return the C++ spelling of a built-in boolean literal.

        Raises:
            NotImplementedError: for any literal other than TRUE/FALSE.
        """
        if builtin == BuiltInLiteral.TRUE:
            return "true"
        elif builtin == BuiltInLiteral.FALSE:
            return "false"
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_mako("static_cast<${dtype}>(${value})")

    # Built once at class creation instead of on every visit.
    NATIVE_FUNCTION_TO_CODE = {
        NativeFunction.ABS: "std::abs",
        NativeFunction.MIN: "std::min",
        NativeFunction.MAX: "std::max",
        NativeFunction.MOD: "std::fmod",
        NativeFunction.SIN: "std::sin",
        NativeFunction.COS: "std::cos",
        NativeFunction.TAN: "std::tan",
        NativeFunction.ARCSIN: "std::asin",
        NativeFunction.ARCCOS: "std::acos",
        NativeFunction.ARCTAN: "std::atan",
        NativeFunction.SQRT: "std::sqrt",
        NativeFunction.POW: "std::pow",
        NativeFunction.EXP: "std::exp",
        NativeFunction.LOG: "std::log",
        NativeFunction.ISFINITE: "std::isfinite",
        NativeFunction.ISINF: "std::isinf",
        NativeFunction.ISNAN: "std::isnan",
        NativeFunction.FLOOR: "std::floor",
        NativeFunction.CEIL: "std::ceil",
        NativeFunction.TRUNC: "std::trunc",
    }

    def visit_NativeFunction(self, func: NativeFunction, **kwargs: Any) -> str:
        """Return the C++ function name for a native function.

        Raises:
            NotImplementedError: if the function has no entry in
                ``NATIVE_FUNCTION_TO_CODE``.
        """
        try:
            return self.NATIVE_FUNCTION_TO_CODE[func]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented NativeFunction '{func}' encountered."
            ) from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    DATA_TYPE_TO_CODE = {
        DataType.BOOL: "bool",
        DataType.INT8: "std::int8_t",
        DataType.INT16: "std::int16_t",
        DataType.INT32: "std::int32_t",
        DataType.INT64: "std::int64_t",
        DataType.FLOAT32: "float",
        DataType.FLOAT64: "double",
    }

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        """Return the C++ type name for a data type.

        Raises:
            NotImplementedError: if the data type has no entry in
                ``DATA_TYPE_TO_CODE``.
        """
        try:
            return self.DATA_TYPE_TO_CODE[dtype]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented DataType '{dtype.name}' encountered."
            ) from error

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    ApiParamDecl = as_fmt("{name}")

    GTStage = as_mako(".stage(${functor}(), ${','.join(args)})")

    GTMultiStage = as_mako("execute_${ loop_order }()${''.join(caches)}${''.join(stages)}")

    IJCache = as_fmt(".ij_cached({name})")

    KCache = as_mako(
        ".k_cached(${'cache_io_policy::fill(), ' if _this_node.fill else ''}${'cache_io_policy::flush(), ' if _this_node.flush else ''}${name})"
    )

    def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str:
        """Return the lower-case name used in ``execute_<order>()``.

        Raises:
            KeyError: if the loop order is not PARALLEL, FORWARD or BACKWARD.
        """
        return {
            LoopOrder.PARALLEL: "parallel",
            LoopOrder.FORWARD: "forward",
            LoopOrder.BACKWARD: "backward",
        }[looporder]

    Temporary = as_fmt("GT_DECLARE_TMP({dtype}, {name});")

    IfStmt = as_mako(
        """if(${cond}) ${true_branch}
        %if _this_node.false_branch:
            else ${false_branch}
        %endif
        """
    )

    BlockStmt = as_mako("{${''.join(body)}}")

    def visit_GTComputationCall(
        self, node: gtcpp.GTComputationCall, **kwargs: Any
    ) -> Union[str, Collection[str]]:
        """Render a computation call under a name derived from the node's type and ``id``."""
        computation_name = type(node).__name__ + str(id(node))
        return self.generic_visit(node, computation_name=computation_name, **kwargs)

    GTComputationCall = as_mako(
        """
        %if len(multi_stages) > 0 and len(arguments) > 0:
        {
            auto grid = make_grid(domain[0], domain[1], axis<1, axis_config::offset_limit<${offset_limit}>>{domain[2]});

            auto ${ computation_name } = [](${ ','.join('auto ' + a for a in arguments) }) {

                ${ '\\n'.join(temporaries) }
                return multi_pass(${ ','.join(multi_stages) });
            };

            run(${computation_name}, ${gt_backend_t}<>{}, grid, ${','.join(f"std::forward<decltype({arg})>({arg})" for arg in arguments)});
        }
        %endif
        """
    )

    Program = as_mako(
        """
        #include <gridtools/stencil/${gt_backend_t}.hpp>
        #include <gridtools/stencil/cartesian.hpp>

        namespace ${ name }_impl_{
            using Domain = std::array<gridtools::uint_t, 3>;
            using namespace gridtools::stencil;
            using namespace gridtools::stencil::cartesian;

            ${'\\n'.join(functors)}

            auto ${name}(Domain domain){
                return [domain](${ ','.join( 'auto&& ' + p for p in parameters)}){
                    ${gt_computation}
                };
            }
        }

        auto ${name}(${name}_impl_::Domain domain){
            return ${name}_impl_::${name}(domain);
        }
        """
    )

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Generate and format the C++ source for a gtcpp program.

        Args:
            root: The program to render; must be a ``gtcpp.Program``.
            **kwargs: Forwarded to the base ``apply``; must contain
                ``gt_backend_t``, which the templates use for the backend
                header and backend type.

        Raises:
            ValueError: if ``root`` is not a ``gtcpp.Program``.
            TypeError: if ``gt_backend_t`` is missing from ``kwargs``.
        """
        if not isinstance(root, gtcpp.Program):
            raise ValueError("apply() requires gtcpp.Program root node")
        if "gt_backend_t" not in kwargs:
            raise TypeError("apply() missing 1 required keyword-only argument: 'gt_backend_t'")
        generated_code = super().apply(root, offset_limit=_offset_limit(root), **kwargs)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code
class UsidCodeGenerator(codegen.TemplatedGenerator):
    """Render an unstructured-mesh computation tree as C++ source.

    Several node classes visited here (e.g. ``Kernel`` and ``KernelCall``)
    have no template in this class, and the ``Computation`` template reads
    ``cache_allocator_``, which is not defined here either.
    NOTE(review): presumably a subclass supplies them; confirm.
    """

    # NOTE(review): "unsigned_int" is not a built-in C++ type name; confirm
    # whether a typedef with that name exists or "unsigned int" was meant.
    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType({
        common.DataType.BOOLEAN: "bool",
        common.DataType.INT32: "int",
        common.DataType.UINT32: "unsigned_int",
        common.DataType.FLOAT32: "float",
        common.DataType.FLOAT64: "double",
    })

    LOCATION_TYPE_TO_STR: ClassVar[Mapping[common.LocationType, str]] = MappingProxyType({
        common.LocationType.Vertex: "vertex",
        common.LocationType.Edge: "edge",
        common.LocationType.Cell: "cell",
    })

    BUILTIN_LITERAL_TO_STR: ClassVar[Mapping[
        common.BuiltInLiteral, str]] = MappingProxyType({
            common.BuiltInLiteral.MAX_VALUE: "std::numeric_limits<TODO>::max()",
            common.BuiltInLiteral.MIN_VALUE: "std::numeric_limits<TODO>::min()",
            common.BuiltInLiteral.ZERO: "0",
            common.BuiltInLiteral.ONE: "1",
        })

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Run ``SymbolTblHelper`` over ``root``, generate the code and format it as C++."""
        resolved_root = SymbolTblHelper().visit(root)
        raw_code = super().apply(resolved_root, **kwargs)
        return codegen.format_source("cpp", raw_code, style="LLVM")

    def location_type_from_dimensions(self, dimensions):
        """Return the single ``common.LocationType`` contained in ``dimensions``.

        Raises:
            ValueError: if ``dimensions`` does not contain exactly one
                location type.
        """
        found = [d for d in dimensions if isinstance(d, common.LocationType)]
        if len(found) == 1:
            return found[0]
        raise ValueError("Doesn't contain a LocationType!")

    headers_ = [
        "<gridtools/next/mesh.hpp>",
        "<gridtools/next/tmp_storage.hpp>",
        "<gridtools/next/unstructured.hpp>",
        "<gridtools/sid/allocator.hpp>",
        "<gridtools/sid/composite.hpp>",
    ]

    preface_ = ""

    Connectivity = as_fmt(
        "auto {name} = gridtools::next::mesh::connectivity<{chain}>(mesh);")

    NeighborChain = as_mako(
        """<%
            loc_strs = [_this_generator.LOCATION_TYPE_TO_STR[e] for e in _this_node.elements]
        %>
        std::tuple<${ ','.join(loc_strs) }>
        """
    )

    SidCompositeNeighborTableEntry = as_fmt(
        "gridtools::next::connectivity::neighbor_table({_this_node.connectivity_deref_.name})"
    )

    SidCompositeEntry = as_fmt("{name}")

    SidComposite = as_mako(
        """
        auto ${ _this_node.field_name } = tu::make<gridtools::sid::composite::keys<${ ','.join([t.tag_name for t in _this_node.entries]) }>::values>(
        ${ ','.join(entries)});
        """
    )

    def visit_KernelCall(self, node: KernelCall, **kwargs):
        """Render a kernel call.

        Looks the kernel up by name in ``symbol_tbl_kernel`` and passes its
        rendered connectivities, its rendered non-empty SID composites, the
        primary connectivity and the call argument list to the template.
        """
        kernel: Kernel = kwargs["symbol_tbl_kernel"][node.name]
        # Composites without entries are neither rendered nor passed along.
        used_sids = [s for s in kernel.sids if len(s.entries) > 0]

        rendered_connectivities = [
            self.generic_visit(c, **kwargs) for c in kernel.connectivities
        ]
        primary_connectivity = kernel.symbol_tbl[kernel.primary_connectivity]
        rendered_sids = [self.generic_visit(s, **kwargs) for s in used_sids]

        # TODO I don't like that I render here and that I somehow have the same pattern for the parameters
        call_args = [c.name for c in kernel.connectivities]
        call_args.extend(
            f"gridtools::sid::get_origin({s.field_name}), gridtools::sid::get_strides({s.field_name})"
            for s in used_sids)

        return self.generic_visit(
            node,
            connectivities=rendered_connectivities,
            sids=rendered_sids,
            primary_connectivity=primary_connectivity,
            args=call_args,
            **kwargs,
        )

    def visit_Kernel(self, node: Kernel, **kwargs):
        """Render a kernel.

        Passes by-name lookup tables for the kernel's connectivities and
        SIDs, plus the kernel parameter names: connectivity names followed
        by the origin/strides names of every non-empty SID.
        """
        conn_by_name = {c.name: c for c in node.connectivities}
        sid_by_name = {s.name: s for s in node.sids}

        parameters = [c.name for c in node.connectivities]
        parameters.extend(
            param
            for s in node.sids
            if len(s.entries) > 0
            for param in (s.origin_name, s.strides_name))

        return self.generic_visit(
            node,
            parameters=parameters,
            symbol_tbl_conn=conn_by_name,
            symbol_tbl_sids=sid_by_name,
            **kwargs,
        )

    FieldAccess = as_mako(
        """<%
            sid_deref = symbol_tbl_sids[_this_node.sid]
            sid_entry_deref = sid_deref.symbol_tbl[_this_node.name]
        %>*gridtools::host_device::at_key<${ sid_entry_deref.tag_name }>(${ sid_deref.ptr_name })"""
    )

    AssignStmt = as_fmt("{left} = {right};")

    BinaryOp = as_fmt("({left} {op} {right})")

    NeighborLoop = as_mako(
        """<%
            outer_sid_deref = symbol_tbl_sids[_this_node.outer_sid]
            sid_deref = symbol_tbl_sids[_this_node.sid] if _this_node.sid else None
            conn_deref = symbol_tbl_conn[_this_node.connectivity]
            body_location = _this_generator.LOCATION_TYPE_TO_STR[sid_deref.location.elements[-1]] if sid_deref else None
        %>
        for (int neigh = 0; neigh < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++neigh) {
            auto absolute_neigh_index = *gridtools::host_device::at_key<${ conn_deref.neighbor_tbl_tag }>(${ outer_sid_deref.ptr_name});
            if (absolute_neigh_index != gridtools::next::connectivity::skip_value(${ conn_deref.name })) {
                % if sid_deref:
                    auto ${ sid_deref.ptr_name } = ${ sid_deref.origin_name }();
                    gridtools::sid::shift(
                        ${ sid_deref.ptr_name }, gridtools::host_device::at_key<${ body_location }>(${ sid_deref.strides_name }), absolute_neigh_index);
                % endif

                // bodyparameters
                ${ ''.join(body) }
                // end body
            }
            gridtools::sid::shift(${ outer_sid_deref.ptr_name }, gridtools::host_device::at_key<neighbor>(${ outer_sid_deref.strides_name }), 1);
        }
        gridtools::sid::shift(${ outer_sid_deref.ptr_name }, gridtools::host_device::at_key<neighbor>(${ outer_sid_deref.strides_name }), -gridtools::next::connectivity::max_neighbors(${ conn_deref.name }));
        """
    )

    Literal = as_mako(
        """<%
            literal= _this_node.value if isinstance(_this_node.value, str) else _this_generator.BUILTIN_LITERAL_TO_STR[_this_node.value]
        %>(${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] })${ literal }"""
    )

    VarAccess = as_fmt("{name}")

    VarDecl = as_mako(
        "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };"
    )

    def visit_Computation(self, node: Computation, **kwargs):
        """Render the computation.

        Passes a by-name lookup table of the kernels, the set of
        ``struct <tag>;`` forward declarations for every SID entry of every
        kernel, and the concatenation of parameters and temporaries.
        """
        kernel_by_name = {k.name: k for k in node.kernels}
        sid_tags = {
            "struct " + entry.tag_name + ";"
            for kernel in node.kernels
            for sid in kernel.sids
            for entry in sid.entries
        }
        return self.generic_visit(
            node,
            computation_fields=node.parameters + node.temporaries,
            sid_tags=sid_tags,
            symbol_tbl_kernel=kernel_by_name,
            **kwargs,
        )

    Computation = as_mako(
        """${_this_generator.preface_}
        ${ '\\n'.join('#include ' + header for header in _this_generator.headers_) }

        namespace ${ name }_impl_ {
            ${ ''.join(sid_tags) }

            ${ ''.join(kernels) }
        }

        template<class mesh_t, ${ ','.join('class ' + p.name + '_t' for p in _this_node.parameters) }>
        void ${ name }(mesh_t&& mesh, ${ ','.join(p.name + '_t&& ' + p.name for p in _this_node.parameters) }){
            namespace tu = gridtools::tuple_util;
            using namespace ${ name }_impl_;

            % if len(temporaries) > 0:
                auto tmp_alloc = ${ _this_generator.cache_allocator_ }
            % endif
            ${ ''.join(temporaries) }

            ${ ''.join(ctrlflow_ast) }
        }
        """
    )

    def visit_Temporary(self, node: Temporary, **kwargs):
        """Render a temporary, passing its C++ value type and location type name."""
        location = self.location_type_from_dimensions(node.dimensions)
        return self.generic_visit(
            node,
            loctype=self.LOCATION_TYPE_TO_STR[location],
            c_vtype=self.DATA_TYPE_TO_STR[node.vtype],
            **kwargs,
        ) if False else self._render_temporary(node, location, **kwargs)

    def _render_temporary(self, node, location, **kwargs):
        # Look up the value type first, then the location type, before
        # delegating to the template.
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.LOCATION_TYPE_TO_STR[location]
        return self.generic_visit(node, loctype=loctype, c_vtype=c_vtype, **kwargs)

    Temporary = as_mako(
        """
        auto ${ name } = gridtools::next::make_simple_tmp_storage<${ loctype }, ${ c_vtype }>(
            (int)gridtools::next::connectivity::size(gridtools::next::mesh::connectivity<std::tuple<${ loctype }>>(mesh)), 1 /* TODO ksize */, tmp_alloc);"""
    )
class IconBindingsCodegen(codegen.TemplatedGenerator):
    """Render C++ binding code (``cpp_bindgen`` exports) around generated stencil code."""

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Generate and format the bindings for ``root``.

        Only the ``stencil_code`` keyword argument is forwarded to the
        visitor; it is inserted verbatim into the generated file.
        """
        assert "stencil_code" in kwargs
        generated_code = cls().visit(root, stencil_code=kwargs["stencil_code"])
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code

    def visit_UField(self, node: UField, **kwargs):
        """Render the field as a parameter, or ``""`` if it has no dimensionality entry."""
        if node.name in kwargs["dimensionality"]:
            return self.generic_visit(node, **kwargs)
        else:
            return ""

    UField = as_fmt(
        "gridtools::fortran_array_view<T, {len(dimensionality[name])}, field_kind<{','.join(str(i) for i in dimensionality[name])}>> {name}"
    )

    def visit_SparseField(self, node: SparseField, **kwargs):
        """Render the field as a parameter, or ``""`` if it has no dimensionality entry."""
        if node.name in kwargs["dimensionality"]:
            return self.generic_visit(node, **kwargs)
        else:
            return ""

    SparseField = as_fmt(
        "gridtools::fortran_array_view<T,{len(dimensionality[name])}, field_kind<{','.join(str(i) for i in dimensionality[name])}>> {name}"
    )

    Connectivity = as_fmt("neigh_tbl_t {name}")

    def visit_Computation(self, node: Computation, **kwargs):
        """Render the bindings for a computation.

        For each parameter, computes the list of dimension ids it has
        (0 if horizontal, 1 if vertical, 2 if it is a ``SparseField``) and
        the expression used to pass it on to the stencil: the plain name, or
        a ``gridtools::sid::rename_dimensions`` call when the position of a
        dimension in that list differs from its id.
        """
        dimensionality = {}
        for p in node.params:
            dims = [0, 1, 2]
            if not p.dimensions.horizontal:
                dims.remove(0)
            if not p.dimensions.vertical:
                dims.remove(1)
            if not isinstance(p, SparseField):
                dims.remove(2)
            dimensionality[p.name] = dims

        param_names = []
        for name, dims in dimensionality.items():
            # Map "position in the view" -> "dimension id" wherever they differ.
            renames = {
                index: dim
                for index, dim in enumerate(dims) if index != dim
            }
            if renames:
                pairs = ",".join(
                    f"gridtools::integral_constant<int,{old}>, gridtools::integral_constant<int,{new}>"
                    for old, new in renames.items())
                # The template argument list is closed once, after all pairs;
                # closing it inside every pair produced unbalanced brackets
                # as soon as more than one dimension had to be renamed.
                # NOTE(review): assumes rename_dimensions accepts several
                # old/new pairs at once - confirm against the GridTools
                # version in use.
                param_names.append(
                    f"gridtools::sid::rename_dimensions<{pairs}>({name})")
            else:
                param_names.append(name)

        return self.generic_visit(node,
                                  param_names=param_names,
                                  dimensionality=dimensionality,
                                  **kwargs)

    Computation = as_mako(
        """
        # include <cpp_bindgen/export.hpp>
        # include <gridtools/storage/adapter/fortran_array_view.hpp>
        # include <gridtools/storage/sid.hpp>
        # include <gridtools/usid/icon.hpp>
        ${stencil_code}
        namespace icon_bindings_${name}_impl{
            struct default_tag{};
            template<int...>
            struct field_kind{};
            // template<class Tag>
            using neigh_tbl_t = gridtools::fortran_array_view<int, 2, default_tag, false>;
            auto alloc_${name}_impl(${','.join(['int n_edges', 'int n_k'] + connectivities)}) {
                return ${name}({-1, n_edges, -1, n_k}, ${','.join(f"icon::make_connectivity_producer<{c.max_neighbors}>({c.name})" for c in _this_node.connectivities)});
            }
            BINDGEN_EXPORT_BINDING_WRAPPED(${2+len(connectivities)}, alloc_${name}, alloc_${name}_impl);
            // template<class Tag>
            using ${name}_t = decltype(alloc_${name}_impl(0,0, neigh_tbl_t/*<Tag>*/{{}}));
            template <class T>
            void ${name}_impl(${name}_t ${name}, ${','.join(params)}){
                ${name}(${','.join(param_names)});
            }
            BINDGEN_EXPORT_GENERIC_BINDING_WRAPPED(${1+len(params)}, ${name}, ${name}_impl, (double));
        }
        """
    )
class TaskletCodegen(codegen.TemplatedGenerator):
    """Render an ``oir.HorizontalExecution`` as Python source using ``dace`` names."""

    ScalarAccess = as_fmt("{name}")

    def visit_FieldAccess(self, node: oir.FieldAccess, *, is_target, targets):
        """Return the variable name standing for a field access.

        A zero-offset access that is being written, or whose field is
        already in ``targets``, is named ``__<field>`` and the field is
        recorded in ``targets``. Every other access is named
        ``<field>__<offset>``. A non-empty data index is appended via
        ``str()``.
        """
        offset_code = self.visit(node.offset)
        written_here = (is_target or node.name in targets) and offset_code == ""
        if written_here:
            targets.add(node.name)
            base = "__" + node.name
        else:
            base = node.name + "__" + offset_code
        suffix = str(node.data_index) if node.data_index else ""
        return base + suffix

    def visit_CartesianOffset(self, node: common.CartesianOffset):
        """Encode the non-zero offset components, e.g. ``im1_kp2``; ``""`` for zero offset."""
        parts = []
        for axis, value in (("i", node.i), ("j", node.j), ("k", node.k)):
            if value != 0:
                sign = "m" if value < 0 else "p"
                parts.append(f"{axis}{sign}{abs(value):d}")
        return "_".join(parts)

    def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs):
        """Render an assignment; the right-hand side is visited before the left."""
        rhs = self.visit(node.right, is_target=False, **kwargs)
        lhs = self.visit(node.left, is_target=True, **kwargs)
        return f"{lhs} = {rhs}"

    BinaryOp = as_fmt("({left} {op} {right})")

    UnaryOp = as_fmt("({op}{expr})")

    TernaryOp = as_fmt("({true_expr} if {cond} else {false_expr})")

    def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str:
        """Return the Python spelling of a built-in boolean literal.

        Raises:
            NotImplementedError: for any literal other than TRUE/FALSE.
        """
        spellings = (
            (common.BuiltInLiteral.TRUE, "True"),
            (common.BuiltInLiteral.FALSE, "False"),
        )
        for literal, code in spellings:
            if builtin == literal:
                return code
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_fmt("{dtype}({value})")

    Cast = as_fmt("{dtype}({expr})")

    def visit_NativeFunction(self, func: common.NativeFunction, **kwargs: Any) -> str:
        """Return the function name used in the tasklet for a native function.

        Raises:
            NotImplementedError: if the function is not in the table.
        """
        try:
            names = {
                common.NativeFunction.ABS: "abs",
                common.NativeFunction.MIN: "min",
                common.NativeFunction.MAX: "max",
                common.NativeFunction.MOD: "fmod",
                common.NativeFunction.SIN: "dace.math.sin",
                common.NativeFunction.COS: "dace.math.cos",
                common.NativeFunction.TAN: "dace.math.tan",
                common.NativeFunction.ARCSIN: "asin",
                common.NativeFunction.ARCCOS: "acos",
                common.NativeFunction.ARCTAN: "atan",
                common.NativeFunction.SQRT: "dace.math.sqrt",
                common.NativeFunction.POW: "dace.math.pow",
                common.NativeFunction.EXP: "dace.math.exp",
                common.NativeFunction.LOG: "dace.math.log",
                common.NativeFunction.ISFINITE: "isfinite",
                common.NativeFunction.ISINF: "isinf",
                common.NativeFunction.ISNAN: "isnan",
                common.NativeFunction.FLOOR: "dace.math.ifloor",
                common.NativeFunction.CEIL: "ceil",
                common.NativeFunction.TRUNC: "trunc",
            }
            return names[func]
        except KeyError as error:
            raise NotImplementedError("Not implemented NativeFunction encountered.") from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str:
        """Return the ``dace`` type name for a data type.

        Raises:
            NotImplementedError: if the data type is not in the table.
        """
        type_names = (
            (common.DataType.BOOL, "dace.bool_"),
            (common.DataType.INT8, "dace.int8"),
            (common.DataType.INT16, "dace.int16"),
            (common.DataType.INT32, "dace.int32"),
            (common.DataType.INT64, "dace.int64"),
            (common.DataType.FLOAT32, "dace.float32"),
            (common.DataType.FLOAT64, "dace.float64"),
        )
        for candidate, code in type_names:
            if dtype == candidate:
                return code
        raise NotImplementedError("Not implemented DataType encountered.")

    def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str:
        """Return the Python spelling of a unary operator.

        Raises:
            NotImplementedError: if the operator is not NOT, NEG or POS.
        """
        operators = (
            (common.UnaryOperator.NOT, " not "),
            (common.UnaryOperator.NEG, "-"),
            (common.UnaryOperator.POS, "+"),
        )
        for candidate, code in operators:
            if op == candidate:
                return code
        raise NotImplementedError("Not implemented UnaryOperator encountered.")

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    LocalScalar = as_fmt("{name}: {dtype}")

    def visit_HorizontalExecution(self, node: oir.HorizontalExecution):
        """Render declarations followed by the body, one item per line.

        A fresh ``targets`` set is shared by all body statements so that
        field accesses can tell which fields have already been written.
        """
        targets: Set[str] = set()
        lines = list(self.visit(node.declarations))
        lines.extend(self.visit(node.body, targets=targets))
        return "\n".join(lines)

    def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs):
        """Render a masked block as an ``if`` statement with an indented body.

        Without a mask, the header line is empty and the body is not
        indented.
        """
        if node.mask is not None:
            header = f"if {self.visit(node.mask, is_target=False, **kwargs)}:"
            indent = " "
        else:
            header = ""
            indent = ""
        body_lines = [
            indent + line
            for line in self.visit(node.body, targets=kwargs["targets"])
        ]
        return "\n".join([header] + body_lines)

    @classmethod
    def apply(cls, node: oir.HorizontalExecution, **kwargs: Any) -> str:
        """Generate and format Python code for a horizontal execution.

        ``kwargs`` are accepted but not forwarded to the base ``apply``.

        Raises:
            ValueError: if ``node`` is not an ``oir.HorizontalExecution``.
        """
        if not isinstance(node, oir.HorizontalExecution):
            raise ValueError("apply() requires oir.HorizontalExecution node")
        raw_code = super().apply(node)
        return codegen.format_source("python", raw_code)
class UsidCodeGenerator(codegen.TemplatedGenerator):
    """Templated generator that renders a ``Computation`` tree to C++ text.

    Class attributes named after node classes hold the ``as_fmt`` /
    ``as_mako`` templates; the ``visit_*`` methods compute extra template
    variables and then defer to ``generic_visit``.  ``headers_`` and
    ``namespace_`` are read by the ``Computation`` template; ``preface_`` is
    not referenced inside this class (presumably a hook for subclasses --
    confirm).
    """

    # NOTE(review): "unsigned_int" is not a built-in C++ type name -- confirm
    # that the generated code has a typedef for it in scope.
    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType(
        {
            common.DataType.BOOLEAN: "bool",
            common.DataType.INT32: "int",
            common.DataType.UINT32: "unsigned_int",
            common.DataType.FLOAT32: "float",
            common.DataType.FLOAT64: "double",
        }
    )

    # NOTE(review): std::numeric_limits<double>::min() is the smallest
    # positive normal value; if MIN_VALUE is meant as "most negative",
    # lowest() would be required -- confirm against the IR definition.
    BUILTIN_LITERAL_TO_STR: ClassVar[Mapping[common.BuiltInLiteral, str]] = MappingProxyType(
        {
            common.BuiltInLiteral.MAX_VALUE: "std::numeric_limits<double>::max()",  # TODO: datatype
            common.BuiltInLiteral.MIN_VALUE: "std::numeric_limits<double>::min()",
            common.BuiltInLiteral.ZERO: "0",
            common.BuiltInLiteral.ONE: "1",
        }
    )

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Render ``root`` and return the result formatted as C++ (LLVM style)."""
        generated_code = super().apply(root, **kwargs)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code

    def location_type_from_dimensions(self, dimensions):
        """Return the single ``common.LocationType`` contained in ``dimensions``.

        Raises:
            ValueError: if ``dimensions`` holds zero or more than one
                ``LocationType`` entry.
        """
        location_type = [dim for dim in dimensions if isinstance(dim, common.LocationType)]
        if len(location_type) != 1:
            raise ValueError("Doesn't contain a LocationType!")
        return location_type[0]

    # Headers emitted as ``#include`` lines by the ``Computation`` template.
    # The original list named gt_math.hpp twice; the redundant second entry
    # has been dropped.
    headers_ = [
        "<gridtools/common/gt_math.hpp>",
        "<gridtools/common/array.hpp>",
        "<gridtools/usid/dim.hpp>",
        "<gridtools/usid/helpers.hpp>",
    ]

    # Appended to ``gridtools::usid::`` in a using-directive by ``Computation``.
    namespace_ = ""

    preface_ = ""

    def visit_LocationType(self, node: common.LocationType, **kwargs):
        """Map a location type to its lower-case name (``KeyError`` if unknown)."""
        return {
            common.LocationType.Vertex: "vertex",
            common.LocationType.Edge: "edge",
            common.LocationType.Cell: "cell",
        }[node]

    def visit_bool(self, node: bool, **kwargs):
        """Render a Python bool as the C++ literal ``true`` / ``false``."""
        if node:
            return "true"
        else:
            return "false"

    def visit_SidCompositeSparseEntry(self, node: SidCompositeSparseEntry, **kwargs):
        """Render a sparse entry, adding the tag of its connectivity.

        The connectivity is looked up by name in ``kwargs["symtable"]``.
        """
        return self.generic_visit(
            node, connectivity_tag=kwargs["symtable"][node.connectivity].tag, **kwargs
        )

    SidCompositeSparseEntry = as_fmt("sid::rename_dimensions<dim::s, {connectivity_tag}>({ref})")

    SidCompositeEntry = as_fmt("{ref}")

    SidComposite = as_mako(
        """
        sid::composite::make<${ ','.join([t.name for t in _this_node.entries]) }>(
        ${ ','.join(entries)})
        """
    )

    def visit_KernelCall(self, node: KernelCall, **kwargs):
        """Render a kernel call with its domain and composite arguments.

        The kernel is resolved by name through ``kwargs["symtable"]``; its
        primary composite followed by the secondary ones are visited to
        produce the ``sids`` argument list.
        """
        kernel: Kernel = kwargs["symtable"][node.name]
        domain = f"d.{self.visit(kernel.primary_location)}"
        sids = self.visit([kernel.primary_composite] + kernel.secondary_composites, **kwargs)
        return self.generic_visit(node, domain=domain, sids=sids)

    KernelCall = as_mako(
        """
        call_kernel<${name}>(${domain}, d.k, ${','.join(sids)});
        """
    )

    FieldAccess = as_mako(
        """<%
            composite_deref = symtable[_this_node.sid]
            sid_entry_deref = symtable[_this_node.name]
        %>field<${ sid_entry_deref.name }>(${ composite_deref.ptr_name })"""
    )

    ArrayAccess = as_fmt("{name}[{subscript}]")

    AssignStmt = as_fmt("{left} = {right};")

    # Must be a Mako template: the argument list is built by an expression,
    # and ``str.format`` replacement fields cannot contain expressions (the
    # former ``as_fmt`` version would raise while rendering).
    NativeFuncCall = as_mako("gridtools::math::${func}(${','.join(args)})")  # TODO: fix func

    BinaryOp = as_fmt("({left} {op} {right})")

    PtrRef = as_fmt("{name}")

    LocalIndex = as_fmt("{name}")

    def visit_NeighborLoop(self, node: NeighborLoop, symtable, **kwargs):
        """Render a neighbor loop, resolving its primary sid and connectivity.

        When ``node.local_index`` is set, the ``_indexed`` variant is emitted
        and the index is added as an extra lambda parameter.
        """
        primary_sid_deref = symtable[node.primary_sid]
        connectivity_deref = symtable[node.connectivity]
        indexed = ""
        index_var = ""
        if node.local_index:
            indexed = "_indexed"
            index_var = f", auto {self.visit(node.local_index)}"
        # NOTE(review): here outer ``symtable`` entries win over
        # ``node.symtable_`` on name clashes, whereas ``visit_Kernel`` merges
        # in the opposite order -- confirm which is intended.
        return self.generic_visit(
            node,
            symtable={
                **node.symtable_,
                **symtable,
            },  # should be partly bounded (should see only global scope (tags) and current scope)
            primary_sid_deref=primary_sid_deref,
            connectivity_deref=connectivity_deref,
            indexed=indexed,
            index_var=index_var,
            **kwargs,
        )

    # TODO consider stricter capture
    NeighborLoop = as_mako(
        """
        foreach_neighbor${indexed}<${connectivity_deref.tag}>([&](auto &&${primary}, auto &&${secondary}${index_var}){${''.join(body)}},
            ${primary_sid_deref.ptr_name}, ${primary_sid_deref.strides_name}, ${secondary_sid});
        """
    )

    Literal = as_mako(
        """<%
            literal= _this_node.value if isinstance(_this_node.value, str) else _this_generator.BUILTIN_LITERAL_TO_STR[_this_node.value]
        %>(${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] })${ literal }"""
    )

    VarAccess = as_fmt("{name}")

    VarDecl = as_mako(
        "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };"
    )

    StaticArrayDecl = as_mako(
        "gridtools::array<${_this_generator.DATA_TYPE_TO_STR[_this_node.vtype]}, ${size}> ${name} = {${','.join(init)}};"
    )

    def visit_Connectivity(self, node: Connectivity, **kwargs):
        """Render a connectivity struct, spelling ``has_skip_values`` as a C++ bool."""
        c_has_skip_values = "true" if node.has_skip_values else "false"
        return self.generic_visit(node, c_has_skip_values=c_has_skip_values)

    Connectivity = as_mako(
        "struct ${_this_node.tag}: connectivity<${max_neighbors},${c_has_skip_values}>{};"
    )

    def visit_Temporary(self, node: Temporary, **kwargs):
        """Render a temporary storage declaration for ``node``."""
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.visit(self.location_type_from_dimensions(node.dimensions))
        return self.generic_visit(node, loctype=loctype, c_vtype=c_vtype, **kwargs)

    Temporary = as_mako(
        """
        auto ${ name } = make_simple_tmp_storage<${ c_vtype }>(
            d.${ loctype }, d.k, alloc);"""
    )

    def visit_TemporarySparseField(self, node: TemporarySparseField, *, symtable, **kwargs):
        """Render a sparse temporary sized by its connectivity's ``max_neighbors``.

        ``symtable`` is consumed here and is not forwarded to ``generic_visit``.
        """
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.visit(self.location_type_from_dimensions(node.dimensions))
        connectivity_deref = symtable[node.connectivity]
        return self.generic_visit(
            node,
            s_size=connectivity_deref.max_neighbors,
            c_vtype=c_vtype,
            loctype=loctype,
            **kwargs,
        )

    TemporarySparseField = as_mako(
        """
        auto ${ name } = make_simple_sparse_tmp_storage<${ c_vtype }>(
            d.${ loctype }, d.k, ${s_size}, alloc);"""
    )

    def visit_Kernel(self, node: Kernel, symtable, **kwargs):
        """Render a kernel struct whose call operator returns the body lambda.

        The lambda takes the primary composite's pointer and strides plus one
        forwarding-reference parameter per secondary composite.  Children are
        visited with ``node.symtable_`` layered on top of ``symtable``.
        """
        primary_signature = f"auto && {node.primary_composite.ptr_name}, auto&& {node.primary_composite.strides_name}"
        secondary_signature = (
            ""
            if len(node.secondary_composites) == 0
            else ", auto &&" + ", auto&&".join(c.name for c in node.secondary_composites)
        )
        return self.generic_visit(
            node,
            symtable={**symtable, **node.symtable_},
            primary_signature=primary_signature,
            secondary_signature=secondary_signature,
            **kwargs,
        )

    Kernel = as_mako(
        """
        struct ${name} {
            GT_FUNCTION auto operator()() const {
                return [](${primary_signature}${secondary_signature}){
                    ${''.join(body)}
                };
            }
        };
        """
    )

    def visit_Computation(self, node: Computation, **kwargs):
        """Render the whole computation.

        Prepares the forward declarations of field tags (entry names that are
        not connectivity tags), the connectivity/field lambda parameters and
        the connectivity captures, and seeds the symbol table with
        ``node.symtable_``.
        """
        # maybe tags should be generated in lowering
        field_tag_names = node.iter_tree().if_isinstance(SidCompositeEntry).getattr("name").to_set()
        connectivity_tag_names = (c.tag for c in node.connectivities)
        field_tags = [
            f"struct {field_tag};"
            for field_tag in field_tag_names.difference(connectivity_tag_names)
        ]
        connectivity_params = [f"auto&& {c.name}" for c in node.connectivities]
        field_params = [f"auto&& {f.name}" for f in node.parameters]
        connectivity_fields = [
            f"{c.name} = sid::rename_dimensions<dim::n, {c.tag}>(std::forward<decltype({c.name})>({c.name})(traits_t()))"
            for c in node.connectivities
        ]
        return self.generic_visit(
            node,
            field_tags=field_tags,
            connectivity_params=connectivity_params,
            connectivity_fields=connectivity_fields,
            field_params=field_params,
            symtable=node.symtable_,
            **kwargs,
        )

    Computation = as_mako(
        """
        ${ '\\n'.join('#include ' + header for header in _this_generator.headers_) }

        namespace ${ name }_impl_ {
            using namespace gridtools;
            using namespace gridtools::usid;
            using namespace gridtools::usid::${_this_generator.namespace_};
            ${ ''.join(connectivities)}
            ${ ''.join(field_tags) }
            ${ ''.join(kernels) }

            auto ${name} = [](domain d
                %if connectivity_params:
                , ${','.join(connectivity_params)}
                %endif
                ) {
                ${ ''.join(f"static_assert(is_sid<decltype({c.name}(traits_t()))>());" for c in _this_node.connectivities)}
                return [d = std::move(d)
                    %if connectivity_fields:
                    , ${','.join(connectivity_fields)}
                    %endif
                    ](
                        ${','.join(field_params)}
                    ){
                    ${ ''.join(f"static_assert(is_sid<decltype({p.name})>());" for p in _this_node.parameters)}
                    %if temporaries:
                    auto alloc = make_allocator();
                    %endif
                    ${''.join(temporaries)}
                    ${''.join(ctrlflow_ast)}
                };
            };
        }
        using ${ name }_impl_::${name};
        """
    )
class GTCppCodegen(codegen.TemplatedGenerator):
    """Templated generator that renders a ``gtcpp.Program`` tree to C++ text.

    Class attributes named after node classes hold the ``as_fmt`` /
    ``as_mako`` templates; the ``visit_*`` methods translate enum-like values
    (literals, native functions, data types, operators, loop orders) into
    their C++ spelling.  Use :meth:`apply` as the entry point.
    """

    GTExtent = as_fmt("extent<{i[0]},{i[1]},{j[0]},{j[1]},{k[0]},{k[1]}>")

    GTAccessor = as_fmt("using {name} = {intent}_accessor<{id}, {extent}>;")

    GTParamList = as_mako(
        """${ '\\n'.join(accessors) }

        using param_list = make_param_list<${ ','.join(a.name for a in _this_node.accessors)}>;
        """
    )

    GTFunctor = as_mako(
        """struct ${ name } {
        ${param_list}

        ${ '\\n'.join(applies) }
    };
    """
    )

    GTLevel = as_fmt("gridtools::stencil::core::level<{splitter}, {offset}, {offset_limit}>")

    GTInterval = as_fmt("gridtools::stencil::core::interval<{from_level}, {to_level}>")

    LocalVarDecl = as_fmt("{dtype} {name};")

    GTApplyMethod = as_mako(
        """
    template<typename Evaluation>
    GT_FUNCTION static void apply(Evaluation eval, ${interval}) {
        ${ ' '.join(local_variables) }
        ${ '\\n'.join(body) }
    }
    """
    )

    AssignStmt = as_fmt("{left} = {right};")

    AccessorRef = as_fmt("eval({name}({offset}))")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}, {j}, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    UnaryOp = as_fmt("({op}{expr})")

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral, **kwargs: Any) -> str:
        """Return the C++ spelling of a built-in literal.

        Raises:
            NotImplementedError: for any literal other than TRUE / FALSE.
        """
        if builtin == BuiltInLiteral.TRUE:
            return "true"
        elif builtin == BuiltInLiteral.FALSE:
            return "false"
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_mako("static_cast<${dtype}>(${value})")

    def visit_NativeFunction(self, func: NativeFunction, **kwargs: Any) -> str:
        """Return the ``gridtools::math`` function name for ``func``.

        Raises:
            NotImplementedError: for any function other than SQRT / MIN / MAX.
        """
        if func == NativeFunction.SQRT:
            return "gridtools::math::sqrt"
        elif func == NativeFunction.MIN:
            return "gridtools::math::min"
        elif func == NativeFunction.MAX:
            return "gridtools::math::max"
        raise NotImplementedError("Not implemented NativeFunction encountered.")

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        """Return the C++ type name for ``dtype``.

        Raises:
            NotImplementedError: for any data type other than INT64, FLOAT64,
                FLOAT32 and BOOL.
        """
        if dtype == DataType.INT64:
            return "long long"
        elif dtype == DataType.FLOAT64:
            return "double"
        elif dtype == DataType.FLOAT32:
            return "float"
        elif dtype == DataType.BOOL:
            return "bool"
        # The message previously named "NativeFunction" (copy-paste slip),
        # which pointed at the wrong visitor when debugging.
        raise NotImplementedError("Not implemented DataType encountered.")

    def visit_UnaryOperator(self, op: UnaryOperator, **kwargs: Any) -> str:
        """Return the C++ token for a unary operator.

        Raises:
            NotImplementedError: for any operator other than NOT / NEG / POS.
        """
        if op == UnaryOperator.NOT:
            return "!"
        elif op == UnaryOperator.NEG:
            return "-"
        elif op == UnaryOperator.POS:
            return "+"
        raise NotImplementedError("Not implemented UnaryOperator encountered.")

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    ApiParamDecl = as_fmt("{name}")

    GTStage = as_mako(".stage(${functor}(), ${','.join(args)})")

    GTMultiStage = as_mako("execute_${ loop_order }()${''.join(caches)}${''.join(stages)}")

    IJCache = as_fmt(".ij_cached({name})")

    KCache = as_mako(
        ".k_cached(${'cache_io_policy::fill(), ' if _this_node.fill else ''}${'cache_io_policy::flush(), ' if _this_node.flush else ''}${name})"
    )

    def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str:
        """Return the suffix used after ``execute_`` (``KeyError`` if unknown)."""
        return {
            LoopOrder.PARALLEL: "parallel",
            LoopOrder.FORWARD: "forward",
            LoopOrder.BACKWARD: "backward",
        }[looporder]

    Temporary = as_fmt("GT_DECLARE_TMP({dtype}, {name});")

    IfStmt = as_mako(
        """if(${cond}) ${true_branch}
        %if _this_node.false_branch:
            else ${false_branch}
        %endif
        """
    )

    BlockStmt = as_mako("{${''.join(body)}}")

    def visit_GTComputationCall(
        self, node: gtcpp.GTComputationCall, **kwargs: Any
    ) -> Union[str, Collection[str]]:
        """Render the computation call, exposing ``node.id_`` as ``computation_name``."""
        return self.generic_visit(node, computation_name=node.id_, **kwargs)

    # Emits nothing unless there is at least one multi-stage and one argument.
    GTComputationCall = as_mako(
        """
        %if len(multi_stages) > 0 and len(arguments) > 0:
        {
            auto grid = make_grid(domain[0], domain[1], axis<1, axis_config::offset_limit<${offset_limit}>>{domain[2]});

            auto ${ computation_name } = [](${ ','.join('auto ' + a for a in arguments) }) {

                ${ '\\n'.join(temporaries) }
                return multi_pass(${ ','.join(multi_stages) });
            };

            run(${computation_name}, ${gt_backend_t}<>{}, grid, ${','.join(arguments)});
        }
        %endif
        """
    )

    Program = as_mako(
        """#include <gridtools/stencil/${gt_backend_t}.hpp>
        #include <gridtools/stencil/cartesian.hpp>

        namespace ${ name }_impl_{
            using Domain = std::array<gridtools::uint_t, 3>;
            using namespace gridtools::stencil;
            using namespace gridtools::stencil::cartesian;

            ${'\\n'.join(functors)}

            auto ${name}(Domain domain){
                return [domain](${ ','.join( 'auto&& ' + p for p in parameters)}){
                    ${gt_computation}
                };
            }
        }

        auto ${name}(${name}_impl_::Domain domain){
            return ${name}_impl_::${name}(domain);
        }
        """
    )

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Render ``root`` and return the result formatted as C++ (LLVM style).

        Args:
            root: must be a ``gtcpp.Program`` instance.
            **kwargs: must contain ``gt_backend_t``; everything is forwarded
                to the templates together with ``offset_limit`` computed by
                ``_offset_limit(root)``.

        Raises:
            ValueError: if ``root`` is not a ``gtcpp.Program``.
            TypeError: if ``gt_backend_t`` is missing from ``kwargs``.
        """
        if not isinstance(root, gtcpp.Program):
            raise ValueError("apply() requires gtcpp.Program root node")
        if "gt_backend_t" not in kwargs:
            raise TypeError("apply() missing 1 required keyword-only argument: 'gt_backend_t'")
        generated_code = super().apply(root, offset_limit=_offset_limit(root), **kwargs)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code
class NaiveCodeGenerator(codegen.TemplatedGenerator):
    """Templated generator that renders a ``Computation`` tree to C++ text.

    The emitted source includes ``driver-includes/unstructured_interface.hpp``
    and defines a ``generated`` class template inside
    ``dawn_generated::cxxnaiveico``.  Templates are stored as class attributes
    named after the node classes they render.
    """

    # NOTE(review): "unsigned_int" is not a built-in C++ type name -- confirm
    # that the generated code has a typedef for it in scope.
    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType(
        {
            common.DataType.BOOLEAN: "bool",
            common.DataType.INT32: "int",
            common.DataType.UINT32: "unsigned_int",
            common.DataType.FLOAT32: "float",
            common.DataType.FLOAT64: "double",
        }
    )

    # Singular / plural spelling of every location type, as used in the
    # generated identifiers.
    LOCATION_TYPE_TO_STR_MAP: ClassVar[
        Mapping[LocationType, Mapping[str, str]]
    ] = MappingProxyType(
        {
            LocationType.Node: MappingProxyType({"singular": "vertex", "plural": "vertices"}),
            LocationType.Edge: MappingProxyType({"singular": "edge", "plural": "edges"}),
            LocationType.Face: MappingProxyType({"singular": "cell", "plural": "cells"}),
        }
    )

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Render ``root`` and return the result formatted as C++ (LLVM style)."""
        raw_source = super().apply(root, **kwargs)
        return codegen.format_source("cpp", raw_source, style="LLVM")

    def visit_DataType(self, node, **kwargs) -> str:
        """Look up the C++ type name of ``node`` (``KeyError`` if unknown)."""
        return self.DATA_TYPE_TO_STR[node]

    def visit_LocationType(self, node, **kwargs) -> Mapping[str, str]:
        """Return the singular/plural name mapping of ``node`` (``KeyError`` if unknown)."""
        return self.LOCATION_TYPE_TO_STR_MAP[node]

    Node = as_mako("${_this_node.__class__.__name__.upper()}")  # only for testing

    UnstructuredField = as_mako(
        """<%
            loc_type = location_type["singular"]
            sparseloc = "sparse_" if _this_node.sparse_location_type else ""
        %>
        dawn::${ sparseloc }${ loc_type }_field_t<LibTag, ${ data_type }>& ${ name };"""
    )

    FieldAccessExpr = as_mako(
        """<%
            sparse_index = "m_sparse_dimension_idx, " if _this_node.is_sparse else ""
            field_acc_itervar = outer_iter_var if _this_node.is_sparse else iter_var
        %>${ name }(deref(LibTag{}, ${ field_acc_itervar }), ${ sparse_index } k)"""
    )

    AssignmentExpr = as_fmt("{left} = {right}")

    VarAccessExpr = as_fmt("{name}")

    BinaryOp = as_fmt("{left} {op} {right}")

    ExprStmt = as_fmt("\n{expr};")

    VarDeclStmt = as_fmt("\n{data_type} {name};")

    # NOTE(review): the using-declaration always names allocateEdgeField while
    # the call below depends on the location type -- confirm this is intended
    # for vertex and cell temporaries.
    TemporaryFieldDeclStmt = as_mako(
        """using dawn::allocateEdgeField;
        auto ${ name } = allocate${ location_type['singular'].capitalize() }Field<${ data_type }>(mesh);"""
    )

    # FORWARD loops count k upwards from 0; every other loop order counts
    # downwards from k_size - 1.
    ForK = as_mako(
        """<%
            if _this_node.loop_order == _this_module.common.LoopOrder.FORWARD:
                k_init = '0'
                k_cond = 'k < k_size'
                k_step = '++k'
            else:
                k_init = 'k_size -1'
                k_cond = 'k >= 0'
                k_step = '--k'
        %>for (int k = ${k_init}; ${k_cond}; ${k_step}) {
            int m_sparse_dimension_idx;
            ${ "".join(horizontal_loops) }\n}"""
    )

    HorizontalLoop = as_mako(
        """<%
            loc_type = location_type['plural'].title()
        %>for(auto const & t: get${ loc_type }(LibTag{}, mesh)) ${ ast }"""
    )

    def visit_HorizontalLoop(self, node, **kwargs) -> str:
        """Render a horizontal loop; children see ``"t"`` as ``iter_var``."""
        return self.generic_visit(node, iter_var="t", **kwargs)

    BlockStmt = as_mako("{${ ''.join(statements) }\n}")

    ReduceOverNeighbourExpr = as_mako(
        """<%
            right_loc_type = right_location_type["singular"].title()
            loc_type = location_type["singular"].title()
        %>(m_sparse_dimension_idx=0,reduce${ right_loc_type }To${ loc_type }(mesh, ${ outer_iter_var }, ${ init }, [&](auto& lhs, auto const& ${ iter_var }) {
            lhs ${ operation }= ${ right };
            m_sparse_dimension_idx++;
            return lhs;
        }))"""
    )

    def visit_ReduceOverNeighbourExpr(self, node, *, iter_var, **kwargs) -> str:
        """Render a neighbour reduction.

        The incoming ``iter_var`` is handed down as ``outer_iter_var`` and the
        children are visited with ``iter_var`` rebound to ``"redIdx"``.
        """
        return self.generic_visit(node, outer_iter_var=iter_var, iter_var="redIdx", **kwargs)

    LiteralExpr = as_fmt("({data_type}){value}")

    Stencil = as_mako(
        """
        void ${name}() {
            using dawn::deref;

            ${ "\\n".join(declarations) if _this_node.declarations else ""}

            ${ "".join(k_loops) }
        }
        """
    )

    Computation = as_mako(
        """<%
            stencil_calls = '\\n'.join("{name}();".format(name=s.name) for s in _this_node.stencils)
            ctor_field_params = ', '.join(
                'dawn::{sparse_loc}{loc_type}_field_t<LibTag, {data_type}>& {name}'.format(
                    loc_type=_this_generator.LOCATION_TYPE_TO_STR_MAP[p.location_type]['singular'],
                    name=p.name,
                    data_type=_this_generator.DATA_TYPE_TO_STR[p.data_type],
                    sparse_loc="sparse_" if p.sparse_location_type else ""
                )
                for p in _this_node.params
            )
            ctor_field_initializers = ', '.join(
                '{name}({name})'.format(name=p.name) for p in _this_node.params
            )
        %>#define DAWN_GENERATED 1
        #define DAWN_BACKEND_T CXXNAIVEICO
        #include <driver-includes/unstructured_interface.hpp>

        namespace dawn_generated {
        namespace cxxnaiveico {
        template <typename LibTag>
        class generated {
        private:
            dawn::mesh_t<LibTag>& mesh;
            int const k_size;

            ${ ''.join(params) }
            ${ ''.join(stencils) }

        public:
            generated(dawn::mesh_t<LibTag>& mesh, int k_size, ${ ctor_field_params }): mesh(mesh), k_size(k_size), ${ ctor_field_initializers } {}

            void run() {
                ${ stencil_calls }
            }
        };
        }
        }
        """
    )