예제 #1
0
파일: usid_codegen.py 프로젝트: fthaler/gtc
class UsidNaiveCodeGenerator(UsidCodeGenerator):
    """Non-CUDA variant of ``UsidCodeGenerator``.

    Each kernel is emitted as a plain C++ function template whose body is a
    ``for`` loop over every index of the kernel's primary connectivity.
    """

    # C++ statement that builds the cached allocator; buffers are obtained via
    # ``std::make_unique<char[]>``. NOTE(review): presumably spliced into the
    # generated code by ``UsidCodeGenerator`` -- the consumer is not visible here.
    cache_allocator_ = "gridtools::sid::make_cached_allocator(&std::make_unique<char[]>);"

    # Call site of one kernel: inside a fresh ``{ ... }`` scope it emits the
    # connectivity and SID declarations (``connectivities`` / ``sids``) and then
    # calls the kernel function ``name`` with ``args``.
    KernelCall = as_mako("""
        {
            ${ ''.join(connectivities) }

            ${ ''.join(sids) }

            ${ name }(${','.join(args)});
        }
        """)

    # Kernel definition: a function template with one ``<p>_t`` type parameter
    # per entry of ``parameters``. It loops ``idx`` from 0 up to
    # ``gridtools::next::connectivity::size(<primary connectivity>)``. If the
    # primary SID has entries, a pointer is taken from its origin and shifted by
    # ``idx`` along the stride keyed by the C++ type that ``LOCATION_TYPE_TO_STR``
    # maps the last element of the primary SID's location to; the statements in
    # ``ast`` form the loop body.
    Kernel = as_mako("""<%
            prim_conn = symbol_tbl_conn[_this_node.primary_connectivity]
            prim_sid = symbol_tbl_sids[_this_node.primary_sid]
        %>
        template<${ ','.join("class {}_t".format(p) for p in parameters)}>
        void ${ name }( ${','.join("{0}_t {0}".format(p) for p in parameters) }) {
            for(std::size_t idx = 0; idx < gridtools::next::connectivity::size(${ prim_conn.name }); idx++) {
                % if len(prim_sid.entries) > 0:
                auto ${ prim_sid.ptr_name } = ${ prim_sid.origin_name }();
                gridtools::sid::shift(${ prim_sid.ptr_name }, gridtools::host_device::at_key<
                    ${ _this_generator.LOCATION_TYPE_TO_STR[prim_sid.location.elements[-1]] }
                    >(${ prim_sid.strides_name }), idx);
                % endif
                ${ "".join(ast) }
            }
        }
        """)
예제 #2
0
파일: usid_codegen.py 프로젝트: fthaler/gtc
class UsidGpuCodeGenerator(UsidCodeGenerator):
    """CUDA variant of ``UsidCodeGenerator``.

    Each kernel becomes a ``__global__`` function template in which every
    thread handles a single index of the primary connectivity. The call site
    launches it and then synchronizes with the device.
    """

    # C++ statement building the cached allocator on top of
    # ``gridtools::cuda_util::cuda_malloc<char[]>``.
    cache_allocator_ = (
        "gridtools::sid::make_cached_allocator(&gridtools::cuda_util::cuda_malloc<char[]>);"
    )

    # Base headers plus the gridtools CUDA utility headers; presumably these
    # provide ``cuda_setup``, ``cuda_malloc`` and ``GT_CUDA_CHECK`` used below.
    headers_ = UsidCodeGenerator.headers_ + [
        "<gridtools/next/cuda_util.hpp>",
        "<gridtools/common/cuda_util.hpp>",
    ]

    # Base preface plus a preprocessor guard that stops compilation with
    # ``#error`` when ``__CUDACC__`` is not defined, i.e. when the generated
    # source is not being compiled by a CUDA compiler.
    preface_ = (UsidCodeGenerator.preface_ + """
        #ifndef __CUDACC__
        #error "Tried to compile CUDA code with a regular C++ compiler."
        #endif
    """)

    # Call site of one kernel. Inside a fresh scope it:
    #   1. declares the connectivities and SIDs;
    #   2. derives ``blocks`` / ``threads_per_block`` from the primary
    #      connectivity size via ``cuda_setup``;
    #   3. launches kernel ``name`` with ``args``;
    #   4. waits for the device with ``cudaDeviceSynchronize()``, wrapped in
    #      ``GT_CUDA_CHECK``.
    KernelCall = as_mako("""
        {
            ${ ''.join(connectivities) }

            ${ ''.join(sids) }

            auto [blocks, threads_per_block] = gridtools::next::cuda_util::cuda_setup(gridtools::next::connectivity::size(${ primary_connectivity.name }));
            ${ name }<<<blocks, threads_per_block>>>(${','.join(args)});
            GT_CUDA_CHECK(cudaDeviceSynchronize());
        }
        """)

    # Kernel definition: a ``__global__`` function template with one ``<p>_t``
    # type parameter per entry of ``parameters``.
    #   - Each thread computes ``idx = blockIdx.x * blockDim.x + threadIdx.x``
    #     and returns early unless ``idx`` is below the primary connectivity
    #     size.
    #   - If the primary SID has entries, a pointer taken from its origin is
    #     shifted by ``idx`` along the stride of the primary location type.
    #   - The statements in ``ast`` then run.
    Kernel = as_mako("""<%
            prim_conn = symbol_tbl_conn[_this_node.primary_connectivity]
            prim_sid = symbol_tbl_sids[_this_node.primary_sid]
        %>
        template<${ ','.join("class {}_t".format(p) for p in parameters)}>
        __global__ void ${ name }( ${','.join("{0}_t {0}".format(p) for p in parameters) }) {
            auto idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (idx >= gridtools::next::connectivity::size(${ prim_conn.name }))
                return;
            % if len(prim_sid.entries) > 0:
            auto ${ prim_sid.ptr_name } = ${ prim_sid.origin_name }();
            gridtools::sid::shift(${ prim_sid.ptr_name }, gridtools::host_device::at_key<
                ${ _this_generator.LOCATION_TYPE_TO_STR[prim_sid.location.elements[-1]] }
                >(${ prim_sid.strides_name }), idx);
            % endif
            ${ "".join(ast) }
        }
        """)
예제 #3
0
def bindings_main_template():
    """Return the Mako template for the pybind11 module wrapping a computation.

    Render variables:
      module_name  -- name passed to ``PYBIND11_MODULE``
      entry_params -- extra parameter declarations of ``run_computation``,
                      placed between ``domain`` and ``exec_info``
      name         -- C++ computation called as ``name(domain)(sid_params...)``
      sid_params   -- argument expressions forwarded to that call

    When ``exec_info`` is not ``None``, it is cast to a ``py::dict``. The keys
    ``run_cpp_start_time`` and ``run_cpp_end_time`` then receive seconds
    computed from ``std::chrono::high_resolution_clock::now()``, taken before
    and after the call. The clock's epoch is implementation-defined.
    """
    return as_mako("""
        #include <chrono>
        #include <pybind11/pybind11.h>
        #include <pybind11/stl.h>
        #include <gridtools/storage/adapter/python_sid_adapter.hpp>
        #include <gridtools/stencil/cartesian.hpp>
        #include <gridtools/stencil/global_parameter.hpp>
        #include <gridtools/sid/sid_shift_origin.hpp>
        #include <gridtools/sid/rename_dimensions.hpp>
        #include "computation.hpp"
        namespace gt = gridtools;
        namespace py = ::pybind11;
        PYBIND11_MODULE(${module_name}, m) {
            m.def("run_computation", [](
            ${','.join(["std::array<gt::uint_t, 3> domain", *entry_params, 'py::object exec_info'])}
            ){
                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_start_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count())/1e9;
                }

                ${name}(domain)(${','.join(sid_params)});

                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_end_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count()/1e9);
                }

            }, "Runs the given computation");}
        """)
예제 #4
0
class DaCeComputationCodegen:
    """Wrap the C++ that DaCe generates for an SDFG in a GridTools-style entry point.

    ``apply`` returns formatted C++ source consisting of the DaCe frame code
    followed by ``auto <sdfg name>(domain)``. That function returns a lambda
    which takes the non-transient arrays as SIDs plus the user-visible symbols
    and forwards them to ``__program_<sdfg name>``.
    """

    # Entry-point template. Render arguments:
    #   name         -- SDFG name (also used for ``<name>_t`` and ``__program_<name>``)
    #   functor_args -- parameter declarations of the returned lambda
    #   tmp_allocs   -- statements allocating persistent transients on ``dace_handle``
    #   dace_args    -- call arguments of ``__program_<name>`` after ``&dace_handle``
    template = as_mako("""
        auto ${name}(const std::array<gt::uint_t, 3>& domain) {
            return [domain](${",".join(functor_args)}) {
                const int __I = domain[0];
                const int __J = domain[1];
                const int __K = domain[2];
                ${name}_t dace_handle;
                auto allocator = gt::sid::make_cached_allocator(&std::make_unique<char[]>);
                ${"\\n".join(tmp_allocs)}
                __program_${name}(${",".join(["&dace_handle", *dace_args])});
            };
        }
        """)

    def generate_tmp_allocs(self, sdfg):
        """Return one C++ allocation statement per persistent transient.

        Only the arrays of ``sdfg.arrays`` (the top-level SDFG) are inspected.
        An array qualifies when it is transient and its lifetime is
        ``dace.AllocationLifetime.Persistent``; its allocation is assigned to
        ``dace_handle.__<sdfg_id>_<name>``.
        """
        alloc_fmt = "dace_handle.{name} = allocate(allocator, gt::meta::lazy::id<{dtype}>(), {size})();"
        statements = []
        for array_name, array in sdfg.arrays.items():
            is_persistent_transient = (
                array.transient
                and array.lifetime == dace.AllocationLifetime.Persistent)
            if not is_persistent_transient:
                continue
            statements.append(
                alloc_fmt.format(
                    name=f"__{sdfg.sdfg_id}_{array_name}",
                    dtype=array.dtype.ctype,
                    size=array.total_size,
                ))
        return statements

    @classmethod
    def apply(cls, gtir, sdfg: dace.SDFG):
        """Generate formatted C++ source for ``sdfg`` plus its entry point.

        The DaCe ``"Frame"`` code object is used with its third line (index 2)
        removed. That line is assumed to include a file that is not generated.
        """
        generator = cls()
        code_objects = sdfg.generate_code()
        frame_position = [co.title for co in code_objects].index("Frame")
        frame_lines = code_objects[frame_position].clean_code.split("\n")
        # Skip line index 2: an include of a file that is not generated.
        kept_lines = frame_lines[:2] + frame_lines[3:]
        computations = codegen.format_source("cpp",
                                             "\n".join(kept_lines),
                                             style="LLVM")
        interface = cls.template.definition.render(
            name=sdfg.name,
            dace_args=generator.generate_dace_args(gtir, sdfg),
            functor_args=generator.generate_functor_args(sdfg),
            tmp_allocs=generator.generate_tmp_allocs(sdfg),
        )
        generated_code = f"""#include <gridtools/sid/sid_shift_origin.hpp>
                             #include <gridtools/sid/allocator.hpp>
                             #include <gridtools/stencil/cartesian.hpp>
                             namespace gt = gridtools;
                             {computations}
                             {interface}
                             """
        return codegen.format_source("cpp", generated_code, style="LLVM")

    def __init__(self):
        # Counter initialised here; no method of this class reads it.
        self._unique_index = 0

    def generate_dace_args(self, gtir, sdfg):
        """Return the C++ call arguments of ``__program_<name>``.

        The result follows the order of
        ``sdfg.signature_arglist(with_types=False, for_call=True)``:

        * ``__<name>_{K,J,I}_stride`` of a transient array maps to ``"1"``,
          ``shape[2]`` and ``shape[1] * shape[2]``.
        * ``__<name>_<axis>_stride`` / ``__<name>_d<n>_stride`` of a
          non-transient array maps to the stride read from ``__<name>_sid``.
        * The name of a non-transient array maps to its SID origin shifted by
          ``-offset`` along each I/J/K axis whose symbol appears in the
          array's shape. The offset is the negated lower extent bound, except
          that K uses the field's lower K boundary when one is known.
        * Every other entry is passed through unchanged.
        """
        oir = gtir_to_oir.GTIRToOIR().visit(gtir)
        field_extents = compute_fields_extents(oir, add_k=True)

        # Per-field (I, J, K) offsets: the negated lower bound of each extent.
        offset_dict: Dict[str, Tuple[int, int, int]] = {}
        for field_name, extent in field_extents.items():
            offset_dict[field_name] = (-extent[0][0], -extent[1][0],
                                       -extent[2][0])
        # Fields with a K boundary use its first element as the K offset.
        for field_name, boundary in compute_k_boundary(gtir).items():
            k_origin = boundary[0]
            i_offset, j_offset, _ = offset_dict[field_name]
            offset_dict[field_name] = (i_offset, j_offset, k_origin)

        symbols = {f"__{axis}": f"__{axis}" for axis in "IJK"}
        stride_fmt = "gt::sid::get_stride<{dim}>(gt::sid::get_strides(__{name}_sid))"
        for name, array in sdfg.arrays.items():
            if array.transient:
                # Transient layout: K contiguous, then J, then I.
                symbols[f"__{name}_K_stride"] = "1"
                symbols[f"__{name}_J_stride"] = str(array.shape[2])
                symbols[f"__{name}_I_stride"] = str(array.shape[1] *
                                                    array.shape[2])
                continue

            axes = [
                axis for axis, present in zip("IJK", array_dimensions(array))
                if present
            ]
            data_ndim = len(array.shape) - len(axes)

            # API field strides, read from the SID passed to the functor.
            for axis in axes:
                symbols[f"__{name}_{axis}_stride"] = stride_fmt.format(
                    dim=f"gt::stencil::dim::{axis.lower()}", name=name)
            for data_dim in range(data_ndim):
                symbols[f"__{name}_d{data_dim}_stride"] = stride_fmt.format(
                    dim=f"gt::integral_constant<int, {3 + data_dim}>",
                    name=name)

            # API field pointer: origin shifted along the IJK axes present in
            # the shape (offset_dict is only consulted for those axes).
            origin_terms = []
            for axis_index, axis in enumerate("IJK"):
                axis_symbol = dace.symbolic.pystr_to_symbolic(f"__{axis}")
                if any(axis_symbol in size_expr.free_symbols
                       for size_expr in array.shape
                       if hasattr(size_expr, "free_symbols")):
                    origin_terms.append(str(-offset_dict[name][axis_index]))
            pointer_fmt = """gt::sid::multi_shifted(
                             gt::sid::get_origin(__{name}_sid)(),
                             gt::sid::get_strides(__{name}_sid),
                             std::array<gt::int_t, {ndim}>{{{origin}}}
                         )"""
            symbols[name] = pointer_fmt.format(name=name,
                                               ndim=len(array.shape),
                                               origin=",".join(origin_terms))

        # Any signature entry without a mapping is passed through by name.
        call_signature = sdfg.signature_arglist(with_types=False,
                                                for_call=True)
        return [symbols.get(arg, arg) for arg in call_signature]

    def generate_functor_args(self, sdfg: dace.SDFG):
        """Return the parameter declarations of the generated lambda.

        The declarations are an ``auto && __<name>_sid`` for every
        non-transient array, followed by ``dtype.as_arg(name)`` for every SDFG
        symbol whose name does not start with ``__``.
        """
        sid_params = [
            f"auto && __{name}_sid"
            for name, array in sdfg.arrays.items()
            if not array.transient
        ]
        symbol_params = [
            dtype.as_arg(name)
            for name, dtype in sdfg.symbols.items()
            if not name.startswith("__")
        ]
        return sid_params + symbol_params
예제 #5
0
class GTCppBindingsCodegen(codegen.TemplatedGenerator):
    """Generate the pybind11 module source that exposes a GTCpp computation.

    ``apply`` visits a ``gtcpp.Program`` and renders ``Program``. Visitor
    keyword arguments must include ``module_name`` and ``gt_backend_t``.
    """

    def __init__(self):
        # Backing counter of unique_index(); the first index handed out is 1.
        self._unique_index: int = 0

    def unique_index(self) -> int:
        """Advance the per-generator counter and return its new value."""
        self._unique_index = self._unique_index + 1
        return self._unique_index

    def visit_DataType(self, dtype: DataType, **kwargs):
        """Render ``dtype`` through a fresh ``GTCppCodegen`` instance."""
        type_codegen = gtcpp_codegen.GTCppCodegen()
        return type_codegen.visit_DataType(dtype)

    def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs):
        """Render a field parameter; returns ``None`` unless ``external_arg`` is given.

        * ``external_arg`` truthy: returns the pybind11 parameter pair
          ``py::buffer <name>, std::array<gt::uint_t,N> <name>_origin``, where
          N is the number of domain dimensions plus data dimensions.
        * ``external_arg`` falsy: returns an expression that turns the buffer
          into a SID and shifts it by ``<name>_origin``. The SID is built with
          ``gt::as_cuda_sid`` when ``gt_backend_t == "gpu"`` and ``gt::as_sid``
          otherwise, tagged with a fresh ``unique_index()``. Its dimensions
          are renamed when the field does not span all three domain
          dimensions.
        """
        assert "gt_backend_t" in kwargs
        if "external_arg" not in kwargs:
            return None

        domain_ndim = node.dimensions.count(True)
        data_ndim = len(node.data_dims)
        sid_ndim = domain_ndim + data_ndim

        if kwargs["external_arg"]:
            return "py::buffer {name}, std::array<gt::uint_t,{sid_ndim}> {name}_origin".format(
                name=node.name,
                sid_ndim=sid_ndim,
            )

        # Keep the original evaluation order: dtype is rendered before the
        # unique index is drawn.
        on_gpu = kwargs["gt_backend_t"] == "gpu"
        rendered_dtype = self.visit(node.dtype)
        tag_index = self.unique_index()
        sid_def = """gt::as_{sid_type}<{dtype}, {sid_ndim},
                    gt::integral_constant<int, {unique_index}>>({name})""".format(
            sid_type="cuda_sid" if on_gpu else "sid",
            name=node.name,
            dtype=rendered_dtype,
            unique_index=tag_index,
            sid_ndim=sid_ndim,
        )

        if domain_ndim != 3:
            # Map numbered SID dimensions onto named axes, then data dims 3, 4, ...
            axis_names = gtc_utils.dimension_flags_to_names(node.dimensions)
            gt_dims = [f"gt::stencil::dim::{dim}" for dim in axis_names]
            gt_dims.extend(f"gt::integral_constant<int, {3 + dim}>"
                           for dim in range(data_ndim))
            sid_def = "gt::sid::rename_numbered_dimensions<{gt_dims}>({sid_def})".format(
                gt_dims=", ".join(gt_dims), sid_def=sid_def)

        return "gt::sid::shift_sid_origin({sid_def}, {name}_origin)".format(
            sid_def=sid_def,
            name=node.name,
        )

    def visit_GlobalParamDecl(self, node: gtcpp.GlobalParamDecl, **kwargs):
        """Render a scalar parameter; returns ``None`` unless ``external_arg`` is given.

        With a truthy ``external_arg`` this is the typed parameter declaration.
        Otherwise it is a ``make_global_parameter`` call wrapping the value.
        """
        if "external_arg" not in kwargs:
            return None
        if not kwargs["external_arg"]:
            return "gridtools::stencil::make_global_parameter({name})".format(
                name=node.name)
        return "{dtype} {name}".format(name=node.name,
                                       dtype=self.visit(node.dtype))

    def visit_Program(self, node: gtcpp.Program, **kwargs):
        """Render the module.

        The parameters are visited twice: first as pybind11 entry parameters,
        then as SID arguments.
        """
        assert "module_name" in kwargs
        parameters = node.parameters
        entry_params = self.visit(parameters, external_arg=True, **kwargs)
        sid_params = self.visit(parameters, external_arg=False, **kwargs)
        return self.generic_visit(
            node,
            entry_params=entry_params,
            sid_params=sid_params,
            **kwargs,
        )

    # Module template. ``run_computation`` calls ``name(domain)(sid_params...)``.
    # When ``exec_info`` is not None, it records ``run_cpp_start_time`` and
    # ``run_cpp_end_time`` (seconds from high_resolution_clock) in that dict.
    Program = as_mako("""
        #include <chrono>
        #include <pybind11/pybind11.h>
        #include <pybind11/stl.h>
        #include <gridtools/storage/adapter/python_sid_adapter.hpp>
        #include <gridtools/stencil/global_parameter.hpp>
        #include <gridtools/sid/sid_shift_origin.hpp>
        #include <gridtools/sid/rename_dimensions.hpp>
        #include "computation.hpp"
        namespace gt = gridtools;
        namespace py = ::pybind11;
        PYBIND11_MODULE(${module_name}, m) {
            m.def("run_computation", [](
            ${','.join(["std::array<gt::uint_t, 3> domain", *entry_params, 'py::object exec_info'])}
            ){
                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_start_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count())/1e9;
                }

                ${name}(domain)(${','.join(sid_params)});

                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_end_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count()/1e9);
                }

            }, "Runs the given computation");}
        """)

    @classmethod
    def apply(cls, root, *, module_name="stencil", **kwargs) -> str:
        """Render ``root`` into LLVM-formatted C++ source for module ``module_name``."""
        generator = cls()
        source = generator.visit(root, module_name=module_name, **kwargs)
        return codegen.format_source("cpp", source, style="LLVM")
예제 #6
0
class DaCeComputationCodegen:
    """Wrap DaCe-generated C++ (CPU or CUDA) in a GridTools-style entry point.

    ``apply`` returns formatted C++ source consisting of the post-processed
    DaCe frame code followed by ``auto <sdfg name>(domain)``. That function
    returns a lambda which takes the non-transient arrays as SIDs plus the
    user-visible symbols and forwards them to ``__program_<sdfg name>``.
    """

    # Entry-point template. Render arguments:
    #   name         -- SDFG name (also used for ``<name>_t`` and ``__program_<name>``)
    #   functor_args -- parameter declarations of the returned lambda
    #   allocator    -- allocation function handed to make_cached_allocator
    #   tmp_allocs   -- statements allocating persistent transients on ``dace_handle``
    #   dace_args    -- call arguments of ``__program_<name>`` after ``&dace_handle``
    template = as_mako("""
        auto ${name}(const std::array<gt::uint_t, 3>& domain) {
            return [domain](${",".join(functor_args)}) {
                const int __I = domain[0];
                const int __J = domain[1];
                const int __K = domain[2];
                ${name}_t dace_handle;
                auto allocator = gt::sid::make_cached_allocator(&${allocator}<char[]>);
                ${"\\n".join(tmp_allocs)}
                __program_${name}(${",".join(["&dace_handle", *dace_args])});
            };
        }
        """)

    def generate_tmp_allocs(self, sdfg):
        """Return one C++ allocation statement per persistent transient.

        The statements cover arrays of ``sdfg`` and its nested SDFGs
        (``arrays_recursive``). An array qualifies when it is transient and its
        lifetime is ``dace.AllocationLifetime.Persistent``; its allocation is
        assigned to ``dace_handle.__<owning sdfg_id>_<name>``.
        """
        alloc_fmt = "dace_handle.{name} = allocate(allocator, gt::meta::lazy::id<{dtype}>(), {size})();"
        statements = []
        for owner_sdfg, array_name, array in sdfg.arrays_recursive():
            if not (array.transient and
                    array.lifetime == dace.AllocationLifetime.Persistent):
                continue
            statements.append(
                alloc_fmt.format(
                    name=f"__{owner_sdfg.sdfg_id}_{array_name}",
                    dtype=array.dtype.ctype,
                    size=array.total_size,
                ))
        return statements

    @staticmethod
    def _postprocess_dace_code(code_objects, is_gpu):
        """Clean up DaCe's ``"Frame"`` code so it can be embedded, then format it.

        For GPU builds:
          * the first line matching ``struct <ident>_t {`` is removed, together
            with every following line up to and including the next line that
            contains ``};``;
          * the first line containing ``#include <dace/dace.h>`` is replaced by
            the clean code of the first code object titled ``"CUDA"``.

        In all builds, lines are then dropped when their stripped text is the
        ``hash.h`` or ``cuda_runtime.h`` include, or when it starts with
        ``DACE_EXPORTED`` and ends with ``);``.
        """
        frame_position = [co.title for co in code_objects].index("Frame")
        lines = code_objects[frame_position].clean_code.split("\n")

        if is_gpu:
            struct_open = re.compile("struct [a-zA-Z_][a-zA-Z0-9_]*_t {")
            first = next(
                (idx for idx, text in enumerate(lines)
                 if struct_open.match(text.strip())),
                None,
            )
            if first is not None:
                # Scan forward to the line closing the struct definition.
                last = first + 1
                while "};" not in lines[last].strip():
                    last += 1
                lines = lines[:first] + lines[last + 1:]

            dace_include = next(
                (idx for idx, text in enumerate(lines)
                 if "#include <dace/dace.h>" in text),
                None,
            )
            if dace_include is not None:
                cuda_code = [
                    co.clean_code for co in code_objects
                    if co.title == "CUDA"
                ][0]
                lines = (lines[:dace_include] + cuda_code.split("\n") +
                         lines[dace_include + 1:])

        dropped_lines = ('#include "../../include/hash.h"',
                         "#include <cuda_runtime.h>")

        def keep_line(line):
            stripped = line.strip()
            if stripped in dropped_lines:
                return False
            return not (stripped.startswith("DACE_EXPORTED")
                        and stripped.endswith(");"))

        kept_source = "\n".join(line for line in lines if keep_line(line))
        return codegen.format_source("cpp", kept_source, style="LLVM")

    @classmethod
    def apply(cls, stencil_ir: gtir.Stencil, sdfg: dace.SDFG):
        """Generate formatted C++ source for ``sdfg`` plus its entry point.

        Code generation runs inside a temporary DaCe config with
        ``compiler.cuda.max_concurrent_streams = -1`` and
        ``compiler.cpu.openmp_sections = False``. The output is treated as GPU
        code when a code object titled ``"CUDA"`` is produced.
        """
        generator = cls()
        with dace.config.temporary_config():
            config = dace.config.Config
            config.set("compiler",
                       "cuda",
                       "max_concurrent_streams",
                       value=-1)
            config.set("compiler", "cpu", "openmp_sections", value=False)
            code_objects = sdfg.generate_code()
        is_gpu = any(co.title == "CUDA" for co in code_objects)

        computations = cls._postprocess_dace_code(code_objects, is_gpu)

        allocator = "gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique"
        interface = cls.template.definition.render(
            name=sdfg.name,
            dace_args=generator.generate_dace_args(stencil_ir, sdfg),
            functor_args=generator.generate_functor_args(sdfg),
            tmp_allocs=generator.generate_tmp_allocs(sdfg),
            allocator=allocator,
        )
        generated_code = f"""#include <gridtools/sid/sid_shift_origin.hpp>
                             #include <gridtools/sid/allocator.hpp>
                             #include <gridtools/stencil/cartesian.hpp>
                             {"#include <gridtools/common/cuda_util.hpp>" if is_gpu else ""}
                             namespace gt = gridtools;
                             {computations}

                             {interface}
                             """
        return codegen.format_source("cpp", generated_code, style="LLVM")

    def __init__(self):
        # Counter initialised here; no method of this class reads it.
        self._unique_index = 0

    def generate_dace_args(self, ir, sdfg):
        """Return the C++ call arguments of ``__program_<name>``.

        The result follows the order of
        ``sdfg.signature_arglist(with_types=False, for_call=True)``:

        * ``__<name>_<axis>_stride`` / ``__<name>_d<n>_stride`` of a
          non-transient array maps to the stride read from ``__<name>_sid``.
        * The name of a non-transient array maps to its SID origin shifted by
          ``-offset`` along each I/J/K axis whose symbol appears in the
          array's shape. The I/J offsets are the negated lower extent bounds
          clamped at 0. The K offset is the field's lower K boundary when one
          is known, and the negated lower K extent otherwise.
        * Every other entry is passed through unchanged.
        """
        oir = GTIRToOIR().visit(ir)
        field_extents = compute_fields_extents(oir, add_k=True)

        # Per-field (I, J, K) offsets; I and J are never negative.
        offset_dict: Dict[str, Tuple[int, int, int]] = {}
        for field_name, extent in field_extents.items():
            offset_dict[field_name] = (max(-extent[0][0], 0),
                                       max(-extent[1][0], 0), -extent[2][0])
        # Fields with a K boundary use its first element as the K offset.
        for field_name, boundary in compute_k_boundary(ir).items():
            k_origin = boundary[0]
            i_offset, j_offset, _ = offset_dict[field_name]
            offset_dict[field_name] = (i_offset, j_offset, k_origin)

        symbols = {f"__{axis}": f"__{axis}" for axis in "IJK"}
        stride_fmt = "gt::sid::get_stride<{dim}>(gt::sid::get_strides(__{name}_sid))"
        for name, array in sdfg.arrays.items():
            if array.transient:
                continue

            axes = [
                axis for axis, present in zip("IJK", array_dimensions(array))
                if present
            ]
            data_ndim = len(array.shape) - len(axes)

            # API field strides, read from the SID passed to the functor.
            for axis in axes:
                symbols[f"__{name}_{axis}_stride"] = stride_fmt.format(
                    dim=f"gt::stencil::dim::{axis.lower()}", name=name)
            for data_dim in range(data_ndim):
                symbols[f"__{name}_d{data_dim}_stride"] = stride_fmt.format(
                    dim=f"gt::integral_constant<int, {3 + data_dim}>",
                    name=name)

            # API field pointer: origin shifted along the IJK axes present in
            # the shape (offset_dict is only consulted for those axes).
            origin_terms = []
            for axis_index, axis in enumerate("IJK"):
                axis_symbol = dace.symbolic.pystr_to_symbolic(f"__{axis}")
                if any(axis_symbol in size_expr.free_symbols
                       for size_expr in array.shape
                       if hasattr(size_expr, "free_symbols")):
                    origin_terms.append(str(-offset_dict[name][axis_index]))
            pointer_fmt = """gt::sid::multi_shifted(
                             gt::sid::get_origin(__{name}_sid)(),
                             gt::sid::get_strides(__{name}_sid),
                             std::array<gt::int_t, {ndim}>{{{origin}}}
                         )"""
            symbols[name] = pointer_fmt.format(name=name,
                                               ndim=len(array.shape),
                                               origin=",".join(origin_terms))

        # Any signature entry without a mapping is passed through by name.
        call_signature = sdfg.signature_arglist(with_types=False,
                                                for_call=True)
        return [symbols.get(arg, arg) for arg in call_signature]

    def generate_functor_args(self, sdfg: dace.SDFG):
        """Return the parameter declarations of the generated lambda.

        The declarations are an ``auto && __<name>_sid`` for every
        non-transient array, followed by ``dtype.as_arg(name)`` for every SDFG
        symbol whose name does not start with ``__``.
        """
        sid_params = [
            f"auto && __{name}_sid"
            for name, array in sdfg.arrays.items()
            if not array.transient
        ]
        symbol_params = [
            dtype.as_arg(name)
            for name, dtype in sdfg.symbols.items()
            if not name.startswith("__")
        ]
        return sid_params + symbol_params
예제 #7
0
class CUIRCodegen(codegen.TemplatedGenerator):

    contexts = (traits.SymbolTableTrait.symtable_merger, )

    LocalScalar = as_fmt("{dtype} {name};")

    FieldDecl = as_fmt("{name}")

    ScalarDecl = as_fmt("{name}")

    Temporary = as_fmt("{name}")

    AssignStmt = as_fmt("{left} = {right};")

    MaskStmt = as_mako("""
        if (${mask}) {
            ${'\\n'.join(body)}
        }
        """)

    While = as_mako("""
        while (${cond}) {
            ${'\\n'.join(body)}
        }
        """)

    def visit_FieldAccess(self, node: cuir.FieldAccess, **kwargs: Any):
        def maybe_const(s):
            try:
                return f"{int(s)}_c"
            except ValueError:
                return s

        kwargs["this_data_index"] = "".join(
            ", " + maybe_const(self.visit(index, **kwargs))
            for index in node.data_index)
        return self.generic_visit(node, **kwargs)

    FieldAccess = as_mako("${name}(${offset}${this_data_index})")

    def visit_IJCacheAccess(self, node: cuir.IJCacheAccess,
                            symtable: Dict[str, Any], **kwargs: Any) -> str:
        extent = symtable[node.name].extent
        if extent.i == extent.j == (0, 0):
            # cache is scalar
            assert node.offset.i == node.offset.j == 0
            return node.name
        if node.offset.i == node.offset.j == 0:
            return "*" + node.name
        offsets = (f"{o} * {d}_stride_{node.name}"
                   for o, d in zip([node.offset.i, node.offset.j], "ij")
                   if o != 0)
        return node.name + "[" + " + ".join(offsets) + "]"

    KCacheAccess = as_mako(
        "${_this_generator.k_cache_var(name, _this_node.offset.k)}")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}_c, {j}_c, {k}_c")

    VariableKOffset = as_fmt("0_c, 0_c, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    UNARY_OPERATOR_TO_CODE = {
        UnaryOperator.NOT: "!",
        UnaryOperator.NEG: "-",
        UnaryOperator.POS: "+",
    }

    UnaryOp = as_fmt(
        "({_this_generator.UNARY_OPERATOR_TO_CODE[_this_node.op]}{expr})")

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    BUILTIN_LITERAL_TO_CODE = {
        BuiltInLiteral.TRUE: "true",
        BuiltInLiteral.FALSE: "false",
    }

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral,
                             **kwargs: Any) -> str:
        try:
            return self.BUILTIN_LITERAL_TO_CODE[builtin]
        except KeyError as error:
            raise NotImplementedError(
                "Not implemented BuiltInLiteral encountered.") from error

    Literal = as_mako("static_cast<${dtype}>(${value})")

    NATIVE_FUNCTION_TO_CODE = {
        NativeFunction.ABS: "std::abs",
        NativeFunction.MIN: "std::min",
        NativeFunction.MAX: "std::max",
        NativeFunction.MOD: "std::fmod",
        NativeFunction.SIN: "std::sin",
        NativeFunction.COS: "std::cos",
        NativeFunction.TAN: "std::tan",
        NativeFunction.ARCSIN: "std::asin",
        NativeFunction.ARCCOS: "std::acos",
        NativeFunction.ARCTAN: "std::atan",
        NativeFunction.SINH: "std::sinh",
        NativeFunction.COSH: "std::cosh",
        NativeFunction.TANH: "std::tanh",
        NativeFunction.ARCSINH: "std::asinh",
        NativeFunction.ARCCOSH: "std::acosh",
        NativeFunction.ARCTANH: "std::atanh",
        NativeFunction.SQRT: "std::sqrt",
        NativeFunction.POW: "std::pow",
        NativeFunction.EXP: "std::exp",
        NativeFunction.LOG: "std::log",
        NativeFunction.GAMMA: "std::tgamma",
        NativeFunction.CBRT: "std::cbrt",
        NativeFunction.ISFINITE: "std::isfinite",
        NativeFunction.ISINF: "std::isinf",
        NativeFunction.ISNAN: "std::isnan",
        NativeFunction.FLOOR: "std::floor",
        NativeFunction.CEIL: "std::ceil",
        NativeFunction.TRUNC: "std::trunc",
    }

    def visit_NativeFunction(self, func: NativeFunction, **kwargs: Any) -> str:
        try:
            return self.NATIVE_FUNCTION_TO_CODE[func]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented NativeFunction '{func}' encountered."
            ) from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    DATA_TYPE_TO_CODE = {
        DataType.BOOL: "bool",
        DataType.INT8: "std::int8_t",
        DataType.INT16: "std::int16_t",
        DataType.INT32: "std::int32_t",
        DataType.INT64: "std::int64_t",
        DataType.FLOAT32: "float",
        DataType.FLOAT64: "double",
    }

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        try:
            return self.DATA_TYPE_TO_CODE[dtype]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented DataType '{dtype.name}' encountered."
            ) from error

    IJExtent = as_fmt("extent<{i[0]}, {i[1]}, {j[0]}, {j[1]}>")

    HorizontalExecution = as_mako("""
        // HorizontalExecution ${id(_this_node)}
        if (validator(${extent}())) {
            ${'\\n'.join(declarations)}
            ${'\\n'.join(body)}
        }
        """)

    def visit_AxisBound(self, node: cuir.AxisBound, **kwargs: Any) -> str:
        if node.level == LevelMarker.START:
            return f"{node.offset}"
        if node.level == LevelMarker.END:
            return f"k_size + {node.offset}"
        raise ValueError("Cannot handle dynamic levels")

    IJCacheDecl = as_mako("""
        % if _this_node.extent.i == _this_node.extent.j == (0, 0):
        // scalar ij-cache
        ${dtype} ${name};
        % else:
        // ij-cache in shared memory
        constexpr int ${name}_cache_data_size = (i_block_size_t() + ${-_this_node.extent.i[0] + _this_node.extent.i[1]}) * (j_block_size_t() + ${-_this_node.extent.j[0] + _this_node.extent.j[1]});
        __shared__ ${dtype} ${name}_cache_data[${name}_cache_data_size];
        constexpr int i_stride_${name} = 1;
        constexpr int j_stride_${name} = i_block_size_t() + ${-_this_node.extent.i[0] + _this_node.extent.i[1]};
        ${dtype} *${name} = ${name}_cache_data + (${-_this_node.extent.i[0]} + _i_block) * i_stride_${name} + (${-_this_node.extent.j[0]} + _j_block) * j_stride_${name};
        % endif
        """)

    KCacheDecl = as_mako("""
        % for var in _this_generator.k_cache_vars(_this_node):
            ${dtype} ${var};
        % endfor
        """)

    VerticalLoopSection = as_mako("""
        <%def name="sid_shift(step)">
            sid::shift(_ptr, sid::get_stride<dim::k>(m_strides), ${step}_c);
        </%def>
        <%def name="cache_shift(cache_vars)">
            % for dst, src in zip(cache_vars[:-1], cache_vars[1:]):
                ${dst} = ${src};
            % endfor
        </%def>
        // VerticalLoopSection ${id(_this_node)}
        % if order == cuir.LoopOrder.FORWARD:
            for (int _k_block = ${start}; _k_block < ${end}; ++_k_block) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}

                ${sid_shift(1)}
                % for k_cache in k_cache_decls:
                    ${cache_shift(_this_generator.k_cache_vars(k_cache))}
                % endfor
            }
        % elif order == cuir.LoopOrder.BACKWARD:
            for (int _k_block = ${end} - 1; _k_block >= ${start}; --_k_block) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}

                ${sid_shift(-1)}
                % for k_cache in k_cache_decls:
                    ${cache_shift(_this_generator.k_cache_vars(k_cache)[::-1])}
                % endfor
            }
        % else:
            if (_k_block >= ${start} && _k_block < ${end}) {
                ${'\\n__syncthreads();\\n'.join(horizontal_executions)}
            }
        % endif
        """)

    @staticmethod
    def k_cache_var(name: str, offset: int) -> str:
        return name + (f"p{offset}" if offset >= 0 else f"m{-offset}")

    @classmethod
    def k_cache_vars(cls, k_cache: cuir.KCacheDecl) -> List[str]:
        assert k_cache.extent
        return [
            cls.k_cache_var(k_cache.name, offset)
            for offset in range(k_cache.extent.k[0], k_cache.extent.k[1] + 1)
        ]

    def visit_VerticalLoop(self, node: cuir.VerticalLoop, *,
                           symtable: Dict[str, Any],
                           **kwargs: Any) -> Union[str, Collection[str]]:
        """Render a vertical loop functor.

        Besides the usual context, the ``VerticalLoop`` template receives
        ``fields``, which maps each field name accessed inside ``node`` to the
        number of extra data dimensions (``len(data_index)``) of that access.
        If a field is accessed several times, the last access seen wins. It
        also receives ``k_cache_decls`` (the raw k-cache nodes) and ``order``
        (the loop order).
        """
        field_accesses = node.iter_tree().if_isinstance(
            cuir.FieldAccess).getattr("name", "data_index")
        fields = {}
        for field_name, data_index in field_accesses:
            fields[field_name] = len(data_index)

        return self.generic_visit(
            node,
            symtable=symtable,
            fields=fields,
            order=node.loop_order,
            k_cache_decls=node.k_caches,
            **kwargs,
        )

    # Per vertical loop, a functor template ``loop_<id>_f<Sid>`` holding the
    # origin (``m_ptr_holder``), the strides and ``k_size`` for a composite SID.
    # Its device ``operator()``:
    #   - moves the pointer to the current CUDA block (``blockIdx.x``/``.y``
    #     along the blocked i/j dimensions) and then to the in-block position
    #     (``_i_block``, ``_j_block``); for parallel loops it also defines
    #     ``_k_block = blockIdx.z`` and shifts to that k level,
    #   - defines one accessor lambda per field, ``field(i, j, k, dim_3, ...)``,
    #     returning a reference to the element at those offsets from the
    #     current pointer (the extra arguments index the data dimensions),
    #   - emits the rendered ``ij_caches``, ``k_caches`` and ``sections``.
    # ``validator`` is accepted but not referenced in this template.
    VerticalLoop = as_mako("""
        template <class Sid>
        struct loop_${id(_this_node)}_f {
            sid::ptr_holder_type<Sid> m_ptr_holder;
            sid::strides_type<Sid> m_strides;
            int k_size;

            template <class Validator>
            GT_FUNCTION_DEVICE void operator()(const int _i_block,
                                               const int _j_block,
                                               Validator validator) const {
                auto _ptr = m_ptr_holder();
                sid::shift(_ptr,
                           sid::get_stride<sid::blocked_dim<dim::i>>(m_strides),
                           blockIdx.x);
                sid::shift(_ptr,
                           sid::get_stride<sid::blocked_dim<dim::j>>(m_strides),
                           blockIdx.y);
                sid::shift(_ptr,
                           sid::get_stride<dim::i>(m_strides),
                           _i_block);
                sid::shift(_ptr,
                           sid::get_stride<dim::j>(m_strides),
                           _j_block);
                % if order == cuir.LoopOrder.PARALLEL:
                const int _k_block = blockIdx.z;
                sid::shift(_ptr,
                           sid::get_stride<dim::k>(m_strides),
                           _k_block);
                % endif

                % for field, data_dims in fields.items():
                    const auto ${field} = [&](auto i, auto j, auto k
                        % for i in range(data_dims):
                            , auto dim_${i + 3}
                        % endfor
                        ) -> auto&& {
                        return *sid::multi_shifted<tag::${field}>(
                            device::at_key<tag::${field}>(_ptr),
                            m_strides,
                            tuple_util::device::make<hymap::keys<dim::i, dim::j, dim::k
                            % for i in range(data_dims):
                                , integral_constant<int, ${i + 3}>
                            % endfor
                            >::template values>(i, j, k
                            % for i in range(data_dims):
                                , dim_${i + 3}
                            % endfor
                            ));
                    };
                % endfor

                % for ij_cache in ij_caches:
                    ${ij_cache}
                % endfor

                % for k_cache in k_caches:
                    ${k_cache}
                % endfor

                % for section in sections:
                    ${section}
                % endfor
            }
        };
        """)

    # Emits the functors of all vertical loops of a kernel, followed by
    # ``kernel_<id>_f``: a functor that stores one loop functor per vertical
    # loop and, in its device ``operator()``, calls them in declaration order
    # with the same (_i_block, _j_block, validator) arguments.
    Kernel = as_mako("""
        % for vertical_loop in vertical_loops:
            ${vertical_loop}
        % endfor

        template <${', '.join(f'class Loop{id(vl)}' for vl in _this_node.vertical_loops)}>
        struct kernel_${id(_this_node)}_f {
            % for vertical_loop in _this_node.vertical_loops:
                Loop${id(vertical_loop)} m_${id(vertical_loop)};
            % endfor

            template <class Validator>
            GT_FUNCTION_DEVICE void operator()(const int _i_block,
                                               const int _j_block,
                                               Validator validator) const {
                % for vertical_loop in _this_node.vertical_loops:
                    m_${id(vertical_loop)}(_i_block, _j_block, validator);
                % endfor
            }
        };

        """)

    def visit_Program(self, node: cuir.Program,
                      **kwargs: Any) -> Union[str, Collection[str]]:
        """Render a whole ``cuir.Program`` into one CUDA translation unit.

        Provides the ``Program`` template with these helpers:
          - ``loop_start``: C++ expression for the k level at which a vertical
            loop's composite origin is placed,
          - ``loop_fields``: the names of the fields accessed by a vertical loop,
          - ``ctype``: the rendered C++ type of a symbol from ``symtable``,
          - ``max_extent``: the rendered union of all horizontal extents of the
            program (at least the zero extent).
        """
        def loop_start(vertical_loop: cuir.VerticalLoop) -> str:
            # Forward loops start at the first section's start level and
            # backward loops one level below the first section's end. Any
            # other order (parallel) uses offset 0.
            if vertical_loop.loop_order == cuir.LoopOrder.FORWARD:
                return self.visit(vertical_loop.sections[0].start, **kwargs)
            if vertical_loop.loop_order == cuir.LoopOrder.BACKWARD:
                return self.visit(vertical_loop.sections[0].end, **kwargs) + " - 1"
            return "0"

        def loop_fields(vertical_loop: cuir.VerticalLoop) -> List[str]:
            # Return a sorted list instead of a set. The template iterates this
            # twice (composite tags, then composite arguments), and the two
            # iterations must agree. Sorting also keeps the generated code
            # independent of per-process string-hash randomization.
            return sorted(vertical_loop.iter_tree().if_isinstance(
                cuir.FieldAccess).getattr("name").to_set())

        def ctype(symbol: str) -> str:
            return self.visit(kwargs["symtable"][symbol].dtype, **kwargs)

        return self.generic_visit(
            node,
            max_extent=self.visit(
                cuir.IJExtent.zero().union(
                    *node.iter_tree().if_isinstance(cuir.IJExtent)), **kwargs),
            loop_start=loop_start,
            loop_fields=loop_fields,
            ctype=ctype,
            cuir=cuir,
            **kwargs,
        )

    # Top-level CUDA translation unit. Inside namespace ``<name>_impl_`` it
    # defines:
    #   - 64 x 8 i/j block sizes and a ``block()`` helper that blocks a SID
    #     along i and j with those sizes,
    #   - one empty tag struct per field accessed in any vertical loop,
    #   - the rendered kernels,
    #   - ``<name>(domain)``, which returns a generic lambda over the program
    #     parameters.
    # When that lambda is called, it
    #   1. creates a cached device allocator (``cuda_malloc``) and one
    #      ``gpu_backend::make_tmp_storage`` per temporary, built from the
    #      block sizes, ``max_extent``, the i/j block counts and ``k_size``,
    #   2. for each vertical loop, asserts ``0 <= loop_start < k_size``,
    #      builds a SID composite of all accessed fields with the origin
    #      shifted in k to ``loop_start`` (program parameters are additionally
    #      passed through ``block()``, temporaries are not), and wraps it in
    #      the loop functor,
    #   3. launches each kernel over i_size x j_size, with ``k_size`` as the
    #      third launch size when the kernel's first vertical loop is parallel
    #      and 1 otherwise.
    # Finally ``<name>`` is re-exported at global scope.
    Program = as_mako("""#include <algorithm>
        #include <array>
        #include <cstdint>
        #include <gridtools/common/cuda_util.hpp>
        #include <gridtools/common/host_device.hpp>
        #include <gridtools/common/hymap.hpp>
        #include <gridtools/common/integral_constant.hpp>
        #include <gridtools/sid/allocator.hpp>
        #include <gridtools/sid/block.hpp>
        #include <gridtools/sid/composite.hpp>
        #include <gridtools/sid/multi_shift.hpp>
        #include <gridtools/stencil/common/dim.hpp>
        #include <gridtools/stencil/common/extent.hpp>
        #include <gridtools/stencil/gpu/launch_kernel.hpp>
        #include <gridtools/stencil/gpu/tmp_storage_sid.hpp>

        namespace ${name}_impl_{
            using namespace gridtools;
            using namespace literals;
            using namespace stencil;

            using domain_t = std::array<unsigned, 3>;
            using i_block_size_t = integral_constant<int, 64>;
            using j_block_size_t = integral_constant<int, 8>;

            template <class Storage>
            auto block(Storage storage) {
                return sid::block(std::move(storage),
                    tuple_util::make<hymap::keys<dim::i, dim::j>::values>(
                        i_block_size_t(), j_block_size_t()));
            }

            namespace tag {
                % for p in set().union(*(loop_fields(v) for k in _this_node.kernels for v in k.vertical_loops)):
                    struct ${p} {};
                % endfor
            }

            % for kernel in kernels:
                ${kernel}
            % endfor

            auto ${name}(domain_t domain){
                return [domain](${','.join(f'auto&& {p}' for p in params)}){
                    auto tmp_alloc = sid::device::make_cached_allocator(&cuda_util::cuda_malloc<char[]>);
                    const int i_size = domain[0];
                    const int j_size = domain[1];
                    const int k_size = domain[2];
                    const int i_blocks = (i_size + i_block_size_t() - 1) / i_block_size_t();
                    const int j_blocks = (j_size + j_block_size_t() - 1) / j_block_size_t();

                    % for tmp in temporaries:
                        auto ${tmp} = gpu_backend::make_tmp_storage<${ctype(tmp)}>(
                            1_c,
                            i_block_size_t(),
                            j_block_size_t(),
                            ${max_extent}(),
                            i_blocks,
                            j_blocks,
                            k_size,
                            tmp_alloc);
                    % endfor

                    % for kernel in _this_node.kernels:

                        // kernel ${id(kernel)}

                        % for vertical_loop in kernel.vertical_loops:
                            // vertical loop ${id(vertical_loop)}

                            assert((${loop_start(vertical_loop)}) >= 0 &&
                                   (${loop_start(vertical_loop)}) < k_size);
                            auto offset_${id(vertical_loop)} = tuple_util::make<hymap::keys<dim::k>::values>(
                                ${loop_start(vertical_loop)}
                            );

                            auto composite_${id(vertical_loop)} = sid::composite::make<
                                    ${', '.join(f'tag::{field}' for field in loop_fields(vertical_loop))}
                                >(

                            % for field in loop_fields(vertical_loop):
                                % if field in params:
                                    block(sid::shift_sid_origin(
                                        ${field},
                                        offset_${id(vertical_loop)}
                                    ))
                                % else:
                                    sid::shift_sid_origin(
                                        ${field},
                                        offset_${id(vertical_loop)}
                                    )
                                % endif
                                ${'' if loop.last else ','}
                            % endfor
                            );
                            using composite_${id(vertical_loop)}_t = decltype(composite_${id(vertical_loop)});
                            loop_${id(vertical_loop)}_f<composite_${id(vertical_loop)}_t> loop_${id(vertical_loop)}{
                                sid::get_origin(composite_${id(vertical_loop)}),
                                sid::get_strides(composite_${id(vertical_loop)}),
                                k_size
                            };

                        % endfor

                        kernel_${id(kernel)}_f<${', '.join(f'decltype(loop_{id(vl)})' for vl in kernel.vertical_loops)}> kernel_${id(kernel)}{
                            ${', '.join(f'loop_{id(vl)}' for vl in kernel.vertical_loops)}
                        };
                        gpu_backend::launch_kernel<${max_extent},
                            i_block_size_t::value, j_block_size_t::value>(
                            i_size,
                            j_size,
                            % if kernel.vertical_loops[0].loop_order == cuir.LoopOrder.PARALLEL:
                                k_size,
                            % else:
                                1,
                            %endif
                            kernel_${id(kernel)},
                            0);
                    % endfor
                };
            }
        }

        using ${name}_impl_::${name};
        """)

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Generate CUDA source for a ``cuir.Program``.

        Args:
            root: The program to render; must be a ``cuir.Program``.
            **kwargs: Forwarded to the base ``apply``. If ``format_source``
                is truthy (the default), the output is formatted with
                ``codegen.format_source`` in LLVM style.

        Returns:
            The generated source code.

        Raises:
            ValueError: If ``root`` is not a ``cuir.Program``.
        """
        if not isinstance(root, cuir.Program):
            # The message previously named "gtcpp.Progam", which is neither
            # the type checked here nor spelled correctly.
            raise ValueError("apply() requires cuir.Program root node")
        generated_code = super().apply(root, **kwargs)
        if kwargs.get("format_source", True):
            generated_code = codegen.format_source("cpp",
                                                   generated_code,
                                                   style="LLVM")

        return generated_code
예제 #8
0
파일: usid_codegen.py 프로젝트: fthaler/gtc
class UsidCodeGenerator(codegen.TemplatedGenerator):
    """Shared templates for generating unstructured-mesh (USID) C++ code.

    Backend subclasses must provide ``cache_allocator_`` (used by the
    ``Computation`` template when temporaries exist) and the ``KernelCall`` /
    ``Kernel`` templates. They may also extend ``headers_`` and ``preface_``.
    """

    # IR scalar type -> C++ type name. These strings are emitted verbatim in
    # casts (Literal), declarations (VarDecl) and template arguments
    # (Temporary), so each must be a valid C++ type ("unsigned_int" was not).
    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType({
        common.DataType.BOOLEAN: "bool",
        common.DataType.INT32: "int",
        common.DataType.UINT32: "unsigned int",
        common.DataType.FLOAT32: "float",
        common.DataType.FLOAT64: "double",
    })

    # Mesh location type -> C++ location tag name.
    LOCATION_TYPE_TO_STR: ClassVar[Mapping[common.LocationType, str]] = MappingProxyType({
        common.LocationType.Vertex: "vertex",
        common.LocationType.Edge: "edge",
        common.LocationType.Cell: "cell",
    })

    # Built-in literal -> C++ expression.
    # NOTE(review): MAX_VALUE/MIN_VALUE still contain a literal ``TODO`` type
    # argument; the emitted code presumably fails to compile unless a ``TODO``
    # type exists in the generated translation unit.
    BUILTIN_LITERAL_TO_STR: ClassVar[Mapping[common.BuiltInLiteral, str]] = MappingProxyType({
        common.BuiltInLiteral.MAX_VALUE: "std::numeric_limits<TODO>::max()",
        common.BuiltInLiteral.MIN_VALUE: "std::numeric_limits<TODO>::min()",
        common.BuiltInLiteral.ZERO: "0",
        common.BuiltInLiteral.ONE: "1",
    })

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Run ``SymbolTblHelper`` over ``root``, render it and LLVM-format the C++."""
        symbol_tbl_resolved = SymbolTblHelper().visit(root)
        generated_code = super().apply(symbol_tbl_resolved, **kwargs)
        return codegen.format_source("cpp", generated_code, style="LLVM")

    def location_type_from_dimensions(self, dimensions):
        """Return the single ``common.LocationType`` contained in ``dimensions``.

        Raises:
            ValueError: If ``dimensions`` holds no ``LocationType`` or more
                than one.
        """
        location_types = [
            dim for dim in dimensions if isinstance(dim, common.LocationType)
        ]
        if len(location_types) != 1:
            raise ValueError(
                "Expected exactly one LocationType in dimensions, "
                f"found {len(location_types)}!")
        return location_types[0]

    # Headers included by every generated computation; subclasses may extend.
    headers_ = [
        "<gridtools/next/mesh.hpp>",
        "<gridtools/next/tmp_storage.hpp>",
        "<gridtools/next/unstructured.hpp>",
        "<gridtools/sid/allocator.hpp>",
        "<gridtools/sid/composite.hpp>",
    ]

    # Raw text emitted before the includes; subclasses may extend.
    preface_ = ""

    Connectivity = as_fmt(
        "auto {name} = gridtools::next::mesh::connectivity<{chain}>(mesh);")

    NeighborChain = as_mako("""<%
            loc_strs = [_this_generator.LOCATION_TYPE_TO_STR[e] for e in _this_node.elements]
        %>
        std::tuple<${ ','.join(loc_strs) }>
        """)

    SidCompositeNeighborTableEntry = as_fmt(
        "gridtools::next::connectivity::neighbor_table({_this_node.connectivity_deref_.name})"
    )

    SidCompositeEntry = as_fmt("{name}")

    SidComposite = as_mako("""
        auto ${ _this_node.field_name } = tu::make<gridtools::sid::composite::keys<${ ','.join([t.tag_name for t in _this_node.entries]) }>::values>(
        ${ ','.join(entries)});
        """)

    def visit_KernelCall(self, node: KernelCall, **kwargs):
        """Render a call of the kernel named by ``node``.

        Provides the subclass ``KernelCall`` template with the kernel's
        rendered connectivities, the rendered non-empty SID composites, the
        primary connectivity node, and ``args``: the connectivity names
        followed by ``origin, strides`` for each non-empty SID composite.
        """
        kernel: Kernel = kwargs["symbol_tbl_kernel"][node.name]
        # SID composites without entries are neither materialized nor passed.
        used_sids = [s for s in kernel.sids if len(s.entries) > 0]

        connectivities = [
            self.generic_visit(conn, **kwargs)
            for conn in kernel.connectivities
        ]
        primary_connectivity: Connectivity = kernel.symbol_tbl[
            kernel.primary_connectivity]
        sids = [self.generic_visit(s, **kwargs) for s in used_sids]

        # TODO I don't like that I render here and that I somehow have the same pattern for the parameters
        args = [c.name for c in kernel.connectivities]
        args += [
            "gridtools::sid::get_origin({0}), gridtools::sid::get_strides({0})"
            .format(s.field_name) for s in used_sids
        ]
        return self.generic_visit(
            node,
            connectivities=connectivities,
            sids=sids,
            primary_connectivity=primary_connectivity,
            args=args,
            **kwargs,
        )

    def visit_Kernel(self, node: Kernel, **kwargs):
        """Render a kernel definition.

        ``parameters`` lists the connectivity names followed by the origin and
        strides names of every non-empty SID composite. This matches the
        argument order built in ``visit_KernelCall``. Lookup tables by name
        for connectivities and SIDs are also provided.
        """
        symbol_tbl_conn = {c.name: c for c in node.connectivities}
        symbol_tbl_sids = {s.name: s for s in node.sids}

        parameters = [c.name for c in node.connectivities]
        for s in node.sids:
            if len(s.entries) > 0:
                parameters.append(s.origin_name)
                parameters.append(s.strides_name)

        return self.generic_visit(
            node,
            parameters=parameters,
            symbol_tbl_conn=symbol_tbl_conn,
            symbol_tbl_sids=symbol_tbl_sids,
            **kwargs,
        )

    FieldAccess = as_mako("""<%
            sid_deref = symbol_tbl_sids[_this_node.sid]
            sid_entry_deref = sid_deref.symbol_tbl[_this_node.name]
        %>*gridtools::host_device::at_key<${ sid_entry_deref.tag_name }>(${ sid_deref.ptr_name })"""
                          )

    AssignStmt = as_fmt("{left} = {right};")

    BinaryOp = as_fmt("({left} {op} {right})")

    NeighborLoop = as_mako("""<%
            outer_sid_deref = symbol_tbl_sids[_this_node.outer_sid]
            sid_deref = symbol_tbl_sids[_this_node.sid] if _this_node.sid else None
            conn_deref = symbol_tbl_conn[_this_node.connectivity]
            body_location = _this_generator.LOCATION_TYPE_TO_STR[sid_deref.location.elements[-1]] if sid_deref else None
        %>
        for (int neigh = 0; neigh < gridtools::next::connectivity::max_neighbors(${ conn_deref.name }); ++neigh) {
            auto absolute_neigh_index = *gridtools::host_device::at_key<${ conn_deref.neighbor_tbl_tag }>(${ outer_sid_deref.ptr_name});
            if (absolute_neigh_index != gridtools::next::connectivity::skip_value(${ conn_deref.name })) {
                % if sid_deref:
                    auto ${ sid_deref.ptr_name } = ${ sid_deref.origin_name }();
                    gridtools::sid::shift(
                        ${ sid_deref.ptr_name }, gridtools::host_device::at_key<${ body_location }>(${ sid_deref.strides_name }), absolute_neigh_index);
                % endif

                // bodyparameters
                ${ ''.join(body) }
                // end body
            }
            gridtools::sid::shift(${ outer_sid_deref.ptr_name }, gridtools::host_device::at_key<neighbor>(${ outer_sid_deref.strides_name }), 1);
        }
        gridtools::sid::shift(${ outer_sid_deref.ptr_name }, gridtools::host_device::at_key<neighbor>(${ outer_sid_deref.strides_name }),
            -gridtools::next::connectivity::max_neighbors(${ conn_deref.name }));

        """)

    Literal = as_mako("""<%
            literal= _this_node.value if isinstance(_this_node.value, str) else _this_generator.BUILTIN_LITERAL_TO_STR[_this_node.value]
        %>(${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] })${ literal }"""
                      )

    VarAccess = as_fmt("{name}")

    VarDecl = as_mako(
        "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };"
    )

    def visit_Computation(self, node: Computation, **kwargs):
        """Render the whole computation.

        Collects one forward declaration ``struct <tag>;`` per distinct SID
        entry tag across all kernels. They are sorted so the generated code
        does not depend on set iteration order. A kernel lookup table by name
        is also collected for ``visit_KernelCall``.
        """
        symbol_tbl_kernel = {k.name: k for k in node.kernels}
        sid_tags = sorted({
            "struct " + e.tag_name + ";"
            for k in node.kernels
            for s in k.sids
            for e in s.entries
        })

        return self.generic_visit(
            node,
            computation_fields=node.parameters + node.temporaries,
            sid_tags=sid_tags,
            symbol_tbl_kernel=symbol_tbl_kernel,
            **kwargs,
        )

    Computation = as_mako("""${_this_generator.preface_}
        ${ '\\n'.join('#include ' + header for header in _this_generator.headers_) }

        namespace ${ name }_impl_ {
            ${ ''.join(sid_tags) }

            ${ ''.join(kernels) }
        }

        template<class mesh_t, ${ ','.join('class ' + p.name + '_t' for p in _this_node.parameters) }>
        void ${ name }(mesh_t&& mesh, ${ ','.join(p.name + '_t&& ' + p.name for p in _this_node.parameters) }){
            namespace tu = gridtools::tuple_util;
            using namespace ${ name }_impl_;

            % if len(temporaries) > 0:
                auto tmp_alloc = ${ _this_generator.cache_allocator_ }
            % endif
            ${ ''.join(temporaries) }

            ${ ''.join(ctrlflow_ast) }
        }
        """)

    def visit_Temporary(self, node: Temporary, **kwargs):
        """Render a temporary with its C++ value type and location tag."""
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.LOCATION_TYPE_TO_STR[self.location_type_from_dimensions(
            node.dimensions)]
        return self.generic_visit(node,
                                  loctype=loctype,
                                  c_vtype=c_vtype,
                                  **kwargs)

    Temporary = as_mako("""
        auto ${ name } = gridtools::next::make_simple_tmp_storage<${ loctype }, ${ c_vtype }>(
            (int)gridtools::next::connectivity::size(gridtools::next::mesh::connectivity<std::tuple<${ loctype }>>(mesh)), 1 /* TODO ksize */, tmp_alloc);"""
                        )
예제 #9
0
class GTCppCodegen(codegen.TemplatedGenerator):
    """Generate GridTools C++ (stencil/cartesian frontend) code from gtcpp IR."""

    GTExtent = as_fmt("extent<{i[0]},{i[1]},{j[0]},{j[1]},{k[0]},{k[1]}>")

    GTAccessor = as_fmt("using {name} = {intent}_accessor<{id}, {extent}, {ndim}>;")

    GTParamList = as_mako(
        """${ '\\n'.join(accessors) }

        using param_list = make_param_list<${ ','.join(a.name for a in _this_node.accessors)}>;
        """
    )

    GTFunctor = as_mako(
        """struct ${ name } {
        ${param_list}

        ${ '\\n'.join(applies) }
    };
    """
    )

    GTLevel = as_fmt("gridtools::stencil::core::level<{splitter}, {offset}, {offset_limit}>")

    GTInterval = as_fmt("gridtools::stencil::core::interval<{from_level}, {to_level}>")

    LocalVarDecl = as_fmt("{dtype} {name};")

    GTApplyMethod = as_mako(
        """
    template<typename Evaluation>
    GT_FUNCTION static void apply(Evaluation eval, ${interval}) {
        ${ ' '.join(local_variables) }
        ${ '\\n'.join(body) }
    }
    """
    )

    AssignStmt = as_fmt("{left} = {right};")

    AccessorRef = as_fmt("eval({name}({', '.join([offset, *data_index])}))")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}, {j}, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    # NOTE: ``UnaryOp`` is defined once, below UNARY_OPERATOR_TO_CODE. An
    # earlier duplicate definition here was always shadowed by that one.

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral, **kwargs: Any) -> str:
        """Render TRUE/FALSE as C++ booleans.

        Raises:
            NotImplementedError: For any other built-in literal.
        """
        if builtin == BuiltInLiteral.TRUE:
            return "true"
        elif builtin == BuiltInLiteral.FALSE:
            return "false"
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_mako("static_cast<${dtype}>(${value})")

    def visit_NativeFunction(self, func: NativeFunction, **kwargs: Any) -> str:
        """Map a native function to its C++ standard-library name.

        Raises:
            NotImplementedError: If ``func`` has no mapping.
        """
        try:
            return {
                NativeFunction.ABS: "std::abs",
                NativeFunction.MIN: "std::min",
                NativeFunction.MAX: "std::max",
                NativeFunction.MOD: "std::fmod",
                NativeFunction.SIN: "std::sin",
                NativeFunction.COS: "std::cos",
                NativeFunction.TAN: "std::tan",
                NativeFunction.ARCSIN: "std::asin",
                NativeFunction.ARCCOS: "std::acos",
                NativeFunction.ARCTAN: "std::atan",
                NativeFunction.SQRT: "std::sqrt",
                NativeFunction.POW: "std::pow",
                NativeFunction.EXP: "std::exp",
                NativeFunction.LOG: "std::log",
                NativeFunction.ISFINITE: "std::isfinite",
                NativeFunction.ISINF: "std::isinf",
                NativeFunction.ISNAN: "std::isnan",
                NativeFunction.FLOOR: "std::floor",
                NativeFunction.CEIL: "std::ceil",
                NativeFunction.TRUNC: "std::trunc",
            }[func]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented NativeFunction '{func}' encountered."
            ) from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    # IR data type -> C++ type name.
    DATA_TYPE_TO_CODE = {
        DataType.BOOL: "bool",
        DataType.INT8: "std::int8_t",
        DataType.INT16: "std::int16_t",
        DataType.INT32: "std::int32_t",
        DataType.INT64: "std::int64_t",
        DataType.FLOAT32: "float",
        DataType.FLOAT64: "double",
    }

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        """Render a data type via ``DATA_TYPE_TO_CODE``.

        Raises:
            NotImplementedError: If ``dtype`` has no mapping.
        """
        try:
            return self.DATA_TYPE_TO_CODE[dtype]
        except KeyError as error:
            raise NotImplementedError(
                f"Not implemented DataType '{dtype.name}' encountered."
            ) from error

    # IR unary operator -> C++ operator token.
    UNARY_OPERATOR_TO_CODE = {
        UnaryOperator.NOT: "!",
        UnaryOperator.NEG: "-",
        UnaryOperator.POS: "+",
    }

    UnaryOp = as_fmt("({_this_generator.UNARY_OPERATOR_TO_CODE[_this_node.op]}{expr})")

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    ApiParamDecl = as_fmt("{name}")

    GTStage = as_mako(".stage(${functor}(), ${','.join(args)})")

    GTMultiStage = as_mako("execute_${ loop_order }()${''.join(caches)}${''.join(stages)}")

    IJCache = as_fmt(".ij_cached({name})")
    KCache = as_mako(
        ".k_cached(${'cache_io_policy::fill(), ' if _this_node.fill else ''}${'cache_io_policy::flush(), ' if _this_node.flush else ''}${name})"
    )

    def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str:
        """Render a loop order as the suffix of the ``execute_*`` call."""
        return {
            LoopOrder.PARALLEL: "parallel",
            LoopOrder.FORWARD: "forward",
            LoopOrder.BACKWARD: "backward",
        }[looporder]

    Temporary = as_fmt("GT_DECLARE_TMP({dtype}, {name});")

    IfStmt = as_mako(
        """if(${cond}) ${true_branch}
        %if _this_node.false_branch:
            else ${false_branch}
        %endif
        """
    )

    BlockStmt = as_mako("{${''.join(body)}}")

    def visit_GTComputationCall(
        self, node: gtcpp.GTComputationCall, **kwargs: Any
    ) -> Union[str, Collection[str]]:
        """Render a computation call under a name unique to ``node`` (via ``id``)."""
        computation_name = type(node).__name__ + str(id(node))
        return self.generic_visit(node, computation_name=computation_name, **kwargs)

    GTComputationCall = as_mako(
        """
        %if len(multi_stages) > 0 and len(arguments) > 0:
        {
            auto grid = make_grid(domain[0], domain[1], axis<1,
                axis_config::offset_limit<${offset_limit}>>{domain[2]});

            auto ${ computation_name } = [](${ ','.join('auto ' + a for a in arguments) }) {

                ${ '\\n'.join(temporaries) }
                return multi_pass(${ ','.join(multi_stages) });
            };

            run(${computation_name}, ${gt_backend_t}<>{}, grid, ${','.join(f"std::forward<decltype({arg})>({arg})" for arg in arguments)});
        }
        %endif
        """
    )

    Program = as_mako(
        """
        #include <gridtools/stencil/${gt_backend_t}.hpp>
        #include <gridtools/stencil/cartesian.hpp>

        namespace ${ name }_impl_{
            using Domain = std::array<gridtools::uint_t, 3>;
            using namespace gridtools::stencil;
            using namespace gridtools::stencil::cartesian;

            ${'\\n'.join(functors)}

            auto ${name}(Domain domain){
                return [domain](${ ','.join( 'auto&& ' + p for p in parameters)}){
                    ${gt_computation}
                };
            }
        }

        auto ${name}(${name}_impl_::Domain domain){
            return ${name}_impl_::${name}(domain);
        }
        """
    )

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Generate LLVM-formatted GridTools C++ for a ``gtcpp.Program``.

        Args:
            root: The program to render.
            **kwargs: Must contain ``gt_backend_t`` (the backend name used in
                the ``#include`` and ``run`` call); all kwargs are forwarded.

        Raises:
            ValueError: If ``root`` is not a ``gtcpp.Program``.
            TypeError: If ``gt_backend_t`` is missing.
        """
        if not isinstance(root, gtcpp.Program):
            raise ValueError("apply() requires gtcpp.Program root node")
        if "gt_backend_t" not in kwargs:
            raise TypeError("apply() missing 1 required keyword-only argument: 'gt_backend_t'")
        generated_code = super().apply(root, offset_limit=_offset_limit(root), **kwargs)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code
예제 #10
0
파일: expansion.py 프로젝트: havogt/gt4py
class TaskletCodegen(codegen.TemplatedGenerator):
    """Generate the Python source of a DaCe tasklet for one ``oir.HorizontalExecution``.

    Field accesses become plain identifiers. A zero-offset access to a field
    that is written in the tasklet is named ``__<field>``. Any other access is
    named ``<field>__<offset>``, e.g. ``a__ip1_km1``.
    """

    ScalarAccess = as_fmt("{name}")

    def visit_FieldAccess(self, node: oir.FieldAccess, *, is_target, targets):
        """Render a field access as a tasklet connector name.

        Args:
            is_target: Whether the access is the left-hand side of an assignment.
            targets: Mutable set of field names already written at zero offset.
                A zero-offset access that is a target or already in ``targets``
                is added to it.
        """
        offset = self.visit(node.offset)
        if (is_target or node.name in targets) and offset == "":
            targets.add(node.name)
            name = "__" + node.name
        else:
            name = node.name + "__" + offset
        offset_str = str(node.data_index) if node.data_index else ""
        return name + offset_str

    def visit_CartesianOffset(self, node: common.CartesianOffset):
        """Encode non-zero offsets as e.g. ``ip1_jm2``; the zero offset is ``""``."""
        res = []
        if node.i != 0:
            res.append(f'i{"m" if node.i<0 else "p"}{abs(node.i):d}')
        if node.j != 0:
            res.append(f'j{"m" if node.j<0 else "p"}{abs(node.j):d}')
        if node.k != 0:
            res.append(f'k{"m" if node.k<0 else "p"}{abs(node.k):d}')
        return "_".join(res)

    def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs):
        # The right-hand side is rendered first, so a read of the target field
        # in the same statement is not yet affected by the write's
        # ``targets`` registration.
        right = self.visit(node.right, is_target=False, **kwargs)
        left = self.visit(node.left, is_target=True, **kwargs)
        return f"{left} = {right}"

    BinaryOp = as_fmt("({left} {op} {right})")

    UnaryOp = as_fmt("({op}{expr})")

    TernaryOp = as_fmt("({true_expr} if {cond} else {false_expr})")

    def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str:
        """Render TRUE/FALSE as Python booleans.

        Raises:
            NotImplementedError: For any other built-in literal.
        """
        if builtin == common.BuiltInLiteral.TRUE:
            return "True"
        elif builtin == common.BuiltInLiteral.FALSE:
            return "False"
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_fmt("{dtype}({value})")

    Cast = as_fmt("{dtype}({expr})")

    def visit_NativeFunction(self, func: common.NativeFunction, **kwargs: Any) -> str:
        """Map a native function to the name callable inside a tasklet.

        Raises:
            NotImplementedError: If ``func`` has no mapping.
        """
        try:
            return {
                common.NativeFunction.ABS: "abs",
                common.NativeFunction.MIN: "min",
                common.NativeFunction.MAX: "max",
                common.NativeFunction.MOD: "fmod",
                common.NativeFunction.SIN: "dace.math.sin",
                common.NativeFunction.COS: "dace.math.cos",
                common.NativeFunction.TAN: "dace.math.tan",
                common.NativeFunction.ARCSIN: "asin",
                common.NativeFunction.ARCCOS: "acos",
                common.NativeFunction.ARCTAN: "atan",
                common.NativeFunction.SQRT: "dace.math.sqrt",
                common.NativeFunction.POW: "dace.math.pow",
                common.NativeFunction.EXP: "dace.math.exp",
                common.NativeFunction.LOG: "dace.math.log",
                common.NativeFunction.ISFINITE: "isfinite",
                common.NativeFunction.ISINF: "isinf",
                common.NativeFunction.ISNAN: "isnan",
                common.NativeFunction.FLOOR: "dace.math.ifloor",
                common.NativeFunction.CEIL: "ceil",
                common.NativeFunction.TRUNC: "trunc",
            }[func]
        except KeyError as error:
            raise NotImplementedError("Not implemented NativeFunction encountered.") from error

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str:
        """Map an IR data type to the corresponding ``dace`` type name.

        Raises:
            NotImplementedError: If ``dtype`` has no mapping.
        """
        if dtype == common.DataType.BOOL:
            return "dace.bool_"
        elif dtype == common.DataType.INT8:
            return "dace.int8"
        elif dtype == common.DataType.INT16:
            return "dace.int16"
        elif dtype == common.DataType.INT32:
            return "dace.int32"
        elif dtype == common.DataType.INT64:
            return "dace.int64"
        elif dtype == common.DataType.FLOAT32:
            return "dace.float32"
        elif dtype == common.DataType.FLOAT64:
            return "dace.float64"
        raise NotImplementedError("Not implemented DataType encountered.")

    def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str:
        """Map a unary operator to its Python token.

        Raises:
            NotImplementedError: If ``op`` has no mapping.
        """
        if op == common.UnaryOperator.NOT:
            return " not "
        elif op == common.UnaryOperator.NEG:
            return "-"
        elif op == common.UnaryOperator.POS:
            return "+"
        raise NotImplementedError("Not implemented UnaryOperator encountered.")

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    LocalScalar = as_fmt("{name}: {dtype}")

    def visit_HorizontalExecution(self, node: oir.HorizontalExecution):
        """Render local declarations followed by the body, one per line."""
        targets: Set[str] = set()
        return "\n".join([*self.visit(node.declarations), *self.visit(node.body, targets=targets)])

    def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs):
        """Render a (possibly unmasked) block of statements.

        With a mask, the body is nested under ``if <mask>:``. The mask is
        rendered before the body on purpose: the body may add names to
        ``targets``, which would change how the mask's field reads are named.
        """
        mask_str = ""
        if node.mask is not None:
            mask_str = f"if {self.visit(node.mask, is_target=False, **kwargs)}:"
        body_code = self.visit(node.body, targets=kwargs["targets"])
        if node.mask is None:
            return "\n".join([mask_str] + list(body_code))
        # Indent every line, not just the first line of each statement. A
        # multi-line statement (e.g. a nested MaskStmt) must stay entirely
        # inside this `if` block.
        indented = [
            "    " + line for stmt in body_code for line in stmt.split("\n")
        ]
        # An `if` needs a non-empty suite to be valid Python.
        if not any(line.strip() for line in indented):
            indented.append("    pass")
        return "\n".join([mask_str] + indented)

    @classmethod
    def apply(cls, node: oir.HorizontalExecution, **kwargs: Any) -> str:
        """Render ``node`` as formatted Python tasklet code.

        Raises:
            ValueError: If ``node`` is not an ``oir.HorizontalExecution``.
        """
        if not isinstance(node, oir.HorizontalExecution):
            raise ValueError("apply() requires oir.HorizontalExecution node")
        generated_code = super().apply(node)
        formatted_code = codegen.format_source("python", generated_code)
        return formatted_code
예제 #11
0
class IconBindingsCodegen(codegen.TemplatedGenerator):
    """Generate cpp_bindgen bindings that expose a usid ``Computation`` to ICON.

    The generated translation unit embeds ``stencil_code`` and exports an
    allocation function (``alloc_<name>``) and a run function (``<name>``)
    that takes Fortran array views.
    """

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Render the bindings for ``root`` and return LLVM-formatted C++.

        Requires the keyword argument ``stencil_code``: the generated stencil
        source that is pasted into the bindings.
        """
        assert "stencil_code" in kwargs
        generated_code = cls().visit(root, stencil_code=kwargs["stencil_code"])
        formatted_code = codegen.format_source("cpp",
                                               generated_code,
                                               style="LLVM")
        return formatted_code

    def visit_UField(self, node: UField, **kwargs):
        # Fields without an entry in ``dimensionality`` render as nothing.
        if node.name in kwargs["dimensionality"]:
            return self.generic_visit(node, **kwargs)
        else:
            return ""

    # Mako template: the placeholders evaluate Python expressions (len, join),
    # which a str.format-style ``as_fmt`` template cannot do.
    UField = as_mako(
        "gridtools::fortran_array_view<T, ${len(dimensionality[name])}, field_kind<${','.join(str(i) for i in dimensionality[name])}>> ${name}"
    )

    def visit_SparseField(self, node: SparseField, **kwargs):
        # Fields without an entry in ``dimensionality`` render as nothing.
        if node.name in kwargs["dimensionality"]:
            return self.generic_visit(node, **kwargs)
        else:
            return ""

    # Mako template for the same reason as ``UField``.
    SparseField = as_mako(
        "gridtools::fortran_array_view<T,${len(dimensionality[name])}, field_kind<${','.join(str(i) for i in dimensionality[name])}>> ${name}"
    )

    Connectivity = as_fmt("neigh_tbl_t {name}")

    def visit_Computation(self, node: Computation, **kwargs):
        """Compute per-parameter dimensionality and call-site expressions, then render."""
        # Numbered dimensions each parameter actually has:
        # 0 = horizontal, 1 = vertical, 2 = sparse (only for SparseField).
        dimensionality = {}
        for p in node.params:
            dimensionality[p.name] = [0, 1, 2]
            if not p.dimensions.horizontal:
                dimensionality[p.name].remove(0)
            if not p.dimensions.vertical:
                dimensionality[p.name].remove(1)
            if not isinstance(p, SparseField):
                dimensionality[p.name].remove(2)

        param_names = []
        for name, dims in dimensionality.items():
            # Position ``index`` of the array view stands for logical dimension
            # ``dims[index]``; rename every position where the two differ.
            renames = {index: dim for index, dim in enumerate(dims) if index != dim}
            if renames:
                # rename_dimensions takes a flat list of <old, new> keys. The
                # closing '>' must follow the last pair only; emitting it after
                # every pair would close the template argument list early when
                # more than one dimension is renamed.
                rename_keys = ",".join(
                    f"gridtools::integral_constant<int,{old}>, gridtools::integral_constant<int,{new}>"
                    for old, new in renames.items())
                param_names.append(
                    f"gridtools::sid::rename_dimensions<{rename_keys}>({name})")
            else:
                param_names.append(name)

        # Arguments used only inside ``decltype`` to name the return type of
        # ``alloc_<name>_impl``: two ints plus one neighbor table per
        # connectivity, i.e. the same arity as that function's parameter list.
        alloc_probe_args = ", ".join(
            ["0", "0"] + ["neigh_tbl_t/*<Tag>*/{{}}"] * len(node.connectivities))

        return self.generic_visit(node,
                                  param_names=param_names,
                                  dimensionality=dimensionality,
                                  alloc_probe_args=alloc_probe_args,
                                  **kwargs)

    Computation = as_mako("""
    # include <cpp_bindgen/export.hpp>
    # include <gridtools/storage/adapter/fortran_array_view.hpp>
    # include <gridtools/storage/sid.hpp>
    # include <gridtools/usid/icon.hpp>

    ${stencil_code}

    namespace icon_bindings_${name}_impl{

    struct default_tag{};

    template<int...>
    struct field_kind{};

    // template<class Tag>
    using neigh_tbl_t = gridtools::fortran_array_view<int, 2, default_tag, false>;

    auto alloc_${name}_impl(${','.join(['int n_edges', 'int n_k'] + connectivities)}) {
        return ${name}({-1, n_edges, -1, n_k}, ${','.join(f"icon::make_connectivity_producer<{c.max_neighbors}>({c.name})" for c in _this_node.connectivities)});
    }

    BINDGEN_EXPORT_BINDING_WRAPPED(${2+len(connectivities)}, alloc_${name}, alloc_${name}_impl);

    // template<class Tag>
    using ${name}_t = decltype(alloc_${name}_impl(${alloc_probe_args}));

    template <class T>
    void ${name}_impl(${name}_t ${name}, ${','.join(params)}){
        ${name}(${','.join(param_names)});
    }

    BINDGEN_EXPORT_GENERIC_BINDING_WRAPPED(${1+len(params)}, ${name}, ${name}_impl,
                                            (double));
    }
    """)
예제 #12
0
class UsidCodeGenerator(codegen.TemplatedGenerator):
    """Generate C++ ``gridtools::usid`` code from a usid IR ``Computation``.

    Subclasses customize the emitted code through ``headers_`` (include list),
    ``namespace_`` (``gridtools::usid::<namespace_>`` to import) and
    ``preface_`` (verbatim code emitted after the includes).
    """

    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType({
        common.DataType.BOOLEAN: "bool",
        common.DataType.INT32: "int",
        # Two words: "unsigned int" is the C++ spelling of the type.
        common.DataType.UINT32: "unsigned int",
        common.DataType.FLOAT32: "float",
        common.DataType.FLOAT64: "double",
    })

    BUILTIN_LITERAL_TO_STR: ClassVar[Mapping[common.BuiltInLiteral, str]] = MappingProxyType({
        common.BuiltInLiteral.MAX_VALUE: "std::numeric_limits<double>::max()",  # TODO: datatype
        # lowest() is the most negative finite double; min() is the smallest
        # *positive* normalized double and would not act as a minimum.
        common.BuiltInLiteral.MIN_VALUE: "std::numeric_limits<double>::lowest()",
        common.BuiltInLiteral.ZERO: "0",
        common.BuiltInLiteral.ONE: "1",
    })

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Render ``root`` and return the LLVM-formatted C++ source."""
        generated_code = super().apply(root, **kwargs)
        formatted_code = codegen.format_source("cpp",
                                               generated_code,
                                               style="LLVM")
        return formatted_code

    def location_type_from_dimensions(self, dimensions):
        """Return the single ``LocationType`` among ``dimensions``.

        Raises:
            ValueError: if there is not exactly one ``LocationType``.
        """
        location_type = [
            dim for dim in dimensions if isinstance(dim, common.LocationType)
        ]
        if len(location_type) != 1:
            raise ValueError("Doesn't contain a LocationType!")
        return location_type[0]

    headers_ = [
        "<gridtools/common/gt_math.hpp>",
        "<gridtools/common/array.hpp>",
        "<gridtools/usid/dim.hpp>",
        "<gridtools/usid/helpers.hpp>",
    ]

    namespace_ = ""

    preface_ = ""

    def visit_LocationType(self, node: common.LocationType, **kwargs):
        return {
            common.LocationType.Vertex: "vertex",
            common.LocationType.Edge: "edge",
            common.LocationType.Cell: "cell",
        }[node]

    def visit_bool(self, node: bool, **kwargs):
        if node:
            return "true"
        else:
            return "false"

    def visit_SidCompositeSparseEntry(self, node: SidCompositeSparseEntry,
                                      **kwargs):
        # The sparse dimension is renamed to the tag of its connectivity.
        return self.generic_visit(
            node,
            connectivity_tag=kwargs["symtable"][node.connectivity].tag,
            **kwargs)

    SidCompositeSparseEntry = as_fmt(
        "sid::rename_dimensions<dim::s, {connectivity_tag}>({ref})")

    SidCompositeEntry = as_fmt("{ref}")

    SidComposite = as_mako("""
        sid::composite::make<${ ','.join([t.name for t in _this_node.entries]) }>(
        ${ ','.join(entries)})
        """)

    def visit_KernelCall(self, node: KernelCall, **kwargs):
        kernel: Kernel = kwargs["symtable"][node.name]
        domain = f"d.{self.visit(kernel.primary_location)}"

        sids = self.visit([kernel.primary_composite] +
                          kernel.secondary_composites, **kwargs)

        return self.generic_visit(node, domain=domain, sids=sids)

    KernelCall = as_mako("""
        call_kernel<${name}>(${domain}, d.k, ${','.join(sids)});
        """)

    FieldAccess = as_mako("""<%
            composite_deref = symtable[_this_node.sid]
            sid_entry_deref = symtable[_this_node.name]
        %>field<${ sid_entry_deref.name }>(${ composite_deref.ptr_name })""")

    ArrayAccess = as_fmt("{name}[{subscript}]")

    AssignStmt = as_fmt("{left} = {right};")

    # Mako template: the argument list needs a Python ``join``, which a
    # str.format-style ``as_fmt`` template cannot evaluate.
    NativeFuncCall = as_mako(
        "gridtools::math::${func}(${','.join(args)})")  # TODO: fix func

    BinaryOp = as_fmt("({left} {op} {right})")

    PtrRef = as_fmt("{name}")

    LocalIndex = as_fmt("{name}")

    def visit_NeighborLoop(self, node: NeighborLoop, symtable, **kwargs):
        primary_sid_deref = symtable[node.primary_sid]
        connectivity_deref = symtable[node.connectivity]
        indexed = ""
        index_var = ""
        if node.local_index:
            indexed = "_indexed"
            index_var = f", auto {self.visit(node.local_index)}"
        return self.generic_visit(
            node,
            symtable={
                **node.symtable_,
                **symtable,
            },  # should be partly bounded (should see only global scope (tags) and current scope)
            primary_sid_deref=primary_sid_deref,
            connectivity_deref=connectivity_deref,
            indexed=indexed,
            index_var=index_var,
            **kwargs,
        )

    # TODO consider stricter capture
    NeighborLoop = as_mako("""
        foreach_neighbor${indexed}<${connectivity_deref.tag}>([&](auto &&${primary}, auto &&${secondary}${index_var}){${''.join(body)}}, ${primary_sid_deref.ptr_name}, ${primary_sid_deref.strides_name}, ${secondary_sid});
        """)

    Literal = as_mako("""<%
            literal= _this_node.value if isinstance(_this_node.value, str) else _this_generator.BUILTIN_LITERAL_TO_STR[_this_node.value]
        %>(${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] })${ literal }"""
                      )

    VarAccess = as_fmt("{name}")

    VarDecl = as_mako(
        "${ _this_generator.DATA_TYPE_TO_STR[_this_node.vtype] } ${ name } = ${ init };"
    )

    StaticArrayDecl = as_mako(
        "gridtools::array<${_this_generator.DATA_TYPE_TO_STR[_this_node.vtype]}, ${size}> ${name} = {${','.join(init)}};"
    )

    def visit_Connectivity(self, node: Connectivity, **kwargs):
        c_has_skip_values = "true" if node.has_skip_values else "false"
        return self.generic_visit(node, c_has_skip_values=c_has_skip_values)

    Connectivity = as_mako(
        "struct ${_this_node.tag}: connectivity<${max_neighbors},${c_has_skip_values}>{};"
    )

    def visit_Temporary(self, node: Temporary, **kwargs):
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.visit(
            self.location_type_from_dimensions(node.dimensions))
        return self.generic_visit(node,
                                  loctype=loctype,
                                  c_vtype=c_vtype,
                                  **kwargs)

    Temporary = as_mako("""
        auto ${ name } = make_simple_tmp_storage<${ c_vtype }>(
            d.${ loctype }, d.k, alloc);""")

    def visit_TemporarySparseField(self, node: TemporarySparseField, *,
                                   symtable, **kwargs):
        c_vtype = self.DATA_TYPE_TO_STR[node.vtype]
        loctype = self.visit(
            self.location_type_from_dimensions(node.dimensions))
        connectivity_deref = symtable[node.connectivity]
        return self.generic_visit(
            node,
            s_size=connectivity_deref.max_neighbors,
            c_vtype=c_vtype,
            loctype=loctype,
            **kwargs,
        )

    TemporarySparseField = as_mako("""
        auto ${ name } = make_simple_sparse_tmp_storage<${ c_vtype }>(
            d.${ loctype }, d.k, ${s_size}, alloc);""")

    def visit_Kernel(self, node: Kernel, symtable, **kwargs):
        primary_signature = f"auto && {node.primary_composite.ptr_name}, auto&& {node.primary_composite.strides_name}"
        secondary_signature = (
            "" if len(node.secondary_composites) == 0 else ", auto &&" +
            ", auto&&".join(c.name for c in node.secondary_composites))
        return self.generic_visit(
            node,
            symtable={
                **symtable,
                **node.symtable_
            },
            primary_signature=primary_signature,
            secondary_signature=secondary_signature,
            **kwargs,
        )

    Kernel = as_mako("""
        struct ${name} {
            GT_FUNCTION auto operator()() const {
                return [](${primary_signature}${secondary_signature}){
                    ${''.join(body)}
                };
            }
        };
        """)

    def visit_Computation(self, node: Computation, **kwargs):
        # maybe tags should be generated in lowering
        field_tag_names = node.iter_tree().if_isinstance(
            SidCompositeEntry).getattr("name").to_set()
        connectivity_tag_names = {c.tag for c in node.connectivities}
        # Sorted so the generated code is deterministic: set iteration order
        # depends on string hashing, which varies between interpreter runs.
        field_tags = [
            f"struct {field_tag};"
            for field_tag in sorted(field_tag_names.difference(connectivity_tag_names))
        ]

        connectivity_params = [f"auto&& {c.name}" for c in node.connectivities]
        field_params = [f"auto&& {f.name}" for f in node.parameters]

        connectivity_fields = [
            f"{c.name} = sid::rename_dimensions<dim::n, {c.tag}>(std::forward<decltype({c.name})>({c.name})(traits_t()))"
            for c in node.connectivities
        ]

        return self.generic_visit(
            node,
            field_tags=field_tags,
            connectivity_params=connectivity_params,
            connectivity_fields=connectivity_fields,
            field_params=field_params,
            symtable=node.symtable_,
            **kwargs,
        )

    # ``preface_`` is emitted right after the includes so subclasses can inject
    # code there (it is empty for this base class).
    Computation = as_mako("""
        ${ '\\n'.join('#include ' + header for header in _this_generator.headers_) }

        ${ _this_generator.preface_ }

        namespace ${ name }_impl_ {
            using namespace gridtools;
            using namespace gridtools::usid;
            using namespace gridtools::usid::${_this_generator.namespace_};
            ${ ''.join(connectivities)}
            ${ ''.join(field_tags) }

            ${ ''.join(kernels) }


            auto ${name} = [](domain d
                %if connectivity_params:
                , ${','.join(connectivity_params)}
                %endif
                ) {
                ${ ''.join(f"static_assert(is_sid<decltype({c.name}(traits_t()))>());" for c in _this_node.connectivities)}
                return
                    [d = std::move(d)
                    %if connectivity_fields:
                    , ${','.join(connectivity_fields)}
                    %endif
                            ](
                        ${','.join(field_params)}
                            ){
                            ${ ''.join(f"static_assert(is_sid<decltype({p.name})>());" for p in _this_node.parameters)}
                            %if temporaries:
                            auto alloc = make_allocator();
                            %endif
                            ${''.join(temporaries)}

                            ${''.join(ctrlflow_ast)}

                            };

            };
        }

        using ${ name }_impl_::${name};
        """)
예제 #13
0
파일: backend.py 프로젝트: havogt/gt4py
class DaCeBindingsCodegen:
    """Generate pybind11 bindings that expose a DaCe-generated computation.

    An instance hands out increasing integer ids (1, 2, 3, ...) that give every
    generated SID its own ``gt::integral_constant`` tag.
    """

    def __init__(self):
        self._unique_index: int = 0

    def unique_index(self) -> int:
        """Return the next id of this generator instance (first call returns 1)."""
        self._unique_index = self._unique_index + 1
        return self._unique_index

    mako_template = as_mako("""#include <chrono>
           #include <pybind11/pybind11.h>
           #include <pybind11/stl.h>
           #include <gridtools/storage/adapter/python_sid_adapter.hpp>
           #include <gridtools/stencil/cartesian.hpp>
           #include <gridtools/stencil/global_parameter.hpp>
           #include <gridtools/sid/sid_shift_origin.hpp>
           #include <gridtools/sid/rename_dimensions.hpp>
           #include "computation.hpp"
           namespace gt = gridtools;
           namespace py = ::pybind11;
           PYBIND11_MODULE(${module_name}, m) {
               m.def("run_computation", [](
               ${','.join(["std::array<gt::uint_t, 3> domain", *entry_params, 'py::object exec_info'])}
               ){
                   if (!exec_info.is(py::none()))
                   {
                       auto exec_info_dict = exec_info.cast<py::dict>();
                       exec_info_dict["run_cpp_start_time"] = static_cast<double>(
                           std::chrono::duration_cast<std::chrono::nanoseconds>(
                               std::chrono::high_resolution_clock::now().time_since_epoch()).count())/1e9;
                   }

                   ${name}(domain)(${','.join(sid_params)});

                   if (!exec_info.is(py::none()))
                   {
                       auto exec_info_dict = exec_info.cast<py::dict>();
                       exec_info_dict["run_cpp_end_time"] = static_cast<double>(
                           std::chrono::duration_cast<std::chrono::nanoseconds>(
                               std::chrono::high_resolution_clock::now().time_since_epoch()).count()/1e9);
                   }

               }, "Runs the given computation");}
        """)

    def generate_entry_params(self, gtir: gtir.Stencil, sdfg: dace.SDFG):
        """C++ parameter declarations of ``run_computation``, in ``gtir.params`` order.

        Arrays become a ``py::buffer`` plus an origin array; symbols not
        starting with ``__`` become plain typed scalars. Stencil parameters
        matching neither are left out.
        """
        import dace.data

        declarations = {}
        for arg_name in sdfg.signature_arglist(with_types=False, for_call=True):
            if arg_name in sdfg.arrays:
                array = sdfg.arrays[arg_name]
                assert isinstance(array, dace.data.Array)
                declarations[arg_name] = (
                    f"py::buffer {arg_name}, std::array<gt::uint_t,{len(array.shape)}> {arg_name}_origin"
                )
            elif arg_name in sdfg.symbols and not arg_name.startswith("__"):
                declarations[arg_name] = f"{sdfg.symbols[arg_name].ctype} {arg_name}"
        return [declarations[param.name] for param in gtir.params if param.name in declarations]

    def generate_sid_params(self, sdfg: dace.SDFG):
        """C++ argument expressions passed to the generated computation.

        Every non-transient array (in ``sdfg.arrays`` order) becomes an
        origin-shifted SID, whose numbered dimensions are renamed when the array
        does not span all three of i/j/k. The names of symbols not starting with
        ``__`` follow.
        """
        import dace.data

        sid_exprs = []
        for array_name, array in sdfg.arrays.items():
            if array.transient:
                continue
            domain_ndim = sum(array_dimensions(array))
            data_ndim = len(array.shape) - domain_ndim
            on_gpu = array.storage in [
                dace.StorageType.GPU_Global, dace.StorageType.GPU_Shared
            ]
            sid_expr = """gt::as_{sid_type}<{dtype}, {num_dims},
                gt::integral_constant<int, {unique_index}>>({name})""".format(
                sid_type="cuda_sid" if on_gpu else "sid",
                name=array_name,
                dtype=array.dtype.ctype,
                unique_index=self.unique_index(),
                num_dims=len(array.shape),
            )
            sid_expr = f"gt::sid::shift_sid_origin({sid_expr}, {array_name}_origin)"

            if domain_ndim != 3:
                gt_dims = []
                for axis in "ijk":
                    spans_axis = any(
                        dace.symbolic.pystr_to_symbolic(f"__{axis.upper()}") in extent.free_symbols
                        for extent in array.shape
                        if hasattr(extent, "free_symbols"))
                    if spans_axis:
                        gt_dims.append(f"gt::stencil::dim::{axis}")
                gt_dims.extend(
                    f"gt::integral_constant<int, {3 + extra}>" for extra in range(data_ndim))
                dims_list = ", ".join(gt_dims)
                sid_expr = f"gt::sid::rename_numbered_dimensions<{dims_list}>({sid_expr})"

            sid_exprs.append(sid_expr)

        # Scalar parameters are passed as plain variables.
        sid_exprs.extend(symbol for symbol in sdfg.symbols.keys() if not symbol.startswith("__"))
        return sid_exprs

    def generate_sdfg_bindings(self, gtir, sdfg, module_name):
        """Render the (unformatted) bindings source for ``sdfg``."""
        entry_params = self.generate_entry_params(gtir, sdfg)
        sid_params = self.generate_sid_params(sdfg)
        return self.mako_template.render_values(
            name=sdfg.name,
            module_name=module_name,
            entry_params=entry_params,
            sid_params=sid_params,
        )

    @classmethod
    def apply(cls, gtir: gtir.Stencil, sdfg: dace.SDFG,
              module_name: str) -> str:
        """Return the LLVM-formatted pybind11 bindings for ``sdfg``."""
        bindings_source = cls().generate_sdfg_bindings(gtir, sdfg, module_name=module_name)
        return codegen.format_source("cpp", bindings_source, style="LLVM")
예제 #14
0
class GTCppCodegen(codegen.TemplatedGenerator):
    """Generate GridTools C++ (cartesian stencil API) source from a ``gtcpp.Program``."""

    GTExtent = as_fmt("extent<{i[0]},{i[1]},{j[0]},{j[1]},{k[0]},{k[1]}>")

    GTAccessor = as_fmt("using {name} = {intent}_accessor<{id}, {extent}>;")

    GTParamList = as_mako(
        """${ '\\n'.join(accessors) }

        using param_list = make_param_list<${ ','.join(a.name for a in _this_node.accessors)}>;
        """
    )

    GTFunctor = as_mako(
        """struct ${ name } {
        ${param_list}

        ${ '\\n'.join(applies) }
    };
    """
    )

    GTLevel = as_fmt("gridtools::stencil::core::level<{splitter}, {offset}, {offset_limit}>")

    GTInterval = as_fmt("gridtools::stencil::core::interval<{from_level}, {to_level}>")

    LocalVarDecl = as_fmt("{dtype} {name};")

    GTApplyMethod = as_mako(
        """
    template<typename Evaluation>
    GT_FUNCTION static void apply(Evaluation eval, ${interval}) {
        ${ ' '.join(local_variables) }
        ${ '\\n'.join(body) }
    }
    """
    )

    AssignStmt = as_fmt("{left} = {right};")

    AccessorRef = as_fmt("eval({name}({offset}))")

    ScalarAccess = as_fmt("{name}")

    CartesianOffset = as_fmt("{i}, {j}, {k}")

    BinaryOp = as_fmt("({left} {op} {right})")

    UnaryOp = as_fmt("({op}{expr})")

    TernaryOp = as_fmt("({cond} ? {true_expr} : {false_expr})")

    Cast = as_fmt("static_cast<{dtype}>({expr})")

    def visit_BuiltInLiteral(self, builtin: BuiltInLiteral, **kwargs: Any) -> str:
        """C++ spelling of a built-in literal; the ``Literal`` template casts it to the dtype.

        Raises:
            NotImplementedError: for literals without a dtype-independent spelling
                (e.g. MAX_VALUE / MIN_VALUE).
        """
        if builtin == BuiltInLiteral.TRUE:
            return "true"
        elif builtin == BuiltInLiteral.FALSE:
            return "false"
        elif builtin == BuiltInLiteral.ZERO:
            return "0"
        elif builtin == BuiltInLiteral.ONE:
            return "1"
        raise NotImplementedError("Not implemented BuiltInLiteral encountered.")

    Literal = as_mako("static_cast<${dtype}>(${value})")

    def visit_NativeFunction(self, func: NativeFunction, **kwargs: Any) -> str:
        if func == NativeFunction.SQRT:
            return "gridtools::math::sqrt"
        elif func == NativeFunction.MIN:
            return "gridtools::math::min"
        elif func == NativeFunction.MAX:
            return "gridtools::math::max"
        raise NotImplementedError("Not implemented NativeFunction encountered.")

    NativeFuncCall = as_mako("${func}(${','.join(args)})")

    def visit_DataType(self, dtype: DataType, **kwargs: Any) -> str:
        """C++ spelling of a scalar data type.

        Raises:
            NotImplementedError: for data types not mapped below.
        """
        if dtype == DataType.INT64:
            return "long long"
        elif dtype == DataType.FLOAT64:
            return "double"
        elif dtype == DataType.FLOAT32:
            return "float"
        elif dtype == DataType.BOOL:
            return "bool"
        raise NotImplementedError("Not implemented DataType encountered.")

    def visit_UnaryOperator(self, op: UnaryOperator, **kwargs: Any) -> str:
        if op == UnaryOperator.NOT:
            return "!"
        elif op == UnaryOperator.NEG:
            return "-"
        elif op == UnaryOperator.POS:
            return "+"
        raise NotImplementedError("Not implemented UnaryOperator encountered.")

    Arg = as_fmt("{name}")

    Param = as_fmt("{name}")

    ApiParamDecl = as_fmt("{name}")

    GTStage = as_mako(".stage(${functor}(), ${','.join(args)})")

    GTMultiStage = as_mako("execute_${ loop_order }()${''.join(caches)}${''.join(stages)}")

    IJCache = as_fmt(".ij_cached({name})")
    KCache = as_mako(
        ".k_cached(${'cache_io_policy::fill(), ' if _this_node.fill else ''}${'cache_io_policy::flush(), ' if _this_node.flush else ''}${name})"
    )

    def visit_LoopOrder(self, looporder: LoopOrder, **kwargs: Any) -> str:
        return {
            LoopOrder.PARALLEL: "parallel",
            LoopOrder.FORWARD: "forward",
            LoopOrder.BACKWARD: "backward",
        }[looporder]

    Temporary = as_fmt("GT_DECLARE_TMP({dtype}, {name});")

    IfStmt = as_mako(
        """if(${cond}) ${true_branch}
        %if _this_node.false_branch:
            else ${false_branch}
        %endif
        """
    )

    BlockStmt = as_mako("{${''.join(body)}}")

    def visit_GTComputationCall(
        self, node: gtcpp.GTComputationCall, **kwargs: Any
    ) -> Union[str, Collection[str]]:
        return self.generic_visit(node, computation_name=node.id_, **kwargs)

    GTComputationCall = as_mako(
        """
        %if len(multi_stages) > 0 and len(arguments) > 0:
        {
            auto grid = make_grid(domain[0], domain[1], axis<1,
                axis_config::offset_limit<${offset_limit}>>{domain[2]});

            auto ${ computation_name } = [](${ ','.join('auto ' + a for a in arguments) }) {

                ${ '\\n'.join(temporaries) }
                return multi_pass(${ ','.join(multi_stages) });
            };

            run(${computation_name}, ${gt_backend_t}<>{}, grid, ${','.join(arguments)});
        }
        %endif
        """
    )

    Program = as_mako(
        """#include <gridtools/stencil/${gt_backend_t}.hpp>
        #include <gridtools/stencil/cartesian.hpp>

        namespace ${ name }_impl_{
            using Domain = std::array<gridtools::uint_t, 3>;
            using namespace gridtools::stencil;
            using namespace gridtools::stencil::cartesian;

            ${'\\n'.join(functors)}

            auto ${name}(Domain domain){
                return [domain](${ ','.join( 'auto&& ' + p for p in parameters)}){
                    ${gt_computation}
                };
            }
        }

        auto ${name}(${name}_impl_::Domain domain){
            return ${name}_impl_::${name}(domain);
        }
        """
    )

    @classmethod
    def apply(cls, root: LeafNode, **kwargs: Any) -> str:
        """Render ``root`` and return LLVM-formatted C++.

        Raises:
            ValueError: if ``root`` is not a ``gtcpp.Program``.
            TypeError: if the ``gt_backend_t`` keyword argument is missing.
        """
        if not isinstance(root, gtcpp.Program):
            raise ValueError("apply() requires gtcpp.Program root node")
        if "gt_backend_t" not in kwargs:
            raise TypeError("apply() missing 1 required keyword-only argument: 'gt_backend_t'")
        generated_code = super().apply(root, offset_limit=_offset_limit(root), **kwargs)
        formatted_code = codegen.format_source("cpp", generated_code, style="LLVM")
        return formatted_code
예제 #15
0
class GTCppBindingsCodegen(codegen.TemplatedGenerator):
    """Generate pybind11 bindings exposing a generated ``gtcpp.Program``.

    Parameters are rendered twice: once as Python-facing declarations
    (``external_arg=True``) and once as the SID/global-parameter expressions
    passed to the computation (``external_arg=False``).
    """

    def __init__(self):
        self._unique_index: int = 0

    def unique_index(self) -> int:
        """Return the next id of this generator instance (first call returns 1)."""
        self._unique_index = self._unique_index + 1
        return self._unique_index

    def visit_DataType(self, dtype: DataType, **kwargs):
        """C++ spelling of a scalar data type.

        Raises:
            AssertionError: for data types not mapped below.
        """
        spellings = (
            (DataType.INT64, "long long"),
            (DataType.FLOAT64, "double"),
            (DataType.FLOAT32, "float"),
            (DataType.BOOL, "bool"),
        )
        for candidate, spelling in spellings:
            if dtype == candidate:
                return spelling
        raise AssertionError(f"Invalid DataType value: {dtype}")

    def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs):
        """Render a field either as a binding parameter or as a shifted SID.

        Returns ``None`` when ``external_arg`` is not given.
        """
        assert "gt_backend_t" in kwargs
        if "external_arg" not in kwargs:
            return None
        if kwargs["external_arg"]:
            return f"py::buffer {node.name}, std::array<gt::uint_t,3> {node.name}_origin"
        element_type = self.visit(node.dtype)
        tag_index = self.unique_index()
        adapter = "cuda_sid" if kwargs["gt_backend_t"] == "gpu" else "sid"
        return """gt::sid::shift_sid_origin(gt::as_{sid_type}<{dtype}, 3,
                    std::integral_constant<int, {unique_index}>>({name}), {name}_origin)""".format(
            name=node.name,
            dtype=element_type,
            unique_index=tag_index,
            sid_type=adapter,
        )

    def visit_GlobalParamDecl(self, node: gtcpp.GlobalParamDecl, **kwargs):
        """Render a scalar either as a binding parameter or as a global parameter.

        Returns ``None`` when ``external_arg`` is not given.
        """
        if "external_arg" not in kwargs:
            return None
        if kwargs["external_arg"]:
            return f"{self.visit(node.dtype)} {node.name}"
        return f"gridtools::stencil::make_global_parameter({node.name})"

    def visit_Program(self, node: gtcpp.Program, **kwargs):
        assert "module_name" in kwargs
        # Python-facing declarations first, then the call-site expressions;
        # this order also fixes the SID tag numbering.
        declared = self.visit(node.parameters, external_arg=True, **kwargs)
        passed = self.visit(node.parameters, external_arg=False, **kwargs)
        return self.generic_visit(node, entry_params=declared, sid_params=passed, **kwargs)

    Program = as_mako("""
        #include <chrono>
        #include <pybind11/pybind11.h>
        #include <pybind11/stl.h>
        #include <gridtools/storage/adapter/python_sid_adapter.hpp>
        #include <gridtools/stencil/global_parameter.hpp>
        #include <gridtools/sid/sid_shift_origin.hpp>
        #include "computation.hpp"
        namespace gt = gridtools;
        namespace py = ::pybind11;
        %if len(entry_params) > 0:
        PYBIND11_MODULE(${module_name}, m) {
            m.def("run_computation", [](std::array<gt::uint_t, 3> domain,
            ${','.join(entry_params)},
            py::object exec_info){
                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_start_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count())/1e9;
                }

                ${name}(domain)(${','.join(sid_params)});

                if (!exec_info.is(py::none()))
                {
                    auto exec_info_dict = exec_info.cast<py::dict>();
                    exec_info_dict["run_cpp_end_time"] = static_cast<double>(
                        std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::high_resolution_clock::now().time_since_epoch()).count()/1e9);
                }

            }, "Runs the given computation");}
        %endif
        """)

    @classmethod
    def apply(cls, root, *, module_name="stencil", **kwargs) -> str:
        """Render the bindings for ``root`` and return LLVM-formatted C++."""
        bindings_source = cls().visit(root, module_name=module_name, **kwargs)
        return codegen.format_source("cpp", bindings_source, style="LLVM")
예제 #16
0
class NaiveCodeGenerator(codegen.TemplatedGenerator):
    # C++ spelling of each scalar data type, used for field element types,
    # variable declarations, literal casts and temporary allocations.
    # Every entry must be a valid C++ type name: UINT32 maps to the builtin
    # "unsigned int" ("unsigned_int" is not a C++ type and would not compile).
    DATA_TYPE_TO_STR: ClassVar[Mapping[common.DataType, str]] = MappingProxyType(
        {
            common.DataType.BOOLEAN: "bool",
            common.DataType.INT32: "int",
            common.DataType.UINT32: "unsigned int",
            common.DataType.FLOAT32: "float",
            common.DataType.FLOAT64: "double",
        }
    )

    # Singular and plural spellings of each location type, as used to build
    # dawn identifiers (e.g. ``edge_field_t``, ``getEdges``, ``reduceEdgeToCell``).
    # Both levels are read-only views.
    LOCATION_TYPE_TO_STR_MAP: ClassVar[
        Mapping[LocationType, Mapping[str, str]]
    ] = MappingProxyType(
        {
            LocationType.Node: MappingProxyType(
                {"singular": "vertex", "plural": "vertices"}
            ),
            LocationType.Edge: MappingProxyType(
                {"singular": "edge", "plural": "edges"}
            ),
            LocationType.Face: MappingProxyType(
                {"singular": "cell", "plural": "cells"}
            ),
        }
    )

    @classmethod
    def apply(cls, root, **kwargs) -> str:
        """Render ``root`` with the base generator, then clang-format it (LLVM style)."""
        unformatted = super().apply(root, **kwargs)
        return codegen.format_source("cpp", unformatted, style="LLVM")

    def visit_DataType(self, node, **kwargs) -> str:
        """Return the C++ type name for a DataType member (KeyError if unmapped)."""
        type_names = self.DATA_TYPE_TO_STR
        return type_names[node]

    def visit_LocationType(self, node, **kwargs) -> Mapping[str, str]:
        """Return the {"singular", "plural"} spelling map for a LocationType member."""
        spellings = self.LOCATION_TYPE_TO_STR_MAP
        return spellings[node]

    # Fallback template rendering the node's class name in upper case (only for testing).
    Node = as_mako("${_this_node.__class__.__name__.upper()}")

    # Member declaration for a field parameter: a reference to a
    # dawn::<loc>_field_t, or dawn::sparse_<loc>_field_t when the field has a
    # sparse location type. ``location_type`` here is the already-visited
    # {"singular", "plural"} mapping returned by visit_LocationType.
    UnstructuredField = as_mako("""<%
loc_type = location_type["singular"]
sparseloc = "sparse_" if _this_node.sparse_location_type else ""
%>
dawn::${ sparseloc }${ loc_type }_field_t<LibTag, ${ data_type }>& ${ name };"""
                                )

    # Field read/write ``name(deref(LibTag{}, <iter>), [sparse idx,] k)``.
    # Sparse accesses are indexed by the outer (reduced-over) element plus the
    # running ``m_sparse_dimension_idx`` maintained by ReduceOverNeighbourExpr;
    # dense accesses use the current iteration variable.
    FieldAccessExpr = as_mako("""<%
sparse_index = "m_sparse_dimension_idx, " if _this_node.is_sparse else ""
field_acc_itervar = outer_iter_var if _this_node.is_sparse else iter_var
%>${ name }(deref(LibTag{}, ${ field_acc_itervar }), ${ sparse_index } k)""")

    # Assignment as an expression; the terminating ";" is added by ExprStmt
    # when it is used as a statement.
    AssignmentExpr = as_fmt("{left} = {right}")

    # A variable reference renders as its bare name.
    VarAccessExpr = as_fmt("{name}")

    # Binary operation, always parenthesized so nested operations keep the
    # IR's grouping: unparenthesized, ``(a + b) * c`` would be emitted as
    # ``a + b * c`` and C++ precedence would change the result.
    BinaryOp = as_fmt("({left} {op} {right})")

    # Expression statement: starts on a new line and is ";"-terminated.
    ExprStmt = as_fmt("\n{expr};")

    # Local variable declaration on a new line; declared without an initializer.
    VarDeclStmt = as_fmt("\n{data_type} {name};")

    TemporaryFieldDeclStmt = as_mako("""using dawn::allocateEdgeField;
        auto ${ name } = allocate${ location_type['singular'].capitalize() }Field<${ data_type }>(mesh);"""
                                     )

    # Vertical loop: ascending 0..k_size-1 for LoopOrder.FORWARD, otherwise
    # descending k_size-1..0. It declares ``m_sparse_dimension_idx`` once per
    # k-iteration (uninitialized; ReduceOverNeighbourExpr sets it to 0 before
    # use) and then emits the contained horizontal loops.
    ForK = as_mako("""<%
if _this_node.loop_order == _this_module.common.LoopOrder.FORWARD:
    k_init = '0'
    k_cond = 'k < k_size'
    k_step = '++k'
else:
    k_init = 'k_size -1'
    k_cond = 'k >= 0'
    k_step = '--k'
%>for (int k = ${k_init}; ${k_cond}; ${k_step}) {
int m_sparse_dimension_idx;
${ "".join(horizontal_loops) }\n}""")

    # Range-for over all mesh elements of the loop's location type
    # (getVertices / getEdges / getCells). The element variable is ``t``,
    # matching the iter_var injected by visit_HorizontalLoop.
    HorizontalLoop = as_mako("""<%
loc_type = location_type['plural'].title()
%>for(auto const & t: get${ loc_type }(LibTag{}, mesh)) ${ ast }""")

    def visit_HorizontalLoop(self, node, **kwargs) -> str:
        """Visit the loop's children with ``iter_var`` bound to the loop element ``t``."""
        return self.generic_visit(node, **kwargs, iter_var="t")

    # Braced compound statement; each child statement supplies its own leading newline.
    BlockStmt = as_mako("{${ ''.join(statements) }\n}")

    # Neighbour reduction via dawn's reduce<Right>To<Loc>(mesh, elem, init, op).
    # The comma expression first resets ``m_sparse_dimension_idx`` to 0. The
    # lambda then accumulates ``lhs <op>= right`` for each neighbour and
    # advances the sparse index, so sparse field accesses inside ``right`` see
    # the neighbour's position.
    ReduceOverNeighbourExpr = as_mako("""<%
right_loc_type = right_location_type["singular"].title()
loc_type = location_type["singular"].title()
%>(m_sparse_dimension_idx=0,reduce${ right_loc_type }To${ loc_type }(mesh, ${ outer_iter_var }, ${ init }, [&](auto& lhs, auto const& ${ iter_var }) {
lhs ${ operation }= ${ right };
m_sparse_dimension_idx++;
return lhs;
}))""")

    def visit_ReduceOverNeighbourExpr(self, node, *, iter_var,
                                      **kwargs) -> str:
        """Visit a neighbour reduction.

        The caller's ``iter_var`` (the element whose neighbours are reduced) is
        forwarded as ``outer_iter_var``. Inside the reduction, ``iter_var``
        becomes ``redIdx``, the name of the reduction lambda's neighbour
        parameter.
        """
        return self.generic_visit(
            node,
            iter_var="redIdx",
            outer_iter_var=iter_var,
            **kwargs,
        )

    # Literal emitted with an explicit (C-style) cast to its data type.
    LiteralExpr = as_fmt("({data_type}){value}")

    # One stencil becomes a parameterless member function of the generated
    # class: local declarations (if any), one per line, followed by its
    # vertical loops.
    Stencil = as_mako("""
void ${name}() {
using dawn::deref;

${ "\\n".join(declarations) if _this_node.declarations else ""}

${ "".join(k_loops) }
}
""")

    Computation = as_mako("""<%
stencil_calls = '\\n'.join("{name}();".format(name=s.name) for s in _this_node.stencils)
ctor_field_params = ', '.join(
    'dawn::{sparse_loc}{loc_type}_field_t<LibTag, {data_type}>& {name}'.format(
        loc_type=_this_generator.LOCATION_TYPE_TO_STR_MAP[p.location_type]['singular'],
        name=p.name,
        data_type=_this_generator.DATA_TYPE_TO_STR[p.data_type],
        sparse_loc="sparse_" if p.sparse_location_type else ""
    )
    for p in _this_node.params
)
ctor_field_initializers = ', '.join(
    '{name}({name})'.format(name=p.name) for p in _this_node.params
)
%>#define DAWN_GENERATED 1
#define DAWN_BACKEND_T CXXNAIVEICO
#include <driver-includes/unstructured_interface.hpp>
namespace dawn_generated {
namespace cxxnaiveico {
template <typename LibTag>
class generated {
private:
dawn::mesh_t<LibTag>& mesh;
int const k_size;

${ ''.join(params) }
${ ''.join(stencils) }

public:
generated(dawn::mesh_t<LibTag>& mesh, int k_size, ${ ctor_field_params }): mesh(mesh), k_size(k_size), ${ ctor_field_initializers } {}

void run() {
${ stencil_calls }
}
};
}
}

""")