Esempio n. 1
0
        def _model_fn(features, labels, mode, params):
            """Assembles the TPUEstimatorSpec for a single GAN estimator call.

            Args:
              features: Inputs handed to the generator.
              labels: The real data.
              mode: One of `tf.estimator.ModeKeys.TRAIN`, `EVAL` or `PREDICT`.
              params: Estimator params; loss keyword arguments are pulled out of
                it with `gan_estimator.extract_gan_loss_args_from_params`.

            Returns:
              A `tf.compat.v1.estimator.tpu.TPUEstimatorSpec` built for `mode`.

            Raises:
              ValueError: If `mode` is not TRAIN, EVAL or PREDICT. Errors raised
                while extracting loss arguments from `params` also propagate
                (presumably a ValueError for forbidden keys such as 'model' --
                confirm against `extract_gan_loss_args_from_params`).
            """
            known_modes = (tf.estimator.ModeKeys.TRAIN,
                           tf.estimator.ModeKeys.EVAL,
                           tf.estimator.ModeKeys.PREDICT)
            if mode not in known_modes:
                raise ValueError('Mode not recognized: %s' % mode)

            # Only gather the GANModel builder functions here, with their
            # arguments already bound. They are not run now: running them
            # creates TF ops, and the variable reads must be chained after the
            # previous step's writes. Execution is left to the spec helpers below.
            model_builders = _get_gan_model_fns(
                mode,
                generator_fn,
                discriminator_fn,
                labels,    # the real data
                features,  # the generator inputs
                num_train_models=required_train_models)

            # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
            # remove `add_summaries` logic below.
            on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
            summaries = add_summaries if not on_tpu else None

            # Extracted for every mode (PREDICT included), so invalid params are
            # rejected regardless of which spec is built.
            loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(params)

            if mode == tf.estimator.ModeKeys.TRAIN:
                spec = get_train_estimator_spec(
                    model_builders, loss_fns, loss_kwargs, optimizers,
                    joint_train, on_tpu, gan_train_steps,
                    add_summaries=summaries, run_config=config)
            elif mode == tf.estimator.ModeKeys.EVAL:
                spec = get_eval_estimator_spec(
                    model_builders, loss_fns, loss_kwargs,
                    prepare_arguments_for_eval_metric_fn,
                    get_eval_metric_ops_fn,
                    add_summaries=summaries)
            else:
                # The validation above leaves PREDICT as the only remaining mode.
                spec = get_predict_estimator_spec(model_builders)

            # Internal invariant: every spec helper must return a TPU spec.
            assert isinstance(spec, tf.compat.v1.estimator.tpu.TPUEstimatorSpec)
            return spec
Esempio n. 2
0
 def test_extract_gan_loss_args_from_params(self):
     """Recognized loss keys are kept and unrecognized ones ('other') are dropped."""
     expected = {'tensor_pool_fn': 1, 'gradient_penalty_target': 2}
     # Same params as the expected dict, plus one key that must be filtered out.
     params = dict(expected, other=3)
     extracted = extract_gan_loss_args_from_params(params)
     self.assertEqual(extracted, expected)
Esempio n. 3
0
 def test_extract_gan_loss_args_from_params_forbidden(self):
     """A params dict containing the forbidden key 'model' raises ValueError."""
     bad_params = dict(tensor_pool_fn=1, model=2)
     with self.assertRaises(ValueError):
         extract_gan_loss_args_from_params(bad_params)