def __init__(self, sheet):
    """Builds a single-run simulator configured from a spreadsheet.

    The wet lab and SMC sampler are sized from ``sheet.num_patients``. The
    group-testing policy is chosen by name from ``sheet.params['policy']``.
    The remaining simulator settings come from ``sheet.params``, and the
    infection-rate prior comes from ``sheet.priors``.

    Args:
      sheet: spreadsheet-like object exposing ``num_patients``, ``params``
        (a mapping holding the keys read below) and ``priors``.

    Raises:
      ValueError: if ``sheet.params['policy']`` is not 'G-MIMAX',
        'InformativeDorfman' or 'Random'.
    """
    lab = wet_lab.WetLab(num_patients=sheet.num_patients)
    # The particle budget grows linearly with the cohort: 150 per patient.
    particle_sampler = smc.SmcSampler(
        num_particles=sheet.num_patients * 150,
        resample_at_each_iteration=True)

    # Zero-argument factories, so that only the selected policy is built.
    policy_factories = {
        'G-MIMAX': lambda: group_policy.MimaxPolicy(
            forward_iterations=2, backward_iterations=1),
        'InformativeDorfman': lambda: group_policy.InformativeDorfmanPolicy(
            cut_off_high=0.95, cut_off_low=0.001),
        # NOTE(review): 'Random' is served by MezardPolicy; confirm intended.
        'Random': lambda: group_policy.MezardPolicy(),
    }
    policy_name = sheet.params['policy']
    make_policy = policy_factories.get(policy_name)
    if make_policy is None:
        raise ValueError(f'Unsupported policy {policy_name}')

    super().__init__(
        wetlab=lab,
        sampler=particle_sampler,
        policy=make_policy(),
        num_simulations=1,
        num_tests_per_cycle=sheet.params['tests per cycle'],
        max_test_cycles=sheet.params['cycles'],
        max_group_size=sheet.params['max group size'],
        prior_specificity=sheet.params['specificity'],
        prior_sensitivity=sheet.params['sensitivity'],
        prior_infection_rate=sheet.priors)
    self.sheet = sheet
Ejemplo n.º 2
0
  def __init__(self,
               workdir=None,
               wetlab=None,
               policy=None,
               sampler=None,
               cheap_sampler=None,
               num_simulations=1,
               num_tests_per_cycle=10,
               max_test_cycles=5,
               max_group_size=8,
               prior_specificity=0.97,
               prior_sensitivity=0.85,
               prior_infection_rate=0.05,
               metrics_cls=metrics.Metrics,
               export_metrics_every=5):
    """Initializes the simulation.

    Args:
      workdir: where results are stored; forwarded to ``metrics_cls``.
      wetlab: object producing test results; must expose ``num_patients``.
        Defaults to a new ``wet_lab.WetLab()`` for each simulator.
      policy: group testing policy. Defaults to a new
        ``group_policy.Dorfman()`` for each simulator.
      sampler: posterior-approximation sampler. Defaults to a new
        ``smc.SmcSampler()`` for each simulator.
      cheap_sampler: cheaper approximate sampler. Defaults to a new
        ``lbp.LbpSampler()`` for each simulator.
      num_simulations: number of simulations to run.
      num_tests_per_cycle: number of tests available per testing cycle.
      max_test_cycles: total number of testing cycles.
      max_group_size: maximal number of individuals pooled in one test.
      prior_specificity: prior guess of the test's specificity.
      prior_sensitivity: prior guess of the test's sensitivity.
      prior_infection_rate: prior guess of the infection probability.
      metrics_cls: class instantiated to store results.
      export_metrics_every: how often metrics are exported.
    """
    # Build default collaborators per instance. Default-argument expressions
    # are evaluated only once, when the function is defined, so instance
    # defaults would be shared (with any state they accumulate) by every
    # simulator created without explicit arguments.
    if wetlab is None:
      wetlab = wet_lab.WetLab()
    if policy is None:
      policy = group_policy.Dorfman()
    if sampler is None:
      sampler = smc.SmcSampler()
    if cheap_sampler is None:
      cheap_sampler = lbp.LbpSampler()

    self._wetlab = wetlab
    self._policy = policy
    # The cheap sampler is stored first, then the (presumably costlier) one.
    self._samplers = [cheap_sampler, sampler]

    self._max_test_cycles = max_test_cycles
    self._num_simulations = num_simulations
    self.state = state.State(self._wetlab.num_patients,
                             num_tests_per_cycle,
                             max_group_size,
                             prior_infection_rate,
                             prior_specificity,
                             prior_sensitivity)
    self.metrics = metrics_cls(
        workdir,
        self._num_simulations,
        self._max_test_cycles,
        self._wetlab.num_patients,
        self.state.num_tests_per_cycle)
    self._export_every = export_metrics_every
Ejemplo n.º 3
0
 def test_selector_with_particles(self, selector):
     """Checks that `selector` queues groups once the state holds particles."""
     smc_sampler = sequential_monte_carlo.SmcSampler()
     sample_key, select_key = jax.random.split(self.rng, 2)

     # Seed the state's particles from a fresh SMC sample.
     smc_sampler.produce_sample(sample_key, self.state)
     self.state.update_particles(smc_sampler)

     # Nothing is queued before the selector runs...
     self.assertEqual(np.size(self.state.groups_to_test), 0)
     selector(select_key, self.state)
     # ...and at least one group is queued afterwards.
     self.assertGreater(np.size(self.state.groups_to_test), 0)
Ejemplo n.º 4
0
 def test_sequential_monte_carlo(self, kernel):
   """Checks SmcSampler fills particles/weights only after sampling."""
   n_particles = 100
   smc_sampler = sequential_monte_carlo.SmcSampler(
       num_particles=n_particles, kernel=kernel)
   tracked = ('particles', 'particle_weights')

   # Before the first sample, neither attribute is populated.
   for attribute in tracked:
     self.assertIsNone(getattr(smc_sampler, attribute))

   smc_sampler.produce_sample(self.rng, self.state)

   # After sampling, both are set and particles is (particles, patients).
   for attribute in tracked:
     self.assertIsNotNone(getattr(smc_sampler, attribute))
   self.assertEqual(smc_sampler.particles.shape,
                    (n_particles, self.state.num_patients))
Ejemplo n.º 5
0
 def __init__(self, sheet, num_particles=10000, policy=None):
     """Builds a single-run simulator configured from a spreadsheet.

     Args:
       sheet: spreadsheet-like object exposing ``num_patients``, ``params``
         (a mapping holding the keys read below) and ``priors``.
       num_particles: number of particles given to the SMC sampler.
       policy: group-testing policy; a new ``group_policy.MimaxPolicy()`` is
         used when it is None.
     """
     lab = wet_lab.WetLab(num_patients=sheet.num_patients)
     particle_sampler = smc.SmcSampler(num_particles=num_particles,
                                       resample_at_each_iteration=True)
     chosen_policy = group_policy.MimaxPolicy() if policy is None else policy

     super().__init__(
         wetlab=lab,
         sampler=particle_sampler,
         policy=chosen_policy,
         num_simulations=1,
         num_tests_per_cycle=sheet.params['tests per cycle'],
         max_test_cycles=sheet.params['cycles'],
         max_group_size=sheet.params['max group size'],
         prior_specificity=sheet.params['specificity'],
         prior_sensitivity=sheet.params['sensitivity'],
         prior_infection_rate=sheet.priors)
     self.sheet = sheet
Ejemplo n.º 6
0
    def __init__(self,
                 workdir=None,
                 wetlab=None,
                 policy=None,
                 sampler=None,
                 cheap_sampler=None,
                 num_simulations=1,
                 num_tests_per_cycle=10,
                 max_test_cycles=5,
                 max_group_size=8,
                 prior_specificity=0.97,
                 prior_sensitivity=0.85,
                 prior_infection_rate=0.05,
                 metrics_cls=metrics.Metrics,
                 export_metrics_every=5):
        """Initializes simulation.

        Args:
          workdir: where results will be stored; forwarded to ``metrics_cls``.
          wetlab: WetLab object tasked with producing test results given
            groups. Defaults to a new ``wet_lab.WetLab()`` per simulator.
          policy: group testing policy, a sequence of algorithms tasked with
            choosing groups to test. It can adapt to the test environment
            (specificity/sensitivity of tests) and to previously tested
            groups, and it can leverage samplers to estimate the most likely
            disease status of each patient. Defaults to a new
            ``group_policy.Dorfman()`` per simulator.
          sampler: sampler that produces a posterior approximation, meant to
            fix cases where the LBP sampler does not quite work the way it
            should. Defaults to a new ``smc.SmcSampler()`` per simulator.
            NOTE(review): the default is built without an explicit
            ``resample_at_each_iteration=True``; confirm SmcSampler's own
            default if per-iteration resampling is required.
          cheap_sampler: LBP object that yields a cheap approximation of the
            marginals. Defaults to a new ``lbp.LbpSampler()`` per simulator.
          num_simulations: number of simulations run consecutively. Here
            randomness can come from the group testing policy (if it uses
            randomness), as well as from the diseased labels, if
            WetLab.freeze_diseased is False.
          num_tests_per_cycle: number of tests a testing machine can carry out
            in the next testing cycle.
          max_test_cycles: number of cycles in total that should be considered.
          max_group_size: maximal size of groups, i.e. how many individuals can
            be pooled.
          prior_specificity: best guess one has, prior to simulation, of the
            testing device's specificity per group size (or for all sizes, if
            singleton).
          prior_sensitivity: best guess one has, prior to simulation, of the
            testing device's sensitivity per group size (or for all sizes, if
            singleton).
          prior_infection_rate: best guess of the prior probability for a
            patient to be infected (same for all patients, if singleton).
          metrics_cls: class of the metrics object used to store results.
          export_metrics_every: frequency of exports to file when carrying out
            num_simulations simulations.
        """
        # Build default collaborators per instance. Default-argument
        # expressions are evaluated only once, when the function is defined,
        # so instance defaults would be shared (with any state they
        # accumulate) by every simulator created without explicit arguments.
        if wetlab is None:
            wetlab = wet_lab.WetLab()
        if policy is None:
            policy = group_policy.Dorfman()
        if sampler is None:
            sampler = smc.SmcSampler()
        if cheap_sampler is None:
            cheap_sampler = lbp.LbpSampler()

        self._wetlab = wetlab
        self._policy = policy
        # The cheap sampler is stored first, followed by the full sampler.
        self._samplers = [cheap_sampler, sampler]

        self._max_test_cycles = max_test_cycles
        self._num_simulations = num_simulations
        self.state = state.State(self._wetlab.num_patients,
                                 num_tests_per_cycle, max_group_size,
                                 prior_infection_rate, prior_specificity,
                                 prior_sensitivity)
        self.metrics = metrics_cls(workdir, self._num_simulations,
                                   self._max_test_cycles,
                                   self._wetlab.num_patients,
                                   self.state.num_tests_per_cycle)
        self._export_every = export_metrics_every