def __init__(self,
                 model: model_lib.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
                 use_experimental_simulation_loop: bool = False):
        """Builds the per-client Federated Averaging computation.

        All variables the client computation needs (for example, the model's
        variables) must be created here, in the constructor, and never inside
        `__call__`.

        Args:
          model: A `tff.learning.Model` instance.
          optimizer: A `tf.keras.Optimizer` instance.
          client_weight_fn: Optional callable that takes the output of
            `model.report_local_outputs` and returns a tensor giving this
            client's weight in the federated average of model deltas. When
            omitted, the default is the total number of examples processed on
            device.
          use_experimental_simulation_loop: Selects the reduce-loop function
            used over the input dataset; an experimental loop is used for
            simulation.

        Raises:
          TypeError: If `model` is not a `model_lib.Model` or
            `client_weight_fn` is given but not callable.
        """
        py_typecheck.check_type(model, model_lib.Model)
        enhanced_model = model_utils.enhance(model)
        self._model = enhanced_model
        self._optimizer = optimizer
        py_typecheck.check_type(enhanced_model, model_utils.EnhancedModel)

        # `None` means "use the default weighting"; anything else must be a
        # callable and is stored as-is.
        if client_weight_fn is not None:
            py_typecheck.check_callable(client_weight_fn)
        self._client_weight_fn = client_weight_fn

        self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
Exemple #2
0
  def __init__(self,
               model: model_lib.Model,
               batch_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
               use_experimental_simulation_loop: bool = False):
    """Builds the per-client Federated SGD computation.

    All variables the client computation needs (for example, the model's
    variables) must be created here, in the constructor, and never inside
    `__call__`.

    Args:
      model: A `learning.Model` for which gradients are computed.
      batch_weight_fn: Optional callable that takes a batch (as passed to
        `forward_pass`) and returns a float32 weight. When omitted, the
        default uses the batch size, measured by the batch dimension of the
        predictions returned by `forward_pass`.
      use_experimental_simulation_loop: Selects the reduce-loop function used
        over the input dataset; an experimental loop is used for simulation.
    """
    has_custom_weighting = batch_weight_fn is not None
    if has_custom_weighting:
      py_typecheck.check_callable(batch_weight_fn)
    self._batch_weight_fn = batch_weight_fn

    enhanced = model_utils.enhance(model)
    self._model = enhanced
    py_typecheck.check_type(enhanced, model_utils.EnhancedModel)

    reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)
    self._dataset_reduce_fn = reduce_fn
Exemple #3
0
  def test_enhanced_var_lists(self):
    """Checks that an enhanced model type-checks values from a bad model.

    `BadModel` returns a list of strings, an int, and plain strings from its
    hooks. Accessing each hook through the `model_utils.enhance` wrapper is
    expected to raise `TypeError` with a message naming the expected type.
    """

    class BadModel(model_examples.TrainableLinearRegression):

      @property
      def trainable_variables(self):
        return ['not_a_variable']

      @property
      def local_variables(self):
        return 1

      def forward_pass(self, batch, training=True):
        return 'Not BatchOutput'

      def train_on_batch(self, batch):
        return 'Not BatchOutput'

    bad_model = model_utils.enhance(BadModel())
    # `assertRaisesRegexp` is a deprecated alias that was removed from
    # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
    self.assertRaisesRegex(TypeError,
                           'Variable', lambda: bad_model.trainable_variables)
    self.assertRaisesRegex(TypeError,
                           'Iterable', lambda: bad_model.local_variables)
    self.assertRaisesRegex(TypeError,
                           'BatchOutput', lambda: bad_model.forward_pass(1))
    self.assertRaisesRegex(TypeError,
                           'BatchOutput', lambda: bad_model.train_on_batch(1))
Exemple #4
0
    def __init__(self,
                 model: model_lib.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 client_weight_fn: Optional[Callable[[Any],
                                                     tf.Tensor]] = None):
        """Builds the per-client Federated Averaging computation.

        Args:
          model: A `tff.learning.Model` instance.
          optimizer: A `tf.keras.Optimizer` instance.
          client_weight_fn: Optional callable that takes the output of
            `model.report_local_outputs` and returns a tensor giving this
            client's weight in the federated average of model deltas. When
            omitted, the default is the total number of examples processed on
            device.

        Raises:
          TypeError: If `model` is not a `model_lib.Model` or
            `client_weight_fn` is given but not callable.
        """
        py_typecheck.check_type(model, model_lib.Model)
        wrapped = model_utils.enhance(model)
        self._model = wrapped
        self._optimizer = optimizer
        py_typecheck.check_type(wrapped, model_utils.EnhancedModel)

        if client_weight_fn is None:
            # No custom weighting requested; callers fall back to the default.
            self._client_weight_fn = None
            return
        py_typecheck.check_callable(client_weight_fn)
        self._client_weight_fn = client_weight_fn
    def __init__(self, model, batch_weight_fn=None):
        """Builds the per-client Federated SGD computation.

        Args:
          model: A `learning.Model` for which gradients are computed. It must
            not be a `model_lib.TrainableModel` (see Raises).
          batch_weight_fn: Optional callable that takes a batch (as passed to
            `forward_pass`) and returns a float32 weight. When omitted, the
            default uses the batch size, measured by the batch dimension of
            the predictions returned by `forward_pass`.

        Raises:
          ValueError: If the enhanced model is a `model_lib.TrainableModel`,
            whose built-in local training would be ignored here.
        """
        if batch_weight_fn is not None:
            py_typecheck.check_callable(batch_weight_fn)
        self._batch_weight_fn = batch_weight_fn

        self._model = model_utils.enhance(model)
        # Drop the raw reference so only the enhanced wrapper is used below.
        del model
        py_typecheck.check_type(self._model, model_utils.EnhancedModel)
        if isinstance(self._model, model_lib.TrainableModel):
            raise ValueError(
                'Do not pass a TrainableModel to ClientSgd, as the '
                'built-in local training algorithm would be ignored. '
                'This failure could be made into a warning if this is inconvenient.'
            )

        def _make_grad_accumulator(path, tensor):
            # One zero-initialized variable per trainable weight, named from
            # the structure path that `nest.map_structure_with_paths` passes
            # in; presumably used to accumulate gradient sums.
            return tf.Variable(lambda: tf.zeros_like(tensor),
                               name='{}_grad'.format(path))

        trainable_weights = self._model.weights.trainable
        self._grad_sum_vars = nest.map_structure_with_paths(
            _make_grad_accumulator, trainable_weights)
        self._batch_weight_sum = tf.Variable(0.0, name='batch_weight_sum')
Exemple #6
0
    def __init__(
            self,
            model: model_lib.Model,
            client_weighting: client_weight_lib.ClientWeightType = (
                client_weight_lib.ClientWeighting.NUM_EXAMPLES),
            use_experimental_simulation_loop: bool = False):
        """Builds the per-client Federated SGD computation.

        All variables the client computation needs (for example, the model's
        variables) must be created here, in the constructor, and never inside
        `__call__`.

        Args:
          model: A `learning.Model` for which gradients are computed.
          client_weighting: A `tff.learning.ClientWeighting` value naming a
            built-in weighting method, or a callable that takes the output of
            `model.report_local_outputs` and returns a tensor giving the
            weight in the federated average of model deltas.
          use_experimental_simulation_loop: Selects the reduce-loop function
            used over the input dataset; an experimental loop is used for
            simulation.
        """
        client_weight_lib.check_is_client_weighting_or_callable(
            client_weighting)
        self._client_weighting = client_weighting

        enhanced = model_utils.enhance(model)
        self._model = enhanced
        py_typecheck.check_type(enhanced, model_utils.EnhancedModel)

        reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        self._dataset_reduce_fn = reduce_fn
Exemple #7
0
def build_encoded_sum_from_model(model_fn, encoder_fn):
    """Builds a `StatefulAggregateFn` over the trainable weights of a model.

    `model_fn` is called once, and the model is wrapped with
    `model_utils.enhance`. `encoder_fn` is then applied to every trainable
    weight, and the weights and encoders are handed to
    `tff.utils.build_encoded_sum`.

    Args:
      model_fn: A no-argument Python callable that returns a
        `tff.learning.Model`.
      encoder_fn: A single-argument Python callable. Its argument is expected
        to be a `tf.Tensor` with the shape and dtype to be encoded. It must
        return a `tensor_encoding.core.SimpleEncoder` whose `encode` method
        expects a `tf.Tensor` of compatible type.

    Returns:
      A `StatefulAggregateFn` for encoding and summing the weights of the
      model created by `model_fn`.

    Raises:
      TypeError: If `model_fn` or `encoder_fn` are not callable objects.
    """
    for fn in (model_fn, encoder_fn):
        py_typecheck.check_callable(fn)
    trainable_weights = model_utils.enhance(model_fn()).weights.trainable
    weight_encoders = tf.nest.map_structure(encoder_fn, trainable_weights)
    return tff.utils.build_encoded_sum(trainable_weights, weight_encoders)
def _compute_p13n_metrics(model_fn, initial_model_weights, train_data,
                          test_data, personalize_fn_dict, context):
    """Trains and evaluates one personalized model per strategy.

    Args:
      model_fn: No-argument callable returning a model to enhance.
      initial_model_weights: Weights assigned to the model before each
        personalization strategy runs.
      train_data: Passed through to every personalize function.
      test_data: Passed through to every personalize function.
      personalize_fn_dict: Mapping from `str` names to no-argument builders.
        Each builder returns a personalize function called as
        `fn(model, train_data, test_data, context)`.
      context: Passed through to every personalize function.

    Returns:
      An `OrderedDict` from strategy name to that strategy's result, produced
      by a `tf.function`.
    """
    model = model_utils.enhance(model_fn())

    # Build each personalize function (and any `tf.Variable`s it owns) here,
    # outside `loop_and_compute`, so the variables land in graphs TFF controls.
    # This is why `personalize_fn_dict` holds builders rather than already
    # built `tf.function`s: `tf.function` usually refuses to create new
    # variables during tracing.
    personalize_fns = collections.OrderedDict()
    for strategy_name, builder in personalize_fn_dict.items():
        py_typecheck.check_type(strategy_name, str)
        py_typecheck.check_callable(builder)
        personalize_fns[strategy_name] = builder()

    @tf.function
    def loop_and_compute():
        metrics_by_name = collections.OrderedDict()
        for strategy_name, personalize_fn in personalize_fns.items():
            # Every strategy starts from the same initial weights.
            tff.utils.assign(model.weights, initial_model_weights)
            py_typecheck.check_callable(personalize_fn)
            metrics_by_name[strategy_name] = personalize_fn(
                model, train_data, test_data, context)
        return metrics_by_name

    return loop_and_compute()
Exemple #9
0
def from_compiled_keras_model(keras_model, dummy_batch):
    """Builds a `tff.learning.Model` from a compiled `tf.keras.Model`.

    Args:
      keras_model: A `tf.keras.Model` object that was compiled.
      dummy_batch: A nested structure of values convertible to *batched*
        tensors with the same shapes and types that `forward_pass()` expects.
        The tensor values are not important and can be filled with any
        reasonable input value.

    Returns:
      A `tff.learning.Model`, wrapped with `model_utils.enhance`.

    Raises:
      TypeError: If `keras_model` is not an instance of `tf.keras.Model`.
      ValueError: If `keras_model` was *not* compiled.
    """
    # `collections.Mapping` was a deprecated alias and was removed in Python
    # 3.10; the abstract base class lives in `collections.abc`.
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    py_typecheck.check_type(keras_model, tf.keras.Model)
    # Optimizer attribute is only set after calling tf.keras.Model.compile().
    if not keras_model.optimizer:
        raise ValueError(
            '`keras_model` must be compiled. Use from_keras_model() '
            'instead.')
    dummy_tensors = _preprocess_dummy_batch(dummy_batch)
    # NOTE: A sub-classed tf.keras.Model does not produce the compiled metrics
    # until the model has been called on input. The work-around is to call
    # Model.test_on_batch() once before asking for metrics.
    if isinstance(dummy_tensors, collections_abc.Mapping):
        keras_model.test_on_batch(**dummy_tensors)
    else:
        keras_model.test_on_batch(*dummy_tensors)
    return model_utils.enhance(_TrainableKerasModel(keras_model,
                                                    dummy_tensors))
Exemple #10
0
def build_encoded_mean_from_model(model_fn, encoder_fn):
    """Builds a `StatefulAggregateFn` averaging a model's trainable weights.

    The model from `model_fn` is built inside a fresh, throwaway `tf.Graph`.
    `encoder_fn` is applied to each of its trainable weights, and the result
    is handed to `tff.utils.build_encoded_mean`.

    Args:
      model_fn: A no-argument Python callable that returns a
        `tff.learning.Model`.
      encoder_fn: A single-argument Python callable. Its argument is expected
        to be a `tf.Tensor` with the shape and dtype to be encoded. It must
        return a `tensor_encoding.core.SimpleEncoder` whose `encode` method
        expects a `tf.Tensor` of compatible type.

    Returns:
      A `StatefulAggregateFn` for encoding and averaging the weights of the
      model created by `model_fn`.

    Raises:
      TypeError: If `model_fn` or `encoder_fn` are not callable objects.
    """
    for fn in (model_fn, encoder_fn):
        py_typecheck.check_callable(fn)
    # TODO(b/144382142): Keras name uniquification is probably the main reason we
    # still need this.
    throwaway_graph = tf.Graph()
    with throwaway_graph.as_default():
        trainable_weights = model_utils.enhance(model_fn()).weights.trainable
    weight_encoders = tf.nest.map_structure(encoder_fn, trainable_weights)
    return tff.utils.build_encoded_mean(trainable_weights, weight_encoders)
Exemple #11
0
    def client_eval(model_weights, dataset):
        """Returns local outputs after evaluating `model_weights` on `dataset`.

        A fresh model is built from `model_fn` (taken from the enclosing scope,
        not visible here). `model_weights` is assigned to its trainable and
        non-trainable variables. `forward_pass(training=False)` is then run on
        every batch of `dataset`.

        Args:
          model_weights: Weights assigned, via `tff.utils.assign`, to a
            structure with `trainable` and `non_trainable` entries; presumably
            it must match that structure -- TODO confirm.
          dataset: Consumed with `.reduce`; presumably a `tf.data.Dataset` of
            batches accepted by `model.forward_pass` -- TODO confirm.

        Returns:
          An `OrderedDict` with `'local_outputs'` (from
          `model.report_local_outputs()`) and, under
          `'workaround for b/121400757'`, the float64 sum of per-batch losses.
        """
        model = model_utils.enhance(model_fn())

        # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
        # NOTE(review): `tf.contrib` does not exist in TensorFlow 2.x; this
        # fragment only runs on TF 1.x.
        @tf.contrib.eager.function(autograph=False)
        def reduce_fn(dummy, batch):
            model_output = model.forward_pass(batch, training=False)
            return dummy + tf.cast(model_output.loss, tf.float64)

        # TODO(b/123092620): Avoid the need for this manual conversion.
        # Repackage the model's variables into the same trainable /
        # non_trainable structure so they can be assigned from `model_weights`.
        model_vars = anonymous_tuple.from_container(collections.OrderedDict([
            ('trainable', model.weights.trainable),
            ('non_trainable', model.weights.non_trainable)
        ]),
                                                    recursive=True)

        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        with tf.control_dependencies(
            [tff.utils.assign(model_vars, model_weights)]):
            dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64),
                                   reduce_fn)

        # Read local outputs only after the reduction over `dataset` has run.
        with tf.control_dependencies([dummy]):
            return collections.OrderedDict([
                ('local_outputs', model.report_local_outputs()),
                ('workaround for b/121400757', dummy)
            ])
    def test_federated_evaluation_dataset_reduce(self, simulation,
                                                 mock_method):
        """Checks which dataset-reduce path federated evaluation takes.

        When `simulation` is True, `mock_method` must never be called;
        otherwise it must be called at least once.
        """
        evaluate_comp = federated_evaluation.build_federated_evaluation(
            _model_fn_from_keras, use_experimental_simulation_loop=simulation)
        fresh_weights = model_utils.enhance(_model_fn_from_keras()).weights
        initial_weights = tf.nest.map_structure(lambda var: var.read_value(),
                                                fresh_weights)

        def _input_dict(temps):
            def _column():
                return np.reshape(np.array(temps, dtype=np.float32), (-1, 1))

            return collections.OrderedDict([('x', _column()),
                                            ('y', _column())])

        client_datasets = [
            [_input_dict([1.0, 10.0, 2.0, 7.0]), _input_dict([6.0, 11.0])],
            [_input_dict([9.0, 12.0, 13.0])],
            [_input_dict([1.0]), _input_dict([22.0, 23.0])],
        ]
        evaluate_comp(initial_weights, client_datasets)

        if not simulation:
            mock_method.assert_called()
        else:
            mock_method.assert_not_called()
Exemple #13
0
def build_federated_evaluation(model_fn):
    """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
      It is called once here, inside a throwaway graph, to derive types, and
      again inside every invocation of the client computation.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters placed at `tff.SERVER` and sequences of batches placed
    at `tff.CLIENTS`, and returns the evaluation metrics as aggregated by
    `tff.learning.Model.federated_output_computation`.
  """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_utils.enhance(model_fn())
        # One TensorType per weight, using the variable's base dtype and shape.
        model_weights_type = tff.to_type(
            tf.nest.map_structure(
                lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
                model.weights))
        batch_type = tff.to_type(model.input_spec)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluating `incoming_model_weights`.

        Builds a fresh model, assigns `incoming_model_weights` to its weights,
        and runs `forward_pass(training=False)` over every batch of `dataset`.
        Returns `model.report_local_outputs()` together with the float64 sum
        of per-batch losses.
        """
        # Fresh model in this computation's graph; within this function it
        # shadows the throwaway `model` built above.
        model = model_utils.enhance(model_fn())

        # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
        @tf.function
        def reduce_fn(dummy, batch):
            model_output = model.forward_pass(batch, training=False)
            return dummy + tf.cast(model_output.loss, tf.float64)

        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        with tf.control_dependencies(
            [tff.utils.assign(model.weights, incoming_model_weights)]):
            dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64),
                                   reduce_fn)

        # Read local outputs only after the reduction over `dataset` has run.
        with tf.control_dependencies([dummy]):
            return collections.OrderedDict([
                ('local_outputs', model.report_local_outputs()),
                ('workaround for b/121400757', dummy)
            ])

    @tff.federated_computation(
        tff.FederatedType(model_weights_type, tff.SERVER),
        tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
    def server_eval(server_model_weights, federated_dataset):
        """Broadcasts weights, evaluates on each client, and aggregates."""
        client_outputs = tff.federated_map(
            client_eval,
            [tff.federated_broadcast(server_model_weights), federated_dataset])
        # Aggregation uses the throwaway `model` built in the graph above.
        return model.federated_output_computation(client_outputs.local_outputs)

    return server_eval
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters placed at `tff.SERVER` and sequences of batches placed
    at `tff.CLIENTS`, and returns the evaluation metrics as aggregated by
    `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `incoming_model_weights`.

    A fresh model is built here, outside the `tf.function` below, so that any
    variables it creates are not created during `tf.function` tracing.
    """

    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""

      tff.utils.assign(model.weights, incoming_model_weights)

      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      # The summed loss is discarded. The reduction runs `forward_pass` on
      # every batch, presumably so the model's metrics read by
      # `report_local_outputs` are updated -- TODO confirm.
      dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

      return collections.OrderedDict([('local_outputs',
                                       model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    """Broadcasts weights, evaluates on each client, and aggregates."""
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    # Aggregation uses the throwaway `model` built in the graph above.
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
Exemple #15
0
def from_keras_model(keras_model,
                     dummy_batch,
                     loss,
                     metrics=None,
                     optimizer=None):
    """Builds a `tff.learning.Model` from an uncompiled `tf.keras.Model`.

    Args:
      keras_model: A `tf.keras.Model` object that is not compiled.
      dummy_batch: A nested structure of values convertible to *batched*
        tensors with the same shapes and types as would be input to
        `keras_model`. The tensor values are not important and can be filled
        with any reasonable input value. Both mapping-style batches
        (unpacked as keyword arguments) and sequence-style batches (unpacked
        positionally) are accepted.
      loss: A callable that takes two batched tensor parameters, `y_true` and
        `y_pred`, and returns the loss.
      metrics: (Optional) a list of `tf.keras.metrics.Metric` objects.
      optimizer: (Optional) a `tf.keras.optimizer.Optimizer`. If None, the
        returned model cannot be used for training.

    Returns:
      A `tff.learning.Model` object, wrapped with `model_utils.enhance`.

    Raises:
      TypeError: If `keras_model` is not an instance of `tf.keras.Model`.
      ValueError: If `keras_model` was compiled.
    """
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    py_typecheck.check_type(keras_model, tf.keras.Model)
    if keras_model._is_compiled:  # pylint: disable=protected-access
        raise ValueError('`keras_model` must not be compiled. Use '
                         'from_compiled_keras_model() instead.')
    dummy_tensors = _preprocess_dummy_batch(dummy_batch)
    if optimizer is None:
        return model_utils.enhance(
            _KerasModel(keras_model, dummy_tensors, loss, metrics))
    keras_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    # NOTE: A sub-classed tf.keras.Model does not produce the compiled metrics
    # until the model has been called on input. The work-around is to call
    # Model.test_on_batch() once before asking for metrics.
    # A non-mapping batch (e.g. an `(x, y)` tuple) cannot be `**`-unpacked, so
    # it is passed positionally instead.
    if isinstance(dummy_tensors, collections_abc.Mapping):
        keras_model.test_on_batch(**dummy_tensors)
    else:
        keras_model.test_on_batch(*dummy_tensors)
    return model_utils.enhance(_TrainableKerasModel(keras_model,
                                                    dummy_tensors))
def _compute_baseline_metrics(model_fn, initial_model_weights, test_data,
                              baseline_evaluate_fn):
    """Evaluates a fresh model whose weights are `initial_model_weights`.

    Returns:
      Whatever `baseline_evaluate_fn(model, test_data)` returns, computed
      inside a `tf.function` after the weights are assigned.
    """
    baseline_model = model_utils.enhance(model_fn())

    @tf.function
    def assign_and_compute():
        # Load the starting weights, then hand the model to the evaluator.
        tff.utils.assign(baseline_model.weights, initial_model_weights)
        py_typecheck.check_callable(baseline_evaluate_fn)
        return baseline_evaluate_fn(baseline_model, test_data)

    return assign_and_compute()
    def __init__(self, model, client_weight_fn=None):
        """Builds the per-client Federated Averaging computation.

        Args:
          model: A `tff.learning.TrainableModel`.
          client_weight_fn: Accepted for signature compatibility and ignored;
            the stored weighting function is always None.

        Raises:
          TypeError: If the enhanced model is not an
            `model_utils.EnhancedTrainableModel`.
        """
        del client_weight_fn  # Intentionally unused.
        enhanced = model_utils.enhance(model)
        self._model = enhanced
        py_typecheck.check_type(enhanced, model_utils.EnhancedTrainableModel)
        self._client_weight_fn = None
Exemple #18
0
 def test_model_initializer(self):
   """Checks that `model_utils.model_initializer` initializes all variables.

   After the initializer op runs, reading the model's local variables and its
   weights must not raise `tf.errors.FailedPreconditionError`.
   """
   with tf.Graph().as_default() as graph:
     model = model_utils.enhance(model_examples.LinearRegression(2))
     init_op = model_utils.model_initializer(model)
     with self.session(graph=graph) as sess:
       sess.run(init_op)
       # Every variable should be readable once the initializer has run.
       try:
         for fetches in (model.local_variables, model.weights):
           sess.run(fetches)
       except tf.errors.FailedPreconditionError:
         self.fail('Expected variables to be initialized, but got '
                   'tf.errors.FailedPreconditionError')
Exemple #19
0
def server_update_model(
    server_state: ServerState,
    weights_delta: Collection[tf.Tensor],
    model_fn: _ModelConstructor,
    optimizer_fn: _OptimizerConstructor,
) -> ServerState:
    """Updates `server_state` based on `weights_delta`.

    Args:
      server_state: A `tff.learning.framework.ServerState` namedtuple, the
        state to be updated.
      weights_delta: An update to the trainable variables of the model.
      model_fn: A no-arg function that returns a `tff.learning.Model`. Passing
        in a function ensures any variables are created when
        server_update_model is called, so they can be captured in a specific
        graph or other context.
      optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As
        with model_fn, we pass in a function to control when variables are
        created.

    Returns:
      An updated `tff.learning.framework.ServerState`.

    Raises:
      TypeError: If `server_state` is not a `ServerState` or `weights_delta`
        is not a collection.
    """
    # `collections.Collection` was a deprecated alias and was removed in
    # Python 3.10; the abstract base class lives in `collections.abc`.
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    py_typecheck.check_type(server_state, ServerState)
    py_typecheck.check_type(weights_delta, collections_abc.Collection)
    model = model_utils.enhance(model_fn())
    optimizer = optimizer_fn()
    apply_delta_fn, optimizer_vars = _build_server_optimizer(model, optimizer)

    # We might have a NaN value e.g. if all of the clients processed
    # had no data, so the denominator in the federated_mean is zero.
    # If we see any NaNs, zero out the whole update.
    no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
        weights_delta)
    # TODO(b/124538167): We should increment a server counter to
    # track the fact a non-finite weights_delta was encountered.

    @tf.function
    def update_model_inner():
        """Loads the current server state, then applies the update."""
        tf.nest.map_structure(
            lambda a, b: a.assign(b), (model.weights, optimizer_vars),
            (server_state.model, server_state.optimizer_state))
        apply_delta_fn(no_nan_weights_delta)
        return model.weights, optimizer_vars

    model_weights, optimizer_vars = update_model_inner()
    # TODO(b/123092620): We must do this outside of the above tf.function, because
    # there could be an AnonymousTuple hiding in server_state,
    # and tf.function's can't return AnonymousTuples.
    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=optimizer_vars)
Exemple #20
0
def server_init(model_fn, optimizer_fn):
  """Creates the initial `tff.learning.framework.ServerState`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`.

  Returns:
    A `tff.learning.framework.ServerState` namedtuple, taken from the second
    element returned by `_create_optimizer_and_server_state`.
  """
  enhanced_model = model_utils.enhance(model_fn())
  server_optimizer = optimizer_fn()
  _, initial_state = _create_optimizer_and_server_state(enhanced_model,
                                                        server_optimizer)
  return initial_state
Exemple #21
0
    def __init__(self, model, batch_weight_fn=None):
        """Builds the per-client Federated SGD computation.

        Args:
          model: A `learning.Model` for which gradients are computed.
          batch_weight_fn: Optional callable that takes a batch (as passed to
            `forward_pass`) and returns a float32 weight. When omitted, the
            default uses the batch size, measured by the batch dimension of
            the predictions returned by `forward_pass`.
        """
        has_custom_weighting = batch_weight_fn is not None
        if has_custom_weighting:
            py_typecheck.check_callable(batch_weight_fn)
        self._batch_weight_fn = batch_weight_fn

        enhanced = model_utils.enhance(model)
        self._model = enhanced
        py_typecheck.check_type(enhanced, model_utils.EnhancedModel)
    def test_federated_evaluation_with_keras(self):
        """Evaluates a compiled identity Keras model; expects perfect metrics."""

        def model_fn():
            dense = tf.keras.layers.Dense(1,
                                          kernel_initializer='ones',
                                          bias_initializer='zeros',
                                          activation=None)
            keras_model = tf.keras.Sequential([dense], name='my_model')
            keras_model.compile(loss='mean_squared_error',
                                optimizer='sgd',
                                metrics=[tf.keras.metrics.Accuracy()])
            dummy_batch = {
                'x': np.zeros((1, 1), np.float32),
                'y': np.zeros((1, 1), np.float32),
            }
            return keras_utils.from_compiled_keras_model(
                keras_model, dummy_batch=dummy_batch)

        evaluate_comp = federated_evaluation.build_federated_evaluation(
            model_fn)
        fresh_weights = model_utils.enhance(model_fn()).weights
        initial_weights = tf.nest.map_structure(lambda var: var.read_value(),
                                                fresh_weights)

        def _input_dict(temps):
            def _column():
                return np.reshape(np.array(temps, dtype=np.float32), (-1, 1))

            return {'x': _column(), 'y': _column()}

        federated_data = [
            [_input_dict([1.0, 10.0, 2.0, 7.0]), _input_dict([6.0, 11.0])],
            [_input_dict([9.0, 12.0, 13.0])],
            [_input_dict([1.0]), _input_dict([22.0, 23.0])],
        ]
        result = evaluate_comp(initial_weights, federated_data)
        # The model is the identity function and every x equals its y, so
        # accuracy should be perfect and the loss zero.
        self.assertEqual(str(result), '<accuracy=1.0,loss=0.0>')
Exemple #23
0
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `incoming_model_weights`.

    A fresh model is built from `model_fn` (taken from the enclosing scope,
    not visible here). This happens outside the `tf.function` below, so that
    variable creation does not occur during tracing.

    Returns:
      An `OrderedDict` with the single key `'local_outputs'`, holding
      `model.report_local_outputs()`.
    """
    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tf_computation_utils.assign(model.weights, incoming_model_weights)

      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      # The summed loss is discarded. The reduction runs `forward_pass` on
      # every batch, presumably so the metrics read by `report_local_outputs`
      # are updated -- TODO confirm.
      dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

      return collections.OrderedDict([('local_outputs',
                                       model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)
    def __init__(self, model, client_weight_fn=None):
        """Builds the per-client Federated Averaging computation.

        Args:
          model: A `tff.learning.TrainableModel`.
          client_weight_fn: Optional callable that takes the output of
            `model.report_local_outputs` and returns a tensor giving this
            client's weight in the federated average of model deltas. When
            omitted, the default is the total number of examples processed on
            device.

        Raises:
          TypeError: If the enhanced model is not an
            `model_utils.EnhancedTrainableModel`, or `client_weight_fn` is
            given but not callable.
        """
        enhanced = model_utils.enhance(model)
        self._model = enhanced
        py_typecheck.check_type(enhanced, model_utils.EnhancedTrainableModel)

        if client_weight_fn is None:
            self._client_weight_fn = None
        else:
            py_typecheck.check_callable(client_weight_fn)
            self._client_weight_fn = client_weight_fn
    def test_federated_evaluation_with_keras(self):
        """Evaluates an identity Keras model built via `from_keras_model`."""

        def model_fn():
            keras_model = tf.keras.Sequential(
                [
                    tf.keras.layers.Input(shape=(1, )),
                    tf.keras.layers.Dense(1,
                                          kernel_initializer='ones',
                                          bias_initializer='zeros',
                                          activation=None),
                ],
                name='my_model')
            batch_spec = collections.OrderedDict(
                x=tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                y=tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
            )
            return keras_utils.from_keras_model(
                keras_model,
                input_spec=batch_spec,
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[tf.keras.metrics.Accuracy()])

        evaluate_comp = federated_evaluation.build_federated_evaluation(
            model_fn)
        fresh_weights = model_utils.enhance(model_fn()).weights
        initial_weights = tf.nest.map_structure(lambda var: var.read_value(),
                                                fresh_weights)

        def _input_dict(temps):
            def _column():
                return np.reshape(np.array(temps, dtype=np.float32), (-1, 1))

            return collections.OrderedDict([('x', _column()),
                                            ('y', _column())])

        federated_data = [
            [_input_dict([1.0, 10.0, 2.0, 7.0]), _input_dict([6.0, 11.0])],
            [_input_dict([9.0, 12.0, 13.0])],
            [_input_dict([1.0]), _input_dict([22.0, 23.0])],
        ]
        result = evaluate_comp(initial_weights, federated_data)
        # The model is the identity function and every x equals its y, so
        # accuracy should be perfect and the loss zero.
        expected = ('<accuracy=1.0,loss=0.0,keras_training_time_client_sum_sec=0.0>')
        self.assertEqual(str(result), expected)
Exemple #26
0
def server_init(model_fn, optimizer_fn, delta_aggregate_state,
                model_broadcast_state):
    """Builds the starting `tff.learning.framework.ServerState`.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`.
      delta_aggregate_state: The initial state of the delta_aggregator.
      model_broadcast_state: The initial state of the model_broadcaster.

    Returns:
      A `tff.learning.framework.ServerState` namedtuple whose `model` holds the
      enhanced model's weights and whose `optimizer_state` holds the variables
      produced by `_build_server_optimizer`. The two aggregator/broadcaster
      states are passed through unchanged.
    """
    enhanced_model = model_utils.enhance(model_fn())
    server_optimizer = optimizer_fn()
    # Only the optimizer variables are needed here; the first element of the
    # returned pair is discarded.
    _, optimizer_state = _build_server_optimizer(enhanced_model,
                                                 server_optimizer)
    return ServerState(
        model=enhanced_model.weights,
        optimizer_state=optimizer_state,
        delta_aggregate_state=delta_aggregate_state,
        model_broadcast_state=model_broadcast_state,
    )
Exemple #27
0
def broadcast_from_model_fn_encoder_fn(model_fn, encoder_fn):
    """Builds `StatefulBroadcastFn` for weights of model returned by `model_fn`.

    `model_fn` is called once here. The weights of the resulting (enhanced)
    model are passed, together with `encoder_fn`, to
    `broadcast_from_encoder_fn`.

    Args:
      model_fn: A Python callable with no arguments that returns a
        `tff.learning.Model`.
      encoder_fn: A Python callable with a single argument, which is expected to
        be a `tf.Tensor` of shape and dtype to be encoded.

    Returns:
      A `StatefulBroadcastFn` for encoding and broadcasting the weights of model
      created by `model_fn`.

    Raises:
      TypeError: If `model_fn` or `encoder_fn` are not callable objects.
    """
    # Validate both callables up front, as the Raises section promises; the
    # original only checked `encoder_fn`.
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(encoder_fn)
    weights = model_utils.enhance(model_fn()).weights
    return broadcast_from_encoder_fn(weights, encoder_fn)
Exemple #28
0
  def __init__(self, model, batch_weight_fn=None):
    """Constructs the client computation for Federated SGD.

    Args:
      model: A `learning.Model` for which gradients are computed.
      batch_weight_fn: Optional function that takes a batch (as passed to
        `forward_pass`) and returns a float32 weight. When omitted, the default
        uses the size of the batch, as measured by the batch dimension of the
        predictions returned by `forward_pass`.

    Raises:
      TypeError: If `batch_weight_fn` is supplied but is not callable.
      ValueError: If the enhanced `model` is a `model_lib.TrainableModel`.
    """
    # The weighting callback is optional; validate it only when supplied.
    if batch_weight_fn is not None:
      py_typecheck.check_callable(batch_weight_fn)
    self._batch_weight_fn = batch_weight_fn

    enhanced_model = model_utils.enhance(model)
    self._model = enhanced_model
    py_typecheck.check_type(enhanced_model, model_utils.EnhancedModel)

    # A TrainableModel's built-in local training would be bypassed by this
    # computation, so reject it explicitly rather than ignore it silently.
    if isinstance(enhanced_model, model_lib.TrainableModel):
      raise ValueError(
          'Do not pass a TrainableModel to ClientSgd, as the '
          'built-in local training algorithm would be ignored. '
          'This failure could be made into a warning if this is inconvenient.')
Exemple #29
0
def server_update_model(current_server_state, weights_delta, model_fn,
                        optimizer_fn):
  """Returns the server state after applying `weights_delta`.

  Fresh model and optimizer variables are created, loaded with the values in
  `current_server_state`, and then updated by applying `weights_delta`. If any
  entry of `weights_delta` is non-finite, the whole delta is zeroed out before
  it is applied.

  Args:
    current_server_state: A `tff.learning.framework.ServerState` namedtuple.
    weights_delta: A `collections.OrderedDict` holding an update to the
      trainable variables of the model.
    model_fn: A no-arg function that returns a `tff.learning.Model`. Passing in
      a function ensures any variables are created when server_update_model is
      called, so they can be captured in a specific graph or other context.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As with
      model_fn, we pass in a function to control when variables are created.

  Returns:
    An updated `tff.learning.framework.ServerState`.

  Raises:
    TypeError: If `current_server_state` is not a `ServerState` or
      `weights_delta` is not a `collections.OrderedDict` (via `py_typecheck`).
  """
  py_typecheck.check_type(current_server_state, ServerState)
  py_typecheck.check_type(weights_delta, collections.OrderedDict)
  # Variables are created here (not by the caller) so they live in whatever
  # graph/context this function is invoked under.
  model = model_utils.enhance(model_fn())
  optimizer = optimizer_fn()
  apply_delta_fn, server_state_vars = _create_optimizer_and_server_state(
      model, optimizer)

  # We might have a NaN value e.g. if all of the clients processed
  # had no data, so the denominator in the federated_mean is zero.
  # If we see any NaNs, zero out the whole update.
  no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      weights_delta)
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.

  @tf.contrib.eager.function(autograph=False)
  def update_model_inner():
    """Applies the update."""
    # Load the incoming server state into the freshly created variables.
    nest.map_structure(tf.assign, server_state_vars, current_server_state)
    # Then apply the (possibly zeroed) delta on top of that state.
    apply_delta_fn(no_nan_weights_delta)
    return server_state_vars

  return update_model_inner()
  def test_federated_evaluation_with_keras(self, simulation):
    """Federated evaluation of the identity Keras model is loss-free.

    Args:
      simulation: Forwarded as `use_experimental_simulation_loop` when building
        the federated evaluation computation.
    """
    evaluate_comp = federated_evaluation.build_federated_evaluation(
        _model_fn_from_keras, use_experimental_simulation_loop=simulation)
    # Read the initial variable values so they can be passed as plain values
    # into the federated computation.
    starting_model = model_utils.enhance(_model_fn_from_keras())
    initial_weights = tf.nest.map_structure(lambda v: v.read_value(),
                                            starting_model.weights)

    def _input_dict(temps):
      # Build a (-1, 1) float32 column; x and y get identical contents.
      def as_column():
        return np.reshape(np.array(temps, dtype=np.float32), (-1, 1))

      return collections.OrderedDict(x=as_column(), y=as_column())

    # Three clients, each holding one or more batches.
    client_data = [
        [_input_dict([1.0, 10.0, 2.0, 7.0]), _input_dict([6.0, 11.0])],
        [_input_dict([9.0, 12.0, 13.0])],
        [_input_dict([1.0]), _input_dict([22.0, 23.0])],
    ]
    result = evaluate_comp(initial_weights, client_data)

    # Identity model + identical x/y => 100% accuracy and zero loss.
    expected = collections.OrderedDict(accuracy=1.0, loss=0.0)
    self.assertDictEqual(result, expected)