Ejemplo n.º 1
0
 def resample_move(self, rng, particles, state, sampling_from_scratch):
     """Resample / Move sequence driven by an increasing temperature `rho`.

     The particles are first scored at temperature 0. `temperature.find_step_length`
     picks the first increment, which becomes `rho`. While `rho < 1`, each round
     does the following:
       * fits the move kernel on the current weights/particles;
       * resamples particle rows by weight;
       * moves them;
       * re-scores them to obtain the next increment and weights.

     Args:
       rng: JAX PRNG key; split into three once per round (next key,
         resampling key, move key).
       particles: array of particles, indexed by row (`particles[idx, :]`).
       state: object exposing `log_posterior_params(...)` and
         `log_base_measure_params(...)`; both are called with
         (`sampling_from_scratch`, `self._start_from_prior`,
         `self._sampled_up_to`).
       sampling_from_scratch: forwarded to the two `state` accessors.
         Presumably selects whether only the latest wave of tests or the
         full prior information is recovered; confirm in `state`.

     Returns:
       `(particle_weights, particles)` as of the last reweighting.

     A progress line `Sampling NN%` terminated by '\r' is printed once
     before the loop and again at the start of every round.
     NOTE(review): termination relies on `find_step_length` returning
     increments that eventually bring `rho` to >= 1; confirm.
     """
     # Both parameter dictionaries are derived from the same three inputs.
     accessor_args = (sampling_from_scratch, self._start_from_prior,
                      self._sampled_up_to)
     posterior_kwargs = state.log_posterior_params(*accessor_args)
     base_measure_kwargs = state.log_base_measure_params(*accessor_args)

     def reweight(current_rho, current_particles):
         # Score the particles under the log posterior, choose the next
         # temperature increment from `current_rho`, and convert the tempered
         # log probabilities to importance weights.
         scores = bayes.log_probability(current_particles, **posterior_kwargs)
         increment, tempered_scores = temperature.find_step_length(
             current_rho, scores)
         return increment, temperature.importance_weights(tempered_scores)

     rho, weights = reweight(0, particles)
     print(f'Sampling {rho:.0%}', end='\r')
     while rho < 1:
         print(f'Sampling {rho:.0%}', end='\r')
         rng, resample_key, move_key = jax.random.split(rng, 3)
         self._kernel.fit_model(weights, particles)
         kept_rows = self.resample(resample_key, weights)
         particles = particles[kept_rows, :]
         particles = self.move(move_key, particles, rho,
                               posterior_kwargs, base_measure_kwargs)
         step, weights = reweight(rho, particles)
         rho = rho + step
     return weights, particles
Ejemplo n.º 2
0
 def produce_sample(self, rng, state):
     """Build the particle set and its importance weights.

     Sets `self.particles` to `all_binary_vectors(state.num_patients,
     self.upper_bound)`. Presumably this is every binary infection vector
     with at most `upper_bound` ones; confirm against that helper.

     Sets `self.particle_weights` to `temperature.importance_weights` of the
     particles' log posteriors. Those are computed by `bayes.log_posterior`
     from the test history and priors stored on `state`.

     `rng` is not used by this method. Returns None; results are stored on
     `self`.
     """
     candidates = all_binary_vectors(state.num_patients, self.upper_bound)
     self.particles = candidates
     # Positional order expected by `bayes.log_posterior` after the particles.
     posterior_inputs = (
         state.past_test_results,
         state.past_groups,
         state.log_prior_specificity,
         state.log_prior_1msensitivity,
         state.prior_infection_rate,
     )
     scores = bayes.log_posterior(candidates, *posterior_inputs)
     self.particle_weights = temperature.importance_weights(scores)
 def resample_move(self, rng, particles, state):
     """Resample / Move sequence driven by an increasing temperature `rho`.

     The particles are scored with `bayes.log_posterior` using the keyword
     arguments in `state.log_posterior_params`. `temperature.find_step_length`
     picks the first increment, which becomes `rho`. While `rho < 1`, each
     round does the following:
       * prints a '\r'-terminated `Sampling NN%` progress line;
       * splits the key;
       * fits the move kernel, resamples rows by weight and moves them;
       * re-scores the particles to obtain the next increment and weights.

     Args:
       rng: JAX PRNG key; split into three once per round.
       particles: array of particles, indexed by row (`particles[idx, :]`).
       state: object whose `log_posterior_params` attribute is a mapping of
         keyword arguments for `bayes.log_posterior`.

     Returns:
       `(particle_weights, particles)` as of the last reweighting.

     NOTE(review): termination relies on `find_step_length` returning
     increments that eventually bring `rho` to >= 1; confirm.
     """
     posterior_kwargs = state.log_posterior_params

     def reweight(current_rho, current_particles):
         # Score, pick the next temperature increment, and turn the tempered
         # log probabilities into importance weights.
         scores = bayes.log_posterior(current_particles, **posterior_kwargs)
         increment, tempered_scores = temperature.find_step_length(
             current_rho, scores)
         return increment, temperature.importance_weights(tempered_scores)

     rho, weights = reweight(0, particles)
     while rho < 1:
         print(f'Sampling {rho:.0%}', end='\r')
         rng, resample_key, move_key = jax.random.split(rng, 3)
         self._kernel.fit_model(weights, particles)
         particles = particles[self.resample(resample_key, weights), :]
         particles = self.move(move_key, particles, rho, posterior_kwargs)
         step, weights = reweight(rho, particles)
         rho = rho + step
     return weights, particles