Beispiel #1
0
def test_stochatreat_input_empty_data(correct_params):
    """
    stochatreat() must raise ValueError when ``data`` is an empty DataFrame.

    All other arguments are taken from the ``correct_params`` fixture so that
    the empty frame is the only deliberately invalid input.
    """
    no_rows = pd.DataFrame()
    with pytest.raises(ValueError):
        stochatreat(
            data=no_rows,
            stratum_cols="stratum",
            treats=correct_params["treat"],
            probs=correct_params["probs"],
            idx_col=correct_params["idx_col"],
        )
Beispiel #2
0
def test_stochatreat_input_invalid_probs(correct_params):
    """
    stochatreat() must reject a ``probs`` list whose entries do not sum to 1.

    The probabilities used here add up to 0.3.
    """
    # NOTE(review): any Exception subclass satisfies this check; narrow it
    # once the exact exception type raised by stochatreat() is confirmed.
    bad_probs = [0.1, 0.2]
    with pytest.raises(Exception):
        stochatreat(
            data=correct_params["data"],
            stratum_cols=["stratum"],
            treats=correct_params["treat"],
            probs=bad_probs,
            idx_col=correct_params["idx_col"],
        )
Beispiel #3
0
def test_stochatreat_input_more_treats_than_probs(correct_params):
    """
    stochatreat() must raise when the number of treatments and the number of
    probabilities disagree.

    ``treats`` is set to 3 while ``probs`` comes from the fixture
    (presumably fewer than three entries -- verify against the fixture).
    """
    # NOTE(review): any Exception subclass satisfies this check; narrow it
    # once the exact exception type raised by stochatreat() is confirmed.
    mismatched_n_treats = 3
    with pytest.raises(Exception):
        stochatreat(
            data=correct_params["data"],
            stratum_cols=["stratum"],
            treats=mismatched_n_treats,
            probs=correct_params["probs"],
            idx_col=correct_params["idx_col"],
        )
Beispiel #4
0
def test_stochatreat_input_idx_col_str(correct_params):
    """
    stochatreat() must raise TypeError when ``idx_col`` is neither a string
    nor None; an integer is passed here.
    """
    non_string_idx_col = 0
    with pytest.raises(TypeError):
        stochatreat(
            data=correct_params["data"],
            stratum_cols=["stratum"],
            treats=correct_params["treat"],
            probs=correct_params["probs"],
            idx_col=non_string_idx_col,
        )
Beispiel #5
0
def test_stochatreat_input_invalid_size(correct_params):
    """
    stochatreat() must raise ValueError when the requested sample ``size``
    exceeds the number of available rows.

    A size of 101 is requested (presumably one more than the fixture data
    holds -- verify against the fixture).
    """
    oversized_sample = 101
    with pytest.raises(ValueError):
        stochatreat(
            data=correct_params["data"],
            stratum_cols=["stratum"],
            treats=correct_params["treat"],
            probs=correct_params["probs"],
            idx_col=correct_params["idx_col"],
            size=oversized_sample,
        )
Beispiel #6
0
def test_stochatreat_input_invalid_strategy(correct_params):
    """
    stochatreat() must raise ValueError when ``misfit_strategy`` is a string
    it does not recognise.
    """
    bogus_strategy = "unknown"
    with pytest.raises(ValueError):
        stochatreat(
            data=correct_params["data"],
            stratum_cols=["stratum"],
            treats=correct_params["treat"],
            probs=correct_params["probs"],
            idx_col=correct_params["idx_col"],
            misfit_strategy=bogus_strategy,
        )
Beispiel #7
0
def treatments_dict_rand_index():
    """
    Fixture: stochatreat() output for data carrying a randomised index.

    Builds 100 rows in three strata (40/30/30) with a shuffled ``id`` column,
    relabels the frame's index with 100 distinct random integers below 300,
    and runs a two-treatment assignment on it.

    Returns a dict with the input ``data``, the ``stratum_cols`` and
    ``idx_col`` used, the resulting ``treatments`` and ``n_treatments``.
    """
    n_treatments = 2
    id_column = "id"
    strata = ["stratum"]

    # The permutation is drawn before the index labels, so the global NumPy
    # RNG is consumed in a fixed order.
    frame = pd.DataFrame(
        data={
            "id": np.random.permutation(100),
            "stratum": [0] * 40 + [1] * 30 + [2] * 30,
        }
    )
    random_labels = np.random.choice(300, 100, replace=False)
    frame = frame.set_index(pd.Index(random_labels))

    assignments = stochatreat(
        data=frame,
        stratum_cols=["stratum"],
        treats=n_treatments,
        idx_col=id_column,
        random_state=42,
    )

    return {
        "data": frame,
        "stratum_cols": strata,
        "idx_col": id_column,
        "treatments": assignments,
        "n_treatments": n_treatments,
    }
Beispiel #8
0
def test_stochatreat_only_misfits(probs):
    """
    Overall treatment shares should match ``probs`` (to two decimals) when
    every unit sits alone in its own stratum, i.e. when, per the intent of
    this test, all units are misfits.

    With 10,000 units this leans on the Law of Large Numbers rather than on
    an exact guarantee.
    """
    n_units = 10_000
    unit_ids = np.arange(n_units)
    singleton_strata = pd.DataFrame(
        data={
            "id": unit_ids,
            # one distinct stratum value per unit
            "stratum": np.arange(n_units),
        }
    )

    assigned = stochatreat(
        data=singleton_strata,
        stratum_cols=["stratum"],
        treats=len(probs),
        probs=probs,
        idx_col="id",
        random_state=42,
    )

    per_treatment_counts = assigned.groupby("treat")["id"].size()
    observed_shares = per_treatment_counts / assigned.shape[0]

    np.testing.assert_almost_equal(
        observed_shares, np.array(probs), decimal=2
    )
Beispiel #9
0
def treatments_dict():
    """
    Fixture: stochatreat() output for a sampled run, used by output-format
    tests.

    Builds 100 rows in three strata (40/30/30), asks for a sample of 90 with
    two treatments, and returns a dict holding the input ``data``, the
    ``idx_col`` and ``size`` used, and the resulting ``treatments``.
    """
    n_treatments = 2
    id_column = "id"
    sample_size = 90

    frame = pd.DataFrame(
        data={
            "id": np.arange(100),
            "stratum": [0] * 40 + [1] * 30 + [2] * 30,
        }
    )

    assignments = stochatreat(
        data=frame,
        stratum_cols=["stratum"],
        treats=n_treatments,
        idx_col=id_column,
        size=sample_size,
        random_state=42,
    )

    return {
        "data": frame,
        "idx_col": id_column,
        "size": sample_size,
        "treatments": assignments,
    }
Beispiel #10
0
def test_stochatreat_input_idx_col_unique(correct_params):
    """
    stochatreat() must raise ValueError when the identifier column does not
    uniquely identify rows.

    Every one of the 100 rows built here has ``id == 1``; the identifier
    column name comes from the fixture (presumably "id" -- verify).
    """
    duplicated_ids = pd.DataFrame(
        data={
            "id": 1,  # scalar is broadcast to every row
            "stratum": np.arange(100),
        }
    )
    with pytest.raises(ValueError):
        stochatreat(
            data=duplicated_ids,
            stratum_cols=["stratum"],
            treats=correct_params["treat"],
            probs=correct_params["probs"],
            idx_col=correct_params["idx_col"],
        )
Beispiel #11
0
def test_stochatreat_output_sample(correct_params):
    """
    The number of rows returned by stochatreat() must equal the requested
    ``size`` (100 here).
    """
    requested_size = 100
    result = stochatreat(
        data=correct_params["data"],
        stratum_cols=["stratum"],
        treats=correct_params["treat"],
        idx_col=correct_params["idx_col"],
        probs=correct_params["probs"],
        size=requested_size,
    )

    assert len(result) == requested_size
Beispiel #12
0
def test_stochatreat_global_strategy(probs, stratum_cols, df):
    """
    With ``misfit_strategy="global"``, at most one ``stratum_id`` may show a
    non-zero total ``count_diff`` as reported by compute_count_diff().
    """
    assigned = stochatreat(
        data=df,
        stratum_cols=stratum_cols,
        treats=len(probs),
        idx_col="id",
        probs=probs,
        random_state=42,
        misfit_strategy="global",
    )
    diffs = compute_count_diff(assigned, probs)

    # Sum the per-treatment differences within each stratum id.
    per_stratum_total = diffs.groupby(["stratum_id"])["count_diff"].sum()
    n_strata_off_target = (per_stratum_total != 0).sum()

    assert n_strata_off_target <= 1, (
        "There is more than one stratum with misfits"
    )
Beispiel #13
0
def test_stochatreat_within_strata_no_misfits(probs, df_no_misfits):
    """
    For data whose strata produce no misfits, every ``count_diff`` reported
    by compute_count_diff() must be exactly zero.
    """
    assigned = stochatreat(
        data=df_no_misfits,
        stratum_cols=["stratum"],
        treats=len(probs),
        probs=probs,
        idx_col="id",
        random_state=42,
    )
    diffs = compute_count_diff(assigned, probs)
    all_on_target = (diffs["count_diff"] == 0).all()

    assert all_on_target, (
        "The required proportions are not reached without misfits"
    )
Beispiel #14
0
def test_stochatreat_no_probs(n_treats, stratum_cols, df):
    """
    When no ``probs`` are supplied, overall treatment shares should each be
    1 / n_treats to two decimals.

    This leans on the Law of Large Numbers rather than an exact guarantee.
    """
    assigned = stochatreat(
        data=df,
        stratum_cols=stratum_cols,
        treats=n_treats,
        idx_col="id",
        random_state=42,
    )

    per_treatment_counts = assigned.groupby("treat")["id"].size()
    observed_shares = per_treatment_counts / assigned.shape[0]
    # Equal share for each of the n_treats treatments.
    expected_shares = np.full(n_treats, 1 / n_treats)

    np.testing.assert_almost_equal(
        observed_shares, expected_shares, decimal=2
    )
Beispiel #15
0
def test_stochatreat_no_misfits(probs, df_no_misfits):
    """
    For data whose strata produce no misfits, overall treatment shares must
    match ``probs`` to two decimals.
    """
    assigned = stochatreat(
        data=df_no_misfits,
        stratum_cols=["stratum"],
        treats=len(probs),
        probs=probs,
        idx_col="id",
        random_state=42,
    )

    per_treatment_counts = assigned.groupby("treat")["id"].size()
    observed_shares = per_treatment_counts / assigned.shape[0]

    np.testing.assert_almost_equal(
        observed_shares, np.array(probs), decimal=2
    )
Beispiel #16
0
def test_stochatreat_within_strata_no_probs(n_treats, stratum_cols, df):
    """
    With equal (defaulted) probabilities, every ``count_diff`` reported by
    compute_count_diff() must stay strictly below ``n_treats``.

    ``n_treats`` is the least common multiple of the denominators of the
    equal probabilities 1 / n_treats, which the test uses as the bound on
    how far misfit allocation may move the counts.
    """
    equal_probs = [1 / n_treats for _ in range(n_treats)]
    bound = n_treats

    assigned = stochatreat(
        data=df,
        stratum_cols=stratum_cols,
        treats=n_treats,
        idx_col="id",
        random_state=42,
    )
    diffs = compute_count_diff(assigned, equal_probs)

    failure_note = (
        "The counts differences exceed the bound that misfit\n"
        "    allocation should not exceed"
    )
    assert (diffs["count_diff"] < bound).all(), failure_note
Beispiel #17
0
def test_stochatreat_output_index_content_unchanged(treatments_dict_rand_index,
                                                    probs, misfit_strategy):
    """
    The set of index labels on stochatreat()'s output must equal the set of
    index labels on the input data (order is not checked).
    """
    source = treatments_dict_rand_index["data"]

    result = stochatreat(
        data=source,
        stratum_cols=["stratum"],
        treats=2,
        probs=probs,
        idx_col=treatments_dict_rand_index["idx_col"],
        misfit_strategy=misfit_strategy,
    )

    labels_out = set(result.index)
    labels_in = set(source.index)
    assert labels_out == labels_in, (
        "The output and input indices do not have the same content"
    )
Beispiel #18
0
def test_stochatreat_random_state(df, stratum_cols, misfit_strategy):
    """
    Two back-to-back calls with identical inputs and the same
    ``random_state`` must yield identical ``treat`` series.
    """
    seed = 42
    runs = [
        stochatreat(
            data=df,
            stratum_cols=stratum_cols,
            treats=2,
            idx_col="id",
            random_state=seed,
            misfit_strategy=misfit_strategy,
        )
        for _ in range(2)
    ]

    first_run, second_run = runs
    pd.testing.assert_series_equal(first_run["treat"], second_run["treat"])
Beispiel #19
0
def test_stochatreat_within_strata_probs(probs, stratum_cols, df):
    """
    With explicit ``probs``, every ``count_diff`` reported by
    compute_count_diff() must stay strictly below the value returned by
    get_lcm_prob_denominators(probs), which the test uses as the bound on
    how far misfit allocation may move the counts.
    """
    bound = get_lcm_prob_denominators(probs)

    assigned = stochatreat(
        data=df,
        stratum_cols=stratum_cols,
        treats=len(probs),
        probs=probs,
        idx_col="id",
        random_state=42,
    )
    diffs = compute_count_diff(assigned, probs)

    failure_note = (
        "The counts differences exceed the bound that misfit\n"
        "    allocation should not exceed"
    )
    assert (diffs["count_diff"] < bound).all(), failure_note
Beispiel #20
0
def test_stochatreat_stratum_ids(df, misfit_strategy, stratum_cols):
    """
    The number of distinct ``stratum_id`` values in the output must match the
    number of distinct strata in the input.

    Under the "global" strategy one extra id is tolerated, depending on
    whether there are misfits.
    """
    assigned = stochatreat(
        data=df,
        stratum_cols=stratum_cols,
        treats=2,
        idx_col="id",
        random_state=42,
        misfit_strategy=misfit_strategy,
    )

    n_input_strata = len(df[stratum_cols].drop_duplicates())
    n_output_ids = len(assigned["stratum_id"].drop_duplicates())
    surplus_ids = n_output_ids - n_input_strata

    if misfit_strategy != "global":
        assert surplus_ids == 0
    else:
        # depending on whether there are misfits
        assert surplus_ids in (0, 1)
Beispiel #21
0
def test_stochatreat_shuffle_data(df, stratum_cols, misfit_strategy):
    """
    The id-to-treatment mapping must not depend on row order: running once on
    the data as given and once on a row-shuffled copy, with the same
    ``random_state``, must give equal ``treat`` series once both outputs are
    sorted by ``id``.
    """
    seed = 42
    current = df
    sorted_runs = []
    for _ in range(2):
        assigned = stochatreat(
            data=current,
            stratum_cols=stratum_cols,
            treats=2,
            idx_col="id",
            random_state=seed,
            misfit_strategy=misfit_strategy,
        )
        sorted_runs.append(assigned.sort_values("id"))

        # Reorder the rows so the next pass sees shuffled input.
        current = current.sample(len(current), random_state=seed)

    original_order, shuffled_order = sorted_runs
    pd.testing.assert_series_equal(original_order["treat"],
                                   shuffled_order["treat"])
Beispiel #22
0
def test_stochatreat_output_index_and_idx_col_correspondence(
        treatments_dict_rand_index, probs, misfit_strategy):
    """
    Each index label must stay paired with the same identifier value: after
    sorting input and output by index, their ``idx_col`` series must be
    equal.
    """
    id_column = treatments_dict_rand_index["idx_col"]
    source = treatments_dict_rand_index["data"]

    result = stochatreat(
        data=source,
        stratum_cols="stratum",
        treats=2,
        probs=probs,
        idx_col=id_column,
        misfit_strategy=misfit_strategy,
    )

    # Align both frames on their index labels before comparing.
    expected_ids = source.sort_index()[id_column]
    actual_ids = result.sort_index()[id_column]

    pd.testing.assert_series_equal(expected_ids, actual_ids)