Beispiel #1
0
  def mnist_model_fn_helper(self, mode, multi_gpu=False):
    """Runs `mnist.model_fn` for `mode` and checks the returned spec.

    The spec is built from `dummy_input_fn()` data with the 'channels_last'
    data format. Which fields are verified depends on `mode`:
      * PREDICT: 'probabilities' is float32 of shape (batch, 10) and
        'classes' is int64 of shape (batch,).
      * any other mode: the loss is a float32 scalar.
      * EVAL (additionally): entries 0 and 1 of the 'accuracy' metric are
        float32 scalars.

    Args:
      mode: Mode value compared against `tf.estimator.ModeKeys` members.
      multi_gpu: Passed through to `mnist.model_fn` as the 'multi_gpu' param.
    """
    print("mnist_model_fn_helper() in")
    features, labels = dummy_input_fn()
    batch_size = features.shape[0]
    params = {'data_format': 'channels_last', 'multi_gpu': multi_gpu}
    spec = mnist.model_fn(features, labels, mode, params)

    if mode == tf.estimator.ModeKeys.PREDICT:
      outputs = spec.predictions
      expected_prob_shape = (batch_size, 10)
      self.assertAllEqual(outputs['probabilities'].shape, expected_prob_shape)
      self.assertEqual(outputs['probabilities'].dtype, tf.float32)
      expected_class_shape = (batch_size,)
      self.assertAllEqual(outputs['classes'].shape, expected_class_shape)
      self.assertEqual(outputs['classes'].dtype, tf.int64)
    else:
      # Every non-PREDICT mode is expected to expose a scalar loss.
      scalar_loss = spec.loss
      self.assertAllEqual(scalar_loss.shape, ())
      self.assertEqual(scalar_loss.dtype, tf.float32)

    if mode == tf.estimator.ModeKeys.EVAL:
      metric_ops = spec.eval_metric_ops
      # Entries 0 and 1 are presumably the (value, update_op) pair of the
      # metric -- confirm against mnist.model_fn. Shapes are checked for both
      # entries before dtypes, so the first failure reported stays the same.
      for position in (0, 1):
        self.assertAllEqual(metric_ops['accuracy'][position].shape, ())
      for position in (0, 1):
        self.assertEqual(metric_ops['accuracy'][position].dtype, tf.float32)
Beispiel #2
0
  def mnist_model_fn_helper(self, mode, multi_gpu=False):
    """Checks the shapes and dtypes of what `mnist.model_fn` returns for `mode`.

    Inputs come from `dummy_input_fn()`; the model is always built with the
    'channels_last' data format.

    Args:
      mode: Mode value compared against `tf.estimator.ModeKeys` members. It
        decides which parts of the returned spec are inspected.
      multi_gpu: Value placed under the 'multi_gpu' key of the params dict
        handed to `mnist.model_fn`.
    """
    features, labels = dummy_input_fn()
    image_count = features.shape[0]
    model_params = dict(data_format='channels_last', multi_gpu=multi_gpu)
    spec = mnist.model_fn(features, labels, mode, model_params)
    mode_keys = tf.estimator.ModeKeys

    if mode == mode_keys.PREDICT:
      # Prediction mode: only the prediction tensors are checked, so the
      # loss and metric checks below are skipped by returning early.
      probabilities = spec.predictions['probabilities']
      self.assertAllEqual(probabilities.shape, (image_count, 10))
      self.assertEqual(probabilities.dtype, tf.float32)
      classes = spec.predictions['classes']
      self.assertAllEqual(classes.shape, (image_count,))
      self.assertEqual(classes.dtype, tf.int64)
      return

    # All remaining modes must provide a float32 scalar loss.
    loss_tensor = spec.loss
    self.assertAllEqual(loss_tensor.shape, ())
    self.assertEqual(loss_tensor.dtype, tf.float32)

    if mode != mode_keys.EVAL:
      return

    # Evaluation mode: both entries of the 'accuracy' metric must be float32
    # scalars (presumably the metric value and its update op -- confirm
    # against mnist.model_fn).
    accuracy = spec.eval_metric_ops['accuracy']
    self.assertAllEqual(accuracy[0].shape, ())
    self.assertAllEqual(accuracy[1].shape, ())
    self.assertEqual(accuracy[0].dtype, tf.float32)
    self.assertEqual(accuracy[1].dtype, tf.float32)