def get_estimator_spec(mode,
                       gan_model,
                       loss_fn,
                       get_eval_metric_ops_fn,
                       generator_optimizer,
                       discriminator_optimizer,
                       get_hooks_fn=None):
  """Build the `tf.estimator.EstimatorSpec` that matches `mode`.

  Args:
    mode: A `tf.estimator.ModeKeys` value. PREDICT and EVAL are handled
      explicitly; any other value is treated as TRAIN.
    gan_model: The GAN model tuple; `generated_data` is used for PREDICT.
    loss_fn: Callable mapping `gan_model` to a GAN loss (EVAL/TRAIN only).
    get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec`.
    generator_optimizer: An optimizer, or a zero-argument callable that
      returns one (called only in TRAIN mode).
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Optional hooks factory for training; when falsy,
      `tfgan_train.get_sequential_train_hooks()` is used.

  Returns:
    The EstimatorSpec for the requested mode.
  """
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  loss = loss_fn(gan_model)
  if mode == tf.estimator.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, loss, get_eval_metric_ops_fn)

  # Every remaining mode is handled as TRAIN.
  def _resolve(opt_or_factory):
    # Optimizers may be passed lazily as factories so they are only built
    # when training actually happens.
    if callable(opt_or_factory):
      return opt_or_factory()
    return opt_or_factory

  generator_opt = _resolve(generator_optimizer)
  discriminator_opt = _resolve(discriminator_optimizer)
  hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(gan_model, loss, generator_opt,
                                   discriminator_opt, hooks_fn)
Example #2
0
def get_train_estimator_spec(
    gan_model, gan_loss, optimizers,
    get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops, is_chief=True):
  """Build the TRAIN-mode `tf.estimator.EstimatorSpec`.

  Args:
    gan_model: The GAN model tuple passed to `train_op_fn`.
    gan_loss: The GAN loss tuple; its `discriminator_loss` is reported as the
      spec's loss.
    optimizers: Optimizer container resolved through
      `_maybe_construct_optimizers`; must expose `gopt` and `dopt`.
    get_hooks_fn: Callable mapping the train ops to a list of training hooks.
      When falsy, `tfgan_train.get_sequential_train_hooks()` is used.
    train_op_fn: Builds the train ops from model, loss and optimizers.
    is_chief: Forwarded to `train_op_fn`.

  Returns:
    An EstimatorSpec whose `train_op` is the ops' `global_step_inc_op`.
  """
  if get_hooks_fn:
    hooks_fn = get_hooks_fn
  else:
    hooks_fn = tfgan_train.get_sequential_train_hooks()
  resolved = _maybe_construct_optimizers(optimizers)

  ops = train_op_fn(gan_model, gan_loss, resolved.gopt, resolved.dopt,
                    is_chief=is_chief)
  hooks = hooks_fn(ops)
  # Only the discriminator loss is surfaced as the Estimator's scalar loss.
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.TRAIN,
      loss=gan_loss.discriminator_loss,
      train_op=ops.global_step_inc_op,
      training_hooks=hooks)
Example #3
0
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        """Build the TF-GAN EstimatorSpec for this model.

        The model's generator and discriminator are wrapped in TF-Hub
        modules. A TF-GAN ``GANModel`` is built from them, plus a
        ``GANLoss`` in TRAIN/EVAL. The method then returns the
        EstimatorSpec for ``mode``.

        Args:
          hparams: Model hyperparameters. The method works on a copy made by
            ``hparams_lib.copy_hparams``. Fields read here: ``batch_size``,
            ``generator_lr``, ``discriminator_lr``, ``beta1``, ``apply_psf``,
            ``gen_steps`` and ``disc_steps``.
          features: Dict of input tensors. ``'inputs'`` holds the real
            images and ``'ps'`` is read to build the additive noise.
            ``'psf'`` is used only when present and ``hparams.apply_psf`` is
            set.
          labels: Unused.
          mode: One of ``ModeKeys.TRAIN``, ``EVAL`` or ``PREDICT``.
          config: Optional run config. Its ``data_parallelism`` is used when
            ``use_tpu`` is False.
          params: Unused.
          decode_hparams: Forwarded to the model constructor.
          use_tpu: When True, ``config.data_parallelism`` is ignored.

        Returns:
          A ``tf.estimator.EstimatorSpec`` for ``mode``.

        Raises:
          ValueError: If ``mode`` is not TRAIN, EVAL or PREDICT.
        """
        if mode not in [
                model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                model_fn_lib.ModeKeys.PREDICT
        ]:
            raise ValueError('Mode not recognized: %s' % mode)

        hparams = hparams_lib.copy_hparams(hparams)

        # Data parallelism is only taken from the run config off-TPU.
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        reuse = tf.get_variable_scope().reuse

        # Instantiate the model. `self` is a plain local here, built from `cls`.
        self = cls(hparams,
                   mode,
                   data_parallelism=data_parallelism,
                   decode_hparams=decode_hparams,
                   _reuse=reuse)

        generator_inputs = self.sample_noise()
        # Real images. Pin a static batch dimension so later ops that rely on
        # static shapes (e.g. noise generation in `generator`) are defined.
        real_data = features['inputs']
        img_shape = common_layers.shape_list(real_data)[1:4]
        real_data.set_shape([hparams.batch_size] + img_shape)

        # To satisfy the TF-GAN API, real data is set to None on predict.
        if mode == tf.estimator.ModeKeys.PREDICT:
            real_data = None

        optimizers = Optimizers(
            tf.compat.v1.train.AdamOptimizer(hparams.generator_lr,
                                             hparams.beta1),
            tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                             hparams.beta1))

        # Build TF-Hub module specs for the discriminator and the generator
        # so each network can be exported as a standalone module.
        def make_discriminator_spec():
            input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
            disc_output = self.discriminator(input_layer, None, mode)
            hub.add_signature(inputs=input_layer, outputs=disc_output)

        disc_spec = hub.create_module_spec(make_discriminator_spec)

        def make_generator_spec():
            input_layer = tf.placeholder(
                tf.float32,
                shape=[None] + common_layers.shape_list(generator_inputs)[1:])
            gen_output = self.generator(input_layer, mode)
            hub.add_signature(inputs=input_layer, outputs=gen_output)

        gen_spec = hub.create_module_spec(make_generator_spec)

        # Create the (trainable) module instances.
        discriminator_module = hub.Module(disc_spec,
                                          name="Discriminator_Module",
                                          trainable=True)
        generator_module = hub.Module(gen_spec,
                                      name="Generator_Module",
                                      trainable=True)

        # Wrap the modules into the callables expected by TF-GAN.
        def generator(code, mode):
            """Run the generator module, then apply optional PSF and noise.

            `mode` is required by the TF-GAN signature but unused here.
            """
            out = generator_module(code)
            # Optionally convolve the output with the provided PSF.
            if hparams.apply_psf and 'psf' in features:
                out = convolve(out,
                               tf.cast(features['psf'][..., 0], tf.complex64))

            # Add noise shaped by features['ps']. White Gaussian noise is
            # taken to Fourier space and scaled by sqrt(exp(ps)). Bins with
            # ps >= 9 are zeroed. NOTE(review): ps is presumably a log power
            # spectrum — confirm against the input pipeline.
            noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
            thresholded_ps = tf.where(features['ps'] >= 9,
                                      tf.zeros_like(features['ps']),
                                      tf.sqrt(tf.exp(features['ps'])))
            noise = noise * tf.cast(thresholded_ps, tf.complex64)
            out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
            return out

        def discriminator(image, conditioning, mode):
            """Run the discriminator module; `conditioning`/`mode` are unused."""
            return discriminator_module(image)

        # Make GANModel, which encapsulates the GAN model architectures.
        gan_model = get_gan_model(mode,
                                  generator,
                                  discriminator,
                                  real_data,
                                  generator_inputs,
                                  add_summaries=self.summaries)

        # Make GANLoss, which encapsulates the losses (not needed to predict).
        if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
            gan_loss = tfgan_train.gan_loss(gan_model,
                                            self.generator_loss,
                                            self.discriminator_loss,
                                            add_summaries=True)

        # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
        # metrics, and optimizers (if required).
        if mode == tf.estimator.ModeKeys.TRAIN:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks(
                namedtuples.GANTrainSteps(hparams.gen_steps,
                                          hparams.disc_steps))
            estimator_spec = get_train_estimator_spec(gan_model,
                                                      gan_loss,
                                                      optimizers,
                                                      get_hooks_fn,
                                                      is_chief=True)
        elif mode == tf.estimator.ModeKeys.EVAL:
            estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
        else:  # tf.estimator.ModeKeys.PREDICT
            # Register hub modules for export
            hub.register_module_for_export(generator_module, "generator")
            hub.register_module_for_export(discriminator_module,
                                           "discriminator")
            estimator_spec = get_predict_estimator_spec(gan_model)
        return estimator_spec
Example #4
0
def get_estimator_spec(mode,
                       gan_model,
                       loss_fn,
                       get_eval_metric_ops_fn,
                       generator_optimizer,
                       discriminator_optimizer,
                       get_hooks_fn=None,
                       cls_model=None,
                       cls_checkpoint=None):
    """Get the EstimatorSpec for the current mode.

    Args:
        mode: A `tf.estimator.ModeKeys` value. PREDICT and EVAL are handled
            explicitly; any other value is treated as TRAIN.
        gan_model: The GAN model tuple. In PREDICT mode its attributes are
            exported as predictions.
        loss_fn: Callable mapping `gan_model` to a GAN loss (EVAL/TRAIN only).
        get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec`.
        generator_optimizer: An optimizer, or a zero-argument callable
            returning one (resolved only in TRAIN mode).
        discriminator_optimizer: Same as `generator_optimizer`, for the
            discriminator.
        get_hooks_fn: Optional training hooks factory; when falsy,
            `tfgan_train.get_sequential_train_hooks()` is used.
        cls_model: Forwarded to `_get_train_estimator_spec`.
        cls_checkpoint: Forwarded to `_get_train_estimator_spec`.

    Returns:
        The EstimatorSpec for the requested mode.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Predictions are only built when they are actually needed, so
        # TRAIN/EVAL never touch the prediction-only model attributes.
        if PREDICT_FLAG:
            # Export every intermediate tensor of the model. The keys are
            # part of the prediction output contract and are kept verbatim.
            predictions = {
                'input_data':
                gan_model.input_data,
                'input_data_domain_label':
                gan_model.input_data_domain_label,
                'generated_data':
                gan_model.generated_data,
                'generated_data_domain_target':
                gan_model.generated_data_domain_target,
                'reconstructed_data':
                gan_model.reconstructed_data,
                'discriminator_input_data_source_predication':
                gan_model.discriminator_input_data_source_predication,
                'discriminator_generated_data_source_predication':
                gan_model.discriminator_generated_data_source_predication,
                'discriminator_input_data_domain_predication':
                gan_model.discriminator_input_data_domain_predication,
                'discriminator_generated_data_domain_predication':
                gan_model.discriminator_generated_data_domain_predication,
            }
        else:
            predictions = gan_model.generated_data
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    gan_loss = loss_fn(gan_model)
    if mode == tf.estimator.ModeKeys.EVAL:
        return _get_eval_estimator_spec(gan_model, gan_loss,
                                        get_eval_metric_ops_fn)

    # Any remaining mode is treated as TRAIN.
    def _maybe_callable(x):
        # Optimizers may be given as factories so they are built lazily.
        return x() if callable(x) else x

    gopt = _maybe_callable(generator_optimizer)
    dopt = _maybe_callable(discriminator_optimizer)
    get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
    return _get_train_estimator_spec(
        gan_model,
        gan_loss,
        gopt,
        dopt,
        get_hooks_fn,
        cls_model=cls_model,
        cls_checkpoint=cls_checkpoint)
Example #5
0
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=USE_DEFAULT,
              save_summaries_steps=USE_DEFAULT,
              save_checkpoint_steps=USE_DEFAULT,
              max_wait_secs=7200,
              config=None):
    """Run GAN training in a `tf.compat.v1.train.MonitoredTrainingSession`.

    The hooks produced by `get_hooks_fn(train_ops)` (e.g. sequential
    generator/discriminator updates) are appended to any user-supplied
    `hooks`. `train_ops.global_step_inc_op` is then run until the session
    reports that it should stop.

    Args:
      train_ops: A GANTrainOps named tuple.
      logdir: The directory where the graph and checkpoints are saved.
      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
        of hooks.
      master: The URL of the master.
      is_chief: Specifies whether or not the training is being run by the primary
        replica during replica training.
      scaffold: An tf.train.Scaffold instance.
      hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
        training loop. The hooks from `get_hooks_fn` are appended after them.
      chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
        inside the training loop for the chief trainer only.
      save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
        using a default checkpoint saver. If `save_checkpoint_secs` is set to
        `None`, then the default checkpoint saver isn't used.
      save_summaries_steps: The frequency, in number of global steps, that the
        summaries are written to disk using a default summary saver. If
        `save_summaries_steps` is set to `None`, then the default summary saver
        isn't used.
      save_checkpoint_steps: The frequency, in number of global steps, that a
        checkpoint is saved by the default checkpoint saver. It is forwarded
        unchanged to `MonitoredTrainingSession`, which defines how it
        interacts with `save_checkpoint_secs`.
      max_wait_secs: Maximum time workers should wait for the session to
        become available. This should be kept relatively short to help detect
        incorrect code, but sometimes may need to be increased if the chief takes
        a while to start up.
      config: An instance of `tf.ConfigProto`.

    Returns:
      The value of `train_ops.global_step_inc_op` from the last training step
      that ran, or `None` if the session stopped before any step was run.
    """
    _validate_gan_train_inputs(logdir, is_chief, save_summaries_steps,
                               save_checkpoint_secs)
    new_hooks = get_hooks_fn(train_ops)
    # User hooks run first; the GAN training hooks are appended after them.
    if hooks is not None:
        hooks = list(hooks) + list(new_hooks)
    else:
        hooks = new_hooks

    with tf.compat.v1.train.MonitoredTrainingSession(
            master=master,
            is_chief=is_chief,
            checkpoint_dir=logdir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_checkpoint_steps=save_checkpoint_steps,
            config=config,
            max_wait_secs=max_wait_secs) as session:
        # Stays None if the session stops before the first step.
        gstep = None
        while not session.should_stop():
            gstep = session.run(train_ops.global_step_inc_op)
    return gstep