def test_args_to_gan_model(self):
    """Check how `_args_to_gan_model` merges tuple fields, kwargs and defaults."""
    fake_tuple = collections.namedtuple('fake_type', ['arg1', 'arg3'])

    def args_loss(arg1, arg2, arg3=3, arg4=4):
        return arg1 + arg2 + arg3 + arg4

    wrapped = tfgan_losses._args_to_gan_model(args_loss)

    # Every argument supplied: arg1/arg3 from the tuple, arg2/arg4 as keywords.
    expected = 1 + 2 + 5 + 6
    self.assertEqual(expected, wrapped(fake_tuple(1, 2), arg2=5, arg4=6))

    # A `None` tuple field falls back to the loss function's default (arg3=3).
    expected = 1 + 5 + 3 + 7
    self.assertEqual(expected, wrapped(fake_tuple(1, None), arg2=5, arg4=7))

    # An argument that is neither in the tuple nor passed uses its default
    # (arg4=4).
    expected = 1 + 5 + 2 + 4
    self.assertEqual(expected, wrapped(fake_tuple(1, 2), arg2=5))

    # A required argument that is not a tuple field must be passed as a keyword.
    with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'):
        wrapped(fake_tuple(1, 2))

    # A tuple field may not also be passed as a keyword argument.
    with self.assertRaisesRegexp(
            ValueError, 'present in both the tuple and keyword args'):
        wrapped(fake_tuple(1, 2), arg2=1, arg3=5)
def test_args_to_gan_model_docstring(self):
    """Test that `_args_to_gan_model` attaches a generated description.

    This test has its own name so that it does not shadow, and is not
    shadowed by, `test_args_to_gan_model_name`. A duplicated method name
    means only the last definition is ever collected and run.
    """
    def loss_fn(x):
        return x

    new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn)
    # NOTE(review): this reads the custom `__docstring__` attribute, not the
    # standard `__doc__`; presumably `_args_to_gan_model` sets it explicitly.
    self.assertIn('The gan_model version of', new_loss_fn.__docstring__)
def test_args_to_gan_model_name(self):
    """Test that `_args_to_gan_model` produces correctly named functions."""
    def loss_fn(x):
        return x

    new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn)
    # The wrapper must keep the wrapped function's name.
    self.assertEqual('loss_fn', new_loss_fn.__name__)
    # NOTE(review): this reads the custom `__docstring__` attribute, not the
    # standard `__doc__`; presumably `_args_to_gan_model` sets it explicitly.
    # `assertIn` (unlike `assertTrue(a in b)`) reports both values on failure.
    self.assertIn('The gan_model version of', new_loss_fn.__docstring__)
def test_tuple_respects_optional_args(self):
    """A keyword passed alongside the tuple must override the loss default."""
    fake_tuple = collections.namedtuple('fake_type', ['arg1', 'arg2'])

    def args_loss(arg1, arg2, arg3=3):
        return arg1 + 2 * arg2 + 3 * arg3

    wrapped_loss = tfgan_losses._args_to_gan_model(args_loss)
    # Had `arg3=4` been ignored, the default arg3=3 would give another total.
    self.assertEqual(
        -1 + 2 * 2 + 3 * 4,
        wrapped_loss(fake_tuple(arg1=-1, arg2=2), arg3=4))
def test_works_with_child_classes(self):
    """Subclasses of a namedtuple must be accepted by `args_to_gan_model` too."""
    fake_base = collections.namedtuple('fake_type', ['arg1', 'arg2'])

    class InheritedType(fake_base):
        pass

    def args_loss(arg1, arg2, arg3=3):
        return arg1 + 2 * arg2 + 3 * arg3

    wrapped_loss = tfgan_losses._args_to_gan_model(args_loss)
    # Had `arg3=4` been ignored, the default arg3=3 would give another total.
    self.assertEqual(
        -1 + 2 * 2 + 3 * 4,
        wrapped_loss(InheritedType(arg1=-1, arg2=2), arg3=4))
def _resolve_train_log_dir():
    """Return the directory for training output, creating it if needed.

    Uses ``FLAGS.train_log_dir`` when it is set. Otherwise builds
    ``data_out/CycleGanEst-<YYYY-mm-dd_HHMMSS>``, so each run started in a
    different second gets its own directory.

    Returns:
        The log directory path.
    """
    if FLAGS.train_log_dir is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
        # os.path.join rather than hand-concatenated "/" separators.
        logdir = os.path.join('data_out', "CycleGanEst-{}".format(timestamp))
    else:
        logdir = FLAGS.train_log_dir
    if not tf.gfile.Exists(logdir):
        tf.gfile.MakeDirs(logdir)
    return logdir


def _provide_image_pairs():
    """Return the (healthy, cancer) input tensors selected by ``FLAGS.use_dataset``.

    ``'synth'`` and ``'cbis'`` use dedicated providers. Any other value falls
    back to ``provide_custom_datasets`` with the x/y file patterns.
    """
    img_size = (FLAGS.height, FLAGS.width)
    if FLAGS.use_dataset == 'synth':
        # Two channels per image: the first is the image, the second the mask.
        return data_provider.provide_synth_dataset(
            batch_size=FLAGS.batch_size, img_size=img_size)
    if FLAGS.use_dataset == 'cbis':
        # Two channels per image: the first is the image, the second the mask.
        return data_provider.provide_cbis_dataset(
            [FLAGS.image_x_file, FLAGS.image_y_file],
            batch_size=FLAGS.batch_size,
            img_size=img_size)
    return data_provider.provide_custom_datasets(
        [FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
        batch_size=FLAGS.batch_size,
        img_size=img_size)


def _gan_loss_fns(gan_type):
    """Return ``(generator_loss_fn, discriminator_loss_fn)`` for ``gan_type``.

    Both functions are wrapped with ``_args_to_gan_model``, so they take the
    GAN model tuple instead of individual tensors.

    Raises:
        ValueError: If ``gan_type`` is not ``'lsgan'`` or ``'hinge'``.
    """
    if gan_type == 'lsgan':
        print("Using lsgan")
        return (
            _args_to_gan_model(tfgan.losses.wargs.least_squares_generator_loss),
            _args_to_gan_model(
                tfgan.losses.wargs.least_squares_discriminator_loss),
        )
    if gan_type == 'hinge':
        print("Using hinge")
        return (
            _args_to_gan_model(mygan.hinge_generator_loss),
            _args_to_gan_model(mygan.hinge_discriminator_loss),
        )
    # Include the offending value so a misspelled flag is easy to spot.
    raise ValueError("Unknown gan type: {!r}.".format(gan_type))


def _training_hooks(logdir, status_message):
    """Return the session hooks: step limit, progress logging, optional checkpoints."""
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ]
    if FLAGS.checkpoint_hook_steps > 0:
        # Extra checkpoints in `<logdir>/chook`, keeping at most 300 of them.
        hooks.append(tf.train.CheckpointSaverHook(
            checkpoint_dir=os.path.join(logdir, 'chook'),
            save_steps=FLAGS.checkpoint_hook_steps,
            saver=tf.train.Saver(max_to_keep=300)))
    return hooks


def main(_):
    """Build the CycleGAN graph from ``FLAGS`` and train it with ``tfgan.gan_train``.

    Args:
        _: Unused; presumably the argv list passed by ``tf.app.run`` (confirm).

    Raises:
        ValueError: If ``FLAGS.gan_type`` is not ``'lsgan'`` or ``'hinge'``.
    """
    logdir = _resolve_train_log_dir()

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        with tf.name_scope('inputs'):
            images_healthy, images_cancer = _provide_image_pairs()

        # Define CycleGAN model.
        print("images healthy", images_healthy.get_shape())
        print("images cancer", images_cancer.get_shape())
        cyclegan_model = _define_model(
            images_healthy, images_cancer, FLAGS.include_masks)

        # Define CycleGAN loss.
        generator_loss_fn, discriminator_loss_fn = _gan_loss_fns(
            FLAGS.gan_type)
        cyclegan_loss = tfgan.cyclegan_loss(
            cyclegan_model,
            cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
            generator_loss_fn=generator_loss_fn,
            discriminator_loss_fn=discriminator_loss_fn,
            cycle_consistency_loss_fn=functools.partial(
                mygan.cycle_consistency_loss,
                lambda_identity=FLAGS.loss_identity_lambda),
            tensor_pool_fn=tfgan.features.tensor_pool)

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

        # Training.
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
            ['Starting train step: ',
             tf.as_string(tf.train.get_or_create_global_step())],
            name='status_message')
        # Without a step limit the graph is built but no training is run.
        if not FLAGS.max_number_of_steps:
            return

        config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of reserving the whole GPU.
        config.gpu_options.allow_growth = True
        # Let an op that cannot run on its requested device execute on another.
        config.allow_soft_placement = True

        # Build the hooks only after the train ops exist. The optional
        # `tf.train.Saver` is created inside and presumably needs to see the
        # variables defined above.
        hooks = _training_hooks(logdir, status_message)

        tfgan.gan_train(
            train_ops,
            logdir,
            hooks=hooks,
            get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
            master=FLAGS.master,
            is_chief=FLAGS.task == 0,
            config=config)