def test_concat_composite_three_party_agg_sum(party_data, expected):
    """
    Build a three-input concat -> multiply -> sum aggregate -> project -> collect
    workflow, apply the PushDown, PushUp, InsertCloseOps and InsertOpenOps
    rewrites (in that order), and compare the resulting DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-2 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second, third = party_data[0], party_data[1], party_data[2]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    third_cols = create_cols(third)

    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])
    rel_three = create("in3", third_cols, third["stored_with"])

    # All downstream operators are addressed via the first party's column names.
    names = first["col_names"]
    concatenated = concat([rel_one, rel_two, rel_three], "concat", names)
    multiplied = multiply(concatenated, "mult", names[0], [names[1], 5])
    summed = aggregate(multiplied, "agg", [names[0]], names[1], "sum")
    projected = project(summed, "proj", [names[0]])
    collect(projected, {1, 2, 3})

    dag = Dag({rel_one, rel_two, rel_three})
    for rewriter_cls in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewriter_cls().rewrite(dag)

    # The sort result is not used; the call itself is kept from the original.
    # NOTE(review): confirm whether top_sort() is needed here at all.
    dag.top_sort()
    compare_to_expected(dag, expected)
def test_join_composite(party_data, expected):
    """
    Build a two-input join -> divide -> collect workflow, apply the PushDown,
    PushUp, InsertCloseOps and InsertOpenOps rewrites (in that order), and
    compare the resulting DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    first_names = first["col_names"]
    second_names = second["col_names"]
    # Join on the leading column of each input.
    joined = join(rel_one, rel_two, "join", first_names[:1], second_names[:1])
    divided = divide(joined, "div", first_names[1], [second_names[1]])
    collect(divided, {1, 2})

    dag = Dag({rel_one, rel_two})
    for rewriter_cls in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewriter_cls().rewrite(dag)

    compare_to_expected(dag, expected)
def test_join_composite_agg(party_data, expected):
    """
    Build a two-input join -> mean aggregate -> multiply -> collect workflow,
    apply the PushDown, PushUp, InsertCloseOps and InsertOpenOps rewrites (in
    that order), and compare the resulting DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    first_names = first["col_names"]
    joined = join(rel_one, rel_two, "join", first_names[:1], second["col_names"][:1])
    averaged = aggregate(joined, "agg", [first_names[0]], first_names[1], "mean")
    multiplied = multiply(averaged, "mult", first_names[1], [7])
    collect(multiplied, {1, 2})

    dag = Dag({rel_one, rel_two})
    for rewriter_cls in (PushDown, PushUp, InsertCloseOps, InsertOpenOps):
        rewriter_cls().rewrite(dag)

    compare_to_expected(dag, expected)
def test_concat_composite(party_data, expected):
    """
    Multiply the first input and divide the second, concat the two results,
    then sum-aggregate, project and collect. Only InsertCloseOps and
    InsertOpenOps are applied before comparing the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    first_names = first["col_names"]
    second_names = second["col_names"]

    # Per-input arithmetic happens before the inputs are combined.
    mult_one = multiply(rel_one, "mult_one", first_names[0], [first_names[1], 5])
    div_two = divide(rel_two, "div_two", second_names[0], [second_names[1], 5])

    concatenated = concat([mult_one, div_two], "concat", first_names)
    summed = aggregate(concatenated, "agg", [first_names[0]], first_names[1], "sum")
    projected = project(summed, "proj", [first_names[0]])
    collect(projected, {1, 2})

    dag = Dag({rel_one, rel_two})
    for rewriter_cls in (InsertCloseOps, InsertOpenOps):
        rewriter_cls().rewrite(dag)

    compare_to_expected(dag, expected)
def compile(protocol: callable, enable_optimizations: [bool, None] = True):
    """
    Call `protocol`, wrap its return value in a Dag, compile that Dag, and
    return it.

    :param protocol: zero-argument callable whose result is handed to Dag(...)
    :param enable_optimizations: when truthy the Dag goes through compile_dag;
        otherwise (False or None) through compile_dag_without_optimizations
    :return: the Dag after the chosen compile routine has been applied to it
    """
    # NOTE(review): `[bool, None]` is a list literal, not a typing construct;
    # it is kept as-is so the signature stays untouched.
    dag = Dag(protocol())
    run_compile = compile_dag if enable_optimizations else compile_dag_without_optimizations
    run_compile(dag)
    return dag
def get_next_partition(self, dag: Dag, candidate_roots: set):
    """
    Pop one root off `dag.roots` and assemble the partition that starts there.

    :param dag: DAG being partitioned; a root is popped from `dag.roots`, and
        `dag.roots` is reassigned when it runs empty while candidates remain
    :param candidate_roots: roots gathered so far for later partitions; it is
        not mutated here (a union produces a new set)
    :return: tuple of (Dag built from this partition's roots, candidate roots
        still to be handed back to the caller)
    """
    # Original author's note: the while-loop condition in the parent method
    # guarantees there is a root to pop here.
    start = dag.roots.pop()
    partition_roots = {start}

    discovered = self._get_next_partition(start, start.requires_mpc(), set())
    candidate_roots = candidate_roots.union(discovered)

    partition_roots, candidate_roots = self._iterate_over_remaining_roots(
        dag, start, candidate_roots, partition_roots
    )

    if dag.roots or not candidate_roots:
        return Dag(partition_roots), candidate_roots

    # dag.roots is exhausted but candidates remain: they become the DAG's new
    # roots, and an empty candidate set is returned alongside the partition.
    dag.roots = candidate_roots
    return Dag(partition_roots), set()
def test_project_reorder_cols(party_data, expected):
    """
    Concat two inputs, project the columns in reversed order, collect, apply
    PushDown, and compare the DAG to `expected`. Afterwards also check each
    node's output column names against expected["ownership_data"].

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded to compare_to_expected; its "ownership_data"
        entries each supply a "col_names" list
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    projected = project(concatenated, "proj", names[::-1])
    collect(projected, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)
    compare_to_expected(dag, expected)

    # Pair nodes (topological order) with the expected names; zip stops at the
    # shorter of the two sequences.
    node_name_pairs = zip(
        dag.top_sort(),
        [entry["col_names"] for entry in expected["ownership_data"]],
    )
    matches = [
        [col.name for col in node.out_rel.columns] == wanted
        for node, wanted in node_name_pairs
    ]
    assert all(matches)
def test_agg_mean(party_data, expected):
    """
    Build a single-input mean aggregate -> divide -> collect workflow, run
    compile_dag on it, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; only entry 0 is used and
        its "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    only = party_data[0]

    rel_one = create("in1", create_cols(only), only["stored_with"])

    names = only["col_names"]
    averaged = aggregate(rel_one, "agg", names[:1], names[1], "mean")
    divided = divide(averaged, "div", names[1], [10])
    collect(divided, {1, 2})

    dag = Dag({rel_one})
    compile_dag(dag)
    compare_to_expected(dag, expected)
def test_join_simple(party_data, expected):
    """
    Join two inputs on each one's first output column, collect the result,
    and compare the DAG to `expected` with no rewrite pass applied.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "stored_with" key is read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    # Key column names come from the created relations, not from party_data.
    left_key = rel_one.out_rel.columns[0].name
    right_key = rel_two.out_rel.columns[0].name
    joined = join(rel_one, rel_two, "join", [left_key], [right_key])
    collect(joined, {1, 2, 3})

    dag = Dag({rel_one, rel_two})
    compare_to_expected(dag, expected)
def test_std_dev_alt_key_col(party_data, expected):
    """
    Concat two inputs and take a "std_dev" aggregate grouped by the second
    column over the first column, collect, apply PushUp, and compare the DAG
    to `expected`. Afterwards also check each node's output column names
    against expected["ownership_data"].

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded to compare_to_expected; its "ownership_data"
        entries each supply a "col_names" list
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    # Group by column 1 and aggregate over column 0.
    std_dev = aggregate(concatenated, "std_dev", [names[1]], names[0], "std_dev")
    collect(std_dev, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)
    compare_to_expected(dag, expected)

    # Pair nodes (topological order) with the expected names; zip stops at the
    # shorter of the two sequences.
    node_name_pairs = zip(
        dag.top_sort(),
        [entry["col_names"] for entry in expected["ownership_data"]],
    )
    matches = [
        [col.name for col in node.out_rel.columns] == wanted
        for node, wanted in node_name_pairs
    ]
    assert all(matches)
def test_single_dataset(party_data, expected):
    """
    Build a single-input mean aggregate -> divide -> collect workflow, compile
    it, partition it with HeuristicPart, and compare the partitions to
    `expected`.

    :param party_data: indexable of per-party dicts; only entry 0 is used and
        its "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_partition_to_expected
    """
    only = party_data[0]

    rel_one = create("in1", create_cols(only), only["stored_with"])

    names = only["col_names"]
    averaged = aggregate(rel_one, "agg", names[:1], names[1], "mean")
    divided = divide(averaged, "div", names[1], [10])
    collect(divided, {1, 2})

    dag = Dag({rel_one})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
def test_concat_simple(party_data, expected):
    """
    Concat three inputs, collect the result, and compare the DAG to
    `expected` with no rewrite pass applied.

    :param party_data: indexable of per-party dicts; entries 0-2 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second, third = party_data[0], party_data[1], party_data[2]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    third_cols = create_cols(third)

    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])
    rel_three = create("in3", third_cols, third["stored_with"])

    concatenated = concat([rel_one, rel_two, rel_three], "concat", first["col_names"])
    collect(concatenated, {1, 2, 3})

    dag = Dag({rel_one, rel_two, rel_three})
    compare_to_expected(dag, expected)
def test_concat(party_data, expected):
    """
    Concat two inputs, collect the result, apply InsertCloseOps, and compare
    the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    concatenated = concat([rel_one, rel_two], "concat", first["col_names"])
    collect(concatenated, {1, 2})

    dag = Dag({rel_one, rel_two})
    InsertCloseOps().rewrite(dag)
    compare_to_expected(dag, expected)
def test_all_stats(party_data, expected):
    """
    Concat two inputs, feed the result to all_stats (grouped by the first
    column, over the second), collect, apply PushUp, and compare the DAG to
    `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    stats = all_stats(concatenated, "all_stat", [names[0]], names[1])
    collect(stats, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)
    compare_to_expected(dag, expected)
def test_join(party_data, expected):
    """
    Join two inputs on each party's first column name, collect the result,
    apply InsertCloseOps, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    left_key = first["col_names"][0]
    right_key = second["col_names"][0]
    joined = join(rel_one, rel_two, "join", [left_key], [right_key])
    collect(joined, {1, 2})

    dag = Dag({rel_one, rel_two})
    InsertCloseOps().rewrite(dag)
    compare_to_expected(dag, expected)
def test_project(party_data, expected):
    """
    Concat two inputs, project down to the first column, collect, apply
    PushDown, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    projected = project(concatenated, "proj", [names[0]])
    collect(projected, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)
    compare_to_expected(dag, expected)
def test_filter_by_scalar(party_data, expected):
    """
    Concat two inputs, apply filter_by on the first column with operator ">"
    and scalar 10, collect, apply PushDown, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    filtered = filter_by(concatenated, "filt", names[0], ">", 10)
    collect(filtered, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)
    compare_to_expected(dag, expected)
def test_variance(party_data, expected):
    """
    Concat two inputs, take a "variance" aggregate grouped by the first column
    over the second, collect, apply PushUp, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    variance = aggregate(concatenated, "variance", [names[0]], names[1], "variance")
    collect(variance, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)
    compare_to_expected(dag, expected)
def test_std_dev_no_key_cols(party_data, expected):
    """
    Concat two inputs, take a "std_dev" aggregate over the first column with
    an empty group-column list, collect, apply PushUp, and compare the DAG to
    `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    # No key columns: the group-by list is deliberately empty.
    std_dev = aggregate(concatenated, "std_dev", [], names[0], "std_dev")
    collect(std_dev, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)
    compare_to_expected(dag, expected)
def test_multiply_target_existing(party_data, expected):
    """
    Concat two inputs, multiply with the (already existing) first column as
    target and [second column, 10] as operands, collect, apply PushDown, and
    compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    multiplied = multiply(concatenated, "mult", names[0], [names[1], 10])
    collect(multiplied, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)
    compare_to_expected(dag, expected)
def test_divide(party_data, expected):
    """
    Concat two inputs, divide with the first column as target and
    [second column, 10] as operands, collect, apply PushUp, and compare the
    DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    divided = divide(concatenated, "div", names[0], [names[1], 10])
    collect(divided, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)
    compare_to_expected(dag, expected)
def test_agg_variance(party_data, expected):
    """
    Concat two inputs, take a "variance" aggregate, multiply the result,
    collect, run compile_dag, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    variance = aggregate(concatenated, "variance", [names[0]], names[1], "variance")
    multiplied = multiply(variance, "mult", names[0], [names[1], 7])
    collect(multiplied, {1, 2})

    dag = Dag({rel_one, rel_two})
    compile_dag(dag)
    compare_to_expected(dag, expected)
def test_join(party_data, expected):
    """
    Build a two-input join -> mean aggregate -> divide -> collect workflow,
    compile it, partition it with HeuristicPart, and compare the partitions
    to `expected`.

    NOTE(review): another `test_join` definition is visible earlier in this
    chunk; if both live in the same module, this later one replaces it at
    import time - confirm they belong to different files.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_partition_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    first_names = first["col_names"]
    joined = join(rel_one, rel_two, "join", [first_names[0]], [second["col_names"][0]])
    averaged = aggregate(joined, "agg", [first_names[0]], first_names[1], "mean")
    divided = divide(averaged, "div", first_names[1], [10])
    collect(divided, {1, 2})

    dag = Dag({rel_one, rel_two})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
def test_agg_std_dev(party_data, expected):
    """
    Concat two inputs, take a "std_dev" aggregate, multiply the result,
    collect, apply PushDown then PushUp, and compare the DAG to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-1 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_to_expected
    """
    first, second = party_data[0], party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two], "concat", names)
    std_dev = aggregate(concatenated, "std_dev", [names[0]], names[1], "std_dev")
    multiplied = multiply(std_dev, "mult", names[0], [names[1], 7])
    collect(multiplied, {1, 2})

    dag = Dag({rel_one, rel_two})
    for rewriter_cls in (PushDown, PushUp):
        rewriter_cls().rewrite(dag)

    compare_to_expected(dag, expected)
def test_three_datasets_concat(party_data, expected):
    """
    Build a three-input concat -> mean aggregate -> divide -> collect
    workflow, compile it, partition it with HeuristicPart, and compare the
    partitions to `expected`.

    :param party_data: indexable of per-party dicts; entries 0-2 are used and
        their "col_names" / "stored_with" keys are read here
    :param expected: forwarded unchanged to compare_partition_to_expected
    """
    first, second, third = party_data[0], party_data[1], party_data[2]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    third_cols = create_cols(third)

    rel_one = create("in1", first_cols, first["stored_with"])
    rel_two = create("in2", second_cols, second["stored_with"])
    rel_three = create("in3", third_cols, third["stored_with"])

    names = first["col_names"]
    concatenated = concat([rel_one, rel_two, rel_three], "cc", names)
    averaged = aggregate(concatenated, "agg", names[:1], names[1], "mean")
    divided = divide(averaged, "div", names[1], [10])
    collect(divided, {1, 2, 3})

    dag = Dag({rel_one, rel_two, rel_three})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
def construct_dag(roots: set):
    """
    Wrap `roots` in a new Dag.

    :param roots: set handed directly to the Dag constructor
    :return: the newly constructed Dag
    """
    dag = Dag(roots)
    return dag
def rewrite(self, dag: Dag):
    """
    Visit every node of `dag` in topological order (reversed when
    `self.reverse` is truthy) and hand it to the matching `_rewrite_*` method.

    A node is matched against the classes below in the listed order; the
    handler of the first class it is an instance of is called, and no other.

    NOTE(review): because matching is first-hit, a class listed later that
    subclasses one listed earlier would never reach its own handler - confirm
    the node hierarchy makes this order safe.

    :param dag: DAG whose top_sort() supplies the nodes to visit
    :raises Exception: if a node is not an instance of any listed class
    """
    # (node class, handler method name) pairs; order is significant.
    dispatch = (
        (Create, "_rewrite_create"),
        (AggregateSum, "_rewrite_aggregate_sum"),
        (AggregateCount, "_rewrite_aggregate_count"),
        (AggregateMean, "_rewrite_aggregate_mean"),
        (AggregateStdDev, "_rewrite_aggregate_std_dev"),
        (AggregateVariance, "_rewrite_aggregate_variance"),
        (Project, "_rewrite_project"),
        (Add, "_rewrite_add"),
        (Subtract, "_rewrite_subtract"),
        (Multiply, "_rewrite_multiply"),
        (Divide, "_rewrite_divide"),
        (Limit, "_rewrite_limit"),
        (Distinct, "_rewrite_distinct"),
        (FilterAgainstCol, "_rewrite_filter_against_col"),
        (FilterAgainstScalar, "_rewrite_filter_against_scalar"),
        (SortBy, "_rewrite_sort_by"),
        (NumRows, "_rewrite_num_rows"),
        (Collect, "_rewrite_collect"),
        (Join, "_rewrite_join"),
        (Concat, "_rewrite_concat"),
        (Open, "_rewrite_open"),
        (Close, "_rewrite_close"),
        (Store, "_rewrite_store"),
        (Read, "_rewrite_read"),
        (Persist, "_rewrite_persist"),
        (Send, "_rewrite_send"),
        (Index, "_rewrite_index"),
        (Shuffle, "_rewrite_shuffle"),
        (AggregateSumCountCol, "_rewrite_aggregate_sum_count_col"),
        (AggregateSumSquaresAndCount, "_rewrite_aggregate_sum_squares_and_count"),
        (AggregateStdDevLocalSqrt, "_rewrite_aggregate_std_dev_local_sqrt"),
        (AggregateVarianceLocalDiff, "_rewrite_aggregate_variance_local_diff"),
        (ColSum, "_rewrite_col_sum"),
        (MemberFilter, "_rewrite_member_filter"),
        (ColumnUnion, "_rewrite_column_union"),
    )

    nodes = dag.top_sort()
    if self.reverse:
        nodes = nodes[::-1]

    for node in nodes:
        for node_cls, handler_name in dispatch:
            if isinstance(node, node_cls):
                # Handler is looked up on self only once a class has matched.
                getattr(self, handler_name)(node)
                break
        else:
            # Inner loop finished without a match.
            raise Exception(f"Unknown class {type(node).__name__}.")