示例#1
0
def update_bond(psi,
                b,
                wf,
                ortho,
                normalize=True,
                max_truncation_error=None,
                max_bond_dim=None):
    """Update the MPS tensors at bond b using the two-site wave-function wf.

    If the MPS orthogonality center is at site b or b+1, the canonical form
    will be preserved, with the new orthogonality center position depending
    on the value of ortho.

    Args:
      psi: The MPS to update in place. Its `nodes` entries at b and b+1 and
        its `center_position` attribute may be modified.
      b: The bond to update; the tensors at sites b and b+1 are replaced.
      wf: The two-site wave-function, it is assumed that wf is a tensornetwork.Node of the form
                    s_b      s_b+1
                     |       |
        bond b-1 ----   wf   ----- bond b+1

        where the edges order of wf is [bond b-1, s_b, s_b+1, bond b+1]
      ortho: 'left' or 'right', on which side of the bond should the orthogonality center
        be located after the update.
      normalize: Whether to rescale the kept singular values to unit norm
        after truncation, keeping the wave-function normalized.
      max_truncation_error: The maximal allowed truncation error when discarding singular values.
      max_bond_dim: An upper bound on the number of kept singular values.

    Returns:
      trunc_svals: The discarded singular values, as returned by
        tn.split_node_full_svd.

    Raises:
      ValueError: If ortho is neither 'left' nor 'right'.
    """
    # Fail fast on a bad argument instead of after the (costly) SVD.
    if ortho not in ('left', 'right'):
        raise ValueError("ortho must be 'left' or 'right'")

    U, S, V, trunc_svals = tn.split_node_full_svd(
        wf, [wf[0], wf[1]], [wf[2], wf[3]],
        max_truncation_err=max_truncation_error,
        max_singular_values=max_bond_dim)
    if normalize:
        # Truncation can shrink the norm; rescale the kept singular values
        # to unit norm. Previously this ran unconditionally, ignoring the flag.
        S.set_tensor(S.tensor / tn.norm(S))
    if ortho == 'left':
        # Absorb S into the left tensor: site b carries the center.
        U = U @ S
        if psi.center_position == b + 1:
            psi.center_position -= 1
    else:
        # Absorb S into the right tensor: site b+1 carries the center.
        V = S @ V
        if psi.center_position == b:
            psi.center_position += 1

    # U's last edge is the bond linking it to V; cut it so the two stored
    # tensors are independent nodes with dangling bond edges.
    tn.disconnect(U[-1])

    psi.nodes[b] = U
    psi.nodes[b + 1] = V

    return trunc_svals
示例#2
0
def test_disconnect_edge(backend):
    """Disconnecting a shared edge yields two dangling edges on the original nodes."""
    node_a = tn.Node(np.ones(5), "a", backend=backend)
    node_b = tn.Node(np.ones(5), "b", backend=backend)
    shared = tn.connect(node_a[0], node_b[0])
    assert not shared.is_dangling()

    left, right = tn.disconnect(shared)
    # Both halves of the former connection must now be free.
    for half in (left, right):
        assert half.is_dangling()
    # Each node keeps its own half at the same axis.
    assert node_a.get_edge(0) == left
    assert node_b.get_edge(0) == right
示例#3
0
def test_disconnect_dangling_edge_value_error(backend):
    """Disconnecting an edge that is already dangling must raise ValueError."""
    identity_node = tn.Node(np.identity(2), backend=backend)
    with pytest.raises(ValueError):
        tn.disconnect(identity_node[0])