def train(hparams):
  """Builds the StarGAN training graph and runs the training loop.

  Args:
    hparams: An HParams instance holding the training hyperparameters. The
      fields read here are `train_log_dir`, `ps_replicas`, `batch_size`,
      `patch_size`, `generator_lr`, `discriminator_lr`, `adam_beta1`,
      `adam_beta2`, `max_number_of_steps`, `gen_disc_step_ratio`,
      `tf_master` and `task`.
  """
  log_dir = hparams.train_log_dir

  # Make sure the training log directory is present before training starts.
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)

  # Shard the model across the parameter servers.
  device_fn = tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)
  with tf.device(device_fn):
    # The input pipeline is built under its own name scope, pinned to the CPU.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      input_images, input_labels = data_provider.provide_data(
          'train', hparams.batch_size, hparams.patch_size)

    with tf.compat.v1.name_scope('model'):
      stargan = _define_model(input_images, input_labels)

    # NOTE(review): 3 * batch_size presumably covers one batch per input
    # domain returned by data_provider -- confirm against provide_data.
    tfgan.eval.add_stargan_image_summaries(
        stargan, num_images=3 * hparams.batch_size, display_diffs=True)

    gan_loss = tfgan.stargan_loss(stargan)

    with tf.compat.v1.name_scope('train_ops'):
      gan_train_ops = _define_train_ops(
          stargan,
          gan_loss,
          hparams.generator_lr,
          hparams.discriminator_lr,
          hparams.adam_beta1,
          hparams.adam_beta2,
          hparams.max_number_of_steps)

    # Relative number of generator vs. discriminator updates per step.
    step_counts = _define_train_step(hparams.gen_disc_step_ratio)

    # A string tensor that is logged every 10 iterations.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    status_message = tf.strings.join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    training_hooks = [
        tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
        tf.estimator.LoggingTensorHook([status_message], every_n_iter=10),
    ]
    # Only task 0 acts as the chief worker.
    chief = hparams.task == 0

    tfgan.gan_train(
        gan_train_ops,
        log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(step_counts),
        hooks=training_hooks,
        master=hparams.tf_master,
        is_chief=chief)
def test_define_train_ops(self):
  """Checks that `_define_train_ops` returns a `tfgan.GANTrainOps` tuple."""
  hp = self.hparams._replace(
      batch_size=2, generator_lr=0.1, discriminator_lr=0.01)
  batch = hp.batch_size

  # An all-zeros float batch of 4x4, 3-channel images. Every example is
  # labelled with one-hot index 0 out of depth 2.
  fake_images = tf.zeros([batch, 4, 4, 3], dtype=tf.float32)
  fake_labels = tf.one_hot([0] * batch, 2)

  stargan = train_lib._define_model(fake_images, fake_labels)
  gan_loss = tfgan.stargan_loss(stargan)
  ops = train_lib._define_train_ops(
      stargan,
      gan_loss,
      hp.generator_lr,
      hp.discriminator_lr,
      hp.adam_beta1,
      hp.adam_beta2,
      hp.max_number_of_steps)

  self.assertIsInstance(ops, tfgan.GANTrainOps)
def test_stargan(self, create_gan_model_fn):
  """Checks that `tfgan.stargan_loss` yields a `GANLoss` of scalar losses.

  Args:
    create_gan_model_fn: Zero-argument callable that builds the StarGAN
      model under test.
  """
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    # NOTE(review): returning here reports a pass, not a skip, in eager mode;
    # consider `self.skipTest(...)` if no graph-and-eager decorator wraps
    # this test.
    return

  gan_model = create_gan_model_fn()
  loss_tuple = tfgan.stargan_loss(gan_model)
  self.assertIsInstance(loss_tuple, tfgan.GANLoss)

  fetches = [loss_tuple.generator_loss, loss_tuple.discriminator_loss]
  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # Generator loss first, then discriminator loss; each must be a scalar.
    for loss_value in sess.run(fetches):
      self.assertTrue(np.isscalar(loss_value))