コード例 #1
0
ファイル: train_main.py プロジェクト: zhouyonglong/gan
def _provide_real_images(batch_size, **kwargs):
    """Fetch a batch of real images from the configured data source.

    The patch size is read from the `final_resolutions` of the schedule
    returned by `train.make_resolution_schedule(**kwargs)`.

    Args:
      batch_size: Forwarded unchanged as `batch_size` to the data provider.
      **kwargs: Configuration values. `colors` is required and
        `dataset_file_pattern` is optional; the whole mapping is also passed
        to `train.make_resolution_schedule`.

    Returns:
      The result of `data_provider.provide_data_from_image_files` when a
      non-empty `dataset_file_pattern` is given, otherwise the result of
      `data_provider.provide_data` for the 'train' split.
    """
    file_pattern = kwargs.get('dataset_file_pattern')
    num_colors = kwargs['colors']
    schedule = train.make_resolution_schedule(**kwargs)
    patch_height, patch_width = schedule.final_resolutions

    # Both providers take the same batch/patch/color arguments.
    shared_args = {
        'batch_size': batch_size,
        'patch_height': patch_height,
        'patch_width': patch_width,
        'colors': num_colors,
    }

    if file_pattern:
        return data_provider.provide_data_from_image_files(
            file_pattern=file_pattern, **shared_args)
    return data_provider.provide_data(split='train', **shared_args)
コード例 #2
0
    def test_provide_data(self, mock_tfds):
        """Check shape and value range of `data_provider.provide_data`."""
        # Make the patched loader hand back the fixture dataset.
        mock_tfds.load.return_value = self.mock_ds

        request = {
            'patch_height': 2,
            'patch_width': 8,
            'colors': 1,
            'batch_size': 4,
        }
        expected_shape = [
            request['batch_size'],
            request['patch_height'],
            request['patch_width'],
            request['colors'],
        ]

        images = data_provider.provide_data('train', **request)
        # The static shape must already match before anything is evaluated.
        self.assertEqual(images.shape.as_list(), expected_shape)

        with self.cached_session() as sess:
            images_np = sess.run(images)

        self.assertTupleEqual(images_np.shape, tuple(expected_shape))
        # Every evaluated value must lie within [-1, 1].
        within_unit_range = np.abs(images_np) <= 1
        self.assertTrue(np.all(within_unit_range))