def testShardedMutableHashTable(self):
    """Inserts three keys, then looks up two present keys and one absent key.

    Runs once per shard count, so the single-shard case and the
    more-shards-than-keys case are both covered.
    """
    for shard_count in (1, 3, 10):
        with self._single_threaded_test_session():
            missing_value = -1
            inserted_keys = tf.constant(["brain", "salad", "surgery"])
            inserted_values = tf.constant([0, 1, 2], tf.int64)
            sharded_table = _ShardedMutableHashTable(
                tf.string, tf.int64, missing_value, num_shards=shard_count)

            # Nothing has been inserted yet.
            self.assertAllEqual(0, sharded_table.size().eval())
            sharded_table.insert(inserted_keys, inserted_values).run()
            self.assertAllEqual(3, sharded_table.size().eval())

            # "tank" was never inserted, so it must come back as missing_value.
            queries = tf.constant(["brain", "salad", "tank"])
            looked_up = sharded_table.lookup(queries)
            self.assertAllEqual([3], looked_up.get_shape())
            self.assertAllEqual([0, 1, -1], looked_up.eval())
def testExportSharded(self):
    """Checks that export_sharded returns per-shard key and value tensors.

    The expected shard of each key is hard-coded. NOTE(review): it depends on
    the table's key-to-shard assignment and must be updated if that changes.
    """
    with self._single_threaded_test_session():
        default_val = -1
        num_shards = 2
        keys = tf.constant(["a1", "b1", "c2"])
        values = tf.constant([0, 1, 2], tf.int64)
        table = _ShardedMutableHashTable(
            tf.string, tf.int64, default_val, num_shards=num_shards)
        self.assertAllEqual(0, table.size().eval())
        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        keys_list, values_list = table.export_sharded()
        # One keys tensor and one values tensor per shard.
        self.assertEqual(num_shards, len(keys_list))
        self.assertEqual(num_shards, len(values_list))

        # Fetch shard 0's keys and values in a single run so both come from
        # the same export execution. Then compare them as a key->value
        # mapping: intra-shard order is treated as unspecified, but each
        # key's value must still line up.
        shard0_keys, shard0_values = tf.get_default_session().run(
            [keys_list[0], values_list[0]])
        self.assertEqual(2, len(shard0_keys))
        self.assertEqual({b"b1": 1, b"c2": 2},
                         dict(zip(shard0_keys, shard0_values)))

        self.assertAllEqual([b"a1"], keys_list[1].eval())
        self.assertAllEqual([0], values_list[1].eval())
def testExportSharded(self):
    """Checks that export_sharded returns per-shard key and value tensors.

    The expected shard of each key is hard-coded. NOTE(review): it depends on
    the table's key-to-shard assignment and must be updated if that changes.
    """
    with self._single_threaded_test_session():
        default_val = -1
        num_shards = 2
        keys = tf.constant(['a1', 'b1', 'c2'])
        values = tf.constant([0, 1, 2], tf.int64)
        table = _ShardedMutableHashTable(
            tf.string, tf.int64, default_val, num_shards=num_shards)
        self.assertAllEqual(0, table.size().eval())
        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        keys_list, values_list = table.export_sharded()
        # One keys tensor and one values tensor per shard.
        self.assertEqual(num_shards, len(keys_list))
        self.assertEqual(num_shards, len(values_list))

        # Fetch shard 0's keys and values in a single run so both come from
        # the same export execution. Then compare them as a key->value
        # mapping: intra-shard order is treated as unspecified, but each
        # key's value must still line up.
        shard0_keys, shard0_values = tf.get_default_session().run(
            [keys_list[0], values_list[0]])
        self.assertEqual(2, len(shard0_keys))
        self.assertEqual({b'b1': 1, b'c2': 2},
                         dict(zip(shard0_keys, shard0_values)))

        self.assertAllEqual([b'a1'], keys_list[1].eval())
        self.assertAllEqual([0], values_list[1].eval())
def testShardedMutableHashTable(self):
    """Inserts three keys and checks size and lookup for each shard count."""
    shard_counts = [1, 3, 10]
    for n_shards in shard_counts:
        with self._single_threaded_test_session():
            fallback = -1
            words = tf.constant(['brain', 'salad', 'surgery'])
            ids = tf.constant([0, 1, 2], tf.int64)
            tbl = _ShardedMutableHashTable(tf.string, tf.int64, fallback,
                                           num_shards=n_shards)

            # Empty before the insert, three entries after it.
            self.assertAllEqual(0, tbl.size().eval())
            tbl.insert(words, ids).run()
            self.assertAllEqual(3, tbl.size().eval())

            # 'tank' is absent and should resolve to the fallback value.
            probe = tf.constant(['brain', 'salad', 'tank'])
            found = tbl.lookup(probe)
            self.assertAllEqual([3], found.get_shape())
            found_values = found.eval()
            self.assertAllEqual([0, 1, -1], found_values)
def testShardedMutableHashTable(self):
    """Checks size, lookup, and values_reduce_sum for several shard counts.

    Three keys with values 0, 1, 2 are inserted. Two present keys and one
    absent key are then looked up, and the values are summed.
    """
    for shard_total in (1, 3, 10):
        with self.test_session():
            absent_marker = -1
            key_tensor = tf.constant(['brain', 'salad', 'surgery'])
            value_tensor = tf.constant([0, 1, 2], tf.int64)
            hash_table = _ShardedMutableHashTable(
                tf.string, tf.int64, absent_marker, num_shards=shard_total)

            # Starts empty; holds three entries after the insert.
            self.assertAllEqual(0, hash_table.size().eval())
            hash_table.insert(key_tensor, value_tensor).run()
            self.assertAllEqual(3, hash_table.size().eval())

            # 'tank' was never inserted and should map to absent_marker.
            lookup_keys = tf.constant(['brain', 'salad', 'tank'])
            lookup_out = hash_table.lookup(lookup_keys)
            self.assertAllEqual([3], lookup_out.get_shape())
            self.assertAllEqual([0, 1, -1], lookup_out.eval())

            # 0 + 1 + 2 == 3.
            total = hash_table.values_reduce_sum()
            self.assertAllEqual(3, total.eval())