def test_accumulate_server_votes_works_as_expected(self):
    """Accumulating a sub-round's votes adds them element-wise to the state.

    The server state already holds one batch of votes. An identical batch is
    accumulated on top of it, so every entry of the result must be doubled.
    """
    num_rows = 10
    silent_row = [0, 0, 0, 0, 0]
    vote_row = [1, 2, 1, 0, 0]
    doubled_row = [2, 4, 2, 0, 0]

    # Only the first two rows carry votes; all remaining rows stay zero.
    votes_matrix = [vote_row, vote_row] + [silent_row] * (num_rows - 2)

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.constant(votes_matrix, dtype=tf.int32))

    new_votes = tf.constant(votes_matrix, dtype=tf.int32)
    state = triehh_tf.accumulate_server_votes(state, new_votes)

    want = tf.constant(
        [doubled_row, doubled_row] + [silent_row] * (num_rows - 2),
        dtype=tf.int32)
    self.assertAllEqual(state.accumulated_votes, want)
  def test_server_update_works_as_expected(self):
    """A round with a single sub-round decodes its votes straight away.

    Prefixes 'a' and 'b' receive positive votes for every non-terminator
    extension. With `num_sub_rounds` set to 1, the update is expected to:
      * turn those votes into the ten two-letter prefixes,
      * find no heavy hitters,
      * clear `accumulated_votes`.
    """
    extension_list = ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR]
    extensions = tf.constant(extension_list, dtype=tf.string)
    width = len(extension_list)
    prefix_budget = tf.constant(10)
    min_votes = tf.constant(1)
    sub_rounds_per_round = tf.constant(1, dtype=tf.int32)

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant(['a', 'b', 'c', 'd', 'e'],
                                        dtype=tf.string),
        round_num=tf.constant(1, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32, shape=[prefix_budget, width]))

    # Votes only in the first two rows; the other eight rows are all zero.
    quiet_row = [0] * width
    votes = tf.constant(
        [[10, 9, 8, 7, 6, 0], [5, 4, 3, 2, 1, 0]] + [quiet_row] * 8,
        dtype=tf.int32)

    state = triehh_tf.server_update(state, extensions, votes,
                                    sub_rounds_per_round, prefix_budget,
                                    min_votes)

    want_prefixes = tf.constant(
        [head + tail for head in 'ab' for tail in 'abcde'], dtype=tf.string)
    no_heavy_hitters = tf.constant([], dtype=tf.string)
    no_counts = tf.constant([], dtype=tf.int32)

    self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
    self.assertSetAllEqual(state.discovered_heavy_hitters, no_heavy_hitters)
    self.assertHistogramsEqual(state.discovered_heavy_hitters,
                               state.heavy_hitters_counts, no_heavy_hitters,
                               no_counts)
    self.assertAllEqual(state.accumulated_votes,
                        tf.zeros([10, width], dtype=tf.int32))
  def test_accumulate_server_votes_and_decode_works_as_expected(self):
    """Accumulating and decoding keeps the best-scoring extensions.

    Prior votes for prefixes 'su' and 'st' are summed with a new batch. The
    four highest-scoring extensions are expected to become the new prefixes,
    and the vote buffer is expected to be reset to zero.
    """
    extension_list = ['a', 'n', 's', 't', 'u', triehh_tf.DEFAULT_TERMINATOR]
    extensions = tf.constant(extension_list, dtype=tf.string)
    prefix_budget = tf.constant(4)
    min_votes = tf.constant(1)
    empty_row = [0, 0, 0, 0, 0, 0]

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant(['su', 'st'], dtype=tf.string),
        round_num=tf.constant(3, dtype=tf.int32),
        accumulated_votes=tf.constant(
            [[1, 2, 1, 0, 0, 0], [1, 2, 1, 0, 0, 0], empty_row, empty_row],
            dtype=tf.int32))

    # The element-wise sum with the prior votes is
    # [[4, 5, 2, 0, 0, 0], [6, 3, 2, 0, 0, 0], 0..., 0...]. Its four largest
    # entries line up with the expected prefixes below.
    latest_votes = tf.constant(
        [[3, 3, 1, 0, 0, 0], [5, 1, 1, 0, 0, 0], empty_row, empty_row],
        dtype=tf.int32)

    state = triehh_tf.accumulate_server_votes_and_decode(
        state, extensions, latest_votes, prefix_budget, min_votes)

    want_prefixes = tf.constant(['sta', 'sun', 'sua', 'stn'], dtype=tf.string)
    no_heavy_hitters = tf.constant([], dtype=tf.string)
    no_counts = tf.constant([], dtype=tf.int32)

    self.assertAllEqual(state.accumulated_votes,
                        tf.zeros([4, 6], dtype=tf.int32))
    self.assertSetAllEqual(state.discovered_prefixes, want_prefixes)
    self.assertSetAllEqual(state.discovered_heavy_hitters, no_heavy_hitters)
    self.assertHistogramsEqual(state.discovered_heavy_hitters,
                               state.heavy_hitters_counts, no_heavy_hitters,
                               no_counts)
  def test_server_update_does_not_decode_in_a_subround(self):
    """An intermediate sub-round only accumulates votes, without decoding.

    The setup uses two sub-rounds per round and `round_num` 0. The update is
    expected to leave the prefixes, heavy hitters and counts untouched, and to
    store the sub-round's votes as `accumulated_votes`.
    """
    extension_list = ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR]
    extensions = tf.constant(extension_list, dtype=tf.string)
    prefix_budget = tf.constant(10)
    min_votes = tf.constant(1)
    sub_rounds_per_round = tf.constant(2, dtype=tf.int32)
    vote_shape = [prefix_budget, len(extension_list)]

    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(dtype=tf.int32, shape=vote_shape))

    # Votes only in the first two rows; the other eight rows are all zero.
    quiet_row = [0] * len(extension_list)
    votes = tf.constant(
        [[1, 2, 1, 2, 0, 0], [2, 0, 0, 0, 0, 0]] + [quiet_row] * 8,
        dtype=tf.int32)

    state = triehh_tf.server_update(state, extensions, votes,
                                    sub_rounds_per_round, prefix_budget,
                                    min_votes)

    no_heavy_hitters = tf.constant([], dtype=tf.string)
    no_counts = tf.constant([], dtype=tf.int32)

    self.assertSetAllEqual(state.discovered_prefixes,
                           tf.constant([''], dtype=tf.string))
    self.assertSetAllEqual(state.discovered_heavy_hitters, no_heavy_hitters)
    self.assertHistogramsEqual(state.discovered_heavy_hitters,
                               state.heavy_hitters_counts, no_heavy_hitters,
                               no_counts)
    # Without decoding, the stored votes are exactly this sub-round's votes.
    self.assertAllEqual(state.accumulated_votes, votes)
  def test_all_tf_functions_work_together_high_threshold(self):
    """Runs client and server updates end to end with a high threshold.

    Every simulated client holds the same three words, and only `clients`
    clients vote in each sub-round. This is expected to fall short of
    `threshold`, so the run should end with no heavy hitters and no
    surviving prefixes.
    """
    clients = 3
    num_sub_rounds = 4
    max_rounds = 6
    max_num_prefixes = 3
    threshold = 100
    max_user_contribution = 100
    roots = (
        string.ascii_lowercase + string.digits + "'@#-;*:./" +
        triehh_tf.DEFAULT_TERMINATOR)
    possible_prefix_extensions = list(roots)
    possible_prefix_extensions_num = len(possible_prefix_extensions)
    possible_prefix_extensions = tf.constant(
        possible_prefix_extensions, dtype=tf.string)
    vote_shape = [max_num_prefixes, possible_prefix_extensions_num]

    # Scalar arguments that never change across rounds; build them once
    # rather than on every iteration of the simulation loop.
    num_sub_rounds_t = tf.constant(num_sub_rounds, dtype=tf.int32)
    max_num_prefixes_t = tf.constant(max_num_prefixes, dtype=tf.int32)
    max_user_contribution_t = tf.constant(
        max_user_contribution, dtype=tf.int32)
    terminator_t = tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)
    threshold_t = tf.constant(threshold, dtype=tf.int32)

    server_state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(dtype=tf.int32, shape=vote_shape))

    def create_dataset_fn(client_id):
      del client_id  # Every client holds the same local data.
      return tf.data.Dataset.from_tensor_slices(['hello', 'hey', 'hi'])

    # `ClientData.from_clients_and_fn` documents `client_ids` as strings, so
    # use string ids rather than relying on ints being tolerated.
    client_ids = [str(i) for i in range(100)]

    client_data = tff.simulation.datasets.ClientData.from_clients_and_fn(
        client_ids=client_ids,
        create_tf_dataset_for_client_fn=create_dataset_fn)

    # The same first `clients` clients participate in every round.
    sampled_clients = client_ids[:clients]

    for round_num in range(max_rounds * num_sub_rounds):
      sampled_datasets = [
          client_data.create_tf_dataset_for_client(client_id)
          for client_id in sampled_clients
      ]
      accumulated_votes = tf.zeros(dtype=tf.int32, shape=vote_shape)

      # This is a workaround to clear the graph cache in the `tf.function`; this
      # is necessary because we need to construct a new lookup table every round
      # based on new prefixes.
      client_update = tf.function(triehh_tf.client_update.python_function)

      for dataset in sampled_datasets:
        client_output = client_update(
            dataset, server_state.discovered_prefixes,
            possible_prefix_extensions, round_num, num_sub_rounds_t,
            max_num_prefixes_t, max_user_contribution_t, terminator_t)
        accumulated_votes += client_output.client_votes

      server_state = triehh_tf.server_update(
          server_state, possible_prefix_extensions, accumulated_votes,
          num_sub_rounds_t, max_num_prefixes_t, threshold_t)

    expected_discovered_heavy_hitters = tf.constant([], dtype=tf.string)
    expected_heavy_hitters_counts = tf.constant([], dtype=tf.int32)
    expected_discovered_prefixes = tf.constant([], dtype=tf.string)

    self.assertSetAllEqual(server_state.discovered_heavy_hitters,
                           expected_discovered_heavy_hitters)
    self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                               server_state.heavy_hitters_counts,
                               expected_discovered_heavy_hitters,
                               expected_heavy_hitters_counts)
    self.assertSetAllEqual(server_state.discovered_prefixes,
                           expected_discovered_prefixes)