Exemplo n.º 1
0
def is_arch_supported(arch):
    """Checks whether an arch is supported on the machine.

    Archs missing from the lookup table are reported as unsupported. If the
    availability probe for an arch raises, a warning naming the arch is
    emitted and the arch is treated as unsupported.

    Args:
        arch (taichi_core.Arch): Specified arch.

    Returns:
        bool: Whether `arch` is supported on the machine.
    """
    arch_table = {
        cuda: _ti_core.with_cuda,
        metal: _ti_core.with_metal,
        opengl: _ti_core.with_opengl,
        cc: _ti_core.with_cc,
        # The attribute lookup happens inside the lambda, i.e. at call time,
        # so a failure there is caught by the try below rather than raised
        # while building this table.
        vulkan: lambda: _ti_core.with_vulkan(),
        wasm: lambda: True,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Use a separate name instead of rebinding the `arch` parameter.
        name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{name}, consider adding `TI_ENABLE_{name.upper()}=0` "
            f"to environment variables to suppress this warning message.")
        return False
Exemplo n.º 2
0
def supported_archs():
    """Gets all supported archs on the machine.

    The candidate archs can be narrowed with the ``TI_WANTED_ARCHS``
    environment variable. A comma-separated list of arch names keeps only
    those archs. A leading ``^`` (e.g. ``^cuda,opengl``) excludes the listed
    archs instead. Surrounding whitespace and empty entries are ignored. The
    remaining candidates are then filtered through ``is_arch_supported``.

    Returns:
        List[taichi_core.Arch]: All supported archs on the machine.
    """
    archs = [cpu, cuda, metal, opengl, cc]

    wanted_spec = os.environ.get('TI_WANTED_ARCHS', '').strip()
    want_exclude = wanted_spec.startswith('^')
    if want_exclude:
        wanted_spec = wanted_spec[1:]
    # Note: ''.split(',') gives [''], so empty entries are removed explicitly.
    # Entries are stripped so that "cpu, cuda" matches "cuda" as well.
    wanted_archs = {name.strip() for name in wanted_spec.split(',')} - {''}
    if wanted_archs:
        # Keep listed archs in include mode, and unlisted ones in exclude mode.
        archs = [
            arch for arch in archs
            if (_ti_core.arch_name(arch) in wanted_archs) != want_exclude
        ]

    return [arch for arch in archs if is_arch_supported(arch)]
Exemplo n.º 3
0
def is_arch_supported(arch):
    """Checks whether an arch is supported on the machine.

    Archs missing from the lookup table are reported as unsupported. If the
    availability probe for an arch raises, a warning naming the arch is
    emitted and the arch is treated as unsupported.

    Args:
        arch (taichi_core.Arch): Specified arch.

    Returns:
        bool: Whether `arch` is supported on the machine.
    """
    arch_table = {
        cuda: _ti_core.with_cuda,
        metal: _ti_core.with_metal,
        opengl: _ti_core.with_opengl,
        cc: _ti_core.with_cc,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Use a separate name instead of rebinding the `arch` parameter.
        name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{name}, consider adding `export TI_WITH_{name.upper()}=0` "
            f"to environment variables to suppress this warning message.")
        return False
Exemplo n.º 4
0
def stat_write(key, value):
    """Records a benchmark statistic in ``benchmark.yml``.

    Does nothing unless the ``TI_CURRENT_BENCHMARK`` environment variable is
    set. The value is stored at ``data[case][key][arch][mode]``, where:

    - ``case`` is the benchmark name with any leading ``benchmark_`` removed;
    - ``arch`` is the name of ``ti.cfg.arch``;
    - ``mode`` is ``'async'`` or ``'sync'`` depending on ``ti.cfg.async_mode``.

    The file is located in ``TI_BENCHMARK_OUTPUT_DIR`` (default: the current
    directory). Existing entries in the file are preserved, and a missing or
    empty file starts from an empty mapping.

    Args:
        key: Name of the statistic (presumably a str; used as a YAML key).
        value: Value to record; must be serializable by ``yaml.SafeDumper``.
    """
    import yaml
    case_name = os.environ.get('TI_CURRENT_BENCHMARK')
    if case_name is None:
        return
    prefix = 'benchmark_'
    if case_name.startswith(prefix):
        case_name = case_name[len(prefix):]
    arch_name = _ti_core.arch_name(ti.cfg.arch)
    async_mode = 'async' if ti.cfg.async_mode else 'sync'
    output_dir = os.environ.get('TI_BENCHMARK_OUTPUT_DIR', '.')
    filename = os.path.join(output_dir, 'benchmark.yml')
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it like no file.
            data = yaml.load(f, Loader=yaml.SafeLoader) or {}
    except FileNotFoundError:
        data = {}
    # Create the intermediate levels on demand, keeping existing entries.
    arch_entry = (data.setdefault(case_name, {})
                  .setdefault(key, {})
                  .setdefault(arch_name, {}))
    arch_entry[async_mode] = value
    with open(filename, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, Dumper=yaml.SafeDumper)