def test_custom_data_provider(self):
  """Checks shape, dtype and value range of provide_custom_data outputs.

  The same '*.jpg' glob from the test-data directory is passed twice, and
  every returned image batch is checked both statically (graph shape/dtype)
  and after evaluation (concrete shape, all values within [-1, 1]).
  """
  if tf.executing_eagerly():
    # dataset.make_initializable_iterator is not supported when eager
    # execution is enabled. Report an explicit skip rather than a bare
    # `return`, which the test runner would count as a pass.
    self.skipTest('dataset.make_initializable_iterator is not supported '
                  'when eager execution is enabled.')
  file_pattern = os.path.join(self.testdata_dir, '*.jpg')
  batch_size = 3
  patch_size = 8
  images_list = data_provider.provide_custom_data(
      batch_size=batch_size,
      image_file_patterns=[file_pattern, file_pattern],
      patch_size=patch_size)
  for images in images_list:
    self.assertListEqual([batch_size, patch_size, patch_size, 3],
                         images.shape.as_list())
    self.assertEqual(tf.float32, images.dtype)

  with self.cached_session() as sess:
    # Use the tf.compat.v1 endpoints so the test also runs under TF2 with
    # eager disabled (the top-level tf.* initializers are TF1-only).
    sess.run(tf.compat.v1.local_variables_initializer())
    sess.run(tf.compat.v1.tables_initializer())
    images_out_list = sess.run(images_list)
    for images_out in images_out_list:
      self.assertTupleEqual((batch_size, patch_size, patch_size, 3),
                            images_out.shape)
      self.assertTrue(np.all(np.abs(images_out) <= 1.0))
def _get_data(image_set_x_file_pattern, image_set_y_file_pattern, batch_size, patch_size): """Returns image Tensors from a custom provider or TFDS.""" if image_set_x_file_pattern and image_set_y_file_pattern: image_file_patterns = [image_set_x_file_pattern, image_set_y_file_pattern] else: if image_set_x_file_pattern or image_set_y_file_pattern: raise ValueError('Both image patterns or neither must be provided.') image_file_patterns = None images_x, images_y = data_provider.provide_custom_data( batch_size=batch_size, image_file_patterns=image_file_patterns, patch_size=patch_size) return images_x, images_y
def test_custom_data_provider(self):
  """Verifies the batches produced by provide_custom_data.

  Feeds the test-data '*.jpg' glob in twice and asserts, for every returned
  tensor, the static shape [batch, patch, patch, 3] and float32 dtype, then
  evaluates the tensors and asserts the concrete shape and that every value
  has magnitude at most 1.0.
  """
  pattern = os.path.join(self.testdata_dir, '*.jpg')
  batch_size, patch_size = 3, 8
  expected_shape = (batch_size, patch_size, patch_size, 3)

  outputs = data_provider.provide_custom_data(
      batch_size=batch_size,
      image_file_patterns=[pattern, pattern],
      patch_size=patch_size)

  # Static (graph-level) checks, before anything is evaluated.
  for tensor in outputs:
    self.assertListEqual(list(expected_shape), tensor.shape.as_list())
    self.assertEqual(tf.float32, tensor.dtype)

  with self.cached_session() as sess:
    # NOTE(review): presumably the input pipeline relies on local variables
    # and lookup tables, hence both initializers run before evaluation.
    sess.run(tf.compat.v1.local_variables_initializer())
    sess.run(tf.compat.v1.tables_initializer())
    evaluated = sess.run(outputs)
    for array in evaluated:
      self.assertTupleEqual(expected_shape, array.shape)
      self.assertTrue(np.all(np.abs(array) <= 1.0))