def test_federated_init_state_not_assignable(self):
    """IterativeProcess rejects a SERVER-placed init for a CLIENTS-typed next."""

    @federated_computation.federated_computation()
    def initialize_fn():
        return intrinsics.federated_value(0, placements.SERVER)

    # next_fn expects (and returns) an int32 placed at CLIENTS, which the
    # SERVER-placed initial state cannot be assigned to.
    @federated_computation.federated_computation(
        FederatedType(tf.int32, placements.CLIENTS))
    def next_fn(state):
        return state

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        iterative_process.IterativeProcess(initialize_fn, next_fn)
예제 #2
0
 def test_federated_init_state_not_assignable(self):
   """MeasuredProcess rejects init state that next_fn's parameter can't take."""

   def server_zero_value():
     return intrinsics.federated_value(0, placements.SERVER)

   initialize_fn = federated_computation.federated_computation()(
       server_zero_value)

   # The state parameter is CLIENTS-placed, while init produces a SERVER value.
   @federated_computation.federated_computation(
       computation_types.FederatedType(tf.int32, placements.CLIENTS))
   def next_fn(state):
     return MeasuredProcessOutput(state, server_zero_value(),
                                  server_zero_value())

   with self.assertRaises(errors.TemplateStateNotAssignableError):
     measured_process.MeasuredProcess(initialize_fn, next_fn)
 def test_federated_next_state_not_assignable(self):
     """IterativeProcess rejects a next_fn whose output is the broadcast state."""

     @federated_computation.federated_computation()
     def initialize_fn():
         return intrinsics.federated_value(0, placements.SERVER)

     state_type = initialize_fn.type_signature.result
     # next_fn is the broadcast intrinsic itself. Its result is not assignable
     # back to the SERVER-placed state type it was given.
     next_fn = federated_computation.federated_computation(state_type)(
         intrinsics.federated_broadcast)

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         iterative_process.IterativeProcess(initialize_fn, next_fn)
예제 #4
0
def _constant_process(value):
    """Builds an `EstimationProcess` whose report is always `value`.

    The state is an empty structure `()` placed at SERVER. `next_fn` returns
    the state unchanged. `report_fn` ignores the state and places `value` at
    SERVER.

    Args:
      value: The constant to report.

    Returns:
      An `estimation_process.EstimationProcess`.
    """

    @federated_computation.federated_computation
    def init_fn():
        return intrinsics.federated_value((), placements.SERVER)

    state_type = init_fn.type_signature.result

    # The parameter names `state` and `value` are kept because TFF uses them to
    # name the parameter struct. This inner `value` is unused and only shadows
    # the outer argument inside `next_fn`.
    @federated_computation.federated_computation(
        state_type, computation_types.at_clients(NORM_TF_TYPE))
    def next_fn(state, value):
        return state

    @federated_computation.federated_computation(state_type)
    def report_fn(state):
        return intrinsics.federated_value(value, placements.SERVER)

    return estimation_process.EstimationProcess(init_fn, next_fn, report_fn)
예제 #5
0
 def test_federated_report_state_not_assignable(self):
     """EstimationProcess rejects a report_fn expecting CLIENTS-placed state."""

     @federated_computation.federated_computation()
     def initialize_fn():
         return intrinsics.federated_value(0, placements.SERVER)

     @federated_computation.federated_computation(
         initialize_fn.type_signature.result)
     def next_fn(state):
         return state

     clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

     @federated_computation.federated_computation(clients_int)
     def report_fn(state):
         return state

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         estimation_process.EstimationProcess(
             initialize_fn, next_fn, report_fn)
예제 #6
0
def federated_output_computation_from_metrics(
    metrics: List[tf.keras.metrics.Metric]
) -> federated_computation.federated_computation:
  """Builds a federated computation that aggregates Keras metrics.

  Works for Keras and non-Keras models that report Keras metrics. Across
  clients, the internal variables of each metric are summed. New metrics are
  built from the summed variables, and `metric.result()` is called on each. The
  aggregation is delegated to `base_utils.federated_aggregate_keras_metric`.

  Args:
    metrics: A List of `tf.keras.metrics.Metric` to aggregate.

  Returns:
    A federated computation whose single CLIENTS-placed parameter has the
    structure of the metrics' variables and which returns the aggregated
    metric results.
  """
  # The variables are read only to derive the `TensorSpec` structure that types
  # the computation's parameter.
  local_outputs_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                             read_metric_variables(metrics))

  # NOTE(review): the return annotation names the `federated_computation`
  # decorator rather than a computation type; confirm the intended annotation.
  @federated_computation.federated_computation(
      computation_types.at_clients(local_outputs_spec))
  def federated_output(local_outputs):
    return base_utils.federated_aggregate_keras_metric(metrics, local_outputs)

  return federated_output
예제 #7
0
    def test_cardinality_free_data_descriptor_places_data(self):
        """Checks that a cardinality-free descriptor places one value per client."""
        place_at_clients = federated_computation.federated_computation(
            lambda x: intrinsics.federated_value(x, placements.CLIENTS),
            tf.int32)
        ds = data_descriptor.CardinalityFreeDataDescriptor(
            place_at_clients, 1000, computation_types.TensorType(tf.int32))
        self.assertEqual(str(ds.type_signature), 'int32@CLIENTS')

        all_equal_clients_int = computation_types.FederatedType(
            tf.int32, placements.CLIENTS, all_equal=True)

        @federated_computation.federated_computation(all_equal_clients_int)
        def foo(x):
            return intrinsics.federated_sum(x)

        # This descriptor does not specify a cardinality, so the number of
        # placed values comes from the executor's `default_num_clients`
        # setting. The sum is therefore 1000 times the client count.
        for num_clients, expected_sum in ((1, 1000), (3, 3000)):
            with executor_test_utils.install_executor(
                    executor_test_utils.LocalTestExecutorFactory(
                        default_num_clients=num_clients)):
                result = foo(ds)
            self.assertEqual(result, expected_sum)
예제 #8
0
 def test_raises_computation_no_dataset_parameter(self):
   """Composition fails if the target computation has no sequence parameter."""

   @federated_computation.federated_computation([tf.int32])
   def no_dataset_comp(x):
     return x

   compose = (
       iterative_process_compositions
       .compose_dataset_computation_with_computation)
   with self.assertRaises(
       iterative_process_compositions.SequenceTypeNotFoundError):
     compose(int_dataset_computation, no_dataset_comp)
예제 #9
0
    def test_raises_on_bad_process_next_single_param(self, make_factory):
        """Rejects a norm process whose next_fn takes only the state."""

        @federated_computation.federated_computation(_float_at_server)
        def next_fn(state):
            return state

        norm_process = _test_norm_process(next_fn=next_fn)
        with self.assertRaisesRegex(TypeError, '.* must take two arguments.'):
            make_factory(norm_process)
예제 #10
0
 def _bind_federated_value(unused_input, input_type,
                           federated_output_value):
     """Wraps `federated_output_value` in a computation and invokes it.

     Builds a federated computation with one parameter of type `input_type`
     placed at CLIENTS. The computation ignores its argument and returns
     `federated_output_value`. It is then called on `unused_input`, which
     despite its name is passed as the argument.

     Returns:
       The result of invoking that computation.
     """
     clients_input_type = computation_types.FederatedType(
         input_type, placements.CLIENTS)

     @federated_computation.federated_computation(clients_input_type)
     def wrapper(_):
         return federated_output_value

     return wrapper(unused_input)
예제 #11
0
    def test_raises_on_bad_process_next_two_outputs(self, make_factory):
        """Rejects a norm process whose next_fn returns a pair instead of state."""

        @federated_computation.federated_computation(_float_at_server,
                                                     _float_at_clients)
        def next_fn(state, val):
            return state, state

        norm_process = _test_norm_process(next_fn=next_fn)
        with self.assertRaisesRegex(TypeError, 'Result type .* state only.'):
            make_factory(norm_process)
예제 #12
0
def _test_float_next_fn(factor):
    """Returns a next_fn that adds `factor * 1.0` to the state.

    The returned computation takes `(state, value)` typed as
    `(_float_at_server, _float_at_clients)`. It maps the shift over `state` and
    ignores `value`.
    """

    @tensorflow_computation.tf_computation
    def shift_one(x):
        return x + (factor * 1.0)

    # `value` is unused, but the name is kept because TFF uses parameter names
    # to build the parameter struct type.
    @federated_computation.federated_computation(_float_at_server,
                                                 _float_at_clients)
    def next_fn(state, value):
        return intrinsics.federated_map(shift_one, state)

    return next_fn
예제 #13
0
    def test_raises_on_bad_process_next_not_float(self, make_factory):
        """Rejects a norm process whose next_fn takes complex64 client values."""

        @federated_computation.federated_computation(
            _float_at_server, computation_types.at_clients(tf.complex64))
        def next_fn(state, value):
            return state

        norm_process = _test_norm_process(next_fn=next_fn)
        with self.assertRaisesRegex(TypeError,
                                    'Second argument .* assignable from'):
            make_factory(norm_process)
예제 #14
0
    def test_raises_on_bad_norm_process_result(self, value, placement,
                                               make_factory):
        """Rejects a norm process whose report places `value` at `placement`.

        `value` and `placement` come from test parameterization and are
        expected to produce an unacceptable report result type.
        """

        @federated_computation.federated_computation(_float_at_server)
        def report_fn(s):
            return intrinsics.federated_value(value, placement)

        norm_process = _test_norm_process(report_fn=report_fn)
        with self.assertRaisesRegex(TypeError,
                                    r'Result type .* assignable to'):
            make_factory(norm_process)
  def test_non_server_placed_init_state_raises(self):
    """AggregationProcess rejects an initial state placed at CLIENTS."""

    @federated_computation.federated_computation
    def initialize_fn():
      return intrinsics.federated_value(0, placements.CLIENTS)

    @federated_computation.federated_computation(CLIENTS_INT, CLIENTS_FLOAT)
    def next_fn(state, val):
      summed = intrinsics.federated_sum(val)
      return MeasuredProcessOutput(state, summed, server_zero())

    with self.assertRaises(aggregation_process.AggregationPlacementError):
      aggregation_process.AggregationProcess(initialize_fn, next_fn)
  def test_init_tuple_of_federated_types_raises(self):
    """AggregationProcess rejects a tuple of federated values as init state."""

    @federated_computation.federated_computation()
    def initialize_fn():
      return server_zero(), server_zero()

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result, CLIENTS_FLOAT)
    def next_fn(state, val):
      summed = intrinsics.federated_sum(val)
      return MeasuredProcessOutput(state, summed, ())

    with self.assertRaises(aggregation_process.AggregationNotFederatedError):
      aggregation_process.AggregationProcess(initialize_fn, next_fn)
예제 #17
0
  def test_federated_next_state_not_assignable(self):
    """MeasuredProcess rejects a next_fn whose output state is broadcast."""

    @federated_computation.federated_computation()
    def initialize_fn():
      return intrinsics.federated_value(0, placements.SERVER)

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result)
    def next_fn(state):
      broadcast_state = intrinsics.federated_broadcast(state)
      return MeasuredProcessOutput(broadcast_state, (), ())

    with self.assertRaises(errors.TemplateStateNotAssignableError):
      measured_process.MeasuredProcess(initialize_fn, next_fn)
예제 #18
0
  def test_init_tuple_of_federated_types_raises(self):
    """FinalizerProcess rejects a tuple of federated values as init state."""

    @federated_computation.federated_computation()
    def initialize_fn():
      return server_zero(), server_zero()

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
    def next_fn(state, weights, update):
      finalized = test_finalizer_result(weights, update)
      return MeasuredProcessOutput(state, finalized, server_zero())

    with self.assertRaises(errors.TemplateNotFederatedError):
      finalizers.FinalizerProcess(initialize_fn, next_fn)
  def test_construction_with_empty_state_does_not_raise(self):
    """AggregationProcess can be constructed with `server_zero` as init.

    NOTE(review): despite the test name, the state here comes from
    `server_zero` and is typed `SERVER_INT` in next_fn, not an empty `()`
    structure. Confirm the intent.
    """
    initialize_fn = federated_computation.federated_computation()(server_zero)

    @federated_computation.federated_computation(SERVER_INT, CLIENTS_FLOAT)
    def next_fn(state, val):
      return MeasuredProcessOutput(
          state, intrinsics.federated_sum(val),
          intrinsics.federated_value(1, placements.SERVER))

    # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
    # SystemExit are not turned into a test failure. The original error stays
    # attached to the failure as its `__context__`.
    try:
      aggregation_process.AggregationProcess(initialize_fn, next_fn)
    except Exception:  # pylint: disable=broad-except
      self.fail('Could not construct an AggregationProcess with empty state.')
예제 #20
0
def _create_test_measured_process_state_at_clients():
  """Returns a `MeasuredProcess` whose state is placed at CLIENTS.

  `initialize_fn` places 0 at CLIENTS. `next_fn` passes the state through
  unchanged, returns the federated sum of the client values as its result, and
  reports the constant 1 at SERVER as its measurement.
  """
  clients_int32 = computation_types.at_clients(tf.int32)

  @federated_computation.federated_computation(clients_int32, clients_int32)
  def next_fn(state, values):
    summed = intrinsics.federated_sum(values)
    one_at_server = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, one_at_server)

  @federated_computation.federated_computation
  def initialize_fn():
    return intrinsics.federated_value(0, placements.CLIENTS)

  return measured_process.MeasuredProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
예제 #21
0
  def test_non_server_placed_init_state_raises(self):
    """FinalizerProcess rejects an initial state placed at CLIENTS."""

    @federated_computation.federated_computation
    def initialize_fn():
      return intrinsics.federated_value(0, placements.CLIENTS)

    parameter_types = (CLIENTS_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)

    @federated_computation.federated_computation(*parameter_types)
    def next_fn(state, weights, update):
      finalized = test_finalizer_result(weights, update)
      return MeasuredProcessOutput(state, finalized, server_zero())

    with self.assertRaises(errors.TemplatePlacementError):
      finalizers.FinalizerProcess(initialize_fn, next_fn)
예제 #22
0
def _create_test_aggregation_process(state_type, state_init, values_type):
  """Returns an `AggregationProcess` that sums client values.

  Args:
    state_type: Type of the state, placed at SERVER in `next_fn`'s signature.
    state_init: Value that `initialize_fn` places at SERVER.
    values_type: Type of the values, placed at CLIENTS in `next_fn`'s
      signature.

  Returns:
    An `aggregation_process.AggregationProcess`. Its `next_fn` passes the state
    through unchanged, returns the federated sum of the values, and reports the
    constant 1 at SERVER as its measurement.
  """
  server_state_type = computation_types.at_server(state_type)
  clients_values_type = computation_types.at_clients(values_type)

  @federated_computation.federated_computation(server_state_type,
                                               clients_values_type)
  def next_fn(state, values):
    summed = intrinsics.federated_sum(values)
    one_at_server = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, one_at_server)

  @federated_computation.federated_computation
  def initialize_fn():
    return intrinsics.federated_value(state_init, placements.SERVER)

  return aggregation_process.AggregationProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
예제 #23
0
    def test_non_server_placed_init_state_raises(self):
        """ClientWorkProcess rejects an initial state placed at CLIENTS."""

        @federated_computation.federated_computation
        def initialize_fn():
            return intrinsics.federated_value(0, placements.CLIENTS)

        state_type = initialize_fn.type_signature.result

        @federated_computation.federated_computation(
            state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(initialize_fn, next_fn)
예제 #24
0
    def test_init_tuple_of_federated_types_raises(self):
        """ClientWorkProcess rejects a tuple of federated values as state."""

        @federated_computation.federated_computation()
        def initialize_fn():
            return server_zero(), server_zero()

        state_type = initialize_fn.type_signature.result

        @federated_computation.federated_computation(
            state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplateNotFederatedError):
            client_works.ClientWorkProcess(initialize_fn, next_fn)
예제 #25
0
  def test_construction_with_empty_state_does_not_raise(self):
    """FinalizerProcess can be constructed with an empty `()` SERVER state."""
    initialize_fn = federated_computation.federated_computation()(
        lambda: intrinsics.federated_value((), placements.SERVER))

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
    def next_fn(state, weights, update):
      return MeasuredProcessOutput(
          state, test_finalizer_result(weights, update),
          intrinsics.federated_value(1, placements.SERVER))

    # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
    # SystemExit are not turned into a test failure. The original error stays
    # attached to the failure as its `__context__`.
    try:
      finalizers.FinalizerProcess(initialize_fn, next_fn)
    except Exception:  # pylint: disable=broad-except
      self.fail('Could not construct an FinalizerProcess with empty state.')
예제 #26
0
    def test_federated_mapped_process_as_expected(self):
        """`EstimationProcess.map` swaps only the report's result type.

        The mapped process keeps `initialize`, `next`, and the report
        parameter type. Its report result type is `map_fn`'s result type.
        """

        @federated_computation.federated_computation()
        def initialize_fn():
            return intrinsics.federated_value(0, placements.SERVER)

        state_type = initialize_fn.type_signature.result

        @federated_computation.federated_computation(state_type)
        def next_fn(state):
            return state

        @federated_computation.federated_computation(state_type)
        def report_fn(state):
            return intrinsics.federated_map(test_report_fn, state)

        process = estimation_process.EstimationProcess(
            initialize_fn, next_fn, report_fn)

        @federated_computation.federated_computation(
            report_fn.type_signature.result)
        def map_fn(estimate):
            return intrinsics.federated_map(test_map_fn, estimate)

        mapped_process = process.map(map_fn)

        self.assertIsInstance(mapped_process,
                              estimation_process.EstimationProcess)
        self.assertEqual(process.initialize, mapped_process.initialize)
        self.assertEqual(process.next, mapped_process.next)
        mapped_report_signature = mapped_process.report.type_signature
        self.assertEqual(process.report.type_signature.parameter,
                         mapped_report_signature.parameter)
        self.assertEqual(map_fn.type_signature.result,
                         mapped_report_signature.result)
예제 #27
0
def _encoded_init_fn(encoders):
  """Creates `init_fn` for the process returned by `EncodedSumFactory`.

  The initial state mirrors the structure of `encoders`. Each leaf is the value
  returned by that encoder's `initial_state()`, computed in a TF computation
  and evaluated at SERVER.

  Args:
    encoders: A collection of `GatherEncoder` objects.

  Returns:
    A no-arg `tff.Computation` returning initial state for `EncodedSumFactory`.
  """

  @tensorflow_computation.tf_computation
  def init_fn_tf():
    return tf.nest.map_structure(lambda encoder: encoder.initial_state(),
                                 encoders)

  @federated_computation.federated_computation
  def init_fn():
    return intrinsics.federated_eval(init_fn_tf, placements.SERVER)

  return init_fn
예제 #28
0
    def test_construction_with_empty_state_does_not_raise(self):
        """ClientWorkProcess can be constructed with an empty `()` SERVER state."""
        initialize_fn = federated_computation.federated_computation()(
            lambda: intrinsics.federated_value((), placements.SERVER))

        @federated_computation.federated_computation(
            initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            return MeasuredProcessOutput(
                state, test_client_result(weights, data),
                intrinsics.federated_value(1, placements.SERVER))

        # Catch `Exception`, not a bare `except:`, so KeyboardInterrupt and
        # SystemExit are not turned into a test failure. The original error
        # stays attached to the failure as its `__context__`.
        try:
            client_works.ClientWorkProcess(initialize_fn, next_fn)
        except Exception:  # pylint: disable=broad-except
            self.fail(
                'Could not construct an ClientWorkProcess with empty state.')
예제 #29
0
  def test_federated_measured_process_output_raises(self):
    """MeasuredProcess rejects a next_fn returning a zipped output."""
    initialize_fn = federated_computation.federated_computation()(
        lambda: intrinsics.federated_value(0, placements.SERVER))

    def empty():
      return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result)
    def next_fn(state):
      # federated_zip puts a single FederatedType at the top of the result
      # instead of three separately federated members.
      output = MeasuredProcessOutput(state, empty(), empty())
      return intrinsics.federated_zip(output)

    # A MeasuredProcessOutput of three `FederatedType`s differs from a single
    # `FederatedType` containing a MeasuredProcessOutput. Currently only the
    # former is considered valid.
    with self.assertRaises(errors.TemplateStateNotAssignableError):
      measured_process.MeasuredProcess(initialize_fn, next_fn)
예제 #30
0
    def test_federated(self):
        """A DataDescriptor with a cardinality of 3 yields a sum of 3000.

        The final constructor argument `3` is presumably the client
        cardinality, which matches the expected 3 * 1000 sum.
        """
        to_clients = federated_computation.federated_computation(
            lambda x: intrinsics.federated_value(x, placements.CLIENTS),
            tf.int32)
        ds = data_descriptor.DataDescriptor(
            to_clients, 1000, computation_types.TensorType(tf.int32), 3)
        self.assertEqual(str(ds.type_signature), 'int32@CLIENTS')

        @federated_computation.federated_computation(
            computation_types.FederatedType(
                tf.int32, placements.CLIENTS, all_equal=True))
        def foo(x):
            return intrinsics.federated_sum(x)

        executor_factory = executor_test_utils.LocalTestExecutorFactory()
        with executor_test_utils.install_executor(executor_factory):
            result = foo(ds)
        self.assertEqual(result, 3000)