def test_accumulate_server_votes_works_as_expected(self):
  """Checks the vote tensor produced by `accumulate_server_votes`.

  The server state starts with a 10x5 vote matrix whose first two rows are
  `[1, 2, 1, 0, 0]`, and an identical matrix is submitted as the sub-round
  votes. The expected accumulated votes equal the element-wise sum of the
  two matrices.
  """
  zero_padding = [[0] * 5 for _ in range(8)]

  def as_vote_tensor(leading_row):
    # Two copies of `leading_row` followed by eight all-zero rows.
    return tf.constant(
        [list(leading_row), list(leading_row)] + zero_padding,
        dtype=tf.int32)

  state_before = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitters_counts=tf.constant([], dtype=tf.int32),
      discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
      round_num=tf.constant(0, dtype=tf.int32),
      accumulated_votes=as_vote_tensor([1, 2, 1, 0, 0]))

  state_after = triehh_tf.accumulate_server_votes(
      state_before, as_vote_tensor([1, 2, 1, 0, 0]))

  self.assertAllEqual(state_after.accumulated_votes,
                      as_vote_tensor([2, 4, 2, 0, 0]))
def test_server_update_works_as_expected(self):
  """Checks the state returned by `server_update` with one sub-round.

  The state starts at round 1 with prefixes 'a'..'e' and an all-zero vote
  matrix. Only the first two rows of the submitted votes are non-zero. The
  test expects ten two-character prefixes starting with 'a' or 'b', no heavy
  hitters, and an all-zero accumulated vote matrix afterwards.
  """
  max_num_prefixes = tf.constant(10)
  threshold = tf.constant(1)
  num_sub_rounds = tf.constant(1, dtype=tf.int32)

  extension_chars = ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR]
  num_extensions = len(extension_chars)
  possible_prefix_extensions = tf.constant(extension_chars, dtype=tf.string)

  no_strings = tf.constant([], dtype=tf.string)
  no_counts = tf.constant([], dtype=tf.int32)

  server_state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitters_counts=tf.constant([], dtype=tf.int32),
      discovered_prefixes=tf.constant(['a', 'b', 'c', 'd', 'e'],
                                      dtype=tf.string),
      round_num=tf.constant(1, dtype=tf.int32),
      accumulated_votes=tf.zeros(
          dtype=tf.int32, shape=[max_num_prefixes, num_extensions]))

  # Two rows carry votes; the remaining eight rows are zero padding.
  sub_round_votes = tf.constant(
      [[10, 9, 8, 7, 6, 0], [5, 4, 3, 2, 1, 0]] +
      [[0] * num_extensions for _ in range(8)],
      dtype=tf.int32)

  server_state = triehh_tf.server_update(server_state,
                                         possible_prefix_extensions,
                                         sub_round_votes, num_sub_rounds,
                                         max_num_prefixes, threshold)

  # 'aa', 'ab', ..., 'ae', 'ba', ..., 'be'.
  expected_discovered_prefixes = tf.constant(
      [head + tail for head in 'ab' for tail in 'abcde'], dtype=tf.string)
  expected_accumulated_votes = tf.constant(
      [[0] * num_extensions for _ in range(10)], dtype=tf.int32)

  self.assertSetAllEqual(server_state.discovered_prefixes,
                         expected_discovered_prefixes)
  self.assertSetAllEqual(server_state.discovered_heavy_hitters, no_strings)
  self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                             server_state.heavy_hitters_counts, no_strings,
                             no_counts)
  self.assertAllEqual(server_state.accumulated_votes,
                      expected_accumulated_votes)
def test_accumulate_server_votes_and_decode_works_as_expected(self):
  """Checks the state returned by `accumulate_server_votes_and_decode`.

  The state starts at round 3 with prefixes 'su' and 'st' and a non-zero 4x6
  vote matrix; a second 4x6 matrix is submitted as the sub-round votes. The
  test expects the accumulated votes to be all zeros afterwards, the
  discovered prefixes to be {'sta', 'sun', 'sua', 'stn'}, and no heavy
  hitters.
  """
  max_num_prefixes = tf.constant(4)
  threshold = tf.constant(1)

  extension_chars = ['a', 'n', 's', 't', 'u', triehh_tf.DEFAULT_TERMINATOR]
  possible_prefix_extensions = tf.constant(extension_chars, dtype=tf.string)

  def padded_votes(first_row, second_row):
    # Two explicit rows followed by two all-zero rows (4x6 in total).
    return tf.constant(
        [first_row, second_row] + [[0] * 6 for _ in range(2)],
        dtype=tf.int32)

  server_state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitters_counts=tf.constant([], dtype=tf.int32),
      discovered_prefixes=tf.constant(['su', 'st'], dtype=tf.string),
      round_num=tf.constant(3, dtype=tf.int32),
      accumulated_votes=padded_votes([1, 2, 1, 0, 0, 0], [1, 2, 1, 0, 0, 0]))

  # Element-wise, initial + sub-round votes give [4, 5, 2, 0, 0, 0] and
  # [6, 3, 2, 0, 0, 0] for the first two rows.
  sub_round_votes = padded_votes([3, 3, 1, 0, 0, 0], [5, 1, 1, 0, 0, 0])

  server_state = triehh_tf.accumulate_server_votes_and_decode(
      server_state, possible_prefix_extensions, sub_round_votes,
      max_num_prefixes, threshold)

  no_strings = tf.constant([], dtype=tf.string)
  no_counts = tf.constant([], dtype=tf.int32)

  self.assertAllEqual(
      server_state.accumulated_votes,
      tf.constant([[0] * 6 for _ in range(4)], dtype=tf.int32))
  self.assertSetAllEqual(
      server_state.discovered_prefixes,
      tf.constant(['sta', 'sun', 'sua', 'stn'], dtype=tf.string))
  self.assertSetAllEqual(server_state.discovered_heavy_hitters, no_strings)
  self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                             server_state.heavy_hitters_counts, no_strings,
                             no_counts)
def test_server_update_does_not_decode_in_a_subround(self):
  """Checks `server_update` at round 0 when there are two sub-rounds.

  The state starts with the single empty prefix and an all-zero vote matrix.
  After one call the test expects the prefixes and heavy hitters to be
  unchanged and the accumulated votes to equal the submitted sub-round votes.
  """
  max_num_prefixes = tf.constant(10)
  threshold = tf.constant(1)
  num_sub_rounds = tf.constant(2, dtype=tf.int32)

  extension_chars = ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR]
  num_extensions = len(extension_chars)
  possible_prefix_extensions = tf.constant(extension_chars, dtype=tf.string)

  server_state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitters_counts=tf.constant([], dtype=tf.int32),
      discovered_prefixes=tf.constant([''], dtype=tf.string),
      round_num=tf.constant(0, dtype=tf.int32),
      accumulated_votes=tf.zeros(
          dtype=tf.int32, shape=[max_num_prefixes, num_extensions]))

  # Two rows carry votes; the remaining eight rows are zero padding.
  sub_round_votes = tf.constant(
      [[1, 2, 1, 2, 0, 0], [2, 0, 0, 0, 0, 0]] +
      [[0] * num_extensions for _ in range(8)],
      dtype=tf.int32)

  server_state = triehh_tf.server_update(server_state,
                                         possible_prefix_extensions,
                                         sub_round_votes, num_sub_rounds,
                                         max_num_prefixes, threshold)

  no_strings = tf.constant([], dtype=tf.string)
  no_counts = tf.constant([], dtype=tf.int32)

  self.assertSetAllEqual(server_state.discovered_prefixes,
                         tf.constant([''], dtype=tf.string))
  self.assertSetAllEqual(server_state.discovered_heavy_hitters, no_strings)
  self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                             server_state.heavy_hitters_counts, no_strings,
                             no_counts)
  # The expected accumulated votes are exactly the votes just submitted.
  self.assertAllEqual(server_state.accumulated_votes, sub_round_votes)
def test_all_tf_functions_work_together_high_threshold(self):
  """Runs `client_update` and `server_update` together with threshold 100.

  Every client holds the words 'hello', 'hey' and 'hi'. Three clients are
  sampled in each of the 24 iterations (6 rounds x 4 sub-rounds), their
  votes are summed and passed to `server_update`. The test expects the final
  state to contain no heavy hitters and no discovered prefixes.
  """
  clients = 3
  num_sub_rounds = 4
  max_rounds = 6
  max_num_prefixes = 3
  threshold = 100
  max_user_contribution = 100

  alphabet = list(string.ascii_lowercase + string.digits + "'@#-;*:./" +
                  triehh_tf.DEFAULT_TERMINATOR)
  votes_shape = [max_num_prefixes, len(alphabet)]
  possible_prefix_extensions = tf.constant(alphabet, dtype=tf.string)

  server_state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitters_counts=tf.constant([], dtype=tf.int32),
      discovered_prefixes=tf.constant([''], dtype=tf.string),
      round_num=tf.constant(0, dtype=tf.int32),
      accumulated_votes=tf.zeros(dtype=tf.int32, shape=votes_shape))

  def create_dataset_fn(client_id):
    del client_id  # Unused: all clients share the same data.
    return tf.data.Dataset.from_tensor_slices(['hello', 'hey', 'hi'])

  client_data = tff.simulation.datasets.ClientData.from_clients_and_fn(
      client_ids=list(range(100)),
      create_tf_dataset_for_client_fn=create_dataset_fn)

  # Constant tensors that are identical in every iteration.
  num_sub_rounds_tensor = tf.constant(num_sub_rounds, dtype=tf.int32)
  max_num_prefixes_tensor = tf.constant(max_num_prefixes, dtype=tf.int32)
  max_user_contribution_tensor = tf.constant(
      max_user_contribution, dtype=tf.int32)
  terminator_tensor = tf.constant(
      triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)
  threshold_tensor = tf.constant(threshold, dtype=tf.int32)

  for round_num in range(max_rounds * num_sub_rounds):
    sampled_datasets = [
        client_data.create_tf_dataset_for_client(client_id)
        for client_id in range(clients)
    ]

    # Workaround: wrap the underlying Python function in a fresh
    # `tf.function` each round to clear the graph cache. A new lookup table
    # has to be constructed every round from the new prefixes.
    client_update = tf.function(triehh_tf.client_update.python_function)

    accumulated_votes = tf.zeros(dtype=tf.int32, shape=votes_shape)
    for dataset in sampled_datasets:
      client_output = client_update(
          dataset, server_state.discovered_prefixes,
          possible_prefix_extensions, round_num, num_sub_rounds_tensor,
          max_num_prefixes_tensor, max_user_contribution_tensor,
          terminator_tensor)
      accumulated_votes += client_output.client_votes

    server_state = triehh_tf.server_update(
        server_state, possible_prefix_extensions, accumulated_votes,
        num_sub_rounds_tensor, max_num_prefixes_tensor, threshold_tensor)

  no_strings = tf.constant([], dtype=tf.string)
  no_counts = tf.constant([], dtype=tf.int32)

  self.assertSetAllEqual(server_state.discovered_heavy_hitters, no_strings)
  self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                             server_state.heavy_hitters_counts, no_strings,
                             no_counts)
  self.assertSetAllEqual(server_state.discovered_prefixes,
                         tf.constant([], dtype=tf.string))