def test_get_gan_model(self, mode):
    """Checks the GANModel produced by `estimator._get_gan_model` for `mode`.

    Generator-side fields must always be populated. In PREDICT mode no real
    data is passed in, and every discriminator-side field (plus `real_data`)
    must be None. In the other modes those fields must be populated, and the
    discriminator must own exactly two variables.
    """
    is_predict = mode == model_fn_lib.ModeKeys.PREDICT
    with ops.Graph().as_default():
        gen_inputs = {'x': array_ops.ones([3, 4])}
        real = None if is_predict else array_ops.zeros([3, 4])
        model = estimator._get_gan_model(
            mode, generator_fn, discriminator_fn, real, gen_inputs,
            add_summaries=False)

        # Generator side: expected regardless of mode.
        self.assertEqual(gen_inputs, model.generator_inputs)
        self.assertIsNotNone(model.generated_data)
        self.assertEqual(2, len(model.generator_variables))  # 1 FC layer
        self.assertIsNotNone(model.generator_fn)

        if is_predict:
            # No real data was supplied, so nothing discriminator-related
            # should have been built.
            for field_name in ('real_data',
                               'discriminator_real_outputs',
                               'discriminator_gen_outputs',
                               'discriminator_variables',
                               'discriminator_scope',
                               'discriminator_fn'):
                self.assertIsNone(getattr(model, field_name))
        else:
            self.assertIsNotNone(model.real_data)
            self.assertIsNotNone(model.discriminator_real_outputs)
            self.assertIsNotNone(model.discriminator_gen_outputs)
            self.assertEqual(2, len(model.discriminator_variables))  # 1 FC layer
            self.assertIsNotNone(model.discriminator_scope)
            self.assertIsNotNone(model.discriminator_fn)
def _model_fn(features, labels, mode, params):
    """GANEstimator model function.

    Builds the GANModel from the incoming generator inputs and real data, then
    wraps it (with the configured losses, eval metrics and optimizers) in a
    TPUEstimatorSpec.

    Args:
      features: Inputs fed to the generator.
      labels: Real data fed to the discriminator (presumably None in PREDICT
        mode -- confirm against the input_fn).
      mode: One of `model_fn_lib.ModeKeys` TRAIN, EVAL or PREDICT.
      params: Ignored.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec`.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    del params  # unused

    supported_modes = (model_fn_lib.ModeKeys.TRAIN,
                       model_fn_lib.ModeKeys.EVAL,
                       model_fn_lib.ModeKeys.PREDICT)
    if mode not in supported_modes:
        raise ValueError('Mode not recognized: %s' % mode)

    # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
    # remove `add_summaries` logic below.
    on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
    summaries_arg = None if on_tpu else add_summaries

    # The GANModel encapsulates the generator/discriminator architectures.
    # `labels` are the real data and `features` the generator inputs.
    gan_model = gan_estimator_lib._get_gan_model(  # pylint:disable=protected-access
        mode, generator_fn, discriminator_fn, labels, features,
        add_summaries=summaries_arg)

    # Combine the GANModel with losses, eval metrics and (if required)
    # optimizers into the TPUEstimatorSpec.
    spec = _get_estimator_spec(
        mode, gan_model, generator_loss_fn, discriminator_loss_fn,
        get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
        joint_train, on_tpu, gan_train_steps)
    assert isinstance(spec, tpu_estimator.TPUEstimatorSpec)
    return spec
def test_get_gan_model(self, mode):
    """Checks the GANModel produced by `estimator._get_gan_model` for `mode`.

    Generator-side fields must always be populated. In PREDICT mode no real
    data is passed in, and every discriminator-side field (plus `real_data`)
    must be None. In the other modes those fields must be populated, and the
    discriminator must own exactly two variables.
    """
    is_predict = mode == model_fn_lib.ModeKeys.PREDICT
    with ops.Graph().as_default():
        gen_inputs = {'x': array_ops.ones([3, 4])}
        real = None if is_predict else array_ops.zeros([3, 4])
        model = estimator._get_gan_model(
            mode, generator_fn, discriminator_fn, real, gen_inputs,
            add_summaries=False)

        # Generator side: expected regardless of mode.
        self.assertEqual(gen_inputs, model.generator_inputs)
        self.assertIsNotNone(model.generated_data)
        self.assertEqual(2, len(model.generator_variables))  # 1 FC layer
        self.assertIsNotNone(model.generator_fn)

        if is_predict:
            # No real data was supplied, so nothing discriminator-related
            # should have been built.
            for field_name in ('real_data',
                               'discriminator_real_outputs',
                               'discriminator_gen_outputs',
                               'discriminator_variables',
                               'discriminator_scope',
                               'discriminator_fn'):
                self.assertIsNone(getattr(model, field_name))
        else:
            self.assertIsNotNone(model.real_data)
            self.assertIsNotNone(model.discriminator_real_outputs)
            self.assertIsNotNone(model.discriminator_gen_outputs)
            self.assertEqual(2, len(model.discriminator_variables))  # 1 FC layer
            self.assertIsNotNone(model.discriminator_scope)
            self.assertIsNotNone(model.discriminator_fn)
def _model_fn(features, labels, mode, params):
    """GANEstimator model function.

    Builds the GANModel from the incoming generator inputs and real data, then
    wraps it (with the configured losses, eval metrics and optimizers) in a
    TPUEstimatorSpec.

    Args:
      features: Inputs fed to the generator.
      labels: Real data fed to the discriminator (presumably None in PREDICT
        mode -- confirm against the input_fn).
      mode: One of `model_fn_lib.ModeKeys` TRAIN, EVAL or PREDICT.
      params: Ignored.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec`.

    Raises:
      ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
    """
    del params  # unused

    supported_modes = (model_fn_lib.ModeKeys.TRAIN,
                       model_fn_lib.ModeKeys.EVAL,
                       model_fn_lib.ModeKeys.PREDICT)
    if mode not in supported_modes:
        raise ValueError('Mode not recognized: %s' % mode)

    # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
    # remove `add_summaries` logic below.
    on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
    summaries_arg = None if on_tpu else add_summaries

    # The GANModel encapsulates the generator/discriminator architectures.
    # `labels` are the real data and `features` the generator inputs.
    gan_model = gan_estimator_lib._get_gan_model(  # pylint:disable=protected-access
        mode, generator_fn, discriminator_fn, labels, features,
        add_summaries=summaries_arg)

    # Combine the GANModel with losses, eval metrics and (if required)
    # optimizers into the TPUEstimatorSpec.
    spec = _get_estimator_spec(
        mode, gan_model, generator_loss_fn, discriminator_loss_fn,
        get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
        joint_train, on_tpu, gan_train_steps)
    assert isinstance(spec, tpu_estimator.TPUEstimatorSpec)
    return spec