def test_run_inference(self):
    """Check run_inference output sizes for one and several serial
    intervals, and that passing a progress callback yields the same
    result sizes."""
    df = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [10, 3, 4, 6, 9]
    })
    ser_int1 = [[1, 2, 1, 0, 0, 0]]
    ser_int2 = [[1, 2], [3, 4]]

    inference1 = bp.BranchProPosteriorMultSI(df, ser_int1, 1, 0.2)
    inference1.run_inference(tau=2)

    inference2 = bp.BranchProPosteriorMultSI(df, ser_int2, 1, 0.2)
    inference2.run_inference(tau=2)

    self.assertEqual(len(inference1.inference_estimates), 3)
    self.assertEqual(len(inference1.inference_times), 3)
    self.assertEqual(len(inference1.inference_posterior.mean()), 3)

    self.assertEqual(len(inference2.inference_estimates), 3)
    self.assertEqual(len(inference2.inference_times), 3)
    self.assertEqual(len(inference2.inference_posterior.mean()), 3)

    def progress_fn(i):
        pass

    # Same inputs as inference2, only with a progress callback: the
    # callback must not change the shape of the inference output.
    inference3 = bp.BranchProPosteriorMultSI(df, ser_int2, 1, 0.2)
    inference3.run_inference(tau=2, progress_fn=progress_fn)

    self.assertEqual(len(inference3.inference_estimates), 3)
    self.assertEqual(len(inference3.inference_times), 3)
    self.assertEqual(len(inference3.inference_posterior.mean()), 3)
def test__init__(self):
    """Construct a posterior with valid serial intervals, and check the
    TypeErrors raised for a non-iterable or non-numeric serial interval."""
    df = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [10, 3, 4, 6, 9]
    })
    ser_ints = [[1, 2], [0, 1]]
    bp.BranchProPosteriorMultSI(df, ser_ints, 1, 0.2)

    # One of the serial intervals is a bare int rather than an iterable.
    with self.assertRaises(TypeError) as test_excep:
        bp.BranchProPosteriorMultSI(df, [[0], 0], 1, 0.2)
    # assertIn shows the actual message on failure, unlike assertTrue.
    self.assertIn('must be iterable', str(test_excep.exception))

    # One of the serial intervals contains a non-numeric entry.
    with self.assertRaises(TypeError) as test_excep:
        bp.BranchProPosteriorMultSI(df, [[1], ['zero']], 1, 0.2)
    self.assertIn('distribution must contain', str(test_excep.exception))
def test_get_serial_intervals(self):
    """get_serial_intervals returns the serial intervals given at
    construction, comparable element-wise to a numpy array."""
    incidence = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [10, 3, 4, 6, 9]
    })
    given = [[1, 2], [0, 1]]
    posterior = bp.BranchProPosteriorMultSI(incidence, given, 1, 0.2)

    expected = np.array([[1, 2], [0, 1]])
    npt.assert_array_equal(posterior.get_serial_intervals(), expected)
def test_set_serial_intervals(self):
    """set_serial_intervals replaces the stored serial intervals, and
    raises ValueError for a malformed collection of intervals."""
    incidence = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [10, 3, 4, 6, 9]
    })
    posterior = bp.BranchProPosteriorMultSI(
        incidence, [[1, 2], [0, 1]], 1, 0.2)

    replacement = [[3, 2, 0], [1, 2, 1], [4, 0, 1]]
    posterior.set_serial_intervals(replacement)
    npt.assert_array_equal(
        posterior.get_serial_intervals(),
        np.array([[3, 2, 0], [1, 2, 1], [4, 0, 1]]))

    # First entry is a bare int (not a tuple), second is a list.
    malformed = [1, [2]]
    with self.assertRaises(ValueError):
        posterior.set_serial_intervals(malformed)
def test_get_intervals(self):
    """get_intervals returns one row per inference time; with zero
    incidence the mean stays near 5.0 (= 1 / 0.2 from the prior
    parameters passed here)."""
    zero_incidence = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [0, 0, 0, 0, 0]
    })
    posterior = bp.BranchProPosteriorMultSI(
        zero_incidence, [[1, 2], [0, 1]], 1, 0.2)
    posterior.run_inference(tau=2)
    summary = posterior.get_intervals(.95)

    for column in ('Time Points', 'Mean', 'Median',
                   'Lower bound CI', 'Upper bound CI'):
        self.assertEqual(len(summary[column]), 3)

    npt.assert_allclose(
        summary['Mean'].to_numpy(), np.array([5.0] * 3), atol=0.5)
    self.assertEqual(
        summary['Central Probability'].to_list(), [.95] * 3)
def update_posterior(self, mean, stdev, tau, central_prob, epsilon=None,
                     progress_fn=None):
    """Update the posterior distribution based on slider values.

    Parameters
    ----------
    mean
        (float) updated position on the slider for the mean of
        the prior for the Branch Pro model in the posterior.
    stdev
        (float) updated position on the slider for the standard
        deviation of the prior for the Branch Pro model in the
        posterior.
    tau
        (int) updated position on the slider for the tau window
        used in the running of the inference of the reproduction
        numbers of the Branch Pro model in the posterior.
    central_prob
        (float) updated position on the slider for the level of
        the computed credible interval of the estimated R number
        values.
    epsilon
        (float) updated position on the slider for the constant
        of proportionality between local and imported cases for
        the Branch Pro model in the posterior. Only used when the
        stored data has an 'Imported Cases' column.
    progress_fn
        Function of integer argument to send to posterior
        run_inference. It can be used for dash callbacks
        set_progress (see update_posterior_storage in the app
        script). Only forwarded when more than one serial interval
        is stored.

    Returns
    -------
    pandas.DataFrame
        The posterior distribution summarized by the posterior's
        ``get_intervals(central_prob)``, with columns including
        'Time Points', 'Mean', 'Lower bound CI' and 'Upper bound CI'.

    Raises
    ------
    dash.exceptions.PreventUpdate
        If no incidence data or no serial interval data is stored.
    """
    # Prior parameters with the requested moments:
    # alpha / beta == mean and alpha / beta**2 == stdev**2.
    new_alpha = (mean / stdev)**2
    new_beta = mean / (stdev**2)

    data = self.session_data.get('data_storage')
    intervals = self.session_data.get('interval_storage')
    if data is None or intervals is None:
        # Nothing to infer from yet; previously a missing interval
        # table crashed with AttributeError on `.columns`.
        raise dash.exceptions.PreventUpdate()

    # The first two columns of the stored data are time and incidence.
    time_label, inc_label = data.columns[:2]
    prior_params = (new_alpha, new_beta)
    labels = {'time_key': time_label, 'inc_key': inc_label}

    # Separate data into local and imported cases when available.
    imported_data = None
    if 'Imported Cases' in data.columns:
        imported_data = pd.DataFrame({
            time_label: data[time_label],
            inc_label: data['Imported Cases']
        })

    if len(intervals.columns) == 1:
        serial_interval = intervals.iloc[:, 0].values
        if imported_data is not None:
            # Posterior follows the LocImp behaviour
            posterior = bp.LocImpBranchProPosterior(
                data, imported_data, epsilon, serial_interval,
                *prior_params, **labels)
        else:
            # Posterior follows the simple behaviour
            posterior = bp.BranchProPosterior(
                data, serial_interval, *prior_params, **labels)
        # NOTE(review): progress_fn is not forwarded here; confirm
        # whether the single-SI posteriors' run_inference accept it.
        posterior.run_inference(tau)
    else:
        # One serial interval per stored column, one per row after .T
        serial_intervals = intervals.values.T
        if imported_data is not None:
            # Posterior follows the LocImp behaviour
            posterior = bp.LocImpBranchProPosteriorMultSI(
                data, imported_data, epsilon, serial_intervals,
                *prior_params, **labels)
        else:
            # Posterior follows the simple behaviour
            posterior = bp.BranchProPosteriorMultSI(
                data, serial_intervals, *prior_params, **labels)
        posterior.run_inference(tau, progress_fn=progress_fn)

    return posterior.get_intervals(central_prob)