def update_bond(psi, b, wf, ortho, normalize=True, max_truncation_error=None,
                max_bond_dim=None):
    """Update the MPS tensors at bond b using the two-site wave-function wf.

    If the MPS orthogonality center is at site b or b+1, the canonical form
    will be preserved, with the new orthogonality center position depending
    on the value of ortho.

    Args:
        psi: The MPS whose tensors at sites b and b+1 are replaced in place;
            ``psi.nodes[b]``, ``psi.nodes[b + 1]`` and possibly
            ``psi.center_position`` are modified.
        b: The bond to update (the bond between sites b and b+1).
        wf: The two-site wave-function. It is assumed that wf is a
            tensornetwork.Node of the form

                         s_b     s_b+1
                          |        |
                bond b-1 ---- wf ----- bond b+1

            where the edge order of wf is [bond b-1, s_b, s_b+1, bond b+1].
        ortho: 'left' or 'right'. This is the side of the bond on which the
            orthogonality center should be located after the update.
        normalize: Whether to keep the wave-function normalized after the
            update. If False, the kept singular values are absorbed unscaled.
        max_truncation_error: The maximal allowed truncation error when
            discarding singular values.
        max_bond_dim: An upper bound on the number of kept singular values.

    Returns:
        trunc_svals: The discarded singular values, as returned by
            ``tn.split_node_full_svd``.

    Raises:
        ValueError: If ortho is neither 'left' nor 'right'.
    """
    # Validate up front so a bad argument does not cost a full SVD.
    if ortho not in ('left', 'right'):
        raise ValueError("ortho must be 'left' or 'right'")

    U, S, V, trunc_svals = tn.split_node_full_svd(
        wf, [wf[0], wf[1]], [wf[2], wf[3]],
        max_truncation_err=max_truncation_error,
        max_singular_values=max_bond_dim)

    # Only rescale the kept singular values when normalization was requested.
    if normalize:
        S.set_tensor(S.tensor / tn.norm(S))

    if ortho == 'left':
        # Absorb S into the left tensor; the center ends up at site b.
        U = U @ S
        if psi.center_position == b + 1:
            psi.center_position -= 1
    else:
        # Absorb S into the right tensor; the center ends up at site b+1.
        V = S @ V
        if psi.center_position == b:
            psi.center_position += 1

    # U's last edge is the new bond-b edge still connecting it to V.
    # Detach it so each tensor can be stored as its own MPS node.
    tn.disconnect(U[-1])
    psi.nodes[b] = U
    psi.nodes[b + 1] = V
    return trunc_svals
def test_disconnect_edge(backend):
    """Disconnecting a connected edge leaves one dangling edge on each node."""
    left = tn.Node(np.ones(5), "a", backend=backend)
    right = tn.Node(np.ones(5), "b", backend=backend)
    shared = tn.connect(left[0], right[0])
    assert not shared.is_dangling()

    first, second = tn.disconnect(shared)
    for dangling in (first, second):
        assert dangling.is_dangling()
    # Each node's axis 0 must now point at its own new dangling edge.
    assert left.get_edge(0) == first
    assert right.get_edge(0) == second
def test_disconnect_dangling_edge_value_error(backend):
    """tn.disconnect must reject an edge that is already dangling."""
    lonely = tn.Node(np.eye(2), backend=backend)
    # Nothing has been connected to `lonely`, so its edge 0 is dangling.
    with pytest.raises(ValueError):
        tn.disconnect(lonely[0])