Ejemplo n.º 1
0
def find_applicable_primitives(primitive):
    """Collect the transform and aggregation primitives stackable with ``primitive``.

    Every primitive class reported by featuretools'
    ``get_transform_primitives`` and ``get_aggregation_primitives`` is handed,
    together with ``primitive``, to ``find_stackable_primitives``, which
    decides what is applicable.

    Args:
        primitive: The primitive passed through as the second argument of
            ``find_stackable_primitives``.

    Returns:
        tuple: ``(transforms, aggregations)`` -- the results of
        ``find_stackable_primitives`` for the transform candidates and the
        aggregation candidates, in that order.
    """
    # Imported lazily at call time rather than at module import time.
    from featuretools.primitives.utils import (
        get_aggregation_primitives,
        get_transform_primitives,
    )

    transform_candidates = list(get_transform_primitives().values())
    aggregation_candidates = list(get_aggregation_primitives().values())

    stackable_transforms = find_stackable_primitives(
        transform_candidates, primitive)
    stackable_aggregations = find_stackable_primitives(
        aggregation_candidates, primitive)
    return stackable_transforms, stackable_aggregations
Ejemplo n.º 2
0
def get_valid_primitives(entityset,
                         target_dataframe_name,
                         max_depth=2,
                         selected_primitives=None):
    """
    Returns two lists of primitives (aggregation and transform) containing
    primitives that can be applied to the specific target dataframe to create
    features.  If the optional 'selected_primitives' parameter is not used,
    all discoverable primitives will be considered.

    Note:
        When using a ``max_depth`` greater than 1, some primitives returned by
        this function may not create any features if passed to DFS alone.  These
        primitives rely on features created by other primitives as input
        (primitive stacking).

    Args:
        entityset (EntitySet): An already initialized entityset
        target_dataframe_name (str): Name of dataframe to create features for.
        max_depth (int, optional): Maximum allowed depth of features.
        selected_primitives(list[str or AggregationPrimitive/TransformPrimitive], optional):
            list of primitives to consider when looking for valid primitives.
            If None, all primitives will be considered
    Returns:
       list[AggregationPrimitive], list[TransformPrimitive]:
           The list of valid aggregation primitives and the list of valid
           transform primitives.
    Raises:
        ValueError: If ``entityset.dataframe_type`` does not match the value of
            any ``Library`` member, if a string in ``selected_primitives`` is
            not a recognized primitive name, or if a non-string entry is not
            an AggregationPrimitive or TransformPrimitive class.
    """
    available_aggs = get_aggregation_primitives()
    available_trans = get_transform_primitives()

    # Resolve the Library member for the entityset's dataframe type.  Fail
    # with a clear message when nothing matches, instead of leaving the name
    # unbound and crashing later with an UnboundLocalError.
    df_library = next(
        (library for library in Library
         if library.value == entityset.dataframe_type),
        None,
    )
    if df_library is None:
        raise ValueError(
            f"Unsupported dataframe type '{entityset.dataframe_type}' "
            "for the given entityset")

    agg_primitives = []
    trans_primitives = []
    if selected_primitives:
        for prim in selected_primitives:
            if isinstance(prim, str):
                # Names are resolved to primitive classes via the registries.
                if prim in available_aggs:
                    prim = available_aggs[prim]
                    prim_list = agg_primitives
                elif prim in available_trans:
                    prim = available_trans[prim]
                    prim_list = trans_primitives
                else:
                    raise ValueError(
                        f"'{prim}' is not a recognized primitive name")
            # ``issubclass`` raises TypeError for non-class objects, so check
            # ``isinstance(prim, type)`` first; anything that is not a
            # primitive class falls through to the ValueError below.
            elif isinstance(prim, type) and issubclass(
                    prim, AggregationPrimitive):
                prim_list = agg_primitives
            elif isinstance(prim, type) and issubclass(
                    prim, TransformPrimitive):
                prim_list = trans_primitives
            else:
                raise ValueError(
                    f"Selected primitive {prim} is not an "
                    "AggregationPrimitive, TransformPrimitive, or str")
            # Keep only primitives compatible with the dataframe library.
            if df_library in prim.compatibility:
                prim_list.append(prim)
    else:
        agg_primitives = [
            agg for agg in available_aggs.values()
            if df_library in agg.compatibility
        ]
        trans_primitives = [
            trans for trans in available_trans.values()
            if df_library in trans.compatibility
        ]

    # Run DFS with the candidates and see which ones yield features.
    dfs_object = DeepFeatureSynthesis(
        target_dataframe_name,
        entityset,
        agg_primitives=agg_primitives,
        trans_primitives=trans_primitives,
        max_depth=max_depth,
    )

    features = dfs_object.build_features()

    trans, agg, _, _ = _categorize_features(features)

    trans_unused = get_unused_primitives(trans_primitives, trans)
    agg_unused = get_unused_primitives(agg_primitives, agg)

    # switch from str to class
    agg_unused = [available_aggs[name] for name in agg_unused]
    trans_unused = [available_trans[name] for name in trans_unused]

    # A candidate is "valid" when it is not reported as unused.
    used_agg_prims = set(agg_primitives).difference(set(agg_unused))
    used_trans_prims = set(trans_primitives).difference(set(trans_unused))
    return list(used_agg_prims), list(used_trans_prims)