def test_category_product_too_large(random_gen):
    """``category_product`` must raise ``ValueError`` for an oversized input.

    Twenty columns, each produced by ``random_cat(10, 1000)`` and labelled
    ``"0"`` .. ``"19"``, are combined into one DataFrame and passed to
    ``category_product``.

    ``random_gen`` is not referenced in the body; it is requested only as a
    fixture (presumably for its seeding side effect -- TODO confirm).
    """
    columns = {str(idx): random_cat(10, 1000) for idx in range(20)}
    # Build the frame outside ``pytest.raises`` so only the product call is
    # expected to fail.
    frame = pd.DataFrame(columns)
    with pytest.raises(ValueError):
        category_product(frame)
def test_interaction_cat_only(cat):
    """An ``Interaction`` built from categoricals alone matches the helpers.

    Checks that:

    * ``nobs`` equals the number of rows of ``cat``;
    * ``cat`` is stored unchanged;
    * ``sparse`` is a ``csc_matrix`` equal to
      ``category_interaction(category_product(cat), precondition=False)``.
    """
    interact = Interaction(cat=cat)
    assert interact.nobs == cat.shape[0]
    assert_frame_equal(cat, interact.cat)
    expected = category_interaction(category_product(cat), precondition=False)
    actual = interact.sparse
    assert isinstance(actual, csc_matrix)
    # ``.toarray()`` instead of ``.A``: the ``.A`` shortcut on SciPy sparse
    # matrices was deprecated in SciPy 1.11 and removed in later releases,
    # while ``.toarray()`` works on every version.
    assert_allclose(expected.toarray(), actual.toarray())
def test_interaction_cat_cont(cat, cont):
    """An ``Interaction`` of categoricals with continuous columns is correct.

    Checks that ``nobs`` matches ``cat`` and that both inputs are stored
    unchanged. It then rebuilds the expected dense matrix and compares it
    with ``interact.sparse``, which must be a ``csc_matrix``.

    The expected matrix is built from the categorical interaction matrix
    ``base``. For each continuous column, the non-zero entries of a copy of
    ``base`` are overwritten with that column's values, and the per-column
    blocks are stacked side by side.
    """
    interact = Interaction(cat=cat, cont=cont)
    assert interact.nobs == cat.shape[0]
    assert_frame_equal(cat, interact.cat)
    assert_frame_equal(cont, interact.cont)
    # ``.toarray()`` instead of the ``.A`` shortcut, which SciPy deprecated
    # in 1.11 and later removed from sparse matrices.
    base = category_interaction(category_product(cat), precondition=False).toarray()
    expected = []
    for _, column in cont.items():
        element = base.copy()
        # ``np.where`` yields non-zero positions in row-major order, so this
        # assignment lines up value i with row i only if every row of
        # ``base`` has exactly one non-zero -- assumed here; TODO confirm.
        element[np.where(element)] = column.to_numpy()
        expected.append(element)
    expected = np.column_stack(expected)
    actual = interact.sparse
    assert isinstance(actual, csc_matrix)
    # Compare against the object already fetched and type-checked above.
    assert_allclose(expected, actual.toarray())
# Example #4
# 0
def test_category_product(cat):
    """Verify ``category_product`` against an independently built joint key.

    With a single column, the product must equal that column (names are
    ignored).

    With several columns, a reference key is formed by adding each column's
    integer value scaled by ``10 ** (4 * position)``. The codes of the
    product and the codes of this reference key must determine each other in
    both directions, i.e. they induce the same grouping of the rows.
    """
    prod = category_product(cat)
    n_cols = cat.shape[1]
    if n_cols == 1:
        assert_series_equal(prod, cat.iloc[:, 0], check_names=False)
        return

    # NOTE(review): the reference key is collision-free only if every
    # category value is below 10**4 and the scaled sum fits in int64 --
    # assumed true for the ``cat`` fixture; TODO confirm.
    combined = cat.iloc[:, 0].astype("int64")
    for pos in range(1, n_cols):
        combined = combined + 10 ** (4 * pos) * cat.iloc[:, pos].astype("int64")
    reference = pd.Series(pd.Categorical(combined))

    codes = pd.DataFrame({"cat_prod": prod.cat.codes, "alt": reference.cat.codes})
    # Every code on one side must map to exactly one code on the other.
    for key, other in (("cat_prod", "alt"), ("alt", "cat_prod")):
        assert (codes.groupby(key)[other].nunique() == 1).all()
def test_category_product_not_cat(random_gen):
    """``category_product`` must raise ``TypeError`` for non-categorical input.

    Three plain integer columns (``"0"``, ``"1"``, ``"2"``) of 1000 draws
    from ``random_gen.randint(0, 10, 1000)`` are combined without converting
    them to categoricals.
    """
    columns = {}
    for idx in range(3):
        draws = random_gen.randint(0, 10, 1000)
        columns[str(idx)] = pd.Series(draws)
    cat = pd.DataFrame(columns)
    with pytest.raises(TypeError):
        category_product(cat)