示例#1
0
 def server_init_tf():
     return ServerState(
         discovered_heavy_hitters=tf.constant([], dtype=tf.string),
         discovered_prefixes=tf.constant([''], dtype=tf.string),
         round_num=tf.constant(0, dtype=tf.int32),
         accumulated_votes=tf.zeros(
             dtype=tf.int32,
             shape=[max_num_prefixes,
                    len(possible_prefix_extensions)]))
示例#2
0
def build_triehh_process(
        possible_prefix_extensions: List[str],
        num_sub_rounds: int,
        max_num_prefixes: int,
        threshold: int,
        max_user_contribution: int,
        default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
    """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The select clients sample
  `max_user_contributions` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_prefixes` extensions get at least 'threshold' votes are used to
  extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extensions must be a single character strings. This
      list should not contain the default_terminator.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those get at least `threshold` votes are discovered. Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
    if default_terminator in possible_prefix_extensions:
        raise ValueError(
            'default_terminator should not appear in possible_prefix_extensions'
        )

    # Append `default_terminator` to `possible_prefix_extensions` to make sure it
    # is the last item in the list.
    possible_prefix_extensions.append(default_terminator)

    @tff.tf_computation
    def server_init_tf():
        return ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitters_frequencies=tf.constant([], dtype=tf.float64),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                dtype=tf.int32,
                shape=[max_num_prefixes,
                       len(possible_prefix_extensions)]),
            accumulated_weights=tf.constant(0, dtype=tf.int32))

    # We cannot use server_init_tf.type_signature.result because the
    # discovered_* fields need to have [None] shapes, since they will grow over
    # time.
    server_state_type = (tff.to_type(
        ServerState(
            discovered_heavy_hitters=tff.TensorType(dtype=tf.string,
                                                    shape=[None]),
            heavy_hitters_frequencies=tff.TensorType(dtype=tf.float64,
                                                     shape=[None]),
            discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
            round_num=tff.TensorType(dtype=tf.int32, shape=[]),
            accumulated_votes=tff.TensorType(
                dtype=tf.int32, shape=[None,
                                       len(possible_prefix_extensions)]),
            accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
        )))

    sub_round_votes_type = tff.TensorType(
        dtype=tf.int32,
        shape=[max_num_prefixes,
               len(possible_prefix_extensions)])
    sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(server_state_type, sub_round_votes_type,
                        sub_round_weight_type)
    def server_update_fn(server_state, sub_round_votes, sub_round_weight):
        return server_update(server_state,
                             tf.constant(possible_prefix_extensions),
                             sub_round_votes,
                             sub_round_weight,
                             num_sub_rounds=tf.constant(num_sub_rounds),
                             max_num_prefixes=tf.constant(max_num_prefixes),
                             threshold=tf.constant(threshold))

    tf_dataset_type = tff.SequenceType(tf.string)
    discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
    round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

    @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                        round_num_type)
    def client_update_fn(tf_dataset, discovered_prefixes, round_num):
        return client_update(tf_dataset, discovered_prefixes,
                             tf.constant(possible_prefix_extensions),
                             round_num, num_sub_rounds, max_num_prefixes,
                             max_user_contribution)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type,
                                               tff.CLIENTS,
                                               all_equal=False)

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        accumulated_weights = tff.federated_sum(client_outputs.client_weight)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, accumulated_votes, accumulated_weights))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output

    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
        next_fn=run_one_round)
示例#3
0
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has
  discovered so far and the list of `possible_prefix_extensions` to a small
  fraction of selected clients. The select clients sample
  `max_user_contributions` words from their local datasets, and use them to vote
  on character extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_heavy_hitters` extensions are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions to
      learned prefixes. Each extensions must be a single character strings.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can contribute.
      Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters,
                   len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32, shape=[None,
                                         len(possible_prefix_extensions)]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters,
             len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type, round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.utils.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round)