def test_concat_composite_three_party_agg_sum(party_data, expected):
    """Check the rewrite passes on a three-party concat/multiply/sum/project DAG.

    One input relation is created per entry of ``party_data[0..2]``. The three
    are concatenated, fed through ``multiply``, ``aggregate(..., "sum")`` and
    ``project`` onto the first column name, and collected to parties
    {1, 2, 3}. The DAG is then rewritten by PushDown, PushUp, InsertCloseOps
    and InsertOpenOps, in that order. The result is checked with
    ``compare_to_expected``.
    """
    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])
    cols_in_three = create_cols(party_data[2])
    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])
    rel_three = create("in3", cols_in_three, party_data[2]["stored_with"])

    # All downstream ops are expressed in terms of the first party's column names.
    col_names = party_data[0]["col_names"]
    cc = concat([rel_one, rel_two, rel_three], "concat", col_names)
    # NOTE(review): presumably col_names[0] is the target column and
    # [col_names[1], 5] the operands -- confirm against multiply()'s signature.
    mult = multiply(cc, "mult", col_names[0], [col_names[1], 5])
    agg = aggregate(mult, "agg", [col_names[0]], col_names[1], "sum")
    p = project(agg, "proj", [col_names[0]])
    collect(p, {1, 2, 3})

    d = Dag({rel_one, rel_two, rel_three})
    # Apply the passes one after another on the same DAG; their return values
    # are not used, only the state of ``d`` is compared afterwards.
    pd = PushDown()
    pd.rewrite(d)
    pu = PushUp()
    pu.rewrite(d)
    ic = InsertCloseOps()
    ic.rewrite(d)
    io = InsertOpenOps()
    io.rewrite(d)

    compare_to_expected(d, expected)
def test_project_reorder_cols(party_data, expected):
    """Check PushDown on a projection that lists the columns in reverse order.

    Two input relations built from ``party_data[0]`` and ``party_data[1]`` are
    concatenated, projected onto the first party's column names reversed, and
    collected to parties {1, 2}. After a PushDown rewrite, two checks apply.
    The DAG must satisfy ``compare_to_expected``. Each node in topological
    order must also expose output column names equal to the matching
    ``col_names`` entry of ``expected["ownership_data"]``.
    """
    first = party_data[0]
    second = party_data[1]

    first_cols = create_cols(first)
    second_cols = create_cols(second)
    first_rel = create("in1", first_cols, first["stored_with"])
    second_rel = create("in2", second_cols, second["stored_with"])

    names = first["col_names"]
    concatenated = concat([first_rel, second_rel], "concat", names)
    # The projection asks for every column, but in reverse order.
    reordered = project(concatenated, "proj", names[::-1])
    collect(reordered, {1, 2})

    dag = Dag({first_rel, second_rel})
    push_down = PushDown()
    push_down.rewrite(dag)
    compare_to_expected(dag, expected)

    # Pair each sorted node with its expected column names. zip() stops at the
    # shorter of the two sequences.
    sorted_nodes = dag.top_sort()
    wanted_names = [entry["col_names"] for entry in expected["ownership_data"]]
    matches = []
    for node, wanted in zip(sorted_nodes, wanted_names):
        actual = [column.name for column in node.out_rel.columns]
        matches.append(actual == wanted)
    assert all(matches)
def test_std_dev_alt_key_col(party_data, expected):
    """Check PushUp on a std-dev aggregate keyed on the second column.

    Two input relations built from ``party_data[0]`` and ``party_data[1]`` are
    concatenated and aggregated with ``"std_dev"``, grouped by
    ``col_names[1]`` over ``col_names[0]``, then collected to parties {1, 2}.
    After a PushUp rewrite, two checks apply. The DAG must satisfy
    ``compare_to_expected``. Each node in topological order must also expose
    output column names equal to the matching ``col_names`` entry of
    ``expected["ownership_data"]``.
    """
    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])
    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    col_names = party_data[0]["col_names"]
    cc = concat([rel_one, rel_two], "concat", col_names)
    std_dev = aggregate(cc, "std_dev", [col_names[1]], col_names[0], "std_dev")
    collect(std_dev, {1, 2})

    d = Dag({rel_one, rel_two})
    # This test runs the PushUp pass. Name the instance ``pu`` (as the sibling
    # tests do) rather than ``pd``, which elsewhere denotes PushDown.
    pu = PushUp()
    pu.rewrite(d)
    compare_to_expected(d, expected)

    # zip() pairs sorted nodes with expected column names, stopping at the
    # shorter sequence.
    zip_col_names = zip(d.top_sort(), [e["col_names"] for e in expected["ownership_data"]])
    col_name_checks = [
        [c.name for c in node.out_rel.columns] == names
        for node, names in zip_col_names
    ]
    assert all(col_name_checks)
def rewrite(self, dag: Dag):
    """Visit every node of ``dag`` and hand it to its type-specific handler.

    Nodes are taken from ``dag.top_sort()`` and walked in that order, or in
    reverse when ``self.reverse`` is truthy. Each node goes to the
    ``self._rewrite_*`` method of the first class in ``handlers`` that it is
    an instance of.

    Raises:
        Exception: if a node matches none of the known node classes.
    """
    visit_order = dag.top_sort()
    if self.reverse:
        visit_order = visit_order[::-1]

    # (node class, handler method name) pairs, checked top to bottom. The
    # first isinstance() hit wins, so the order matters whenever one of these
    # classes subclasses another -- do not reorder casually.
    handlers = (
        (Create, "_rewrite_create"),
        (AggregateSum, "_rewrite_aggregate_sum"),
        (AggregateCount, "_rewrite_aggregate_count"),
        (AggregateMean, "_rewrite_aggregate_mean"),
        (AggregateStdDev, "_rewrite_aggregate_std_dev"),
        (AggregateVariance, "_rewrite_aggregate_variance"),
        (Project, "_rewrite_project"),
        (Add, "_rewrite_add"),
        (Subtract, "_rewrite_subtract"),
        (Multiply, "_rewrite_multiply"),
        (Divide, "_rewrite_divide"),
        (Limit, "_rewrite_limit"),
        (Distinct, "_rewrite_distinct"),
        (FilterAgainstCol, "_rewrite_filter_against_col"),
        (FilterAgainstScalar, "_rewrite_filter_against_scalar"),
        (SortBy, "_rewrite_sort_by"),
        (NumRows, "_rewrite_num_rows"),
        (Collect, "_rewrite_collect"),
        (Join, "_rewrite_join"),
        (Concat, "_rewrite_concat"),
        (Open, "_rewrite_open"),
        (Close, "_rewrite_close"),
        (Store, "_rewrite_store"),
        (Read, "_rewrite_read"),
        (Persist, "_rewrite_persist"),
        (Send, "_rewrite_send"),
        (Index, "_rewrite_index"),
        (Shuffle, "_rewrite_shuffle"),
        (AggregateSumCountCol, "_rewrite_aggregate_sum_count_col"),
        (AggregateSumSquaresAndCount, "_rewrite_aggregate_sum_squares_and_count"),
        (AggregateStdDevLocalSqrt, "_rewrite_aggregate_std_dev_local_sqrt"),
        (AggregateVarianceLocalDiff, "_rewrite_aggregate_variance_local_diff"),
        (ColSum, "_rewrite_col_sum"),
        (MemberFilter, "_rewrite_member_filter"),
        (ColumnUnion, "_rewrite_column_union"),
    )

    for current in visit_order:
        for node_cls, method_name in handlers:
            if isinstance(current, node_cls):
                # Look the handler up only when it is needed, as a direct
                # ``self._rewrite_x`` call would.
                getattr(self, method_name)(current)
                break
        else:
            raise Exception(f"Unknown class {type(current).__name__}.")