def test_client_update_works_as_expected(self):
  """Checks the vote matrix `client_update` returns for a small dataset."""
  local_words = ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi']
  known_prefixes = tf.constant(['a', 'b', 'c', 'd', 'e'], dtype=tf.string)
  extensions = tf.constant(
      ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
      dtype=tf.string)
  prefix_cap = tf.constant(10)
  contribution_cap = tf.constant(10)
  current_round = tf.constant(1)
  sub_rounds = tf.constant(1)
  dataset = tf.data.Dataset.from_tensor_slices(local_words)
  terminator = tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)

  result = triehh_tf.client_update(
      dataset, known_prefixes, extensions, current_round, sub_rounds,
      prefix_cap, contribution_cap, terminator)

  # Every string has triehh_tf.DEFAULT_TERMINATOR appended before the client
  # votes, so the bare string 'a' also earns a vote (as 'a' + terminator).
  vote_rows = [[1, 2, 1, 0, 0, 1], [1, 0, 0, 0, 0, 0]]
  # The remaining eight rows of the 10 x 6 matrix carry no votes.
  vote_rows += [[0] * 6 for _ in range(8)]
  expected = tf.constant(vote_rows, dtype=tf.int32)
  self.assertAllEqual(result.client_votes, expected)
def test_client_update_works_on_empty_local_datasets(self):
  """Checks that an empty client dataset produces an all-zero vote matrix."""

  def no_elements():
    # Generator factory that never yields anything.
    return iter(())

  known_prefixes = tf.constant(['a', 'b', 'c', 'd', 'e'], dtype=tf.string)
  extensions = tf.constant(
      ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
      dtype=tf.string)
  prefix_cap = tf.constant(10)
  contribution_cap = tf.constant(10)
  current_round = tf.constant(1)
  sub_rounds = tf.constant(1)

  # An empty `from_tensor_slices` input would not be typed as tf.string, so
  # the empty dataset is built from a generator with an explicit string dtype.
  dataset = tf.data.Dataset.from_generator(
      generator=no_elements, output_types=tf.string, output_shapes=())
  terminator = tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)

  result = triehh_tf.client_update(
      dataset, known_prefixes, extensions, current_round, sub_rounds,
      prefix_cap, contribution_cap, terminator)

  # No local data means no votes anywhere in the 10 x 6 matrix.
  expected = tf.zeros([10, 6], dtype=tf.int32)
  self.assertAllEqual(result.client_votes, expected)
def test_client_update_works_on_empty_discovered_prefixes(self):
  """Checks that no discovered prefixes produces an all-zero vote matrix."""
  local_words = ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi']
  # Explicit dtype is required: an empty list carries no string type on its own.
  known_prefixes = tf.constant([], dtype=tf.string)
  extensions = tf.constant(
      ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
      dtype=tf.string)
  prefix_cap = tf.constant(10)
  contribution_cap = tf.constant(10)
  current_round = tf.constant(1)
  sub_rounds = tf.constant(1)
  dataset = tf.data.Dataset.from_tensor_slices(local_words)
  terminator = tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)

  result = triehh_tf.client_update(
      dataset, known_prefixes, extensions, current_round, sub_rounds,
      prefix_cap, contribution_cap, terminator)

  # The test expects every entry of the 10 x 6 vote matrix to be zero.
  expected = tf.zeros([10, 6], dtype=tf.int32)
  self.assertAllEqual(result.client_votes, expected)
def client_update_fn(tf_dataset, discovered_prefixes, round_num):
  """Delegates to `client_update`, filling in settings defined outside.

  Args:
    tf_dataset: Forwarded unchanged as the first argument of `client_update`.
    discovered_prefixes: Forwarded unchanged as the second argument.
    round_num: Forwarded unchanged as the fourth argument.

  Returns:
    Whatever `client_update` returns.
  """
  # `possible_prefix_extensions`, `default_terminator`, `num_sub_rounds`,
  # `max_num_prefixes` and `max_user_contribution` are not parameters; they
  # are resolved from the surrounding scope.
  extensions = tf.constant(possible_prefix_extensions)
  terminator = tf.constant(default_terminator, dtype=tf.string)
  return client_update(
      tf_dataset,
      discovered_prefixes,
      extensions,
      round_num,
      num_sub_rounds,
      max_num_prefixes,
      max_user_contribution,
      terminator)