def testCategorical(self, p, axis, dtype, sample_shape):
    """Chi-squared test that `random.categorical` matches the given pmf.

    Draws samples both eagerly and under `api.jit` and checks each against
    the categorical distribution defined by `p` along `axis`.
    """
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)
    logits = np.log(p) - 42  # test unnormalized
    out_shape = tuple(np.delete(logits.shape, axis))
    shape = sample_shape + out_shape
    # NOTE: the lambda closes over `logits` and ignores its `p` argument;
    # `p` is only passed so jit sees a consistent signature.
    rand = lambda key, p: random.categorical(
        key, logits, shape=shape, axis=axis)
    crand = api.jit(rand)
    uncompiled_samples = rand(key, p)
    compiled_samples = crand(key, p)
    if axis < 0:
        # Normalize to a non-negative axis for the transpose check below.
        axis += len(logits.shape)
    for samples in [uncompiled_samples, compiled_samples]:
        assert samples.shape == shape
        # Flatten the sample dimensions; assumes prod(sample_shape) == 10000
        # -- TODO confirm against the parameterization supplying sample_shape.
        samples = jnp.reshape(samples, (10000, ) + out_shape)
        if len(p.shape[:-1]) > 0:
            # Multi-batch case: check each categorical column separately.
            ps = np.transpose(p, (1, 0)) if axis == 0 else p
            for cat_samples, cat_p in zip(samples.transpose(), ps):
                self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
        else:
            self._CheckChiSquared(samples, pmf=lambda x: p[x])
def sample_scan(params, tup, x):
    """Run one GRU step and sample the next token from its softmax output.

    Carry `tup` is (key, input, logP, hidden); returns the updated carry
    (key, one-hot sample, accumulated logP, new hidden) plus the sample.
    """
    (_, (update_W, update_U, update_b),
     (reset_W, reset_U, reset_b),
     (out_W, out_U, out_b),
     (sm_W, sm_b)) = params
    key, inp, logP, hidden = tup[0], tup[1], tup[2], tup[3]
    # Standard GRU gate equations.
    z = sigmoid(np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
    r = sigmoid(np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
    h_tilde = np.tanh(
        np.dot(inp, out_W) + np.dot(r * hidden, out_U) + out_b)
    new_hidden = z * hidden + (1 - z) * h_tilde
    logits = np.dot(new_hidden, sm_W) + sm_b
    key, subkey = random.split(key)
    # Sample the conditional distribution and convert to one-hot encoding.
    draws = random.categorical(subkey, logits, axis=1, shape=None)
    draws = one_hot(draws, sm_b.shape[0])
    # Accumulate the log-probability of the drawn token.
    new_logP = logP + np.sum(draws * log_softmax(logits), axis=1)
    return (key, draws, new_logP, new_hidden), draws
def draw_state(val, key, params):
    """Simulate one step of a system that evolves as
        A z_{t-1} + B_k + eps,  eps ~ N(0, Q),
    with a discrete latent regime k drawn from a transition matrix.

    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system).
    key: PRNGKey
    params: PRBPFParamsDiscrete
        Must provide transition_matrix, A, B, Q, C, R.

    Returns
    -------
    ((latent_new, state_new), (latent_new, state_new, obs_new))
    """
    latent_old, state_old = val
    # Split the key once up front. The original reused `key` for the
    # categorical draw AND split it again for the Gaussians, which is the
    # classic JAX key-reuse anti-pattern (correlated draws).
    key_latent, key_state, key_obs = random.split(key, 3)
    probabilities = params.transition_matrix[latent_old, :]
    logits = logit(probabilities)
    latent_new = random.categorical(key_latent, logits)
    # Continuous state: linear dynamics plus regime-dependent offset.
    state_mean = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_state, state_mean, params.Q)
    # Noisy linear observation of the new state.
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R)
    return (latent_new, state_new), (latent_new, state_new, obs_new)
def testCategorical(self, p, axis, dtype, sample_shape):
    """Chi-squared test for `random.categorical`, eager and jitted.

    Multi-dimensional `p` is currently skipped (see skipTest below).
    """
    key = random.PRNGKey(0)
    p = onp.array(p, dtype=dtype)
    logits = onp.log(p) - 42  # test unnormalized
    shape = sample_shape + tuple(onp.delete(logits.shape, axis))
    # NOTE: the lambda closes over `logits` and ignores its `p` argument.
    rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
    crand = api.jit(rand)
    uncompiled_samples = rand(key, p)
    compiled_samples = crand(key, p)
    if p.ndim > 1:
        self.skipTest("multi-dimensional categorical tests are currently broken!")
    for samples in [uncompiled_samples, compiled_samples]:
        if axis < 0:
            # Normalize to a non-negative axis for the shape/index logic.
            axis += len(logits.shape)
        assert samples.shape == shape
        if len(p.shape[:-1]) > 0:
            # Batched case: check each batch row against its own pmf.
            for cat_index, p_ in enumerate(p):
                self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
        else:
            self._CheckChiSquared(samples, pmf=lambda x: p[x])
def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params):
    """Single-particle Rao-Blackwellised PF step: sample the next discrete
    regime, run a Kalman update conditioned on it, and reweight.

    NOTE(review): `logit` of the transition row is used as unnormalized
    logits for the categorical draw — confirm this is the intended
    parameterization rather than `log`.
    """
    transition_logits = logit(params.transition_matrix[st])
    regime = random.categorical(key, transition_logits)
    mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, regime, xt, params)
    return mu_t, Sigma_t, weight_t * Ltk, Ltk
def rbpf(current_config, xt, params, nparticles=100):
    """Rao-Blackwell Particle Filter using prior as proposal.

    `current_config` is (key, mu, Sigma, weights, regimes); returns the
    updated config plus per-step history (mu, Sigma, weights, regimes, Ltk).
    """
    key, mu_t, Sigma_t, weights_t, st = current_config
    key_sample, key_state, key_next, key_reindex = random.split(key, 4)
    particle_keys = random.split(key_sample, nparticles)
    # Propagate the discrete regime of every particle from the prior.
    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(particle_keys, weights_t, st,
                                                  mu_t, Sigma_t, xt, params)
    weights_t = weights_t / weights_t.sum()
    # Multinomial resampling, then reset to uniform weights.
    pi = random.choice(key_reindex, jnp.arange(nparticles),
                       shape=(nparticles, ), p=weights_t, replace=True)
    st, mu_t, Sigma_t = st[pi], mu_t[pi, ...], Sigma_t[pi, ...]
    weights_t = jnp.ones(nparticles) / nparticles
    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, Ltk)
def resample_final(self, sample: cdict) -> cdict:
    """Resample the final-time weighted particles into an unweighted cdict.

    NOTE(review): uses a fixed PRNGKey(1), so the resampling is
    deterministic across calls — presumably intentional for reproducibility.
    """
    final_log_weight = sample.log_weight[-1]
    chosen = random.categorical(random.PRNGKey(1),
                                logits=final_log_weight,
                                shape=(self.n, ))
    return cdict(value=sample.value[-1, chosen])
def sample(self, n):
    """Draw `n` samples from the Gaussian mixture (self._pi, self._mu, self._sig)."""
    component = random.categorical(seed(), np.log(self._pi), shape=(n,))
    noise = random.normal(seed(), (n, 1, self.dim))
    # Shift component means by correlated noise: mu_k + eps @ sig_k.
    return self._mu[component].reshape(n, 1, -1) + noise @ self._sig[component]
def resample(self, ensemble_state: cdict, random_key: jnp.ndarray) -> cdict:
    """Multinomial resampling of an ensemble according to its log-weights.

    Returns a new ensemble with uniform (zero) log-weights and full ESS.
    """
    n = ensemble_state.value.shape[0]
    chosen_inds = random.categorical(random_key,
                                     ensemble_state.log_weight,
                                     shape=(n, ))
    resampled = ensemble_state[chosen_inds]
    resampled.log_weight = jnp.zeros(n)
    resampled.ess = jnp.zeros(n) + n  # ESS is full after resampling
    return resampled
def full_stitch_single(ssm_scenario: StateSpaceModel,
                       x0_single: jnp.ndarray,
                       t: float,
                       x1_all: jnp.ndarray,
                       tplus1: float,
                       x1_log_weight: jnp.ndarray,
                       random_key: jnp.ndarray) -> jnp.ndarray:
    """Sample the index of a t+1 particle to stitch onto `x0_single`,
    reweighting the t+1 log-weights by the transition potential."""
    transition_pots = vmap(ssm_scenario.transition_potential,
                           (None, None, 0, None))(x0_single, t, x1_all, tplus1)
    stitching_log_weight = x1_log_weight - transition_pots
    return random.categorical(random_key, stitching_log_weight)
def sample(rng, params, num_samples=1):
    """Draw `num_samples` points from the mixture defined by the closure
    variables `means`, `covariances` and `weights`.

    NOTE(review): `weights` is passed to `random.categorical` as logits —
    confirm it holds log-weights rather than probabilities.
    """
    per_cluster = []
    for mean, cov in zip(means, covariances):
        rng, temp_rng = random.split(rng)
        per_cluster.append(
            random.multivariate_normal(temp_rng, mean, cov, (num_samples, )))
    stacked = np.dstack(per_cluster)
    # Pick one cluster per sample and gather it along the last axis.
    idx = random.categorical(rng, weights, shape=(num_samples, 1, 1))
    return np.squeeze(np.take_along_axis(stacked, idx, -1))
def init_kernel(init_params, num_warmup, adapt_state_size=None,
                inverse_mass_matrix=None, dense_mass=False,
                model_args=(), model_kwargs=None,
                rng_key=random.PRNGKey(0)):
    """Initialize the Sample-Adaptive (SA) sampler state.

    Builds the initial adaptive ensemble of `adapt_state_size` points around
    `init_params`, computes their potential energies, estimates a proposal
    scale, and returns a device-put SAState. Closes over `wa_steps`,
    `potential_fn` / `potential_fn_gen`, `_sample_proposal`, `SAAdaptState`
    and `SAState` from the enclosing scope.
    """
    nonlocal wa_steps
    wa_steps = num_warmup
    pe_fn = potential_fn
    if potential_fn_gen:
        if pe_fn is not None:
            raise ValueError(
                'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
            )
        else:
            kwargs = {} if model_kwargs is None else model_kwargs
            pe_fn = potential_fn_gen(*model_args, **kwargs)
    rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    if inverse_mass_matrix is None:
        # Identity (dense) or ones (diagonal) mass matrix by default.
        inverse_mass_matrix = jnp.identity(
            z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
    inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
        else jnp.sqrt(inverse_mass_matrix)
    if adapt_state_size is None:
        # XXX: heuristic choice
        adapt_state_size = 2 * z_flat.shape[-1]
    else:
        assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
    # NB: mean is init_params
    zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                   (adapt_state_size, ))
    # compute potential energies
    pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
    if dense_mass:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # if cholesky is NaN, we use the scale from `sample_proposal` here
        inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                         inv_mass_matrix_sqrt, cholesky)
    else:
        inv_mass_matrix_sqrt = jnp.std(zs, 0)
    adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)
    # Pick a uniformly random ensemble member as the initial chain position
    # (zero logits == uniform categorical).
    k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                       jnp.array(False), adapt_state, rng_key_sa)
    return device_put(sa_state)
def rejection_stitch_proposal_single(ssm_scenario: StateSpaceModel,
                                     x0_single: jnp.ndarray,
                                     t: float,
                                     x1_all: jnp.ndarray,
                                     tplus1: float,
                                     x1_log_weight: jnp.ndarray,
                                     bound: float,
                                     random_key: jnp.ndarray) \
        -> Tuple[jnp.ndarray, float, bool, jnp.ndarray]:
    """Propose a t+1 particle index by weight, then rejection-test it against
    the transition density with the given `bound`.

    Returns (proposed index, conditional density, reject flag, updated key).
    """
    random_key, choice_key, uniform_key = random.split(random_key, 3)
    proposed_ind = random.categorical(choice_key, x1_log_weight)
    cond_dens = jnp.exp(-ssm_scenario.transition_potential(x0_single, t,
                                                           x1_all[proposed_ind],
                                                           tplus1))
    reject = random.uniform(uniform_key) > cond_dens / bound
    return proposed_ind, cond_dens, reject, random_key
def metric(samp_arr, log_weight=None):
    """Gaussian KL-style divergence between a sample and a reference posterior.

    Closes over `prec_post`, `mean_post`, `cov_post` and `len_t` from the
    enclosing scope. Weighted samples are first resampled to an unweighted
    set (with a fixed PRNGKey(0), so deterministically).
    """
    if isinstance(samp_arr, mocat.cdict):
        samp_arr = samp_arr.value
    if samp_arr.ndim == 3:
        samp_arr = samp_arr[..., 0].T
    if log_weight is not None:
        picks = random.categorical(random.PRNGKey(0), log_weight,
                                   shape=(len(samp_arr), ))
        samp_arr = samp_arr[picks]
    mean = samp_arr.mean(0)
    cov = jnp.cov(samp_arr.T, ddof=1)
    mean_diff = mean_post - mean
    return 0.5 * (jnp.trace(prec_post @ cov)
                  + mean_diff.T @ prec_post @ mean_diff
                  - len_t
                  + jnp.log(jnp.linalg.det(cov_post) / jnp.linalg.det(cov)))
def get_k(rng):
    """Sample k according to args.k_type: constant, exponentially decaying,
    or power-law decaying over 1..L^2."""
    if args.k_type == 'const':
        return int(args.k_param)
    ranks = jnp.arange(1, args.L**2 + 1)
    if args.k_type == 'exp':
        log_probs = -ranks * jnp.log(args.k_param)
    elif args.k_type == 'power':
        log_probs = -args.k_param * jnp.log(ranks)
    else:
        raise ValueError('Unknown k_type: {}'.format(args.k_type))
    # Categorical returns 0-based index; k is 1-based.
    return jrand.categorical(rng, log_probs) + 1
def sampling_loop_body_fn(state):
    """Sampling loop state update.

    State is (step i, sequences, decoder cache, current token, ended mask,
    rng). Closes over `tokens_to_logits`, `topk`, `temperature` and
    `end_marker` from the enclosing scope.
    """
    i, sequences, cache, cur_token, ended, rng = state
    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    logits, new_cache = tokens_to_logits(cur_token, cache)
    # Sample next token from logits.
    # TODO(levskaya): add top-p "nucleus" sampling option.
    if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, topk_idxs = lax.top_k(logits, topk)
        topk_token = jnp.expand_dims(random.categorical(
            rng1, topk_logits / temperature).astype(jnp.int32), axis=-1)
        # Return the original indices corresponding to the sampled top-k tokens.
        next_token = jnp.squeeze(jnp.take_along_axis(topk_idxs, topk_token,
                                                     axis=-1), axis=-1)
    else:
        next_token = random.categorical(rng1, logits / temperature).astype(
            jnp.int32)
    # Only use sampled tokens if we're past provided prefix tokens
    # (position i+1 is zero only where no prompt token was supplied).
    out_of_prompt = (sequences[:, i + 1] == 0)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)
    # If end-marker reached for batch item, only emit padding tokens.
    next_token_or_endpad = next_token * ~ended
    ended |= (next_token_or_endpad == end_marker)
    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                             (0, i + 1))
    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)
def body(state):
    """One multi-ellipsoid rejection-sampling iteration.

    Closes over `log_p`, `mu`, `radii`, `rotation`, `sample_ellipsoid`,
    `point_in_ellipsoid` and `unit_cube_constraint`. A proposal drawn from
    a weight-chosen ellipsoid is accepted with probability 1/n_intersect to
    correct for ellipsoid overlap.
    """
    i, _, key, _, _ = state
    key, accept_key, sample_key, select_key = random.split(key, 4)
    k = random.categorical(select_key, log_p)
    proposal = sample_ellipsoid(sample_key, mu[k, :], radii[k, :],
                                rotation[k, :, :],
                                unit_cube_constraint=unit_cube_constraint)
    n_intersect = jnp.sum(
        vmap(lambda m, r, rot: point_in_ellipsoid(proposal, m, r, rot))(
            mu, radii, rotation))
    accept = random.uniform(accept_key) < jnp.reciprocal(n_intersect)
    return (i + 1, k, key, accept, proposal)
def logistic_mix_sample(nn_out, rng):
    """Sample an RGB image from the discretized logistic mixture
    parameterized by `nn_out` (PixelCNN-style), with autoregressive
    channel coupling via the `t` coefficients."""
    m, t, inv_scales, logit_weights = logistic_preprocess(nn_out)
    rng_mix, rng_logistic = random.split(rng)
    mix_idx = random.categorical(rng_mix, logit_weights, -3)

    def select_mix(arr):
        expanded = jnp.expand_dims(mix_idx, (-4, -1))
        return jnp.squeeze(jnp.take_along_axis(arr, expanded, -4), -4)

    m, t, inv_scales = (jnp.moveaxis(select_mix(a), -1, 0)
                        for a in (m, t, inv_scales))
    noise = random.logistic(rng_logistic, m.shape) / inv_scales
    red = m[0] + noise[0]
    green = m[1] + t[0] * red + noise[1]
    blue = m[2] + t[1] * red + t[2] * green + noise[2]
    return jnp.stack([red, green, blue], -1)
def sample(self, n_samples, key=None):
    """Draw `n_samples` from the Gaussian mixture.

    When `key` is None, splits and advances `self.threadkey` (mutates self).
    """
    if key is None:
        self.threadkey, key = random.split(self.threadkey)
    key, subkey = random.split(key)
    component_keys = random.split(key, n_samples)
    components = random.categorical(subkey, np.log(self.weights),
                                    shape=(n_samples, ))

    def draw_one(rkey, component):
        return random.multivariate_normal(rkey, self.means[component],
                                          self.covs[component])

    draws = vmap(draw_one)(component_keys, components)
    return draws.reshape((n_samples, self.d))
def rejection_stitching(ssm_scenario: StateSpaceModel,
                        x0_all: jnp.ndarray,
                        t: float,
                        x1_all: jnp.ndarray,
                        tplus1: float,
                        x1_log_weight: jnp.ndarray,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    """Rejection-sample indices stitching each time-t particle to a time-t+1
    particle, falling back to exact (full) stitching for particles still
    unaccepted after `maximum_rejections` rounds.

    Returns (stitch indices, total number of transition-potential evaluations).
    """
    rejection_initial_keys = random.split(random_key, 3)
    n = len(x1_all)
    # Prerun to initiate bound
    x1_initial_inds = random.categorical(rejection_initial_keys[0],
                                         x1_log_weight, shape=(n,))
    initial_cond_dens = jnp.exp(-vmap(ssm_scenario.transition_potential,
                                      (0, None, 0, None))(x0_all, t,
                                                          x1_all[x1_initial_inds],
                                                          tplus1))
    max_cond_dens = jnp.max(initial_cond_dens)
    # Inflate the bound if the prerun already exceeded the initial guess.
    initial_bound = jnp.where(max_cond_dens > init_bound_param,
                              max_cond_dens * bound_inflation,
                              init_bound_param)
    initial_not_yet_accepted_arr = random.uniform(rejection_initial_keys[1],
                                                  (n,)) > initial_cond_dens / initial_bound
    # Keep proposing while any particle is unaccepted and rejections remain.
    out_tup = while_loop(lambda tup: jnp.logical_and(tup[0].sum() > 0,
                                                     tup[-2] < maximum_rejections),
                         lambda tup: rejection_stitch_proposal_all(ssm_scenario, x0_all, t,
                                                                   x1_all, tplus1,
                                                                   x1_log_weight,
                                                                   bound_inflation,
                                                                   *tup),
                         (initial_not_yet_accepted_arr, x1_initial_inds,
                          initial_bound,
                          random.split(rejection_initial_keys[2], n), 1, n))
    not_yet_accepted_arr, x1_final_inds, final_bound, random_keys, \
        rej_attempted, num_transition_evals = out_tup
    # Exact stitch for anything still unaccepted.
    # NOTE(review): `map` over jnp.arange is presumably jax.lax.map imported
    # at file level — a builtin map would return a lazy iterator; confirm.
    x1_final_inds = map(lambda i: full_stitch_single_cond(not_yet_accepted_arr[i],
                                                          x1_final_inds[i],
                                                          ssm_scenario,
                                                          x0_all[i], t,
                                                          x1_all, tplus1,
                                                          x1_log_weight,
                                                          random_keys[i]),
                        jnp.arange(n))
    num_transition_evals = num_transition_evals \
        + len(x1_all) * not_yet_accepted_arr.sum()
    return x1_final_inds, num_transition_evals
def action_selection(self, rng_key, beliefs, gamma=1e3):
    """Sample a choice from the categorical defined by `logits(...)`.

    Stores the computed logits on `self.logits` as a side effect.
    """
    p_cfm, params = beliefs
    # Marginalize the belief tensor down to choices.
    p_c = einsum('...cfm->...c', p_cfm)
    if self.dyn_pref:
        utility = jnp.expand_dims(jnp.log(self.lam), -2)
    else:
        utility = self.U
    self.logits = logits((p_c, params), gamma, utility)
    return random.categorical(rng_key, self.logits)
def conditional_sample(self, x, idx, n):
    """Draw `n` conditional samples per row of `x` from the mixture
    p(x_{idx:} | x_{:idx}) given by self.condition."""
    reps = x.shape[0]
    flat_rows = np.arange(n * reps)
    pi, mu, var = self.condition(x, idx)
    # Clamp variances before the Cholesky for numerical stability.
    sig = vmap(np.linalg.cholesky)(np.maximum(var, 1e-5))
    pi, mu, sig = (np.repeat(a, n, 0) for a in (pi, mu, sig))
    sig = np.maximum(sig, 0.)
    comp = random.categorical(seed(), np.log(pi), axis=1).reshape(-1)
    offset = mu[flat_rows, comp]
    noise = random.normal(seed(), (n * reps, 1, self.dim - idx))
    noise = (noise @ sig[flat_rows, comp]).reshape(-1, mu.shape[-1])
    return offset + noise
def inner_body(inner_state):
    """One batched (size 10) rejection attempt for sampling a point uniformly
    from the union of boxes.

    Closes over `log_Vp_i`, `N`, `D`, `cube_width`, `points_lower`,
    `points_upper` and `points_in_box`.
    """
    key, _, _ = inner_state
    key, sample_key, accept_key, select_key = random.split(key, 4)
    i = random.categorical(select_key, logits=log_Vp_i)
    # NOTE(review): mean_trials is computed but never used downstream here.
    mean_trials = jnp.exp(jnp.log(N) + D * jnp.log(cube_width))
    u_test = random.uniform(sample_key,
                            shape=(10, D, ),
                            minval=points_lower[i, None, :],
                            maxval=points_upper[i, None, :])
    # For each candidate, count how many boxes contain it.
    counts = vmap(lambda u: jnp.sum(
        vmap(lambda lo, hi: points_in_box(u, lo, hi))(points_lower,
                                                      points_upper)))(u_test)
    # Accept with probability 1/count to correct for box overlap.
    accept_mask = counts * random.uniform(accept_key, shape=(10, )) < 1.
    chosen = u_test[jnp.argmax(accept_mask), :]
    return (key, jnp.any(accept_mask), chosen)
def backward_simulation_rejection(ssm_scenario: StateSpaceModel,
                                  marginal_particles: cdict,
                                  n_samps: int,
                                  random_key: jnp.ndarray,
                                  maximum_rejections: int,
                                  init_bound_param: float,
                                  bound_inflation: float) -> cdict:
    """Backward simulation smoother using rejection resampling.

    Starts from a weighted draw of final-time particles and scans backwards
    through time, stitching each step with `rejection_resampling`. Returns a
    cdict of smoothed trajectories (T, n_samps, d) with per-step transition
    evaluation counts and no log_weight field.
    """
    marg_particles_vals = marginal_particles.value
    times = marginal_particles.t
    marginal_log_weight = marginal_particles.log_weight
    T, n_pf, d = marg_particles_vals.shape
    t_keys = random.split(random_key, T)
    # Initialize at the final time: weighted resample of the marginal particles.
    final_particle_vals = marg_particles_vals[-1,
                                              random.categorical(t_keys[-1],
                                                                 marginal_log_weight[-1],
                                                                 shape=(n_samps,))]

    def back_sim_body(x_tplus1_all: jnp.ndarray, ind: int):
        # One backward step: stitch time-ind particles onto the carried
        # time-(ind+1) trajectories.
        x_t_all, num_transition_evals = rejection_resampling(ssm_scenario,
                                                            marg_particles_vals[ind],
                                                            times[ind],
                                                            x_tplus1_all,
                                                            times[ind + 1],
                                                            marginal_log_weight[ind],
                                                            t_keys[ind],
                                                            maximum_rejections,
                                                            init_bound_param,
                                                            bound_inflation)
        return x_t_all, (x_t_all, num_transition_evals)

    # Scan backwards over indices T-2 .. 0.
    _, back_sim_out = scan(back_sim_body, final_particle_vals,
                           jnp.arange(T - 2, -1, -1), unroll=1)
    back_sim_particles, num_transition_evals = back_sim_out
    out_samps = marginal_particles.copy()
    # Reverse back into forward-time order and append the final-time draw.
    out_samps.value = jnp.vstack([back_sim_particles[::-1],
                                  final_particle_vals[jnp.newaxis]])
    out_samps.num_transition_evals = jnp.append(0, num_transition_evals[::-1])
    # Smoothed samples are unweighted.
    del out_samps.log_weight
    return out_samps
def body(state):
    """One outer iteration of a box-union volume/rejection estimator.

    Closes over `log_Vp_i`, `log_Vp`, `log_T`, `D`, `N`, `points_lower`,
    `points_upper` and `points_in_box`. Samples a query point from a
    volume-weighted box, then runs an inner geometric loop counting trials
    (in log space) until the point lands in a randomly chosen box or the
    trial budget `log_T` is exhausted.
    """
    (key, _done, log_t_cum, log_M) = state
    key, choose_key, sample_key = random.split(key, 3)
    # Choose a box with probability proportional to its volume share.
    i = random.categorical(choose_key, logits=log_Vp_i - log_Vp)
    # sample query
    x = random.uniform(sample_key,
                       shape=(D, ),
                       minval=points_lower[i, :],
                       maxval=points_upper[i, :])

    def inner_body(inner_state):
        (
            key,
            _done,
            _completed,
            log_t_M,
        ) = inner_state
        # Completed once the cumulative trial count reaches log_T.
        completed = jnp.logaddexp(log_t_cum, log_t_M) >= log_T
        # if not completed then increment t_M and test if point in cube
        log_t_M = jnp.where(completed, log_t_M, jnp.logaddexp(log_t_M, 0.))
        key, inner_choose_key = random.split(key, 2)
        j = random.randint(inner_choose_key, shape=(), minval=0, maxval=N + 1)
        # point query
        in_j = points_in_box(x, points_lower[j, :], points_upper[j, :])
        done = in_j | completed
        return key, done, completed, log_t_M

    (key, _done, completed, log_t_M) = while_loop(
        lambda inner_state: ~inner_state[1], inner_body,
        (key, log_t_cum >= log_T, log_t_cum >= log_T, -jnp.inf))
    done = completed
    # Accumulate trials and count one more outer sample (log-space increment).
    log_t_cum = jnp.logaddexp(log_t_cum, log_t_M)
    return (key, done, log_t_cum, jnp.logaddexp(log_M, 0.))
def sample(self, key, sample_shape=()):
    """Draw categorical samples of shape sample_shape + batch_shape."""
    assert is_prng_key(key)
    out_shape = sample_shape + self.batch_shape
    return random.categorical(key, self.logits, shape=out_shape)
def sample_kernel(sa_state, model_args=(), model_kwargs=None):
    """One transition of the Sample-Adaptive (SA) MCMC kernel.

    Proposes a new point from the adaptive ensemble's estimated location and
    scale, rejects one ensemble member by weighted categorical draw, and
    resamples the reported chain position uniformly from the ensemble.
    Closes over `potential_fn` / `potential_fn_gen`, `_sample_proposal`,
    `_get_proposal_loc_and_scale`, `_numpy_delete`, `max_delta_energy`,
    `wa_steps`, `SAAdaptState` and `SAState`.
    """
    pe_fn = potential_fn
    if potential_fn_gen:
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
    zs, pes, loc, scale = sa_state.adapt_state
    # we recompute loc/scale after each iteration to avoid precision loss
    # XXX: consider to expose a setting to do this job periodically
    # to save some computations
    loc = jnp.mean(zs, 0)
    if scale.ndim == 2:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
    else:
        scale = jnp.std(zs, 0)
    rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
        sa_state.rng_key, 4)
    _, unravel_fn = ravel_pytree(sa_state.z)
    z = loc + _sample_proposal(scale, rng_key_z)
    pe = pe_fn(unravel_fn(z))
    # Treat NaN energy as infinitely bad so the proposal is never kept.
    pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
    diverging = (pe - sa_state.potential_energy) > max_delta_energy
    # NB: all terms having the pattern *s will have shape N x ...
    # and all terms having the pattern *s_ will have shape (N + 1) x ...
    locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
    zs_ = jnp.concatenate([zs, z[None, :]])
    pes_ = jnp.concatenate([pes, pe[None]])
    locs_ = jnp.concatenate([locs, loc[None, :]])
    scales_ = jnp.concatenate([scales, scale[None, ...]])
    if scale.ndim == 2:  # dense_mass
        log_weights_ = dist.MultivariateNormal(
            locs_, scale_tril=scales_).log_prob(zs_) + pes_
    else:
        log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
    # mask invalid values (nan, +inf) by -inf
    log_weights_ = jnp.where(jnp.isfinite(log_weights_), log_weights_, -jnp.inf)
    # get rejecting index
    j = random.categorical(rng_key_reject, log_weights_)
    zs = _numpy_delete(zs_, j)
    pes = _numpy_delete(pes_, j)
    loc = locs_[j]
    scale = scales_[j]
    adapt_state = SAAdaptState(zs, pes, loc, scale)
    # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
    accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
    itr = sa_state.i + 1
    # Running mean of accept_prob, restarted after warmup.
    n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
    mean_accept_prob = sa_state.mean_accept_prob + (
        accept_prob - sa_state.mean_accept_prob) / n
    # XXX: we make a modification of SA sampler in [1]
    # in [1], each MCMC state contains N points `zs`
    # here we do resampling to pick randomly a point from those N points
    k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                   adapt_state, rng_key)
def sample(self, key, sample_shape=()):
    """Draw categorical samples; output shape is sample_shape + batch_shape."""
    out_shape = sample_shape + self.batch_shape
    return random.categorical(key, self.logits, shape=out_shape)
def sample(self, rng_key, sample_shape):
    """Draw samples of shape sample_shape + batch_shape + event_shape.

    NOTE(review): `self.probs` is handed to `random.categorical`, which
    expects (unnormalized) logits — confirm `probs` actually stores
    log-probabilities upstream.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    return random.categorical(rng_key, self.probs, axis=-1, shape=full_shape)
def mix(key):
    """Sample from the mixture: pick a component by weight, then sample it.

    Closes over `weights` and `components` (a sequence of samplers).
    """
    component_key, sample_key = random.split(key)
    chosen = random.categorical(component_key, np.log(weights))
    return components[chosen](sample_key)