def test_module_data_equivalence():
    """Args data derived from the legacy IIR must equal args data derived from GTIR.

    Both are built from the same ``StencilBuilder`` over ``sample_stencil_with_args``
    and compared as whole objects.
    """
    stencil_builder = StencilBuilder(sample_stencil_with_args)
    from_iir = make_args_data_from_iir(stencil_builder.implementation_ir)
    from_gtir = make_args_data_from_gtir(stencil_builder.gtir_pipeline)
    assert from_iir == from_gtir
def test_make_args_data_from_gtir_matches_iir(backend_name, mode):
    """Check that GTIR-derived args data agrees with the legacy IIR-derived args data.

    Field info, parameter info and the unreferenced-argument list returned by
    ``make_args_data_from_gtir`` must equal those returned by
    ``make_args_data_from_iir`` for the same stencil, backend and ``MODE`` external.

    This test is named ``..._matches_iir`` because the module also defines a separate
    ``test_make_args_data_from_gtir``. Two ``def`` statements with the same name in one
    module make the later one rebind the name, so pytest would never collect this check.
    """
    # NOTE(review): backend_name and mode are presumably supplied by pytest
    # parametrization/fixtures that are not visible here — confirm.
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    gtir_args_data = make_args_data_from_gtir(builder.gtir_pipeline)
    iir_args_data = make_args_data_from_iir(builder.implementation_ir)

    assert gtir_args_data.field_info == iir_args_data.field_info
    assert gtir_args_data.parameter_info == iir_args_data.parameter_info
    assert gtir_args_data.unreferenced == iir_args_data.unreferenced
def test_module_data():
    """Check the access kinds GTIR analysis assigns to each argument of the sample stencil.

    ``used_io_field`` must be WRITE, ``used_in_field`` READ and ``unused_field`` NONE.
    Among scalars, ``used_scalar`` must be READ and ``unused_scalar`` NONE.
    """
    stencil_builder = StencilBuilder(sample_stencil_with_args)
    data = make_args_data_from_gtir(stencil_builder.gtir_pipeline)

    # Checked in this order: fields first, then scalar parameters.
    expected_field_access = {
        "used_io_field": AccessKind.WRITE,
        "used_in_field": AccessKind.READ,
        "unused_field": AccessKind.NONE,
    }
    expected_param_access = {
        "used_scalar": AccessKind.READ,
        "unused_scalar": AccessKind.NONE,
    }

    for field_name, kind in expected_field_access.items():
        assert data.field_info[field_name].access == kind, field_name
    for param_name, kind in expected_param_access.items():
        assert data.parameter_info[param_name].access == kind, param_name
def test_generate_post_run(backend_name, mode):
    """Check the post-run snippet produced by the backend's module generator.

    For a backend whose storage device is ``"cpu"``, the snippet must be empty.
    Otherwise it must be exactly ``out._set_device_modified()``.
    """
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    module_args = make_args_data_from_gtir(builder.gtir_pipeline)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = module_args
    post_run_source = generator.generate_post_run()

    runs_on_cpu = gt_backend.from_name(backend_name).storage_info["device"] == "cpu"
    expected_source = "" if runs_on_cpu else "out._set_device_modified()"
    assert post_run_source == expected_source
def test_make_args_data_from_gtir(backend_name, mode):
    """Check ``make_args_data_from_gtir`` against the GTIR declarations and expected accesses.

    The test asserts that:

    * the unreferenced argument names equal ``unreferenced_val[mode]``;
    * every referenced field's ``(name, dtype, axes, data_dims)`` matches the
      ``gtir.FieldDecl`` params of ``builder.gtir``;
    * every referenced scalar's ``(name, dtype)`` matches the ``gtir.ScalarDecl`` params;
    * field ``"out"`` is WRITE, fields in ``field_info_val[mode]`` are READ, other
      fields are NONE;
    * parameters in ``parameter_info_val[mode]`` are READ and the others are NONE.
    """
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    args_data = make_args_data_from_gtir(builder.gtir_pipeline)
    skipped = args_data.unreferenced

    assert set(skipped) == set(unreferenced_val[mode])

    # Fields: compare the GTIR declarations with the referenced entries of the args data.
    gtir_fields = set()
    for decl in builder.gtir.params:
        if not isinstance(decl, gtir.FieldDecl):
            continue
        axes = utils.dimension_flags_to_names(decl.dimensions).upper()
        gtir_fields.add((decl.name, np.dtype(decl.dtype.name.lower()), axes, decl.data_dims))
    data_fields = {
        (field_name, info.dtype, "".join(info.axes), info.data_dims)
        for field_name, info in args_data.field_info.items()
        if field_name not in skipped
    }
    assert gtir_fields == data_fields

    # Scalars: the same comparison, restricted to name and dtype.
    gtir_scalars = set()
    for decl in builder.gtir.params:
        if isinstance(decl, gtir.ScalarDecl):
            gtir_scalars.add((decl.name, np.dtype(decl.dtype.name.lower())))
    data_scalars = {
        (param_name, info.dtype)
        for param_name, info in args_data.parameter_info.items()
        if param_name not in skipped
    }
    assert gtir_scalars == data_scalars

    def expected_field_access(field_name):
        # "out" is the only written field; read-only fields are listed per mode.
        if field_name == "out":
            return AccessKind.WRITE
        return AccessKind.READ if field_name in field_info_val[mode] else AccessKind.NONE

    for field_name, info in args_data.field_info.items():
        assert info.access == expected_field_access(field_name)

    for param_name, info in args_data.parameter_info.items():
        wanted = AccessKind.READ if param_name in parameter_info_val[mode] else AccessKind.NONE
        assert info.access == wanted
def test_generate_pre_run(backend_name, mode):
    """Check the pre-run snippet produced by the backend's module generator.

    For a backend whose storage device is ``"cpu"``, the snippet must be empty.
    Otherwise it must contain ``<name>.host_to_device()`` for every name in
    ``field_info_val[mode]``. It must not contain that call for any name in
    ``unreferenced_val[mode]``.
    """
    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    module_args = make_args_data_from_gtir(builder.gtir_pipeline)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    generator.args_data = module_args
    pre_run_source = generator.generate_pre_run()

    if gt_backend.from_name(backend_name).storage_info["device"] == "cpu":
        assert pre_run_source == ""
        return

    def copies_to_device(field_name):
        return f"{field_name}.host_to_device()" in pre_run_source

    for field_name in field_info_val[mode]:
        assert copies_to_device(field_name), field_name
    for field_name in unreferenced_val[mode]:
        assert not copies_to_device(field_name), field_name
def test_device_sync_option(backend_name, mode, device_sync):
    """Check that the ``device_sync`` backend option toggles the device synchronize call.

    The generated module source must contain ``cupy.cuda.Device(0).synchronize()``
    exactly when ``device_sync`` is truthy.
    """
    sync_call = "cupy.cuda.Device(0).synchronize()"

    backend_cls = backend_registry[backend_name]
    builder = StencilBuilder(stencil_def, backend=backend_cls).with_externals({"MODE": mode})
    builder.options.backend_opts["device_sync"] = device_sync
    module_args = make_args_data_from_gtir(builder.gtir_pipeline)

    generator = backend_cls.MODULE_GENERATOR_CLASS()
    module_source = generator(
        module_args,
        builder,
        pyext_module_name=builder.module_name,
        pyext_file_path=str(builder.module_path),
    )

    assert (sync_call in module_source) == bool(device_sync)
def _expand_and_finalize_sdfg(stencil_ir: gtir.Stencil, sdfg: dace.SDFG, layout_map) -> dace.SDFG:
    """Expand library nodes of ``sdfg`` and apply the pre/post expansion transformations.

    If every entry of the stencil's args-data ``field_info`` is ``None``, a fresh SDFG
    is returned instead. This includes a stencil with no fields at all. The fresh SDFG
    is named after the stencil and holds a single state with that name. Otherwise
    ``sdfg`` is modified in place and returned:

    1. every transient array gets a persistent allocation lifetime;
    2. ``_pre_expand_trafos`` runs;
    3. library nodes are expanded recursively;
    4. transient strides are specialized for ``layout_map``;
    5. ``_post_expand_trafos`` runs.
    """
    module_args = make_args_data_from_gtir(GtirPipeline(stencil_ir))

    # Stencils without effect. NOTE(review): a None entry presumably marks an unused field.
    if not any(info is not None for info in module_args.field_info.values()):
        placeholder = dace.SDFG(stencil_ir.name)
        placeholder.add_state(stencil_ir.name)
        return placeholder

    for descriptor in (desc for desc in sdfg.arrays.values() if desc.transient):
        descriptor.lifetime = dace.AllocationLifetime.Persistent

    _pre_expand_trafos(sdfg)
    sdfg.expand_library_nodes(recursive=True)
    _specialize_transient_strides(sdfg, layout_map=layout_map)
    _post_expand_trafos(sdfg)
    return sdfg