Example #1
0
def test_get_valid_primitives_custom_primitives(pd_es):
    """Locally defined primitives are absent from the default results.

    Two ad-hoc primitive classes are declared inside the test. When
    ``get_valid_primitives`` is called for the "log" dataframe without
    selecting them, neither class appears in the returned lists. Passing
    the transform by its string name ``"add_three"`` is expected to raise
    a ``ValueError`` saying the name is not recognized.
    """

    class ThreeMostCommonCat(AggregationPrimitive):
        name = "n_most_common_categorical"
        input_types = [ColumnSchema(semantic_tags={"category"})]
        return_type = ColumnSchema(semantic_tags={"category"})
        number_output_features = 3

    class AddThree(TransformPrimitive):
        name = "add_three"
        # Three numeric inputs, each described by its own schema instance.
        input_types = [ColumnSchema(semantic_tags="numeric") for _ in range(3)]
        return_type = ColumnSchema(semantic_tags="numeric")
        commutative = True
        compatibility = [Library.PANDAS, Library.DASK, Library.KOALAS]

    aggregations, transforms = get_valid_primitives(pd_es, "log")
    assert ThreeMostCommonCat not in aggregations
    assert AddThree not in transforms

    unknown_name_error = "'add_three' is not a recognized primitive name"
    with pytest.raises(ValueError, match=unknown_name_error):
        # The call raises, so its return value is deliberately not bound.
        get_valid_primitives(pd_es, "log", 2, [ThreeMostCommonCat, "add_three"])
Example #2
0
def test_get_valid_primitives_selected_primitives(es):
    """Only the explicitly selected primitives can be returned.

    For the "log" dataframe with ``Hour`` and ``Count`` selected, the
    aggregation result is exactly ``Count`` and the transform result is
    exactly ``Hour``. For "products" with only ``Hour`` selected and
    ``max_depth=1``, both results are expected to be empty.
    """
    chosen = [Hour, Count]
    aggregations, transforms = get_valid_primitives(
        es,
        "log",
        selected_primitives=chosen,
    )
    assert set(aggregations) == {Count}
    assert set(transforms) == {Hour}

    chosen = [Hour]
    aggregations, transforms = get_valid_primitives(
        es,
        "products",
        selected_primitives=chosen,
        max_depth=1,
    )
    assert not set(aggregations)
    assert not set(transforms)
Example #3
0
def test_invalid_primitive(es):
    """Invalid entries in ``selected_primitives`` raise ``ValueError``.

    Two cases are exercised against the "log" dataframe: an unknown
    primitive name given as a string, and an object (the ``Library`` enum
    class) that is neither a primitive class nor a string.
    """
    unknown_name_error = "'foobar' is not a recognized primitive name"
    with pytest.raises(ValueError, match=unknown_name_error):
        get_valid_primitives(
            es,
            target_dataframe_name="log",
            selected_primitives=["foobar"],
        )

    wrong_type_error = (
        "Selected primitive <enum 'Library'> is not an "
        "AggregationPrimitive, TransformPrimitive, or str"
    )
    with pytest.raises(ValueError, match=wrong_type_error):
        get_valid_primitives(
            es,
            target_dataframe_name="log",
            selected_primitives=[Library],
        )
Example #4
0
def test_get_valid_primitives_single_table(transform_es):
    """A single-dataframe entityset warns and yields no aggregations.

    The call for the "first" dataframe is expected to emit a
    ``UserWarning`` about ``max_depth`` being changed to 1. Afterwards the
    aggregation result must be empty and ``IsIn`` must be among the
    transforms.
    """
    expected_warning = (
        "Only one dataframe in entityset, changing max_depth to 1 "
        "since deeper features cannot be created"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        aggregations, transforms = get_valid_primitives(transform_es, "first")

    assert not set(aggregations)
    assert IsIn in transforms
Example #5
0
def test_primitive_compatibility(es):
    """``TimeSincePrevious`` is returned only for pandas entitysets.

    With ``TimeSincePrevious`` as the sole selected primitive on the
    "customers" dataframe, exactly one transform is expected when the
    entityset's dataframe type is pandas, and none otherwise.
    """
    _, transforms = get_valid_primitives(
        es,
        "customers",
        selected_primitives=[TimeSincePrevious],
    )

    is_pandas = es.dataframe_type == Library.PANDAS.value
    expected_count = 1 if is_pandas else 0
    assert len(transforms) == expected_count
def test_get_valid_primitives_with_dfs_kwargs(es):
    """Extra dfs keyword arguments affect which primitives are returned.

    Baseline: on "customers" with ``Hour``, ``Count`` and ``Not``
    selected, ``Count`` is the only aggregation and ``Hour`` and ``Not``
    are the transforms. Ignoring the ``loves_ice_cream`` column is
    expected to drop ``Not`` from the transforms, and ignoring the "log"
    dataframe when targeting "products" is expected to leave both results
    empty.
    """
    aggregations, transforms = get_valid_primitives(
        es,
        "customers",
        selected_primitives=[Hour, Count, Not],
    )
    assert set(aggregations) == {Count}
    assert set(transforms) == {Hour, Not}

    # Other dfs parameters are accepted and are applied to the result.
    skipped_columns = {"customers": ["loves_ice_cream"]}
    aggregations, transforms = get_valid_primitives(
        es,
        "customers",
        selected_primitives=[Hour, Count, Not],
        ignore_columns=skipped_columns,
    )
    assert set(aggregations) == {Count}
    assert set(transforms) == {Hour}

    aggregations, transforms = get_valid_primitives(
        es,
        "products",
        selected_primitives=[Hour, Count],
        ignore_dataframes=["log"],
    )
    assert not set(aggregations)
    assert not set(transforms)
Example #7
0
def test_get_valid_primitives_all_primitives(es):
    """With no selection given, common primitives are among the results.

    Calling ``get_valid_primitives`` for "customers" without
    ``selected_primitives`` is expected to include ``Count`` in the
    aggregations and ``Hour`` in the transforms.
    """
    aggregations, transforms = get_valid_primitives(es, "customers")
    assert Count in aggregations
    assert Hour in transforms