def get_input(args):
    """Build the featran example input pipeline under the "input" name scope.

    Loads the dataset at ``args.input`` via
    ``Datasets.get_featran_example_dataset`` with ``gen_spec=["label"]`` and
    creates an initializable iterator over it.

    Args:
        args: object exposing an ``input`` attribute, forwarded as the
            dataset location.

    Returns:
        Tuple ``(iterator, label, features)``:
          - ``iterator``: the initializable iterator (caller must run its
            initializer before reading).
          - ``label``: the label tensor reshaped to ``[-1, 1]``.
          - ``features``: the feature tensor reshaped to
            ``[-1, num_features]``, where ``num_features`` comes from the
            settings object returned alongside the dataset.
    """
    with tf.name_scope("input"):
        ds, settings = Datasets.get_featran_example_dataset(
            args.input, gen_spec=["label"]
        )
        it = ds.make_initializable_iterator()
        next_batch = it.get_next()
        # The first element is a 1-tuple holding the "label" spec output.
        (raw_label,), raw_features = next_batch
        # One label per row.
        label_col = tf.reshape(raw_label, [-1, 1])
        # One example per row, num_features columns.
        feature_mat = tf.reshape(raw_features, [-1, settings.num_features])
    return it, label_col, feature_mat
def test_get_featran_example_dataset(self):
    """Round-trip featran test data through ``get_featran_example_dataset``.

    Writes the test data with ``DataUtil.write_featran_test_data`` and loads
    it back. The test then asserts that:
      - the returned settings report two features;
      - the first record yields ``f1 == 1`` and ``f2 == 2``;
      - a further read raises ``tf.errors.OutOfRangeError``, which means the
        dataset holds exactly one record.
    """
    d, _, _ = DataUtil.write_featran_test_data()
    with self.test_session() as sess:
        dataset, c = Datasets.get_featran_example_dataset(d)
        # Use assertEqual: the assertEquals alias is deprecated and was
        # removed in Python 3.12.
        self.assertEqual(len(c.features), 2)
        iterator = dataset.make_one_shot_iterator()
        r = iterator.get_next()
        f1, f2 = r["f1"], r["f2"]
        # Consumes the first (and only expected) record.
        self.assertAllEqual([1, 2], sess.run([f1, f2]))
        # Evaluating get_next again must exhaust the one-shot iterator.
        with self.assertRaises(tf.errors.OutOfRangeError):
            f1.eval()