def testSavedModelSaveRestore(self):
  """Round-trips a SimpleHashTable through tf.saved_model save/load.

  Populates a table held on an AutoTrackable root, exposes a traced `lookup`
  function, saves the root, reloads it, and checks that the restored `lookup`
  returns every inserted value and falls back to the table's default for a
  missing key.
  """
  save_dir = os.path.join(self.get_temp_dir(), 'save_restore')
  save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), 'hash')

  # TODO(b/203097231) is there an alternative that is not __internal__?
  root = tf.__internal__.tracking.AutoTrackable()

  default_value = -1
  root.table = simple_hash_table.SimpleHashTable(
      tf.int64, tf.int64, default_value=default_value)

  # A scalar int64 input signature gives the saved function a concrete trace.
  @def_function.function(input_signature=[tf.TensorSpec((), tf.int64)])
  def lookup(key):
    return root.table.find(key)

  root.lookup = lookup

  expected = {1: 100, 2: 200, 3: 300}
  for key, value in expected.items():
    root.table.insert(key, value)

  # Sanity-check the table before saving.
  self.assertEqual(root.lookup(2), 200)
  self.assertAllEqual(3, len(self.evaluate(root.table.export()[0])))

  tf.saved_model.save(root, save_path)
  # Drop the original object so only the reloaded one is exercised below.
  del root

  loaded = tf.saved_model.load(save_path)
  # Every inserted entry must survive the round trip, not just one of them.
  for key, value in expected.items():
    self.assertEqual(loaded.lookup(key), value)
  # A key that was never inserted resolves to the table's default value.
  self.assertEqual(loaded.lookup(10), default_value)
def test_export(self):
  """Checks that export() returns every inserted key and value.

  Keys and values are compared after sorting so the assertions do not depend
  on the order in which export() emits entries.
  """
  table = simple_hash_table.SimpleHashTable(tf.int64, tf.int64, default_value=-1)
  for k, v in ((1, 100), (2, 200), (3, 300)):
    table.insert(k, v)

  exported_keys, exported_values = self.evaluate(table.export())

  self.assertAllEqual(sorted(exported_keys), [1, 2, 3])
  self.assertAllEqual(sorted(exported_values), [100, 200, 300])
def _use_table(self, key_dtype, value_dtype):
  """Exercises find/insert/remove on key 1 and stacks the three lookups.

  Every lookup passes an explicit fallback of -999. The constructor's third
  positional argument (111) is presumably the table's own default value and
  is not expected to appear in the result — TODO confirm against
  SimpleHashTable.

  Args:
    key_dtype: dtype for the table's keys.
    value_dtype: dtype for the table's values.

  Returns:
    A stacked tensor of the lookups before insert, after insert, and after
    remove; expected to be [-999, 100, -999].
  """
  table = simple_hash_table.SimpleHashTable(key_dtype, value_dtype, 111)
  fallback = -999

  before_insert = table.find(1, fallback)
  table.insert(1, 100)
  after_insert = table.find(1, fallback)
  table.remove(1)
  after_remove = table.find(1, fallback)

  return tf.stack((before_insert, after_insert, after_remove))
def test_import(self):
  """Checks that do_import() loads key/value tensors into the table.

  Each imported key must map to its value, and a key that was not imported
  (9) must resolve to the table's default of -1.
  """
  table = simple_hash_table.SimpleHashTable(tf.int64, tf.int64, default_value=-1)
  pairs = {1: 100, 2: 200, 3: 300}

  key_tensor = tf.constant(list(pairs), dtype=tf.int64)
  value_tensor = tf.constant(list(pairs.values()), dtype=tf.int64)
  table.do_import(key_tensor, value_tensor)

  for k, v in pairs.items():
    self.assertEqual(table.find(k), v)
  self.assertEqual(table.find(9), -1)
def test_find_insert_find_strings_eager(self):
  """Checks string-to-string find/insert behaviour.

  A lookup before insertion returns the supplied fallback; after inserting
  'Foo' -> 'Bar', the same lookup returns 'Bar'.
  """
  default = 'Default'
  key, value = 'Foo', 'Bar'
  table = simple_hash_table.SimpleHashTable(tf.string, tf.string, default)

  self.assertEqual(table.find(key, default), default)

  table.insert(key, value)
  self.assertEqual(table.find(key, default), value)