Ejemplo n.º 1
0
  def test_accumulate_server_votes_works_as_expected(self):
    """Checks that sub-round votes are added onto the votes held in state.

    The state starts with a 10x5 vote matrix and receives an identical
    sub-round matrix; the accumulated votes are expected to be their sum.
    """
    num_rows, num_cols = 10, 5

    def votes_matrix(scale):
      # Two populated rows followed by all-zero padding rows.
      populated = [[scale * v for v in (1, 2, 1, 0, 0)] for _ in range(2)]
      padding = [[0] * num_cols for _ in range(num_rows - 2)]
      return tf.constant(populated + padding, dtype=tf.int32)

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            ['a', 'b', 'c', 'd', 'e'], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=votes_matrix(1))

    state = triehh_tf.accumulate_server_votes(state, votes_matrix(1))

    # Element-wise sum of the initial and the sub-round votes.
    want_votes = votes_matrix(2)
    self.assertAllEqual(state.accumulated_votes, want_votes)
Ejemplo n.º 2
0
    def test_server_update_finds_heavy_hitters_with_threshold(self):
        """Runs `server_update` once with a vote threshold and checks the state.

        The test expects prefixes 'aa', 'ab', 'ac', 'ad' and 'ba', the single
        heavy hitter 'a' with count 6, and accumulated votes reset to zero.
        """
        num_rows, num_cols = 10, 5
        max_num_prefixes = tf.constant(10)
        threshold = tf.constant(5)
        num_sub_rounds = tf.constant(1, dtype=tf.int32)
        alphabet = ['a', 'b', 'c', 'd', triehh_tf.DEFAULT_TERMINATOR]
        extensions = tf.constant(alphabet, dtype=tf.string)

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitters_counts=tf.constant([], dtype=tf.int32),
            # The discovered prefixes start out equal to the extensions.
            discovered_prefixes=tf.constant(alphabet, dtype=tf.string),
            round_num=tf.constant(1, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                shape=[max_num_prefixes, len(extensions)], dtype=tf.int32))

        nonzero_rows = [[10, 9, 8, 7, 6], [5, 4, 3, 0, 4], [2, 1, 0, 0, 0]]
        zero_rows = [[0] * num_cols
                     for _ in range(num_rows - len(nonzero_rows))]
        votes = tf.constant(nonzero_rows + zero_rows, dtype=tf.int32)

        state = triehh_tf.server_update(state, extensions, votes,
                                        num_sub_rounds, max_num_prefixes,
                                        threshold)

        want_prefixes = tf.constant(['aa', 'ab', 'ac', 'ad', 'ba'],
                                    dtype=tf.string)
        want_heavy_hitters = tf.constant(['a'], dtype=tf.string)
        want_counts = tf.constant([6], dtype=tf.int32)
        want_votes = tf.constant(
            [[0] * num_cols for _ in range(num_rows)], dtype=tf.int32)

        self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertSetAllEqual(state.discovered_heavy_hitters,
                               want_heavy_hitters)
        self.assertHistogramsEqual(state.discovered_heavy_hitters,
                                   state.heavy_hitters_counts,
                                   want_heavy_hitters, want_counts)
        self.assertAllEqual(state.accumulated_votes, want_votes)
Ejemplo n.º 3
0
    def test_accumulate_server_votes_and_decode_threhold_works_as_expected(
            self):
        """Accumulates one sub-round with a threshold, then checks decoding.

        Starting from prefixes 'su' and 'st', the test expects the prefixes
        'sta' and 'sun', no heavy hitters, and votes and weights reset to
        zero.
        """
        num_rows, num_cols = 4, 6
        max_num_prefixes = tf.constant(4)
        threshold = tf.constant(5)
        extensions = tf.constant(
            ['a', 'n', 's', 't', 'u', triehh_tf.DEFAULT_TERMINATOR],
            dtype=tf.string)

        def padded_votes(leading_rows):
            # Pads the given rows with all-zero rows up to `num_rows`.
            padding = [[0] * num_cols
                       for _ in range(num_rows - len(leading_rows))]
            return tf.constant(leading_rows + padding, dtype=tf.int32)

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitter_frequencies=tf.constant([], dtype=tf.float64),
            discovered_prefixes=tf.constant(['su', 'st'], dtype=tf.string),
            round_num=tf.constant(3, dtype=tf.int32),
            accumulated_votes=padded_votes(
                [[1, 2, 1, 0, 0, 0], [1, 2, 1, 0, 0, 0]]),
            accumulated_weights=tf.constant(10, dtype=tf.int32))

        votes = padded_votes([[3, 3, 1, 0, 0, 0], [5, 1, 1, 0, 0, 0]])
        weight = tf.constant(20, dtype=tf.int32)

        state = triehh_tf.accumulate_server_votes_and_decode(
            state, extensions, votes, weight, max_num_prefixes, threshold)

        want_votes = padded_votes([])
        want_weights = tf.constant(0, dtype=tf.int32)
        want_prefixes = tf.constant(['sta', 'sun'], dtype=tf.string)
        want_heavy_hitters = tf.constant([], dtype=tf.string)
        want_frequencies = tf.constant([], dtype=tf.float64)

        self.assertAllEqual(state.accumulated_votes, want_votes)
        self.assertEqual(state.accumulated_weights, want_weights)
        self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertSetAllEqual(state.discovered_heavy_hitters,
                               want_heavy_hitters)
        self.assertHistogramsEqual(state.discovered_heavy_hitters,
                                   state.heavy_hitter_frequencies,
                                   want_heavy_hitters, want_frequencies)
  def test_server_update_does_not_decode_in_a_subround(self):
    """Checks `server_update` with two sub-rounds at round number zero.

    The test expects prefixes and heavy hitters to stay unchanged, and the
    accumulated votes and weights to equal the submitted sub-round values.
    """
    num_rows, num_cols = 10, 6
    max_num_prefixes = tf.constant(10)
    threshold = tf.constant(1)
    num_sub_rounds = tf.constant(2, dtype=tf.int32)
    extensions = tf.constant(
        ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
        dtype=tf.string)

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_frequencies=tf.constant([], dtype=tf.float64),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            shape=[max_num_prefixes, len(extensions)], dtype=tf.int32),
        accumulated_weights=tf.constant(0, dtype=tf.int32))

    nonzero_rows = [[1, 2, 1, 2, 0, 0], [2, 0, 0, 0, 0, 0]]
    zero_rows = [[0] * num_cols for _ in range(num_rows - len(nonzero_rows))]
    votes = tf.constant(nonzero_rows + zero_rows, dtype=tf.int32)
    weight = tf.constant(10, dtype=tf.int32)

    state = triehh_tf.server_update(state, extensions, votes, weight,
                                    num_sub_rounds, max_num_prefixes,
                                    threshold)

    want_prefixes = tf.constant([''], dtype=tf.string)
    want_heavy_hitters = tf.constant([], dtype=tf.string)
    want_frequencies = tf.constant([], dtype=tf.float64)

    self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
    self.assertSetAllEqual(state.discovered_heavy_hitters,
                           want_heavy_hitters)
    self.assertHistogramsEqual(state.discovered_heavy_hitters,
                               state.heavy_hitters_frequencies,
                               want_heavy_hitters, want_frequencies)
    # The accumulated values are compared against the submitted inputs.
    self.assertAllEqual(state.accumulated_votes, votes)
    self.assertEqual(state.accumulated_weights, weight)
Ejemplo n.º 5
0
    def test_server_update_finds_heavy_hitters(self):
        """Runs `server_update` for one sub-round and checks the decoded state.

        The test expects nine two-character prefixes, the single heavy hitter
        'a', and accumulated votes reset to zero.
        """
        num_rows, num_cols = 10, 5
        max_num_heavy_hitters = tf.constant(10)
        default_terminator = tf.constant('$', tf.string)
        num_sub_rounds = tf.constant(1, dtype=tf.int32)
        alphabet = ['a', 'b', 'c', 'd', '$']

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            # The discovered prefixes start out equal to the extensions.
            discovered_prefixes=tf.constant(list(alphabet), dtype=tf.string),
            possible_prefix_extensions=tf.constant(alphabet, dtype=tf.string),
            round_num=tf.constant(1, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                shape=[max_num_heavy_hitters, len(alphabet)], dtype=tf.int32))

        nonzero_rows = [[10, 9, 8, 7, 6], [5, 4, 3, 0, 0], [2, 1, 0, 0, 0]]
        zero_rows = [[0] * num_cols
                     for _ in range(num_rows - len(nonzero_rows))]
        votes = tf.constant(nonzero_rows + zero_rows, dtype=tf.int32)

        state = triehh_tf.server_update(state, votes, num_sub_rounds,
                                        max_num_heavy_hitters,
                                        default_terminator)

        want_prefixes = tf.constant(
            ['aa', 'ab', 'ac', 'ad', 'ba', 'bb', 'bc', 'ca', 'cb'],
            dtype=tf.string)
        want_heavy_hitters = tf.constant(['a'], dtype=tf.string)
        want_votes = tf.constant(
            [[0] * num_cols for _ in range(num_rows)], dtype=tf.int32)

        self.assertAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertAllEqual(state.discovered_heavy_hitters,
                            want_heavy_hitters)
        self.assertAllEqual(state.accumulated_votes, want_votes)
Ejemplo n.º 6
0
    def test_server_update_works_on_empty_discovered_prefixes(self):
        """Runs `server_update` on a state that has no discovered prefixes.

        With all-zero sub-round votes, the test expects prefixes and heavy
        hitters to remain empty and the accumulated votes to stay at zero.
        """
        num_rows, num_cols = 10, 6
        max_num_prefixes = tf.constant(10)
        threshold = tf.constant(1)
        num_sub_rounds = tf.constant(1, dtype=tf.int32)
        extensions = tf.constant(
            ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
            dtype=tf.string)

        def all_zero_votes():
            # A fresh `num_rows` x `num_cols` int32 matrix of zeros.
            return tf.constant(
                [[0] * num_cols for _ in range(num_rows)], dtype=tf.int32)

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            discovered_prefixes=tf.constant([], dtype=tf.string),
            round_num=tf.constant(1, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                shape=[max_num_prefixes, len(extensions)], dtype=tf.int32))

        state = triehh_tf.server_update(state, extensions, all_zero_votes(),
                                        num_sub_rounds, max_num_prefixes,
                                        threshold)

        want_prefixes = tf.constant([], dtype=tf.string)
        want_heavy_hitters = tf.constant([], dtype=tf.string)

        self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertSetAllEqual(state.discovered_heavy_hitters,
                               want_heavy_hitters)
        self.assertAllEqual(state.accumulated_votes, all_zero_votes())
Ejemplo n.º 7
0
    def test_accumulate_server_votes_works_as_expected(self):
        """Checks that sub-round votes and weights are added onto the state.

        The state starts with a 10x5 vote matrix and weight 10, and receives
        an identical sub-round; the sums are expected in the returned state.
        """
        num_rows, num_cols = 10, 5

        def votes_matrix(scale):
            # Two populated rows followed by all-zero padding rows.
            populated = [[scale * v for v in (1, 2, 1, 0, 0)]
                         for _ in range(2)]
            padding = [[0] * num_cols for _ in range(num_rows - 2)]
            return tf.constant(populated + padding, dtype=tf.int32)

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            heavy_hitter_frequencies=tf.constant([], dtype=tf.float64),
            discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=votes_matrix(1),
            accumulated_weights=tf.constant(10, dtype=tf.int32))

        state = triehh_tf.accumulate_server_votes(
            state, votes_matrix(1), tf.constant(10, dtype=tf.int32))

        # Votes and weights are both expected to have doubled.
        want_votes = votes_matrix(2)
        want_weights = tf.constant(20, dtype=tf.int32)

        self.assertAllEqual(state.accumulated_votes, want_votes)
        self.assertEqual(state.accumulated_weights, want_weights)
Ejemplo n.º 8
0
    def test_server_update_does_not_decode_in_a_subround(self):
        """Checks `server_update` with two sub-rounds at round number zero.

        The test expects prefixes and heavy hitters to stay unchanged and the
        accumulated votes to equal the submitted sub-round votes.
        """
        num_rows, num_cols = 10, 5
        max_num_heavy_hitters = tf.constant(10)
        default_terminator = tf.constant('$', tf.string)
        num_sub_rounds = tf.constant(2, dtype=tf.int32)
        alphabet = ['a', 'b', 'c', 'd', 'e']

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            possible_prefix_extensions=tf.constant(alphabet, dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(
                shape=[max_num_heavy_hitters, len(alphabet)], dtype=tf.int32))

        nonzero_rows = [[1, 2, 1, 2, 0], [2, 0, 0, 0, 0]]
        zero_rows = [[0] * num_cols
                     for _ in range(num_rows - len(nonzero_rows))]
        votes = tf.constant(nonzero_rows + zero_rows, dtype=tf.int32)

        state = triehh_tf.server_update(state, votes, num_sub_rounds,
                                        max_num_heavy_hitters,
                                        default_terminator)

        want_prefixes = tf.constant([''], dtype=tf.string)
        want_heavy_hitters = tf.constant([], dtype=tf.string)

        self.assertAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertAllEqual(state.discovered_heavy_hitters,
                            want_heavy_hitters)
        # The accumulated votes are compared against the submitted input.
        self.assertAllEqual(state.accumulated_votes, votes)
Ejemplo n.º 9
0
    def test_accumulate_server_votes_and_decode_works_as_expected(self):
        """Accumulates one sub-round of votes and checks the decoded state.

        Starting from prefixes 'su' and 'st', the test expects the prefixes
        'sta', 'sun', 'sua' and 'stn' in that order, no heavy hitters, and
        accumulated votes reset to zero.
        """
        num_rows, num_cols = 4, 5
        max_num_heavy_hitters = tf.constant(4)
        default_terminator = tf.constant('$', tf.string)

        def padded_votes(leading_rows):
            # Pads the given rows with all-zero rows up to `num_rows`.
            padding = [[0] * num_cols
                       for _ in range(num_rows - len(leading_rows))]
            return tf.constant(leading_rows + padding, dtype=tf.int32)

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            discovered_prefixes=tf.constant(['su', 'st'], dtype=tf.string),
            possible_prefix_extensions=tf.constant(
                ['a', 'n', 's', 't', 'u'], dtype=tf.string),
            round_num=tf.constant(3, dtype=tf.int32),
            accumulated_votes=padded_votes(
                [[1, 2, 1, 0, 0], [1, 2, 1, 0, 0]]))

        votes = padded_votes([[3, 3, 1, 0, 0], [5, 1, 1, 0, 0]])

        state = triehh_tf.accumulate_server_votes_and_decode(
            state, votes, max_num_heavy_hitters, default_terminator)

        want_votes = padded_votes([])
        want_prefixes = tf.constant(['sta', 'sun', 'sua', 'stn'],
                                    dtype=tf.string)
        want_heavy_hitters = tf.constant([], dtype=tf.string)

        self.assertAllEqual(state.accumulated_votes, want_votes)
        self.assertAllEqual(state.discovered_prefixes, want_prefixes)
        self.assertAllEqual(state.discovered_heavy_hitters,
                            want_heavy_hitters)
Ejemplo n.º 10
0
    def test_all_tf_functions_work_together(self):
        """Drives client and server updates end to end over repeated rounds.

        Every client holds the words 'hello', 'hey' and 'hi'; after all
        rounds the discovered heavy hitters are expected to be 'hi', 'hey'
        and 'hello', in that order.
        """
        clients = 3
        num_sub_rounds = 4
        max_rounds = 6
        max_num_heavy_hitters = 3
        max_user_contribution = 100
        alphabet = list(string.ascii_lowercase + string.digits + "'@#-;*:./" +
                        triehh_tf.DEFAULT_TERMINATOR)
        votes_shape = [max_num_heavy_hitters, len(alphabet)]

        state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            discovered_prefixes=tf.constant([''], dtype=tf.string),
            possible_prefix_extensions=tf.constant(alphabet, dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=tf.zeros(shape=votes_shape, dtype=tf.int32))

        def create_dataset_fn(client_id):
            # Every client gets the same three words, whatever its id.
            del client_id
            return tf.data.Dataset.from_tensor_slices(['hello', 'hey', 'hi'])

        client_data = tff.simulation.ClientData.from_clients_and_fn(
            client_ids=list(range(100)),
            create_tf_dataset_for_client_fn=create_dataset_fn)

        total_sub_rounds = max_rounds * num_sub_rounds
        for round_num in range(total_sub_rounds):
            # The same first `clients` client ids are sampled every round.
            round_datasets = [
                client_data.create_tf_dataset_for_client(cid)
                for cid in range(clients)
            ]
            round_votes = tf.zeros(shape=votes_shape, dtype=tf.int32)

            # Re-wrapping the Python function is a workaround to clear the
            # graph cache of the `tf.function`: a new lookup table has to be
            # constructed every round based on the new prefixes.
            client_update = tf.function(
                triehh_tf.client_update.python_function)

            for dataset in round_datasets:
                output = client_update(
                    dataset, state.discovered_prefixes,
                    state.possible_prefix_extensions, round_num,
                    tf.constant(num_sub_rounds),
                    tf.constant(max_num_heavy_hitters, dtype=tf.int32),
                    tf.constant(max_user_contribution, dtype=tf.int32))
                round_votes += output.client_votes

            state = triehh_tf.server_update(
                state, round_votes,
                tf.constant(num_sub_rounds, dtype=tf.int32),
                tf.constant(max_num_heavy_hitters, dtype=tf.int32),
                tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string))

        want_heavy_hitters = tf.constant(['hi', 'hey', 'hello'],
                                         dtype=tf.string)

        self.assertAllEqual(state.discovered_heavy_hitters,
                            want_heavy_hitters)