Example #1
def get_tpu_estimator(generator, discriminator, hparams, config):
    return tfgan.estimator.TPUGANEstimator(
        generator_fn=generator,
        discriminator_fn=discriminator,
        generator_loss_fn=tfgan.losses.wasserstein_hinge_generator_loss,
        discriminator_loss_fn=(
            tfgan.losses.wasserstein_hinge_discriminator_loss),
        generator_optimizer=tf.compat.v1.train.AdamOptimizer(
            hparams.generator_lr, hparams.beta1),
        discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(
            hparams.discriminator_lr, hparams.beta1),
        prepare_arguments_for_eval_metric_fn=prepare_metric_arguments,
        get_eval_metric_ops_fn=functools.partial(get_metrics, hparams=hparams),
        eval_on_tpu=hparams.debug_params.eval_on_tpu,
        # Define update schedule
        gan_train_steps=tfgan_tuples.GANTrainSteps(flags.FLAGS.G_steps,
                                                   flags.FLAGS.D_steps),
        train_batch_size=hparams.train_batch_size,
        eval_batch_size=hparams.eval_batch_size,
        predict_batch_size=hparams.predict_batch_size,
        use_tpu=hparams.debug_params.use_tpu,
        config=config,
        params=hparams._asdict(),
        warm_start_from=os.path.join(hparams.model_dir, hparams.load_dir)
        if hparams.load_dir else None)
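
A minimal usage sketch for the helper above, following the standard tf.estimator train/evaluate flow; `make_generator`, `make_discriminator`, `train_input_fn`, `eval_input_fn`, `run_config`, and the `train_steps`/`eval_steps` fields on hparams are hypothetical placeholders, not defined in the snippet:

# Hypothetical driver code; every name below is a placeholder.
estimator = get_tpu_estimator(make_generator, make_discriminator, hparams, run_config)

# Standard Estimator loop: train for a fixed number of global steps, then evaluate.
estimator.train(train_input_fn, max_steps=hparams.train_steps)
eval_metrics = estimator.evaluate(eval_input_fn, steps=hparams.eval_steps)
print(eval_metrics)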
Example #2
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):

        if mode not in [
                model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                model_fn_lib.ModeKeys.PREDICT
        ]:
            raise ValueError('Mode not recognized: %s' % mode)

        if mode == model_fn_lib.ModeKeys.TRAIN:
            is_training = True
        else:
            is_training = False

        hparams = hparams_lib.copy_hparams(hparams)

        # Determine data parallelism and variable reuse before building the model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        reuse = tf.get_variable_scope().reuse

        # Instantiate model
        self = cls(hparams,
                   mode,
                   data_parallelism=data_parallelism,
                   decode_hparams=decode_hparams,
                   _reuse=reuse)

        generator_inputs = self.sample_noise()
        # rename inputs for clarity
        real_data = features['inputs']
        img_shape = common_layers.shape_list(real_data)[1:4]
        real_data.set_shape([hparams.batch_size] + img_shape)

        # To satisfy the TF-GAN API, set real data to None in predict mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            real_data = None

        optimizers = Optimizers(
            tf.compat.v1.train.AdamOptimizer(hparams.generator_lr,
                                             hparams.beta1),
            tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                             hparams.beta1))

        # Create TF-Hub module specs for both the generator and the discriminator
        def make_discriminator_spec():
            input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
            disc_output = self.discriminator(input_layer, None, mode)
            hub.add_signature(inputs=input_layer, outputs=disc_output)

        disc_spec = hub.create_module_spec(make_discriminator_spec)

        def make_generator_spec():
            input_layer = tf.placeholder(
                tf.float32,
                shape=[None] + common_layers.shape_list(generator_inputs)[1:])
            gen_output = self.generator(input_layer, mode)
            hub.add_signature(inputs=input_layer, outputs=gen_output)

        gen_spec = hub.create_module_spec(make_generator_spec)

        # Create the modules
        discriminator_module = hub.Module(disc_spec,
                                          name="Discriminator_Module",
                                          trainable=True)
        generator_module = hub.Module(gen_spec,
                                      name="Generator_Module",
                                      trainable=True)

        # Wraps the modules into functions expected by TF-GAN
        def generator(code, mode):
            p = hparams
            out = generator_module(code)
            shape = common_layers.shape_list(out)
            # Convolve the generated image with the provided PSF
            if p.apply_psf and 'psf' in features:
                out = convolve(out,
                               tf.cast(features['psf'][..., 0], tf.complex64))

            # Adds noise according to the provided power spectrum
            noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
            thresholded_ps = tf.where(features['ps'] >= 9,
                                      tf.zeros_like(features['ps']),
                                      tf.sqrt(tf.exp(features['ps'])))
            noise = noise * tf.cast(thresholded_ps, tf.complex64)
            out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
            return out

        discriminator = lambda image, conditioning, mode: discriminator_module(
            image)

        # Make GANModel, which encapsulates the GAN model architectures.
        gan_model = get_gan_model(mode,
                                  generator,
                                  discriminator,
                                  real_data,
                                  generator_inputs,
                                  add_summaries=self.summaries)

        # Make GANLoss, which encapsulates the losses.
        if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
            gan_loss = tfgan_train.gan_loss(gan_model,
                                            self.generator_loss,
                                            self.discriminator_loss,
                                            add_summaries=True)

        # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
        # metrics, and optimizers (if required).
        if mode == tf.estimator.ModeKeys.TRAIN:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks(
                namedtuples.GANTrainSteps(hparams.gen_steps,
                                          hparams.disc_steps))
            estimator_spec = get_train_estimator_spec(gan_model,
                                                      gan_loss,
                                                      optimizers,
                                                      get_hooks_fn,
                                                      is_chief=True)
        elif mode == tf.estimator.ModeKeys.EVAL:
            estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
        else:  # tf.estimator.ModeKeys.PREDICT
            # Register hub modules for export
            hub.register_module_for_export(generator_module, "generator")
            hub.register_module_for_export(discriminator_module,
                                           "discriminator")
            estimator_spec = get_predict_estimator_spec(gan_model)
        return estimator_spec
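
The distinctive piece above is wrapping the generator and discriminator as TF-Hub modules so they can be registered for export from the PREDICT branch. A minimal sketch of that TF1 hub pattern in isolation, assuming a toy `toy_net` function (hypothetical, standing in for the model's generator/discriminator):

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_v2_behavior()  # the TF1 hub.Module API requires graph mode

def toy_net(x):
    # Hypothetical stand-in for self.generator / self.discriminator above.
    return tf.layers.dense(x, 1)

def make_spec():
    # The module_fn builds the graph once and registers its default signature.
    inputs = tf.placeholder(tf.float32, shape=[None, 8])
    outputs = toy_net(inputs)
    hub.add_signature(inputs=inputs, outputs=outputs)

spec = hub.create_module_spec(make_spec)
module = hub.Module(spec, name="Toy_Module", trainable=True)
out = module(tf.zeros([4, 8]))  # instantiates the module's variables in the current graph

# Registering the module makes the export path write it out next to the SavedModel.
hub.register_module_for_export(module, "toy")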
Example #3
    def __init__(
            self,
            # Arguments to construct the `model_fn`.
            generator_fn=None,
            discriminator_fn=None,
            generator_loss_fn=None,
            discriminator_loss_fn=None,
            generator_optimizer=None,
            discriminator_optimizer=None,
            prepare_arguments_for_eval_metric_fn=None,
            get_eval_metric_ops_fn=None,
            add_summaries=None,
            joint_train=False,
            gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
            # TPUEstimator options.
            model_dir=None,
            config=None,
            params=None,
            use_tpu=True,
            train_batch_size=None,
            eval_batch_size=None,
            predict_batch_size=None,
            batch_axis=None,
            eval_on_tpu=True,
            export_to_tpu=True,
            warm_start_from=None):
        """Initializes a TPUGANEstimator instance.

    Args:
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `generator_inputs`.
        Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
        and examples.
      generator_loss_fn: The loss function on the generator. Takes a `GANModel`
        tuple.
      discriminator_loss_fn: The loss function on the discriminator. Takes a
        `GANModel` tuple.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will
        be called when the default graph is the `GANEstimator`'s graph, so
        utilities like `tf.train.get_or_create_global_step` will
        work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      prepare_arguments_for_eval_metric_fn: A function that takes a list of
        arguments and returns a nested structure of tensors keyed by name. The
        returned tensors must be compatible with TPUEstimatorSpec.eval_metrics
        (i.e., in batch-major format, where the batch size is the first
        dimension) and will be passed to the provided get_eval_metric_ops_fn.
        The arguments must be:
            * generator_inputs
            * generated_data
            * real_data
            * discriminator_real_outputs
            * discriminator_gen_outputs
        The default implementation simply returns the arguments as-is. This
        function is executed on the TPU, allowing for compute-heavy eval-only
        operations to be performed.
      get_eval_metric_ops_fn: A function that takes a list of arguments and
        returns a dict of metric results keyed by name, executed on CPU. The
        arguments of the function should be the keys of the dict returned
        by prepare_arguments_for_eval_metric_fn (see the
        prepare_arguments_for_eval_metric_fn for the defaults), and should
        return a dict from metric string name to the result of calling a metric
        function, namely a (metric_tensor, update_op) tuple.
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
        This is ignored for jobs that run on TPU, such as the train job if
        `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`.
      joint_train: A Python boolean. If `True`, jointly train the generator and
        the discriminator. If `False`, sequentially train them. See `train.py`
        in TFGAN for more details on the differences between the two GAN
        training methods.
      gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
        of generator to discriminator steps.
      model_dir: Same as `TPUEstimator`: Directory to save model parameters,
        graph and etc. This can also be used to load checkpoints from the
        directory into an estimator to continue training a previously saved
        model. If `None`, the model_dir in `config` will be used if set. If both
        are set, they must be the same. If both are `None`, a temporary directory
        will be used.
      config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
        object. Cannot be `None`.
      params: Same as `TPUEstimator`: An optional `dict` of hyper parameters
        that will be passed into `input_fn` and `model_fn`.  Keys are names of
        parameters, values are basic python types. There are reserved keys for
        `TPUEstimator`, including 'batch_size'. If any `params` are args to
        TF-GAN's `gan_loss`, they will be passed to `gan_loss` during training
        and evaluation.
      use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
        enabled. Currently, TPU training and evaluation respect this bit, but
        eval_on_tpu can override execution of eval. See below. Predict still
        happens on CPU.
      train_batch_size: Same as `TPUEstimator`: An int representing the global
        training batch size. TPUEstimator transforms this global batch size to a
        per-shard batch size, as params['batch_size'], when calling `input_fn`
        and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
        divisible by total number of replicas.
      eval_batch_size: Same as `TPUEstimator`: An int representing evaluation
        batch size. Must be divisible by total number of replicas.
      predict_batch_size: Same as `TPUEstimator`: An int representing the
        prediction batch size. Must be divisible by total number of replicas.
      batch_axis: Same as `TPUEstimator`: A python tuple of int values
        describing how each tensor produced by the Estimator `input_fn` should
        be split across the TPU compute shards. For example, if your input_fn
        produced (images, labels) where the images tensor is in `HWCN` format,
        your shard dimensions would be [3, 0], where 3 corresponds to the `N`
        dimension of your images Tensor, and 0 corresponds to the dimension
        along which to split the labels to match up with the corresponding
        images. If None is supplied, and per_host_input_for_training is True,
        batches will be sharded based on the major dimension. If
        tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
        batch_axis is ignored.
      eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
        GPU. In this case, the model_fn must return `EstimatorSpec` when called
        with `mode` as `EVAL`.
      export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
        exports a metagraph for serving on TPU besides the one on CPU.
      warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
        checkpoint or SavedModel to warm-start from, or a
        `tf.estimator.WarmStartSettings` object to fully configure
        warm-starting.  If the string filepath is provided instead of a
        `WarmStartSettings`, then all variables are warm-started, and it is
        assumed that vocabularies and Tensor names are unchanged.

    Raises:
      ValueError: If loss functions aren't callable.
      ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
        tuple.
      ValueError: If `gan_train_steps` isn't 1:1 training.
    """
        _validate_input_args(generator_loss_fn, discriminator_loss_fn,
                             gan_train_steps)
        loss_fns = LossFns(generator_loss_fn, discriminator_loss_fn)
        optimizers = Optimizers(generator_optimizer, discriminator_optimizer)

        # Determine how many GAN models need to be created in order to train
        # with different D:G step ratios on TPU.
        required_train_models = _required_train_models(gan_train_steps,
                                                       joint_train)
        effective_train_batch_size = required_train_models * train_batch_size

        def _model_fn(features, labels, mode, params):
            """GANEstimator model function."""
            if mode not in [
                    tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                    tf.estimator.ModeKeys.PREDICT
            ]:
                raise ValueError('Mode not recognized: %s' % mode)
            real_data = labels  # rename inputs for clarity
            generator_inputs = features  # rename inputs for clarity

            # Collect GANModel builder functions, which encapsulate the GAN model
            # architectures. Don't actually execute them here, since the functions
            # actually create the TF ops and the variable reads need to be chained
            # after the writes from the previous step. Instead just pass the functions
            # with bound arguments down so that they can easily be executed later.
            gan_model_fns = _get_gan_model_fns(
                mode,
                generator_fn,
                discriminator_fn,
                real_data,
                generator_inputs,
                num_train_models=required_train_models)

            # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
            # remove `add_summaries` logic below.
            is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
            summary_types = None if is_on_tpu else add_summaries

            # Make the TPUEstimatorSpec, which incorporates the model, losses, eval
            # metrics, and optimizers (if required).
            gan_loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(
                params)
            if mode == tf.estimator.ModeKeys.TRAIN:
                estimator_spec = get_train_estimator_spec(
                    gan_model_fns,
                    loss_fns,
                    gan_loss_kwargs,
                    optimizers,
                    joint_train,
                    is_on_tpu,
                    gan_train_steps,
                    add_summaries=summary_types)
            elif mode == tf.estimator.ModeKeys.EVAL:
                estimator_spec = get_eval_estimator_spec(
                    gan_model_fns,
                    loss_fns,
                    gan_loss_kwargs,
                    prepare_arguments_for_eval_metric_fn,
                    get_eval_metric_ops_fn,
                    add_summaries=summary_types)
            else:  # predict
                estimator_spec = get_predict_estimator_spec(gan_model_fns)
            assert isinstance(estimator_spec,
                              tf.compat.v1.estimator.tpu.TPUEstimatorSpec)

            return estimator_spec

        super(TPUGANEstimator,
              self).__init__(model_fn=_model_fn,
                             model_dir=model_dir,
                             config=config,
                             params=params,
                             use_tpu=use_tpu,
                             train_batch_size=effective_train_batch_size,
                             eval_batch_size=eval_batch_size,
                             predict_batch_size=predict_batch_size,
                             batch_axis=batch_axis,
                             eval_on_tpu=eval_on_tpu,
                             export_to_tpu=export_to_tpu,
                             warm_start_from=warm_start_from)
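
The docstring above pins down the contract between `prepare_arguments_for_eval_metric_fn` (runs on the TPU and returns batch-major tensors keyed by name) and `get_eval_metric_ops_fn` (runs on the CPU and returns `(metric_tensor, update_op)` pairs keyed by metric name). A minimal sketch of a compatible pair, in the spirit of the `prepare_metric_arguments` / `get_metrics` helpers referenced but not shown in Example #1; the mean-output metrics here are purely illustrative:

import tensorflow.compat.v1 as tf

def prepare_metric_arguments(generator_inputs, generated_data, real_data,
                             discriminator_real_outputs,
                             discriminator_gen_outputs):
    # Runs on the TPU: return batch-major tensors keyed by name. This sketch
    # simply forwards the discriminator outputs, close to the documented default.
    del generator_inputs, generated_data, real_data  # unused in this sketch
    return {
        'discriminator_real_outputs': discriminator_real_outputs,
        'discriminator_gen_outputs': discriminator_gen_outputs,
    }

def get_metrics(discriminator_real_outputs, discriminator_gen_outputs,
                hparams=None):
    # Runs on the CPU: parameter names must match the keys returned above, and
    # each value is a (metric_tensor, update_op) pair as produced by tf.metrics.*.
    del hparams  # unused in this sketch
    return {
        'mean_discriminator_real_output': tf.metrics.mean(discriminator_real_outputs),
        'mean_discriminator_gen_output': tf.metrics.mean(discriminator_gen_outputs),
    }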