def test_tree_sum_noise_efficient(self, total_steps, noise_std,
                                  variable_shape, tolerance):
    """Checks EfficientTreeAggregator yields no more noise than TreeAggregator.

    Both aggregators share one Gaussian noise generator (standard deviation
    `noise_std`, node shape `variable_shape`) and are advanced `total_steps`
    leaves in lockstep. After the last leaf, the variance across the elements
    of the cumulative sum returned by `EfficientTreeAggregator` must be
    strictly less than `(1 + tolerance)` times the corresponding variance from
    `TreeAggregator`. A small `tolerance` absorbs numerical noise;
    `tolerance == 0` requires the efficient aggregator to be strictly better.
    """
    random_generator = tree_aggregation.GaussianNoiseGenerator(
        noise_std, tf.TensorSpec(variable_shape))
    tree_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=random_generator)
    tree_aggregator_baseline = tree_aggregation.TreeAggregator(
        value_generator=random_generator)

    state = tree_aggregator.init_state()
    state_baseline = tree_aggregator_baseline.init_state()
    for leaf_node_idx in range(total_steps):
        # Verify both states so the final comparison is between aggregators
        # that have consumed exactly the same number of leaves.
        self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
        self.assertEqual(leaf_node_idx,
                         tree_aggregation.get_step_idx(state_baseline))
        val, state = tree_aggregator.get_cumsum_and_update(state)
        val_baseline, state_baseline = (
            tree_aggregator_baseline.get_cumsum_and_update(state_baseline))
    self.assertLess(
        tf.math.reduce_variance(val),
        (1 + tolerance) * tf.math.reduce_variance(val_baseline))
def test_tree_sum_steps_max(self, total_steps, node_value):
    """Checks no TreeAggregator cumulative sum exceeds the tree-depth bound.

    Every tree node holds the constant `node_value`. Assuming the standard
    binary-tree decomposition, the prefix of length `t` is covered by
    popcount(t) nodes, so the cumulative sum after `t` leaves is at most
    `node_value * popcount(t)`. Over `t` in [1, total_steps] that count is
    maximized at floor(log2(total_steps + 1)). Assumes `node_value` is
    non-negative; otherwise the inequality direction would flip.
    """
    tree_aggregator = tree_aggregation.TreeAggregator(
        value_generator=ConstantValueGenerator(node_value))
    # floor(log2(total_steps + 1)) via exact integer arithmetic. Note that
    # ceil(log2(total_steps)) would be 0 for total_steps == 1 and reject the
    # (correct) single-leaf value `node_value`.
    max_nodes = (int(total_steps) + 1).bit_length() - 1
    max_val = node_value * max_nodes
    state = tree_aggregator.init_state()
    for leaf_node_idx in range(total_steps):
        self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
        val, state = tree_aggregator.get_cumsum_and_update(state)
        self.assertLessEqual(val, max_val)
def test_tree_sum_last_step_expected_value_fn(self, total_steps,
                                              expected_value, node_value):
    """Uses a zero-argument callable as a stateless TreeAggregator generator.

    Every tree node is obtained by calling a function that returns
    `node_value`. After `total_steps` leaves, the cumulative sum must equal
    `expected_value` exactly.
    """
    def constant_fn():
        return node_value

    aggregator = tree_aggregation.TreeAggregator(value_generator=constant_fn)
    agg_state = aggregator.init_state()
    for step in range(total_steps):
        self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
        cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
    self.assertEqual(expected_value, cumsum)
def test_tree_sum_last_step_expected(self, total_steps, expected_value,
                                     node_value):
    """Checks the TreeAggregator cumulative sum after `total_steps` leaves.

    Each tree node is the constant `node_value`, standing in for the "noise"
    a private algorithm would add when no private values are present. After
    the final leaf, the cumulative sum must equal `expected_value` exactly.
    """
    aggregator = tree_aggregation.TreeAggregator(
        value_generator=ConstantValueGenerator(node_value))
    agg_state = aggregator.init_state()
    for step in range(total_steps):
        self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
        cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
    self.assertEqual(expected_value, cumsum)
def test_cumsum_vector(self, total_steps=15):
    """Checks EfficientTreeAggregator with a nested (list-of-tensors) node value.

    A "vector" aggregator whose node value is `[ones([2, 2]), [2.]]` is run
    next to a scalar "truth" aggregator whose node value is `1.`. At every
    step both must report the same step index, and each component of the
    vector cumulative sum must be close to the scalar cumulative sum times the
    matching component of the node value.
    """
    ones_2x2 = tf.ones([2, 2], dtype=tf.float32)
    twos = tf.constant([2], dtype=tf.float32)
    vector_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=ConstantValueGenerator([ones_2x2, twos]))
    scalar_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=ConstantValueGenerator(1.))

    vector_state = vector_aggregator.init_state()
    scalar_state = scalar_aggregator.init_state()
    for step in range(total_steps):
        self.assertEqual(step, tree_aggregation.get_step_idx(vector_state))
        vector_sum, vector_state = vector_aggregator.get_cumsum_and_update(
            vector_state)
        scalar_sum, scalar_state = scalar_aggregator.get_cumsum_and_update(
            scalar_state)
        self.assertEqual(
            tree_aggregation.get_step_idx(vector_state),
            tree_aggregation.get_step_idx(scalar_state))
        expected = [scalar_sum * component for component in (ones_2x2, twos)]
        tf.nest.map_structure(self.assertAllClose, vector_sum, expected)
def test_tree_sum_last_step_expected(self, total_steps, expected_value,
                                     step_value):
    """Checks the EfficientTreeAggregator output after `total_steps` leaves.

    Every tree node is generated as the constant `step_value`, standing in for
    the "noise" a private algorithm would add when no private values are
    present. After the final leaf, the cumulative sum must be close to
    `expected_value` (checked with `assertAllClose`). `expected_value` follows
    the `EfficientTreeAggregator` node-weighting scheme, which depends
    strongly on the depth of the binary tree.
    """
    tree_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=ConstantValueGenerator(step_value))
    state = tree_aggregator.init_state()
    for leaf_node_idx in range(total_steps):
        self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
        val, state = tree_aggregator.get_cumsum_and_update(state)
    self.assertAllClose(expected_value, val)
def test_tree_sum_noise_expected(self, total_steps, expected_variance,
                                 noise_std, variable_shape, tolerance):
    """Checks the EfficientTreeAggregator noise level after `total_steps` leaves.

    Every tree node is a `variable_shape` tensor of Gaussian noise with
    standard deviation `noise_std` (generator seeded with 2020). After the
    final leaf, the standard deviation across the elements of the cumulative
    sum must match `sqrt(expected_variance)` within relative tolerance
    `tolerance`. The tolerance therefore applies to the standard deviation,
    not to the variance. Because of the `EfficientTreeAggregator` update rule,
    the variance of a tree node is smaller than the vanilla node variance
    `noise_std**2`, and `expected_variance` reflects that reduction.
    """
    random_generator = tree_aggregation.GaussianNoiseGenerator(
        noise_std, tf.TensorSpec(variable_shape), seed=2020)
    tree_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=random_generator)
    state = tree_aggregator.init_state()
    for leaf_node_idx in range(total_steps):
        self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
        val, state = tree_aggregator.get_cumsum_and_update(state)
    self.assertAllClose(
        math.sqrt(expected_variance), tf.math.reduce_std(val), rtol=tolerance)
def test_tree_sum_noise_expected(self, total_steps, expected_variance,
                                 noise_std, variable_shape, tolerance):
    """Checks the TreeAggregator noise level at every step.

    Tree nodes are `variable_shape` tensors of Gaussian noise with standard
    deviation `noise_std` (generator seeded with 2020). After leaf `i` is
    consumed, the standard deviation across the elements of the cumulative
    sum must be within relative `tolerance` of `sqrt(expected_variance[i])`.
    """
    noise = tree_aggregation.GaussianNoiseGenerator(
        noise_std, tf.TensorSpec(variable_shape), seed=2020)
    aggregator = tree_aggregation.TreeAggregator(value_generator=noise)

    agg_state = aggregator.init_state()
    for step in range(total_steps):
        self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
        cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
        expected_std = math.sqrt(expected_variance[step])
        self.assertAllClose(
            expected_std, tf.math.reduce_std(cumsum), rtol=tolerance)