예제 #1
0
def test_global_default_context():
    """Check clearing and installing the global default context.

    After the global default is cleared with ``None``, asking for it must
    raise ``chainerx.ContextError``; after a freshly created context is
    installed, the getter must hand back that very object.
    """
    fresh_ctx = chainerx.Context()

    # Nothing installed: querying the global default is an error.
    chainerx.set_global_default_context(None)
    with pytest.raises(chainerx.ContextError):
        chainerx.get_global_default_context()

    # Installed: the same object comes back (identity, not equality).
    chainerx.set_global_default_context(fresh_ctx)
    current = chainerx.get_global_default_context()
    assert current is fresh_ctx
예제 #2
0
def test_global_default_context():
    """Exercise ``set_global_default_context`` / ``get_global_default_context``.

    Clearing the global default makes the getter raise
    ``chainerx.ContextError``; installing a new context makes the getter
    return exactly that context object.
    """
    expected = chainerx.Context()

    # Clear the global default and confirm the getter refuses to answer.
    chainerx.set_global_default_context(None)
    with pytest.raises(chainerx.ContextError):
        chainerx.get_global_default_context()

    # Install our context and confirm it is the one reported back.
    chainerx.set_global_default_context(expected)
    assert expected is chainerx.get_global_default_context()
예제 #3
0
def cache_restore_context(request):
    """Snapshot ChainerX context state and restore it when ``request`` finishes.

    Captures the current default device, default context and global default
    context, then registers a finalizer on ``request`` that puts all three
    back.
    """
    saved_device = chainerx.get_default_device()
    saved_context = chainerx.get_default_context()
    saved_global = chainerx.get_global_default_context()

    def _restore():
        # Same order as the original: global context, default context, and
        # the default device last. NOTE(review): presumably the device must
        # come last because setting the context can affect the default
        # device -- confirm against the chainerx implementation.
        chainerx.set_global_default_context(saved_global)
        chainerx.set_default_context(saved_context)
        chainerx.set_default_device(saved_device)

    request.addfinalizer(_restore)
예제 #4
0
def cache_restore_context(request):
    """Register a teardown that reinstates the current ChainerX context state.

    The default device, default context and global default context are read
    now; a finalizer added to ``request`` writes them back afterwards.
    """
    prev_device = chainerx.get_default_device()
    prev_context = chainerx.get_default_context()
    prev_global_context = chainerx.get_global_default_context()

    def _put_back():
        # Keep this order: global context, then default context, then the
        # default device. NOTE(review): presumably the device is set last
        # because changing the context may reset it -- confirm.
        chainerx.set_global_default_context(prev_global_context)
        chainerx.set_default_context(prev_context)
        chainerx.set_default_device(prev_device)

    request.addfinalizer(_put_back)
예제 #5
0
def test_creation():
    """Native devices 0 and 1 report the expected name, backend, context, index."""
    ctx = chainerx.get_global_default_context()
    backend = ctx.get_backend('native')
    for index in (0, 1):
        device = backend.get_device(index)
        assert device.name == 'native:{}'.format(index)
        assert device.backend is backend
        assert device.context is ctx
        assert device.index == index
예제 #6
0
def get_cuda_limit():
    """Return how many CUDA devices tests may use, caching the result.

    If the ``CHAINERX_TEST_CUDA_DEVICE_LIMIT`` environment variable is set,
    its integer value is used. Otherwise the limit is the device count of the
    global default context's ``'cuda'`` backend, or 0 when that lookup raises
    ``chainerx.BackendError``. A successfully computed value is stored in the
    module-level ``_cuda_limit`` and returned directly on later calls.

    Returns:
        int: The non-negative CUDA device limit.

    Raises:
        ValueError: If the environment variable is not an integer literal.
        chainerx.ChainerxError: If the environment variable is negative.
    """
    global _cuda_limit
    if _cuda_limit is not None:
        return _cuda_limit

    # Read the variable once so the presence check and the parse agree.
    env_limit = os.getenv('CHAINERX_TEST_CUDA_DEVICE_LIMIT')
    if env_limit is None:
        try:
            backend = chainerx.get_global_default_context().get_backend('cuda')
            _cuda_limit = backend.get_device_count()
        except chainerx.BackendError:
            # No usable CUDA backend: treat as zero available devices.
            _cuda_limit = 0
    else:
        limit = int(env_limit)
        # Validate before caching: storing an invalid value first would make
        # every later call return it instead of raising again.
        if limit < 0:
            raise chainerx.ChainerxError(
                'CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative '
                'integer: {}'.format(limit))
        _cuda_limit = limit
    return _cuda_limit
예제 #7
0
def get_cuda_limit():
    """Return how many CUDA devices tests may use, caching the result.

    If the ``CHAINERX_TEST_CUDA_DEVICE_LIMIT`` environment variable is set,
    its integer value is used. Otherwise the limit is the device count of the
    global default context's ``'cuda'`` backend, or 0 when that lookup raises
    ``chainerx.BackendError``. A successfully computed value is stored in the
    module-level ``_cuda_limit`` and returned directly on later calls.

    Returns:
        int: The non-negative CUDA device limit.

    Raises:
        ValueError: If the environment variable is not an integer literal.
        chainerx.ChainerxError: If the environment variable is negative.
    """
    global _cuda_limit
    if _cuda_limit is not None:
        return _cuda_limit

    # Read the variable once so the presence check and the parse agree.
    env_limit = os.getenv('CHAINERX_TEST_CUDA_DEVICE_LIMIT')
    if env_limit is None:
        try:
            backend = chainerx.get_global_default_context().get_backend('cuda')
            _cuda_limit = backend.get_device_count()
        except chainerx.BackendError:
            # No usable CUDA backend: treat as zero available devices.
            _cuda_limit = 0
    else:
        limit = int(env_limit)
        # Validate before caching: storing an invalid value first would make
        # every later call return it instead of raising again.
        if limit < 0:
            raise chainerx.ChainerxError(
                'CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative '
                'integer: {}'.format(limit))
        _cuda_limit = limit
    return _cuda_limit
예제 #8
0
def test_synchronize():
    """Calling ``synchronize`` on native device 0 completes without error."""
    native0 = chainerx.get_global_default_context().get_device('native', 0)
    native0.synchronize()
예제 #9
0
def device_instance2(request, device_data2):
    """Return the native device whose index is ``device_data2['index']``.

    The device is looked up on the global default context. ``request`` is
    not used by this body.
    """
    context = chainerx.get_global_default_context()
    index = device_data2['index']
    return context.get_device('native', index)
예제 #10
0
def test_name_native():
    """The native backend reports its name as ``'native'``."""
    ctx = chainerx.get_global_default_context()
    backend = ctx.get_backend('native')
    assert backend.name == 'native'
예제 #11
0
def test_get_device_count_cuda():
    """The CUDA backend reports at least one device."""
    ctx = chainerx.get_global_default_context()
    cuda_backend = ctx.get_backend('cuda')
    device_count = cuda_backend.get_device_count()
    assert device_count > 0
예제 #12
0
def test_get_device_cuda():
    """CUDA device 0 has index 0 and name ``'cuda:0'``.

    Looking device 0 up again must return the very same object.
    """
    cuda_backend = chainerx.get_global_default_context().get_backend('cuda')
    first = cuda_backend.get_device(0)
    assert first.index == 0
    assert first.name == 'cuda:0'
    again = cuda_backend.get_device(0)
    assert first is again
예제 #13
0
def test_name_cuda():
    """The CUDA backend reports its name as ``'cuda'``."""
    ctx = chainerx.get_global_default_context()
    cuda_backend = ctx.get_backend('cuda')
    assert cuda_backend.name == 'cuda'
예제 #14
0
def test_get_device_count_native():
    """The native backend reports at least one device."""
    ctx = chainerx.get_global_default_context()
    native_backend = ctx.get_backend('native')
    device_count = native_backend.get_device_count()
    assert device_count > 0
예제 #15
0
def test_get_device_native():
    """Native device 0 has index 0 and name ``'native:0'``.

    Looking device 0 up again must return the very same object.
    """
    ctx = chainerx.get_global_default_context()
    native_backend = ctx.get_backend('native')
    first = native_backend.get_device(0)
    assert first.index == 0
    assert first.name == 'native:0'
    again = native_backend.get_device(0)
    assert first is again