Example #1
 def _call(self, x, training=True, rng=None):
     info = self.info
     if training:
         if rng is None:
             raise ValueError('rng is required when training is True')
         # Using tie_in to avoid materializing constants
         keep = primitive.tie_in(x,
                                 random.bernoulli(rng, info.rate, x.shape))
         return np.where(keep, x / info.rate, 0)
     else:
         return x
Example #2
 def _get_inputs(cls, out_logits, test_shape, train_shape):
     key = random.PRNGKey(0)
     key, split = random.split(key)
     x_train = random.normal(split, train_shape)
     key, split = random.split(key)
     y_train = np.array(
         random.bernoulli(split, shape=(train_shape[0], out_logits)),
         np.float32)
     key, split = random.split(key)
     x_test = random.normal(split, test_shape)
     return key, x_test, x_train, y_train
Example #3
  def testBernoulli(self, p, dtype):
    key = random.PRNGKey(0)
    p = np.array(p, dtype=dtype)
    rand = lambda key, p: random.bernoulli(key, p, (10000,))
    crand = api.jit(rand)

    uncompiled_samples = rand(key, p)
    compiled_samples = crand(key, p)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)
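A self-contained sketch of the kind of goodness-of-fit check `_CheckChiSquared` performs, using `scipy.stats.chisquare` directly; the helper name and the 0.01 threshold below are assumptions, not the test suite's actual implementation:

import scipy.stats
from jax import random

def check_bernoulli_fit(samples, p, alpha=0.01):
    # Compare observed success/failure counts against the expected
    # Bernoulli(p) counts with a chi-squared test.
    n = samples.size
    successes = int(samples.sum())
    observed = [n - successes, successes]
    expected = [n * (1 - p), n * p]
    _, p_value = scipy.stats.chisquare(observed, expected)
    assert p_value > alpha, 'chi-squared test rejected Bernoulli(%s) fit' % p

samples = random.bernoulli(random.PRNGKey(0), 0.3, (10000,))
check_bernoulli_fit(samples, 0.3)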
Example #4
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key,
                              size,
                              minval=lower_bound,
                              maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1], )), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key,
                           p=np.ones((n, )) / n,
                           n=constraint.upper_bound,
                           shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1))
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = np.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #5
    def __call__(self, x, deterministic=False, rng=None):
        if self.rate == 0.:
            return x
        keep_prob = 1. - self.rate

        if deterministic:
            return x
        else:
            if rng is None:
                rng = self.scope.make_rng('dropout')
            mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
            return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
Example #6
    def step(self):
        self.update_agent_states()  # Update the counts of the various states
        self.update_counts()
        if self.model.log_file:
            self.model.write_logs()  # Write some logs, if enabled

        # NB: At some point, it may make sense to make agents move. For now, they're static.

        ############################# Transmission logic ###################################
        # Get all the currently contagious agents, and have them infect new agents.
        # TODO: include hospital transmission, vary transmissibility by state.
        contagious = np.asarray(
            (self.model.epidemic_state == self.model.STATE_PRESYMPTOMATIC)
            | (self.model.epidemic_state == self.model.STATE_ASYMPTOMATIC)
            | (self.model.epidemic_state == self.model.STATE_SYMPTOMATIC)
        ).nonzero()

        # For each contagious person, infect some of its neighbors based on their
        # hygiene and the contagious person's social radius.
        # Use jax.random instead of numpyro here to keep these deterministic.
        # TODO: figure out a way to do this in a (more) vectorized manner. Probably
        # some sort of kernel convolution method with each radius. Should also look
        # into numpyro's scan.
        for x, y in zip(*contagious):
            radius = self.model.social_radius[x, y]
            base_isolation = self.model.base_isolation[x, y]
            nx, ny = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
                np.arange(x - radius, x + radius),
                np.arange(y - radius, y + radius))
            neighbor_inds = np.vstack([nx.ravel(), ny.ravel()])
            # Higher base_isolation leads to less infection.
            # TODO: modify isolation so that symptomatic agents isolate more.
            infection_attempts = jrand.choice(
                self.PRNGKey,
                neighbor_inds,
                shape=int(len(neighbor_inds) * (1 - base_isolation)))
            potentially_infected_hygiene = self.model.hygiene[
                infection_attempts[:, 0], infection_attempts[:, 1]]
            susceptible = self.model.epidemic_state[
                infection_attempts[:, 0],
                infection_attempts[:, 1]] == self.model.STATE_SUSCEPTIBLE

            indexer = jrand.bernoulli(self.PRNGKey,
                                      potentially_infected_hygiene.ravel(),
                                      len(infection_attempts))
            got_infected = np.zeros(self.model.epidemic_state.shape,
                                    dtype=np.bool_)
            got_infected[potentially_infected_hygiene[indexer]] = True

            # Set the date to become infectious
            self.model.epidemic_state[got_infected
                                      & susceptible] = self.model.STATE_EXPOSED
            self.model.date_infected[got_infected & susceptible] = self.time
            self.model.date_contagious[
                got_infected
                & susceptible] = self.time + self.params.EXPOSED_PERIOD
Example #7
File: vae.py Project: byzhang/d3p
def binarize(rng, batch):
    """Binarizes a batch of observations with values in [0,1] by sampling from
        a Bernoulli distribution and using the original observations as means.

    Reason: This example assumes a Bernoulli distribution for the decoder output
    and thus requires inputs to be binary values as well.

    :param rng: rng seed key
    :param batch: Batch of data with continuous values in the interval [0, 1]
    :return: binarized batch with the same dtype as ``batch``.
    """
    return random.bernoulli(rng, batch).astype(batch.dtype)
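A short usage sketch for `binarize`; the batch below is a stand-in for the [0, 1]-valued image batches the VAE example feeds in:

from jax import random

rng = random.PRNGKey(0)
rng, binarize_rng = random.split(rng)
batch = random.uniform(rng, (4, 784))           # stand-in for MNIST pixels in [0, 1]
binary_batch = binarize(binarize_rng, batch)    # 0./1. values, same dtype as batch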
Example #8
 def forward(ctx, input, p=0.5, train=False):
     assert isinstance(input, Variable)
     noise = random.bernoulli(
         key=random.PRNGKey(rd.randint(-1000000000000000000, 1000000000000000000)),
         p=p, shape=input.data.shape)
     if not train:
         noise = jnp.ones(input.data.shape)
     if p == 1:
         noise = jnp.zeros(input.data.shape)
     def np_fn(input_np, noise):
         return input_np * noise
     np_args = (input.data, noise)
     id = "Dropout"
     return np_fn, np_args, jit(np_fn)(*np_args), id
Example #9
 def apply_fun(params, inputs, rng):
     if rng is None:
         msg = (
             "Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, key)` where `key` is a "
             "jax.random.PRNGKey value.")
         raise ValueError(msg)
     if mode == 'train':
         keep = random.bernoulli(rng, rate, inputs.shape)
         return np.where(keep, inputs / rate, 0)
     else:
         return inputs
Example #10
 def apply_fun(params, inputs, is_training, **kwargs):
     rng = kwargs.get('rng', None)
     if rng is None:
         msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
                "argument. That is, instead of `apply_fun(params, inputs)`, call "
                "it like `apply_fun(params, inputs, rng)` where `rng` is a "
                "jax.random.PRNGKey value.")
         raise ValueError(msg)
     keep = random.bernoulli(rng, rate, inputs.shape)
     outs = np.where(keep, inputs / rate, 0)
     # if not training, just return inputs and discard any computation done
     out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
     return out
Example #11
  def testMaxLearningRate(self, train_shape, network, out_logits,
                          fn_and_kernel, lr_factor, momentum):

    key = random.PRNGKey(0)

    key, split = random.split(key)
    if len(train_shape) == 2:
      train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
      train_shape = (16, 8, 8, 3)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

    # Regress to an MSE loss.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
      return loss(get_params(opt_state), x_train)

    steps = 30

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')

    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size, momentum=momentum) * lr_factor
    opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                           mass=momentum)

    opt_state = opt_init(params)

    init_loss = get_loss(opt_state)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor < 1.:
      self.assertLess(loss_ratio, 0.1)
    elif lr_factor == 1:
      # At the threshold, the loss decays slowly
      self.assertLess(loss_ratio, 1.)
    if lr_factor > 2.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
Example #12
    def sample(self, state, model_args, model_kwargs):
        i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
        x_flat, unravel_fn = ravel_pytree(x)
        x_grad_flat, _ = ravel_pytree(x_grad)
        shape = jnp.shape(x_flat)
        rng_key, key_normal, key_bernoulli, key_accept = random.split(
            rng_key, 4)

        mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

        x_grad_flat_scaled = mass_sqrt_inv @ x_grad_flat if self._dense_mass else mass_sqrt_inv * x_grad_flat

        # Generate proposal y.
        z = adapt_state.step_size * random.normal(key_normal, shape)

        p = expit(-z * x_grad_flat_scaled)
        b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1., -1.)

        dx_flat = b * z
        dx_flat_scaled = mass_sqrt_inv.T @ dx_flat if self._dense_mass else mass_sqrt_inv * dx_flat

        y_flat = x_flat + dx_flat_scaled

        y = unravel_fn(y_flat)
        y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
        y_grad_flat, _ = ravel_pytree(y_grad)
        y_grad_flat_scaled = mass_sqrt_inv @ y_grad_flat if self._dense_mass else mass_sqrt_inv * y_grad_flat

        log_accept_ratio = x_pe - y_pe + jnp.sum(
            softplus(dx_flat * x_grad_flat_scaled) -
            softplus(-dx_flat * y_grad_flat_scaled))
        accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.)

        x, x_flat, pe, x_grad = jax.lax.cond(
            random.bernoulli(key_accept,
                             accept_prob), (y, y_flat, y_pe, y_grad), identity,
            (x, x_flat, x_pe, x_grad), identity)

        # do not update adapt_state after warmup phase
        adapt_state = jax.lax.cond(i < self._num_warmup,
                                   (i, accept_prob, (x, ), adapt_state),
                                   lambda args: self._wa_update(*args),
                                   adapt_state, identity)

        itr = i + 1
        n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
        mean_accept_prob = mean_accept_prob + (accept_prob -
                                               mean_accept_prob) / n

        return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob,
                             adapt_state, rng_key)
Example #13
def sample_rewards(policy,
                   discriminator,
                   initialization=False,
                   initial_state=0):
    global key
    get_new_key()
    s = bernoulli(key, p=initial_distribution[0]).astype(int)
    if initialization:
        s = initial_state
    get_new_key()
    a = bernoulli(key, p=policy[s][0]).astype(int)

    traj = []
    traj.append((s, a))
    returns = []
    returns.append(jnp.log(discriminator[s][a]))

    for i in range(traj_len - 1):
        s, a = roll_out(s, a, true_transition, policy)
        traj.append((s, a))
        returns.append(jnp.log(discriminator[s][a]))

    return jnp.array(copy.deepcopy(returns)), jnp.array(copy.deepcopy(traj))
Example #14
 def apply_fun(params, inputs, **kwargs):
     rng = kwargs.get("rng", None)
     if rng is None:
         msg = (
             "Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, key)` where `key` is a "
             "jax.random.PRNGKey value.")
         raise ValueError(msg)
     if mode == "train":
         keep = random.bernoulli(rng, 1 - rate, inputs.shape)
         return np.where(keep, inputs / (1. - rate), 0.)
     else:
         return inputs
Example #15
def sample_trajectory(policy, key, traj, init_state=False, my_s=0, init_action=False, my_a=0):
    true_transition = jnp.array([[[0.7, 0.3], [0.2, 0.8]],
                                 [[0.99, 0.01], [0.99, 0.01]]])
    initial_distribution = jnp.ones(2) / 2
    if init_state:
        s = my_s
    else:
        key = get_new_key(key)
        s = bernoulli(key, p=initial_distribution[0]).astype(int)
    if init_action:
        a = my_a
    else:
        key = get_new_key(key)
        a = bernoulli(key, p=policy[s][0]).astype(int)
    states = [s]
    actions = [a]
    # sample twice as many steps for GAE because there is no absorbing state
    traj_len = len(traj)
    for _ in range(int(traj_len*2)):
        s, a, key = roll_out(s, a, true_transition, policy, key)
        states.append(s)
        actions.append(a)
    return states, actions, key
Example #16
def _discrete_modified_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size,
                                   stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    rng_key, rng_proposal, rng_stay = random.split(rng_key, 3)
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)

    i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1)
    proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i)
    proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal)
    z_new_flat = ops.index_update(z_discrete_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_accept_ratio = pe - pe_new
    return rng_key, z_new, pe_new, log_accept_ratio
Example #17
 def apply_fun(params, inputs, **kwargs):  # pylint: disable=missing-docstring
     del params  # Unused.
     rng = kwargs.get('rng', None)
     if rng is None:
         msg = (
             'Dropout layer requires apply_fun to be called with a PRNG key '
             'argument. That is, instead of `apply_fun(params, inputs)`, call '
             'it like `apply_fun(params, inputs, key)` where `key` is a '
             'jax.random.PRNGKey value.')
         raise ValueError(msg)
     if mode == 'train':
         keep = random.bernoulli(rng, rate, inputs.shape)
         return np.where(keep, inputs / rate, 0)
     else:
         return inputs
Example #18
 def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
     num_steps = _get_num_steps(step_size, trajectory_len)
     vv_state_new = fori_loop(0, num_steps,
                              lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                              vv_state)
     energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
     energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
     delta_energy = energy_new - energy_old
     delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
     accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
     transition = random.bernoulli(rng, accept_prob)
     vv_state = cond(transition,
                     vv_state_new, lambda state: state,
                     vv_state, lambda state: state)
     return vv_state, num_steps, accept_prob
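The accept/reject pattern used above, shown in isolation as a minimal sketch with toy scalar energies and states (not the actual HMC kernel's state handling):

import jax.numpy as np
from jax import lax, random

def metropolis_accept(rng, energy_old, energy_new, state_old, state_new):
    # Accept the proposal with probability min(1, exp(-(energy_new - energy_old))).
    delta_energy = energy_new - energy_old
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.minimum(np.exp(-delta_energy), 1.0)
    accept = random.bernoulli(rng, accept_prob)
    new_state = lax.cond(accept, lambda _: state_new, lambda _: state_old, None)
    return new_state, accept_prob

state, prob = metropolis_accept(random.PRNGKey(0), 1.2, 0.8, np.array(0.), np.array(1.))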
Example #19
def bernoulli(p, size=None):
    """Sample Bernoulli random values with given shape and mean.

  Args:
    p: a float or array of floats for the mean of the random
      variables. Must be broadcast-compatible with ``size``.
    size: optional, a tuple of nonnegative integers representing the result
      shape. Must be broadcast-compatible with ``p.shape``. The default (None)
      produces a result shape equal to ``p.shape``.

  Returns:
    A random array with boolean dtype and shape given by ``size`` if ``size``
    is not None, or else ``p.shape``.
  """
    return JaxArray(
        jr.bernoulli(DEFAULT.split_key(), p=p, shape=_size2shape(size)))
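To illustrate the shape semantics the docstring describes, here is the underlying `jax.random.bernoulli` call pattern this wrapper delegates to, with an explicit key since the wrapper's internal key handling is module-specific:

import jax.numpy as jnp
from jax import random as jr

key = jr.PRNGKey(0)
per_unit = jr.bernoulli(key, jnp.array([0.1, 0.5, 0.9]))   # size=None: result shape follows p.shape -> (3,)
grid = jr.bernoulli(key, 0.3, shape=(4, 3))                # explicit shape; p broadcasts -> (4, 3), bool dtype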
Example #20
 def apply_fun(params, inputs, **kwargs):
     rng = kwargs.get("rng", None)
     if rng is None:
         msg = (
             "Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng)` where `rng` is a "
             "jax.random.PRNGKey value.")
         raise ValueError(msg)
     if mode == "train":
         hrate = kwargs.get("bparam") + rate
         _, key = random.split(random.PRNGKey(0))
         keep = random.bernoulli(key, hrate, inputs.shape)
         return np.where(keep, inputs, 0)
     else:
         return inputs
Example #21
def dropout(x, key, keep_rate):
    """Implement a dropout layer.

  Arguments:
    x: np array to be dropped out
    key: random.PRNGKey for random bits
    keep_rate: probability of keeping each element (1.0 disables dropout)

  Returns:
    np array of dropped out x
  """
    # The shenanigans with np.where are to avoid having to re-jit if
    # keep rate changes.
    do_keep = random.bernoulli(key, keep_rate, x.shape)
    kept_rates = np.where(do_keep, x / keep_rate, 0.0)
    return np.where(keep_rate < 1.0, kept_rates, x)
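A usage sketch showing why the final `np.where` matters: `keep_rate` stays a traced argument, so a single jitted function covers both the dropout case and the `keep_rate >= 1.0` pass-through without retracing (names below are illustrative):

from jax import jit, random

dropout_jit = jit(dropout)
key = random.PRNGKey(0)
x = random.normal(key, (8, 128))
y_train = dropout_jit(x, key, 0.8)    # random mask, kept values rescaled by 1/0.8
y_eval = dropout_jit(x, key, 1.0)     # keep_rate >= 1.0: returns x unchanged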
Example #22
    def __call__(self, x: JaxArray, training: bool, dropout_keep: Optional[float] = None) -> JaxArray:
        """Performs dropout of input tensor.

        Args:
            x: input tensor.
            training: if True then apply dropout to the input, otherwise keep input tensor unchanged.
            dropout_keep: optional argument, when set overrides dropout keep probability.

        Returns:
            Tensor with applied dropout.
        """
        keep = dropout_keep or self.keep
        if not training or keep >= 1:
            return x
        keep_mask = jr.bernoulli(self.keygen(), keep, x.shape)
        return jn.where(keep_mask, x / keep, 0)
Example #23
    def sample(self, curr_state, model_args, model_kwargs):
        """
        Run Hop from the given :data:`~numpyro.infer.hop.HopState` and return
        the resulting :data:`~numpyro.infer.hop.HopState`.

        :param HopState curr_state: Represents the current state.
        :param model_args: Arguments provided to the model.
        :param model_kwargs: Keyword arguments provided to the model.
        :return: Next `hop_state` after running Hop.
        """
        def proposal_dist(z, g):
            g = -self._preconditioner.flatten(g)
            dim = jnp.size(g)
            rho2 = jnp.clip(jnp.dot(g, g), a_min=1.0)
            covar = (self._mu2 * jnp.eye(dim) + self._lam2_minus_mu2 *
                     jnp.outer(g, g) / jnp.dot(g, g)) / rho2
            return dist.MultivariateNormal(loc=self._preconditioner.flatten(z),
                                           covariance_matrix=covar)

        def proposal_density(dist, z):
            return dist.log_prob(self._preconditioner.flatten(z))

        itr, curr_z, curr_pe, curr_grad, num_steps, _, rng_key = curr_state

        rng_key, rng_key_hop, rng_key_ar = random.split(rng_key, 3)

        curr_to_prop = proposal_dist(curr_z, curr_grad)

        prop_z = self._preconditioner.unflatten(
            curr_to_prop.sample(rng_key_hop))

        prop_pe, prop_grad = value_and_grad(self._potential_fn)(prop_z)
        prop_to_curr = proposal_dist(prop_z, prop_grad)

        log_accept_ratio = -prop_pe + curr_pe + \
            proposal_density(prop_to_curr, curr_z) - \
            proposal_density(curr_to_prop, prop_z)
        accept_prob = to_accept_prob(log_accept_ratio)
        transition = random.bernoulli(rng_key_ar, accept_prob)
        next_z, next_pe, next_grad = cond(transition,
                                          (prop_z, prop_pe, prop_grad),
                                          identity,
                                          (curr_z, curr_pe, curr_grad),
                                          identity)

        return HState(itr + 1, next_z, next_pe, next_grad, num_steps,
                      accept_prob, rng_key)
Example #24
 def __call__(self, x, training: bool, rng: PRNGKey = None):
     """Applies a random dropout mask to the input.
     Args:
         x: the inputs that should be randomly masked.
         training: if true the inputs are scaled by `1 / (1 - rate)` and
             masked, whereas if false, no mask is applied and the inputs are returned as is.
         rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will be used.
     Returns:
         The masked inputs reweighted to preserve mean.
     """
     if self.rate == 0. or not training:
         return x
     keep_prob = 1. - self.rate
     if rng is None:
         rng = self.make_rng('dropout')
     mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
     return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
Example #25
 def dropout(inputs, *args, **kwargs):
     if len(args) == 1:
         rng = args[0]
     else:
         rng = kwargs.get('rng', None)
         if rng is None:
             msg = (
                 "dropout requires to be called with a PRNG key argument. "
                 "That is, instead of `dropout(params, inputs)`, "
                 "call it like `dropout(inputs, key)` "
                 "where `key` is a jax.random.PRNGKey value.")
             raise ValueError(msg)
     if mode == 'train':
         keep = random.bernoulli(rng, rate, inputs.shape)
         return np.where(keep, inputs / rate, 0)
     else:
         return inputs
Example #26
def _discrete_gibbs_proposal_body_fn(z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val):
    rng_key, z, pe, log_weight_sum = val
    rng_key, rng_transition = random.split(rng_key)
    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
    z_new_flat = ops.index_update(z_init_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_weight_new = pe_init - pe_new
    # Handles the NaN case...
    log_weight_new = jnp.where(jnp.isfinite(log_weight_new), log_weight_new, -jnp.inf)
    # transition_prob = e^weight_new / (e^weight_logsumexp + e^weight_new)
    transition_prob = expit(log_weight_new - log_weight_sum)
    z, pe = cond(random.bernoulli(rng_transition, transition_prob),
                 (z_new, pe_new), identity,
                 (z, pe), identity)
    log_weight_sum = jnp.logaddexp(log_weight_new, log_weight_sum)
    return rng_key, z, pe, log_weight_sum
Example #27
    def _hmc_next(
        step_size,
        inverse_mass_matrix,
        vv_state,
        model_args,
        model_kwargs,
        rng_key,
        trajectory_length,
    ):
        if potential_fn_gen:
            nonlocal vv_update, forward_mode_ad
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)

        # no need to spend too many steps if the state z has 0 size (i.e. z is empty)
        if len(inverse_mass_matrix) == 0:
            num_steps = 1
        else:
            num_steps = _get_num_steps(step_size, trajectory_length)
        # makes sure trajectory length is constant, rather than step_size * num_steps
        step_size = trajectory_length / num_steps
        vv_state_new = fori_loop(
            0,
            num_steps,
            lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
            vv_state,
        )
        energy_old = vv_state.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(
            inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf,
                                 delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(
            transition,
            (vv_state_new, energy_new),
            identity,
            (vv_state, energy_old),
            identity,
        )
        return vv_state, energy, num_steps, accept_prob, diverging
Example #28
def mask_uniform(inputs, rate, rng, mask_value):
    """Applies a random dropout mask to the input.

  Args:
    inputs: the inputs that should be randomly masked.
    rate: the probability of masking out a value.
    rng: a `jax.random.PRNGKey` used to generate the mask.
    mask_value: Value to mask with.

  Returns:
    The masked inputs.
  """
    if rate == 0.:
        return inputs
    keep_prob = 1. - rate
    mask = jrandom.bernoulli(rng, p=keep_prob, shape=inputs.shape)
    return lax.select(mask, inputs, jnp.full_like(inputs, mask_value))
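A usage sketch, masking attention-style logits down to a large negative value; the shapes and the -1e9 fill are illustrative:

from jax import random as jrandom

rng, mask_rng = jrandom.split(jrandom.PRNGKey(0))
logits = jrandom.normal(rng, (2, 8, 8))
masked = mask_uniform(logits, rate=0.1, rng=mask_rng, mask_value=-1e9)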
Example #29
def drop_path(x: jnp.ndarray, drop_rate: float = 0., rng=None) -> jnp.ndarray:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_rate == 0.:
        return x
    keep_prob = 1. - drop_rate
    if rng is None:
        rng = make_rng()
    mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    mask = jnp.broadcast_to(mask, x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))
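A usage sketch for stochastic depth inside a residual connection; the residual wiring and NHWC shapes are illustrative, not from the original project:

from jax import random

key, drop_rng = random.split(random.PRNGKey(0))
x = random.normal(key, (16, 32, 32, 64))     # NHWC batch
branch = x * 0.5                             # stand-in for a residual branch output
out = x + drop_path(branch, drop_rate=0.2, rng=drop_rng)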
Example #30
def _discrete_modified_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size,
                                      stay_prob=0.):
    assert isinstance(stay_prob, float) and stay_prob >= 0. and stay_prob < 1
    z_discrete_flat, unravel_fn = ravel_pytree(z_discrete)
    body_fn = partial(_discrete_gibbs_proposal_body_fn,
                      z_discrete_flat, unravel_fn, pe, potential_fn, idx)
    # like gibbs_step but here, weight of the current value is 0
    init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf))
    rng_key, z_new, pe_new, log_weight_sum = fori_loop(0, support_size - 1, body_fn, init_val)
    rng_key, rng_stay = random.split(rng_key)
    z_new, pe_new = cond(random.bernoulli(rng_stay, stay_prob),
                         (z_discrete, pe), identity,
                         (z_new, pe_new), identity)
    # here we calculate the MH correction: (1 - P(z)) / (1 - P(z_new))
    # where 1 - P(z) ~ weight_sum
    # and 1 - P(z_new) ~ 1 + weight_sum - z_new_weight
    log_accept_ratio = log_weight_sum - jnp.log(jnp.exp(log_weight_sum) - jnp.expm1(pe - pe_new))
    return rng_key, z_new, pe_new, log_accept_ratio