Esempio n. 1
0
def test_prefer_gpu():
    """Check that prefer_gpu() selects CupyOps when cupy is importable.

    When cupy cannot be imported, prefer_gpu() is expected to return a falsy
    value instead. The ops that were active before the test are restored
    even if an assertion fails, so later tests are not affected.
    """
    current_ops = get_current_ops()
    try:
        # Only the import decides which branch runs; an ImportError raised by
        # prefer_gpu() itself must not be mistaken for "cupy is missing".
        try:
            import cupy  # noqa: F401
        except ImportError:
            assert not prefer_gpu()
        else:
            prefer_gpu()
            assert isinstance(get_current_ops(), CupyOps)
    finally:
        set_current_ops(current_ops)
Esempio n. 2
0
def test_require_gpu():
    """Check that require_gpu() selects CupyOps when cupy is importable.

    When cupy cannot be imported, require_gpu() is expected to raise
    ValueError. The ops that were active before the test are restored even
    if an assertion fails, so later tests are not affected.
    """
    current_ops = get_current_ops()
    try:
        # Keep the try narrow: only a failed cupy import selects the
        # "no GPU" branch, not an ImportError coming from require_gpu().
        try:
            import cupy  # noqa: F401
        except ImportError:
            with pytest.raises(ValueError):
                require_gpu()
        else:
            require_gpu()
            assert isinstance(get_current_ops(), CupyOps)
    finally:
        set_current_ops(current_ops)
Esempio n. 3
0
def test_require_cpu():
    """Check that require_cpu() selects NumpyOps, including after a GPU switch.

    If cupy is importable, the test also switches to CupyOps via
    require_gpu() and verifies that require_cpu() switches back. The ops
    that were active before the test are restored even if an assertion
    fails, so later tests are not affected.
    """
    current_ops = get_current_ops()
    try:
        require_cpu()
        assert isinstance(get_current_ops(), NumpyOps)
        # Only a failed cupy import skips the GPU round-trip.
        try:
            import cupy  # noqa: F401
        except ImportError:
            pass
        else:
            require_gpu()
            assert isinstance(get_current_ops(), CupyOps)
        require_cpu()
        assert isinstance(get_current_ops(), NumpyOps)
    finally:
        set_current_ops(current_ops)
Esempio n. 4
0
def set_backend(name, gpu_id):
    """Install the global ops backend and rewrite the module-level CONFIG.

    Args:
        name: Backend name. "jax" selects JaxOps; any other value selects
            NumpyOps or CupyOps depending on ``gpu_id``.
        gpu_id: -1 selects NumpyOps; any other value selects CupyOps.
            Ignored when ``name == "jax"``.
    """
    global CONFIG
    if name != "jax":
        chosen_ops = NumpyOps() if gpu_id == -1 else CupyOps()
        set_current_ops(chosen_ops)
        # NOTE(review): this runs for every non-jax name, and it is not
        # idempotent: a second call turns "PyTorchLSTM.v1" into
        # "PyTorchPyTorchLSTM.v1", because the substring "LSTM.v1" matches again.
        CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")
        return
    set_current_ops(JaxOps())
    # Strip every "PyTorch" occurrence so the config refers to the plain layers.
    CONFIG = CONFIG.replace("PyTorch", "")
Esempio n. 5
0
def set_backend(name, gpu_id):
    """Install the global ops backend and, for PyTorch, patch CONFIG.

    Args:
        name: "generic" installs the base Ops class. Any other value installs
            NumpyOps (with use_blis=True) or CupyOps, chosen by ``gpu_id``.
            "pytorch" additionally limits torch to one thread and rewrites
            "LSTM.v1" to "PyTorchLSTM.v1" in the module-level CONFIG.
        gpu_id: -1 selects NumpyOps; any other value selects CupyOps.
            Ignored when ``name == "generic"``.
    """
    global CONFIG
    if name == "generic":
        set_current_ops(Ops())
        return
    if gpu_id == -1:
        backend_ops = NumpyOps(use_blis=True)
    else:
        backend_ops = CupyOps()
    set_current_ops(backend_ops)
    if name != "pytorch":
        return
    # torch is imported lazily so the other backends do not require it.
    import torch

    torch.set_num_threads(1)
    CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")
Esempio n. 6
0
def set_backend(name, gpu_id):
    """Install NumpyOps or CupyOps and rewrite the module-level CONFIG.

    Args:
        name: Accepted for interface compatibility with the other
            ``set_backend`` variants; not used here.
        gpu_id: -1 selects NumpyOps; any other value selects CupyOps.
    """
    # Required: CONFIG is assigned below, so without this declaration Python
    # treats it as a local and the read on the right-hand side raises
    # UnboundLocalError.
    global CONFIG
    if gpu_id == -1:
        set_current_ops(NumpyOps())
    else:
        set_current_ops(CupyOps())
    CONFIG = CONFIG.replace("LSTM.v1", "PyTorchLSTM.v1")