예제 #1
0
 def test_clients_placed(self):
     """Checks `sequence_map` over client-placed int32 sequences.

     The mapped result is compared against the type string
     `'{bool*}@CLIENTS'`.
     """
     int_sequence = computation_types.SequenceType(tf.int32)
     client_data = _mock_data_of_type(
         computation_types.at_clients(int_sequence))
     mapped = intrinsics.sequence_map(self.over_ten_fn(), client_data)
     self.assert_value(mapped, '{bool*}@CLIENTS')
예제 #2
0
  def create(
      self,
      value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
    """Creates a process that aggregates discretized versions of client values.

    Each round proceeds as follows:

    1. The server `step_size` is broadcast to clients.
    2. Client values are discretized with it (`_discretize_struct`).
    3. The inner aggregation process aggregates the discretized values.
    4. The aggregate is undiscretized on the server with the same step size
       (`_undiscretize_struct`).

    If a distortion aggregation factory is configured, the process also
    computes each client's mean squared error between its value and the
    discretize -> undiscretize reconstruction. That error is aggregated and
    reported under the `'distortion'` measurement.

    Args:
      value_type: A `TensorType`, or a `StructWithPythonType` containing only
        `TensorType`s. All component dtypes must be floats.

    Returns:
      An `AggregationProcess`. Its server state is an `OrderedDict` with
      `step_size` and `inner_agg_process` entries. Its measurements hold the
      inner measurements under `deterministic_discretization`, plus
      `distortion` when configured.

    Raises:
      TypeError: If `value_type` is neither a `TensorType` nor a
        `StructWithPythonType` of tensors, or if any component is not a float.
    """
    # Validate input args and value_type and parse out the TF dtypes.
    if value_type.is_tensor():
      tf_dtype = value_type.dtype
    elif (value_type.is_struct_with_python() and
          type_analysis.is_structure_of_tensors(value_type)):
      # Presumably a structure shaped like `value_type` whose leaves are the
      # tensor dtypes (built by mapping `x.dtype` over the type tree).
      tf_dtype = type_conversions.structure_from_tensor_type_tree(
          lambda x: x.dtype, value_type)
    else:
      raise TypeError('Expected `value_type` to be `TensorType` or '
                      '`StructWithPythonType` containing only `TensorType`. '
                      f'Found type: {repr(value_type)}')

    # Check that all values are floats.
    if not type_analysis.is_structure_of_floats(value_type):
      raise TypeError('Component dtypes of `value_type` must all be floats. '
                      f'Found {repr(value_type)}.')

    # Each client's distortion is a float32 scalar (see
    # `distortion_measurement_fn`), so its aggregator is built for that type.
    # `distortion_aggregation_process` is only bound here. Every later use is
    # guarded by the same `is not None` check.
    if self._distortion_aggregation_factory is not None:
      distortion_aggregation_process = self._distortion_aggregation_factory.create(
          computation_types.to_type(tf.float32))

    # Maps (value, step_size) to the discretized value. Its result type is
    # also the value type handed to the inner aggregation factory below.
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def discretize_fn(value, step_size):
      return _discretize_struct(value, step_size)

    # Inverse of `discretize_fn`. `tf_dtype` is passed along, presumably so
    # results are cast back to the original component dtypes.
    @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                           tf.float32)
    def undiscretize_fn(value, step_size):
      return _undiscretize_struct(value, step_size, tf_dtype)

    # Mean squared error between `value` and its discretize -> undiscretize
    # reconstruction, averaged over every element of every component.
    @tensorflow_computation.tf_computation(value_type, tf.float32)
    def distortion_measurement_fn(value, step_size):
      reconstructed_value = undiscretize_fn(
          discretize_fn(value, step_size), step_size)
      err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
      squared_err = tf.nest.map_structure(tf.square, err)
      # Flatten every component to 1-D float32 so they can be concatenated
      # and averaged together.
      flat_squared_errs = [
          tf.cast(tf.reshape(t, [-1]), tf.float32)
          for t in tf.nest.flatten(squared_err)
      ]
      all_squared_errs = tf.concat(flat_squared_errs, axis=0)
      mean_squared_err = tf.reduce_mean(all_squared_errs)
      return mean_squared_err

    inner_agg_process = self._inner_agg_factory.create(
        discretize_fn.type_signature.result)

    # Server state: the configured step size and the inner process state.
    @federated_computation.federated_computation()
    def init_fn():
      state = collections.OrderedDict(
          step_size=intrinsics.federated_value(self._step_size,
                                               placements.SERVER),
          inner_agg_process=inner_agg_process.initialize())
      return intrinsics.federated_zip(state)

    @federated_computation.federated_computation(
        init_fn.type_signature.result, computation_types.at_clients(value_type))
    def next_fn(state, value):
      server_step_size = state['step_size']
      client_step_size = intrinsics.federated_broadcast(server_step_size)

      # Discretize on clients, aggregate with the inner process, then
      # undiscretize the aggregate on the server with the server step size.
      discretized_value = intrinsics.federated_map(discretize_fn,
                                                   (value, client_step_size))

      inner_state = state['inner_agg_process']
      inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

      undiscretized_agg_value = intrinsics.federated_map(
          undiscretize_fn, (inner_agg_output.result, server_step_size))

      # The step size is carried over unchanged; only the inner state evolves.
      new_state = collections.OrderedDict(
          step_size=server_step_size, inner_agg_process=inner_agg_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_agg_output.measurements)

      if self._distortion_aggregation_factory is not None:
        distortions = intrinsics.federated_map(distortion_measurement_fn,
                                               (value, client_step_size))
        # NOTE(review): the distortion process is freshly initialized every
        # round, and its state is not stored in `new_state`. This presumably
        # assumes a stateless distortion aggregator; confirm.
        aggregate_distortion = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), distortions).result
        measurements['distortion'] = aggregate_distortion

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=undiscretized_agg_value,
          measurements=intrinsics.federated_zip(measurements))

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #3
0
    def create(self, value_type):
        """Creates a process that aggregates discretized client values.

        Each round proceeds as follows:

        1. The server `scale_factor` and `prior_norm_bound` are broadcast to
           clients.
        2. Client values are discretized by the function built with
           `_build_discretize_fn`, configured with `self._stochastic` and
           `self._beta`.
        3. The inner aggregation process aggregates the discretized values.
        4. The aggregate is undiscretized on the server with the server scale
           factor (`_undiscretize_struct`).

        Args:
          value_type: A `TensorType`, or a `StructWithPythonType` containing
            only `TensorType`s. All component dtypes must be floats. A struct
            is rejected when `self._prior_norm_bound` is truthy.

        Returns:
          An `AggregationProcess`. Its server state is an `OrderedDict` with
          `scale_factor`, `prior_norm_bound` and `inner_agg_process` entries.
          Its measurements hold the inner measurements under `discretize`.

        Raises:
          TypeError: If `value_type` is neither a `TensorType` nor a
            `StructWithPythonType` of tensors, if it is a struct while
            `prior_norm_bound` is set, or if any component is not a float.
        """
        # Validate input args and value_type and parse out the TF dtypes.
        if value_type.is_tensor():
            tf_dtype = value_type.dtype
        elif (value_type.is_struct_with_python()
              and type_analysis.is_structure_of_tensors(value_type)):
            # A prior norm bound is only supported for a single tensor.
            if self._prior_norm_bound:
                raise TypeError(
                    'If `prior_norm_bound` is specified, `value_type` must '
                    f'be `TensorType`. Found type: {repr(value_type)}.')
            # Presumably a structure shaped like `value_type` whose leaves are
            # the tensor dtypes (built by mapping `x.dtype` over the tree).
            tf_dtype = type_conversions.structure_from_tensor_type_tree(
                lambda x: x.dtype, value_type)
        else:
            raise TypeError(
                'Expected `value_type` to be `TensorType` or '
                '`StructWithPythonType` containing only `TensorType`. '
                f'Found type: {repr(value_type)}')

        # Check that all values are floats.
        if not type_analysis.is_structure_of_floats(value_type):
            raise TypeError(
                'Component dtypes of `value_type` must all be floats. '
                f'Found {repr(value_type)}.')

        discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                             self._beta)

        # Inverse of `discretize_fn`. `tf_dtype` is passed along, presumably
        # so results are cast back to the original component dtypes.
        @tensorflow_computation.tf_computation(
            discretize_fn.type_signature.result, tf.float32)
        def undiscretize_fn(value, scale_factor):
            return _undiscretize_struct(value, scale_factor, tf_dtype)

        # The inner aggregator operates on the discretized representation.
        inner_value_type = discretize_fn.type_signature.result
        inner_agg_process = self._inner_agg_factory.create(inner_value_type)

        @federated_computation.federated_computation()
        def init_fn():
            # NOTE(review): `self._prior_norm_bound` is turned into a server
            # value even when no bound was requested. This presumably relies
            # on `__init__` storing a numeric placeholder rather than `None`;
            # verify.
            state = collections.OrderedDict(
                scale_factor=intrinsics.federated_value(
                    self._scale_factor, placements.SERVER),
                prior_norm_bound=intrinsics.federated_value(
                    self._prior_norm_bound, placements.SERVER),
                inner_agg_process=inner_agg_process.initialize())
            return intrinsics.federated_zip(state)

        @federated_computation.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
            server_scale_factor = state['scale_factor']
            client_scale_factor = intrinsics.federated_broadcast(
                server_scale_factor)
            server_prior_norm_bound = state['prior_norm_bound']
            prior_norm_bound = intrinsics.federated_broadcast(
                server_prior_norm_bound)

            # Discretize on clients using the broadcast parameters.
            discretized_value = intrinsics.federated_map(
                discretize_fn, (value, client_scale_factor, prior_norm_bound))

            inner_state = state['inner_agg_process']
            inner_agg_output = inner_agg_process.next(inner_state,
                                                      discretized_value)

            # Undiscretize the aggregate on the server.
            undiscretized_agg_value = intrinsics.federated_map(
                undiscretize_fn,
                (inner_agg_output.result, server_scale_factor))

            # Scale factor and norm bound are carried over unchanged; only the
            # inner state evolves.
            new_state = collections.OrderedDict(
                scale_factor=server_scale_factor,
                prior_norm_bound=server_prior_norm_bound,
                inner_agg_process=inner_agg_output.state)
            measurements = collections.OrderedDict(
                discretize=inner_agg_output.measurements)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=undiscretized_agg_value,
                measurements=intrinsics.federated_zip(measurements))

        return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #4
0
def create_whimsy_intrinsic_def_federated_sum():
    """Returns `FEDERATED_SUM` paired with a concrete function type for it.

    Returns:
      A `(intrinsic_def, type_signature)` tuple. `type_signature` is the
      function type from a client-placed `tf.float32` to a server-placed
      `tf.float32`.
    """
    intrinsic = intrinsic_defs.FEDERATED_SUM
    parameter_type = computation_types.at_clients(tf.float32)
    result_type = computation_types.at_server(tf.float32)
    return intrinsic, computation_types.FunctionType(parameter_type,
                                                     result_type)
예제 #5
0
def create_whimsy_value_at_clients(number_of_clients: int = 3):
    """Returns a Python value and federated type at clients.

    Args:
      number_of_clients: Number of client values to produce.

    Returns:
      A `(value, type_signature)` tuple. `value` is the list of floats
      `10.0, 11.0, ...`, one per client. `type_signature` is `tf.float32`
      placed at clients.
    """
    first = 10
    value = list(map(float, range(first, first + number_of_clients)))
    return value, computation_types.at_clients(tf.float32)
 def test_roundtrip_with_nonempty_tuple_clients_argument(self):
   """Round-trips a 10-element int tuple typed as int32 placed at clients."""
   ints = tuple(range(10))
   clients_int_type = computation_types.at_clients(tf.int32)
   self.assertRoundTripEqual(ints, clients_int_type, ints)
예제 #7
0
def create_whimsy_intrinsic_def_federated_broadcast():
    """Returns `FEDERATED_BROADCAST` paired with a concrete function type.

    Returns:
      A `(intrinsic_def, type_signature)` tuple. `type_signature` maps a
      server-placed `tf.float32` to an all-equal, client-placed `tf.float32`.
    """
    intrinsic = intrinsic_defs.FEDERATED_BROADCAST
    server_float = computation_types.at_server(tf.float32)
    clients_float = computation_types.at_clients(tf.float32, all_equal=True)
    return intrinsic, computation_types.FunctionType(server_float,
                                                     clients_float)
예제 #8
0
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
    """Builds the TFF computation for federated evaluation of the given model.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each
        invocation; returning the same pre-constructed model on each call
        will result in an error.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        model weights on the server to the clients. It must support the
        signature `(input_values@SERVER -> output_values@CLIENTS)` and have
        empty state. If left as the default `None`, the server model is
        broadcast to the clients using the default `tff.federated_broadcast`.
      metrics_aggregator: An optional function that takes in the metric
        finalizers (i.e., `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`).
        It returns a federated TFF computation with the type signature
        `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`. If
        `None`, uses `tff.learning.metrics.sum_then_finalize`, which returns a
        federated TFF computation that sums the unfinalized metrics from
        `CLIENTS` and then applies the corresponding metric finalizers at
        `SERVER`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation.

    Returns:
      A federated computation (an instance of `tff.Computation`) that accepts
      model parameters and federated data. It returns an `OrderedDict` with a
      single `eval` entry holding the aggregated evaluation metrics.

    Raises:
      ValueError: If `broadcast_process` is given but is not a
        `MeasuredProcess`, or if it is stateful (has non-empty state).
    """
    # Validate the optional custom broadcaster before building anything.
    if broadcast_process is not None:
        if not isinstance(broadcast_process, measured_process.MeasuredProcess):
            raise ValueError(
                '`broadcast_process` must be a `MeasuredProcess`, got '
                f'{type(broadcast_process)}.')
        if iterative_process.is_stateful(broadcast_process):
            raise ValueError(
                'Cannot create a federated evaluation with a stateful '
                'broadcast process, must be stateless (have empty state), has state: '
                f'{broadcast_process.initialize.type_signature.result!r}')
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(model)
        batch_type = computation_types.to_type(model.input_spec)
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        if metrics_aggregator is not None:
            metrics_aggregation_computation = metrics_aggregator(
                model.metric_finalizers(), unfinalized_metrics_type)
        else:
            metrics_aggregation_computation = aggregator.sum_then_finalize(
                model.metric_finalizers(), unfinalized_metrics_type)

    @federated_computation.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(SequenceType(batch_type)))
    def server_eval(server_model_weights, federated_dataset):
        # Per-client evaluation: (model weights, dataset) -> local outputs.
        client_eval = build_local_evaluation(model_fn, model_weights_type,
                                             batch_type,
                                             use_experimental_simulation_loop)
        if broadcast_process is not None:
            # TODO(b/179091838): Zip the measurements from the broadcast_process with
            # the result of `model_metrics` below to avoid dropping these metrics.
            # The broadcaster is stateless (checked above), so a freshly
            # initialized state is passed every time.
            broadcast_output = broadcast_process.next(
                broadcast_process.initialize(), server_model_weights)
            client_outputs = intrinsics.federated_map(
                client_eval, (broadcast_output.result, federated_dataset))
        else:
            # Default path: plain federated broadcast of the server weights.
            client_outputs = intrinsics.federated_map(client_eval, [
                intrinsics.federated_broadcast(server_model_weights),
                federated_dataset
            ])
        # Aggregate only the clients' local (unfinalized) metric outputs.
        model_metrics = metrics_aggregation_computation(
            client_outputs.local_outputs)
        return intrinsics.federated_zip(
            collections.OrderedDict(eval=model_metrics))

    return server_eval
예제 #9
0
#
# @federated_computation
# def federated_aggregate(x, zero, accumulate, merge, report):
#   a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS)
#   b = generic_reduce(a, zero, merge, SERVER)
#   c = generic_map(report, b)
#   return c
#
# Actual implementations might vary.
#
# Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER
FEDERATED_AGGREGATE = IntrinsicDef(
    'FEDERATED_AGGREGATE',
    'federated_aggregate',
    computation_types.FunctionType(
        parameter=[
            # {T}@CLIENTS: the client-placed values to aggregate.
            computation_types.at_clients(computation_types.AbstractType('T')),
            # U: the `zero` (initial accumulator) value.
            computation_types.AbstractType('U'),
            # (<U,T> -> U): `accumulate`.
            type_factory.reduction_op(computation_types.AbstractType('U'),
                                      computation_types.AbstractType('T')),
            # (<U,U> -> U): `merge`.
            type_factory.binary_op(computation_types.AbstractType('U')),
            # (U -> R): `report`.
            computation_types.FunctionType(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('R')),
        ],
        # R@SERVER: the reported aggregate.
        result=computation_types.at_server(
            computation_types.AbstractType('R')),
    ),
    aggregation_kind=AggregationKind.DEFAULT)

# Applies a given function to a value on the server.
#
# Type signature: <(T->U),T@SERVER> -> U@SERVER
FEDERATED_APPLY = IntrinsicDef(
예제 #10
0
def build_federated_evaluation(
    model_fn: training_process.ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: training_process.LossFn,
    metrics_fn: Optional[training_process.MetricsFn] = None,
    reconstruction_optimizer_fn: training_process.OptimizerFn = functools.
    partial(tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
) -> computation_base.Computation:
    """Builds a `tff.Computation` for evaluating a reconstruction `Model`.

    The returned computation proceeds in two stages: (1) reconstruction and (2)
    evaluation. During the reconstruction stage, local variables are
    reconstructed by freezing global variables and training using
    `reconstruction_optimizer_fn`. During the evaluation stage, the
    reconstructed local variables and global variables are evaluated using the
    provided `loss_fn` and `metrics_fn`.

    Usage of returned computation:
      eval_comp = build_federated_evaluation(...)
      metrics = eval_comp(tff.learning.reconstruction.get_global_variables(model),
                          federated_data)

    Args:
      model_fn: A no-arg function that returns a
        `tff.learning.reconstruction.Model`. This method must *not* capture
        TensorFlow tensors or variables and use them. The model must be
        constructed entirely from scratch on each invocation; returning the
        same pre-constructed model on each call will result in an error.
      loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
        reconstruct and evaluate the model. The loss will be applied to the
        model's outputs during the evaluation stage. The final loss metric is
        the example-weighted mean loss across batches (and across clients).
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`s to evaluate the model. The metrics will be
        applied to the model's outputs during the evaluation stage. Final
        metric values are the example-weighted mean of metric values across
        batches (and across clients). If None, no metrics are applied.
      reconstruction_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` used to reconstruct the local
        variables with the global ones frozen.
      dataset_split_fn: A `tff.learning.reconstruction.DatasetSplitFn` taking
        in a single TF dataset and producing two TF datasets. The first is
        iterated over during reconstruction, and the second is iterated over
        during evaluation. This can be used to preprocess datasets, e.g. to
        iterate over them for multiple epochs or to use disjoint data for
        reconstruction and evaluation. If None, split client data in half for
        each user, using one half for reconstruction and the other for
        evaluation. See `tff.learning.reconstruction.build_dataset_split_fn`
        for options.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        model weights on the server to the clients. It must support the
        signature `(input_values@SERVER -> output_values@CLIENTS)` and have
        empty state. If left as the default `None`, the server model is
        broadcast to the clients using the default `tff.federated_broadcast`.

    Raises:
      TypeError: if `broadcast_process` does not have the expected signature or
        has non-empty state.

    Returns:
      A `tff.Computation` that accepts global model parameters and federated
      data. It returns a server-placed `OrderedDict` with `broadcast` (the
      broadcast process measurements) and `eval` (example-weighted evaluation
      loss and metrics).
    """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    with tf.Graph().as_default():
        model = model_fn()
        global_weights = reconstruction_utils.get_global_variables(model)
        model_weights_type = type_conversions.type_from_tensors(global_weights)
        batch_type = computation_types.to_type(model.input_spec)
        # The loss is always tracked as a metric, ahead of any user metrics.
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        federated_output_computation = (
            keras_utils.federated_output_computation_from_metrics(metrics))
        # Remove unneeded variables to avoid polluting namespace.
        del model
        del global_weights
        del metrics

    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)

    # Fall back to a stateless broadcaster. Either way, the broadcaster is
    # validated below.
    if broadcast_process is None:
        broadcast_process = optimizer_utils.build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not optimizer_utils.is_valid_broadcast_process(broadcast_process):
        raise TypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))
    if iterative_process.is_stateful(broadcast_process):
        raise TypeError(
            f'Eval broadcast_process must be stateless (have an empty '
            'state), has state '
            f'{broadcast_process.initialize.type_signature.result!r}')

    @tensorflow_computation.tf_computation(
        model_weights_type, computation_types.SequenceType(batch_type))
    def client_computation(incoming_model_weights: computation_types.Type,
                           client_dataset: computation_types.SequenceType):
        """Reconstructs and evaluates with `incoming_model_weights`."""
        # A fresh model, metric set, loss and optimizer per client invocation.
        client_model = model_fn()
        client_global_weights = reconstruction_utils.get_global_variables(
            client_model)
        client_local_weights = reconstruction_utils.get_local_variables(
            client_model)
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        client_loss = loss_fn()
        reconstruction_optimizer = reconstruction_optimizer_fn()

        @tf.function
        def reconstruction_reduce_fn(num_examples_sum, batch):
            """Runs reconstruction training on local client batch."""
            with tf.GradientTape() as tape:
                output = client_model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            # Only the trainable *local* variables are updated, so the global
            # variables stay frozen during reconstruction.
            gradients = tape.gradient(batch_loss,
                                      client_local_weights.trainable)
            reconstruction_optimizer.apply_gradients(
                zip(gradients, client_local_weights.trainable))
            return num_examples_sum + output.num_examples

        @tf.function
        def evaluation_reduce_fn(num_examples_sum, batch):
            """Runs evaluation on client batch without training."""
            output = client_model.forward_pass(batch, training=False)
            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)
            return num_examples_sum + output.num_examples

        @tf.function
        def tf_client_computation(incoming_model_weights, client_dataset):
            """Reconstructs and evaluates with `incoming_model_weights`."""
            recon_dataset, eval_dataset = dataset_split_fn(client_dataset)

            # Assign incoming global weights to `client_model` before reconstruction.
            tf.nest.map_structure(lambda v, t: v.assign(t),
                                  client_global_weights,
                                  incoming_model_weights)

            # The example counts returned by these reductions are not used;
            # the reductions run for their side effects (local-variable
            # training, then metric updates).
            recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
            eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

            eval_local_outputs = keras_utils.read_metric_variables(metrics)
            return eval_local_outputs

        return tf_client_computation(incoming_model_weights, client_dataset)

    @federated_computation.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(
            computation_types.SequenceType(batch_type)))
    def server_eval(server_model_weights: computation_types.FederatedType,
                    federated_dataset: computation_types.FederatedType):
        # The broadcaster is stateless (checked above), so a freshly
        # initialized state is passed every time.
        broadcast_output = broadcast_process.next(
            broadcast_process.initialize(), server_model_weights)
        client_outputs = intrinsics.federated_map(
            client_computation, [broadcast_output.result, federated_dataset])
        aggregated_client_outputs = federated_output_computation(
            client_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(broadcast=broadcast_output.measurements,
                                    eval=aggregated_client_outputs))
        return measurements

    return server_eval
예제 #11
0
def _build_mime_lite_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for Mime Lite.

    Each round proceeds as follows:

    1. The server optimizer state is broadcast to the clients and passed to
       the client update.
    2. The clients' full gradients are aggregated with
       `full_gradient_aggregator`, weighted by each client's `update_weight`.
    3. The aggregated gradient advances the server optimizer state for the
       next round.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each
        invocation; returning the same pre-constructed model on each call
        will result in an error.
      optimizer: A `tff.learning.optimizers.Optimizer` used both to create and
        update a global optimizer state, and for optimization at clients given
        the global state, which is fixed during the optimization.
      client_weighting: A member of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method.
      full_gradient_aggregator: An optional
        `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
        gradients on client datasets. If `None`, this is set to
        `tff.aggregators.MeanFactory`.
      metrics_aggregator: A function that takes in the metric finalizers
        (i.e., `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`).
        It returns a `tff.Computation` for aggregating the unfinalized
        metrics. If `None`, this is set to
        `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It
        is currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `ClientWorkProcess`. Its state is a server-placed pair
      `(optimizer_state, full_gradient_aggregator_state)`. Its measurements
      are an `OrderedDict` with the aggregated training metrics under `train`.

    Raises:
      The argument checks use `py_typecheck`, presumably raising `TypeError`
      for a non-callable `model_fn` or wrongly typed `optimizer`,
      `client_weighting` or `full_gradient_aggregator`.
    """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)
    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global context
        # with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)
    weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(
        weights_type)

    # NOTE: from here on, `full_gradient_aggregator` names the aggregation
    # *process* (not the factory). It aggregates gradients shaped like the
    # trainable weights, with a float32 weight.
    full_gradient_aggregator = full_gradient_aggregator.create(
        weights_type.trainable, computation_types.TensorType(tf.float32))

    @federated_computation.federated_computation
    def init_fn():
        # Optimizer state is initialized on the server for the trainable
        # weight specs, then zipped with the gradient aggregator's state.
        specs = weight_tensor_specs.trainable
        optimizer_state = intrinsics.federated_eval(
            tensorflow_computation.tf_computation(
                lambda: optimizer.initialize(specs)), placements.SERVER)
        aggregator_state = full_gradient_aggregator.initialize()
        return intrinsics.federated_zip((optimizer_state, aggregator_state))

    client_update_fn = _build_client_update_fn_for_mime_lite(
        model_fn, optimizer, client_weighting,
        use_experimental_simulation_loop)

    @tensorflow_computation.tf_computation(
        init_fn.type_signature.result.member[0], weights_type.trainable)
    def update_optimizer_state(state, aggregate_gradient):
        # `optimizer.next` also needs weights to update. Pass zeros shaped
        # like the gradient, keep only the new optimizer state, and discard
        # the updated weights.
        whimsy_weights = tf.nest.map_structure(
            lambda g: tf.zeros(g.shape, g.dtype), aggregate_gradient)
        updated_state, _ = optimizer.next(state, whimsy_weights,
                                          aggregate_gradient)
        return updated_state

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        optimizer_state, aggregator_state = state
        # Clients run with the optimizer state from the start of the round.
        optimizer_state_at_clients = intrinsics.federated_broadcast(
            optimizer_state)
        client_result, model_outputs, full_gradient = (
            intrinsics.federated_map(
                client_update_fn,
                (optimizer_state_at_clients, weights, client_data)))
        # Weighted aggregation of the full gradients, using each client's
        # update weight.
        full_gradient_agg_output = full_gradient_aggregator.next(
            aggregator_state, full_gradient, client_result.update_weight)
        # Advance the server optimizer state with the aggregated gradient.
        updated_optimizer_state = intrinsics.federated_map(
            update_optimizer_state,
            (optimizer_state, full_gradient_agg_output.result))

        new_state = intrinsics.federated_zip(
            (updated_optimizer_state, full_gradient_agg_output.state))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(new_state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
def build_scheduled_client_work(
    model_fn: Callable[[], model_lib.Model],
    learning_rate_fn: Callable[[int], float],
    optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation],
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Builds a round-scheduled `ClientWorkProcess` for federated averaging.

  The process state is an integer round counter created at the server with
  value 0. Every call to `ClientWorkProcess.next` broadcasts the counter to the
  clients and then increments it by one. On the clients, the local optimizer is
  constructed as `optimizer_fn(learning_rate_fn(round_num))`, so the learning
  rate can follow a schedule indexed by the round number.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    learning_rate_fn: A callable accepting an integer round number and returning
      a float to be used as a learning rate for the optimizer. That is, the
      client work will call `optimizer_fn(learning_rate_fn(round_num))` where
      `round_num` is the integer round number.
    optimizer_fn: A callable accepting a float learning rate, and returning a
      `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
  # Stamp out a throwaway model and optimizer inside a private graph so that the
  # variables they create do not pollute the global context. They exist only to
  # derive types and to pick the matching client-update builder.
  with tf.Graph().as_default():
    probe_model = model_fn()
    probe_optimizer = optimizer_fn(1.0)
    metrics_type = type_conversions.type_from_tensors(
        probe_model.report_local_unfinalized_metrics())
    aggregate_metrics = metrics_aggregator(probe_model.metric_finalizers(),
                                           metrics_type)
  dataset_type = computation_types.SequenceType(probe_model.input_spec)
  model_weights_type = model_utils.weights_type_from_model(probe_model)

  # TFF optimizers and Keras optimizers require different update builders.
  build_client_update_fn = (
      model_delta_client_work.build_model_delta_update_with_tff_optimizer
      if isinstance(probe_optimizer, optimizer_base.Optimizer) else
      model_delta_client_work.build_model_delta_update_with_keras_optimizer)

  @tensorflow_computation.tf_computation(model_weights_type, dataset_type,
                                         tf.int32)
  def client_update_computation(initial_model_weights, dataset, round_num):
    # The optimizer is rebuilt inside the computation so its learning rate can
    # depend on the round number received from the server.
    client_optimizer = optimizer_fn(learning_rate_fn(round_num))
    client_update = build_client_update_fn(
        model_fn=model_fn,
        weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    return client_update(client_optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    return intrinsics.federated_value(0, placements.SERVER)

  @tensorflow_computation.tf_computation(tf.int32)
  @tf.function
  def add_one(x):
    return x + 1

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(model_weights_type),
      computation_types.at_clients(dataset_type))
  def next_fn(state, weights, client_data):
    round_at_clients = intrinsics.federated_broadcast(state)
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data, round_at_clients))
    next_round = intrinsics.federated_map(add_one, state)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=aggregate_metrics(model_outputs)))
    return measured_process.MeasuredProcessOutput(next_round, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
예제 #13
0
  def create(
      self,
      metric_finalizers: model_lib.MetricFinalizersType,
      local_unfinalized_metrics_type: computation_types.StructWithPythonType,
      initial_unfinalized_metrics: Optional[OrderedDict[str, Any]] = None
  ) -> aggregation_process.AggregationProcess:
    """Creates a `tff.templates.AggregationProcess` for metrics aggregation.

    The state of the returned process is a pair. Its first element is the state
    of an inner `SumFactory` process. Its second element holds running
    accumulators of the unfinalized metrics, summed over all rounds so far.

    Each call to `next` does the following:
    - sums the clients' unfinalized metrics for the current round;
    - adds that sum into the accumulators;
    - finalizes both the round sum and the accumulators.

    The process `result` is the pair
    `(current_round_metrics, total_rounds_metrics)`. Its `measurements` are
    passed through from the inner sum process.

    Args:
      metric_finalizers: An `OrderedDict` of metric names to finalizers, should
        have same keys as the unfinalized metrics. A finalizer is a function
        (typically a `tf.function` decorated callable or a `tff.tf_computation`
        decorated TFF Computation) that takes in a metric's unfinalized values,
        and returns the finalized metric values. This can be obtained from
        `tff.learning.Model.metric_finalizers()`.
      local_unfinalized_metrics_type: A `tff.types.StructWithPythonType` (with
        `OrderedDict` as the Python container) of a client's local unfinalized
        metrics. Let `local_unfinalized_metrics` be the output of
        `tff.learning.Model.report_local_unfinalized_metrics()`, its type can be
        obtained by
        `tff.framework.type_from_tensors(local_unfinalized_metrics)`.
      initial_unfinalized_metrics: Optional. An `OrderedDict` of metric names to
        the initial values of local unfinalized metrics, its structure should
        match that of `local_unfinalized_metrics_type`. If not specified,
        defaults to zero.

    Returns:
      An instance of `tff.templates.AggregationProcess`.

    Raises:
      TypeError: If any argument type mismatches; if the metric finalizers
        mismatch the type of local unfinalized metrics; if the initial
        unfinalized metrics mismatch the type of local unfinalized metrics.
    """
    # Validate the finalizers and the metrics type, then check that they agree.
    aggregator.check_metric_finalizers(metric_finalizers)
    aggregator.check_local_unfinalzied_metrics_type(
        local_unfinalized_metrics_type)
    aggregator.check_finalizers_matches_unfinalized_metrics(
        metric_finalizers, local_unfinalized_metrics_type)

    # Inner process that sums the clients' unfinalized metrics each round.
    inner_summation_process = sum_factory_lib.SumFactory().create(
        local_unfinalized_metrics_type)

    @federated_computation.federated_computation
    def init_fn():
      # The accumulators start from `initial_unfinalized_metrics`. The helper
      # handles `None`; per the docstring, the default is zero.
      unfinalized_metrics_accumulators = (
          _intialize_unfinalized_metrics_accumulators(
              local_unfinalized_metrics_type, initial_unfinalized_metrics))
      # State layout: (inner sum state, all-rounds accumulators).
      return intrinsics.federated_zip((inner_summation_process.initialize(),
                                       unfinalized_metrics_accumulators))

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(local_unfinalized_metrics_type))
    def next_fn(state,
                unfinalized_metrics) -> measured_process.MeasuredProcessOutput:
      inner_summation_state, unfinalized_metrics_accumulators = state

      # Sum this round's client metrics.
      inner_summation_output = inner_summation_process.next(
          inner_summation_state, unfinalized_metrics)
      summed_unfinalized_metrics = inner_summation_output.result
      inner_summation_state = inner_summation_output.state

      # NOTE: the parameters of this tf_computation shadow the enclosing
      # `unfinalized_metrics` / `summed_unfinalized_metrics` names. Inside it
      # they are just the two structures added element-wise.
      @tensorflow_computation.tf_computation(local_unfinalized_metrics_type,
                                             local_unfinalized_metrics_type)
      def add_unfinalized_metrics(unfinalized_metrics,
                                  summed_unfinalized_metrics):
        return tf.nest.map_structure(tf.add, unfinalized_metrics,
                                     summed_unfinalized_metrics)

      # New running total = previous accumulators + this round's sum.
      unfinalized_metrics_accumulators = intrinsics.federated_map(
          add_unfinalized_metrics,
          (unfinalized_metrics_accumulators, summed_unfinalized_metrics))

      finalizer_computation = _build_finalizer_computation(
          metric_finalizers, local_unfinalized_metrics_type)

      # The same finalizer is applied to this round's sum and to the totals.
      current_round_metrics = intrinsics.federated_map(
          finalizer_computation, summed_unfinalized_metrics)
      total_rounds_metrics = intrinsics.federated_map(
          finalizer_computation, unfinalized_metrics_accumulators)

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(
              (inner_summation_state, unfinalized_metrics_accumulators)),
          result=intrinsics.federated_zip(
              (current_round_metrics, total_rounds_metrics)),
          measurements=inner_summation_output.measurements)

    return aggregation_process.AggregationProcess(init_fn, next_fn)
예제 #14
0
 def test_clients_placed(self):
     """`sequence_sum` over int32 sequences at CLIENTS yields int32 at CLIENTS."""
     int_sequence_type = computation_types.SequenceType(tf.int32)
     x = _mock_data_of_type(computation_types.at_clients(int_sequence_type))
     self.assert_value(intrinsics.sequence_sum(x), '{int32}@CLIENTS')
    async def compute_federated_select(
        self, arg: FederatedResolvingStrategyValue
    ) -> FederatedResolvingStrategyValue:
        """Resolves a federated select by running `select_fn` on the server.

        For every client, each of that client's keys is copied to the server
        executor. There `select_fn(server_val, key)` is evaluated, and the
        selected values are created on that client's executor as a sequence.

        Args:
          arg: A value whose internal representation is a `structure.Struct`
            of `(client_keys, max_key, server_val, select_fn)`. `client_keys`
            and `server_val` are lists; only `server_val[0]` is used. `max_key`
            is ignored.

        Returns:
          A `FederatedResolvingStrategyValue` holding one value per client,
          typed `at_clients(SequenceType(<select_fn result type>))`.

        Raises:
          TypeError: If the client key member type is not a rank-1 `tf.int32`
            tensor, or if any of the `py_typecheck` checks below fail.
        """
        client_keys_type, max_key_type, server_val_type, select_fn_type = (
            arg.type_signature)
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        client_keys, max_key, server_val, select_fn = arg.internal_representation
        # We slice up the value as-needed, so `max_key` is not used.
        del max_key, max_key_type
        del server_val_type  # unused
        py_typecheck.check_type(client_keys, list)
        py_typecheck.check_type(server_val, list)
        server_val_at_server = server_val[0]
        py_typecheck.check_type(server_val_at_server,
                                executor_value_base.ExecutorValue)
        py_typecheck.check_type(select_fn, pb.Computation)
        server = self._target_executors[placements.SERVER][0]
        clients = self._target_executors[placements.CLIENTS]
        # Type of a single key extracted from a client's key tensor.
        single_key_type = computation_types.TensorType(tf.int32)
        client_keys_type.member.check_tensor()
        if (client_keys_type.member.dtype != tf.int32
                or client_keys_type.member.shape.rank != 1):
            raise TypeError(
                f'Unexpected `client_keys_type`: {client_keys_type}')
        # Static number of keys held by each client.
        # NOTE(review): if this dimension is unknown, `.value` is presumably
        # None and `range(num_keys_per_client)` below would fail lazily.
        # Confirm that callers always supply a fully-defined key shape.
        num_keys_per_client: int = client_keys_type.member.shape.dims[0].value
        unplaced_result_type = computation_types.SequenceType(
            select_fn_type.result)
        # Embed `select_fn` and an indexing operator (keys tensor, index -> key)
        # on the server once. Both are reused for every client and every key.
        select_fn_at_server = await server.create_value(
            select_fn, select_fn_type)
        index_fn_at_server = await executor_utils.embed_indexing_operator(
            server, client_keys_type.member, single_key_type)

        async def select_single_key(keys_at_server, key_index):
            # Grab the `key_index`th key from the keys tensor.
            index_arg = await server.create_struct(
                structure.Struct([
                    (None, keys_at_server),
                    (None, await server.create_value(key_index,
                                                     single_key_type)),
                ]))
            key_at_server = await server.create_call(index_fn_at_server,
                                                     index_arg)
            # Evaluate select_fn(server_val, key) on the server.
            select_fn_arg = await server.create_struct(
                structure.Struct([
                    (None, server_val_at_server),
                    (None, key_at_server),
                ]))
            selected = await server.create_call(select_fn_at_server,
                                                select_fn_arg)
            return await selected.compute()

        async def select_single_client(client, keys_at_client):
            # Materialize this client's keys and re-create them on the server.
            keys_at_server = await server.create_value(
                await keys_at_client.compute(), client_keys_type.member)
            # Select every key of this client concurrently. `gather` keeps
            # the results in key order.
            unplaced_values = await asyncio.gather(*[
                select_single_key(keys_at_server, i)
                for i in range(num_keys_per_client)
            ])
            return await client.create_value(unplaced_values,
                                             unplaced_result_type)

        # Process all clients concurrently; results stay aligned with `clients`.
        # NOTE(review): `zip` silently truncates if `clients` and `client_keys`
        # differ in length.
        return FederatedResolvingStrategyValue(
            list(await asyncio.gather(*[
                select_single_client(client, keys_at_client)
                for client, keys_at_client in zip(clients, client_keys)
            ])), computation_types.at_clients(unplaced_result_type))
예제 #16
0
def build_model_delta_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    *,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for federated averaging.

    This client work is constructed in slightly different manners depending on
    whether `optimizer` is a `tff.learning.optimizers.Optimizer`, or a no-arg
    callable returning a `tf.keras.optimizers.Optimizer`.

    If it is a `tff.learning.optimizers.Optimizer`, we avoid creating
    `tf.Variable`s associated with the optimizer state within the scope of the
    client work, as they are not necessary. This also means that the client's
    model weights are updated by computing `optimizer.next` and then assigning
    the result to the model weights (while a `tf.keras.optimizers.Optimizer`
    will modify the model weight in place using `optimizer.apply_gradients`).

    The returned process is stateless: its state is an empty tuple at the
    server, which `next` passes through unchanged.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      optimizer: A `tff.learning.optimizers.Optimizer`, or a no-arg callable
        that returns a `tf.keras.Optimizer`.
      client_weighting:  A `tff.learning.ClientWeighting` value.
      delta_l2_regularizer: A nonnegative float representing the parameter of
        the L2-regularization term applied to the delta from initial model
        weights during training. Values larger than 0.0 prevent clients from
        moving too far from the server model during local training.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It
        is currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `ClientWorkProcess`.

    Raises:
      TypeError: If `optimizer` is neither a `tff.learning.optimizers.Optimizer`
        nor a callable. The `py_typecheck` checks on `model_fn`,
        `client_weighting` and `delta_l2_regularizer` may also raise.
      ValueError: If `delta_l2_regularizer` is negative.
    """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    py_typecheck.check_type(delta_l2_regularizer, float)
    if delta_l2_regularizer < 0.0:
        # The trailing space keeps the two literals from fusing ("...,but").
        raise ValueError('Provided delta_l2_regularizer must be non-negative, '
                         f'but found: {delta_l2_regularizer}')
    if not (isinstance(optimizer, optimizer_base.Optimizer)
            or callable(optimizer)):
        raise TypeError(
            'Provided optimizer must be either a '
            'tff.learning.optimizers.Optimizer or a no-arg callable returning '
            'a tf.keras.optimizers.Optimizer.')

    if metrics_aggregator is None:
        metrics_aggregator = aggregator.sum_then_finalize

    with tf.Graph().as_default():
        # Wrap model construction in a graph to avoid polluting the global
        # context with variables created for this model.
        model = model_fn()
        unfinalized_metrics_type = type_conversions.type_from_tensors(
            model.report_local_unfinalized_metrics())
        metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                    unfinalized_metrics_type)
    data_type = computation_types.SequenceType(model.input_spec)
    weights_type = model_utils.weights_type_from_model(model)

    if isinstance(optimizer, optimizer_base.Optimizer):

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            client_update = build_model_delta_update_with_tff_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=
                use_experimental_simulation_loop)
            return client_update(optimizer, initial_model_weights, dataset)

    else:

        @tensorflow_computation.tf_computation(weights_type, data_type)
        def client_update_computation(initial_model_weights, dataset):
            # The Keras optimizer is constructed inside the computation so its
            # variables are created in the computation's own graph.
            keras_optimizer = optimizer()
            client_update = build_model_delta_update_with_keras_optimizer(
                model_fn=model_fn,
                weighting=client_weighting,
                delta_l2_regularizer=delta_l2_regularizer,
                use_experimental_simulation_loop=
                use_experimental_simulation_loop)
            return client_update(keras_optimizer, initial_model_weights,
                                 dataset)

    @federated_computation.federated_computation
    def init_fn():
        # Empty tuple means "no state" / stateless.
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (weights, client_data))
        train_metrics = metrics_aggregation_fn(model_outputs)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(train=train_metrics))
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
예제 #17
0
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
    """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      set to default None, the server model is broadcast to the clients using
      the default tff.federated_broadcast.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`,
    together with the total number of examples (`eval` and `stat` keys).

  Raises:
    ValueError: If `broadcast_process` is given but is not a `MeasuredProcess`,
      or if it is stateful.
  """
    if broadcast_process is not None:
        # Only a stateless broadcast is supported: `server_eval` below calls
        # `broadcast_process.initialize()` afresh on every invocation and never
        # carries broadcast state across calls.
        if not isinstance(broadcast_process, measured_process.MeasuredProcess):
            raise ValueError(
                '`broadcast_process` must be a `MeasuredProcess`, got '
                f'{type(broadcast_process)}.')
        if optimizer_utils.is_stateful_process(broadcast_process):
            raise ValueError(
                'Cannot create a federated evaluation with a stateful '
                'broadcast process, must be stateless, has state: '
                f'{broadcast_process.initialize.type_signature.result!r}')
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(model)
        batch_type = computation_types.to_type(model.input_spec)

    @computations.tf_computation(model_weights_type, SequenceType(batch_type))
    @tf.function
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluating `model_weights` on `dataset`."""
        # Create the model (and its variables) outside the `tf.function` trace
        # via `init_scope`, then overwrite its weights with the incoming ones.
        with tf.init_scope():
            model = model_fn()
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                              incoming_model_weights)

        # Accumulates the number of examples seen across all batches.
        def reduce_fn(num_examples, batch):
            model_output = model.forward_pass(batch, training=False)
            if model_output.num_examples is None:
                # The model did not report a count, so use the leading
                # dimension of the predictions as the batch size.
                return num_examples + tf.shape(model_output.predictions,
                                               out_type=tf.int64)[0]
            else:
                return num_examples + tf.cast(model_output.num_examples,
                                              tf.int64)

        dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        num_examples = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
        return collections.OrderedDict(
            local_outputs=model.report_local_outputs(),
            num_examples=num_examples)

    @computations.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(SequenceType(batch_type)))
    def server_eval(server_model_weights, federated_dataset):
        if broadcast_process is not None:
            # TODO(b/179091838): Zip the measurements from the broadcast_process with
            # the result of `model.federated_output_computation` below to avoid
            # dropping these metrics.
            broadcast_output = broadcast_process.next(
                broadcast_process.initialize(), server_model_weights)
            client_outputs = intrinsics.federated_map(
                client_eval, (broadcast_output.result, federated_dataset))
        else:
            client_outputs = intrinsics.federated_map(client_eval, [
                intrinsics.federated_broadcast(server_model_weights),
                federated_dataset
            ])
        # `model` here is the throwaway model built in the graph above (not the
        # one inside `client_eval`); only its aggregation computation is used.
        model_metrics = model.federated_output_computation(
            client_outputs.local_outputs)
        statistics = collections.OrderedDict(
            num_examples=intrinsics.federated_sum(client_outputs.num_examples))
        return intrinsics.federated_zip(
            collections.OrderedDict(eval=model_metrics, stat=statistics))

    return server_eval
예제 #18
0
def build_functional_model_delta_client_work(
    *,
    model: functional.FunctionalModel,
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
) -> client_works.ClientWorkProcess:
    """Creates a `ClientWorkProcess` for federated averaging.

    This differs from `tff.learning.templates.build_model_delta_client_work` in
    that it only accepts `tff.learning.models.FunctionalModel` and
    `tff.learning.optimizers.Optimizer` type arguments, resulting in TensorFlow
    graphs that do not contain `tf.Variable` operations.

    The returned process is stateless (empty tuple at the server). Metrics are
    not yet aggregated: `measurements` is currently an empty tuple.

    Args:
      model: A `tff.learning.models.FunctionalModel` to train.
      optimizer: A `tff.learning.optimizers.Optimizer` to use for local,
        on-client optimization.
      client_weighting:  A `tff.learning.ClientWeighting` value.
      delta_l2_regularizer: A nonnegative float representing the parameter of
        the L2-regularization term applied to the delta from initial model
        weights during training. Values larger than 0.0 prevent clients from
        moving too far from the server model during local training.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
        NOTE: currently unused (see TODO below).

    Returns:
      A `ClientWorkProcess`.

    Raises:
      ValueError: If `delta_l2_regularizer` is negative. The `py_typecheck`
        checks on the other arguments may also raise.
    """
    py_typecheck.check_type(model, functional.FunctionalModel)
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)
    py_typecheck.check_type(delta_l2_regularizer, float)
    if delta_l2_regularizer < 0.0:
        # The trailing space keeps the two literals from fusing ("...,but").
        raise ValueError('Provided delta_l2_regularizer must be non-negative, '
                         f'but found: {delta_l2_regularizer}')

    if metrics_aggregator is None:
        metrics_aggregator = aggregator.sum_then_finalize

    # TODO(b/229612282): Add metrics implementation.
    # Until then `metrics_aggregator` is resolved above but not used, and the
    # model outputs are dropped in `next_fn`.

    data_type = computation_types.SequenceType(model.input_spec)

    def ndarray_to_tensorspec(ndarray):
        return tf.TensorSpec(shape=ndarray.shape,
                             dtype=tf.dtypes.as_dtype(ndarray.dtype))

    # Wrap in a `ModelWeights` structure that is required by the `finalizer.`
    weights_type = model_utils.ModelWeights(
        tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[0]),
        tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[1]))

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
        # Switch to the tuple expected by FunctionalModel.
        initial_model_weights = (initial_model_weights.trainable,
                                 initial_model_weights.non_trainable)
        client_update = build_functional_model_delta_update(
            model=model,
            weighting=client_weighting,
            delta_l2_regularizer=delta_l2_regularizer)
        return client_update(optimizer, initial_model_weights, dataset)

    @federated_computation.federated_computation
    def init_fn():
        # Empty tuple means "no state" / stateless.
        return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        computation_types.at_server(()),
        computation_types.at_clients(weights_type),
        computation_types.at_clients(data_type))
    def next_fn(state, weights, client_data):
        client_result, model_outputs = intrinsics.federated_map(
            client_update_computation, (weights, client_data))
        # TODO(b/229612282): Add metrics computations
        del model_outputs
        measurements = intrinsics.federated_value((), placements.SERVER)
        return measured_process.MeasuredProcessOutput(state, client_result,
                                                      measurements)

    return client_works.ClientWorkProcess(init_fn, next_fn)
예제 #19
0
    type_signature = computation_types.FunctionType(
        computation_types.at_clients(tf.float32),
        computation_types.at_server(tf.float32))
    return value, type_signature


def create_whimsy_intrinsic_def_federated_secure_sum_bitwidth():
    """Returns `FEDERATED_SECURE_SUM_BITWIDTH` paired with a whimsy signature.

    The signature takes an int32 placed at CLIENTS and an unplaced int32
    (presumably the bitwidth), and produces an int32 placed at SERVER.
    """
    parameter_types = [
        computation_types.at_clients(tf.int32),
        tf.int32,
    ]
    result_type = computation_types.at_server(tf.int32)
    signature = computation_types.FunctionType(parameter_types, result_type)
    return intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH, signature


# Whimsy (placeholder) types describing a `federated_select`-style signature.
# Each client holds a rank-1 int32 tensor of 3 keys.
_WHIMSY_SELECT_CLIENT_KEYS_TYPE = computation_types.at_clients(
    computation_types.TensorType(tf.int32, [3]))
# Server-placed int32 "max key" (presumably an upper bound on the key values).
_WHIMSY_SELECT_MAX_KEY_TYPE = computation_types.at_server(tf.int32)
# Server-placed value that keys select from.
_WHIMSY_SELECT_SERVER_STATE_TYPE = computation_types.at_server(tf.string)
# Type produced by the select function for a single key.
_WHIMSY_SELECTED_TYPE = computation_types.to_type((tf.string, tf.int32))
# Select function: (server value, key) -> selected value.
_WHIMSY_SELECT_SELECT_FN_TYPE = computation_types.FunctionType(
    (tf.string, tf.int32), _WHIMSY_SELECTED_TYPE)
# Each client receives a sequence of selected values.
_WHIMSY_SELECT_RESULT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(_WHIMSY_SELECTED_TYPE))
# Full signature: (keys@CLIENTS, max_key@SERVER, value@SERVER, select_fn)
# -> sequence of selected values @CLIENTS.
_WHIMSY_SELECT_TYPE = computation_types.FunctionType([
    _WHIMSY_SELECT_CLIENT_KEYS_TYPE,
    _WHIMSY_SELECT_MAX_KEY_TYPE,
    _WHIMSY_SELECT_SERVER_STATE_TYPE,
    _WHIMSY_SELECT_SELECT_FN_TYPE,
], _WHIMSY_SELECT_RESULT_TYPE)
# Number of clients to pair with the types above (independent of the 3 keys
# per client in `_WHIMSY_SELECT_CLIENT_KEYS_TYPE`).
_WHIMSY_SELECT_NUM_CLIENTS = 3
예제 #20
0
    def test_type_properties(self, value_type, mechanism):
        """Checks the factory's type signatures and secure-aggregation property.

        The expected signatures come from building the equivalent chain of
        component factories by hand and comparing it with the process created
        by `_make_test_factory`. From outermost to innermost, the chain is:
        concat, Hadamard rotation, L2 clipping, discretization, DP, secure
        modular sum.
        """
        ddp_factory = _make_test_factory(mechanism=mechanism)
        self.assertIsInstance(ddp_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = ddp_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The state is a nested object with component factory states. Construct
        # test factories directly and compare the signatures.
        modsum_f = secure.SecureModularSumFactory(2**15, True)

        if mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=10.0, local_stddev=10.0)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(l1_norm_bound=10.0,
                                                      l2_norm_bound=10.0,
                                                      local_stddev=10.0)

        dp_f = differential_privacy.DifferentiallyPrivateFactory(
            dp_query, modsum_f)
        discrete_f = discretization.DiscretizationFactory(dp_f)
        l2clip_f = robust.clipping_factory(clipping_norm=10.0,
                                           inner_agg_factory=discrete_f)
        rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
        expected_process = concat.concat_factory(rot_f).create(value_type)

        # Check init_fn/state.
        expected_init_type = expected_process.initialize.type_signature
        expected_state_type = expected_init_type.result
        actual_init_type = process.initialize.type_signature
        self.assertTrue(
            actual_init_type.is_equivalent_to(expected_init_type),
            msg=(f'Expected initialize type {expected_init_type}, '
                 f'found {actual_init_type}.'))

        # Check next_fn/measurements.
        tensor2type = type_conversions.type_from_tensors
        discrete_state = discrete_f.create(
            computation_types.to_type(tf.float32)).initialize()
        dp_query_state = dp_query.initial_global_state()
        dp_query_metrics_type = tensor2type(
            dp_query.derive_metrics(dp_query_state))
        expected_measurements_type = collections.OrderedDict(
            l2_clip=robust.NORM_TF_TYPE,
            scale_factor=tensor2type(discrete_state['scale_factor']),
            scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
            scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
            actual_num_clients=tf.int32,
            padded_dim=tf.int32,
            dp_query_metrics=dp_query_metrics_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=computation_types.at_server(
                    expected_measurements_type)))
        actual_next_type = process.next.type_signature
        self.assertTrue(
            actual_next_type.is_equivalent_to(expected_next_type),
            msg=(f'Expected next type {expected_next_type}, '
                 f'found {actual_next_type}.'))
        # Catch `Exception` rather than everything, so that KeyboardInterrupt /
        # SystemExit are not turned into a misleading test failure.
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except Exception:  # pylint: disable=broad-except
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
예제 #21
0
def create_whimsy_intrinsic_def_federated_eval_at_clients():
    """Returns the FEDERATED_EVAL_AT_CLIENTS intrinsic and a matching type.

    The type maps a no-argument function returning `tf.float32` to a
    `tf.float32` placed at CLIENTS.
    """
    zero_arg_fn_type = computation_types.FunctionType(None, tf.float32)
    clients_float_type = computation_types.at_clients(tf.float32)
    signature = computation_types.FunctionType(zero_arg_fn_type,
                                               clients_float_type)
    return intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS, signature
예제 #22
0

def _clipped_sum(clip=2.0):
    """Wraps a `SumFactory` in `robust.clipping_factory` with bound `clip`."""
    inner_factory = sum_factory.SumFactory()
    return robust.clipping_factory(clip, inner_factory)


def _zeroed_mean(clip=2.0, norm_order=2.0):
    """Wraps a `MeanFactory` in `robust.zeroing_factory`.

    `clip` and `norm_order` are forwarded positionally to `zeroing_factory`.
    """
    inner_factory = mean.MeanFactory()
    return robust.zeroing_factory(clip, inner_factory, norm_order)


def _zeroed_sum(clip=2.0, norm_order=2.0):
    """Wraps a `SumFactory` in `robust.zeroing_factory`.

    `clip` and `norm_order` are forwarded positionally to `zeroing_factory`.
    """
    inner_factory = sum_factory.SumFactory()
    return robust.zeroing_factory(clip, inner_factory, norm_order)


# Federated `tf.float32` types used as the state (SERVER) and value (CLIENTS)
# types of the test computations below.
_float_at_server = computation_types.at_server(tf.float32)
_float_at_clients = computation_types.at_clients(tf.float32)


@computations.federated_computation()
def _test_init_fn():
    """Initial state: the float `1.0` placed at SERVER."""
    initial_value = 1.0
    return intrinsics.federated_value(initial_value, placements.SERVER)


@computations.federated_computation(_float_at_server, _float_at_clients)
def _test_next_fn(state, value):
    """Ignores the client `value` and returns `state + 1.0` at SERVER."""
    del value  # Unused; present only to fix the computation's signature.
    increment = computations.tf_computation(lambda x: x + 1., tf.float32)
    return intrinsics.federated_map(increment, state)


@computations.federated_computation(_float_at_server)
예제 #23
0
def create_whimsy_intrinsic_def_federated_value_at_clients():
    """Returns the FEDERATED_VALUE_AT_CLIENTS intrinsic and a matching type.

    The type maps a `tf.float32` to an all-equal `tf.float32` at CLIENTS.
    """
    all_equal_clients_float = computation_types.at_clients(
        tf.float32, all_equal=True)
    signature = computation_types.FunctionType(tf.float32,
                                               all_equal_clients_float)
    return intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS, signature
예제 #24
0
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.templates import client_works

# Shorthand federated types used by the helpers and computations in this
# example.
SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER)
SERVER_FLOAT = computation_types.FederatedType(tf.float32, placements.SERVER)
CLIENTS_FLOAT_SEQUENCE = computation_types.FederatedType(
    computation_types.SequenceType(tf.float32), placements.CLIENTS)
CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS)
CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS)
# `ModelWeights` with a single float32 trainable value and empty non-trainable
# weights, placed at CLIENTS.
MODEL_WEIGHTS_TYPE = computation_types.at_clients(
    computation_types.to_type(model_utils.ModelWeights(tf.float32, ())))
# Short alias for the measured-process output container.
MeasuredProcessOutput = measured_process.MeasuredProcessOutput


def server_zero():
    """Places the integer `0` at SERVER."""
    zero = 0
    return intrinsics.federated_value(zero, placements.SERVER)


def client_one():
    """Places the float `1.0` at CLIENTS."""
    one = 1.0
    return intrinsics.federated_value(one, placements.CLIENTS)


def federated_add(a, b):
    """Adds `a` and `b` by mapping a TF addition over the pair `(a, b)`."""
    add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
    operands = (a, b)
    return intrinsics.federated_map(add_fn, operands)
예제 #25
0
def create_whimsy_value_at_clients_all_equal():
    """Returns the value `10.0` together with its federated type.

    The type is `tf.float32` placed at CLIENTS with `all_equal=True`.
    """
    federated_type = computation_types.at_clients(tf.float32, all_equal=True)
    return 10.0, federated_type
예제 #26
0
def build_model_delta_client_work(model_fn: Callable[[], model_lib.Model],
                                  optimizer: optimizer_base.Optimizer):
  """Builds `ClientWorkProcess` returning change to the trained model weights.

  The created `ClientWorkProcess` expects model weights that can be assigned to
  the model created by `model_fn`, and will apply `optimizer` to optimize the
  model using the client data. The returned `ClientResult` will contain the
  difference between the initial and trained trainable model weights (aka
  "model delta", computed as `initial - trained`) as update, and the
  update_weight will be the number of examples used in training, cast to
  `tf.float32`. The type signature for client data is derived from the input
  spec of the model.

  This method is the recommended starting point for forking a custom
  implementation of the `ClientWorkProcess`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    optimizer: A `tff.learning.optimizers.Optimizer`.

  Returns:
    A `ClientWorkProcess` whose state is an empty tuple at SERVER and whose
    measurements are an empty tuple at SERVER.
  """
  py_typecheck.check_callable(model_fn)
  # TODO(b/190334722): Include support for Keras optimizers via
  # tff.learning.optimizers.KerasOptimizer when ready.
  py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  weights_type, data_type = _weights_and_data_type_from_model_fn(model_fn)
  # TODO(b/161529310): We flatten and convert the trainable specs to tuple, as
  # "for batch in data:" pattern would try to stack the tensors in a list.
  optimizer_tensor_specs = _flat_tuple(
      type_conversions.type_to_tf_tensor_specs(weights_type.trainable))

  # Client-side training: loads `initial_weights` into a fresh model, applies
  # `optimizer` once per batch of `data`, and returns the weight change plus
  # the example count.
  @computations.tf_computation(weights_type, data_type)
  @tf.function
  def local_update(initial_weights, data):
    # TODO(b/190334722): Restructure so that model_fn only needs to be invoked
    # once.
    # `tf.init_scope` lifts model (variable) creation out of the tf.function
    # trace.
    with tf.init_scope():
      model = model_fn()
    model_weights = model_utils.ModelWeights.from_model(model)

    # Overwrite the freshly created variables with the incoming weights.
    tf.nest.map_structure(lambda weight, value: weight.assign(value),
                          model_weights, initial_weights)
    num_examples = tf.constant(0, tf.int32)
    optimizer_state = optimizer.initialize(optimizer_tensor_specs)

    # TODO(b/161529310): Different from creating an iterator using iter(data).
    for batch in data:
      with tf.GradientTape() as tape:
        outputs = model.forward_pass(batch)
      gradients = tape.gradient(outputs.loss, model_weights.trainable)
      # Examples are counted from the leading dimension of the predictions.
      num_examples += tf.shape(outputs.predictions)[0]

      # The optimizer operates on flat tuples; repack its output into the
      # original trainable-weights structure before assigning.
      optimizer_state, updated_weights = optimizer.next(
          optimizer_state, _flat_tuple(model_weights.trainable),
          _flat_tuple(gradients))
      updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                 updated_weights)
      tf.nest.map_structure(lambda weight, value: weight.assign(value),
                            model_weights.trainable, updated_weights)

    # Delta is `initial - trained` (note the operand order).
    model_delta = tf.nest.map_structure(lambda x, y: x - y,
                                        initial_weights.trainable,
                                        model_weights.trainable)
    return ClientResult(
        update=model_delta, update_weight=tf.cast(num_examples, tf.float32))

  # The process keeps no server state: it is an empty tuple at SERVER.
  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value((), placements.SERVER)

  # Runs `local_update` on every client; `state` is passed through unchanged
  # and no measurements are reported.
  @computations.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    client_result = intrinsics.federated_map(local_update,
                                             (weights, client_data))
    empty_measurements = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, client_result,
                                                  empty_measurements)

  return ClientWorkProcess(init_fn, next_fn)
예제 #27
0
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.optimizers import sgdm
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import finalizers
from tensorflow_federated.python.learning.templates import learning_process

# Scalar float32 tensor type shared by the constants below.
FLOAT_TYPE = computation_types.TensorType(tf.float32)
# `ModelWeights` with a single float32 trainable value and empty non-trainable
# weights (unplaced).
MODEL_WEIGHTS_TYPE = computation_types.to_type(
    model_utils.ModelWeights(FLOAT_TYPE, ()))
# A sequence of float32 values placed at CLIENTS.
CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(FLOAT_TYPE))


def empty_at_server():
    """Places an empty tuple at SERVER."""
    empty = ()
    return intrinsics.federated_value(empty, placements.SERVER)


@federated_computation.federated_computation()
def empty_init_fn():
    """Initial state: an empty tuple placed at SERVER."""
    initial_state = empty_at_server()
    return initial_state


@tensorflow_computation.tf_computation()
def test_init_model_weights_fn():
    """Builds `ModelWeights` with trainable `1.0` and empty non-trainables."""
    trainable = tf.constant(1.0)
    return model_utils.ModelWeights(trainable=trainable, non_trainable=())
예제 #28
0
@computations.federated_computation()
def test_init_fn():
  """Initial state: the integer `0` placed at SERVER."""
  return intrinsics.federated_value(0, placements.SERVER)


# Server-state type produced by `test_init_fn`. This requires `test_init_fn`
# to be a computation (hence the decorator above); a plain Python function has
# no `type_signature` attribute.
test_state_type = test_init_fn.type_signature.result


@computations.tf_computation
def sum_sequence(s):
  """Returns the sum of all elements of the dataset `s`.

  The accumulator starts as zeros matching `s.element_spec`. Elements may be a
  single tensor or a nested structure of tensors; each leaf is summed
  independently.
  """
  spec = s.element_spec
  # Build zeros leaf-by-leaf so nested element specs work as well as a single
  # TensorSpec; the reducer below already maps over the structure.
  initial = tf.nest.map_structure(
      lambda leaf: tf.zeros(leaf.shape, leaf.dtype), spec)
  return s.reduce(
      initial,
      lambda total, element: tf.nest.map_structure(tf.add, total, element))


# A sequence of int32 values placed at CLIENTS; the client input type of the
# `next_fn` built by `build_next_fn`.
ClientIntSequenceType = computation_types.at_clients(
    computation_types.SequenceType(tf.int32))


def build_next_fn(server_init_fn):
  """Builds a `next_fn` whose state type matches `server_init_fn`'s result.

  The returned computation sums each client's int32 sequence, sums those
  totals across clients, and returns them as metrics with `state` unchanged.
  """
  state_type = server_init_fn.type_signature.result

  @computations.federated_computation(state_type, ClientIntSequenceType)
  def next_fn(state, client_values):
    per_client_sums = intrinsics.federated_map(sum_sequence, client_values)
    total = intrinsics.federated_sum(per_client_sums)
    return LearningProcessOutput(state, total)

  return next_fn


def build_report_fn(server_init_fn):
예제 #29
0
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue,
    local_computation_factory: local_computation_factory_base.
    LocalComputationFactory = tensorflow_computation_factory.
    TensorFlowComputationFactory()
) -> executor_value_base.ExecutorValue:
    """Computes a federated weighted mean on the given `executor`.

    The mean is computed as `sum(value * weight) / sum(weight)` by composing
    federated intrinsics:
      1. the client values and weights are zipped at CLIENTS;
      2. each value is multiplied by its weight via `federated_map`;
      3. the products and the weights are each summed to SERVER;
      4. the two sums are zipped at SERVER and divided via `federated_apply`.

    Args:
      executor: The executor to use.
      arg: The argument embedded in `executor`: a two-element structure of
        federated values at CLIENTS, where element 0 holds the values to
        average and element 1 holds their weights.
      local_computation_factory: An instance of `LocalComputationFactory` to
        use to construct local computations used as parameters in certain
        federated operators (such as `tff.federated_sum`, etc.). Defaults to a
        TensorFlow computation factory that generates TensorFlow code.

    Returns:
      The result embedded in `executor`, placed at SERVER.
    """
    # NOTE(review): the default `local_computation_factory` is constructed once
    # at definition time and shared by every call that omits it; this is fine
    # only if the factory is stateless — confirm.
    type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
        arg.type_signature)
    # FEDERATED_ZIP_AT_CLIENTS: <values@CLIENTS, weights@CLIENTS>
    #   -> <value, weight>@CLIENTS.
    zip1_type = computation_types.FunctionType(
        computation_types.StructType([
            computation_types.at_clients(arg.type_signature[0].member),
            computation_types.at_clients(arg.type_signature[1].member)
        ]),
        computation_types.at_clients(
            computation_types.StructType(
                [arg.type_signature[0].member, arg.type_signature[1].member])))

    # Local computation that multiplies a value by its scalar weight.
    operand_type = zip1_type.result.member[0]
    scalar_type = zip1_type.result.member[1]
    multiply_comp_pb, multiply_comp_type = local_computation_factory.create_scalar_multiply_operator(
        operand_type, scalar_type)
    multiply_blk = building_blocks.CompiledComputation(
        multiply_comp_pb, type_signature=multiply_comp_type)
    # FEDERATED_MAP of the multiply computation over the zipped pairs.
    map_type = computation_types.FunctionType(
        computation_types.StructType(
            [multiply_blk.type_signature, zip1_type.result]),
        computation_types.at_clients(multiply_blk.type_signature.result))

    # FEDERATED_SUM of the weighted values (numerator).
    sum1_type = computation_types.FunctionType(
        computation_types.at_clients(map_type.result.member),
        computation_types.at_server(map_type.result.member))

    # FEDERATED_SUM of the weights (denominator).
    sum2_type = computation_types.FunctionType(
        computation_types.at_clients(arg.type_signature[1].member),
        computation_types.at_server(arg.type_signature[1].member))

    # FEDERATED_ZIP_AT_SERVER of the numerator and denominator.
    zip2_type = computation_types.FunctionType(
        computation_types.StructType([sum1_type.result, sum2_type.result]),
        computation_types.at_server(
            computation_types.StructType(
                [sum1_type.result.member, sum2_type.result.member])))

    # TensorFlow binary `tf.divide` built with
    # `create_tensorflow_binary_operator_with_upcast` (numerator / denominator).
    divide_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
        zip2_type.result.member, tf.divide)

    # The coroutines below assemble the call graph bottom-up; independent
    # pieces are awaited concurrently with `asyncio.gather`.

    async def _compute_multiply_fn():
        return await executor.create_value(multiply_blk.proto,
                                           multiply_blk.type_signature)

    async def _compute_multiply_arg():
        zip1_comp = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip1_type)
        zip_fn = await executor.create_value(zip1_comp, zip1_type)
        return await executor.create_call(zip_fn, arg)

    async def _compute_product_fn():
        map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP,
                                         map_type)
        return await executor.create_value(map_comp, map_type)

    async def _compute_product_arg():
        multiply_fn, multiply_arg = await asyncio.gather(
            _compute_multiply_fn(), _compute_multiply_arg())
        return await executor.create_struct((multiply_fn, multiply_arg))

    async def _compute_products():
        product_fn, product_arg = await asyncio.gather(_compute_product_fn(),
                                                       _compute_product_arg())
        return await executor.create_call(product_fn, product_arg)

    async def _compute_total_weight():
        # Selects element 1 (the weights) of `arg` and sums it to SERVER.
        sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                          sum2_type)
        sum2_fn, sum2_arg = await asyncio.gather(
            executor.create_value(sum2_comp, sum2_type),
            executor.create_selection(arg, 1))
        return await executor.create_call(sum2_fn, sum2_arg)

    async def _compute_sum_of_products():
        sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                          sum1_type)
        sum1_fn, products = await asyncio.gather(
            executor.create_value(sum1_comp, sum1_type), _compute_products())
        return await executor.create_call(sum1_fn, products)

    async def _compute_zip2_fn():
        zip2_comp = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_SERVER, zip2_type)
        return await executor.create_value(zip2_comp, zip2_type)

    async def _compute_zip2_arg():
        sum_of_products, total_weight = await asyncio.gather(
            _compute_sum_of_products(), _compute_total_weight())
        return await executor.create_struct([sum_of_products, total_weight])

    async def _compute_divide_fn():
        return await executor.create_value(divide_blk.proto,
                                           divide_blk.type_signature)

    async def _compute_divide_arg():
        zip_fn, zip_arg = await asyncio.gather(_compute_zip2_fn(),
                                               _compute_zip2_arg())
        return await executor.create_call(zip_fn, zip_arg)

    async def _compute_apply_fn():
        # FEDERATED_APPLY of the divide computation on the zipped sums.
        apply_type = computation_types.FunctionType(
            computation_types.StructType(
                [divide_blk.type_signature, zip2_type.result]),
            computation_types.at_server(divide_blk.type_signature.result))
        apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                           apply_type)
        return await executor.create_value(apply_comp, apply_type)

    async def _compute_apply_arg():
        divide_fn, divide_arg = await asyncio.gather(_compute_divide_fn(),
                                                     _compute_divide_arg())
        return await executor.create_struct([divide_fn, divide_arg])

    async def _compute_divided():
        apply_fn, apply_arg = await asyncio.gather(_compute_apply_fn(),
                                                   _compute_apply_arg())
        return await executor.create_call(apply_fn, apply_arg)

    return await _compute_divided()
예제 #30
0
 def test_errors_on_client_int(self):
     """`federated_broadcast` raises TypeError for an all-equal int at CLIENTS."""
     # Build the input outside `assertRaises` so that a TypeError raised while
     # constructing the mock value cannot make this test pass spuriously.
     x = _mock_data_of_type(
         computation_types.at_clients(tf.int32, all_equal=True))
     with self.assertRaises(TypeError):
         intrinsics.federated_broadcast(x)