def cifar10_model_fn_helper(self, mode):
        """Checks shapes and dtypes of what the CIFAR-10 model function returns.

        Feeds the output of `self.input_fn()` to
        `cifar10_main.cifar10_model_fn` and asserts on the resulting spec:
        predictions always, the loss for every mode except PREDICT, and the
        accuracy metric ops for EVAL only.

        Args:
          mode: Value compared against `tf.estimator.ModeKeys` members and
            forwarded to the model function.
        """
        features, labels = self.input_fn()
        params = {
            'resnet_size': 32,
            'data_format': 'channels_last',
            'batch_size': _BATCH_SIZE,
        }
        spec = cifar10_main.cifar10_model_fn(features, labels, mode, params)

        outputs = spec.predictions

        # One probability row per example, 10 entries per row.
        probabilities = outputs['probabilities']
        self.assertAllEqual(probabilities.shape, (_BATCH_SIZE, 10))
        self.assertEqual(probabilities.dtype, tf.float32)

        # One int64 class id per example.
        classes = outputs['classes']
        self.assertAllEqual(classes.shape, (_BATCH_SIZE, ))
        self.assertEqual(classes.dtype, tf.int64)

        if mode != tf.estimator.ModeKeys.PREDICT:
            # The loss is expected to be a float32 scalar.
            scalar_loss = spec.loss
            self.assertAllEqual(scalar_loss.shape, ())
            self.assertEqual(scalar_loss.dtype, tf.float32)

        if mode == tf.estimator.ModeKeys.EVAL:
            # Entries 0 and 1 of the accuracy metric must both be float32
            # scalars; all shapes are checked before any dtype.
            accuracy = spec.eval_metric_ops['accuracy']
            for idx in (0, 1):
                self.assertAllEqual(accuracy[idx].shape, ())
            for idx in (0, 1):
                self.assertEqual(accuracy[idx].dtype, tf.float32)
Beispiel #2
0
  def cifar10_model_fn_helper(self, mode, multi_gpu=False):
    """Checks shapes and dtypes of what the CIFAR-10 model function returns.

    Draws one `(features, labels)` pair from a one-shot iterator over the
    dataset produced by the input function from
    `cifar10_main.get_synth_input_fn()`, runs it through
    `cifar10_main.cifar10_model_fn`, and asserts on the resulting spec:
    predictions always, the loss for every mode except PREDICT, and the
    accuracy metric ops for EVAL only.

    Args:
      mode: Value compared against `tf.estimator.ModeKeys` members and
        forwarded to the model function.
      multi_gpu: Forwarded to the model function as the 'multi_gpu' param.
    """
    dataset = cifar10_main.get_synth_input_fn()(True, '', _BATCH_SIZE)
    features, labels = dataset.make_one_shot_iterator().get_next()

    params = {
        'resnet_size': 32,
        'data_format': 'channels_last',
        'batch_size': _BATCH_SIZE,
        'multi_gpu': multi_gpu
    }
    spec = cifar10_main.cifar10_model_fn(features, labels, mode, params)

    outputs = spec.predictions

    # One probability row per example, 10 entries per row.
    probabilities = outputs['probabilities']
    self.assertAllEqual(probabilities.shape, (_BATCH_SIZE, 10))
    self.assertEqual(probabilities.dtype, tf.float32)

    # One int64 class id per example.
    classes = outputs['classes']
    self.assertAllEqual(classes.shape, (_BATCH_SIZE,))
    self.assertEqual(classes.dtype, tf.int64)

    if mode != tf.estimator.ModeKeys.PREDICT:
      # The loss is expected to be a float32 scalar.
      scalar_loss = spec.loss
      self.assertAllEqual(scalar_loss.shape, ())
      self.assertEqual(scalar_loss.dtype, tf.float32)

    if mode == tf.estimator.ModeKeys.EVAL:
      # Entries 0 and 1 of the accuracy metric must both be float32 scalars;
      # all shapes are checked before any dtype.
      accuracy = spec.eval_metric_ops['accuracy']
      for idx in (0, 1):
        self.assertAllEqual(accuracy[idx].shape, ())
      for idx in (0, 1):
        self.assertEqual(accuracy[idx].dtype, tf.float32)
Beispiel #3
0
  def cifar10_model_fn_helper(self, mode):
    """Validates the spec built by the CIFAR-10 model function for `mode`.

    Inputs come from `self.input_fn()`. Predictions are always checked; the
    loss is checked unless `mode` is PREDICT; the accuracy metric ops are
    checked only when `mode` is EVAL.

    Args:
      mode: Value compared against `tf.estimator.ModeKeys` members and
        forwarded to the model function.
    """
    features, labels = self.input_fn()
    model_params = {
        'resnet_size': 32,
        'data_format': 'channels_last',
        'batch_size': _BATCH_SIZE,
    }
    spec = cifar10_main.cifar10_model_fn(
        features, labels, mode, model_params)

    preds = spec.predictions

    # Probabilities: float32, one row of 10 values per example.
    probs = preds['probabilities']
    self.assertAllEqual(probs.shape, (_BATCH_SIZE, 10))
    self.assertEqual(probs.dtype, tf.float32)

    # Classes: int64, one value per example.
    class_ids = preds['classes']
    self.assertAllEqual(class_ids.shape, (_BATCH_SIZE,))
    self.assertEqual(class_ids.dtype, tf.int64)

    if mode != tf.estimator.ModeKeys.PREDICT:
      # A float32 scalar loss is expected outside of PREDICT.
      loss_tensor = spec.loss
      self.assertAllEqual(loss_tensor.shape, ())
      self.assertEqual(loss_tensor.dtype, tf.float32)

    if mode == tf.estimator.ModeKeys.EVAL:
      # Both entries of the accuracy metric must be float32 scalars; shapes
      # are verified first, then dtypes.
      accuracy_ops = spec.eval_metric_ops['accuracy']
      for position in (0, 1):
        self.assertAllEqual(accuracy_ops[position].shape, ())
      for position in (0, 1):
        self.assertEqual(accuracy_ops[position].dtype, tf.float32)
    def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
        """Validates the spec built by the CIFAR-10 model function.

        Draws one `(features, labels)` pair from a one-shot iterator over
        the dataset produced by the input function from
        `cifar10_main.get_synth_input_fn()`, runs it through
        `cifar10_main.cifar10_model_fn`, and asserts on the result.
        Predictions are always checked; the loss is checked unless `mode`
        is PREDICT; the accuracy metric ops are checked only for EVAL.

        Args:
          mode: Value compared against `tf.estimator.ModeKeys` members and
            forwarded to the model function.
          version: Forwarded to the model function as the 'version' param.
          multi_gpu: Forwarded to the model function as the 'multi_gpu'
            param.
        """
        make_dataset = cifar10_main.get_synth_input_fn()
        one_shot = make_dataset(True, '', _BATCH_SIZE).make_one_shot_iterator()
        features, labels = one_shot.get_next()

        model_params = {
            'resnet_size': 32,
            'data_format': 'channels_last',
            'batch_size': _BATCH_SIZE,
            'version': version,
            'multi_gpu': multi_gpu
        }
        spec = cifar10_main.cifar10_model_fn(
            features, labels, mode, model_params)

        preds = spec.predictions

        # Probabilities: float32, one row of 10 values per example.
        probs = preds['probabilities']
        self.assertAllEqual(probs.shape, (_BATCH_SIZE, 10))
        self.assertEqual(probs.dtype, tf.float32)

        # Classes: int64, one value per example.
        class_ids = preds['classes']
        self.assertAllEqual(class_ids.shape, (_BATCH_SIZE, ))
        self.assertEqual(class_ids.dtype, tf.int64)

        if mode != tf.estimator.ModeKeys.PREDICT:
            # A float32 scalar loss is expected outside of PREDICT.
            loss_tensor = spec.loss
            self.assertAllEqual(loss_tensor.shape, ())
            self.assertEqual(loss_tensor.dtype, tf.float32)

        if mode == tf.estimator.ModeKeys.EVAL:
            # Both entries of the accuracy metric must be float32 scalars;
            # shapes are verified first, then dtypes.
            accuracy_ops = spec.eval_metric_ops['accuracy']
            for position in (0, 1):
                self.assertAllEqual(accuracy_ops[position].shape, ())
            for position in (0, 1):
                self.assertEqual(accuracy_ops[position].dtype, tf.float32)
Beispiel #5
0
  def cifar10_model_fn_helper(self, mode):
    """Validates the spec built by the CIFAR-10 model function for `mode`.

    Inputs come from `self.input_fn()`; the expected batch dimension is read
    from `FLAGS.batch_size`. Predictions are always checked; the loss is
    checked unless `mode` is PREDICT; the accuracy metric ops are checked
    only when `mode` is EVAL.

    Args:
      mode: Value compared against `tf.estimator.ModeKeys` members and
        forwarded to the model function.
    """
    features, labels = self.input_fn()
    spec = cifar10_main.cifar10_model_fn(features, labels, mode)

    preds = spec.predictions

    # Probabilities: float32, one row of 10 values per example.
    probs = preds['probabilities']
    self.assertAllEqual(probs.shape, (FLAGS.batch_size, 10))
    self.assertEqual(probs.dtype, tf.float32)

    # Classes: int64, one value per example.
    class_ids = preds['classes']
    self.assertAllEqual(class_ids.shape, (FLAGS.batch_size,))
    self.assertEqual(class_ids.dtype, tf.int64)

    if mode != tf.estimator.ModeKeys.PREDICT:
      # A float32 scalar loss is expected outside of PREDICT.
      loss_tensor = spec.loss
      self.assertAllEqual(loss_tensor.shape, ())
      self.assertEqual(loss_tensor.dtype, tf.float32)

    if mode == tf.estimator.ModeKeys.EVAL:
      # Both entries of the accuracy metric must be float32 scalars; shapes
      # are verified first, then dtypes.
      accuracy_ops = spec.eval_metric_ops['accuracy']
      for position in (0, 1):
        self.assertAllEqual(accuracy_ops[position].shape, ())
      for position in (0, 1):
        self.assertEqual(accuracy_ops[position].dtype, tf.float32)