def _provide_real_images(batch_size, **kwargs):
    """Builds the real-image input at the final training resolution.

    The patch size is taken from the `final_resolutions` attribute of the
    schedule returned by `train.make_resolution_schedule(**kwargs)`.

    Args:
      batch_size: Batch size forwarded to the data provider.
      **kwargs: Configuration values. `colors` is required;
        `dataset_file_pattern` is optional. The full set is also forwarded to
        `train.make_resolution_schedule`.

    Returns:
      The result of `data_provider.provide_data_from_image_files` when a
      non-empty `dataset_file_pattern` is given, otherwise the result of
      `data_provider.provide_data` on the 'train' split.

    Raises:
      KeyError: If `colors` is missing from `kwargs`.
    """
    file_pattern = kwargs.get('dataset_file_pattern')
    num_colors = kwargs['colors']
    schedule = train.make_resolution_schedule(**kwargs)
    patch_height, patch_width = schedule.final_resolutions

    # Both providers take the same batch/patch/color arguments.
    shared_args = dict(
        batch_size=batch_size,
        patch_height=patch_height,
        patch_width=patch_width,
        colors=num_colors,
    )

    if file_pattern:
        return data_provider.provide_data_from_image_files(
            file_pattern=file_pattern, **shared_args)

    # No (or empty) file pattern: fall back to the built-in 'train' split.
    return data_provider.provide_data(split='train', **shared_args)
def test_provide_data(self, mock_tfds):
    """Checks the shape and value range of `data_provider.provide_data` output.

    `mock_tfds.load` is stubbed to return `self.mock_ds`. The test then verifies
    that the returned images have the static and evaluated shape
    [batch_size, patch_height, patch_width, colors] and that every evaluated
    value has absolute value at most 1.
    """
    mock_tfds.load.return_value = self.mock_ds

    batch_size, patch_height, patch_width, colors = 4, 2, 8, 1
    expected_shape = [batch_size, patch_height, patch_width, colors]

    images = data_provider.provide_data(
        'train',
        patch_height=patch_height,
        patch_width=patch_width,
        colors=colors,
        batch_size=batch_size,
    )

    # Static shape check on the graph tensor.
    self.assertEqual(images.shape.as_list(), expected_shape)

    with self.cached_session() as sess:
        images_np = sess.run(images)

    # Evaluated shape and value-range checks on the concrete array.
    self.assertTupleEqual(images_np.shape, tuple(expected_shape))
    magnitudes = np.abs(images_np)
    self.assertTrue(np.all(magnitudes <= 1))