def test_dataset_input_fn(self):
        """Parses one synthetic fixed-length record and checks the result.

        A single record is written to a temporary file, read back through
        `tf.data.FixedLengthRecordDataset`, and passed to
        `cifar10_main.parse_record` with `is_training=False`. The test then
        checks the static shapes of the parsed tensors, the label value and
        every pixel of the evaluated image.
        """
        # Function-scope import: `os` is only needed here to wrap the
        # descriptor returned by mkstemp.
        import os

        # Record layout: one label byte (7), then one plane of
        # _HEIGHT * _WIDTH bytes per channel, each filled with the channel
        # index.
        fake_data = bytearray([7])
        for channel in range(_NUM_CHANNELS):
            fake_data.extend([channel] * (_HEIGHT * _WIDTH))

        # mkstemp hands back an already-open OS-level descriptor. Write
        # through it so it is closed when the block exits instead of leaking
        # for the rest of the test run.
        fd, filename = mkstemp(dir=self.get_temp_dir())
        with os.fdopen(fd, 'wb') as data_file:
            data_file.write(fake_data)

        fake_dataset = tf.data.FixedLengthRecordDataset(
            filename, cifar10_main._RECORD_BYTES)  # pylint: disable=protected-access
        fake_dataset = fake_dataset.map(
            lambda val: cifar10_main.parse_record(val, False, tf.float32))
        image, label = fake_dataset.make_one_shot_iterator().get_next()

        # Static shapes: a scalar label and a height x width x channels image.
        self.assertAllEqual(label.shape, ())
        self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))

        with self.test_session() as sess:
            image, label = sess.run([image, label])

            self.assertEqual(label, 7)

            # -1.225, 0 and 1.225 are what the raw values 0, 1, 2 become
            # after subtracting their mean (1) and dividing by their
            # standard deviation (sqrt(2/3)).
            # NOTE(review): presumably parse_record standardizes each image;
            # confirm against cifar10_main.
            expected_pixel = np.array([-1.225, 0., 1.225])
            for row in image:
                for pixel in row:
                    self.assertAllClose(pixel, expected_pixel, rtol=1e-3)
Beispiel #2
0
  def test_dataset_input_fn(self):
    """Parses one synthetic fixed-length record and checks the result.

    A single record is written to a temporary file, read back through
    `tf.data.FixedLengthRecordDataset`, and passed to
    `cifar10_main.parse_record` with `is_training=False`. The test then checks
    the static shapes of the parsed tensors, the label value and every pixel
    of the evaluated image.
    """
    # Function-scope import: `os` is only needed here to wrap the descriptor
    # returned by mkstemp.
    import os

    # Record layout: one label byte (7), then one plane of _HEIGHT * _WIDTH
    # bytes per channel, each filled with the channel index.
    fake_data = bytearray([7])
    for channel in range(_NUM_CHANNELS):
      fake_data.extend([channel] * (_HEIGHT * _WIDTH))

    # mkstemp hands back an already-open OS-level descriptor. Write through it
    # so it is closed when the block exits instead of leaking for the rest of
    # the test run.
    fd, filename = mkstemp(dir=self.get_temp_dir())
    with os.fdopen(fd, 'wb') as data_file:
      data_file.write(fake_data)

    fake_dataset = tf.data.FixedLengthRecordDataset(
        filename, cifar10_main._RECORD_BYTES)  # pylint: disable=protected-access
    fake_dataset = fake_dataset.map(
        lambda val: cifar10_main.parse_record(val, False, tf.float32))
    image, label = tf.compat.v1.data.make_one_shot_iterator(
        fake_dataset).get_next()

    # Static shapes: a scalar label and a height x width x channels image.
    self.assertAllEqual(label.shape, ())
    self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))

    with self.test_session() as sess:
      image, label = sess.run([image, label])

      self.assertEqual(label, 7)

      # -1.225, 0 and 1.225 are what the raw values 0, 1, 2 become after
      # subtracting their mean (1) and dividing by their standard deviation
      # (sqrt(2/3)).
      # NOTE(review): presumably parse_record standardizes each image;
      # confirm against cifar10_main.
      expected_pixel = np.array([-1.225, 0., 1.225])
      for row in image:
        for pixel in row:
          self.assertAllClose(pixel, expected_pixel, rtol=1e-3)
Beispiel #3
0
def parse_record_keras(raw_record, is_training, dtype):
    """Parse a raw record into an image and a dense, one-hot style label.

    Parsing and image preprocessing are delegated to
    `cifar_main.parse_record`; this wrapper only expands the label it returns
    into a dense vector of length `cifar_main.NUM_CLASSES`, which is the
    form the loss function expects.

    Args:
      raw_record: scalar string Tensor holding one record, forwarded
        unchanged to `cifar_main.parse_record`.
      is_training: A boolean denoting whether the input is for training.
      dtype: Data type to use for input images.

    Returns:
      Tuple of (processed image tensor, dense label tensor with a 1 at the
      label index and the op's default value elsewhere).
    """
    parsed = cifar_main.parse_record(raw_record, is_training, dtype)
    image, class_index = parsed

    # NOTE(review): tf.sparse_to_dense is deprecated in later TensorFlow 1.x
    # releases; confirm the TensorFlow version this file targets.
    dense_shape = (cifar_main.NUM_CLASSES, )
    dense_label = tf.sparse_to_dense(class_index, dense_shape, 1)
    return image, dense_label
def parse_record_keras(raw_record, is_training, dtype):
  """Parses a record containing a training example of an image.

  The input record is parsed into a label and image by
  `cifar_main.parse_record`, which also applies the image preprocessing
  steps. The label is then expanded into a dense vector of length
  `cifar_main.NUM_CLASSES` to fit the loss function.

  Args:
    raw_record: scalar string Tensor holding one record, forwarded unchanged
      to `cifar_main.parse_record`.
    is_training: A boolean denoting whether the input is for training.
    dtype: Data type to use for input images.

  Returns:
    Tuple with processed image tensor and one-hot-encoded label tensor.
  """
  image, label = cifar_main.parse_record(raw_record, is_training, dtype)
  # The top-level tf.sparse_to_dense alias is deprecated and is not exported
  # in TensorFlow 2.x. The compat.v1 endpoint is the same op and is available
  # in both late 1.x and 2.x, so the result is unchanged.
  label = tf.compat.v1.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
  return image, label