def test_make_args_data_from_gtir(backend_name, mode):
    """Args data built from the GTIR pipeline must match the IIR-based args data.

    The stencil is built for the backend registered under ``backend_name`` with
    the ``MODE`` external set to ``mode``; both args-data objects are then
    compared attribute by attribute.
    """
    stencil_builder = StencilBuilder(
        stencil_def, backend=backend_registry[backend_name]
    ).with_externals({"MODE": mode})

    from_gtir = make_args_data_from_gtir(stencil_builder.gtir_pipeline)
    from_iir = make_args_data_from_iir(stencil_builder.implementation_ir)

    # Same three comparisons, in the same order, as separate assertions.
    for attr_name in ("field_info", "parameter_info", "unreferenced"):
        assert getattr(from_gtir, attr_name) == getattr(from_iir, attr_name)
def test_generate_post_run(backend_name, mode):
    """Check the post-run snippet emitted by the backend's module generator.

    For backends whose storage device is ``"cpu"`` the generated source must be
    empty; for every other device it must be exactly the call that marks
    ``out`` as modified on the device.
    """
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals(
        {"MODE": mode}
    )

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = make_args_data_from_iir(stencil_builder.implementation_ir)
    source = generator.generate_post_run()

    device = gt_backend.from_name(backend_name).storage_info["device"]
    expected_source = "" if device == "cpu" else "out._set_device_modified()"
    assert source == expected_source
def test_generate_pre_run(backend_name, mode):
    """Check the pre-run snippet emitted by the backend's module generator.

    CPU backends must generate an empty snippet, except that backends whose
    name contains ``"dawn:"`` are not checked. For other devices, every key of
    ``field_info_val[mode]`` must have a ``host_to_device()`` call in the
    snippet and no key of ``unreferenced_val[mode]`` may have one.
    """
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals(
        {"MODE": mode}
    )

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = make_args_data_from_iir(stencil_builder.implementation_ir)
    source = generator.generate_pre_run()

    # NOTE(review): the host_to_device checks are taken to be the alternative
    # to the "cpu" device test (not to the "dawn:" test) -- confirm.
    if gt_backend.from_name(backend_name).storage_info["device"] == "cpu":
        if "dawn:" not in backend_name:
            assert source == ""
        return

    for field_name in field_info_val[mode]:
        assert f"{field_name}.host_to_device()" in source
    for field_name in unreferenced_val[mode]:
        assert f"{field_name}.host_to_device()" not in source
def test_device_sync_option(backend_name, mode, device_sync):
    """The generated module synchronizes the device iff ``device_sync`` is set.

    The ``device_sync`` backend option is written into the builder options, the
    args data is produced by whichever toolchain the backend declares through
    ``USE_LEGACY_TOOLCHAIN``, and the full module source is then searched for
    the cupy synchronize call.
    """
    backend_cls = backend_registry[backend_name]
    stencil_builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals(
        {"MODE": mode}
    )
    stencil_builder.options.backend_opts["device_sync"] = device_sync

    # Only the IR required by the selected toolchain is requested from the builder.
    args_data = (
        make_args_data_from_iir(stencil_builder.implementation_ir)
        if backend_cls.USE_LEGACY_TOOLCHAIN
        else make_args_data_from_gtir(stencil_builder.gtir_pipeline)
    )

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    source = generator(
        args_data,
        stencil_builder,
        pyext_module_name=stencil_builder.module_name,
        pyext_file_path=str(stencil_builder.module_path),
    )

    has_sync_call = "cupy.cuda.Device(0).synchronize()" in source
    assert has_sync_call == bool(device_sync)