def _model_fn(features, labels, mode, params):
  """GANEstimator model function."""
  if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL,
                  tf.estimator.ModeKeys.PREDICT):
    raise ValueError('Mode not recognized: %s' % mode)

  # In the GAN setting, `labels` carry the real data and `features` carry
  # the generator inputs; rebind them for clarity.
  real_data, generator_inputs = labels, features

  # Collect GANModel builder functions, which encapsulate the GAN model
  # architectures. They are not executed here: the TF ops must be created
  # later so that variable reads are chained after the writes from the
  # previous step. Passing bound functions down lets callers run them at
  # the right time.
  gan_model_fns = _get_gan_model_fns(
      mode,
      generator_fn,
      discriminator_fn,
      real_data,
      generator_inputs,
      num_train_models=required_train_models)

  # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
  # remove `add_summaries` logic below.
  on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
  summary_types = add_summaries if not on_tpu else None

  # Losses may be configured through `params`; extract only the
  # loss-related keyword arguments.
  gan_loss_kwargs = gan_estimator.extract_gan_loss_args_from_params(params)

  # Build the TPUEstimatorSpec, which incorporates the model, losses, eval
  # metrics, and optimizers (if required). Mode was validated above, so
  # exactly one of these three branches applies.
  if mode == tf.estimator.ModeKeys.PREDICT:
    estimator_spec = get_predict_estimator_spec(gan_model_fns)
  elif mode == tf.estimator.ModeKeys.TRAIN:
    estimator_spec = get_train_estimator_spec(
        gan_model_fns,
        loss_fns,
        gan_loss_kwargs,
        optimizers,
        joint_train,
        on_tpu,
        gan_train_steps,
        add_summaries=summary_types,
        run_config=config)
  else:  # EVAL
    estimator_spec = get_eval_estimator_spec(
        gan_model_fns,
        loss_fns,
        gan_loss_kwargs,
        prepare_arguments_for_eval_metric_fn,
        get_eval_metric_ops_fn,
        add_summaries=summary_types)

  # Internal sanity check: every branch must produce a TPU-compatible spec.
  assert isinstance(estimator_spec,
                    tf.compat.v1.estimator.tpu.TPUEstimatorSpec)
  return estimator_spec
def test_extract_gan_loss_args_from_params(self):
  """Only recognized loss-argument keys should be extracted from params."""
  params = {'tensor_pool_fn': 1, 'gradient_penalty_target': 2, 'other': 3}
  extracted = extract_gan_loss_args_from_params(params)
  # Everything except the unrecognized 'other' key should survive, with
  # values passed through unchanged.
  expected = {
      key: params[key] for key in ('tensor_pool_fn', 'gradient_penalty_target')
  }
  self.assertEqual(extracted, expected)
def test_extract_gan_loss_args_from_params_forbidden(self):
  """A forbidden key ('model') in params must raise ValueError."""
  forbidden_params = {'tensor_pool_fn': 1, 'model': 2}
  with self.assertRaises(ValueError):
    extract_gan_loss_args_from_params(forbidden_params)