def testShardedMutableHashTable(self):
    """Check scalar int64 insert, size and lookup for several shard counts.

    For each shard count a fresh table is created, checked to be empty,
    filled with three key/value pairs, and then queried with two present
    keys and one absent key. The absent key must map to ``default_val``.
    """
    for num_shards in [1, 3, 10]:
        with self._single_threaded_test_session():
            default_val = -1
            empty_key = 0
            keys = tf.constant([11, 12, 13], tf.int64)
            values = tf.constant([0, 1, 2], tf.int64)
            table = _ShardedMutableDenseHashTable(
                tf.int64, tf.int64, default_val, empty_key,
                num_shards=num_shards)
            self.assertAllEqual(0, table.size().eval())

            table.insert(keys, values).run()
            self.assertAllEqual(3, table.size().eval())

            # 14 was never inserted, so its lookup must yield default_val.
            query_keys = tf.constant([11, 12, 14], tf.int64)
            output = table.lookup(query_keys)
            self.assertAllEqual([3], output.get_shape())
            self.assertAllEqual([0, 1, default_val], output.eval())
def testShardedMutableHashTableVectors(self):
    """Verify vector keys and vector values across several shard counts.

    Three 2-element int64 keys are stored with 2-element float32 values.
    A lookup batch of two stored keys plus one unknown key must return
    the stored rows followed by the default row.
    """
    for shard_count in (1, 3, 10):
        with self._single_threaded_test_session():
            missing_row = [-0.1, 0.2]
            sentinel_key = [0, 1]
            stored_keys = tf.constant(
                [[11, 12], [13, 14], [15, 16]], tf.int64)
            stored_values = tf.constant(
                [[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)

            table = _ShardedMutableDenseHashTable(
                tf.int64,
                tf.float32,
                missing_row,
                sentinel_key,
                num_shards=shard_count,
            )
            self.assertAllEqual(0, table.size().eval())

            table.insert(stored_keys, stored_values).run()
            self.assertAllEqual(3, table.size().eval())

            # [11, 14] was never inserted; it should come back as missing_row.
            probe = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
            found = table.lookup(probe)
            self.assertAllEqual([3, 2], found.get_shape())
            expected_rows = [[0.5, 0.6], [1.5, 1.6], missing_row]
            self.assertAllClose(expected_rows, found.eval())
def testExportSharded(self):
    """Check that export_sharded returns one key and one value tensor per shard.

    Three int64 keys are inserted into a two-shard table. Each exported
    shard is compared as a set because the exported tensors also contain
    the unused buckets: keys equal to ``empty_key`` and values equal to 0.
    """
    with self._single_threaded_test_session():
        empty_key = -2
        default_val = -1
        num_shards = 2
        keys = tf.constant([10, 11, 12], tf.int64)
        values = tf.constant([2, 3, 4], tf.int64)
        table = _ShardedMutableDenseHashTable(
            tf.int64, tf.int64, default_val, empty_key,
            num_shards=num_shards)
        self.assertAllEqual(0, table.size().eval())

        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        keys_list, values_list = table.export_sharded()
        self.assertAllEqual(num_shards, len(keys_list))
        self.assertAllEqual(num_shards, len(values_list))

        # NOTE(review): the expected membership (10 and 12 in shard 0, 11 in
        # shard 1) looks like sharding by key % num_shards -- confirm against
        # _ShardedMutableDenseHashTable's shard function.
        # Set comparisons use assertEqual so failures report differing items.
        # Exported keys include empty key buckets set to the empty_key.
        self.assertEqual({empty_key, 10, 12},
                         set(keys_list[0].eval().flatten()))
        self.assertEqual({empty_key, 11},
                         set(keys_list[1].eval().flatten()))
        # Exported values include empty value buckets set to 0.
        self.assertEqual({0, 2, 4}, set(values_list[0].eval().flatten()))
        self.assertEqual({0, 3}, set(values_list[1].eval().flatten()))