def __init__(self,
             workdir=None,
             wetlab=None,
             policy=None,
             sampler=None,
             cheap_sampler=None,
             num_simulations=1,
             num_tests_per_cycle=10,
             max_test_cycles=5,
             max_group_size=8,
             prior_specificity=0.97,
             prior_sensitivity=0.85,
             prior_infection_rate=0.05,
             metrics_cls=metrics.Metrics,
             export_metrics_every=5):
    """Initializes simulation.

    Args:
      workdir: where results will be stored.
      wetlab: WetLab object tasked with producing test results given groups.
        When None (the default), a fresh wet_lab.WetLab() is created for this
        simulator.
      policy: group testing policy, choosing which groups to test. When None,
        a fresh group_policy.Dorfman() is created for this simulator.
      sampler: sampler that produces a posterior approximation. When None, a
        fresh smc.SmcSampler() is created for this simulator.
      cheap_sampler: sampler yielding a cheap approximation of marginals. When
        None, a fresh lbp.LbpSampler() is created for this simulator.
      num_simulations: number of simulations run consecutively.
      num_tests_per_cycle: number of tests a testing machine can carry out in
        one testing cycle.
      max_test_cycles: number of cycles in total that should be considered.
      max_group_size: maximal size of groups, i.e. how many individuals can be
        pooled.
      prior_specificity: prior guess of the testing device's specificity.
      prior_sensitivity: prior guess of the testing device's sensitivity.
      prior_infection_rate: prior probability for a patient to be infected.
      metrics_cls: class of metrics object used to store results; called as
        metrics_cls(workdir, num_simulations, max_test_cycles, num_patients,
        num_tests_per_cycle).
      export_metrics_every: frequency of exports to file when carrying out
        num_simulations runs.
    """
    # Collaborators are built here instead of in the signature: default
    # argument expressions are evaluated only once, at definition time, so
    # every instance would otherwise share the same stateful objects (e.g.
    # a policy's step index, a sampler's cached particles).
    if wetlab is None:
        wetlab = wet_lab.WetLab()
    if policy is None:
        policy = group_policy.Dorfman()
    if sampler is None:
        sampler = smc.SmcSampler()
    if cheap_sampler is None:
        cheap_sampler = lbp.LbpSampler()

    self._wetlab = wetlab
    self._policy = policy
    # Ordered cheapest first: [cheap_sampler, sampler].
    self._samplers = [cheap_sampler, sampler]
    self._max_test_cycles = max_test_cycles
    self._num_simulations = num_simulations
    # Keyword arguments guard against silently swapping the priors, which
    # are all floats and would not be caught by a type error.
    self.state = state.State(
        num_patients=self._wetlab.num_patients,
        num_tests_per_cycle=num_tests_per_cycle,
        max_group_size=max_group_size,
        prior_infection_rate=prior_infection_rate,
        prior_specificity=prior_specificity,
        prior_sensitivity=prior_sensitivity)
    self.metrics = metrics_cls(
        workdir, self._num_simulations, self._max_test_cycles,
        self._wetlab.num_patients, self.state.num_tests_per_cycle)
    self._export_every = export_metrics_every
def setUp(self):
    """Prepares a 20-patient state with 3 tests per cycle and a fixed key."""
    super().setUp()
    self.rng = jax.random.PRNGKey(0)
    self.num_patients, self.num_tests_per_cycle = 20, 3
    # Priors deliberately differ from the other fixtures (high infection
    # rate, weaker test) so the state is far from trivial.
    priors = dict(
        max_group_size=4,
        prior_infection_rate=0.30,
        prior_specificity=0.9,
        prior_sensitivity=0.7,
    )
    self.state = state.State(
        self.num_patients, self.num_tests_per_cycle, **priors)
def setUp(self):
    """Creates a deterministic PRNG key and a 32-patient testing state."""
    super().setUp()
    state_config = {
        'num_patients': 32,
        'num_tests_per_cycle': 3,
        'max_group_size': 5,
        'prior_infection_rate': 0.05,
        'prior_specificity': 0.95,
        'prior_sensitivity': 0.75,
    }
    self.rng = jax.random.PRNGKey(0)
    self.state = state.State(**state_config)
def test_exhaustive(self):
    """ExhaustiveSampler fills particles and weights only after sampling."""
    # Only 8 patients, to keep the number of possible states small.
    small_state = state.State(
        num_patients=8,
        num_tests_per_cycle=3,
        max_group_size=5,
        prior_infection_rate=0.05,
        prior_specificity=0.95,
        prior_sensitivity=0.75)
    sampler = exhaustive.ExhaustiveSampler()
    tracked_attrs = ('particles', 'particle_weights')

    for attr_name in tracked_attrs:
        self.assertIsNone(getattr(sampler, attr_name))

    sampler.produce_sample(self.rng, small_state)

    for attr_name in tracked_attrs:
        self.assertIsNotNone(getattr(sampler, attr_name))
def test_act(self):
    """One act() call schedules groups and moves the policy index to 1."""
    n_patients, n_tests = 40, 4
    test_state = state.State(
        n_patients,
        n_tests,
        max_group_size=5,
        prior_infection_rate=0.05,
        prior_specificity=0.95,
        prior_sensitivity=0.80)

    # Before acting: nothing scheduled, policy at its first step.
    self.assertEqual(np.size(test_state.groups_to_test), 0)
    self.assertEqual(self.policy.index, 0)

    self.policy.act(self.rng, test_state)

    scheduled = test_state.groups_to_test
    self.assertGreater(np.size(scheduled), 0)
    # One column per patient and at least one row.
    self.assertEqual(scheduled.shape[1], n_patients)
    self.assertGreater(scheduled.shape[0], 0)
    self.assertEqual(self.policy.index, 1)
def setUp(self):
    """Builds a 72-patient state plus 4 random groups and random outcomes."""
    super().setUp()
    self.state = state.State(
        num_patients=72,
        num_tests_per_cycle=3,
        max_group_size=5,
        prior_infection_rate=0.05,
        prior_specificity=0.95,
        prior_sensitivity=0.75)
    self.num_groups = 4

    root_key = jax.random.PRNGKey(0)
    self.rng, groups_key, results_key = jax.random.split(root_key, 3)

    # Boolean membership matrix: each entry is True with probability 0.7.
    membership_shape = (self.num_groups, self.state.num_patients)
    self.groups = jax.random.uniform(groups_key, membership_shape) > 0.3
    # One boolean outcome per group, True with probability 0.8.
    self.results = jax.random.uniform(results_key, (self.num_groups,)) > 0.2
def __init__(self,
             workdir=None,
             wetlab=None,
             policy=None,
             sampler=None,
             cheap_sampler=None,
             num_simulations=1,
             num_tests_per_cycle=10,
             max_test_cycles=5,
             max_group_size=8,
             prior_specificity=0.97,
             prior_sensitivity=0.85,
             prior_infection_rate=0.05,
             metrics_cls=metrics.Metrics,
             export_metrics_every=5):
    """Initializes simulation.

    Args:
      workdir: where results will be stored.
      wetlab: WetLab object tasked with producing test results given groups.
        When None (the default), a fresh wet_lab.WetLab() is created for this
        simulator.
      policy: group testing policy, a sequence of algorithms tasked with
        choosing groups to test. Can be adaptive to test environment
        (spec/sens) of tests. Can be adaptive to previously tested groups. Can
        leverage samplers to build information on what are the most likely
        disease status among patients. When None, a fresh
        group_policy.Dorfman() is created for this simulator.
      sampler: sampler that produces a posterior approximation. When None, a
        fresh smc.SmcSampler() is created; it resamples at each iteration to
        fix cases where the LBP sampler does not quite work the way it should.
      cheap_sampler: LBP object to yield cheap approximation of marginal. When
        None, a fresh lbp.LbpSampler() is created for this simulator.
      num_simulations: number of simulations run consecutively. Here
        randomness can come from the group testing policy (if it uses
        randomness), as well as diseased label, if WetLab.freeze_diseased is
        False.
      num_tests_per_cycle: number of tests a testing machine can carry out in
        the next testing cycle.
      max_test_cycles: number of cycles in total that should be considered.
      max_group_size: maximal size of groups, i.e. how many individuals can be
        pooled.
      prior_specificity: best guess one has prior to simulation of the testing
        device's specificity per group size (or for all sizes, if singleton).
      prior_sensitivity: best guess one has prior to simulation of the testing
        device's sensitivity per group size (or for all sizes, if singleton).
      prior_infection_rate: best guess of prior probability for patient to be
        infected (same for all patients, if singleton).
      metrics_cls: class of metrics object used to store results.
      export_metrics_every: frequency of exports to file when carrying out
        num_simulations runs.
    """
    # Collaborators are built here instead of in the signature: default
    # argument expressions are evaluated only once, at definition time, so
    # every instance would otherwise share the same stateful objects (e.g.
    # a policy's step index, a sampler's cached particles).
    if wetlab is None:
        wetlab = wet_lab.WetLab()
    if policy is None:
        policy = group_policy.Dorfman()
    if sampler is None:
        sampler = smc.SmcSampler()
    if cheap_sampler is None:
        cheap_sampler = lbp.LbpSampler()

    self._wetlab = wetlab
    self._policy = policy
    # Ordered cheapest first: [cheap_sampler, sampler].
    self._samplers = [cheap_sampler, sampler]
    self._max_test_cycles = max_test_cycles
    self._num_simulations = num_simulations
    # Keyword arguments guard against silently swapping the priors, which
    # are all floats and would not be caught by a type error.
    self.state = state.State(
        num_patients=self._wetlab.num_patients,
        num_tests_per_cycle=num_tests_per_cycle,
        max_group_size=max_group_size,
        prior_infection_rate=prior_infection_rate,
        prior_specificity=prior_specificity,
        prior_sensitivity=prior_sensitivity)
    self.metrics = metrics_cls(workdir, self._num_simulations,
                               self._max_test_cycles,
                               self._wetlab.num_patients,
                               self.state.num_tests_per_cycle)
    self._export_every = export_metrics_every