Ejemplo n.º 1
0
def test_compilation(unstructured_sir_with_reference_code):
    """Smoke-test compiling an unstructured SIR with the naive ICO C++ backend.

    Two routes through dawn4py are exercised:

    * the one-shot ``dawn4py.compile`` entry point, and
    * the explicit pipeline ``lower_and_optimize`` (no pass groups), then
      ``optimize`` (default pass groups), then ``codegen``.

    Neither result is inspected; the test passes as long as no call raises.

    NOTE(review): the fixture also yields reference code that is currently
    unused here — confirm whether a comparison against it was intended.
    """
    sir, _reference_code = unstructured_sir_with_reference_code
    ico_backend = dawn4py.CodeGenBackend.CXXNaiveIco

    # Route 1: single-call compilation.
    dawn4py.compile(sir, backend=ico_backend)

    # Route 2: the same work split into explicit stages.
    lowered = dawn4py.lower_and_optimize(sir, groups=[])
    optimized = dawn4py.optimize(lowered,
                                 groups=dawn4py.default_pass_groups())
    dawn4py.codegen(optimized, backend=ico_backend)
def main(args: argparse.Namespace):
    """Build a three-stage unstructured SIR and write its generated C++ code.

    Every stage assigns to the cell field ``out`` over the whole vertical
    column. The stages differ in the horizontal range passed as the fourth
    argument of ``make_vertical_region_decl_stmt``. The SIR is compiled with
    the ``CXXNaiveIco`` backend, with the stage and multistage merger pass
    groups added to the defaults. The generated code is written to
    ``OUTPUT_PATH``.

    Args:
        args: parsed command-line arguments; only ``args.verbose`` is read.
            When truthy, the SIR is printed as JSON before compiling.
    """
    # One vertical interval, shared by every stage: the whole column.
    interval = serial_utils.make_interval(AST.Interval.Start, AST.Interval.End,
                                          0, 0)

    # Stage 1: out = in_1 on inner cells
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_field_access_expr("in_1"),
            "=",
        )
    ])
    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval, AST.VerticalRegion.Forward,
        serial_utils.make_magic_num_interval(0, 1, 0, 0))

    # Stage 2: out = out + in_2 on inner cells
    #   expected to be mergeable with the previous stage
    # NOTE(review): stages 2 and 3 build their horizontal range with
    # make_interval, while stage 1 uses make_magic_num_interval; confirm
    # which helper is intended for unstructured ranges.
    body_ast_2 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("out"),
                "+",
                serial_utils.make_field_access_expr("in_2"),
            ), "=")
    ])
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval, AST.VerticalRegion.Forward,
        serial_utils.make_interval(2, 3, 0, 0))

    # Stage 3: out = in_3 on lateral boundary cells
    # NOTE(review): this overwrites `out` rather than accumulating
    # (`out = out + in_3`); confirm which of the two is intended.
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("out"),
            serial_utils.make_field_access_expr("in_3"),
            "=",
        )
    ])
    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval, AST.VerticalRegion.Forward,
        serial_utils.make_interval(3, 4, 0, 0))

    stencil_body = serial_utils.make_ast([
        vertical_region_stmt_1, vertical_region_stmt_2,
        vertical_region_stmt_3
    ])
    # All four fields are declared identically, on cells. A fresh dimensions
    # message is built for each field, in declaration order.
    fields = [
        serial_utils.make_field(
            name,
            serial_utils.make_field_dimensions_unstructured(
                [AST.LocationType.Value("Cell")], 1),
        ) for name in ("in_1", "in_2", "in_3", "out")
    ]

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Unstructured"),
        [serial_utils.make_stencil(OUTPUT_NAME, stencil_body, fields)],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # Compile with the default pass groups plus explicit merging. Both inserts
    # target index 1, so the final order is: first default group,
    # StageMerger, MultiStageMerger, then the remaining default groups.
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    pass_groups.insert(1, dawn4py.PassGroup.StageMerger)
    code = dawn4py.compile(sir,
                           groups=pass_groups,
                           backend=dawn4py.CodeGenBackend.CXXNaiveIco)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)
Ejemplo n.º 3
0
    def make_extension_sources(self, gt_backend_t: str) -> Dict[str, Any]:
        """Generate the source files for this stencil's Python extension.

        Compiles ``self.sir`` with dawn and renders every template listed in
        ``self.TEMPLATE_FILES`` against the stencil's fields and parameters.

        Backend options read from ``self.builder.options.backend_opts``:

        * ``dump_sir``: if truthy, write the SIR as JSON before compiling. A
          ``str`` value is used as the file name; ``True`` selects
          ``<stencil_short_name>_gt4py.sir``.
        * ``no_opt``: if truthy, run no dawn pass groups.
        * ``opt_groups``: names looked up in ``DAWN_PASS_GROUPS`` to choose
          the pass groups explicitly; must not be combined with
          ``default_opt``.
        * any key of ``_DAWN_TOOLCHAIN_OPTIONS``: forwarded unchanged to
          ``dawn4py.compile``.

        Args:
            gt_backend_t: backend identifier handed to the templates as
                ``gt_backend``.

        Returns:
            Mapping of output name to source text. It contains the
            dawn-generated code under ``_dawn_<stencil_short_name>.hpp`` plus
            one rendered entry per key of ``self.TEMPLATE_FILES``.

        Raises:
            TypeError: ``dump_sir`` is truthy but neither ``str`` nor ``bool``.
            ValueError: ``opt_groups`` is combined with ``default_opt``, or a
                stencil parameter has an unsupported data type.
        """
        stencil_short_name = self.builder.stencil_id.qualified_name.split(
            ".")[-1]
        backend_opts = dict(**self.builder.options.backend_opts)
        dawn_namespace = self.DAWN_BACKEND_NS

        dump_sir_opt = backend_opts.get("dump_sir", False)
        if dump_sir_opt:
            if isinstance(dump_sir_opt, str):
                dump_sir_file = dump_sir_opt
            elif isinstance(dump_sir_opt, bool):
                dump_sir_file = f"{stencil_short_name}_gt4py.sir"
            else:
                # Explicit raise: an `assert` would be stripped under -O.
                raise TypeError(
                    "'dump_sir' must be a file name (str) or a bool, got "
                    f"{type(dump_sir_opt).__name__}")
            with open(dump_sir_file, "w", encoding="utf-8") as f:
                f.write(sir_utils.to_json(self.sir))

        # Select the dawn pass groups. `no_opt` is tested for truthiness (like
        # `dump_sir`), so an explicit `no_opt=False` keeps optimisation on.
        if backend_opts.get("no_opt", False):
            pass_groups = []
        elif "opt_groups" in backend_opts:
            if "default_opt" in backend_opts:
                raise ValueError(
                    "Do not add 'default_opt' when using 'opt_groups'. " +
                    "Instead, append dawn4py.default_pass_groups()")
            pass_groups = [
                DAWN_PASS_GROUPS[k] for k in backend_opts["opt_groups"]
            ]
        else:
            pass_groups = dawn4py.default_pass_groups()

        # Fall back to the "GridTools" code generator when no name is set.
        dawn_backend = DAWN_CODEGEN_BACKENDS[self.DAWN_BACKEND_NAME
                                             or "GridTools"]

        # Forward only the options dawn's toolchain understands.
        dawn_opts = {
            key: value
            for key, value in backend_opts.items()
            if key in _DAWN_TOOLCHAIN_OPTIONS
        }
        source = dawn4py.compile(self.sir,
                                 groups=pass_groups,
                                 backend=dawn_backend,
                                 run_with_sync=False,
                                 **dawn_opts)
        stencil_unique_name = self.pyext_class_name
        module_name = self.pyext_module_name
        pyext_sources = {f"_dawn_{stencil_short_name}.hpp": source}

        # `layout_id` is the field's position among the API fields.
        arg_fields = [{
            "name": field.name,
            "dtype": self._DATA_TYPE_TO_CPP[field.data_type],
            "layout_id": i
        } for i, field in enumerate(self.builder.definition_ir.api_fields)]
        header_file = "computation.hpp"

        # C++ type used for each supported scalar parameter type.
        # NOTE(review): every integer width (including INT64) maps to "int";
        # confirm that narrowing is acceptable for the templates.
        param_cpp_dtypes = {
            gt_ir.DataType.BOOL: "bool",
            gt_ir.DataType.INT8: "int",
            gt_ir.DataType.INT16: "int",
            gt_ir.DataType.INT32: "int",
            gt_ir.DataType.INT64: "int",
            gt_ir.DataType.FLOAT32: "double",
            gt_ir.DataType.FLOAT64: "double",
        }
        parameters = []
        for parameter in self.builder.definition_ir.parameters:
            dtype = param_cpp_dtypes.get(parameter.data_type)
            if dtype is None:
                # Explicit raise: under -O an `assert` would let a stale
                # `dtype` from a previous iteration leak into this entry.
                raise ValueError(
                    f"Wrong data_type for parameter '{parameter.name}': "
                    f"{parameter.data_type}")
            parameters.append({"name": parameter.name, "dtype": dtype})

        template_args = dict(
            arg_fields=arg_fields,
            dawn_namespace=dawn_namespace,
            gt_backend=gt_backend_t,
            header_file=header_file,
            module_name=module_name,
            parameters=parameters,
            stencil_short_name=stencil_short_name,
            stencil_unique_name=stencil_unique_name,
        )

        for key, file_name in self.TEMPLATE_FILES.items():
            template_path = os.path.join(self.TEMPLATE_DIR, file_name)
            with open(template_path, "r", encoding="utf-8") as f:
                template = jinja2.Template(f.read())
            pyext_sources[key] = template.render(**template_args)

        return pyext_sources
Ejemplo n.º 4
0
def main(args: argparse.Namespace):
    """Build a Cartesian tridiagonal-solve SIR and write its CUDA code.

    The stencil has three vertical regions over fields ``a``, ``b``, ``c``
    and ``d``:

    1. a forward region computing ``c = c / b``;
    2. a forward sweep from the second level up, computing
       ``m = 1 / (b - a * c[k-1])``, ``c *= m`` and
       ``d = (d - a * d[k-1]) * m``;
    3. a backward region computing ``d -= c * d[k+1]``.

    This is the shape of the Thomas algorithm; presumably ``a``, ``b`` and
    ``c`` hold the three diagonals and ``d`` the right-hand side, overwritten
    with the solution. The SIR is compiled for the CUDA backend with the
    multistage merger pass group added, and the result is written to
    ``OUTPUT_PATH``.

    Args:
        args: parsed command-line arguments; only ``args.verbose`` is read.
            When truthy, the SIR is printed as JSON before compiling.
    """

    # ---- First vertical region statement: c = c / b ----
    # NOTE(review): the classic Thomas algorithm normalises only the first
    # level here (and divides d by b as well); this region spans the whole
    # column and leaves d untouched — confirm this is intended.
    interval_1 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 0, 0)
    body_ast_1 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "/",
                serial_utils.make_field_access_expr("b"),
            ),
            "=",
        )
    ])

    vertical_region_stmt_1 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_1, interval_1, AST.VerticalRegion.Forward)

    # ---- Second vertical region statement: forward sweep ----
    # Starts one level above the bottom because it reads c[k-1] and d[k-1].
    interval_2 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 1, 0)

    body_ast_2 = serial_utils.make_ast([
        # m = 1.0 / (b - a * c[k-1]). The local must be floating point: it
        # is initialised from a floating division, and declaring it Integer
        # would truncate the factor in the generated code.
        serial_utils.make_var_decl_stmt(
            serial_utils.make_type(AST.BuiltinType.Float),
            "m",
            0,
            "=",
            serial_utils.make_expr(
                serial_utils.make_binary_operator(
                    serial_utils.make_literal_access_expr(
                        "1.0", AST.BuiltinType.Float),
                    "/",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("b"),
                        "-",
                        serial_utils.make_binary_operator(
                            serial_utils.make_field_access_expr("a"),
                            "*",
                            serial_utils.make_field_access_expr(
                                "c", [0, 0, -1]),
                        ),
                    ),
                )),
        ),
        # c = c * m
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("c"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"), "*",
                serial_utils.make_var_access_expr("m")),
            "=",
        ),
        # d = (d - a * d[k-1]) * m
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_binary_operator(
                    serial_utils.make_field_access_expr("d"),
                    "-",
                    serial_utils.make_binary_operator(
                        serial_utils.make_field_access_expr("a"),
                        "*",
                        serial_utils.make_field_access_expr("d", [0, 0, -1]),
                    ),
                ),
                "*",
                serial_utils.make_var_access_expr("m"),
            ),
            "=",
        ),
    ])
    vertical_region_stmt_2 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_2, interval_2, AST.VerticalRegion.Forward)

    # ---- Third vertical region statement: back substitution ----
    # Runs backward and reads d[k+1], so presumably its upper bound stops
    # one level below the top (upper offset -1).
    interval_3 = serial_utils.make_interval(AST.Interval.Start,
                                            AST.Interval.End, 0, -1)
    body_ast_3 = serial_utils.make_ast([
        serial_utils.make_assignment_stmt(
            serial_utils.make_field_access_expr("d"),
            serial_utils.make_binary_operator(
                serial_utils.make_field_access_expr("c"),
                "*",
                serial_utils.make_field_access_expr("d", [0, 0, 1]),
            ),
            "-=",
        )
    ])

    vertical_region_stmt_3 = serial_utils.make_vertical_region_decl_stmt(
        body_ast_3, interval_3, AST.VerticalRegion.Backward)

    sir = serial_utils.make_sir(
        OUTPUT_FILE,
        AST.GridType.Value("Cartesian"),
        [
            serial_utils.make_stencil(
                OUTPUT_NAME,
                serial_utils.make_ast([
                    vertical_region_stmt_1, vertical_region_stmt_2,
                    vertical_region_stmt_3
                ]),
                [
                    serial_utils.make_field(
                        "a", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "b", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "c", serial_utils.make_field_dimensions_cartesian()),
                    serial_utils.make_field(
                        "d", serial_utils.make_field_dimensions_cartesian()),
                ],
            )
        ],
    )

    # print the SIR
    if args.verbose:
        print(MessageToJson(sir))

    # Compile: default pass groups, with MultiStageMerger inserted right
    # after the first default group.
    pass_groups = dawn4py.default_pass_groups()
    pass_groups.insert(1, dawn4py.PassGroup.MultiStageMerger)
    code = dawn4py.compile(sir,
                           groups=pass_groups,
                           backend=dawn4py.CodeGenBackend.CUDA)

    # write to file
    print(f"Writing generated code to '{OUTPUT_PATH}'")
    with open(OUTPUT_PATH, "w") as f:
        f.write(code)