Example #1
0
def test_make_args_data_from_gtir(backend_name, mode):
    """Args data built from the GTIR pipeline must match the IIR-derived one.

    Builds ``stencil_def`` for the backend registered under ``backend_name``
    with the ``MODE`` external set to ``mode``. It then compares field info,
    parameter info and unreferenced arguments produced by
    ``make_args_data_from_gtir`` and ``make_args_data_from_iir``.
    """
    builder = StencilBuilder(
        stencil_def, backend=backend_registry[backend_name]
    ).with_externals({"MODE": mode})

    # The GTIR pipeline is accessed first, then the implementation IR,
    # matching the original evaluation order.
    from_gtir = make_args_data_from_gtir(builder.gtir_pipeline)
    from_iir = make_args_data_from_iir(builder.implementation_ir)

    for attribute in ("field_info", "parameter_info", "unreferenced"):
        assert getattr(from_gtir, attribute) == getattr(from_iir, attribute)
Example #2
0
def test_generate_post_run(backend_name, mode):
    """Check the post-run source emitted by the backend's module generator.

    On backends whose storage device is ``"cpu"``, the generated post-run
    code must be empty. Otherwise it must be exactly
    ``out._set_device_modified()``.
    """
    backend_cls = backend_registry[backend_name]
    implementation_ir = (
        StencilBuilder(stencil_def, backend=backend_cls)
        .with_externals({"MODE": mode})
        .implementation_ir
    )
    iir_args = make_args_data_from_iir(implementation_ir)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = iir_args
    post_run_source = generator.generate_post_run()

    on_cpu = gt_backend.from_name(backend_name).storage_info["device"] == "cpu"
    expected_source = "" if on_cpu else "out._set_device_modified()"
    assert post_run_source == expected_source
Example #3
0
def test_generate_pre_run(backend_name, mode):
    """Check the pre-run source emitted by the backend's module generator.

    On CPU-storage backends, the pre-run code must be empty, except for
    ``dawn:`` backends, which are not checked. On other devices, every field
    listed in ``field_info_val[mode]`` must get a ``host_to_device()`` call.
    No field in ``unreferenced_val[mode]`` may get one.
    """
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    iir_args = make_args_data_from_iir(builder.implementation_ir)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = iir_args
    pre_run_source = generator.generate_pre_run()

    device = gt_backend.from_name(backend_name).storage_info["device"]
    if device == "cpu":
        # NOTE(review): for "dawn:" backends on CPU nothing is asserted;
        # presumably their pre-run output is non-empty. Confirm intent.
        if "dawn:" not in backend_name:
            assert pre_run_source == ""
        return

    for referenced in field_info_val[mode]:
        assert f"{referenced}.host_to_device()" in pre_run_source
    for unreferenced in unreferenced_val[mode]:
        assert f"{unreferenced}.host_to_device()" not in pre_run_source
Example #4
0
def test_device_sync_option(backend_name, mode, device_sync):
    """The ``device_sync`` backend option toggles an explicit CuPy device sync.

    Sets ``backend_opts["device_sync"]`` on the builder and renders the full
    module source. The source must contain
    ``cupy.cuda.Device(0).synchronize()`` exactly when ``device_sync`` is
    truthy.
    """
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    builder.options.backend_opts["device_sync"] = device_sync

    # Legacy-toolchain backends derive args data from the implementation IR;
    # the others use the GTIR pipeline.
    args_data = (
        make_args_data_from_iir(builder.implementation_ir)
        if backend_cls.USE_LEGACY_TOOLCHAIN
        else make_args_data_from_gtir(builder.gtir_pipeline)
    )

    module_source = backend_cls.MODULE_GENERATOR_CLASS()(
        args_data,
        builder,
        pyext_module_name=builder.module_name,
        pyext_file_path=str(builder.module_path),
    )

    sync_call = "cupy.cuda.Device(0).synchronize()"
    assert (sync_call in module_source) == bool(device_sync)