예제 #1
0
def test_multioutput(es):
    """Check the rendered graph for one output slice of a multi-output feature.

    Builds an ``NMostCommon`` aggregation of ``log.zipcode`` onto ``sessions``,
    takes output slice 0, and verifies that the graph source contains the
    expected node names, table labels and edges, and that the ``log`` and
    ``sessions`` tables each hold exactly the expected rows.
    """
    zipcode = IdentityFeature(es["log"].ww["zipcode"])
    multi = AggregationFeature(zipcode, "sessions", NMostCommon)
    sliced = FeatureOutputSlice(multi, 0)
    source = graph_feature(sliced).source

    sliced_name = sliced.get_name()
    multi_name = multi.get_name()
    # Node names are derived from the un-sliced feature's name.
    primitive_node = "0_{}_n_most_common".format(multi_name)
    groupby_node = "{}_groupby_log--session_id".format(multi_name)

    sessions_label = "\u2605 sessions (target)"
    log_label = "log"

    expected_fragments = (
        sliced_name,
        primitive_node,
        groupby_node,
        sessions_label,
        log_label,
        # edges: groupby key, groupby input, primitive input, final feature
        'log:session_id -> "{}"'.format(groupby_node),
        'log:zipcode -> "{}"'.format(groupby_node),
        '"{}" -> "{}"'.format(groupby_node, primitive_node),
        '"{}" -> sessions:"{}"'.format(primitive_node, sliced_name),
    )
    for fragment in expected_fragments:
        assert fragment in source

    expected_rows = {
        "log": [log_label, "zipcode", "session_id"],
        "sessions": [sessions_label, sliced_name],
    }
    for table_name, remaining in expected_rows.items():
        table_pattern = r"{} \[label=<\n<TABLE.*?</TABLE>>".format(table_name)
        tables = re.findall(table_pattern, source, re.DOTALL)
        assert len(tables) == 1

        table_rows = re.findall(r"<TR.*?</TR>", tables[0], re.DOTALL)
        assert len(table_rows) == len(remaining)
        for table_row in table_rows:
            # Every row must contain one expected label that has not yet been
            # claimed by an earlier row; claimed labels are dropped.
            hit = next((label for label in remaining if label in table_row), None)
            assert hit is not None
            remaining.remove(hit)
def test_stacked(es, trans_feat):
    """Check the graph of a ``Mode`` aggregation stacked on ``trans_feat``.

    Verifies that the graph source contains the expected node names and edges
    between ``customers`` and ``cohorts``, and that the transform primitive
    node is labelled ``Step 1`` while the aggregation primitive node is
    labelled ``Step 2``.

    Args:
        es: entity set fixture providing the ``cohorts`` target.
        trans_feat: transform feature fixture; the expected node names assume
            it is a ``year`` primitive over ``customers.cancel_date``
            (NOTE(review): defined outside this block -- confirm).
    """
    stacked = AggregationFeature(trans_feat, es['cohorts'], Mode)
    graph = graph_feature(stacked).source

    feat_name = stacked.get_name()
    intermediate_name = trans_feat.get_name()
    agg_primitive = '0_{}_mode'.format(feat_name)
    trans_primitive = '1_{}_year'.format(intermediate_name)
    groupby_node = '{}_groupby_customers--cohort'.format(feat_name)

    trans_prim_edge = 'customers:cancel_date -> "{}"'.format(trans_primitive)
    intermediate_edge = '"{}" -> customers:"{}"'.format(trans_primitive, intermediate_name)
    groupby_edge = 'customers:cohort -> "{}"'.format(groupby_node)
    groupby_input = 'customers:"{}" -> "{}"'.format(intermediate_name, groupby_node)
    agg_input = '"{}" -> "{}"'.format(groupby_node, agg_primitive)
    feat_edge = '"{}" -> cohorts:"{}"'.format(agg_primitive, feat_name)

    graph_components = [feat_name, intermediate_name, agg_primitive, trans_primitive,
                        groupby_node, trans_prim_edge, intermediate_edge, groupby_edge,
                        groupby_input, agg_input, feat_edge]
    for component in graph_components:
        assert component in graph

    # Each primitive node must be declared exactly once and carry its step
    # label. re.escape makes the node name match literally: feature names hold
    # regex metacharacters beyond parentheses (e.g. '.'), which escaping only
    # '(' and ')' by hand would leave active.
    expected_steps = ((agg_primitive, 'Step 2'), (trans_primitive, 'Step 1'))
    for node_name, step_label in expected_steps:
        node_regex = '"{}" \\[label.*'.format(re.escape(node_name))
        node_lines = re.findall(node_regex, graph)
        assert len(node_lines) == 1
        assert step_label in node_lines[0]
예제 #3
0
def test_multioutput(es):
    """Check the rendered graph for one output slice of a multi-output feature.

    Builds an ``NMostCommon`` aggregation of ``log.zipcode`` onto ``sessions``,
    takes output slice 0, and verifies that the graph source contains the
    expected node names, table labels and edges, and that the ``log`` and
    ``sessions`` tables each hold exactly the expected rows.
    """
    multi = AggregationFeature(es['log']['zipcode'], es['sessions'],
                               NMostCommon)
    sliced = FeatureOutputSlice(multi, 0)
    source = graph_feature(sliced).source

    sliced_name = sliced.get_name()
    multi_name = multi.get_name()
    # Node names are derived from the un-sliced feature's name.
    primitive_node = '0_{}_n_most_common'.format(multi_name)
    groupby_node = '{}_groupby_log--session_id'.format(multi_name)

    sessions_label = '\u2605 sessions (target)'
    log_label = 'log'

    expected_fragments = (
        sliced_name,
        primitive_node,
        groupby_node,
        sessions_label,
        log_label,
        # edges: groupby key, groupby input, primitive input, final feature
        'log:session_id -> "{}"'.format(groupby_node),
        'log:zipcode -> "{}"'.format(groupby_node),
        '"{}" -> "{}"'.format(groupby_node, primitive_node),
        '"{}" -> sessions:"{}"'.format(primitive_node, sliced_name),
    )
    for fragment in expected_fragments:
        assert fragment in source

    expected_rows = {
        'log': [log_label, 'zipcode', 'session_id'],
        'sessions': [sessions_label, sliced_name],
    }
    for entity_id, remaining in expected_rows.items():
        table_pattern = r"{} \[label=<\n<TABLE.*?</TABLE>>".format(entity_id)
        tables = re.findall(table_pattern, source, re.DOTALL)
        assert len(tables) == 1

        table_rows = re.findall(r"<TR.*?</TR>", tables[0], re.DOTALL)
        assert len(table_rows) == len(remaining)
        for table_row in table_rows:
            # Every row must contain one expected label that has not yet been
            # claimed by an earlier row; claimed labels are dropped.
            hit = next((label for label in remaining if label in table_row), None)
            assert hit is not None
            remaining.remove(hit)
def test_aggregation(es):
    """Check the rendered graph for a ``Count`` aggregation of ``log.id``.

    Aggregates ``log.id`` onto ``sessions`` and verifies that the graph source
    contains the expected node names, table labels and edges, and that the
    ``log`` and ``sessions`` tables each hold exactly the expected rows.
    """
    log_id = IdentityFeature(es['log'].ww['id'])
    count_feat = AggregationFeature(log_id, 'sessions', Count)
    source = graph_feature(count_feat).source

    count_name = count_feat.get_name()
    primitive_node = '0_{}_count'.format(count_name)
    groupby_node = '{}_groupby_log--session_id'.format(count_name)

    sessions_label = '\u2605 sessions (target)'
    log_label = 'log'

    expected_fragments = (
        count_name,
        primitive_node,
        groupby_node,
        sessions_label,
        log_label,
        # edges: groupby key, groupby input, primitive input, final feature
        'log:session_id -> "{}"'.format(groupby_node),
        'log:id -> "{}"'.format(groupby_node),
        '"{}" -> "{}"'.format(groupby_node, primitive_node),
        '"{}" -> sessions:"{}"'.format(primitive_node, count_name),
    )
    for fragment in expected_fragments:
        assert fragment in source

    expected_rows = {
        'log': [log_label, 'id', 'session_id'],
        'sessions': [sessions_label, count_name],
    }
    for table_name, remaining in expected_rows.items():
        table_pattern = r"{} \[label=<\n<TABLE.*?</TABLE>>".format(table_name)
        tables = re.findall(table_pattern, source, re.DOTALL)
        assert len(tables) == 1

        table_rows = re.findall(r"<TR.*?</TR>", tables[0], re.DOTALL)
        assert len(table_rows) == len(remaining)
        for table_row in table_rows:
            # Every row must contain one expected label that has not yet been
            # claimed by an earlier row; claimed labels are dropped.
            hit = next((label for label in remaining if label in table_row), None)
            assert hit is not None
            remaining.remove(hit)