コード例 #1
0
def test_stacked(es):
    """Check the graph rendered for an aggregation stacked on a transform.

    Builds ``Year`` over ``customers.cancel_date``, aggregates it onto the
    ``cohorts`` entity with ``Mode``, renders the result with
    ``graph_feature`` and asserts that:

    * every expected node id and DOT edge appears in the graph source, and
    * each primitive node appears exactly once with a label naming its
      step (transform is ``Step 1``, aggregation is ``Step 2``).
    """
    trans_feat = TransformFeature(es['customers']['cancel_date'], Year)
    stacked = AggregationFeature(trans_feat, es['cohorts'], Mode)
    graph = graph_feature(stacked).source

    feat_name = stacked.get_name()
    intermediate_name = trans_feat.get_name()
    # Node ids as emitted by the grapher for this two-primitive feature.
    agg_primitive = '0_{}_mode'.format(feat_name)
    trans_primitive = '1_{}_year'.format(intermediate_name)
    groupby_node = '{}_groupby_customers--cohort'.format(feat_name)

    # Expected DOT edges: input column -> transform -> intermediate column,
    # then (groupby key, intermediate column) -> groupby -> aggregation ->
    # final feature on the parent entity.
    trans_prim_edge = 'customers:cancel_date -> "{}"'.format(trans_primitive)
    intermediate_edge = '"{}" -> customers:"{}"'.format(
        trans_primitive, intermediate_name)
    groupby_edge = 'customers:cohort -> "{}"'.format(groupby_node)
    groupby_input = 'customers:"{}" -> "{}"'.format(intermediate_name,
                                                    groupby_node)
    agg_input = '"{}" -> "{}"'.format(groupby_node, agg_primitive)
    feat_edge = '"{}" -> cohorts:"{}"'.format(agg_primitive, feat_name)

    graph_components = [
        feat_name, intermediate_name, agg_primitive, trans_primitive,
        groupby_node, trans_prim_edge, intermediate_edge, groupby_edge,
        groupby_input, agg_input, feat_edge
    ]
    for component in graph_components:
        assert component in graph

    # Feature names contain regex metacharacters (at least parentheses, and
    # possibly others such as '.'), so escape the whole node id rather than
    # only '(' and ')' before embedding it in the pattern.
    expected_steps = [(agg_primitive, 'Step 2'), (trans_primitive, 'Step 1')]
    for node_id, step in expected_steps:
        pattern = r'"{}" \[label.*'.format(re.escape(node_id))
        node_lines = re.findall(pattern, graph)
        assert len(node_lines) == 1
        assert step in node_lines[0]
コード例 #2
0
def test_transform(es):
    """Check the graph rendered for a single transform feature.

    Applies ``Year`` to ``customers.cancel_date``, renders it with
    ``graph_feature`` and asserts that the graph source contains the feature
    name, the starred target-entity title, the primitive node, and both
    edges around it. It also checks that the ``customers`` HTML table
    appears once, with three rows mentioning the entity, the input column,
    and the new feature, in that order.
    """
    feature = TransformFeature(es['customers']['cancel_date'], Year)
    source = graph_feature(feature).source

    name = feature.get_name()
    year_node = '0_{}_year'.format(name)

    expected_fragments = (
        name,
        '\u2605 customers (target)',
        year_node,
        'customers:cancel_date -> "{}"'.format(year_node),
        '"{}" -> customers:"{}"'.format(year_node, name),
    )
    for fragment in expected_fragments:
        assert fragment in source

    # Pull out the customers entity's HTML-like table label and check its
    # rows individually.
    tables = re.findall(r"customers \[label=<\n<TABLE.*?</TABLE>>", source,
                        flags=re.DOTALL)
    assert len(tables) == 1

    table_rows = re.findall(r"<TR.*?</TR>", tables[0], flags=re.DOTALL)
    assert len(table_rows) == 3

    for wanted, table_row in zip(['customers', 'cancel_date', name],
                                 table_rows):
        assert wanted in table_row