def __init__(self, sheet):
  """Builds a single-run simulator whose settings all come from `sheet`.

  The wet lab and SMC sampler are sized from `sheet.num_patients` (the sampler
  uses 150 particles per patient). The group testing policy is chosen by name
  from `sheet.params['policy']`:
    'G-MIMAX'            -> group_policy.MimaxPolicy(forward_iterations=2,
                                                      backward_iterations=1)
    'InformativeDorfman' -> group_policy.InformativeDorfmanPolicy(
                                cut_off_high=0.95, cut_off_low=0.001)
    'Random'             -> group_policy.MezardPolicy()
  The remaining simulator settings are read from `sheet.params` keys
  'tests per cycle', 'cycles', 'max group size', 'specificity' and
  'sensitivity', and the prior infection rate from `sheet.priors`.

  Args:
    sheet: object exposing `num_patients`, `params` and `priors`.

  Raises:
    ValueError: if `sheet.params['policy']` is not one of the names above.
  """
  wetlab = wet_lab.WetLab(num_patients=sheet.num_patients)
  sampler = smc.SmcSampler(
      num_particles=sheet.num_patients * 150,
      resample_at_each_iteration=True)

  policy_name = sheet.params['policy']
  # Factories are lambdas so that only the selected policy is instantiated.
  policy_factories = (
      ('G-MIMAX',
       lambda: group_policy.MimaxPolicy(
           forward_iterations=2, backward_iterations=1)),
      ('InformativeDorfman',
       lambda: group_policy.InformativeDorfmanPolicy(
           cut_off_high=0.95, cut_off_low=0.001)),
      ('Random',
       lambda: group_policy.MezardPolicy()),
  )
  for known_name, make_policy in policy_factories:
    if policy_name == known_name:
      policy = make_policy()
      break
  else:
    raise ValueError(f'Unsupported policy {policy_name}')

  super().__init__(
      wetlab=wetlab,
      sampler=sampler,
      policy=policy,
      num_simulations=1,
      num_tests_per_cycle=sheet.params['tests per cycle'],
      max_test_cycles=sheet.params['cycles'],
      max_group_size=sheet.params['max group size'],
      prior_specificity=sheet.params['specificity'],
      prior_sensitivity=sheet.params['sensitivity'],
      prior_infection_rate=sheet.priors)
  self.sheet = sheet
def __init__(self,
             workdir=None,
             wetlab=None,
             policy=None,
             sampler=None,
             cheap_sampler=None,
             num_simulations=1,
             num_tests_per_cycle=10,
             max_test_cycles=5,
             max_group_size=8,
             prior_specificity=0.97,
             prior_sensitivity=0.85,
             prior_infection_rate=0.05,
             metrics_cls=metrics.Metrics,
             export_metrics_every=5):
  """Initializes the simulation.

  Args:
    workdir: where results will be stored (forwarded to `metrics_cls`).
    wetlab: WetLab object producing test results given groups. Defaults to a
      new `wet_lab.WetLab()` owned by this simulator.
    policy: group testing policy used to choose the groups to test. Defaults
      to a new `group_policy.Dorfman()`.
    sampler: sampler producing a posterior approximation. Defaults to a new
      `smc.SmcSampler()`.
    cheap_sampler: sampler yielding a cheap approximation. Defaults to a new
      `lbp.LbpSampler()`.
    num_simulations: number of simulations run consecutively.
    num_tests_per_cycle: number of tests that can be carried out per cycle.
    max_test_cycles: total number of testing cycles considered.
    max_group_size: maximal number of individuals that can be pooled.
    prior_specificity: prior guess of the testing device's specificity.
    prior_sensitivity: prior guess of the testing device's sensitivity.
    prior_infection_rate: prior probability for a patient to be infected.
    metrics_cls: class of the metrics object used to store results. It is
      called as `metrics_cls(workdir, num_simulations, max_test_cycles,
      num_patients, num_tests_per_cycle)`.
    export_metrics_every: frequency of exports of the metrics to file.
  """
  # Defaults are built here rather than in the signature. A default such as
  # `wet_lab.WetLab()` in the signature is evaluated once, when the function
  # is defined. Every simulator relying on it would then share (and mutate,
  # e.g. through `WetLab.reset`) the very same object.
  if wetlab is None:
    wetlab = wet_lab.WetLab()
  if policy is None:
    policy = group_policy.Dorfman()
  if sampler is None:
    sampler = smc.SmcSampler()
  if cheap_sampler is None:
    cheap_sampler = lbp.LbpSampler()

  self._wetlab = wetlab
  self._policy = policy
  # Stored cheap sampler first, then the (default SMC) sampler.
  self._samplers = [cheap_sampler, sampler]
  self._max_test_cycles = max_test_cycles
  self._num_simulations = num_simulations
  self.state = state.State(self._wetlab.num_patients, num_tests_per_cycle,
                           max_group_size, prior_infection_rate,
                           prior_specificity, prior_sensitivity)
  self.metrics = metrics_cls(
      workdir,
      self._num_simulations,
      self._max_test_cycles,
      self._wetlab.num_patients,
      self.state.num_tests_per_cycle)
  self._export_every = export_metrics_every
def test_run(self):
  """Runs a 100-patient simulation and inspects the final recorded groups.

  Checks that `sim.metrics.groups[0, -1]` (presumably simulation 0, last
  cycle; confirm against Metrics) is not entirely NaN.
  """
  sim = simulator.Simulator(
      workdir=None,
      wetlab=wet_lab.WetLab(num_patients=100),
      num_simulations=self.num_simulations,
      max_test_cycles=self.max_test_cycles,
      num_tests_per_cycle=4,
      max_group_size=5)
  sim.run(0)

  final_groups = sim.metrics.groups[0, -1]
  everything_missing = np.all(np.isnan(final_groups))
  self.assertFalse(everything_missing)
def test_reset_unfrozen(self):
  """An unfrozen WetLab draws a different diseased vector on a new key."""
  patient_count = 10
  lab = wet_lab.WetLab(num_patients=patient_count, freeze_diseased=False)
  # Nothing is drawn before the first reset.
  self.assertIsNone(lab.diseased)

  lab.reset(self.rng)
  first_draw = np.array(lab.diseased)

  second_rng, _ = jax.random.split(self.rng)
  lab.reset(second_rng)
  self.assertFalse(np.all(lab.diseased == first_draw))
def test_reset_frozen(self):
  """A frozen WetLab keeps its diseased vector across resets."""
  patient_count = 10
  lab = wet_lab.WetLab(num_patients=patient_count, freeze_diseased=True)
  # Nothing is drawn before the first reset.
  self.assertIsNone(lab.diseased)

  lab.reset(self.rng)
  self.assertIsNotNone(lab.diseased)
  first_draw = np.array(lab.diseased)
  self.assertEqual(lab.diseased.shape[0], patient_count)

  # A reset with a different key must not change the frozen vector.
  second_rng, _ = jax.random.split(self.rng)
  lab.reset(second_rng)
  self.assertTrue(np.all(lab.diseased == first_draw))
def setUp(self):
  """Creates a small frozen wet lab and a random-selector simulator."""
  super().setUp()
  self.num_patients = 10
  self.wetlab = wet_lab.WetLab(self.num_patients, freeze_diseased=True)
  self.num_simulations = 1
  self.max_test_cycles = 4

  selectors = [random.RandomSelector()]
  self.policy = policy.Policy(selectors)

  simulator_kwargs = dict(
      workdir=None,
      wetlab=self.wetlab,
      policy=self.policy,
      num_simulations=self.num_simulations,
      max_test_cycles=self.max_test_cycles,
      num_tests_per_cycle=4,
      max_group_size=5)
  self.simulator = simulator.Simulator(**simulator_kwargs)
  self.rng = jax.random.PRNGKey(0)
def test_group_tests_outputs(self):
  """Checks dtype, shape and positivity of WetLab.group_tests_outputs."""
  num_patients = 10
  num_groups = 5
  lab = wet_lab.WetLab(num_patients=num_patients, freeze_diseased=True)

  reset_key, groups_key, tests_key = jax.random.split(self.rng, 3)
  lab.reset(reset_key)

  # Each patient lands in each group independently with probability 0.3.
  draws = jax.random.uniform(groups_key, shape=(num_groups, num_patients))
  groups = draws < 0.3
  output = lab.group_tests_outputs(tests_key, groups)

  self.assertDtypesMatch(output, lab.diseased)
  self.assertEqual(output.shape, (num_groups,))
  self.assertEqual(np.any(lab.diseased), np.any(output))
def __init__(self, sheet, num_particles=10000, policy=None):
  """Builds a single-run simulator from `sheet` with an SMC sampler.

  Args:
    sheet: object exposing `num_patients`, `params` and `priors`. `params`
      must contain 'tests per cycle', 'cycles', 'max group size',
      'specificity' and 'sensitivity'.
    num_particles: number of particles of the SMC sampler.
    policy: group testing policy; `group_policy.MimaxPolicy()` when None.
  """
  wetlab = wet_lab.WetLab(num_patients=sheet.num_patients)
  sampler = smc.SmcSampler(
      num_particles=num_particles, resample_at_each_iteration=True)
  chosen_policy = group_policy.MimaxPolicy() if policy is None else policy

  parent_kwargs = dict(
      wetlab=wetlab,
      sampler=sampler,
      policy=chosen_policy,
      num_simulations=1,
      num_tests_per_cycle=sheet.params['tests per cycle'],
      max_test_cycles=sheet.params['cycles'],
      max_group_size=sheet.params['max group size'],
      prior_specificity=sheet.params['specificity'],
      prior_sensitivity=sheet.params['sensitivity'],
      prior_infection_rate=sheet.priors)
  super().__init__(**parent_kwargs)
  self.sheet = sheet
def __init__(self,
             workdir=None,
             wetlab=None,
             policy=None,
             sampler=None,
             cheap_sampler=None,
             num_simulations=1,
             num_tests_per_cycle=10,
             max_test_cycles=5,
             max_group_size=8,
             prior_specificity=0.97,
             prior_sensitivity=0.85,
             prior_infection_rate=0.05,
             metrics_cls=metrics.Metrics,
             export_metrics_every=5):
  """Initializes simulation.

  Args:
    workdir: where results will be stored (forwarded to `metrics_cls`).
    wetlab: WetLab object tasked with producing test results given groups.
      Defaults to a new `wet_lab.WetLab()` owned by this simulator.
    policy: group testing policy, a sequence of algorithms tasked with
      choosing groups to test. Can be adaptive to the test environment
      (spec/sens) of tests. Can be adaptive to previously tested groups. Can
      leverage samplers to build information on the most likely disease
      status among patients. Defaults to a new `group_policy.Dorfman()`.
    sampler: sampler that produces a posterior approximation. Defaults to a
      new `smc.SmcSampler()`. NOTE(review): this default was previously
      described as resampling at each iteration, to fix cases where the LBP
      sampler does not quite work the way it should. That depends on
      SmcSampler's own defaults; confirm.
    cheap_sampler: LBP object to yield a cheap approximation of marginals.
      Defaults to a new `lbp.LbpSampler()`.
    num_simulations: number of simulations run consecutively. Randomness can
      come from the group testing policy (if it uses randomness), as well as
      from the diseased labels, if WetLab.freeze_diseased is False.
    num_tests_per_cycle: number of tests a testing machine can carry out in
      the next testing cycle.
    max_test_cycles: total number of cycles that should be considered.
    max_group_size: maximal size of groups, i.e. how many individuals can be
      pooled.
    prior_specificity: best guess, prior to simulation, of the testing
      device's specificity per group size (or for all sizes, if singleton).
    prior_sensitivity: best guess, prior to simulation, of the testing
      device's sensitivity per group size (or for all sizes, if singleton).
    prior_infection_rate: best guess of the prior probability for a patient
      to be infected (same for all patients, if singleton).
    metrics_cls: class of the metrics object used to store results. It is
      called as `metrics_cls(workdir, num_simulations, max_test_cycles,
      num_patients, num_tests_per_cycle)`.
    export_metrics_every: frequency of exports to file when carrying out
      num_simulations runs.
  """
  # Defaults are built here rather than in the signature. A default such as
  # `wet_lab.WetLab()` in the signature is evaluated once, when the function
  # is defined. Every simulator relying on it would then share (and mutate,
  # e.g. through `WetLab.reset`) the very same object.
  if wetlab is None:
    wetlab = wet_lab.WetLab()
  if policy is None:
    policy = group_policy.Dorfman()
  if sampler is None:
    sampler = smc.SmcSampler()
  if cheap_sampler is None:
    cheap_sampler = lbp.LbpSampler()

  self._wetlab = wetlab
  self._policy = policy
  # Stored cheap sampler first, then the (default SMC) sampler.
  self._samplers = [cheap_sampler, sampler]
  self._max_test_cycles = max_test_cycles
  self._num_simulations = num_simulations
  self.state = state.State(self._wetlab.num_patients, num_tests_per_cycle,
                           max_group_size, prior_infection_rate,
                           prior_specificity, prior_sensitivity)
  self.metrics = metrics_cls(workdir, self._num_simulations,
                             self._max_test_cycles,
                             self._wetlab.num_patients,
                             self.state.num_tests_per_cycle)
  self._export_every = export_metrics_every