Example no. 1
0
  def test_dataset_input_fn(self):
    """Checks that cifar10_main.parse_record decodes one synthetic record.

    Writes a single fixed-length record to a temp file: one label byte (7)
    followed by _NUM_CHANNELS planes of _HEIGHT * _WIDTH bytes, where every
    byte of plane i equals i. The record is read back through
    tf.data.FixedLengthRecordDataset and cifar10_main.parse_record. The test
    then checks the static shapes, that the label is one-hot at index 7, and
    that every pixel has the expected per-channel values.
    """
    import os  # Local import: only needed to wrap the mkstemp descriptor.

    fake_data = bytearray([7])  # Label byte comes first.
    for i in range(_NUM_CHANNELS):
      # Channel-major layout: a full plane filled with value i per channel.
      fake_data.extend([i] * (_HEIGHT * _WIDTH))

    # mkstemp returns an already-open OS-level descriptor. Write through it
    # (and close it via the context manager) instead of discarding it and
    # reopening the path, which leaked the descriptor.
    fd, filename = mkstemp(dir=self.get_temp_dir())
    with os.fdopen(fd, 'wb') as data_file:
      data_file.write(fake_data)

    fake_dataset = tf.data.FixedLengthRecordDataset(
        filename, cifar10_main._RECORD_BYTES)
    fake_dataset = fake_dataset.map(
        lambda val: cifar10_main.parse_record(val, False))
    image, label = fake_dataset.make_one_shot_iterator().get_next()

    # Static shapes are checked before anything is evaluated in a session.
    self.assertAllEqual(label.shape, (10,))
    self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))

    with self.test_session() as sess:
      image, label = sess.run([image, label])

      self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))

      # Raw channel values 0, 1, 2 are expected to come out as roughly
      # -1.225, 0, 1.225 -- presumably per-image standardization inside
      # parse_record (mean 1, std sqrt(2/3)); TODO confirm in cifar10_main.
      for row in image:
        for pixel in row:
          self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
    def test_dataset_input_fn(self):
        """Checks that cifar10_main.parse_record decodes one synthetic record.

        Writes a single fixed-length record to a temp file: one label byte
        (7) followed by _NUM_CHANNELS planes of _HEIGHT * _WIDTH bytes, where
        every byte of plane i equals i. The record is read back through
        tf.data.FixedLengthRecordDataset and cifar10_main.parse_record. The
        test then checks the static shapes, that the label is one-hot at
        index 7, and that every pixel has the expected per-channel values.
        """
        import os  # Local import: only needed to wrap the mkstemp descriptor.

        fake_data = bytearray([7])  # Label byte comes first.
        for i in range(_NUM_CHANNELS):
            # Channel-major layout: a full plane filled with value i.
            fake_data.extend([i] * (_HEIGHT * _WIDTH))

        # mkstemp returns an already-open OS-level descriptor. Write through
        # it (and close it via the context manager) instead of discarding it
        # and reopening the path, which leaked the descriptor.
        fd, filename = mkstemp(dir=self.get_temp_dir())
        with os.fdopen(fd, 'wb') as data_file:
            data_file.write(fake_data)

        fake_dataset = tf.data.FixedLengthRecordDataset(
            filename, cifar10_main._RECORD_BYTES)
        fake_dataset = fake_dataset.map(
            lambda val: cifar10_main.parse_record(val, False))
        image, label = fake_dataset.make_one_shot_iterator().get_next()

        # Static shapes are checked before anything runs in a session.
        self.assertAllEqual(label.shape, (10, ))
        self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))

        with self.test_session() as sess:
            image, label = sess.run([image, label])

            self.assertAllEqual(label,
                                np.array([int(i == 7) for i in range(10)]))

            # Raw channel values 0, 1, 2 are expected to come out as roughly
            # -1.225, 0, 1.225 -- presumably per-image standardization inside
            # parse_record (mean 1, std sqrt(2/3)); TODO confirm.
            for row in image:
                for pixel in row:
                    self.assertAllClose(pixel,
                                        np.array([-1.225, 0., 1.225]),
                                        rtol=1e-3)
def parse_record_keras(raw_record, is_training, dtype):
    """Parse one record and expand its class label into a one-hot vector.

    Decoding and any image preprocessing are delegated to
    ``cifar_main.parse_record``; presumably it applies augmentation such as
    cropping/flipping when ``is_training`` is true (confirm in cifar_main).
    The returned class index is then scattered into a dense vector of length
    ``cifar_main.NUM_CLASSES`` so the label fits a categorical loss.

    Args:
        raw_record: scalar ``tf.string`` Tensor holding one serialized record.
        is_training: boolean telling whether the record is used for training.
        dtype: data type requested for the returned image.

    Returns:
        Tuple ``(image, one_hot_label)`` of tensors.
    """
    parsed_image, class_index = cifar_main.parse_record(
        raw_record, is_training, dtype)
    # Place a single 1 at `class_index`; all other entries take the
    # sparse_to_dense default value of 0.
    one_hot_label = tf.sparse_to_dense(
        sparse_indices=class_index,
        output_shape=(cifar_main.NUM_CLASSES, ),
        sparse_values=1)
    return parsed_image, one_hot_label