def test_stacked(es):
    """Check the graph rendered for an aggregation stacked on a transform.

    Builds a ``Mode`` aggregation (onto ``es['cohorts']``) over the ``Year``
    transform of ``es['customers']['cancel_date']``, renders it with
    ``graph_feature`` and verifies on the returned ``.source`` text that:

    * every expected node name and edge string is present, and
    * exactly one node definition exists for each primitive, with the
      transform labelled ``Step 1`` and the aggregation labelled ``Step 2``.

    Args:
        es: fixture indexed here by ``'customers'`` and ``'cohorts'``
            (presumably the shared test entity set -- confirm in conftest).
    """
    trans_feat = TransformFeature(es['customers']['cancel_date'], Year)
    stacked = AggregationFeature(trans_feat, es['cohorts'], Mode)
    graph = graph_feature(stacked).source

    feat_name = stacked.get_name()
    intermediate_name = trans_feat.get_name()

    # Node identifiers expected in the graph source.
    agg_primitive = '0_{}_mode'.format(feat_name)
    trans_primitive = '1_{}_year'.format(intermediate_name)
    groupby_node = '{}_groupby_customers--cohort'.format(feat_name)

    # Edge strings expected in the graph source.
    trans_prim_edge = 'customers:cancel_date -> "{}"'.format(trans_primitive)
    intermediate_edge = '"{}" -> customers:"{}"'.format(
        trans_primitive, intermediate_name)
    groupby_edge = 'customers:cohort -> "{}"'.format(groupby_node)
    groupby_input = 'customers:"{}" -> "{}"'.format(intermediate_name,
                                                    groupby_node)
    agg_input = '"{}" -> "{}"'.format(groupby_node, agg_primitive)
    feat_edge = '"{}" -> cohorts:"{}"'.format(agg_primitive, feat_name)

    graph_components = [
        feat_name,
        intermediate_name,
        agg_primitive,
        trans_primitive,
        groupby_node,
        trans_prim_edge,
        intermediate_edge,
        groupby_edge,
        groupby_input,
        agg_input,
        feat_edge,
    ]
    for component in graph_components:
        assert component in graph

    # The node names are matched literally inside a regex.  re.escape
    # neutralises every metacharacter in the generated names, not only the
    # parentheses, so characters such as '.' cannot act as wildcards.
    agg_node = re.findall(
        r'"{}" \[label.*'.format(re.escape(agg_primitive)), graph)
    assert len(agg_node) == 1
    assert 'Step 2' in agg_node[0]

    trans_node = re.findall(
        r'"{}" \[label.*'.format(re.escape(trans_primitive)), graph)
    assert len(trans_node) == 1
    assert 'Step 1' in trans_node[0]
def test_transform(es):
    """Check the graph rendered for a single ``Year`` transform feature.

    Renders the ``Year`` transform of ``es['customers']['cancel_date']`` with
    ``graph_feature`` and verifies on the returned ``.source`` text that the
    expected node names and edges are present, and that exactly one
    ``customers`` table is emitted whose three rows mention, in order, the
    table name, the ``cancel_date`` input and the new feature.

    Args:
        es: fixture indexed here by ``'customers'`` (presumably the shared
            test entity set -- confirm in conftest).
    """
    year_feat = TransformFeature(es['customers']['cancel_date'], Year)
    dot_source = graph_feature(year_feat).source

    name = year_feat.get_name()
    primitive_id = '0_{}_year'.format(name)

    # Substrings that must each appear somewhere in the rendered source.
    expected_snippets = (
        name,
        '\u2605 customers (target)',
        primitive_id,
        'customers:cancel_date -> "{}"'.format(primitive_id),
        '"{}" -> customers:"{}"'.format(primitive_id, name),
    )
    for snippet in expected_snippets:
        assert snippet in dot_source

    # Exactly one table definition for the customers node.
    tables = re.findall(r"customers \[label=<\n<TABLE.*?</TABLE>>",
                        dot_source, re.DOTALL)
    assert len(tables) == 1

    # Exactly three rows inside that table.
    table_rows = re.findall(r"<TR.*?</TR>", tables[0], re.DOTALL)
    assert len(table_rows) == 3

    # Rows are checked positionally against the expected contents.
    expected_in_rows = ['customers', 'cancel_date', name]
    for expected, row in zip(expected_in_rows, table_rows):
        assert expected in row