def test_wrong_debug_measurements_fn_compression_aggregator(self):
    """A misbehaving debug_measurements_fn must make construction fail.

    The callback below ignores the factory it receives and instead calls the
    mixed-dtype client-update statistics helper directly on plain Python
    lists. Building the compression aggregator with this callback is expected
    to raise ``context_base.ContextError``.
    """

    def wrong_debug_measurements_fn(
            aggregation_factory: factory.AggregationFactory) -> ...:
        del aggregation_factory  # Deliberately unused.
        stats_fn = (
            debug_measurements._calculate_client_update_statistics_mixed_dtype)
        return stats_fn([1.0], [1.0])

    # Only the aggregator construction is expected to raise, so it is the
    # sole statement inside the assertRaises context.
    with self.assertRaises(context_base.ContextError):
        model_update_aggregator.compression_aggregator(
            debug_measurements_fn=wrong_debug_measurements_fn)
Example #2
0
  def test_compression_aggregator(self, zeroing, clipping):
    """compression_aggregator builds a weighted factory when weighted is omitted.

    Args:
      zeroing: Value forwarded as the ``zeroing`` option.
      clipping: Value forwarded as the ``clipping`` option.
    """
    # Pass the options by keyword, as the sibling tests do, so this test does
    # not depend on the positional order of compression_aggregator's params.
    factory_ = model_update_aggregator.compression_aggregator(
        zeroing=zeroing, clipping=clipping)

    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    process = factory_.create(_float_type, _float_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)
    self.assertTrue(process.is_weighted)
    # A weighted process's `next` takes three arguments (presumably state,
    # value and weight -- TODO confirm against AggregationProcess docs).
    self.assertLen(process.next.type_signature.parameter, 3)
Example #3
0
  def test_compression_aggregator_unweighted(self, zeroing, clipping):
    """With weighted=False both the factory and its process are unweighted."""
    agg_factory = model_update_aggregator.compression_aggregator(
        zeroing=zeroing,
        clipping=clipping,
        weighted=False,
    )
    self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory)

    # Only the value type is passed to create(); there is no weight type here.
    agg_process = agg_factory.create(_float_type)
    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)
    self.assertFalse(agg_process.is_weighted)
    def test_compression_aggregator_weighted_mixed_dtype(
            self, zeroing, clipping, debug_measurements_fn):
        """A weighted compression aggregator accepts a debug_measurements_fn."""
        weighted_factory = model_update_aggregator.compression_aggregator(
            zeroing=zeroing, clipping=clipping,
            debug_measurements_fn=debug_measurements_fn)
        self.assertIsInstance(
            weighted_factory, factory.WeightedAggregationFactory)

        # Weighted factories are created from a value type and a weight type.
        weighted_process = weighted_factory.create(_float_type, _float_type)
        self.assertIsInstance(
            weighted_process, aggregation_process.AggregationProcess)
        self.assertTrue(weighted_process.is_weighted)
Example #5
0
 def test_compression_aggregator(self):
   """Default compression shrinks the aggregated size by more than 60%."""
   process = model_update_aggregator.compression_aggregator().create(
       _float_matrix_type, _float_type)
   # 40% of 60000 scalars; 60000 is presumably the element count of
   # _float_matrix_type and the value presumably an upper bound checked by
   # the helper -- TODO confirm both.
   max_scalar_count = 60000 * 0.4
   self._check_aggregated_scalar_count(process, max_scalar_count)