def test_secure_aggregator(self):
  """Checks the MapReduceForm produced for the default secure aggregator.

  Builds an aggregation process from `secure_aggregator()` over
  `_float_matrix_type` values with `_float_type` weights. It then runs the
  shared scalar-count check and asserts that the resulting MapReduceForm
  aggregates its tensors securely.
  """
  base_scalar_count = 60000
  secure_factory = model_update_aggregator.secure_aggregator()
  process = secure_factory.create(_float_matrix_type, _float_type)
  # NOTE(review): the two numeric arguments look like an upper bound (1%
  # slack) and a lower bound on the aggregated scalar count; the helper is
  # defined outside this view -- confirm its parameter order.
  map_reduce_form = self._check_aggregated_scalar_count(
      process, base_scalar_count * 1.01, base_scalar_count)
  # The MapReduceForm should be using secure aggregation.
  self.assertTrue(map_reduce_form.securely_aggregates_tensors)
def test_secure_aggregator(self, zeroing, clipping):
  """Checks the weighted secure aggregator factory and its process.

  `zeroing` and `clipping` are forwarded positionally to
  `secure_aggregator`. They are presumably supplied by a parameterization
  decorator outside this view.
  """
  agg_factory = model_update_aggregator.secure_aggregator(zeroing, clipping)
  self.assertIsInstance(agg_factory, factory.WeightedAggregationFactory)

  agg_process = agg_factory.create(_float_type, _float_type)
  self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)
  # A weighted `next` takes three arguments (presumably state, value and
  # weight).
  next_parameters = agg_process.next.type_signature.parameter
  self.assertLen(next_parameters, 3)
def test_secure_aggregator_unweighted(self, zeroing, clipping):
  """Checks that `weighted=False` yields an unweighted factory and process.

  `zeroing` and `clipping` are presumably supplied by a parameterization
  decorator outside this view.
  """
  unweighted_factory = model_update_aggregator.secure_aggregator(
      zeroing=zeroing, clipping=clipping, weighted=False)
  self.assertIsInstance(unweighted_factory,
                        factory.UnweightedAggregationFactory)

  # Unweighted factories are created from the value type alone.
  unweighted_process = unweighted_factory.create(_float_type)
  self.assertIsInstance(unweighted_process,
                        aggregation_process.AggregationProcess)
  self.assertFalse(unweighted_process.is_weighted)
def test_weighted_secure_aggregator_only_contains_secure_aggregation(self):
  """Checks that the weighted secure aggregator uses only secure aggregation.

  Creates the process for `_float_matrix_type` values with `_float_type`
  weights. Any exception from the static check on its `next` computation is
  reported as a test failure.
  """
  aggregator = model_update_aggregator.secure_aggregator(
      weighted=True).create(_float_matrix_type, _float_type)
  try:
    static_assert.assert_not_contains_unsecure_aggregation(aggregator.next)
  except Exception:  # pylint: disable=broad-except
    # Narrowed from a bare `except:` so that BaseException subclasses such as
    # KeyboardInterrupt and SystemExit propagate instead of being reported as
    # a test failure. The caught exception stays attached as the implicit
    # context of the AssertionError raised by self.fail.
    self.fail('Secure aggregator contains non-secure aggregation.')
def test_no_unsecure_aggregation_with_secure_aggregator(self):
  """Checks that FedSGD built with secure aggregators has no unsecure parts.

  Both the model-update aggregator and the metrics aggregator are the
  secure variants. The static check is then run on the process's `next`.
  """
  linear_model_fn = model_examples.LinearRegression
  secure_model_aggregator = model_update_aggregator.secure_aggregator()
  fed_sgd_process = fed_sgd.build_fed_sgd(
      linear_model_fn,
      model_aggregator=secure_model_aggregator,
      metrics_aggregator=aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(fed_sgd_process.next)
def test_unweighted_mime_lite_with_only_secure_aggregation(self):
  """Checks that unweighted Mime Lite with secure aggregators is all-secure.

  A single unweighted secure factory is used for both the model update and
  the full gradient. Metrics use the secure sum-then-finalize aggregator.
  """
  secure_factory = model_update_aggregator.secure_aggregator(weighted=False)
  momentum_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
  mime_process = mime.build_unweighted_mime_lite(
      model_examples.LinearRegression,
      base_optimizer=momentum_optimizer,
      model_aggregator=secure_factory,
      full_gradient_aggregator=secure_factory,
      metrics_aggregator=metrics_aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(mime_process.next)
def test_unweighted_fed_avg_with_only_secure_aggregation(self):
  """Checks that unweighted FedAvg with secure aggregators is all-secure."""

  def client_optimizer_fn():
    # Plain Keras SGD with learning rate 1.0, built fresh on each call.
    return tf.keras.optimizers.SGD(1.0)

  linear_model_fn = model_examples.LinearRegression
  unweighted_secure_factory = model_update_aggregator.secure_aggregator(
      weighted=False)
  fed_avg_process = fed_avg.build_unweighted_fed_avg(
      linear_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      model_aggregator=unweighted_secure_factory,
      metrics_aggregator=aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(fed_avg_process.next)
def test_construction_with_only_secure_aggregation(self):
  """Checks that scheduled weighted FedAvg with secure aggregators is secure.

  The client learning-rate schedule is a constant 0.5. The client optimizer
  is the Keras SGD class.
  """

  def client_learning_rate_fn(x):
    # Ignores its argument (presumably the round number) and always
    # returns 0.5.
    del x
    return 0.5

  linear_model_fn = model_examples.LinearRegression
  build_scheduled_fed_avg = (
      fed_avg_with_optimizer_schedule
      .build_weighted_fed_avg_with_optimizer_schedule)
  weighted_secure_factory = model_update_aggregator.secure_aggregator(
      weighted=True)
  scheduled_process = build_scheduled_fed_avg(
      linear_model_fn,
      client_learning_rate_fn=client_learning_rate_fn,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      model_aggregator=weighted_secure_factory,
      metrics_aggregator=aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(
      scheduled_process.next)
def test_weighted_fed_avg_with_only_secure_aggregation(self):
  """Checks that weighted FedAvg on a functional model is all-secure.

  The process is built from a functional linear-regression model
  (`model_fn=None`). It uses a TFF SGDM optimizer instance as the client
  optimizer and secure aggregators for both model updates and metrics.
  """
  functional_model = test_models.build_functional_linear_regression()
  client_optimizer = sgdm.build_sgdm(learning_rate=0.1)
  weighted_secure_factory = model_update_aggregator.secure_aggregator(
      weighted=True)
  fed_avg_process = fed_avg.build_weighted_fed_avg(
      model_fn=None,
      model=functional_model,
      client_optimizer_fn=client_optimizer,
      model_aggregator=weighted_secure_factory,
      metrics_aggregator=aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(fed_avg_process.next)