Example #1
    def test_bad_type_coercion_raises(self):
        tensor_type = tff.TensorType(shape=[None], dtype=tf.float32)

        @tff.tf_computation(tensor_type)
        def foo(x):
            # We will pass in a tensor which passes the TFF type check, but fails the
            # reshape.
            return tf.reshape(x, [])

        @tff.federated_computation(tff.type_at_clients(tensor_type))
        def map_foo_at_clients(x):
            return tff.federated_map(foo, x)

        @tff.federated_computation(tff.type_at_server(tensor_type))
        def map_foo_at_server(x):
            return tff.federated_map(foo, x)

        bad_tensor = tf.constant([1.] * 10, dtype=tf.float32)
        good_tensor = tf.constant([1.], dtype=tf.float32)
        # Ensure running this computation at both placements, or unplaced, still
        # raises.
        with self.assertRaises(Exception):
            foo(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_server(bad_tensor)
        with self.assertRaises(Exception):
            map_foo_at_clients([bad_tensor] * 10)
        # We give the distributed runtime a chance to clean itself up, otherwise
        # workers may be getting SIGABRT while they are handling another exception,
        # causing the test infra to crash. Making a successful call ensures that
        # cleanup happens after failures have been handled.
        map_foo_at_clients([good_tensor] * 10)
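
For reference, a sketch of the type signatures the three computations above should report (assuming TFF's printed type notation; exact strings may vary by version):

    print(foo.type_signature)                 # (float32[?] -> float32)
    print(map_foo_at_clients.type_signature)  # ({float32[?]}@CLIENTS -> {float32}@CLIENTS)
    print(map_foo_at_server.type_signature)   # (float32[?]@SERVER -> float32@SERVER)

The TFF type check only compares against `float32[?]`, so the length-10 `bad_tensor` is accepted at the computation boundary and the failure only surfaces inside `tf.reshape` at runtime.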
Example #2
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.learning.framework.weights_type_from_model(model)

    validate_client_tf = tff.tf_computation(
        lambda dataset, state, weights: __validate_client(
            dataset, state, weights, coefficient_fn, model_fn,
            tf.function(client.validate)),
        (dataset_type, client_state_type, weights_type))

    federated_weights_type = tff.type_at_server(weights_type)
    federated_dataset_type = tff.type_at_clients(dataset_type)
    federated_client_state_type = tff.type_at_clients(client_state_type)

    def validate(weights, datasets, client_states):
        broadcast = tff.federated_broadcast(weights)
        outputs = tff.federated_map(validate_client_tf,
                                    (datasets, client_states, broadcast))
        metrics = model.federated_output_computation(outputs.metrics)

        return metrics

    return tff.federated_computation(
        validate, (federated_weights_type, federated_dataset_type,
                   federated_client_state_type))
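
A hypothetical invocation sketch; `server_weights`, `client_datasets`, and `client_states` below are illustrative placeholders, not names from the original code:

    evaluate = validator(coefficient_fn, model_fn, client_state_fn)
    # Server weights are broadcast, each client validates locally, and the
    # per-client metrics are aggregated back at the server.
    metrics = evaluate(server_weights, client_datasets, client_states)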
Example #3
    def __attrs_post_init__(self):
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types
        self._generator = self.generator_model_fn()
        _ = self._generator(self.dummy_gen_input)
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type,
            meta_gen=self.generator_weights_type,
            meta_disc=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        if self.train_discriminator_dp_average_query is not None:
            self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
                query=self.train_discriminator_dp_average_query).create(
                    value_type=tff.to_type(self.discriminator_weights_type))
        else:
            self.aggregation_process = tff.aggregators.MeanFactory().create(
                value_type=tff.to_type(self.discriminator_weights_type),
                weight_type=tff.to_type(tf.float32))
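
Either branch yields a `tff.templates.AggregationProcess`. A minimal sketch of driving it, assuming the standard `initialize`/`next` process API; `client_discriminator_deltas` and `client_weights` are placeholder values:

    state = aggregation_process.initialize()
    # DP branch is unweighted, so next takes (state, value):
    # output = aggregation_process.next(state, client_discriminator_deltas)
    # Mean branch is weighted, so next takes (state, value, weight):
    # output = aggregation_process.next(state, client_discriminator_deltas,
    #                                   client_weights)
    # output.state is the updated state; output.result is the aggregate.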
Example #4
    def __attrs_post_init__(self):
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types
        self._generator = self.generator_model_fn()
        _ = self._generator(self.dummy_gen_input)
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
                value_type=tff.to_type(self.discriminator_weights_type),
                query=self.train_discriminator_dp_average_query)
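
A sketch of the alternative formulation mentioned in the comment above, assuming the `tff.aggregators` factory API: always construct a stateful aggregation process, and default to a plain mean when DP is off (`build_aggregator` is a hypothetical helper, not part of the original code):

    def build_aggregator(weights_type, dp_query=None):
        # Always return a tff.templates.AggregationProcess; without a DP
        # query, fall back to an unweighted mean.
        if dp_query is not None:
            factory = tff.aggregators.DifferentiallyPrivateFactory(query=dp_query)
        else:
            factory = tff.aggregators.UnweightedMeanFactory()
        return factory.create(value_type=tff.to_type(weights_type))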
Example #5
def iterator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN,
  client_optimizer_fn: OPTIMIZER_FN
):
  model = model_fn()
  client_state = client_state_fn()

  init_tf = tff.tf_computation(
    lambda: ()
  )
  
  server_state_type = init_tf.type_signature.result
  client_state_type = tff.framework.type_from_tensors(client_state)
  dataset_type = tff.SequenceType(model.input_spec)
  
  update_client_tf = tff.tf_computation(
    lambda dataset, state: __update_client(
      dataset,
      state,
      model_fn,
      client_optimizer_fn,
      tf.function(client.update)
    ),
    (dataset_type, client_state_type)
  )
  
  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def init_tff():
    return tff.federated_value(init_tf(), tff.SERVER)
  
  def next_tff(server_state, datasets, client_states):
    outputs = tff.federated_map(update_client_tf, (datasets, client_states))
    metrics = model.federated_output_computation(outputs.metrics)

    return server_state, metrics, outputs.client_state

  return tff.templates.IterativeProcess(
    initialize_fn=tff.federated_computation(init_tff),
    next_fn=tff.federated_computation(
      next_tff,
      (federated_server_state_type, federated_dataset_type, federated_client_state_type)
    )
  )
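
An illustrative driver loop for the returned process; `datasets`, `client_states`, and `NUM_ROUNDS` are placeholders supplied by the caller:

    process = iterator(model_fn, client_state_fn, client_optimizer_fn)
    state = process.initialize()
    for _ in range(NUM_ROUNDS):
      # next returns the (unchanged) server state, aggregated metrics, and
      # the updated per-client states to feed into the next round.
      state, metrics, client_states = process.next(state, datasets, client_states)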
Example #6
def build_federated_reconstruction_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    evaluate_reconstruction: bool = False,
    jointly_train_variables: bool = False,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_factory: Optional[
        tff.aggregators.WeightedAggregationFactory] = None,
) -> tff.templates.IterativeProcess:
    """Builds the IterativeProcess for optimization using FedRecon.

  Returns a `tff.templates.IterativeProcess` for Federated Reconstruction. On
  the client, computation can be divided into two stages: (1) reconstruction of
  local variables and (2) training of global variables (possibly jointly with
  reconstructed local variables).

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      compute local model updates during reconstruction and post-reconstruction
      and evaluate the model during training. The final loss metric is the
      example-weighted mean loss across batches and across clients. Depending on
      whether `evaluate_reconstruction` is True, the loss metric may or may not
      include reconstruction batches in the loss.
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. Metrics results are computed locally as described
      by the metric, and are aggregated across clients as in
      `federated_aggregate_keras_metric`. If None, no metrics are applied.
      Depending on whether evaluate_reconstruction is True, metrics may or may
      not be computed on reconstruction batches as well.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates to the global model
      on the server.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for local client training after
      reconstruction.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables,
      with the global ones frozen, or the first stage described above.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and training round number (1-indexed) and producing two TF
      datasets. The first is iterated over during reconstruction, and the second
      is iterated over post-reconstruction. This can be used to preprocess
      datasets to e.g. iterate over them for multiple epochs or use disjoint
      data for reconstruction and post-reconstruction. If None,
      `reconstruction_utils.simple_dataset_split_fn` is used, which results in
      iterating over the original client data for both phases of training. See
      `reconstruction_utils.build_dataset_split_fn` for options.
    evaluate_reconstruction: If True, metrics (including loss) are computed on
      batches during reconstruction and post-reconstruction. If False, metrics
      are computed on batches only post-reconstruction, when global weights are
      being updated. Note that metrics are aggregated across batches as given by
      the metric (example-weighted mean for the loss). Setting this to True
      includes all local batches in metric calculations. Setting this to False
      brings the interpretation of these metrics closer to the interpretation of
      metrics in FedAvg. Note that this does not affect training at all: losses
      for individual batches are calculated and used to update variables
      regardless.
    jointly_train_variables: Whether to train local variables during the second
      stage described above. If True, global and local variables are trained
      jointly after reconstruction of local variables using the optimizer given
      by client_optimizer_fn. If False, only global variables are trained during
      the second stage with local variables frozen, similar to alternating
      minimization.
    client_weight_fn: Optional function that takes the local model's output,
      and returns a tensor that provides the weight in the federated average of
      model deltas. If not provided, the default is the total number of examples
      processed on device during post-reconstruction phase.
    aggregation_factory: An optional instance of
      `tff.aggregators.WeightedAggregationFactory` determining the method of
      aggregation to perform. If unspecified, uses a default
      `tff.aggregators.MeanFactory` which computes a stateless weighted mean
      across clients.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    with tf.Graph().as_default():
        throwaway_model_for_metadata = model_fn()

    model_weights_type = tff.framework.type_from_tensors(
        reconstruction_utils.get_global_variables(
            throwaway_model_for_metadata))

    aggregation_process = _instantiate_aggregation_process(
        aggregation_factory, model_weights_type, client_weight_fn)
    aggregator_state_type = (
        aggregation_process.initialize.type_signature.result.member)

    server_init_tff = build_server_init_fn(model_fn, server_optimizer_fn,
                                           aggregation_process)
    server_state_type = server_init_tff.type_signature.result.member

    server_update_fn = build_server_update_fn(
        model_fn,
        server_optimizer_fn,
        server_state_type,
        server_state_type.model,
        aggregator_state_type=aggregator_state_type)

    tf_dataset_type = tff.SequenceType(throwaway_model_for_metadata.input_spec)
    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.simple_dataset_split_fn
    client_update_fn = build_client_update_fn(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        tf_dataset_type=tf_dataset_type,
        model_weights_type=server_state_type.model,
        client_optimizer_fn=client_optimizer_fn,
        reconstruction_optimizer_fn=reconstruction_optimizer_fn,
        dataset_split_fn=dataset_split_fn,
        evaluate_reconstruction=evaluate_reconstruction,
        jointly_train_variables=jointly_train_variables,
        client_weight_fn=client_weight_fn)

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)
    # Create placeholder metrics to produce a corresponding federated output
    # computation.
    metrics = []
    if metrics_fn is not None:
        metrics.extend(metrics_fn())
    metrics.append(keras_utils.MeanLossMetric(loss_fn()))
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))

    run_one_round_tff = build_run_one_round_fn(
        server_update_fn,
        client_update_fn,
        federated_output_computation,
        federated_server_state_type,
        federated_dataset_type,
        aggregation_process=aggregation_process,
    )

    iterative_process = tff.templates.IterativeProcess(
        initialize_fn=server_init_tff, next_fn=run_one_round_tff)

    @tff.tf_computation(server_state_type)
    def get_model_weights(server_state):
        return server_state.model

    iterative_process.get_model_weights = get_model_weights
    return iterative_process
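
A hypothetical usage sketch; `model_fn`, `loss_fn`, the per-round `sampled_client_datasets`, and `NUM_ROUNDS` are placeholders supplied by the caller:

    process = build_federated_reconstruction_process(
        model_fn=model_fn, loss_fn=loss_fn)
    state = process.initialize()
    for round_num in range(NUM_ROUNDS):
        state, metrics = process.next(state, sampled_client_datasets)
    # Extract only the global model weights from the final server state.
    final_weights = process.get_model_weights(state)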
Example #7
def build_triehh_process(
        possible_prefix_extensions: List[str],
        num_sub_rounds: int,
        max_num_prefixes: int,
        threshold: int,
        max_user_contribution: int,
        default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
    """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The selected clients sample
  `max_user_contribution` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_prefixes` extensions that get at least `threshold` votes are used to
  extend the already discovered prefixes, and the extended prefixes are used in
  the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extension must be a single-character string. This
      list should not contain `default_terminator`.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those that get at least `threshold` votes are discovered. Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
    if default_terminator in possible_prefix_extensions:
        raise ValueError(
            'default_terminator should not appear in possible_prefix_extensions'
        )

    # Append `default_terminator` to `possible_prefix_extensions` to make sure it
    # is the last item in the list.
    possible_prefix_extensions.append(default_terminator)

    @tff.tf_computation
    def server_init_tf():
        return ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitters_counts=tf.constant([], dtype=tf.int32),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                dtype=tf.int32,
                shape=[max_num_prefixes,
                       len(possible_prefix_extensions)]))

    # We cannot use server_init_tf.type_signature.result because the
    # discovered_* fields need to have [None] shapes, since they will grow over
    # time.
    server_state_type = (tff.to_type(
        ServerState(
            discovered_heavy_hitters=tff.TensorType(dtype=tf.string,
                                                    shape=[None]),
            heavy_hitters_counts=tff.TensorType(dtype=tf.int32, shape=[None]),
            discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
            round_num=tff.TensorType(dtype=tf.int32, shape=[]),
            accumulated_votes=tff.TensorType(
                dtype=tf.int32, shape=[None,
                                       len(possible_prefix_extensions)]),
        )))

    sub_round_votes_type = tff.TensorType(
        dtype=tf.int32,
        shape=[max_num_prefixes,
               len(possible_prefix_extensions)])

    @tff.tf_computation(server_state_type, sub_round_votes_type)
    def server_update_fn(server_state, sub_round_votes):
        return server_update(server_state,
                             tf.constant(possible_prefix_extensions),
                             sub_round_votes,
                             num_sub_rounds=tf.constant(num_sub_rounds),
                             max_num_prefixes=tf.constant(max_num_prefixes),
                             threshold=tf.constant(threshold))

    tf_dataset_type = tff.SequenceType(tf.string)
    discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
    round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                        round_num_type)
    def client_update_fn(tf_dataset, discovered_prefixes, round_num):
        return client_update(tf_dataset, discovered_prefixes,
                             tf.constant(possible_prefix_extensions),
                             round_num, num_sub_rounds, max_num_prefixes,
                             max_user_contribution,
                             tf.constant(default_terminator, dtype=tf.string))

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`.
    """
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, accumulated_votes))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round)
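
An illustrative sketch of driving the returned process; `client_word_datasets` (one `tf.data.Dataset` of strings per sampled client) and `TOTAL_ROUNDS` are placeholders:

    process = build_triehh_process(
        possible_prefix_extensions=list('abcdefghijklmnopqrstuvwxyz'),
        num_sub_rounds=1,
        max_num_prefixes=10,
        threshold=2,
        max_user_contribution=10)
    state = process.initialize()
    for _ in range(TOTAL_ROUNDS):
        # next returns (updated ServerState, empty server output).
        state, _ = process.next(state, client_word_datasets)
    # Heavy hitters discovered so far, with their accumulated counts.
    print(state.discovered_heavy_hitters, state.heavy_hitters_counts)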