コード例 #1
0
 async def test_single_sample_with_some_depth(self):
     """Depth-2 training on one sample leaves outcome 1 as the only unpruned leaf."""
     dataset = ObliviousDataset.create(Sample([s(1)], s(1)))
     revealed_tree = reveal(await train(dataset, depth=2))
     expected_tree = Branch(
         0,
         threshold=0,
         left=Branch(0, threshold=0, left=pruned(), right=pruned()),
         right=Branch(0, threshold=0, left=pruned(), right=leaf(1)),
     )
     self.assertEqual(revealed_tree, expected_tree)
コード例 #2
0
 def test_select_best_continuous_attribute(self):
     """Values 1..5 labelled 0,0,0,1,1 pick attribute 0 with threshold 3."""
     labels = [0, 0, 0, 1, 1]
     samples = ObliviousDataset.create(
         *[Sample([s(value)], s(label))
           for value, label in enumerate(labels, start=1)],
         continuous=[True])
     best_attribute, threshold = select_best_attribute(samples)
     self.assertEqual(reveal(best_attribute), 0)
     self.assertEqual(reveal(threshold), 3)
コード例 #3
0
 def test_column_with_public_index(self):
     """column(i) with a plain index yields the i-th input of every sample."""
     rows = [sample(s(0), s(1), s(2)),
             sample(s(10), s(11), s(12)),
             sample(s(20), s(21), s(22))]
     dataset = ObliviousDataset.create(*rows)
     expected_columns = [[0, 10, 20], [1, 11, 21], [2, 12, 22]]
     for index, expected in enumerate(expected_columns):
         self.assertEqual(reveal(dataset.column(index)), expected)
コード例 #4
0
 def test_continuous_attributes(self):
     """is_continuous(i) reports the flag passed for attribute i at creation."""
     flags = (False, True, False)
     dataset = ObliviousDataset.create(
         Sample([s(0), s(1), s(1)], s(0)),
         Sample([s(1), s(2), s(1)], s(1)),
         continuous=list(flags))
     for attribute, flag in enumerate(flags):
         check = self.assertTrue if flag else self.assertFalse
         check(dataset.is_continuous(attribute))
コード例 #5
0
def calculate_gains_for_thresholds(column, outcomes):
    """
    Computes the gain of each candidate split position in ``column``.

    The column is walked from its last entry to its first. The gain stored
    at ``index`` is computed while ``is_right`` is 1 exactly for the
    positions after ``index``. Positions ``0..index`` are still 0, so this is
    the split that keeps entries up to and including ``index`` on the left.

    A position is deselected when its value equals the value of the next
    later *included* entry. Splitting there would put equal values on both
    sides, so each run of adjacent equal values contributes one threshold.
    NOTE(review): this only removes *adjacent* duplicates, so it presumably
    expects ``column`` to be sorted. Confirm with the callers.

    ``column`` must expose ``values``, ``map``, ``is_included`` and
    ``select``. ``outcomes`` is passed unchanged to ``calculate_gain``.
    Returns ``gains.select(selection)``: the per-position gains with the
    duplicate positions deselected.
    """
    gains = column.map(lambda _: None)
    # Secure 0/1 marker per position. A position becomes 1 ("right side")
    # only after its own gain has been computed.
    is_right = column.map(lambda _: s(0))
    selection = [None for _ in range(len(column.values))]
    # Sentinel for "no later included value seen yet".
    # NOTE(review): a real value of -1 would compare equal to this sentinel
    # and be treated as a duplicate. Confirm that values are never -1.
    last_considered_value = s(-1)
    for index in reversed(range(len(column.values))):
        # Must run before marking ``index`` as right: this gain belongs to
        # the split whose left side still contains ``index``.
        gains.values[index] = calculate_gain(is_right, outcomes)
        is_right.values[index] = s(1)
        is_duplicate = column.values[index] == last_considered_value
        # Keep this threshold unless it duplicates the next included value.
        selection[index] = ~is_duplicate
        # Only included entries update the comparison value (oblivious
        # select), so excluded entries do not hide a duplicate.
        last_considered_value = mpc.if_else(column.is_included(index),
                                            column.values[index],
                                            last_considered_value)
    return gains.select(selection)
コード例 #6
0
 def test_random_sample(self):
     """Among 100 choice() draws, each of the two samples appears."""
     dataset = ObliviousDataset.create(
         Sample([s(1), s(2), s(3)], s(4)),
         Sample([s(11), s(12), s(13)], s(14)))
     # Assuming choice() is uniform, the chance that 100 draws miss one of
     # the two samples is 2 * 2**-100, so this is effectively deterministic.
     random_samples = [reveal(dataset.choice()) for _ in range(100)]
     self.assertIn(Sample([1, 2, 3], 4), random_samples)
     self.assertIn(Sample([11, 12, 13], 14), random_samples)
コード例 #7
0
 async def test_continuous_attribute_with_some_depth(self):
     """Depth-2 training on one continuous attribute splits at 2, then at 4."""
     samples = ObliviousDataset.create(Sample([s(1)], s(0)),
                                       Sample([s(2)], s(0)),
                                       Sample([s(3)], s(1)),
                                       Sample([s(4)], s(1)),
                                       Sample([s(5)], s(0)),
                                       continuous=[True])
     tree = reveal(await train(samples, depth=2))
     self.assertEqual(tree.attribute, 0)
     self.assertEqual(tree.threshold, 2)
     # assertIsInstance reports the actual type on failure, which
     # assertTrue(isinstance(...)) does not.
     self.assertIsInstance(tree.left, Branch)
     self.assertIsInstance(tree.right, Branch)
     self.assertEqual(tree.right.attribute, 0)
     self.assertEqual(tree.right.threshold, 4)
コード例 #8
0
 def test_sorts_column_and_outcomes_of_array(self):
     """sort() orders a column and reorders its outcomes in lockstep."""
     pairs = [(2, 5), (1, 6), (3, 7), (4, 8)]
     dataset = ObliviousDataset.create(
         *[Sample([s(value)], s(outcome)) for value, outcome in pairs])
     sorted_column, sorted_outcomes = sort(dataset.column(s(0)),
                                           dataset.outcomes)
     self.assertEqual(reveal(sorted_column), [1, 2, 3, 4])
     self.assertEqual(reveal(sorted_outcomes), [6, 5, 7, 8])
コード例 #9
0
def maximum(quotients):
    """
    Returns both the maximum quotient and the index of the maximum in an
    oblivious sequence.

    Quotients are handled as 2-element (numerator, denominator) sequences,
    like the neutral element ``(s(0), s(0))`` below. A candidate replaces
    the running maximum whenever ``ge_quotient(candidate, running)`` holds.
    Assuming that means ">=", ties resolve to the *last* maximal index.

    Only works for quotients that have positive numerator and denominator.
    """
    def step(accumulator, current):
        # One fold step. ``accumulator`` is
        # (best quotient so far, index of best, index of ``current``).
        (best, index_of_best, position) = accumulator

        is_new_maximum = ge_quotient(current, best)
        index_of_best = if_else(is_new_maximum, position, index_of_best)
        best = tuple(if_else(is_new_maximum, list(current), list(best)))

        return (best, index_of_best, position + 1)

    # NOTE(review): (0, 0) presumably makes ge_quotient accept the first
    # element as the new maximum. Confirm against ge_quotient.
    neutral = (s(0), s(0))
    initial = (neutral, s(0), s(0))
    best_quotient, index_of_best, _ = quotients.reduce(neutral, step, initial)
    return best_quotient, index_of_best
コード例 #10
0
 def test_gini_gain_mpc(self):
     """The revealed Gini quotient, scaled by 1/total, gives a gain of 0.5."""
     numerator, denominator = gini_gain_quotient(
         s(2), s(2), s(1), s(1), s(1), s(1))
     numerator = reveal(numerator)
     denominator = reveal(denominator)
     total = 4
     gain = (1 / total) * float(numerator / denominator)
     # Compare floats with a tolerance rather than exact equality.
     self.assertAlmostEqual(gain, 0.5)
コード例 #11
0
 def test_reveal_branches(self):
     """reveal() turns every secure field of a branch and its leaves plain."""
     secret_tree = Branch(s(0), threshold=s(10),
                          left=Leaf(s(1), pruned=s(False)),
                          right=Leaf(s(2), pruned=s(True)))
     plain_tree = Branch(0, threshold=10,
                         left=Leaf(1, pruned=False),
                         right=Leaf(2, pruned=True))
     self.assertEqual(reveal(secret_tree), plain_tree)
コード例 #12
0
 def test_calculate_gains_for_thresholds_ignores_duplicates(self):
     """Three identical values produce exactly one candidate threshold."""
     samples = ObliviousDataset.create(
         *[Sample([s(0)], s(0)) for _ in range(3)], continuous=[True])
     gains = calculate_gains_for_thresholds(samples.column(0),
                                            samples.outcomes)
     self.assertEqual(len(reveal(gains)), 1)
コード例 #13
0
def calculate_gains(samples):
    """
    Computes a gain and a threshold for every attribute of ``samples``.

    For a continuous attribute, the column and outcomes are sorted together
    and ``select_best_threshold`` supplies both values. Any other attribute
    takes its gain from ``calculate_gain_for_attribute`` and uses ``s(0)`` as
    the threshold.

    Returns ``(gains, thresholds)``, two lists indexed by attribute.
    """
    attribute_count = samples.number_of_attributes
    gains, thresholds = [], []
    outcomes = samples.outcomes
    for attribute in range(attribute_count):
        column = samples.column(attribute)
        if not samples.is_continuous(attribute):
            gain = calculate_gain_for_attribute(column, outcomes)
            threshold = s(0)
        else:
            sorted_column, sorted_outcomes = sort(column, outcomes)
            gain, threshold = select_best_threshold(sorted_column,
                                                    sorted_outcomes)
        gains.append(gain)
        thresholds.append(threshold)
    return gains, thresholds
コード例 #14
0
 def test_output_sec_int(self):
     """reveal() returns the plain value of a secure integer."""
     revealed = reveal(s(42))
     self.assertEqual(revealed, 42)
コード例 #15
0
 def test_column_of_subset_with_secret_index(self):
     """A secret column index works on a subset chosen with select()."""
     full = ObliviousDataset.create(sample(s(0), s(1), s(2)),
                                    sample(s(10), s(11), s(12)),
                                    sample(s(20), s(21), s(22)))
     subset = full.select([s(1), s(0), s(1)])
     self.assertEqual(reveal(subset.column(s(1))), [1, 21])
コード例 #16
0
def sample(ins, out):
    """Builds a Sample from plain inputs and a plain outcome, wrapping each in s()."""
    secret_inputs = list(map(s, ins))
    return Sample(secret_inputs, s(out))
コード例 #17
0
 def test_random_sample_with_one_sample(self):
     """choice() on a one-sample dataset returns that sample."""
     only_sample = Sample([s(1), s(2), s(3)], s(4))
     dataset = ObliviousDataset.create(only_sample)
     chosen = reveal(dataset.choice())
     self.assertEqual(chosen, Sample([1, 2, 3], 4))
コード例 #18
0
def pruned():
    """Returns a Leaf with outcome s(0) whose pruned flag is s(True)."""
    outcome, is_pruned = s(0), s(True)
    return Leaf(outcome, is_pruned)
コード例 #19
0
 def test_classify_with_continuous_attribute(self):
     """Attribute 1 at the threshold (2) gives outcome 1; above it (3) gives 0."""
     tree = Branch(s(1), s(2), leaf(s(1)), leaf(s(0)))
     for value, expected in ((2, 1), (3, 0)):
         inputs = [s(1), s(value), s(1)]
         self.assertEqual(reveal(classify(inputs, tree)), expected)
コード例 #20
0
 def test_determine_class_multiple_samples(self):
     """With outcomes 0, 1, 1, determine_class() yields 1."""
     outcomes = (0, 1, 1)
     dataset = ObliviousDataset.create(
         *[Sample([s(0)], s(outcome)) for outcome in outcomes])
     self.assertEqual(reveal(dataset.determine_class()), 1)
コード例 #21
0
def sample(*inputs):
    """
    Builds a Sample from the given inputs with outcome s(0).

    Callers pass values already wrapped in s(). The inputs are stored as a
    list, like every other ``Sample([...], ...)`` construction. ``*inputs``
    on its own would pass a tuple, which never compares equal to a list.
    """
    return Sample(list(inputs), s(0))
コード例 #22
0
 def test_outcomes(self):
     """outcomes lists only the outcomes of the samples kept by select()."""
     dataset = ObliviousDataset.create(
         Sample([s(0), s(1), s(2)], outcome=s(60)),
         Sample([s(10), s(11), s(12)], outcome=s(70)),
         Sample([s(20), s(21), s(22)], outcome=s(80)))
     selected = dataset.select([s(1), s(0), s(1)])
     self.assertEqual(reveal(selected.outcomes), [60, 80])
コード例 #23
0
 def test_classify_with_a_branch(self):
     """Attribute 1 equal to 0 gives outcome 1; equal to 1 gives outcome 0."""
     tree = Branch(s(1), s(0), leaf(s(1)), leaf(s(0)))
     for middle, expected in ((0, 1), (1, 0)):
         inputs = [s(1), s(middle), s(1)]
         self.assertEqual(reveal(classify(inputs, tree)), expected)
コード例 #24
0
 def test_tuple(self):
     """if_else accepts tuple operands and returns the first for a true condition."""
     condition = s(True)
     when_true = (s(1), s(2))
     when_false = (s(3), s(4))
     chosen = if_else(condition, when_true, when_false)
     self.assertEqual(reveal(chosen), (1, 2))
コード例 #25
0
 def test_classify_with_pruned_subtree(self):
     """classify() yields 1 for [0, 0, 1] when the left subtree is fully pruned."""
     tree = Branch(s(1), threshold=s(0),
                   left=Branch(s(0), threshold=s(0),
                               left=pruned(), right=pruned()),
                   right=Branch(s(2), threshold=s(0),
                                left=leaf(s(0)), right=leaf(s(1))))
     inputs = [s(0), s(0), s(1)]
     self.assertEqual(reveal(classify(inputs, tree)), 1)
コード例 #26
0
 def test_number_of_attributes(self):
     """number_of_attributes equals the number of inputs per sample (three)."""
     first = sample(s(1), s(2), s(3))
     second = sample(s(4), s(5), s(6))
     dataset = ObliviousDataset.create(first, second)
     self.assertEqual(dataset.number_of_attributes, 3)
コード例 #27
0
def leaf(outcome):
    """Returns a Leaf holding ``outcome`` whose pruned flag is s(False)."""
    not_pruned = s(False)
    return Leaf(outcome, not_pruned)
コード例 #28
0
 def test_add_samples(self):
     """Adding two samples sums their inputs position by position and their outcomes."""
     total = (Sample([s(1), s(2), s(3)], s(4))
              + Sample([s(5), s(6), s(7)], s(8)))
     self.assertEqual(reveal(total), Sample([6, 8, 10], 12))
コード例 #29
0
 def test_classify_with_only_leaf_node(self):
     """A tree that is a single leaf classifies the input as that leaf's outcome."""
     inputs = [s(1), s(0), s(1)]
     only_leaf = leaf(s(1))
     self.assertEqual(reveal(classify(inputs, only_leaf)), 1)
コード例 #30
0
 def test_continuous_attribute_check_with_secret_index(self):
     """is_continuous also accepts a secret attribute index."""
     flags = (False, True, False)
     dataset = ObliviousDataset.create(
         Sample([s(0), s(1), s(1)], s(0)),
         Sample([s(1), s(2), s(1)], s(1)),
         continuous=list(flags))
     for attribute, flag in enumerate(flags):
         revealed = reveal(dataset.is_continuous(s(attribute)))
         if flag:
             self.assertTrue(revealed)
         else:
             self.assertFalse(revealed)