def test_adding_mismatched_accumulator_length_raises_exception(self):
    """Merging CompoundAccumulators of different lengths raises ValueError."""
    # The base wraps two sub-accumulators; the one being merged wraps only one.
    base_compound_accumulator = accumulator.CompoundAccumulator([
        MeanAccumulator(params=[], values=[11]),
        SumOfSquaresAccumulator(params=[], values=[1]),
    ])
    shorter_compound_accumulator = accumulator.CompoundAccumulator(
        [MeanAccumulator(params=[], values=[22])])

    with self.assertRaises(ValueError) as context:
        base_compound_accumulator.add_accumulator(shorter_compound_accumulator)

    # Checked after the `with` block so the assertion actually executes.
    expected_message = ("Accumulators in the input are not of the same size. "
                        "Expected size = 2 received size = 1.")
    self.assertEqual(expected_message, str(context.exception))
def test_privacy_id_count(self):
    """Tracks privacy_id_count through add_value and add_accumulator."""
    first = accumulator.CompoundAccumulator([])
    second = accumulator.CompoundAccumulator([])

    # A newly constructed CompoundAccumulator holds data of one privacy id.
    self.assertEqual(1, first.privacy_id_count)

    # Adding a value does not change the number of privacy ids.
    first.add_value(3)
    self.assertEqual(1, first.privacy_id_count)

    # Merging sums the privacy id counts, because different
    # CompoundAccumulators are assumed to hold data from non-overlapping
    # sets of privacy ids.
    first.add_accumulator(second)
    self.assertEqual(2, first.privacy_id_count)
def test_adding_mismatched_accumulator_order_raises_exception(self):
    """Merging sub-accumulators of mismatched types raises TypeError.

    Both CompoundAccumulators hold a mean and a sum-of-squares accumulator,
    but in swapped order, so index 0 is the first position that differs.
    """
    base_compound_accumulator = accumulator.CompoundAccumulator([
        MeanAccumulator(params=[], values=[11]),
        SumOfSquaresAccumulator(params=[], values=[1]),
    ])
    swapped_compound_accumulator = accumulator.CompoundAccumulator([
        SumOfSquaresAccumulator(params=[], values=[2]),
        MeanAccumulator(params=[], values=[22]),
    ])

    with self.assertRaises(TypeError) as context:
        base_compound_accumulator.add_accumulator(swapped_compound_accumulator)

    expected_message = ("The type of the accumulators don't match at index 0. "
                        "MeanAccumulator != SumOfSquaresAccumulator.")
    self.assertEqual(expected_message, str(context.exception))
def test_adding_accumulator(self):
    """Metrics after merging another CompoundAccumulator and adding a value."""
    # Base accumulator whose sub-accumulators start with the value 5.
    compound_accumulator = accumulator.CompoundAccumulator([
        MeanAccumulator(params=None, values=[5]),
        SumOfSquaresAccumulator(params=None, values=[5]),
    ])

    # Initially empty accumulator that then receives the value 4.
    other_compound_accumulator = accumulator.CompoundAccumulator([
        MeanAccumulator(params=[], values=[]),
        SumOfSquaresAccumulator(params=[], values=[]),
    ])
    other_compound_accumulator.add_value(4)

    compound_accumulator.add_accumulator(other_compound_accumulator)
    compound_accumulator.add_value(3)

    # Expected: mean(5, 4, 3) == 4 and 5**2 + 4**2 + 3**2 == 50.
    metrics = compound_accumulator.compute_metrics()
    self.assertEqual(2, len(metrics))
    self.assertEqual([4, 50], metrics)
def test_with_mean_and_sum_squares(self):
    """Metrics of a mean + sum-of-squares CompoundAccumulator after add_value."""
    compound_accumulator = accumulator.CompoundAccumulator([
        MeanAccumulator(params=[], values=[]),
        SumOfSquaresAccumulator(params=[], values=[]),
    ])
    compound_accumulator.add_value(3)
    compound_accumulator.add_value(4)

    computed_metrics = compound_accumulator.compute_metrics()
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(compound_accumulator,
                          accumulator.CompoundAccumulator)
    self.assertEqual(2, len(computed_metrics))
    # Expected: mean(3, 4) == 3.5 and 3**2 + 4**2 == 25, in the order the
    # sub-accumulators were given.
    self.assertEqual([3.5, 25], computed_metrics)
def test_serialization_compound_accumulator(self):
    """A serialize/deserialize round trip preserves structure and metrics."""
    original = accumulator.CompoundAccumulator([
        MeanAccumulator(params=[], values=[15]),
        SumOfSquaresAccumulator(params=[], values=[1]),
    ])

    # Deserialize through the base Accumulator class; the checks below verify
    # that the concrete CompoundAccumulator and its sub-accumulators come back.
    payload = original.serialize()
    restored = accumulator.Accumulator.deserialize(payload)

    self.assertIsInstance(restored, accumulator.CompoundAccumulator)
    self.assertEqual(len(restored.accumulators), 2)
    expected_types = (MeanAccumulator, SumOfSquaresAccumulator)
    for sub_accumulator, expected_type in zip(restored.accumulators,
                                              expected_types):
        self.assertIsInstance(sub_accumulator, expected_type)
    self.assertEqual(restored.compute_metrics(), original.compute_metrics())