def test_extend_prefixes_works_as_expected(self):
  """Checks that every prefix is paired with every candidate extension.

  The expected output is grouped by prefix: all five extensions of 'a'
  come first (in the order the extensions are given), then those of 'b',
  and so on through 'e'.
  """
  prefixes_to_extend = tf.constant(['a', 'b', 'c', 'd', 'e'], dtype=tf.string)
  candidate_extensions = tf.constant(['a', 'b', 'c', 'd', '$'],
                                     dtype=tf.string)

  actual = triehh_tf.extend_prefixes(prefixes_to_extend, candidate_extensions)

  # One row per prefix, one column per extension.
  expected = tf.constant(
      [
          'aa', 'ab', 'ac', 'ad', 'a$',
          'ba', 'bb', 'bc', 'bd', 'b$',
          'ca', 'cb', 'cc', 'cd', 'c$',
          'da', 'db', 'dc', 'dd', 'd$',
          'ea', 'eb', 'ec', 'ed', 'e$',
      ],
      dtype=tf.string)
  self.assertAllEqual(actual, expected)
def test_extend_prefixes_with_threshold_works_as_expected(self):
  """Checks that only extensions whose vote count reaches the threshold survive.

  The vote vector has 3 * 4 = 12 entries, one per (prefix, extension) pair.
  Reading it row-major -- 'a' x {a,b,c,d}, then 'b' x ..., then 'c' x ... --
  the entries >= 3 are aa (4), ac (3), ba (7) and cd (8), which is exactly
  the expected set below. The layout is inferred from this test's data;
  TODO confirm against triehh_tf.
  """
  extensions_wo_terminator = tf.constant(['a', 'b', 'c', 'd'], dtype=tf.string)
  discovered_prefixes = tf.constant(['a', 'b', 'c'], dtype=tf.string)
  # Previously written as `threshold = threshold = ...`; the chained
  # self-assignment was redundant.
  threshold = tf.constant(3)
  max_num_prefixes = tf.constant(20)
  prefixes_votes = tf.constant([4, 2, 3, 0, 7, 1, 0, 0, 0, 0, 0, 8],
                               dtype=tf.int32)

  # NOTE(review): this calls `triehh_tf.extend_prefixes` with five positional
  # arguments (votes first). test_extend_prefixes_works_as_expected calls the
  # same name with two arguments (prefixes first). One function cannot match
  # both call sites, so one of them presumably targets a different
  # (threshold-specific) helper -- confirm the intended function in triehh_tf.
  extended_prefixes = triehh_tf.extend_prefixes(
      prefixes_votes, discovered_prefixes, extensions_wo_terminator,
      max_num_prefixes, threshold)

  expected_extended_prefixes = tf.constant(['aa', 'ac', 'ba', 'cd'],
                                           dtype=tf.string)
  # Compared as a set: only membership is asserted, not output order.
  self.assertSetAllEqual(extended_prefixes, expected_extended_prefixes)