Example #1
0
    def testSavedModelSaveRestore(self):
        """Round-trip a SimpleHashTable through tf.saved_model save and load.

        A table and a `lookup` tf.function are attached to an AutoTrackable
        root and the table is populated. The root is saved as a SavedModel and
        then reloaded. The test checks that every inserted key still resolves
        to its value through the reloaded `lookup`, and that a key that was
        never inserted resolves to the table's default value.
        """
        # Create a unique directory under the test temp dir via `dir=`.
        # Passing the full path as `prefix` only happens to work because
        # os.path.join drops the default temp dir when given an absolute path.
        save_dir = tempfile.mkdtemp(prefix='save_restore',
                                    dir=self.get_temp_dir())
        save_path = os.path.join(save_dir, 'hash')

        # TODO(b/203097231) is there an alternative that is not __internal__?
        root = tf.__internal__.tracking.AutoTrackable()

        default_value = -1
        expected = {1: 100, 2: 200, 3: 300}
        root.table = simple_hash_table.SimpleHashTable(
            tf.int64, tf.int64, default_value=default_value)

        # An explicit input_signature gives the saved function a fixed
        # scalar-int64 signature.
        @def_function.function(input_signature=[tf.TensorSpec((), tf.int64)])
        def lookup(key):
            return root.table.find(key)

        root.lookup = lookup

        # Dict iteration order is insertion order, so inserts happen as 1, 2, 3.
        for key, value in expected.items():
            root.table.insert(key, value)
        self.assertEqual(root.lookup(2), 200)
        # export()[0] is the exported key tensor; its length is the entry count.
        self.assertLen(self.evaluate(root.table.export()[0]), len(expected))
        tf.saved_model.save(root, save_path)

        # Drop the local reference so the assertions below go through the
        # reloaded object only.
        del root
        loaded = tf.saved_model.load(save_path)
        for key, value in expected.items():
            self.assertEqual(loaded.lookup(key), value)
        # Key 10 was never inserted.
        self.assertEqual(loaded.lookup(10), default_value)
Example #2
0
 def test_export(self):
     """Export a populated table and check its contents, including pairing.

     The key and value arrays are compared after sorting, so the test does not
     depend on export order. Sorting the two arrays independently cannot
     detect a key exported next to the wrong value, so the key->value mapping
     is also compared directly.
     """
     table = simple_hash_table.SimpleHashTable(tf.int64,
                                               tf.int64,
                                               default_value=-1)
     table.insert(1, 100)
     table.insert(2, 200)
     table.insert(3, 300)
     keys, values = self.evaluate(table.export())
     # The sorted-key check also catches duplicate or missing keys, which the
     # dict comparison below would collapse or miss.
     self.assertAllEqual(sorted(keys), [1, 2, 3])
     self.assertAllEqual(sorted(values), [100, 200, 300])
     # int() turns whatever scalar type evaluate() yields into plain ints so
     # the dict compares cleanly against Python literals.
     exported = {int(key): int(value) for key, value in zip(keys, values)}
     self.assertDictEqual(exported, {1: 100, 2: 200, 3: 300})
Example #3
0
 def _use_table(self, key_dtype, value_dtype):
     """Run find / insert / find / remove / find on a fresh table for key 1.

     Args:
       key_dtype: dtype of the table's keys.
       value_dtype: dtype of the table's values.

     Returns:
       The three find() results stacked into one tensor, expected to be
       [-999, 100, -999].
     """
     # 111 is passed as the table's third constructor argument (presumably
     # default_value). It is never observed here because every find() below
     # supplies -999 explicitly.
     table = simple_hash_table.SimpleHashTable(key_dtype, value_dtype, 111)
     missing = -999
     # The stateful calls must stay in this exact order.
     lookups = [table.find(1, missing)]
     table.insert(1, 100)
     lookups.append(table.find(1, missing))
     table.remove(1)
     lookups.append(table.find(1, missing))
     return tf.stack(lookups)
Example #4
0
 def test_import(self):
     """Bulk-load a table with do_import() and verify each lookup.

     Keys 1, 2 and 3 come from the imported tensors and must map to their
     paired values. Key 9 was never imported, so find() must return the
     table's default value.
     """
     default_value = -1
     expected = {1: 100, 2: 200, 3: 300}
     table = simple_hash_table.SimpleHashTable(tf.int64,
                                               tf.int64,
                                               default_value=default_value)
     # Dict order is insertion order, so these are [1, 2, 3] and [100, 200, 300].
     table.do_import(tf.constant(list(expected.keys()), dtype=tf.int64),
                     tf.constant(list(expected.values()), dtype=tf.int64))
     for key, value in [*expected.items(), (9, default_value)]:
         self.assertEqual(table.find(key), value)
Example #5
0
 def test_find_insert_find_strings_eager(self):
     """Check find/insert on a string-to-string SimpleHashTable.

     A lookup of a key that has not been inserted returns the default passed
     to find(). After the key is inserted, the same lookup returns the
     inserted value.
     """
     default = 'Default'
     foo = 'Foo'
     bar = 'Bar'
     hash_table = simple_hash_table.SimpleHashTable(tf.string, tf.string,
                                                    default)
     # 'Foo' has not been inserted yet, so the supplied default comes back.
     result1 = hash_table.find(foo, default)
     self.assertEqual(result1, default)
     hash_table.insert(foo, bar)
     # After insertion, the stored value is returned instead of the default.
     result2 = hash_table.find(foo, default)
     self.assertEqual(result2, bar)