def test_dp_sum_structure_list(self):
    """DP sum over list-structured records using custom type/conversion fns.

    The query is ``GaussianSumQuery(5.0, 0.0)``. The global-state assertions
    below show it carries ``l2_norm_clip == 5.0`` and ``stddev == 0.0``. With
    zero stddev the aggregate is expected to be the exact sum of the clipped
    records.
    """
    query = privacy.GaussianSumQuery(5.0, 0.0)

    def _value_type_fn(value):
        # The record type is fixed for this test, so the value is ignored.
        del value
        return [tff.TensorType(tf.float32), tff.TensorType(tf.float32)]

    def _from_anon_tuple_fn(record):
        # Turns an aggregated record into a plain Python list. Presumably the
        # record is a TFF anonymous tuple, per the keyword name.
        return list(record)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(
        query,
        value_type_fn=_value_type_fn,
        from_anon_tuple_fn=_from_anon_tuple_fn)

    def datapoint(a, b):
        return [tf.Variable(a, name='a'), tf.Variable(b, name='b')]

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, data[0])
    global_state = initialize()

    global_state, result = aggregate(global_state, data)
    self.assertEqual(global_state.l2_norm_clip, 5.0)
    self.assertEqual(global_state.stddev, 0.0)

    result = list(result)
    # The third record has L2 norm sqrt(36 + 64) = 10 > 5, so it contributes
    # (3.0, 4.0). The expected sums are 1 + 2 + 3 = 6 and 2 + 3 + 4 = 9.
    self.assertEqual(result[0], 6.0)
    self.assertEqual(result[1], 9.0)
def test_dp_global_state_type(self):
    """Checks the class of the global-state type from build_dp_aggregate.

    The second value returned by ``build_dp_aggregate`` must be an instance
    of a class named ``NamedTupleTypeWithPyContainerType``. The check is done
    by comparing class names.
    """
    gaussian_query = privacy.GaussianSumQuery(5.0, 0.0)
    _, state_type = differential_privacy.build_dp_aggregate(gaussian_query)

    # Compare by class name rather than with isinstance.
    state_type_name = state_type.__class__.__name__
    self.assertEqual(state_type_name, 'NamedTupleTypeWithPyContainerType')
def test_dp_sum(self):
    """DP sum over scalar records with L2 clip 4.0 and zero noise stddev.

    The global-state assertions below show the clip is 4.0 and the stddev is
    0.0. Record 5.0 exceeds the clip and contributes 4.0, so the expected sum
    is 1 + 3 + 4 = 8.
    """
    query = privacy.GaussianSumQuery(4.0, 0.0)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(query)

    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, 0.0)
    global_state = initialize()

    global_state, result = aggregate(global_state, [1.0, 3.0, 5.0])

    self.assertEqual(global_state.l2_norm_clip, 4.0)
    self.assertEqual(global_state.stddev, 0.0)
    self.assertEqual(result, 8.0)
def test_dp_sum_structure_odict(self):
    """DP sum over OrderedDict-structured records.

    Each record has a tuple under ``'a'`` and a list under ``'b'``. The query
    is ``GaussianSumQuery(5.0, 0.0)``; the global-state assertions below show
    the clip is 5.0 and the stddev is 0.0.
    """
    query = privacy.GaussianSumQuery(5.0, 0.0)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(query)

    def datapoint(a, b):
        return collections.OrderedDict([('a', (a,)), ('b', [b])])

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, data[0])
    global_state = initialize()

    global_state, result = aggregate(global_state, data)

    self.assertEqual(global_state.l2_norm_clip, 5.0)
    self.assertEqual(global_state.stddev, 0.0)

    # The third record has L2 norm 10 > 5, so it contributes (3.0, 4.0). The
    # expected sums are 1 + 2 + 3 = 6 for 'a' and 2 + 3 + 4 = 9 for 'b'.
    self.assertEqual(result.a[0], 6.0)
    self.assertEqual(result.b[0], 9.0)