コード例 #1
0
def test_sir_logistic_policy(penn_chime_setup, sir_data_w_policy):
    """Compare the local SIR model against penn_chime's SIR with social policies.

    Here the policy change is not applied as a hard date switch. Instead beta
    is damped by ``one_minus_logistic_fcn`` with a very large steepness ``k``,
    so the transition is effectively a sharp drop. Both models must agree on
    ``COLS_TO_COMPARE``.
    """
    chime_params, sir = penn_chime_setup
    data, model_pars = sir_data_w_policy

    policies = sir.gen_policy(chime_params)
    first_policy, second_policy = policies[0], policies[1]

    # Starting beta: the first policy's beta scaled by the population.
    model_pars["beta"] = first_policy[0] * chime_params.population
    # Extra parameters consumed by one_minus_logistic_fcn.
    # NOTE(review): assumes the damping factor tends to 1 - L well after x0,
    # which makes the late beta equal to the second policy's beta -- confirm.
    model_pars["L"] = 1 - second_policy[0] / first_policy[0]
    # Place the transition half a day before the first policy ends.
    model_pars["x0"] = first_policy[1] - 0.5
    # Huge steepness -> the logistic curve behaves like a step (sharp decay).
    model_pars["k"] = 1.0e7

    def update_parameters(ddate, **kwargs):
        """Return a copy of ``kwargs`` with ``beta`` damped for ``ddate``."""
        elapsed_days = (ddate - data["dates"][0]).days
        updated = dict(kwargs)
        damping = one_minus_logistic_fcn(
            elapsed_days,
            L=kwargs["L"],
            k=kwargs["k"],
            x0=kwargs["x0"],
        )
        updated["beta"] = kwargs["beta"] * damping
        return updated

    model = SIRModel(update_parameters=update_parameters)
    predictions = model.propagate_uncertainties(data, model_pars)

    expected = (
        sir.raw_df.set_index("date")
        .rename(columns=COLUMN_MAP)[COLS_TO_COMPARE]
        .fillna(0)
    )
    assert_frame_equal(expected, predictions[COLS_TO_COMPARE])
コード例 #2
0
ファイル: seir_test.py プロジェクト: BrianThomasRoss/CHIME-2
def test_compare_sir_vs_seir(sir_data_wo_policy, seir_data, monkeypatch):
    """Check that SEIR reproduces SIR once the exposed stage is bypassed.

    The SEIR simulation step is patched so that:

    * alpha = gamma (set by hand below)
    * E is reset to 0 before each simulation step
    * dI = dE (newly exposed individuals are added straight to infected)

    Under these constraints both models must give identical
    ``COLS_TO_COMPARE`` columns.
    """
    sir_x, sir_pars = sir_data_wo_policy
    seir_x, seir_pars = seir_data

    # Enforced manually here rather than by the model itself.
    seir_pars["alpha"] = sir_pars["gamma"]

    def step_without_exposed(data, **pars):
        """Run the real SEIR step with E forced to 0 and dE folded into I."""
        data["exposed"] = 0
        result = SEIRModel.simulation_step(data, **pars)
        result["infected"] += result["exposed_new"]
        return result

    seir_model = SEIRModel()
    monkeypatch.setattr(seir_model, "simulation_step", step_without_exposed)

    sir_predictions = SIRModel().propagate_uncertainties(sir_x, sir_pars)
    seir_predictions = seir_model.propagate_uncertainties(seir_x, seir_pars)

    assert_frame_equal(
        sir_predictions[COLS_TO_COMPARE],
        seir_predictions[COLS_TO_COMPARE],
    )
コード例 #3
0
def test_sir_vs_penn_chime_no_policies(penn_chime_raw_df_no_policy,
                                       sir_data_wo_policy):
    """Compare the local SIR model against penn_chime's SIR without policies.

    The penn_chime frame is renamed through ``COLUMN_MAP`` so that both
    frames expose the same ``COLS_TO_COMPARE`` column names.
    """
    data, sir_pars = sir_data_wo_policy

    predicted = SIRModel().propagate_uncertainties(data, sir_pars)

    reference = penn_chime_raw_df_no_policy.rename(columns=COLUMN_MAP)
    assert_frame_equal(
        reference[COLS_TO_COMPARE],
        predicted[COLS_TO_COMPARE],
    )
コード例 #4
0
def test_sir_type_conversion(sir_data_w_policy):
    """Check that passing ``recovery_days`` is equivalent to passing ``gamma``.

    The model runs once with ``gamma``. It then runs again after ``gamma`` is
    replaced by ``recovery_days = 1 / gamma``. Both runs must yield
    identical frames.
    """
    data, pars = sir_data_w_policy
    model = SIRModel()

    with_gamma = model.propagate_uncertainties(data, pars)

    # Swap gamma for its reciprocal under the recovery_days key.
    gamma = pars.pop("gamma")
    pars["recovery_days"] = 1 / gamma
    with_recovery_days = model.propagate_uncertainties(data, pars)

    assert_frame_equal(with_gamma, with_recovery_days)
コード例 #5
0
def test_conserved_n(sir_data_wo_policy):
    """Check that the total population (S + I + R) is conserved by local SIR.

    The sum of the ``initial_<compartment>`` parameters must equal the
    row-wise sum of the predicted compartment columns at every date.
    """
    data, pars = sir_data_wo_policy
    sir_model = SIRModel()

    n_total = sum(pars[f"initial_{name}"] for name in sir_model.compartments)

    predictions = sir_model.propagate_uncertainties(data, pars)

    n_computed = predictions[sir_model.compartments].sum(axis=1)
    # Same total repeated for every predicted date, aligned on the same index.
    n_expected = Series(
        data=[n_total] * len(n_computed),
        index=n_computed.index,
    )

    assert_series_equal(n_expected, n_computed)
コード例 #6
0
def test_sir_vs_penn_chime_w_policies(penn_chime_setup, sir_data_w_policy):
    """Compare the local SIR model against penn_chime's SIR with social policies.

    The local model uses the first policy's beta until ``policies[0][1]``
    days after the first date. From then on it uses the second policy's beta,
    as a hard step change. Both betas are scaled by the population.
    """
    chime_params, sir = penn_chime_setup
    data, pars = sir_data_w_policy

    policies = sir.gen_policy(chime_params)
    switch_date = data["dates"][0] + timedelta(days=policies[0][1])
    beta_before, beta_after = policies[0][0], policies[1][0]

    def update_parameters(ddate, **kwargs):
        """Set ``beta`` for ``ddate`` according to the policy switch date."""
        beta = beta_before if ddate < switch_date else beta_after
        kwargs["beta"] = beta * chime_params.population
        return kwargs

    sir_model = SIRModel(update_parameters=update_parameters)
    predictions = sir_model.propagate_uncertainties(data, pars)

    expected = (
        sir.raw_df.set_index("date")
        .fillna(0)
        .rename(columns=COLUMN_MAP)[COLS_TO_COMPARE]
    )
    assert_frame_equal(expected, predictions[COLS_TO_COMPARE])