示例#1
0
 def server_init():
     """Assembles the initial `ServerState` and passes it to `tff.federated_zip`.

     `server_init_tf` is evaluated at `tff.SERVER` and its result is unpacked
     into the model and the server optimizer state. The aggregation state is
     obtained by calling `aggregation_process_init()`.

     Returns:
       The result of `tff.federated_zip` applied to the assembled `ServerState`.
     """
     model_at_server, optimizer_state_at_server = tff.federated_eval(
         server_init_tf, tff.SERVER)
     unzipped_state = ServerState(
         model=model_at_server,
         optimizer_state=optimizer_state_at_server,
         delta_aggregate_state=aggregation_process_init())
     return tff.federated_zip(unzipped_state)
示例#2
0
 def fed_server_initial_state():
   """Builds the initial GAN server state and zips it with `tff.federated_zip`.

   The computation returned by `build_server_initial_state_comp(gan)` is
   evaluated at `tff.SERVER`. Its generator weights, discriminator weights and
   counters are combined with the state from
   `gan.aggregation_process.initialize()`.

   Returns:
     The result of `tff.federated_zip` applied to a
     `gan_training_tf_fns.ServerState`.
   """
   initial_state_comp = build_server_initial_state_comp(gan)
   evaluated = tff.federated_eval(initial_state_comp, tff.SERVER)
   unzipped_state = gan_training_tf_fns.ServerState(
       evaluated.generator_weights,
       evaluated.discriminator_weights,
       evaluated.counters,
       aggregation_state=gan.aggregation_process.initialize())
   return tff.federated_zip(unzipped_state)
示例#3
0
 def server_init_tff():
     """Creates the initial `reconstruction_utils.ServerState`.

     `server_init_tf` is evaluated at `tff.SERVER`; elements 0, 1 and 2 of its
     result supply the model, optimizer state and round number. The aggregator
     state comes from `aggregation_process.initialize()`.

     Returns:
       The result of `tff.federated_zip` applied to the assembled
       `reconstruction_utils.ServerState`.
     """
     server_values = tff.federated_eval(server_init_tf, tff.SERVER)
     initial_aggregator_state = aggregation_process.initialize()
     state_fields = dict(
         model=server_values[0],
         optimizer_state=server_values[1],
         round_num=server_values[2],
         aggregator_state=initial_aggregator_state)
     return tff.federated_zip(reconstruction_utils.ServerState(**state_fields))
示例#4
0
 def fed_server_initial_state():
   """Builds the initial GAN server state, including the DP averaging state.

   The computation returned by `build_server_initial_state_comp(gan)` is
   evaluated at `tff.SERVER`. When `gan.dp_averaging_fn` is set, its
   `initialize()` result is used as the DP averaging state; otherwise the
   `dp_averaging_state` field of the evaluated state is used.

   Returns:
     The result of `tff.federated_zip` applied to a
     `gan_training_tf_fns.ServerState`.
   """
   initial_state_comp = build_server_initial_state_comp(gan)
   evaluated = tff.federated_eval(initial_state_comp, tff.SERVER)

   if gan.dp_averaging_fn is not None:
     averaging_state = gan.dp_averaging_fn.initialize()
   else:
     averaging_state = evaluated.dp_averaging_state

   unzipped_state = gan_training_tf_fns.ServerState(
       evaluated.generator_weights,
       evaluated.discriminator_weights,
       evaluated.counters,
       dp_averaging_state=averaging_state)
   return tff.federated_zip(unzipped_state)
示例#5
0
def build_federated_averaging_process(
    model_fn,
    client_optimizer_fn,
    server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(learning_rate=
                                                               1.0)):
    """Assembles a federated averaging `tff.templates.IterativeProcess`.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` used for local client training.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` used to apply updates on the server.
        Defaults to a `flars_optimizer.FLARSOptimizer` with learning rate 1.0.

    Returns:
      A `tff.templates.IterativeProcess` whose `initialize` evaluates the server
      init computation at `tff.SERVER` and whose `next` runs one round.
    """
    # Build a throwaway model inside its own graph; only its metadata (the
    # dtypes of its trainable variables and its input spec) is read below.
    with tf.Graph().as_default():
        metadata_model = model_fn()
    trainable_leaves = tf.nest.flatten(metadata_model.trainable_variables)
    grads_norm_dtypes = tuple([leaf.dtype for leaf in trainable_leaves])

    server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn)
    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model

    server_update_fn = build_server_update_fn(
        model_fn,
        server_optimizer_fn,
        server_state_type,
        model_weights_type,
        grads_norm_dtypes)

    dataset_type = tff.SequenceType(metadata_model.input_spec)
    client_update_fn = build_client_update_fn(
        model_fn,
        client_optimizer_fn,
        dataset_type,
        model_weights_type)

    server_state_at_server = tff.FederatedType(server_state_type, tff.SERVER)
    datasets_at_clients = tff.FederatedType(dataset_type, tff.CLIENTS)
    one_round_comp = build_run_one_round_fn(
        server_update_fn,
        client_update_fn,
        metadata_model,
        server_state_at_server,
        datasets_at_clients)

    initialize_comp = tff.federated_computation(
        lambda: tff.federated_eval(server_init_tf, tff.SERVER))
    return tff.templates.IterativeProcess(
        initialize_fn=initialize_comp, next_fn=one_round_comp)
示例#6
0
def build_triehh_process(
        possible_prefix_extensions: List[str],
        num_sub_rounds: int,
        max_num_prefixes: int,
        threshold: int,
        max_user_contribution: int,
        default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
    """Builds the TFF computations for heavy hitters discovery with TrieHH.

    TrieHH works by interactively keeping track of popular prefixes. In each
    round, the server broadcasts the popular prefixes it has discovered so far
    and the list of `possible_prefix_extensions` to a small fraction of selected
    clients. The selected clients sample `max_user_contribution` words from
    their local datasets, and use them to vote on character extensions to the
    broadcasted popular prefixes. Client votes are accumulated across
    `num_sub_rounds` rounds, and then the top `max_num_prefixes` extensions that
    get at least `threshold` votes are used to extend the already discovered
    prefixes, and the extended prefixes are used in the next round. When an
    already discovered prefix is extended by `default_terminator` it is added to
    the list of discovered heavy hitters.

    Args:
      possible_prefix_extensions: A list containing all the possible extensions
        to learned prefixes. Each extension must be a single character string.
        This list should not contain the default_terminator. The list passed in
        is not modified; the terminator is appended to an internal copy.
      num_sub_rounds: The total number of sub rounds to be executed before
        decoding aggregated votes. Must be positive.
      max_num_prefixes: The maximum number of prefixes we can keep in the trie.
        Must be positive.
      threshold: The threshold for heavy hitters and discovered prefixes. Only
        those that get at least `threshold` votes are discovered. Must be
        positive.
      max_user_contribution: The maximum number of examples a user can
        contribute. Must be positive.
      default_terminator: The end of sequence symbol.

    Returns:
      A `tff.templates.IterativeProcess`.

    Raises:
      ValueError: If possible_prefix_extensions contains default_terminator.
    """
    if default_terminator in possible_prefix_extensions:
        raise ValueError(
            'default_terminator should not appear in possible_prefix_extensions'
        )

    # Work on a private copy with `default_terminator` as the last item.
    # Appending to the caller's list in place would leak the terminator into
    # the caller's data and make a second call with the same list raise the
    # ValueError above.
    possible_prefix_extensions = (
        list(possible_prefix_extensions) + [default_terminator])

    @tff.tf_computation
    def server_init_tf():
        return ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitters_frequencies=tf.constant([], dtype=tf.float64),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                dtype=tf.int32,
                shape=[max_num_prefixes,
                       len(possible_prefix_extensions)]),
            accumulated_weights=tf.constant(0, dtype=tf.int32))

    # We cannot use server_init_tf.type_signature.result because the
    # discovered_* fields need to have [None] shapes, since they will grow over
    # time.
    server_state_type = (tff.to_type(
        ServerState(
            discovered_heavy_hitters=tff.TensorType(dtype=tf.string,
                                                    shape=[None]),
            heavy_hitters_frequencies=tff.TensorType(dtype=tf.float64,
                                                     shape=[None]),
            discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
            round_num=tff.TensorType(dtype=tf.int32, shape=[]),
            accumulated_votes=tff.TensorType(
                dtype=tf.int32, shape=[None,
                                       len(possible_prefix_extensions)]),
            accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
        )))

    # Per-sub-round aggregates handed from the clients to the server update.
    sub_round_votes_type = tff.TensorType(
        dtype=tf.int32,
        shape=[max_num_prefixes,
               len(possible_prefix_extensions)])
    sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(server_state_type, sub_round_votes_type,
                        sub_round_weight_type)
    def server_update_fn(server_state, sub_round_votes, sub_round_weight):
        return server_update(server_state,
                             tf.constant(possible_prefix_extensions),
                             sub_round_votes,
                             sub_round_weight,
                             num_sub_rounds=tf.constant(num_sub_rounds),
                             max_num_prefixes=tf.constant(max_num_prefixes),
                             threshold=tf.constant(threshold))

    tf_dataset_type = tff.SequenceType(tf.string)
    discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
    round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                        round_num_type)
    def client_update_fn(tf_dataset, discovered_prefixes, round_num):
        return client_update(tf_dataset, discovered_prefixes,
                             tf.constant(possible_prefix_extensions),
                             round_num, num_sub_rounds, max_num_prefixes,
                             max_user_contribution)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                               tff.CLIENTS,
                                               all_equal=False)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and an empty server output.
        """
        # Send the server's current prefixes and round number to the clients.
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        # Sum the client votes and weights before the server update.
        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        accumulated_weights = tff.federated_sum(client_outputs.client_weight)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, accumulated_votes, accumulated_weights))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round)
示例#7
0
 def server_init_tff():
     """Evaluates `server_init_tf` at `tff.SERVER` and returns the result."""
     initial_server_state = tff.federated_eval(server_init_tf, tff.SERVER)
     return initial_server_state
示例#8
0
 def create_zero_model_on_server():
     """Evaluates `create_zero_model` at `tff.SERVER` and returns the result."""
     zero_model_at_server = tff.federated_eval(create_zero_model, tff.SERVER)
     return zero_model_at_server