Example #1
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          An updated `ServerState`.
        """
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        accumulated_weights = tff.federated_sum(client_outputs.client_weight)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, accumulated_votes, accumulated_weights))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output
Example #2
 def next_fn(state, value):
   one_at_clients = tff.federated_value(1, tff.CLIENTS)
   dp_sum = self._dp_sum_process.next(state, value)
   summed_one = tff.federated_sum(one_at_clients)
   return tff.templates.MeasuredProcessOutput(
       state=dp_sum.state,
       result=tff.federated_map(div, (dp_sum.result, summed_one)),
       measurements=dp_sum.measurements)
Example #3
 def next_fn(state, value, weight):
     weighted_values = tff.federated_map(_mul, (value, weight))
     summed_value = tff.federated_sum(weighted_values)
     normalized_value = tff.federated_map(_div, (summed_value, state))
     measurements = tff.federated_value((), tff.SERVER)
     return tff.templates.MeasuredProcessOutput(
         state=state,
         result=normalized_value,
         measurements=measurements)
Example #4
 def _robust_aggregation_fn(state, value, weight):
     aggregate = tff.federated_mean(value, weight=weight)
     for _ in range(num_communication_passes - 1):
         aggregate_at_client = tff.federated_broadcast(aggregate)
         updated_weight = tff.federated_map(
             update_weight_fn, (weight, aggregate_at_client, value))
         aggregate = tff.federated_mean(value, weight=updated_weight)
     no_metrics = tff.federated_value((), tff.SERVER)
     return tff.templates.MeasuredProcessOutput(state, aggregate,
                                                no_metrics)
Example #5
 def initialize_computation():
     model = model_fn()
     initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
         server_init_tf, tff.SERVER)
     return intrinsics.federated_zip(
         ServerState(
             model=initial_global_model,
             optimizer_state=initial_global_optimizer_state,
             round_num=tff.federated_value(0.0, tff.SERVER),
             aggregation_state=aggregation_process.initialize(),
         ))
Example #6
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    stateful_delta_aggregate_fn=build_stateless_mean(),
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
    """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`, used during local client training.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`, used to apply updates to the global
        model.
      stateful_delta_aggregate_fn: A `tff.Computation` that aggregates model
        deltas placed at `tff.CLIENTS` into an aggregated model delta placed at
        `tff.SERVER`.
      client_update_tf: A `tf.function` that computes the `ClientOutput`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    dummy_model_for_metadata = model_fn()

    server_init_tf = build_server_init_fn(
        model_fn, server_optimizer_fn,
        stateful_delta_aggregate_fn.initialize())
    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)

    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, stateful_delta_aggregate_fn,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
        next_fn=run_one_round_tff)
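A minimal construction sketch for the builder above, assuming a suitable `model_fn` is in scope; the boost-factor values are illustrative, not taken from the library:

# Sketch only; `model_fn` and the boost factors are assumptions.
benign_process = build_federated_averaging_process_attacked(
    model_fn, client_update_tf=ClientExplicitBoosting(boost_factor=1.0))
attacked_process = build_federated_averaging_process_attacked(
    model_fn, client_update_tf=ClientExplicitBoosting(boost_factor=10.0))
state = attacked_process.initialize()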
Example #7
 def initialize_computation():
     model = model_fn()
     initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
         server_init_tf, placements.SERVER)
     return intrinsics.federated_zip(
         ServerState(
             model=initial_global_model,
             optimizer_state=initial_global_optimizer_state,
             round_num=tff.federated_value(0.0, tff.SERVER),
             effective_num_clients=intrinsics.federated_eval(
                 get_effective_num_clients, placements.SERVER),
             delta_aggregate_state=aggregation_process.initialize(),
         ))
Example #8
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set
            # to a constant 1.0 here; however, the underlying
            # AggregationProcess ignores the parameter and performs no
            # weighting.
            ignored_weight = tff.federated_value(1.0, tff.CLIENTS)
            aggregation_output = gan.dp_averaging_fn.next(
                server_state.dp_averaging_state,
                client_outputs.discriminator_weights_delta,
                weight=ignored_weight)
            new_dp_averaging_state = aggregation_output.state
            averaged_discriminator_weights_delta = aggregation_output.result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        server_computation = build_server_computation(
            gan, server_state.type_signature.member, client_output_type)
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Example #9
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
        learning_rate=1.0)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for the local client training.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for applying updates on the server.

    Returns:
      A `tff.utils.IterativeProcess`.
    """
    dummy_model_for_metadata = model_fn()
    type_signature_grads_norm = tff.NamedTupleType([
        weight.dtype for weight in tf.nest.flatten(
            _get_weights(dummy_model_for_metadata).trainable)
    ])

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model,
                                              type_signature_grads_norm)

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
    run_one_round_tff = build_run_one_round_fn(server_update_fn,
                                               client_update_fn,
                                               dummy_model_for_metadata,
                                               federated_server_state_type,
                                               federated_dataset_type)

    return tff.utils.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
        next_fn=run_one_round_tff)
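For reference, `type_signature_grads_norm` above is a tuple with one scalar per trainable weight tensor; a hypothetical helper (not part of the library) producing a value of that structure might look like:

import tensorflow as tf

def per_layer_grad_norms(grads):
  # One L2 norm per trainable tensor; illustrative only.
  return tuple(tf.norm(g) for g in tf.nest.flatten(grads))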
Example #10
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer`.

    Returns:
      A `tff.utils.IterativeProcess`.
    """

    dummy_model_for_metadata = model_fn()

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)

    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
    client_update_fn = build_client_update_fn(model_fn, tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
    run_one_round_tff = build_run_one_round_fn(server_update_fn,
                                               client_update_fn,
                                               dummy_model_for_metadata,
                                               federated_server_state_type,
                                               federated_dataset_type)

    return tff.utils.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
        next_fn=run_one_round_tff)
Example #11
 def fed_server_initial_state():
     return tff.federated_value(server_initial_state(), tff.SERVER)
Example #12
 def federated_aggregate_test(deltas, weights):
     state = tff.federated_value(aggregate_fn.initialize(), tff.SERVER)
     return aggregate_fn(state, deltas, weights)
Example #13
 def initialize_fn():
     """Initialize the server state."""
     return tff.federated_value(server_init_tf(), tff.SERVER)
Example #14
 def initialize_fn():
     # state = AggregationState(self._num_participants)
     return tff.federated_value(self._num_participants, tff.SERVER)
Example #15
 def initialize_fn(federated_dataset):
     server_state = tff.federated_value(server_init_tf(), tff.SERVER)
     client_states = tff.federated_map(init_client_state,
                                       (federated_dataset))
     return (server_state, client_states)
Example #16
 def initialize_fn():
   return tff.federated_value((), tff.SERVER)
Example #17
 def server_init_tff():
     """Orchestration logic for server model initialization."""
     return tff.federated_value(server_init_tf(), tff.SERVER)
Example #18
 def init_tff():
     return tff.federated_value(init_tf(), tff.SERVER)
Example #19
 def server_init_tff():
   """Returns initial `tff.learning.framework.ServerState."""
   return tff.federated_value(server_init_tf(), tff.SERVER)
Example #20
 def initialize_fn():
     return tff.federated_value(server_init_tf(), tff.SERVER)
Example #21
 def get_clip_range():
     return (tff.federated_value(clip_range_lower, tff.SERVER),
             tff.federated_value(clip_range_upper, tff.SERVER))
Example #22
 def initialize():
   return tff.federated_value(create_initial_state(), tff.SERVER)
Example #23
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
    """Builds the TFF computations for optimization using federated averaging.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for server update.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for client update.

    Returns:
      A `tff.utils.IterativeProcess`.
    """

    # TODO(b/144510813): try to remove dependency on dummy model.
    dummy_model = model_fn()

    @tff.tf_computation
    def server_init_tf():
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(model, server_optimizer)
        return ServerState(model=model.weights,
                           optimizer_state=server_optimizer.variables())

    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
        model = model_fn()
        server_optimizer = server_optimizer_fn()
        _initialize_optimizer_vars(model, server_optimizer)
        return server_update(model, server_optimizer, server_state,
                             model_delta)

    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_update_fn(tf_dataset, initial_model_weights):
        model = model_fn()
        client_optimizer = client_optimizer_fn()
        return client_update(model, tf_dataset, initial_model_weights,
                             client_optimizer)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs

    return tff.utils.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
        next_fn=run_one_round)
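A minimal driving loop for the returned process, assuming a suitable `model_fn` and a list `federated_train_data` of client `tf.data.Dataset`s (both names are placeholders, not part of the example above):

# Sketch only; `model_fn` and `federated_train_data` are assumed to exist.
iterative_process = build_federated_averaging_process(model_fn)
state = iterative_process.initialize()
for round_num in range(10):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))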
Example #24
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The selected clients sample
  `max_user_contribution` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_heavy_hitters` extensions are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extension must be a single-character string.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters,
                   len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32, shape=[None,
                                         len(possible_prefix_extensions)]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters,
             len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type, round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.utils.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round)
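The voting step described in the docstring above can be illustrated with a small, purely local sketch; the helper below is hypothetical and is not the library's `client_update`:

import collections

def toy_vote(local_words, discovered_prefixes, possible_prefix_extensions):
  # Count which single-character extension of each discovered prefix every
  # local word supports. Illustrative only.
  votes = collections.Counter()
  for word in local_words:
    for prefix in discovered_prefixes:
      if word.startswith(prefix) and len(word) > len(prefix):
        extension = word[len(prefix)]
        if extension in possible_prefix_extensions:
          votes[(prefix, extension)] += 1
  return votes

# toy_vote(['the', 'this', 'dog'], [''], list('abcdefghijklmnopqrstuvwxyz$'))
# -> Counter({('', 't'): 2, ('', 'd'): 1})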
Example #25
 def stateless_init():
     return tff.federated_value((), tff.SERVER)
Example #26
 def foo(x):
   return tff.federated_value(x, tff.SERVER)
Example #27
 def empty_agg():
     val_at_clients = tff.federated_value([], tff.CLIENTS)
     return tff.federated_aggregate(val_at_clients, [], accumulate,
                                    accumulate, report)
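For context, `tff.federated_aggregate` takes `(value, zero, accumulate, merge, report)`, and the snippet above reuses `accumulate` in the `merge` position. A self-contained sketch of the same pattern for summing client integers (all names here are illustrative):

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tf.int32, tf.int32)
def accumulate(partial_sum, client_value):
  return partial_sum + client_value

@tff.tf_computation(tf.int32)
def report(total):
  return total

@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def sum_at_server(client_values):
  # `accumulate` is also passed as the merge function, as in the example above.
  return tff.federated_aggregate(
      client_values, 0, accumulate, accumulate, report)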