def test_reduction_cyclegan(self):
  """Checks that cyclegan_loss with Reduction.NONE yields rank-2 loss terms."""
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  loss = tfgan.cyclegan_loss(
      create_cyclegan_model(),
      reduction=tf.compat.v1.losses.Reduction.NONE)
  self.assertIsInstance(loss, tfgan.CycleGANLoss)
  # Same two checks for each translation direction (x->y, then y->x),
  # discriminator loss before generator loss.
  for direction_loss in (loss.loss_x2y, loss.loss_y2x):
    self.assertEqual(2, direction_loss.discriminator_loss.shape.ndims)
    self.assertEqual(2, direction_loss.generator_loss.shape.ndims)
def test_cyclegan_output_type(self, get_gan_model_fn):
  """Checks cyclegan_loss returns rank-0 losses and registers summaries.

  Args:
    get_gan_model_fn: Callable invoked with no arguments; its result is the
      model passed to `tfgan.cyclegan_loss`.
  """
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  loss = tfgan.cyclegan_loss(get_gan_model_fn(), add_summaries=True)
  self.assertIsInstance(loss, tfgan.CycleGANLoss)
  # No explicit reduction is requested, so every loss term is expected to be
  # a scalar (rank 0).
  self.assertEqual(0, loss.loss_x2y.discriminator_loss.shape.ndims)
  self.assertEqual(0, loss.loss_x2y.generator_loss.shape.ndims)
  self.assertEqual(0, loss.loss_y2x.discriminator_loss.shape.ndims)
  self.assertEqual(0, loss.loss_y2x.generator_loss.shape.ndims)
  # Eager mode already returned above, so graph collections are available
  # here unconditionally; add_summaries=True must have populated SUMMARIES.
  self.assertNotEmpty(
      tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))
def train(hparams):
  """Builds the CycleGAN graph described by `hparams` and trains it.

  Returns without training when `hparams.max_number_of_steps` is falsy.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  log_dir = hparams.train_log_dir
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)

  device_fn = tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)
  with tf.device(device_fn):
    # The input pipeline is pinned to the CPU.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(
          hparams.image_set_x_file_pattern,
          hparams.image_set_y_file_pattern,
          hparams.batch_size,
          hparams.patch_size)

    model = _define_model(images_x, images_y)
    losses = tfgan.cyclegan_loss(
        model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    train_ops = _define_train_ops(model, losses, hparams)

    # Alternate one generator step with one discriminator step.
    steps_per_phase = tfgan.GANTrainSteps(1, 1)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    status_message = tf.strings.join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    if not hparams.max_number_of_steps:
      return

    tfgan.gan_train(
        train_ops,
        log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(steps_per_phase),
        hooks=[
            tf.estimator.StopAtStepHook(
                num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook(
                {'status_message': status_message}, every_n_iter=10),
        ],
        master=hparams.master,
        is_chief=hparams.task == 0)
def test_define_train_ops(self):
  """Checks _define_train_ops returns a GANTrainOps for a tiny CycleGAN."""
  self.hparams = self.hparams._replace(
      batch_size=2, generator_lr=0.1, discriminator_lr=0.01)
  # Two all-zero image batches of shape [batch, 4, 4, 3]; x is built first.
  shape = [self.hparams.batch_size, 4, 4, 3]
  images_x, images_y = [tf.zeros(shape, dtype=tf.float32) for _ in range(2)]

  model = train_lib._define_model(images_x, images_y)
  losses = tfgan.cyclegan_loss(model, cycle_consistency_loss_weight=10.0)
  ops = train_lib._define_train_ops(model, losses, self.hparams)

  self.assertIsInstance(ops, tfgan.GANTrainOps)
def main(_):
  """Builds the CycleGAN graph from FLAGS and trains it.

  The positional argument (presumably argv from the app runner) is unused.
  Returns without training when `FLAGS.max_number_of_steps` is falsy.
  """
  log_dir = FLAGS.train_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      # load_op returns a hook exposing the input iterator as `input_itr`;
      # the same hook is registered with gan_train below.
      input_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
      images_x, images_y = input_hook.input_itr.get_next()
      # NOTE(review): static batch-size shapes for summaries
      # (set_shape([FLAGS.batch_size, None, None, None]) on both image
      # tensors) are intentionally disabled here — confirm still needed.

    model = _define_model(images_x, images_y)
    losses = tfgan.cyclegan_loss(
        model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    train_ops = _define_train_ops(model, losses)

    # Alternate one generator step with one discriminator step.
    steps_per_phase = tfgan.GANTrainSteps(1, 1)
    global_step = tf.train.get_or_create_global_step()
    status_message = tf.string_join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    if not FLAGS.max_number_of_steps:
      return

    tfgan.gan_train(
        train_ops,
        log_dir,
        save_checkpoint_secs=60 * 10,
        get_hooks_fn=tfgan.get_sequential_train_hooks(steps_per_phase),
        hooks=[
            input_hook,
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10),
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def define_loss(self, model):
  """Builds the CycleGAN loss for `model`.

  Uses `cyclegan_loss_with_identity` (adding `self._identity_loss_weight`)
  when `self._use_identity_loss` is truthy, otherwise `tfgan.cyclegan_loss`.
  Generator/discriminator loss functions are not passed, so each builder
  uses its own defaults (Wasserstein overrides were present here only as
  commented-out code).

  Args:
    model: The CycleGAN model to compute losses for.

  Returns:
    Whatever the selected loss builder returns.
  """
  with_identity = self._use_identity_loss
  # Arguments common to both loss builders.
  shared_kwargs = {
      'cycle_consistency_loss_weight': self._cycle_consistency_loss_weight,
      'tensor_pool_fn': tfgan.features.tensor_pool,
  }
  if not with_identity:
    return tfgan.cyclegan_loss(model, **shared_kwargs)
  return cyclegan_loss_with_identity(
      model, identity_loss_weight=self._identity_loss_weight, **shared_kwargs)
def test_cyclegan(self, create_gan_model_fn):
  """Test that CycleGan models work.

  Evaluates all four loss terms and checks that, in each direction, the
  generator loss exceeds the discriminator loss and every value is a scalar.
  """
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  loss = tfgan.cyclegan_loss(create_gan_model_fn())
  self.assertIsInstance(loss, tfgan.CycleGANLoss)

  fetches = [
      loss.loss_x2y.generator_loss,
      loss.loss_x2y.discriminator_loss,
      loss.loss_y2x.generator_loss,
      loss.loss_y2x.discriminator_loss,
  ]
  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    values = sess.run(fetches)
    x2y_gen, x2y_dis, y2x_gen, y2x_dis = values
    self.assertGreater(x2y_gen, x2y_dis)
    self.assertGreater(y2x_gen, y2x_dis)
    for value in values:
      self.assertTrue(np.isscalar(value))
def train(hparams):
  """Trains a CycleGAN.

  Before building the graph, the hyperparameters are written as JSON to
  `train_result.json` inside `hparams.train_log_dir`. Returns without
  training when `hparams.max_number_of_steps` is falsy.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
      It must support `_asdict()`, which is used to serialize it.
  """
  import os  # Local import: only used to join the log-dir path.

  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Join with a separator instead of plain '+'. Otherwise a log dir without a
  # trailing slash yields e.g. '/tmp/logtrain_result.json' beside the log
  # directory. GFile matches the tf.io.gfile calls above, so non-local log
  # dirs also work.
  hparams_path = os.path.join(hparams.train_log_dir, 'train_result.json')
  with tf.io.gfile.GFile(hparams_path, 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(
      tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    # The input pipeline is pinned to the CPU.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training: alternate one generator step with one discriminator step.
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ], name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      # Size the Saver so that every checkpoint written during the run is
      # kept: one per save_checkpoint_steps, plus one extra.
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      additional_params = {
          # Use the tf.compat.v1 names: Scaffold and Saver are not exported
          # under tf.train in TF2, which the rest of this function targets.
          'scaffold': tf.compat.v1.train.Scaffold(
              saver=tf.compat.v1.train.Saver(max_to_keep=max_to_keep)),
          # Turn off time-based saving so only the step schedule applies.
          'save_checkpoint_secs': None,
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(
                num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook(
                {'status_message': status_message}, every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params)