def test_partial_sum_scalar_tree_aggregation(self, scalar_value,
                                             tree_node_value):
    """Checks each streamed partial sum when the noise generator is constant.

    The same scalar record is fed to the query for a fixed number of steps and
    the noised result of every step is compared against a closed-form value.

    Args:
      scalar_value: Scalar record submitted at every step.
      tree_node_value: Constant returned by the noise generator.
    """
    num_steps = 8
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False,
    )
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for step in range(1, num_steps + 1):
        sample_state = query.initial_sample_state(scalar_value)
        sample_state = query.accumulate_record(params, sample_state,
                                               scalar_value)
        query_result, global_state, _ = query.get_noised_result(
            sample_state, global_state)
        # At the 1-based streaming step `step`, the expected value is roughly
        # `scalar_value * step + tree_aggregation(tree_node_value, step)`.
        # The tree aggregation value can be inferred from the binary
        # representation of the current step: one node value per set bit.
        set_bits = bin(step)[2:].count('1')
        self.assertEqual(query_result,
                         scalar_value * step + tree_node_value * set_bits)
def test_build_l2_gaussian_query(self, records_num, record_dim,
                                 noise_multiplier, seed, total_steps, clip,
                                 use_efficient):
    """Compares `build_l2_gaussian_query` against a hand-assembled query.

    Both queries receive the same random records at every step and their
    results must agree within a relative tolerance of 1e-6.

    Args:
      records_num: Number of records per step; also used as the `maxval` of
        the uniform sampler.
      record_dim: Length of each record vector.
      noise_multiplier: Multiplier passed to the factory; the reference query
        uses `clip * noise_multiplier` for its noise generator.
      seed: Seed given to both noise sources.
      total_steps: Number of streaming steps to run.
      clip: Clipping value given to both queries.
      use_efficient: Forwarded unchanged to both queries.
    """
    record_specs = tf.TensorSpec(shape=[record_dim])
    factory_query = (
        tree_aggregation_query.TreeCumulativeSumQuery.build_l2_gaussian_query(
            clip_norm=clip,
            noise_multiplier=noise_multiplier,
            record_specs=record_specs,
            noise_seed=seed,
            use_efficient=use_efficient))
    manual_query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=clip,
        noise_generator=_get_noise_generator(record_specs,
                                             clip * noise_multiplier, seed),
        record_specs=record_specs,
        use_efficient=use_efficient)
    factory_state = factory_query.initial_global_state()
    manual_state = manual_query.initial_global_state()
    for _ in range(total_steps):
        # A fresh batch per step, shared by the two queries.
        batch = [
            tf.random.uniform(shape=[record_dim], maxval=records_num)
            for _ in range(records_num)
        ]
        factory_result, factory_state = test_utils.run_query(
            factory_query, batch, factory_state)
        manual_result, manual_state = test_utils.run_query(
            manual_query, batch, manual_state)
        self.assertAllClose(factory_result, manual_result, rtol=1e-6)
def test_sum_scalar_tree_aggregation_reset(self, scalar_value,
                                           tree_node_value, frequency):
    """Checks streamed sums when the state is reset every `frequency` steps.

    Args:
      scalar_value: Scalar record submitted at every step.
      tree_node_value: Constant returned by the noise generator.
      frequency: `reset_state` is called after every `frequency`-th step.
    """
    num_steps = 20
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=scalar_value + 1.,  # no clip
        noise_generator=lambda: tree_node_value,
        record_specs=tf.TensorSpec([]),
        use_efficient=False)
    global_state = query.initial_global_state()
    params = query.derive_sample_params(global_state)
    for i in range(num_steps):
        sample_state = query.initial_sample_state(scalar_value)
        sample_state = query.accumulate_record(params, sample_state,
                                               scalar_value)
        query_result, global_state, _ = query.get_noised_result(
            sample_state, global_state)
        if i % frequency == frequency - 1:
            global_state = query.reset_state(query_result, global_state)
        # Expected value is the combination of cumsum of signal; sum of trees
        # that have been reset; current tree sum. The tree aggregation value
        # can be inferred from the binary representation of the current step.
        signal_cumsum = scalar_value * (i + 1)
        reset_trees_sum = (
            i // frequency * tree_node_value * bin(frequency)[2:].count('1'))
        current_tree_sum = (
            tree_node_value * bin(i % frequency + 1)[2:].count('1'))
        expected = signal_cumsum + reset_trees_sum + current_tree_sum
        self.assertEqual(query_result, expected)
def test_sum_tree_aggregator_instance(self, use_efficient, tree_class):
    """Checks the type of the query's `_tree_aggregator` attribute.

    Args:
      use_efficient: Flag forwarded to the query constructor.
      tree_class: Class the query's `_tree_aggregator` must be an instance of.
    """
    scalar_specs = tf.TensorSpec([])
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=1.,
        noise_generator=_get_noise_fn(scalar_specs, 1.),
        record_specs=scalar_specs,
        use_efficient=use_efficient,
    )
    aggregator = query._tree_aggregator
    self.assertIsInstance(aggregator, tree_class)
def test_noiseless_query_structure_type_record(self):
    """Runs the query on two copies of `STRUCT_RECORD` and checks their sum.

    The noise generator comes from `_get_no_noise_fn`, and the expected result
    is the leaf-wise sum of the two structured records.
    """
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=10.,
        noise_generator=_get_no_noise_fn(STRUCTURE_SPECS),
        record_specs=STRUCTURE_SPECS)
    result, _ = test_utils.run_query(query, [STRUCT_RECORD, STRUCT_RECORD])
    doubled_record = tf.nest.map_structure(lambda a, b: a + b, STRUCT_RECORD,
                                           STRUCT_RECORD)
    self.assertAllClose(result, doubled_record)
def test_noiseless_query_single_value_type_record(self):
    """Runs the query on `SINGLE_VALUE_RECORDS` and expects a result of 9."""
    # Specs are derived from the shape of the first record.
    specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
                                  SINGLE_VALUE_RECORDS[0])
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=10.,
        noise_generator=_get_no_noise_fn(specs),
        record_specs=specs)
    result, _ = test_utils.run_query(query, SINGLE_VALUE_RECORDS)
    self.assertAllClose(result, tf.constant(9.))
def test_linfty_clips_structure_type_record(self, record, norm_clip):
    """Checks `preprocess_record` against element-wise value clipping.

    Args:
      record: Structured record to preprocess.
      norm_clip: Bound; every leaf is expected to be clipped to
        `[-norm_clip, norm_clip]`.
    """
    # NOTE(review): the noise generator is built from STRUCTURE_SPECS while
    # `record_specs` is derived from `record` -- confirm these always agree.
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l_infty_clip_fn(),
        clip_value=norm_clip,
        noise_generator=_get_no_noise_fn(STRUCTURE_SPECS),
        record_specs=tf.nest.map_structure(
            lambda t: tf.TensorSpec(tf.shape(t)), record))
    state = query.initial_global_state()
    want = tf.nest.map_structure(
        lambda t: tf.clip_by_value(t, -norm_clip, norm_clip), record)
    got = query.preprocess_record(state.clip_value, record)
    self.assertAllClose(want, got)
def test_correct_initial_global_state_struct_type(self):
    """Checks the initial global state of a query over `STRUCTURE_SPECS`.

    The state's `tree_state` must be a `tree_aggregation.TreeState`, and
    `samples_cumulative_sum` must be all zeros with the shapes given by
    `STRUCTURE_SPECS`.
    """
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=10.,
        noise_generator=_get_no_noise_fn(STRUCTURE_SPECS),
        record_specs=STRUCTURE_SPECS)
    initial_state = query.initial_global_state()
    self.assertIsInstance(initial_state.tree_state,
                          tree_aggregation.TreeState)
    zeros_like_specs = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape), STRUCTURE_SPECS)
    self.assertAllClose(zeros_like_specs,
                        initial_state.samples_cumulative_sum)
def test_correct_initial_global_state_single_value_type(self):
    """Checks the initial global state of a query over single-value records.

    The state's `tree_state` must be a `tree_aggregation.TreeState`, and
    `samples_cumulative_sum` must be all zeros with the shape of the specs
    derived from `SINGLE_VALUE_RECORDS[0]`.
    """
    specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
                                  SINGLE_VALUE_RECORDS[0])
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=10.,
        noise_generator=_get_no_noise_fn(specs),
        record_specs=specs)
    initial_state = query.initial_global_state()
    self.assertIsInstance(initial_state.tree_state,
                          tree_aggregation.TreeState)
    zeros_like_specs = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape), specs)
    self.assertAllClose(zeros_like_specs,
                        initial_state.samples_cumulative_sum)
def test_noisy_cumsum_and_state_update(self, records, value_generator):
    """Checks the spread of noised results across independently seeded runs.

    A new query is built for each trial, seeded with the trial index, and run
    on the same records. The standard deviation over all collected results
    must be within 0.7 of `NOISE_STD`.

    Args:
      records: Values; each is turned into a constant vector of length 100.
      value_generator: Callable taking `(record_specs, seed=...)` and
        returning the noise generator for the query.
    """
    num_trials, vector_size = 10, 100
    record_specs = tf.TensorSpec([vector_size])
    record_tensors = [tf.constant(r, shape=[vector_size]) for r in records]

    def _noised_sum(trial_seed):
        # Build a fresh query for this seed and return its result as numpy.
        query = tree_aggregation_query.TreeCumulativeSumQuery(
            clip_fn=_get_l2_clip_fn(),
            clip_value=10.,
            noise_generator=value_generator(record_specs, seed=trial_seed),
            record_specs=record_specs)
        query_result, _ = test_utils.run_query(query, record_tensors)
        return query_result.numpy()

    noised_sums = [_noised_sum(trial) for trial in range(num_trials)]
    result_stddev = np.std(noised_sums)
    self.assertNear(result_stddev, NOISE_STD, 0.7)  # value for chi-squared test
def test_partial_sum_scalar_no_noise(self, streaming_scalars, clip_norm,
                                     partial_sum):
    """Checks streamed partial sums when the noise generator returns 0.

    Args:
      streaming_scalars: Scalars fed to the query one per step.
      clip_norm: Clip value handed to the query.
      partial_sum: Expected query result after each corresponding scalar.
    """
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=clip_norm,
        noise_generator=lambda: 0.,
        record_specs=tf.TensorSpec([]),
    )
    state = query.initial_global_state()
    params = query.derive_sample_params(state)
    for value, want in zip(streaming_scalars, partial_sum):
        sample_state = query.initial_sample_state(value)
        sample_state = query.accumulate_record(params, sample_state, value)
        got, state, _ = query.get_noised_result(sample_state, state)
        self.assertEqual(got, want)
def test_l2_clips_structure_type_record(self, record, l2_norm_clip):
    """Checks `preprocess_record` against manual global-norm rescaling.

    Args:
      record: Structured record to preprocess.
      l2_norm_clip: Clip value; when the record's global norm exceeds it, each
        leaf is expected to be scaled by `l2_norm_clip / norm`, otherwise the
        record is expected back unchanged.
    """
    # NOTE(review): the noise generator is built from STRUCTURE_SPECS while
    # `record_specs` is derived from `record` -- confirm these always agree.
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=l2_norm_clip,
        noise_generator=_get_no_noise_fn(STRUCTURE_SPECS),
        record_specs=tf.nest.map_structure(
            lambda t: tf.TensorSpec(tf.shape(t)), record))
    state = query.initial_global_state()
    norm = tf.linalg.global_norm(record)
    want = (
        tf.nest.map_structure(lambda t: t * l2_norm_clip / norm, record)
        if norm > l2_norm_clip else record)
    got = query.preprocess_record(state.clip_value, record)
    self.assertAllClose(want, got)
def test_l2_clips_single_record(self, record, l2_norm_clip):
    """Checks `preprocess_record` against manual norm rescaling of one record.

    Args:
      record: Single-value record to preprocess.
      l2_norm_clip: Clip value; when `tf.norm(record)` exceeds it, the record
        is expected to be scaled by `l2_norm_clip / norm`, otherwise it is
        expected back unchanged.
    """
    # NOTE(review): specs are taken from SINGLE_VALUE_RECORDS[0], not from
    # `record` -- confirm the two always share a shape.
    specs = tf.nest.map_structure(lambda t: tf.TensorSpec(tf.shape(t)),
                                  SINGLE_VALUE_RECORDS[0])
    query = tree_aggregation_query.TreeCumulativeSumQuery(
        clip_fn=_get_l2_clip_fn(),
        clip_value=l2_norm_clip,
        noise_generator=_get_no_noise_fn(specs),
        record_specs=specs)
    state = query.initial_global_state()
    norm = tf.norm(record)
    want = (
        tf.nest.map_structure(lambda t: t * l2_norm_clip / norm, record)
        if norm > l2_norm_clip else record)
    got = query.preprocess_record(state.clip_value, record)
    self.assertAllClose(want, got)