Exemplo n.º 1
0
    def __call__(self, rng, particles, rho, log_posterior_params):
        """Performs one Metropolis-Hastings move step on the particles.

        Carries out procedure 4 in https://arxiv.org/pdf/1101.6037.pdf.
        Assumes fit_model has been called right before, so that the proposal
        model is stored in self.model.

        Args:
         rng: np.ndarray<int> random key.
         particles: np.ndarray [n_particles,n_patients] plausible infections
           states.
         rho: float, scaling for posterior.
         log_posterior_params: Dict of parameters to compute log-posterior.

        Returns:
         A np.ndarray representing the new particles.
        """
        key_proposal, key_accept = jax.random.split(rng, 2)
        num_particles = particles.shape[0]

        # Draw candidate states from the fitted model, along with proposal
        # log-densities of the candidates and of the current particles.
        candidates, logq_candidates, logq_current = self.sample_from_model(
            key_proposal, particles)

        # Tempered posterior log-densities of current and candidate states.
        ll_current = rho * bayes.log_posterior(particles,
                                               **log_posterior_params)
        ll_candidates = rho * bayes.log_posterior(candidates,
                                                  **log_posterior_params)

        # Metropolis-Hastings log acceptance ratio (target ratio times
        # reverse/forward proposal ratio).
        log_ratio = (ll_candidates - ll_current +
                     logq_current - logq_candidates)
        accept_prob = np.minimum(np.exp(log_ratio), 1)
        accept = (jax.random.uniform(key_accept, shape=(num_particles,)) <
                  accept_prob)
        reject = np.logical_not(accept)
        # Keep the candidate where accepted, the original particle otherwise.
        return (accept[:, np.newaxis] * candidates +
                reject[:, np.newaxis] * particles)
Exemplo n.º 2
0
 def gibbs_loop(i, rng_particles_log_posteriors):
     """One Gibbs/MH coordinate update, shaped for use in jax.lax.fori_loop.

     Flips coordinate ``i % num_patients`` of every particle and accepts or
     rejects each flip with a Metropolis-Hastings test.  Relies on closure
     variables from the enclosing scope: ``num_patients``, ``rho``,
     ``log_posterior_params`` and ``liu_modification``.

     Args:
      i: int, fori_loop iteration counter.
      rng_particles_log_posteriors: list carry [rng, particles,
        log_posteriors].
        NOTE(review): log_posteriors is assumed to be already scaled by
        ``rho``, consistent with log_posteriors_flipped_at_i below — confirm
        against the caller that initializes the carry.

     Returns:
      The updated carry [rng, particles, log_posteriors].
     """
     rng, particles, log_posteriors = rng_particles_log_posteriors
     # Cycle over coordinates: several fori_loop iterations make one sweep.
     i = i % num_patients
     # flip values at index i
     particles_flipped = jax.ops.index_update(
         particles, jax.ops.index[:, i], np.logical_not(particles[:, i]))
     # compute log_posterior of flipped particles
     log_posteriors_flipped_at_i = rho * bayes.log_posterior(
         particles_flipped, **log_posterior_params)
     # compute acceptance probability, depending on whether we use Liu mod.
     # Liu: min(1, p(flip)/p(no flip)); vanilla: p(flip)/(p(flip)+p(no flip)).
     if liu_modification:
         log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
     else:
         log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
             log_posteriors_flipped_at_i, log_posteriors)
     # here the MH thresholding is implicitly done: comparing log(U) with the
     # log-ratio clips the acceptance probability at 1 for free.
     rng, rng_unif = jax.random.split(rng, 2)
     random_values = jax.random.uniform(rng_unif, particles.shape[:1])
     flipped_at_i = np.log(random_values) < log_proposal_ratio
     # XOR with the current value yields the flipped bit exactly where the
     # flip was accepted, and the unchanged bit elsewhere.
     selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
     particles = jax.ops.index_update(particles, jax.ops.index[:, i],
                                      selected_at_i)
     # Keep log-posteriors in sync with the accepted flips.
     log_posteriors = np.where(flipped_at_i, log_posteriors_flipped_at_i,
                               log_posteriors)
     return [rng, particles, log_posteriors]
Exemplo n.º 3
0
 def produce_sample(self, rng, state):
     """Enumerates all candidate infection states and weights them.

     Builds every binary vector up to self.upper_bound infections, scores
     each against the posterior defined by the state's past test results,
     and stores importance weights.  The rng argument is kept for interface
     compatibility; enumeration is exhaustive, so no randomness is used.
     """
     candidates = all_binary_vectors(state.num_patients, self.upper_bound)
     self.particles = candidates
     log_posteriors = bayes.log_posterior(
         candidates,
         state.past_test_results,
         state.past_groups,
         state.log_prior_specificity,
         state.log_prior_1msensitivity,
         state.prior_infection_rate)
     self.particle_weights = temperature.importance_weights(log_posteriors)
 def resample_move(self, rng, particles, state):
     """Resample / Move sequence.

     Anneals the tempering exponent rho from its first adaptive step up to 1,
     alternating resampling (according to importance weights) with MH move
     steps, as in tempered SMC samplers.

     Args:
      rng: np.ndarray<int> random key.
      particles: np.ndarray [n_particles, n_patients] particle states.
      state: object providing log_posterior_params
        (NOTE(review): presumably the simulator state — confirm its type).

     Returns:
      A (particle_weights, particles) tuple for the final temperature rho=1.
     """
     log_posterior_params = state.log_posterior_params
     log_posterior = bayes.log_posterior(particles, **log_posterior_params)
     # First adaptive temperature increment, starting from rho = 0.
     alpha, logpialphaparticles = temperature.find_step_length(
         0, log_posterior)
     rho = alpha
     particle_weights = temperature.importance_weights(logpialphaparticles)
     while rho < 1:
         # Progress indicator; '\r' keeps it on a single console line.
         print(f'Sampling {rho:.0%}', end='\r')
         rng, *rngs = jax.random.split(rng, 3)
         # Fit the move kernel's proposal model BEFORE resampling/moving.
         self._kernel.fit_model(particle_weights, particles)
         # Resample particle indices according to their weights.
         particles = particles[self.resample(rngs[0], particle_weights), :]
         # Move step at the current temperature rho.
         particles = self.move(rngs[1], particles, rho,
                               log_posterior_params)
         # Next adaptive temperature increment and associated weights.
         alpha, logpialphaparticles = temperature.find_step_length(
             rho, bayes.log_posterior(particles, **log_posterior_params))
         particle_weights = temperature.importance_weights(
             logpialphaparticles)
         rho = rho + alpha
     return particle_weights, particles
Exemplo n.º 5
0
def gibbs_kernel(rng,
                 particles,
                 rho,
                 log_posterior_params,
                 cycles=2,
                 liu_modification=True):
    """Applies a (Liu modified) Gibbs kernel (with MH) update.

  Implements vanilla (sequential, looping over coordinates) Gibbs sampling.
  The Liu variant comes from Jun Liu's remarks in
  https://academic.oup.com/biomet/article-abstract/83/3/681/241540?redirectedFrom=fulltext

  which essentially changes the acceptance of a flip from
    p(flip) / [ p(no flip) + p(flip) ]
  to
    min(1, p(flip) / p(no flip) )

  In other words, Liu's modification increases the probability to flip.

  Args:
   rng: np.ndarray<int> random key.
   particles: np.ndarray [n_particles,n_patients] plausible infections states.
   rho: float, scaling for posterior.
   log_posterior_params: Dict of parameters to compute log-posterior.
   cycles: the number of times we want of do Gibbs sampling.
   liu_modification : use or not Liu's modification.

  Returns:
   A np.array representing the new particles.
  """
    def gibbs_loop(i, rng_particles_log_posteriors):
        """One coordinate flip + MH accept/reject; fori_loop body."""
        rng, particles, log_posteriors = rng_particles_log_posteriors
        # Cycle over coordinates: num_patients iterations make one sweep.
        i = i % num_patients
        # flip values at index i
        particles_flipped = jax.ops.index_update(
            particles, jax.ops.index[:, i], np.logical_not(particles[:, i]))
        # compute log_posterior of flipped particles (tempered by rho)
        log_posteriors_flipped_at_i = rho * bayes.log_posterior(
            particles_flipped, **log_posterior_params)
        # compute acceptance probability, depending on whether we use Liu mod.
        if liu_modification:
            log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
        else:
            log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
                log_posteriors_flipped_at_i, log_posteriors)
        # here the MH thresholding is implicitly done: comparing log(U) with
        # the log-ratio clips the acceptance probability at 1 for free.
        rng, rng_unif = jax.random.split(rng, 2)
        random_values = jax.random.uniform(rng_unif, particles.shape[:1])
        flipped_at_i = np.log(random_values) < log_proposal_ratio
        # XOR gives the flipped bit where accepted, the original bit elsewhere.
        selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
        particles = jax.ops.index_update(particles, jax.ops.index[:, i],
                                         selected_at_i)
        # Keep carried log-posteriors in sync with the accepted flips.
        log_posteriors = np.where(flipped_at_i, log_posteriors_flipped_at_i,
                                  log_posteriors)
        return [rng, particles, log_posteriors]

    num_patients = particles.shape[1]
    # Scale the initial carry by rho so it is tempered consistently with
    # log_posteriors_flipped_at_i inside the loop; without the rho factor the
    # first sweep would compare tempered vs. untempered log-densities
    # whenever rho != 1.
    log_posteriors = rho * bayes.log_posterior(particles,
                                               **log_posterior_params)
    rng_particles = jax.lax.fori_loop(0, cycles * num_patients, gibbs_loop,
                                      [rng, particles, log_posteriors])
    # TODO(cuturi) : might be relevant to forward log_posterior_particles
    return rng_particles[1]