def client_update(dataset, discovered_prefixes, possible_prefix_extensions,
                  round_num, num_sub_rounds, max_num_prefixes,
                  max_user_contribution, default_terminator):
  """Creates a ClientOutput object that holds the client's votes.

  This function takes in a `tf.data.Dataset` containing the client's words,
  selects (up to) `max_user_contribution` words from the given `dataset`, and
  creates a `ClientOutput` object that holds the client's votes on character
  extensions to `discovered_prefixes`. The allowed character extensions are
  found in `possible_prefix_extensions`. `round_num` and `num_sub_rounds` are
  needed to compute the length of the prefix to be extended.
  `max_num_prefixes` is needed to set the shape of the tensor holding the
  client votes.

  Args:
    dataset: A `tf.data.Dataset` containing the client's on-device words.
    discovered_prefixes: A tf.string containing candidate prefixes.
    possible_prefix_extensions: A tf.string of shape
      (num_discovered_prefixes, ) containing possible prefix extensions.
    round_num: A tf.constant dictating the algorithm's round number.
    num_sub_rounds: A tf.constant containing the number of sub rounds in a
      round.
    max_num_prefixes: A tf.constant dictating the maximum number of prefixes
      we can keep in the trie.
    max_user_contribution: A tf.constant dictating the maximum number of
      examples a client can contribute.
    default_terminator: A tf.string containing the end of sequence symbol.

  Returns:
    A ClientOutput object holding the client's votes.
  """
  # Start from an all-zero vote tensor: one row per candidate prefix, one
  # column per possible extension character.
  zero_votes = tf.zeros(
      dtype=tf.int32,
      shape=[max_num_prefixes, tf.shape(possible_prefix_extensions)[0]])

  # If discovered_prefixes is empty (training is done), skip the voting.
  if tf.math.equal(tf.size(discovered_prefixes), 0):
    return ClientOutput(zero_votes)
  else:
    # Map each discovered prefix / possible extension to its row / column
    # index in the vote tensor; misses fall back to DEFAULT_VALUE.
    prefix_index_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            discovered_prefixes, tf.range(tf.shape(discovered_prefixes)[0])),
        DEFAULT_VALUE)
    extension_index_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            possible_prefix_extensions,
            tf.range(tf.shape(possible_prefix_extensions)[0])), DEFAULT_VALUE)

    vote_accumulator = make_accumulate_client_votes_fn(
        round_num, num_sub_rounds, prefix_index_table, extension_index_table,
        default_terminator)

    # Cap the client's contribution at its `max_user_contribution` most
    # frequent words, then fold the sampled words into the vote tensor.
    top_words = hh_utils.get_top_elements(dataset, max_user_contribution)
    top_words_dataset = tf.data.Dataset.from_tensor_slices(top_words)
    return ClientOutput(top_words_dataset.reduce(zero_votes, vote_accumulator))
def test_over_max_contribution(self):
  """Selection is capped at the `max_user_contribution` most frequent words."""
  words = ['a', 'b', 'a', 'c', 'b', 'c', 'c']
  client_data = tf.data.Dataset.from_tensor_slices(words)
  selected = hh_utils.get_top_elements(client_data, max_user_contribution=2)
  # 'c' appears three times and 'a'/'b' twice each; with a cap of 2 the top
  # elements are 'c' plus one of the ties — here 'a' per the helper's order.
  self.assertCountEqual(selected.numpy(), [b'a', b'c'])
def test_empty_dataset(self):
  """An empty client dataset produces an empty selection."""
  empty_data = tf.data.Dataset.from_tensor_slices([])
  selected = hh_utils.get_top_elements(empty_data, max_user_contribution=10)
  self.assertEmpty(selected)