def testMultipleHashTables(self):
        """Three independent tables built from the same data answer identically."""
        with self.test_session() as sess:
            shared_name = ''
            default_val = -1
            # Create the tables one after another, in the same order as before.
            tables = [
                tf.HashTable(tf.string, tf.int64, default_val, shared_name)
                for _ in range(3)
            ]

            keys = tf.constant(['brain', 'salad', 'surgery'])
            values = tf.constant([0, 1, 2], tf.int64)
            for table in tables:
                table.initialize_from(keys, values)

            # A single global initializer run populates every table.
            tf.initialize_all_tables().run()
            for table in tables:
                self.assertAllEqual(3, table.size().eval())

            probe = tf.constant(['brain', 'salad', 'tank'])
            lookups = [table.lookup(probe) for table in tables]
            # Fetch all lookups in one session run; 'tank' is a miss -> -1.
            for found in sess.run(lookups):
                self.assertAllEqual([0, 1, -1], found)
    def testSignatureMismatch(self):
        """A string->int64 table rejects int64 lookup keys and a string default."""
        with self.test_session():
            name = ''
            string_to_int = tf.HashTable(tf.string, tf.int64, -1, name)

            # Populate the table from key and value tensors before probing it.
            string_to_int.initialize_from(
                tf.constant(['brain', 'salad', 'surgery']),
                tf.constant([0, 1, 2], tf.int64)).run()

            # Keys are declared as tf.string, so int64 lookup keys are refused.
            int_keys = tf.constant([1, 2, 3], tf.int64)
            with self.assertRaises(TypeError):
                string_to_int.lookup(int_keys)

            # A string default cannot serve a table whose values are int64.
            with self.assertRaises(TypeError):
                tf.HashTable(tf.string, tf.int64, 'UNK', name)
    def testInitializationWithInvalidDataTypes(self):
        """Initializing a string->int64 table with swapped key/value types fails."""
        with self.test_session():
            table = tf.HashTable(tf.string, tf.int64, -1, '')

            # Deliberately reversed: integer keys and string values.
            bad_keys, bad_values = [0, 1, 2], ['brain', 'salad', 'surgery']
            with self.assertRaises(TypeError):
                table.initialize_from(bad_keys, bad_values)
    def testInitializationWithInvalidDimensions(self):
        """Key and value tensors of different lengths are rejected."""
        with self.test_session():
            table = tf.HashTable(tf.string, tf.int64, -1, '')

            three_keys = tf.constant(['brain', 'salad', 'surgery'])
            five_values = tf.constant([0, 1, 2, 3, 4], tf.int64)
            # 3 keys vs. 5 values: the length mismatch must raise ValueError.
            with self.assertRaises(ValueError):
                table.initialize_from(three_keys, five_values)
    def testNotInitialized(self):
        """Evaluating a lookup on a never-initialized table is an op error."""
        with self.test_session():
            table = tf.HashTable(tf.string, tf.int64, -1, '')
            # Building the lookup succeeds; the failure only surfaces at eval time.
            lookup = table.lookup(tf.constant(['brain', 'salad', 'surgery']))

            with self.assertRaisesOpError('Table not initialized'):
                lookup.eval()
    def testInitializeTwice(self):
        """Running the same initializer a second time is an op error."""
        with self.test_session():
            table = tf.HashTable(tf.string, tf.int64, -1, '')
            initializer = table.initialize_from(
                tf.constant(['brain', 'salad', 'surgery']),
                tf.constant([0, 1, 2], tf.int64))

            # The first run succeeds; only the repeat should fail.
            initializer.run()
            with self.assertRaisesOpError('Table already initialized'):
                initializer.run()
    def testHashTableWithTensorDefault(self):
        """The default value may be supplied as an int64 scalar tensor."""
        with self.test_session():
            fallback = tf.constant(-1, tf.int64)
            table = tf.HashTable(tf.string, tf.int64, fallback, '')
            table.initialize_from(
                tf.constant(['brain', 'salad', 'surgery']),
                tf.constant([0, 1, 2], tf.int64)).run()

            # 'tank' is absent from the table and should map to the default.
            found = table.lookup(tf.constant(['brain', 'salad', 'tank']))
            self.assertAllEqual([0, 1, -1], found.eval())
    def testHashTableInitWithNumPyArrays(self):
        """Initialize a table directly from NumPy arrays rather than tensors.

        Keys use the builtin ``str`` dtype. The ``np.str`` alias formerly used
        here was only an alias for ``str``; it was deprecated in NumPy 1.20
        and removed in 1.24, where accessing it raises AttributeError.
        """
        with self.test_session():
            shared_name = ''
            default_val = -1
            table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)

            # Initialize with NumPy key and value arrays (not tf tensors).
            keys = np.array(['brain', 'salad', 'surgery'], dtype=str)
            values = np.array([0, 1, 2], dtype=np.int64)
            init = table.initialize_from(keys, values)
            init.run()
            self.assertAllEqual(3, table.size().eval())

            # 'tank' is not a key, so it maps to default_val.
            input_string = tf.constant(['brain', 'salad', 'tank'])
            output = table.lookup(input_string)

            result = output.eval()
            self.assertAllEqual([0, 1, -1], result)
    def testHashTableFindHighRank(self):
        """Lookups on a rank-2 key tensor return a result of the same shape."""
        with self.test_session():
            table = tf.HashTable(tf.string, tf.int64, -1, '')
            table.initialize_from(
                tf.constant(['brain', 'salad', 'surgery']),
                tf.constant([0, 1, 2], tf.int64)).run()
            self.assertAllEqual(3, table.size().eval())

            grid = tf.constant([['brain', 'salad'],
                                ['tank', 'tarkus']])
            # Known keys map to their ids; unknown ones fall back to -1.
            expected = [[0, 1],
                        [-1, -1]]
            self.assertAllEqual(expected, table.lookup(grid).eval())
 def testDTypes(self):
     """A list-wrapped key dtype is not a valid dtype and must raise TypeError."""
     with self.test_session():
         shared_name = ''
         default_val = -1
         # The value dtype (tf.int64) matches default_val, so the list-wrapped
         # key dtype is the only invalid argument. With tf.string values, the -1
         # default would itself be a dtype mismatch, and the assertion could
         # pass without ever exercising the key-dtype check.
         with self.assertRaises(TypeError):
             tf.HashTable([tf.string], tf.int64, default_val, shared_name)