def test_global_default_context():
    """Exercise clearing and re-installing the global default context.

    With the global default context cleared via ``None``, looking it up must
    raise ``chainerx.ContextError``. After a freshly created context is
    installed, the getter must return that very object.
    """
    fresh_ctx = chainerx.Context()

    # Clear the global default and confirm the lookup now fails.
    chainerx.set_global_default_context(None)
    with pytest.raises(chainerx.ContextError):
        chainerx.get_global_default_context()

    # Install the new context and confirm identity on read-back.
    # NOTE(review): this leaves ``fresh_ctx`` as the global default; presumably
    # a fixture outside this view restores the prior state — confirm.
    chainerx.set_global_default_context(fresh_ctx)
    installed = chainerx.get_global_default_context()
    assert installed is fresh_ctx
def test_global_default_context():
    """Exercise clearing and re-installing the global default context.

    With the global default context cleared via ``None``, looking it up must
    raise ``chainerx.ContextError``. After a freshly created context is
    installed, the getter must return that very object.
    """
    fresh_ctx = chainerx.Context()

    # Clear the global default and confirm the lookup now fails.
    chainerx.set_global_default_context(None)
    with pytest.raises(chainerx.ContextError):
        chainerx.get_global_default_context()

    # Install the new context and confirm identity on read-back.
    # NOTE(review): this leaves ``fresh_ctx`` as the global default; presumably
    # a fixture outside this view restores the prior state — confirm.
    chainerx.set_global_default_context(fresh_ctx)
    installed = chainerx.get_global_default_context()
    assert installed is fresh_ctx
def cache_restore_context(request):
    """Snapshot chainerx default state and reinstate it at teardown.

    Records the current default device, default context and global default
    context, then registers a finalizer on ``request`` that sets all three
    back to the recorded values.

    Args:
        request: Object providing ``addfinalizer`` (a pytest fixture request,
            presumably — confirm against the decorator outside this view).
    """
    # Tuple elements are evaluated left to right, so the getters run in the
    # order device -> context -> global context.
    saved_device, saved_context, saved_global_context = (
        chainerx.get_default_device(),
        chainerx.get_default_context(),
        chainerx.get_global_default_context(),
    )

    def _put_back():
        # Restore the global context first, then the default context, and
        # the default device last.
        chainerx.set_global_default_context(saved_global_context)
        chainerx.set_default_context(saved_context)
        chainerx.set_default_device(saved_device)

    request.addfinalizer(_put_back)
def cache_restore_context(request):
    """Snapshot chainerx default state and reinstate it at teardown.

    Records the current default device, default context and global default
    context, then registers a finalizer on ``request`` that sets all three
    back to the recorded values.

    Args:
        request: Object providing ``addfinalizer`` (a pytest fixture request,
            presumably — confirm against the decorator outside this view).
    """
    # Tuple elements are evaluated left to right, so the getters run in the
    # order device -> context -> global context.
    saved_device, saved_context, saved_global_context = (
        chainerx.get_default_device(),
        chainerx.get_default_context(),
        chainerx.get_global_default_context(),
    )

    def _put_back():
        # Restore the global context first, then the default context, and
        # the default device last.
        chainerx.set_global_default_context(saved_global_context)
        chainerx.set_default_context(saved_context)
        chainerx.set_default_device(saved_device)

    request.addfinalizer(_put_back)
def test_creation():
    """Native devices 0 and 1 report consistent name/backend/context/index.

    Each device fetched from the global default context's 'native' backend
    must carry the expected name, point back at that backend and context,
    and report its own index.
    """
    ctx = chainerx.get_global_default_context()
    native = ctx.get_backend('native')

    for idx, expected_name in ((0, 'native:0'), (1, 'native:1')):
        dev = native.get_device(idx)
        assert dev.name == expected_name
        assert dev.backend is native
        assert dev.context is ctx
        assert dev.index == idx
def get_cuda_limit():
    """Return how many CUDA devices the test suite is allowed to use.

    The result is computed on the first call and cached in the module-level
    ``_cuda_limit``. Later calls return the cached value.

    If ``CHAINERX_TEST_CUDA_DEVICE_LIMIT`` is unset, the limit is the device
    count reported by the global default context's ``'cuda'`` backend, or 0
    when obtaining it raises ``chainerx.BackendError``. Otherwise the
    variable is parsed with ``int``.

    Returns:
        int: The (non-negative) CUDA device limit.

    Raises:
        chainerx.ChainerxError: If the environment variable holds a negative
            integer.
        ValueError: If the environment variable is not an integer literal.
    """
    global _cuda_limit
    if _cuda_limit is not None:
        return _cuda_limit

    env_limit = os.getenv('CHAINERX_TEST_CUDA_DEVICE_LIMIT')
    if env_limit is None:
        try:
            backend = chainerx.get_global_default_context().get_backend('cuda')
            limit = backend.get_device_count()
        except chainerx.BackendError:
            limit = 0
    else:
        limit = int(env_limit)
        if limit < 0:
            # Validate before caching. Otherwise the invalid value would be
            # memoized and silently returned by every later call.
            raise chainerx.ChainerxError(
                'CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative '
                'integer: {}'.format(limit))

    _cuda_limit = limit
    return _cuda_limit
def get_cuda_limit():
    """Return how many CUDA devices the test suite is allowed to use.

    The result is computed on the first call and cached in the module-level
    ``_cuda_limit``. Later calls return the cached value.

    If ``CHAINERX_TEST_CUDA_DEVICE_LIMIT`` is unset, the limit is the device
    count reported by the global default context's ``'cuda'`` backend, or 0
    when obtaining it raises ``chainerx.BackendError``. Otherwise the
    variable is parsed with ``int``.

    Returns:
        int: The (non-negative) CUDA device limit.

    Raises:
        chainerx.ChainerxError: If the environment variable holds a negative
            integer.
        ValueError: If the environment variable is not an integer literal.
    """
    global _cuda_limit
    if _cuda_limit is not None:
        return _cuda_limit

    env_limit = os.getenv('CHAINERX_TEST_CUDA_DEVICE_LIMIT')
    if env_limit is None:
        try:
            backend = chainerx.get_global_default_context().get_backend('cuda')
            limit = backend.get_device_count()
        except chainerx.BackendError:
            limit = 0
    else:
        limit = int(env_limit)
        if limit < 0:
            # Validate before caching. Otherwise the invalid value would be
            # memoized and silently returned by every later call.
            raise chainerx.ChainerxError(
                'CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative '
                'integer: {}'.format(limit))

    _cuda_limit = limit
    return _cuda_limit
def test_synchronize():
    """Call ``synchronize`` on native device 0; passes if nothing raises."""
    native_dev = chainerx.get_global_default_context().get_device('native', 0)
    native_dev.synchronize()
def device_instance2(request, device_data2):
    """Return the native device selected by ``device_data2['index']``.

    The device is looked up on the global default context. ``request`` is
    accepted but not used in this body.
    """
    ctx = chainerx.get_global_default_context()
    wanted_index = device_data2['index']
    return ctx.get_device('native', wanted_index)
def test_name_native():
    """The backend registered as 'native' reports the name 'native'."""
    ctx = chainerx.get_global_default_context()
    native_backend = ctx.get_backend('native')
    assert native_backend.name == 'native'
def test_get_device_count_cuda():
    """The 'cuda' backend reports a positive device count."""
    ctx = chainerx.get_global_default_context()
    device_count = ctx.get_backend('cuda').get_device_count()
    assert 0 < device_count
def test_get_device_cuda():
    """CUDA device 0 has index 0 and name 'cuda:0'.

    Also checks that a repeated lookup returns the very same device object.
    """
    ctx = chainerx.get_global_default_context()
    cuda_backend = ctx.get_backend('cuda')
    first = cuda_backend.get_device(0)

    assert first.index == 0
    assert first.name == 'cuda:0'

    again = cuda_backend.get_device(0)
    assert first is again
def test_name_cuda():
    """The backend registered as 'cuda' reports the name 'cuda'."""
    ctx = chainerx.get_global_default_context()
    cuda_backend = ctx.get_backend('cuda')
    assert cuda_backend.name == 'cuda'
def test_get_device_count_native():
    """The 'native' backend reports a positive device count."""
    ctx = chainerx.get_global_default_context()
    device_count = ctx.get_backend('native').get_device_count()
    assert 0 < device_count
def test_get_device_native():
    """Native device 0 has index 0 and name 'native:0'.

    Also checks that a repeated lookup returns the very same device object.
    """
    ctx = chainerx.get_global_default_context()
    native_backend = ctx.get_backend('native')
    first = native_backend.get_device(0)

    assert first.index == 0
    assert first.name == 'native:0'

    again = native_backend.get_device(0)
    assert first is again