Example No. 1
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: An model object having all information of progressive GAN model, e.g.
      the return of build_model().
    **kwargs: A dictionary of
        'train_log_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
          workers to identify each worker.
        'save_summaries_num_images': Save summaries in this number of images.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)

  tfgan.gan_train(
      model.gan_train_ops,
      logdir=make_train_sub_dir(model.stage_id, **kwargs),
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=[
          tf.estimator.StopAtStepHook(last_step=model.num_images),
          tf.estimator.LoggingTensorHook([make_status_message(model)],
                                         every_n_iter=10)
      ],
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
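
A minimal sketch of how this train function might be invoked, assuming the
build_model() helper mentioned in the docstring; the log directory, master,
and task values below are hypothetical, not from the source:

model = build_model(stage_id=1)  # assumed helper and signature
train(
    model,
    train_log_dir='/tmp/progressive_gan',  # hypothetical log directory
    master='',  # empty string: run against a local in-process master
    task=0,     # task 0 acts as chief
    save_summaries_num_images=1000)
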
Example No. 2
def train(hparams):
  """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """

  # Create the log_dir if it does not exist.
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Shard the model to different parameter servers.
  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):

    # Create the input dataset.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images, labels = data_provider.provide_data('train', hparams.batch_size,
                                                  hparams.patch_size)

    # Define the model.
    with tf.compat.v1.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model, num_images=3 * hparams.batch_size, display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.compat.v1.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss, hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2,
                                    hparams.max_number_of_steps)

    # Define the train steps.
    train_steps = _define_train_step(hparams.gen_disc_step_ratio)

    # Define a status message.
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=hparams.tf_master,
        is_chief=hparams.task == 0)
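
_define_train_step is project code not shown on this page. One plausible
implementation, sketching how a generator/discriminator step ratio could be
mapped onto tfgan.GANTrainSteps (the helper body below is an assumption, not
the source):

def _define_train_step(gen_disc_step_ratio):
  """Sketch: a ratio of 0.5 means 1 generator step per 2 discriminator
  steps; a ratio of 3 means 3 generator steps per discriminator step."""
  if gen_disc_step_ratio <= 1:
    discriminator_steps = int(1.0 / gen_disc_step_ratio)
    return tfgan.GANTrainSteps(1, discriminator_steps)
  generator_steps = int(gen_disc_step_ratio)
  return tfgan.GANTrainSteps(generator_steps, 1)
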
Example No. 3
def train(hparams):
    """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
            images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                           hparams.image_set_y_file_pattern,
                                           hparams.batch_size,
                                           hparams.patch_size)

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=(
                hparams.cycle_consistency_loss_weight),
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if not hparams.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            hparams.train_log_dir,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.estimator.StopAtStepHook(
                    num_steps=hparams.max_number_of_steps),
                tf.estimator.LoggingTensorHook(
                    {'status_message': status_message}, every_n_iter=10)
            ],
            master=hparams.master,
            is_chief=hparams.task == 0)
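
The CycleGAN examples pass tfgan.features.tensor_pool as tensor_pool_fn; it
keeps a history buffer of previously generated images so the discriminator
also trains against older generator output. If a different buffer size is
wanted, one option is to bind the pool_size parameter with functools.partial
(a sketch; the value 100 is arbitrary):

import functools

cyclegan_loss = tfgan.cyclegan_loss(
    cyclegan_model,
    cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
    # pool_size is a parameter of tfgan.features.tensor_pool.
    tensor_pool_fn=functools.partial(tfgan.features.tensor_pool,
                                     pool_size=100))
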
Example No. 4
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        with tf.name_scope('inputs'):
            initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
            training_input_iter = initializer_hook.input_itr
            images_x, images_y = training_input_iter.get_next()
            # Optionally set a static batch size so image summaries work:
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
            [
                'Starting train step: ',
                tf.as_string(tf.train.get_or_create_global_step())
            ],
            name='status_message')
        if not FLAGS.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            FLAGS.train_log_dir,
            save_checkpoint_secs=60*10,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                initializer_hook,
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10)
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)
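
load_op is defined elsewhere in this project; the snippet only shows that it
returns a hook exposing a dataset iterator as input_itr. A minimal sketch of
such a hook, assuming it wraps a tf.data initializable iterator (the class
and names below are illustrative, not from the source):

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Sketch: initializes a dataset iterator once the session is created."""

    def __init__(self, dataset):
        self.input_itr = dataset.make_initializable_iterator()

    def after_create_session(self, session, coord):
        # Run the initializer inside the monitored training session.
        session.run(self.input_itr.initializer)
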
Example No. 5
def train(model, **kwargs):
    """Trains progressive GAN for stage `stage_id`.

  Args:
    model: An model object having all information of progressive GAN model,
        e.g. the return of build_model().
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
            workers to identify each worker.
        'save_summaries_num_images': Save summaries in this number of images.
        'debug_hook': Whether to attach the debug hook to the training session.
  Returns:
    None.
  """
    logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
                 model.num_blocks, model.num_images)

    scaffold = make_scaffold(model.stage_id, model.optimizer_var_list,
                             **kwargs)

    logdir = make_train_sub_dir(model.stage_id, **kwargs)
    logging.info('Starting training, logdir: %s', logdir)
    hooks = []
    if model.stage_train_time_limit is None:
        hooks.append(tf.train.StopAtStepHook(last_step=model.num_images))
    hooks.append(
        tf.train.LoggingTensorHook([make_status_message(model)],
                                   every_n_iter=1))
    hooks.append(TrainTimeHook(model.train_time, model.stage_train_time_limit))
    if kwargs['debug_hook']:
        hooks.append(ProganDebugHook())
    tfgan.gan_train(model.gan_train_ops,
                    logdir=logdir,
                    get_hooks_fn=tfgan.get_sequential_train_hooks(
                        tfgan.GANTrainSteps(1, 1)),
                    hooks=hooks,
                    master=kwargs['master'],
                    is_chief=(kwargs['task'] == 0),
                    scaffold=scaffold,
                    save_checkpoint_secs=600,
                    save_summaries_steps=(kwargs['save_summaries_num_images']))
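
TrainTimeHook is also project code; from its call site it tracks elapsed
training time in model.train_time and enforces stage_train_time_limit. A
rough sketch of a wall-clock time-limit hook in the same spirit (the real
hook also updates the train_time variable; this version only shows the stop
condition and is an assumption):

import time

class TimeLimitHook(tf.train.SessionRunHook):
    """Sketch: request a stop once training time exceeds a limit."""

    def __init__(self, train_time_limit_secs):
        self._limit = train_time_limit_secs
        self._start = None

    def after_create_session(self, session, coord):
        self._start = time.time()

    def after_run(self, run_context, run_values):
        if self._limit is not None and time.time() - self._start > self._limit:
            run_context.request_stop()
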
Example No. 6
  def test_multiple_steps(self, get_hooks_fn_fn):
    """Test multiple train steps."""
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
    train_steps = tfgan.GANTrainSteps(
        generator_train_steps=3, discriminator_train_steps=4)
    final_step = tfgan.gan_train(
        train_ops,
        get_hooks_fn=get_hooks_fn_fn(train_steps),
        logdir='',
        hooks=[tf.estimator.StopAtStepHook(num_steps=1)])

    self.assertTrue(np.isscalar(final_step))
    # One joint pass runs the generator op 3 times (each adds 10 to the step
    # counter) and the discriminator op 4 times (each adds 100), plus the
    # single global-step increment from the one allowed train step.
    self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
Example No. 7
  def test_run_helper(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    tf.compat.v1.random.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
    train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = tfgan.gan_train(
        train_ops, logdir='', hooks=[tf.estimator.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)
Example No. 8
def train(hparams):
    """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.name_scope('inputs'), tf.device('/cpu:0'):
        images, one_hot_labels = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple. Optionally, condition the GAN on the label or
    # use an InfoGAN to learn a latent representation.
    if hparams.gan_type == 'unconditional':
        gan_model = tfgan.gan_model(
            generator_fn=networks.unconditional_generator,
            discriminator_fn=networks.unconditional_discriminator,
            real_data=images,
            generator_inputs=tf.random.normal(
                [hparams.batch_size, hparams.noise_dims]))
    elif hparams.gan_type == 'conditional':
        noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
        gan_model = tfgan.gan_model(
            generator_fn=networks.conditional_generator,
            discriminator_fn=networks.conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    elif hparams.gan_type == 'infogan':
        cat_dim, cont_dim = 10, 2
        generator_fn = functools.partial(networks.infogan_generator,
                                         categorical_dim=cat_dim)
        discriminator_fn = functools.partial(networks.infogan_discriminator,
                                             categorical_dim=cat_dim,
                                             continuous_dim=cont_dim)
        unstructured_inputs, structured_inputs = util.get_infogan_noise(
            hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
        gan_model = tfgan.infogan_model(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            real_data=images,
            unstructured_generator_inputs=unstructured_inputs,
            structured_generator_inputs=structured_inputs)
    else:
        # Guard against an unknown gan_type so gan_model is always defined.
        raise ValueError('Invalid gan_type: %s' % hparams.gan_type)
    tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

    # Get the GANLoss tuple. You can pass a custom function, use one of the
    # already-implemented losses from the losses library, or use the defaults.
    with tf.name_scope('loss'):
        if hparams.gan_type == 'infogan':
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.modified_generator_loss,
                discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
                mutual_information_penalty_weight=1.0,
                add_summaries=True)
        else:
            gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
        tfgan.eval.add_regularization_loss_summaries(gan_model)

    # Get the GANTrain ops using custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = _learning_rate(hparams.gan_type)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if hparams.max_number_of_steps == 0:
        return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks(),
        save_checkpoint_secs=60)
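
util.get_infogan_noise is a project helper that splits the noise into the
unstructured and structured inputs expected by tfgan.infogan_model. A sketch
of what such a helper typically returns, with the categorical code drawn
uniformly and one-hot encoded and the continuous code uniform in [-1, 1]
(the exact distributions are an assumption):

def get_infogan_noise(batch_size, categorical_dim, continuous_dim, noise_dims):
    """Sketch: unstructured noise plus (categorical, continuous) codes."""
    # Unstructured Gaussian noise, as in the unconditional case.
    unstructured = tf.random.normal([batch_size, noise_dims])
    # Structured categorical code: uniform class draws, one-hot encoded.
    categorical = tf.one_hot(
        tf.random.uniform([batch_size], maxval=categorical_dim,
                          dtype=tf.int32),
        depth=categorical_dim)
    # Structured continuous code, uniform in [-1, 1].
    continuous = tf.random.uniform([batch_size, continuous_dim],
                                   minval=-1.0, maxval=1.0)
    return unstructured, [categorical, continuous]
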
Example No. 9
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)
    
  with open(os.path.join(hparams.train_log_dir, 'train_result.json'),
            'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      additional_params = {
          # Setting 'save_checkpoint_secs' to None disables time-based
          # checkpointing so the step-based setting takes effect.
          'scaffold':
              tf.compat.v1.train.Scaffold(
                  saver=tf.compat.v1.train.Saver(max_to_keep=max_to_keep)),
          'save_checkpoint_secs': None,
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
Example No. 10
def train(hparams):
    """Trains a CIFAR10 GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        with tf.compat.v1.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images, _ = data_provider.provide_data('train',
                                                       hparams.batch_size,
                                                       num_parallel_calls=4)

        # Define the GANModel tuple.
        generator_fn = networks.generator
        discriminator_fn = networks.discriminator
        generator_inputs = tf.random.normal([hparams.batch_size, 64])
        gan_model = tfgan.gan_model(generator_fn,
                                    discriminator_fn,
                                    real_data=images,
                                    generator_inputs=generator_inputs)
        tfgan.eval.add_gan_model_image_summaries(gan_model)

        # Get the GANLoss tuple. Use the selected GAN loss functions.
        with tf.compat.v1.name_scope('loss'):
            gan_loss = tfgan.gan_loss(gan_model,
                                      gradient_penalty_weight=1.0,
                                      add_summaries=True)

        # Get the GANTrain ops using the custom optimizers and optional
        # discriminator weight clipping.
        with tf.compat.v1.name_scope('train'):
            gen_opt, dis_opt = _get_optimizers(hparams)
            train_ops = tfgan.gan_train_ops(gan_model,
                                            gan_loss,
                                            generator_optimizer=gen_opt,
                                            discriminator_optimizer=dis_opt,
                                            summarize_gradients=True)

        # Run the alternating training loop. Skip it if no steps should be taken
        # (used for graph construction tests).
        status_message = tf.strings.join([
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
                                         name='status_message')
        if hparams.max_number_of_steps == 0:
            return
        tfgan.gan_train(train_ops,
                        hooks=([
                            tf.estimator.StopAtStepHook(
                                num_steps=hparams.max_number_of_steps),
                            tf.estimator.LoggingTensorHook([status_message],
                                                           every_n_iter=10)
                        ]),
                        logdir=hparams.train_log_dir,
                        master=hparams.master,
                        is_chief=hparams.task == 0)
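
_get_optimizers is project code not shown here. Since this example enables a
gradient penalty (gradient_penalty_weight=1.0, WGAN-GP style), one plausible
sketch uses Adam with the beta values common in that setting; the hparams
field names below are assumptions:

def _get_optimizers(hparams):
    """Sketch: plausible generator/discriminator optimizers."""
    gen_opt = tf.compat.v1.train.AdamOptimizer(
        hparams.generator_lr, beta1=0.5, beta2=0.999)
    dis_opt = tf.compat.v1.train.AdamOptimizer(
        hparams.discriminator_lr, beta1=0.5, beta2=0.999)
    return gen_opt, dis_opt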