Example #1
  def testPmapMapVmapCombinations(self):
    # https://github.com/google/jax/issues/2822
    def vv(x, y):
      """Vector-vector multiply"""
      return np.dot(x, y)

    def matrix_vector(x, y, parallel=True):
      """Matrix vector multiply. First batch it and then row by row"""
      fv = lambda z: lax.map(lambda j: vv(j, y), z)
      if parallel:
        # split leading axis in two
        new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
        # apply map
        new_res = pmap(fv)(new_x)
        # reshape back out
        res = new_res.reshape(x.shape[0], *new_res.shape[2:])
      else:
        res = fv(x)
      return res

    x = random.normal(random.PRNGKey(1), (80, 5))
    y = random.normal(random.PRNGKey(1), (10, 5))

    result1 = vmap(lambda b: matrix_vector(x, b, True))(y)       # vmap + pmap
    result2 = lax.map(lambda b: matrix_vector(x, b, False), y)   # map + map
    result3 = lax.map(lambda b: matrix_vector(x, b, True), y)    # map + pmap
    result4 = np.stack([matrix_vector(x, b, False) for b in y])  # none + map

    self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
    self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
    self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)
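The core pattern this test exercises can be reduced to a minimal, self-contained sketch (the array values below are illustrative, not from the original test): lax.map applies a function sequentially over the leading axis, so for a pure per-row function it should agree with vmap.

import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(12.0).reshape(4, 3)        # four rows of length three
y = jnp.ones(3)

row_dot = lambda row: jnp.dot(row, y)     # per-row computation
sequential = lax.map(row_dot, x)          # loops over the leading axis under the hood
vectorized = vmap(row_dot)(x)             # batches over the leading axis instead

assert jnp.allclose(sequential, vectorized)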
Example #2
  def calc_bpd_loop(self, model_fn, *, x_start, rng):
    """Calculate variational bound (loop over all timesteps and sum)."""
    batch_size = x_start.shape[0]

    noise_shape = x_start.shape + (self.num_pixel_vals,)
    def map_fn(map_val):
      t, cur_rng = map_val
      # Calculate VB term at the current timestep
      t = jnp.full((batch_size,), t)
      vb, _ = self.vb_terms_bpd(
          model_fn=model_fn, x_start=x_start, t=t,
          x_t=self.q_sample(
              x_start=x_start, t=t,
              noise=jax.random.uniform(cur_rng, noise_shape)))
      del cur_rng
      assert vb.shape == (batch_size,)
      return vb

    vbterms_tb = lax.map(
        map_fn, (jnp.arange(self.num_timesteps),
                 jax.random.split(rng, self.num_timesteps)))
    vbterms_bt = vbterms_tb.T
    assert vbterms_bt.shape == (batch_size, self.num_timesteps)

    prior_b = self.prior_bpd(x_start=x_start)
    total_b = vbterms_tb.sum(axis=0) + prior_b
    assert prior_b.shape == total_b.shape == (batch_size,)

    return {
        'total': total_b,
        'vbterms': vbterms_bt,
        'prior': prior_b,
    }
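In equation form, the quantity assembled above is (a sketch using the code's own names, with T = num_timesteps and b indexing the batch):

\mathrm{total}_b = \sum_{t=0}^{T-1} L_{\mathrm{vb}}(b, t) + L_{\mathrm{prior}}(b)

i.e. vbterms_tb.sum(axis=0) + prior_b, evaluated per example so every term has shape (batch_size,).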
Example #3
def rvcoref(t, T0, P, e, omegaA, K, i):
    """Unit-free radial velocity curve w/o systemic velocity, in addition, i
    and K are separated.

    Args:
       t: Time in your time unit
       T0: Time of periastron passage in your time unit
       P: orbital period in your time unit
       e: eccentricity
       omegaA: argument of periastron
       K: RV semi-amplitude divided by sin(i), in your velocity unit
       i: inclination

    Returns:
       radial velocity curve in your velocity unit
    """
    n = 2*jnp.pi/P
    M = n*(t-T0)

    Ea = map(lambda x: getE.getE(x, e), M)
    cosE = jnp.cos(Ea)
    cosf = (-cosE + e)/(-1 + cosE*e)
    sinf = jnp.sqrt((-1 + cosE*cosE)*(-1 + e*e))/(-1 + cosE*e)
    sinf = jnp.where(Ea < jnp.pi, -sinf, sinf)

    cosfpo = cosf*jnp.cos(omegaA)-sinf*jnp.sin(omegaA)
    face = 1.0/jnp.sqrt(1.0-e*e)
    Ksini = K*jnp.sin(i)
    model = Ksini*face*(cosfpo+e*jnp.cos(omegaA))

    return model
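For reference, the closed-form Keplerian relations this function evaluates (E is the eccentric anomaly returned by getE.getE, f the true anomaly; the sign of sin f follows that of sin E, which is what the jnp.where switch implements):

\cos f = \frac{\cos E - e}{1 - e\cos E}, \qquad
\sin f = \frac{\sqrt{(1 - \cos^2 E)\,(1 - e^2)}}{1 - e\cos E},

v_r = \frac{K \sin i}{\sqrt{1 - e^2}} \bigl[\cos(f + \omega_A) + e \cos\omega_A\bigr].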
Example #4
def test_multiple_twists(dim=1, num_twists=3, integration_range=(-20, 20), num_point_ops=20000):
    """
    For several random twists of a base Gaussian, assert that the manually
    twisted log-pdf and the parameter-twisted log-pdf agree.
    """
    from jax.lax import map

    for i in tqdm.trange(10):
        #define the params
        x = np.random.randn(dim)
        mean = np.random.randn(dim)
        cov = np.abs(np.random.randn(dim)) + np.ones(dim)*10
        As, bs = np.abs(np.random.randn(num_twists, dim)), np.random.randn(num_twists, dim)

        #get the untwisted logZ, logpdf, and the twisting value
        untwisted_logZ = Normal_logZ(mean, cov)
        untwisted_logpdf = unnormalized_Normal_logp(x, mean, cov)
        log_twist = log_psi_twist(x, As, bs) #manually compute the twisting value
        manual_unnorm_logp_twist = untwisted_logpdf + log_twist #add the twist value to the untwisted unnormalized logp

        #do the parameter twist
        twisted_mu, twisted_cov, log_normalizer = do_param_twist(mean, cov, As, bs, True)
        twisted_logZ = Normal_logZ(twisted_mu, twisted_cov)
        twisted_logpdf = unnormalized_Normal_logp(x, twisted_mu, twisted_cov) + log_normalizer - twisted_logZ + untwisted_logZ

        #make sure the manual and automatic one are equal
        assert np.isclose(manual_unnorm_logp_twist, twisted_logpdf)

        #now, we have to do the hard part and try to compute the twisted normalization constant...
        unnorm_logp = lambda x: jnp.exp(unnormalized_Normal_logp(x, mean, cov) + log_psi_twist(x, As, bs))
        vals = map(unnorm_logp, jnp.linspace(integration_range[0], integration_range[1], num_point_ops)[..., jnp.newaxis])
        dx = (integration_range[1] - integration_range[0]) / num_point_ops
        man_logZ = jnp.log(dx*vals.sum())
        assert np.isclose(man_logZ, log_normalizer + untwisted_logZ, atol=1e-2)
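The final assertion compares a simple Riemann-sum estimate of the twisted normalizer against the analytic value; in the code's notation (a sketch for the 1-D case, with (a, b) = integration_range and N = num_point_ops):

\log Z_{\mathrm{twist}} \approx \log\Bigl(\Delta x \sum_{k=1}^{N} \exp\bigl(\log p(x_k) + \log\psi(x_k)\bigr)\Bigr), \qquad \Delta x = \frac{b - a}{N},

which should agree with log_normalizer + untwisted_logZ up to the stated tolerance.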
Example #5
def gaussian_cl_covariance(ell, probes, cl_signal, cl_noise, f_sky=0.25):
    """
    Computes a Gaussian covariance for the angular cls of the provided probes

    Returns the covariance matrix.
    """
    ell = np.atleast_1d(ell)
    n_ell = len(ell)

    # Adding noise to auto-spectra
    cl_obs = cl_signal + cl_noise
    n_cls = cl_obs.shape[0]

    # Normalization of covariance
    norm = (2 * ell + 1) * np.gradient(ell) * f_sky

    # Retrieve ordering for blocks of the covariance matrix
    cov_blocks = np.array(_get_cov_blocks_ordering(probes))

    def get_cov_block(inds):
        a, b, c, d = inds
        cov = (cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]) / norm
        return cov * np.eye(n_ell)

    cov_mat = lax.map(get_cov_block, cov_blocks)

    # Reshape covariance matrix into proper matrix
    cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
    cov_mat = cov_mat.transpose((0, 2, 1, 3)).reshape(
        (n_ell * n_cls, n_ell * n_cls))
    return cov_mat
Example #6
 def distributed_matrix_vector(x, y):
     """Matrix vector multiply. First batch it and then row by row"""
     fv = lambda z: lax.map(lambda j: vv(j, y), z)
     res = pmap(fv)(x.reshape((jax.device_count(), -1) +
                              tuple(x.shape[1:])))
     res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
     return res
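This snippet relies on a vv helper from its enclosing scope (compare Examples #1 and #21). A hedged, self-contained way to run it might look like the following, assuming the function above is pasted into the same file; the leading axis of x must be divisible by jax.device_count():

import jax
import jax.numpy as jnp
from jax import lax, pmap, random

def vv(x, y):
    """Vector-vector multiply, as in Example #1."""
    return jnp.dot(x, y)

x = random.normal(random.PRNGKey(0), (8 * jax.device_count(), 5))
y = random.normal(random.PRNGKey(1), (5,))
res = distributed_matrix_vector(x, y)   # shape: (8 * jax.device_count(),)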
Example #7
def rejection_stitch_proposal_all(ssm_scenario: StateSpaceModel,
                                  x0_all: jnp.ndarray,
                                  t: float,
                                  x1_all: jnp.ndarray,
                                  tplus1: float,
                                  x1_log_weight: jnp.ndarray,
                                  bound_inflation: float,
                                  not_yet_accepted_arr: jnp.ndarray,
                                  x1_all_sampled_inds: jnp.ndarray,
                                  bound: float,
                                  random_keys: jnp.ndarray,
                                  rejection_iter: int,
                                  num_transition_evals: int) \
        -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, int, int]:
    n = len(x1_all)
    mapped_tup = map(lambda i: rejection_stitch_proposal_single_cond(not_yet_accepted_arr[i],
                                                                     x1_all_sampled_inds[i],
                                                                     ssm_scenario,
                                                                     x0_all[i],
                                                                     t,
                                                                     x1_all,
                                                                     tplus1,
                                                                     x1_log_weight,
                                                                     bound,
                                                                     random_keys[i]), jnp.arange(n))
    x1_all_sampled_inds, dens_evals, not_yet_accepted_arr_new, random_keys = mapped_tup

    # Check if we need to start again
    max_dens = jnp.max(dens_evals)
    reset_bound = max_dens > bound
    bound = jnp.where(reset_bound, max_dens * bound_inflation, bound)
    not_yet_accepted_arr_new = jnp.where(reset_bound, jnp.ones(n, dtype='bool'), not_yet_accepted_arr_new)
    return not_yet_accepted_arr_new, x1_all_sampled_inds, bound, random_keys, rejection_iter + 1, \
           num_transition_evals + not_yet_accepted_arr.sum()
Example #8
        def integrand(a):
            # Step 1: retrieve the associated comoving distance
            chi = bkgrd.radial_comoving_distance(cosmo, a)

            # Step 2: get the power spectrum for this combination of chi and a
            k = (ell + 0.5) / np.clip(chi, 1.0)

            # pk should have shape [na]
            pk = power.nonlinear_matter_power(cosmo, k, a, transfer_fn,
                                              nonlinear_fn)

            # Compute the kernels for all probes
            kernels = np.vstack([p.kernel(cosmo, a2z(a), ell) for p in probes])

            # Define an ordering for the blocks of the signal vector
            cl_index = np.array(_get_cl_ordering(probes))

            # Compute all combinations of tracers
            def combine_kernels(inds):
                return kernels[inds[0]] * kernels[inds[1]]

            # Now kernels has shape [ncls, na]
            kernels = lax.map(combine_kernels, cl_index)

            result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(
                chi**2, 1.0)

            # We transpose the result just to make sure that na is first
            return result.T
Example #9
    def _predict_spatial_batched(self, elec_params, x, y):
        """ A batched version of _predict_spatial_jax
        Parameters:
        -------------
        elec_params : np.array with shape (batch_size, n_electrodes, 3)
            The 3 columns are freq, amp, pdur for each electrode
        x, y: np.array with shape (n_electrodes)
            x and y coordinates of electrodes
        Returns:
        ------------
        resp : np.array() representing the resulting percepts, shape (batch_size, :, 1)
        """
        bright_effects = self.bright_model(elec_params[:, :, 0],
                                           elec_params[:, :, 1],
                                           elec_params[:, :, 2])
        size_effects = self.size_model(elec_params[:, :, 0],
                                       elec_params[:, :, 1],
                                       elec_params[:, :, 2])
        streak_effects = self.streak_model(elec_params[:, :, 0],
                                           elec_params[:, :, 1],
                                           elec_params[:, :, 2])
        eparams = jnp.stack([bright_effects, size_effects, streak_effects],
                            axis=2)

        def predict_one(e_params):
            return self.biphasic_axon_map_jax(e_params, x, y,
                                              self.axon_contrib, self.rho,
                                              self.thresh_percept)

        resps = lax.map(predict_one, eparams)

        return resps
Example #10
def rvf(t, T0, P, e, omegaA, Ksini, Vsys):
    """Unit-free radial velocity curve for SB1.

    Args:
       t: Time in your time unit
       T0: Time of periastron passage in your time unit
       P: orbital period in your time unit
       e: eccentricity
       omegaA: argument of periastron
       Ksini: RV semi-amplitude in your velocity unit
       Vsys: systemic velocity in your velocity unit

    Returns:
       radial velocity curve in your velocity unit
    """

    n = 2*jnp.pi/P
    M = n*(t-T0)

    Ea = map(lambda x: getE.getE(x, e), M)
    cosE = jnp.cos(Ea)
    cosf = (-cosE + e)/(-1 + cosE*e)
    sinf = jnp.sqrt((-1 + cosE*cosE)*(-1 + e*e))/(-1 + cosE*e)
    sinf = jnp.where(Ea < jnp.pi, -sinf, sinf)

    cosfpo = cosf*jnp.cos(omegaA)-sinf*jnp.sin(omegaA)
    face = 1.0/jnp.sqrt(1.0-e*e)
    model = Ksini*face*(cosfpo+e*jnp.cos(omegaA)) + Vsys

    return model
Example #11
    def _unpack_and_constrain(self, latent_sample, params):
        def unpack_single_latent(latent):
            unpacked_samples = self._unpack_latent(latent)
            # XXX: we need to add param here to be able to replay model
            unpacked_samples.update({
                k: v
                for k, v in params.items() if k in self.prototype_trace
                and self.prototype_trace[k]["type"] == "param"
            })
            samples = self._postprocess_fn(unpacked_samples)
            # filter out param sites
            return {
                k: v
                for k, v in samples.items() if k in self.prototype_trace
                and self.prototype_trace[k]["type"] != "param"
            }

        sample_shape = jnp.shape(latent_sample)[:-1]
        if sample_shape:
            latent_sample = jnp.reshape(latent_sample,
                                        (-1, jnp.shape(latent_sample)[-1]))
            unpacked_samples = lax.map(unpack_single_latent, latent_sample)
            return tree_map(
                lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                unpacked_samples,
            )
        else:
            return unpack_single_latent(latent_sample)
Example #12
def safe_map(f, xs):
  """equivalent to jax.lax.map, but handles empty arrays too.

  """
  if xs.shape[0] == 0:
    return xs

  return l.map(f, xs)
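A brief usage sketch (illustrative), assuming safe_map is in scope and l is the jax.lax module, as the body above implies:

import jax.numpy as jnp

doubled = safe_map(lambda x: x * 2, jnp.arange(5))   # [0, 2, 4, 6, 8]
empty = safe_map(lambda x: x * 2, jnp.arange(0))     # empty input is returned unchanged
assert doubled.shape == (5,) and empty.shape == (0,)

Note that the empty-array shortcut is only equivalent to lax.map when f preserves the element shape and dtype; otherwise the returned empty array has the input's shape, not the output's.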
Example #13
            def laxmap_postprocess_fn(states, args, kwargs):
                if self.postprocess_fn is None:
                    body_fn = self.sampler.postprocess_fn(args, kwargs)
                else:
                    body_fn = self.postprocess_fn
                if self.chain_method == "vectorized" and self.num_chains > 1:
                    body_fn = vmap(body_fn)

                return lax.map(body_fn, states)
Example #14
def gaussians_twist(Xp, potential, dt, potential_params, A_fn, b_fn, A_params,
                    b_params, get_log_normalizer):
    """
    conduct a gaussian twist to return a twisted mean, covariance vector, normalizing constant, As, and bs

    arguments
        Xp : jnp.array(Dx)
            previous x values
        potential : function
            potential function
        dt : float
            time increment
        potential_params : jnp.array(R)
            arg 1 for potential function
        A_fn : function
            A twisting function
        b_fn : function
            b twisting function
        A_params : jnp.array(U, V)
            U vector of parameters to twist; each is of dim V
        b_params : jnp.array(U, W)
            U vector of parameters to twist; each is of dim W
        get_log_normalizer : bool
            whether to compute the normalization constant of the forward kernel
    return
        twisted_mu : jnp.array(Dx)
            twisted mean
        twisted_cov : jnp.array(Dx)
            twisted covariance vector
        logZ : float
            forward kernel log normalizer
        As : jnp.array(Q, Dx)
            array of A twisting vectors
        bs : jnp.array(Q, Dx)
            array of b twisting vectors
    """
    partial_A_fn, partial_b_fn = partial(A_fn, Xp), partial(b_fn, Xp)
    As = map(partial_A_fn, A_params)
    bs = map(partial_b_fn, b_params)
    start_mu, start_cov = base_fk_params(Xp, potential, dt, potential_params)
    twisted_mu, twisted_cov, logZ = do_param_twist(start_mu, start_cov, As, bs,
                                                   get_log_normalizer)
    return twisted_mu, twisted_cov, logZ, (As, bs)
Example #15
 def init_kernel(init_params,
                 num_warmup,
                 adapt_state_size=None,
                 inverse_mass_matrix=None,
                 dense_mass=False,
                 model_args=(),
                 model_kwargs=None,
                 rng_key=random.PRNGKey(0)):
     nonlocal wa_steps
     wa_steps = num_warmup
     pe_fn = potential_fn
     if potential_fn_gen:
         if pe_fn is not None:
             raise ValueError(
                 'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
             )
         else:
             kwargs = {} if model_kwargs is None else model_kwargs
             pe_fn = potential_fn_gen(*model_args, **kwargs)
     rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
     z = init_params
     z_flat, unravel_fn = ravel_pytree(z)
     if inverse_mass_matrix is None:
         inverse_mass_matrix = jnp.identity(
             z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
     inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
         else jnp.sqrt(inverse_mass_matrix)
     if adapt_state_size is None:
         # XXX: heuristic choice
         adapt_state_size = 2 * z_flat.shape[-1]
     else:
         assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
     # NB: mean is init_params
     zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                    (adapt_state_size, ))
     # compute potential energies
     pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
     if dense_mass:
         cov = jnp.cov(zs, rowvar=False, bias=True)
         if cov.shape == ():  # JAX returns scalar for 1D input
             cov = cov.reshape((1, 1))
         cholesky = jnp.linalg.cholesky(cov)
         # if cholesky is NaN, we use the scale from `sample_proposal` here
         inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                          inv_mass_matrix_sqrt, cholesky)
     else:
         inv_mass_matrix_sqrt = jnp.std(zs, 0)
     adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0),
                                inv_mass_matrix_sqrt)
     k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
     z = unravel_fn(zs[k])
     pe = pes[k]
     sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                        jnp.array(False), adapt_state, rng_key_sa)
     return device_put(sa_state)
Example #16
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # reshape to map over axis 0
    p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    key = random.split(key, jnp.size(p))
    if jax.default_backend() == "cpu":
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return jnp.reshape(ret, shape)
Example #17
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(np.shape(p), np.shape(n))
    # reshape to map over axis 0
    p = np.reshape(np.broadcast_to(p, shape), -1)
    n = np.reshape(np.broadcast_to(n, shape), -1)
    key = random.split(key, np.size(p))
    if xla_bridge.get_backend().platform == 'cpu':
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return np.reshape(ret, shape)
Example #18
def _poisson(key, rate, shape, dtype):
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    shape = shape or np.shape(rate)
    rate = lax.convert_element_type(rate, canonicalize_dtype(np.float64))
    rate = np.broadcast_to(rate, shape)
    rng_keys = random.split(key, np.size(rate))
    if xla_bridge.get_backend().platform == 'cpu':
        k = lax.map(_poisson_one, (rng_keys, np.reshape(rate, -1)))
    else:
        k = vmap(_poisson_one)((rng_keys, np.reshape(rate, -1)))
    k = lax.convert_element_type(k, dtype)
    return np.reshape(k, shape)
Example #19
 def _constrain(self, latent_samples):
     name = list(latent_samples)[0]
     sample_shape = jnp.shape(latent_samples[name])[
         :jnp.ndim(latent_samples[name]) - jnp.ndim(self._init_locs[name])]
     if sample_shape:
         flatten_samples = tree_map(lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[len(sample_shape):]),
                                    latent_samples)
          constrained_samples = lax.map(self._postprocess_fn, flatten_samples)
          return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                          constrained_samples)
     else:
         return self._postprocess_fn(latent_samples)
Example #20
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        to apply `fn` element-wise over them.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (batch_size, ) if batch_size > 1 else ()
    xs = tree_map(
        lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]),
        xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(
        batch_size, chunk_size)
    if chunk_size > 1:
        pad = chunk_size - (batch_size % chunk_size)
        xs = tree_map(
            lambda x: jnp.pad(x, ((0, pad), ) + ((0, 0), ) * (np.ndim(x) - 1)),
            xs)
        num_chunks = batch_size // chunk_size + int(pad > 0)
        prepend_shape = (-1, ) if num_chunks > 1 else ()
        xs = tree_map(
            lambda x: jnp.reshape(
                x, prepend_shape + (chunk_size, ) + jnp.shape(x)[1:]),
            xs,
        )
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    ys = tree_map(
        lambda y: jnp.reshape(y, (int(np.prod(jnp.shape(y)[:map_ndims])), ) +
                              jnp.shape(y)[map_ndims:])[:batch_size],
        ys,
    )
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]),
                    ys)
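A small usage sketch of the chunking behaviour (values illustrative), assuming the soft_vmap above is in scope:

import jax.numpy as jnp

# 10 inputs processed 4 at a time: the input is padded to 12 elements,
# lax.map runs over 3 chunks of vmap'ed work, and the result is trimmed
# back to the original 10 entries.
out = soft_vmap(lambda x: x ** 2, jnp.arange(10.0), batch_ndims=1, chunk_size=4)
assert out.shape == (10,)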
Example #21
 def matrix_vector(x, y, parallel=True):
   """Matrix vector multiply. First batch it and then row by row"""
   fv = lambda z: lax.map(lambda j: vv(j, y), z)
   if parallel:
     # split leading axis in two
     new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
     # apply map
     new_res = pmap(fv)(new_x)
     # reshape back out
     res = new_res.reshape(x.shape[0], *new_res.shape[2:])
   else:
     res = fv(x)
   return res
Example #22
def gaussian_cl_covariance(ell,
                           probes,
                           cl_signal,
                           cl_noise,
                           f_sky=0.25,
                           sparse=True):
    """
    Computes a Gaussian covariance for the angular cls of the provided probes

    Set sparse True to return a sparse matrix representation that uses a factor
    of n_ell less memory and is compatible with the linear algebra operations
    in :mod:`jax_cosmo.sparse`.

    Returns the covariance matrix.
    """
    ell = np.atleast_1d(ell)
    n_ell = len(ell)
    one = 1.0 if sparse else np.eye(n_ell)

    # Adding noise to auto-spectra
    cl_obs = cl_signal + cl_noise
    n_cls = cl_obs.shape[0]

    # Normalization of covariance
    norm = (2 * ell + 1) * np.gradient(ell) * f_sky

    # Retrieve ordering for blocks of the covariance matrix
    cov_blocks = np.array(_get_cov_blocks_ordering(probes))

    def get_cov_block(inds):
        a, b, c, d = inds
        cov = (cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]) / norm
        return cov * one

    # Return a sparse representation of the matrix containing only the diagonals
    # for each of the n_cls x n_cls blocks of size n_ell x n_ell.
    # We could compress this further using the symmetry of the blocks, but
    # it is easier to invert this matrix with this redundancy included.
    cov_mat = lax.map(get_cov_block, cov_blocks)

    # Reshape covariance matrix into proper matrix
    if sparse:
        cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell))
    else:
        cov_mat = cov_mat.reshape((n_cls, n_cls, n_ell, n_ell))
        cov_mat = cov_mat.transpose((0, 2, 1, 3)).reshape(
            (n_ell * n_cls, n_ell * n_cls))
    return cov_mat
Example #23
    def _single_chain_mcmc(self, init, args, kwargs, collect_fields):
        rng_key, init_state, init_params = init
        if init_state is None:
            init_state = self.sampler.init(rng_key,
                                           self.num_warmup,
                                           init_params,
                                           model_args=args,
                                           model_kwargs=kwargs)
        if self.postprocess_fn is None:
            postprocess_fn = self.sampler.postprocess_fn(args, kwargs)
        else:
            postprocess_fn = self.postprocess_fn
        diagnostics = (lambda x: self.sampler.get_diagnostics_str(x[0])
                       if rng_key.ndim == 1 else '')  # noqa: E731
        init_val = (init_state, args,
                    kwargs) if self._jit_model_args else (init_state, )
        lower_idx = self._collection_params["lower"]
        upper_idx = self._collection_params["upper"]
        phase = self._collection_params["phase"]

        collect_vals = fori_collect(
            lower_idx,
            upper_idx,
            self._get_cached_fn(),
            init_val,
            transform=_collect_fn(collect_fields),
            progbar=self.progress_bar,
            return_last_val=True,
            collection_size=self._collection_params["collection_size"],
            progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
            diagnostics_fn=diagnostics)
        states, last_val = collect_vals
        # Get first argument of type `HMCState`
        last_state = last_val[0]
        if len(collect_fields) == 1:
            states = (states, )
        states = dict(zip(collect_fields, states))
        # Apply constraints if number of samples is non-zero
        site_values = tree_flatten(states[self._sample_field])[0]
        # XXX: lax.map still works if some arrays have 0 size
        # so we only need to filter out the case site_value.shape[0] == 0
        # (which happens when lower_idx==upper_idx)
        if len(site_values) > 0 and jnp.shape(site_values[0])[0] > 0:
            if self.chain_method == "vectorized" and self.num_chains > 1:
                postprocess_fn = vmap(postprocess_fn)
            states[self._sample_field] = lax.map(postprocess_fn,
                                                 states[self._sample_field])
        return states, last_state
Example #24
    def _unpack_and_constrain(self, latent_sample, params):
        def unpack_single_latent(latent):
            unpacked_samples = self._unpack_latent(latent)
            # add param sites in model
            unpacked_samples.update({k: v for k, v in params.items() if k in self.prototype_trace
                                     and self.prototype_trace[k]['type'] == 'param'})
            return self._postprocess_fn(unpacked_samples)

        sample_shape = jnp.shape(latent_sample)[:-1]
        if sample_shape:
            latent_sample = jnp.reshape(latent_sample, (-1, jnp.shape(latent_sample)[-1]))
            unpacked_samples = lax.map(unpack_single_latent, latent_sample)
            return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                            unpacked_samples)
        else:
            return unpack_single_latent(latent_sample)
Example #25
def noise_cl(ell, probes):
    """
    Computes noise contributions to auto-spectra
    """
    n_ell = len(ell)
    # Concatenate noise power for each tracer
    noise = np.concatenate([p.noise() for p in probes])
    # Define an ordering for the blocks of the signal vector
    cl_index = np.array(_get_cl_ordering(probes))

    # Only include a noise contribution for the auto-spectra
    def get_noise_cl(inds):
        i, j = inds
        delta = 1.0 - np.clip(np.abs(i - j), 0.0, 1.0)
        return noise[i] * delta * np.ones(n_ell)

    return lax.map(get_noise_cl, cl_index)
Example #26
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)

    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0], x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential,
                                      (0, None, 0, None))(x0_all, t, x1_all[x1_initial_inds], tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    initial_bound = jnp.where(max_cond_dens > init_bound_param, max_cond_dens * bound_inflation, init_bound_param)
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1], (n,)) > initial_cond_dens / initial_bound

    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0, tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t, x1_all, tplus1,
                                                                   x1_log_weight,
                                                                   bound_inflation, *tup),
                         (initial_not_yet_accepted_arr,
                          x1_initial_inds,
                          initial_bound,
                          random.split(rejection_initial_keys[2], n),
                          1,
                          n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, rej_attempted, num_transition_evals = out_tup

    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i],
                                                          x1_final_inds[i],
                                                          ssm_scenario,
                                                          x0_all[i],
                                                          t,
                                                          x1_all,
                                                          tplus1,
                                                          x1_log_weight,
                                                          random_keys[i]), jnp.arange(n))

    num_transition_evals = num_transition_evals + len(x1_all) * not_yet_accepted_arr.sum()

    return x1_final_inds, num_transition_evals
Example #27
    def _single_chain_mcmc(self,
                           rng_key,
                           init_state,
                           init_params,
                           args,
                           kwargs,
                           collect_fields=('z', )):
        if init_state is None:
            init_state = self.sampler.init(rng_key,
                                           self.num_warmup,
                                           init_params,
                                           model_args=args,
                                           model_kwargs=kwargs)
        if self.constrain_fn is None:
            self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
        diagnostics = (lambda x: get_diagnostics_str(x[0])
                       if rng_key.ndim == 1 else None)  # noqa: E731
        init_val = (init_state, args,
                    kwargs) if self._jit_model_args else (init_state, )
        lower_idx = self._collection_params["lower"]
        upper_idx = self._collection_params["upper"]

        collect_vals = fori_collect(
            lower_idx,
            upper_idx,
            self._get_cached_fn(),
            init_val,
            transform=_collect_fn(collect_fields),
            progbar=self.progress_bar,
            return_last_val=True,
            collection_size=self._collection_params["collection_size"],
            progbar_desc=functools.partial(get_progbar_desc_str, lower_idx),
            diagnostics_fn=diagnostics)
        states, last_val = collect_vals
        # Get first argument of type `HMCState`
        last_state = last_val[0]
        if len(collect_fields) == 1:
            states = (states, )
        states = dict(zip(collect_fields, states))
        # Apply constraints if number of samples is non-zero
        site_values = tree_flatten(states['z'])[0]
        if len(site_values) > 0 and site_values[0].size > 0:
            states['z'] = lax.map(self.constrain_fn, states['z'])
        return states, last_state
Example #28
def _predictive(rng_key,
                model,
                posterior_samples,
                num_samples,
                return_sites=None,
                parallel=True,
                model_args=(),
                model_kwargs={}):
    rng_keys = random.split(rng_key, num_samples)

    def single_prediction(val):
        rng_key, samples = val
        model_trace = trace(seed(substitute(model, samples),
                                 rng_key)).get_trace(*model_args,
                                                     **model_kwargs)
        if return_sites is not None:
            if return_sites == '':
                sites = {
                    k
                    for k, site in model_trace.items()
                    if site['type'] != 'plate'
                }
            else:
                sites = return_sites
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site['type'] == 'sample' and k not in samples) or (
                    site['type'] == 'deterministic')
            }
        return {
            name: site['value']
            for name, site in model_trace.items() if name in sites
        }

    if parallel:
        return vmap(single_prediction)((rng_keys, posterior_samples))
    else:
        return lax.map(single_prediction, (rng_keys, posterior_samples))
Example #29
 def testMap(self):
     f = lambda x: x**2
     xs = np.arange(10)
     expected = xs**2
     actual = lax.map(f, xs)
     self.assertAllClose(actual, expected, check_dtypes=True)
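The same check as a stand-alone snippet, which is about the smallest possible demonstration of lax.map:

import jax.numpy as jnp
from jax import lax

xs = jnp.arange(10)
squares = lax.map(lambda x: x ** 2, xs)   # applies the function over the leading axis
assert jnp.array_equal(squares, xs ** 2)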
Example #30
def find_valid_initial_params(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    enum=False,
    model_args=(),
    model_kwargs=None,
    prototype_params=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
    valid unconstrained value for all the parameters. This function also returns
    the corresponding potential energy, the gradients, and an
    `is_valid` flag to say whether the initial parameters are valid. Parameter values
    are considered valid if the values and the gradients for the log density have
    finite values.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
    :param bool enum: whether to enumerate over discrete latent sites.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict prototype_params: an optional prototype parameters, which is used
        to define the shape for initial parameters.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. Defaults to False.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: tuple of `init_params_info` and `is_valid`, where `init_params_info` is the tuple
        containing the initial params, their potential energy, and their gradients.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    init_strategy = (init_strategy if isinstance(init_strategy, partial) else
                     init_strategy())
    # handle those init strategies differently to save computation
    if init_strategy.func is init_to_uniform:
        radius = init_strategy.keywords.get("radius")
        init_values = {}
    elif init_strategy.func is _init_to_unconstrained_value:
        radius = 2
        init_values = init_strategy.keywords.get("values")
    else:
        radius = None

    def cond_fn(state):
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model, subkey),
                                      substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(
                *model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (v["type"] == "sample" and not v["is_observed"]
                        and not v["fn"].is_discrete):
                    constrained_values[k] = v["value"]
                    inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(
                inv_transforms,
                {k: v
                 for k, v in constrained_values.items()},
                invert=True,
            )
        else:  # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(subkey,
                                               jnp.shape(v),
                                               minval=-radius,
                                               maxval=radius)
                    key, subkey = random.split(key)

        potential_fn = partial(potential_energy,
                               model,
                               model_args,
                               model_kwargs,
                               enum=enum)
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None

        return i + 1, key, (params, pe, z_grad), is_valid

    def _find_valid_params(rng_key, exit_early=False):
        init_state = (0, rng_key, (prototype_params, 0.0, prototype_params),
                      False)
        if exit_early and not_jax_tracer(rng_key):
            # Early return if valid params found. This is only helpful for single chain,
            # where we can avoid compiling body_fn in while_loop.
            _, _, (init_params, pe,
                   z_grad), is_valid = init_state = body_fn(init_state)
            if not_jax_tracer(is_valid):
                if device_get(is_valid):
                    return (init_params, pe, z_grad), is_valid

        # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
        # even if the init_state is a valid result
        _, _, (init_params, pe,
               z_grad), is_valid = while_loop(cond_fn, body_fn, init_state)
        return (init_params, pe, z_grad), is_valid

    # Handle possible vectorization
    if rng_key.ndim == 1:
        (init_params, pe,
         z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    else:
        (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params,
                                                      rng_key)
    return (init_params, pe, z_grad), is_valid