def test_get_example_dataset(self):
    """Pull a single batch from a schema-driven example dataset and check its shape.

    The fetched batch must contain ``self.N_FEATURES`` entries, and the ``f1``
    entry must hold exactly as many rows as the requested batch size (16).
    """
    batch_size = 16
    next_batch = (
        Datasets.examples_via_schema(
            self.train_data, self.schema_path, batch_size=batch_size
        )
        .make_one_shot_iterator()
        .get_next()
    )
    with tf.Session() as session:
        features = session.run(next_batch)
        self.assertEqual(len(features), self.N_FEATURES)
        self.assertEqual(len(features["f1"]), batch_size)
def test_simple_get_example_dataset(self):
    """Read the small fixture produced by ``SquareTest._write_test_data``.

    The first element fetched from the dataset must have ``f1 == [1]`` and
    ``f2 == [2]``; any further read must signal end-of-data.
    """
    data, schema_path = SquareTest._write_test_data()
    with self.test_session() as sess:
        next_example = (
            Datasets.examples_via_schema(data, schema_path)
            .make_one_shot_iterator()
            .get_next()
        )
        first, second = next_example["f1"], next_example["f2"]
        self.assertAllEqual([[1], [2]], sess.run([first, second]))
        # After the single element above, the one-shot iterator is exhausted,
        # so evaluating the tensor again must raise OutOfRangeError.
        with self.assertRaises(tf.errors.OutOfRangeError):
            first.eval()