def testIdentityIndexedDatasetIterator(self):
    """Iterate an IdentityIndexedDataset(16) and check every element.

    The iterator is expected to yield 0, 1, ..., 15 in order and then
    raise OutOfRangeError on the next fetch.
    """
    dataset = indexed_dataset_ops.IdentityIndexedDataset(16)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    with self.cached_session() as sess:
        sess.run(iterator.initializer)
        for expected in range(16):
            self.assertEqual(expected, sess.run(next_element))
        # All 16 elements have been consumed; one more fetch must fail.
        with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
def testIdentityIndexedDataset(self):
    """Materialize an IdentityIndexedDataset(16) and look elements up by index.

    Every index i in [0, 16) is expected to return [i]; looking up index 16
    (one past the end) is expected to raise InvalidArgumentError.
    """
    num_elements = 16
    ds = indexed_dataset_ops.IdentityIndexedDataset(num_elements)
    materialized = ds.materialize()
    with self.cached_session() as sess:
        sess.run(materialized.initializer)
        placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
        # Build the lookup op once and feed different indices into it;
        # calling materialized.get() inside the loop would add a new op to
        # the graph on every iteration.
        get_op = materialized.get(placeholder)
        for i in range(num_elements):
            output = sess.run(get_op, feed_dict={placeholder: i})
            self.assertEqual([i], output)
        with self.assertRaises(errors.InvalidArgumentError):
            sess.run(get_op, feed_dict={placeholder: num_elements})