예제 #1
0
파일: train_test.py 프로젝트: Aerochip7/gan
  def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
    """Checks sequential GAN train steps driven by an old-style Supervisor.

    Runs a single sequential train step configured with 3 generator and
    4 discriminator sub-steps, whose train ops are the constants 3.0 and 2.0,
    and checks the total loss that the step function reports.
    """
    # Graph mode only: none of the usual utilities work in eager.
    if tf.executing_eagerly():
      return

    global_step = tf.compat.v1.train.create_global_step()
    gan_ops = tfgan.GANTrainOps(
        generator_train_op=tf.constant(3.0),
        discriminator_train_op=tf.constant(2.0),
        global_step_inc_op=global_step.assign_add(1))
    sub_steps = tfgan.GANTrainSteps(
        generator_train_steps=3, discriminator_train_steps=4)
    max_steps = 1

    # Standard Supervisor wiring: stop once the global step reaches max_steps.
    step_kwargs = {
        'should_stop': tf.greater_equal(global_step, max_steps),
    }
    run_train_step = tfgan.get_sequential_train_steps(sub_steps)
    supervisor = tf.compat.v1.train.Supervisor(
        logdir='', global_step=global_step)
    with supervisor.managed_session(master='') as session:
      while not supervisor.should_stop():
        total_loss, stop_requested = run_train_step(
            session, gan_ops, global_step, step_kwargs)
        if stop_requested:
          supervisor.request_stop()
          break

    # Expected value: 17.0 == 3 * 3.0 + 4 * 2.0.
    self.assertTrue(np.isscalar(total_loss))
    self.assertEqual(17.0, total_loss)
예제 #2
0
def train(model, **kwargs):
  """Runs the training loop for one stage of a progressive GAN.

  Args:
    model: A model object holding all information of the progressive GAN
      model, e.g. the return of build_model().
    **kwargs: Keyword options. This function reads
        'master': Name of the TensorFlow master to use.
        'task': The Task ID; the worker whose task is 0 is the chief.
        'save_summaries_num_images': Passed on as `save_summaries_steps`.
      All options are also forwarded to make_scaffold() and
      make_train_sub_dir().

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  stage_scaffold = make_scaffold(
      model.stage_id, model.optimizer_var_list, **kwargs)

  gan_train_ops = model.gan_train_ops
  stage_logdir = make_train_sub_dir(model.stage_id, **kwargs)
  # One generator step and one discriminator step per training iteration.
  sequential_hooks_fn = tfgan.get_sequential_train_hooks(
      tfgan.GANTrainSteps(1, 1))
  stop_hook = tf.estimator.StopAtStepHook(last_step=model.num_images)
  status_hook = tf.estimator.LoggingTensorHook(
      [make_status_message(model)], every_n_iter=10)

  tfgan.gan_train(
      gan_train_ops,
      logdir=stage_logdir,
      get_hooks_fn=sequential_hooks_fn,
      hooks=[stop_hook, status_hook],
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=stage_scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
예제 #3
0
def _define_train_step(gen_disc_step_ratio):
    """Get the training step for generator and discriminator for each GAN step.

  Args:
    gen_disc_step_ratio: A python number. The ratio of generator to
      discriminator training steps.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """

    if gen_disc_step_ratio <= 1:
        discriminator_step = int(1 / gen_disc_step_ratio)
        return tfgan.GANTrainSteps(1, discriminator_step)
    else:
        generator_step = int(gen_disc_step_ratio)
        return tfgan.GANTrainSteps(generator_step, 1)
예제 #4
0
파일: train_lib.py 프로젝트: yyht/gan
def train(hparams):
    """Builds the CycleGAN graph and runs its training loop.

    Creates the log directory when it does not exist, then constructs the
    inputs, model, loss and train ops. Training uses one generator step and
    one discriminator step per iteration. When
    `hparams.max_number_of_steps` is falsy the function returns after graph
    construction without training.

    Args:
      hparams: An HParams instance containing the hyperparameters for
        training.
    """
    if not tf.io.gfile.exists(hparams.train_log_dir):
        tf.io.gfile.makedirs(hparams.train_log_dir)

    device_setter = tf.compat.v1.train.replica_device_setter(
        hparams.ps_replicas)
    with tf.device(device_setter):
        # Input pipeline, placed on the CPU.
        with tf.compat.v1.name_scope('inputs'):
            with tf.device('/cpu:0'):
                images_x, images_y = _get_data(
                    hparams.image_set_x_file_pattern,
                    hparams.image_set_y_file_pattern,
                    hparams.batch_size,
                    hparams.patch_size)

        # Model.
        model = _define_model(images_x, images_y)

        # Loss.
        loss = tfgan.cyclegan_loss(
            model,
            cycle_consistency_loss_weight=(
                hparams.cycle_consistency_loss_weight),
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Train ops.
        gan_train_ops = _define_train_ops(model, loss, hparams)

        # Training configuration.
        steps_per_iteration = tfgan.GANTrainSteps(1, 1)
        global_step_text = tf.as_string(
            tf.compat.v1.train.get_or_create_global_step())
        status_message = tf.strings.join(
            ['Starting train step: ', global_step_text],
            name='status_message')

        # Nothing to train: stop after building the graph.
        if not hparams.max_number_of_steps:
            return

        sequential_hooks_fn = tfgan.get_sequential_train_hooks(
            steps_per_iteration)
        stop_hook = tf.estimator.StopAtStepHook(
            num_steps=hparams.max_number_of_steps)
        status_hook = tf.estimator.LoggingTensorHook(
            {'status_message': status_message}, every_n_iter=10)
        tfgan.gan_train(
            gan_train_ops,
            hparams.train_log_dir,
            get_hooks_fn=sequential_hooks_fn,
            hooks=[stop_hook, status_hook],
            master=hparams.master,
            is_chief=hparams.task == 0)
예제 #5
0
파일: train_test.py 프로젝트: Aerochip7/gan
  def test_multiple_steps(self, get_hooks_fn_fn):
    """Checks the final value after one train step with several sub-steps.

    Args:
      get_hooks_fn_fn: Callable that takes a `GANTrainSteps` and returns the
        `get_hooks_fn` handed to `tfgan.gan_train`.
    """
    # Graph mode only: none of the usual utilities work in eager.
    if tf.executing_eagerly():
      return

    gan_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
    sub_steps = tfgan.GANTrainSteps(
        generator_train_steps=3, discriminator_train_steps=4)
    hooks_fn = get_hooks_fn_fn(sub_steps)
    stop_after_one_step = tf.estimator.StopAtStepHook(num_steps=1)

    final_step = tfgan.gan_train(
        gan_ops,
        get_hooks_fn=hooks_fn,
        logdir='',
        hooks=[stop_after_one_step])

    # 1 for the step itself, plus 3 generator adds of 10 and 4 discriminator
    # adds of 100.
    expected_final_step = 1 + 3 * 10 + 4 * 100
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(expected_final_step, final_step)
예제 #6
0
def main(_):
    """Builds the CycleGAN training graph from FLAGS and runs training.

    The single positional argument is ignored. When
    `FLAGS.max_number_of_steps` is falsy the function returns after graph
    construction without training.
    """
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    device_setter = tf.train.replica_device_setter(FLAGS.ps_tasks)
    with tf.device(device_setter):
        with tf.name_scope('inputs'):
            # The hook returned by load_op exposes the input iterator and is
            # also registered with the training session below.
            initializer_hook = load_op(
                FLAGS.batch_size, FLAGS.max_number_of_steps)
            images_x, images_y = initializer_hook.input_itr.get_next()
            # Set batch size for summaries.
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Model.
        model = _define_model(images_x, images_y)

        # Loss.
        loss = tfgan.cyclegan_loss(
            model,
            cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Train ops.
        gan_train_ops = _define_train_ops(model, loss)

        # Training configuration: one generator and one discriminator step
        # per iteration.
        steps_per_iteration = tfgan.GANTrainSteps(1, 1)
        global_step_text = tf.as_string(tf.train.get_or_create_global_step())
        status_message = tf.string_join(
            ['Starting train step: ', global_step_text],
            name='status_message')

        # Nothing to train: stop after building the graph.
        if not FLAGS.max_number_of_steps:
            return

        checkpoint_interval_secs = 60*10
        sequential_hooks_fn = tfgan.get_sequential_train_hooks(
            steps_per_iteration)
        stop_hook = tf.train.StopAtStepHook(
            num_steps=FLAGS.max_number_of_steps)
        status_hook = tf.train.LoggingTensorHook(
            [status_message], every_n_iter=10)
        tfgan.gan_train(
            gan_train_ops,
            FLAGS.train_log_dir,
            save_checkpoint_secs=checkpoint_interval_secs,
            get_hooks_fn=sequential_hooks_fn,
            hooks=[initializer_hook, stop_hook, status_hook],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)
예제 #7
0
    def test_train(self, g_steps, d_steps, joint_train,
                   expected_total_substeps, expected_g_substep_mask,
                   expected_d_substep_mask):
        """Trains a TPUGANEstimator for one step and checks substep state.

        Args:
          g_steps: Generator steps passed to `GANTrainSteps`.
          d_steps: Discriminator steps passed to `GANTrainSteps`.
          joint_train: Forwarded to the estimator's `joint_train` argument.
          expected_total_substeps: Expected value of the substep counter
            variable after training.
          expected_g_substep_mask: Collection that must contain the generator
            substep mask value, or None to skip that check.
          expected_d_substep_mask: Collection that must contain the
            discriminator substep mask value, or None to skip that check.
        """
        # Both wrappers share the same underlying optimizer.
        base_optimizer = tf.compat.v1.train.GradientDescentOptimizer(1e-2)
        generator_opt = TestOptimizerWrapper(base_optimizer, name='g_opt')
        discriminator_opt = TestOptimizerWrapper(base_optimizer, name='d_opt')
        estimator = tfgan.estimator.TPUGANEstimator(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            generator_optimizer=generator_opt,
            discriminator_optimizer=discriminator_opt,
            gan_train_steps=tfgan.GANTrainSteps(g_steps, d_steps),
            joint_train=joint_train,
            get_eval_metric_ops_fn=get_metrics,
            train_batch_size=4,
            eval_batch_size=10,
            predict_batch_size=8,
            use_tpu=flags.FLAGS.use_tpu,
            config=self._config)

        def train_input_fn(params):
            ones = tf.ones([params['batch_size'], 4], dtype=tf.float32)
            return ones, ones

        estimator.train(train_input_fn, steps=1)

        self.assertEqual(1, estimator.get_variable_value('global_step'))

        # The counter is read from the generator scope only when there are no
        # discriminator steps.
        counter_scope = (
            'generator_train' if d_steps == 0 else 'discriminator_train')
        substep_counter_name = counter_scope + '/substep_counter'
        self.assertEqual(
            expected_total_substeps,
            estimator.get_variable_value(substep_counter_name))

        if expected_g_substep_mask is not None:
            self.assertIn(
                estimator.get_variable_value('generator_train/substep_mask'),
                expected_g_substep_mask)
        if expected_d_substep_mask is not None:
            self.assertIn(
                estimator.get_variable_value(
                    'discriminator_train/substep_mask'),
                expected_d_substep_mask)
예제 #8
0
def train(model, **kwargs):
    """Runs the training loop for one stage of a progressive GAN.

    Args:
      model: A model object holding all information of the progressive GAN
        model, e.g. the return of build_model().
      **kwargs: Keyword options. This function reads
          'master': Name of the TensorFlow master to use.
          'task': The Task ID; the worker whose task is 0 is the chief.
          'save_summaries_num_images': Passed on as `save_summaries_steps`.
          'debug_hook': Whether to attach the debug hook to the training
            session.
        All options are also forwarded to make_scaffold() and
        make_train_sub_dir().

    Returns:
      None.
    """
    logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
                 model.num_blocks, model.num_images)

    stage_scaffold = make_scaffold(
        model.stage_id, model.optimizer_var_list, **kwargs)

    logdir = make_train_sub_dir(model.stage_id, **kwargs)
    print('starting training, logdir: {}'.format(logdir))

    # The step-count stop hook is only installed when the stage has no time
    # limit; the time hook is always attached.
    if model.stage_train_time_limit is None:
        training_hooks = [tf.train.StopAtStepHook(last_step=model.num_images)]
    else:
        training_hooks = []
    training_hooks += [
        tf.train.LoggingTensorHook([make_status_message(model)],
                                   every_n_iter=1),
        TrainTimeHook(model.train_time, model.stage_train_time_limit),
    ]
    if kwargs['debug_hook']:
        training_hooks.append(ProganDebugHook())

    gan_train_ops = model.gan_train_ops
    # One generator step and one discriminator step per training iteration.
    sequential_hooks_fn = tfgan.get_sequential_train_hooks(
        tfgan.GANTrainSteps(1, 1))
    tfgan.gan_train(
        gan_train_ops,
        logdir=logdir,
        get_hooks_fn=sequential_hooks_fn,
        hooks=training_hooks,
        master=kwargs['master'],
        is_chief=(kwargs['task'] == 0),
        scaffold=stage_scaffold,
        save_checkpoint_secs=600,
        save_summaries_steps=(kwargs['save_summaries_num_images']))
예제 #9
0
  def test_get_train_estimator_spec(self, joint_train):
    """Checks the TRAIN-mode spec built by `get_train_estimator_spec`.

    Args:
      joint_train: Whether to build the spec for joint training; one model
        function is supplied when true and two otherwise.
    """
    if joint_train:
      model_fns = [get_dummy_gan_model]
    else:
      model_fns = [get_dummy_gan_model] * 2

    with tf.Graph().as_default():
      estimator_spec = get_train_estimator_spec(
          model_fns,
          self._loss_fns,
          {},  # gan_loss_kwargs
          self._optimizers,
          joint_train=joint_train,
          is_on_tpu=flags.FLAGS.use_tpu,
          gan_train_steps=tfgan.GANTrainSteps(1, 1),
          add_summaries=not flags.FLAGS.use_tpu)

    self.assertIsInstance(estimator_spec, TPUEstimatorSpec)
    self.assertEqual(tf.estimator.ModeKeys.TRAIN, estimator_spec.mode)

    # The loss must be a scalar.
    self.assertShapeEqual(np.array(0), estimator_spec.loss)
    self.assertIsNotNone(estimator_spec.train_op)
    self.assertIsNotNone(estimator_spec.training_hooks)
예제 #10
0
def train(hparams):
  """Trains a CycleGAN.

  Creates the log directory when needed, dumps the hyperparameters to
  `train_result.json` inside it, builds the model, loss and train ops, and
  runs `tfgan.gan_train` with one generator and one discriminator step per
  iteration. When `hparams.max_number_of_steps` is falsy the function returns
  after graph construction without training.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
      It must provide `_asdict()`, which is used for the JSON dump.
  """
  import os  # Local import: only needed to build the hparams dump path.

  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Join the path properly so the file lands inside the log directory whether
  # or not `train_log_dir` ends with a separator, and write through
  # tf.io.gfile to match how the directory itself was created.
  hparams_path = os.path.join(hparams.train_log_dir, 'train_result.json')
  with tf.io.gfile.GFile(hparams_path, 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      # Keep every checkpoint written during the run: one per
      # `save_checkpoint_steps` interval, plus one.
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      # Use the compat.v1 names, like the rest of this function.
      saver = tf.compat.v1.train.Saver(max_to_keep=max_to_keep)
      additional_params = {
          'scaffold': tf.compat.v1.train.Scaffold(saver=saver),
          # Step-based checkpointing replaces the time-based one.
          'save_checkpoint_secs': None,
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
예제 #11
0
def main(_):
    """Builds the GAN selected by `FLAGS.gan_type` and runs training.

    Loads the data set and shadow map through the loader class named by
    `FLAGS.loader_name`, builds the input pipeline, model, loss and train
    ops, and trains with one generator and one discriminator step per
    iteration. When `FLAGS.max_number_of_steps` is falsy the function returns
    after graph construction without training.

    The single positional argument is ignored.

    Raises:
      KeyError: If `FLAGS.gan_type` is not one of 'cycle_gan', 'gan_x2y' or
        'gan_y2x'.
    """
    log_dir = FLAGS.train_log_dir
    if not tf.gfile.Exists(log_dir):
        tf.gfile.MakeDirs(log_dir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        validation_iteration_count = FLAGS.validation_itr_count
        validation_sample_count = FLAGS.validation_sample_count
        loader_name = FLAGS.loader_name
        neighborhood = 0
        # The loader class is looked up as "<loader_name>.<loader_name>".
        loader = get_class(loader_name + '.' + loader_name)(FLAGS.path)
        data_set = loader.load_data(neighborhood, True)

        shadow_map, shadow_ratio = loader.load_shadow_map(
            neighborhood, data_set)

        with tf.name_scope('inputs'):
            # The hook returned by load_op exposes the input iterator and is
            # also registered with the training session below.
            initializer_hook = load_op(FLAGS.batch_size,
                                       FLAGS.max_number_of_steps, loader,
                                       data_set, shadow_map, shadow_ratio,
                                       FLAGS.regularization_support_rate)
            training_input_iter = initializer_hook.input_itr
            images_x, images_y = training_input_iter.get_next()
            # Set batch size for summaries.
            # images_x.set_shape([FLAGS.batch_size, None, None, None])
            # images_y.set_shape([FLAGS.batch_size, None, None, None])

        # Define model.
        gan_type = FLAGS.gan_type
        gan_train_wrapper_dict = {
            "cycle_gan":
            CycleGANWrapper(cycle_consistency_loss_weight=FLAGS.
                            cycle_consistency_loss_weight,
                            identity_loss_weight=FLAGS.identity_loss_weight,
                            use_identity_loss=FLAGS.use_identity_loss),
            "gan_x2y":
            GANWrapper(identity_loss_weight=FLAGS.identity_loss_weight,
                       use_identity_loss=FLAGS.use_identity_loss,
                       swap_inputs=False),
            "gan_y2x":
            GANWrapper(identity_loss_weight=FLAGS.identity_loss_weight,
                       use_identity_loss=FLAGS.use_identity_loss,
                       swap_inputs=True)
        }
        wrapper = gan_train_wrapper_dict[gan_type]

        with tf.variable_scope('Model', reuse=tf.AUTO_REUSE):
            the_gan_model = wrapper.define_model(images_x, images_y)
            peer_validation_hook = wrapper.create_validation_hook(
                data_set, loader, log_dir, neighborhood, shadow_map,
                shadow_ratio, validation_iteration_count,
                validation_sample_count)

            the_gan_loss = wrapper.define_loss(the_gan_model)

        # Define train ops for the selected GAN.
        train_ops = _define_train_ops(the_gan_model, the_gan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join([
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
                                        name='status_message')
        if not FLAGS.max_number_of_steps:
            return

        # Enable memory growth on every visible GPU. Indexing the first GPU
        # only would raise IndexError on CPU-only machines, and TensorFlow
        # requires the setting to be identical across all GPUs.
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

        training_scaffold = Scaffold(saver=tf.train.Saver(max_to_keep=20))

        gan_train(
            train_ops,
            log_dir,
            scaffold=training_scaffold,
            # A checkpoint is saved every `validation_iteration_count` steps.
            save_checkpoint_steps=validation_iteration_count,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            hooks=[
                initializer_hook, peer_validation_hook,
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=1000)
            ],
            master=FLAGS.master,
            is_chief=FLAGS.task == 0)