示例#1
0
    def test_cumsum_vector(self, total_steps=15):
        # Runs a list-of-tensors aggregator side by side with a scalar one
        # (whose value function always returns 1.) and checks, step by step,
        # that each cumulative sum of the structured aggregator equals the
        # scalar cumulative sum multiplied by the corresponding leaf value.
        def vector_value_fn():
            matrix_part = tf.ones([2, 2], dtype=tf.float32)
            vector_part = tf.constant([2], dtype=tf.float32)
            return [matrix_part, vector_part]

        def scalar_value_fn():
            return 1.

        vector_aggregator = tree_aggregation.TFEfficientTreeAggregator(
            new_value_fn=vector_value_fn)
        scalar_aggregator = tree_aggregation.TFEfficientTreeAggregator(
            new_value_fn=scalar_value_fn)
        vector_state = vector_aggregator.init_state()
        scalar_state = scalar_aggregator.init_state()

        for step in range(total_steps):
            # Before each update the state's step index must equal the number
            # of updates performed so far.
            current_step = tree_aggregation.get_step_idx(vector_state)
            self.assertEqual(step, current_step)

            cumsum, vector_state = vector_aggregator.get_cumsum_and_update(
                vector_state)
            scalar_cumsum, scalar_state = (
                scalar_aggregator.get_cumsum_and_update(scalar_state))

            # Both aggregators must advance in lockstep.
            self.assertEqual(tree_aggregation.get_step_idx(vector_state),
                             tree_aggregation.get_step_idx(scalar_state))

            # Scale each leaf value by the scalar cumulative sum to get the
            # expected structured result.
            expected = [scalar_cumsum * leaf for leaf in vector_value_fn()]
            tf.nest.map_structure(self.assertAllClose, cumsum, expected)
示例#2
0
    def __init__(self,
                 learning_rate: float,
                 momentum: float,
                 noise_std: float,
                 model_weight_shape: Collection[tf.Tensor],
                 efficient_tree: bool = True,
                 use_nesterov: bool = False):
        """Initialize the momentum DPFTRL optimizer.

        Args:
          learning_rate: Learning rate, stored as `self.lr`.
          momentum: Momentum value; passed to `_check_momentum` for
            validation before being stored.
          noise_std: Standard deviation of the normal noise drawn for every
            entry of `model_weight_shape`.
          model_weight_shape: Nested structure whose leaves are used as the
            `shape` argument when sampling noise.
          efficient_tree: If true, noise is aggregated with
            `TFEfficientTreeAggregator`; otherwise with `TFTreeAggregator`.
          use_nesterov: Whether Nesterov momentum is requested; stored as
            `self.use_nesterov`.

        Raises:
          ValueError: If `use_nesterov` is set while `momentum` equals 0.
        """
        _check_momentum(momentum)
        if use_nesterov and momentum == 0:
            raise ValueError('Use a positive momentum for Nesterov')

        self.lr = learning_rate
        self.momentum = momentum
        self.model_weight_shape = model_weight_shape
        self.use_nesterov = use_nesterov

        rng = tf.random.Generator.from_non_deterministic_state()

        def _sample_for_shape(shape):
            # One fresh normal draw shaped like a single model weight.
            return rng.normal(shape, stddev=noise_std)

        def _noise_fn():
            # Mirror the nested structure of the model weights.
            return tf.nest.map_structure(_sample_for_shape, model_weight_shape)

        # Pick the aggregator implementation once, then build it.
        aggregator_cls = (tree_aggregation.TFEfficientTreeAggregator
                          if efficient_tree else
                          tree_aggregation.TFTreeAggregator)
        self.noise_generator = aggregator_cls(new_value_fn=_noise_fn)
    def test_tree_sum_noise_efficient(self, total_steps, noise_std,
                                      variable_shape, tolerance):
        # Compares the variance of the cumulative noise produced by
        # `TFEfficientTreeAggregator` against the `TFTreeAggregator` baseline
        # after `total_steps` leaf nodes have been traversed. Every tree node
        # is a `variable_shape` tensor of Gaussian noise with `noise_std`, and
        # both aggregators draw from the same generator object. The efficient
        # variant must have a smaller variance, up to a relative `tolerance`
        # that exists for numerical stability; `tolerance == 0` demands that
        # it is strictly better than the baseline.
        noise_source = tree_aggregation.GaussianNoiseGenerator(
            noise_std, tf.TensorSpec(variable_shape))
        efficient_aggregator = tree_aggregation.TFEfficientTreeAggregator(
            value_generator=noise_source)
        baseline_aggregator = tree_aggregation.TFTreeAggregator(
            value_generator=noise_source)

        efficient_state = efficient_aggregator.init_state()
        baseline_state = baseline_aggregator.init_state()
        for step in range(total_steps):
            current_step = tree_aggregation.get_step_idx(efficient_state)
            self.assertEqual(step, current_step)
            efficient_sum, efficient_state = (
                efficient_aggregator.get_cumsum_and_update(efficient_state))
            baseline_sum, baseline_state = (
                baseline_aggregator.get_cumsum_and_update(baseline_state))

        # Only the cumulative sums from the final step are compared.
        efficient_variance = tf.math.reduce_variance(efficient_sum)
        variance_bound = (1 + tolerance) * tf.math.reduce_variance(
            baseline_sum)
        self.assertLess(efficient_variance, variance_bound)
示例#4
0
    def __init__(self,
                 learning_rate: float,
                 momentum: float,
                 noise_std: float,
                 model_weight_shape: Collection[tf.Tensor],
                 efficient_tree: bool = True):
        """Initialize the momentum DPFTRL optimizer.

        Args:
          learning_rate: Learning rate, stored as `self.lr`.
          momentum: Momentum value, stored as `self.momentum`.
          noise_std: Standard deviation of the normal noise drawn for every
            entry of `model_weight_shape`.
          model_weight_shape: Nested structure whose leaves are used as the
            `shape` argument when sampling noise.
          efficient_tree: If true, noise is aggregated with
            `TFEfficientTreeAggregator`; otherwise with `TFTreeAggregator`.
        """
        self.lr = learning_rate
        self.momentum = momentum
        self.model_weight_shape = model_weight_shape

        rng = tf.random.Generator.from_non_deterministic_state()

        def _draw(shape):
            # Normal noise shaped like one model weight.
            return rng.normal(shape, stddev=noise_std)

        def _noise_fn():
            # Produce noise with the same nesting as the model weights.
            return tf.nest.map_structure(_draw, model_weight_shape)

        if not efficient_tree:
            aggregator_cls = tree_aggregation.TFTreeAggregator
        else:
            aggregator_cls = tree_aggregation.TFEfficientTreeAggregator
        self.noise_generator = aggregator_cls(new_value_fn=_noise_fn)
示例#5
0
    def __init__(self,
                 learning_rate: float,
                 momentum: float,
                 noise_std: float,
                 model_weight_specs: Collection[tf.TensorSpec],
                 efficient_tree: bool = True,
                 use_nesterov: bool = False):
        """Initialize the momentum DPFTRL optimizer.

        Args:
          learning_rate: Learning rate, stored as `self.lr`.
          momentum: Momentum value; passed to `_check_momentum` for
            validation before being stored.
          noise_std: First argument handed to
            `tree_aggregation.GaussianNoiseGenerator`.
          model_weight_specs: Specs of the model weights, handed to
            `tree_aggregation.GaussianNoiseGenerator` as its second argument.
          efficient_tree: If true, noise is aggregated with
            `TFEfficientTreeAggregator`; otherwise with `TFTreeAggregator`.
          use_nesterov: Whether Nesterov momentum is requested; stored as
            `self.use_nesterov`.

        Raises:
          ValueError: If `use_nesterov` is set while `momentum` equals 0.
        """
        _check_momentum(momentum)
        if use_nesterov and momentum == 0:
            raise ValueError('Use a positive momentum for Nesterov')

        self.lr = learning_rate
        self.momentum = momentum
        self.model_weight_specs = model_weight_specs
        self.use_nesterov = use_nesterov

        noise_source = tree_aggregation.GaussianNoiseGenerator(
            noise_std, model_weight_specs)

        # Pick the aggregator implementation once, then build it around the
        # shared noise source.
        aggregator_cls = (tree_aggregation.TFEfficientTreeAggregator
                          if efficient_tree else
                          tree_aggregation.TFTreeAggregator)
        self.noise_generator = aggregator_cls(value_generator=noise_source)
 def test_tree_sum_last_step_expected(self, total_steps, expected_value,
                                      step_value):
     # Verifies that the aggregator returns `expected_value` once
     # `total_steps` leaf nodes have been traversed. Every tree node holds
     # the constant `step_value`, which stands in for the "noise" (without
     # private values) of a private algorithm. Per the original author's
     # note, `expected_value` follows a weighting scheme that depends
     # strongly on the depth of the binary tree.
     constant_source = ConstantValueGenerator(step_value)
     aggregator = tree_aggregation.TFEfficientTreeAggregator(
         value_generator=constant_source)
     agg_state = aggregator.init_state()
     for step in range(total_steps):
         current_step = tree_aggregation.get_step_idx(agg_state)
         self.assertEqual(step, current_step)
         cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
     # Only the cumulative sum from the final step is checked.
     self.assertAllClose(expected_value, cumsum)
 def test_tree_sum_noise_expected(self, total_steps, expected_variance,
                                  noise_std, variable_shape, tolerance):
     # Verifies that, after `total_steps` leaf nodes have been traversed, the
     # standard deviation of the aggregated noise matches
     # sqrt(`expected_variance`) within a relative `tolerance`. Every tree
     # node is a `variable_shape` tensor of Gaussian noise with `noise_std`,
     # drawn from a generator seeded with 2020. Per the original author's
     # note, the variance of a tree node is smaller than that of a vanilla
     # node with `noise_std` because of the update rule of
     # `TFEfficientTreeAggregator`.
     noise_source = tree_aggregation.GaussianNoiseGenerator(
         noise_std, tf.TensorSpec(variable_shape), seed=2020)
     aggregator = tree_aggregation.TFEfficientTreeAggregator(
         value_generator=noise_source)
     agg_state = aggregator.init_state()
     for step in range(total_steps):
         current_step = tree_aggregation.get_step_idx(agg_state)
         self.assertEqual(step, current_step)
         cumsum, agg_state = aggregator.get_cumsum_and_update(agg_state)
     # Compare standard deviations of the final cumulative sum only.
     expected_std = math.sqrt(expected_variance)
     observed_std = tf.math.reduce_std(cumsum)
     self.assertAllClose(expected_std, observed_std, rtol=tolerance)