Esempio n. 1
0
  def test_secure_aggregator(self):
    """Checks the default secure aggregator via the scalar-count helper.

    The aggregation process is created from the default secure aggregator
    factory and handed to `self._check_aggregated_scalar_count`; the
    MapReduceForm that helper returns must report that it securely aggregates
    tensors.
    """
    secure_factory = model_update_aggregator.secure_aggregator()
    process = secure_factory.create(_float_matrix_type, _float_type)
    # NOTE(review): the two numeric arguments are presumably bounds on the
    # aggregated scalar count -- confirm against the helper's definition.
    map_reduce_form = self._check_aggregated_scalar_count(
        process, 60000 * 1.01, 60000)

    # Secure aggregation must be what the MapReduceForm relies on.
    self.assertTrue(map_reduce_form.securely_aggregates_tensors)
    def test_secure_aggregator(self, zeroing, clipping):
        """Checks the factory and process types for the secure aggregator.

        Args:
            zeroing: Forwarded positionally to `secure_aggregator`.
            clipping: Forwarded positionally to `secure_aggregator`.
        """
        secure_factory = model_update_aggregator.secure_aggregator(
            zeroing, clipping)
        self.assertIsInstance(
            secure_factory, factory.WeightedAggregationFactory)

        agg_process = secure_factory.create(_float_type, _float_type)
        self.assertIsInstance(
            agg_process, aggregation_process.AggregationProcess)

        # `next` is expected to take exactly three parameters.
        next_parameters = agg_process.next.type_signature.parameter
        self.assertLen(next_parameters, 3)
Esempio n. 3
0
  def test_secure_aggregator_unweighted(self, zeroing, clipping):
    """Checks the unweighted secure aggregator factory and its process.

    Args:
      zeroing: Forwarded as the `zeroing` keyword of `secure_aggregator`.
      clipping: Forwarded as the `clipping` keyword of `secure_aggregator`.
    """
    unweighted_factory = model_update_aggregator.secure_aggregator(
        zeroing=zeroing, clipping=clipping, weighted=False)
    self.assertIsInstance(
        unweighted_factory, factory.UnweightedAggregationFactory)

    agg_process = unweighted_factory.create(_float_type)
    self.assertIsInstance(
        agg_process, aggregation_process.AggregationProcess)
    # The resulting process must not report itself as weighted.
    self.assertFalse(agg_process.is_weighted)
 def test_weighted_secure_aggregator_only_contains_secure_aggregation(self):
     """Checks that the weighted secure aggregator passes the static check.

     The aggregation process is created from a weighted secure aggregator
     factory, and its `next` computation is passed to
     `static_assert.assert_not_contains_unsecure_aggregation`. Any ordinary
     exception raised by that check is reported as a test failure.
     """
     aggregator = model_update_aggregator.secure_aggregator(
         weighted=True).create(_float_matrix_type, _float_type)
     try:
         static_assert.assert_not_contains_unsecure_aggregation(
             aggregator.next)
     # Catch `Exception` rather than using a bare `except:` so that
     # KeyboardInterrupt / SystemExit still propagate instead of being
     # reported as a test failure.
     except Exception:  # pylint: disable=broad-except
         self.fail('Secure aggregator contains non-secure aggregation.')
Esempio n. 5
0
 def test_no_unsecure_aggregation_with_secure_aggregator(self):
     """Checks FedSGD built with secure aggregators passes the static check.

     The learning process is built with `fed_sgd.build_fed_sgd` using the
     default secure model aggregator and `secure_sum_then_finalize` for
     metrics; its `next` computation must pass
     `static_assert.assert_not_contains_unsecure_aggregation`.
     """
     secure_model_aggregator = model_update_aggregator.secure_aggregator()
     process = fed_sgd.build_fed_sgd(
         model_examples.LinearRegression,
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)

     next_computation = process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
Esempio n. 6
0
 def test_unweighted_mime_lite_with_only_secure_aggregation(self):
     """Checks unweighted Mime Lite with secure aggregators passes the check.

     The same unweighted secure aggregator is used for both the model and the
     full-gradient aggregation, and `secure_sum_then_finalize` is used for
     metrics; the `next` computation of the resulting process must pass
     `static_assert.assert_not_contains_unsecure_aggregation`.
     """
     secure_agg = model_update_aggregator.secure_aggregator(weighted=False)
     optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)

     process = mime.build_unweighted_mime_lite(
         model_examples.LinearRegression,
         base_optimizer=optimizer,
         model_aggregator=secure_agg,
         full_gradient_aggregator=secure_agg,
         metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)

     next_computation = process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
Esempio n. 7
0
 def test_unweighted_fed_avg_with_only_secure_aggregation(self):
     """Checks unweighted FedAvg with secure aggregators passes the check.

     The learning process is built with `fed_avg.build_unweighted_fed_avg`
     using an unweighted secure model aggregator and
     `secure_sum_then_finalize` for metrics; its `next` computation must pass
     `static_assert.assert_not_contains_unsecure_aggregation`.
     """
     secure_model_aggregator = model_update_aggregator.secure_aggregator(
         weighted=False)
     process = fed_avg.build_unweighted_fed_avg(
         model_examples.LinearRegression,
         client_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)

     next_computation = process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
 def test_construction_with_only_secure_aggregation(self):
     """Checks scheduled-optimizer FedAvg with secure aggregators.

     The learning process is built with
     `build_weighted_fed_avg_with_optimizer_schedule` using a weighted secure
     model aggregator and `secure_sum_then_finalize` for metrics; its `next`
     computation must pass
     `static_assert.assert_not_contains_unsecure_aggregation`.
     """
     # Bind the builder to a short local name to keep the call readable.
     build_process = (
         fed_avg_with_optimizer_schedule
         .build_weighted_fed_avg_with_optimizer_schedule)
     secure_model_aggregator = model_update_aggregator.secure_aggregator(
         weighted=True)

     process = build_process(
         model_examples.LinearRegression,
         client_learning_rate_fn=lambda x: 0.5,
         client_optimizer_fn=tf.keras.optimizers.SGD,
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)

     next_computation = process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
Esempio n. 9
0
 def test_weighted_fed_avg_with_only_secure_aggregation(self):
     """Checks weighted FedAvg on a functional model with secure aggregators.

     The learning process is built with `fed_avg.build_weighted_fed_avg`,
     passing `model_fn=None` together with a model from
     `test_models.build_functional_linear_regression`, a weighted secure model
     aggregator, and `secure_sum_then_finalize` for metrics; its `next`
     computation must pass
     `static_assert.assert_not_contains_unsecure_aggregation`.
     """
     functional_model = test_models.build_functional_linear_regression()
     optimizer = sgdm.build_sgdm(learning_rate=0.1)
     secure_model_aggregator = model_update_aggregator.secure_aggregator(
         weighted=True)

     process = fed_avg.build_weighted_fed_avg(
         model_fn=None,
         model=functional_model,
         client_optimizer_fn=optimizer,
         model_aggregator=secure_model_aggregator,
         metrics_aggregator=aggregator.secure_sum_then_finalize)

     next_computation = process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)