Example No. 1
def plot_3d_traj_metric_trace_init_and_opt(fig, ax, metric, traj_init,
                                           traj_opt):

    metric_tensor_init, _, _ = gp_metric_tensor(
        traj_init[:, 0:2],
        metric.gp.Z,
        metric.gp.kernel,
        mean_func=metric.gp.mean_func,
        f=metric.gp.q_mu,
        full_cov=True,
        q_sqrt=metric.gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    metric_tensor_opt, _, _ = gp_metric_tensor(
        traj_opt[:, 0:2],
        metric.gp.Z,
        metric.gp.kernel,
        mean_func=metric.gp.mean_func,
        f=metric.gp.q_mu,
        full_cov=True,
        q_sqrt=metric.gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )

    metric_trace_init = np.trace(metric_tensor_init, axis1=1, axis2=2)
    metric_trace_opt = np.trace(metric_tensor_opt, axis1=1, axis2=2)

    plot_3d_traj_col(fig, ax, traj_init, zs=metric_trace_init, color="c")
    plot_3d_traj_col(fig, ax, traj_opt, zs=metric_trace_opt, color="m")
    return fig, ax
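
A minimal standalone check of the batched-trace idiom used above (np.trace with axis1/axis2 over a stack of matrices); the array values here are illustrative only:

import numpy as np

# For an array of shape (N, D, D), tracing over the last two axes
# yields one scalar per leading index.
G = np.arange(2 * 3 * 3).reshape(2, 3, 3)
print(np.trace(G, axis1=1, axis2=2))  # [12 39]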
Example No. 2
    def test_sample(self):
        np.random.seed(0)
        n, d = 1000, 5
        batch_size = 10
        num_samples = 200
        parallel_chains = 1

        obj = create_random_least_squares(
            num_objectives=1,
            batch_size=batch_size,
            n_features=(d - 1),
            n_samples=(n, ),
            lam=1e-3,
        )[0]
        opt = obj.solve()
        q_obj = Quadratic.from_least_squares(obj)
        posterior_cov = jnp.linalg.pinv(q_obj.A)
        posterior_cov /= jnp.trace(posterior_cov)

        # Approximate sampling from the posterior.
        prng_key = random.PRNGKey(0)
        sampler = IASG(
            avg_steps=100,
            burnin_steps=100,
            learning_rate=1.0,
            discard_steps=100,
        )
        prng_key, subkey = random.split(prng_key)
        init_state = random.normal(subkey, shape=(d, ))
        samples = sampler.sample(
            objective=obj,
            prng_key=prng_key,
            init_state=init_state,
            num_samples=num_samples,
            parallel_chains=parallel_chains,
        )
        self.assertEqual(samples.shape[0], num_samples)

        sample_mean = jnp.mean(samples, axis=0)
        sample_cov = jnp.cov(samples, rowvar=False)
        sample_cov /= jnp.trace(sample_cov)

        sample_cov_fro_err = jnp.linalg.norm(sample_cov - posterior_cov, "fro")
        np.testing.assert_allclose(sample_mean, opt, rtol=1e-1, atol=1e-1)
        np.testing.assert_allclose(sample_cov_fro_err,
                                   0.0,
                                   rtol=1e-1,
                                   atol=1e-1)
Example No. 3
def get_idt(y_train, eigenspace, kdd):
    r"""Lower bound on I(\theta; D)."""
    normalization = y_train.size
    op_fn = _make_inv_expm1_fn_double(normalization)
    trace_ntk = np.trace(kdd.ntk)

    def predict(t):
        evals, evecs = eigenspace
        fl, ufl = nt.predict._make_flatten_uflatten(kdd.ntk, y_train)

        op_evals = -op_fn(evals, t)
        yexpm1y = np.einsum('j...,ji,i,ki,k...',
                            fl(y_train),
                            evecs,
                            op_evals,
                            evecs,
                            fl(y_train),
                            optimize=True)
        trace_ntk_t = t * trace_ntk
        kdd_exp1m = np.einsum('kj, ji,i,li->kl',
                              kdd.nngp,
                              evecs,
                              op_evals,
                              evecs,
                              optimize=True)
        idt = np.trace(kdd_exp1m) + np.mean(yexpm1y) + trace_ntk_t
        return idt

    return predict
Example No. 4
def isdm(mat):
    """Checks whether a given matrix is a valid density matrix.

    Args:
        mat (:obj:`jnp.ndarray`): Input matrix
    
    Returns:
        bool: ``True`` if input matrix is a valid density matrix; 
            ``False`` otherwise
    """
    isdensity = True

    if (
        isket(mat)
        or isbra(mat)
        or not isherm(mat)
        or not jnp.allclose(jnp.real(jnp.trace(mat)), 1, atol=1e-09)
    ):
        isdensity = False
    else:
        evals, _ = jnp.linalg.eig(mat)
        for eig in evals:
            if eig < 0 and not jnp.allclose(eig, 0, atol=1e-06):
                isdensity = False
                break

    return isdensity
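
A quick usage sketch, assuming isdm and its helpers (isket, isbra, isherm) are importable from the same module; the test matrices are illustrative:

import jax.numpy as jnp

rho_valid = jnp.eye(2) / 2  # maximally mixed qubit state: Hermitian, trace 1, PSD
rho_invalid = jnp.array([[1.5, 0.0], [0.0, -0.5]])  # trace 1 but has a negative eigenvalue

print(isdm(rho_valid))    # expected: True
print(isdm(rho_invalid))  # expected: False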
Example No. 5
def kl_div(mu: np.ndarray,
           A_chol: np.ndarray,
           sigma_prior: float
           ) -> float:
    """
    Computes the KL divergence between
    - the approximated posterior distribution N(mu, Sigma)
    - and the prior distribution on the parameters N(0, (sigma_prior ** 2) I)

    Instead of working directly with the covariance matrix Sigma, we only deal with its Cholesky factor A_chol,
    the lower triangular matrix such that Sigma = A_chol @ A_chol.T.

    :param mu: mean of the posterior distribution approximated by variational inference
    :param A_chol: Cholesky factor such that Sigma = A_chol @ A_chol.T,
        where Sigma is the covariance of the posterior distribution approximated by variational inference
    :param sigma_prior: standard deviation of the prior on the parameters; the prior is
        N(mean=0, variance=(sigma_prior**2) I)
    :return: the value of the KL divergence
    """
    covariance_post = np.dot(A_chol, A_chol.T)
    mean_post = mu
    size = len(mu)
    mean_prior = np.zeros(shape=(A_chol.shape[0],1))
    covariance_prior = sigma_prior**2*np.identity(A_chol.shape[0])

    cov_ratio = np.linalg.det(covariance_prior)/np.linalg.det(covariance_post) # get ratio of 2 covariance matrices

    trace_matrices = np.trace(np.dot(np.linalg.inv(covariance_prior), covariance_post))

    last_term = np.dot((mean_prior - mean_post).T, np.dot(np.linalg.inv(covariance_prior), (mean_prior - mean_post)))

    kl = 0.5*(np.log(cov_ratio) - size + trace_matrices + last_term)

    return kl[0][0]
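
A sanity check against the closed form for isotropic Gaussians; a sketch, following the function's assumption that mu is a column vector:

import numpy as np

d = 3
mu = np.zeros((d, 1))      # posterior mean equal to the prior mean
A_chol = 0.5 * np.eye(d)   # Sigma = A_chol @ A_chol.T = 0.25 * I

# KL(N(0, 0.25 I) || N(0, I)) = 0.5 * d * (0.25 - 1 - log 0.25)
expected = 0.5 * d * (0.25 - 1.0 - np.log(0.25))
print(np.isclose(kl_div(mu, A_chol, sigma_prior=1.0), expected))  # True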
Example No. 6
def theory_cnn(x_train, y_train, beta, kernel_fns, hidden_widths):

    N_tr = x_train.shape[0]
    n0 = x_train.shape[1] * x_train.shape[2]
    nd = y_train.shape[1]

    Gxx = jnp.moveaxis(jnp.tensordot(x_train, x_train, (3, 3)), (3),
                       (1))  ## Tensordot in channel axis
    Gyy = y_train @ y_train.T / nd

    K_nngp = []
    for i in range(len(kernel_fns)):
        print(convert_nt(kernel_fns[i](x_train, ).nngp).shape)
        K_nngp += [convert_nt(kernel_fns[i](x_train, ).nngp, i)]

    # Trace over the two flattened spatial axes; n0 = H * W is the flattened spatial dimension.
    KPsi = jnp.trace(Gxx.reshape(N_tr, N_tr, n0, n0), axis1=2, axis2=3) / n0
    #     KPsi_2 = x_train.reshape(N_tr,-1)@x_train.reshape(N_tr,-1).T/n0
    #     print((KPsi-KPsi_2).std())

    I = jnp.eye(N_tr)
    gamma = KPsi + I / beta
    gamma_inv = jnp.linalg.inv(gamma)
    Phi = gamma_inv @ (Gyy - KPsi - I / beta) @ gamma_inv

    prefactor = jnp.cumsum(nd / jnp.array(hidden_widths))

    K_theory = []
    for i in range(len(prefactor)):
        K_theory += [
            K_nngp[i] + prefactor[i] * correction_layer(K_nngp[i], Phi)
        ]

    return K_nngp, K_theory, Gxx, Gyy
Example No. 7
    def body(carry, tW):
        dt, dW = tW
        t0, X0, Y0, Y0_tilde, Z0 = carry
        sigma = sigma_tf(t0, X0, Y0)
        dx = mu_tf(t0, X0, Y0, Z0) * (dt) + jnp.dot(sigma, dW)  ##dx step
        X1 = X0 + dx
        dydt = dt_forward(params, t0, X0)  ##EULER STEP

        ##########################################################
        #for Hessian vector product
        # f_partial = partial(forward, params, t0)
        # vhvp0 = vhvp(f_partial, X0, sigma)
        # vdot = vmap(jnp.dot, in_axes=(0,0))  #potentially 1 in one of the axis
        # vsHs = vdot(sigma, vhvp0)

        # sumvsHs = jnp.asarray([jnp.sum(vsHs)])
        ##########################################################
        # ddydx = 0.5 * sumvsHs
        # Itô correction term: 0.5 * Tr(sigma^T Hess(Y) sigma); reuse `sigma` computed above.
        ddydx = 0.5 * jnp.trace(
            jnp.dot(sigma.T, jnp.dot(hess_forward(params, t0, X0), sigma)))

        Y1_tilde = Y0 + phi_tf(t0, X0, Y0, Z0) * (dt) + jnp.dot(
            jnp.dot(Z0.T, sigma), dW)
        # dy = dydt * dt + jnp.dot(Z0.T, dx) #EULER STEP
        dy = (dydt + ddydx) * dt + jnp.dot(Z0.T, dx) #ITO STEP

        t1 = t0 + dt
        # Y1 = jnp.asarray([forward(params, t1, X1)])
        Y1 = Y0 + dy
        Z1 = grad_forward(params, t1, X1)
        carry_new = t1, X1, Y1, Y1_tilde, Z1
        return (carry_new, carry)
Example No. 8
def gaussian_expected_log_lik(Y, q_mu, q_covar, noise, mask=None):
    """
    :param Y: N x 1
    :param q_mu: N x 1
    :param q_covar: N x N
    :param noise: N x N
    :param mask: N x 1
    :return:
        E[log 𝓝(yₙ|fₙ,σ²)] = ∫ log 𝓝(yₙ|fₙ,σ²) 𝓝(fₙ|mₙ,vₙ) dfₙ
    """

    if mask is not None:
        # build a mask for computing the log likelihood of a partially observed multivariate Gaussian
        maskv = mask.reshape(-1, 1)
        q_mu = np.where(maskv, Y, q_mu)
        noise = np.where(maskv + maskv.T, 0.,
                         noise)  # ensure masked entries are independent
        noise = np.where(np.diag(mask), INV2PI,
                         noise)  # ensure masked entries return log like of 0
        q_covar = np.where(maskv + maskv.T, 0.,
                           q_covar)  # ensure masked entries are independent
        q_covar = np.where(
            np.diag(mask), 1e-20,
            q_covar)  # ensure masked entries return trace term of 0

    ml = mvn_logpdf(Y, q_mu, noise)
    trace_term = -0.5 * np.trace(solve(noise, q_covar))
    return ml + trace_term
Example No. 9
def L(C1s, C0s, ks, bs, sigma=1):
    """
        The loss function big L. 
    """
    # return jnp.linalg.det(FIM(q,ps,C1s,C0s,ks,bs,sigma))
    return lambda q, ps: jnp.trace(
        jnp.linalg.inv(FIM(C1s, C0s, ks, bs, sigma)(q, ps)))
Example No. 10
    def variational_expectation(self, y):
        return -.5 * jnp.squeeze(
            (jnp.sum(jnp.square(self.qu_mean - y))
             + jnp.trace(self.qu_scale @ self.qu_scale.T))
            / self.observation_noise_scale ** 2
            + y.shape[-1] * jnp.log(self.observation_noise_scale ** 2)
            + jnp.log(2 * jnp.pi))
Example No. 11
def loss(t1, flat_p, omega, U_T):
    '''
    define the loss function, which is a pure function
    '''
    t_set = jnp.linspace(0., t1, 5)

    D, _ = jnp.shape(U_T)
    U_0 = jnp.eye(D, dtype=jnp.complex128)

    def func(y, t, *args):
        t1, omega, flat_p, = args

        return -1.0j * (omega * sz + A(t, flat_p, t1) * sx) @ y
        # return -1.0j*( omega* sz)@y

    res = odeint(func,
                 U_0,
                 t_set,
                 t1,
                 omega,
                 flat_p,
                 rtol=1.4e-10,
                 atol=1.4e-10)

    U_F = res[-1, :, :]
    return (1 - jnp.abs(jnp.trace(U_T.conj().T @ U_F) / D)**2)
Example No. 12
def _add_diagonal_regularizer(A: np.ndarray,
                              diag_reg: float,
                              diag_reg_absolute_scale: bool) -> np.ndarray:
  dimension = A.shape[0]
  if not diag_reg_absolute_scale:
    diag_reg *= np.trace(A) / dimension
  return A + diag_reg * np.eye(dimension)
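
A small illustration of the two scaling modes (here with plain NumPy; the snippet's np may be jax.numpy, which behaves identically for this call):

import numpy as np

A = np.array([[2.0, 0.0], [0.0, 4.0]])
# Relative mode: adds diag_reg * trace(A) / dim = 1e-2 * 6 / 2 = 0.03 to the diagonal.
print(_add_diagonal_regularizer(A, diag_reg=1e-2, diag_reg_absolute_scale=False))
# Absolute mode: adds exactly diag_reg = 1e-2 to the diagonal.
print(_add_diagonal_regularizer(A, diag_reg=1e-2, diag_reg_absolute_scale=True))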
Example No. 13
        def trace_mps(tensors, edge):
            def multiply_tensors(left_tensor, right_tensor):
                return jnp.einsum("ij,jk->ik", left_tensor, right_tensor), None

            edge, _ = jax.lax.scan(multiply_tensors, edge, tensors)

            return jnp.trace(edge)
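
A usage sketch, assuming trace_mps is lifted to module scope; the shapes are illustrative, and jax.lax.scan requires the stacked tensors to share a leading axis:

import jax
import jax.numpy as jnp

chi = 4
tensors = jax.random.normal(jax.random.PRNGKey(0), (10, chi, chi))  # stack of 10 chi x chi matrices
edge = jnp.eye(chi)  # identity boundary

# scan left-multiplies the boundary by each tensor in turn; the trace closes the loop.
print(trace_mps(tensors, edge))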
Example No. 14
    def norm(self, x):
        norm = jnp.sum(x[..., 0:self._dimension]**2, axis=-1)
        x_mat = jnp.reshape(x[..., self._dimension:],
                            (-1, self._dimension, self._dimension))

        norm += jnp.trace(x_mat, axis1=-2, axis2=-1)
        return norm
Example No. 15
    def _bl_update(H, C, R, state):
        G, (α, _), μ, τ = state
        tr_inv_H = np.trace(solve(H, I, assume_a="pos"))  # H is symmetric positive definite
        γ = n - α * tr_inv_H
        α = np.float32(n / (2 * R + tr_inv_H))
        β = np.float32((x.shape[0] - γ) / (2 * C))
        return G, (α, β), μ, τ
Example No. 16
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
  """ The Hermitian matrix `H` is split into two matrices `Hm`
  `Hp`, respectively sharing its eigenspaces beneath and above
  its `split_point`th eigenvalue.

  Returns, in addition, `Vm` and `Vp`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
  returned instead; this allows the overall isometries mapping from
  an initial input matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
    precision: TPU matmul precision.
  Returns:
    Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    Vm: An isometry from the input space of `V0` to `Hm`.
    Hp: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    Vp: An isometry from the input space of `V0` to `Hp`.
  """
  def _fill_diagonal(X, vals):
    return X.at[jnp.diag_indices(X.shape[0])].set(vals)

  H_shift = _fill_diagonal(H, H.diagonal() - split_point)
  U, _ = jsp.linalg.polar_unitary(H_shift)
  P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
  rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
  rank = int(rank)
  return _split_spectrum_jittable(P, H, V0, rank, precision)
Example No. 17
def _expect_dm(oper, state):
    """Private function to calculate the expectation value of 
    an operator with respect to a density matrix
    """
    # convert to jax.numpy arrays in case user gives raw numpy
    oper, rho = jnp.asarray(oper), jnp.asarray(state)
    # Tr(rho*op)
    return jnp.trace(jnp.dot(rho, oper))
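
Two quick checks, assuming _expect_dm is importable; Tr(rho sigma_z) is 0 for the maximally mixed state and 1 for |0><0|:

import jax.numpy as jnp

sigma_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
rho_mixed = jnp.eye(2) / 2                    # maximally mixed state
rho_up = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # |0><0|

print(_expect_dm(sigma_z, rho_mixed))  # 0.0
print(_expect_dm(sigma_z, rho_up))     # 1.0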
Example No. 18
def test_opt(wavelet_tensors):
    h, iso, dis = wavelet_tensors
    s = np.reshape(np.eye(2**3) / 2**3, [2] * 6)
    for _ in range(20):
        s = simple_mera.descend(h, s, iso, dis)
    s, iso, dis = simple_mera.optimize_linear(h, s, iso, dis, 100)
    en = np.trace(np.reshape(s, [2**3, -1]) @ np.reshape(h, [2**3, -1]))
    assert en < -1.25
Example No. 19
def test_energy(wavelet_tensors):
    h, iso, dis = wavelet_tensors
    s = np.reshape(np.eye(2**3) / 2**3, [2] * 6)
    for _ in range(20):
        s = simple_mera.descend(h, s, iso, dis)
    en = np.trace(np.reshape(s, [2**3, -1]) @ np.reshape(h, [2**3, -1]))
    assert np.isclose(en, -1.242, rtol=1e-3, atol=1e-3)
    en = simple_mera.binary_mera_energy(h, s, iso, dis)
    assert np.isclose(en, -1.242, rtol=1e-3, atol=1e-3)
Example No. 20
def test_to_matrix(vstate_rho, normalize):
    rho = vstate_rho.to_matrix(normalize=normalize)

    if normalize:
        np.testing.assert_allclose(jnp.trace(rho), 1.0)

    rho_norm = rho / jnp.trace(rho)

    assert rho.shape == (
        vstate_rho.hilbert.physical.n_states,
        vstate_rho.hilbert.physical.n_states,
    )

    x = vstate_rho.hilbert.all_states()
    rho_exact = jnp.exp(vstate_rho.log_value(x)).reshape(rho.shape)
    rho_exact = rho_exact / jnp.trace(rho_exact)

    np.testing.assert_allclose(rho_norm, rho_exact)
Example No. 21
    def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng):
        onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype)
        lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype)
        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(onp_fun,
                                lnp_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Example No. 22
def _variance(c: np.ndarray, t: np.ndarray) -> float:
    """
    Returns the variance |<aa aa>| - <aa>**2 that would be zero for the fixed point of the RG
    This version is only checking if c and t are ANY environment, see also `_variance3`

    Parameters
    ----------
    c
        The corner matrix c(d, r) with `chi`-dimensional legs
    t
        The half-column / half-row tensor t(l, d, d', r)
    Returns
    -------
    var : float
        The variance

    """

    #          c --- c               c --- t --- c                     c --- t --- t --- c
    #  lr  =   |     |      ltr  =   |     |     |         lttr   =    |     |     |     |
    #          c --- c               c --- t --- c                     c --- t --- t --- c
    # if RG fixed point is reached then lttr/lr = <tt> = <t>**2 = (ltr/lr)**2

    lr = np.trace(c @ c @ c @ c)

    connects = [[1, 6], [6, 7, 8, 5], [5, 4], [4, 3], [2, 1], [2, 7, 8, 3]]
    cont_order = [4, 1, 3, 2, 6, 5, 7, 8]
    ltr = ncon([c, t, c, c, c, t], connects, cont_order)

    connects = [[1, 4], [4, 5, 6, 9], [10, 3], [3, 11], [2, 1], [2, 5, 6, 12],
                [9, 7, 8, 10], [12, 7, 8, 11]]
    cont_order = [1, 2, 3, 11, 10, 7, 8, 12, 4, 5, 6, 9]
    lttr = ncon([c, t, c, c, c, t, t, t], connects, cont_order)

    return abs(lttr / lr) - (ltr / lr)**2
Example No. 23
def test_descend(random_tensors):
    h, s, iso, dis = random_tensors
    s = simple_mera.descend(h, s, iso, dis)
    assert len(s.shape) == 6
    D = s.shape[0]
    smat = np.reshape(s, [D**3] * 2)
    assert np.isclose(np.trace(smat), 1.0)
    assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
    spec, _ = np.linalg.eigh(smat)
    assert np.all(spec >= 0.0)
Example No. 24
    def update_params(self, params, C_post, m_post):

        sigma = params[0]
        theta = params[1]

        theta = (self.n_features - theta * np.trace(C_post)) / np.sum(m_post**2)

        upper = np.sum(self.YtY - 2 * self.XtY * m_post + m_post.T @ self.XtX @ m_post)
        lower = self.n_features - np.sum(1 - theta * np.diag(C_post))
        sigma = upper / lower

        return np.asarray([sigma, theta])
Example No. 25
    def trace(self, contra_index: int, cov_index: int) -> "Tensor[P]":
        """Contract the contravariant [contra_index] with the covariant [cov_index].

        Requires 0 <= contra_index < n_contra, and 0 <= cov_index < n_cov
        """
        if contra_index < 0 or contra_index >= self.n_contra:
            raise ValueError(f"contra_index out of bounds: {contra_index}")
        if cov_index < 0 or cov_index >= self.n_cov:
            raise ValueError(f"cov_index out of bounds: {cov_index}")
        coords = jnp.trace(self.t_coords, contra_index,
                           self.n_contra + cov_index)
        return Tensor(self.point, coords, self.n_contra - 1)
Example No. 26
def _rotation_q(pos: jnp.ndarray, indices: jnp.ndarray,
                refpos: jnp.ndarray) -> float:
    dx = pos[indices[:-1]] - pos[indices[:-1]].mean(0)
    R = dx.T @ refpos
    Rtr = jnp.trace(R)
    Ftop = jnp.array([R[1, 2] - R[2, 1], R[2, 0] - R[0, 2], R[0, 1] - R[1, 0]])
    F = jnp.block([
        [Rtr, Ftop[None, :]],
        [Ftop[:, None], -Rtr * jnp.eye(3) + R + R.T],
    ])
    q = eigh_rightmost(F)
    return q * jnp.sign(q[0])
Example No. 27
def plot_metric_trace_over_time(metric, solver, traj_init, traj_opt):
    metric_tensor_init, _, _ = gp_metric_tensor(
        traj_init[:, 0:2],
        metric.gp.Z,
        metric.gp.kernel,
        mean_func=metric.gp.mean_func,
        f=metric.gp.q_mu,
        full_cov=True,
        q_sqrt=metric.gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    metric_trace_init = np.trace(metric_tensor_init, axis1=1, axis2=2)
    metric_tensor_opt, _, _ = gp_metric_tensor(
        traj_opt[:, 0:2],
        metric.gp.Z,
        metric.gp.kernel,
        mean_func=metric.gp.mean_func,
        f=metric.gp.q_mu,
        full_cov=True,
        q_sqrt=metric.gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    metric_trace_opt = np.trace(metric_tensor_opt, axis1=1, axis2=2)
    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))
    ax.set_xlabel("Time $t$")
    ax.set_ylabel("Tr$(G(\mathbf{x}_t))$")

    ax.plot(
        solver.times,
        metric_trace_init,
        color=color_init,
        label="Initial trajectory",
    )
    ax.plot(
        solver.times,
        metric_trace_opt,
        color=color_opt,
        label="Optimised trajectory",
    )
    ax.legend()
Example No. 28
def h(x, y, kernel, logp):
    k = kernel

    def h2(x_, y_):
        return np.inner(grad(logp)(y_), grad(k, argnums=0)(x_, y_))

    def d_xk(x_, y_):
        return grad(k, argnums=0)(x_, y_)

    out = np.inner(grad(logp)(x), grad(logp)(y)) * k(x, y) +\
        h2(x, y) + h2(y, x) +\
        np.trace(jacfwd(d_xk, argnums=1)(x, y))
    return out
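
A usage sketch, assuming the snippet's np is jax.numpy and grad/jacfwd come from JAX; the RBF kernel and standard-normal log density below are illustrative choices:

import jax.numpy as jnp

def rbf(x_, y_):
    return jnp.exp(-0.5 * jnp.sum((x_ - y_) ** 2))

def logp(x_):
    return -0.5 * jnp.sum(x_ ** 2)  # standard normal log density, up to a constant

x = jnp.array([0.5, -1.0])
y = jnp.array([1.0, 0.3])
print(h(x, y, rbf, logp))  # scalar value of the Stein kernel h(x, y)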
Example No. 29
    def g(Z):
        """differentiable piece of objective in μ problem"""
        if mask is not None:
            loss_term = loss(mu0 * cmp.ilr_inv(Z, basis),
                             self.X[mask, :], self.L[mask, :])
        else:
            loss_term = loss(mu0 * cmp.ilr_inv(Z, basis), self.X,
                             self.L)
        spline_term = (β_spline / 2) * ((D1 @ Z)**2).sum()
        # generalized Tikhonov
        Z_delta = Z - Z_ref
        ridge_term = (β_ridge / 2) * np.trace(Z_delta.T @ Γ @ Z_delta)
        return loss_term + spline_term + ridge_term
Example No. 30
def eigengame_subspace_distance(Phi, optimal_subspace):
    """Compute subspace distance as per the eigengame paper."""
    try:
        d = Phi.shape[1]
        U_star = optimal_subspace @ optimal_subspace.T

        U_phi, _, _ = jnp.linalg.svd(Phi)
        U_phi = U_phi[:, :d]
        P_star = U_phi @ U_phi.T

        return 1 - 1 / d * jnp.trace(U_star @ P_star)
    except np.linalg.LinAlgError:
        return jnp.nan
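
A quick check with exactly aligned and exactly orthogonal subspaces (illustrative shapes):

import jax.numpy as jnp

optimal_subspace = jnp.eye(4)[:, :2]  # orthonormal basis of a 2D subspace of R^4

# Same subspace: distance ~ 0.  Orthogonal subspace: maximal distance of 1.
print(eigengame_subspace_distance(jnp.eye(4)[:, :2], optimal_subspace))  # ~0.0
print(eigengame_subspace_distance(jnp.eye(4)[:, 2:], optimal_subspace))  # ~1.0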