def test_prefer_gpu():
    """prefer_gpu() should activate CupyOps when cupy is importable,
    and report failure (falsy return) when it is not."""
    saved_ops = get_current_ops()
    try:
        import cupy  # noqa: F401

        prefer_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        # No cupy installed: prefer_gpu() must decline rather than raise.
        assert not prefer_gpu()
    # Restore whatever ops backend was active before the test.
    set_current_ops(saved_ops)
def test_require_gpu():
    """require_gpu() should activate CupyOps when cupy is importable,
    and raise ValueError when no GPU stack is available."""
    saved_ops = get_current_ops()
    try:
        import cupy  # noqa: F401

        require_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        # Without cupy the requirement cannot be met: expect a hard error.
        with pytest.raises(ValueError):
            require_gpu()
    # Restore whatever ops backend was active before the test.
    set_current_ops(saved_ops)
def test_require_cpu():
    """require_cpu() should force NumpyOps, including after a GPU switch."""
    saved_ops = get_current_ops()
    # First call: CPU must always be available.
    require_cpu()
    assert isinstance(get_current_ops(), NumpyOps)
    # If a GPU stack exists, switch to it so we can verify the way back.
    try:
        import cupy  # noqa: F401

        require_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        pass
    # Second call: must restore NumpyOps even after a GPU was active.
    require_cpu()
    assert isinstance(get_current_ops(), NumpyOps)
    set_current_ops(saved_ops)
def set_backend(name, gpu_id):
    """Activate the requested ops backend and patch the module-level CONFIG.

    name: "jax" selects JaxOps (and strips PyTorch layers from CONFIG);
          anything else selects NumPy/CuPy and swaps in the PyTorch LSTM.
    gpu_id: -1 means CPU (NumpyOps); any other value selects CupyOps.
    """
    global CONFIG
    if name == "jax":
        set_current_ops(JaxOps())
        # Jax backend cannot host PyTorch layers: drop them from the config.
        CONFIG = CONFIG.replace("PyTorch", "")
    else:
        ops = NumpyOps() if gpu_id == -1 else CupyOps()
        set_current_ops(ops)
        # Use the wrapped PyTorch LSTM implementation on this backend.
        CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")
def set_backend(name, gpu_id):
    """Activate the requested ops backend and patch the module-level CONFIG.

    name: "generic" selects the base Ops; "pytorch" additionally pins torch
          to one thread and swaps the PyTorch LSTM into CONFIG.
    gpu_id: -1 means CPU (NumpyOps with BLIS); any other value means CupyOps.

    NOTE(review): original indentation was lost; the nesting of the
    "pytorch" branch below is reconstructed — confirm against upstream.
    """
    global CONFIG
    if name == "generic":
        set_current_ops(Ops())
    else:
        ops = NumpyOps(use_blis=True) if gpu_id == -1 else CupyOps()
        set_current_ops(ops)
        if name == "pytorch":
            import torch

            # Avoid oversubscription when torch runs alongside our ops.
            torch.set_num_threads(1)
            CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")
def set_backend(name, gpu_id):
    """Activate NumPy or CuPy ops and swap the PyTorch LSTM into CONFIG.

    name: backend label (unused here beyond the caller's bookkeeping).
    gpu_id: -1 means CPU (NumpyOps); any other value selects CupyOps.
    """
    # Fix: the sibling set_backend definitions declare ``global CONFIG``;
    # without it, the assignment below makes CONFIG function-local and the
    # read on the right-hand side raises UnboundLocalError at call time.
    global CONFIG
    if gpu_id == -1:
        set_current_ops(NumpyOps())
    else:
        set_current_ops(CupyOps())
    # Use the wrapped PyTorch LSTM implementation on this backend.
    CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")