Example #1
    def _dp(self, log_potentials, length):
        semiring = self.semiring
        print(log_potentials.shape)
        N, N2, NT = log_potentials.shape
        assert N == N2
        reduced_scores = semiring.sum(log_potentials)
        term = np.diagonal(reduced_scores, 0, 0, 1)
        ns = np.arange(N)

        chart = np.full((2, N, N), semiring.zero, log_potentials.dtype)
        chart = jax.ops.index_update(chart, jax.ops.index[A, ns, 0], term)
        chart = jax.ops.index_update(chart, jax.ops.index[B, ns, N - 1], term)

        # Run
        for w in range(1, N):
            # def loop(w, chart):
            left = slice(None, N - w)
            right = slice(w, None)
            Y = chart[A, left, :w]
            Z = chart[B, right, N - w:]
            score = np.diagonal(reduced_scores, w, 0, 1)
            new = semiring.times(semiring.dot(Y, Z), score)
            chart = jax.ops.index_update(chart, jax.ops.index[A, left, w], new)
            chart = jax.ops.index_update(chart, jax.ops.index[B, right,
                                                              N - w - 1], new)

        # chart = jax.lax.fori_loop(1, N, loop, chart)
        return chart[A, 0, length - 1]
Example #2
    def _inner_raw(self, Y=None, full=True):
        if not full and Y is not None:
            raise ValueError(
                "Ambiguous inputs: `diagonal` and `y` are not compatible.")
        assert (full)

        if Y is not None and Y != self:
            assert (self.k == Y.k and self.use_inner == Y.use_inner)
            gram_self = self.k(self.inspace_points).astype(float)
            gram_mix = self.k(self.inspace_points,
                              Y.inspace_points).astype(float)
            gram_other = self.k(Y.inspace_points).astype(float)
        else:
            Y = self
            gram_self = gram_mix = gram_other = self.k(
                self.inspace_points).astype(float)

        r1 = self.reduce_gram(gram_mix, axis=0)
        gram_mix_red = Y.reduce_gram(r1, axis=1)
        if self.use_inner == "linear" or self.use_inner == "poly":
            return gram_mix_red
        elif self.use_inner == "gen_gauss":
            gram_self_red = np.diagonal(
                self.reduce_gram(self.reduce_gram(gram_self, axis=0),
                                 axis=1)).reshape((-1, 1))
            gram_other_red = np.diagonal(
                Y.reduce_gram(Y.reduce_gram(gram_other, axis=0),
                              axis=1)).reshape((1, -1))
            return {
                "gram_mix_red": gram_mix_red,
                "gram_self_red": gram_self_red,
                "gram_other_red": gram_other_red
            }
Example #3
def run(log_potentials, length, semiring="Log"):
    "Main code, vectorized inside-outside"
    if semiring == "Log":
        semiring = LogSemiring
    else:
        semiring = MaxSemiring
    N, N2, NT = log_potentials.shape
    assert N == N2
    reduced_scores = semiring.sum(log_potentials)
    term = np.diagonal(reduced_scores, 0, 0, 1)
    ns = np.arange(N)

    chart = np.full((2, N, N), semiring.zero, log_potentials.dtype)
    chart = jax.ops.index_update(chart, jax.ops.index[A, ns, 0], term)
    chart = jax.ops.index_update(chart, jax.ops.index[B, ns, N - 1], term)

    # Run
    for w in range(1, N):
        left = slice(None, N - w)
        right = slice(w, None)
        Y = chart[A, left, :w]
        Z = chart[B, right, N - w:]
        score = np.diagonal(reduced_scores, w, 0, 1)
        new = semiring.times(semiring.dot(Y, Z), score)
        chart = jax.ops.index_update(chart, jax.ops.index[A, left, w], new)
        chart = jax.ops.index_update(chart, jax.ops.index[B, right, N - w - 1],
                                     new)
    return chart[A, 0, length - 1]
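Note: jax.ops.index_update and jax.ops.index, used in Examples #1 and #3 above, have been removed from recent JAX releases. A minimal, self-contained sketch of the equivalent functional update with the current .at[...].set() syntax (not part of the original sources):

import jax.numpy as jnp

# jax.ops.index_update(x, jax.ops.index[i, j], v) is written today as x.at[i, j].set(v)
x = jnp.zeros((2, 3))
x = x.at[0, 1].set(5.0)   # returns a new array; the original x is unchanged
print(x)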
Example #4
def rkhs_gram_cdist_unchecked(G_ab: np.array,
                              G_a: np.array,
                              G_b: np.array,
                              power: float = 2.):
    sqdist = np.diagonal(G_a)[:, np.newaxis] + np.diagonal(G_b)[
        np.newaxis, :] - 2 * G_ab
    if power == 2.:
        return sqdist
    else:
        return np.power(sqdist, power / 2.)
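A quick sanity check for Example #4 (a sketch, not from the original source): with Gram matrices built from a plain linear kernel, rkhs_gram_cdist_unchecked reduces to ordinary pairwise squared Euclidean distances.

import numpy as onp
from scipy.spatial.distance import cdist

A = onp.random.randn(5, 3)
B = onp.random.randn(4, 3)
# G_ab[i, j] = <a_i, b_j>, G_a[i, j] = <a_i, a_j>, G_b[i, j] = <b_i, b_j>
sq = rkhs_gram_cdist_unchecked(A @ B.T, A @ A.T, B @ B.T)
print(onp.allclose(sq, cdist(A, B) ** 2))   # expected: True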
Example #5
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     if self.domain is constraints.lower_cholesky:
         # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
         n = jnp.shape(x)[-1]
         order = jnp.arange(n, 0, -1)
         return n * jnp.log(2) + jnp.sum(order * jnp.log(jnp.diagonal(x, axis1=-2, axis2=-1)), axis=-1)
     else:
         # NB: see derivation in LKJCholesky implementation
         n = jnp.shape(x)[-1]
         order = jnp.arange(n - 1, -1, -1)
         return jnp.sum(order * jnp.log(jnp.diagonal(x, axis1=-2, axis2=-1)), axis=-1)
Example #6
def diagonal_between(x: np.ndarray,
                     start_axis: int = 0,
                     end_axis: int = -1) -> np.ndarray:
    """Returns the diagonal along all dimensions between start and end axes."""
    if end_axis == -1:
        end_axis = x.ndim
    half_ndim, ragged = divmod(end_axis - start_axis, 2)
    if ragged:
        raise ValueError(
            f'Need even number of axes to flatten, got {end_axis - start_axis}.'
        )
    if half_ndim == 0:
        return x

    side_shape = x.shape[start_axis:start_axis + half_ndim]
    side_size = size_at(side_shape)

    shape_2d = x.shape[:start_axis] + (side_size,
                                       side_size) + x.shape[end_axis:]
    shape_result = x.shape[:start_axis] + side_shape + x.shape[end_axis:]

    x = np.diagonal(x.reshape(shape_2d),
                    axis1=start_axis,
                    axis2=start_axis + 1)
    x = np.moveaxis(x, -1, start_axis)
    return x.reshape(shape_result)
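A short shape check for diagonal_between (a sketch; size_at is assumed here to be the library helper that returns the product of the entries of a shape tuple, and is stubbed out accordingly):

import numpy as onp

def size_at(shape):
    # assumed stand-in for the helper used by diagonal_between above
    out = 1
    for s in shape:
        out *= s
    return out

x = onp.arange(36.0).reshape(2, 3, 2, 3)
y = diagonal_between(x)   # y[i, j] == x[i, j, i, j]
print(y.shape)            # expected: (2, 3)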
Example #7
def stable_svd_jvp(primals, tangents):
    """Copied from the JAX source code and slightly tweaked for stability"""
    # Deformation parameter which yields regular SVD JVP rule when set to 0
    eps = 1e-10
    A, = primals
    dA, = tangents
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)

    _T = lambda x: jnp.swapaxes(x, -1, -2)
    _H = lambda x: jnp.conj(_T(x))
    k = s.shape[-1]
    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]
    dS = jnp.matmul(jnp.matmul(Ut, dA), V)
    ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))

    # Deformation by eps avoids getting NaN's when SV's are degenerate
    f = jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k)
    f = f + eps / f  # eps controls stability
    F = 1 / f - jnp.eye(k) / (1 + eps)

    dSS = s_dim * dS
    SdS = _T(s_dim) * dS
    dU = jnp.matmul(U, F * (dSS + _T(dSS)))
    dV = jnp.matmul(V, F * (SdS + _T(SdS)))

    m, n = A.shape[-2], A.shape[-1]
    if m > n:
        dU = dU + jnp.matmul(
            jnp.eye(m) - jnp.matmul(U, Ut), jnp.matmul(dA, V)) / s_dim
    if n > m:
        dV = dV + jnp.matmul(
            jnp.eye(n) - jnp.matmul(V, Vt), jnp.matmul(_H(dA), U)) / s_dim
    return (U, s, Vt), (dU, ds, _T(dV))
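One way such a rule is typically registered (a sketch, assuming stable_svd_jvp above is in scope; not part of the original source):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_svd(A):
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False, compute_uv=True)
    return U, s, Vt

stable_svd.defjvp(stable_svd_jvp)   # use the tweaked JVP rule when differentiating

U, s, Vt = stable_svd(jnp.arange(6.0).reshape(3, 2))
print(U.shape, s.shape, Vt.shape)   # (3, 2) (2,) (2, 2)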
Example #8
def test_welford_covariance(jitted, diagonal, regularize):
    with optional(jitted,
                  disable_jit()), optional(jitted,
                                           control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(3)
        a = np.random.randn(3, 3)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov, size=(2000, ))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(
                diagonal=diagonal)
            wc_state = wc_init(3)
            wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val),
                                 wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.sqrt(jnp.reciprocal(diag_cov)),
                            rtol=0.06)
        else:
            assert_allclose(cov, target_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.linalg.cholesky(jnp.linalg.inv(cov)),
                            rtol=0.06)
Example #9
 def __call__(self, x):
     tril = np.tril(x)
     lower_triangular = np.all(np.reshape(tril == x, x.shape[:-2] + (-1, )),
                               axis=-1)
     positive_diagonal = np.all(np.diagonal(x, axis1=-2, axis2=-1) > 0,
                                axis=-1)
     return lower_triangular & positive_diagonal
Example #10
 def __call__(self, x):
     tril = jnp.tril(x)
     lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
     positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
     x_norm = jnp.linalg.norm(x, axis=-1)
     unit_norm_row = jnp.all((x_norm <= 1) & (x_norm > 1 - 1e-6), axis=-1)
     return lower_triangular & positive_diagonal & unit_norm_row
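A minimal illustration of what a constraint with this __call__ accepts (a sketch; corr_chol_constraint is a hypothetical instance of the class shown above): rows of a valid correlation Cholesky factor have unit norm, so the first matrix passes and the second does not.

import jax.numpy as jnp

# corr_chol_constraint: hypothetical instance of the constraint class above
good = jnp.array([[1.0, 0.0],
                  [0.6, 0.8]])   # lower triangular, positive diagonal, unit-norm rows
bad = jnp.array([[1.0, 0.0],
                 [0.5, 0.5]])    # second row has norm < 1
print(corr_chol_constraint(good), corr_chol_constraint(bad))   # expected: True False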
Example #11
def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    **References:**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = jnp.broadcast_to(L, batch_shape + L.shape[-2:])
    x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = jnp.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit-diagonal triangular form: L @ D @ L.T
    L = L / diag[..., None, :]
    D = jnp.square(diag)

    def scan_fn(carry, val):
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * jnp.square(wj)
        Dj_new = gamma / b
        b = gamma / Dj_new

        # update vectors w and L_j
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = jnp.moveaxis(D, -1, 0), jnp.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(scan_fn, (jnp.ones(batch_shape), x), (jnp.arange(D.shape[0]), D, L))
    D, L = jnp.moveaxis(D, 0, -1), jnp.moveaxis(L, 0, -1)  # move scan dim back
    return L * jnp.sqrt(D)[..., None, :]
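A quick consistency check for cholesky_update (a sketch, not from the original source): the rank-one update should agree with re-factorizing the updated matrix directly.

import jax.numpy as jnp
from jax import random

A = random.normal(random.PRNGKey(0), (4, 4))
L = jnp.linalg.cholesky(A @ A.T + 4.0 * jnp.eye(4))   # a well-conditioned SPD factor
x = random.normal(random.PRNGKey(1), (4,))

L_updated = cholesky_update(L, x, coef=1)
L_direct = jnp.linalg.cholesky(L @ L.T + jnp.outer(x, x))
print(jnp.allclose(L_updated, L_direct, atol=1e-4))   # expected: True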
Example #12
 def log_prob(self, value):
     M = _batch_mahalanobis(self.scale_tril, value - self.loc)
     half_log_det = np.log(np.diagonal(self.scale_tril, axis1=-2,
                                       axis2=-1)).sum(-1)
     normalize_term = half_log_det + 0.5 * self.scale_tril.shape[
         -1] * np.log(2 * np.pi)
     return -0.5 * M - normalize_term
Example #13
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
     n = jnp.shape(x)[-1]
     order = -jnp.arange(n, 0, -1)
     return -n * jnp.log(2) + jnp.sum(
         order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1
     )
Example #14
 def __call__(self, x):
     # check for symmetric
     symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
     # check for the smallest eigenvalue is positive
     positive = jnp.linalg.eigh(x)[0][..., 0] > 0
     # check for diagonal equal to 1
     unit_variance = jnp.all(jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1)
     return symmetric & positive & unit_variance
Example #15
def rkhs_gram_cdist_ignore_const(G_ab: np.array,
                                 G_b: np.array,
                                 power: float = 2.):
    sqdist = np.diagonal(G_b)[np.newaxis, :] - 2 * G_ab
    if power == 2.:
        return sqdist
    else:
        return np.power(sqdist, power / 2.)
Example #16
def mvn_log_prob(mean, cov, value):
    scale_tril = jsp.linalg.cholesky(cov, lower=True)
    M = mahalanobis(scale_tril, value - mean)
    half_log_det = jnp.log(jnp.diagonal(scale_tril, axis1=-2,
                                        axis2=-1)).sum(-1)
    normalize_term = half_log_det + 0.5 * scale_tril.shape[-1] * jnp.log(
        2 * jnp.pi)
    return -0.5 * M - normalize_term
Example #17
def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * np.sum(np.log(np.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1) + np.log(D).sum(-1)
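A small numeric check of the determinant-lemma identity used above (a sketch, assuming np refers to jax.numpy as in the surrounding examples):

import jax.numpy as np

W = np.array([[1.0, 0.2], [0.3, 1.5], [0.0, 0.7]])   # (3, 2) low-rank factor
D = np.array([0.5, 1.0, 2.0])                        # diagonal entries of D
C = np.eye(2) + (W.T / D) @ W                        # capacitance matrix I + W.T @ inv(D) @ W
capacitance_tril = np.linalg.cholesky(C)

lemma = _batch_lowrank_logdet(W, D, capacitance_tril)
direct = np.linalg.slogdet(W @ W.T + np.diag(D))[1]
print(np.allclose(lemma, direct, atol=1e-4))          # expected: True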
Example #18
 def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
     onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
     lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
     args_maker = lambda: [rng(shape, dtype)]
     self._CheckAgainstNumpy(onp_fun,
                             lnp_fun,
                             args_maker,
                             check_dtypes=True)
     self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Example #19
def main():

    Gamma0, Omega, Sigma, T, Y_obs, amp, mu0, tec, freqs = generate_data()

    hmm = NonLinearDynamicsSmoother(TecLinearPhaseNestedSampling(freqs))
    hmm = jit(
        partial(hmm,
                tol=1.,
                maxiter=2,
                omega_window=None,
                sigma_window=None,
                momentum=0.,
                omega_diag_range=(0, jnp.inf),
                sigma_diag_range=(0, jnp.inf)))
    #
    # with disable_jit():
    keys = random.split(random.PRNGKey(0), T)
    # with disable_jit():
    res = hmm(Y_obs, Sigma, mu0, Gamma0, Omega, amp, keys)

    print(res.converged, res.niter)
    plt.plot(tec, label='true tec')
    plt.plot(res.post_mu[:, 0], label='infer tec')
    plt.fill_between(jnp.arange(T),
                     res.post_mu[:, 0] - jnp.sqrt(res.post_Gamma[:, 0, 0]),
                     res.post_mu[:, 0] + jnp.sqrt(res.post_Gamma[:, 0, 0]),
                     alpha=0.5)
    plt.legend()
    plt.show()

    plt.plot(jnp.sqrt(res.post_Gamma[:, 0, 0]))
    plt.title("Uncertainty tec")
    plt.show()

    plt.plot(tec - res.post_mu[:, 0], label='infer')
    plt.fill_between(
        jnp.arange(T),
        (tec - res.post_mu[:, 0]) - jnp.sqrt(res.post_Gamma[:, 0, 0]),
        (tec - res.post_mu[:, 0]) + jnp.sqrt(res.post_Gamma[:, 0, 0]),
        alpha=0.5)
    plt.title("Residual tec")
    plt.legend()
    plt.show()
    plt.plot(jnp.sqrt(res.Omega[:, 0, 0]))
    plt.title("omega")
    plt.show()
    plt.plot(
        jnp.mean(jnp.sqrt(jnp.diagonal(res.Sigma, axis2=-2, axis1=-1)),
                 axis=-1))
    plt.title("mean sigma")
    plt.show()
Example #20
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the respective pairs of axes from the `ntk`.
    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).
  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).
  """

    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    for i, d in enumerate(diagonal_axes):
        axis1 = d - i
        axis2 = output_ndim + d - 2 * i - n_trace
        for c in trace_axes:
            if c < d:
                axis1 -= 1
                axis2 -= 1
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size
Example #21
 def compute_expectations(self, mean, cov):
     """
     returns a list of expected values of the following expressions:
     x, x^2, cos(x), sin(x)
     """
     # characteristic function at [1, ..., 1]:
     t = np.ones(self.d)
     char = np.exp(np.vdot(t, (1j * mean - np.dot(cov, t) / 2)))
     expectations = [
         mean,
         np.diagonal(cov) + mean**2,
         np.real(char),
         np.imag(char)
     ]
     return expectations
Example #22
    def _assert_is_diagonal(self, j, axis1, axis2, constant_diagonal: bool):
        c = j.shape[axis1]
        self.assertEqual(c, j.shape[axis2])
        mask_shape = [c if i in (axis1, axis2) else 1 for i in range(j.ndim)]
        mask = np.eye(c, dtype=np.bool_).reshape(mask_shape)

        # Check that removing the diagonal makes the array all 0.
        j_masked = np.where(mask, np.zeros((), j.dtype), j)
        self.assertAllClose(np.zeros_like(j, j.dtype), j_masked)

        if constant_diagonal:
            # Check that diagonal is constant.
            if j.size != 0:
                j_diagonals = np.diagonal(j, axis1=axis1, axis2=axis2)
                self.assertAllClose(np.min(j_diagonals, -1),
                                    np.max(j_diagonals, -1))
Example #23
def log_prob_multivariate_normal(loc, scale_tril, value):
    def _batch_mahalanobis(bL, bx):
        # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
        # because we don't want to broadcast bL to the shape (i, j, n, n).

        # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
        # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tril_solve
        sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
        out_shape = np.shape(bx)[:-1]  # shape of output
        # Reshape bx with the shape (..., 1, i, j, 1, n)
        bx_new_shape = out_shape[:sample_ndim]
        for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
            bx_new_shape += (sx // sL, sL)
        bx_new_shape += (-1, )
        bx = np.reshape(bx, bx_new_shape)
        # Permute bx to make it have shape (..., 1, j, i, 1, n)
        permute_dims = (tuple(range(sample_ndim)) +
                        tuple(range(sample_ndim, bx.ndim - 1, 2)) +
                        tuple(range(sample_ndim + 1, bx.ndim - 1, 2)) +
                        (bx.ndim - 1, ))
        bx = np.transpose(bx, permute_dims)

        # reshape to (-1, i, 1, n)
        xt = np.reshape(bx, (-1, ) + bL.shape[:-1])
        # permute to (i, 1, n, -1)
        xt = np.moveaxis(xt, 0, -1)
        solve_bL_bx = solve_triangular(bL, xt,
                                       lower=True)  # shape: (i, 1, n, -1)
        M = np.sum(solve_bL_bx**2, axis=-2)  # shape: (i, 1, -1)
        # permute back to (-1, i, 1)
        M = np.moveaxis(M, -1, 0)
        # reshape back to (..., 1, j, i, 1)
        M = np.reshape(M, bx.shape[:-1])
        # permute back to (..., 1, i, j, 1)
        permute_inv_dims = tuple(range(sample_ndim))
        for i in range(bL.ndim - 2):
            permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
        M = np.transpose(M, permute_inv_dims)
        return np.reshape(M, out_shape)

    mahalanobis = _batch_mahalanobis(scale_tril, value - loc)
    half_log_det = np.log(np.diagonal(scale_tril, axis1=-2, axis2=-1)).sum(-1)
    normalize_term = half_log_det + 0.5 * scale_tril.shape[-1] * np.log(
        2 * np.pi)
    return -0.5 * mahalanobis - normalize_term
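A cross-check of Example #23 against the library density (a sketch, assuming np is jax.numpy and solve_triangular is jax.scipy.linalg.solve_triangular, as in the defining module):

import jax.numpy as np
from jax.scipy.stats import multivariate_normal

loc = np.array([0.5, -1.0, 2.0])
cov = np.array([[2.0, 0.3, 0.0],
                [0.3, 1.0, 0.2],
                [0.0, 0.2, 1.5]])
scale_tril = np.linalg.cholesky(cov)
value = np.array([0.0, 0.0, 1.0])

print(log_prob_multivariate_normal(loc, scale_tril, value))
print(multivariate_normal.logpdf(value, loc, cov))   # the two values should agree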
Example #24
    def ker_fun(kernels):
        var1, nngp, var2, ntk, is_gaussian, _ = kernels

        pixel_axes = tuple(range(2, nngp.ndim))
        nngp = np.mean(nngp, axis=pixel_axes)
        ntk = np.mean(ntk, axis=pixel_axes) if _is_array(ntk) else ntk

        if var2 is None:
            var1 = np.diagonal(nngp)
        else:
            # TODO(romann)
            warnings.warn(
                'Pooling for different inputs `x1` and `x2` is not '
                'implemented and will only work if there are no '
                'nonlinearities in the network anywhere after the pooling '
                'layer. `var1` and `var2` will have wrong values. '
                'This will be fixed soon.')

        return Kernel(var1, nngp, var2, ntk, is_gaussian, True)
Example #25
 def compute_expectations(self, means, covs, weights):
     """
     returns a list of expected values of the following expressions:
     x, x^2, cos(x), sin(x)
     """
     # characteristic function at [1, ..., 1]:
     t = np.ones(self.d)
     chars = np.array([
         np.exp(np.vdot(t, (1j * mean - np.dot(cov, t) / 2)))
         for mean, cov in zip(means, covs)
     ])  # shape (k,d)
     char = np.vdot(weights, chars)
     mean = np.einsum("i,id->d", weights, means)
     xsquares = [
         np.diagonal(cov) + mean**2 for mean, cov in zip(means, covs)
     ]
     expectations = [
         mean,
         np.einsum("i,id->d", weights, xsquares),
         np.real(char),
         np.imag(char)
     ]
     expectations = [np.squeeze(e) for e in expectations]
     return expectations
Example #26
def bayesianpca_photonly(
    params,
    data_batch,
    data_aux,
    n_components,
    opt_basis,
    opt_priors,
):

    if opt_basis and opt_priors:
        pcacomponents_photonly = params[0]
        components_prior_params_photonly = params[1]
    if opt_basis and not opt_priors:
        components_prior_params_photonly = data_aux[0]
        pcacomponents_photonly = params[0]
    if not opt_basis and opt_priors:
        pcacomponents_photonly = data_aux[0]
        components_prior_params_photonly = params[0]

    (
        si,
        bs,
        phot,
        phot_invvar,
        phot_loginvvar,
        batch_redshifts,
        transferfunctions,
        batch_interprightindices_transfer,
        batch_interpweights_transfer,
    ) = data_batch

    components_phot_photonly = np.sum(
        pcacomponents_photonly[None, :, :, None] * transferfunctions[:, None, :, :],
        axis=2,
    )  # [n_z_transfer, n_components, n_phot]

    components_phot_photonly_obj = np.take(
        components_phot_photonly, batch_interprightindices_transfer, axis=0
    )
    components_phot_photonly_atz = (
        batch_interpweights_transfer[:, None, None] * components_phot_photonly_obj
        + (1 - batch_interpweights_transfer[:, None, None])
        * components_phot_photonly_obj
    )

    components_prior_mean_photonly = PriorModel.get_mean_at_z(
        components_prior_params_photonly, batch_redshifts
    )
    components_prior_loginvvar_photonly = PriorModel.get_loginvvar_at_z(
        components_prior_params_photonly, batch_redshifts
    )
    components_prior_invvar_photonly = np.exp(components_prior_loginvvar_photonly)
    (
        logfml_photonly,
        thetamap_photonly,
        theta_cov_photonly,
    ) = logmarglike_lineargaussianmodel_twotransfers_jitvmap(
        components_phot_photonly_atz,  # [n_obj, n_components, nphot]
        phot,  # [n_obj, nphot]
        phot_invvar,  # [n_obj, nphot]
        phot_loginvvar,  # [n_obj, nphot]
        components_prior_mean_photonly,
        components_prior_invvar_photonly,
        components_prior_loginvvar_photonly,
    )
    # ellfactors = np.ones_like(batch_redshifts)
    photmod_map_photonly = np.sum(
        components_phot_photonly_atz * thetamap_photonly[:, :, None],
        axis=1,
    )
    thetastd_photonly = np.diagonal(theta_cov_photonly, axis1=1, axis2=2) ** 0.5

    return (
        logfml_photonly,
        thetamap_photonly,
        thetastd_photonly,
        photmod_map_photonly,
    )
Example #27
def bayesianpca_specandphot_explicit(
    components_spec,  # [n_obj, n_archetypes, n_components, nspec]
    components_phot,  # [n_obj, n_archetypes, n_components, nphot]
    polynomials_spec,  # [n_poly, nspec]
    ellfactors,  # [n_obj, n_archetypes]
    spec,  # [n_obj, nspec]
    spec_invvar,  # [n_obj, nspec]
    spec_loginvvar,  # [n_obj, nspec]
    phot,  # [n_obj, nphot]
    phot_invvar,  # [n_obj, nphot]
    phot_loginvvar,  # [n_obj, nphot]
    components_prior_mean,  # [n_obj, n_archetypes, n_components]
    components_prior_loginvvar,  # [n_obj, n_archetypes, n_components]
    polynomials_prior_mean,  # [n_poly]
    polynomials_prior_loginvvar,  # [n_poly]
):

    n_obj, n_archetypes, n_components, nspec = np.shape(components_spec)
    n_poly = np.shape(polynomials_spec)[0]
    nphot = np.shape(phot)[1]

    components_spec_all = np.concatenate(
        [
            components_spec,
            polynomials_spec[None, :, :] * np.ones((n_obj, n_archetypes, 1, nspec)),
        ],
        axis=-2,
    )  # [n_obj, n_archetypes, n_components+n_poly, nspec]

    components_phot_all = np.concatenate(
        [components_phot, np.zeros((n_obj, n_archetypes, n_poly, nphot))], axis=2
    )  # [n_obj,n_archetypes,  n_components+n_poly, nphot]

    # if shape is [n_poly] instead of [n_obj, n_components]
    mu = np.concatenate(
        [
            components_prior_mean,
            polynomials_prior_mean[None, None, :] * np.ones((n_obj, n_archetypes, 1)),
        ],
        axis=-1,
    )
    logmuinvvar = np.concatenate(
        [
            components_prior_loginvvar,
            polynomials_prior_loginvvar[None, None, :]
            * np.ones((n_obj, n_archetypes, 1)),
        ],
        axis=-1,
    )
    muinvvar = np.exp(logmuinvvar)
    # if shape is [n_obj, n_components] instead of [n_poly]
    # mu = polynomials_prior_mean
    # muinvvar = np.exp(polynomials_prior_loginvvar)
    # logmuinvvar = np.log(muinvvar)  # Assume no mask in last dimension
    (
        logfml,
        thetamap,
        theta_cov,
    ) = logmarglike_lineargaussianmodel_threetransfers_jitvmapvmap(
        ellfactors,
        components_spec_all,
        components_phot_all,
        spec[:, None, :] * np.ones((1, n_archetypes, 1)),
        spec_invvar[:, None, :] * np.ones((1, n_archetypes, 1)),
        spec_loginvvar[:, None, :] * np.ones((1, n_archetypes, 1)),
        phot[:, None, :] * np.ones((1, n_archetypes, 1)),
        phot_invvar[:, None, :] * np.ones((1, n_archetypes, 1)),
        phot_loginvvar[:, None, :] * np.ones((1, n_archetypes, 1)),
        mu,
        muinvvar,
        logmuinvvar,
    )

    thetastd = np.diagonal(theta_cov, axis1=2, axis2=3) ** 0.5

    # Produce best fit models
    specmod_map = np.sum(components_spec_all * thetamap[:, :, :, None], axis=-2)
    photmod_map = np.sum(components_phot_all * thetamap[:, :, :, None], axis=-2)

    return (logfml, thetamap, thetastd, specmod_map, photmod_map)
Example #28
 def quantiles(self, params, quantiles):
     transform = self._get_transform(params)
     quantiles = np.array(quantiles)[..., None]
     latent = dist.Normal(transform.loc,
                          np.diagonal(transform.scale_tril)).icdf(quantiles)
     return self._unpack_and_constrain(latent, params)
Example #29
 def _inverse(self, y):
     z = matrix_to_tril_vec(y, diagonal=-1)
     diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1))
     return jnp.concatenate([z, diag], axis=-1)
Example #30
 def _inverse(self, y):
     z = matrix_to_tril_vec(y, diagonal=-1)
     return jnp.concatenate([z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)