def input_fn(mode, params):
    """Build the input `tf.data.Dataset` for GANEstimator.

    Args:
      mode: A `tf.estimator.ModeKeys` value. TRAIN selects the 'train' split
        and turns the shuffle flag on; any other mode selects the 'test'
        split. PREDICT returns noise only.
      params: Mapping of hyperparameters. 'batch_size' and 'noise_dims' are
        always required. Outside PREDICT mode 'use_dummy_data' is read, and
        'num_reader_parallel_calls' is read when dummy data is not used.

    Returns:
      In PREDICT mode, an endlessly repeating dataset of random-normal noise
      with shape [batch_size, noise_dims]. Otherwise that noise dataset
      zipped with an images dataset, yielding (noise, images) pairs.

    Raises:
      ValueError: If 'batch_size' or 'noise_dims' is missing from `params`.
    """
    # Validate before touching anything else so callers get a clear message.
    required_params = (
        ('batch_size', 'batch_size must be in params'),
        ('noise_dims', 'noise_dims must be in params'),
    )
    for key, message in required_params:
        if key not in params:
            raise ValueError(message)

    batch_size = params['batch_size']
    noise_dims = params['noise_dims']

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    noise_only = mode == tf.estimator.ModeKeys.PREDICT
    if is_training:
        split = 'train'
    else:
        split = 'test'

    def _sample_noise(_):
        # The dummy element is ignored; each step draws a fresh noise batch.
        return tf.random.normal([batch_size, noise_dims])

    def _take_images(features):
        # Just take the images.
        return features['images']

    noise_ds = tf.data.Dataset.from_tensors(0).repeat().map(_sample_noise)
    if noise_only:
        return noise_ds

    if params['use_dummy_data']:
        # All-zero stand-in batch, repeated forever.
        blank_batch = np.zeros((batch_size, 28, 28, 1), dtype=np.float32)
        images_ds = tf.data.Dataset.from_tensors(blank_batch).repeat()
    else:
        # Shuffle only while training (fourth positional argument).
        real_ds = data_provider.provide_dataset(
            split, batch_size, params['num_reader_parallel_calls'],
            is_training)
        images_ds = real_ds.map(_take_images)

    return tf.data.Dataset.zip((noise_ds, images_ds))
def test_provide_dataset(self, mock_tfds):
    """Check the structure and contents of `provide_dataset`'s output.

    `mock_tfds.load` is stubbed to return `self.mock_ds`. The test then
    verifies that the 'test' split dataset is a dict of 'images' and
    'labels' tensors with the expected classes, static shapes and dtypes,
    and that one evaluated batch has the right shapes with image values
    inside [-1, 1].

    Args:
      mock_tfds: Mock object whose `load` return value is overridden here;
        presumably injected by a patch decorator outside this view.
    """
    batch_size = 5
    expected_keys = {'images', 'labels'}
    mock_tfds.load.return_value = self.mock_ds

    dataset = data_provider.provide_dataset('test', batch_size)
    self.assertIsInstance(dataset, tf.data.Dataset)

    # Element classes: a dict of plain tensors.
    classes = compat_utils.ds_output_classes(dataset)
    self.assertIsInstance(classes, dict)
    self.assertSetEqual(set(classes.keys()), expected_keys)
    self.assertEqual(classes['images'], tf.Tensor)
    self.assertEqual(classes['labels'], tf.Tensor)

    # Static shapes: batched 28x28x1 images and 10-wide labels.
    static_shapes = compat_utils.ds_output_shapes(dataset)
    self.assertIsInstance(static_shapes, dict)
    self.assertSetEqual(set(static_shapes.keys()), expected_keys)
    self.assertIsInstance(static_shapes['images'], tf.TensorShape)
    self.assertIsInstance(static_shapes['labels'], tf.TensorShape)
    self.assertListEqual(
        static_shapes['images'].as_list(), [batch_size, 28, 28, 1])
    self.assertListEqual(static_shapes['labels'].as_list(), [batch_size, 10])

    # Dtypes: both components are float32.
    dtypes = compat_utils.ds_output_types(dataset)
    self.assertIsInstance(dtypes, dict)
    self.assertSetEqual(set(dtypes.keys()), expected_keys)
    self.assertEqual(dtypes['images'], tf.float32)
    self.assertEqual(dtypes['labels'], tf.float32)

    # Evaluate one batch and check the concrete values.
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    batch = iterator.get_next()
    images_op = batch['images']
    labels_op = batch['labels']
    with self.cached_session() as sess:
        images_np, labels_np = sess.run([images_op, labels_op])

    self.assertEqual(images_np.shape, (batch_size, 28, 28, 1))
    self.assertTrue(np.all(np.abs(images_np) <= 1))
    self.assertEqual(labels_np.shape, (batch_size, 10))