def _test_logits_helper(self, mode):
  """Tests that the expected logits are passed to mock head.

  Builds `estimator._gan_model_fn` in a fresh graph with a `mock_head` that is
  told which generator inputs, real data and generator scope name to expect.
  The op selected by `mode` is then run once inside a
  `MonitoredTrainingSession` rooted at `self._model_dir`.

  NOTE(review): an identical method named `_test_logits_helper` is also
  defined later in this file (indented as a class member). This copy
  previously used non-standard 5-space indentation. It has been re-indented to
  the file's 2-space convention. Confirm which definition is intended and
  remove the other.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. TRAIN runs the spec's `train_op`,
      EVAL runs its `loss`, and PREDICT runs its `predictions`; in PREDICT
      mode no real data (labels) is supplied. Any other value fails the test.
  """
  with ops.Graph().as_default():
    # Create the global step up front; presumably needed by the train op
    # and/or the MonitoredTrainingSession hooks -- confirm.
    training_util.get_or_create_global_step()
    generator_inputs = {'x': array_ops.zeros([5, 4])}
    # PREDICT mode has no labels, so the head should expect no real data.
    real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
                 array_ops.zeros([5, 4]))
    generator_scope_name = 'generator'
    head = mock_head(self,
                     expected_generator_inputs=generator_inputs,
                     expected_real_data=real_data,
                     generator_scope_name=generator_scope_name)
    estimator_spec = estimator._gan_model_fn(
        features=generator_inputs,
        labels=real_data,
        mode=mode,
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        generator_scope_name=generator_scope_name,
        head=head)
    with monitored_session.MonitoredTrainingSession(
        checkpoint_dir=self._model_dir) as sess:
      # Run only the part of the spec that is meaningful for this mode.
      if mode == model_fn_lib.ModeKeys.TRAIN:
        sess.run(estimator_spec.train_op)
      elif mode == model_fn_lib.ModeKeys.EVAL:
        sess.run(estimator_spec.loss)
      elif mode == model_fn_lib.ModeKeys.PREDICT:
        sess.run(estimator_spec.predictions)
      else:
        self.fail('Invalid mode: {}'.format(mode))
 def _test_logits_helper(self, mode):
   """Tests that the expected logits are passed to mock head."""
   with ops.Graph().as_default():
     training_util.get_or_create_global_step()
     generator_inputs = {'x': array_ops.zeros([5, 4])}
     real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
                  array_ops.zeros([5, 4]))
     generator_scope_name = 'generator'
     head = mock_head(self,
                      expected_generator_inputs=generator_inputs,
                      expected_real_data=real_data,
                      generator_scope_name=generator_scope_name)
     estimator_spec = estimator._gan_model_fn(
         features=generator_inputs,
         labels=real_data,
         mode=mode,
         generator_fn=generator_fn,
         discriminator_fn=discriminator_fn,
         generator_scope_name=generator_scope_name,
         head=head)
     with monitored_session.MonitoredTrainingSession(
         checkpoint_dir=self._model_dir) as sess:
       if mode == model_fn_lib.ModeKeys.TRAIN:
         sess.run(estimator_spec.train_op)
       elif mode == model_fn_lib.ModeKeys.EVAL:
         sess.run(estimator_spec.loss)
       elif mode == model_fn_lib.ModeKeys.PREDICT:
         sess.run(estimator_spec.predictions)
       else:
         self.fail('Invalid mode: {}'.format(mode))