示例#1
0
    def mnist_model_fn_helper(self, mode, multi_gpu=False):
        """Run ``mnist.model_fn`` on dummy inputs and check the returned spec.

        Depending on ``mode`` the helper asserts the shape and dtype of the
        predictions (PREDICT), of the loss (every mode except PREDICT) and of
        the two entries of the 'accuracy' metric (EVAL).

        Args:
            mode: Value compared against ``tf.estimator.ModeKeys`` members and
                passed through to ``mnist.model_fn``.
            multi_gpu: Forwarded to ``mnist.model_fn`` under the 'multi_gpu'
                params key.
        """
        features, labels = dummy_input_fn()
        image_count = features.shape[0]
        params = {
            'data_format': 'channels_last',
            'multi_gpu': multi_gpu,
        }
        spec = mnist.model_fn(features, labels, mode, params)

        scalar_shape = ()

        if mode == tf.estimator.ModeKeys.PREDICT:
            outputs = spec.predictions

            # One probability row of width 10 per input image.
            probabilities = outputs['probabilities']
            self.assertAllEqual(probabilities.shape, (image_count, 10))
            self.assertEqual(probabilities.dtype, tf.float32)

            # One class entry per input image.
            classes = outputs['classes']
            self.assertAllEqual(classes.shape, (image_count, ))
            self.assertEqual(classes.dtype, tf.int64)
        else:
            # Every non-PREDICT mode is expected to expose a scalar loss.
            loss = spec.loss
            self.assertAllEqual(loss.shape, scalar_shape)
            self.assertEqual(loss.dtype, tf.float32)

        if mode == tf.estimator.ModeKeys.EVAL:
            accuracy = spec.eval_metric_ops['accuracy']
            # Shapes are checked for both entries before any dtype is checked.
            for position in (0, 1):
                self.assertAllEqual(accuracy[position].shape, scalar_shape)
            for position in (0, 1):
                self.assertEqual(accuracy[position].dtype, tf.float32)
示例#2
0
  def mnist_model_fn_helper(self, mode, multi_gpu=False):
    """Build a spec with `mnist.model_fn` from dummy inputs and validate it.

    For PREDICT the 'probabilities' and 'classes' predictions are checked;
    for any other mode the loss is checked; for EVAL the two entries of the
    'accuracy' metric are checked as well.

    Args:
      mode: Compared against `tf.estimator.ModeKeys` members and handed to
        `mnist.model_fn` unchanged.
      multi_gpu: Stored under the 'multi_gpu' key of the params dict given to
        `mnist.model_fn`.
    """
    features, labels = dummy_input_fn()
    image_count = features.shape[0]
    model_params = {'data_format': 'channels_last', 'multi_gpu': multi_gpu}
    spec = mnist.model_fn(features, labels, mode, model_params)

    mode_keys = tf.estimator.ModeKeys

    if mode == mode_keys.PREDICT:
      outputs = spec.predictions
      # (key, expected shape, expected dtype), checked in this order.
      expectations = (
          ('probabilities', (image_count, 10), tf.float32),
          ('classes', (image_count,), tf.int64),
      )
      for key, expected_shape, expected_dtype in expectations:
        self.assertAllEqual(outputs[key].shape, expected_shape)
        self.assertEqual(outputs[key].dtype, expected_dtype)
      # Nothing else is checked in PREDICT mode.
      return

    loss = spec.loss
    self.assertAllEqual(loss.shape, ())
    self.assertEqual(loss.dtype, tf.float32)

    if mode != mode_keys.EVAL:
      return

    accuracy = spec.eval_metric_ops['accuracy']
    first, second = accuracy[0], accuracy[1]
    # Both shapes are asserted before either dtype.
    for metric_part in (first, second):
      self.assertAllEqual(metric_part.shape, ())
    for metric_part in (first, second):
      self.assertEqual(metric_part.dtype, tf.float32)