def main(_):
  """Build training hyperparameters from command-line flags and run training.

  Args:
    _: Ignored. Presumably the argv list supplied by the program runner —
      confirm against the entry point.
  """
  # Pass every field by keyword so construction does not silently depend on
  # the positional field order of `HParams`.
  hparams = train_experiment_lib.HParams(
      generator_lr=FLAGS.generator_lr,
      discriminator_lr=FLAGS.discriminator_lr,
      joint_train=FLAGS.joint_train,
      batch_size=FLAGS.batch_size,
      noise_dims=FLAGS.noise_dims,
      model_dir=FLAGS.model_dir,
      num_train_steps=FLAGS.num_train_steps,
      num_eval_steps=FLAGS.num_eval_steps,
      num_reader_parallel_calls=FLAGS.num_reader_parallel_calls,
      use_dummy_data=FLAGS.use_dummy_data)
  train_experiment_lib.train(hparams)
def test_full_flow(self, mock_mnist_frechet_distance, mock_mnist_score):
  """Runs a minimal train/eval cycle and checks that a PNG image is written.

  Uses dummy data with one train step and one eval step, then asserts that
  `<model_dir>/outputs` exists and contains at least one `.png` file.

  Args:
    mock_mnist_frechet_distance: Mock replacing the MNIST Frechet distance
      computation (presumably injected by a `mock.patch` decorator outside
      this view — confirm).
    mock_mnist_score: Mock replacing the MNIST score computation (presumably
      injected the same way).
  """
  hparams = train_experiment_lib.HParams(
      generator_lr=0.000076421,
      discriminator_lr=0.0031938,
      joint_train=False,
      batch_size=16,
      noise_dims=4,
      model_dir=self.get_temp_dir(),
      num_train_steps=1,
      num_eval_steps=1,
      num_reader_parallel_calls=4,
      use_dummy_data=True)

  # Mock computationally expensive eval computations.
  mock_mnist_score.return_value = 0.0
  mock_mnist_frechet_distance.return_value = 1.0

  train_experiment_lib.train(hparams)

  # Check that there's a .png file in the output directory.
  out_dir = os.path.join(hparams.model_dir, 'outputs')
  self.assertTrue(
      tf.io.gfile.exists(out_dir), 'Missing output directory: %s' % out_dir)
  out_files = tf.io.gfile.listdir(out_dir)
  self.assertTrue(
      any(f.endswith('.png') for f in out_files),
      'No .png file in %s; found: %s' % (out_dir, out_files))