Example #1
0
def test_set_backend_local_threadsafe():
    """Check that thread-local backend changes stay local while global ones propagate.

    A worker thread from a single-worker ``ThreadPoolExecutor`` observes the
    active backend while the main thread:

    * enters a ``local_threadsafe=True`` numpy context -- the worker must
      still report the original default;
    * sets ``pytorch`` with ``local_threadsafe=False`` -- the worker must now
      report ``pytorch``, and a thread-local numpy context opened inside the
      worker must leave its ``pytorch`` backend in place once it exits.

    The original backend is restored in both threads afterwards.
    Skipped when PyTorch is not installed.
    """
    pytest.importorskip('torch')

    original_backend = tl.get_backend()

    with ThreadPoolExecutor(max_workers=1) as pool:

        def in_worker(fn, *args):
            # Run ``fn`` on the worker thread and block for its result;
            # ``Future.result`` re-raises anything ``fn`` raised, so
            # assertion failures in the worker fail this test.
            return pool.submit(fn, *args).result()

        with tl.backend_context('numpy', local_threadsafe=True):
            assert tl.get_backend() == 'numpy'
            # The main thread's local context must not be visible to the worker.
            assert in_worker(tl.get_backend) == original_backend

        try:
            # Switch the top-level default; every thread should observe it.
            tl.set_backend('pytorch', local_threadsafe=False)
            assert in_worker(tl.get_backend) == 'pytorch'

            with tl.backend_context('numpy', local_threadsafe=True):
                assert tl.get_backend() == 'numpy'

                def worker_check():
                    # Worker sees the global 'pytorch', not the main thread's
                    # local 'numpy' context ...
                    assert tl.get_backend() == 'pytorch'
                    # ... and its own local context applies only inside the block.
                    with tl.backend_context('numpy', local_threadsafe=True):
                        assert tl.get_backend() == 'numpy'
                    assert tl.get_backend() == 'pytorch'

                in_worker(worker_check)
        finally:
            # Restore the default from the main thread, then also call
            # set_backend on the worker thread (default local_threadsafe).
            # NOTE(review): presumably this clears any thread-local state the
            # worker holds -- confirm against tensorly's set_backend.
            tl.set_backend(original_backend, local_threadsafe=False)
            in_worker(tl.set_backend, original_backend)

        assert tl.get_backend() == original_backend
        assert in_worker(tl.get_backend) == original_backend
Example #2
0
def test_set_backend():
    """Check switching the backend via ``backend_context`` and ``set_backend``.

    Verifies that:

    * nested ``tl.backend_context`` blocks switch the active backend and fall
      back to the enclosing one on exit;
    * ``tl.tensor`` / ``T.tensor`` produce the active backend's array type and
      ``tl.float32`` / ``T.float32`` are the active backend's dtype object;
    * ``tl.set_backend`` changes the active backend outside a context manager;
    * an unknown backend name raises ``ValueError`` and leaves the active
      backend unchanged.

    The starting backend is restored in ``finally`` clauses so that a failed
    assertion cannot leak a different global backend into later tests.
    Skipped when PyTorch is not installed.
    """
    torch = pytest.importorskip('torch')

    toplevel_backend = tl.get_backend()

    # Set in context manager
    with tl.backend_context('numpy'):
        assert tl.get_backend() == 'numpy'
        assert isinstance(tl.tensor([1, 2, 3]), np.ndarray)
        assert isinstance(T.tensor([1, 2, 3]), np.ndarray)
        assert tl.float32 is T.float32 is np.float32

        with tl.backend_context('pytorch'):
            assert tl.get_backend() == 'pytorch'
            assert torch.is_tensor(tl.tensor([1, 2, 3]))
            assert torch.is_tensor(T.tensor([1, 2, 3]))
            assert tl.float32 is T.float32 is torch.float32

        # Sets back to numpy
        assert tl.get_backend() == 'numpy'
        assert isinstance(tl.tensor([1, 2, 3]), np.ndarray)
        assert isinstance(T.tensor([1, 2, 3]), np.ndarray)
        assert tl.float32 is T.float32 is np.float32

    # Reset back to initial backend
    assert tl.get_backend() == toplevel_backend

    # Set not in context manager. This is a global change, so restore it in
    # ``finally``: otherwise a failing assertion would leave 'pytorch' active
    # for every test that runs after this one.
    try:
        tl.set_backend('pytorch')
        assert tl.get_backend() == 'pytorch'
    finally:
        tl.set_backend(toplevel_backend)

    assert tl.get_backend() == toplevel_backend

    # Improper name doesn't reset backend. Guard the global state here too, in
    # case a regression makes the bad call change the backend before failing.
    try:
        with assert_raises(ValueError):
            tl.set_backend('not-a-real-backend')
        assert tl.get_backend() == toplevel_backend
    finally:
        tl.set_backend(toplevel_backend)
Example #3
0
 def check():
     """Assert 'pytorch' is active around a thread-local numpy context.

     The active backend must be 'pytorch' before and after the
     ``tl.backend_context('numpy', local_threadsafe=True)`` block, and 'numpy'
     inside it.
     """
     outer_backend = 'pytorch'
     inner_backend = 'numpy'

     assert tl.get_backend() == outer_backend
     with tl.backend_context(inner_backend, local_threadsafe=True):
         # Inside the local context the requested backend is active.
         assert tl.get_backend() == inner_backend
     # Leaving the context must hand control back to the outer backend.
     assert tl.get_backend() == outer_backend