コード例 #1
0
  def __init__(self,
               workdir=None,
               wetlab=None,
               policy=None,
               sampler=None,
               cheap_sampler=None,
               num_simulations=1,
               num_tests_per_cycle=10,
               max_test_cycles=5,
               max_group_size=8,
               prior_specificity=0.97,
               prior_sensitivity=0.85,
               prior_infection_rate=0.05,
               metrics_cls=metrics.Metrics,
               export_metrics_every=5):
    """Initializes the simulation.

    Args:
      workdir: forwarded as first argument to `metrics_cls`.
      wetlab: object exposing `num_patients`. A new `wet_lab.WetLab()` is
        created when None.
      policy: group testing policy. A new `group_policy.Dorfman()` is created
        when None.
      sampler: second entry of the sampler list. A new `smc.SmcSampler()` is
        created when None.
      cheap_sampler: first entry of the sampler list. A new
        `lbp.LbpSampler()` is created when None.
      num_simulations: number of simulations; stored and forwarded to
        `metrics_cls`.
      num_tests_per_cycle: forwarded to `state.State`.
      max_test_cycles: number of test cycles; stored and forwarded to
        `metrics_cls`.
      max_group_size: forwarded to `state.State`.
      prior_specificity: forwarded to `state.State`.
      prior_sensitivity: forwarded to `state.State`.
      prior_infection_rate: forwarded to `state.State`.
      metrics_cls: callable invoked as `metrics_cls(workdir, num_simulations,
        max_test_cycles, num_patients, num_tests_per_cycle)`.
      export_metrics_every: stored as the export frequency.
    """
    # Default collaborators are built here, per instance, instead of in the
    # signature: signature defaults are evaluated once at definition time, so
    # every instance relying on them would share the same (stateful) wetlab,
    # policy and sampler objects.
    self._wetlab = wet_lab.WetLab() if wetlab is None else wetlab
    self._policy = group_policy.Dorfman() if policy is None else policy
    if sampler is None:
      sampler = smc.SmcSampler()
    if cheap_sampler is None:
      cheap_sampler = lbp.LbpSampler()
    # Order matters to users of the list: cheap sampler first.
    self._samplers = [cheap_sampler, sampler]

    self._max_test_cycles = max_test_cycles
    self._num_simulations = num_simulations
    self.state = state.State(self._wetlab.num_patients,
                             num_tests_per_cycle,
                             max_group_size,
                             prior_infection_rate,
                             prior_specificity,
                             prior_sensitivity)
    # The per-cycle test count is read back from the state rather than taken
    # from the argument directly.
    self.metrics = metrics_cls(
        workdir,
        self._num_simulations,
        self._max_test_cycles,
        self._wetlab.num_patients,
        self.state.num_tests_per_cycle)
    self._export_every = export_metrics_every
コード例 #2
0
  def setUp(self):
    """Prepares a 20-patient state (3 tests per cycle) and a PRNG key."""
    super().setUp()
    self.num_patients, self.num_tests_per_cycle = 20, 3
    self.state = state.State(
        self.num_patients,
        self.num_tests_per_cycle,
        max_group_size=4,
        prior_infection_rate=0.30,
        prior_specificity=0.9,
        prior_sensitivity=0.7,
    )

    # Constant seed: every test starts from the same key.
    self.rng = jax.random.PRNGKey(0)
コード例 #3
0
 def setUp(self):
   """Prepares a PRNG key and a 32-patient state for the tests."""
   super().setUp()
   # Constant seed: every test starts from the same key.
   self.rng = jax.random.PRNGKey(0)
   state_kwargs = dict(
       num_patients=32,
       num_tests_per_cycle=3,
       max_group_size=5,
       prior_infection_rate=0.05,
       prior_specificity=0.95,
       prior_sensitivity=0.75,
   )
   self.state = state.State(**state_kwargs)
コード例 #4
0
 def test_exhaustive(self):
   """Checks particles and weights are None before sampling, set after."""
   # Use few patients to reduce the number of possible states.
   small_state = state.State(
       num_patients=8,
       num_tests_per_cycle=3,
       max_group_size=5,
       prior_infection_rate=0.05,
       prior_specificity=0.95,
       prior_sensitivity=0.75,
   )
   sampler = exhaustive.ExhaustiveSampler()
   tracked = ('particles', 'particle_weights')

   for attr_name in tracked:
     self.assertIsNone(getattr(sampler, attr_name))

   sampler.produce_sample(self.rng, small_state)

   for attr_name in tracked:
     self.assertIsNotNone(getattr(sampler, attr_name))
コード例 #5
0
 def test_act(self):
   """Checks that one `act` call proposes groups and advances the index."""
   num_patients, num_tests_per_cycle = 40, 4
   test_state = state.State(
       num_patients,
       num_tests_per_cycle,
       max_group_size=5,
       prior_infection_rate=0.05,
       prior_specificity=0.95,
       prior_sensitivity=0.80,
   )

   # Before acting: no groups queued, policy at its first step.
   self.assertEqual(np.size(test_state.groups_to_test), 0)
   self.assertEqual(self.policy.index, 0)

   self.policy.act(self.rng, test_state)

   # After acting: a non-empty (num_groups, num_patients) array is queued
   # and the policy has moved one step forward.
   self.assertGreater(np.size(test_state.groups_to_test), 0)
   self.assertEqual(test_state.groups_to_test.shape[1], num_patients)
   self.assertGreater(test_state.groups_to_test.shape[0], 0)
   self.assertEqual(self.policy.index, 1)
コード例 #6
0
 def setUp(self):
     """Prepares a 72-patient state plus random groups and test results."""
     super().setUp()
     self.rng = jax.random.PRNGKey(0)
     state_kwargs = dict(
         num_patients=72,
         num_tests_per_cycle=3,
         max_group_size=5,
         prior_infection_rate=0.05,
         prior_specificity=0.95,
         prior_sensitivity=0.75,
     )
     self.state = state.State(**state_kwargs)
     self.num_groups = 4

     # Split into three keys: one kept on the test case, one per random draw.
     self.rng, groups_key, results_key = jax.random.split(self.rng, 3)

     # Boolean masks obtained by thresholding uniform draws.
     groups_shape = (self.num_groups, self.state.num_patients)
     self.groups = jax.random.uniform(groups_key, groups_shape) > 0.3
     results_shape = (self.num_groups,)
     self.results = jax.random.uniform(results_key, results_shape) > 0.2
コード例 #7
0
    def __init__(self,
                 workdir=None,
                 wetlab=None,
                 policy=None,
                 sampler=None,
                 cheap_sampler=None,
                 num_simulations=1,
                 num_tests_per_cycle=10,
                 max_test_cycles=5,
                 max_group_size=8,
                 prior_specificity=0.97,
                 prior_sensitivity=0.85,
                 prior_infection_rate=0.05,
                 metrics_cls=metrics.Metrics,
                 export_metrics_every=5):
        """Initializes simulation.

        Args:
          workdir: where results will be stored.
          wetlab: WetLab object tasked with producing test results given
            groups. A new `wet_lab.WetLab()` is created when None.
          policy: group testing policy, a sequence of algorithms tasked with
            choosing groups to test. Can be adaptive to test environment
            (spec/sens) of tests. Can be adaptive to previously tested groups.
            Can leverage samplers to build information on what are the most
            likely disease status among patients. A new
            `group_policy.Dorfman()` is created when None.
          sampler: sampler that produces a posterior approximation. When None,
            a new `smc.SmcSampler()` is created, which resamples at each
            iteration to fix cases where the LBP sampler does not quite work
            the way it should.
          cheap_sampler: LBP object to yield cheap approximation of marginal.
            A new `lbp.LbpSampler()` is created when None.
          num_simulations: number of simulations run consecutively. Here
            randomness can come from the group testing policy (if it uses
            randomness), as well as diseased label, if WetLab.freeze_diseased
            is False.
          num_tests_per_cycle: number of tests a testing machine can carry out
            in the next testing cycle.
          max_test_cycles: number of cycles in total that should be considered.
          max_group_size: maximal size of groups, how many individuals can be
            pooled.
          prior_specificity: best guess one has prior to simulation of the
            testing device's specificity per group size (or for all sizes, if
            singleton).
          prior_sensitivity: best guess one has prior to simulation of the
            testing device's sensitivity per group size (or for all sizes, if
            singleton).
          prior_infection_rate: best guess of prior probability for patient to
            be infected (same for all patients, if singleton).
          metrics_cls: class of metrics object used to store results.
          export_metrics_every: frequency of exports to file when carrying out
            num_simulations results.
        """
        # Default collaborators are built here, per instance, instead of in
        # the signature: signature defaults are evaluated once at definition
        # time, so every instance relying on them would share the same
        # (stateful) wetlab, policy and sampler objects.
        self._wetlab = wet_lab.WetLab() if wetlab is None else wetlab
        self._policy = group_policy.Dorfman() if policy is None else policy
        if sampler is None:
            sampler = smc.SmcSampler()
        if cheap_sampler is None:
            cheap_sampler = lbp.LbpSampler()
        # Order matters to users of the list: cheap sampler first.
        self._samplers = [cheap_sampler, sampler]

        self._max_test_cycles = max_test_cycles
        self._num_simulations = num_simulations
        self.state = state.State(self._wetlab.num_patients,
                                 num_tests_per_cycle, max_group_size,
                                 prior_infection_rate, prior_specificity,
                                 prior_sensitivity)
        # The per-cycle test count is read back from the state rather than
        # taken from the argument directly.
        self.metrics = metrics_cls(workdir, self._num_simulations,
                                   self._max_test_cycles,
                                   self._wetlab.num_patients,
                                   self.state.num_tests_per_cycle)
        self._export_every = export_metrics_every