def test_calculate_contacts_no_policy(states_all_alive, contact_models):
    """Without any contact policy the computed contacts pass through unchanged.

    Every row is expected to have 1 in the first column, and 1 in the second
    column only for rows before ``round(n / 2)``. No recurrent contacts are
    expected.
    """
    n_people = len(states_all_alive)
    cutoff = round(n_people / 2)
    expected = np.column_stack(
        [np.ones(n_people), np.arange(n_people) < cutoff]
    ).astype(DTYPE_N_CONTACTS)

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    # An empty policy dict: nothing should be modified.
    policed_contacts = apply_contact_policies(
        contact_policies={},
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )

    recurrent_out, random_out = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_out is None
    assert (random_out.to_numpy() == expected).all()
def test_calculate_contacts_policy_model_active(states_all_alive, contact_models):
    """An active ``shut_down_model`` policy silences the affected model.

    The policy targets "first_half_meet" and its window contains the
    simulation date. After it is applied, the second column is expected to be
    0 for everyone, while the first column stays 1.
    """
    shutdown_policy = {
        "affected_contact_model": "first_half_meet",
        "start": pd.Timestamp("2020-09-01"),
        "end": pd.Timestamp("2020-09-30"),
        "policy": shut_down_model,
    }

    expected = np.zeros((len(states_all_alive), 2), dtype=DTYPE_N_CONTACTS)
    expected[:, 0] = 1

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies={"noone_meets": shutdown_policy},
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )

    recurrent_out, random_out = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_out is None
    assert (random_out.to_numpy() == expected).all()
def test_calculate_contacts_policy_active_policy_func(states_all_alive, contact_models):
    """A callable policy can rewrite the contacts of the affected model.

    The policy zeroes the first quarter of the rows it receives. Without a
    policy, the affected model ("first_half_meet") produces a 1 in the second
    column only for rows before ``round(n / 2)``. After the policy, only rows
    in ``[n // 4, round(n / 2))`` keep that contact.

    NOTE(review): assumes the policy receives one row per individual in
    ``states_all_alive``; confirm against ``apply_contact_policies``.
    """

    def zero_first_quarter(states, contacts, seed):
        # Work on a copy so the caller's contacts are not mutated in place.
        contacts = contacts.copy()
        contacts[: len(contacts) // 4] = 0
        return contacts

    contact_policies = {
        "noone_meets": {
            "affected_contact_model": "first_half_meet",
            "start": pd.Timestamp("2020-09-01"),
            "end": pd.Timestamp("2020-09-30"),
            "policy": zero_first_quarter,
        },
    }
    date = pd.Timestamp("2020-09-29")
    params = pd.DataFrame()

    # Derive the surviving rows from the fixture size instead of hard-coding
    # them, so the test does not silently depend on a particular fixture length.
    n_people = len(states_all_alive)
    first_quarter = n_people // 4
    first_half = round(n_people / 2)
    expected = np.tile([1, 0], (n_people, 1)).astype(DTYPE_N_CONTACTS)
    expected[first_quarter:first_half, 1] = 1

    contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=params,
        seed=itertools.count(),
    )

    contacts = apply_contact_policies(
        contact_policies=contact_policies,
        contacts=contacts,
        states=states_all_alive,
        date=date,
        seed=itertools.count(),
    )

    recurrent_contacts, random_contacts = post_process_contacts(
        contacts, states_all_alive, contact_models
    )

    assert recurrent_contacts is None
    assert (random_contacts.to_numpy() == expected).all()
def test_calculate_contacts_with_dead(states_with_dead, contact_models):
    """Contacts are computed correctly when the population contains dead people.

    No policy is applied. Rows with a 0 in the first column are presumably the
    dead individuals of the ``states_with_dead`` fixture (TODO confirm against
    the fixture).
    """
    # Expected values, written column by column and transposed into rows.
    first_column = [0, 1, 0, 1, 1, 0, 1, 0]
    second_column = [0, 1, 0, 1, 0, 0, 0, 0]
    expected = np.array([first_column, second_column], dtype=DTYPE_N_CONTACTS).T

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_with_dead,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies={},
        contacts=raw_contacts,
        states=states_with_dead,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )

    recurrent_out, random_out = post_process_contacts(
        policed_contacts, states_with_dead, contact_models
    )

    assert recurrent_out is None
    assert (random_out.to_numpy() == expected).all()
def test_calculate_contacts_policy_scalar_active(states_all_alive):
    """A scalar policy scales the contacts of the affected model.

    A single non-recurrent model gives everyone 10 contacts. The policy value
    0.5 is active on the simulation date, so every individual is expected to
    end up with 5 contacts.
    """

    def ten_contacts(params, states, seed):
        return pd.Series(10, index=states.index)

    contact_models = {"ten": {"model": ten_contacts, "is_recurrent": False}}
    halving_policy = {
        "affected_contact_model": "ten",
        "start": pd.Timestamp("2020-09-01"),
        "end": pd.Timestamp("2020-09-30"),
        "policy": 0.5,
    }

    raw_contacts = calculate_contacts(
        contact_models=contact_models,
        states=states_all_alive,
        params=pd.DataFrame(),
        seed=itertools.count(),
    )
    policed_contacts = apply_contact_policies(
        contact_policies={"noone_meets": halving_policy},
        contacts=raw_contacts,
        states=states_all_alive,
        date=pd.Timestamp("2020-09-29"),
        seed=itertools.count(),
    )

    recurrent_out, random_out = post_process_contacts(
        policed_contacts, states_all_alive, contact_models
    )

    assert recurrent_out is None
    assert (random_out.to_numpy() == 5).all()