def test_context_scope():
    """Check that ``context_scope`` temporarily overrides the default context.

    A scope object that has been created but not yet entered must leave the
    current default untouched. Leaving a scope must bring back the context
    that was the default before it was entered.

    NOTE(review): this test leaves the default context set to a local
    Context; presumably a fixture resets it afterwards -- confirm.
    """
    outer = chainerx.Context()
    inner = chainerx.Context()
    chainerx.set_default_context(outer)

    # Entering the scope switches the default context right away.
    with chainerx.context_scope(inner):
        assert chainerx.get_default_context() is inner

    # Merely constructing a scope must not change the default.
    pending_scope = chainerx.context_scope(inner)
    assert chainerx.get_default_context() is outer

    with pending_scope:
        assert chainerx.get_default_context() is inner
    # On exit the previously set default is back in effect.
    assert chainerx.get_default_context() is outer
def test_global_default_context():
    """Check clearing and setting the global default context.

    NOTE(review): the global default is left pointing at a Context created
    here; presumably a fixture restores it for later tests -- confirm.
    """
    new_context = chainerx.Context()

    # With no global default configured, querying it is an error.
    chainerx.set_global_default_context(None)
    with pytest.raises(chainerx.ContextError):
        chainerx.get_global_default_context()

    # After setting it, the very same object is returned.
    chainerx.set_global_default_context(new_context)
    assert chainerx.get_global_default_context() is new_context
def test_chainerx_get_device():
    """Check that module-level ``get_device`` resolves within the scoped context."""
    scoped = chainerx.Context()
    with chainerx.context_scope(scoped):
        native0 = chainerx.get_device('native:0')
        assert native0.context is scoped
        assert native0.name == 'native:0'

        # Every spelling of the same device yields the identical object.
        by_parts = chainerx.get_device('native', 0)
        assert native0 is by_parts
        from_device = chainerx.get_device(native0)
        assert native0 is from_device

        # With no arguments, get_device returns the default device.
        default_dev = chainerx.get_default_device()
        assert default_dev is chainerx.get_device()
def test_get_backend():
    """Check ``Context.get_backend`` lookups by name, including unknown names."""
    ctx = chainerx.Context()

    native = ctx.get_backend('native')
    assert native.name == 'native'

    # A second lookup of the same name returns the identical backend object.
    again = ctx.get_backend('native')
    assert again is native

    unknown_name = 'something_that_does_not_exist'
    with pytest.raises(chainerx.BackendError):
        ctx.get_backend(unknown_name)
def test_create_and_release_backprop_id():
    """Check that a backprop ID is valid after creation and invalid once released."""
    ctx = chainerx.Context()
    bp_id = ctx.make_backprop_id("bp1")

    # The new ID carries the requested name and belongs to its context.
    assert "bp1" == bp_id.name
    assert ctx == bp_id.context

    # Validation passes while the ID is live ...
    ctx._check_valid_backprop_id(bp_id)

    # ... and fails once it has been released.
    ctx.release_backprop_id(bp_id)
    with pytest.raises(chainerx.ChainerxError):
        ctx._check_valid_backprop_id(bp_id)
def test_get_device():
    """Check ``Context.get_device`` name parsing and lookup identity."""
    ctx = chainerx.Context()

    # A bare backend name resolves to the device with index 0.
    first = ctx.get_device('native')
    assert first.name == 'native:0'
    assert first.index == 0

    # The full name and the (backend, index) form give the same object.
    assert ctx.get_device('native:0') is first
    assert ctx.get_device('native', 0) is first

    with pytest.raises(chainerx.BackendError):
        ctx.get_device('something_that_does_not_exist:0')
def test_is_backprop_required():
    """Check ``is_backprop_required`` under nested no/force backprop modes.

    Inside ``no_backprop_mode`` with ``force_backprop_mode(bp1)``, only
    ``bp1`` reports that backprop is required. The current default context
    reports no backprop, while a separate context still reports it as
    required. A non-Context ``context`` argument raises ``TypeError``.
    """
    default_ctx = chainerx.get_default_context()
    separate_ctx = chainerx.Context()

    with chainerx.backprop_scope('bp1') as forced_bp, \
            chainerx.backprop_scope('bp2') as idle_bp:
        with chainerx.no_backprop_mode():
            with chainerx.force_backprop_mode(forced_bp):
                # Only the explicitly forced ID is re-enabled.
                assert not chainerx.is_backprop_required()
                assert chainerx.is_backprop_required(forced_bp)
                assert not chainerx.is_backprop_required(idle_bp)

                # The mode stack is per context: the separate one is
                # unaffected.
                assert not chainerx.is_backprop_required(
                    context=default_ctx)
                assert chainerx.is_backprop_required(context=separate_ctx)

    with pytest.raises(TypeError):
        chainerx.is_backprop_required(context='foo')
def test_chainerx_get_backend():
    """Check that module-level ``get_backend`` uses the scoped default context."""
    scoped = chainerx.Context()
    with chainerx.context_scope(scoped):
        native = chainerx.get_backend('native')
        # The backend comes from the context installed by the scope.
        assert native.context is scoped
        assert native.name == 'native'
def test_creation():
    """Check that a Context can be constructed with no arguments."""
    new_context = chainerx.Context()
    # Drop the reference immediately, as the original temporary did.
    del new_context