예제 #1
0
    def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                             lower_bound, measurements_dtype):
        """Checks the type signatures of a process built from constant bounds.

        Args:
          value_type: Type spec converted with `computation_types.to_type`
            and used as the aggregated value type.
          upper_bound: Constant passed as `upper_bound_threshold`.
          lower_bound: Constant passed as `lower_bound_threshold`.
          measurements_dtype: Forwarded to `_measurements_type` to build the
            expected measurements type.
        """
        sum_factory = secure_factory.SecureSumFactory(
            upper_bound_threshold=upper_bound,
            lower_bound_threshold=lower_bound)
        self.assertIsInstance(sum_factory,
                              factory.UnweightedAggregationFactory)

        converted_type = computation_types.to_type(value_type)
        process = sum_factory.create(converted_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # With constant bounds the process is expected to carry an empty
        # server state.
        empty_struct = computation_types.to_type(())
        state_type = computation_types.at_server(empty_struct)
        measurements_type = _measurements_type(measurements_dtype)

        # initialize: () -> state@SERVER
        initialize_type = computation_types.FunctionType(
            parameter=None, result=state_type)
        actual_initialize_type = process.initialize.type_signature
        self.assertTrue(
            actual_initialize_type.is_equivalent_to(initialize_type))

        # next: <state@SERVER, value@CLIENTS>
        #       -> <state@SERVER, result@SERVER, measurements>
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(converted_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(converted_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        actual_next_type = process.next.type_signature
        self.assertTrue(actual_next_type.is_equivalent_to(next_type))
예제 #2
0
 def test_value_type_incompatible_with_config_mode_raises_two_processes(
         self):
     """Expects TypeError for tf.int32 when both bounds are processes."""
     upper_process = _test_estimation_process(1)
     lower_process = _test_estimation_process(-1)
     secure_sum_f = secure_factory.SecureSumFactory(upper_process,
                                                    lower_process)
     with self.assertRaises(TypeError):
         int32_type = computation_types.TensorType(tf.int32)
         secure_sum_f.create_unweighted(int32_type)
예제 #3
0
    def test_float_two_processes_bounds(self):
        """Float32 secure sum with both bounds driven by estimation processes.

        Over three rounds, the thresholds reported in the measurements move
        from +/-1.0 to +/-2.0 to +/-3.0. Each round's result is expected to be
        the sum of the client values clipped to that round's thresholds.
        """
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=_test_estimation_process(1),
            lower_bound_threshold=_test_estimation_process(-1))
        process = secure_sum_f.create(computation_types.to_type(tf.float32))
        client_data = [-2.5, -0.5, 0.0, 1.0, 1.5, 3.5]

        # Round 1, bounds [-1, 1]: clipped [-1, -0.5, 0, 1, 1, 1] -> 1.5.
        # The result was previously unchecked in this round.
        state = process.initialize()
        output = process.next(state, client_data)
        self.assertAllClose(1.5, output.result)
        self._check_measurements(output.measurements,
                                 expected_secure_upper_clipped_count=2,
                                 expected_secure_lower_clipped_count=1,
                                 expected_secure_upper_threshold=1.0,
                                 expected_secure_lower_threshold=-1.0)

        # Round 2, bounds [-2, 2]: clipped [-2, -0.5, 0, 1, 1.5, 2] -> 2.0.
        output = process.next(output.state, client_data)
        self.assertAllClose(2.0, output.result)
        self._check_measurements(output.measurements,
                                 expected_secure_upper_clipped_count=1,
                                 expected_secure_lower_clipped_count=1,
                                 expected_secure_upper_threshold=2.0,
                                 expected_secure_lower_threshold=-2.0)

        # Round 3, bounds [-3, 3]: clipped [-2.5, -0.5, 0, 1, 1.5, 3] -> 2.5.
        output = process.next(output.state, client_data)
        self.assertAllClose(2.5, output.result)
        self._check_measurements(output.measurements,
                                 expected_secure_upper_clipped_count=1,
                                 expected_secure_lower_clipped_count=0,
                                 expected_secure_upper_threshold=3.0,
                                 expected_secure_lower_threshold=-3.0)
예제 #4
0
 def test_int_ranges_beyond_2_pow_32(self):
     """Bounds of magnitude 2**33 yield an int64 result type."""
     bound = 2**33
     secure_sum_f = secure_factory.SecureSumFactory(bound, -bound)
     # Bounds this large should be provided only with tf.int64 value_type.
     int64_type = computation_types.TensorType(tf.int64)
     process = secure_sum_f.create_unweighted(int64_type)
     result_type = process.next.type_signature.result.result
     self.assertEqual(result_type.member.dtype, tf.int64)
def secure_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates a secure aggregator with adaptive zeroing and clipping.

    Zeroes out extremely large values for robustness to data corruption on
    clients, and clips to a moderately high norm for robustness to outliers.
    After weighting in the mean, the weighted values are summed using a
    cryptographic protocol ensuring that the server cannot see individual
    updates until a sufficient number of updates have been added together.
    For details, see Bonawitz et al. (2017)
    https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is
    realized using the `tff.federated_secure_sum` operator.

    The bound used for secure summation is itself estimated by a noise-free
    quantile estimation process (initial estimate 50.0, target quantile 0.95,
    learning rate 1.0, multiplier 2.0).

    Args:
      zeroing: Whether to enable adaptive zeroing.
      clipping: Whether to enable adaptive clipping.

    Returns:
      A `tff.aggregators.WeightedAggregationFactory`.
    """
    # Estimation process supplied as the secure sum's (upper) bound.
    # NOTE(review): no lower bound is passed here; presumably
    # SecureSumFactory derives one from the upper bound — confirm.
    bound_estimator = (
        quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
            initial_estimate=50.0,
            target_quantile=0.95,
            learning_rate=1.0,
            multiplier=2.0))
    aggregator = mean_factory.MeanFactory(
        secure_factory.SecureSumFactory(bound_estimator))

    # Wrappers are applied in a fixed order: clipping first, then zeroing.
    if clipping:
        aggregator = _default_clipping(aggregator)
    return _default_zeroing(aggregator) if zeroing else aggregator
예제 #6
0
    def test_float_32_larger_than_2_pow_32(self):
        """Float32 sums beyond 2**32 are computed without clipping."""
        two_pow_33 = float(2**33)
        two_pow_34 = float(2**34)
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=two_pow_34)
        process = secure_sum_f.create(computation_types.to_type(tf.float32))
        client_data = [two_pow_33, two_pow_33, two_pow_34]

        output = process.next(process.initialize(), client_data)
        # 2**33 + 2**33 + 2**34 == 2**35; no value lies outside the bounds.
        self.assertAllClose(float(2**35), output.result)
        # Only the upper bound is given; the test expects the lower threshold
        # to be its negation.
        self._check_measurements(output.measurements,
                                 expected_secure_upper_clipped_count=0,
                                 expected_secure_lower_clipped_count=0,
                                 expected_secure_upper_threshold=two_pow_34,
                                 expected_secure_lower_threshold=-two_pow_34)
예제 #7
0
    def test_int_constant_bounds(self):
        """Int32 values are clipped to the constant bounds [-1, 1]."""
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=1, lower_bound_threshold=-1)
        int32_type = computation_types.to_type(tf.int32)
        process = secure_sum_f.create(int32_type)
        values = [-2, -1, 0, 1, 2, 3]

        initial_state = process.initialize()
        round_output = process.next(initial_state, values)
        # Clipped: [-1, -1, 0, 1, 1, 1] -> 1. Values 2 and 3 exceed the upper
        # bound; -2 falls below the lower bound.
        self.assertEqual(1, round_output.result)
        self._check_measurements(round_output.measurements,
                                 expected_secure_upper_clipped_count=2,
                                 expected_secure_lower_clipped_count=1,
                                 expected_secure_upper_threshold=1,
                                 expected_secure_lower_threshold=-1)
예제 #8
0
    def test_float_constant_bounds(self):
        """Float32 values are clipped to the constant bounds [-1.0, 1.0]."""
        # NOTE(review): this test uses `create_unweighted` and the
        # `expected_*_bound_*` keyword names, while sibling tests use `create`
        # and `expected_secure_*`; confirm which `_check_measurements`
        # signature is current.
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=1.0, lower_bound_threshold=-1.0)
        float32_type = computation_types.to_type(tf.float32)
        process = secure_sum_f.create_unweighted(float32_type)
        client_values = [-2.5, -0.5, 0.0, 1.0, 1.5, 2.5]

        output = process.next(process.initialize(), client_values)
        # Clipped: [-1.0, -0.5, 0.0, 1.0, 1.0, 1.0] -> 1.5.
        self.assertAllClose(1.5, output.result)
        self._check_measurements(output.measurements,
                                 expected_upper_bound_clipped_count=2,
                                 expected_lower_bound_clipped_count=1,
                                 expected_upper_bound_threshold=1.0,
                                 expected_lower_bound_threshold=-1.0)
예제 #9
0
    def test_float_64_larger_than_2_pow_64(self):
        """Float64 sums beyond 2**64 are computed without clipping."""

        def f64(value):
            # Fresh 0-d float64 array for each call.
            return np.array(value, dtype=np.float64)

        bound = 2**66
        secure_sum_f = secure_factory.SecureSumFactory(
            upper_bound_threshold=f64(bound))
        process = secure_sum_f.create(computation_types.to_type(tf.float64))
        client_data = [f64(2**65), f64(2**65), f64(bound)]

        output = process.next(process.initialize(), client_data)
        # 2**65 + 2**65 + 2**66 == 2**67; no value lies outside the bounds.
        self.assertAllClose(f64(2**67), output.result)
        self._check_measurements(
            output.measurements,
            expected_secure_upper_clipped_count=0,
            expected_secure_lower_clipped_count=0,
            expected_secure_upper_threshold=f64(bound),
            expected_secure_lower_threshold=f64(-bound))
예제 #10
0
 def test_incorrect_value_type_raises(self, bad_value_type):
     """`create` rejects the given value type with a TypeError."""
     sum_factory = secure_factory.SecureSumFactory(1.0, -1.0)
     self.assertRaises(TypeError, sum_factory.create, bad_value_type)
예제 #11
0
 def test_value_type_incompatible_with_config_mode_raises_float(
         self, upper, lower):
     """Expects TypeError when a tf.int32 value type meets these bounds.

     The bounds are presumably float constants (per the test name); confirm
     against the parameterization.
     """
     sum_factory = secure_factory.SecureSumFactory(upper, lower)
     with self.assertRaises(TypeError):
         int32_type = computation_types.TensorType(tf.int32)
         sum_factory.create(int32_type)
예제 #12
0
 def test_upper_bound_not_larger_than_lower_bound_raises(
         self, upper, lower):
     """Constructing the factory with upper <= lower raises ValueError."""
     self.assertRaises(ValueError, secure_factory.SecureSumFactory, upper,
                       lower)