Example #1
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object containing all information about the progressive GAN
      model, e.g. the return value of build_model().
    **kwargs: A dictionary of
        'train_log_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
          workers to identify each worker.
        'save_summaries_num_images': Save summaries once every this many
          images.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)

  tfgan.gan_train(
      model.gan_train_ops,
      logdir=make_train_sub_dir(model.stage_id, **kwargs),
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=[
          tf.estimator.StopAtStepHook(last_step=model.num_images),
          tf.estimator.LoggingTensorHook([make_status_message(model)],
                                         every_n_iter=10)
      ],
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
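
The make_status_message() helper is not shown in this excerpt. Below is a minimal, hypothetical sketch of what it might look like, mirroring the status-message pattern used in the other examples on this page; the fields the real project reports may differ.

import tensorflow as tf


def make_status_message(model):
  # Hypothetical sketch only: builds the scalar string tensor that the
  # LoggingTensorHook above prints every 10 iterations. model.stage_id is the
  # only attribute assumed here, and it already appears in train()'s logging.
  return tf.strings.join([
      'stage_id: {}, '.format(model.stage_id), 'Starting train step: ',
      tf.as_string(tf.compat.v1.train.get_or_create_global_step())
  ], name='status_message')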
Example #2
def train(hparams):
  """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """

  # Create the log_dir if not exist.
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Shard the model to different parameter servers.
  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):

    # Create the input dataset.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images, labels = data_provider.provide_data('train', hparams.batch_size,
                                                  hparams.patch_size)

    # Define the model.
    with tf.compat.v1.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model, num_images=3 * hparams.batch_size, display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.compat.v1.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss, hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2,
                                    hparams.max_number_of_steps)

    # Define the train steps.
    train_steps = _define_train_step(hparams.gen_disc_step_ratio)

    # Define a status message.
    status_message = tf.strings.join(
        ['Starting train step: ',
         tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
        name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=hparams.tf_master,
        is_chief=hparams.task == 0)
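
The _define_train_step() helper is not included above. A plausible sketch is shown below, assuming gen_disc_step_ratio is the number of generator steps per discriminator step (for example, 0.2 means one generator step for every five discriminator steps); the project's actual implementation may differ.

import tensorflow_gan as tfgan


def _define_train_step(gen_disc_step_ratio):
  # Plausible sketch: translate the ratio into a GANTrainSteps namedtuple,
  # which get_sequential_train_hooks() uses to schedule the two sub-steps.
  if gen_disc_step_ratio <= 1:
    discriminator_step = int(1 / gen_disc_step_ratio)
    return tfgan.GANTrainSteps(1, discriminator_step)
  generator_step = int(gen_disc_step_ratio)
  return tfgan.GANTrainSteps(generator_step, 1)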
Example #3
File: train_lib.py  Project: yyht/gan
def train(hparams):
    """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    with tf.device(
            tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
        with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
            images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                           hparams.image_set_y_file_pattern,
                                           hparams.batch_size,
                                           hparams.patch_size)

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=hparams.
            cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.strings.join(
            ['Starting train step: ',
             tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
            name='status_message')
        if not hparams.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            hparams.train_log_dir,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                tf.estimator.StopAtStepHook(
                    num_steps=hparams.max_number_of_steps),
                tf.estimator.LoggingTensorHook(
                    {'status_message': status_message}, every_n_iter=10)
            ],
            master=hparams.master,
            is_chief=hparams.task == 0)
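
To try this snippet, hparams only needs to expose the fields read above. A hypothetical, minimal container with placeholder values is sketched below; the real project may use a different HParams type, and _define_train_ops() may read additional fields.

import collections

HParams = collections.namedtuple('HParams', [
    'image_set_x_file_pattern', 'image_set_y_file_pattern', 'batch_size',
    'patch_size', 'cycle_consistency_loss_weight', 'max_number_of_steps',
    'train_log_dir', 'ps_replicas', 'master', 'task'
])

# Placeholder values, for illustration only.
hparams = HParams(
    image_set_x_file_pattern='/tmp/horse2zebra/trainA-*',
    image_set_y_file_pattern='/tmp/horse2zebra/trainB-*',
    batch_size=1,
    patch_size=64,
    cycle_consistency_loss_weight=10.0,
    max_number_of_steps=500000,
    train_log_dir='/tmp/cyclegan/',
    ps_replicas=0,
    master='',
    task=0)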
Example #4
def train(hparams, override_generator_fn=None, override_discriminator_fn=None):
    """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
    override_generator_fn: A generator function that overrides the default one.
    override_discriminator_fn: A discriminator function that overrides the
      default one.
  """
    # Create directories if not exist.
    if not tf.io.gfile.exists(hparams.output_dir):
        tf.io.gfile.makedirs(hparams.output_dir)

    # Make sure steps integers are consistent.
    if hparams.max_number_of_steps % hparams.steps_per_eval != 0:
        raise ValueError('`max_number_of_steps` must be divisible by '
                         '`steps_per_eval`.')

    # Create optimizers.
    gen_opt, dis_opt = _get_optimizer(hparams.generator_lr,
                                      hparams.discriminator_lr,
                                      hparams.adam_beta1, hparams.adam_beta2)

    # Create estimator.
    stargan_estimator = tfgan.estimator.StarGANEstimator(
        generator_fn=override_generator_fn or network.generator,
        discriminator_fn=override_discriminator_fn or network.discriminator,
        loss_fn=tfgan.stargan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        get_hooks_fn=tfgan.get_sequential_train_hooks(
            _define_train_step(hparams.gen_disc_step_ratio)),
        add_summaries=tfgan.estimator.SummaryType.IMAGES)

    # Get input function for training and test images.
    train_input_fn = lambda: data_provider.provide_data(  # pylint:disable=g-long-lambda
        'train', hparams.batch_size, hparams.patch_size)
    test_images_np = data_provider.provide_celeba_test_set(hparams.patch_size)
    filename_str = os.path.join(hparams.output_dir, 'summary_image_%i.png')

    # Periodically train and write prediction output to disk.
    cur_step = 0
    while cur_step < hparams.max_number_of_steps:
        cur_step += hparams.steps_per_eval
        stargan_estimator.train(train_input_fn, steps=cur_step)
        summary_img = _get_summary_image(stargan_estimator, test_images_np)
        with tf.io.gfile.GFile(filename_str % cur_step, 'w') as f:
            PIL.Image.fromarray(
                (255 * summary_img).astype(np.uint8)).save(f, 'PNG')
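
The _get_optimizer() helper referenced above is not shown. One plausible sketch, assuming a plain Adam optimizer per network driven by the four values that train() forwards from hparams, is:

import tensorflow as tf


def _get_optimizer(gen_lr, dis_lr, beta1, beta2):
    # Sketch only: separate Adam optimizers for the generator and the
    # discriminator; use_locking is optional and merely avoids racy updates.
    gen_opt = tf.compat.v1.train.AdamOptimizer(
        gen_lr, beta1=beta1, beta2=beta2, use_locking=True)
    dis_opt = tf.compat.v1.train.AdamOptimizer(
        dis_lr, beta1=beta1, beta2=beta2, use_locking=True)
    return gen_opt, dis_opt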
Example #5
def main(_):
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        with tf.name_scope('inputs'):
            initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
            training_input_iter = initializer_hook.input_itr
            images_x, images_y = training_input_iter.get_next()
            # Set batch size for summaries.
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Define CycleGAN model.
        cyclegan_model = _define_model(images_x, images_y)

        # Define CycleGAN loss.
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
            [
                'Starting train step: ',
                tf.as_string(tf.train.get_or_create_global_step())
            ],
            name='status_message')
        if not FLAGS.max_number_of_steps:
            return
        tfgan.gan_train(
            train_ops,
            FLAGS.train_log_dir,
            save_checkpoint_secs=60*10,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                initializer_hook,
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10)
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)
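
load_op() and the hook it returns are not shown above. The sketch below illustrates the pattern this code relies on, assuming a TF1-style graph session: a SessionRunHook that exposes the dataset iterator and initializes it once the session is created. The toy random tensors stand in for the real image pipeline.

import tensorflow as tf


class IteratorInitializerHook(tf.compat.v1.train.SessionRunHook):
    """Runs the dataset iterator's initializer once the session exists."""

    def __init__(self):
        self.input_itr = None
        self.initializer_func = None

    def after_create_session(self, session, coord):
        del coord  # Unused.
        self.initializer_func(session)


def load_op(batch_size, num_steps):
    """Sketch of load_op: builds a toy two-domain pipeline plus its hook."""
    del num_steps  # Accepted only to mirror the call site above.
    images_x = tf.random.uniform([8, 64, 64, 3])
    images_y = tf.random.uniform([8, 64, 64, 3])
    dataset = (tf.data.Dataset.from_tensor_slices((images_x, images_y))
               .repeat()
               .batch(batch_size, drop_remainder=True))
    itr = tf.compat.v1.data.make_initializable_iterator(dataset)
    hook = IteratorInitializerHook()
    hook.input_itr = itr
    hook.initializer_func = lambda sess: sess.run(itr.initializer)
    return hook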
Example #6
def train(model, **kwargs):
    """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object containing all information about the progressive GAN
        model, e.g. the return value of build_model().
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
            workers to identify each worker.
        'save_summaries_num_images': Save summaries once every this many
            images.
        'debug_hook': Whether to attach the debug hook to the training session.
  Returns:
    None.
  """
    logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
                 model.num_blocks, model.num_images)

    scaffold = make_scaffold(model.stage_id, model.optimizer_var_list,
                             **kwargs)

    logdir = make_train_sub_dir(model.stage_id, **kwargs)
    print('starting training, logdir: {}'.format(logdir))
    hooks = []
    if model.stage_train_time_limit is None:
        hooks.append(tf.train.StopAtStepHook(last_step=model.num_images))
    hooks.append(
        tf.train.LoggingTensorHook([make_status_message(model)],
                                   every_n_iter=1))
    hooks.append(TrainTimeHook(model.train_time, model.stage_train_time_limit))
    if kwargs['debug_hook']:
        hooks.append(ProganDebugHook())
    tfgan.gan_train(model.gan_train_ops,
                    logdir=logdir,
                    get_hooks_fn=tfgan.get_sequential_train_hooks(
                        tfgan.GANTrainSteps(1, 1)),
                    hooks=hooks,
                    master=kwargs['master'],
                    is_chief=(kwargs['task'] == 0),
                    scaffold=scaffold,
                    save_checkpoint_secs=600,
                    save_summaries_steps=(kwargs['save_summaries_num_images']))
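
TrainTimeHook and ProganDebugHook are project-specific helpers that are not shown here. As a rough illustration of what TrainTimeHook might do, the hypothetical sketch below accumulates elapsed wall-clock time into model.train_time (assumed to be a scalar tf.float32 variable) and requests a stop once the stage time limit is reached; the real hook may behave differently.

import time

import tensorflow as tf


class TrainTimeHook(tf.compat.v1.train.SessionRunHook):
    """Hypothetical sketch: tracks wall-clock training time, enforces a limit."""

    def __init__(self, train_time_var, time_limit):
        self._train_time_var = train_time_var  # Assumed scalar float32 variable.
        self._time_limit = time_limit
        self._last = None
        self._elapsed_ph = tf.compat.v1.placeholder(tf.float32, [])
        self._assign_op = self._train_time_var.assign_add(self._elapsed_ph)

    def after_create_session(self, session, coord):
        del coord  # Unused.
        self._last = time.time()

    def after_run(self, run_context, run_values):
        del run_values  # Unused.
        now = time.time()
        total = run_context.session.run(
            self._assign_op, feed_dict={self._elapsed_ph: now - self._last})
        self._last = now
        if self._time_limit is not None and total >= self._time_limit:
            run_context.request_stop()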
Example #7
File: train_test.py  Project: Aerochip7/gan
  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    if tf.executing_eagerly():
      # None of the usual utilities work in eager.
      return
    model = create_gan_model_fn()
    loss = tfgan.gan_loss(model)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = tfgan.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    sequential_train_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sync_opts = [
        hook._sync_optimizer
        for hook in sequential_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    joint_train_hooks = tfgan.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    sync_opts = [
        hook._sync_optimizer
        for hook in joint_train_hooks
        if isinstance(hook, get_sync_optimizer_hook_type())
    ]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
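
get_sync_optimizer() and get_sync_optimizer_hook_type() are helpers defined elsewhere in train_test.py. A plausible sketch of the former, assuming it simply wraps a basic optimizer in SyncReplicasOptimizer so that the sync-replica hooks counted above appear, is:

import tensorflow as tf


def get_sync_optimizer():
  # Sketch only: wrapping in SyncReplicasOptimizer is what makes the
  # hook._sync_optimizer entries show up in the returned hook lists.
  return tf.compat.v1.train.SyncReplicasOptimizer(
      tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)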
Example #8
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)
    
  with open(hparams.train_log_dir + 'train_result.json', 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size, hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join(
        ['Starting train step: ',
         tf.as_string(tf.compat.v1.train.get_or_create_global_step())],
        name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
        max_to_keep = hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1
        additional_params = {
            'scaffold': tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=max_to_keep)),
            'save_checkpoint_secs': None,
            'save_checkpoint_steps': hparams.save_checkpoint_steps,
        }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
Example #9
def main(_):
    log_dir = FLAGS.train_log_dir
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        validation_iteration_count = FLAGS.validation_itr_count
        validation_sample_count = FLAGS.validation_sample_count
        loader_name = FLAGS.loader_name
        neighborhood = 0
        loader = get_class(loader_name + '.' + loader_name)(FLAGS.path)
        data_set = loader.load_data(neighborhood, True)

        shadow_map, shadow_ratio = loader.load_shadow_map(
            neighborhood, data_set)

        with tf.name_scope('inputs'):
            initializer_hook = load_op(FLAGS.batch_size,
                                       FLAGS.max_number_of_steps, loader,
                                       data_set, shadow_map, shadow_ratio,
                                       FLAGS.regularization_support_rate)
            training_input_iter = initializer_hook.input_itr
            images_x, images_y = training_input_iter.get_next()
            # Set batch size for summaries.
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Define model.
        gan_type = FLAGS.gan_type
        gan_train_wrapper_dict = {
            "cycle_gan":
            CycleGANWrapper(cycle_consistency_loss_weight=FLAGS.
                            cycle_consistency_loss_weight,
                            identity_loss_weight=FLAGS.identity_loss_weight,
                            use_identity_loss=FLAGS.use_identity_loss),
            "gan_x2y":
            GANWrapper(identity_loss_weight=FLAGS.identity_loss_weight,
                       use_identity_loss=FLAGS.use_identity_loss,
                       swap_inputs=False),
            "gan_y2x":
            GANWrapper(identity_loss_weight=FLAGS.identity_loss_weight,
                       use_identity_loss=FLAGS.use_identity_loss,
                       swap_inputs=True)
        }
        wrapper = gan_train_wrapper_dict[gan_type]

        with tf.variable_scope('Model', reuse=tf.AUTO_REUSE):
            the_gan_model = wrapper.define_model(images_x, images_y)
            peer_validation_hook = wrapper.create_validation_hook(
                data_set, loader, log_dir, neighborhood, shadow_map,
                shadow_ratio, validation_iteration_count,
                validation_sample_count)

            the_gan_loss = wrapper.define_loss(the_gan_model)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(the_gan_model, the_gan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
            ['Starting train step: ',
             tf.as_string(tf.train.get_or_create_global_step())],
            name='status_message')
        if not FLAGS.max_number_of_steps:
            return

        # Enable memory growth only if a GPU is actually present.
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            tf.config.experimental.set_memory_growth(gpus[0], True)

        training_scaffold = Scaffold(saver=tf.train.Saver(max_to_keep=20))

        gan_train(
            train_ops,
            log_dir,
            scaffold=training_scaffold,
            save_checkpoint_steps=validation_iteration_count,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                initializer_hook, peer_validation_hook,
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=1000)
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)
Example #10
def train(hparams, override_generator_fn=None, override_discriminator_fn=None):
  """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
    override_generator_fn: A generator function that overrides the default one.
    override_discriminator_fn: A discriminator function that overrides the
      default one.
  """
  # Create directories if not exist.
  if not tf.io.gfile.exists(hparams.output_dir):
    tf.io.gfile.makedirs(hparams.output_dir)

  with open(hparams.output_dir + 'train_result.json', 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)
    
  # Make sure steps integers are consistent.
  if hparams.max_number_of_steps % hparams.steps_per_eval != 0:
    raise ValueError('`max_number_of_steps` must be divisible by '
                     '`steps_per_eval`.')

  # Create optimizers.
  gen_opt, dis_opt = _get_optimizer(hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2)

  # Create estimator.
  if hparams.cls_model and hparams.cls_checkpoint:
    raise ValueError('Only one of hparams.cls_model and '
                     'hparams.cls_checkpoint may be set.')

  if hparams.cls_model:
    print("[!!!!] LOAD custom classification model in discriminator.")

    network_discriminator = network.CustomKerasDiscriminator(hparams.cls_model + '/base_model.h5')
    # network_discriminator = network.custom_keras_discriminator(hparams.cls_model)
    # tf.keras.estimator.model_to_estimator(keras_model_path=hparams.cls_model, model_dir='/tmp/temp_checkpoint/')
  elif hparams.cls_checkpoint:
    network_discriminator = network.custom_tf_discriminator()
  else:
    network_discriminator = network.discriminator

  stargan_estimator = tfgan.estimator.StarGANEstimator(
      model_dir= hparams.output_dir + "checkpoints/",
      generator_fn=override_generator_fn or network.generator,
      discriminator_fn=override_discriminator_fn or network_discriminator,
      # loss_fn=tfgan.stargan_loss,
      loss_fn=_get_stargan_loss(reconstruction_loss_weight=hparams.reconstruction_loss_weight,
                                self_consistency_loss_weight=hparams.self_consistency_loss_weight,
                                classification_loss_weight=hparams.classification_loss_weight),
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      get_hooks_fn=tfgan.get_sequential_train_hooks(
          _define_train_step(hparams.gen_disc_step_ratio)),
      add_summaries=tfgan.estimator.SummaryType.IMAGES,
      config=tf.estimator.RunConfig(save_checkpoints_steps=hparams.save_checkpoints_steps,
                                    keep_checkpoint_max=hparams.keep_checkpoint_max),
      cls_model=hparams.cls_model,
      cls_checkpoint=hparams.cls_checkpoint
  )

  # Get input function for training and test images.
  if hparams.tfdata_source:
    print("[**] load train dataset: tensorflow dataset: {x}".format(x=hparams.tfdata_source))
    train_input_fn = lambda: data_provider.provide_data(  # pylint:disable=g-long-lambda
        hparams.tfdata_source,
        hparams.batch_size,
        hparams.patch_size,
        split='train',
        color_labeled=hparams.use_color_labels,
        num_parallel_calls=None,
        shuffle=True,
        domains=tuple(hparams.tfdata_source_domains.split(",")),
        download=eval(hparams.download),
        data_dir=hparams.data_dir)

    if hparams.tfdata_source.startswith('cycle_gan'):
        test_images_np = data_provider.provide_cyclegan_test_set(hparams.tfdata_source, hparams.patch_size)
        num_domains = 2
    elif hparams.tfdata_source == 'celeb_a':
        test_images_np = data_provider.provide_celeba_test_set(hparams.patch_size,
                                                               download=eval(hparams.download),
                                                               data_dir=hparams.data_dir)
        num_domains = len(test_images_np)
    else:
        test_images_np, num_domains = data_provider.provide_categorized_test_set(hparams.tfdata_source,
                                                                                 hparams.patch_size,
                                                                                 color_labeled=hparams.use_color_labels,
                                                                                 download=eval(hparams.download),
                                                                                 data_dir=hparams.data_dir)


  else:
    train_input_fn = None
    test_images_np = None
    num_domains = None
    raise Exception("TODO: support external data souce.")
    
  filename_str = os.path.join(hparams.output_dir, 'summary_image_%i.png')

  # Periodically train and write prediction output to disk.
  cur_step = 0
  while cur_step < hparams.max_number_of_steps:
    cur_step += hparams.steps_per_eval
    print("current step: {cur_step} /{max_step}".format(cur_step=cur_step, max_step=hparams.max_number_of_steps))
    stargan_estimator.train(train_input_fn, steps=cur_step)
    summary_img = _get_summary_image(stargan_estimator, test_images_np, num_domains)
    with tf.io.gfile.GFile(filename_str % cur_step, 'w') as f:
        # Handle single-channel images
        if summary_img.shape[2] == 1:
            summary_img = np.repeat(summary_img, 3, axis=2)
        PIL.Image.fromarray((255 * summary_img).astype(np.uint8)).save(f, 'PNG')