def test_contextmanager_nested():
    """Nested ``DefaultBackend`` scopes shadow outer ones and unwind in order.

    A node built in the outer "tensorflow" scope uses tensorflow, one built in
    the inner "numpy" scope uses numpy, and once the inner scope exits the
    outer one applies again. Outside every scope the test expects the global
    default to be "numpy"; it does not set that default itself.
    """
    def fresh_node():
        # A new array per node, exactly as each original construction did.
        return tn.Node(np.ones((10, )))

    with tn.DefaultBackend("tensorflow"):
        outer_first = fresh_node()
        assert outer_first.backend.name == "tensorflow"
        with tn.DefaultBackend("numpy"):
            inner = fresh_node()
            assert inner.backend.name == "numpy"
        # Leaving the inner scope must hand control back to the outer one.
        outer_second = fresh_node()
        assert outer_second.backend.name == "tensorflow"
    unscoped = fresh_node()
    assert unscoped.backend.name == "numpy"
def test_operator_kron(backend):
    """``tn.kron`` of two 2x2 operators matches ``np.kron`` reshaped to rank 4.

    ``np.kron`` produces a 4x4 matrix whose entry [2*i + k, 2*j + l] equals
    X[i, j] * Z[k, l]. Reshaping it to (2, 2, 2, 2) therefore orders the axes
    as (row of X, row of Z, column of X, column of Z). The tensor of the node
    returned by ``tn.kron`` is compared against that layout.
    """
    with tn.DefaultBackend(backend):
        x_op = np.array([[0, 1], [1, 0]], dtype=np.float32)
        z_op = np.array([[1, 0], [0, -1]], dtype=np.float32)
        reference = np.kron(x_op, z_op).reshape(2, 2, 2, 2)
        operands = [tn.Node(x_op), tn.Node(z_op)]
        kron_node = tn.kron(operands)
        np.testing.assert_allclose(kron_node.tensor, reference)
def test_kron_raises(backend):
    """``tn.kron`` raises ``ValueError`` when given order-3 (odd) operator tensors."""
    with tn.DefaultBackend(backend):
        odd_nodes = [tn.Node(np.ones((2, 2, 2))) for _ in range(2)]
        expected_msg = "All operator tensors must have an even order."
        with pytest.raises(ValueError, match=expected_msg):
            tn.kron(odd_nodes)
# Example #4
# 0
def test_get_neighbors(backend):
    """``tn.get_neighbors`` reports the nodes directly wired to a hub node.

    The hub is connected to ``left`` (hub axis 0) and ``right`` (hub axis 3).
    It also has a trace edge between its own axes 1 and 2, which must not
    make it its own neighbour. ``far`` touches only ``left``, so it is not a
    neighbour of the hub either.
    """
    with tn.DefaultBackend(backend):
        left = tn.Node(np.ones((2, 2)))
        hub = tn.Node(np.ones((2, 2, 2, 2)))
        right = tn.Node(np.ones((2, 2, 2)))
        far = tn.Node(np.ones((2, 2)))
        # Each pair is indexed and connected before the next pair is touched.
        wiring = (
            ((hub, 0), (left, 1)),
            ((hub, 3), (right, 2)),
            ((left, 0), (far, 1)),
            ((hub, 1), (hub, 2)),
        )
        for (node_a, axis_a), (node_b, axis_b) in wiring:
            node_a[axis_a] ^ node_b[axis_b]
        assert tn.get_neighbors(hub) == [left, right]
def test_contextmanager_simple():
    """Nodes created inside ``DefaultBackend("tensorflow")`` use tensorflow.

    Checking only that the two nodes agree is not enough. If the context
    manager were ignored, both nodes would share the global default and that
    equality would still hold. So the actual backend name is pinned as well.
    """
    with tn.DefaultBackend("tensorflow"):
        a = tn.Node(np.ones((10, )))
        b = tn.Node(np.ones((10, )))
    assert a.backend.name == b.backend.name
    assert a.backend.name == "tensorflow"
def test_contextmanager_BaseBackend():
    """``DefaultBackend`` accepts a backend object as well as a backend name.

    A pytorch backend object is taken from a node created while the global
    default was "pytorch". The global default is then moved to "numpy"
    before the scope is entered. Without that switch, the scoped node would be
    pytorch even if ``DefaultBackend`` ignored its argument, and the test
    would prove nothing.
    """
    tn.set_default_backend("pytorch")
    a = tn.Node(np.ones((10, )))
    # Make the scope, not the global default, responsible for choosing pytorch.
    tn.set_default_backend("numpy")
    with tn.DefaultBackend(a.backend):
        b = tn.Node(np.ones((10, )))
    assert b.backend.name == "pytorch"
    # After the scope exits, the global default applies again.
    c = tn.Node(np.ones((10, )))
    assert c.backend.name == "numpy"
def test_contextmanager_wrong_item():
    """Passing a ``Node`` (rather than a backend) to ``DefaultBackend`` raises ValueError."""
    not_a_backend = tn.Node(np.ones((10, )))
    with pytest.raises(ValueError):
        tn.DefaultBackend(not_a_backend)  # pytype: disable=wrong-arg-types
def test_contextmanager_interruption():
    """Changing the global default inside an active ``DefaultBackend`` scope
    raises ``AssertionError``."""
    tn.set_default_backend("pytorch")
    # The scope is entered after pytest.raises, so its exit runs first, exactly
    # as with two nested ``with`` statements.
    with pytest.raises(AssertionError), tn.DefaultBackend("numpy"):
        tn.set_default_backend("tensorflow")
def test_contextmanager_default_backend():
    """Entering a ``DefaultBackend`` scope does not overwrite the stored global default.

    Inside a "numpy" scope, the stack's ``default_backend`` still reports the
    value last passed to ``tn.set_default_backend``.
    """
    tn.set_default_backend("pytorch")
    with tn.DefaultBackend("numpy"):
        current_global = _default_backend_stack.default_backend
        assert current_global == "pytorch"