def test_join_composite(party_data, expected):
    """Join two input relations, divide the result, rewrite, and compare.

    Builds relations "in1" and "in2" from the first two entries of
    ``party_data``, joins them on each side's first column name, applies a
    divide op, collects to parties {1, 2}, then runs the PushDown, PushUp,
    InsertCloseOps and InsertOpenOps rewrites (in that order) over the DAG
    before handing it to ``compare_to_expected``.
    """
    left_party = party_data[0]
    right_party = party_data[1]

    left_cols = create_cols(left_party)
    right_cols = create_cols(right_party)
    left_rel = create("in1", left_cols, left_party["stored_with"])
    right_rel = create("in2", right_cols, right_party["stored_with"])

    left_names = left_party["col_names"]
    right_names = right_party["col_names"]

    joined = join(left_rel, right_rel, "join", left_names[:1], right_names[:1])
    divided = divide(joined, "div", left_names[1], [right_names[1]])
    collect(divided, {1, 2})

    dag = Dag({left_rel, right_rel})
    # Each pass is instantiated immediately before it rewrites the DAG.
    for rewrite_pass in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewrite_pass().rewrite(dag)

    compare_to_expected(dag, expected)
def test_join_composite_agg(party_data, expected):
    """Join two input relations, aggregate ("mean"), multiply, and compare.

    The join uses each party's first column name. The aggregate groups on the
    first party's first column name and aggregates over its second. The
    multiply targets that second column name with the scalar operand 7.
    After collecting to parties {1, 2}, the PushDown, PushUp, InsertCloseOps
    and InsertOpenOps rewrites run in that order and the DAG is passed to
    ``compare_to_expected``.
    """
    left_party = party_data[0]
    right_party = party_data[1]

    left_cols = create_cols(left_party)
    right_cols = create_cols(right_party)
    left_rel = create("in1", left_cols, left_party["stored_with"])
    right_rel = create("in2", right_cols, right_party["stored_with"])

    left_names = left_party["col_names"]
    joined = join(
        left_rel,
        right_rel,
        "join",
        left_names[:1],
        right_party["col_names"][:1],
    )
    averaged = aggregate(joined, "agg", [left_names[0]], left_names[1], "mean")
    scaled = multiply(averaged, "mult", left_names[1], [7])
    collect(scaled, {1, 2})

    dag = Dag({left_rel, right_rel})
    # Each pass is instantiated immediately before it rewrites the DAG.
    for rewrite_pass in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewrite_pass().rewrite(dag)

    compare_to_expected(dag, expected)
def test_concat_composite_three_party_agg_mean(party_data, expected):
    """Concat three relations, then project, multiply, aggregate ("mean").

    Relations "in1".."in3" are built from the first three entries of
    ``party_data``. Every op uses the first party's column names. The result
    is collected to parties {1, 2, 3}, rewritten by PushDown, PushUp,
    InsertCloseOps and InsertOpenOps (in that order), and passed to
    ``compare_to_expected``.
    """
    parties = (party_data[0], party_data[1], party_data[2])

    # All column lists are created before any relation is.
    column_sets = [create_cols(party) for party in parties]
    rel_one, rel_two, rel_three = [
        create(rel_name, cols, party["stored_with"])
        for rel_name, cols, party in zip(("in1", "in2", "in3"), column_sets, parties)
    ]

    names = parties[0]["col_names"]
    combined = concat([rel_one, rel_two, rel_three], "concat", names)
    projected = project(combined, "conc", names[:2])
    scaled = multiply(projected, "mult", names[0], [5])
    averaged = aggregate(scaled, "agg", [names[0]], names[1], "mean")
    collect(averaged, {1, 2, 3})

    dag = Dag({rel_one, rel_two, rel_three})
    # Each pass is instantiated immediately before it rewrites the DAG.
    for rewrite_pass in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewrite_pass().rewrite(dag)

    compare_to_expected(dag, expected)
def test_project(party_data, expected):
    """Concat two relations, project onto the first column, and compare.

    Only the PushDown rewrite runs before the DAG is passed to
    ``compare_to_expected``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    names = first_party["col_names"]
    combined = concat([first_rel, second_rel], "concat", names)
    projected = project(combined, "proj", [names[0]])
    collect(projected, {1, 2})

    dag = Dag({first_rel, second_rel})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
def test_filter_by_scalar(party_data, expected):
    """Concat two relations, filter the first column with ``> 10``, compare.

    Only the PushDown rewrite runs before the DAG is passed to
    ``compare_to_expected``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    names = first_party["col_names"]
    combined = concat([first_rel, second_rel], "concat", names)
    filtered = filter_by(combined, "filt", names[0], ">", 10)
    collect(filtered, {1, 2})

    dag = Dag({first_rel, second_rel})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
def test_multiply_target_existing(party_data, expected):
    """Concat two relations, multiply into the first column, and compare.

    The multiply targets the first party's first column name, with operands
    being its second column name and the scalar 10. Only the PushDown rewrite
    runs before the DAG is passed to ``compare_to_expected``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    names = first_party["col_names"]
    combined = concat([first_rel, second_rel], "concat", names)
    operands = [names[1], 10]
    product = multiply(combined, "mult", names[0], operands)
    collect(product, {1, 2})

    dag = Dag({first_rel, second_rel})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
def test_project_reorder_cols(party_data, expected):
    """Concat two relations, project with reversed column order, compare.

    After the PushDown rewrite and the ``compare_to_expected`` check, this
    also asserts that each node from ``dag.top_sort()`` has output column
    names equal to the matching ``expected["ownership_data"]`` entry's
    ``"col_names"``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    names = first_party["col_names"]
    combined = concat([first_rel, second_rel], "concat", names)
    reordered = project(combined, "proj", names[::-1])
    collect(reordered, {1, 2})

    dag = Dag({first_rel, second_rel})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)

    sorted_nodes = dag.top_sort()
    expected_names = [entry["col_names"] for entry in expected["ownership_data"]]
    # zip stops at the shorter sequence, so only paired nodes are checked.
    names_match = [
        [column.name for column in node.out_rel.columns] == wanted
        for node, wanted in zip(sorted_nodes, expected_names)
    ]
    assert all(names_match)
def test_concat(party_data, expected):
    """Concat two relations, collect, rewrite, and compare.

    The concat output is collected to parties {1, 2}. The PushDown, PushUp,
    InsertCloseOps and InsertOpenOps rewrites then run in that order before
    the DAG is passed to ``compare_to_expected``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    combined = concat([first_rel, second_rel], "concat", first_party["col_names"])
    collect(combined, {1, 2})

    dag = Dag({first_rel, second_rel})
    # Each pass is instantiated immediately before it rewrites the DAG.
    for rewrite_pass in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewrite_pass().rewrite(dag)

    compare_to_expected(dag, expected)
def test_agg_variance(party_data, expected):
    """Concat two relations, aggregate ("variance"), multiply, and compare.

    The aggregate groups on the first party's first column name and
    aggregates over its second. The multiply targets the first column name
    with operands being the second column name and the scalar 7. The PushDown
    and PushUp rewrites run (in that order) before the DAG is passed to
    ``compare_to_expected``.
    """
    first_party = party_data[0]
    second_party = party_data[1]

    first_cols = create_cols(first_party)
    second_cols = create_cols(second_party)
    first_rel = create("in1", first_cols, first_party["stored_with"])
    second_rel = create("in2", second_cols, second_party["stored_with"])

    names = first_party["col_names"]
    combined = concat([first_rel, second_rel], "concat", names)
    spread = aggregate(combined, "variance", [names[0]], names[1], "variance")
    scaled = multiply(spread, "mult", names[0], [names[1], 7])
    collect(scaled, {1, 2})

    dag = Dag({first_rel, second_rel})
    # Each pass is instantiated immediately before it rewrites the DAG.
    for rewrite_pass in (PushDown, PushUp):
        rewrite_pass().rewrite(dag)

    compare_to_expected(dag, expected)