Esempio n. 1
0
    def test_raises_bad_measurement_fn(self):
        """Pairing a factory with a measurement fn of the wrong kind fails."""
        # Each case: (factory, measurement fn it must reject, expected message).
        mismatched_cases = (
            (sum_factory.SumFactory(), _get_weighted_min, 'single parameter'),
            (mean.MeanFactory(), _get_min, 'two parameters'),
        )
        for agg_factory, bad_fn, expected_msg in mismatched_cases:
            with self.assertRaisesRegex(ValueError, expected_msg):
                measurements.add_measurements(agg_factory, bad_fn)
Esempio n. 2
0
def add_debug_measurements(
    aggregation_factory: factory.AggregationFactory
) -> factory.AggregationFactory:
    """Adds measurements suitable for debugging learning processes.

    This will wrap a `tff.aggregators.AggregationFactory` as a new factory that
    will produce additional measurements useful for debugging learning
    processes. The underlying aggregation of client values will remain
    unchanged.

    These measurements generally concern the norm of the client updates, and
    the norm of the aggregated server update. The implicit weighting will be
    determined by `aggregation_factory`: If this is weighted, then the
    debugging measurements will use this weighting when computing averages. If
    it is unweighted, the debugging measurements will use uniform weighting.

    The client measurements are:

    *   The average Euclidean norm of client updates.
    *   The standard deviation of these norms.

    The standard deviation we report is the square root of the **unbiased**
    variance. The server measurements are:

    *   The maximum entry of the aggregate client update.
    *   The Euclidean norm of the aggregate client update.
    *   The minimum entry of the aggregate client update.

    In the above, an "entry" means any coordinate across all tensors in the
    structure. For example, suppose that we have client structures before
    aggregation:

    *   Client A: `[[-1, -3, -5], [2]]`
    *   Client B: `[[-1, -3, 1], [0]]`

    If we use unweighted averaging, then the aggregate client update will be
    the structure `[[-1, -3, -2], [1]]`. The maximum entry is `1`, the minimum
    entry is `-3`, and the Euclidean norm is `sqrt(15)`.

    Args:
      aggregation_factory: A `tff.aggregators.AggregationFactory`. Can be
        weighted or unweighted.

    Returns:
      A `tff.aggregators.AggregationFactory` that aggregates like
      `aggregation_factory` and additionally reports the measurements above.
    """
    is_weighted_aggregator = isinstance(aggregation_factory,
                                        factory.WeightedAggregationFactory)
    # The weighting flag lets `_build_aggregator_measurement_fns` adapt the
    # statistics fns to a weighted vs. unweighted aggregator.
    client_measurement_fn, server_measurement_fn = (
        _build_aggregator_measurement_fns(
            client_measurement_fn=_calculate_client_update_statistics,
            server_measurement_fn=_calculate_server_update_statistics,
            weighted_aggregator=is_weighted_aggregator))

    return measurements.add_measurements(
        aggregation_factory,
        client_measurement_fn=client_measurement_fn,
        server_measurement_fn=server_measurement_fn)
Esempio n. 3
0
    def test_unweighted(self):
        """A min measurement is reported alongside an unweighted sum."""
        measured_sum = measurements.add_measurements(sum_factory.SumFactory(),
                                                     _get_min)
        process = measured_sum.create(_float_type)

        output = process.next(process.initialize(), [1.0, 2.0, 3.0])

        self.assertAllClose(6.0, output.result)
        expected_measurements = collections.OrderedDict(min_value=1.0)
        self.assertDictEqual(expected_measurements, output.measurements)
Esempio n. 4
0
    def test_unweighted_struct(self):
        """A min-norm measurement is reported alongside a summed structure."""
        measured_sum = measurements.add_measurements(sum_factory.SumFactory(),
                                                     _get_min_norm)
        process = measured_sum.create(_struct_type)

        structs = list(map(_make_struct, [1.0, 2.0, 3.0]))
        output = process.next(process.initialize(), structs)

        self.assertAllClose(_make_struct(6.0), output.result)
        expected_measurements = collections.OrderedDict(min_norm=2.0)
        self.assertDictEqual(expected_measurements, output.measurements)
Esempio n. 5
0
    def test_weighted(self):
        """A weighted min measurement is reported alongside a weighted mean."""
        measured_mean = measurements.add_measurements(mean.MeanFactory(),
                                                      _get_weighted_min)
        process = measured_mean.create(_float_type, _float_type)

        values, weights = [1.0, 2.0, 3.0], [3.0, 1.0, 2.0]
        output = process.next(process.initialize(), values, weights)

        # (1*3 + 2*1 + 3*2) / (3 + 1 + 2) == 11 / 6
        self.assertAllClose(11 / 6, output.result)
        expected_measurements = collections.OrderedDict(
            mean_value=(), mean_weight=(), min_weighted_value=2.0)
        self.assertDictEqual(expected_measurements, output.measurements)
Esempio n. 6
0
    def test_weighted_client(self):
        """A client-side weighted min-norm is reported for a weighted mean."""
        measured_mean = measurements.add_measurements(
            mean.MeanFactory(), client_measurement_fn=_get_min_weighted_norm)
        process = measured_mean.create(_struct_type, _float_type)

        structs = list(map(_make_struct, [1.0, 2.0, 3.0]))
        weights = [3.0, 1.0, 2.0]
        output = process.next(process.initialize(), structs, weights)

        self.assertAllClose(_make_struct(11 / 6), output.result)
        expected_measurements = collections.OrderedDict(
            mean_value=(), mean_weight=(), min_weighted_norm=4.0)
        self.assertDictEqual(expected_measurements, output.measurements)
Esempio n. 7
0
                                   b=tf.constant(value,
                                                 dtype=tf.float32,
                                                 shape=(3, 3)))


def _named_test_cases_product(*args):
    """Utility for creating parameterized named test cases."""
    named_cases = []
    dict1, dict2 = args
    for k1, v1 in dict1.items():
        for k2, v2 in dict2.items():
            named_cases.append(('_'.join([k1, k2]), v1, v2))
    return named_cases


# A SumFactory that additionally reports the federated sum of the client
# values as a client measurement.
_measurement_aggregator = measurements.add_measurements(
    sum_factory.SumFactory(),
    client_measurement_fn=intrinsics.federated_sum,
)


class DeterministicDiscretizationComputationTest(tf.test.TestCase,
                                                 parameterized.TestCase):
    @parameterized.named_parameters(
        ('float', tf.float32),
        ('struct_list_float_scalars', [tf.float16, tf.float32, tf.float64]),
        ('struct_list_float_mixed', _test_struct_type_float),
        ('struct_nested', _test_nested_struct_type_float))
    def test_type_properties(self, value_type):
        factory = deterministic_discretization.DeterministicDiscretizationFactory(
            step_size=0.1,
            inner_agg_factory=_measurement_aggregator,
            distortion_aggregation_factory=mean.UnweightedMeanFactory())
        value_type = computation_types.to_type(value_type)
Esempio n. 8
0
def _measured_test_sum_factory():
  """Returns a SumFactory that also reports the sum as a server measurement.

  This is useful for monitoring what values are passed through an inner
  aggregator.
  """
  def _report_sum(total):
    return collections.OrderedDict(sum=total)

  return measurements.add_measurements(
      sum_factory.SumFactory(), server_measurement_fn=_report_sum)
Esempio n. 9
0
def add_debug_measurements_with_mixed_dtype(
    aggregation_factory: factory.AggregationFactory
) -> factory.AggregationFactory:
    """Adds measurements suitable for debugging learning processes.

    WARNING: This method works for model updates with mixed, non-`tf.float32`
    dtypes by casting all tensors to `tf.float32`. This has important numerical
    considerations: for example, if the updates are quantized to `tf.int32`,
    precision can be lost for values with magnitude greater than approximately
    2^24, and `tf.float64` dtypes will be narrowed to `tf.float32`. Most users
    should prefer `tff.learning.add_debug_measurements`.

    This will wrap a `tff.aggregators.AggregationFactory` as a new factory that
    will produce additional measurements useful for debugging learning
    processes. The underlying aggregation of client values will remain
    unchanged.

    These measurements generally concern the norm of the client updates, and
    the norm of the aggregated server update. The implicit weighting will be
    determined by `aggregation_factory`: If this is weighted, then the
    debugging measurements will use this weighting when computing averages. If
    it is unweighted, the debugging measurements will use uniform weighting.

    The client measurements are:

    *   The average Euclidean norm of client updates.
    *   The standard deviation of these norms.

    The standard deviation we report is the square root of the **unbiased**
    variance. The server measurements are:

    *   The maximum entry of the aggregate client update.
    *   The Euclidean norm of the aggregate client update.
    *   The minimum entry of the aggregate client update.

    In the above, an "entry" means any coordinate across all tensors in the
    structure. For example, suppose that we have client structures before
    aggregation:

    *   Client A: `[[-1, -3, -5], [2]]`
    *   Client B: `[[-1, -3, 1], [0]]`

    If we use unweighted averaging, then the aggregate client update will be
    the structure `[[-1, -3, -2], [1]]`. The maximum entry is `1`, the minimum
    entry is `-3`, and the Euclidean norm is `sqrt(15)`.

    Args:
      aggregation_factory: A `tff.aggregators.AggregationFactory`. Can be
        weighted or unweighted.

    Returns:
      A `tff.aggregators.AggregationFactory` that aggregates like
      `aggregation_factory` and additionally reports the measurements above.
    """
    is_weighted_aggregator = isinstance(aggregation_factory,
                                        factory.WeightedAggregationFactory)
    # Same wiring as `add_debug_measurements`, but with the `_mixed_dtype`
    # statistics fns; the weighting flag lets the builder adapt the fns to a
    # weighted vs. unweighted aggregator.
    client_fn = _calculate_client_update_statistics_mixed_dtype
    server_fn = _calculate_server_update_statistics_mixed_dtype
    client_measurement_fn, server_measurement_fn = (
        _build_aggregator_measurement_fns(
            client_measurement_fn=client_fn,
            server_measurement_fn=server_fn,
            weighted_aggregator=is_weighted_aggregator))

    return measurements.add_measurements(
        aggregation_factory,
        client_measurement_fn=client_measurement_fn,
        server_measurement_fn=server_measurement_fn)