Esempio n. 1
0
def test_split_node_qr_unitarity(dtype, num_charges):
    """Check that the Q factor of a symmetric-backend QR split is unitary.

    A 50x50 block-sparse matrix is split with ``tn.split_node_qr``. Both
    ``Q Q^dagger`` and ``Q^dagger Q`` must equal the identity inside every
    diagonal charge block.
    """
    np.random.seed(10)
    a = tn.Node(get_square_matrix(50, num_charges, dtype=dtype),
                backend='symmetric')
    q, r = tn.split_node_qr(a, [a[0]], [a[1]])
    # Cut the Q-R bond so Q's column edge is free to pair with its conjugate.
    r[0] | q[1]
    qbar = tn.conj(q)
    q[1] ^ qbar[1]
    u1 = q @ qbar  # Q Q^dagger
    qbar[0] ^ q[0]
    u2 = qbar @ q  # Q^dagger Q
    # The same block-wise identity check applies to both products, so run it
    # once per product instead of duplicating the loop.
    for product in (u1, u2):
        blocks, _, shapes = _find_diagonal_sparse_blocks(
            product.tensor.flat_charges,
            product.tensor.flat_flows,
            len(product.tensor._order[0]))
        for n, block in enumerate(blocks):
            np.testing.assert_almost_equal(
                np.reshape(product.tensor.data[block], shapes[:, n]),
                np.eye(N=shapes[0, n], M=shapes[1, n]))
Esempio n. 2
0
def test_split_node_qr(backend):
    """QR-splitting a rank-5 node and contracting it back reproduces the node."""
    a = tn.Node(np.random.rand(2, 3, 4, 5, 6), backend=backend)
    # Axes 0-2 go to the Q factor, axes 3-4 to the R factor.
    left_edges = [a[k] for k in range(3)]
    right_edges = [a[k] for k in range(3, 5)]
    left, right = tn.split_node_qr(a, left_edges, right_edges)
    tn.check_correct([left, right])
    # With three left edges, axis 3 of the left node is the new Q-R bond.
    bond = left[3]
    np.testing.assert_allclose(a.tensor, tn.contract(bond).tensor)
Esempio n. 3
0
def test_split_node_qr_unitarity_float(backend):
    """Q from a real 3x3 QR split satisfies Q Q^dagger = Q^dagger Q = I."""
    a = tn.Node(np.random.rand(3, 3), backend=backend)
    q, r = tn.split_node_qr(a, [a[0]], [a[1]])
    # Detach Q from R so its column edge can be paired with conj(Q).
    q[1] | r[0]
    qbar = tn.conj(q)

    # Pair the column axes: Q Q^dagger.
    q[1] ^ qbar[1]
    q_qdag = q @ qbar

    # Pair the row axes: Q^dagger Q.
    qbar[0] ^ q[0]
    qdag_q = qbar @ q

    identity = np.eye(3)
    for product in (q_qdag, qdag_q):
        np.testing.assert_almost_equal(product.tensor, identity)
Esempio n. 4
0
def test_split_node_qr_unitarity_float(backend):
    """Check that the Q factor of a real 3x3 QR split is unitary.

    Both products pair Q with conj(Q): contracting the column axes gives
    ``Q Q^dagger``, and contracting the row axes gives a matrix that is the
    identity exactly when ``Q^dagger Q`` is. Conjugation is a no-op for real
    data, but including it makes the test check the real unitarity identity.
    It also matches the complex-valued variant of this test.
    """
    a = tn.Node(np.random.rand(3, 3), backend=backend)
    q, _ = tn.split_node_qr(a, [a[0]], [a[1]])

    # Q Q^dagger: pair the column axis of Q with that of conj(Q).
    n1 = tn.Node(q.tensor, backend=backend)
    n2 = tn.linalg.node_linalg.conj(q)
    n1[1] ^ n2[1]
    u1 = tn.contract_between(n1, n2)

    # Q^dagger Q check: pair the row axis of Q with that of conj(Q).
    n1 = tn.Node(q.tensor, backend=backend)
    n2 = tn.linalg.node_linalg.conj(q)
    n2[0] ^ n1[0]
    u2 = tn.contract_between(n1, n2)

    np.testing.assert_almost_equal(u1.tensor, np.eye(3))
    np.testing.assert_almost_equal(u2.tensor, np.eye(3))
def test_split_node_qr_unitarity_complex(backend):
  """Q from a complex 3x3 QR split satisfies Q Q^dagger = Q^dagger Q = I."""
  # Backends known not to handle this case are skipped with a reason.
  skip_reasons = {
      "pytorch": "Complex numbers currently not supported in PyTorch",
      "jax": "Complex QR crashes jax",
  }
  if backend in skip_reasons:
    pytest.skip(skip_reasons[backend])

  real_part = np.random.rand(3, 3)
  imag_part = np.random.rand(3, 3)
  a = tn.Node(real_part + 1j * imag_part, backend=backend)
  q, r = tn.split_node_qr(a, [a[0]], [a[1]])
  # Free Q's column edge from the Q-R bond before pairing with conj(Q).
  q[1] | r[0]
  qbar = tn.conj(q)

  q[1] ^ qbar[1]
  q_qdag = q @ qbar

  qbar[0] ^ q[0]
  qdag_q = qbar @ q

  identity = np.eye(3)
  for product in (q_qdag, qdag_q):
    np.testing.assert_almost_equal(product.tensor, identity)
Esempio n. 6
0
def test_split_node_qr_names(backend):
    """Custom node and bond names passed to split_node_qr are applied."""
    a = tn.Node(np.zeros((2, 3, 4, 5, 6)), backend=backend)
    left_edges = [a[k] for k in range(3)]
    right_edges = [a[k] for k in range(3, 5)]
    naming = dict(left_name='left', right_name='right', edge_name='edge')
    left, right = tn.split_node_qr(a, left_edges, right_edges, **naming)
    assert left.name == 'left'
    assert right.name == 'right'
    # The new bond is the last edge of the left node and the first of the right.
    assert left.edges[-1].name == 'edge'
    assert right.edges[0].name == 'edge'
def test_split_node_qr(dtype, num_charges):
  """Symmetric-backend QR split followed by re-contraction reproduces the node.

  Both the stored block data and the charges on every leg must match.
  """
  np.random.seed(10)
  random_tensor = get_random((6, 7, 8, 9, 10),
                             num_charges=num_charges,
                             dtype=dtype)
  a = tn.Node(random_tensor, backend='symmetric')
  left_edges = [a[k] for k in range(3)]
  right_edges = [a[k] for k in range(3, 5)]
  left, right = tn.split_node_qr(a, left_edges, right_edges)
  tn.check_correct([left, right])
  # Axis 3 of the left node is the Q-R bond; contracting it rebuilds a.
  result = tn.contract(left[3])
  np.testing.assert_allclose(result.tensor.data, a.tensor.data)
  charges_match = [
      charge_equal(result.tensor._charges[n], original_charge)
      for n, original_charge in enumerate(a.tensor._charges)
  ]
  assert np.all(charges_match)
Esempio n. 8
0
def test_split_node_qr_names(num_charges):
    """Custom names given to split_node_qr reach the symmetric-backend nodes."""
    np.random.seed(10)
    random_tensor = get_random((2, 3, 4, 5, 6), num_charges=num_charges)
    a = tn.Node(random_tensor, backend='symmetric')
    left_edges = [a[k] for k in range(3)]
    right_edges = [a[k] for k in range(3, 5)]
    naming = dict(left_name='left', right_name='right', edge_name='edge')
    left, right = tn.split_node_qr(a, left_edges, right_edges, **naming)
    assert left.name == 'left'
    assert right.name == 'right'
    # The bond created by the split is last on the left node, first on the right.
    assert left.edges[-1].name == 'edge'
    assert right.edges[0].name == 'edge'
Esempio n. 9
0
def test_split_node_qr_unitarity_complex(backend):
    """Q from a complex 3x3 QR split must be unitary (both product orders)."""
    if backend == "pytorch":
        pytest.skip("Complex numbers currently not supported in PyTorch")
    if backend == "jax":
        pytest.skip("Complex QR crashes jax")

    real_part = np.random.rand(3, 3)
    imag_part = np.random.rand(3, 3)
    a = tn.Node(real_part + 1j * imag_part, backend=backend)
    q, _ = tn.split_node_qr(a, [a[0]], [a[1]])

    # Column-axis pairing of Q with conj(Q): Q Q^dagger.
    plain = tn.Node(q.tensor, backend=backend)
    conjugated = tn.linalg.node_linalg.conj(q)
    plain[1] ^ conjugated[1]
    column_product = tn.contract_between(plain, conjugated)

    # Row-axis pairing of Q with conj(Q): identity exactly when Q^dagger Q is.
    plain = tn.Node(q.tensor, backend=backend)
    conjugated = tn.linalg.node_linalg.conj(q)
    conjugated[0] ^ plain[0]
    row_product = tn.contract_between(plain, conjugated)

    for product in (column_product, row_product):
        np.testing.assert_almost_equal(product.tensor, np.eye(3))
Esempio n. 10
0
def test_split_node_qr_orig_shape(backend):
    """Splitting a node must leave the original node's shape unchanged."""
    original = tn.Node(np.random.rand(3, 4, 5), backend=backend)
    left_edges = [original[0], original[2]]
    right_edges = [original[1]]
    tn.split_node_qr(original, left_edges, right_edges)
    np.testing.assert_allclose(original.shape, (3, 4, 5))
Esempio n. 11
0
def test_split_node_qr_of_node_without_backend_raises_error():
    """A raw ndarray (not a tn.Node) passed to split_node_qr raises AttributeError."""
    raw_array = np.random.rand(3, 3, 3)
    with pytest.raises(AttributeError):
        tn.split_node_qr(raw_array, left_edges=[], right_edges=[])
Esempio n. 12
0
    def func(self, inputs):
        """Contract the PEPS grid against ``inputs`` and return a normalized vector.

        Steps:
          1. Build one PEPS node per site from ``self.peps_var[i][j]`` (named
             ``p_i_j``) and one input node per site from ``inputs[:, i, j]``.
          2. Connect every input to axis 0 of its PEPS node, and wire the
             x/y bonds using ``self.index_result``.
          3. Absorb each input into its site and normalize each site tensor.
          4. Fold rows ``1..cx-1`` into row 0, and rows ``xnodes-2..cx+1``
             into row ``xnodes-1``. After each folded row, run an RQ sweep, a
             truncated-SVD sweep (``self.max_singular_values``) and a QR step
             along the row.
          5. Absorb both boundary rows into the middle row ``cx``.
          6. Contract that row along y from both ends toward ``cy``.

        Routing during the decompositions is decided by node *names*: ``p_*``
        for sites of the next row, ``l_*`` / ``r_*`` for the boundary rows,
        and ``'__unnamed_node__'`` for freshly created nodes.

        Args:
            inputs: Per the original comment, laid out as
                (C, x_nodes, y_nodes); only ``inputs[:, i, j]`` slices are
                read. Presumably a torch tensor (``torch`` ops and
                ``.view``/``.norm`` are used) — TODO confirm.

        Returns:
            The final contracted tensor, reshaped with ``.view([10])`` and
            divided by its norm. It must therefore hold exactly 10 elements.

        NOTE(review): two small grid sizes look like they would fail.
            - ``xnodes == 2``: ``right_nodes`` and ``middle_nodes`` are the
              same list (``input_nodes[1]``), so the second absorption loop
              would contract a node with itself.
            - ``ynodes == 2``: ``middle_nodes[1]`` is contracted into
              ``down_node`` and is also ``up_node``.
            Confirm the intended minimum grid size.
        """
        # C * x_nodes * y_nodes
        peps_nodes = []
        input_nodes = []

        # Build the grid of PEPS nodes ('p_i_j') and input nodes ('i_i_j').
        for i in range(self.xnodes):
            peps_line = []
            input_line = []
            for j in range(self.ynodes):
                peps_line.append(
                    tn.Node(self.peps_var[i][j], name=f'p_{i}_{j}'))
                input_line.append(tn.Node(inputs[:, i, j], name=f'i_{i}_{j}'))
            peps_nodes.append(peps_line)
            input_nodes.append(input_line)

        # Connect the edges
        # (cx, cy) is the central site that contraction proceeds toward.
        cx, cy = self.xnodes // 2, self.ynodes // 2

        # Input Features
        # Axis 0 of every PEPS node is its feature leg.
        for i in range(self.xnodes):
            for j in range(self.ynodes):
                input_nodes[i][j][0] ^ peps_nodes[i][j][0]

        # Y Bond
        # self.index_result[i, j, k] is the axis of PEPS node (i, j) used for
        # direction k; from the usage here k=1 faces j+1 and k=3 faces j-1.
        for i in range(self.xnodes):
            for j in range(self.ynodes - 1):
                index1 = self.index_result[i, j, 1]
                index2 = self.index_result[i, j + 1, 3]
                peps_nodes[i][j][index1] ^ peps_nodes[i][j + 1][index2]

        # X Bond
        # Likewise k=0 faces i+1 and k=2 faces i-1.
        for j in range(self.ynodes):
            for i in range(self.xnodes - 1):
                index1 = self.index_result[i, j, 0]
                index2 = self.index_result[i + 1, j, 2]
                peps_nodes[i][j][index1] ^ peps_nodes[i + 1][j][index2]

        # Contract
        # Contract the features
        # contracted_nodes = []
        # Absorb each input into its PEPS node. The merged node reuses the
        # 'p_i_j' name that later routing depends on, and is divided by its
        # norm.
        for i in range(self.xnodes):
            for j in range(self.ynodes):
                input_nodes[i][j] = input_nodes[i][j] @ peps_nodes[i][j]
                input_nodes[i][j].name = f'p_{i}_{j}'
                input_nodes[i][j].tensor = input_nodes[i][j].tensor / \
                    input_nodes[i][j].tensor.norm()

                # contracted_nodes.append(input_nodes[i][j])

        # Contract each row
        # These are aliases of the row lists inside input_nodes, not copies:
        # writing left_nodes[j] also rewrites input_nodes[0][j], etc.
        left_nodes: List[tn.Node] = input_nodes[0]
        right_nodes: List[tn.Node] = input_nodes[self.xnodes - 1]
        middle_nodes: List[tn.Node] = input_nodes[cx]

        # Left boundary: fold rows 1..cx-1 into row 0, one row at a time.
        for i in range(1, cx):
            for j in range(self.ynodes):
                left_nodes[j] = left_nodes[j] @ input_nodes[i][j]
                left_nodes[j].name = f'l_{j}'

            # Row Normalization
            # Divide every node in the row by the mean of the row's norms.
            row_norm = torch.mean(
                torch.stack([t.tensor.norm() for t in left_nodes]))
            for t in left_nodes:
                t.tensor = t.tensor / row_norm

            # RQ Decomposition
            # Sweep j = 0..ynodes-2. Edges toward the next row ('p_*') and
            # toward l_{j-1} go to right_edges; everything else (the bond to
            # l_{j+1}) goes to left_edges. The right_edges factor stays at
            # site j, and the other factor is absorbed into site j+1.
            for j in range(self.ynodes - 1):
                left_edges = []
                right_edges = []

                for edge in left_nodes[j].edges:
                    # NOTE(review): unlike the later loops there is no skip
                    # for dangling edges here; a dangling edge whose node1 is
                    # this node would hit edge.node2 == None. Confirm every
                    # edge here is connected.
                    nxt_node_name = edge.node1.name if edge.node1.name != f'l_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name[0] == 'p':
                        right_edges.append(edge)
                    elif nxt_node_name == f'l_{j-1}':
                        right_edges.append(edge)
                    else:
                        left_edges.append(edge)

                node1, node2 = tn.split_node_rq(left_nodes[j],
                                                left_edges=left_edges,
                                                right_edges=right_edges)
                left_nodes[j] = node2
                left_nodes[j + 1] = left_nodes[j + 1] @ node1
                # left_nodes[j+1].tensor = left_nodes[j+1].tensor / \
                #     left_nodes[j+1].tensor.norm()
                left_nodes[j].name = f'l_{j}'
                left_nodes[j + 1].name = f'l_{j+1}'

            # SVD
            # Sweep back j = ynodes-1..1. Merge sites j and j-1, then split
            # them again with at most self.max_singular_values kept. Edges
            # toward p_{i+1}_j and l_{j+1} stay on site j; the rest go to
            # site j-1. The third return value of split_node is ignored.
            for j in range(self.ynodes - 1, 0, -1):
                tmp_node = left_nodes[j] @ left_nodes[j - 1]
                left_edges = []
                right_edges = []

                for edge in tmp_node.edges:
                    nxt_node_name = edge.node1.name if edge.node1.name != f'l_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name == f'p_{i+1}_{j}':
                        left_edges.append(edge)
                    elif nxt_node_name == f'p_{i+1}_{j-1}':
                        right_edges.append(edge)
                    elif nxt_node_name == f'l_{j+1}':
                        left_edges.append(edge)
                    else:
                        right_edges.append(edge)

                node1, node2, _ = tn.split_node(
                    tmp_node,
                    left_edges=left_edges,
                    right_edges=right_edges,
                    max_singular_values=self.max_singular_values)

                left_nodes[j] = node1
                left_nodes[j - 1] = node2
                left_nodes[j].name = f'l_{j}'
                left_nodes[j - 1].name = f'l_{j-1}'

                # QR Decomposition
                # Re-split site j: edges to 'p_*' and l_{j+1} stay at site j,
                # and the factor carrying the bond toward j-1 is absorbed
                # into site j-1.
                left_edges = []
                right_edges = []

                for edge in left_nodes[j].edges:
                    # NOTE(review): skips only when BOTH endpoints are falsy;
                    # an edge presumably always has node1, so this looks like
                    # dead code — confirm whether `or` was intended.
                    if not edge.node2 and not edge.node1:
                        continue
                    nxt_node_name = edge.node1.name if edge.node1.name != f'l_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name[0] == 'p':
                        left_edges.append(edge)
                    elif nxt_node_name == f'l_{j+1}':
                        left_edges.append(edge)
                    else:
                        right_edges.append(edge)

                node1, node2 = tn.split_node_qr(left_nodes[j],
                                                left_edges=left_edges,
                                                right_edges=right_edges)

                left_nodes[j] = node1
                left_nodes[j].name = f'l_{j}'
                left_nodes[j - 1] = node2 @ left_nodes[j - 1]
                # left_nodes[j-1].tensor = left_nodes[j-1].tensor / \
                #     left_nodes[j-1].tensor.norm()
                left_nodes[j - 1].name = f'l_{j-1}'

        # Right boundary: fold rows xnodes-2..cx+1 into row xnodes-1. This
        # mirrors the left-boundary loop with 'r_*' names and next row i-1.
        for i in range(self.xnodes - 2, cx, -1):
            for j in range(self.ynodes):
                right_nodes[j] = right_nodes[j] @ input_nodes[i][j]
                right_nodes[j].name = f'r_{j}'

            # Row Normalization
            row_norm = torch.mean(
                torch.stack([t.tensor.norm() for t in right_nodes]))
            for t in right_nodes:
                t.tensor = t.tensor / row_norm

            # RQ Decomposition
            for j in range(self.ynodes - 1):
                left_edges = []
                right_edges = []

                for edge in right_nodes[j].edges:
                    if not edge.node2 and not edge.node1:
                        continue
                    nxt_node_name = edge.node1.name if edge.node1.name != f'r_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name[0] == 'p':
                        right_edges.append(edge)
                    elif nxt_node_name == f'r_{j-1}':
                        right_edges.append(edge)
                    else:
                        left_edges.append(edge)

                node1, node2 = tn.split_node_rq(right_nodes[j],
                                                left_edges=left_edges,
                                                right_edges=right_edges)
                right_nodes[j] = node2
                right_nodes[j + 1] = right_nodes[j + 1] @ node1
                # right_nodes[j+1].tensor = right_nodes[j+1].tensor / \
                #     right_nodes[j+1].tensor.norm()
                right_nodes[j].name = f'r_{j}'
                right_nodes[j + 1].name = f'r_{j+1}'

            # SVD
            for j in range(self.ynodes - 1, 0, -1):
                tmp_node = right_nodes[j] @ right_nodes[j - 1]
                left_edges = []
                right_edges = []

                for edge in tmp_node.edges:
                    if not edge.node2 and not edge.node1:
                        continue
                    nxt_node_name = edge.node1.name if edge.node1.name != f'r_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name == f'p_{i-1}_{j}':
                        left_edges.append(edge)
                    elif nxt_node_name == f'p_{i-1}_{j-1}':
                        right_edges.append(edge)
                    elif nxt_node_name == f'r_{j+1}':
                        left_edges.append(edge)
                    else:
                        right_edges.append(edge)

                node1, node2, _ = tn.split_node(
                    tmp_node,
                    left_edges=left_edges,
                    right_edges=right_edges,
                    max_singular_values=self.max_singular_values)

                right_nodes[j] = node1
                right_nodes[j - 1] = node2
                right_nodes[j].name = f'r_{j}'
                right_nodes[j - 1].name = f'r_{j-1}'

                # QR Decomposition
                left_edges = []
                right_edges = []

                for edge in right_nodes[j].edges:
                    if not edge.node2 and not edge.node1:
                        continue
                    nxt_node_name = edge.node1.name if edge.node1.name != f'r_{j}' and edge.node1.name != '__unnamed_node__' else edge.node2.name

                    if nxt_node_name[0] == 'p':
                        left_edges.append(edge)
                    elif nxt_node_name == f'r_{j+1}':
                        left_edges.append(edge)
                    else:
                        right_edges.append(edge)

                node1, node2 = tn.split_node_qr(right_nodes[j],
                                                left_edges=left_edges,
                                                right_edges=right_edges)

                right_nodes[j] = node1
                right_nodes[j].name = f'r_{j}'
                right_nodes[j - 1] = node2 @ right_nodes[j - 1]
                # right_nodes[j-1].tensor = right_nodes[j-1].tensor / \
                #     right_nodes[j-1].tensor.norm()
                right_nodes[j - 1].name = f'r_{j-1}'

        # Absorb the left boundary row, then the right one, into the middle
        # row cx.
        for j in range(self.ynodes):
            middle_nodes[j] = left_nodes[j] @ middle_nodes[j]
            # middle_nodes[j].tensor = middle_nodes[j].tensor / \
            #     middle_nodes[j].tensor.norm()

        for j in range(self.ynodes):
            middle_nodes[j] = right_nodes[j] @ middle_nodes[j]
            # middle_nodes[j].tensor = middle_nodes[j].tensor / \
            #     middle_nodes[j].tensor.norm()

        # Contract the middle row along y from both ends toward cy: down_node
        # absorbs sites 1..cy, up_node absorbs ynodes-2..cy+1. Each is
        # normalized after every step.
        down_node = middle_nodes[0]
        up_node = middle_nodes[self.ynodes - 1]

        for j in range(1, cy + 1):
            down_node = down_node @ middle_nodes[j]
            down_node.tensor = down_node.tensor / down_node.tensor.norm()

        for j in range(self.ynodes - 2, cy, -1):
            up_node = up_node @ middle_nodes[j]
            up_node.tensor = up_node.tensor / up_node.tensor.norm()

        result = (down_node @ up_node).tensor

        # Contract the remaining peps (With Auto Mode)
        # result = contractors.auto(contracted_nodes).tensor
        # print(result[0].item())

        # The output size is hard-coded: view([10]) requires exactly 10 elements.
        result = result.view([10]) / result.norm()
        return result