def main(_):
    """Build training hyperparameters from command-line FLAGS and train.

    The single positional argument is ignored (presumably the argv list
    handed over by an `app.run`-style launcher -- confirm against caller).
    """
    # NOTE: HParams is built positionally, so the order of this tuple must
    # stay in sync with the parameter order of `train_lib.HParams`.
    flag_values = (
        FLAGS.batch_size,
        FLAGS.patch_size,
        FLAGS.train_log_dir,
        FLAGS.generator_lr,
        FLAGS.discriminator_lr,
        FLAGS.max_number_of_steps,
        FLAGS.adam_beta1,
        FLAGS.adam_beta2,
        FLAGS.gen_disc_step_ratio,
        FLAGS.tf_master,
        FLAGS.ps_replicas,
        FLAGS.task,
    )
    train_lib.train(train_lib.HParams(*flag_values))
def test_main(self, mock_provide_data):
    """Smoke-test `train_lib.train` end to end on tiny mocked inputs.

    Args:
      mock_provide_data: the mocked data provider (presumably injected by a
        `mock.patch` decorator outside this view). It is configured to
        return, for each of `num_domains` domains, a batch of all-zero images
        and one-hot labels that all point at domain 0.
    """
    if tf.executing_eagerly():
        # `tfgan.stargan_model` doesn't work when executing eagerly; report
        # the test as skipped rather than letting it fail.
        self.skipTest('`tfgan.stargan_model` does not support eager execution.')

    # Keep the run small: a batch of 2 and only 10 training steps.
    # (`self.hparams` is presumably set up in `setUp` -- not visible here.)
    hparams = self.hparams._replace(batch_size=2, max_number_of_steps=10)
    num_domains = 3

    # Construct mock inputs: one [batch, patch, patch, 3] zero tensor per
    # domain, and labels one-hot encoded over `num_domains` classes.
    images_shape = [
        hparams.batch_size, hparams.patch_size, hparams.patch_size, 3
    ]
    img_list = [tf.zeros(images_shape)] * num_domains
    lbl_list = [tf.one_hot([0] * hparams.batch_size, num_domains)] * num_domains
    mock_provide_data.return_value = (img_list, lbl_list)

    train_lib.train(hparams)
def test_main(self, mock_provide_data):
    """Smoke-test `train_lib.train` end to end on tiny mocked inputs.

    Skipped (instead of silently passing) when eager execution is enabled.

    Args:
      mock_provide_data: the mocked data provider (presumably injected by a
        `mock.patch` decorator outside this view). It is configured to
        return, for each of `num_domains` domains, a batch of all-zero images
        and one-hot labels that all point at domain 0.
    """
    if tf.executing_eagerly():
        # `tfgan.stargan_model` doesn't work when executing eagerly. A bare
        # `return` would mark the test as passed even though nothing ran, so
        # report it as skipped instead.
        self.skipTest('`tfgan.stargan_model` does not support eager execution.')

    # Keep the run small: a batch of 2 and only 10 training steps.
    # (`self.hparams` is presumably set up in `setUp` -- not visible here.)
    hparams = self.hparams._replace(batch_size=2, max_number_of_steps=10)
    num_domains = 3

    # Construct mock inputs: one [batch, patch, patch, 3] zero tensor per
    # domain, and labels one-hot encoded over `num_domains` classes.
    images_shape = [
        hparams.batch_size, hparams.patch_size, hparams.patch_size, 3
    ]
    img_list = [tf.zeros(images_shape)] * num_domains
    lbl_list = [tf.one_hot([0] * hparams.batch_size, num_domains)] * num_domains
    mock_provide_data.return_value = (img_list, lbl_list)

    train_lib.train(hparams)