def test_calculate_contacts_no_policy(states_all_alive, contact_models):
    """Check the contact pipeline when no contact policy is in force.

    The expected random contacts have a 1 in the first column for every
    individual. The second column is 1 only for rows whose position is below
    ``round(n / 2)``. No recurrent contacts are expected.
    """
    n_individuals = len(states_all_alive)
    cutoff = round(n_individuals / 2)

    # NOTE(review): presumably column 1 corresponds to the fixture's
    # "first_half_meet" model — confirm against the ``contact_models`` fixture.
    expected = np.array(
        [(1, int(row < cutoff)) for row in range(n_individuals)],
        dtype=DTYPE_N_CONTACTS,
    )

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    # Pass an empty policy mapping so that no policy can be active on this date.
    policed_contacts = apply_contact_policies(
        contact_policies={},
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )
    recurrent_part, random_part = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_part is None
    assert np.all(random_part.to_numpy() == expected)
def test_calculate_contacts_policy_model_active(states_all_alive, contact_models):
    """Check that an active ``shut_down_model`` policy zeroes the affected model.

    The policy targets "first_half_meet" and its date window covers the
    evaluation date. After the policy, every individual is expected to have
    ``[1, 0]`` random contacts. No recurrent contacts are expected.
    """
    evaluation_date = pd.Timestamp("2020-09-29")
    shutdown_policy = dict(
        affected_contact_model="first_half_meet",
        start=pd.Timestamp("2020-09-01"),
        end=pd.Timestamp("2020-09-30"),
        policy=shut_down_model,
    )
    policies = {"noone_meets": shutdown_policy}

    # Only the first column stays at 1 once the targeted model is shut down.
    expected = np.zeros((len(states_all_alive), 2), dtype=DTYPE_N_CONTACTS)
    expected[:, 0] = 1

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies=policies,
        contacts=raw_contacts,
        states=states_all_alive,
        date=evaluation_date,
        seed=itertools.count(),
    )
    recurrent_part, random_part = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_part is None
    assert np.all(random_part.to_numpy() == expected)
def test_calculate_contacts_policy_active_policy_func(states_all_alive, contact_models):
    """Check that a custom policy function is applied to the affected model.

    The policy zeroes the first quarter of the affected model's contacts. The
    expected result is ``[1, 0]`` for everyone except rows 2 and 3, which keep
    ``[1, 1]``. No recurrent contacts are expected.
    """

    def silence_first_quarter(states, contacts, seed):
        # Work on a copy so the object handed to the policy is left unmodified.
        trimmed = contacts.copy()
        trimmed[: len(contacts) // 4] = 0
        return trimmed

    policies = {
        "noone_meets": dict(
            affected_contact_model="first_half_meet",
            start=pd.Timestamp("2020-09-01"),
            end=pd.Timestamp("2020-09-30"),
            policy=silence_first_quarter,
        ),
    }

    expected = np.zeros((len(states_all_alive), 2), dtype=DTYPE_N_CONTACTS)
    expected[:, 0] = 1
    # NOTE(review): rows 2-3 are hard-coded. They presumably assume an 8-row
    # fixture, where the first half is rows 0-3 and the zeroed quarter is
    # rows 0-1. Confirm if the fixture size changes.
    expected[2:4, 1] = 1

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies=policies,
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )
    recurrent_part, random_part = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_part is None
    assert np.all(random_part.to_numpy() == expected)
def test_calculate_contacts_with_dead(states_with_dead, contact_models):
    """Check the contact pipeline without policies on a fixture that contains dead individuals.

    The expected random contacts are spelled out per column for the 8-row
    ``states_with_dead`` fixture. No recurrent contacts are expected.
    """
    # Expected values, one list per output column (row order follows the fixture).
    first_column = [0, 1, 0, 1, 1, 0, 1, 0]
    second_column = [0, 1, 0, 1, 0, 0, 0, 0]
    expected = np.column_stack([first_column, second_column]).astype(
        DTYPE_N_CONTACTS
    )

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_with_dead,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies={},
        contacts=raw_contacts,
        states=states_with_dead,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )
    recurrent_part, random_part = post_process_contacts(
        policed_contacts, states_with_dead, contact_models
    )

    assert recurrent_part is None
    assert np.all(random_part.to_numpy() == expected)
def test_calculate_contacts_policy_scalar_active(states_all_alive):
    """Check that a scalar policy applies to the affected model's contacts.

    A single non-recurrent model gives every individual 10 contacts. An active
    scalar policy of 0.5 is expected to leave 5 random contacts per individual.
    No recurrent contacts are expected.
    """

    def ten_contacts(params, states, seed):
        # Every individual in ``states`` gets exactly 10 contacts.
        return pd.Series(10, index=states.index)

    models = {"ten": {"model": ten_contacts, "is_recurrent": False}}
    policies = {
        "noone_meets": dict(
            affected_contact_model="ten",
            start=pd.Timestamp("2020-09-01"),
            end=pd.Timestamp("2020-09-30"),
            policy=0.5,
        ),
    }

    raw_contacts = calculate_contacts(
        contact_models=models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies=policies,
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )
    recurrent_part, random_part = post_process_contacts(
        policed_contacts, states_all_alive, models
    )

    assert recurrent_part is None
    assert np.all(random_part.to_numpy() == 5)