def __call__(self, rng, particles, rho, log_posterior_params):
    """Applies one accept/reject move to every particle.

    Carries out procedure 4 of https://arxiv.org/pdf/1101.6037.pdf. One
    expects that fit_model has been called right before, so that the model
    is stored in self.model.

    Args:
      rng: np.ndarray<int> random key.
      particles: np.ndarray [n_particles,n_patients] plausible infections
        states.
      rho: float, scaling for posterior.
      log_posterior_params: Dict of parameters to compute log-posterior.

    Returns:
      A np.ndarray representing the new particles.
    """
    key_proposal, key_accept = jax.random.split(rng, 2)
    num_particles = particles.shape[0]

    proposed, logprop_proposed, logprop_particles = self.sample_from_model(
        key_proposal, particles)

    # Log-posteriors scaled by rho, for the current and the proposed states.
    ll_current = rho * bayes.log_posterior(particles, **log_posterior_params)
    ll_proposed = rho * bayes.log_posterior(proposed, **log_posterior_params)

    # Target ratio corrected by the two terms returned by sample_from_model
    # (presumably the proposal log-probabilities -- confirm against the model).
    log_ratio = (
        ll_proposed - ll_current + logprop_particles - logprop_proposed)
    accept_prob = np.minimum(np.exp(log_ratio), 1)

    uniforms = jax.random.uniform(key_accept, shape=(num_particles,))
    accept = uniforms < accept_prob
    reject = np.logical_not(accept)

    # Row-wise selection: accepted rows take the proposal, the others keep
    # their current value.
    return accept[:, np.newaxis] * proposed + reject[:, np.newaxis] * particles
def gibbs_loop(i, rng_particles_log_posteriors):
    """Runs one single-coordinate Gibbs / MH update on all particles.

    The names num_patients, rho, log_posterior_params and liu_modification
    are not defined here; they are read from the enclosing scope.

    Args:
      i: loop counter; the coordinate that is updated is i % num_patients.
      rng_particles_log_posteriors: sequence [rng, particles, log_posteriors]
        holding the random key, the current particles and the log-posterior
        values the flipped candidates are compared against.

    Returns:
      The list [rng, particles, log_posteriors] after the update.
    """
    rng, particles, log_posteriors = rng_particles_log_posteriors
    i = i % num_patients

    # Flip values at index i. The functional `.at[...].set(...)` update is
    # used because jax.ops.index_update was removed from JAX.
    particles_flipped = particles.at[:, i].set(
        np.logical_not(particles[:, i]))

    # Compute the rho-scaled log-posterior of the flipped particles.
    log_posteriors_flipped_at_i = rho * bayes.log_posterior(
        particles_flipped, **log_posterior_params)

    # Compute acceptance probability, depending on whether we use Liu mod.
    if liu_modification:
        log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
    else:
        log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
            log_posteriors_flipped_at_i, log_posteriors)

    # The MH thresholding min(1, .) is implicit: log(u) <= 0 for u in [0, 1).
    rng, rng_unif = jax.random.split(rng, 2)
    random_values = jax.random.uniform(rng_unif, particles.shape[:1])
    flipped_at_i = np.log(random_values) < log_proposal_ratio

    # XOR toggles coordinate i exactly for the particles whose flip is accepted.
    selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
    particles = particles.at[:, i].set(selected_at_i)
    log_posteriors = np.where(
        flipped_at_i, log_posteriors_flipped_at_i, log_posteriors)
    return [rng, particles, log_posteriors]
def produce_sample(self, rng, state):
    """Builds candidate binary vectors and weights them by log-posterior.

    The vectors returned by all_binary_vectors are stored in self.particles
    and the matching importance weights in self.particle_weights. Nothing is
    returned.

    Args:
      rng: random key; not used by this method.
      state: object providing num_patients, past_test_results, past_groups,
        log_prior_specificity, log_prior_1msensitivity and
        prior_infection_rate.
    """
    self.particles = all_binary_vectors(state.num_patients, self.upper_bound)

    # Positional arguments of bayes.log_posterior, in the order it is called
    # with.
    posterior_args = (
        state.past_test_results,
        state.past_groups,
        state.log_prior_specificity,
        state.log_prior_1msensitivity,
        state.prior_infection_rate,
    )
    log_post = bayes.log_posterior(self.particles, *posterior_args)
    self.particle_weights = temperature.importance_weights(log_post)
def resample_move(self, rng, particles, state):
    """Runs the resample / move sequence until rho reaches 1.

    Each iteration fits the kernel model on the weighted particles, resamples
    them, moves them with self.move, and then advances rho by the step
    returned by temperature.find_step_length.

    Args:
      rng: random key; split afresh at every iteration.
      particles: array of particles, indexed by rows when resampling.
      state: object providing log_posterior_params.

    Returns:
      A tuple (particle_weights, particles) holding the last importance
      weights computed and the final particles.
    """
    log_posterior_params = state.log_posterior_params

    # First step starts from 0, so rho is simply the first step length.
    rho, tempered = temperature.find_step_length(
        0, bayes.log_posterior(particles, **log_posterior_params))
    particle_weights = temperature.importance_weights(tempered)

    while rho < 1:
        print(f'Sampling {rho:.0%}', end='\r')
        rng, key_resample, key_move = jax.random.split(rng, 3)

        # The model is fitted on the weighted particles before resampling.
        self._kernel.fit_model(particle_weights, particles)
        indices = self.resample(key_resample, particle_weights)
        particles = particles[indices, :]
        particles = self.move(key_move, particles, rho, log_posterior_params)

        step, tempered = temperature.find_step_length(
            rho, bayes.log_posterior(particles, **log_posterior_params))
        particle_weights = temperature.importance_weights(tempered)
        rho = rho + step

    return particle_weights, particles
def gibbs_kernel(rng,
                 particles,
                 rho,
                 log_posterior_params,
                 cycles=2,
                 liu_modification=True):
    """Applies a (Liu modified) Gibbs kernel (with MH) update.

    Implements vanilla (sequential, looping over coordinates) Gibbs sampling.
    The Liu variant comes from Jun Liu's remarks in
    https://academic.oup.com/biomet/article-abstract/83/3/681/241540?redirectedFrom=fulltext
    which essentially changes the acceptance of a flip from
      p(flip) / [ p(no flip) + p(flip) ]
    to
      min(1, p(flip) / p(no flip) )
    In other words, Liu's modification increases the probability to flip.

    Args:
      rng: np.ndarray<int> random key.
      particles: np.ndarray [n_particles,n_patients] plausible infections
        states.
      rho: float, scaling for posterior.
      log_posterior_params: Dict of parameters to compute log-posterior.
      cycles: the number of times we want to do Gibbs sampling, i.e. the
        number of full sweeps over all coordinates.
      liu_modification: use or not Liu's modification.

    Returns:
      A np.array representing the new particles.
    """
    num_patients = particles.shape[1]

    def gibbs_loop(i, rng_particles_log_posteriors):
        """Updates coordinate i % num_patients of every particle."""
        rng, particles, log_posteriors = rng_particles_log_posteriors
        i = i % num_patients

        # Flip values at index i. The functional `.at[...].set(...)` update
        # is used because jax.ops.index_update was removed from JAX.
        particles_flipped = particles.at[:, i].set(
            np.logical_not(particles[:, i]))

        # Compute the rho-scaled log-posterior of the flipped particles.
        log_posteriors_flipped_at_i = rho * bayes.log_posterior(
            particles_flipped, **log_posterior_params)

        # Compute acceptance probability, depending on whether we use Liu
        # mod.
        if liu_modification:
            log_proposal_ratio = log_posteriors_flipped_at_i - log_posteriors
        else:
            log_proposal_ratio = log_posteriors_flipped_at_i - np.logaddexp(
                log_posteriors_flipped_at_i, log_posteriors)

        # The MH thresholding min(1, .) is implicit: log(u) <= 0 for u in
        # [0, 1).
        rng, rng_unif = jax.random.split(rng, 2)
        random_values = jax.random.uniform(rng_unif, particles.shape[:1])
        flipped_at_i = np.log(random_values) < log_proposal_ratio

        # XOR toggles coordinate i exactly for the accepted flips.
        selected_at_i = np.logical_xor(flipped_at_i, particles[:, i])
        particles = particles.at[:, i].set(selected_at_i)
        log_posteriors = np.where(
            flipped_at_i, log_posteriors_flipped_at_i, log_posteriors)
        return [rng, particles, log_posteriors]

    # The carried log-posteriors must live on the same rho-scaled scale as the
    # flipped values they are compared with inside the loop; otherwise the
    # acceptance ratios mix tempered and untempered values when rho != 1.
    log_posteriors = rho * bayes.log_posterior(
        particles, **log_posterior_params)
    rng_particles = jax.lax.fori_loop(
        0, cycles * num_patients, gibbs_loop,
        [rng, particles, log_posteriors])
    # TODO(cuturi) : might be relevant to forward log_posterior_particles
    return rng_particles[1]