Beispiel #1
0
def main(_):
    """Builds training hyperparameters from command-line flags and trains.

    Each `HParams` field is read from the flag of the same name, and the
    result is passed to `train_experiment_lib.train`.

    Args:
      _: Unused positional argument (presumably the argv list supplied by
        the program's `app.run` entry point -- confirm against the caller).
    """
    # Pass every field by keyword so that a change to HParams' field order
    # cannot silently bind a flag to the wrong hyperparameter.
    hparams = train_experiment_lib.HParams(
        generator_lr=FLAGS.generator_lr,
        discriminator_lr=FLAGS.discriminator_lr,
        joint_train=FLAGS.joint_train,
        batch_size=FLAGS.batch_size,
        noise_dims=FLAGS.noise_dims,
        model_dir=FLAGS.model_dir,
        num_train_steps=FLAGS.num_train_steps,
        num_eval_steps=FLAGS.num_eval_steps,
        num_reader_parallel_calls=FLAGS.num_reader_parallel_calls,
        use_dummy_data=FLAGS.use_dummy_data)
    train_experiment_lib.train(hparams)
Beispiel #2
0
  def test_full_flow(self, mock_mnist_frechet_distance, mock_mnist_score):
    """Runs a minimal training job and checks that a PNG image is written.

    Trains for one step (with one eval step) on dummy data in a temporary
    model directory. It then asserts that `<model_dir>/outputs` exists and
    contains at least one file ending in `.png`.

    Args:
      mock_mnist_frechet_distance: Mock standing in for an expensive eval
        computation; given a constant return value of 1.0. Presumably
        injected by a `mock.patch` decorator outside this view.
      mock_mnist_score: Mock standing in for an expensive eval computation;
        given a constant return value of 0.0. Presumably injected by a
        `mock.patch` decorator outside this view.
    """
    hparams = train_experiment_lib.HParams(
        generator_lr=0.000076421,
        discriminator_lr=0.0031938,
        joint_train=False,
        batch_size=16,
        noise_dims=4,
        model_dir=self.get_temp_dir(),
        num_train_steps=1,
        num_eval_steps=1,
        num_reader_parallel_calls=4,
        use_dummy_data=True)

    # Mock computationally expensive eval computations.
    mock_mnist_score.return_value = 0.0
    mock_mnist_frechet_distance.return_value = 1.0

    train_experiment_lib.train(hparams)

    # Check that there's a .png file in the output directory.
    out_dir = os.path.join(hparams.model_dir, 'outputs')
    self.assertTrue(tf.io.gfile.exists(out_dir))
    filenames = tf.io.gfile.listdir(out_dir)
    # Match the full '.png' suffix: split('.')[-1] == 'png' would also accept
    # a dot-less file literally named 'png'.
    self.assertTrue(
        any(name.endswith('.png') for name in filenames),
        'No .png file in %s; found %s' % (out_dir, filenames))