# Example #1
def client_update(dataset, discovered_prefixes, possible_prefix_extensions,
                  round_num, num_sub_rounds, max_num_prefixes,
                  max_user_contribution, default_terminator):
  """Creates a ClientOutput object that holds the client's votes.

  This function takes in a `tf.data.Dataset` containing the client's words,
  selects (up to) `max_user_contribution` words from the given `dataset`, and
  creates a `ClientOutput` object that holds the client's votes on character
  extensions to `discovered_prefixes`. The allowed character extensions are
  found in `possible_prefix_extensions`. `round_num` and `num_sub_rounds` are
  needed to compute the length of the prefix to be extended.
  `max_num_prefixes` is needed to set the shape of the tensor holding the
  client votes.

  Args:
    dataset: A 'tf.data.Dataset' containing the client's on-device words.
    discovered_prefixes: A tf.string containing candidate prefixes.
    possible_prefix_extensions: A tf.string of shape
      (num_possible_prefix_extensions, ) containing possible prefix extensions.
    round_num: A tf.constant dictating the algorithm's round number.
    num_sub_rounds: A tf.constant containing the number of sub rounds in a
      round.
    max_num_prefixes: A tf.constant dictating the maximum number of prefixes we
      can keep in the trie.
    max_user_contribution: A tf.constant dictating the maximum number of
      examples a client can contribute.
    default_terminator: A tf.string containing the end of sequence symbol.

  Returns:
    A ClientOutput object holding the client's votes.
  """
  # Create an all-zero client vote tensor of shape
  # [max_num_prefixes, num_possible_prefix_extensions]: one vote counter per
  # (prefix, extension) pair.
  client_votes = tf.zeros(
      dtype=tf.int32,
      shape=[max_num_prefixes,
             tf.shape(possible_prefix_extensions)[0]])

  # If discovered_prefixes is empty (training is done), skip the voting.
  # NOTE(review): this tensor-valued `if` relies on autograph conversion to
  # tf.cond when traced inside a tf.function — confirm the caller wraps this.
  if tf.math.equal(tf.size(discovered_prefixes), 0):
    return ClientOutput(client_votes)
  else:
    # Map each discovered prefix to its row index in the vote tensor; lookups
    # of unknown keys return DEFAULT_VALUE (presumably a sentinel such as -1,
    # defined elsewhere in this module — verify).
    discovered_prefixes_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            discovered_prefixes, tf.range(tf.shape(discovered_prefixes)[0])),
        DEFAULT_VALUE)

    # Map each possible extension character to its column index, with the
    # same DEFAULT_VALUE fallback for unknown keys.
    possible_prefix_extensions_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            possible_prefix_extensions,
            tf.range(tf.shape(possible_prefix_extensions)[0])), DEFAULT_VALUE)

    # Build the per-word reduction function that accumulates this client's
    # votes into the vote tensor (defined elsewhere in this module).
    accumulate_client_votes_fn = make_accumulate_client_votes_fn(
        round_num, num_sub_rounds, discovered_prefixes_table,
        possible_prefix_extensions_table, default_terminator)

    # Cap the client's contribution at max_user_contribution words, then
    # fold those words into the vote tensor.
    sampled_data_list = hh_utils.get_top_elements(dataset,
                                                  max_user_contribution)
    sampled_data = tf.data.Dataset.from_tensor_slices(sampled_data_list)

    return ClientOutput(
        sampled_data.reduce(client_votes, accumulate_client_votes_fn))
 def test_over_max_contribution(self):
   # With seven words but a cap of two, only the most frequent entries
   # should be returned.
   words = ['a', 'b', 'a', 'c', 'b', 'c', 'c']
   dataset = tf.data.Dataset.from_tensor_slices(words)
   result = hh_utils.get_top_elements(dataset, max_user_contribution=2)
   self.assertCountEqual(result.numpy(), [b'a', b'c'])
 def test_empty_dataset(self):
   # An empty client dataset must produce no top elements at all.
   empty_dataset = tf.data.Dataset.from_tensor_slices([])
   result = hh_utils.get_top_elements(empty_dataset, max_user_contribution=10)
   self.assertEmpty(result)