Example #1
def _get_loss_for_train(gan_model, loss_fns, gan_loss_kwargs, add_summaries):
    kwargs = gan_loss_kwargs or {}
    return tfgan_train.gan_loss(gan_model,
                                loss_fns.g_loss_fn,
                                loss_fns.d_loss_fn,
                                add_summaries=add_summaries,
                                **kwargs)
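A minimal usage sketch for the helper above. It only assumes that loss_fns exposes g_loss_fn and d_loss_fn attributes, so a plain namedtuple suffices; the loss functions, the extra kwargs, and the pre-built gan_model are illustrative rather than taken from the original code.

import collections

import tensorflow_gan as tfgan  # assumed import; the snippet refers to its train library as tfgan_train

# Hypothetical container matching the attribute access in the helper above.
LossFns = collections.namedtuple('LossFns', ['g_loss_fn', 'd_loss_fn'])

loss_fns = LossFns(
    g_loss_fn=tfgan.losses.wasserstein_generator_loss,
    d_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# gan_model is a tfgan.GANModel built elsewhere (e.g. with tfgan.gan_model).
gan_loss = _get_loss_for_train(
    gan_model,
    loss_fns,
    gan_loss_kwargs={'gradient_penalty_weight': 10.0},  # optional extras forwarded via **kwargs
    add_summaries=True)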
Example #2
        def _model_fn(features, labels, mode, params):
            """GANEstimator model function."""
            if mode not in [
                    tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                    tf.estimator.ModeKeys.PREDICT
            ]:
                raise ValueError('Mode not recognized: %s' % mode)
            real_data = labels  # rename inputs for clarity
            generator_inputs = features  # rename inputs for clarity

            # Make GANModel, which encapsulates the GAN model architectures.
            gan_model = get_gan_model(mode, generator_fn, discriminator_fn,
                                      real_data, generator_inputs,
                                      add_summaries)

            # Make GANLoss, which encapsulates the losses.
            if mode in [
                    tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
            ]:
                gan_loss_kwargs = extract_gan_loss_args_from_params(
                    params) or {}
                gan_loss = tfgan_train.gan_loss(
                    gan_model,
                    generator_loss_fn,
                    discriminator_loss_fn,
                    add_summaries=use_loss_summaries,
                    **gan_loss_kwargs)

            # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
            # metrics, and optimizers (if required).
            if mode == tf.estimator.ModeKeys.TRAIN:
                estimator_spec = get_train_estimator_spec(gan_model,
                                                          gan_loss,
                                                          optimizers,
                                                          get_hooks_fn,
                                                          is_chief=is_chief)
            elif mode == tf.estimator.ModeKeys.EVAL:
                estimator_spec = get_eval_estimator_spec(
                    gan_model, gan_loss, get_eval_metric_ops_fn)
            else:  # tf.estimator.ModeKeys.PREDICT
                estimator_spec = get_predict_estimator_spec(gan_model)

            return estimator_spec
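For context, a sketch of how a model function like the one above is usually consumed through TF-GAN's public GANEstimator wrapper; the networks, loss functions, optimizers, model directory, and input function below are placeholders, not values from the snippet.

import tensorflow as tf
import tensorflow_gan as tfgan

# Sketch only: generator_fn and discriminator_fn are user-defined networks.
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/tmp/gan_model',  # hypothetical directory
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    add_summaries=tfgan.estimator.SummaryType.IMAGES)

# Standard Estimator loop; train_input_fn is assumed to yield (noise, images).
gan_estimator.train(train_input_fn, max_steps=10000)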
Example #3
def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            get_eval_metric_ops_fn, add_summaries):
  """Estimator spec for eval case."""
  assert len(gan_model_fns) == 1, (
      '`gan_model_fns` must be length 1 in eval mode. Got length %d' %
      len(gan_model_fns))

  gan_model = gan_model_fns[0]()

  _maybe_add_summaries(gan_model, add_summaries)

  # Eval losses for metrics must preserve batch dimension.
  kwargs = gan_loss_kwargs or {}
  gan_loss_no_reduction = tfgan_train.gan_loss(
      gan_model,
      loss_fns.g_loss_fn,
      loss_fns.d_loss_fn,
      add_summaries=add_summaries,
      reduction=tf.compat.v1.losses.Reduction.NONE,
      **kwargs)

  # Make the metric function and tensor names.
  if get_eval_metric_ops_fn is not None:
    metric_fn = _make_custom_metric_fn(get_eval_metric_ops_fn)
    tensors_for_metric_fn = _make_custom_metric_tensors(
        gan_model, gan_loss_no_reduction)
  else:
    metric_fn = _make_default_metric_fn()
    tensors_for_metric_fn = _make_default_metric_tensors(gan_loss_no_reduction)

  scalar_loss = tf.compat.v1.losses.compute_weighted_loss(
      gan_loss_no_reduction.discriminator_loss,
      loss_collection=None,
      reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

  return contrib.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      predictions=_predictions_from_generator_output(gan_model.generated_data),
      loss=scalar_loss,
      eval_metrics=(metric_fn, tensors_for_metric_fn))
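A possible shape for the custom metric hook consumed above. get_eval_metric_ops_fn is wrapped by _make_custom_metric_fn and called with the tensors prepared by _make_custom_metric_tensors, and it must return a dict mapping metric names to (value_op, update_op) pairs. The keyword names below are an assumption about that helper and should be adjusted to match it.

import tensorflow as tf

def get_eval_metric_ops_fn(**gan_tensors):
    # Assumption: the dict built by _make_custom_metric_tensors forwards the
    # GANModel/GANLoss fields under names like these.
    return {
        'mean_real_logit':
            tf.compat.v1.metrics.mean(gan_tensors['discriminator_real_outputs']),
        'mean_generated_logit':
            tf.compat.v1.metrics.mean(gan_tensors['discriminator_gen_outputs']),
    }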
Example #4
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):

        if mode not in [
                model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                model_fn_lib.ModeKeys.PREDICT
        ]:
            raise ValueError('Mode not recognized: %s' % mode)

        if mode == model_fn_lib.ModeKeys.TRAIN:
            is_training = True
        else:
            is_training = False

        hparams = hparams_lib.copy_hparams(hparams)

        # Determine data parallelism and variable reuse for the model.
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        reuse = tf.get_variable_scope().reuse

        # Instantiate model
        self = cls(hparams,
                   mode,
                   data_parallelism=data_parallelism,
                   decode_hparams=decode_hparams,
                   _reuse=reuse)

        generator_inputs = self.sample_noise()
        # rename inputs for clarity
        real_data = features['inputs']
        img_shape = common_layers.shape_list(real_data)[1:4]
        real_data.set_shape([hparams.batch_size] + img_shape)

        # To satisfy the TF-GAN API, set real data to None in predict mode.
        if mode == tf.estimator.ModeKeys.PREDICT:
            real_data = None

        optimizers = Optimizers(
            tf.compat.v1.train.AdamOptimizer(hparams.generator_lr,
                                             hparams.beta1),
            tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                             hparams.beta1))

        # Create TF-Hub module specs for both generator and discriminator.
        def make_discriminator_spec():
            input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
            disc_output = self.discriminator(input_layer, None, mode)
            hub.add_signature(inputs=input_layer, outputs=disc_output)

        disc_spec = hub.create_module_spec(make_discriminator_spec)

        def make_generator_spec():
            input_layer = tf.placeholder(
                tf.float32,
                shape=[None] + common_layers.shape_list(generator_inputs)[1:])
            gen_output = self.generator(input_layer, mode)
            hub.add_signature(inputs=input_layer, outputs=gen_output)

        gen_spec = hub.create_module_spec(make_generator_spec)

        # Create the modules
        discriminator_module = hub.Module(disc_spec,
                                          name="Discriminator_Module",
                                          trainable=True)
        generator_module = hub.Module(gen_spec,
                                      name="Generator_Module",
                                      trainable=True)

        # Wraps the modules into functions expected by TF-GAN
        def generator(code, mode):
            p = hparams
            out = generator_module(code)
            shape = common_layers.shape_list(out)
            # Apply the PSF via convolution.
            if p.apply_psf and 'psf' in features:
                out = convolve(out,
                               tf.cast(features['psf'][..., 0], tf.complex64))

            # Adds noise according to the provided power spectrum
            noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
            thresholded_ps = tf.where(features['ps'] >= 9,
                                      tf.zeros_like(features['ps']),
                                      tf.sqrt(tf.exp(features['ps'])))
            noise = noise * tf.cast(thresholded_ps, tf.complex64)
            out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
            return out

        discriminator = lambda image, conditioning, mode: discriminator_module(
            image)

        # Make GANModel, which encapsulates the GAN model architectures.
        gan_model = get_gan_model(mode,
                                  generator,
                                  discriminator,
                                  real_data,
                                  generator_inputs,
                                  add_summaries=self.summaries)

        # Make GANLoss, which encapsulates the losses.
        if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
            gan_loss = tfgan_train.gan_loss(gan_model,
                                            self.generator_loss,
                                            self.discriminator_loss,
                                            add_summaries=True)

        # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
        # metrics, and optimizers (if required).
        if mode == tf.estimator.ModeKeys.TRAIN:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks(
                namedtuples.GANTrainSteps(hparams.gen_steps,
                                          hparams.disc_steps))
            estimator_spec = get_train_estimator_spec(gan_model,
                                                      gan_loss,
                                                      optimizers,
                                                      get_hooks_fn,
                                                      is_chief=True)
        elif mode == tf.estimator.ModeKeys.EVAL:
            estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
        else:  # tf.estimator.ModeKeys.PREDICT
            # Register hub modules for export
            hub.register_module_for_export(generator_module, "generator")
            hub.register_module_for_export(discriminator_module,
                                           "discriminator")
            estimator_spec = get_predict_estimator_spec(gan_model)
        return estimator_spec
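Once exported, the registered TF-Hub modules can be reloaded independently of the estimator. A sketch assuming an illustrative export path and latent size, using the same TF1-style hub.Module API as the snippet:

import tensorflow as tf
import tensorflow_hub as hub

generator = hub.Module('/tmp/export/generator')  # export path is illustrative
noise = tf.random.normal([16, 128])              # latent size is an assumption
samples = generator(noise)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.tables_initializer())
    images = sess.run(samples)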
Example #5
def get_eval_estimator_spec(gan_model_fns, loss_fns, gan_loss_kwargs,
                            prepare_arguments_for_eval_metric_fn,
                            get_eval_metric_ops_fn, add_summaries):
    """Estimator spec for eval case."""
    assert len(gan_model_fns) == 1, (
        '`gan_model_fns` must be length 1 in eval mode. Got length %d' %
        len(gan_model_fns))

    gan_model = gan_model_fns[0]()

    _maybe_add_summaries(gan_model, add_summaries)

    # Eval losses for metrics must preserve batch dimension.
    kwargs = gan_loss_kwargs or {}
    gan_loss_no_reduction = tfgan_train.gan_loss(
        gan_model,
        loss_fns.g_loss_fn,
        loss_fns.d_loss_fn,
        add_summaries=add_summaries,
        reduction=tf.compat.v1.losses.Reduction.NONE,
        **kwargs)

    if prepare_arguments_for_eval_metric_fn is None:
        # Set the default prepare_arguments_for_eval_metric_fn value: a function
        # that returns its arguments in a dict.
        prepare_arguments_for_eval_metric_fn = lambda **kwargs: kwargs

    default_metric_fn = _make_default_metric_fn()
    # Prepare tensors needed for calculating the metrics: the first element in
    # `tensors_for_metric_fn` holds a dict containing the arguments for
    # `default_metric_fn`, and the second element holds a dict for arguments for
    # `get_eval_metric_ops_fn` (if it is not None).
    tensors_for_metric_fn = [
        _make_default_metric_tensors(gan_loss_no_reduction)
    ]
    if get_eval_metric_ops_fn is not None:
        tensors_for_metric_fn.append(
            prepare_arguments_for_eval_metric_fn(
                **_make_custom_metric_tensors(gan_model)))

    scalar_loss = tf.compat.v1.losses.compute_weighted_loss(
        gan_loss_no_reduction.discriminator_loss,
        loss_collection=None,
        reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    # TPUEstimatorSpec.eval_metrics expects a function and a list of tensors.
    # However, some structures in tensors_for_metric_fn might be dictionaries
    # (e.g., generator_inputs and real_data). We therefore need to flatten
    # tensors_for_metric_fn before passing it to the function and then restore
    # the original structure inside the function.
    def _metric_fn_wrapper(*args):
        """Unflattens the arguments and pass them to the metric functions."""
        unpacked_arguments = tf.nest.pack_sequence_as(tensors_for_metric_fn,
                                                      args)
        # Calculate default metrics.
        metrics = default_metric_fn(**unpacked_arguments[0])
        if get_eval_metric_ops_fn is not None:
            # Append custom metrics.
            custom_eval_metric_ops = get_eval_metric_ops_fn(
                **unpacked_arguments[1])
            if not isinstance(custom_eval_metric_ops, dict):
                raise TypeError('`get_eval_metric_ops_fn` must return a dict, '
                                'received: {}'.format(custom_eval_metric_ops))
            metrics.update(custom_eval_metric_ops)

        return metrics

    flat_tensors = tf.nest.flatten(tensors_for_metric_fn)
    if not all(isinstance(t, tf.Tensor) for t in flat_tensors):
        raise ValueError('All objects nested within the TF-GAN model must be '
                         'tensors. Instead, types are: %s.' %
                         str([type(v) for v in flat_tensors]))
    return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        predictions=_predictions_from_generator_output(
            gan_model.generated_data),
        loss=scalar_loss,
        eval_metrics=(_metric_fn_wrapper, flat_tensors))
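A sketch of the prepare_arguments_for_eval_metric_fn hook accepted above: it receives the custom metric tensors as keyword arguments and returns the dict that is later flattened for TPUEstimatorSpec, so every value it returns must be a Tensor (or a nest of Tensors). The keyword names are an assumption about _make_custom_metric_tensors.

def prepare_arguments_for_eval_metric_fn(**gan_tensors):
    # Keep only what the custom metric needs; anything returned here must
    # survive the tf.nest flatten/pack_sequence_as round trip.
    return {
        'real_data': gan_tensors['real_data'],             # assumed key
        'generated_data': gan_tensors['generated_data'],   # assumed key
    }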