Ejemplo n.º 1
0
    def make_extension_sources(self, gt_backend_t: str) -> Dict[str, Any]:
        """Generate the C++ sources for this stencil's Python extension.

        Compiles ``self.sir`` with ``dawn4py.compile`` and renders every
        template listed in ``self.TEMPLATE_FILES`` (read from
        ``self.TEMPLATE_DIR``) with the stencil's field and parameter metadata.

        Backend options read from ``self.builder.options.backend_opts``:

        * ``dump_sir``: if truthy, write the SIR as JSON before compiling. A
          ``str`` value is used as the output path; ``True`` writes to
          ``"<stencil_short_name>_gt4py.sir"``.
        * ``no_opt``: if present, no dawn pass groups are run.
        * ``opt_groups``: keys into ``DAWN_PASS_GROUPS`` selecting the pass
          groups to run; must not be combined with ``default_opt``.
        * any key also present in ``_DAWN_TOOLCHAIN_OPTIONS`` is forwarded as a
          keyword argument to ``dawn4py.compile``.

        Args:
            gt_backend_t: Value passed to the templates as ``gt_backend``.

        Returns:
            Mapping from file name to source text: the dawn-generated code under
            ``"_dawn_<stencil_short_name>.hpp"`` plus one rendered template per
            key of ``self.TEMPLATE_FILES``.

        Raises:
            TypeError: If ``dump_sir`` is truthy but neither a ``str`` nor a
                ``bool``.
            ValueError: If both ``opt_groups`` and ``default_opt`` are given, or
                if a stencil parameter has an unsupported data type.
        """
        stencil_short_name = self.builder.stencil_id.qualified_name.split(".")[-1]
        # Work on a shallow copy of the builder's backend options.
        backend_opts = dict(**self.builder.options.backend_opts)
        dawn_namespace = self.DAWN_BACKEND_NS

        dump_sir_opt = backend_opts.get("dump_sir", False)
        if dump_sir_opt:
            if isinstance(dump_sir_opt, str):
                dump_sir_file = dump_sir_opt
            elif isinstance(dump_sir_opt, bool):
                dump_sir_file = f"{stencil_short_name}_gt4py.sir"
            else:
                # Explicit raise instead of ``assert`` so the check survives ``python -O``.
                raise TypeError(
                    "'dump_sir' backend option must be a file name (str) or a bool, "
                    f"got {type(dump_sir_opt).__name__}"
                )
            with open(dump_sir_file, "w") as f:
                f.write(sir_utils.to_json(self.sir))

        # Select the dawn optimization pass groups.
        if "no_opt" in backend_opts:
            pass_groups = []
        elif "opt_groups" in backend_opts:
            pass_groups = [DAWN_PASS_GROUPS[k] for k in backend_opts["opt_groups"]]
            if "default_opt" in backend_opts:
                raise ValueError(
                    "Do not add 'default_opt' when opt 'opt_groups'. " +
                    "Instead, append dawn4py.default_pass_groups()")
        else:
            pass_groups = dawn4py.default_pass_groups()

        # Fall back to the "GridTools" code generator when no backend name is set.
        dawn_backend = DAWN_CODEGEN_BACKENDS[self.DAWN_BACKEND_NAME or "GridTools"]

        # Forward only the options the dawn toolchain understands.
        dawn_opts = {
            key: value
            for key, value in backend_opts.items()
            if key in _DAWN_TOOLCHAIN_OPTIONS
        }
        source = dawn4py.compile(
            self.sir,
            groups=pass_groups,
            backend=dawn_backend,
            run_with_sync=False,
            **dawn_opts,
        )
        stencil_unique_name = self.pyext_class_name
        module_name = self.pyext_module_name
        pyext_sources = {f"_dawn_{stencil_short_name}.hpp": source}

        # Each API field gets a distinct layout id equal to its position.
        arg_fields = [
            {
                "name": field.name,
                "dtype": self._DATA_TYPE_TO_CPP[field.data_type],
                "layout_id": i,
            }
            for i, field in enumerate(self.builder.definition_ir.api_fields)
        ]
        header_file = "computation.hpp"

        # Map each scalar parameter's data type to the C++ type used in the templates.
        parameters = []
        for parameter in self.builder.definition_ir.parameters:
            if parameter.data_type == gt_ir.DataType.BOOL:
                dtype = "bool"
            elif parameter.data_type in (
                gt_ir.DataType.INT8,
                gt_ir.DataType.INT16,
                gt_ir.DataType.INT32,
                gt_ir.DataType.INT64,
            ):
                dtype = "int"
            elif parameter.data_type in (gt_ir.DataType.FLOAT32, gt_ir.DataType.FLOAT64):
                dtype = "double"
            else:
                # Explicit raise instead of ``assert False``: under ``python -O`` the
                # assert vanishes and ``dtype`` would silently keep the previous
                # parameter's value (or be unbound for the first parameter).
                raise ValueError(
                    f"Wrong data_type for parameter '{parameter.name}': "
                    f"{parameter.data_type}"
                )
            parameters.append({"name": parameter.name, "dtype": dtype})

        template_args = dict(
            arg_fields=arg_fields,
            dawn_namespace=dawn_namespace,
            gt_backend=gt_backend_t,
            header_file=header_file,
            module_name=module_name,
            parameters=parameters,
            stencil_short_name=stencil_short_name,
            stencil_unique_name=stencil_unique_name,
        )

        for key, file_name in self.TEMPLATE_FILES.items():
            with open(os.path.join(self.TEMPLATE_DIR, file_name), "r") as f:
                template = jinja2.Template(f.read())
            pyext_sources[key] = template.render(**template_args)

        return pyext_sources
Ejemplo n.º 2
0
def main():
    """Build the ICON Laplacian stencil as SIR and generate naive C++ for it.

    The stencil, named ``"ICON_laplacian_stencil"``, runs on an unstructured
    grid. It has a single forward vertical region from ``Interval.Start`` to
    ``Interval.End``, which computes in order:

    1. ``rot_vec``: sum over the Vertex -> Edge neighbours of ``vec * geofac_rot``.
    2. ``div_vec``: sum over the Cell -> Edge neighbours of ``vec * geofac_div``.
    3. ``nabla2t1_vec``: sum of ``rot_vec`` over the Edge -> Vertex neighbours,
       weighted by (-1.0, 1.0). The result is multiplied by
       ``tangent_orientation`` and divided by ``primal_edge_length``.
    4. ``nabla2t2_vec``: sum of ``div_vec`` over the Edge -> Cell neighbours,
       weighted by (-1.0, 1.0). The result is divided by ``dual_edge_length``.
    5. ``nabla2_vec = nabla2t2_vec - nabla2t1_vec``.

    All reductions use ``+`` with the double literal ``0.0`` as initial value.

    The SIR is written as JSON to ``<stencil_name>.sir`` for debugging and
    compiled with the ``CXXNaiveIco`` backend. The generated code is written to
    ``<stencil_name>.cpp``.
    """
    stencil_name = "ICON_laplacian_stencil"
    gen_outputfile = f"{stencil_name}.cpp"
    sir_outputfile = f"{stencil_name}.sir"

    loc = SIR.LocationType.Value

    def double_literal(value: str):
        """Return a new double-typed literal expression for ``value``."""
        return sir_utils.make_literal_access_expr(value, SIR.BuiltinType.Double)

    def field(name: str):
        """Return a plain (offset-less) access to field ``name``."""
        return sir_utils.make_field_access_expr(name)

    def neighbour_field(name: str):
        # `[True, 0]` offset: presumably "read at the reduction's neighbour,
        # vertical offset 0" -- confirm against sir_utils.make_field_access_expr.
        return sir_utils.make_field_access_expr(name, [True, 0])

    def binop(lhs, op: str, rhs):
        return sir_utils.make_binary_operator(lhs, op, rhs)

    def assign(lhs_name: str, rhs):
        """Return the statement ``lhs_name = rhs``."""
        return sir_utils.make_assignment_stmt(field(lhs_name), rhs, "=")

    interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    body_ast = sir_utils.make_ast(
        [
            assign(
                "rot_vec",
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=binop(neighbour_field("vec"), "*", field("geofac_rot")),
                    chain=[loc("Vertex"), loc("Edge")],
                ),
            ),
            assign(
                "div_vec",
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=binop(neighbour_field("vec"), "*", field("geofac_div")),
                    chain=[loc("Cell"), loc("Edge")],
                ),
            ),
            assign(
                "nabla2t1_vec",
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=neighbour_field("rot_vec"),
                    chain=[loc("Edge"), loc("Vertex")],
                    weights=[double_literal("-1.0"), double_literal("1.0")],
                ),
            ),
            assign(
                "nabla2t1_vec",
                binop(
                    binop(field("tangent_orientation"), "*", field("nabla2t1_vec")),
                    "/",
                    field("primal_edge_length"),
                ),
            ),
            assign(
                "nabla2t2_vec",
                sir_utils.make_reduction_over_neighbor_expr(
                    op="+",
                    init=double_literal("0.0"),
                    rhs=neighbour_field("div_vec"),
                    chain=[loc("Edge"), loc("Cell")],
                    weights=[double_literal("-1.0"), double_literal("1.0")],
                ),
            ),
            assign(
                "nabla2t2_vec",
                binop(field("nabla2t2_vec"), "/", field("dual_edge_length")),
            ),
            assign(
                "nabla2_vec",
                binop(field("nabla2t2_vec"), "-", field("nabla2t1_vec")),
            ),
        ]
    )

    vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward
    )

    # (name, location chain) of every stencil field, in declaration order.
    # Each field is declared with the trailing dimension argument ``1``. This
    # presumably enables the vertical dimension; confirm against
    # sir_utils.make_field_dimensions_unstructured.
    field_specs = [
        ("vec", ["Edge"]),
        ("div_vec", ["Cell"]),
        ("rot_vec", ["Vertex"]),
        ("nabla2t1_vec", ["Edge"]),
        ("nabla2t2_vec", ["Edge"]),
        ("nabla2_vec", ["Edge"]),
        ("primal_edge_length", ["Edge"]),
        ("dual_edge_length", ["Edge"]),
        ("tangent_orientation", ["Edge"]),
        ("geofac_rot", ["Vertex", "Edge"]),
        ("geofac_div", ["Cell", "Edge"]),
    ]
    fields = [
        sir_utils.make_field(
            name,
            sir_utils.make_field_dimensions_unstructured(
                [loc(location) for location in locations], 1
            ),
        )
        for name, locations in field_specs
    ]

    sir = sir_utils.make_sir(
        gen_outputfile,
        SIR.GridType.Value("Unstructured"),
        [
            sir_utils.make_stencil(
                stencil_name,
                sir_utils.make_ast([vertical_region_stmt]),
                fields,
            ),
        ],
    )

    # Write the SIR to file (for debugging purposes). The ``with`` block closes
    # the file even if serialization or writing raises.
    with open(sir_outputfile, "w") as f:
        f.write(sir_utils.to_json(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    print(f"Writing generated code to '{gen_outputfile}'")
    with open(gen_outputfile, "w") as f:
        f.write(code)
Ejemplo n.º 3
0
def main(args: argparse.Namespace):
    """Generate naive C++ code for a five-point Laplacian on a Cartesian grid.

    Builds a SIR stencil ``OUTPUT_NAME`` with fields ``in`` and ``out``. It has
    one forward vertical region from ``Interval.Start`` to ``Interval.End``,
    which computes (offsets written as ``[i, j, k]``)::

        out = (in[0,0,0] * -4.0
               + (in[1,0,0] + (in[-1,0,0] + (in[0,1,0] + in[0,-1,0]))))
              / (dx * dx)

    ``dx`` is an external global variable, declared in the SIR with the double
    value ``0.0``.

    Steps:

    * If ``args.verbose`` is set, the SIR is pretty-printed.
    * The SIR is serialized as JSON to
      ``./laplacian_stencil_from_python.sir``.
    * The SIR is compiled with the ``CXXNaive`` backend.
    * The generated code is written to ``OUTPUT_PATH``.
    """
    def in_at(i: int, j: int):
        """Access field ``in`` at offset ``[i, j, 0]``."""
        return serial_utils.make_field_access_expr("in", [i, j, 0])

    def add(lhs, rhs):
        return serial_utils.make_binary_operator(lhs, "+", rhs)

    def dx():
        """Return a new access expression for the external global ``dx``."""
        return serial_utils.make_var_access_expr("dx", is_external=True)

    interval = serial_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)

    # Build the laplace statement from named sub-expressions.
    center_term = serial_utils.make_binary_operator(
        in_at(0, 0),
        "*",
        serial_utils.make_literal_access_expr("-4.0", serial_utils.BuiltinType.Float),
    )
    # Right-nested: in[1,0,0] + (in[-1,0,0] + (in[0,1,0] + in[0,-1,0])).
    neighbour_sum = add(in_at(1, 0), add(in_at(-1, 0), add(in_at(0, 1), in_at(0, -1))))
    dx_squared = serial_utils.make_binary_operator(dx(), "*", dx())
    laplacian = serial_utils.make_binary_operator(
        add(center_term, neighbour_sum), "/", dx_squared
    )

    body_ast = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out", [0, 0, 0]),
            laplacian,
            "=",
        ),
    ])

    vertical_region_stmt = serial_utils.make_vertical_region_decl_stmt(
        body_ast, interval, SIR.VerticalRegion.Forward)

    # Declare the global "dx" used by the stencil, with double value 0.0.
    stencils_globals = serial_utils.GlobalVariableMap()
    stencils_globals.map["dx"].double_value = 0.0

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        SIR.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([vertical_region_stmt]),
                [
                    serial_utils.make_field(
                        "out", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "in", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
        global_variables=stencils_globals,
    )

    if args.verbose:
        serial_utils.pprint(sir)

    # Serialize the SIR to file. Text mode is required: to_json produces JSON
    # text (presumably a str, like protobuf's MessageToJson). Writing a str to a
    # binary-mode ("wb") file raises TypeError.
    with open("./laplacian_stencil_from_python.sir", "w") as sir_file:
        sir_file.write(serial_utils.to_json(sir))

    code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaive)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
Ejemplo n.º 4
0
def main(args: argparse.Namespace):
    """Generate CUDA code for a small unstructured stencil mixing field kinds.

    The stencil ``OUTPUT_NAME`` has one forward vertical region from
    ``Interval.Start`` to ``Interval.End``. It assigns the edge field ``out``
    twice:

    1. ``out = full + (horizontal + vertical)``
    2. ``out`` = ``+``-reduction of ``horizontal_sparse`` over the
       Edge -> Cell neighbour chain, with initial value ``1.0``. This
       overwrites the value from step 1.

    The SIR is printed as JSON when ``args.verbose`` is set and also written to
    ``out.json``. It is compiled with the ``CUDAIco`` backend, and the
    generated code is written to ``OUTPUT_PATH``.
    """
    edge = AST.LocationType.Value("Edge")
    cell = AST.LocationType.Value("Cell")
    access = serial_utils.make_field_access_expr
    binop = serial_utils.make_binary_operator

    # Statement 1: pointwise sum of the three non-sparse input fields.
    pointwise_sum = serial_utils.make_assignment_stmt(
        access("out"),
        binop(access("full"), "+", binop(access("horizontal"), "+", access("vertical"))),
        "=",
    )

    # Statement 2: neighbour reduction over the sparse field. The `[True, 0]`
    # offset presumably means "at the neighbour, vertical offset 0".
    sparse_reduction = serial_utils.make_assignment_stmt(
        access("out"),
        serial_utils.make_reduction_over_neighbor_expr(
            op="+",
            init=serial_utils.make_literal_access_expr("1.0", AST.BuiltinType.Float),
            rhs=access("horizontal_sparse", [True, 0]),
            chain=[edge, cell],
        ),
        "=",
    )

    region = serial_utils.make_vertical_region_decl_stmt(
        serial_utils.make_ast([pointwise_sum, sparse_reduction]),
        serial_utils.make_interval(AST.Interval.Start, AST.Interval.End, 0, 0),
        AST.VerticalRegion.Forward,
    )

    def unstructured_field(name, locations, k_flag):
        # k_flag is the second argument of make_field_dimensions_unstructured
        # (1 for the dense fields here, 0 for the horizontal-only ones). It is
        # presumably the vertical-dimension mask; confirm against serial_utils.
        return serial_utils.make_field(
            name, serial_utils.make_field_dimensions_unstructured(locations, k_flag)
        )

    fields = [
        unstructured_field("out", [edge], 1),
        unstructured_field("full", [edge], 1),
        unstructured_field("horizontal", [edge], 0),
        unstructured_field("horizontal_sparse", [edge, cell], 0),
        serial_utils.make_vertical_field("vertical"),
    ]

    stencil = serial_utils.make_stencil(
        OUTPUT_NAME, serial_utils.make_ast([region]), fields
    )
    sir = serial_utils.make_sir(
        OUTPUT_FILE, AST.GridType.Value("Unstructured"), [stencil]
    )

    if args.verbose:
        print(MessageToJson(sir))

    with open("out.json", "w+") as json_file:
        json_file.write(serial_utils.to_json(sir))

    generated = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CUDAIco)

    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as code_file:
        code_file.write(generated)