コード例 #1
0
def _create_tfrecord_dataset(tmpdir):
  """Build a `dataset.Dataset` backed by one generated TFRecord file.

  Args:
    tmpdir: Directory handed to `test_utils.create_tfrecord_files`; it is
      created first when `gfile.Exists` reports it missing.

  Returns:
    A `dataset.Dataset` that reads the generated file(s) with
    `io_ops.TFRecordReader` and exposes the decoded items 'image' and
    'label'. NOTE(review): `num_samples` is fixed at 100 — confirm this
    matches what `create_tfrecord_files` actually writes.
  """
  if not gfile.Exists(tmpdir):
    gfile.MakeDirs(tmpdir)

  record_files = test_utils.create_tfrecord_files(tmpdir, num_files=1)

  def string_feature(default):
    """Return a scalar string FixedLenFeature with the given default."""
    return parsing_ops.FixedLenFeature(
        shape=(), dtype=dtypes.string, default_value=default)

  # Built key by key in the same order as a literal dict would evaluate.
  feature_spec = {}
  feature_spec['image/encoded'] = string_feature('')
  feature_spec['image/format'] = string_feature('jpeg')
  feature_spec['image/class/label'] = parsing_ops.FixedLenFeature(
      shape=[1],
      dtype=dtypes.int64,
      default_value=array_ops.zeros([1], dtype=dtypes.int64))

  # 'image' uses the Image handler's own default keys (presumably
  # 'image/encoded' and 'image/format' — confirm); 'label' is read verbatim.
  handlers = dict(
      image=tfexample_decoder.Image(),
      label=tfexample_decoder.Tensor('image/class/label'))

  example_decoder = tfexample_decoder.TFExampleDecoder(feature_spec, handlers)

  return dataset.Dataset(
      data_sources=record_files,
      reader=io_ops.TFRecordReader,
      decoder=example_decoder,
      num_samples=100,
      items_to_descriptions=None)
コード例 #2
0
  def testDecodeExampleWithJpegEncoding(self):
    """A 2x3x3 JPEG round-trips through the default Image handler."""
    shape = (2, 3, 3)
    fmt = 'jpeg'
    expected, serialized = self.GenerateImage(
        image_format=fmt, image_shape=shape)

    actual = self.RunDecodeExample(
        serialized, tfexample_decoder.Image(), image_format=fmt)

    # JPEG encode/decode is lossy, so allow up to one unit per pixel.
    self.assertAllClose(expected, actual, atol=1.001)
コード例 #3
0
 def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
   """Decoding a JPEG with a uint16 Image handler still yields the pixels."""
   # The generated image is uint8; asking the handler for uint16 output
   # should neither fail nor noticeably change the pixel values.
   expected, serialized = self.GenerateImage(
       image_format='jpeg', image_shape=(2, 3, 3))
   uint16_handler = tfexample_decoder.Image(dtype=tf.uint16)
   actual = self.RunDecodeExample(
       serialized, uint16_handler, image_format='jpeg')
   self.assertAllClose(expected, actual, atol=1.001)
コード例 #4
0
  def testDecodeExampleWithRawEncodingFloatDtype(self):
    """A float32 'raw' image decodes back with zero absolute tolerance."""
    shape = (2, 3, 3)
    expected, serialized = self.GenerateImage(
        image_format='raw', image_shape=shape, image_dtype=np.float32)

    # Raw bytes presumably carry no shape, hence the explicit shape here.
    float_handler = tfexample_decoder.Image(shape=shape, dtype=tf.float32)
    actual = self.RunDecodeExample(
        serialized, float_handler, image_format='raw')

    self.assertAllClose(expected, actual, atol=0)
コード例 #5
0
  def testDecodeExampleWithRAWEncoding(self):
    """An upper-case 'RAW' format string decodes with zero tolerance."""
    shape = (2, 3, 3)
    fmt = 'RAW'
    expected, serialized = self.GenerateImage(
        image_format=fmt, image_shape=shape)

    actual = self.RunDecodeExample(
        serialized, tfexample_decoder.Image(shape=shape), image_format=fmt)

    self.assertAllClose(expected, actual, atol=0)
コード例 #6
0
  def testDecodeExampleWithNoShapeInfo(self):
    """Without a shape, the decoded JPEG tensor still has static rank 3."""
    for num_channels in (1, 3):
      _, serialized = self.GenerateImage(
          image_format='jpeg', image_shape=(2, 3, num_channels))

      handler = tfexample_decoder.Image(shape=None, channels=num_channels)
      decoded = self.DecodeExample(serialized, handler, image_format='jpeg')

      # Only the rank is checked; the exact dimensions are not asserted.
      self.assertEqual(decoded.get_shape().ndims, 3)
コード例 #7
0
  def testDecodeExampleWithPNGEncoding(self):
    """PNG images with 1, 3 and 4 channels decode with zero tolerance."""
    for num_channels in (1, 3, 4):
      expected, serialized = self.GenerateImage(
          image_format='PNG', image_shape=(2, 3, num_channels))

      actual = self.RunDecodeExample(
          serialized,
          tfexample_decoder.Image(channels=num_channels),
          image_format='PNG')

      self.assertAllClose(expected, actual, atol=0)
コード例 #8
0
  def testDecodeExampleWithJPEGEncoding(self):
    """JPEG images with 1 and 3 channels round-trip within +/-1 per pixel."""
    for num_channels in (1, 3):
      expected, serialized = self.GenerateImage(
          image_format='JPEG', image_shape=(2, 3, num_channels))

      actual = self.RunDecodeExample(
          serialized,
          tfexample_decoder.Image(channels=num_channels),
          image_format='JPEG')

      # JPEG encode/decode is lossy, so allow up to one unit per pixel.
      self.assertAllClose(expected, actual, atol=1.001)
コード例 #9
0
  def testDecodeExampleWithRepeatedImages(self):
    """Two copies of one PNG decode into a stacked (2, 2, 3, 3) array.

    The 'image/encoded' feature holds the same encoded bytes twice and is
    parsed as a fixed-length (2,) string feature; `Image(repeated=True)` is
    expected to decode each entry and stack them on a new leading axis.
    """
    fmt = 'png'
    image, _ = self.GenerateImage(image_format=fmt, image_shape=(2, 3, 3))

    # Encode once and pull the encoded bytes out of the graph.
    encoded_tensor = self._Encoder(image, fmt)
    with self.cached_session():
      encoded_bytes = encoded_tensor.eval()

    encoded_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[encoded_bytes, encoded_bytes]))
    features = tf.train.Features(feature={
        'image/encoded': encoded_feature,
        'image/format': self._StringFeature(fmt),
    })
    example = tf.train.Example(features=features)
    serialized_example = example.SerializeToString()

    with self.cached_session():
      # Wrap the serialized proto as a scalar string tensor for decoding.
      serialized_tensor = array_ops.reshape(serialized_example, shape=[])

      keys_to_features = {
          'image/encoded': parsing_ops.FixedLenFeature((2,), tf.string),
          'image/format': parsing_ops.FixedLenFeature(
              (), tf.string, default_value=fmt),
      }
      decoder = tfexample_decoder.TFExampleDecoder(
          keys_to_features=keys_to_features,
          items_to_handlers={'image': tfexample_decoder.Image(repeated=True)})
      [tf_image] = decoder.decode(serialized_tensor, ['image'])

      output_image = tf_image.eval()

      self.assertEqual(output_image.shape, (2, 2, 3, 3))
      for idx in range(2):
        self.assertAllEqual(np.squeeze(output_image[idx]), image)