示例#1
0
def solve_bwd(params, res, grad):
    """Backward (VJP) rule paired with a ``solve`` forward pass.

    Args:
        params: Passed through unchanged to ``solve_impl``.
        res: Residuals saved by the forward pass, unpacked as ``(x, z)``.
        grad: Output cotangents, unpacked as a pair; only the first element
            (the cotangent associated with ``x``) is used.

    Returns:
        ``(z_grad, b_grad)`` where, with ``x_adj`` from an adjoint call to
        ``solve_impl``, ``z_grad[i] = Re(conj(x_adj[i]) * x[i])`` and
        ``b_grad[i] = conj(x_adj[i])``.
    """
    primal_x, primal_z = res
    x_cotangent, _unused = grad
    adjoint_sol, _ = solve_impl(primal_z, x_cotangent, adjoint=True,
                                params=params)
    z_grad = tuple(
        np.real(np.conj(adj) * sol)
        for adj, sol in zip(adjoint_sol, primal_x)
    )
    b_grad = tuple(map(np.conj, adjoint_sol))
    return z_grad, b_grad
示例#2
0
def jaxsvd_bwd(r, tangents):
    """Backward rule for an SVD: map cotangents of (U, S, V) to the input.

    Args:
        r: Residuals from the forward pass, unpacked as ``(U, S, V)``. From the
            products below (e.g. ``U @ diag(S) @ V``), ``V`` sits in the
            right-factor ("V^H") position — presumably as returned by the
            forward ``svd`` (TODO confirm).
        tangents: Cotangents ``(du, ds, dv)`` matching ``(U, S, V)``.

    Returns:
        A 1-tuple ``(grad_a,)`` holding the cotangent for the input matrix.
    """
    U, S, V = r
    du, ds, dv = tangents

    # NOTE(review): du and ds are conjugated but dv is only transposed —
    # confirm this asymmetry matches the convention used for V.
    dU = jnp.conj(du)
    dS = jnp.conj(ds)
    dV = jnp.transpose(dv)

    ms = jnp.diag(S)
    ms1 = jnp.diag(_safe_reciprocal(S))  # diag(1/S), guarded by _safe_reciprocal
    dAs = U @ jnp.diag(dS) @ V  # contribution of the singular-value cotangent

    # F[i, j] = 1 / (S[j]**2 - S[i]**2) off the diagonal; the diagonal term is
    # subtracted out (assuming _safe_reciprocal stays finite at 0 — confirm).
    F = S * S - (S * S)[:, None]
    F = _safe_reciprocal(F) - jnp.diag(jnp.diag(_safe_reciprocal(F)))

    # Contribution of the U cotangent. `h` is defined elsewhere; presumably
    # the conjugate transpose (TODO confirm).
    J = F * (h(U) @ dU)
    dAu = U @ (J + h(J)) @ ms @ V

    # Contribution of the V cotangent.
    K = F * (V @ dV)
    dAv = U @ ms @ (K + h(K)) @ V

    # Diagonal correction term built from dU.
    O = h(dU) @ U @ ms1
    dAc = -1 / 2.0 * U @ (jnp.diag(jnp.diag(O - jnp.conj(O)))) @ V

    # Terms with the projectors (I - h(V) @ V) and (I - U @ h(U)).
    # jnp.size(V[1, :]) is V's column count and jnp.size(U[:, 1]) is U's row
    # count (JAX clamps the index 1 if there is only one row/column).
    dAv = dAv + U @ ms1 @ h(dV) @ (jnp.eye(jnp.size(V[1, :])) - h(V) @ V)
    dAu = dAu + (jnp.eye(jnp.size(U[:, 1])) - U @ h(U)) @ dU @ ms1 @ V
    grad_a = jnp.conj(dAv + dAu + dAs + dAc)
    return (grad_a, )
示例#3
0
文件: _lbfgs.py 项目: romanngg/jax
 def body_fun1(j, carry):
   """One step of the first L-BFGS two-loop pass (newest slot to oldest).

   Step ``j`` visits history slot ``his_size - 1 - j``. It stores
   ``rho * Re(_dot(conj(s), q))`` (cast to ``dtype``) in the coefficient
   history and subtracts that coefficient times ``conj(y)`` from ``q``.
   ``his_size``, ``state``, ``_dot`` and ``dtype`` come from the enclosing
   scope.
   """
   q_vec, alphas = carry
   slot = his_size - 1 - j  # walk the history backwards
   projection = _dot(jnp.conj(state.s_history[slot]), q_vec).real.astype(dtype)
   alpha = state.rho_history[slot] * projection
   updated_q = q_vec - alpha * jnp.conj(state.y_history[slot])
   return updated_q, alphas.at[slot].set(alpha)
示例#4
0
    def matvec(vec):
        """Apply the bond-by-bond operator built from ``pot``/``hop`` to ``vec``.

        The state is viewed as a ``(4, 2**(N - 2))`` array whose leading axis
        spans the two sites of the current bond. After each bond term
        ``t1 + t2`` is accumulated, both the state ``x`` and the accumulator
        ``out`` are cyclically shifted by one site, so the next bond moves
        onto the leading axis. The two boundary sites get their full ``pot``
        term; interior sites get half of it from each of their two bonds.
        (The names suggest a nearest-neighbour Hamiltonian with on-site
        potential ``pot`` and hopping ``hop`` — confirm against the caller.)

        ``N``, ``pot``, ``hop``, ``neye``, ``eyen``, ``cTc`` and ``ccT`` come
        from the enclosing scope.

        Args:
            vec: State vector with ``2**N`` entries.

        Returns:
            Flat array with ``2**N`` entries.
        """
        bond_shape = (4, 2**(N - 2))

        def _shift(a, new_shape=bond_shape):
            # Move the leading site to the end: (2, rest) -> (rest, 2).
            return a.reshape((2, 2**(N - 1))).transpose((1, 0)).reshape(new_shape)

        def _add_bond(out, x, site, onsite):
            # Accumulate (onsite + hopping on bond `site`) applied to x.
            hopping = cTc * hop[site] - ccT * jnp.conj(hop[site])
            return out + jnp.einsum('ij,ki -> kj', x, onsite + hopping)

        x = vec.reshape(bond_shape)
        out = jnp.zeros(x.shape, x.dtype)

        # First bond: site 0 is a boundary site, so its potential is not halved.
        out = _add_bond(out, x, 0, neye * pot[0] + eyen * pot[1] / 2)
        x, out = _shift(x), _shift(out)

        # Interior bonds: each interior site's potential is split over two bonds.
        for site in range(1, N - 2):
            out = _add_bond(out, x, site,
                            neye * pot[site] / 2 + eyen * pot[site + 1] / 2)
            x, out = _shift(x), _shift(out)

        # Last bond: site N - 1 is a boundary site. `x` is not needed after
        # this, so only the accumulator is shifted back into place and
        # flattened (the original also shifted `x`, but never used it).
        out = _add_bond(out, x, N - 2,
                        neye * pot[N - 2] / 2 + eyen * pot[N - 1])
        out = _shift(out)
        return _shift(out, 2**N)
示例#5
0
    def update(b, env):
        """One environment update step that absorbs ``b`` into ``(c, t)``.

        The helper names (``_c4v_symmetrise*``, ``chi_ctm``) suggest a
        C4v-symmetric corner-transfer-matrix scheme (TODO confirm). Steps:
        insert ``b`` and ``conj(b)`` into the corner ``c`` and edge ``t``,
        symmetrise the enlarged corner, build a projector ``P`` with
        ``svd_truncated``, renormalise ``c`` and ``t`` with it, and symmetrise
        again.

        Args:
            b: Site tensor with five legs (rank fixed by the ncon index lists).
            env: Tuple ``(c, t)``: corner (two legs) and edge (four legs).

        Returns:
            The new environment, as returned by ``_c4v_symmetrise(c, t)``.
        """
        c, t = env

        # C insertion
        c_tilde = ncon([c, t, t, b, np.conj(b)],
                       [[1, 2], [2, 3, 4, -4], [-1, 5, 6, 1],
                        [7, 3, 5, -2, -5], [7, 4, 6, -3, -6]],
                       [1, 2, 3, 5, 4, 6, 7])
        # (D,d,d',R,r,r') -> (D~,R~)
        _D, _d, _d_, _R, _r, _r_ = c_tilde.shape
        c_tilde = np.reshape(c_tilde, [_D * _d * _d_, _R * _r * _r_])
        # T insertion
        t_tilde = ncon(
            [t, b, np.conj(b)],
            [[-1, 1, 2, -6], [3, 1, -2, -4, -7], [3, 2, -3, -5, -8]])
        # (L,l,l',d,d',R,r,r') -> (L~,d,d',R~)
        _L, _l, _l_, _d, _d_, _R, _r, _r_ = t_tilde.shape
        t_tilde = np.reshape(t_tilde, [_L * _l * _l_, _d, _d_, _R * _r * _r_])

        # enforce symmetry
        c_tilde = _c4v_symmetrise_c(c_tilde)

        # find projector
        # (presumably the truncated left singular vectors, keeping at most
        # chi_ctm states — confirm against svd_truncated)
        P, _, _ = svd_truncated(c_tilde, chi_max=chi_ctm, cutoff=0.)  # (D~,R)

        # renormalise
        # NOTE(review): c is projected as P^T c P (no conjugation) while t is
        # projected with conj(P) on both legs. For complex P these conventions
        # differ — confirm which one is intended.
        c = np.transpose(P) @ c_tilde @ P
        t = ncon([np.conj(P), t_tilde, np.conj(P)],
                 [[1, -1], [1, -2, -3, 2], [2, -4]])

        # enforce symmetry
        env = _c4v_symmetrise(c, t)

        return env
示例#6
0
def lqpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL, d*chiR) matrix,
    and computes its LQ decomposition, with the phase of L fixed so as to
    have a real, non-negative main diagonal. A new right-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with
    L.
    In addition to being phase-adjusted, L is normalized by division with
    its L2 (Frobenius) norm, so the factorisation holds up to that scale:
    mps = norm(mps) * L @ mps_R.

    Diagonal entries of L that are exactly zero keep phase 1. This leaves
    the matching row of Q intact, so mps_R stays right-orthogonal.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor. Requires
                      chiL <= d*chiR so that Q can be reshaped back to
                      mps.shape.

    RETURNS
    -------
    L, mps_R:  A lower-triangular (chiL x chiL) matrix with a non-negative
               main-diagonal and unit Frobenius norm, and a right-orthogonal
               (chiL, d, chiR) MPS tensor such that mps is proportional to
               L @ mps_R.
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL, chiR * d))
    # LQ of A via QR of A^H:  A^H = Qdag @ Ldag  =>  A = L @ Q.
    Qdag, Ldag = jnp.linalg.qr(jnp.conj(mps_mat.T))
    Q = jnp.conj(Qdag.T)
    L = jnp.conj(Ldag.T)

    diag = jnp.diag(L)
    # jnp.sign(z) = z/|z|; use 1 for exact zeros so no row of Q is wiped out.
    phases = jnp.where(diag == 0, 1, jnp.sign(diag))
    # Scale L's columns by conj(phase), which turns diag(L) into |diag(L)|
    # (also for complex entries), and Q's rows by the phase, so L @ Q is
    # unchanged.
    L = L * jnp.conj(phases)
    Q = phases[:, None] * Q
    L = L / jnp.linalg.norm(L)
    mps_R = Q.reshape(mps.shape)
    return (L, mps_R)
示例#7
0
    def update(i, state, u):
        """One adaptive step: compute the loss/gradient and update ``w``.

        The loss is ``lcma`` (deviation of ``|v|**2`` from ``R2``) plus
        ``lmu``, a penalty on the cross-correlation of recent outputs. The
        correlations are tracked in ``r`` as an exponential moving average
        with bias correction. ``mimo``, ``r2c``, ``dims``, ``beta``, ``R2``
        and ``lr`` come from the enclosing scope.

        Args:
            i: Step index, passed to the learning-rate schedule ``lr``.
            state: ``(w, z, r, betapow)`` — weights, output history buffer,
                correlation moving average, and the running power of
                ``beta`` used for bias correction.
            u: Current input passed to ``mimo``.

        Returns:
            ``(new_state, (w, l))``, where ``w`` is the weight used in this
            step (before the update) and ``l`` is the loss.
        """
        w, z, r, betapow = state
        # The output is needed both for the history buffer and for the loss;
        # compute it once (the original called mimo(w, u) twice).
        v = mimo(w, u)
        # Push the newest output onto the front of the history buffer.
        z = jnp.concatenate((r2c(v[None, :]), z[:-1, :]))
        z0 = jnp.repeat(z, dims, axis=-1)
        z1 = jnp.tile(z, (1, dims))
        rt = jax.vmap(lambda a, b: a[0] * b.conj(), in_axes=-1,
                      out_axes=0)(z0, z1).reshape(r.shape)
        r = beta * r + (1 - beta) * rt  # exponential moving average
        rhat = r / (1 - betapow)  # bias correction due to small beta
        r_sqsum = jnp.sum(jnp.abs(rhat)**2, axis=-1)

        lcma = jnp.sum(jnp.abs(jnp.abs(v)**2 - R2)**2)
        # Sum of all squared correlations minus the self-correlation terms.
        lmu = 2 * (jnp.sum(r_sqsum) - jnp.sum(jnp.diag(r_sqsum)))
        gcma = 4 * (v * (jnp.abs(v)**2 - R2))[..., None,
                                              None] * jnp.conj(u).T[None, ...]
        gmu_tmp_full = (4 * rhat[..., None, None] *
                        z.T[None, ..., None, None] *
                        jnp.conj(u).T[None, None, None, ...]
                        )  # shape: [dims, dims, delta, dims, T]
        # reduce delta axis first
        gmu_tmp_dr = jnp.sum(gmu_tmp_full,
                             axis=2)  # shape: [dims, dims, dims, T]
        # cross correlation = full correlation - self correlation
        gmu = jnp.sum(gmu_tmp_dr, axis=1) - gmu_tmp_dr[jnp.arange(dims),
                                                       jnp.arange(dims), ...]
        l = lcma + lmu
        g = gcma + gmu

        out = (w, l)  # report the pre-update weights
        w = w - lr(i) * g
        betapow *= beta
        state = (w, z, r, betapow)
        return state, out
示例#8
0
文件: _lbfgs.py 项目: yashk2810/jax
 def body_fun1(j, carry):
     """One step of the first L-BFGS two-loop pass (newest slot to oldest).

     Step ``j`` visits history slot ``his_size - 1 - j``. It stores
     ``a_i = rho_i * Re(_dot(conj(s_i), q))`` in the coefficient history and
     returns ``q - a_i * conj(y_i)``. ``his_size``, ``state`` and ``_dot``
     come from the enclosing scope.
     """
     i = his_size - 1 - j
     _q, _a_his = carry
     a_i = state.rho_history[i] * jnp.real(
         _dot(jnp.conj(state.s_history[i]), _q))
     # `jax.ops.index_update` / `ops.index` were removed from JAX; `.at[i].set`
     # is the supported functional update with the same result.
     _a_his = _a_his.at[i].set(a_i)
     _q = _q - a_i * jnp.conj(state.y_history[i])
     return _q, _a_his
def jit_my_stuff():
    """(Re)build the module-level statistics helpers for the current devices.

    With ``global_defs.usePmap`` set, the helpers are recompiled with
    ``global_defs.pmap_for_my_devices`` whenever ``pmap_devices_updated()``
    returns true. Reductions use ``jax.lax.psum`` over axis ``'i'``, and
    ``pmapDevices`` records ``global_defs.myPmapDevices``. Otherwise the
    helpers are recompiled with ``global_defs.jit_for_my_device`` whenever
    the cached ``jitDevice`` differs from ``global_defs.myDevice``. In that
    case, results get a leading axis of length 1 via ``expand_dims``.
    """
    global _sum_up_pmapd
    global _sum_sq_pmapd
    global _sum_sq_withp_pmapd
    global mean_helper
    global cov_helper_with_p
    global cov_helper_without_p
    global pmapDevices
    global jitDevice

    if not global_defs.usePmap:
        # Single-device path: recompile only when the target device changed.
        if jitDevice != global_defs.myDevice:
            compile_for_device = global_defs.jit_for_my_device

            def local_sum(x):
                return jnp.expand_dims(jnp.sum(x, axis=0), axis=0)

            def local_sum_sq(data, mean):
                centered = data - mean
                return jnp.expand_dims(
                    jnp.sum(jnp.conj(centered) * centered, axis=0), axis=0)

            def local_sum_sq_weighted(data, mean, p):
                centered = data - mean
                return jnp.expand_dims(
                    jnp.conj(centered).dot(p * centered), axis=0)

            def local_weighted_mean(data, p):
                return jnp.expand_dims(jnp.dot(p, data), axis=0)

            _sum_up_pmapd = compile_for_device(local_sum)
            _sum_sq_pmapd = compile_for_device(local_sum_sq)
            _sum_sq_withp_pmapd = compile_for_device(local_sum_sq_weighted)
            mean_helper = compile_for_device(local_weighted_mean)
            cov_helper_with_p = compile_for_device(_cov_helper_with_p)
            cov_helper_without_p = compile_for_device(_cov_helper_without_p)

            jitDevice = global_defs.myDevice
        return

    # Multi-device path: recompile only when the pmap device set changed.
    if not pmap_devices_updated():
        return

    compile_pmap = global_defs.pmap_for_my_devices

    def reduced_sum(x):
        return jax.lax.psum(jnp.sum(x, axis=0), 'i')

    def reduced_sum_sq(data, mean):
        centered = data - mean
        return jax.lax.psum(
            jnp.sum(jnp.conj(centered) * centered, axis=0), 'i')

    def reduced_sum_sq_weighted(data, mean, p):
        centered = data - mean
        return jax.lax.psum(jnp.conj(centered).dot(p * centered), 'i')

    def device_weighted_mean(data, p):
        return jnp.expand_dims(jnp.dot(p, data), axis=0)

    _sum_up_pmapd = compile_pmap(reduced_sum, axis_name='i')
    _sum_sq_pmapd = compile_pmap(reduced_sum_sq, axis_name='i',
                                 in_axes=(0, None))
    _sum_sq_withp_pmapd = compile_pmap(reduced_sum_sq_weighted,
                                       axis_name='i',
                                       in_axes=(0, None, 0))
    mean_helper = compile_pmap(device_weighted_mean, in_axes=(0, 0))
    cov_helper_with_p = compile_pmap(_cov_helper_with_p, in_axes=(0, 0))
    cov_helper_without_p = compile_pmap(_cov_helper_without_p)

    pmapDevices = global_defs.myPmapDevices
示例#10
0
    def transform_to_eigenbasis(self, S, F, EOdata):
        """Diagonalise ``S`` and compute signal-to-noise ratios in its eigenbasis.

        Sets, in this order:
          * ``self.ev``, ``self.V`` — output of ``jnp.linalg.eigh(S)``;
          * ``self.VtF`` — ``V^H @ F``;
          * ``self.rhoVar`` — ``mpi.global_variance`` of ``EOdata`` after
            ``self.transform_EO(EOdata, self.V)``;
          * ``self.snr`` — ``sqrt(|N / (rhoVar / |VtF|**2 - 1)|)`` with
            ``N = mpi.globNumSamples``.
        """
        self.ev, self.V = jnp.linalg.eigh(S)
        v_dagger = jnp.conj(self.V).T
        self.VtF = jnp.dot(v_dagger, F)

        transformed = self.transform_EO(EOdata, self.V)
        self.rhoVar = mpi.global_variance(transformed)

        f_power = jnp.conj(self.VtF) * self.VtF
        ratio = mpi.globNumSamples / (self.rhoVar / f_power - 1.)
        self.snr = jnp.sqrt(jnp.abs(ratio))
示例#11
0
def update_dis(hamiltonian, state, isometry, disentangler):
    """Updates the disentangler with the aim of reducing the energy.

    Builds the disentangler environment with ``env_dis`` and takes its full
    SVD (bottom legs vs. top legs). The singular values are discarded and the
    two remaining factors are contracted back together. The complex conjugate
    of that product is returned.

    Args:
      hamiltonian: The hamiltonian (rank-6 tensor) defined at the bottom of the
        MERA layer.
      state: The 3-site reduced state (rank-6 tensor) defined at the top of the
        MERA layer.
      isometry: The isometry tensor (rank 3) of the binary MERA.
      disentangler: The disentangler tensor (rank 4) of the binary MERA.

    Returns:
      The updated disentangler, with legs ordered (bl, br, tl, tr).
    """
    env = env_dis(hamiltonian, state, isometry, disentangler)

    # NOTE(review): this uses the graph-object `tensornetwork.TensorNetwork`
    # API; newer tensornetwork releases work with free-standing Nodes —
    # confirm the installed version still provides it.
    net = tensornetwork.TensorNetwork(backend="jax")
    nenv = net.add_node(env, axis_names=["bl", "br", "tl", "tr"])
    # Remember the original leg order for the final contraction.
    output_edges = [nenv["bl"], nenv["br"], nenv["tl"], nenv["tr"]]

    # Full SVD of env, splitting the bottom legs from the top legs.
    nu, ns, nv, _ = net.split_node_full_svd(nenv, [nenv["bl"], nenv["br"]],
                                            [nenv["tl"], nenv["tr"]])
    # Remove the singular-value node and join its two dangling edges
    # directly, so the two factors are contracted with each other.
    _, s_edges = net.remove_node(ns)
    net.connect(s_edges[0], s_edges[1])
    nres = net.contract_between(nu, nv, output_edge_order=output_edges)

    return np.conj(nres.get_tensor())
示例#12
0
def expectation_m(m, beta, e, v):
    """Weighted sum of the diagonal matrix elements of ``m`` in the basis ``v``.

    Returns ``sum_k p[k] * conj(v[:, k]) @ m @ v[:, k]`` with
    ``p = fermion_weight(beta * e)``. The reshapes require ``v`` to be square
    (n x n). The names suggest a thermal (Fermi-weighted) expectation value —
    confirm against ``fermion_weight``.
    """
    n = v.shape[0]
    weights = fermion_weight(beta * e)
    bras = jnp.conj(v.T).reshape([n, 1, n])  # bras[k] = conj(v[:, k]) as a row
    kets = v.T.reshape([n, n, 1])  # kets[k] = v[:, k] as a column
    diagonal = (bras @ m @ kets).reshape([n])
    return jnp.dot(weights, diagonal)
示例#13
0
def stable_svd_jvp(primals, tangents):
    """JVP rule for the thin SVD, regularised against degenerate singular values.

    Copied from the JAX source code and slightly tweaked for stability. The
    factor ``1 / (s_j**2 - s_i**2)`` is replaced by ``1 / (f + eps / f)``.
    This stays finite when singular values coincide; ``eps = 0`` recovers the
    regular SVD JVP rule.

    Args:
        primals: 1-tuple ``(A,)``.
        tangents: 1-tuple ``(dA,)``.

    Returns:
        ``((U, s, Vt), (dU, ds, dVt))`` for ``A = U @ diag(s) @ Vt``
        (``full_matrices=False``).
    """
    eps = 1e-10  # deformation parameter; 0 gives the plain SVD JVP
    (A,) = primals
    (dA,) = tangents
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)

    def transpose(x):
        return jnp.swapaxes(x, -1, -2)

    def adjoint(x):
        return jnp.conj(transpose(x))

    rank = s.shape[-1]
    Uh = adjoint(U)
    V = adjoint(Vt)
    s_row = s[..., None, :]
    s_col = transpose(s_row)

    dS = jnp.matmul(jnp.matmul(Uh, dA), V)
    ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

    # gap[i, j] = s_j**2 - s_i**2 (1 on the diagonal); the eps deformation
    # avoids NaNs when singular values are degenerate.
    gap = jnp.square(s_row) - jnp.square(s_col) + jnp.eye(rank)
    gap = gap + eps / gap
    F = 1 / gap - jnp.eye(rank) / (1 + eps)  # exactly zero on the diagonal

    dSS = s_row * dS
    SdS = s_col * dS
    dU = jnp.matmul(U, F * (dSS + transpose(dSS)))
    dV = jnp.matmul(V, F * (SdS + transpose(SdS)))

    # Non-square inputs: add the component orthogonal to the thin factor.
    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        projector = jnp.eye(m) - jnp.matmul(U, Uh)
        dU = dU + jnp.matmul(projector, jnp.matmul(dA, V)) / s_row
    elif n > m:
        projector = jnp.eye(n) - jnp.matmul(V, Vt)
        dV = dV + jnp.matmul(projector, jnp.matmul(adjoint(dA), U)) / s_row
    return (U, s, Vt), (dU, ds, transpose(dV))
示例#14
0
def _two_loop_recursion(state: LBFGSResults):
    """Compute the L-BFGS search direction with the two-loop recursion.

    Starting from ``q = -conj(g_k)``, the first loop walks the stored
    ``(s, y, rho)`` history from the newest slot to the oldest, recording one
    coefficient per slot. ``q`` is then scaled by ``state.gamma``, and the
    second loop walks the same slots from oldest to newest. Only the last
    ``min(state.k, his_size)`` slots of the history buffers are visited.

    Returns:
        The search direction ``q``.
    """
    his_size = len(state.rho_history)
    curr_size = jnp.where(state.k < his_size, state.k, his_size)
    q_init = -jnp.conj(state.g_k)
    alphas_init = jnp.zeros_like(state.rho_history)

    def newest_to_oldest(j, carry):
        q_vec, alphas = carry
        idx = his_size - 1 - j
        alpha = state.rho_history[idx] * jnp.real(
            _dot(jnp.conj(state.s_history[idx]), q_vec))
        q_vec = q_vec - alpha * jnp.conj(state.y_history[idx])
        return q_vec, alphas.at[idx].set(alpha)

    q, alphas = lax.fori_loop(0, curr_size, newest_to_oldest,
                              (q_init, alphas_init))

    def oldest_to_newest(j, q_vec):
        idx = his_size - curr_size + j
        b_coef = state.rho_history[idx] * jnp.real(
            _dot(state.y_history[idx], q_vec))
        return q_vec + (alphas[idx] - b_coef) * state.s_history[idx]

    return lax.fori_loop(0, curr_size, oldest_to_newest, state.gamma * q)
示例#15
0
def update_dis(hamiltonian, state, isometry, disentangler):
  """Updates the disentangler with the aim of reducing the energy.

  Builds the disentangler environment with ``env_dis`` and splits it with a
  full SVD (bottom legs vs. top legs). The singular values are dropped and
  the two remaining factors are contracted directly. The complex conjugate of
  that product is returned.

  Args:
    hamiltonian: The hamiltonian (rank-6 tensor) defined at the bottom of the
      MERA layer.
    state: The 3-site reduced state (rank-6 tensor) defined at the top of the
      MERA layer.
    isometry: The isometry tensor (rank 3) of the binary MERA.
    disentangler: The disentangler tensor (rank 4) of the binary MERA.

  Returns:
    The updated disentangler, with legs ordered (bl, br, tl, tr).
  """
  env = env_dis(hamiltonian, state, isometry, disentangler)

  nenv = tensornetwork.Node(
      env, axis_names=["bl", "br", "tl", "tr"], backend="jax")
  # Remember the original leg order for the final contraction.
  output_edges = [nenv["bl"], nenv["br"], nenv["tl"], nenv["tr"]]

  # Full SVD: the new bond edges are named "s1" (on nu) and "s2" (on nv).
  nu, _, nv, _ = tensornetwork.split_node_full_svd(
      nenv, [nenv["bl"], nenv["br"]], [nenv["tl"], nenv["tr"]],
      left_edge_name="s1",
      right_edge_name="s2")
  # Cut nu and nv loose from the singular-value node and join them directly,
  # which drops the singular values from the product.
  nu["s1"].disconnect()
  nv["s2"].disconnect()
  tensornetwork.connect(nu["s1"], nv["s2"])
  nres = tensornetwork.contract_between(nu, nv, output_edge_order=output_edges)

  return np.conj(nres.get_tensor())
示例#16
0
def qrpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL*d, chiR) matrix,
    and computes its QR decomposition, with the phase of R fixed so as to
    have a real, non-negative main diagonal. A new left-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with
    R.

    In addition to being phase-adjusted, R is normalized by division with
    its L2 (Frobenius) norm, so the factorisation holds up to that scale:
    mps = norm(mps) * mps_L @ R.

    Diagonal entries of R that are exactly zero keep phase 1. This leaves
    the matching column of Q (and row of R) intact, so mps_L stays
    left-orthogonal and the factorisation is preserved.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor. Requires
                      chiL*d >= chiR so that Q can be reshaped back to
                      mps.shape.

    RETURNS
    -------
    mps_L, R: A left-orthogonal (chiL, d, chiR) MPS tensor, and an upper
              triangular (chiR x chiR) matrix with a non-negative main
              diagonal and unit Frobenius norm, such that mps is
              proportional to mps_L @ R.
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL * d, chiR))
    Q, R = jnp.linalg.qr(mps_mat)
    diag = jnp.diag(R)
    # jnp.sign(z) = z/|z|; use 1 for exact zeros so no column of Q (and no
    # row of R) is wiped out.
    phases = jnp.where(diag == 0, 1, jnp.sign(diag))
    # Scale Q's columns by the phase and R's rows by its conjugate: Q @ R is
    # unchanged and diag(R) becomes |diag(R)|.
    Q = Q * phases
    R = jnp.conj(phases)[:, None] * R
    R = R / jnp.linalg.norm(R)
    mps_L = Q.reshape(mps.shape)
    return (mps_L, R)
def test_ascend(random_tensors):
    """simple_mera.ascend must return a rank-6 tensor that is Hermitian.

    Hermiticity is checked on the (D**3, D**3) matricisation of the result.
    """
    ham, rho, iso, dis = random_tensors
    ascended = simple_mera.ascend(ham, rho, iso, dis)
    assert len(ascended.shape) == 6
    dim = ascended.shape[0]
    as_matrix = np.reshape(ascended, [dim**3] * 2)
    asymmetry = as_matrix - np.conj(np.transpose(as_matrix))
    assert np.isclose(np.linalg.norm(asymmetry), 0.0)
示例#18
0
def _matrix_transpose(a, name='matrix_transpose', conjugate=False):  # pylint: disable=unused-argument
    a = np.array(a)
    if a.ndim < 2:
        raise ValueError('Input must have rank at least `2`; found {}.'.format(
            a.ndim))
    x = np.swapaxes(a, -2, -1)
    return np.conj(x) if conjugate else x
示例#19
0
def jaxeigh_bwd(r, tangents):
    """Backward rule for an eigendecomposition: map (de, dv) to the input.

    Args:
        r: Residuals ``(a, e, v)``: the input matrix (only its shape and dtype
            are used), the eigenvalues and the eigenvector matrix.
        tangents: Cotangents ``(de, dv)`` for ``e`` and ``v``.

    Returns:
        A 1-tuple holding ``conj(v) @ (diag(de) + f * (v.T @ dv)) @ v.T``. Here
        ``f[i, j] = _safe_reciprocal(e[j] - e[i])`` off the diagonal, and the
        diagonal is ``_safe_reciprocal(1) - 1`` (zero, assuming
        ``_safe_reciprocal(1) == 1``).
    """
    a, e, v = r
    de, dv = tangents
    identity = jnp.eye(a.shape[-1], dtype=a.dtype)
    # gaps[..., i, j] = e[j] - e[i]. The identity is added before the
    # reciprocal (keeping the zero diagonal away from 0) and subtracted after.
    gaps = e[..., jnp.newaxis, :] - e[..., jnp.newaxis]
    f = _safe_reciprocal(gaps + identity) - identity
    inner = jnp.diag(de) + jnp.multiply(f, v.T @ dv)
    grad_a = jnp.conj(v) @ inner @ v.T
    return (grad_a,)
示例#20
0
def encode_point(x, phis):
    """Encode ``x`` as a real vector via elementwise powers of a phase spectrum.

    Builds the conjugate-symmetric spectrum
    ``uf = [1, exp(i*phis), reversed(conj(exp(i*phis)))]`` of length
    ``dim = 2 * len(phis) + 1``, raises it elementwise to the power ``x``,
    and returns the real part of its inverse FFT.

    Args:
        x: Exponent applied to every spectral component.
        phis: 1-D array-like of phases.

    Returns:
        Real array of length ``2 * len(phis) + 1``.
    """
    phis = np.asarray(phis)
    # Built by concatenation instead of the removed `jax.ops.index_update`
    # API. Layout: uf[0] = 1, uf[1:L+1] = exp(i*phis), and uf[L+1:] holds
    # conj(exp(i*phis)) in reverse order, so uf[-k] == conj(uf[k]).
    phase = np.exp(1.j * phis)
    uf = np.concatenate(
        [np.ones((1,), dtype='complex64'), phase, np.conj(phase)[::-1]]
    ).astype('complex64')
    # Full complex ifft (not irfft) keeps the odd length 2*len(phis) + 1.
    ret = np.fft.ifft(uf**x).real
    return ret
def simulate(t, state, dt, space, t_h, parameters, goal, N_step, dx):
    """Advance ``(t, state)`` by ``N_step`` calls to ``step`` and score the result.

    Runs ``step(t, state, dt, space, t_h, parameters)`` ``N_step`` times via
    ``fori_loop``. It then computes ``occ = occupation(state, goal, dx)`` and
    prints and returns ``Re(occ * conj(occ))``.

    NOTE(review): if this function is traced (e.g. under jax.jit), the print
    runs at trace time and shows a tracer rather than a value.
    """
    def advance(_, carry):
        return step(*carry, dt, space, t_h, parameters)

    t, state = fori_loop(0, N_step, advance, (t, state))
    overlap = occupation(state, goal, dx)
    score = (overlap * np.conj(overlap)).real
    print(score)
    return score
示例#22
0
    def body_fun(state: LBFGSResults):
        """Run one L-BFGS iteration and return the updated state.

        Finds a search direction with ``_two_loop_recursion`` and runs
        ``line_search`` along it. It then appends the new ``s_k``, ``y_k`` and
        ``rho_k`` to the history buffers and refreshes the convergence and
        status fields.

        Args:
            state: The current ``LBFGSResults``.

        Returns:
            The ``LBFGSResults`` for the next iteration.
        """
        # find search direction
        p_k = _two_loop_recursion(state)

        # line search
        ls_results = line_search(
            f=fun,
            xk=state.x_k,
            pk=p_k,
            old_fval=state.f_k,
            gfk=state.g_k,
            maxiter=maxls,
        )

        # evaluate at next iterate
        s_k = ls_results.a_k * p_k  # step actually taken
        x_kp1 = state.x_k + s_k
        f_kp1 = ls_results.f_k
        g_kp1 = ls_results.g_k
        y_k = g_kp1 - state.g_k  # change in gradient
        rho_k_inv = jnp.real(_dot(y_k, s_k))
        rho_k = jnp.reciprocal(rho_k_inv)
        # Stored below as state.gamma for the next iteration.
        gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

        # replacements for next iteration
        # Each jnp.where overrides the earlier ones when its condition holds,
        # so the effective precedence is: 5 (line search failed) >
        # 1 (k >= maxiter) > 2 (nfev >= maxfun) > 3 (ngev >= maxgrad) >
        # 4 (f decreased by less than ftol). 0 means no stop condition fired.
        status = 0
        status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
        status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
        status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
        status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
        status = jnp.where(ls_results.failed, 5, status)

        # Gradient-norm test; `norm` (the order) and `gtol` come from the
        # enclosing scope.
        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
        state = state._replace(
            converged=converged,
            failed=(status > 0) & (~converged),
            k=state.k + 1,
            nfev=state.nfev + ls_results.nfev,
            ngev=state.ngev + ls_results.ngev,
            x_k=x_kp1.astype(state.x_k.dtype),
            f_k=f_kp1.astype(state.f_k.dtype),
            g_k=g_kp1.astype(state.g_k.dtype),
            s_history=_update_history_vectors(history=state.s_history,
                                              new=s_k),
            y_history=_update_history_vectors(history=state.y_history,
                                              new=y_k),
            rho_history=_update_history_scalars(history=state.rho_history,
                                                new=rho_k),
            gamma=gamma,
            # Convergence clears any stop code.
            status=jnp.where(converged, 0, status),
            ls_status=ls_results.status,
        )

        return state
def random_tensors(request):
    """Build random MERA test tensors ``(h, s, iso, dis)`` for bond dimension D.

    ``D`` is read from ``request.param``.
      * ``h`` is a Hermitian tensor with six legs of size D.
      * ``s`` is positive semi-definite with unit trace (six legs of size D).
      * ``dis`` and ``iso`` are reshaped from the ``u``/``vh`` factors of an
        SVD, so ``dis`` reshapes to a unitary ``(D**2, D**2)`` matrix.
    All four are cast to complex128.
    """
    D = request.param
    # One independent subkey per draw. Reusing a single key would make h and
    # s derive from the very same Gaussian matrix.
    key_h, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)

    h = jax.random.normal(key_h, shape=[D**3] * 2)
    h = 0.5 * (h + np.conj(np.transpose(h)))
    h = np.reshape(h, [D] * 6)

    s = jax.random.normal(key_s, shape=[D**3] * 2)
    s = s @ np.conj(np.transpose(s))
    s /= np.trace(s)
    s = np.reshape(s, [D] * 6)

    a = jax.random.normal(key_a, shape=[D**2] * 2)
    u, _, vh = np.linalg.svd(a)
    dis = np.reshape(u, [D] * 4)
    iso = np.reshape(vh, [D] * 4)[:, :, :, 0]

    return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
def test_descend(random_tensors):
    """simple_mera.descend must return a valid 3-site density matrix.

    The result must have six legs. Its (D**3, D**3) matricisation must have
    unit trace, be Hermitian, and be positive semi-definite (up to
    round-off).
    """
    h, s, iso, dis = random_tensors
    s = simple_mera.descend(h, s, iso, dis)
    assert len(s.shape) == 6
    D = s.shape[0]
    smat = np.reshape(s, [D**3] * 2)
    assert np.isclose(np.trace(smat), 1.0)
    assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
    spec, _ = np.linalg.eigh(smat)
    # `alltrue` was removed in NumPy 2.0 (and from jax.numpy); use `all`.
    # Eigenvalues of a PSD matrix can come out as tiny negatives numerically,
    # so allow a small, precision-scaled tolerance (trace is 1, so the
    # spectrum is O(1)).
    tol = 1e3 * np.finfo(smat.dtype).eps
    assert np.all(spec >= -tol)
示例#25
0
def random_tensors(request):
    """Build random MERA test tensors ``(h, s, iso, dis)`` for bond dimension D.

    ``D`` is read from ``request.param``. The PRNG is seeded from the wall
    clock, so every run draws new tensors.
      * ``h`` is a Hermitian tensor with six legs of size D.
      * ``s`` is positive semi-definite with unit trace (six legs of size D).
      * ``dis`` and ``iso`` are reshaped from the ``u``/``vh`` factors of an
        SVD.
    """
    D = request.param
    key = jax.random.PRNGKey(int(time.time()))
    # One independent subkey per draw. Reusing a single key would make h and
    # s derive from the very same Gaussian matrix.
    key_h, key_s, key_a = jax.random.split(key, 3)

    # `jax.random.normal` takes `shape=`; the former `shape_tensor=` keyword
    # does not exist and raised a TypeError.
    h = jax.random.normal(key_h, shape=[D**3] * 2)
    h = 0.5 * (h + np.conj(np.transpose(h)))
    h = np.reshape(h, [D] * 6)

    s = jax.random.normal(key_s, shape=[D**3] * 2)
    s = s @ np.conj(np.transpose(s))
    s /= np.trace(s)
    s = np.reshape(s, [D] * 6)

    a = jax.random.normal(key_a, shape=[D**2] * 2)
    u, _, vh = np.linalg.svd(a)
    dis = np.reshape(u, [D] * 4)
    iso = np.reshape(vh, [D] * 4)[:, :, :, 0]

    return (h, s, iso, dis)
示例#26
0
 def hint(const, var, e, v):
     """Interaction energy assembled from a Wick expansion.

     For each ``site`` in ``uloc``, with ``partner = spin_flip_func(site)``,
     this adds
     ``expectation(site, site) * expectation(partner, partner)
     - |expectation(site, partner)|**2``.
     Here each ``expectation`` call is taken on the ``loc[...]`` operators
     with ``const.beta``, ``e`` and ``v``. When ``u`` is truthy, the total is
     scaled by ``getattr(const, u)``. ``var`` is unused.
     """
     total = 0
     for site in uloc:  # interaction part by wick expansion
         partner = spin_flip_func(site)
         cross = expectation(loc[site], loc[partner], const.beta, e, v)
         direct = (expectation(loc[site], loc[site], const.beta, e, v) *
                   expectation(loc[partner], loc[partner], const.beta, e, v))
         total += direct - jnp.conj(cross) * cross
     if u:
         total *= getattr(const, u)
     return total
示例#27
0
def RZ(theta):
    r"""One-qubit rotation about the z axis.

    The rotation is diagonal in the computational basis, so only the
    diagonal is returned.

    Args:
        theta (float): rotation angle

    Returns:
        array[complex]: the diagonal part of the rotation matrix
        :math:`e^{-i \sigma_z \theta/2}`, i.e.
        ``[exp(-i*theta/2), conj(exp(-i*theta/2))]``
    """
    phase = jnp.exp(-0.5j * theta)
    return jnp.array([phase, phase.conj()])
示例#28
0
def CRZ(theta):
    r"""Two-qubit controlled rotation about the z axis.

    Args:
        theta (float): rotation angle
    Returns:
        array[complex]: diagonal part of the 4x4 rotation matrix
        :math:`|0\rangle\langle 0|\otimes \mathbb{I}+|1\rangle\langle 1|\otimes R_z(\theta)`,
        i.e. ``[1, 1, exp(-i*theta/2), conj(exp(-i*theta/2))]``
    """
    phase = jnp.exp(-0.5j * theta)
    control_off = [1.0, 1.0]  # identity block while the control qubit is |0>
    return jnp.array(control_off + [phase, phase.conj()])
示例#29
0
def XopR(R, mpo, mps):
    """
    ---0mps2--       --
         1    |       |
         3    2       2
    ---0mpo1-0R  ->  0R
         2    1       1
         |    |       |
    ----mps*--|      --

    Contract the right environment ``R`` with one MPO tensor, one MPS tensor
    and the conjugated MPS tensor, giving the updated right environment.

    Index map used by the einsum:
      R[f, e, d] * mpo[a, f, g, h] * mps[c, g, d] * conj(mps)[b, h, e]
      -> new_R[a, b, c]
    so new_R leg 0 is mpo leg 0, leg 1 is mps* leg 0, and leg 2 is mps leg 0.

    NOTE(review): the einsum contracts mps leg 1 with mpo leg 2 and mps* leg 1
    with mpo leg 3. The diagram above draws mps on mpo leg 3 and mps* on mpo
    leg 2 — confirm which MPO leg convention is intended.
    """
    conj_mps = jnp.conj(mps)
    return jnp.einsum("fed,afgh,cgd,bhe->abc", R, mpo, mps, conj_mps)
示例#30
0
def XnoL(mps):
    """
    ----0mps2--      ---
    |     1          |
    |     |          1
    |     |          L
    |     |          0
    |     1          |
    ----0mps*2-      ---

    Contract an MPS tensor with its conjugate over legs 0 and 1:
    ``L[a, b] = sum_{c, d} conj(mps)[c, d, a] * mps[c, d, b]``.
    Leg 0 of ``L`` comes from mps*, leg 1 from mps.
    """
    return jnp.einsum("cdb,cda->ab", mps, jnp.conj(mps))