Example #1
0
 def test_next_not_tff_computation_raises(self):
     """A plain Python callable passed as `next_fn` raises a TypeError."""

     def python_next_fn(state, w, d):
         return MeasuredProcessOutput(state, w + d, ())

     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                        next_fn=python_next_fn)
Example #2
0
    def test_next_return_tuple_raises(self):
        """Returning a bare tuple instead of `MeasuredProcessOutput` fails."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def tuple_next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return (state, client_result, server_zero())

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            client_works.ClientWorkProcess(test_initialize_fn, tuple_next_fn)
Example #3
0
    def test_two_param_next_raises(self):
        """A `next_fn` taking only two parameters is rejected."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE)
        def next_fn(state, weights):
            trainable_weights = weights.trainable
            return MeasuredProcessOutput(state, trainable_weights,
                                         server_zero())

        with self.assertRaises(errors.TemplateNextFnNumArgsError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #4
0
    def test_non_server_placed_next_measurements_raises(self):
        """Measurements placed at CLIENTS instead of SERVER are rejected."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            clients_measurements = intrinsics.federated_value(
                1.0, placements.CLIENTS)
            return MeasuredProcessOutput(state, client_result,
                                         clients_measurements)

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #5
0
    def test_next_return_odict_raises(self):
        """Returning an `OrderedDict` instead of `MeasuredProcessOutput` fails."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def odict_next_fn(state, weights, data):
            output_fields = collections.OrderedDict()
            output_fields['state'] = state
            output_fields['result'] = test_client_result(weights, data)
            output_fields['measurements'] = server_zero()
            return output_fields

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            client_works.ClientWorkProcess(test_initialize_fn, odict_next_fn)
Example #6
0
    def test_non_clients_placed_next_weights_param_raises(self):
        """A SERVER-placed `weights` parameter is rejected."""
        server_weights_type = computation_types.at_server(
            MODEL_WEIGHTS_TYPE.member)

        @federated_computation.federated_computation(
            SERVER_INT, server_weights_type, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            weights_at_clients = intrinsics.federated_broadcast(weights)
            client_result = test_client_result(weights_at_clients, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #7
0
    def test_next_state_not_assignable(self):
        """A returned state whose type differs from the input state is rejected."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def float_next_fn(state, weights, data):
            del state  # Deliberately replaced by a float-typed state below.
            float_state = intrinsics.federated_value(0.0, placements.SERVER)
            client_result = test_client_result(weights, data)
            measurements = intrinsics.federated_value(1, placements.SERVER)
            return MeasuredProcessOutput(float_state, client_result,
                                         measurements)

        with self.assertRaises(errors.TemplateStateNotAssignableError):
            client_works.ClientWorkProcess(test_initialize_fn, float_next_fn)
Example #8
0
    def test_non_server_placed_init_state_raises(self):
        """An initial state placed at CLIENTS is rejected."""
        initialize_fn = federated_computation.federated_computation(
            lambda: intrinsics.federated_value(0, placements.CLIENTS))
        state_type = initialize_fn.type_signature.result

        @federated_computation.federated_computation(
            state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(initialize_fn, next_fn)
Example #9
0
    def test_non_sequence_or_struct_next_data_param_raises(self):
        """A `data` parameter that is neither a sequence nor a struct fails."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT)
        def next_fn(state, weights, data):
            update = federated_add(weights.trainable, data)
            zipped_result = intrinsics.federated_zip(
                client_works.ClientResult(update, client_one()))
            return MeasuredProcessOutput(state, zipped_result, server_zero())

        with self.assertRaises(client_works.ClientDataTypeError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #10
0
    def test_non_zipped_next_result_raises(self):
        """A `ClientResult` of separately placed (unzipped) values fails."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            reduced_data = intrinsics.federated_map(tf_data_sum, data)
            update = federated_add(weights.trainable, reduced_data)
            # Intentionally *not* wrapped in `federated_zip`.
            unzipped_result = client_works.ClientResult(update, client_one())
            return MeasuredProcessOutput(state, unzipped_result, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #11
0
    def test_init_tuple_of_federated_types_raises(self):
        """An initial state that is a tuple of federated values is rejected."""
        initialize_fn = federated_computation.federated_computation()(
            lambda: (server_zero(), server_zero()))
        state_type = initialize_fn.type_signature.result

        @federated_computation.federated_computation(
            state_type, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplateNotFederatedError):
            client_works.ClientWorkProcess(initialize_fn, next_fn)
Example #12
0
    def test_non_federated_init_next_raises(self):
        """Plain TensorFlow (non-federated) computations are rejected."""
        initialize_fn = tensorflow_computation.tf_computation(lambda: 0)
        float_sequence_type = computation_types.SequenceType(tf.float32)

        @tensorflow_computation.tf_computation(
            tf.int32, MODEL_WEIGHTS_TYPE.member, float_sequence_type)
        def next_fn(state, weights, data):
            update = weights.trainable + tf_data_sum(data)
            return MeasuredProcessOutput(
                state, client_works.ClientResult(update, ()), ())

        with self.assertRaises(errors.TemplateNotFederatedError):
            client_works.ClientWorkProcess(initialize_fn, next_fn)
Example #13
0
    def test_next_return_namedtuple_raises(self):
        """A look-alike namedtuple is not accepted as `MeasuredProcessOutput`."""
        measured_process_output = collections.namedtuple(
            'MeasuredProcessOutput', ['state', 'result', 'measurements'])

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def namedtuple_next_fn(state, weights, data):
            return measured_process_output(
                state=state,
                result=test_client_result(weights, data),
                measurements=server_zero())

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            client_works.ClientWorkProcess(test_initialize_fn,
                                           namedtuple_next_fn)
Example #14
0
    def test_incorrect_client_result_container_raises(self):
        """A zipped `OrderedDict` is not accepted in place of `ClientResult`."""

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            reduced_data = intrinsics.federated_map(tf_data_sum, data)
            update = federated_add(weights.trainable, reduced_data)
            odict_result = collections.OrderedDict(update=update,
                                                   update_weight=client_one())
            bad_client_result = intrinsics.federated_zip(odict_result)
            return MeasuredProcessOutput(state, bad_client_result,
                                         server_zero())

        with self.assertRaises(client_works.ClientResultTypeError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #15
0
def test_client_work():
    """Builds a `ClientWorkProcess` used as a fixture.

    Each client's update is the trainable part of its weights, and the update
    weight is the sum of the elements in the client's dataset. The state is
    passed through unchanged and the measurements come from `empty_at_server`.

    Returns:
      A `client_works.ClientWorkProcess`.
    """

    @tensorflow_computation.tf_computation()
    def make_result(value, data):
        data_total = data.reduce(0.0, lambda x, y: x + y)
        return client_works.ClientResult(update=value.trainable,
                                         update_weight=data_total)

    @federated_computation.federated_computation(
        empty_init_fn.type_signature.result,
        computation_types.at_clients(MODEL_WEIGHTS_TYPE),
        CLIENTS_SEQUENCE_FLOAT_TYPE)
    def next_fn(state, value, client_data):
        client_results = intrinsics.federated_map(make_result,
                                                  (value, client_data))
        return measured_process.MeasuredProcessOutput(
            state, client_results, empty_at_server())

    return client_works.ClientWorkProcess(empty_init_fn, next_fn)
Example #16
0
    def test_constructs_with_struct_of_client_data_parameter(self):
        """A nested struct of sequences is accepted as the client data type."""
        float_sequence = computation_types.SequenceType(tf.float32)
        nested_data_type = computation_types.at_clients(
            (float_sequence, (float_sequence, float_sequence)))

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, nested_data_type)
        def next_fn(state, unused_weights, unused_data):
            empty_result = intrinsics.federated_value(
                client_works.ClientResult((), ()), placements.CLIENTS)
            return MeasuredProcessOutput(state, empty_result, server_zero())

        try:
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
        except client_works.ClientDataTypeError:
            self.fail('Could not construct a valid ClientWorkProcess.')
Example #17
0
    def test_construction_with_empty_state_does_not_raise(self):
        """Construction succeeds when the SERVER state is an empty struct."""
        initialize_fn = federated_computation.federated_computation()(
            lambda: intrinsics.federated_value((), placements.SERVER))

        @federated_computation.federated_computation(
            initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            return MeasuredProcessOutput(
                state, test_client_result(weights, data),
                intrinsics.federated_value(1, placements.SERVER))

        try:
            client_works.ClientWorkProcess(initialize_fn, next_fn)
        except Exception as e:  # pylint: disable=broad-except
            # Catch `Exception` rather than using a bare `except:` so that
            # KeyboardInterrupt/SystemExit still propagate. Include the
            # underlying error so the failure is diagnosable.
            self.fail('Could not construct a ClientWorkProcess with empty '
                      f'state: {e!r}')
Example #18
0
    def test_non_clients_placed_next_data_param_raises(self):
        """SERVER-placed `data` is rejected even if broadcast inside next_fn."""
        server_sequence_float_type = computation_types.at_server(
            computation_types.SequenceType(tf.float32))

        @federated_computation.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, server_sequence_float_type)
        def next_fn(state, weights, data):
            data_at_clients = intrinsics.federated_broadcast(data)
            client_result = test_client_result(weights, data_at_clients)
            return MeasuredProcessOutput(state, client_result, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
Example #19
0
    def test_constructs_with_non_model_weights_parameter(self):
        """A plain weights struct does not trigger `ClientDataTypeError`."""
        weights_struct = collections.OrderedDict(trainable=tf.float32,
                                                 non_trainable=())
        non_model_weights_type = computation_types.at_clients(
            computation_types.to_type(weights_struct))

        @federated_computation.federated_computation(
            SERVER_INT, non_model_weights_type, CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            return MeasuredProcessOutput(state, client_result, server_zero())

        try:
            client_works.ClientWorkProcess(test_initialize_fn, next_fn)
        except client_works.ClientDataTypeError:
            self.fail('Could not construct a valid ClientWorkProcess.')
Example #20
0
def _build_kmeans_client_work(centroids_type: computation_types.TensorType,
                              data_type: computation_types.SequenceType):
    """Creates a `tff.learning.templates.ClientWorkProcess` for k-means.

    The process has an empty SERVER-placed state that `next` passes through
    unchanged. Each client runs `_compute_kmeans_step` on its (CLIENTS-placed)
    centroids and local data. The first output of that step is the client
    result; the second is summed across clients to form the measurements.

    Args:
      centroids_type: The TFF type of the cluster centroids.
      data_type: The TFF sequence type of each client's dataset.

    Returns:
      A `ClientWorkProcess`.
    """

    @federated_computation.federated_computation
    def init_fn():
        # k-means client work carries no state between rounds.
        return intrinsics.federated_value((), placements.SERVER)

    @tensorflow_computation.tf_computation(centroids_type, data_type)
    def client_update(centroids, client_data):
        return _compute_kmeans_step(centroids, client_data)

    clients_centroids_type = computation_types.at_clients(centroids_type)
    clients_data_type = computation_types.at_clients(data_type)

    @federated_computation.federated_computation(
        init_fn.type_signature.result, clients_centroids_type,
        clients_data_type)
    def next_fn(state, cluster_centers, client_data):
        client_result, stat_output = intrinsics.federated_map(
            client_update, (cluster_centers, client_data))
        summed_stats = intrinsics.federated_sum(stat_output)
        return measured_process.MeasuredProcessOutput(
            state, client_result, summed_stats)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #21
0
def _build_fed_sgd_client_work(
    model_fn: Callable[[], model_lib.Model],
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation],
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `tff.learning.templates.ClientWorkProcess` for federated SGD.

    The returned process has an empty (`()`) SERVER-placed state, which `next`
    passes through unchanged. In each call, every client runs the update built
    by `_build_client_update` on its CLIENTS-placed model weights and dataset.
    The per-client model outputs are aggregated with the computation returned
    by `metrics_aggregator` and reported under the `'train'` key of the
    measurements.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It is
        currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `tff.learning.templates.ClientWorkProcess`.
    """
    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global context
        # with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    # Outside the graph, `model` is only read for type information (input spec
    # and weights type); training uses a fresh model built below.
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)

    @federated_computation.federated_computation
    def init_fn():
        # Empty server state: this client work keeps nothing between rounds.
        return intrinsics.federated_value((), placements.SERVER)

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
        # A new model is built inside this computation, per the `model_fn`
        # contract above (no pre-constructed model is reused).
        client_update = _build_client_update(model_fn(),
                                             use_experimental_simulation_loop)
        return client_update(initial_model_weights, dataset)

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, model_weights, client_data):
        # Each client produces a (client result, model outputs) pair.
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (model_weights, client_data))
        train_metrics = metrics_aggregation_fn(model_outputs)
        # Pack the aggregated metrics under a 'train' key as one federated value.
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #22
0
 def test_construction_does_not_raise(self):
     """Construction from `test_initialize_fn` and `test_next_fn` succeeds."""
     try:
         client_works.ClientWorkProcess(test_initialize_fn, test_next_fn)
     except Exception as e:  # pylint: disable=broad-except
         # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
         # not swallowed; the underlying error is included for debugging.
         self.fail(f'Could not construct a valid ClientWorkProcess: {e!r}')
def build_scheduled_client_work(
    model_fn: Callable[[], model_lib.Model],
    learning_rate_fn: Callable[[int], float],
    optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation],
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for federated averaging.

  This `ClientWorkProcess` creates a state containing the current round number
  (starting at 0), which is incremented at each call to
  `ClientWorkProcess.next`. This integer round number is broadcast to clients
  and used to call `optimizer_fn(learning_rate_fn(round_num))`, in order to
  construct the proper optimizer for that round.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    learning_rate_fn: A callable accepting an integer round number and returning
      a float to be used as a learning rate for the optimizer. That is, the
      client work will call `optimizer_fn(learning_rate_fn(round_num))` where
      `round_num` is the integer round number.
    optimizer_fn: A callable accepting a float learning rate, and returning a
      `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    whimsy_model = model_fn()
    # Built only so the isinstance check below can tell which optimizer API
    # (TFF or Keras) `optimizer_fn` returns; 1.0 is an arbitrary learning rate.
    whimsy_optimizer = optimizer_fn(1.0)
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        whimsy_model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(
        whimsy_model.metric_finalizers(), unfinalized_metrics_type)
  data_type = computation_types.SequenceType(whimsy_model.input_spec)
  weights_type = model_utils.weights_type_from_model(whimsy_model)

  # Choose the client update builder that matches the optimizer flavor.
  if isinstance(whimsy_optimizer, optimizer_base.Optimizer):
    build_client_update_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer
  else:
    build_client_update_fn = model_delta_client_work.build_model_delta_update_with_keras_optimizer

  @tensorflow_computation.tf_computation(weights_type, data_type, tf.int32)
  def client_update_computation(initial_model_weights, dataset, round_num):
    # The optimizer is rebuilt from the round number on every invocation, so
    # the learning rate follows `learning_rate_fn`'s schedule.
    learning_rate = learning_rate_fn(round_num)
    optimizer = optimizer_fn(learning_rate)
    client_update = build_client_update_fn(
        model_fn=model_fn,
        weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    return client_update(optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    # The state is the round counter, starting at 0 on the server.
    return intrinsics.federated_value(0, placements.SERVER)

  @tensorflow_computation.tf_computation(tf.int32)
  @tf.function
  def add_one(x):
    return x + 1

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    # Every client needs the current round number to build its optimizer.
    round_num_at_clients = intrinsics.federated_broadcast(state)
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data, round_num_at_clients))
    # Advance the round counter for the next call.
    updated_state = intrinsics.federated_map(add_one, state)
    train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=train_metrics))
    return measured_process.MeasuredProcessOutput(updated_state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
Example #24
0
def _build_mime_lite_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for Mime Lite.

    The SERVER-placed state is a pair `(optimizer_state, aggregator_state)`.
    In each call, the optimizer state is broadcast to clients, and each client
    runs the update built by `_build_client_update_fn_for_mime_lite`. The
    clients' full gradients, weighted by `client_result.update_weight`, are
    aggregated and used to advance the server optimizer state.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
        both creating and updating a global optimizer state, as well as
        optimization at clients given the global state, which is fixed during
        the optimization.
      client_weighting: A member of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method.
      full_gradient_aggregator: An optional
        `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
        gradients on client datasets. If `None`, this is set to
        `tff.aggregators.MeanFactory`.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized
        metrics. If `None`, this is set to
        `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It is
        currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `ClientWorkProcess`.

    Raises:
      TypeError: (presumably, via `py_typecheck`) if an argument has the wrong
        type -- confirm against `py_typecheck`.
    """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    # Fill in defaults for the optional aggregators.
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)
    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global context
        # with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)
    weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(
        weights_type)

    # NOTE: this rebinds the name. From here on, `full_gradient_aggregator` is
    # the aggregation process created by the factory, not the factory itself.
    full_gradient_aggregator = full_gradient_aggregator.create(
        weights_type.trainable, computation_types.TensorType(tf.float32))

    @federated_computation.federated_computation
    def init_fn():
        # Server state: (optimizer state for the trainable weights,
        # full-gradient aggregator state).
        specs = weight_tensor_specs.trainable
        optimizer_state = intrinsics.federated_eval(
            tensorflow_computation.tf_computation(
                lambda: optimizer.initialize(specs)), placements.SERVER)
        aggregator_state = full_gradient_aggregator.initialize()
        return intrinsics.federated_zip((optimizer_state, aggregator_state))

    client_update_fn = _build_client_update_fn_for_mime_lite(
        model_fn, optimizer, client_weighting,
        use_experimental_simulation_loop)

    @tensorflow_computation.tf_computation(
        init_fn.type_signature.result.member[0], weights_type.trainable)
    def update_optimizer_state(state, aggregate_gradient):
        # Only the optimizer state is needed here: `optimizer.next` is applied
        # to all-zero placeholder weights, and the weights it returns are
        # discarded.
        whimsy_weights = tf.nest.map_structure(
            lambda g: tf.zeros(g.shape, g.dtype), aggregate_gradient)
        updated_state, _ = optimizer.next(state, whimsy_weights,
                                          aggregate_gradient)
        return updated_state

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        optimizer_state, aggregator_state = state
        # Clients use the (fixed) global optimizer state during local training.
        optimizer_state_at_clients = intrinsics.federated_broadcast(
            optimizer_state)
        client_result, model_outputs, full_gradient = (
            intrinsics.federated_map(
                client_update_fn,
                (optimizer_state_at_clients, weights, client_data)))
        # Aggregate full gradients, weighted by each client's update weight.
        full_gradient_agg_output = full_gradient_aggregator.next(
            aggregator_state, full_gradient, client_result.update_weight)
        updated_optimizer_state = intrinsics.federated_map(
            update_optimizer_state,
            (optimizer_state, full_gradient_agg_output.result))

        new_state = intrinsics.federated_zip(
            (updated_optimizer_state, full_gradient_agg_output.state))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(new_state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #25
0
def build_model_delta_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    *,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for federated averaging.

    This client work is constructed in slightly different manners depending on
    whether `optimizer` is a `tff.learning.optimizers.Optimizer`, or a no-arg
    callable returning a `tf.keras.optimizers.Optimizer`.

    If it is a `tff.learning.optimizers.Optimizer`, we avoid creating
    `tf.Variable`s associated with the optimizer state within the scope of the
    client work, as they are not necessary. This also means that the client's
    model weights are updated by computing `optimizer.next` and then assigning
    the result to the model weights (while a `tf.keras.optimizers.Optimizer`
    will modify the model weight in place using `optimizer.apply_gradients`).

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      optimizer: A `tff.learning.optimizers.Optimizer`, or a no-arg callable
        that returns a `tf.keras.Optimizer`.
      client_weighting: A `tff.learning.ClientWeighting` value.
      delta_l2_regularizer: A nonnegative float representing the parameter of
        the L2-regularization term applied to the delta from initial model
        weights during training. Values larger than 0.0 prevent clients from
        moving too far from the server model during local training.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It is
        currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `ClientWorkProcess`. Its state is an empty server-placed tuple that
      `next_fn` passes through unchanged. Its measurements hold the aggregated
      training metrics under the `train` key.

    Raises:
      ValueError: If `delta_l2_regularizer` is negative or NaN.
      TypeError: If `optimizer` is neither a `tff.learning.optimizers.Optimizer`
        nor a callable. The `py_typecheck` calls additionally reject a
        non-callable `model_fn` and wrongly-typed `client_weighting` /
        `delta_l2_regularizer`.
    """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    py_typecheck.check_type(delta_l2_regularizer, float)
    # Written as `not >=` so that NaN (for which every comparison is False)
    # is rejected along with negative values.
    if not delta_l2_regularizer >= 0.0:
        raise ValueError('Provided delta_l2_regularizer must be non-negative, '
                         f'but found: {delta_l2_regularizer}')
    if not (isinstance(optimizer, optimizer_base.Optimizer)
            or callable(optimizer)):
        raise TypeError(
            'Provided optimizer must be either a '
            'tff.learning.optimizers.Optimizer or a no-arg callable returning '
            'a tf.keras.optimizers.Optimizer.')

    if metrics_aggregator is None:
        metrics_aggregator = aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global
        # context with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)

    if isinstance(optimizer, optimizer_base.Optimizer):

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            # The functional TFF optimizer is passed as-is; no optimizer
            # variables are created inside the client computation.
            client_update = build_model_delta_update_with_tff_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=(
                    use_experimental_simulation_loop))
            return client_update(optimizer, initial_model_weights, dataset)

    else:

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            # The Keras optimizer is constructed inside the traced computation
            # so that its variables belong to this graph.
            keras_optimizer = optimizer()
            client_update = build_model_delta_update_with_keras_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=(
                    use_experimental_simulation_loop))
            return client_update(keras_optimizer, initial_model_weights,
                                 dataset)

    @federated_computation.federated_computation
    def init_fn():
        # Empty tuple means "no state" / stateless.
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (weights, client_data))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        # The (empty) state is returned unchanged.
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #26
0
def build_functional_model_delta_client_work(
    *,
    model: functional.FunctionalModel,
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for federated averaging.

    This differs from `tff.learning.templates.build_model_delta_client_work` in
    that it only accepts `tff.learning.models.FunctionalModel` and
    `tff.learning.optimizers.Optimizer` type arguments, resulting in TensorFlow
    graphs that do not contain `tf.Variable` operations.

    Args:
      model: A `tff.learning.models.FunctionalModel` to train.
      optimizer: A `tff.learning.optimizers.Optimizer` to use for local,
        on-client optimization.
      client_weighting: A `tff.learning.ClientWeighting` value.
      delta_l2_regularizer: A nonnegative float representing the parameter of
        the L2-regularization term applied to the delta from initial model
        weights during training. Values larger than 0.0 prevent clients from
        moving too far from the server model during local training.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
        NOTE: metrics are not yet implemented for this builder (see the TODO
        below), so this argument is currently resolved but otherwise unused.

    Returns:
      A `ClientWorkProcess`. Its state is an empty server-placed tuple that
      `next_fn` passes through unchanged. Its measurements are currently always
      an empty server-placed tuple.

    Raises:
      ValueError: If `delta_l2_regularizer` is negative or NaN.
      The `py_typecheck` calls additionally reject wrongly-typed `model`,
      `optimizer`, `client_weighting` and `delta_l2_regularizer` arguments.
    """
    py_typecheck.check_type(model, functional.FunctionalModel)
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    py_typecheck.check_type(delta_l2_regularizer, float)
    # Written as `not >=` so that NaN (for which every comparison is False)
    # is rejected along with negative values.
    if not delta_l2_regularizer >= 0.0:
        raise ValueError('Provided delta_l2_regularizer must be non-negative, '
                         f'but found: {delta_l2_regularizer}')

    if metrics_aggregator is None:
        metrics_aggregator = aggregator.sum_then_finalize

    # TODO(b/229612282): Add metrics implementation.

    data_type = computation_types.SequenceType(model.input_spec)

    def ndarray_to_tensorspec(ndarray):
        # Describe an initial-weight array by its shape and dtype only.
        return tf.TensorSpec(shape=ndarray.shape,
                             dtype=tf.dtypes.as_dtype(ndarray.dtype))

    # Wrap in a `ModelWeights` structure that is required by the `finalizer.`
    # `model.initial_weights` is indexed as (trainable, non_trainable).
    weights_type = model_utils.ModelWeights(
        tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[0]),
        tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[1]))

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
        # Switch to the tuple expected by FunctionalModel.
        initial_model_weights = (initial_model_weights.trainable,
                                 initial_model_weights.non_trainable)
        client_update = build_functional_model_delta_update(
            model=model,
            weighting=client_weighting,
            delta_l2_regularizer=delta_l2_regularizer)
        return client_update(optimizer, initial_model_weights, dataset)

    @federated_computation.federated_computation
    def init_fn():
        # Empty tuple means "no state" / stateless.
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        computation_types.at_server(()),
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (weights, client_data))
        # TODO(b/229612282): Add metrics computations
        del model_outputs
        # Until metrics are implemented, report empty measurements.
        measurements = intrinsics.federated_value((), placements.SERVER)
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
Example #27
0
 def test_init_not_tff_computation_raises(self):
     """A plain Python callable as `initialize_fn` raises a TypeError."""

     def plain_python_initialize_fn():
         return 0

     error_pattern = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, error_pattern):
         client_works.ClientWorkProcess(
             initialize_fn=plain_python_initialize_fn, next_fn=test_next_fn)
Example #28
0
 def test_init_param_not_empty_raises(self):
     """An `initialize_fn` that accepts a parameter is rejected."""

     @federated_computation.federated_computation(SERVER_INT)
     def one_arg_initialize_fn(x):
         return x

     with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
         client_works.ClientWorkProcess(one_arg_initialize_fn, test_next_fn)
Example #29
0
 def test_init_state_not_assignable(self):
     """A server-placed float initial state raises StateNotAssignable."""

     @federated_computation.federated_computation()
     def float_initialize_fn():
         return intrinsics.federated_value(0.0, placements.SERVER)

     expected_error = errors.TemplateStateNotAssignableError
     with self.assertRaises(expected_error):
         client_works.ClientWorkProcess(float_initialize_fn, test_next_fn)