def test_concat_composite_three_party_agg_sum(party_data, expected):
    """Concat three parties' inputs, then multiply, sum-aggregate and project.

    The DAG is rewritten by PushDown, PushUp, InsertCloseOps and
    InsertOpenOps (in that order) and then compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])
    cols_in_three = create_cols(party_data[2])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])
    rel_three = create("in3", cols_in_three, party_data[2]["stored_with"])

    cc = concat([rel_one, rel_two, rel_three], "concat",
                party_data[0]["col_names"])
    mult = multiply(cc, "mult", party_data[0]["col_names"][0],
                    [party_data[0]["col_names"][1], 5])
    agg = aggregate(mult, "agg", [party_data[0]["col_names"][0]],
                    party_data[0]["col_names"][1], "sum")
    p = project(agg, "proj", [party_data[0]["col_names"][0]])
    collect(p, {1, 2, 3})

    d = Dag({rel_one, rel_two, rel_three})
    pd = PushDown()
    pd.rewrite(d)
    pu = PushUp()
    pu.rewrite(d)
    ic = InsertCloseOps()
    ic.rewrite(d)
    io = InsertOpenOps()
    io.rewrite(d)

    # The sort result is not needed; the call is kept (presumably as a smoke
    # check that the rewritten DAG can still be ordered) without binding an
    # unused local.
    d.top_sort()

    compare_to_expected(d, expected)
# Example no. 2 (score: 0)
def test_join_composite(party_data, expected):
    """Join two inputs on their first columns, then divide one column.

    The DAG is rewritten by PushDown, PushUp, InsertCloseOps and
    InsertOpenOps and then compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    j = join(rel_one, rel_two, "join", party_data[0]["col_names"][:1],
             party_data[1]["col_names"][:1])
    # Distinct name so the divide node is not shadowed by the Dag below.
    div = divide(j, "div", party_data[0]["col_names"][1],
                 [party_data[1]["col_names"][1]])
    collect(div, {1, 2})

    d = Dag({rel_one, rel_two})
    pd = PushDown()
    pd.rewrite(d)
    pu = PushUp()
    pu.rewrite(d)
    ic = InsertCloseOps()
    ic.rewrite(d)
    io = InsertOpenOps()
    io.rewrite(d)

    compare_to_expected(d, expected)
# Example no. 3 (score: 0)
def test_join_composite_agg(party_data, expected):
    """Join two inputs, mean-aggregate the join, then multiply by 7.

    The DAG is rewritten by PushDown, PushUp, InsertCloseOps and
    InsertOpenOps and then compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    j = join(rel_one, rel_two, "join", party_data[0]["col_names"][:1],
             party_data[1]["col_names"][:1])
    agg = aggregate(j, "agg", [party_data[0]["col_names"][0]],
                    party_data[0]["col_names"][1], "mean")
    # Distinct name so the multiply node is not shadowed by the Dag below.
    mult = multiply(agg, "mult", party_data[0]["col_names"][1], [7])
    collect(mult, {1, 2})

    d = Dag({rel_one, rel_two})
    pd = PushDown()
    pd.rewrite(d)
    pu = PushUp()
    pu.rewrite(d)
    ic = InsertCloseOps()
    ic.rewrite(d)
    io = InsertOpenOps()
    io.rewrite(d)

    compare_to_expected(d, expected)
def test_concat_composite(party_data, expected):
    """Transform each input locally, concat, sum-aggregate, then project.

    Party one's input is multiplied and party two's is divided before the
    concat. Only InsertCloseOps and InsertOpenOps are applied before the
    comparison with ``expected``.
    """
    first, second = party_data[0], party_data[1]
    names_one, names_two = first["col_names"], second["col_names"]

    cols_one = create_cols(first)
    cols_two = create_cols(second)
    in_one = create("in1", cols_one, first["stored_with"])
    in_two = create("in2", cols_two, second["stored_with"])

    scaled = multiply(in_one, "mult_one", names_one[0], [names_one[1], 5])
    shrunk = divide(in_two, "div_two", names_two[0], [names_two[1], 5])
    combined = concat([scaled, shrunk], "concat", names_one)
    summed = aggregate(combined, "agg", [names_one[0]], names_one[1], "sum")
    projected = project(summed, "proj", [names_one[0]])
    collect(projected, {1, 2})

    dag = Dag({in_one, in_two})
    # Each pass is constructed right before it runs, in this fixed order.
    for pass_cls in (InsertCloseOps, InsertOpenOps):
        pass_cls().rewrite(dag)

    compare_to_expected(dag, expected)
# Example no. 5 (score: 0)
    def compile(protocol: callable, enable_optimizations: [bool, None] = True):
        """Build a Dag from ``protocol()`` and compile it in place.

        ``protocol`` is called with no arguments and its result is passed
        straight to ``Dag``. A truthy ``enable_optimizations`` selects
        ``compile_dag``; anything falsy (including ``None``) selects
        ``compile_dag_without_optimizations``. Returns the compiled Dag.
        """
        dag = Dag(protocol())
        # Only the selected compiler name is looked up.
        compiler = (compile_dag if enable_optimizations
                    else compile_dag_without_optimizations)
        compiler(dag)
        return dag
# Example no. 6 (score: 0)
    def get_next_partition(self, dag: Dag, candidate_roots: set):
        """Pop one root off ``dag.roots`` and build the partition around it.

        Nodes returned by ``self._get_next_partition`` for the popped root are
        merged into a copy of ``candidate_roots``. Then
        ``self._iterate_over_remaining_roots`` decides which roots join the
        partition.

        Returns ``(Dag(partition_roots), remaining_candidates)``. If
        ``dag.roots`` has been exhausted while candidates remain, those
        candidates become ``dag.roots`` and an empty set is returned as the
        second element instead.
        """
        # The caller's while-loop condition guarantees dag.roots is non-empty.
        root = dag.roots.pop()
        seed = {root}

        reachable = self._get_next_partition(root, root.requires_mpc(), set())
        pending = candidate_roots.union(reachable)

        partition_roots, pending = self._iterate_over_remaining_roots(
            dag, root, pending, seed)

        if dag.roots or not pending:
            return Dag(partition_roots), pending

        # Out of roots but candidates remain: promote them for the next round.
        dag.roots = pending
        return Dag(partition_roots), set()
def test_project_reorder_cols(party_data, expected):
    """Project the concat of two inputs with its column order reversed.

    After PushDown, the DAG must match ``expected``. In addition, each node
    in topological order must carry exactly the column names of the matching
    ``expected["ownership_data"]`` entry.
    """
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    reordered = project(combined, "proj", names[::-1])
    collect(reordered, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)

    ordered_nodes = dag.top_sort()
    wanted_names = [entry["col_names"] for entry in expected["ownership_data"]]
    matches = [
        [column.name for column in node.out_rel.columns] == wanted
        for node, wanted in zip(ordered_nodes, wanted_names)
    ]
    assert all(matches)
def test_agg_mean(party_data, expected):
    """Mean-aggregate a single input, divide by 10, and run compile_dag."""
    party = party_data[0]
    names = party["col_names"]

    source = create("in1", create_cols(party), party["stored_with"])
    mean = aggregate(source, "agg", names[:1], names[1], "mean")
    scaled = divide(mean, "div", names[1], [10])
    collect(scaled, {1, 2})

    dag = Dag({source})
    compile_dag(dag)

    compare_to_expected(dag, expected)
def test_join_simple(party_data, expected):
    """Join two inputs on their first columns; no rewrite passes are applied.

    The join keys come from the created relations' own first columns, not
    from ``party_data``.
    """
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    left = create("in1", cols_one, party_data[0]["stored_with"])
    right = create("in2", cols_two, party_data[1]["stored_with"])

    left_key = left.out_rel.columns[0].name
    right_key = right.out_rel.columns[0].name
    joined = join(left, right, "join", [left_key], [right_key])
    collect(joined, {1, 2, 3})

    compare_to_expected(Dag({left, right}), expected)
# Example no. 10 (score: 0)
def test_std_dev_alt_key_col(party_data, expected):
    """Std-dev aggregate over a concat, keyed on the *second* column.

    After PushUp, the DAG must match ``expected``. Each node in topological
    order must also carry exactly the column names of the matching
    ``expected["ownership_data"]`` entry.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    cc = concat([rel_one, rel_two], "concat", party_data[0]["col_names"])
    std_dev = aggregate(cc, "std_dev", [party_data[0]["col_names"][1]],
                        party_data[0]["col_names"][0], "std_dev")
    collect(std_dev, {1, 2})

    d = Dag({rel_one, rel_two})
    # Named ``pu`` (not ``pd``) because this is the PushUp pass.
    pu = PushUp()
    pu.rewrite(d)

    compare_to_expected(d, expected)

    zip_col_names = zip(d.top_sort(),
                        [e["col_names"] for e in expected["ownership_data"]])
    col_name_checks = [[c.name for c in z[0].out_rel.columns] == z[1]
                       for z in zip_col_names]
    assert all(col_name_checks)
# Example no. 11 (score: 0)
def test_single_dataset(party_data, expected):
    """Compile a one-input mean/divide pipeline and check its partitioning.

    The compiled DAG is split by HeuristicPart with ``partition(100)`` and
    the parts are compared against ``expected``.
    """
    party = party_data[0]
    names = party["col_names"]

    source = create("in1", create_cols(party), party["stored_with"])
    mean = aggregate(source, "agg", names[:1], names[1], "mean")
    scaled = divide(mean, "div", names[1], [10])
    collect(scaled, {1, 2})

    dag = Dag({source})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
# Example no. 12 (score: 0)
def test_concat_simple(party_data, expected):
    """Concat three parties' inputs; no rewrite passes are applied."""
    parties = (party_data[0], party_data[1], party_data[2])
    # All columns are built first, then all relations, matching call order.
    column_sets = [create_cols(party) for party in parties]
    rel_one, rel_two, rel_three = [
        create(rel_name, cols, party["stored_with"])
        for rel_name, cols, party in zip(("in1", "in2", "in3"),
                                         column_sets, parties)
    ]

    combined = concat([rel_one, rel_two, rel_three], "concat",
                      party_data[0]["col_names"])
    collect(combined, {1, 2, 3})

    compare_to_expected(Dag({rel_one, rel_two, rel_three}), expected)
# Example no. 13 (score: 0)
def test_concat(party_data, expected):
    """Concat two inputs and apply only the InsertCloseOps pass."""
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", party_data[0]["col_names"])
    collect(combined, {1, 2})

    dag = Dag({rel_one, rel_two})
    InsertCloseOps().rewrite(dag)

    compare_to_expected(dag, expected)
def test_all_stats(party_data, expected):
    """Run ``all_stats`` over the concat of two inputs, then PushUp.

    The grouping key is the first column and the statistic column is the
    second. The rewritten DAG is compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    cc = concat([rel_one, rel_two], "concat", party_data[0]["col_names"])
    all_stat = all_stats(cc, "all_stat", [party_data[0]["col_names"][0]],
                         party_data[0]["col_names"][1])
    collect(all_stat, {1, 2})

    d = Dag({rel_one, rel_two})
    # Named ``pu`` (not ``pd``) because this is the PushUp pass.
    pu = PushUp()
    pu.rewrite(d)

    compare_to_expected(d, expected)
# Example no. 15 (score: 0)
def test_join(party_data, expected):
    """Join two inputs on their first columns and apply InsertCloseOps."""
    names_one = party_data[0]["col_names"]
    names_two = party_data[1]["col_names"]

    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    left = create("in1", cols_one, party_data[0]["stored_with"])
    right = create("in2", cols_two, party_data[1]["stored_with"])

    joined = join(left, right, "join", [names_one[0]], [names_two[0]])
    collect(joined, {1, 2})

    dag = Dag({left, right})
    InsertCloseOps().rewrite(dag)

    compare_to_expected(dag, expected)
def test_project(party_data, expected):
    """Project the first column of a two-input concat, then PushDown."""
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    projected = project(combined, "proj", [names[0]])
    collect(projected, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
# Example no. 17 (score: 0)
def test_filter_by_scalar(party_data, expected):
    """Filter a two-input concat by ``first_col > 10``, then PushDown."""
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    filtered = filter_by(combined, "filt", names[0], ">", 10)
    collect(filtered, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
def test_variance(party_data, expected):
    """Variance-aggregate a two-input concat, keyed on the first column.

    The DAG is rewritten by PushUp and then compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    cc = concat([rel_one, rel_two], "concat", party_data[0]["col_names"])
    variance = aggregate(cc, "variance", [party_data[0]["col_names"][0]],
                         party_data[0]["col_names"][1], "variance")
    collect(variance, {1, 2})

    d = Dag({rel_one, rel_two})
    # Named ``pu`` (not ``pd``) because this is the PushUp pass.
    pu = PushUp()
    pu.rewrite(d)

    compare_to_expected(d, expected)
# Example no. 19 (score: 0)
def test_std_dev_no_key_cols(party_data, expected):
    """Std-dev aggregate over a two-input concat with no grouping columns.

    The DAG is rewritten by PushUp and then compared against ``expected``.
    """

    cols_in_one = create_cols(party_data[0])
    cols_in_two = create_cols(party_data[1])

    rel_one = create("in1", cols_in_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_in_two, party_data[1]["stored_with"])

    cc = concat([rel_one, rel_two], "concat", party_data[0]["col_names"])
    std_dev = aggregate(cc, "std_dev", [], party_data[0]["col_names"][0],
                        "std_dev")
    collect(std_dev, {1, 2})

    d = Dag({rel_one, rel_two})
    # Named ``pu`` (not ``pd``) because this is the PushUp pass.
    pu = PushUp()
    pu.rewrite(d)

    compare_to_expected(d, expected)
# Example no. 20 (score: 0)
def test_multiply_target_existing(party_data, expected):
    """Multiply into an existing column of a two-input concat, then PushDown.

    The target is the first column; the operands are the second column and
    the scalar 10.
    """
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    product = multiply(combined, "mult", names[0], [names[1], 10])
    collect(product, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushDown().rewrite(dag)

    compare_to_expected(dag, expected)
# Example no. 21 (score: 0)
def test_divide(party_data, expected):
    """Divide into the first column of a two-input concat, then PushUp.

    The operands are the second column and the scalar 10.
    """
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    quotient = divide(combined, "div", names[0], [names[1], 10])
    collect(quotient, {1, 2})

    dag = Dag({rel_one, rel_two})
    PushUp().rewrite(dag)

    compare_to_expected(dag, expected)
# Example no. 22 (score: 0)
def test_agg_variance(party_data, expected):
    """Variance-aggregate a two-input concat, multiply the result, compile.

    The aggregate is keyed on the first column. The multiply then targets
    the first column with operands (second column, 7).
    """
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    var_node = aggregate(combined, "variance", [names[0]], names[1],
                         "variance")
    scaled = multiply(var_node, "mult", names[0], [names[1], 7])
    collect(scaled, {1, 2})

    dag = Dag({rel_one, rel_two})
    compile_dag(dag)

    compare_to_expected(dag, expected)
def test_join(party_data, expected):
    """Join, mean-aggregate and divide; compile and check the partitioning.

    The compiled DAG is split by HeuristicPart with ``partition(100)`` and
    the parts are compared against ``expected``.
    """
    names_one = party_data[0]["col_names"]
    names_two = party_data[1]["col_names"]

    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    left = create("in1", cols_one, party_data[0]["stored_with"])
    right = create("in2", cols_two, party_data[1]["stored_with"])

    joined = join(left, right, "join", [names_one[0]], [names_two[0]])
    mean = aggregate(joined, "agg", [names_one[0]], names_one[1], "mean")
    scaled = divide(mean, "div", names_one[1], [10])
    collect(scaled, {1, 2})

    dag = Dag({left, right})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
# Example no. 24 (score: 0)
def test_agg_std_dev(party_data, expected):
    """Std-dev aggregate a two-input concat, multiply, then PushDown/PushUp.

    The aggregate is keyed on the first column. The multiply targets the
    first column with operands (second column, 7).
    """
    names = party_data[0]["col_names"]
    cols_one, cols_two = create_cols(party_data[0]), create_cols(party_data[1])
    rel_one = create("in1", cols_one, party_data[0]["stored_with"])
    rel_two = create("in2", cols_two, party_data[1]["stored_with"])

    combined = concat([rel_one, rel_two], "concat", names)
    spread = aggregate(combined, "std_dev", [names[0]], names[1], "std_dev")
    scaled = multiply(spread, "mult", names[0], [names[1], 7])
    collect(scaled, {1, 2})

    dag = Dag({rel_one, rel_two})
    # Each pass is constructed right before it runs, in this fixed order.
    for pass_cls in (PushDown, PushUp):
        pass_cls().rewrite(dag)

    compare_to_expected(dag, expected)
# Example no. 25 (score: 0)
def test_three_datasets_concat(party_data, expected):
    """Concat three inputs, mean-aggregate, divide, and check partitioning.

    The compiled DAG is split by HeuristicPart with ``partition(100)`` and
    the parts are compared against ``expected``.
    """
    names = party_data[0]["col_names"]
    parties = (party_data[0], party_data[1], party_data[2])
    # All columns are built first, then all relations, matching call order.
    column_sets = [create_cols(party) for party in parties]
    rel_one, rel_two, rel_three = [
        create(rel_name, cols, party["stored_with"])
        for rel_name, cols, party in zip(("in1", "in2", "in3"),
                                         column_sets, parties)
    ]

    combined = concat([rel_one, rel_two, rel_three], "cc", names)
    mean = aggregate(combined, "agg", names[:1], names[1], "mean")
    scaled = divide(mean, "div", names[1], [10])
    collect(scaled, {1, 2, 3})

    dag = Dag({rel_one, rel_two, rel_three})
    compile_dag(dag)

    partitions = HeuristicPart(dag).partition(100)
    compare_partition_to_expected(partitions, expected)
# Example no. 26 (score: 0)
 def construct_dag(roots: set):
     """Wrap the given root nodes in a new Dag and return it."""
     dag = Dag(roots)
     return dag
# Example no. 27 (score: 0)
    def rewrite(self, dag: Dag):
        """Dispatch every node of ``dag`` to its type-specific ``_rewrite_*`` hook.

        Nodes are visited in ``dag.top_sort()`` order, reversed when
        ``self.reverse`` is truthy. Each node is tested with ``isinstance``
        against the table below, top to bottom, and only the first match's
        hook is called. The order is kept fixed in case a node class
        subclasses another. A node matching no entry raises ``Exception``.
        """
        ordered = dag.top_sort()
        if self.reverse:
            ordered = ordered[::-1]

        # (node class, hook method name) in the exact precedence order.
        # Hooks are looked up by name only when their class matches.
        dispatch = (
            (Create, "_rewrite_create"),
            (AggregateSum, "_rewrite_aggregate_sum"),
            (AggregateCount, "_rewrite_aggregate_count"),
            (AggregateMean, "_rewrite_aggregate_mean"),
            (AggregateStdDev, "_rewrite_aggregate_std_dev"),
            (AggregateVariance, "_rewrite_aggregate_variance"),
            (Project, "_rewrite_project"),
            (Add, "_rewrite_add"),
            (Subtract, "_rewrite_subtract"),
            (Multiply, "_rewrite_multiply"),
            (Divide, "_rewrite_divide"),
            (Limit, "_rewrite_limit"),
            (Distinct, "_rewrite_distinct"),
            (FilterAgainstCol, "_rewrite_filter_against_col"),
            (FilterAgainstScalar, "_rewrite_filter_against_scalar"),
            (SortBy, "_rewrite_sort_by"),
            (NumRows, "_rewrite_num_rows"),
            (Collect, "_rewrite_collect"),
            (Join, "_rewrite_join"),
            (Concat, "_rewrite_concat"),
            (Open, "_rewrite_open"),
            (Close, "_rewrite_close"),
            (Store, "_rewrite_store"),
            (Read, "_rewrite_read"),
            (Persist, "_rewrite_persist"),
            (Send, "_rewrite_send"),
            (Index, "_rewrite_index"),
            (Shuffle, "_rewrite_shuffle"),
            (AggregateSumCountCol, "_rewrite_aggregate_sum_count_col"),
            (AggregateSumSquaresAndCount,
             "_rewrite_aggregate_sum_squares_and_count"),
            (AggregateStdDevLocalSqrt, "_rewrite_aggregate_std_dev_local_sqrt"),
            (AggregateVarianceLocalDiff,
             "_rewrite_aggregate_variance_local_diff"),
            (ColSum, "_rewrite_col_sum"),
            (MemberFilter, "_rewrite_member_filter"),
            (ColumnUnion, "_rewrite_column_union"),
        )

        for node in ordered:
            for node_cls, hook_name in dispatch:
                if isinstance(node, node_cls):
                    getattr(self, hook_name)(node)
                    break
            else:
                raise Exception(f"Unknown class {type(node).__name__}.")