import numpy as np
import tensorflow as tf

# `data_provider` is a project-local module (e.g.
# tensorflow_gan.examples.mnist.data_provider) that supplies the real images.


def input_fn(mode, params):
    """Input function for GANEstimator."""
    if 'batch_size' not in params:
        raise ValueError('batch_size must be in params')
    if 'noise_dims' not in params:
        raise ValueError('noise_dims must be in params')
    bs = params['batch_size']
    nd = params['noise_dims']
    split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
    shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
    just_noise = (mode == tf.estimator.ModeKeys.PREDICT)

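    # An endlessly repeating stream of noise batches, one [bs, nd] tensor per step.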
    noise_ds = (tf.data.Dataset.from_tensors(0).repeat().map(
        lambda _: tf.random.normal([bs, nd])))

    if just_noise:
        return noise_ds

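    # Real-image pipeline: all-zero dummy batches for fast testing, or the
    # requested split from the real dataset.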
    if params['use_dummy_data']:
        img = np.zeros((bs, 28, 28, 1), dtype=np.float32)
        images_ds = tf.data.Dataset.from_tensors(img).repeat()
    else:
        images_ds = (data_provider.provide_dataset(
            split, bs, params['num_reader_parallel_calls'],
            shuffle).map(lambda x: x['images']))  # Just take the images.

    return tf.data.Dataset.zip((noise_ds, images_ds))
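
For reference, a minimal sketch (not part of the original example) of how an input_fn like this can be handed to tfgan.estimator.GANEstimator. The toy generator and discriminator below, the optimizers, and the hyperparameter values are illustrative placeholders, not the example's real networks.

import tensorflow as tf
import tensorflow_gan as tfgan


def toy_generator_fn(noise):
    # Placeholder generator: one dense layer reshaped into a 28x28x1 image.
    img = tf.compat.v1.layers.dense(noise, 28 * 28, activation=tf.tanh)
    return tf.reshape(img, [-1, 28, 28, 1])


def toy_discriminator_fn(img, unused_conditioning):
    # Placeholder discriminator: flatten and emit a single logit.
    return tf.compat.v1.layers.dense(tf.reshape(img, [-1, 28 * 28]), 1)


params = {
    'batch_size': 32,
    'noise_dims': 64,
    'use_dummy_data': True,  # keeps the sketch self-contained; no download needed
    'num_reader_parallel_calls': 4,
}

gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=toy_generator_fn,
    discriminator_fn=toy_discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    params=params)

# The Estimator framework forwards `mode` and `params` to input_fn automatically.
gan_estimator.train(input_fn, max_steps=100)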
Example #2
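The test method below receives mock_tfds from a patch decorator and reads a fake dataset prepared in setUp. A minimal sketch of that scaffolding is shown first; the class name, patch target, and fake-example fields are assumptions inferred from how the test uses them (data_provider and compat_utils are project-local modules imported elsewhere in the test file).

from unittest import mock

import numpy as np
import tensorflow as tf


class DataProviderTest(tf.test.TestCase):  # assumed class name

    def setUp(self):
        super().setUp()
        # One fake MNIST-style example in raw tfds form, repeated indefinitely.
        self.mock_ds = tf.data.Dataset.from_tensors({
            'image': np.zeros([28, 28, 1], dtype=np.uint8),
            'label': np.int64(0),
        }).repeat()

    @mock.patch.object(data_provider, 'tfds', autospec=True)  # assumed patch target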
    def test_provide_dataset(self, mock_tfds):
        batch_size = 5
        mock_tfds.load.return_value = self.mock_ds

        ds = data_provider.provide_dataset('test', batch_size)
        self.assertIsInstance(ds, tf.data.Dataset)

        output = compat_utils.ds_output_classes(ds)
        self.assertIsInstance(output, dict)
        self.assertSetEqual(set(output.keys()), set(['images', 'labels']))
        self.assertEqual(output['images'], tf.Tensor)
        self.assertEqual(output['labels'], tf.Tensor)

        shapes = compat_utils.ds_output_shapes(ds)
        self.assertIsInstance(shapes, dict)
        self.assertSetEqual(set(shapes.keys()), set(['images', 'labels']))
        self.assertIsInstance(shapes['images'], tf.TensorShape)
        self.assertIsInstance(shapes['labels'], tf.TensorShape)
        self.assertListEqual(shapes['images'].as_list(),
                             [batch_size, 28, 28, 1])
        self.assertListEqual(shapes['labels'].as_list(), [batch_size, 10])

        types = compat_utils.ds_output_types(ds)
        self.assertIsInstance(types, dict)
        self.assertSetEqual(set(types.keys()), set(['images', 'labels']))
        self.assertEqual(types['images'], tf.float32)
        self.assertEqual(types['labels'], tf.float32)

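        # Pull one concrete batch through a TF1-style one-shot iterator.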
        next_batch = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
        images = next_batch['images']
        labels = next_batch['labels']

        with self.cached_session() as sess:
            images, labels = sess.run([images, labels])

        self.assertEqual(images.shape, (batch_size, 28, 28, 1))
        self.assertTrue(np.all(np.abs(images) <= 1))
        self.assertEqual(labels.shape, (batch_size, 10))