def _call(self, x, training=True, rng=None):
    info = self.info
    if training:
        if rng is None:
            raise ValueError('rng is required when training is True')
        # Using tie_in to avoid materializing constants
        keep = primitive.tie_in(x, random.bernoulli(rng, info.rate, x.shape))
        return np.where(keep, x / info.rate, 0)
    else:
        return x
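# A minimal, self-contained sketch of the same inverted-dropout pattern using
# plain jax.random (avoiding tie_in, which newer JAX versions no longer expose).
# `keep_prob` here plays the role of `info.rate` above; treating that rate as a
# keep probability is an assumption about the snippet's convention.
import jax.numpy as jnp
from jax import random

def dropout_sketch(x, keep_prob, rng):
    keep = random.bernoulli(rng, keep_prob, x.shape)  # True with probability keep_prob
    return jnp.where(keep, x / keep_prob, 0.)         # rescale so the mean is preserved

example = dropout_sketch(jnp.ones((2, 3)), keep_prob=0.8, rng=random.PRNGKey(0))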
def _get_inputs(cls, out_logits, test_shape, train_shape):
    key = random.PRNGKey(0)
    key, split = random.split(key)
    x_train = random.normal(split, train_shape)
    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)
    key, split = random.split(key)
    x_test = random.normal(split, test_shape)
    return key, x_test, x_train, y_train
def testBernoulli(self, p, dtype):
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)
    rand = lambda key, p: random.bernoulli(key, p, (10000,))
    crand = api.jit(rand)

    uncompiled_samples = rand(key, p)
    compiled_samples = crand(key, p)
    for samples in [uncompiled_samples, compiled_samples]:
        self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = np.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def __call__(self, x, deterministic=False, rng=None):
    if self.rate == 0.:
        return x
    keep_prob = 1. - self.rate
    if deterministic:
        return x
    else:
        if rng is None:
            rng = self.scope.make_rng('dropout')
        mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
        return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
def step(self):
    self.update_agent_states()

    # Update the counts of the various states
    self.update_counts()

    if self.model.log_file:
        self.model.write_logs()  # Write some logs, if enabled

    # NB: At some point, it may make sense to make agents move. For now, they're static.

    ############################# Transmission logic ###################################
    # Get all the currently contagious agents, and have them infect new agents.
    # TODO: include hospital transmission, vary transmissability by state.
    contagious = np.asarray(
        (self.model.epidemic_state == self.model.STATE_PRESYMPTOMATIC)
        | (self.model.epidemic_state == self.model.STATE_ASYMPTOMATIC)
        | (self.model.epidemic_state == self.model.STATE_SYMPTOMATIC)).nonzero()

    # For each contagious person, infect some of its neighbors based on their
    # hygiene and the contagious person's social radius.
    # Use jax.random instead of numpyro here to keep these deterministic.
    # TODO: figure out a way to do this in a (more) vectorized manner. Probably
    # some sort of kernel convolution method with each radius. Should also look
    # into numpyro's scan.
    for x, y in zip(*contagious):
        radius = self.model.social_radius[x, y]
        base_isolation = self.model.base_isolation[x, y]
        nx, ny = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
            np.arange(x - radius, x + radius),
            np.arange(y - radius, y + radius))
        neighbor_inds = np.vstack([nx.ravel(), ny.ravel()])

        # Higher base_isolation leads to less infection.
        # TODO: modify isolation so that symptomatic agents isolate more.
        infection_attempts = jrand.choice(
            self.PRNGKey,
            neighbor_inds,
            shape=int(len(neighbor_inds) * (1 - base_isolation)))
        potentially_infected_hygiene = self.model.hygiene[
            infection_attempts[:, 0], infection_attempts[:, 1]]
        susceptible = self.model.epidemic_state[
            infection_attempts[:, 0],
            infection_attempts[:, 1]] == self.model.STATE_SUSCEPTIBLE
        indexer = jrand.bernoulli(self.PRNGKey,
                                  potentially_infected_hygiene.ravel(),
                                  len(infection_attempts))
        got_infected = np.zeros(self.model.epidemic_state.shape, dtype=np.bool_)
        got_infected[potentially_infected_hygiene[indexer]] = True

        # Set the date to become infectious
        self.model.epidemic_state[got_infected & susceptible] = self.model.STATE_EXPOSED
        self.model.date_infected[got_infected & susceptible] = self.time
        self.model.date_contagious[
            got_infected & susceptible] = self.time + self.params.EXPOSED_PERIOD
def binarize(rng, batch):
    """Binarizes a batch of observations with values in [0, 1] by sampling from
    a Bernoulli distribution and using the original observations as means.

    Reason: This example assumes a Bernoulli distribution for the decoder output
    and thus requires inputs to be binary values as well.

    :param rng: rng seed key
    :param batch: Batch of data with continuous values in the interval [0, 1]
    :return: binarized batch
    """
    return random.bernoulli(rng, batch).astype(batch.dtype)
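# Hedged usage sketch for `binarize`: values in [0, 1] act as Bernoulli means,
# so brighter pixels are more likely to become 1. The batch below is made-up
# example data, not data from the original project.
import jax.numpy as jnp
from jax import random

rng = random.PRNGKey(0)
batch = jnp.array([[0.05, 0.50, 0.95],
                   [0.20, 0.80, 0.99]])
binary_batch = random.bernoulli(rng, batch).astype(batch.dtype)  # entries are 0. or 1.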
def forward(ctx, input, p=0.5, train=False):
    assert isinstance(input, Variable)
    noise = random.bernoulli(
        key=random.PRNGKey(rd.randint(-1000000000000000000, 1000000000000000000)),
        p=p, shape=input.data.shape)
    if not train:
        noise = jnp.ones(input.data.shape)
    if p == 1:
        noise = jnp.zeros(input.data.shape)

    def np_fn(input_np, noise):
        return input_np * noise

    np_args = (input.data, noise)
    id = "Dropout"
    return np_fn, np_args, jit(np_fn)(*np_args), id
def apply_fun(params, inputs, rng):
    if rng is None:
        msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
               "argument. That is, instead of `apply_fun(params, inputs)`, call "
               "it like `apply_fun(params, inputs, key)` where `key` is a "
               "jax.random.PRNGKey value.")
        raise ValueError(msg)
    if mode == 'train':
        keep = random.bernoulli(rng, rate, inputs.shape)
        return np.where(keep, inputs / rate, 0)
    else:
        return inputs
def apply_fun(params, inputs, is_training, **kwargs):
    rng = kwargs.get('rng', None)
    if rng is None:
        msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
               "argument. That is, instead of `apply_fun(params, inputs)`, call "
               "it like `apply_fun(params, inputs, rng)` where `rng` is a "
               "jax.random.PRNGKey value.")
        raise ValueError(msg)
    keep = random.bernoulli(rng, rate, inputs.shape)
    outs = np.where(keep, inputs / rate, 0)
    # if not training, just return inputs and discard any computation done
    out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
    return out
def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel,
                        lr_factor, momentum):
    key = random.PRNGKey(0)
    key, split = random.split(key)
    if len(train_shape) == 2:
        train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
        train_shape = (16, 8, 8, 3)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

    # Regress to an MSE loss.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
        return loss(get_params(opt_state), x_train)

    steps = 30

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')

    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size, momentum=momentum) * lr_factor
    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum)

    opt_state = opt_init(params)
    init_loss = get_loss(opt_state)

    for i in range(steps):
        params = get_params(opt_state)
        opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor < 1.:
        self.assertLess(loss_ratio, 0.1)
    elif lr_factor == 1:
        # At the threshold, the loss decays slowly
        self.assertLess(loss_ratio, 1.)
    if lr_factor > 2.:
        if not math.isnan(loss_ratio):
            self.assertGreater(loss_ratio, 10.)
def sample(self, state, model_args, model_kwargs):
    i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
    x_flat, unravel_fn = ravel_pytree(x)
    x_grad_flat, _ = ravel_pytree(x_grad)
    shape = jnp.shape(x_flat)
    rng_key, key_normal, key_bernoulli, key_accept = random.split(rng_key, 4)

    mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

    x_grad_flat_scaled = (mass_sqrt_inv @ x_grad_flat if self._dense_mass
                          else mass_sqrt_inv * x_grad_flat)

    # Generate proposal y.
    z = adapt_state.step_size * random.normal(key_normal, shape)

    p = expit(-z * x_grad_flat_scaled)
    b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1., -1.)

    dx_flat = b * z
    dx_flat_scaled = (mass_sqrt_inv.T @ dx_flat if self._dense_mass
                      else mass_sqrt_inv * dx_flat)

    y_flat = x_flat + dx_flat_scaled

    y = unravel_fn(y_flat)
    y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
    y_grad_flat, _ = ravel_pytree(y_grad)
    y_grad_flat_scaled = (mass_sqrt_inv @ y_grad_flat if self._dense_mass
                          else mass_sqrt_inv * y_grad_flat)

    log_accept_ratio = x_pe - y_pe + jnp.sum(
        softplus(dx_flat * x_grad_flat_scaled) - softplus(-dx_flat * y_grad_flat_scaled))
    accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.)

    x, x_flat, pe, x_grad = jax.lax.cond(
        random.bernoulli(key_accept, accept_prob),
        (y, y_flat, y_pe, y_grad), identity,
        (x, x_flat, x_pe, x_grad), identity)

    # do not update adapt_state after warmup phase
    adapt_state = jax.lax.cond(
        i < self._num_warmup,
        (i, accept_prob, (x,), adapt_state), lambda args: self._wa_update(*args),
        adapt_state, identity)

    itr = i + 1
    n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

    return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob,
                         adapt_state, rng_key)
def sample_rewards(policy, discriminator, initialization=False, initial_state=0):
    global key
    get_new_key()
    s = bernoulli(key, p=initial_distribution[0]).astype(int)
    if initialization:
        s = initial_state
    get_new_key()
    a = bernoulli(key, p=policy[s][0]).astype(int)
    traj = []
    traj.append((s, a))
    returns = []
    returns.append(jnp.log(discriminator[s][a]))
    for i in range(traj_len - 1):
        s, a = roll_out(s, a, true_transition, policy)
        traj.append((s, a))
        returns.append(jnp.log(discriminator[s][a]))
    return jnp.array(copy.deepcopy(returns)), jnp.array(copy.deepcopy(traj))
def apply_fun(params, inputs, **kwargs): rng = kwargs.get("rng", None) if rng is None: msg = ( "Dropout layer requires apply_fun to be called with a PRNG key " "argument. That is, instead of `apply_fun(params, inputs)`, call " "it like `apply_fun(params, inputs, key)` where `key` is a " "jax.random.PRNGKey value.") raise ValueError(msg) if mode == "train": keep = random.bernoulli(rng, 1 - rate, inputs.shape) return np.where(keep, inputs / (1. - rate), 0.) else: return inputs
def sample_trajectory(policy, key, traj, init_state=False, my_s=0,
                      init_action=False, my_a=0):
    true_transition = jnp.array([[[0.7, 0.3], [0.2, 0.8]],
                                 [[0.99, 0.01], [0.99, 0.01]]])
    initial_distribution = jnp.ones(2) / 2

    if init_state:
        s = my_s
    else:
        key = get_new_key(key)
        s = bernoulli(key, p=initial_distribution[0]).astype(int)
    if init_action:
        a = my_a
    else:
        key = get_new_key(key)
        a = bernoulli(key, p=policy[s][0]).astype(int)

    states = [s]
    actions = [a]
    # sample 2 times more for gae because of no absorbing state
    traj_len = len(traj)
    for _ in range(int(traj_len * 2)):
        s, a, key = roll_out(s, a, true_transition, policy, key)
        states.append(s)
        actions.append(a)
    return states, actions, key
def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx,
                                   support_size, stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    z_new_flat = ops.index_update(z_discrete_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
def apply_fun(params, inputs, **kwargs):  # pylint: disable=missing-docstring
    del params  # Unused.
    rng = kwargs.get('rng', None)
    if rng is None:
        msg = ('Dropout layer requires apply_fun to be called with a PRNG key '
               'argument. That is, instead of `apply_fun(params, inputs)`, call '
               'it like `apply_fun(params, inputs, key)` where `key` is a '
               'jax.random.PRNGKey value.')
        raise ValueError(msg)
    if mode == 'train':
        keep = random.bernoulli(rng, rate, inputs.shape)
        return np.where(keep, inputs / rate, 0)
    else:
        return inputs
def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
    num_steps = _get_num_steps(step_size, trajectory_len)
    vv_state_new = fori_loop(0, num_steps,
                             lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                             vv_state)
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                            vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    transition = random.bernoulli(rng, accept_prob)
    vv_state = cond(transition,
                    vv_state_new, lambda state: state,
                    vv_state, lambda state: state)
    return vv_state, num_steps, accept_prob
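# Minimal sketch of the Metropolis accept/reject step used above: a Bernoulli
# draw with probability min(1, exp(-delta_energy)) decides whether the proposed
# state replaces the current one. This uses the current jax.lax.cond signature;
# the snippet above relies on an older cond(pred, true_operand, true_fn, ...) helper.
import jax.numpy as jnp
from jax import lax, random

def metropolis_accept(rng, current, proposal, delta_energy):
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.minimum(jnp.exp(-delta_energy), 1.0)
    accept = random.bernoulli(rng, accept_prob)
    new_state = lax.cond(accept, lambda _: proposal, lambda _: current, None)
    return new_state, accept_prob

state, prob = metropolis_accept(random.PRNGKey(0), jnp.array(1.0), jnp.array(2.0),
                                delta_energy=jnp.array(0.3))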
def bernoulli(p, size=None):
    """Sample Bernoulli random values with given shape and mean.

    Args:
      p: a float or array of floats for the mean of the random variables.
        Must be broadcast-compatible with ``size``.
      size: optional, a tuple of nonnegative integers representing the result
        shape. Must be broadcast-compatible with ``p.shape``. The default (None)
        produces a result shape equal to ``p.shape``.

    Returns:
      A random array with boolean dtype and shape given by ``size`` if ``size``
      is not None, or else ``p.shape``.
    """
    return JaxArray(
        jr.bernoulli(DEFAULT.split_key(), p=p, shape=_size2shape(size)))
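# The wrapper above delegates to jax.random.bernoulli with a freshly split key.
# A direct, self-contained equivalent without the JaxArray/DEFAULT machinery
# (which belongs to the wrapping library) looks like this:
from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)                        # keep `key` for later draws
samples = random.bernoulli(subkey, p=0.3, shape=(4,))  # boolean array, ~30% True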
def apply_fun(params, inputs, **kwargs):
    rng = kwargs.get("rng", None)
    if rng is None:
        msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
               "argument. That is, instead of `apply_fun(params, inputs)`, call "
               "it like `apply_fun(params, inputs, rng)` where `rng` is a "
               "jax.random.PRNGKey value.")
        raise ValueError(msg)
    if mode == "train":
        hrate = kwargs.get("bparam") + rate
        _, key = random.split(random.PRNGKey(0))
        keep = random.bernoulli(key, hrate, inputs.shape)
        return np.where(keep, inputs, 0)
    else:
        return inputs
def dropout(x, key, keep_rate):
    """Implement a dropout layer.

    Arguments:
      x: np array to be dropped out
      key: random.PRNGKey for random bits
      keep_rate: probability of keeping each element
    Returns:
      np array of dropped out x
    """
    # The shenanigans with np.where are to avoid having to re-jit if
    # keep rate changes.
    do_keep = random.bernoulli(key, keep_rate, x.shape)
    kept_rates = np.where(do_keep, x / keep_rate, 0.0)
    return np.where(keep_rate < 1.0, kept_rates, x)
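# Hedged usage sketch for `dropout` above. It assumes the surrounding module
# imports `jax.numpy as np` and `jax.random as random`; the final np.where on
# `keep_rate` means keep_rate == 1.0 simply returns x unchanged.
import jax.numpy as np
from jax import random

x = np.ones((2, 4))
key = random.PRNGKey(0)
y_train = dropout(x, key, keep_rate=0.5)  # roughly half the entries zeroed, the rest scaled by 2
y_eval = dropout(x, key, keep_rate=1.0)   # returned unchanged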
def __call__(self, x: JaxArray, training: bool,
             dropout_keep: Optional[float] = None) -> JaxArray:
    """Performs dropout of input tensor.

    Args:
        x: input tensor.
        training: if True then apply dropout to the input, otherwise keep
            input tensor unchanged.
        dropout_keep: optional argument, when set overrides dropout keep
            probability.

    Returns:
        Tensor with applied dropout.
    """
    keep = dropout_keep or self.keep
    if not training or keep >= 1:
        return x
    keep_mask = jr.bernoulli(self.keygen(), keep, x.shape)
    return jn.where(keep_mask, x / keep, 0)
def sample(self, curr_state, model_args, model_kwargs):
    """
    Run Hop from the given :data:`~numpyro.infer.hop.HopState` and return
    the resulting :data:`~numpyro.infer.hop.HopState`.

    :param HopState curr_state: Represents the current state.
    :param model_args: Arguments provided to the model.
    :param model_kwargs: Keyword arguments provided to the model.
    :return: Next `hop_state` after running Hop.
    """
    def proposal_dist(z, g):
        g = -self._preconditioner.flatten(g)
        dim = jnp.size(g)
        rho2 = jnp.clip(jnp.dot(g, g), a_min=1.0)
        covar = (self._mu2 * jnp.eye(dim)
                 + self._lam2_minus_mu2 * jnp.outer(g, g) / jnp.dot(g, g)) / rho2
        return dist.MultivariateNormal(loc=self._preconditioner.flatten(z),
                                       covariance_matrix=covar)

    def proposal_density(dist, z):
        return dist.log_prob(self._preconditioner.flatten(z))

    itr, curr_z, curr_pe, curr_grad, num_steps, _, rng_key = curr_state
    rng_key, rng_key_hop, rng_key_ar = random.split(rng_key, 3)

    curr_to_prop = proposal_dist(curr_z, curr_grad)
    prop_z = self._preconditioner.unflatten(curr_to_prop.sample(rng_key_hop))
    prop_pe, prop_grad = value_and_grad(self._potential_fn)(prop_z)
    prop_to_curr = proposal_dist(prop_z, prop_grad)

    log_accept_ratio = (-prop_pe + curr_pe
                        + proposal_density(prop_to_curr, curr_z)
                        - proposal_density(curr_to_prop, prop_z))
    accept_prob = to_accept_prob(log_accept_ratio)

    transition = random.bernoulli(rng_key_ar, accept_prob)
    next_z, next_pe, next_grad = cond(transition,
                                      (prop_z, prop_pe, prop_grad), identity,
                                      (curr_z, curr_pe, curr_grad), identity)

    return HState(itr + 1, next_z, next_pe, next_grad, num_steps, accept_prob,
                  rng_key)
def __call__(self, x, training: bool, rng: PRNGKey = None):
    """Applies a random dropout mask to the input.

    Args:
      x: the inputs that should be randomly masked.
      training: if true, the inputs are masked and scaled by `1 / (1 - rate)`;
        if false, no mask is applied and the inputs are returned as is.
      rng: an optional `jax.random.PRNGKey`. By default `self.make_rng('dropout')`
        will be used.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    if self.rate == 0. or not training:
        return x
    keep_prob = 1. - self.rate
    if rng is None:
        rng = self.make_rng('dropout')
    mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
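# A small numerical check of the rescaling used above, independent of Flax:
# with keep_prob = 1 - rate, dividing kept activations by keep_prob leaves the
# expected value of each unit unchanged.
import jax.numpy as jnp
from jax import random

rate = 0.3
keep_prob = 1. - rate
x = jnp.ones((100_000,))
mask = random.bernoulli(random.PRNGKey(0), p=keep_prob, shape=x.shape)
out = jnp.where(mask, x / keep_prob, 0.)
print(out.mean())  # close to 1.0, the mean of x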
def dropout(inputs, *args, **kwargs):
    if len(args) == 1:
        rng = args[0]
    else:
        rng = kwargs.get('rng', None)
    if rng is None:
        msg = ("dropout requires to be called with a PRNG key argument. "
               "That is, instead of `dropout(params, inputs)`, "
               "call it like `dropout(inputs, key)` "
               "where `key` is a jax.random.PRNGKey value.")
        raise ValueError(msg)
    if mode == 'train':
        keep = random.bernoulli(rng, rate, inputs.shape)
        return np.where(keep, inputs / rate, 0)
    else:
        return inputs
def _discrete_gibbs_proposal_body_fn(z_init_flat, unravel_fn, pe_init,
                                     potential_fn, idx, i, val):
    rng_key, z, pe, log_weight_sum = val
    rng_key, rng_transition = random.split(rng_key)
    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
    z_new_flat = ops.index_update(z_init_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_weight_new = pe_init - pe_new
    # Handles the NaN case...
    log_weight_new = jnp.where(jnp.isfinite(log_weight_new), log_weight_new, -jnp.inf)
    # transition_prob = e^weight_new / (e^weight_logsumexp + e^weight_new)
    transition_prob = expit(log_weight_new - log_weight_sum)
    z, pe = cond(random.bernoulli(rng_transition, transition_prob),
                 (z_new, pe_new), identity,
                 (z, pe), identity)
    log_weight_sum = jnp.logaddexp(log_weight_new, log_weight_sum)
    return rng_key, z, pe, log_weight_sum
def _hmc_next(
    step_size,
    inverse_mass_matrix,
    vv_state,
    model_args,
    model_kwargs,
    rng_key,
    trajectory_length,
):
    if potential_fn_gen:
        nonlocal vv_update, forward_mode_ad
        pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

    # no need to spend too many steps if the state z has 0 size (i.e. z is empty)
    if len(inverse_mass_matrix) == 0:
        num_steps = 1
    else:
        num_steps = _get_num_steps(step_size, trajectory_length)

    # makes sure trajectory length is constant, rather than step_size * num_steps
    step_size = trajectory_length / num_steps
    vv_state_new = fori_loop(
        0,
        num_steps,
        lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
        vv_state,
    )
    energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
    energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix,
                                                            vv_state_new.r)
    delta_energy = energy_new - energy_old
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > max_delta_energy
    transition = random.bernoulli(rng_key, accept_prob)
    vv_state, energy = cond(
        transition,
        (vv_state_new, energy_new),
        identity,
        (vv_state, energy_old),
        identity,
    )
    return vv_state, energy, num_steps, accept_prob, diverging
def mask_uniform(inputs, rate, rng, mask_value):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      rate: the probability of masking out a value.
      rng: a `jax.random.PRNGKey` used to draw the mask.
      mask_value: value to mask with.

    Returns:
      The masked inputs.
    """
    if rate == 0.:
        return inputs
    keep_prob = 1. - rate
    mask = jrandom.bernoulli(rng, p=keep_prob, shape=inputs.shape)
    return lax.select(mask, inputs, jnp.full_like(inputs, mask_value))
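# Hedged usage sketch: filling masked positions with a large negative value, a
# common use of this kind of helper before a softmax. `mask_uniform` refers to
# the function above and is assumed to be in scope along with its imports
# (jnp, lax, jrandom).
import jax.numpy as jnp
from jax import random as jrandom

logits = jnp.zeros((2, 5))
masked = mask_uniform(logits, rate=0.2, rng=jrandom.PRNGKey(0), mask_value=-1e9)
# roughly 20% of the entries become -1e9; the rest keep their original value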
def drop_path(x: jnp.array, drop_rate: float = 0., rng=None) -> jnp.array:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc
    networks, however, the original name is misleading as 'Drop Connect' is a
    different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_rate == 0.:
        return x
    keep_prob = 1. - drop_rate
    if rng is None:
        rng = make_rng()
    mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    mask = jnp.broadcast_to(mask, x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
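# Hedged sketch of the per-sample behaviour above: the Bernoulli mask has shape
# (batch, 1, 1, 1), so an entire example's residual branch is either kept and
# rescaled or zeroed. This standalone version takes the rng explicitly instead
# of calling make_rng().
import jax.numpy as jnp
from jax import lax, random

def drop_path_sketch(x, drop_rate, rng):
    keep_prob = 1. - drop_rate
    mask = random.bernoulli(rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    mask = jnp.broadcast_to(mask, x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))

y = drop_path_sketch(jnp.ones((8, 4, 4, 3)), drop_rate=0.25, rng=random.PRNGKey(0))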
def _discrete_modified_gibbs_proposal(rng_key, z_discrete, pe, potential_fn,
                                      idx, support_size, stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(_discrete_gibbs_proposal_body_fn, z_discrete_flat,
                      unravel_fn, pe, potential_fn, idx)
    # like gibbs_step but here, weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1,
                                                       body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    z_new, pe_new = cond(random.bernoulli(rng_stay, stay_prob),
                         (z_discrete, pe), identity,
                         (z_new, pe_new), identity)
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(
        jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio