Пример #1
0
    def testCategorical(self, p, axis, dtype, sample_shape):
        """Check jax.random.categorical against the expected pmf.

        Draws samples (eagerly and under jit) from unnormalized logits
        derived from `p`, then chi-squared-tests them against `p`.
        """
        key = random.PRNGKey(0)
        p = np.array(p, dtype=dtype)
        logits = np.log(p) - 42  # test unnormalized
        out_shape = tuple(np.delete(logits.shape, axis))
        shape = sample_shape + out_shape
        # NOTE(review): the lambda ignores its `p` argument and closes over
        # `logits` computed above; `p` is threaded through only so jit traces
        # a function of (key, p).
        rand = lambda key, p: random.categorical(
            key, logits, shape=shape, axis=axis)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, p)
        compiled_samples = crand(key, p)

        # Normalize a negative axis once before it is used below.
        if axis < 0:
            axis += len(logits.shape)

        for samples in [uncompiled_samples, compiled_samples]:
            assert samples.shape == shape
            # assumes sample_shape == (10000,) — TODO confirm against the
            # test parameterization
            samples = jnp.reshape(samples, (10000, ) + out_shape)
            if len(p.shape[:-1]) > 0:
                # Batched case: test each categorical distribution separately.
                ps = np.transpose(p, (1, 0)) if axis == 0 else p
                for cat_samples, cat_p in zip(samples.transpose(), ps):
                    self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
            else:
                self._CheckChiSquared(samples, pmf=lambda x: p[x])
Пример #2
0
        def sample_scan(params, tup, x):
            """Perform a single GRU step and sample the next token.

            Carry `tup` is (key, inp, logP, hidden); returns the updated
            carry and the one-hot sample for scan's per-step output.
            """
            # Unpack GRU weights: update gate, reset gate, candidate output,
            # and the final softmax projection; the leading `_` discards the
            # first entry of `params` (unused in this step).
            _, (update_W, update_U,
                update_b), (reset_W, reset_U,
                            reset_b), (out_W, out_U, out_b), (sm_W,
                                                              sm_b) = params
            hidden = tup[3]
            logP = tup[2]
            key = tup[0]
            inp = tup[1]

            # Standard GRU cell equations.
            update_gate = sigmoid(
                np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(
                np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
            output_gate = np.tanh(
                np.dot(inp, out_W) +
                np.dot(np.multiply(reset_gate, hidden), out_U) + out_b)
            output = np.multiply(update_gate, hidden) + np.multiply(
                1 - update_gate, output_gate)
            hidden = output
            logits = np.dot(hidden, sm_W) + sm_b

            key, subkey = random.split(key)

            samples = random.categorical(
                subkey, logits, axis=1, shape=None)  # sampling the conditional
            samples = one_hot(
                samples, sm_b.shape[0])  # convert to one hot encoded vector
            # Accumulate the log-probability of the sampled token.
            log_P_new = np.sum(samples * log_softmax(logits), axis=1)
            log_P_new = log_P_new + logP  # update the value of the logP of the sample

            return (key, samples, log_P_new, output), samples
def draw_state(val, key, params):
    """
    Simulate one step of a system that evolves as
                A z_{t-1} + Bk + eps,
    where eps ~ N(0, Q).

    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system).
    key: PRNGKey
    params: PRBPFParamsDiscrete

    Returns
    -------
    tuple
        ((latent_new, state_new), (latent_new, state_new, obs_new)):
        the new carry and the per-step (latent, state, observation) output.
    """
    latent_old, state_old = val
    # Use an independent subkey for each random draw. Previously `key` was
    # consumed by random.categorical AND split again afterwards, which
    # correlates the categorical draw with the Gaussian draws.
    key_latent, key_state, key_obs = random.split(key, 3)

    probabilities = params.transition_matrix[latent_old, :]
    logits = logit(probabilities)
    latent_new = random.categorical(key_latent, logits)

    state_new = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_state, state_new, params.Q)
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new,
                                         params.R)

    return (latent_new, state_new), (latent_new, state_new, obs_new)
Пример #4
0
  def testCategorical(self, p, axis, dtype, sample_shape):
    """Check jax.random.categorical against the expected pmf.

    Samples are drawn both eagerly and under jit, then chi-squared-tested
    against the (unnormalized) probabilities `p`.
    """
    key = random.PRNGKey(0)
    p = onp.array(p, dtype=dtype)
    logits = onp.log(p) - 42  # test unnormalized
    shape = sample_shape + tuple(onp.delete(logits.shape, axis))
    # NOTE: the lambda ignores its `p` argument and closes over `logits`;
    # `p` is only threaded through so jit traces a function of (key, p).
    rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
    crand = api.jit(rand)

    uncompiled_samples = rand(key, p)
    compiled_samples = crand(key, p)

    if p.ndim > 1:
      self.skipTest("multi-dimensional categorical tests are currently broken!")

    # Normalize a negative axis once, up front. The previous in-loop guard
    # (with broken indentation) only worked because the second loop pass saw
    # an already-normalized axis.
    if axis < 0:
      axis += len(logits.shape)

    for samples in [uncompiled_samples, compiled_samples]:
      assert samples.shape == shape

      if len(p.shape[:-1]) > 0:
        # Batched case: chi-square test each row's distribution separately.
        for cat_index, p_ in enumerate(p):
          self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
      else:
        self._CheckChiSquared(samples, pmf=lambda x: p[x])
def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params):
    """Single Rao-Blackwellised particle-filter step for one particle.

    Samples the next discrete regime `k` from the transition row of `st`,
    runs the conditional Kalman-filter update for observation `xt`, and
    reweights the particle by the resulting likelihood `Ltk`.
    """
    # NOTE(review): logit(p) = log p - log(1 - p) is used as categorical
    # logits; softmax of that is not exactly p — confirm this matches the
    # intended proposal distribution.
    log_p_next = logit(params.transition_matrix[st])
    k = random.categorical(key, log_p_next)
    mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params)
    weight_t = weight_t * Ltk  # unnormalized importance weight

    return mu_t, Sigma_t, weight_t, Ltk
def rbpf(current_config, xt, params, nparticles=100):
    """
    Rao-Blackwell Particle Filter using prior as proposal
    """
    # Carry: (PRNG key, particle means, particle covariances, weights,
    # discrete regimes).
    key, mu_t, Sigma_t, weights_t, st = current_config

    key_sample, key_state, key_next, key_reindex = random.split(key, 4)
    keys = random.split(key_sample, nparticles)

    # Propagate each particle's discrete regime from the transition prior.
    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t,
                                                  Sigma_t, xt, params)
    weights_t = weights_t / weights_t.sum()  # normalize importance weights

    # Multinomial resampling: draw indices with probability weights_t.
    indices = jnp.arange(nparticles)
    pi = random.choice(key_reindex,
                       indices,
                       shape=(nparticles, ),
                       p=weights_t,
                       replace=True)
    st = st[pi]
    mu_t = mu_t[pi, ...]
    Sigma_t = Sigma_t[pi, ...]
    # After resampling all particles carry equal weight.
    weights_t = jnp.ones(nparticles) / nparticles

    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t,
                                                      st, Ltk)
Пример #7
0
 def resample_final(self, sample: cdict) -> cdict:
     """Resample the final-time particles by log-weight into an
     unweighted cdict of self.n values."""
     final_log_weight = sample.log_weight[-1]
     chosen = random.categorical(random.PRNGKey(1),
                                 logits=final_log_weight,
                                 shape=(self.n, ))
     return cdict(value=sample.value[-1, chosen])
 def sample(self, n):
     """Draw `n` samples from the Gaussian mixture."""
     # Pick a mixture component per sample, then draw standard normals
     # and transform by the chosen component's mean and scale factor.
     comp = random.categorical(seed(), np.log(self._pi), shape=(n, ))
     noise = random.normal(seed(), (n, 1, self.dim))
     return self._mu[comp].reshape(n, 1, -1) + noise @ self._sig[comp]
Пример #9
0
 def resample(self, ensemble_state: cdict,
              random_key: jnp.ndarray) -> cdict:
     """Multinomially resample the ensemble by log-weight.

     Returns a new ensemble with uniform (zero) log-weights and the
     effective sample size reset to the ensemble size.
     """
     n = ensemble_state.value.shape[0]
     resampled_indices = random.categorical(random_key,
                                            ensemble_state.log_weight,
                                            shape=(n, ))
     # cdict indexing selects particles along the leading axis.
     resampled_ensemble_state = ensemble_state[resampled_indices]
     resampled_ensemble_state.log_weight = jnp.zeros(n)
     resampled_ensemble_state.ess = jnp.zeros(n) + n  # ESS reset to n
     return resampled_ensemble_state
Пример #10
0
def full_stitch_single(ssm_scenario: StateSpaceModel,
                       x0_single: jnp.ndarray,
                       t: float,
                       x1_all: jnp.ndarray,
                       tplus1: float,
                       x1_log_weight: jnp.ndarray,
                       random_key: jnp.ndarray) -> jnp.ndarray:
    """Sample the index of a t+1 particle to stitch to `x0_single`.

    Reweights the t+1 particles by the (negated) transition potential
    from x0_single and draws a single categorical index.
    """
    transition_pots = vmap(ssm_scenario.transition_potential,
                           (None, None, 0, None))(x0_single, t,
                                                  x1_all, tplus1)
    stitch_log_weight = x1_log_weight - transition_pots
    return random.categorical(random_key, stitch_log_weight)
Пример #11
0
 def sample(rng, params, num_samples=1):
     """Draw `num_samples` points from the mixture defined by the
     enclosing `means`, `covariances` and `weights`."""
     # Draw num_samples candidates from every component, splitting the
     # key once per component (split order matters for reproducibility).
     per_cluster = []
     for mean, cov in zip(means, covariances):
         rng, component_rng = random.split(rng)
         per_cluster.append(
             random.multivariate_normal(component_rng, mean, cov,
                                        (num_samples, )))
     stacked = np.dstack(per_cluster)
     # Choose one component per sample, gather along the last axis.
     chosen = random.categorical(rng, weights, shape=(num_samples, 1, 1))
     return np.squeeze(np.take_along_axis(stacked, chosen, -1))
Пример #12
0
 def init_kernel(init_params,
                 num_warmup,
                 adapt_state_size=None,
                 inverse_mass_matrix=None,
                 dense_mass=False,
                 model_args=(),
                 model_kwargs=None,
                 rng_key=random.PRNGKey(0)):
     """Initialize the Sample-Adaptive (SA) MCMC sampler state.

     Builds the initial adaptive ensemble of `adapt_state_size` points
     around `init_params` and returns a device-put SAState.
     """
     nonlocal wa_steps
     wa_steps = num_warmup
     pe_fn = potential_fn
     if potential_fn_gen:
         if pe_fn is not None:
             raise ValueError(
                 'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
             )
         else:
             kwargs = {} if model_kwargs is None else model_kwargs
             pe_fn = potential_fn_gen(*model_args, **kwargs)
     rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
     z = init_params
     z_flat, unravel_fn = ravel_pytree(z)
     if inverse_mass_matrix is None:
         # Default mass matrix: identity (dense) or ones (diagonal).
         inverse_mass_matrix = jnp.identity(
             z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
     inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
         else jnp.sqrt(inverse_mass_matrix)
     if adapt_state_size is None:
         # XXX: heuristic choice
         adapt_state_size = 2 * z_flat.shape[-1]
     else:
         assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
     # NB: mean is init_params
     zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                    (adapt_state_size, ))
     # compute potential energies
     pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
     if dense_mass:
         cov = jnp.cov(zs, rowvar=False, bias=True)
         if cov.shape == ():  # JAX returns scalar for 1D input
             cov = cov.reshape((1, 1))
         cholesky = jnp.linalg.cholesky(cov)
         # if cholesky is NaN, we use the scale from `sample_proposal` here
         inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                          inv_mass_matrix_sqrt, cholesky)
     else:
         inv_mass_matrix_sqrt = jnp.std(zs, 0)
     adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0),
                                inv_mass_matrix_sqrt)
     # Pick a uniformly random starting point from the ensemble.
     k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
     z = unravel_fn(zs[k])
     pe = pes[k]
     sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                        jnp.array(False), adapt_state, rng_key_sa)
     return device_put(sa_state)
Пример #13
0
def rejection_stitch_proposal_single(ssm_scenario: StateSpaceModel,
                                     x0_single: jnp.ndarray,
                                     t: float,
                                     x1_all: jnp.ndarray,
                                     tplus1: float,
                                     x1_log_weight: jnp.ndarray,
                                     bound: float,
                                     random_key: jnp.ndarray) \
        -> Tuple[jnp.ndarray, float, bool, jnp.ndarray]:
    """Propose a single stitch index by rejection.

    Draws a candidate t+1 index from the marginal log-weights and
    rejects it when a uniform draw exceeds conditional density / bound.
    Returns (index, conditional density, reject flag, advanced key).
    """
    random_key, choice_key, uniform_key = random.split(random_key, 3)
    proposed_ind = random.categorical(choice_key, x1_log_weight)
    cond_dens = jnp.exp(
        -ssm_scenario.transition_potential(x0_single, t,
                                           x1_all[proposed_ind], tplus1))
    reject = random.uniform(uniform_key) > cond_dens / bound
    return proposed_ind, cond_dens, reject, random_key
Пример #14
0
def metric(samp_arr, log_weight=None):
    """Gaussian KL-style divergence of the (optionally weighted) samples
    from the reference posterior (module-level mean_post/cov_post/prec_post)."""
    if isinstance(samp_arr, mocat.cdict):
        samp_arr = samp_arr.value
    if samp_arr.ndim == 3:
        samp_arr = samp_arr[..., 0].T
    if log_weight is not None:
        # Convert weighted samples to unweighted via categorical resampling.
        resample_inds = random.categorical(random.PRNGKey(0),
                                           log_weight,
                                           shape=(len(samp_arr), ))
        samp_arr = samp_arr[resample_inds]
    sample_mean = samp_arr.mean(0)
    sample_cov = jnp.cov(samp_arr.T, ddof=1)
    mean_diff = mean_post - sample_mean
    trace_term = jnp.trace(prec_post @ sample_cov)
    quad_term = mean_diff.T @ prec_post @ mean_diff
    log_det_term = jnp.log(
        jnp.linalg.det(cov_post) / jnp.linalg.det(sample_cov))
    return 0.5 * (trace_term + quad_term - len_t + log_det_term)
Пример #15
0
def get_k(rng):
    """Draw a size k according to args.k_type.

    'const': fixed args.k_param; 'exp': exponential decay with base
    args.k_param; 'power': power-law decay with exponent args.k_param.
    """
    if args.k_type == 'const':
        return int(args.k_param)

    sizes = jnp.arange(1, args.L**2 + 1)

    if args.k_type == 'exp':
        log_probs = -sizes * jnp.log(args.k_param)
    elif args.k_type == 'power':
        log_probs = -args.k_param * jnp.log(sizes)
    else:
        raise ValueError('Unknown k_type: {}'.format(args.k_type))

    # Shift by one: categorical returns 0-based indices, k starts at 1.
    return jrand.categorical(rng, log_probs) + 1
Пример #16
0
 def sampling_loop_body_fn(state):
     """Sampling loop state update."""
     # state: (position i, token sequences, decoder cache, current token,
     # per-batch-item ended flags, RNG key).
     i, sequences, cache, cur_token, ended, rng = state
     # Split RNG for sampling.
     rng1, rng2 = random.split(rng)
     # Call fast-decoder model on current tokens to get next-position logits.
     logits, new_cache = tokens_to_logits(cur_token, cache)
     # Sample next token from logits.
     # TODO(levskaya): add top-p "nucleus" sampling option.
     if topk:
         # Get top-k logits and their indices, sample within these top-k tokens.
         topk_logits, topk_idxs = lax.top_k(logits, topk)
         topk_token = jnp.expand_dims(random.categorical(
             rng1, topk_logits / temperature).astype(jnp.int32),
                                      axis=-1)
         # Return the original indices corresponding to the sampled top-k tokens.
         next_token = jnp.squeeze(jnp.take_along_axis(topk_idxs,
                                                      topk_token,
                                                      axis=-1),
                                  axis=-1)
     else:
         next_token = random.categorical(rng1, logits / temperature).astype(
             jnp.int32)
     # Only use sampled tokens if we're past provided prefix tokens.
     out_of_prompt = (sequences[:, i + 1] == 0)
     next_token = (next_token * out_of_prompt +
                   sequences[:, i + 1] * ~out_of_prompt)
     # If end-marker reached for batch item, only emit padding tokens.
     next_token_or_endpad = next_token * ~ended
     ended |= (next_token_or_endpad == end_marker)
     # Add current sampled tokens to recorded sequences.
     new_sequences = lax.dynamic_update_slice(sequences,
                                              next_token_or_endpad,
                                              (0, i + 1))
     return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended,
             rng2)
Пример #17
0
 def body(state):
     """One iteration of multi-ellipsoid rejection sampling.

     Picks an ellipsoid by log-weight, samples a point inside it and
     accepts with probability 1/(number of ellipsoids containing it),
     which de-biases points lying in overlap regions.
     """
     (i, _, key, done, _) = state
     key, accept_key, sample_key, select_key = random.split(key, 4)
     k = random.categorical(select_key, log_p)
     mu_k = mu[k, :]
     radii_k = radii[k, :]
     rotation_k = rotation[k, :, :]
     u_test = sample_ellipsoid(sample_key,
                               mu_k,
                               radii_k,
                               rotation_k,
                               unit_cube_constraint=unit_cube_constraint)
     # Count how many ellipsoids contain the proposal.
     inside = vmap(lambda mu, radii, rotation: point_in_ellipsoid(
         u_test, mu, radii, rotation))(mu, radii, rotation)
     n_intersect = jnp.sum(inside)
     done = (random.uniform(accept_key) < jnp.reciprocal(n_intersect))
     return (i + 1, k, key, done, u_test)
Пример #18
0
def logistic_mix_sample(nn_out, rng):
    """Sample an RGB image from a logistic-mixture network output.

    `nn_out` is preprocessed into per-mixture means `m`, channel-coupling
    coefficients `t`, inverse scales, and mixture logits; one mixture index
    is drawn per pixel and the three channels are sampled autoregressively
    (red -> green -> blue).
    """
    m, t, inv_scales, logit_weights = logistic_preprocess(nn_out)
    rng_mix, rng_logistic = random.split(rng)
    # Mixture index per pixel; axis -3 indexes the mixture components.
    mix_idx = random.categorical(rng_mix, logit_weights, -3)

    def select_mix(arr):
        # Gather the chosen mixture's parameters and drop the mixture axis.
        return jnp.squeeze(
            jnp.take_along_axis(arr, jnp.expand_dims(mix_idx, (-4, -1)), -4),
            -4)

    # Move the channel axis to the front for per-channel indexing below.
    m, t, inv_scales = map(lambda x: jnp.moveaxis(select_mix(x), -1, 0),
                           (m, t, inv_scales))
    l = random.logistic(rng_logistic, m.shape) / inv_scales
    # Channel coupling: green depends on red, blue on red and green.
    img_red = m[0] + l[0]
    img_green = m[1] + t[0] * img_red + l[1]
    img_blue = m[2] + t[1] * img_red + t[2] * img_green + l[2]
    return jnp.stack([img_red, img_green, img_blue], -1)
    def sample(self, n_samples, key=None):
        """Draw n_samples from the Gaussian mixture (advances
        self.threadkey when no key is given)."""
        if key is None:
            self.threadkey, key = random.split(self.threadkey)

        key, subkey = random.split(key)
        # NB: `key` is split again here, matching the original sampling
        # stream exactly.
        component_keys = random.split(key, n_samples)
        components = random.categorical(subkey,
                                        np.log(self.weights),
                                        shape=(n_samples, ))

        def draw_one(component_key, component):
            return random.multivariate_normal(component_key,
                                              self.means[component],
                                              self.covs[component])

        draws = vmap(draw_one)(component_keys, components)
        return draws.reshape((n_samples, self.d))
Пример #20
0
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    """Stitch every t-particle to a t+1 particle by rejection sampling.

    Runs up to `maximum_rejections` vectorized rejection rounds with an
    adaptive density bound, then falls back to exact (full) stitching for
    any still-unaccepted particles. Returns the chosen t+1 indices and the
    number of transition-potential evaluations performed.
    """
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)

    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0], x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential,
                                      (0, None, 0, None))(x0_all, t, x1_all[x1_initial_inds], tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    # Inflate the bound if the prerun already exceeded the initial guess.
    initial_bound = jnp.where(max_cond_dens > init_bound_param, max_cond_dens * bound_inflation, init_bound_param)
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1], (n,)) > initial_cond_dens / initial_bound

    # Loop while any particle is unaccepted and the rejection budget remains.
    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0, tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t, x1_all, tplus1,
                                                                   x1_log_weight,
                                                                   bound_inflation, *tup),
                         (initial_not_yet_accepted_arr,
                          x1_initial_inds,
                          initial_bound,
                          random.split(rejection_initial_keys[2], n),
                          1,
                          n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, rej_attempted, num_transition_evals = out_tup

    # Fall back to full stitching for particles the rejection loop missed.
    # NOTE(review): `map` is presumably a vectorized map imported at module
    # level (e.g. jax.lax.map) — the Python builtin would return a lazy
    # iterator here; confirm against the module's imports.
    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i],
                                                          x1_final_inds[i],
                                                          ssm_scenario,
                                                          x0_all[i],
                                                          t,
                                                          x1_all,
                                                          tplus1,
                                                          x1_log_weight,
                                                          random_keys[i]), jnp.arange(n))

    # Each fallback call evaluates the potential against all of x1_all.
    num_transition_evals = num_transition_evals + len(x1_all) * not_yet_accepted_arr.sum()

    return x1_final_inds, num_transition_evals
Пример #21
0
    def action_selection(self, rng_key, beliefs, gamma=1e3):
        # sample choices

        # Marginalize the joint belief p(c, f, m) over f and m to get p(c).
        p_cfm, params = beliefs

        p_c = einsum('...cfm->...c', p_cfm)

        beliefs = (p_c, params)

        # Utilities: either dynamic (log of current preference weights
        # self.lam) or the static self.U.
        if self.dyn_pref:
            U = jnp.expand_dims(jnp.log(self.lam), -2)
        else:
            U = self.U

        # NB: `logits` is a module-level function here; the result is cached
        # on the instance for later inspection.
        self.logits = logits(beliefs, gamma, U)

        return random.categorical(rng_key, self.logits)
    def conditional_sample(self, x, idx, n):
        """Draw n conditional GMM samples per row of x, conditioning on the
        first `idx` dimensions."""
        reps = x.shape[0]
        _idx_help = np.arange(n*reps)
        # Conditional mixture weights, means, covariances given x[:, :idx].
        pi, mu, var = self.condition(x, idx)
        var = np.maximum(var, 1e-5)  # guard against singular covariances
        sig = vmap(np.linalg.cholesky)(var)

        # Repeat each conditioning row n times so each gets n draws.
        pi = np.repeat(pi, n, 0)
        mu = np.repeat(mu, n, 0)
        sig = np.repeat(sig, n, 0)
        sig = np.maximum(sig, 0.)
        # One component index per (row, draw) pair.
        comp = random.categorical(
                seed(),
                np.log(pi),
                axis=1).reshape(-1)
        offset = mu[_idx_help, comp]
        ran = random.normal(seed(), (n*reps, 1, self.dim-idx))
        # Transform standard normals by each chosen component's Cholesky.
        ran = (ran @ sig[_idx_help, comp]).reshape(-1, mu.shape[-1])
        samples = offset + ran
        return samples
Пример #23
0
 def inner_body(inner_state):
     """One proposal round for sampling uniformly from a union of boxes.

     Selects a box by log-weight, draws a batch of 10 candidate points
     inside it, and accepts a candidate with probability
     1/(number of boxes containing it) to de-bias overlap regions.
     Returns (key, any-accepted flag, first accepted candidate).
     """
     (key, accept, _) = inner_state
     key, sample_key, accept_key, select_key = random.split(key, 4)
     i = random.categorical(select_key, logits=log_Vp_i)
     # (Removed an unused `mean_trials` computation that had no effect.)
     u_test = random.uniform(sample_key,
                             shape=(
                                 10,
                                 D,
                             ),
                             minval=points_lower[i, None, :],
                             maxval=points_upper[i, None, :])
     # For each candidate, count how many boxes contain it.
     n_intersect = vmap(lambda u_test: jnp.sum(
         vmap(lambda y_lower, y_upper: points_in_box(
             u_test, y_lower, y_upper))(points_lower, points_upper)))(
                 u_test)
     accept = n_intersect * random.uniform(accept_key,
                                           shape=(10, )) < 1.
     # Keep the first accepted candidate (argmax of the boolean mask).
     u_test = u_test[jnp.argmax(accept), :]
     accept = jnp.any(accept)
     return (key, accept, u_test)
Пример #24
0
def backward_simulation_rejection(ssm_scenario: StateSpaceModel,
                                  marginal_particles: cdict,
                                  n_samps: int,
                                  random_key: jnp.ndarray,
                                  maximum_rejections: int,
                                  init_bound_param: float,
                                  bound_inflation: float) -> cdict:
    """Backward-simulation smoother using rejection resampling.

    Starting from weighted marginal filtering particles, samples n_samps
    trajectories backwards in time, stitching each step with
    rejection_resampling. Returns a cdict of smoothed trajectories with
    per-step transition-evaluation counts recorded.
    """
    marg_particles_vals = marginal_particles.value
    times = marginal_particles.t
    marginal_log_weight = marginal_particles.log_weight

    # T time steps, n_pf filtering particles, d state dimension.
    T, n_pf, d = marg_particles_vals.shape

    t_keys = random.split(random_key, T)
    # Initialize at the final time by resampling the last marginals.
    final_particle_vals = marg_particles_vals[-1, random.categorical(t_keys[-1],
                                                                     marginal_log_weight[-1],
                                                                     shape=(n_samps,))]

    def back_sim_body(x_tplus1_all: jnp.ndarray, ind: int):
        # One backward step: stitch time-`ind` particles to the already
        # sampled values at time `ind + 1`.
        x_t_all, num_transition_evals = rejection_resampling(ssm_scenario,
                                                             marg_particles_vals[ind], times[ind],
                                                             x_tplus1_all, times[ind + 1],
                                                             marginal_log_weight[ind], t_keys[ind],
                                                             maximum_rejections, init_bound_param, bound_inflation)
        return x_t_all, (x_t_all, num_transition_evals)

    # Scan backwards from T-2 down to 0.
    _, back_sim_out = scan(back_sim_body,
                           final_particle_vals,
                           jnp.arange(T - 2, -1, -1), unroll=1)

    back_sim_particles, num_transition_evals = back_sim_out

    out_samps = marginal_particles.copy()
    # Restore forward-time order and append the final-time values.
    out_samps.value = jnp.vstack([back_sim_particles[::-1], final_particle_vals[jnp.newaxis]])
    out_samps.num_transition_evals = jnp.append(0, num_transition_evals[::-1])
    # Trajectories are unweighted after backward simulation.
    del out_samps.log_weight
    return out_samps
Пример #25
0
    def body(state):
        """One outer iteration of a log-space geometric-trial procedure.

        Picks a box i (categorical over log_Vp_i - log_Vp), samples a
        query point x in it, then runs inner geometric trials of random
        boxes until one contains x or the accumulated log count reaches
        log_T. NOTE(review): interpretation inferred from structure —
        confirm against the enclosing algorithm.
        """
        (key, _done, log_t_cum, log_M) = state
        key, choose_key, sample_key = random.split(key, 3)
        i = random.categorical(choose_key, logits=log_Vp_i - log_Vp)
        # sample query
        x = random.uniform(sample_key,
                           shape=(D, ),
                           minval=points_lower[i, :],
                           maxval=points_upper[i, :])

        def inner_body(inner_state):
            # Inner trial: pick a random box j and test whether x is in it.
            (
                key,
                _done,
                _completed,
                log_t_M,
            ) = inner_state
            completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
            # if not completed then increment t_M and test if point in cube
            log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
            key, inner_choose_key = random.split(key, 2)
            j = random.randint(inner_choose_key,
                               shape=(),
                               minval=0,
                               maxval=N + 1)
            # point query
            in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
            done = in_j | completed
            return key, done, completed, log_t_M

        # Repeat inner trials until x lands in a sampled box or the
        # budget log_T is exhausted.
        (key, _done, completed, log_t_M) = while_loop(
            lambda inner_state: ~inner_state[1], inner_body,
            (key, log_t_cum >= log_T, log_t_cum >= log_T, -jnp.inf))
        done = completed
        log_t_cum = jnp.logaddexp(log_t_cum, log_t_M)
        return (key, done, log_t_cum, jnp.logaddexp(log_M, 0.))
Пример #26
0
 def sample(self, key, sample_shape=()):
     """Draw categorical samples of shape sample_shape + batch_shape."""
     assert is_prng_key(key)
     target_shape = sample_shape + self.batch_shape
     return random.categorical(key, self.logits, shape=target_shape)
Пример #27
0
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        """Generate one Sample-Adaptive MCMC transition from `sa_state`.

        Proposes a new point from a Normal/MVN fit to the adaptive
        ensemble `zs`, removes a categorically-chosen ensemble member,
        and resamples the reported state uniformly from the ensemble.
        """
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # we recompute loc/scale after each iteration to avoid precision loss
        # XXX: consider to expose a setting to do this job periodically
        # to save some computations
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            # keep the previous scale if the Cholesky factor contains NaN
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
            sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(
                locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_,
                                       scales_).log_prob(zs_).sum(-1) + pes_
        # mask invalid values (nan, +inf) by -inf
        log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_,
                                 -jnp.inf)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        # Accept-rate running mean restarts after warmup ends.
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (
            accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                       adapt_state, rng_key)
Пример #28
0
 def sample(self, key, sample_shape=()):
     """Draw categorical samples with shape sample_shape + batch_shape."""
     out_shape = sample_shape + self.batch_shape
     return random.categorical(key, self.logits, shape=out_shape)
Пример #29
0
 def sample(self, rng_key, sample_shape):
     """Draw samples of shape sample_shape + batch_shape + event_shape."""
     shape = sample_shape + self.batch_shape + self.event_shape
     # NOTE(review): jax.random.categorical expects *logits*, but
     # self.probs is passed — confirm self.probs actually holds
     # log-probabilities. Also confirm folding event_shape into `shape`
     # is intended, since categorical draws indices over the last axis.
     return random.categorical(rng_key, self.probs, axis=-1, shape=shape)
Пример #30
0
 def mix(key):
     """Sample once from the mixture: pick a component, then sample it."""
     component_key, sample_key = random.split(key)
     chosen = random.categorical(component_key, np.log(weights))
     return components[chosen](sample_key)