  def test_iterative_process_fails_with_dp_agg_and_none_client_weighting(self):

    def loss_fn():
      return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
      return [
          NumExamplesCounter(),
          NumBatchesCounter(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    # No values should be changed by clipping, but using inf directly zeroes
    # out all updates. Prefer a very large value, but one that can still be
    # handled in multiplication/division.
    gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
    dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
        query=gaussian_sum_query,
        record_aggregation_factory=sum_factory.SumFactory())
    dp_mean_factory = _DPMean(dp_sum_factory)

    with self.assertRaisesRegex(ValueError, 'unweighted aggregator'):
      training_process.build_training_process(
          MnistModel,
          loss_fn=loss_fn,
          metrics_fn=metrics_fn,
          server_optimizer_fn=_get_keras_optimizer_fn(0.01),
          client_optimizer_fn=_get_keras_optimizer_fn(0.001),
          reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
          aggregation_factory=dp_mean_factory,
          client_weighting=None,
          dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
Example #2
  def test_simple_sum(self):
    factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
    value_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type)

    # The test query has clip 1.0 and no noise, so this computes clipped sum.

    state = process.initialize()

    client_data = [0.5, 1.0, 1.5]
    output = process.next(state, client_data)
    self.assertAllClose(2.5, output.result)
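
  # For reference, a plausible stand-in for `_test_dp_query` (not shown in
  # this snippet), assuming the "clip 1.0 and no noise" behavior the comment
  # above describes:
  #
  #   _test_dp_query = tfp.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)
  #
  # With clip 1.0, the client values [0.5, 1.0, 1.5] become [0.5, 1.0, 1.0],
  # so the expected clipped sum is 2.5.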
Example #3
def create_central_hierarchical_histogram_factory(
        stddev: float = 0.0,
        arity: int = 2,
        max_records_per_user: int = 10,
        secure_sum: bool = False):
    """Creates aggregator for hierarchical histograms with differential privacy.

  Args:
    stddev: The standard deviation of noise added to each node of the central
      tree.
    arity: The branching factor of the tree.
    max_records_per_user: The maximum number of records each user can upload
      in their local histogram.
    secure_sum: A boolean deciding whether to use secure aggregation. Defaults
      to `False`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    ValueError: If `stddev < 0`, `arity < 2` or `max_records_per_user < 1`.
  """
    if stddev < 0:
        raise ValueError("Standard deviation should be nonnegative. "
                         f"stddev={stddev} is given.")

    if arity < 2:
        raise ValueError("Arity should be at least 2. "
                         f"arity={arity} is given.")

    if max_records_per_user < 1:
        raise ValueError(
            "Maximum records per user should be at least 1. "
            f"max_records_per_user={max_records_per_user} is given.")

    central_tree_agg_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
        stddev=stddev, arity=arity, l1_bound=max_records_per_user)

    if secure_sum:
        inner_agg_factory = secure.SecureSumFactory(
            upper_bound_threshold=float(max_records_per_user),
            lower_bound_threshold=0.)
    else:
        inner_agg_factory = sum_factory.SumFactory()

    return differential_privacy.DifferentiallyPrivateFactory(
        central_tree_agg_query, inner_agg_factory)
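
# A minimal usage sketch under the assumptions of the surrounding snippets
# (`computation_types` and `tf` imported as in the other examples; the 8-bin
# float32 histogram type is illustrative):
#
#   factory_ = create_central_hierarchical_histogram_factory(
#       stddev=1.0, arity=2, max_records_per_user=10)
#   value_type = computation_types.to_type((tf.float32, (8,)))
#   process = factory_.create(value_type)
#   state = process.initialize()
#   # output = process.next(state, client_histograms)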
Example #4
  def test_inner_sum(self):
    factory_ = differential_privacy.DifferentiallyPrivateFactory(
        _test_dp_query, _test_inner_agg_factory)
    value_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type)

    # The test query has clip 1.0 and no noise, so this computes clipped sum.
    # Inner agg adds another 1.0 (post-clipping).

    state = process.initialize()
    self.assertAllEqual(0, state[1])

    client_data = [0.5, 1.0, 1.5]
    output = process.next(state, client_data)
    self.assertAllEqual(1, output.state[1])
    self.assertAllClose(3.5, output.result)
    self.assertAllEqual(test_utils.MEASUREMENT_CONSTANT,
                        output.measurements['dp'])
Example #5
  def test_structure_sum(self):
    factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
    value_type = computation_types.to_type([tf.float32, tf.float32])
    process = factory_.create(value_type)

    # The test query has clip 1.0 and no noise, so this computes clipped sum.

    state = process.initialize()

    # pyformat: disable
    client_data = [
        [0.1, 0.2],         # not clipped (norm < 1)
        [5 / 13, 12 / 13],  # not clipped (norm == 1)
        [3.0, 4.0]          # clipped to 3/5, 4/5
    ]
    output = process.next(state, client_data)

    expected_result = [0.1 +  5 / 13 + 3 / 5,
                       0.2 + 12 / 13 + 4 / 5]
    # pyformat: enable
    self.assertAllClose(expected_result, output.result)
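
  # Worked check of the clipped row above (a sketch using NumPy directly
  # rather than the TFF process): the structure is treated as a single vector,
  # so [3.0, 4.0] has global L2 norm 5 and is scaled by 1/5 to [0.6, 0.8].
  #
  #   import numpy as np
  #   v = np.array([3.0, 4.0])
  #   clipped = v * min(1.0, 1.0 / np.linalg.norm(v))
  #   np.testing.assert_allclose(clipped, [3 / 5, 4 / 5])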
Example #6
  def test_type_properties(self, value_type, inner_agg_factory):
    factory_ = differential_privacy.DifferentiallyPrivateFactory(
        _test_dp_query, inner_agg_factory)
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    value_type = computation_types.to_type(value_type)
    process = factory_.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    query_state = _test_dp_query.initial_global_state()
    query_state_type = type_conversions.type_from_tensors(query_state)
    query_metrics_type = type_conversions.type_from_tensors(
        _test_dp_query.derive_metrics(query_state))

    inner_state_type = tf.int32 if inner_agg_factory else ()

    server_state_type = computation_types.at_server(
        (query_state_type, inner_state_type))
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    inner_measurements_type = tf.int32 if inner_agg_factory else ()
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(
            dp_query_metrics=query_metrics_type, dp=inner_measurements_type))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=server_state_type,
            value=computation_types.at_clients(value_type)),
        result=measured_process.MeasuredProcessOutput(
            state=server_state_type,
            result=computation_types.at_server(value_type),
            measurements=expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))
Example #7
    def test_adaptive_query(self):
        query = tfp.QuantileAdaptiveClipSumQuery(initial_l2_norm_clip=1.0,
                                                 noise_multiplier=0.0,
                                                 target_unclipped_quantile=1.0,
                                                 learning_rate=1.0,
                                                 clipped_count_stddev=0.0,
                                                 expected_num_records=3.0,
                                                 geometric_update=False)
        factory_ = differential_privacy.DifferentiallyPrivateFactory(query)
        value_type = computation_types.to_type(tf.float32)
        process = factory_.create(value_type)

        state = process.initialize()

        client_data = [0.5, 1.5, 2.0]  # Two clipped on first round.
        expected_result = 0.5 + 1.0 + 1.0
        output = process.next(state, client_data)
        self.assertAllClose(expected_result, output.result)

        # Clip is increased by 2/3 to 5/3.
        expected_result = 0.5 + 1.5 + 5 / 3
        output = process.next(output.state, client_data)
        self.assertAllClose(expected_result, output.result)
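
    # Arithmetic behind the clip update above, as a sketch of the linear
    # (`geometric_update=False`) rule in `QuantileAdaptiveClipSumQuery`:
    #
    #   clip, lr, target = 1.0, 1.0, 1.0
    #   unclipped_fraction = 1 / 3  # only 0.5 is within the clip of 1.0
    #   new_clip = clip + lr * (target - unclipped_fraction)  # 1 + 2/3 = 5/3
    #
    # With `clipped_count_stddev=0.0` the clipped-count estimate is exact, so
    # the second round clips only 2.0 (to 5/3), giving 0.5 + 1.5 + 5/3.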
Example #8
  def test_incorrect_value_type_raises(self, bad_value_type):
    factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
    with self.assertRaises(TypeError):
      factory_.create(bad_value_type)
Example #9
  def test_init_non_agg_factory_raises(self):
    with self.assertRaises(TypeError):
      differential_privacy.DifferentiallyPrivateFactory(_test_dp_query,
                                                        'not an agg factory')
Example #10
    def test_type_properties(self, value_type, mechanism):
        ddp_factory = _make_test_factory(mechanism=mechanism)
        self.assertIsInstance(ddp_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = ddp_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The state is a nested object with component factory states. Construct
        # test factories directly and compare the signatures.
        modsum_f = secure.SecureModularSumFactory(2**15, True)

        if mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=10.0, local_stddev=10.0)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(l1_norm_bound=10.0,
                                                      l2_norm_bound=10.0,
                                                      local_stddev=10.0)

        dp_f = differential_privacy.DifferentiallyPrivateFactory(
            dp_query, modsum_f)
        discrete_f = discretization.DiscretizationFactory(dp_f)
        l2clip_f = robust.clipping_factory(clipping_norm=10.0,
                                           inner_agg_factory=discrete_f)
        rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
        expected_process = concat.concat_factory(rot_f).create(value_type)

        # Check init_fn/state.
        expected_init_type = expected_process.initialize.type_signature
        expected_state_type = expected_init_type.result
        actual_init_type = process.initialize.type_signature
        self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

        # Check next_fn/measurements.
        tensor2type = type_conversions.type_from_tensors
        discrete_state = discrete_f.create(
            computation_types.to_type(tf.float32)).initialize()
        dp_query_state = dp_query.initial_global_state()
        dp_query_metrics_type = tensor2type(
            dp_query.derive_metrics(dp_query_state))
        expected_measurements_type = collections.OrderedDict(
            l2_clip=robust.NORM_TF_TYPE,
            scale_factor=tensor2type(discrete_state['scale_factor']),
            scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
            scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
            actual_num_clients=tf.int32,
            padded_dim=tf.int32,
            dp_query_metrics=dp_query_metrics_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=computation_types.at_server(
                    expected_measurements_type)))
        actual_next_type = process.next.type_signature
        self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except:  # pylint: disable=bare-except
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
Example #11
  def test_execution_with_custom_dp_query(self):
    client_data = create_emnist_client_data()
    train_data = [client_data(), client_data()]

    def loss_fn():
      return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
      return [
          NumExamplesCounter(),
          NumBatchesCounter(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    # No values should be changed by clipping, but using inf directly zeroes
    # out all updates. Prefer a very large value, but one that can still be
    # handled in multiplication/division.
    gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
    dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
        query=gaussian_sum_query,
        record_aggregation_factory=sum_factory.SumFactory())
    dp_mean_factory = _DPMean(dp_sum_factory)

    # Disable reconstruction via a 0 learning rate to ensure the post-recon
    # loss matches exact expectations in round 0 and decreases by the next
    # round.
    trainer = training_process.build_training_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        server_optimizer_fn=_get_keras_optimizer_fn(0.01),
        client_optimizer_fn=_get_keras_optimizer_fn(0.001),
        reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
        aggregation_factory=dp_mean_factory,
        dataset_split_fn=reconstruction_utils.simple_dataset_split_fn,
        client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
    )
    state = trainer.initialize()

    outputs = []
    states = []
    for _ in range(2):
      state, output = trainer.next(state, train_data)
      outputs.append(output)
      states.append(state)

    # All weights and biases are initialized to 0, so the initial logits are
    # all 0 and the softmax probabilities are uniform over 10 classes, making
    # the negative log likelihood -ln(1/10) = ln(10). This holds in
    # expectation, so we use a relaxed tolerance.
    self.assertAllClose(
        outputs[0]['train']['loss'], tf.math.log(10.0), rtol=1e-4)
    self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
    self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

    # Expect 6 reconstruction examples, 6 training examples. Only training
    # included in metrics.
    self.assertEqual(outputs[0]['train']['num_examples_total'], 6.0)
    self.assertEqual(outputs[1]['train']['num_examples_total'], 6.0)

    # Expect 4 reconstruction batches and 4 training batches. Only training
    # included in metrics.
    self.assertEqual(outputs[0]['train']['num_batches_total'], 4.0)
    self.assertEqual(outputs[1]['train']['num_batches_total'], 4.0)
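
  # Worked check for the round-0 loss above (a NumPy sketch, independent of
  # the trainer): all-zero logits give uniform softmax probabilities of 0.1,
  # so the cross-entropy for any label is -ln(0.1) = ln(10) ~= 2.3026.
  #
  #   import numpy as np
  #   logits = np.zeros(10)
  #   probs = np.exp(logits) / np.exp(logits).sum()  # 0.1 each
  #   np.testing.assert_allclose(-np.log(probs[0]), np.log(10.0))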
Example #12
def create_hierarchical_histogram_aggregation_factory(
        num_bins: int,
        arity: int = 2,
        clip_mechanism: str = 'sub-sampling',
        max_records_per_user: int = 10,
        dp_mechanism: str = 'no-noise',
        noise_multiplier: float = 0.0,
        expected_clients_per_round: int = 10,
        bits: int = 22,
        enable_secure_sum: bool = True):
    """Creates hierarchical histogram aggregation factory.

  The hierarchical histogram factory is constructed by composing 3 aggregation
  factories.
  (1) The innermost factory is `SumFactory`.
  (2) The middle factory is `DifferentiallyPrivateFactory` whose inner query is
      `TreeRangeSumQuery`. This factory (a) takes in a clipped histogram,
      constructs the hierarchical histogram and checks its norm bound at
      clients, and (b) adds noise either at clients or at the server according
      to `dp_mechanism`.
  (3) The outermost factory is `HistogramClippingSumFactory`, which clips the
      input histogram to bound each user's contribution.

  Args:
    num_bins: An `int` representing the input histogram size.
    arity: An `int` representing the branching factor of the tree. Defaults to
      2.
    clip_mechanism: A `str` representing the clipping mechanism. Currently
      supported mechanisms are
      - 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
        records without replacement from the client dataset.
      - 'distinct': Uniquify client dataset and uniformly sample up to
        `max_records_per_user` records without replacement from it.
    max_records_per_user: An `int` representing the maximum number of records
      each user can include in their local histogram. Defaults to 10.
    dp_mechanism: A `str` representing the differentially private mechanism to
      use. Currently supported mechanisms are
      - 'no-noise': (Default) Tree aggregation mechanism without noise.
      - 'central-gaussian': Tree aggregation with central Gaussian mechanism.
      - 'distributed-discrete-gaussian': Tree aggregation with the distributed
        discrete Gaussian mechanism from "The Distributed Discrete Gaussian
        Mechanism for Federated Learning with Secure Aggregation" (Peter
        Kairouz, Ziyu Liu, Thomas Steinke).
    noise_multiplier: A `float` specifying the noise multiplier (central noise
      stddev / L2 clip norm) for model updates. Only needed when `dp_mechanism`
      is not 'no-noise'. Defaults to 0.0.
    expected_clients_per_round: An `int` specifying the lower bound of the
      expected number of clients. Only needed when `dp_mechanism` is
      'distributed-discrete-gaussian'. Defaults to 10.
    bits: A positive integer specifying the communication bit-width B (where
      2**B will be the field size for SecAgg operations). Only needed when
      `dp_mechanism` is 'distributed-discrete-gaussian'. Please read the below
      precautions carefully and set `bits` accordingly. Otherwise, unexpected
      overflow or accuracy degradation might happen. (1) Should be in the
      inclusive range [1, 22] to avoid overflow inside secure aggregation; (2)
      Should be at least as large as `log2(4 * sqrt(expected_clients_per_round)*
      noise_multiplier * l2_norm_bound + expected_clients_per_round *
      max_records_per_user) + 1` to avoid accuracy degradation caused by
      frequent modular clipping; (3) If the number of clients exceeds
      `expected_clients_per_round`, overflow might happen.
    enable_secure_sum: Whether to aggregate clients' updates with secure
      summation. Defaults to `True`. When `dp_mechanism` is set to
      `'distributed-discrete-gaussian'`, `enable_secure_sum` must be `True`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    TypeError: If arguments have the wrong type(s).
    ValueError: If arguments have invalid value(s).
  """
    _check_positive(num_bins, 'num_bins')
    _check_greater_equal(arity, 2, 'arity')
    _check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
                      'clip_mechanism')
    _check_positive(max_records_per_user, 'max_records_per_user')
    _check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
    _check_non_negative(noise_multiplier, 'noise_multiplier')
    _check_positive(expected_clients_per_round, 'expected_clients_per_round')
    _check_in_range(bits, 'bits', 1, 22)

    # Converts `max_records_per_user` to the corresponding norm bound according to
    # the chosen `clip_mechanism` and `dp_mechanism`.
    if dp_mechanism in ['central-gaussian', 'distributed-discrete-gaussian']:
        if clip_mechanism == 'sub-sampling':
            l2_norm_bound = max_records_per_user * math.sqrt(
                _tree_depth(num_bins, arity))
        elif clip_mechanism == 'distinct':
            # The following code block converts `max_records_per_user` to L2 norm
            # bound of the hierarchical histogram layer by layer. For the bottom
            # layer with only 0s and at most `max_records_per_user` 1s, the L2 norm
            # bound is `sqrt(max_records_per_user)`. For the second layer from bottom,
            # the worst case is only 0s and `max_records_per_user/2` 2s. And so on
            # until the root node. Another natural L2 norm bound on each layer is
            # `max_records_per_user` so we take the minimum between the two bounds.
            square_l2_norm_bound = 0.
            square_layer_l2_norm_bound = max_records_per_user
            for _ in range(_tree_depth(num_bins, arity)):
                square_l2_norm_bound += min(max_records_per_user**2,
                                            square_layer_l2_norm_bound)
                square_layer_l2_norm_bound *= arity
            l2_norm_bound = math.sqrt(square_l2_norm_bound)
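            # Worked example (illustrative numbers): num_bins=4, arity=2 and
            # max_records_per_user=4 give tree depth 3, and the squared layer
            # bounds accumulate as min(16, 4) + min(16, 8) + min(16, 16) = 28,
            # so l2_norm_bound = sqrt(28) ~= 5.29.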

    if not enable_secure_sum and dp_mechanism in DISTRIBUTED_DP_MECHANISMS:
        raise ValueError('When dp_mechanism is one of '
                         f'{DISTRIBUTED_DP_MECHANISMS}, enable_secure_sum '
                         'must be set to True to preserve distributed DP.')

    # Build the nested aggregation factory from innermost to outermost.
    # 1. Sum factory. The innermost factory that sums the preprocessed records.
    # (1) If `enable_secure_sum` is `False`, this should be `SumFactory`.
    if not enable_secure_sum:
        nested_factory = sum_factory.SumFactory()
    else:
        # (2) If `enable_secure_sum` is `True` and `dp_mechanism` is 'no-noise'
        # or 'central-gaussian', the sum factory should be `SecureSumFactory`,
        # with an `upper_bound_threshold` of `max_records_per_user`. When
        # `dp_mechanism` is 'central-gaussian', use a float `SecureSumFactory`
        # to be compatible with `GaussianSumQuery`.
        if dp_mechanism in ['no-noise']:
            nested_factory = secure.SecureSumFactory(max_records_per_user)
        elif dp_mechanism in ['central-gaussian']:
            nested_factory = secure.SecureSumFactory(
                float(max_records_per_user))
        # (3) If `dp_mechanism` is in `DISTRIBUTED_DP_MECHANISMS`, should be
        #     `SecureSumFactory`. To preserve DP and avoid overflow, we have 4
        #    modular clips from nesting two modular clip aggregators:
        #    #1. outer-client: clips to [-2**(bits-1), 2**(bits-1))
        #        Bounds the client values.
        #    #2. inner-client: clips to [0, 2**bits)
        #        Similar to applying a two's complement to the values such that
        #        frequent values (post-rotation) are now near 0 (representing small
        #        positives) and 2**bits (small negatives). 0 also always maps to 0,
        #        and we do not require another explicit value range shift from
        #        [-2**(bits-1), 2**(bits-1)) to [0, 2**bits) to make sure that
        #        values are compatible with SecAgg's mod m = 2**bits. This can be
        #        reverted at #4.
        #    #3. inner-server: clips to [0, 2**bits)
        #        Ensures the aggregated value range does not grow by
        #        `log2(expected_clients_per_round)`.
        #        NOTE: If underlying SecAgg is implemented using the new
        #        `tff.federated_secure_modular_sum()` operator with the same
        #        modular clipping range, then this would correspond to a no-op.
        #    #4. outer-server: clips to [-2**(bits-1), 2**(bits-1))
        #        Keeps aggregated values centered near 0 out of the logical SecAgg
        #        black box for outer aggregators.
        elif dp_mechanism in ['distributed-discrete-gaussian']:
            # TODO(b/196312838): Please add scaling to the distributed case once
            # we have a stable guideline for setting the scaling factor to
            # improve performance and avoid overflow. The check below ensures
            # that modular clipping happens only with small probability, so the
            # accuracy of the result is not harmed. However, if the number of
            # clients exceeds `expected_clients_per_round`, overflow might still
            # happen. It is the caller's responsibility to choose `bits`
            # carefully, according to system details, to avoid overflow or
            # performance degradation.
            if bits < math.log2(4 * math.sqrt(expected_clients_per_round) *
                                noise_multiplier * l2_norm_bound +
                                expected_clients_per_round *
                                max_records_per_user) + 1:
                raise ValueError(
                    f'The selected bit-width ({bits}) is too small for the '
                    f'given parameters (expected_clients_per_round = '
                    f'{expected_clients_per_round}, max_records_per_user = '
                    f'{max_records_per_user}, noise_multiplier = '
                    f'{noise_multiplier}) and will harm the accuracy of the '
                    f'result. Please decrease `expected_clients_per_round`, '
                    f'`max_records_per_user`, or `noise_multiplier`, or '
                    f'increase `bits`.')
            nested_factory = secure.SecureSumFactory(
                upper_bound_threshold=2**bits - 1, lower_bound_threshold=0)
            nested_factory = modular_clipping_factory.ModularClippingSumFactory(
                clip_range_lower=0,
                clip_range_upper=2**bits,
                inner_agg_factory=nested_factory)
            nested_factory = modular_clipping_factory.ModularClippingSumFactory(
                clip_range_lower=-2**(bits - 1),
                clip_range_upper=2**(bits - 1),
                inner_agg_factory=nested_factory)

    # 2. DP operations.
    # Constructs `DifferentiallyPrivateFactory` according to the chosen
    # `dp_mechanism`.
    if dp_mechanism == 'central-gaussian':
        query = tfp.TreeRangeSumQuery.build_central_gaussian_query(
            l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
        # If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`,
        # the record is cast to `tf.float32` before being fed to the DP
        # factory.
        cast_to_float = True
    elif dp_mechanism == 'distributed-discrete-gaussian':
        query = tfp.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
            l2_norm_bound, noise_multiplier * l2_norm_bound /
            math.sqrt(expected_clients_per_round), arity)
        # If the inner `DifferentiallyPrivateFactory` uses
        # `DistributedDiscreteGaussianSumQuery`, the record is kept as
        # `tf.int32` before being fed to the DP factory.
        cast_to_float = False
    elif dp_mechanism == 'no-noise':
        inner_query = tfp.NoPrivacySumQuery()
        query = tfp.TreeRangeSumQuery(arity=arity, inner_query=inner_query)
        # If the inner `DifferentiallyPrivateFactory` uses `NoPrivacySumQuery`,
        # the record is kept as `tf.int32` before being fed to the DP factory.
        cast_to_float = False
    else:
        raise ValueError('Unexpected dp_mechanism.')
    nested_factory = differential_privacy.DifferentiallyPrivateFactory(
        query, nested_factory)

    # 3. Clip as specified by `clip_mechanism`.
    nested_factory = clipping_factory.HistogramClippingSumFactory(
        clip_mechanism=clip_mechanism,
        max_records_per_user=max_records_per_user,
        inner_agg_factory=nested_factory,
        cast_to_float=cast_to_float)

    return nested_factory
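
# A minimal usage sketch under the same assumptions as the other snippets
# (illustrative parameters; `computation_types` and `tf` imported as above):
#
#   factory_ = create_hierarchical_histogram_aggregation_factory(
#       num_bins=8, arity=2, dp_mechanism='central-gaussian',
#       noise_multiplier=1.0, enable_secure_sum=False)
#   process = factory_.create(computation_types.to_type((tf.int32, (8,))))
#   state = process.initialize()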
Example #13
    def _build_aggregation_factory(self):
        central_stddev = self._value_noise_mult * self._initial_l2_clip
        local_stddev = central_stddev / math.sqrt(self._num_clients)

        # Ensure dim is at least 1; needed only for computing DDP parameters.
        self._client_dim = max(1, self._client_dim)
        if self._rotation_type == 'hd':
            # The Hadamard transform requires the dimension to be a power of 2.
            self._padded_dim = 2**math.ceil(math.log2(self._client_dim))
            rotation_factory = rotation.HadamardTransformFactory
        else:
            # DFT pads at most 1 zero.
            self._padded_dim = math.ceil(self._client_dim / 2.0) * 2
            rotation_factory = rotation.DiscreteFourierTransformFactory

        scale = _heuristic_scale_factor(local_stddev, self._initial_l2_clip,
                                        self._bits, self._num_clients,
                                        self._padded_dim,
                                        self._k_stddevs).numpy()

        # Very large scales could lead to overflows and are not as helpful for
        # utility. See comment above for more details.
        scale = min(scale, MAX_SCALE_FACTOR)

        if scale <= 1:
            warnings.warn(
                f'The selected scale_factor {scale} is <= 1. This may lead to '
                f'substantial quantization errors. Consider increasing '
                f'the bit-width (currently {self._bits}) or decreasing the '
                f'expected number of clients per round (currently '
                f'{self._num_clients}).')

        # The procedure for obtaining the inflated L2 bound assumes eager TF
        # execution and can be rewritten with NumPy if needed.
        inflated_l2 = discretization.inflated_l2_norm_bound(
            l2_norm_bound=self._initial_l2_clip,
            gamma=1.0 / scale,
            beta=self._beta,
            dim=self._padded_dim).numpy()

        # Add small leeway on norm bounds to gracefully allow numerical errors.
        # Specifically, the norm thresholds are computed directly from the specified
        # parameters in Python and will be checked right before noising; on the
        # other hand, the actual norm of the record (to be measured at noising time)
        # can possibly be (negligibly) higher due to the float32 arithmetic after
        # the conditional rounding (thus failing the check). While we have mitigated
        # this by sharing the computation for the inflated norm bound from
        # quantization, adding a leeway makes the execution more robust (it does not
        # need to abort should any precision issues happen) while not affecting the
        # correctness if privacy accounting is done based on the norm bounds at the
        # DPQuery/DPFactory (which incorporates the leeway).
        scaled_inflated_l2 = (inflated_l2 + 1e-5) * scale
        # Since values are scaled and rounded to integers, each |x_i| is a
        # nonnegative integer with |x_i| <= x_i^2, so L1 <= L2^2 holds on top
        # of the general bound L1 <= sqrt(d) * L2.
        scaled_l1 = math.ceil(
            scaled_inflated_l2 *
            min(math.sqrt(self._padded_dim), scaled_inflated_l2))

        # Build the nested aggregation factory.
        # 1. Secure Aggregation. In particular, we have 4 modular clips from
        #    nesting two modular clip aggregators:
        #    #1. outer-client: clips to [-2^(b-1), 2^(b-1))
        #        Bounds the client values (with limited effect as scaling was
        #        chosen such that `num_clients` is taken into account).
        #    #2. inner-client: clips to [0, 2^b)
        #        Similar to applying a two's complement to the values such that
        #        frequent values (post-rotation) are now near 0 (representing small
        #        positives) and 2^b (small negatives). 0 also always maps to 0, and
        #        we do not require another explicit value range shift from
        #        [-2^(b-1), 2^(b-1)) to [0, 2^b) to make sure that values are
        #        compatible with SecAgg's mod m = 2^b. This can be reverted at #4.
        #    #3. inner-server: clips to [0, 2^b)
        #        Ensures the aggregated value range does not grow by log_2(n).
        #        NOTE: If the underlying SecAgg is implemented using the new
        #        `tff.federated_secure_modular_sum()` operator with the same
        #        modular clipping range, then this would correspond to a no-op.
        #    #4. outer-server: clips to [-2^(b-1), 2^(b-1))
        #        Keeps aggregated values centered near 0 out of the logical SecAgg
        #        black box for outer aggregators.
        #    Note that the scaling factor and the bit-width are chosen such that
        #    the number of clients to aggregate is taken into account.
        nested_factory = secure.SecureSumFactory(
            upper_bound_threshold=2**self._bits - 1, lower_bound_threshold=0)
        nested_factory = modular_clipping.ModularClippingSumFactory(
            clip_range_lower=0,
            clip_range_upper=2**self._bits,
            inner_agg_factory=nested_factory)
        nested_factory = modular_clipping.ModularClippingSumFactory(
            clip_range_lower=-(2**(self._bits - 1)),
            clip_range_upper=2**(self._bits - 1),
            inner_agg_factory=nested_factory)

        # 2. DP operations. DP params are in the scaled domain (post-quantization).
        if self._mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=scaled_inflated_l2,
                local_stddev=local_stddev * scale)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(
                l1_norm_bound=scaled_l1,
                l2_norm_bound=scaled_inflated_l2,
                local_stddev=local_stddev * scale)

        nested_factory = differential_privacy.DifferentiallyPrivateFactory(
            query=dp_query, record_aggregation_factory=nested_factory)

        # 3. Discretization operations. This appropriately quantizes the inputs.
        nested_factory = discretization.DiscretizationFactory(
            inner_agg_factory=nested_factory,
            scale_factor=scale,
            stochastic=True,
            beta=self._beta,
            prior_norm_bound=self._initial_l2_clip)

        # 4. L2 clip, possibly adaptively with a `tff.templates.EstimationProcess`.
        nested_factory = robust.clipping_factory(
            clipping_norm=self._l2_clip,
            inner_agg_factory=nested_factory,
            clipped_count_sum_factory=secure.SecureSumFactory(
                upper_bound_threshold=1, lower_bound_threshold=0))

        # 5. Flattening to improve quantization and reduce modular wrapping.
        nested_factory = rotation_factory(inner_agg_factory=nested_factory)

        # 6. Concat the input structure into a single vector.
        nested_factory = concat.concat_factory(
            inner_agg_factory=nested_factory)
        return nested_factory
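
        # Reading the composition client-to-server, a value flows through the
        # factories in reverse construction order (a summary of the nesting
        # above): concat -> rotation (HD/DFT) -> L2 clip -> discretize (scale +
        # stochastic rounding) -> DP query (distributed dgauss / Skellam) ->
        # modular clip to [-2^(b-1), 2^(b-1)) -> modular clip to [0, 2^b) ->
        # secure sum (SecAgg).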