def schmidt_decomposition_tensornetwork(bipartitepurestate_tensor):
    """Calculate the Schmidt decomposition of the given discrete bipartite quantum system.

    This is called by :func:`schmidt_decomposition`. This runs tensornetwork.

    :param bipartitepurestate_tensor: tensor describing the bi-partitite states,
        with each element the coefficient for :math:`|ij\\rangle`
    :return: list of tuples containing the Schmidt coefficient, eigenmode for
        first subsystem, and eigenmode for second subsystem
    :type bipartitepurestate_tensor: numpy.ndarray
    :rtype: list
    """
    # The number of Schmidt modes equals the smaller subsystem dimension.
    nmodes = min(bipartitepurestate_tensor.shape)
    state_node = tn.Node(bipartitepurestate_tensor)
    u_node, s_node, vh_node, _ = tn.split_node_full_svd(
        state_node, [state_node[0]], [state_node[1]])
    modes = [
        (s_node.tensor[k, k], u_node.tensor[:, k], vh_node.tensor[k, :])
        for k in range(nmodes)
    ]
    # Largest Schmidt coefficient first.
    modes.sort(key=lambda mode: mode[0], reverse=True)
    return modes
def update_dis(hamiltonian, state, isometry, disentangler):
    """Updates the disentangler with the aim of reducing the energy.

    Args:
      hamiltonian: The hamiltonian (rank-6 tensor) defined at the bottom of
        the MERA layer.
      state: The 3-site reduced state (rank-6 tensor) defined at the top of
        the MERA layer.
      isometry: The isometry tensor (rank 3) of the binary MERA.
      disentangler: The disentangler tensor (rank 4) of the binary MERA.

    Returns:
      The updated disentangler.
    """
    environment = env_dis(hamiltonian, state, isometry, disentangler)
    env_node = tensornetwork.Node(
        environment, axis_names=["bl", "br", "tl", "tr"], backend="jax")
    out_order = [env_node["bl"], env_node["br"], env_node["tl"], env_node["tr"]]
    u_node, _, v_node, _ = tensornetwork.split_node_full_svd(
        env_node,
        [env_node["bl"], env_node["br"]],
        [env_node["tl"], env_node["tr"]],
        left_edge_name="s1",
        right_edge_name="s2")
    # Discard the singular values by detaching them and wiring U straight
    # into V; only the unitary part of the environment SVD is kept.
    u_node["s1"].disconnect()
    v_node["s2"].disconnect()
    tensornetwork.connect(u_node["s1"], v_node["s2"])
    product = tensornetwork.contract_between(
        u_node, v_node, output_edge_order=out_order)
    return np.conj(product.get_tensor())
def convert_mps(state, dia=False):
    """Convert a single rank-n state node into an MPS via repeated SVDs.

    Args:
      state: A tn.Node with n dangling edges, one per site.
      dia: If True, keep the diagonal singular-value nodes separate and
        return them alongside the site tensors; if False, absorb each one
        into the tensor to its right.

    Returns:
      If dia is False: a list S of n MPS site nodes.
      If dia is True: a tuple (S, D) where D holds the n - 1 diagonal
      singular-value nodes between neighbouring sites.
    """
    n = len(state.tensor.shape)
    S = [tn.Node(np.array([0])) for _ in range(n)]
    if dia:
        D = [tn.Node(np.array([0])) for _ in range(n - 1)]
        # NOTE: max_truncation_err=10e-5 equals 1e-4 (possibly 1e-5 was
        # intended) — kept as in the original to preserve behavior.
        S[0], D[0], P, _ = tn.split_node_full_svd(
            state, state[:1], state[1:], max_truncation_err=10e-5)
        for i in range(1, n - 1):
            S[i], D[i], P, _ = tn.split_node_full_svd(
                P, P[:2], P[2:], max_truncation_err=10e-5)
        S[-1] = P
        # Bug fix: the original `del state, n, P, i, dia` raised NameError
        # when n == 2 (the loop never runs, so `i` is unbound). The dels
        # were redundant — locals are released on return — so they are gone.
        return S, D
    else:
        S[0], D, P, _ = tn.split_node_full_svd(
            state, state[:1], state[1:], max_truncation_err=10e-5)
        P = D @ P
        for i in range(1, n - 1):
            S[i], D, P, _ = tn.split_node_full_svd(
                P, P[:2], P[2:], max_truncation_err=10e-5)
            P = D @ P
        S[-1] = P
        return S
def convert_mpo(ope, dia=False):
    """Convert a single rank-2n operator node into an MPO via repeated SVDs.

    Args:
      ope: A tn.Node with 2*n dangling edges, two per site.
      dia: If True, keep the diagonal singular-value nodes separate and
        return them alongside the site tensors; if False, absorb each one
        into the tensor to its right.

    Returns:
      If dia is False: a list O of n MPO site nodes.
      If dia is True: a tuple (O, D) where D holds the n - 1 diagonal
      singular-value nodes between neighbouring sites.
    """
    # Two dangling edges per site, so the site count is half the rank.
    n = len(ope.tensor.shape) // 2
    O = [tn.Node(np.array([0.0])) for _ in range(n)]
    if dia:
        D = [tn.Node(np.array([0.0])) for _ in range(n - 1)]
        # NOTE: max_truncation_err=10e-5 equals 1e-4 (possibly 1e-5 was
        # intended) — kept as in the original to preserve behavior.
        O[0], D[0], P, _ = tn.split_node_full_svd(
            ope, ope[:2], ope[2:], max_truncation_err=10e-5)
        for i in range(1, n - 1):
            O[i], D[i], P, _ = tn.split_node_full_svd(
                P, P[:3], P[3:], max_truncation_err=10e-5)
        O[-1] = P
        # Bug fix: the original `del ope, dia, n, P, i` raised NameError
        # when n == 2 (the loop never runs, so `i` is unbound). The dels
        # were redundant — locals are released on return — so they are gone.
        return O, D
    else:
        O[0], D, P, _ = tn.split_node_full_svd(
            ope, ope[:2], ope[2:], max_truncation_err=10e-5)
        P = D @ P
        for i in range(1, n - 1):
            O[i], D, P, _ = tn.split_node_full_svd(
                P, P[:3], P[3:], max_truncation_err=10e-5)
            P = D @ P
        O[-1] = P
        return O
def test_split_node_full_svd_relative_tolerance(backend):
    """Truncation error 0.2 keeps different singular values in absolute
    versus relative mode; check the discarded values in each case."""
    node_abs = tn.Node(np.diag([2.0, 1.0, 0.2, 0.1]), backend=backend)
    node_rel = tn.Node(np.diag([2.0, 1.0, 0.2, 0.1]), backend=backend)
    err_bound = 0.2
    *_, discarded_abs = tn.split_node_full_svd(
        node=node_abs,
        left_edges=[node_abs[0]],
        right_edges=[node_abs[1]],
        max_truncation_err=err_bound,
        relative=False)
    *_, discarded_rel = tn.split_node_full_svd(
        node=node_rel,
        left_edges=[node_rel[0]],
        right_edges=[node_rel[1]],
        max_truncation_err=err_bound,
        relative=True)
    np.testing.assert_almost_equal(discarded_abs, [0.1])
    np.testing.assert_almost_equal(discarded_rel, [0.2, 0.1])
def test_split_node_full_svd(backend):
    """Build a matrix with known singular values and check the middle node
    of the full SVD reproduces them."""
    u_left = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)
    u_right = np.array([[0.0, 1.0], [1.0, 0.0]])
    svals = np.array([9.1, 7.5], dtype=np.float32)
    matrix = np.dot(u_left, np.dot(np.diag(svals), u_right.T))
    node = tn.Node(matrix, backend=backend)
    left_edge, right_edge = node[0], node[1]
    _, s_node, _, _ = tn.split_node_full_svd(node, [left_edge], [right_edge])
    tn.check_correct(tn.reachable(s_node))
    np.testing.assert_allclose(s_node.tensor, np.diag([9.1, 7.5]), rtol=1e-5)
def update_bond(psi,
                b,
                wf,
                ortho,
                normalize=True,
                max_truncation_error=None,
                max_bond_dim=None):
    """Update the MPS tensors at bond b using the two-site wave-function wf.

    If the MPS orthogonality center is at site b or b+1, the canonical form
    will be preserved with the new orthogonality center position depending
    on the value of ortho.

    Args:
      psi: The MPS to update.
      b: The bond to update.
      wf: The two-site wave-function; it is assumed that wf is a
        tensornetwork.Node of the form

                 s_b  s_b+1
                  |    |
        bond b-1 --- wf --- bond b+1

        where the edge order of wf is [bond b-1, s_b, s_b+1, bond b+1].
      ortho: 'left' or 'right', on which side of the bond the orthogonality
        center should be located after the update.
      normalize: Whether to keep the wave-function normalized after update.
      max_truncation_error: The maximal allowed truncation error when
        discarding singular values.
      max_bond_dim: An upper bound on the number of kept singular values.

    Returns:
      trunc_svals: A list of discarded singular values.

    Raises:
      ValueError: If ortho is not 'left' or 'right'.
    """
    U, S, V, trunc_svals = tn.split_node_full_svd(
        wf, [wf[0], wf[1]], [wf[2], wf[3]],
        max_truncation_err=max_truncation_error,
        max_singular_values=max_bond_dim)
    # Bug fix: the singular values were previously renormalized
    # unconditionally, which made the `normalize` argument a no-op.
    if normalize:
        S.set_tensor(S.tensor / tn.norm(S))
    if ortho == 'left':
        U = U @ S
        if psi.center_position == b + 1:
            psi.center_position -= 1
    elif ortho == 'right':
        V = S @ V
        if psi.center_position == b:
            psi.center_position += 1
    else:
        raise ValueError("ortho must be 'left' or 'right'")
    # Detach the updated pair from the rest of the network before storing.
    tn.disconnect(U[-1])
    psi.nodes[b] = U
    psi.nodes[b + 1] = V
    return trunc_svals
def schmidt_decomposition_tensornetwork(bipartitepurestate_tensor):
    """Return the Schmidt decomposition of a bipartite pure state tensor as
    a list of (coefficient, left mode, right mode) tuples, sorted by
    decreasing coefficient."""
    # Only min(dim_A, dim_B) Schmidt modes exist.
    num_modes = min(bipartitepurestate_tensor.shape)
    state_node = tn.Node(bipartitepurestate_tensor)
    left, middle, right_h, _ = tn.split_node_full_svd(
        state_node, [state_node[0]], [state_node[1]])
    modes = [(middle.tensor[k, k], left.tensor[:, k], right_h.tensor[k, :])
             for k in range(num_modes)]
    return sorted(modes, key=lambda mode: mode[0], reverse=True)
def test_split_node_full_svd_names(backend):
    """Node and edge names passed to split_node_full_svd must be applied to
    the resulting U, S, V nodes and their connecting edges."""
    node = tn.Node(np.random.rand(10, 10), backend=backend)
    left, s, right, _ = tn.split_node_full_svd(
        node, [node[0]], [node[1]],
        left_name='left',
        middle_name='center',
        right_name='right',
        left_edge_name='left_edge',
        right_edge_name='right_edge')
    assert left.name == 'left'
    assert s.name == 'center'
    assert right.name == 'right'
    assert left.edges[-1].name == 'left_edge'
    assert s[0].name == 'left_edge'
    assert s[1].name == 'right_edge'
    assert right.edges[0].name == 'right_edge'
def test_split_node_full_svd_names(num_charges):
    """Same naming check as the dense version, but on the symmetric
    backend with a charged random tensor."""
    np.random.seed(10)
    node = tn.Node(
        get_random((10, 10), num_charges=num_charges), backend='symmetric')
    left, s, right, _ = tn.split_node_full_svd(
        node, [node[0]], [node[1]],
        left_name='left',
        middle_name='center',
        right_name='right',
        left_edge_name='left_edge',
        right_edge_name='right_edge')
    assert left.name == 'left'
    assert s.name == 'center'
    assert right.name == 'right'
    assert left.edges[-1].name == 'left_edge'
    assert s[0].name == 'left_edge'
    assert s[1].name == 'right_edge'
    assert right.edges[0].name == 'right_edge'
def getRenyiEntropy(psi: List[tn.Node], n: int, ASize: int, maxBondDim=1024):
    """Compute sum_k lambda_k^(2n) over the Schmidt coefficients lambda_k of
    the cut between the first ASize sites and the rest of the MPS.

    Works on a copy of psi; the copy is released before returning.
    """
    psiCopy = copyState(psi)
    # Sweep the working site leftwards until it sits at the A/B cut.
    for site in range(len(psiCopy) - 1, ASize, -1):
        psiCopy = shiftWorkingSite(psiCopy, site, '<<')
    # Merge the two sites adjacent to the cut and SVD across it.
    M = multiContraction(psiCopy[ASize - 1], psiCopy[ASize], [2], [0])
    leftEdges, rightEdges = M.edges[:2], M.edges[2:]
    maxBondDim = getAppropriateMaxBondDim(maxBondDim, leftEdges, rightEdges)
    U, S, V, truncErr = tn.split_node_full_svd(
        M, leftEdges, rightEdges, max_singular_values=maxBondDim)
    schmidt_coeffs = np.diag(S.tensor)
    result = sum(coeff ** (2 * n) for coeff in schmidt_coeffs)
    removeState(psiCopy)
    return result
def svdTruncation(node: tn.Node, leftEdges: List[tn.Edge], rightEdges: List[tn.Edge],
                  dir: str, maxBondDim=1024, leftName='U', rightName='V',
                  edgeName=None):
    """Split node into two by a truncated SVD.

    The singular values are absorbed into the right part when dir is '>>'
    and into the left part otherwise; the named bond edge ends up on the
    side the sweep is moving towards. Returns [left, right, truncErr].
    """
    maxBondDim = getAppropriateMaxBondDim(maxBondDim, leftEdges, rightEdges)
    moving_right = dir == '>>'
    U, S, V, truncErr = tn.split_node_full_svd(
        node, leftEdges, rightEdges,
        max_singular_values=maxBondDim,
        left_name=leftName, right_name=rightName,
        left_edge_name=edgeName if moving_right else None,
        right_edge_name=None if moving_right else edgeName)
    if moving_right:
        l = copyState([U])[0]
        r = copyState([tn.contract_between(S, V, name=V.name)])[0]
    else:
        l = copyState([tn.contract_between(U, S, name=U.name)])[0]
        r = copyState([V])[0]
    # Drop the intermediate SVD nodes; only the copies are returned.
    for leftover in (U, S, V):
        tn.remove_node(leftover)
    return [l, r, truncErr]
def test_split_node_full_svd_orig_shape(backend):
    """Splitting a node must not mutate the original node's shape."""
    node = tn.Node(np.random.rand(3, 4, 5), backend=backend)
    tn.split_node_full_svd(node, [node[0], node[2]], [node[1]])
    np.testing.assert_allclose(node.shape, (3, 4, 5))
def test_split_node_full_svd_of_node_without_backend_raises_error():
    """Passing a raw ndarray instead of a tn.Node must raise AttributeError."""
    raw_tensor = np.random.rand(3, 3, 3)
    with pytest.raises(AttributeError):
        tn.split_node_full_svd(raw_tensor, left_edges=[], right_edges=[])
def svdTruncation(node: tn.Node, leftEdges: List[int], rightEdges: List[int],
                  dir: str, maxBondDim=128, leftName='U', rightName='V',
                  edgeName='default', normalize=False, maxTrunc=8):
    """Split node into two by a truncated SVD, with extra thresholding of
    near-zero singular values.

    Args:
      node: The node to split. Note: here leftEdges/rightEdges are edge
        *indices* into node.edges, unlike the other svdTruncation variant.
      leftEdges: Indices of the edges that go to the left factor.
      rightEdges: Indices of the edges that go to the right factor.
      dir: '>>' absorbs the singular values into the right factor, '<<'
        into the left factor, '>*<' returns U, S, V separately.
        Any other value raises NameError (l/r stay unbound) — caller beware.
      maxBondDim: Cap on the number of kept singular values.
      leftName / rightName: Names for the U and V nodes.
      edgeName: Name given to the new bond edge(s).
      normalize: If True, divide the singular values by their 2-norm.
      maxTrunc: Number of decimals used to drop singular values that round
        to zero relative to the norm; <= 0 disables this thresholding.

    Returns:
      [l, r, truncErr] for dir '>>' or '<<', or [U, S, V, truncErr] for
      dir '>*<' (S is then a diagonal-vector node).
    """
    maxBondDim = getAppropriateMaxBondDim(maxBondDim, [node.edges[e] for e in leftEdges], [node.edges[e] for e in rightEdges])
    # Put the named edge on the side the sweep moves towards.
    if dir == '>>':
        leftEdgeName = edgeName
        rightEdgeName = None
    else:
        leftEdgeName = None
        rightEdgeName = edgeName
    [U, S, V, truncErr] = tn.split_node_full_svd(node, [node.edges[e] for e in leftEdges], [node.edges[e] for e in rightEdges],
                                                 max_singular_values=maxBondDim,
                                                 left_name=leftName, right_name=rightName,
                                                 left_edge_name=leftEdgeName, right_edge_name=rightEdgeName)
    # Replace the diagonal-matrix S node with a rank-1 node holding just the
    # singular-value vector (the isDiag flags below rely on this).
    s = S
    S = tn.Node(np.diag(S.tensor))
    tn.remove_node(s)
    norm = np.sqrt(sum(S.tensor**2))
    if maxTrunc > 0:
        # Keep only singular values that do not round to zero at maxTrunc
        # decimals relative to the norm; assumes the values are sorted in
        # descending order (standard SVD output) — TODO confirm.
        meaningful = sum(np.round(S.tensor / norm, maxTrunc) > 0)
        S.tensor = S.tensor[:meaningful]
        U.tensor = np.transpose(np.transpose(U.tensor)[:meaningful])
        V.tensor = V.tensor[:meaningful]
    if normalize:
        S = multNode(S, 1 / norm)
    for e in S.edges:
        e.name = edgeName
    if dir == '>>':
        l = copyState([U])[0]
        r = multiContraction(S, V, '1', '0', cleanOr1=True, cleanOr2=True, isDiag1=True)
    elif dir == '<<':
        l = multiContraction(U, S, [len(U.edges) - 1], '0', cleanOr1=True, cleanOr2=True, isDiag2=True)
        r = copyState([V])[0]
    elif dir == '>*<':
        # Return detached copies of U, S, V without absorbing S anywhere.
        v = V
        V = copyState([V])[0]
        tn.remove_node(v)
        u = U
        U = copyState([U])[0]
        tn.remove_node(u)
        return [U, S, V, truncErr]
    # Drop the intermediate SVD nodes; only the copies/contractions survive.
    tn.remove_node(U)
    tn.remove_node(S)
    tn.remove_node(V)
    return [l, r, truncErr]
def zip_up(
        self,
        array: Any,
        axes: Optional[List[Tuple]] = None,
        left_index: Optional[int] = None,
        right_index: Optional[int] = None,
        direction: Optional[Text] = "right",
        max_singular_values: Optional[int] = None,
        max_truncation_err: Optional[float] = None,
        relative: Optional[bool] = False,
        copy: Optional[bool] = True) -> List[Tuple]:
    """Contract `array` into this NodeArray site by site, truncating bonds
    with an SVD as the sweep "zips up" the chain.

    For each pair of aligned nodes the edges named in `axes` are connected
    and the pair (plus the carry from the previous step) is contracted; all
    nodes except the last are then split by a truncated SVD, the U part is
    stored back into self, and S @ V is carried into the next contraction.
    Left/right dangling edges of `array` are adopted by self when self has
    none on that side.

    Args:
      array: The NodeArray to contract into self (copied when `copy`).
      axes: (axis_in_self, axis_in_array) pairs to connect per site;
        defaults to [(0, 0)].
      left_index / right_index: Range of sites of self to act on; resolved
        by _parse_left_right_index.
      direction: "right" sweeps left-to-right, "left" right-to-left;
        anything else raises ValueError.
      max_singular_values / max_truncation_err / relative: Truncation
        parameters forwarded to tn.split_node_full_svd.
      copy: Whether to work on a copy of `array`.

    Returns:
      A list of (singular_values, truncated_values) tuples, one per SVD.
    """
    # -- input parsing --
    if axes is None:
        axes = [(0,0)]
    assert self.rank + array.rank - 2*len(axes) > 0, \
        "This contraction would lead to nodes with no legs. " \
        + "To fully contract node use NodeArray.contract()."
    _left_index, _right_index = _parse_left_right_index(self, array, left_index, right_index)
    b = array.copy() if copy else array
    # -- handle left and right edges --
    # Adopt b's dangling boundary edges; only legal when self has none there.
    if b.left:
        assert _left_index == 0
        assert not self.left
        self.left_edge = b.left_edge
    if b.right:
        assert _right_index == len(self) - 1
        assert not self.right
        self.right_edge = b.right_edge
    # -- set variables for directions --
    # Variables reappear in comments below.
    if direction == "right":
        from_index = _left_index
        to_index = _right_index
        sign = +1
        reverse = False
        ia_start_border = 0
        outer_start = self.left
        outer_start_edge = self.left_edge
        outer_bond = -1
        inner_bond = 0
    elif direction == "left":
        reverse = False  # NOTE(review): dead store — overwritten by reverse = True below.
        from_index = _right_index
        to_index = _left_index
        sign = -1
        reverse = True
        ia_start_border = len(self)-1
        outer_start = self.right
        outer_start_edge = self.right_edge
        outer_bond = 0
        inner_bond = -1
    else:
        raise ValueError()
    carry_node = None  # holds S @ V from the previous step's SVD
    singular_values = []
    # sign ( 1 / -1 )
    ias = range(from_index, to_index + sign, sign)
    # reverse ( False / True )
    ibs = reversed(range(len(b))) if reverse else range(len(b))
    for ia, ib in zip(ias, ibs):
        # Connect the requested axes of the two aligned nodes, then drop
        # those edge slots and append b's remaining dangling edges to self's.
        ax_a_list = []
        ax_b_list = []
        for ax_a, ax_b in axes:
            self.array_edges[ia][ax_a] ^ b.array_edges[ib][ax_b]
            ax_a_list.append(ax_a)
            ax_b_list.append(ax_b)
        for ax_a in sorted(ax_a_list, reverse=True):
            del self.array_edges[ia][ax_a]
        for ax_b in sorted(ax_b_list, reverse=True):
            del b.array_edges[ib][ax_b]
        self.array_edges[ia].extend(b.array_edges[ib])
        contraction_nodes = [self.nodes[ia], b.nodes[ib]]
        if carry_node is not None:
            contraction_nodes.append(carry_node)
        contracted_node = tn.contractors.greedy(contraction_nodes, ignore_edge_order=True)
        # to_index (_right_index / _left_index)
        if ia == to_index:
            # Final site of the sweep: store the contraction, no SVD.
            self.nodes[ia] = contracted_node
        else:
            u_edges = []
            # ia_start_border (0 / len(self)-1)
            if ia == ia_start_border:
                # outer_start (self.left / self.right)
                if outer_start:
                    # outer_start_edge (self.left_edge / self.right_edge)
                    u_edges.append(outer_start_edge)
            else:
                # outer_bond (-1 / 0)
                u_edges.append(self.bond_edges[ia + outer_bond])
            u_edges.extend(self.array_edges[ia])
            # inner_bond (0 / -1)
            # Both bond edges towards the un-swept side go into V.
            v_edges = [self.bond_edges[ia + inner_bond], b.bond_edges[ib + inner_bond]]
            u, s, vh, trun_vals = tn.split_node_full_svd(
                node=contracted_node,
                left_edges=u_edges,
                right_edges=v_edges,
                max_singular_values=max_singular_values,
                max_truncation_err=max_truncation_err,
                relative=relative)
            singular_values.append((s.tensor.diagonal(), trun_vals))
            self.nodes[ia] = u
            # inner_bond (0 / -1)
            self.bond_edges[ia + inner_bond] = s[0]
            # Carry S @ V into the next site's contraction.
            carry_node = s @ vh
    return singular_values
def svd_sweep(
        self,
        from_index: Optional[int] = 0,
        to_index: Optional[int] = -1,
        max_singular_values: Optional[int] = None,
        max_truncation_err: Optional[float] = None,
        relative: Optional[bool] = False) -> List[Tuple]:
    """Sweep truncating SVDs along the chain from from_index towards
    to_index, absorbing S @ V into the next node at each step.

    ...................................................................

          |   |   |   |   |   |          |   |   |   |   |   |
        --A1~~A2~~A3~~A4~~A5~~A6--  ==> --A1~~A2~~a3~~a4~~a5~~A6--

                  self              ==>          new self
    ...................................................................

    Args:
      from_index / to_index: Site range to sweep over; negative values are
        resolved relative to len(self). Equal indices make this a no-op.
      max_singular_values / max_truncation_err / relative: Truncation
        parameters forwarded to tn.split_node_full_svd.

    Returns:
      A list of (singular_values, truncated_values) tuples, one per SVD.
      (The original annotation said None, but the code returns this list.)

    Raises:
      IndexError: If either resolved index is out of range.
    """
    _from_index = len(self) + from_index if from_index<0 else from_index
    _to_index = len(self) + to_index if to_index<0 else to_index
    if _from_index < 0 or _from_index >= len(self):
        raise IndexError("Index out of range.")
    if _to_index < 0 or _to_index >= len(self):
        raise IndexError("Index out of range.")
    singular_values = []
    if _from_index < _to_index:
        # Left-to-right sweep: U stays at site i, S @ V flows into site i+1.
        for i in range(_from_index, _to_index):
            u_edges = []
            if i == 0:
                if self.left:
                    u_edges.append(self.left_edge)
            else:
                u_edges.append(self.bond_edges[i-1])
            u_edges.extend(self.array_edges[i])
            v_edges = [self.bond_edges[i]]
            u, s, vh, trun_vals = tn.split_node_full_svd(
                node=self.nodes[i],
                left_edges=u_edges,
                right_edges=v_edges,
                max_singular_values=max_singular_values,
                max_truncation_err=max_truncation_err,
                relative=relative)
            singular_values.append((s.tensor.diagonal(),trun_vals))
            self.nodes[i] = u
            self.bond_edges[i] = s[0]
            svh = s @ vh
            self.nodes[i+1] = svh @ self.nodes[i+1]
    elif _to_index < _from_index:
        # Right-to-left sweep: U stays at site i, S @ V flows into site i-1.
        for i in range(_from_index, _to_index, -1):
            u_edges = []
            u_edges.extend(self.array_edges[i])
            if i == len(self)-1:
                if self.right:
                    u_edges.append(self.right_edge)
            else:
                u_edges.append(self.bond_edges[i])
            v_edges = [self.bond_edges[i-1]]
            u, s, vh, trun_vals = tn.split_node_full_svd(
                node=self.nodes[i],
                left_edges=u_edges,
                right_edges=v_edges,
                max_singular_values=max_singular_values,
                max_truncation_err=max_truncation_err,
                relative=relative)
            singular_values.append((s.tensor.diagonal(),trun_vals))
            self.nodes[i] = u
            self.bond_edges[i-1] = s[0]
            svh = s @ vh
            self.nodes[i-1] = svh @ self.nodes[i-1]
    return singular_values