def vumps_delta(mps: ThreeTensors, A_C: tn.Tensor, oldA_L: tn.Tensor,
                mode: Text):
    """
    Estimate the current vuMPS error.

    Args:
      mps: The MPS, unpacked as (A_L, C, A_R) in "gauge mismatch" mode.
      A_C: Current A_C.
      oldA_L: A_L from the last iteration. Only used in "null space" mode.
      mode: gradient_estimate_mode in vumps_params. See that docstring for
        details. Supported values are "gauge mismatch" and "null space".

    Returns:
      delta: The error estimate.
        - "gauge mismatch": max(||A_C - A_L C||, ||A_C - C A_R||).
        - "null space": ||N_L A_C||, where N_L is the conjugate transpose of
          the null space of the matricized conjugate transpose of oldA_L.

    Raises:
      ValueError: If `mode` is not one of the supported values.
    """
    if mode == "gauge mismatch":
        A_L, C, A_R = mps
        # Both gauge conditions A_C = A_L C and A_C = C A_R hold at a fixed
        # point; report the larger violation.
        eL = tn.norm(A_C - ct.rightmult(A_L, C))
        eR = tn.norm(A_C - ct.leftmult(C, A_R))
        delta = max(eL, eR)
    elif mode == "null space":
        # NOTE(review): assumes tn.pivot(..., pivot_axis=2) matricizes the
        # rank-3 tensor as (axes 0,1) x (axis 2) -- confirm against tn docs.
        A_Ldag = tn.pivot(oldA_L, pivot_axis=2).H
        N_Ldag = tn_vumps.polar.null_space(A_Ldag)
        N_L = N_Ldag.H
        A_Cmat = tn.pivot(A_C, pivot_axis=2)
        # Project A_C onto the null space of A_L^dagger.
        B = N_L @ A_Cmat
        delta = tn.norm(B)
    else:
        # f-prefix so the offending mode actually appears in the message.
        raise ValueError(f"Invalid mode {mode}.")
    return delta
def minimum_eigenpair(
        matvec: Callable, mv_args: Sequence, guess: tn.Tensor, tol: float,
        max_restarts: int = 100, n_krylov: int = 40, reorth: bool = True,
        n_diag: int = 10,
        verbose: bool = True) -> Tuple[tn.Tensor, tn.Tensor, float]:
    """
    Finds the eigenpair of the Hermitian matvec with the most negative
    eigenvalue by the explicitly restarted Lanczos method.

    Args:
      matvec: The function x = matvec(x, *mv_args) representing the operator.
      mv_args: A list of fixed positional arguments to matvec.
      guess: Guess eigenvector.
      tol: The degree to which the returned eigenpair is allowed to deviate
        from solving the eigenvalue equation.
      max_restarts: The Krylov space will be rebuilt at most this many times
        even if tol is not achieved. Must be at least 1.
      n_krylov: Size of the Krylov space to build.
      reorth: If True the Krylov space will be explicitly reorthogonalized at
        each solver iteration.
      n_diag: An argument to TN Lanczos that currently has no effect.
      verbose: If True a warning will be printed to console if tol was not
        reached.

    Returns:
      ev, eV, err: The eigenvalue, vector, and error. The error is the
        (unnormalized) residual norm of matvec(eV) - ev * eV.

    Raises:
      ValueError: If max_restarts < 1 (no eigensolve would be performed).
    """
    # Without at least one Lanczos run there is no eigenpair to return; the
    # previous behaviour was an UnboundLocalError on `err`.
    if max_restarts < 1:
        raise ValueError(
            f"max_restarts must be at least 1, got {max_restarts}.")

    eV = guess
    for _ in range(max_restarts):
        # Restart Lanczos from the current best eigenvector estimate.
        ev, eV = tn.linalg.krylov.eigsh_lanczos(
            matvec, backend=eV.backend, args=mv_args, x0=eV, numeig=1,
            num_krylov_vecs=n_krylov, ndiag=n_diag, reorthogonalize=reorth)
        ev = ev[0]
        eV = eV[0]

        # Residual of the eigenvalue equation A x = e x.
        Ax = matvec(eV, *mv_args)
        e_eV = ev * eV
        err = tn.norm(tn.abs(Ax - e_eV))
        if err <= tol:
            return ev, eV, err

    if verbose:
        print("Warning: eigensolve exited early with error=", err)
    benchmark.block_until_ready(eV)
    return ev, eV, err
def update_bond(psi, b, wf, ortho, normalize=True, max_truncation_error=None,
                max_bond_dim=None):
    """Update the MPS tensors at bond b using the two-site wave-function wf.

    If the MPS orthogonality center is at site b or b+1, the canonical form
    will be preserved with the new orthogonality center position depending
    on the value of ortho.

    Args:
      psi: The MPS whose tensors at sites b and b+1 are replaced in place.
      b: The bond to update.
      wf: The two-site wave-function. It is assumed that wf is a
        tensornetwork.Node of the form

                         s_b     s_b+1
                          |       |
            bond b-1 ---- wf ----- bond b+1

        where the edges order of wf is [bond b-1, s_b, s_b+1, bond b+1].
      ortho: 'left' or 'right', on which side of the bond the orthogonality
        center should be located after the update.
      normalize: Whether to keep the wave-function normalized after the
        update. If True, the kept singular values are rescaled to unit norm;
        if False they are left as returned by the SVD.
      max_truncation_error: The maximal allowed truncation error when
        discarding singular values.
      max_bond_dim: An upper bound on the number of kept singular values.

    Returns:
      trunc_svals: A list of discarded singular values.

    Raises:
      ValueError: If ortho is not 'left' or 'right'.
    """
    # Validate before splitting so an invalid ortho is rejected before wf's
    # edges are handed over to the SVD factors.
    if ortho not in ('left', 'right'):
        raise ValueError("ortho must be 'left' or 'right'")

    U, S, V, trunc_svals = tn.split_node_full_svd(
        wf, [wf[0], wf[1]], [wf[2], wf[3]],
        max_truncation_err=max_truncation_error,
        max_singular_values=max_bond_dim)

    # Previously this ran unconditionally, silently ignoring normalize=False.
    if normalize:
        S.set_tensor(S.tensor / tn.norm(S))

    if ortho == 'left':
        # Absorb the singular values into the left tensor.
        U = U @ S
        if psi.center_position == b + 1:
            psi.center_position -= 1
    else:
        # Absorb the singular values into the right tensor.
        V = S @ V
        if psi.center_position == b:
            psi.center_position += 1

    # U's last edge now connects to V; break it so the two site tensors are
    # stored as separate nodes.
    tn.disconnect(U[-1])
    psi.nodes[b] = U
    psi.nodes[b + 1] = V
    return trunc_svals
def test_norm_of_node_without_backend_raises_error():
    """tn.norm must reject a bare numpy array with an AttributeError."""
    dims = (3, 3, 3)
    raw_array = np.random.rand(*dims)
    with pytest.raises(AttributeError):
        tn.norm(raw_array)