  def test_args_to_gan_model(self):
    """Test `_args_to_gan_model`."""
    tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg3'])

    def args_loss(arg1, arg2, arg3=3, arg4=4):
      return arg1 + arg2 + arg3 + arg4

    gan_model_loss = tfgan_losses._args_to_gan_model(args_loss)

    # Value is correct.
    self.assertEqual(1 + 2 + 5 + 6,
                     gan_model_loss(tuple_type(1, 2), arg2=5, arg4=6))

    # Uses tuple argument with defaults.
    self.assertEqual(1 + 5 + 3 + 7,
                     gan_model_loss(tuple_type(1, None), arg2=5, arg4=7))

    # Uses non-tuple argument with defaults.
    self.assertEqual(1 + 5 + 2 + 4,
                     gan_model_loss(tuple_type(1, 2), arg2=5))

    # Requires non-tuple, non-default arguments.
    with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'):
      gan_model_loss(tuple_type(1, 2))

    # Can't pass tuple argument outside tuple.
    with self.assertRaisesRegexp(
        ValueError, 'present in both the tuple and keyword args'):
      gan_model_loss(tuple_type(1, 2), arg2=1, arg3=5)

  def test_args_to_gan_model_name(self):
    """Test that `_args_to_gan_model` produces correctly named functions."""
    def loss_fn(x):
      return x
    new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn)
    self.assertEqual('loss_fn', new_loss_fn.__name__)
    self.assertTrue('The gan_model version of' in new_loss_fn.__docstring__)
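
# `_args_to_gan_model` is internal to TF-GAN and not shown in these snippets.
# As a rough illustration only (an assumption, not the library's code), it
# behaves like the sketch below: it wraps a loss that takes explicit arguments
# into one that takes a namedtuple (e.g. a GANModel) plus keyword arguments,
# filling in matching non-None tuple fields, rejecting values supplied both
# ways, and requiring that every non-default argument is provided. The real
# wrapper also rewrites the name and docstring, which the preceding test checks.
import inspect


def _args_to_gan_model_sketch(loss_fn):
  """Hypothetical stand-in mirroring the behaviour exercised by these tests."""
  argspec = inspect.getfullargspec(loss_fn)
  defaults = argspec.defaults or ()
  required_args = set(argspec.args[:len(argspec.args) - len(defaults)])

  def new_loss_fn(gan_model, **kwargs):
    # Pull loss arguments out of the namedtuple, skipping unset (None) fields.
    tuple_args = {k: v for k, v in gan_model._asdict().items()
                  if k in argspec.args and v is not None}
    overlap = set(tuple_args) & set(kwargs)
    if overlap:
      raise ValueError(
          'Argument(s) %s present in both the tuple and keyword args.' %
          ', '.join(sorted(overlap)))
    all_args = dict(tuple_args, **kwargs)
    missing = required_args - set(all_args)
    if missing:
      raise ValueError('`%s` must be supplied.' % sorted(missing)[0])
    return loss_fn(**all_args)

  new_loss_fn.__name__ = loss_fn.__name__
  return new_loss_fn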
  def test_tuple_respects_optional_args(self):
    """Test that optional args can be changed with tuple losses."""
    tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
    def args_loss(arg1, arg2, arg3=3):
      return arg1 + 2 * arg2 + 3 * arg3

    loss_fn = tfgan_losses._args_to_gan_model(args_loss)
    loss = loss_fn(tuple_type(arg1=-1, arg2=2), arg3=4)

    # If `arg3` were not set properly, this value would be different.
    self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
  def test_works_with_child_classes(self):
    """`args_to_gan_model` should work with classes derived from namedtuple."""
    tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])

    class InheritedType(tuple_type):
      pass
    def args_loss(arg1, arg2, arg3=3):
      return arg1 + 2 * arg2 + 3 * arg3

    loss_fn = tfgan_losses._args_to_gan_model(args_loss)
    loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4)

    # If `arg3` were not set properly, this value would be different.
    self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
def main(_):
    if FLAGS.train_log_dir is None:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
        logdir = "{}/{}-{}".format('data_out', "CycleGanEst", timestamp)
    else:
        logdir = FLAGS.train_log_dir

    if not tf.gfile.Exists(logdir):
        tf.gfile.MakeDirs(logdir)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
        with tf.name_scope('inputs'):
            if FLAGS.use_dataset == 'synth':
                # Generates two channels: the first channel is the image, the second is the mask.
                images_healthy, images_cancer = data_provider.provide_synth_dataset(
                        batch_size=FLAGS.batch_size, img_size=(FLAGS.height, FLAGS.width))
            elif FLAGS.use_dataset == 'cbis':
                # Generates two channels: the first channel is the image, the second is the mask.
                images_healthy, images_cancer = data_provider.provide_cbis_dataset(
                        [FLAGS.image_x_file, FLAGS.image_y_file],
                        batch_size=FLAGS.batch_size,
                        img_size=(FLAGS.height, FLAGS.width))
            else:
                images_healthy, images_cancer = data_provider.provide_custom_datasets(
                        [FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
                        batch_size=FLAGS.batch_size,
                        img_size=(FLAGS.height, FLAGS.width))

        # Define CycleGAN model.
        print("images healthy", images_healthy.get_shape())
        print("images cancer", images_cancer.get_shape())
        cyclegan_model = _define_model(images_healthy, images_cancer, FLAGS.include_masks)

        # Define CycleGAN loss.
        if FLAGS.gan_type == 'lsgan':
            print("Using lsgan")
            generator_loss_fn = _args_to_gan_model(
                    tfgan.losses.wargs.least_squares_generator_loss)
            discriminator_loss_fn = _args_to_gan_model(
                    tfgan.losses.wargs.least_squares_discriminator_loss)
        elif FLAGS.gan_type == 'hinge':
            print("Using hinge")
            generator_loss_fn = _args_to_gan_model(mygan.hinge_generator_loss)
            discriminator_loss_fn = _args_to_gan_model(mygan.hinge_discriminator_loss)
        else:
            raise ValueError("Unknown gan type.")

        cyclegan_loss = tfgan.cyclegan_loss(
                cyclegan_model,
                cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
                generator_loss_fn=generator_loss_fn,
                discriminator_loss_fn=discriminator_loss_fn,
                cycle_consistency_loss_fn=functools.partial(
                        mygan.cycle_consistency_loss, lambda_identity=FLAGS.loss_identity_lambda),
                tensor_pool_fn=tfgan.features.tensor_pool)
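        # The CycleGANLoss returned above bundles one GANLoss per direction
        # (x -> y and y -> x), each combining its adversarial losses with the
        # weighted cycle-consistency term.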

        # Define CycleGAN train ops.
        train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

        # Training
        train_steps = tfgan.GANTrainSteps(1, 1)
        status_message = tf.string_join(
                ['Starting train step: ',
                 tf.as_string(tf.train.get_or_create_global_step())],
                name='status_message')
        if not FLAGS.max_number_of_steps:
            return

        # To avoid problems with GPU memory.
        config = tf.ConfigProto()
        # Do not grab all of the GPU memory up front; allocate it as needed.
        config.gpu_options.allow_growth = True
        # If an op cannot be placed on the requested device, let it run on another one.
        config.allow_soft_placement = True
        hooks = [
                tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
                tf.train.LoggingTensorHook([status_message], every_n_iter=10),
        ]
        if FLAGS.checkpoint_hook_steps > 0:
            chkpt_hook = tf.train.CheckpointSaverHook(
                    checkpoint_dir=os.path.join(logdir, 'chook'),
                    save_steps=FLAGS.checkpoint_hook_steps,
                    saver=tf.train.Saver(max_to_keep=300))
            hooks.append(chkpt_hook)

        tfgan.gan_train(
                train_ops,
                logdir,
                hooks=hooks,
                get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
                master=FLAGS.master,
                is_chief=FLAGS.task == 0,
                config=config)
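
# `mygan.hinge_generator_loss` and `mygan.hinge_discriminator_loss` come from
# the author's own module and are not shown here. A common hinge-loss
# formulation, sketched under that assumption (the actual `mygan` code may
# differ), looks like this:
import tensorflow as tf


def hinge_generator_loss_sketch(discriminator_gen_outputs, **unused_kwargs):
    # The generator pushes D(G(x)) up, so its loss is -mean(D(G(x))).
    return -tf.reduce_mean(discriminator_gen_outputs)


def hinge_discriminator_loss_sketch(discriminator_real_outputs,
                                    discriminator_gen_outputs,
                                    **unused_kwargs):
    # Hinge margins on the real and generated logits.
    loss_on_real = tf.reduce_mean(tf.nn.relu(1.0 - discriminator_real_outputs))
    loss_on_generated = tf.reduce_mean(
            tf.nn.relu(1.0 + discriminator_gen_outputs))
    return loss_on_real + loss_on_generated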