def test_accumulate_client_votes_works_as_expected(self):
    """Verifies that client votes are accumulated only for discovered prefixes.

    The accumulator is built for round 1 with a single sub-round and fed two
    client strings:

    * 'ab' -- its prefix 'a' is in the discovered prefixes, so the vote
      matrix is expected to differ from the initial one only at [0, 1]
      (the row indexed by 'a', the column indexed by 'b').
    * 'ea' -- its prefix 'e' is not in the discovered prefixes, so the vote
      must be dropped and the initial matrix returned unchanged.

    NOTE(review): another method with this exact name is defined right after
    this one. If both live in the same class, the later definition replaces
    this one and this test is never collected. Consider renaming one of them.
    """

    def index_table(keys):
        # Maps every key to its position in `keys`; lookups of keys not in
        # `keys` yield triehh_tf.DEFAULT_VALUE.
        initializer = tf.lookup.KeyValueTensorInitializer(
            keys, tf.range(tf.shape(keys)[0]))
        return tf.lookup.StaticHashTable(initializer, triehh_tf.DEFAULT_VALUE)

    num_rows, num_cols = 10, 5

    def vote_matrix(*leading_rows):
        # Pads the given rows with all-zero rows up to `num_rows` rows.
        rows = list(leading_rows)
        rows.extend([0] * num_cols for _ in range(num_rows - len(rows)))
        return tf.constant(rows, dtype=tf.int32)

    possible_prefix_extensions = tf.constant(
        ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
        dtype=tf.string)
    discovered_prefixes = tf.constant(['a', 'b', 'c', 'd'], dtype=tf.string)
    round_num = tf.constant(1)
    num_sub_rounds = tf.constant(1)

    prefixes_lookup = index_table(discovered_prefixes)
    extensions_lookup = index_table(possible_prefix_extensions)
    terminator = tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string)
    accumulate_client_votes = triehh_tf.make_accumulate_client_votes_fn(
        round_num, num_sub_rounds, prefixes_lookup, extensions_lookup,
        terminator)

    # NOTE(review): these matrices have 5 columns while there are 6 possible
    # extensions, so the terminator (index 5) has no column of its own.
    # Confirm this width is what make_accumulate_client_votes_fn expects.
    initial_votes = vote_matrix([1, 2, 1, 0, 0], [1, 0, 0, 0, 0])
    votes_after_ab = vote_matrix([1, 3, 1, 0, 0], [1, 0, 0, 0, 0])

    known_prefix_example = tf.constant('ab', dtype=tf.string)
    self.assertAllEqual(
        accumulate_client_votes(initial_votes, known_prefix_example),
        votes_after_ab)

    # The prefix 'e' is not among the discovered prefixes, so this vote must
    # not be counted.
    unknown_prefix_example = tf.constant('ea', dtype=tf.string)
    self.assertAllEqual(
        accumulate_client_votes(initial_votes, unknown_prefix_example),
        initial_votes)
def test_accumulate_client_votes_works_as_expected(self):
    """Verifies that a single client vote lands in the expected cell.

    The possible extensions here use a literal '$' as the terminator, and the
    accumulator factory is called with four arguments (no terminator tensor).
    For round 1 with a single sub-round, the client string 'ab' is expected
    to change the vote matrix only at [0, 1] (the row indexed by 'a' among
    the discovered prefixes, the column indexed by 'b' among the extensions).

    NOTE(review): a method with this exact name is defined immediately before
    this one. If both live in the same class, this definition replaces the
    earlier one. The two variants also call make_accumulate_client_votes_fn
    with different numbers of arguments; unless its terminator parameter has
    a default, only one of them matches the current signature -- confirm
    which variant is current.
    """

    def index_table(keys):
        # Maps every key to its position in `keys`; lookups of keys not in
        # `keys` yield triehh_tf.DEFAULT_VALUE.
        initializer = tf.lookup.KeyValueTensorInitializer(
            keys, tf.range(tf.shape(keys)[0]))
        return tf.lookup.StaticHashTable(initializer, triehh_tf.DEFAULT_VALUE)

    num_rows, num_cols = 10, 5

    def vote_matrix(*leading_rows):
        # Pads the given rows with all-zero rows up to `num_rows` rows.
        rows = list(leading_rows)
        rows.extend([0] * num_cols for _ in range(num_rows - len(rows)))
        return tf.constant(rows, dtype=tf.int32)

    possible_prefix_extensions = tf.constant(
        ['a', 'b', 'c', 'd', '$'], dtype=tf.string)
    discovered_prefixes = tf.constant(['a', 'b', 'c', 'd'], dtype=tf.string)
    round_num = tf.constant(1)
    num_sub_rounds = tf.constant(1)
    client_example = tf.constant('ab', dtype=tf.string)

    prefixes_lookup = index_table(discovered_prefixes)
    extensions_lookup = index_table(possible_prefix_extensions)
    accumulate_client_votes = triehh_tf.make_accumulate_client_votes_fn(
        round_num, num_sub_rounds, prefixes_lookup, extensions_lookup)

    initial_votes = vote_matrix([1, 2, 1, 0, 0], [1, 0, 0, 0, 0])
    votes_after_ab = vote_matrix([1, 3, 1, 0, 0], [1, 0, 0, 0, 0])

    self.assertAllEqual(
        accumulate_client_votes(initial_votes, client_example),
        votes_after_ab)