def test_tree_sum_noise_efficient(self, total_steps, noise_std,
                                  variable_shape, tolerance):
    # Test the variance returned by `EfficientTreeAggregator` is smaller than
    # `TreeAggregator` (within a relative `tolerance`) after `total_steps` of
    # leaf nodes are traversed. Each tree node is a `variable_shape` tensor of
    # Gaussian noise with `noise_std`. A small `tolerance` is used for numerical
    # stability, `tolerance==0` means `EfficientTreeAggregator` is strictly
    # better than `TreeAggregator` for reducing variance. `total_steps` must be
    # at least 1, otherwise no cumulative sum is produced to compare.
    #
    # The generator is seeded (like the other noise tests) so that this
    # statistical comparison is deterministic rather than occasionally flaky.
    random_generator = tree_aggregation.GaussianNoiseGenerator(
        noise_std, tf.TensorSpec(variable_shape), seed=2020)
    tree_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=random_generator)
    tree_aggregator_baseline = tree_aggregation.TreeAggregator(
        value_generator=random_generator)

    state = tree_aggregator.init_state()
    state_baseline = tree_aggregator_baseline.init_state()
    for leaf_node_idx in range(total_steps):
      self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
      # Both aggregators must advance in lockstep so that the final cumsums
      # being compared cover the same number of leaf nodes.
      self.assertEqual(leaf_node_idx,
                       tree_aggregation.get_step_idx(state_baseline))
      val, state = tree_aggregator.get_cumsum_and_update(state)
      val_baseline, state_baseline = (
          tree_aggregator_baseline.get_cumsum_and_update(state_baseline))
    self.assertLess(
        tf.math.reduce_variance(val),
        (1 + tolerance) * tf.math.reduce_variance(val_baseline))
 def test_tree_sum_steps_max(self, total_steps, node_value):
   # Test that no cumulative sum returned by `TreeAggregator` while traversing
   # `total_steps` leaf nodes exceeds `max_val`. Every tree node is the
   # constant `node_value` (assumed non-negative), so the bound scales with
   # the largest number of nodes summed at any single step.
   tree_aggregator = tree_aggregation.TreeAggregator(
       value_generator=ConstantValueGenerator(node_value))
   # Assumes the cumsum at 0-based step `t` combines one node per set bit of
   # `t + 1` (binary-tree decomposition) -- TODO confirm against
   # `TreeAggregator`. Under that assumption the largest count over steps
   # 0..total_steps-1 is floor(log2(total_steps + 1)), computed exactly with
   # integer `bit_length`. Unlike `ceil(log2(total_steps))`, this bound is not
   # 0 when `total_steps == 1` (step 0 already sums one node) and does not
   # raise a math domain error when `total_steps == 0`.
   max_nodes = (total_steps + 1).bit_length() - 1
   max_val = node_value * max_nodes
   state = tree_aggregator.init_state()
   for leaf_node_idx in range(total_steps):
     self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
     val, state = tree_aggregator.get_cumsum_and_update(state)
     self.assertLessEqual(val, max_val)
 def test_tree_sum_last_step_expected_value_fn(self, total_steps,
                                               expected_value, node_value):
   # A plain zero-argument callable is passed as a stateless
   # `value_generator` instead of a generator object. After `total_steps`
   # leaf nodes, the last cumulative sum must equal `expected_value`.
   node_fn = lambda: node_value
   aggregator = tree_aggregation.TreeAggregator(value_generator=node_fn)
   agg_state = aggregator.init_state()
   for step in range(total_steps):
     # The aggregator's internal step counter tracks the loop index.
     self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
     cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
   self.assertEqual(expected_value, cumsum)
 def test_tree_sum_last_step_expected(self, total_steps, expected_value,
                                      node_value):
   # Every tree node holds the constant `node_value`, which stands in for the
   # "noise" that a private algorithm would add to private values. Once
   # `total_steps` leaf nodes have been traversed, the cumulative sum reported
   # by `TreeAggregator` must equal `expected_value` exactly.
   constant_gen = ConstantValueGenerator(node_value)
   aggregator = tree_aggregation.TreeAggregator(value_generator=constant_gen)
   agg_state = aggregator.init_state()
   for step in range(total_steps):
     # The aggregator's internal step counter tracks the loop index.
     self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
     cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
   self.assertEqual(expected_value, cumsum)
  def test_cumsum_vector(self, total_steps=15):
    # A structured (list-of-tensors) node value is aggregated component-wise:
    # at every step, each component of the structured cumsum must equal the
    # scalar cumsum (computed with node value 1.) times that component's node
    # value.
    node_values = [
        tf.ones([2, 2], dtype=tf.float32),
        tf.constant([2], dtype=tf.float32),
    ]
    vector_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=ConstantValueGenerator(node_values))
    scalar_aggregator = tree_aggregation.EfficientTreeAggregator(
        value_generator=ConstantValueGenerator(1.))
    vector_state = vector_aggregator.init_state()
    scalar_state = scalar_aggregator.init_state()
    for step in range(total_steps):
      self.assertEqual(step, tree_aggregation.get_step_idx(vector_state))
      cumsum, vector_state = vector_aggregator.get_cumsum_and_update(
          vector_state)
      scalar_cumsum, scalar_state = scalar_aggregator.get_cumsum_and_update(
          scalar_state)
      # Both aggregators must be at the same step before comparing results.
      self.assertEqual(
          tree_aggregation.get_step_idx(vector_state),
          tree_aggregation.get_step_idx(scalar_state))
      expected_result = [scalar_cumsum * component for component in node_values]
      tf.nest.map_structure(self.assertAllClose, cumsum, expected_result)
 def test_tree_sum_last_step_expected(self, total_steps, expected_value,
                                      step_value):
   # Test whether `tree_aggregator` will output `expected_value` after
   # `total_steps` of leaf nodes are traversed. The value of each tree node
   # is the constant `step_value` for test purposes. Note that `step_value`
   # denotes the "noise" without private values in private algorithms. The
   # `expected_value` is based on a weighting scheme that strongly depends on
   # the depth of the binary tree, and the final cumsum is compared with it
   # approximately via `assertAllClose` rather than exactly.
   tree_aggregator = tree_aggregation.EfficientTreeAggregator(
       value_generator=ConstantValueGenerator(step_value))
   state = tree_aggregator.init_state()
   for leaf_node_idx in range(total_steps):
     self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
     val, state = tree_aggregator.get_cumsum_and_update(state)
   self.assertAllClose(expected_value, val)
 def test_tree_sum_noise_expected(self, total_steps, expected_variance,
                                  noise_std, variable_shape, tolerance):
   # Test whether `tree_aggregator` will output `expected_variance` (within a
   # relative `tolerance`) after `total_steps` of leaf nodes are traversed.
   # Each tree node is a `variable_shape` tensor of Gaussian noise with
   # standard deviation `noise_std`. Note that `expected_variance` can be
   # smaller than the variance of a single vanilla node (`noise_std**2`)
   # because of the update rule of `EfficientTreeAggregator`.
   random_generator = tree_aggregation.GaussianNoiseGenerator(
       noise_std, tf.TensorSpec(variable_shape), seed=2020)
   tree_aggregator = tree_aggregation.EfficientTreeAggregator(
       value_generator=random_generator)
   state = tree_aggregator.init_state()
   for leaf_node_idx in range(total_steps):
     self.assertEqual(leaf_node_idx, tree_aggregation.get_step_idx(state))
     val, state = tree_aggregator.get_cumsum_and_update(state)
   # The comparison is on standard deviations: sqrt(expected_variance) vs. the
   # empirical std over the elements of the final cumsum.
   self.assertAllClose(
       math.sqrt(expected_variance), tf.math.reduce_std(val), rtol=tolerance)
 def test_tree_sum_noise_expected(self, total_steps, expected_variance,
                                  noise_std, variable_shape, tolerance):
   # At every step `t` of a `TreeAggregator` run over `total_steps` leaf
   # nodes, whose values are `variable_shape` tensors of Gaussian noise with
   # `noise_std`, the empirical std of the cumsum must match
   # sqrt(expected_variance[t]) within the relative `tolerance`.
   noise_gen = tree_aggregation.GaussianNoiseGenerator(
       noise_std, tf.TensorSpec(variable_shape), seed=2020)
   aggregator = tree_aggregation.TreeAggregator(value_generator=noise_gen)
   agg_state = aggregator.init_state()
   for step in range(total_steps):
     self.assertEqual(step, tree_aggregation.get_step_idx(agg_state))
     cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
     expected_std = math.sqrt(expected_variance[step])
     self.assertAllClose(
         expected_std, tf.math.reduce_std(cumsum), rtol=tolerance)