def test_contextmanager_nested():
  """Check that nested ``DefaultBackend`` contexts stack correctly.

  The innermost active context decides the backend of newly created nodes.
  Leaving it falls back to the enclosing context. Leaving every context
  falls back to the global default, which this test expects to be "numpy".
  """

  def _fresh_node():
    # A new array per node, created inside whatever context is active.
    return tn.Node(np.ones((10, )))

  with tn.DefaultBackend("tensorflow"):
    outer_first = _fresh_node()
    assert outer_first.backend.name == "tensorflow"
    with tn.DefaultBackend("numpy"):
      inner = _fresh_node()
      assert inner.backend.name == "numpy"
    outer_second = _fresh_node()
    assert outer_second.backend.name == "tensorflow"
  after_all = _fresh_node()
  assert after_all.backend.name == "numpy"
def test_operator_kron(backend):
  """``tn.kron`` of two 2x2 operators matches ``np.kron``.

  The numpy reference is reshaped to an order-4 tensor with legs of
  dimension 2 before the comparison.
  """
  with tn.DefaultBackend(backend):
    pauli_x = np.array([[0, 1], [1, 0]], dtype=np.float32)
    pauli_z = np.array([[1, 0], [0, -1]], dtype=np.float32)
    reference = np.kron(pauli_x, pauli_z).reshape(2, 2, 2, 2)
    factors = [tn.Node(matrix) for matrix in (pauli_x, pauli_z)]
    kron_node = tn.kron(factors)
    np.testing.assert_allclose(kron_node.tensor, reference)
def test_kron_raises(backend):
  """``tn.kron`` raises ``ValueError`` for operands of odd tensor order."""
  expected_message = "All operator tensors must have an even order."
  with tn.DefaultBackend(backend):
    # Two order-3 nodes: an odd order cannot be split into in/out legs.
    odd_order_nodes = [tn.Node(np.ones((2, 2, 2))) for _ in range(2)]
    with pytest.raises(ValueError, match=expected_message):
      tn.kron(odd_order_nodes)
def test_get_neighbors(backend):
  """``tn.get_neighbors`` returns only the nodes directly connected to a node.

  A self (trace) edge does not list the node itself, and a node reachable
  only through another node is not included. The expected list follows
  the order of the center node's axes.
  """
  with tn.DefaultBackend(backend):
    left = tn.Node(np.ones((2, 2)))
    center = tn.Node(np.ones((2, 2, 2, 2)))
    right = tn.Node(np.ones((2, 2, 2)))
    far = tn.Node(np.ones((2, 2)))
    # Direct neighbours of `center`.
    center[0] ^ left[1]
    center[3] ^ right[2]
    # `far` touches only `left`, so it is not a neighbour of `center`.
    left[0] ^ far[1]
    # Trace edge of `center` onto itself.
    center[1] ^ center[2]
    assert tn.get_neighbors(center) == [left, right]
def test_contextmanager_simple():
  """Nodes created inside ``DefaultBackend("tensorflow")`` use tensorflow.

  Both nodes are checked against the exact backend name. Comparing them
  only with each other would also pass if the context manager were
  silently ignored.
  """
  with tn.DefaultBackend("tensorflow"):
    a = tn.Node(np.ones((10, )))
    b = tn.Node(np.ones((10, )))
  assert a.backend.name == "tensorflow"
  assert b.backend.name == "tensorflow"
  assert a.backend.name == b.backend.name
def test_contextmanager_BaseBackend():
  """``DefaultBackend`` accepts a backend object, not only a backend name.

  The global default is switched to "pytorch" for the duration of the test
  and restored afterwards, so later tests that rely on the previous default
  are not affected.
  """
  previous_default = _default_backend_stack.default_backend
  tn.set_default_backend("pytorch")
  try:
    a = tn.Node(np.ones((10, )))
    with tn.DefaultBackend(a.backend):
      b = tn.Node(np.ones((10, )))
    assert b.backend.name == "pytorch"
  finally:
    # The context above has exited by now, so changing the default is allowed.
    tn.set_default_backend(previous_default)
def test_contextmanager_wrong_item():
  """``DefaultBackend`` raises ``ValueError`` when given a non-backend (a Node)."""
  not_a_backend = tn.Node(np.ones((10, )))
  with pytest.raises(ValueError):
    tn.DefaultBackend(not_a_backend)  # pytype: disable=wrong-arg-types
def test_contextmanager_interruption():
  """Changing the global default inside ``DefaultBackend`` raises AssertionError.

  The global default is set to "pytorch" for the test and restored
  afterwards so the change does not leak into other tests.
  """
  previous_default = _default_backend_stack.default_backend
  tn.set_default_backend("pytorch")
  try:
    with pytest.raises(AssertionError):
      with tn.DefaultBackend("numpy"):
        tn.set_default_backend("tensorflow")
  finally:
    # Runs after the context has been exited, so the reset itself is legal.
    tn.set_default_backend(previous_default)
def test_contextmanager_default_backend():
  """Entering ``DefaultBackend`` leaves the stored global default untouched.

  The global default is set to "pytorch" for the test and restored
  afterwards so the change does not leak into other tests.
  """
  previous_default = _default_backend_stack.default_backend
  tn.set_default_backend("pytorch")
  try:
    with tn.DefaultBackend("numpy"):
      assert _default_backend_stack.default_backend == "pytorch"
  finally:
    # Runs after the context has been exited, so the reset itself is legal.
    tn.set_default_backend(previous_default)