def test_dp_sum_structure_list(self):
    """Checks clipped DP summation of records structured as Python lists.

    Each record is a two-element list of float variables. The query clips
    each record to an L2 norm of 5.0 and has stddev 0.0. The test expects the
    exact sum of the clipped records and an unchanged global state.
    """
    query = privacy.GaussianSumQuery(5.0, 0.0)

    def _value_type_fn(value):
        # The value is not inspected; the declared type is always a list of
        # two float32 tensors.
        del value
        return [tff.TensorType(tf.float32), tff.TensorType(tf.float32)]

    def _from_anon_tuple_fn(record):
        # Convert `record` into a plain Python list.
        return list(record)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(
        query,
        value_type_fn=_value_type_fn,
        from_anon_tuple_fn=_from_anon_tuple_fn)

    def datapoint(a, b):
        return [tf.Variable(a, name='a'), tf.Variable(b, name='b')]

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    # The first record is handed to wrap_aggregate_fn, presumably as an
    # example value for type inference (TODO confirm).
    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, data[0])
    global_state = initialize()

    global_state, result = aggregate(global_state, data)

    self.assertEqual(global_state.l2_norm_clip, 5.0)
    self.assertEqual(global_state.stddev, 0.0)

    result = list(result)
    # 1.0 + 2.0 + 3.0 and 2.0 + 3.0 + 4.0 after clipping the last record.
    self.assertEqual(result[0], 6.0)
    self.assertEqual(result[1], 9.0)
def test_dp_stateful_mean(self):
    """Checks that the query's global state is carried across rounds.

    Despite the name, this exercises a summing query.
    ShrinkingSumQuery.get_noised_result returns sample_state unchanged as the
    result and returns a global state whose l2_norm_clip is lowered by 1
    (floored at 0.0). Aggregating the same records repeatedly therefore clips
    them ever more tightly.
    """

    class ShrinkingSumQuery(privacy.GaussianSumQuery):

        def get_noised_result(self, sample_state, global_state):
            global_state = self._GlobalState(
                tf.maximum(global_state.l2_norm_clip - 1, 0.0),
                global_state.stddev)
            return sample_state, global_state

    query = ShrinkingSumQuery(4.0, 0.0)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(query)
    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, 0.0)
    global_state = initialize()

    records = [1.0, 3.0, 5.0]

    def run_and_check(global_state, expected_l2_norm_clip, expected_result):
        # Run one round. `expected_l2_norm_clip` is the clip in the state
        # returned by this round, i.e. the one the *next* round will use.
        global_state, result = aggregate(global_state, records)
        self.assertEqual(global_state.l2_norm_clip, expected_l2_norm_clip)
        self.assertEqual(result, expected_result)
        return global_state

    self.assertEqual(global_state.l2_norm_clip, 4.0)
    # Records [1, 3, 5] clipped to the current norm c and summed:
    # c=4 -> 1+3+4=8, c=3 -> 1+3+3=7, c=2 -> 1+2+2=5, c=1 -> 1+1+1=3, c=0 -> 0.
    global_state = run_and_check(global_state, 3.0, 8.0)
    global_state = run_and_check(global_state, 2.0, 7.0)
    global_state = run_and_check(global_state, 1.0, 5.0)
    global_state = run_and_check(global_state, 0.0, 3.0)
    global_state = run_and_check(global_state, 0.0, 0.0)
def test_dp_global_state_type(self):
    """Checks that build_dp_aggregate reports a StructWithPythonType state.

    NOTE(review): this file also defines another test_dp_global_state_type
    that asserts a different type name. If both live in the same class, the
    later definition replaces this one and this test never runs. Keep only
    the variant that matches the TFF version in use.
    """
    gaussian_query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)
    _unused_aggregate_fn, state_type = differential_privacy.build_dp_aggregate(
        gaussian_query)
    self.assertIsInstance(state_type, computation_types.StructWithPythonType)
def test_dp_global_state_type(self):
    """Checks the class name of the global-state type from build_dp_aggregate.

    The type is identified by comparing its class name with the string
    'NamedTupleTypeWithPyContainerType' rather than by an isinstance check.

    NOTE(review): this file also defines another test_dp_global_state_type
    that asserts StructWithPythonType. If both live in one class, only the
    later definition runs.
    """
    _, state_type = differential_privacy.build_dp_aggregate(
        tensorflow_privacy.GaussianSumQuery(5.0, 0.0))
    type_name = state_type.__class__.__name__
    self.assertEqual(type_name, 'NamedTupleTypeWithPyContainerType')
def test_dp_sum(self):
    """Checks a single round of clipped DP summation over scalar records.

    The query uses l2_norm_clip 4.0 and stddev 0.0, so the record 5.0
    contributes 4.0 and the expected result is 1.0 + 3.0 + 4.0 = 8.0. The
    global state is expected to come back with the same values it was built
    with.
    """
    query = privacy.GaussianSumQuery(4.0, 0.0)

    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(query)
    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, 0.0)

    global_state = initialize()
    global_state, result = aggregate(global_state, [1.0, 3.0, 5.0])

    self.assertEqual(global_state.l2_norm_clip, 4.0)
    self.assertEqual(global_state.stddev, 0.0)
    self.assertEqual(result, 8.0)
def test_dp_sum_structure_odict(self):
    """Checks clipped DP summation of OrderedDict-structured records.

    Each record maps 'a' to a 1-tuple and 'b' to a 1-element list. With an
    L2 clip of 5.0, the record (6.0, 8.0) becomes (3.0, 4.0), so the expected
    sums are 6.0 for 'a' and 9.0 for 'b'.

    NOTE(review): this file also defines a later test_dp_sum_structure_odict.
    If both live in the same class, that later definition replaces this one
    and this test never runs.
    """
    query = privacy.GaussianSumQuery(5.0, 0.0)
    dp_aggregate_fn, _ = differential_privacy.build_dp_aggregate(query)

    def datapoint(a, b):
        return collections.OrderedDict([('a', (a,)), ('b', [b])])

    data = [
        datapoint(1.0, 2.0),
        datapoint(2.0, 3.0),
        datapoint(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    initialize, aggregate = wrap_aggregate_fn(dp_aggregate_fn, data[0])
    global_state = initialize()
    global_state, result = aggregate(global_state, data)

    self.assertEqual(global_state.l2_norm_clip, 5.0)
    self.assertEqual(global_state.stddev, 0.0)
    # The aggregated result exposes the record keys as attributes.
    self.assertEqual(result.a[0], 6.0)
    self.assertEqual(result.b[0], 9.0)
def test_dp_sum_structure_odict(self):
    """Checks clipped DP summation of OrderedDict-structured records.

    Every record maps 'a' to a 1-tuple and 'b' to a 1-element list. With an
    L2 clip of 5.0, the record (6.0, 8.0) is scaled to (3.0, 4.0), so the
    expected sums are 6.0 under 'a' and 9.0 under 'b'. The result is indexed
    by key.

    NOTE(review): this file also defines an earlier method with this same
    name. If both live in one class, this later definition is the one that
    runs.
    """
    gaussian_query = tensorflow_privacy.GaussianSumQuery(5.0, 0.0)
    aggregate_fn, _ = differential_privacy.build_dp_aggregate(gaussian_query)

    def make_record(a_value, b_value):
        return collections.OrderedDict(a=(a_value,), b=[b_value])

    records = [
        make_record(1.0, 2.0),
        make_record(2.0, 3.0),
        make_record(6.0, 8.0),  # Clipped to 3.0, 4.0
    ]

    init_fn, aggregate_step = wrap_aggregate_fn(aggregate_fn, records[0])
    state, summed = aggregate_step(init_fn(), records)

    self.assertEqual(state.l2_norm_clip, 5.0)
    self.assertEqual(state.stddev, 0.0)
    for key, expected_sum in (('a', 6.0), ('b', 9.0)):
        self.assertEqual(summed[key][0], expected_sum)