def test_optionally_applies_central_crop(self):
    """Check that requesting a central crop changes the image size.

    A central_crop_size of (500, 100) is expected to yield images that are
    100 pixels high and 500 pixels wide, with 3 channels.
    """
    num_images = 4
    test_split = datasets.fsns_test.get_test_split()
    provided = data_provider.get_data(
        dataset=test_split,
        batch_size=num_images,
        augment=True,
        central_crop_size=(500, 100),
    )

    # The queue runners must be active while the session pulls a batch.
    with self.test_session() as session, queues.QueueRunners(session):
        fetched_images = session.run(provided.images)

    # Note the (width, height) crop argument vs. (height, width) image axes.
    expected_shape = (num_images, 100, 500, 3)
    self.assertEqual(fetched_images.shape, expected_shape)
def test_provided_data_has_correct_shape(self):
    """Verify image and one-hot label shapes when no crop is requested.

    With central_crop_size=None the images are expected to be 150x600 with
    3 channels, and the one-hot labels to be (batch, 37, 134).
    """
    n_examples = 4
    provided = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=n_examples,
        augment=True,
        central_crop_size=None,
    )
    fetches = [provided.images, provided.labels_one_hot]

    # Pull a single batch while the input queue runners are active.
    with self.test_session() as session, queues.QueueRunners(session):
        fetched = session.run(fetches)
    image_batch, one_hot_batch = fetched

    self.assertEqual(image_batch.shape, (n_examples, 150, 600, 3))
    # NOTE(review): 37 / 134 presumably are max label length and charset
    # size of the FSNS test split — confirm against the dataset config.
    self.assertEqual(one_hot_batch.shape, (n_examples, 37, 134))
def test_labels_correctly_shuffled(self):
    """Fetch one augmented batch, sanity-check it, and dump it for inspection.

    Images, labels, probabilities and texts are fetched in a single
    ``sess.run`` call, so they come from the same batch. Each example is then
    shown with ``plt.imshow`` and printed so a human can check that every
    image still matches its text/label after batching.

    Fixes over the previous version:
      * the loop ran over ``batch_size * batch_size`` indices, which indexed
        past the end of a batch of ``batch_size`` rows and raised IndexError;
      * the test asserted nothing; it now at least requires every fetched
        tensor to have ``batch_size`` rows.
    """
    batch_size = 4
    data = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=None)

    with self.test_session() as sess, queues.QueueRunners(sess):
        # A single run() call keeps all four outputs from the same batch.
        images, labels, probs, texts = sess.run(
            [data.images, data.labels, data.probs, data.texts])

    # All outputs must describe the same number of examples.
    fetched = (('images', images), ('labels', labels),
               ('probs', probs), ('texts', texts))
    for name, value in fetched:
        self.assertEqual(len(value), batch_size, msg=name)

    # Iterate over the batch itself rather than a hard-coded index range.
    for image, text, prob, label in zip(images, texts, probs, labels):
        plt.imshow(image)
        print(text, prob, label)