Exemplo n.º 1
0
    def test_construction_with_aggregation_process(self):
        """Aggregation state and metric types flow into the iterative process."""
        # A scratch graph is used only to derive the TFF type of the model's
        # trainable weights (the type of a model update).
        with tf.Graph().as_default():
            trainable_weights = model_utils.ModelWeights.from_model(
                model_examples.LinearRegression()).trainable
            model_update_type = tff.framework.type_from_tensors(
                trainable_weights)
        aggregation_process = _build_test_measured_mean(model_update_type)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            aggregation_process=aggregation_process)

        expected_state_type = (
            aggregation_process.initialize.type_signature.result)
        next_signature = iterative_process.next.type_signature
        # The aggregation state must be embedded, server-placed, in the result
        # of `initialize` and in both the input and output server state of
        # `next`.
        server_state_types = (
            iterative_process.initialize.type_signature.result,
            next_signature.parameter[0],
            next_signature.result[0],
        )
        for server_state_type in server_state_types:
            self.assertEqual(
                tff.FederatedType(
                    server_state_type.member.delta_aggregate_state,
                    tff.SERVER), expected_state_type)

        # The aggregation measurements are surfaced under the `aggregation`
        # key of the metrics returned by `next`.
        expected_metrics_type = (
            aggregation_process.next.type_signature.result.measurements)
        self.assertEqual(
            tff.FederatedType(next_signature.result[1].member.aggregation,
                              tff.SERVER), expected_metrics_type)
Exemplo n.º 2
0
    def test_construction_with_broadcast_process(self):
        """Broadcast process state shows up in the iterative process state."""
        # A scratch graph is used only to derive the model weights' TFF type.
        with tf.Graph().as_default():
            model_weights = model_utils.ModelWeights.from_model(
                model_examples.LinearRegression())
            model_weights_type = tff.framework.type_from_tensors(model_weights)
        broadcast_process = _build_test_measured_broadcast(model_weights_type)
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            broadcast_process=broadcast_process)

        expected_broadcast_state_type = (
            broadcast_process.initialize.type_signature.result)
        next_type = iterative_process.next.type_signature
        # The broadcast state must be embedded, server-placed, in the result of
        # `initialize` and in both the input and output server state of `next`.
        server_state_types = (
            iterative_process.initialize.type_signature.result,
            next_type.parameter[0],
            next_type.result[0],
        )
        for server_state_type in server_state_types:
            self.assertEqual(
                tff.FederatedType(
                    server_state_type.member.model_broadcast_state,
                    tff.SERVER), expected_broadcast_state_type)
Exemplo n.º 3
0
def build_federated_evaluation(model_fn):
    """Builds the TFF computation for federated evaluation of the given model.

    NOTE(review): a second `build_federated_evaluation` is defined later in
    this file; at module level that later definition would shadow this one.

    Args:
      model_fn: A no-argument function that returns a `tff.learning.Model`.
        It is called more than once (to derive types, and again inside the
        client computation), so it should build a new model on each call.

    Returns:
      A federated computation (an instance of `tff.Computation`) that accepts
      model parameters and federated data, and returns the evaluation metrics
      as aggregated by `tff.learning.Model.federated_output_computation`.
    """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_utils.enhance(model_fn())
        # Weight types use each variable's base dtype and its static shape.
        model_weights_type = tff.to_type(
            tf.nest.map_structure(
                lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
                model.weights))
        batch_type = tff.to_type(model.input_spec)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluating `incoming_model_weights` on `dataset`."""
        model = model_utils.enhance(model_fn())

        # Sums the per-batch loss (as float64). The running total is only a
        # "dummy" value used to sequence the ops below.
        # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
        @tf.function
        def reduce_fn(dummy, batch):
            model_output = model.forward_pass(batch, training=False)
            return dummy + tf.cast(model_output.loss, tf.float64)

        # The weights are assigned before the dataset is reduced, and the
        # outputs are read only after the reduction has produced `dummy`.
        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        with tf.control_dependencies(
            [tff.utils.assign(model.weights, incoming_model_weights)]):
            dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64),
                                   reduce_fn)

        with tf.control_dependencies([dummy]):
            # `dummy` is also returned so it is part of the computation's
            # outputs (see the b/121400757 workaround key).
            return collections.OrderedDict([
                ('local_outputs', model.report_local_outputs()),
                ('workaround for b/121400757', dummy)
            ])

    @tff.federated_computation(
        tff.FederatedType(model_weights_type, tff.SERVER),
        tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
    def server_eval(server_model_weights, federated_dataset):
        # Broadcast the server weights, evaluate on every client, then
        # aggregate only the `local_outputs` part of the client results.
        client_outputs = tff.federated_map(
            client_eval,
            [tff.federated_broadcast(server_model_weights), federated_dataset])
        return model.federated_output_computation(client_outputs.local_outputs)

    return server_eval
def build_federated_evaluation(model_fn):
  """Returns a federated computation that evaluates a model on client data.

  Args:
    model_fn: A no-argument callable returning a `tff.learning.Model`. It is
      invoked more than once (to derive types, and again inside the client
      computation), so it must construct a brand-new model from scratch on
      every call and must *not* capture existing TensorFlow tensors or
      variables. Returning the same pre-constructed model on each call results
      in an error.

  Returns:
    A `tff.Computation` that takes server-placed model weights and
    client-placed datasets, and returns the evaluation metrics produced by
    `tff.learning.Model.federated_output_computation` from the clients' local
    outputs.
  """
  # A throwaway model is stamped out in a scratch graph purely so that the
  # weight and batch types needed by the computations below can be derived.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    stamped_model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(stamped_model.weights)
    batch_type = tff.to_type(stamped_model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Evaluates `incoming_model_weights` on `dataset`; returns local outputs."""
    client_model = model_utils.enhance(model_fn())

    @tf.function
    def evaluate(weights, batches):
      """Loads `weights`, runs every batch, and reports local outputs."""
      # Install the received weights before any forward pass runs.
      tff.utils.assign(client_model.weights, weights)

      def accumulate_loss(loss_total, batch):
        outputs = client_model.forward_pass(batch, training=False)
        return loss_total + tf.cast(outputs.loss, tf.float64)

      # The summed loss is discarded; presumably the reduction is run for the
      # side effects of `forward_pass` on the model's local outputs.
      batches.reduce(tf.constant(0.0, dtype=tf.float64), accumulate_loss)

      return collections.OrderedDict(
          local_outputs=client_model.report_local_outputs())

    return evaluate(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    weights_at_clients = tff.federated_broadcast(server_model_weights)
    client_outputs = tff.federated_map(
        client_eval, [weights_at_clients, federated_dataset])
    return stamped_model.federated_output_computation(
        client_outputs.local_outputs)

  return server_eval
Exemplo n.º 5
0
    def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
        """Composition rewrites the client dataset argument of `next`."""
        original_iterproc = (
            _create_stateless_int_dataset_reduction_iterative_process())
        composed_iterproc = (
            iterative_process_compositions.compose_dataset_computation(
                int_dataset_computation, original_iterproc))

        # After composition the second argument of `next` is expected to be a
        # client-placed `tf.string` rather than a sequence of int64.
        server_int64 = tff.FederatedType(tf.int64, tff.SERVER)
        expected_next_type = tff.FunctionType(
            [server_int64, tff.FederatedType(tf.string, tff.CLIENTS)],
            server_int64)

        self.assertTrue(
            expected_next_type.is_equivalent_to(
                composed_iterproc.next.type_signature))
Exemplo n.º 6
0
    def test_orchestration_type_signature(self):
        """`initialize` and `next` of the delta-optimizer process are typed."""
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(
                learning_rate=1.0))

        trainable_type = collections.OrderedDict([
            ('a', tff.TensorType(tf.float32, [2, 1])),
            ('b', tf.float32),
        ])
        non_trainable_type = collections.OrderedDict([('c', tf.float32)])
        expected_model_weights_type = model_utils.ModelWeights(
            trainable_type, non_trainable_type)

        # Only the model part of ServerState is pinned down; the three
        # remaining fields (optimizer, aggregation and broadcast state) are
        # matched by `test.AnyType()` because TFF doesn't care about them here.
        expected_federated_server_state_type = tff.FederatedType(
            optimizer_utils.ServerState(expected_model_weights_type,
                                        test.AnyType(), test.AnyType(),
                                        test.AnyType()),
            placement=tff.SERVER,
            all_equal=True)

        input_spec = model_examples.TrainableLinearRegression().input_spec
        expected_federated_dataset_type = tff.FederatedType(
            tff.SequenceType(input_spec), tff.CLIENTS, all_equal=False)

        expected_model_output_types = tff.FederatedType(
            collections.OrderedDict([
                ('loss', tff.TensorType(tf.float32, [])),
                ('num_examples', tff.TensorType(tf.int32, [])),
            ]),
            tff.SERVER,
            all_equal=True)

        # `initialize` takes no arguments and returns the server state.
        expected_initialize_type = tff.FunctionType(
            parameter=None, result=expected_federated_server_state_type)
        self.assertEqual(expected_initialize_type,
                         iterative_process.initialize.type_signature)

        # `next` maps (server state, client datasets) to
        # (server state, model outputs).
        expected_next_type = tff.FunctionType(
            parameter=[
                expected_federated_server_state_type,
                expected_federated_dataset_type,
            ],
            result=(expected_federated_server_state_type,
                    expected_model_output_types))
        self.assertEqual(expected_next_type,
                         iterative_process.next.type_signature)
def build_stateless_robust_aggregation(model_type,
                                       num_communication_passes=5,
                                       tolerance=1e-6):
    """Create TFF function for robust aggregation.

    The robust aggregate is an approximate geometric median computed via the
    smoothed Weiszfeld algorithm. The first pass takes the weighted mean of
    the client values. Each further pass re-weights every client by
    `weight / max(tolerance, ||aggregate - client_value||)` and recomputes the
    weighted mean. Here `||.||` is the L2 distance taken over all tensors of
    the structure.

    Args:
      model_type: tff typespec of quantity to be aggregated. May be a single
        tensor type or any (possibly nested) structure of tensor types.
      num_communication_passes: number of communication rounds in the smoothed
        Weiszfeld algorithm (min. 1). Must be an `int`.
      tolerance: smoothing parameter of smoothed Weiszfeld algorithm. Default
        1e-6. It lower-bounds the distance used when re-weighting. A
        non-positive value can divide by zero when a client value coincides
        with the current aggregate.

    Returns:
      An instance of `tff.utils.StatefulAggregateFn` which implements a
      (stateless) robust aggregate.

    Raises:
      ValueError: If `num_communication_passes` is less than 1.
    """
    py_typecheck.check_type(num_communication_passes, int)
    if num_communication_passes < 1:
        raise ValueError('Aggregation requires num_communication_passes >= 1')
    # client weights have been hardcoded as float32, this needs to be
    # parameterized.

    @tff.tf_computation(tf.float32, model_type, model_type)
    def update_weight_fn(weight, server_model, client_model):
        # Squared L2 distance per tensor. Each distance is cast to the weight
        # dtype so that models whose tensors are not float32 can still be
        # re-weighted without a dtype mismatch.
        sqnorms = tf.nest.map_structure(
            lambda a, b: tf.cast(tf.reduce_sum(tf.square(a - b)),
                                 weight.dtype),
            server_model, client_model)
        # `tf.nest.flatten` handles a bare tensor as well as dict, list and
        # tuple structures, so any `model_type` structure is supported.
        sqnorm = tf.reduce_sum(tf.nest.flatten(sqnorms))
        return weight / tf.math.maximum(tolerance, tf.math.sqrt(sqnorm))

    client_model_type = tff.FederatedType(model_type, tff.CLIENTS)
    client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

    @tff.federated_computation(client_model_type, client_weight_type)
    def robust_aggregation_fn(value, weight):
        # Pass 1: plain weighted mean.
        aggregate = tff.federated_mean(value, weight=weight)
        # Remaining passes: Weiszfeld re-weighting around the current estimate.
        for _ in range(num_communication_passes - 1):
            aggregate_at_client = tff.federated_broadcast(aggregate)
            updated_weight = tff.federated_map(
                update_weight_fn, (weight, aggregate_at_client, value))
            aggregate = tff.federated_mean(value, weight=updated_weight)
        return aggregate

    def _stateless_next(state, value, weight):
        # No state is kept; it is passed through unchanged.
        return state, robust_aggregation_fn(value, weight)

    return tff.utils.StatefulAggregateFn(initialize_fn=lambda: (),
                                         next_fn=_stateless_next)
Exemplo n.º 8
0
  def test_orchestration_typecheck(self):
    """`initialize` and `next` of federated SGD have the expected types."""
    iterative_process = federated_sgd.build_federated_sgd_process(
        model_fn=model_examples.LinearRegression)

    trainable_type = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    non_trainable_type = collections.OrderedDict([('c', tf.float32)])
    expected_model_weights_type = model_utils.ModelWeights(
        trainable_type, non_trainable_type)

    # ServerState holds the model plus an optimizer_state. The optimizer_state
    # comes from TensorFlow, so TFF doesn't care about its exact value/type.
    expected_federated_server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(expected_model_weights_type,
                                    test.AnyType()),
        placement=tff.SERVER,
        all_equal=True)

    batch_type = model_examples.LinearRegression.make_batch(
        tff.TensorType(tf.float32, [None, 2]),
        tff.TensorType(tf.float32, [None, 1]))
    expected_federated_dataset_type = tff.FederatedType(
        tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

    expected_model_output_types = tff.FederatedType(
        collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ]),
        tff.SERVER,
        all_equal=True)

    # `initialize` takes no arguments and returns the server state.
    expected_initialize_type = tff.FunctionType(
        parameter=None, result=expected_federated_server_state_type)
    self.assertEqual(expected_initialize_type,
                     iterative_process.initialize.type_signature)

    # `next` maps (server state, client datasets) to
    # (server state, model outputs).
    expected_next_type = tff.FunctionType(
        parameter=[
            expected_federated_server_state_type,
            expected_federated_dataset_type,
        ],
        result=(expected_federated_server_state_type,
                expected_model_output_types))
    self.assertEqual(expected_next_type,
                     iterative_process.next.type_signature)
Exemplo n.º 9
0
def _wrap_in_measured_process(
        stateful_fn: Union[tff.utils.StatefulBroadcastFn,
                           tff.utils.StatefulAggregateFn],
        input_type: tff.Type) -> tff.templates.MeasuredProcess:
    """Adapts a legacy stateful broadcast/aggregate function to a process.

    Args:
      stateful_fn: The `tff.utils.StatefulBroadcastFn` or
        `tff.utils.StatefulAggregateFn` to wrap. Its `initialize` may be a
        `tff.Computation` or a plain Python callable. A plain callable is
        wrapped with `tff.tf_computation`.
      input_type: The unplaced type of the value being broadcast (placed at
        `tff.SERVER`) or aggregated (placed at `tff.CLIENTS`).

    Returns:
      A `tff.templates.MeasuredProcess` whose `next` returns an `OrderedDict`
      with `state`, `result` and empty server-placed `measurements`. For an
      aggregate function, `next` also takes a `tf.float32` client weight.

    Raises:
      TypeError: If `stateful_fn` is of neither supported type.
    """
    py_typecheck.check_type(
        stateful_fn,
        (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))

    @tff.federated_computation()
    def initialize_comp():
        init_fn = stateful_fn.initialize
        # A plain Python initializer is promoted to a TF computation so that it
        # can be evaluated at the server.
        if not isinstance(init_fn, tff.Computation):
            init_fn = tff.tf_computation(init_fn)
        return tff.federated_eval(init_fn, tff.SERVER)

    state_type = initialize_comp.type_signature.result

    if isinstance(stateful_fn, tff.utils.StatefulBroadcastFn):
        server_value_type = tff.FederatedType(input_type, tff.SERVER)

        @tff.federated_computation(state_type, server_value_type)
        def next_comp(state, value):
            no_metrics = tff.federated_value((), tff.SERVER)
            new_state, broadcast_value = stateful_fn(state, value)
            return collections.OrderedDict(state=new_state,
                                           result=broadcast_value,
                                           measurements=no_metrics)

    elif isinstance(stateful_fn, tff.utils.StatefulAggregateFn):
        client_value_type = tff.FederatedType(input_type, tff.CLIENTS)
        client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

        @tff.federated_computation(state_type, client_value_type,
                                   client_weight_type)
        def next_comp(state, value, weight):
            no_metrics = tff.federated_value((), tff.SERVER)
            new_state, aggregate = stateful_fn(state, value, weight)
            return collections.OrderedDict(state=new_state,
                                           result=aggregate,
                                           measurements=no_metrics)

    else:
        # Expected to be unreachable given the `check_type` call above; kept
        # as a defensive guard.
        raise TypeError(
            'Received a {t}, expected either a tff.utils.StatefulAggregateFn or a '
            'tff.utils.StatefulBroadcastFn.'.format(t=type(stateful_fn)))

    return tff.templates.MeasuredProcess(next_fn=next_comp,
                                         initialize_fn=initialize_comp)
Exemplo n.º 10
0
def build_stateless_mean(
    *, model_delta_type: Union[tff.NamedTupleType, tff.TensorType]
) -> tff.templates.MeasuredProcess:
  """Builds a stateless `MeasuredProcess` that wraps `tff.federated_mean`.

  Args:
    model_delta_type: The unplaced type of the client values to average.

  Returns:
    A `tff.templates.MeasuredProcess` that passes its state through
    unchanged. Its result is the weighted mean of the client values, and its
    measurements are an empty server-placed tuple.
  """
  client_value_type = tff.FederatedType(model_delta_type, tff.CLIENTS)
  client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

  @tff.federated_computation(NONE_SERVER_TYPE, client_value_type,
                             client_weight_type)
  def stateless_mean(state, value, weight):
    no_metrics = tff.federated_value((), tff.SERVER)
    weighted_mean = tff.federated_mean(value, weight=weight)
    return collections.OrderedDict(
        state=state, result=weighted_mean, measurements=no_metrics)

  return tff.templates.MeasuredProcess(
      next_fn=stateless_mean, initialize_fn=_empty_server_initialization)
Exemplo n.º 11
0
  def test_construction(self):
    """Default processes give the expected `initialize`/`next` types."""
    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    model_weights_type = model_utils.ModelWeights(
        trainable=[
            tff.TensorType(tf.float32, [2, 1]),
            tff.TensorType(tf.float32),
        ],
        non_trainable=[tff.TensorType(tf.float32)])
    # No custom aggregation/broadcast processes are supplied, so both of
    # their states are expected to be empty tuples.
    server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(
            model=model_weights_type,
            optimizer_state=[tf.int64],
            delta_aggregate_state=(),
            model_broadcast_state=()), tff.SERVER)

    expected_initialize_type = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(
        str(iterative_process.initialize.type_signature),
        str(expected_initialize_type))

    batch_type = collections.OrderedDict(
        x=tff.TensorType(tf.float32, [None, 2]),
        y=tff.TensorType(tf.float32, [None, 1]))
    dataset_type = tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS)

    train_metrics_type = collections.OrderedDict(
        loss=tff.TensorType(tf.float32),
        num_examples=tff.TensorType(tf.int32))
    metrics_type = tff.FederatedType(
        collections.OrderedDict(
            broadcast=(), aggregation=(), train=train_metrics_type),
        tff.SERVER)

    expected_next_type = tff.FunctionType(
        parameter=(server_state_type, dataset_type),
        result=(server_state_type, metrics_type))
    self.assertEqual(
        str(iterative_process.next.type_signature), str(expected_next_type))
Exemplo n.º 12
0
def _build_test_measured_mean(
        model_update_type: tff.NamedTupleType
) -> tff.templates.MeasuredProcess:
    """Builds a test `MeasuredProcess` that has state and metrics.

    The state is a server-placed int32 counter. It starts at 0 and is updated
    with `_add_one` on every `next` call. The result is the weighted
    `tff.federated_mean` of the client values. The measurements report the
    number of participating clients under `num_clients`.
    """

    @tff.federated_computation()
    def initialize_comp():
        return tff.federated_value(0, tff.SERVER)

    state_type = tff.FederatedType(tf.int32, tff.SERVER)
    value_type = tff.FederatedType(model_update_type, tff.CLIENTS)
    weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

    @tff.federated_computation(state_type, value_type, weight_type)
    def next_comp(state, value, weight):
        updated_state = tff.federated_map(_add_one, state)
        weighted_mean = tff.federated_mean(value, weight)
        # Every client contributes 1, so the sum counts the clients.
        num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
        measurements = tff.federated_zip(
            collections.OrderedDict(num_clients=num_clients))
        return collections.OrderedDict(state=updated_state,
                                       result=weighted_mean,
                                       measurements=measurements)

    return tff.templates.MeasuredProcess(initialize_fn=initialize_comp,
                                         next_fn=next_comp)
Exemplo n.º 13
0
def _build_test_measured_broadcast(
        model_weights_type: tff.NamedTupleType
) -> tff.templates.MeasuredProcess:
    """Builds a test `MeasuredProcess` that has state and metrics.

    The state is a server-placed int32 counter. It starts at 0 and is updated
    with `_add_one` on every `next` call. The result broadcasts the server
    value to the clients. The measurement is an arbitrary server-side value:
    the global norm of the value's tensors plus 3.0.
    """

    @tff.federated_computation()
    def initialize_comp():
        return tff.federated_value(0, tff.SERVER)

    state_type = tff.FederatedType(tf.int32, tff.SERVER)
    value_type = tff.FederatedType(model_weights_type, tff.SERVER)

    @tff.federated_computation(state_type, value_type)
    def next_comp(state, value):
        updated_state = tff.federated_map(_add_one, state)
        broadcast_value = tff.federated_broadcast(value)

        def _norm_plus_three(v):
            return tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0

        # Arbitrary metrics for testing.
        norm_metric = tff.federated_map(tff.tf_computation(_norm_plus_three),
                                        value)
        return collections.OrderedDict(state=updated_state,
                                       result=broadcast_value,
                                       measurements=norm_metric)

    return tff.templates.MeasuredProcess(initialize_fn=initialize_comp,
                                         next_fn=next_comp)
Exemplo n.º 14
0
    def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
        """Composition finds a client dataset that is `next`'s third arg."""
        base_iterproc = (
            _create_stateless_int_dataset_reduction_iterative_process())
        base_param_type = base_iterproc.next.type_signature.parameter

        # Splice an unrelated `tf.int32` argument in front of the dataset so
        # that the dataset becomes the third parameter of `next`.
        @tff.federated_computation(
            tff.StructType(
                [base_param_type[0], tf.int32, base_param_type[1]]))
        def new_next(param):
            # Drop the spliced-in argument and delegate to the original.
            return base_iterproc.next([param[0], param[2]])

        three_arg_iterproc = tff.templates.IterativeProcess(
            base_iterproc.initialize, new_next)
        server_int64 = tff.FederatedType(tf.int64, tff.SERVER)
        expected_next_type = tff.FunctionType(
            [server_int64, tf.int32,
             tff.FederatedType(tf.string, tff.CLIENTS)],
            server_int64)

        composed_iterproc = (
            iterative_process_compositions.compose_dataset_computation(
                int_dataset_computation, three_arg_iterproc))

        self.assertTrue(
            expected_next_type.is_equivalent_to(
                composed_iterproc.next.type_signature))
Exemplo n.º 15
0
    def federated_output_computation(self):
        """Returns a federated computation that aggregates local metric outputs.

        The computation takes the clients' `report_local_outputs()` values.
        For each metric from `get_metrics()`, it aggregates that metric's
        variables with `federated_aggregate_keras_metric`, and it returns the
        results keyed by metric name.
        """
        local_outputs_spec = nest.map_structure(tf.TensorSpec.from_tensor,
                                                self.report_local_outputs())
        clients_outputs_type = tff.FederatedType(local_outputs_spec,
                                                 tff.CLIENTS,
                                                 all_equal=False)

        @tff.federated_computation(clients_outputs_type)
        def federated_output(local_outputs):
            # Metrics and their per-client variables are paired positionally.
            return collections.OrderedDict(
                (metric.name,
                 federated_aggregate_keras_metric(type(metric),
                                                  metric.get_config(),
                                                  variables))
                for metric, variables in zip(self.get_metrics(),
                                             local_outputs))

        return federated_output
Exemplo n.º 16
0
def build_stateless_broadcaster(
    *, model_weights_type: Union[tff.NamedTupleType, tff.TensorType]
) -> tff.templates.MeasuredProcess:
  """Builds a stateless `MeasuredProcess` around `tff.federated_broadcast`.

  Args:
    model_weights_type: The unplaced type of the server value to broadcast.

  Returns:
    A `tff.templates.MeasuredProcess` that passes its state through
    unchanged and broadcasts the server value to the clients. It reports an
    empty server-placed tuple as its measurements.
  """
  server_weights_type = tff.FederatedType(model_weights_type, tff.SERVER)

  @tff.federated_computation(NONE_SERVER_TYPE, server_weights_type)
  def stateless_broadcast(state, value):
    no_metrics = tff.federated_value((), tff.SERVER)
    broadcast_value = tff.federated_broadcast(value)
    return collections.OrderedDict(
        state=state, result=broadcast_value, measurements=no_metrics)

  return tff.templates.MeasuredProcess(
      next_fn=stateless_broadcast, initialize_fn=_empty_server_initialization)
Exemplo n.º 17
0
    def __init__(self, keras_model: tf.keras.Model, input_spec,
                 loss_fns: List[tf.keras.losses.Loss],
                 loss_weights: List[float],
                 metrics: List[tf.keras.metrics.Metric]):
        """Stores the Keras model, losses and metrics used by this wrapper.

        Args:
          keras_model: The wrapped `tf.keras.Model`.
          input_spec: The model's input spec, stored as-is.
          loss_fns: Loss functions. With a single loss it is applied to the
            whole prediction; otherwise loss `i` scores output `i`.
          loss_weights: Per-loss weights, used only when there are multiple
            losses.
          metrics: The Keras metrics to track.
        """
        self._keras_model = keras_model
        self._input_spec = input_spec
        self._loss_fns = loss_fns
        self._loss_weights = loss_weights
        self._metrics = metrics

        # Defined inline so that it closes over `loss_fns` and `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""

            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                # `sample_weight` is ignored; each batch loss is weighted by
                # the batch size instead.
                if len(self._loss_fns) != 1:
                    batch_size = tf.shape(y_pred[0])[0]
                    batch_loss = tf.zeros(())
                    for idx, loss_fn in enumerate(self._loss_fns):
                        batch_loss += (self._loss_weights[idx] *
                                       loss_fn(y_true[idx], y_pred[idx]))
                else:
                    batch_size = tf.shape(y_pred)[0]
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        local_outputs_spec = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        clients_outputs_type = tff.FederatedType(local_outputs_spec,
                                                 tff.CLIENTS)

        def federated_output(local_outputs):
            return federated_aggregate_keras_metric(self.get_metrics(),
                                                    local_outputs)

        self._federated_output_computation = tff.federated_computation(
            federated_output, clients_outputs_type)
Exemplo n.º 18
0
def _create_stateless_int_dataset_reduction_iterative_process():
    """Builds an `IterativeProcess` that sums integer client datasets.

    `initialize` places an int64 zero at the server. `next` ignores that
    state and returns, at the server, the sum over all clients of each
    client's int64 dataset total.
    """

    @tff.tf_computation()
    def make_zero():
        return tf.constant(0, dtype=tf.int64)

    @tff.federated_computation()
    def init():
        return tff.federated_eval(make_zero, tff.SERVER)

    @tff.tf_computation(tff.SequenceType(tf.int64))
    def reduce_dataset(x):
        return x.reduce(tf.constant(0, dtype=tf.int64),
                        lambda total, element: total + element)

    client_datasets_type = tff.FederatedType(tff.SequenceType(tf.int64),
                                             tff.CLIENTS)

    @tff.federated_computation((init.type_signature.result,
                                client_datasets_type))
    def next_fn(empty_tup, x):
        # NOTE: despite its name, `empty_tup` receives the server int64 state
        # produced by `init`; it is deliberately ignored.
        del empty_tup  # Unused
        per_client_totals = tff.federated_map(reduce_dataset, x)
        return tff.federated_sum(per_client_totals)

    return tff.templates.IterativeProcess(initialize_fn=init, next_fn=next_fn)
Exemplo n.º 19
0
    def __init__(self,
                 inner_model,
                 dummy_batch,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):
        """Wraps a Keras model together with its losses and metrics for TFF.

        Args:
          inner_model: The Keras model to wrap. It is called once on
            `dummy_batch` so that sub-classed models create their variables.
          dummy_batch: A sample batch. It is either a mapping with an `'x'`
            key or a sequence whose first element is the model input. Its
            structure, with the leading (batch) dimension left unspecified,
            becomes the model's input spec.
          loss_fns: Loss functions. With a single loss it is applied to the
            whole prediction; otherwise loss `i` scores output `i`.
          loss_weights: Optional per-loss weights. Accepts a mapping keyed by
            `inner_model.output_names`, a sequence aligned with `loss_fns`,
            or `None`, which gives every loss a weight of 1.0.
          metrics: Optional metrics (presumably `tf.keras.metrics.Metric`s).
            `None` means no metrics.

        Raises:
          KeyError: If `loss_weights` is a mapping that is missing one of the
            model's output names.
        """
        # `collections.Mapping` was removed in Python 3.10; the ABC lives in
        # `collections.abc` (available since Python 3.3).
        from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

        # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
        # variables until they are called on input. We force that here.
        if isinstance(dummy_batch, collections_abc.Mapping):
            inner_model(dummy_batch['x'])
        else:
            inner_model(dummy_batch[0])

        def _tensor_spec_with_undefined_batch_dim(tensor):
            # Remove the batch dimension and leave it unspecified.
            spec = tf.TensorSpec(shape=[None] + tensor.shape.dims[1:],
                                 dtype=tensor.dtype)
            return spec

        self._input_spec = tf.nest.map_structure(
            _tensor_spec_with_undefined_batch_dim, dummy_batch)

        self._keras_model = inner_model
        self._loss_fns = loss_fns

        if isinstance(loss_weights, collections_abc.Mapping):
            # Every model output needs a weight; the weights are then ordered
            # like `inner_model.output_names` (presumably the same order as
            # `loss_fns`).
            if any(name not in loss_weights
                   for name in inner_model.output_names):
                raise KeyError(
                    'Output missing from loss_weights dictionary'
                    '\nloss_weights: {}\noutputs: {}'.format(
                        list(loss_weights.keys()),
                        inner_model.output_names))
            self._loss_weights = [
                loss_weights[name] for name in inner_model.output_names
            ]
        elif loss_weights is None:
            # Default: every loss contributes with weight 1.0.
            self._loss_weights = [1.0] * len(loss_fns)
        else:
            self._loss_weights = loss_weights

        # Rebind to the normalized list so the metric class below closes over
        # it rather than over the raw argument.
        loss_weights = self._loss_weights
        self._metrics = metrics if metrics is not None else []

        # Defined here so that it closes over `loss_fns` and `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""

            def __init__(self, name='loss', dtype=tf.float32):
                super(_WeightedMeanLossMetric, self).__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                # `sample_weight` is ignored; each batch loss is weighted by
                # the batch size instead.
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    # Weighted sum of the per-output losses.
                    batch_loss = tf.zeros(())
                    for i, loss_fn in enumerate(self._loss_fns):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += self._loss_weights[i] * loss_fn(y_t, y_p)

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                return super(_WeightedMeanLossMetric,
                             self).update_state(batch_loss, batch_size)

        class _TrainingTimeHistory(tf.keras.metrics.Sum):
            """Sums the values passed to `log_time`; ignores `update_state`."""

            def update_state(self, y_true, y_pred, sample_weight=None):
                pass

            def log_time(self, time_value):
                return super(_TrainingTimeHistory,
                             self).update_state(values=time_value)

        self._loss_metric = _WeightedMeanLossMetric()
        self._training_timing = _TrainingTimeHistory(name='training_time_sec')

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            results = collections.OrderedDict()
            for metric, variables in zip(self.get_metrics(), local_outputs):
                results[metric.name] = federated_aggregate_keras_metric(
                    type(metric), metric.get_config(), variables)
            return results

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)

        # Keras creates variables that are not added to any collection, making it
        # impossible for TFF to extract them and create the appropriate initializer
        # before call a tff.Computation. Here we store them in a TFF specific
        # collection so that they can be retrieved later.
        # TODO(b/122081673): this likely goes away in TF2.0
        for variable in itertools.chain(self.trainable_variables,
                                        self.non_trainable_variables,
                                        self.local_variables):
            tf.compat.v1.add_to_collection(
                graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
Exemplo n.º 20
0
  Args:
    process: A measured process to validate.

  Returns:
    `True` iff the process is a validate aggregation process, otherwise `False`.
  """
    next_type = process.next.type_signature
    return (isinstance(process, tff.templates.MeasuredProcess)
            and _is_valid_stateful_process(process)
            and next_type.parameter[1].placement is tff.CLIENTS
            and next_type.result.result.placement is tff.SERVER)


# ============================================================================

# Federated type of an empty tuple placed at the server. The
# `stateless_mean` and `stateless_broadcast` computations take it as the type
# of their (empty) state argument.
NONE_SERVER_TYPE = tff.FederatedType((), tff.SERVER)


def _wrap_in_measured_process(
        stateful_fn: Union[tff.utils.StatefulBroadcastFn,
                           tff.utils.StatefulAggregateFn],
        input_type: tff.Type) -> tff.templates.MeasuredProcess:
    """Converts a `tff.utils.StatefulFn` to a `tff.templates.MeasuredProcess`."""
    py_typecheck.check_type(
        stateful_fn,
        (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))

    @tff.federated_computation()
    def initialize_comp():
        if not isinstance(stateful_fn.initialize, tff.Computation):
            initialize = tff.tf_computation(stateful_fn.initialize)
Exemplo n.º 21
0
    def benchmark_fc_api_mnist(self):
        """Benchmarks building and running a federated MNIST trainer (FC API).

        Code adapted from the Federated Core (FC) API tutorial notebook. Two
        results are reported through `report_benchmark`:

        * the wall time needed to construct all TFF computations, and
        * the mean wall time of one round of `federated_train` over
          `n_rounds` rounds run on `generate_fake_mnist_data()`, with the
          standard deviation reported as an extra.
        """
        n_rounds = 10

        batch_type = tff.NamedTupleType([
            ("x", tff.TensorType(tf.float32, [None, 784])),
            ("y", tff.TensorType(tf.int32, [None]))
        ])

        model_type = tff.NamedTupleType([
            ("weights", tff.TensorType(tf.float32, [784, 10])),
            ("bias", tff.TensorType(tf.float32, [10]))
        ])

        local_data_type = tff.SequenceType(batch_type)

        server_model_type = tff.FederatedType(model_type,
                                              tff.SERVER,
                                              all_equal=True)
        client_data_type = tff.FederatedType(local_data_type, tff.CLIENTS)

        server_float_type = tff.FederatedType(tf.float32,
                                              tff.SERVER,
                                              all_equal=True)

        computation_building_start = time.time()

        # pylint: disable=missing-docstring
        @tff.tf_computation(model_type, batch_type)
        def batch_loss(model, batch):
            logits = tf.matmul(batch.x, model.weights) + model.bias
            # `log_softmax` is the numerically stable form of
            # `log(softmax(logits))`: the naive form produces -inf (and then
            # NaN via 0 * -inf) as soon as a probability underflows to zero.
            log_probs = tf.nn.log_softmax(logits)
            # `axis` replaces the deprecated `reduction_indices` argument.
            return -tf.reduce_mean(
                tf.reduce_sum(tf.one_hot(batch.y, 10) * log_probs, axis=[1]))

        initial_model = {
            "weights": np.zeros([784, 10], dtype=np.float32),
            "bias": np.zeros([10], dtype=np.float32)
        }

        @tff.tf_computation(model_type, batch_type, tf.float32)
        def batch_train(initial_model, batch, learning_rate):
            model_vars = tff.utils.get_variables("v", model_type)
            init_model = tff.utils.assign(model_vars, initial_model)

            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            # Load the incoming weights before taking the gradient step.
            with tf.control_dependencies([init_model]):
                train_model = optimizer.minimize(batch_loss(model_vars, batch))

            # Read the variables only after the step has been applied.
            with tf.control_dependencies([train_model]):
                return tff.utils.identity(model_vars)

        @tff.federated_computation(model_type, tf.float32, local_data_type)
        def local_train(initial_model, learning_rate, all_batches):
            @tff.federated_computation(model_type, batch_type)
            def batch_fn(model, batch):
                return batch_train(model, batch, learning_rate)

            return tff.sequence_reduce(all_batches, initial_model, batch_fn)

        @tff.federated_computation(server_model_type, server_float_type,
                                   client_data_type)
        def federated_train(model, learning_rate, data):
            return tff.federated_average(
                tff.federated_map(local_train, [
                    tff.federated_broadcast(model),
                    tff.federated_broadcast(learning_rate), data
                ]))

        computation_building_stop = time.time()
        building_time = computation_building_stop - computation_building_start
        self.report_benchmark(name="computation_building_time, FC API",
                              wall_time=building_time,
                              iters=1)

        model = initial_model
        learning_rate = 0.1

        federated_data = generate_fake_mnist_data()

        execution_array = []
        for _ in range(n_rounds):
            execution_start = time.time()
            model = federated_train(model, learning_rate, federated_data)
            execution_stop = time.time()
            execution_array.append(execution_stop - execution_start)

        self.report_benchmark(name="Average per round execution time, FC API",
                              wall_time=np.mean(execution_array),
                              iters=n_rounds,
                              extras={"std_dev": np.std(execution_array)})
Exemplo n.º 22
0
  def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
    """Wraps a Keras model so that it can be used by TFF.

    Args:
      inner_model: The `tf.keras.Model` to wrap. It is called once on
        `dummy_batch['x']` so that all of its variables get created.
      dummy_batch: A nested structure of tensors with an `'x'` entry. It is
        used to build `inner_model` and to derive the input spec, with the
        leading (batch) dimension of every tensor left unspecified.
      loss_fn: A callable `loss_fn(y_true, y_pred)`. Its value (presumably the
        mean loss over the batch) is averaged across batches, weighted by
        batch size.
      metrics: An optional list of `tf.keras.metrics.Metric`s. `None` is
        treated as an empty list.
    """
    # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
    # variables until they are called on input. We force that here.
    inner_model(dummy_batch['x'])

    def _tensor_spec_with_undefined_batch_dim(tensor):
      # Remove the batch dimension and leave it unspecified.
      spec = tf.TensorSpec(
          shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
      return spec

    self._input_spec = tf.nest.map_structure(
        _tensor_spec_with_undefined_batch_dim, dummy_batch)

    self._keras_model = inner_model
    self._loss_fn = loss_fn
    self._metrics = metrics if metrics is not None else []

    # This is defined here so that it closes over the `loss_fn`.
    class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
      """A `tf.keras.metrics.Metric` wrapper for the loss function."""

      def __init__(self, name='loss', dtype=tf.float32):
        super(_WeightedMeanLossMetric, self).__init__(name, dtype)
        self._loss_fn = loss_fn

      def update_state(self, y_true, y_pred, sample_weight=None):
        # `sample_weight` is ignored; each batch is weighted by its size.
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)

        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        batch_loss = self._loss_fn(y_true, y_pred)

        return super(_WeightedMeanLossMetric,
                     self).update_state(batch_loss, batch_size)

    self._loss_metric = _WeightedMeanLossMetric()

    metric_variable_type_dict = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, self.report_local_outputs())
    federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                     tff.CLIENTS)

    def federated_output(local_outputs):
      # Applies `federated_aggregate_keras_metric` to each metric's local
      # variables; results are keyed by metric name.
      results = collections.OrderedDict()
      for metric, variables in zip(self.get_metrics(), local_outputs):
        results[metric.name] = federated_aggregate_keras_metric(
            type(metric), metric.get_config(), variables)
      return results

    self._federated_output_computation = tff.federated_computation(
        federated_output, federated_local_outputs_type)

    # Keras creates variables that are not added to any collection, making it
    # impossible for TFF to extract them and create the appropriate initializer
    # before call a tff.Computation. Here we store them in a TFF specific
    # collection so that they can be retrieved later.
    # TODO(b/122081673): this likely goes away in TF2.0
    # `tf.compat.v1.add_to_collection` is used because the top-level
    # `tf.add_to_collection` alias does not exist in TF 2.x.
    for variable in itertools.chain(self.trainable_variables,
                                    self.non_trainable_variables,
                                    self.local_variables):
      tf.compat.v1.add_to_collection(
          graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
Exemplo n.º 23
0
    def __init__(self,
                 inner_model,
                 input_spec,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):
        """Wraps a (possibly multi-output) Keras model for use with TFF.

        Args:
          inner_model: The `tf.keras.Model` to wrap.
          input_spec: The input specification of the model, stored as-is.
          loss_fns: A non-empty sequence of callables `loss_fn(y_true, y_pred)`.
            A single loss function requires `inner_model.output` to be a
            tensor. Several loss functions require a non-tensor output; if that
            output is a list, it must have one entry per loss function.
          loss_weights: Optional sequence of weights, one per entry of
            `loss_fns`, used to combine the per-output losses. Defaults to
            `1.0` for every loss.
          metrics: Optional list of `tf.keras.metrics.Metric`s. `None` is
            treated as an empty list.

        Raises:
          ValueError: If `loss_fns` is empty, does not match the model outputs,
            or has a different length than `loss_weights`.
        """
        self._input_spec = input_spec

        if not loss_fns:
            raise ValueError(
                'Must specify at least one loss_fns, got: {l}'.format(
                    l=loss_fns))
        if ((len(loss_fns) == 1) != tf.is_tensor(inner_model.output)
                or (isinstance(inner_model.output, list)
                    and len(loss_fns) != len(inner_model.output))):
            raise ValueError(
                'Must specify the same number of loss_fns as model '
                'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                    l=loss_fns, o=inner_model.output))
        self._loss_fns = loss_fns

        if loss_weights is None:
            loss_weights = [1.0] * len(loss_fns)
        else:
            # `collections.Sequence` was removed in Python 3.10; the ABC lives
            # in `collections.abc`. Imported under an alias so the module-level
            # `collections` name is not shadowed inside this method.
            from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
            py_typecheck.check_type(loss_weights, collections_abc.Sequence)
            if len(loss_weights) != len(loss_fns):
                raise ValueError(
                    'Must specify the same number of '
                    'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                    'loss_weights: {lw}\nloss_fns: {lf}'.format(
                        lw=loss_weights,
                        llw=len(loss_weights),
                        lf=loss_fns,
                        llf=len(loss_fns)))
        self._loss_weights = loss_weights
        self._keras_model = inner_model
        self._metrics = metrics if metrics is not None else []

        # This is defined here so that it closes over `loss_fns` and
        # `loss_weights`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                # `sample_weight` is ignored; each batch is weighted by its size.
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    # Start from a zero of the metric's dtype so the sum below
                    # does not mix dtypes when the metric is not float32.
                    batch_loss = tf.zeros((), dtype=self._dtype)
                    for i, (loss_fn, loss_weight) in enumerate(
                            zip(self._loss_fns, self._loss_weights)):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += loss_weight * loss_fn(y_t, y_p)

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            # Applies `federated_aggregate_keras_metric` to each metric's local
            # variables; results are keyed by metric name.
            results = collections.OrderedDict()
            for metric, variables in zip(self.get_metrics(), local_outputs):
                results[metric.name] = federated_aggregate_keras_metric(
                    type(metric), metric.get_config(), variables)
            return results

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)

        # Keras creates variables that are not added to any collection, making it
        # impossible for TFF to extract them and create the appropriate initializer
        # before call a tff.Computation. Here we store them in a TFF specific
        # collection so that they can be retrieved later.
        # TODO(b/122081673): this likely goes away in TF2.0
        for variable in itertools.chain(self.trainable_variables,
                                        self.non_trainable_variables,
                                        self.local_variables):
            tf.compat.v1.add_to_collection(
                graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
Exemplo n.º 24
0
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary context
      object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when users
      invoke the returned TFF computation), and a `tf.data.Dataset`, evaluates
      the model on the dataset, and returns the evaluation metrics. The
      evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one client).
      Metric values at the same position, e.g., metric_1[i], metric_2[i]..., all
      come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Validate all arguments up front, before building a throwaway model and
  # tracing the client computation.
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    # The key is reserved for the output of `baseline_evaluate_fn`.
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')
  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
Exemplo n.º 25
0
  def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
    """Wraps a Keras model so that TFF can drive it.

    Args:
      inner_model: The `tf.keras.Model` to wrap; it is invoked once on the
        dummy `'x'` features so its variables are created.
      dummy_batch: A mapping (or a namedtuple, converted via `_asdict()`) of
        example inputs with an `'x'` entry. Each entry is turned into a tensor
        and used to derive the input spec with an unknown batch dimension.
      loss_fn: A callable `loss_fn(y_true, y_pred)` returning the mean loss of
        a batch; it backs the `'loss'` metric.
      metrics: Optional list of Keras metrics; `None` means no extra metrics.
    """
    # TODO(b/124477598): the following set_session() should be removed in the
    # future. This is a workaround for Keras' caching sessions in a way that
    # isn't compatible with TFF. This is already fixed in TF master, but not as
    # of v1.13.1.
    #
    # We do not use .clear_session() because it blows away the graph stack by
    # resetting the default graph.
    tf.keras.backend.set_session(None)

    batch_mapping = (dummy_batch._asdict() if hasattr(dummy_batch, '_asdict')
                     else dummy_batch)
    # Turn each entry (possibly a nested list) into a single tensor, keeping
    # the original key order.
    dummy_tensors = collections.OrderedDict(
        (key, tf.convert_to_tensor_or_sparse_tensor(value))
        for key, value in six.iteritems(batch_mapping))
    # Sub-classed Keras models only create their variables on the first call,
    # so run the model once on the dummy features.
    inner_model(dummy_tensors['x'])

    def _spec_without_batch_dim(t):
      """Returns a `tf.TensorSpec` for `t` whose leading dimension is unknown."""
      return tf.TensorSpec(shape=[None] + t.shape.dims[1:], dtype=t.dtype)

    self._input_spec = nest.map_structure(_spec_without_batch_dim,
                                          dummy_tensors)

    self._keras_model = inner_model
    self._loss_fn = loss_fn
    self._metrics = [] if metrics is None else metrics

    # Declared inside `__init__` so it can capture `loss_fn`.
    class _WeightedMeanLossMetric(keras_metrics.Metric):
      """A `tf.keras.metrics.Metric` wrapper for the loss function."""

      def __init__(self, name='loss', dtype=tf.float32):
        super(_WeightedMeanLossMetric, self).__init__(name, dtype)
        self._total_loss = self.add_weight('total_loss', initializer='zeros')
        self._total_weight = self.add_weight(
            'total_weight', initializer='zeros')
        self._loss_fn = loss_fn

      def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        # The loss function reports a per-batch mean; scale it back up by the
        # number of examples so totals can be accumulated.
        num_examples = tf.cast(tf.shape(y_pred)[0], self._dtype)
        summed_loss = self._loss_fn(y_true, y_pred) * num_examples
        add_loss = self._total_loss.assign_add(summed_loss)
        with tf.control_dependencies([add_loss]):
          return self._total_weight.assign_add(num_examples)

      def result(self):
        return tf.div_no_nan(self._total_loss, self._total_weight)

    self._loss_metric = _WeightedMeanLossMetric()

    local_outputs_spec = nest.map_structure(tf.TensorSpec.from_tensor,
                                            self.report_local_outputs())
    federated_local_outputs_type = tff.FederatedType(
        local_outputs_spec, tff.CLIENTS, all_equal=False)

    def federated_output(local_outputs):
      # One `federated_aggregate_keras_metric` result per metric, keyed by the
      # metric's name.
      return collections.OrderedDict(
          (metric.name,
           federated_aggregate_keras_metric(type(metric), metric.get_config(),
                                            variables))
          for metric, variables in zip(self.get_metrics(), local_outputs))

    self._federated_output_computation = tff.federated_computation(
        federated_output, federated_local_outputs_type)

    # Keras variables live in no collection, so TFF could not find them to
    # build an initializer before invoking a `tff.Computation`. Register them
    # in a TFF-specific collection for later retrieval.
    # TODO(b/122081673): this likely goes away in TF2.0
    variables_to_register = itertools.chain(self.trainable_variables,
                                            self.non_trainable_variables,
                                            self.local_variables)
    for var in variables_to_register:
      tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                           var)
Exemplo n.º 26
0
  Args:
    process: A measured process to validate.

  Returns:
    `True` iff the process is a validate aggregation process, otherwise `False`.
  """
  next_type = process.next.type_signature
  return (isinstance(process, tff.templates.MeasuredProcess) and
          _is_valid_stateful_process(process) and
          next_type.parameter[1].placement is tff.CLIENTS and
          next_type.result.result.placement is tff.SERVER)


# ============================================================================

# Federated type of an empty `tff.NamedTupleType` placed at `tff.SERVER`.
# NOTE(review): presumably used where a process has no state or measurements
# to report — confirm against callers.
NONE_SERVER_TYPE = tff.FederatedType(tff.NamedTupleType([]), tff.SERVER)


def _wrap_in_measured_process(
    stateful_fn: Union[tff.utils.StatefulBroadcastFn,
                       tff.utils.StatefulAggregateFn],
    input_type: tff.Type) -> tff.templates.MeasuredProcess:
  """Converts a `tff.utils.StatefulFn` to a `tff.templates.MeasuredProcess`."""
  py_typecheck.check_type(
      stateful_fn,
      (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))

  @tff.federated_computation()
  def initialize_comp():
    if not isinstance(stateful_fn.initialize, tff.Computation):
      initialize = tff.tf_computation(stateful_fn.initialize)
Exemplo n.º 27
0
def build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn,
                                        server_optimizer_fn):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a ClientDeltaFn
  that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.

  Returns:
    A `tff.utils.IterativeProcess`. Its `initialize` computation places the
    initial server state at `tff.SERVER`; its `next` computation maps
    `(server_state@SERVER, datasets@CLIENTS)` to
    `(updated server_state@SERVER, aggregated model outputs@SERVER)`.

  Raises:
    TypeError: Presumably raised by `py_typecheck.check_callable` if any
      argument is not callable.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  # TODO(b/122081673): would be nice not to have to construct a throwaway
  # model here just to get the types. After fully moving to TF2.0 and
  # eager-mode, we should re-evaluate what happens here and where `g` is used
  # below.
  with tf.Graph().as_default() as g:
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    no_arg_server_init_fn = lambda: server_init(model_fn, server_optimizer_fn)
    server_init_tf = tff.tf_computation(no_arg_server_init_fn)
    return tff.federated_value(server_init_tf(), tff.SERVER)

  # The server state type is read off the result type of the init computation.
  federated_server_state_type = server_init_tff.type_signature.result
  server_state_type = federated_server_state_type.member

  # Each client holds one dataset whose elements match the model's input spec.
  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    # Broadcast the server model and run local optimization on every client.
    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    # Mean of the client deltas, weighted by each client's
    # `weights_delta_weight`.
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
Exemplo n.º 28
0
def _build_one_round_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    broadcast_process: tff.templates.MeasuredProcess,
    aggregation_process: tff.templates.MeasuredProcess,
) -> tff.Computation:
  """Builds the `next` computation for a model delta averaging process.

  One round broadcasts the server model with `broadcast_process`, runs local
  training on each client to produce a model delta, aggregates the deltas with
  `aggregation_process`, and applies the aggregate to the server model using
  the server optimizer.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when called.
      Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    model_to_client_delta_fn: a callable that takes a single no-arg callable
      that returns `tff.learning.Model` as an argument and returns a
      `ClientDeltaFn` which performs the local training loop and model delta
      computation.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the
      global model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that runs one round of training (the `next` step of the
    process; it does not initialize it). The computation takes a tuple of
    `(ServerState@SERVER, tf.data.Dataset@CLIENTS)` argument, and returns a
    tuple of `(ServerState@SERVER, metrics@SERVER)`, where `metrics` is an
    `OrderedDict` with keys `broadcast`, `aggregation` and `train`.
  """
  # TODO(b/124477628): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  # TODO(b/144382142): Keras name uniquification is probably the main reason we
  # still need this.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    model_weights_type = tff.framework.type_from_tensors(
        model_utils.ModelWeights.from_model(dummy_model_for_metadata))

    dummy_optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(
        model=dummy_model_for_metadata, optimizer=dummy_optimizer)
    optimizer_variable_type = tff.framework.type_from_tensors(
        dummy_optimizer.variables())

  @tff.tf_computation(model_weights_type, model_weights_type.trainable,
                      optimizer_variable_type)
  def server_update(global_model, model_delta, optimizer_state):
    """Applies `model_delta` to `global_model` using the server optimizer.

    Returns the updated `(model_weights, optimizer_variables)` pair.
    """
    # Construct variables first.
    model = model_fn()
    optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(model=model, optimizer=optimizer)

    @tf.function
    def update_model_inner(weights_delta):
      """Applies the update to the global model."""
      model_variables = model_utils.ModelWeights.from_model(model)
      optimizer_variables = optimizer.variables()
      # We might have a NaN value e.g. if all of the clients processed
      # had no data, so the denominator in the federated_mean is zero.
      # If we see any NaNs, zero out the whole update.
      no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
          weights_delta)

      # TODO(b/124538167): We should increment a server counter to
      # track the fact a non-finite weights_delta was encountered.

      # Set the variables to the current global model (before update).
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            (model_variables, optimizer_variables),
                            (global_model, optimizer_state))
      # Update the variables with the delta, and return the new global model.
      _apply_delta(optimizer=optimizer, model=model, delta=no_nan_weights_delta)
      return model_variables, optimizer_variables

    return update_model_inner(model_delta)

  dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  @tff.tf_computation(dataset_type, model_weights_type)
  def _compute_local_training_and_client_delta(dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(dataset, initial_model_weights)
    return client_output

  # Unplaced state types of the two sub-processes; they become fields of the
  # server state so that each round can thread them through.
  broadcast_state = broadcast_process.initialize.type_signature.result.member
  aggregation_state = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      delta_aggregate_state=aggregation_state,
      model_broadcast_state=broadcast_state)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(dataset_type, tff.CLIENTS))
  def one_round_computation(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and an
      `OrderedDict` of `broadcast`, `aggregation` and `train` measurements
      (the latter from `tff.learning.Model.federated_output_computation`),
      both having `tff.SERVER` placement.
    """
    # Send the current server model to the clients.
    broadcast_output = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)
    # Local training on each client against the broadcast weights.
    client_outputs = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_output.result))
    # Weighted aggregation of the client deltas.
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        client_outputs.weights_delta_weight)
    # Apply the aggregated delta to the server model on the server.
    new_global_model, new_optimizer_state = tff.federated_map(
        server_update, (server_state.model, aggregation_output.result,
                        server_state.optimizer_state))
    new_server_state = tff.federated_zip(
        ServerState(new_global_model, new_optimizer_state,
                    aggregation_output.state, broadcast_output.state))
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    measurements = tff.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            aggregation=aggregation_output.measurements,
            train=aggregated_outputs))
    return new_server_state, measurements

  return one_round_computation
Exemplo n.º 29
0
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None):
    """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

    This provides the TFF orchestration logic connecting the common server logic
    which applies aggregated model deltas to the server model with a
    `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

    Note: We pass in functions rather than constructed objects so we can ensure
    any variables or ops created in constructors are placed in the correct graph.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      model_to_client_delta_fn: A function from a `model_fn` to a
        `ClientDeltaFn`.
      server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
        `apply_gradients` method of this optimizer is used to apply client
        updates to the server model.
      stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
        `next_fn` performs a federated aggregation and updates state. That is,
        it has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
        (state@SERVER, aggregate@SERVER)`, where the `value` type is
        `tff.learning.framework.ModelWeights.trainable` corresponding to the
        object returned by `model_fn`. If `None` (the default), a fresh
        `build_stateless_mean()` is used.
      stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
        `next_fn` performs a federated broadcast and updates state. That is, it
        has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
        value@CLIENTS)`, where the `value` type is
        `tff.learning.framework.ModelWeights` corresponding to the object
        returned by `model_fn`. If `None` (the default), a fresh
        `build_stateless_broadcaster()` is used.

    Returns:
      A `tff.utils.IterativeProcess`.
    """
    # Build the defaults per call instead of in the signature: a call in a
    # default argument runs once when the `def` executes (at import time), and
    # every process built afterwards would share that single object.
    if stateful_delta_aggregate_fn is None:
        stateful_delta_aggregate_fn = build_stateless_mean()
    if stateful_model_broadcast_fn is None:
        stateful_model_broadcast_fn = build_stateless_broadcaster()

    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

    # TODO(b/122081673): would be nice not to have to construct a throwaway
    # model here just to get the types. After fully moving to TF2.0 and
    # eager-mode, we should re-evaluate what happens here.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_utils.enhance(model_fn())

    # ===========================================================================
    # TensorFlow Computations

    @tff.tf_computation
    def tf_init_fn():
        return server_init(model_fn, server_optimizer_fn,
                           stateful_delta_aggregate_fn.initialize(),
                           stateful_model_broadcast_fn.initialize())

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    server_state_type = tf_init_fn.type_signature.result

    @tff.tf_computation(tf_dataset_type, server_state_type.model)
    def tf_client_delta(tf_dataset, initial_model_weights):
        """Performs client local model optimization.

        Args:
          tf_dataset: a `tf.data.Dataset` that provides training examples.
          initial_model_weights: a `model_utils.ModelWeights` containing the
            starting weights.

        Returns:
          A `ClientOutput` structure.
        """
        client_delta_fn = model_to_client_delta_fn(model_fn)
        client_output = client_delta_fn(tf_dataset, initial_model_weights)
        return client_output

    @tff.tf_computation(server_state_type, server_state_type.model.trainable,
                        server_state_type.delta_aggregate_state,
                        server_state_type.model_broadcast_state)
    def tf_server_update(server_state, model_delta, new_delta_aggregate_state,
                         new_broadcaster_state):
        """Converts args to correct python types and calls server_update_model."""
        py_typecheck.check_type(server_state, ServerState)
        server_state = ServerState(
            model=server_state.model,
            optimizer_state=list(server_state.optimizer_state),
            delta_aggregate_state=new_delta_aggregate_state,
            model_broadcast_state=new_broadcaster_state)

        return server_update_model(server_state,
                                   model_delta,
                                   model_fn=model_fn,
                                   optimizer_fn=server_optimizer_fn)

    weight_type = tf_client_delta.type_signature.result.weights_delta_weight

    @tff.tf_computation(weight_type)
    def _cast_weight_to_float(x):
        return tf.cast(x, tf.float32)

    # ===========================================================================
    # Federated Computations

    @tff.federated_computation
    def server_init_tff():
        """Orchestration logic for server model initialization."""
        return tff.federated_value(tf_init_fn(), tff.SERVER)

    federated_server_state_type = server_init_tff.type_signature.result
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round_tff(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `tff.learning.framework.ServerState` and the result
          of `tff.learning.Model.federated_output_computation`.
        """
        new_broadcaster_state, client_model = stateful_model_broadcast_fn(
            server_state.model_broadcast_state, server_state.model)

        client_outputs = tff.federated_map(tf_client_delta,
                                           (federated_dataset, client_model))

        # TODO(b/124070381): We hope to remove this explicit cast once we have a
        # full solution for type analysis in multiplications and divisions
        # inside TFF
        weight_denom = tff.federated_map(_cast_weight_to_float,
                                         client_outputs.weights_delta_weight)
        new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
            server_state.delta_aggregate_state,
            client_outputs.weights_delta,
            weight=weight_denom)

        # TODO(b/123408447): remove tff.federated_apply and call
        # tf_server_update directly once T <-> T@SERVER isomorphism is
        # supported.
        server_state = tff.federated_apply(
            tf_server_update,
            (server_state, round_model_delta, new_delta_aggregate_state,
             new_broadcaster_state))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)

        # Promote the FederatedType outside the NamedTupleType
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.utils.IterativeProcess(initialize_fn=server_init_tff,
                                      next_fn=run_one_round_tff)
Exemplo n.º 30
0
    def __init__(self,
                 inner_model,
                 input_spec,
                 loss_fns,
                 loss_weights=None,
                 metrics=None):
        """Wraps a Keras model together with its losses and metrics.

        Args:
          inner_model: the Keras model being wrapped. Its flattened `output`
            must have as many entries as `loss_fns`.
          input_spec: stored unchanged on `self._input_spec`.
          loss_fns: a non-empty, indexable sequence of loss callables, one per
            model output.
          loss_weights: optional sequence of per-loss weights with the same
            length as `loss_fns`. Defaults to `1.0` for every loss. Only used
            when there is more than one loss function.
          metrics: optional list of metrics; defaults to an empty list.

        Raises:
          ValueError: if `loss_fns` is empty, if the number of `loss_fns` does
            not match the number of model outputs, or if `loss_weights` and
            `loss_fns` differ in length.
        """
        # `collections.Sequence` was removed in Python 3.10; the ABC lives in
        # `collections.abc`. Importing it here guarantees the submodule is
        # loaded regardless of the file's top-level imports.
        import collections.abc  # pylint: disable=g-import-not-at-top

        self._input_spec = input_spec

        if not loss_fns:
            raise ValueError(
                'Must specify at least one loss_fns, got: {l}'.format(
                    l=loss_fns))
        if (len(tf.nest.flatten(loss_fns)) != len(
                tf.nest.flatten(inner_model.output))):
            raise ValueError(
                'Must specify the same number of loss_fns as model '
                'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
                    l=loss_fns, o=inner_model.output))
        self._loss_fns = loss_fns

        if loss_weights is None:
            loss_weights = [1.0] * len(loss_fns)
        else:
            py_typecheck.check_type(loss_weights, collections.abc.Sequence)
            if len(loss_weights) != len(loss_fns):
                raise ValueError(
                    'Must specify the same number of '
                    'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
                    'loss_weights: {lw}\nloss_fns: {lf}'.format(
                        lw=loss_weights,
                        llw=len(loss_weights),
                        lf=loss_fns,
                        llf=len(loss_fns)))
        self._loss_weights = loss_weights
        self._keras_model = inner_model
        self._metrics = metrics if metrics is not None else []

        # This is defined here so that it closes over the `loss_fn`.
        class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
            """A `tf.keras.metrics.Metric` wrapper for the loss function."""
            def __init__(self, name='loss', dtype=tf.float32):
                super().__init__(name, dtype)
                self._loss_fns = loss_fns
                self._loss_weights = loss_weights

            def update_state(self, y_true, y_pred, sample_weight=None):
                if len(self._loss_fns) == 1:
                    batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
                    y_true = tf.cast(y_true, self._dtype)
                    y_pred = tf.cast(y_pred, self._dtype)
                    batch_loss = self._loss_fns[0](y_true, y_pred)

                else:
                    # Weighted sum of per-output losses. Accumulate in the
                    # metric's own dtype so a non-float32 metric does not hit
                    # a dtype mismatch against a float32 zero.
                    batch_loss = tf.zeros((), dtype=self._dtype)
                    for i, loss_fn in enumerate(self._loss_fns):
                        y_t = tf.cast(y_true[i], self._dtype)
                        y_p = tf.cast(y_pred[i], self._dtype)
                        batch_loss += self._loss_weights[i] * loss_fn(y_t, y_p)

                    batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

                # Mean over batches, weighted by the batch size.
                return super().update_state(batch_loss, batch_size)

        self._loss_metric = _WeightedMeanLossMetric()

        metric_variable_type_dict = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, self.report_local_outputs())
        federated_local_outputs_type = tff.FederatedType(
            metric_variable_type_dict, tff.CLIENTS)

        def federated_output(local_outputs):
            return federated_aggregate_keras_metric(self.get_metrics(),
                                                    local_outputs)

        self._federated_output_computation = tff.federated_computation(
            federated_output, federated_local_outputs_type)