def main(_):
  """Builds CycleGAN training hyperparameters from flags and runs training.

  Args:
    _: Unused positional argument (presumably argv from the app launcher).
  """
  # Every field is passed by name, so the flag-to-hyperparameter mapping is
  # explicit instead of depending on the positional order of HParams.
  params = train_lib.HParams(
      image_set_x_file_pattern=FLAGS.image_set_x_file_pattern,
      image_set_y_file_pattern=FLAGS.image_set_y_file_pattern,
      batch_size=FLAGS.batch_size,
      patch_size=FLAGS.patch_size,
      master=FLAGS.master,
      train_log_dir=FLAGS.train_log_dir,
      generator_lr=FLAGS.generator_lr,
      discriminator_lr=FLAGS.discriminator_lr,
      max_number_of_steps=FLAGS.max_number_of_steps,
      ps_replicas=FLAGS.ps_replicas,
      task=FLAGS.task,
      cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight)
  train_lib.train(params)
def testTrainingAndInferenceGraphsAreCompatible(self, mock_provide_custom_data,
                                                 unused_mock_gan_train):
  """Checkpoints the training graph and runs the inference demo on it.

  Training and inference graphs can get out of sync if changes are made to
  one but not the other. This test keeps them in sync in three steps:
  1. Save a checkpoint of the variables created by `train_lib.train`.
  2. Run `inference_demo.main` with that checkpoint.
  3. Check that the generated images have the expected names and look
     realistic.

  Args:
    mock_provide_custom_data: Mock for the data provider. Its return value is
      set to zero-filled image batches, so training needs no real files.
    unused_mock_gan_train: Mock for the GAN train loop (not used here).
  """
  if tf.executing_eagerly():
    # `tfgan.cyclegan_model` doesn't work when executing eagerly.
    return

  # Build the training graph.
  hparams = train_lib.HParams(
      image_set_x_file_pattern='/tmp/x/*.jpg',
      image_set_y_file_pattern='/tmp/y/*.jpg',
      batch_size=3,
      patch_size=128,
      master='master',
      train_log_dir=self._export_dir,
      generator_lr=0.02,
      discriminator_lr=0.3,
      max_number_of_steps=1,
      ps_replicas=0,
      task=0,
      cycle_consistency_loss_weight=2.0)
  mock_provide_custom_data.return_value = (tf.zeros([3, 4, 4, 3]),
                                           tf.zeros([3, 4, 4, 3]))
  train_lib.train(hparams)

  # Save the freshly initialized training variables. The session is
  # scoped so it is closed (not leaked) and exits before the default graph
  # is reset below.
  with tf.Session() as train_sess:
    train_sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(train_sess, save_path=self._ckpt_path)

  # Build the inference graph from scratch and point it at the checkpoint.
  tf.reset_default_graph()
  FLAGS.patch_dim = hparams.patch_size
  logging.info('dir_path: %s', os.listdir(self._export_dir))
  FLAGS.checkpoint_path = self._ckpt_path
  FLAGS.image_set_x_glob = self._image_glob
  FLAGS.image_set_y_glob = self._image_glob
  FLAGS.generated_x_dir = self._genx_dir
  FLAGS.generated_y_dir = self._geny_dir

  inference_demo.main(None)
  logging.info('gen x: %s', os.listdir(self._genx_dir))

  # Check that the image names match. Inputs from set X are expected in the
  # generated-Y directory, and inputs from set Y in the generated-X directory.
  self.assertSetEqual(set(_basenames_from_glob(FLAGS.image_set_x_glob)),
                      set(os.listdir(FLAGS.generated_y_dir)))
  self.assertSetEqual(set(_basenames_from_glob(FLAGS.image_set_y_glob)),
                      set(os.listdir(FLAGS.generated_x_dir)))

  # Check that every generated image, in BOTH output directories, looks as
  # expected.
  for directory in (FLAGS.generated_x_dir, FLAGS.generated_y_dir):
    for base_name in os.listdir(directory):
      self.assertRealisticImage(os.path.join(directory, base_name))
def test_main(self, mock_gan_train, mock_define_train_ops, mock_cyclegan_loss,
              mock_define_model, mock_data_provider, mock_gfile):
  """Checks that `train_lib.train` wires its (mocked) collaborators together.

  It verifies that the hyperparameters reach four collaborators:
  - the data provider;
  - the model definition;
  - the CycleGAN loss;
  - the train-op definition.
  It also checks that the train ops are handed to the GAN train loop with the
  expected log dir, master and chief flag.

  Args:
    mock_gan_train: Mock for the GAN train loop.
    mock_define_train_ops: Mock for the train-op builder.
    mock_cyclegan_loss: Mock for the CycleGAN loss builder.
    mock_define_model: Mock for the model builder.
    mock_data_provider: Mock for the data provider module.
    mock_gfile: Mock for the file API (not referenced in this test).
  """
  overrides = dict(
      image_set_x_file_pattern='/tmp/x/*.jpg',
      image_set_y_file_pattern='/tmp/y/*.jpg',
      batch_size=3,
      patch_size=8,
      generator_lr=0.02,
      discriminator_lr=0.3,
      train_log_dir='/tmp/foo',
      master='master',
      task=0,
      cycle_consistency_loss_weight=2.0,
      max_number_of_steps=1,
  )
  self.hparams = self.hparams._replace(**overrides)

  # One zero-filled batch for each image set (X and Y).
  provide = mock_data_provider.provide_custom_data
  provide.return_value = tuple(
      tf.zeros([3, 2, 2, 3], dtype=tf.float32) for _ in range(2))

  train_lib.train(self.hparams)

  provide.assert_called_once_with(
      batch_size=3,
      image_file_patterns=['/tmp/x/*.jpg', '/tmp/y/*.jpg'],
      patch_size=8)

  model = mock_define_model.return_value
  mock_define_model.assert_called_once_with(mock.ANY, mock.ANY)

  loss = mock_cyclegan_loss.return_value
  mock_cyclegan_loss.assert_called_once_with(
      model, cycle_consistency_loss_weight=2.0, tensor_pool_fn=mock.ANY)

  mock_define_train_ops.assert_called_once_with(model, loss, self.hparams)

  mock_gan_train.assert_called_once_with(
      mock_define_train_ops.return_value,
      '/tmp/foo',
      get_hooks_fn=mock.ANY,
      hooks=mock.ANY,
      master='master',
      is_chief=True)