def is_arch_supported(arch):
    """Checks whether an arch is supported on the machine.

    Args:
        arch (taichi_core.Arch): Specified arch.

    Returns:
        bool: Whether `arch` is supported on the machine. Archs not listed
        in the probe table below, and archs whose probe raises, are
        reported as unsupported.
    """
    # Each value is a zero-argument probe. The vulkan probe is wrapped in a
    # lambda so the `_ti_core.with_vulkan` attribute lookup happens at call
    # time, inside the try block below, rather than while building the dict.
    arch_table = {
        cuda: _ti_core.with_cuda,
        metal: _ti_core.with_metal,
        opengl: _ti_core.with_opengl,
        cc: _ti_core.with_cc,
        vulkan: lambda: _ti_core.with_vulkan(),
        wasm: lambda: True,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Detection is best-effort: a failing probe is downgraded to a
        # warning and the arch is treated as unsupported.
        arch_name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{arch_name}, consider adding `TI_ENABLE_{arch_name.upper()}=0` "
            f"to environment variables to suppress this warning message.")
        return False
def supported_archs():
    """Gets all supported archs on the machine.

    The candidate archs can be narrowed with the ``TI_WANTED_ARCHS``
    environment variable. It holds a comma-separated list of arch names,
    and only those archs are kept. A leading ``^`` (e.g. ``^cuda,metal``)
    excludes the listed archs instead. Whitespace around each name is
    ignored, and an empty or unset variable applies no filtering.

    Returns:
        List[taichi_core.Arch]: All supported archs on the machine.
    """
    archs = [cpu, cuda, metal, opengl, cc]

    wanted_archs = os.environ.get('TI_WANTED_ARCHS', '')
    want_exclude = wanted_archs.startswith('^')
    if want_exclude:
        wanted_archs = wanted_archs[1:]
    # ''.split(',') gives [''], so empty entries are dropped. Entries are
    # stripped so that "cpu, cuda" selects the same archs as "cpu,cuda".
    wanted = {name.strip() for name in wanted_archs.split(',')} - {''}
    if wanted:
        # Include mode keeps listed archs; exclude mode keeps unlisted ones.
        archs = [
            arch for arch in archs
            if (_ti_core.arch_name(arch) in wanted) != want_exclude
        ]

    return [arch for arch in archs if is_arch_supported(arch)]
def is_arch_supported(arch):
    """Checks whether an arch is supported on the machine.

    Args:
        arch (taichi_core.Arch): Specified arch.

    Returns:
        bool: Whether `arch` is supported on the machine. Archs not listed
        in the probe table below, and archs whose probe raises, are
        reported as unsupported.
    """
    # NOTE(review): if this definition shares a module with another
    # `is_arch_supported` (e.g. one that also probes vulkan/wasm), whichever
    # is defined last wins at import time. Confirm which one is intended.
    arch_table = {
        cuda: _ti_core.with_cuda,
        metal: _ti_core.with_metal,
        opengl: _ti_core.with_opengl,
        cc: _ti_core.with_cc,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Detection is best-effort: a failing probe is downgraded to a
        # warning and the arch is treated as unsupported.
        arch_name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{arch_name}, consider adding `TI_ENABLE_{arch_name.upper()}=0` "
            f"to environment variables to suppress this warning message.")
        return False
def stat_write(key, value):
    """Records one benchmark statistic in ``benchmark.yml``.

    Does nothing unless ``TI_CURRENT_BENCHMARK`` is set. Otherwise the value
    is stored at ``data[case][key][arch][mode]``, where:

    - ``case`` is the benchmark name with any leading ``benchmark_`` removed;
    - ``arch`` is the name of ``ti.cfg.arch``;
    - ``mode`` is ``'async'`` or ``'sync'`` according to ``ti.cfg.async_mode``.

    The file is placed in ``TI_BENCHMARK_OUTPUT_DIR`` (default: current
    directory). Existing entries in it are preserved, and a missing or empty
    file starts a fresh mapping.

    Args:
        key: Name of the statistic.
        value: Value to record; must be serializable by ``yaml.SafeDumper``.
    """
    case_name = os.environ.get('TI_CURRENT_BENCHMARK')
    if case_name is None:
        return
    # Imported lazily: yaml is only needed when a benchmark is being recorded.
    import yaml

    prefix = 'benchmark_'
    if case_name.startswith(prefix):
        case_name = case_name[len(prefix):]
    arch_name = _ti_core.arch_name(ti.cfg.arch)
    async_mode = 'async' if ti.cfg.async_mode else 'sync'
    output_dir = os.environ.get('TI_BENCHMARK_OUTPUT_DIR', '.')
    filename = f'{output_dir}/benchmark.yml'

    try:
        with open(filename, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it like no file.
            data = yaml.load(f, Loader=yaml.SafeLoader) or {}
    except FileNotFoundError:
        data = {}

    arch_entry = (data.setdefault(case_name, {})
                  .setdefault(key, {})
                  .setdefault(arch_name, {}))
    arch_entry[async_mode] = value

    with open(filename, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, Dumper=yaml.SafeDumper)