Beispiel #1
0
    def test_clip_type_properties_with_clipped_count_agg_factory(
            self, value_type):
        """Checks clipping-factory type signatures with a custom count factory.

        Builds a clipping factory with a constant clipping norm of 1.0, a
        `SumFactory` inner aggregator and `SumPlusOneFactory` as the
        clipped-count aggregator. It then verifies that the created process is
        an `AggregationProcess` and that its `initialize`/`next` signatures
        match the expected state and measurement structures.
        """
        inner = sum_factory.SumFactory()
        count_agg = aggregator_test_utils.SumPlusOneFactory()
        clip_factory = robust.clipping_factory(
            clipping_norm=1.0,
            inner_agg_factory=inner,
            clipped_count_sum_factory=count_agg)
        tff_value_type = computation_types.to_type(value_type)
        process = clip_factory.create(tff_value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Expected server state: only the clipped-count aggregator carries a
        # non-empty (int32) state in this configuration.
        state_at_server = computation_types.at_server(
            collections.OrderedDict(
                clipping_norm=(),
                inner_agg=(),
                clipped_count_agg=tf.int32,
            ))
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(init_signature))

        measurements_at_server = computation_types.at_server(
            collections.OrderedDict(
                clipping=(),
                clipping_norm=robust.NORM_TF_TYPE,
                clipped_count=robust.COUNT_TF_TYPE,
            ))
        next_parameter = collections.OrderedDict(
            state=state_at_server,
            value=computation_types.at_clients(tff_value_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(tff_value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
def _default_clipping(
        inner_factory: factory.AggregationFactory
) -> factory.AggregationFactory:
    """Wraps `inner_factory` in adaptive norm clipping.

    The clipping norm is driven by a noiseless quantile estimation process
    that starts at 1.0, targets the 0.8 quantile and uses learning rate 0.2.

    Args:
      inner_factory: The aggregation factory applied to the clipped values.

    Returns:
      The factory produced by `robust.clipping_factory`.
    """
    # NOTE(review): `_default_clipping` is defined again right below; if both
    # live in the same module, this definition is shadowed.
    # Adapts relatively quickly to a moderately high norm.
    estimator_cls = quantile_estimation.PrivateQuantileEstimationProcess
    norm_estimator = estimator_cls.no_noise(
        initial_estimate=1.0, target_quantile=0.8, learning_rate=0.2)
    return robust.clipping_factory(norm_estimator, inner_factory)
def _default_clipping(
        inner_factory: factory.AggregationFactory,
        secure_estimation: bool = False) -> factory.AggregationFactory:
    """Wraps `inner_factory` in adaptive norm clipping.

    The clipping norm is driven by a noiseless quantile estimation process
    that starts at 1.0, targets the 0.8 quantile and uses learning rate 0.2.

    Args:
      inner_factory: The aggregation factory applied to the clipped values.
      secure_estimation: If truthy, the quantile estimation is built with
        `secure_estimation=True`. The clipped count is then summed with a
        `SecureSumFactory` bounded to [0, 1].

    Returns:
      The factory produced by `robust.clipping_factory`.
    """
    # Adapts relatively quickly to a moderately high norm.
    estimator_cls = quantile_estimation.PrivateQuantileEstimationProcess
    norm_estimator = estimator_cls.no_noise(
        initial_estimate=1.0,
        target_quantile=0.8,
        learning_rate=0.2,
        secure_estimation=secure_estimation)

    if not secure_estimation:
        return robust.clipping_factory(norm_estimator, inner_factory)

    # The clipped-count summands are bounded to [0, 1] for secure summation.
    count_aggregator = secure.SecureSumFactory(upper_bound_threshold=1,
                                               lower_bound_threshold=0)
    return robust.clipping_factory(
        norm_estimator,
        inner_factory,
        clipped_count_sum_factory=count_aggregator)
Beispiel #4
0
def _clipped_sum(clip=2.0):
    """Returns a `SumFactory` wrapped by `robust.clipping_factory(clip, ...)`."""
    summing = sum_factory.SumFactory()
    return robust.clipping_factory(clip, summing)
Beispiel #5
0
def _clipped_mean(clip=2.0):
    """Returns a `MeanFactory` wrapped by `robust.clipping_factory(clip, ...)`."""
    averaging = mean.MeanFactory()
    return robust.clipping_factory(clip, averaging)
Beispiel #6
0
    def test_type_properties(self, value_type, mechanism):
        """Checks the DDP factory's type signatures against a manual build.

        An expected process is assembled directly from the component factories:
        secure modular sum -> DP query -> discretization -> L2 clipping ->
        Hadamard rotation -> concat. Its `initialize` and `next` type
        signatures are compared with those of the factory under test. The test
        also checks that `next` contains no non-secure aggregation.

        Args:
          value_type: Value type spec accepted by `computation_types.to_type`.
          mechanism: `'distributed_dgauss'`, or any other value to select the
            Skellam mechanism.
        """
        ddp_factory = _make_test_factory(mechanism=mechanism)
        self.assertIsInstance(ddp_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = ddp_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The state is a nested object with component factory states. Construct
        # test factories directly and compare the signatures.
        # NOTE(review): the parameters below presumably mirror those used by
        # `_make_test_factory`; confirm if that helper changes.
        modsum_f = secure.SecureModularSumFactory(2**15, True)

        if mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=10.0, local_stddev=10.0)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(l1_norm_bound=10.0,
                                                      l2_norm_bound=10.0,
                                                      local_stddev=10.0)

        dp_f = differential_privacy.DifferentiallyPrivateFactory(
            dp_query, modsum_f)
        discrete_f = discretization.DiscretizationFactory(dp_f)
        l2clip_f = robust.clipping_factory(clipping_norm=10.0,
                                           inner_agg_factory=discrete_f)
        rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
        expected_process = concat.concat_factory(rot_f).create(value_type)

        # Check init_fn/state.
        expected_init_type = expected_process.initialize.type_signature
        expected_state_type = expected_init_type.result
        actual_init_type = process.initialize.type_signature
        self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

        # Check next_fn/measurements.
        tensor2type = type_conversions.type_from_tensors
        discrete_state = discrete_f.create(
            computation_types.to_type(tf.float32)).initialize()
        dp_query_state = dp_query.initial_global_state()
        dp_query_metrics_type = tensor2type(
            dp_query.derive_metrics(dp_query_state))
        expected_measurements_type = collections.OrderedDict(
            l2_clip=robust.NORM_TF_TYPE,
            scale_factor=tensor2type(discrete_state['scale_factor']),
            scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
            scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
            actual_num_clients=tf.int32,
            padded_dim=tf.int32,
            dp_query_metrics=dp_query_metrics_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=computation_types.at_server(
                    expected_measurements_type)))
        actual_next_type = process.next.type_signature
        self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception:  # pylint: disable=broad-except
            # A bare `except:` would also swallow KeyboardInterrupt and
            # SystemExit. The original error stays attached as the implicit
            # `__context__` of the failure raised below.
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
Beispiel #7
0
    def _build_aggregation_factory(self):
        """Builds the nested distributed-DP aggregation factory.

        From outermost to innermost, the returned factory is:
        concat -> rotation (Hadamard or DFT) -> L2 clipping (clipped count
        summed securely) -> discretization -> DP query -> modular clipping
        (two levels) -> `SecureSumFactory`.

        Side effects: clamps `self._client_dim` to at least 1 and sets
        `self._padded_dim` according to `self._rotation_type`. Emits a
        `warnings.warn` message when the computed scale factor is <= 1.

        Returns:
          The outermost aggregation factory (from `concat.concat_factory`).
        """
        central_stddev = self._value_noise_mult * self._initial_l2_clip
        local_stddev = central_stddev / math.sqrt(self._num_clients)

        # Ensure dim is at least 1 only for computing DDP parameters.
        self._client_dim = max(1, self._client_dim)
        if self._rotation_type == 'hd':
            # Hadamard transform requires dimension to be powers of 2.
            self._padded_dim = 2**math.ceil(math.log2(self._client_dim))
            rotation_factory = rotation.HadamardTransformFactory
        else:
            # DFT pads at most 1 zero.
            self._padded_dim = math.ceil(self._client_dim / 2.0) * 2
            rotation_factory = rotation.DiscreteFourierTransformFactory

        scale = _heuristic_scale_factor(local_stddev, self._initial_l2_clip,
                                        self._bits, self._num_clients,
                                        self._padded_dim,
                                        self._k_stddevs).numpy()

        # Very large scales could lead to overflows and are not as helpful for
        # utility. See comment above for more details.
        scale = min(scale, MAX_SCALE_FACTOR)

        if scale <= 1:
            # Each fragment ends with a space so the concatenated message reads
            # correctly (previously words were glued together).
            warnings.warn(
                f'The selected scale_factor {scale} <= 1. This may lead to '
                'substantial quantization errors. Consider increasing '
                f'the bit-width (currently {self._bits}) or decreasing the '
                'expected number of clients per round (currently '
                f'{self._num_clients}).')

        # The procedure for obtaining inflated L2 bound assumes eager TF execution
        # and can be rewritten with NumPy if needed.
        inflated_l2 = discretization.inflated_l2_norm_bound(
            l2_norm_bound=self._initial_l2_clip,
            gamma=1.0 / scale,
            beta=self._beta,
            dim=self._padded_dim).numpy()

        # Add small leeway on norm bounds to gracefully allow numerical errors.
        # Specifically, the norm thresholds are computed directly from the specified
        # parameters in Python and will be checked right before noising; on the
        # other hand, the actual norm of the record (to be measured at noising time)
        # can possibly be (negligibly) higher due to the float32 arithmetic after
        # the conditional rounding (thus failing the check). While we have mitigated
        # this by sharing the computation for the inflated norm bound from
        # quantization, adding a leeway makes the execution more robust (it does not
        # need to abort should any precision issues happen) while not affecting the
        # correctness if privacy accounting is done based on the norm bounds at the
        # DPQuery/DPFactory (which incorporates the leeway).
        scaled_inflated_l2 = (inflated_l2 + 1e-5) * scale
        # Since values are scaled and rounded to integers, we have L1 <= L2^2
        # on top of the general bound L1 <= sqrt(d) * L2.
        scaled_l1 = math.ceil(
            scaled_inflated_l2 *
            min(math.sqrt(self._padded_dim), scaled_inflated_l2))

        # Build nested aggregation factory.
        # 1. Secure Aggregation. In particular, we have 4 modular clips from
        #    nesting two modular clip aggregators:
        #    #1. outer-client: clips to [-2^(b-1), 2^(b-1)]
        #        Bounds the client values (with limited effect as scaling was
        #        chosen such that `num_clients` is taken into account).
        #    #2. inner-client: clips to [0, 2^b]
        #        Similar to applying a two's complement to the values such that
        #        frequent values (post-rotation) are now near 0 (representing small
        #        positives) and 2^b (small negatives). 0 always maps to 0, and
        #        we do not require another explicit value range shift from
        #        [-2^(b-1), 2^(b-1)] to [0, 2^b] to make sure that values are
        #        compatible with SecAgg's mod m = 2^b. This can be reverted at #4.
        #    #3. inner-server: clips to [0, 2^b]
        #        Ensures the aggregated value range does not grow by log_2(n).
        #        NOTE: If underlying SecAgg is implemented using the new
        #        `tff.federated_secure_modular_sum()` operator with the same
        #        modular clipping range, then this would correspond to a no-op.
        #    #4. outer-server: clips to [-2^(b-1), 2^(b-1)]
        #        Keeps aggregated values centered near 0 out of the logical SecAgg
        #        black box for outer aggregators.
        #    Note that the scaling factor and the bit-width are chosen such that
        #    the number of clients to aggregate is taken into account.
        nested_factory = secure.SecureSumFactory(
            upper_bound_threshold=2**self._bits - 1, lower_bound_threshold=0)
        nested_factory = modular_clipping.ModularClippingSumFactory(
            clip_range_lower=0,
            clip_range_upper=2**self._bits,
            inner_agg_factory=nested_factory)
        nested_factory = modular_clipping.ModularClippingSumFactory(
            clip_range_lower=-(2**(self._bits - 1)),
            clip_range_upper=2**(self._bits - 1),
            inner_agg_factory=nested_factory)

        # 2. DP operations. DP params are in the scaled domain (post-quantization).
        if self._mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=scaled_inflated_l2,
                local_stddev=local_stddev * scale)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(
                l1_norm_bound=scaled_l1,
                l2_norm_bound=scaled_inflated_l2,
                local_stddev=local_stddev * scale)

        nested_factory = differential_privacy.DifferentiallyPrivateFactory(
            query=dp_query, record_aggregation_factory=nested_factory)

        # 3. Discretization operations. This appropriately quantizes the inputs.
        nested_factory = discretization.DiscretizationFactory(
            inner_agg_factory=nested_factory,
            scale_factor=scale,
            stochastic=True,
            beta=self._beta,
            prior_norm_bound=self._initial_l2_clip)

        # 4. L2 clip, possibly adaptively with a `tff.templates.EstimationProcess`.
        nested_factory = robust.clipping_factory(
            clipping_norm=self._l2_clip,
            inner_agg_factory=nested_factory,
            clipped_count_sum_factory=secure.SecureSumFactory(
                upper_bound_threshold=1, lower_bound_threshold=0))

        # 5. Flattening to improve quantization and reduce modular wrapping.
        nested_factory = rotation_factory(inner_agg_factory=nested_factory)

        # 6. Concat the input structure into a single vector.
        nested_factory = concat.concat_factory(
            inner_agg_factory=nested_factory)
        return nested_factory