예제 #1
0
def test_module_data_equivalence():
    """Args data derived from the implementation IR and from the GTIR pipeline must agree."""
    stencil_builder = StencilBuilder(sample_stencil_with_args)

    # Keep the original evaluation order: IIR-derived data first, then GTIR-derived.
    from_iir = make_args_data_from_iir(stencil_builder.implementation_ir)
    from_gtir = make_args_data_from_gtir(stencil_builder.gtir_pipeline)

    assert from_iir == from_gtir
예제 #2
0
def test_make_args_data_from_gtir(backend_name, mode):
    """GTIR-derived args data must match the IIR-derived one attribute by attribute.

    ``backend_name`` selects the backend from ``backend_registry``; ``mode`` is
    passed to the stencil as the ``MODE`` external.
    """
    builder = StencilBuilder(
        stencil_def, backend=backend_registry[backend_name]
    ).with_externals({"MODE": mode})

    gtir_data = make_args_data_from_gtir(builder.gtir_pipeline)
    iir_data = make_args_data_from_iir(builder.implementation_ir)

    # Compare the three components separately so a failure pinpoints which one differs.
    for from_gtir, from_iir in (
        (gtir_data.field_info, iir_data.field_info),
        (gtir_data.parameter_info, iir_data.parameter_info),
        (gtir_data.unreferenced, iir_data.unreferenced),
    ):
        assert from_gtir == from_iir
예제 #3
0
def test_module_data():
    """Check the access kind inferred for each argument of ``sample_stencil_with_args``."""
    args = make_args_data_from_gtir(
        StencilBuilder(sample_stencil_with_args).gtir_pipeline
    )

    # Expected access kind per field argument, checked in this order.
    expected_field_access = {
        "used_io_field": AccessKind.WRITE,
        "used_in_field": AccessKind.READ,
        "unused_field": AccessKind.NONE,
    }
    for field_name, kind in expected_field_access.items():
        assert args.field_info[field_name].access == kind

    # Expected access kind per scalar parameter, checked in this order.
    expected_param_access = {
        "used_scalar": AccessKind.READ,
        "unused_scalar": AccessKind.NONE,
    }
    for param_name, kind in expected_param_access.items():
        assert args.parameter_info[param_name].access == kind
예제 #4
0
def test_generate_post_run(backend_name, mode):
    """Post-run code is empty for CPU backends; otherwise it marks ``out`` as device-modified."""
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls)
    stencil_builder = stencil_builder.with_externals({"MODE": mode})
    gtir_args = make_args_data_from_gtir(stencil_builder.gtir_pipeline)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = gtir_args
    post_run_source = generator.generate_post_run()

    on_cpu = gt_backend.from_name(backend_name).storage_info["device"] == "cpu"
    expected_source = "" if on_cpu else "out._set_device_modified()"
    assert post_run_source == expected_source
예제 #5
0
def test_make_args_data_from_gtir(backend_name, mode):
    """Validate GTIR-derived args data against the GTIR parameter declarations.

    Checks, for the stencil built with ``MODE=mode`` on ``backend_name``:
    * the unreferenced argument names match ``unreferenced_val[mode]``;
    * name/dtype/axes/data-dims of referenced fields match the ``gtir.FieldDecl``
      params, and name/dtype of referenced scalars match the ``gtir.ScalarDecl`` params;
    * each field/parameter carries the expected ``AccessKind``.
    """
    builder = StencilBuilder(
        stencil_def, backend=backend_registry[backend_name]
    ).with_externals({"MODE": mode})
    args_data = make_args_data_from_gtir(builder.gtir_pipeline)
    unreferenced = args_data.unreferenced

    assert set(unreferenced) == set(unreferenced_val[mode])

    # Field signatures as (name, dtype, upper-case axes string, data dims).
    expected_fields = set()
    for decl in builder.gtir.params:
        if isinstance(decl, gtir.FieldDecl):
            expected_fields.add(
                (
                    decl.name,
                    np.dtype(decl.dtype.name.lower()),
                    utils.dimension_flags_to_names(decl.dimensions).upper(),
                    decl.data_dims,
                )
            )
    actual_fields = set()
    for field_name, info in args_data.field_info.items():
        if field_name not in unreferenced:
            actual_fields.add(
                (field_name, info.dtype, "".join(info.axes), info.data_dims)
            )
    assert expected_fields == actual_fields

    # Scalar signatures as (name, dtype).
    expected_params = set()
    for decl in builder.gtir.params:
        if isinstance(decl, gtir.ScalarDecl):
            expected_params.add((decl.name, np.dtype(decl.dtype.name.lower())))
    actual_params = set()
    for param_name, info in args_data.parameter_info.items():
        if param_name not in unreferenced:
            actual_params.add((param_name, info.dtype))
    assert expected_params == actual_params

    # "out" is the only written field; fields listed for this mode are read; the rest unused.
    for field_name, info in args_data.field_info.items():
        expected_access = (
            AccessKind.WRITE
            if field_name == "out"
            else AccessKind.READ
            if field_name in field_info_val[mode]
            else AccessKind.NONE
        )
        assert info.access == expected_access

    for param_name, info in args_data.parameter_info.items():
        expected_access = (
            AccessKind.READ
            if param_name in parameter_info_val[mode]
            else AccessKind.NONE
        )
        assert info.access == expected_access
예제 #6
0
def test_generate_pre_run(backend_name, mode):
    """Pre-run code is empty on CPU; otherwise it copies exactly the referenced fields to device."""
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls)
    stencil_builder = stencil_builder.with_externals({"MODE": mode})
    gtir_args = make_args_data_from_gtir(stencil_builder.gtir_pipeline)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = gtir_args
    pre_run_source = generator.generate_pre_run()

    if gt_backend.from_name(backend_name).storage_info["device"] == "cpu":
        assert pre_run_source == ""
        return

    # Referenced fields are transferred host -> device ...
    for used_name in field_info_val[mode]:
        assert f"{used_name}.host_to_device()" in pre_run_source
    # ... unreferenced ones are not.
    for unused_name in unreferenced_val[mode]:
        assert f"{unused_name}.host_to_device()" not in pre_run_source
예제 #7
0
def test_device_sync_option(backend_name, mode, device_sync):
    """The ``device_sync`` backend option toggles the explicit device synchronize call."""
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls)
    stencil_builder = stencil_builder.with_externals({"MODE": mode})
    stencil_builder.options.backend_opts["device_sync"] = device_sync
    gtir_args = make_args_data_from_gtir(stencil_builder.gtir_pipeline)

    generate_module = backend_cls.MODULE_GENERATOR_CLASS()
    module_source = generate_module(
        gtir_args,
        stencil_builder,
        pyext_module_name=stencil_builder.module_name,
        pyext_file_path=str(stencil_builder.module_path),
    )

    # The synchronize call must be present exactly when device_sync is truthy.
    sync_call = "cupy.cuda.Device(0).synchronize()"
    assert (sync_call in module_source) == bool(device_sync)
예제 #8
0
def _expand_and_finalize_sdfg(stencil_ir: gtir.Stencil, sdfg: dace.SDFG,
                              layout_map) -> dace.SDFG:
    """Expand library nodes of ``sdfg`` and apply the pre/post expansion transformations.

    If no entry of the stencil's field info is set (including the case of no
    fields at all), the stencil is treated as having no effect: a fresh SDFG
    named after the stencil, with a single state of the same name, is returned
    and ``sdfg`` is not touched. Otherwise ``sdfg`` is modified in place and
    returned. ``layout_map`` is forwarded to the transient-stride specialization.
    """
    arg_info = make_args_data_from_gtir(GtirPipeline(stencil_ir))

    # Stencils without effect: no field carries any info.
    if not any(info is not None for info in arg_info.field_info.values()):
        empty_sdfg = dace.SDFG(stencil_ir.name)
        empty_sdfg.add_state(stencil_ir.name)
        return empty_sdfg

    # Give every transient array a persistent allocation lifetime.
    transient_arrays = (arr for arr in sdfg.arrays.values() if arr.transient)
    for arr in transient_arrays:
        arr.lifetime = dace.AllocationLifetime.Persistent

    # The order matters: pre-expansion trafos, expansion, stride specialization, post trafos.
    _pre_expand_trafos(sdfg)
    sdfg.expand_library_nodes(recursive=True)
    _specialize_transient_strides(sdfg, layout_map=layout_map)
    _post_expand_trafos(sdfg)

    return sdfg