def __init__(self, sheet):
    """Configures a single-run simulator from a spreadsheet-like `sheet`.

    The wet lab and SMC sampler are sized from `sheet.num_patients`; the group
    testing policy is picked by name from `sheet.params['policy']`; the other
    simulation settings come from `sheet.params` and the infection priors from
    `sheet.priors`.

    Args:
      sheet: object exposing `num_patients`, `params` (a mapping read with keys
        'policy', 'tests per cycle', 'cycles', 'max group size', 'specificity'
        and 'sensitivity') and `priors`.

    Raises:
      ValueError: if the requested policy name is not one of the supported
        names listed below.
    """
    num_patients = sheet.num_patients
    wetlab = wet_lab.WetLab(num_patients=num_patients)
    # 150 particles per patient, resampling after every iteration.
    sampler = smc.SmcSampler(num_particles=num_patients * 150,
                             resample_at_each_iteration=True)

    # Factories are lambdas so that only the selected policy is constructed.
    supported_policies = (
        ('G-MIMAX',
         lambda: group_policy.MimaxPolicy(forward_iterations=2,
                                          backward_iterations=1)),
        ('InformativeDorfman',
         lambda: group_policy.InformativeDorfmanPolicy(cut_off_high=0.95,
                                                       cut_off_low=0.001)),
        # NOTE(review): 'Random' is served by MezardPolicy — confirm intended.
        ('Random', lambda: group_policy.MezardPolicy()),
    )

    policy_name = sheet.params['policy']
    for candidate_name, make_policy in supported_policies:
        if policy_name == candidate_name:
            policy = make_policy()
            break
    else:
        raise ValueError(f'Unsupported policy {policy_name}')

    super().__init__(
        wetlab=wetlab,
        sampler=sampler,
        policy=policy,
        num_simulations=1,
        num_tests_per_cycle=sheet.params['tests per cycle'],
        max_test_cycles=sheet.params['cycles'],
        max_group_size=sheet.params['max group size'],
        prior_specificity=sheet.params['specificity'],
        prior_sensitivity=sheet.params['sensitivity'],
        prior_infection_rate=sheet.priors)
    self.sheet = sheet
Ejemplo n.º 2
0
  def __init__(self,
               workdir=None,
               wetlab=None,
               policy=None,
               sampler=None,
               cheap_sampler=None,
               num_simulations=1,
               num_tests_per_cycle=10,
               max_test_cycles=5,
               max_group_size=8,
               prior_specificity=0.97,
               prior_sensitivity=0.85,
               prior_infection_rate=0.05,
               metrics_cls=metrics.Metrics,
               export_metrics_every=5):
    """Initializes simulation.

    `wetlab`, `policy`, `sampler` and `cheap_sampler` default to `None`, and a
    fresh default instance is built for each call. Previously the defaults were
    evaluated once, when the `def` ran, so every Simulator that omitted them
    shared the same (stateful) objects.

    Args:
      workdir: where results will be stored; forwarded to `metrics_cls`.
      wetlab: object producing test results given groups; must expose
        `num_patients`. Defaults to a new `wet_lab.WetLab()`.
      policy: group testing policy. Defaults to a new `group_policy.Dorfman()`.
      sampler: posterior sampler. Defaults to a new `smc.SmcSampler()`.
      cheap_sampler: cheaper sampler, stored before `sampler` in
        `self._samplers`. Defaults to a new `lbp.LbpSampler()`.
      num_simulations: number of simulations to run.
      num_tests_per_cycle: number of tests per testing cycle.
      max_test_cycles: total number of test cycles considered.
      max_group_size: maximal number of individuals pooled in a group.
      prior_specificity: prior guess of the test specificity.
      prior_sensitivity: prior guess of the test sensitivity.
      prior_infection_rate: prior probability of a patient being infected.
      metrics_cls: class used to build the metrics store.
      export_metrics_every: how often metrics are exported.
    """
    # Build defaults per call to avoid sharing mutable instances across
    # simulators (default argument values are evaluated only once).
    if wetlab is None:
      wetlab = wet_lab.WetLab()
    if policy is None:
      policy = group_policy.Dorfman()
    if sampler is None:
      sampler = smc.SmcSampler()
    if cheap_sampler is None:
      cheap_sampler = lbp.LbpSampler()

    self._wetlab = wetlab
    self._policy = policy
    # Order matters to consumers of _samplers: cheap sampler first.
    self._samplers = [cheap_sampler, sampler]

    self._max_test_cycles = max_test_cycles
    self._num_simulations = num_simulations
    self.state = state.State(self._wetlab.num_patients,
                             num_tests_per_cycle,
                             max_group_size,
                             prior_infection_rate,
                             prior_specificity,
                             prior_sensitivity)
    self.metrics = metrics_cls(
        workdir,
        self._num_simulations,
        self._max_test_cycles,
        self._wetlab.num_patients,
        self.state.num_tests_per_cycle)
    self._export_every = export_metrics_every
Ejemplo n.º 3
0
 def test_run(self):
   """A setup where we find a solution before the last cycle."""
   sim = simulator.Simulator(
       None, wet_lab.WetLab(num_patients=100),
       num_simulations=self.num_simulations,
       max_test_cycles=self.max_test_cycles,
       num_tests_per_cycle=4, max_group_size=5)
   sim.run(0)
   # NOTE(review): metrics.groups is presumably indexed
   # (simulation, cycle, ...) — confirm. The assertion requires the recorded
   # groups at [0, -1] to not be entirely NaN.
   last_groups = sim.metrics.groups[0, -1]
   self.assertFalse(np.all(np.isnan(last_groups)))
Ejemplo n.º 4
0
 def test_reset_unfrozen(self):
     """An unfrozen wet lab draws a different diseased vector per reset key."""
     lab = wet_lab.WetLab(num_patients=10, freeze_diseased=False)
     # Nothing is drawn before the first reset.
     self.assertIsNone(lab.diseased)

     lab.reset(self.rng)
     first_draw = np.array(lab.diseased)  # copy before the next reset

     next_key, _ = jax.random.split(self.rng)
     lab.reset(next_key)
     self.assertFalse(np.all(lab.diseased == first_draw))
Ejemplo n.º 5
0
 def test_reset_frozen(self):
     """A frozen wet lab keeps its first diseased vector across resets."""
     patient_count = 10
     lab = wet_lab.WetLab(num_patients=patient_count, freeze_diseased=True)
     # Nothing is drawn before the first reset.
     self.assertIsNone(lab.diseased)

     lab.reset(self.rng)
     self.assertIsNotNone(lab.diseased)
     first_draw = np.array(lab.diseased)  # copy before the next reset
     self.assertEqual(lab.diseased.shape[0], patient_count)

     next_key, _ = jax.random.split(self.rng)
     lab.reset(next_key)
     self.assertTrue(np.all(lab.diseased == first_draw))
Ejemplo n.º 6
0
 def setUp(self):
   """Builds a frozen wet lab and a random-selector simulator for each test."""
   super().setUp()
   self.num_patients = 10
   self.wetlab = wet_lab.WetLab(self.num_patients, freeze_diseased=True)
   self.num_simulations = 1
   self.max_test_cycles = 4
   self.policy = policy.Policy([random.RandomSelector()])
   simulator_kwargs = dict(
       workdir=None,
       wetlab=self.wetlab,
       policy=self.policy,
       num_simulations=self.num_simulations,
       max_test_cycles=self.max_test_cycles,
       num_tests_per_cycle=4,
       max_group_size=5,
   )
   self.simulator = simulator.Simulator(**simulator_kwargs)
   self.rng = jax.random.PRNGKey(0)
Ejemplo n.º 7
0
    def test_group_tests_outputs(self):
        """Checks the dtype, shape and any-positive consistency of outputs."""
        num_patients = 10
        lab = wet_lab.WetLab(num_patients=num_patients, freeze_diseased=True)
        reset_key, groups_key, tests_key = jax.random.split(self.rng, 3)
        lab.reset(reset_key)

        num_groups = 5
        # Each patient joins each group independently with probability 0.3.
        membership = jax.random.uniform(groups_key,
                                        shape=(num_groups, num_patients))
        groups = membership < 0.3
        output = lab.group_tests_outputs(tests_key, groups)

        self.assertDtypesMatch(output, lab.diseased)
        self.assertEqual(output.shape, (num_groups, ))
        # Some output is positive exactly when some patient is diseased.
        self.assertEqual(np.any(lab.diseased), np.any(output))
Ejemplo n.º 8
0
 def __init__(self, sheet, num_particles=10000, policy=None):
     """Configures a single-run simulator from a spreadsheet-like `sheet`.

     Args:
       sheet: object exposing `num_patients`, `params` (a mapping read with
         keys 'tests per cycle', 'cycles', 'max group size', 'specificity'
         and 'sensitivity') and `priors`.
       num_particles: particle count for the SMC sampler, which resamples at
         each iteration.
       policy: group testing policy; a new `group_policy.MimaxPolicy()` is
         built when `None`.
     """
     lab = wet_lab.WetLab(num_patients=sheet.num_patients)
     smc_sampler = smc.SmcSampler(num_particles=num_particles,
                                  resample_at_each_iteration=True)
     chosen_policy = group_policy.MimaxPolicy() if policy is None else policy
     params = sheet.params
     super().__init__(
         wetlab=lab,
         sampler=smc_sampler,
         policy=chosen_policy,
         num_simulations=1,
         num_tests_per_cycle=params['tests per cycle'],
         max_test_cycles=params['cycles'],
         max_group_size=params['max group size'],
         prior_specificity=params['specificity'],
         prior_sensitivity=params['sensitivity'],
         prior_infection_rate=sheet.priors)
     self.sheet = sheet
Ejemplo n.º 9
0
    def __init__(self,
                 workdir=None,
                 wetlab=None,
                 policy=None,
                 sampler=None,
                 cheap_sampler=None,
                 num_simulations=1,
                 num_tests_per_cycle=10,
                 max_test_cycles=5,
                 max_group_size=8,
                 prior_specificity=0.97,
                 prior_sensitivity=0.85,
                 prior_infection_rate=0.05,
                 metrics_cls=metrics.Metrics,
                 export_metrics_every=5):
        """Initializes simulation.

        `wetlab`, `policy`, `sampler` and `cheap_sampler` default to `None`,
        and a fresh default instance is built for each call. Previously the
        defaults were evaluated once, when the `def` ran, so every Simulator
        that omitted them shared the same (stateful) objects.

        Args:
          workdir: where results will be stored.
          wetlab: WetLab object tasked with producing test results given
            groups. Defaults to a new `wet_lab.WetLab()`.
          policy: group testing policy, a sequence of algorithms tasked with
            choosing groups to test. Can be adaptive to test environment
            (spec/sens) of tests. Can be adaptive to previously tested groups.
            Can leverage samplers to build information on what are the most
            likely disease status among patients. Defaults to a new
            `group_policy.Dorfman()`.
          sampler: sampler that produces a posterior approximation. Defaults
            to a new `smc.SmcSampler()`. The intent is an SMC sampler that
            resamples at each iteration, to fix cases where the LBP sampler
            does not quite work the way it should.
            NOTE(review): confirm `SmcSampler()`'s defaults actually resample
            at each iteration.
          cheap_sampler: LBP object to yield cheap approximation of marginal.
            Defaults to a new `lbp.LbpSampler()`.
          num_simulations: number of simulations run consecutively. Here
            randomness can come from the group testing policy (if it uses
            randomness), as well as diseased label, if
            WetLab.freeze_diseased is False.
          num_tests_per_cycle: number of tests a testing machine can carry out
            in the next testing cycle.
          max_test_cycles: number of cycles in total that should be considered.
          max_group_size: maximal size of groups, how many individuals can be
            pooled.
          prior_specificity: best guess one has prior to simulation of the
            testing device's specificity per group size (or for all sizes, if
            singleton).
          prior_sensitivity: best guess one has prior to simulation of the
            testing device's sensitivity per group size (or for all sizes, if
            singleton).
          prior_infection_rate: best guess of prior probability for patient to
            be infected (same for all patients, if singleton).
          metrics_cls: class of metrics object used to store results.
          export_metrics_every: frequency of exports to file when carrying out
            num_simulations results.
        """
        # Build defaults per call so simulators never share mutable instances
        # (default argument values are evaluated only once).
        if wetlab is None:
            wetlab = wet_lab.WetLab()
        if policy is None:
            policy = group_policy.Dorfman()
        if sampler is None:
            sampler = smc.SmcSampler()
        if cheap_sampler is None:
            cheap_sampler = lbp.LbpSampler()

        self._wetlab = wetlab
        self._policy = policy
        # Cheap sampler is stored first, the more accurate sampler second.
        self._samplers = [cheap_sampler, sampler]

        self._max_test_cycles = max_test_cycles
        self._num_simulations = num_simulations
        self.state = state.State(self._wetlab.num_patients,
                                 num_tests_per_cycle, max_group_size,
                                 prior_infection_rate, prior_specificity,
                                 prior_sensitivity)
        self.metrics = metrics_cls(workdir, self._num_simulations,
                                   self._max_test_cycles,
                                   self._wetlab.num_patients,
                                   self.state.num_tests_per_cycle)
        self._export_every = export_metrics_every