def test_set_backend_local_threadsafe():
    """Verify backend-change scoping across threads.

    A ``local_threadsafe=True`` change must stay confined to the calling
    thread, while a ``local_threadsafe=False`` change must become the
    default visible from every thread (checked via a one-worker pool).
    """
    pytest.importorskip('torch')
    original_backend = tl.get_backend()

    with ThreadPoolExecutor(max_workers=1) as executor:
        # A thread-local override is invisible to the worker thread.
        with tl.backend_context('numpy', local_threadsafe=True):
            assert tl.get_backend() == 'numpy'
            assert executor.submit(tl.get_backend).result() == original_backend

        try:
            # A global (non-threadsafe) change is seen by the worker too.
            tl.set_backend('pytorch', local_threadsafe=False)
            assert executor.submit(tl.get_backend).result() == 'pytorch'

            with tl.backend_context('numpy', local_threadsafe=True):
                assert tl.get_backend() == 'numpy'

                def worker_check():
                    # The worker still sees the global default, and its own
                    # thread-local override works independently of ours.
                    assert tl.get_backend() == 'pytorch'
                    with tl.backend_context('numpy', local_threadsafe=True):
                        assert tl.get_backend() == 'numpy'
                    assert tl.get_backend() == 'pytorch'

                executor.submit(worker_check).result()
        finally:
            # Restore the original backend globally and in the worker thread,
            # even if an assertion above failed.
            tl.set_backend(original_backend, local_threadsafe=False)
            executor.submit(tl.set_backend, original_backend).result()

        assert tl.get_backend() == original_backend
        assert executor.submit(tl.get_backend).result() == original_backend
def test_set_backend():
    """Check backend switching via context manager and ``set_backend``.

    Covers nested contexts (numpy -> pytorch -> numpy), restoration of the
    initial backend on exit, direct ``set_backend`` calls, and rejection of
    an unknown backend name without clobbering the current backend.
    """
    torch = pytest.importorskip('torch')
    initial_backend = tl.get_backend()

    # The context manager switches the active backend...
    with tl.backend_context('numpy'):
        assert tl.get_backend() == 'numpy'
        assert isinstance(tl.tensor([1, 2, 3]), np.ndarray)
        assert isinstance(T.tensor([1, 2, 3]), np.ndarray)
        assert tl.float32 is T.float32 is np.float32

        # ...and nests correctly.
        with tl.backend_context('pytorch'):
            assert tl.get_backend() == 'pytorch'
            assert torch.is_tensor(tl.tensor([1, 2, 3]))
            assert torch.is_tensor(T.tensor([1, 2, 3]))
            assert tl.float32 is T.float32 is torch.float32

        # Leaving the inner context restores numpy.
        assert tl.get_backend() == 'numpy'
        assert isinstance(tl.tensor([1, 2, 3]), np.ndarray)
        assert isinstance(T.tensor([1, 2, 3]), np.ndarray)
        assert tl.float32 is T.float32 is np.float32

    # Leaving the outer context restores the initial backend.
    assert tl.get_backend() == initial_backend

    # Direct set_backend calls, outside any context manager.
    tl.set_backend('pytorch')
    assert tl.get_backend() == 'pytorch'
    tl.set_backend(initial_backend)
    assert tl.get_backend() == initial_backend

    # An invalid name raises and leaves the backend unchanged.
    with assert_raises(ValueError):
        tl.set_backend('not-a-real-backend')
    assert tl.get_backend() == initial_backend
def check():
    """Assert the current thread sees the global 'pytorch' default and that
    a thread-local numpy override applies and unwinds cleanly."""
    expected_default = 'pytorch'
    assert tl.get_backend() == expected_default
    with tl.backend_context('numpy', local_threadsafe=True):
        assert tl.get_backend() == 'numpy'
    assert tl.get_backend() == expected_default