コード例 #1
0
def _create_tfrecord_dataset(tmpdir):
    """Build a `dataset.Dataset` over one freshly written TFRecord file.

    Args:
      tmpdir: Directory that receives the TFRecord file. It is created with
        `gfile.MakeDirs` when `gfile.Exists` reports it missing.

    Returns:
      A `dataset.Dataset` reading the generated file(s) with
      `io_ops.TFRecordReader` and decoding each record into the items
      'image' (via `tfexample_decoder.Image`) and 'label' (the
      'image/class/label' int64 feature).

    Note:
      `num_samples` is fixed at 100 rather than derived from the file that
      `test_utils.create_tfrecord_files` writes. Presumably that helper emits
      100 records; confirm against its implementation.
    """
    directory_missing = not gfile.Exists(tmpdir)
    if directory_missing:
        gfile.MakeDirs(tmpdir)

    record_files = test_utils.create_tfrecord_files(tmpdir, num_files=1)

    def _scalar_string(default):
        # Scalar string feature with the given fallback value.
        return parsing_ops.FixedLenFeature(
            shape=(), dtype=dtypes.string, default_value=default)

    features_spec = {}
    features_spec['image/encoded'] = _scalar_string('')
    # Records without an explicit format are treated as JPEG.
    features_spec['image/format'] = _scalar_string('jpeg')
    # One int64 label per record; a missing label falls back to [0].
    features_spec['image/class/label'] = parsing_ops.FixedLenFeature(
        shape=[1],
        dtype=dtypes.int64,
        default_value=array_ops.zeros([1], dtype=dtypes.int64))

    handlers = dict(
        image=tfexample_decoder.Image(),
        label=tfexample_decoder.Tensor('image/class/label'))

    example_decoder = tfexample_decoder.TFExampleDecoder(
        features_spec, handlers)

    return dataset.Dataset(
        data_sources=record_files,
        reader=io_ops.TFRecordReader,
        decoder=example_decoder,
        num_samples=100,
        items_to_descriptions=None)
コード例 #2
0
  def testDecodeExampleWithJpegEncoding(self):
    """A (2, 3, 3) JPEG decodes via a default `Image` handler to ~the source."""
    shape = (2, 3, 3)
    expected, serialized = self.GenerateImage(
        image_format='jpeg', image_shape=shape)

    default_handler = tfexample_decoder.Image()
    actual = self.RunDecodeExample(
        serialized, default_handler, image_format='jpeg')

    # JPEG encode/decode is noisy, so allow pixel differences of about 1.
    self.assertAllClose(expected, actual, atol=1.001)
コード例 #3
0
 def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
   """Decoding a uint8 JPEG with a uint16 `Image` handler still succeeds."""
   shape = (2, 3, 3)
   # The generated image is uint8; requesting uint16 output from the handler
   # must not break decoding.
   expected, serialized = self.GenerateImage(
       image_format='jpeg', image_shape=shape)
   wide_handler = tfexample_decoder.Image(dtype=dtypes.uint16)
   actual = self.RunDecodeExample(
       serialized, wide_handler, image_format='jpeg')
   # Tolerance of ~1 absorbs JPEG encode/decode noise.
   self.assertAllClose(expected, actual, atol=1.001)
コード例 #4
0
  def testDecodeExampleWithRawEncodingFloatDtype(self):
    """A float32 image stored as 'raw' bytes decodes to exactly the source."""
    shape = (2, 3, 3)
    expected, serialized = self.GenerateImage(
        image_format='raw', image_shape=shape, image_dtype=np.float32)

    # The handler is given the image shape and float32 dtype explicitly.
    float_handler = tfexample_decoder.Image(shape=shape, dtype=dtypes.float32)
    actual = self.RunDecodeExample(
        serialized, float_handler, image_format='raw')

    # Exact match required (zero tolerance).
    self.assertAllClose(expected, actual, atol=0)
コード例 #5
0
  def testDecodeExampleWithRAWEncoding(self):
    """An upper-case 'RAW' format string decodes to exactly the source image."""
    shape = (2, 3, 3)
    expected, serialized = self.GenerateImage(
        image_format='RAW', image_shape=shape)

    shaped_handler = tfexample_decoder.Image(shape=shape)
    actual = self.RunDecodeExample(
        serialized, shaped_handler, image_format='RAW')

    # Exact match required (zero tolerance).
    self.assertAllClose(expected, actual, atol=0)
コード例 #6
0
  def testDecodeExampleWithNoShapeInfo(self):
    """With `shape=None`, a decoded JPEG still has a static rank of 3."""
    for num_channels in (1, 3):
      _, serialized = self.GenerateImage(
          image_format='jpeg', image_shape=(2, 3, num_channels))

      unshaped_handler = tfexample_decoder.Image(
          shape=None, channels=num_channels)
      decoded_tensor = self.DecodeExample(
          serialized, unshaped_handler, image_format='jpeg')
      # Only the static rank is checked here, not the pixel values.
      self.assertEqual(decoded_tensor.get_shape().ndims, 3)
コード例 #7
0
  def testDecodeExampleWithPNGEncoding(self):
    """PNG images with 1, 3 and 4 channels decode to exactly the source."""
    for num_channels in (1, 3, 4):
      expected, serialized = self.GenerateImage(
          image_format='PNG', image_shape=(2, 3, num_channels))

      png_handler = tfexample_decoder.Image(channels=num_channels)
      actual = self.RunDecodeExample(
          serialized, png_handler, image_format='PNG')

      # Exact match required (zero tolerance).
      self.assertAllClose(expected, actual, atol=0)
コード例 #8
0
  def testDecodeExampleWithJPEGEncoding(self):
    """'JPEG' images with 1 or 3 channels decode to ~the source pixels."""
    for num_channels in (1, 3):
      expected, serialized = self.GenerateImage(
          image_format='JPEG', image_shape=(2, 3, num_channels))

      jpeg_handler = tfexample_decoder.Image(channels=num_channels)
      actual = self.RunDecodeExample(
          serialized, jpeg_handler, image_format='JPEG')

      # JPEG encode/decode is noisy, so allow pixel differences of about 1.
      self.assertAllClose(expected, actual, atol=1.001)
コード例 #9
0
  def testDecodeExampleWithRepeatedImages(self):
    """Two PNGs in one 'image/encoded' feature decode into a stacked array.

    The same encoded PNG bytes are stored twice. With `repeated=True`, the
    `Image` handler is expected to return both images stacked along a new
    leading axis (shape (2, 2, 3, 3)), each equal to the source image.
    """
    shape = (2, 3, 3)
    fmt = 'png'
    source_image, _ = self.GenerateImage(image_format=fmt, image_shape=shape)
    encode_op = self._Encoder(source_image, fmt)
    with self.cached_session():
      png_bytes = encode_op.eval()

    # Store the same encoded bytes twice to form a repeated feature.
    encoded_feature = feature_pb2.Feature(
        bytes_list=feature_pb2.BytesList(value=[png_bytes, png_bytes]))
    proto = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                'image/encoded': encoded_feature,
                'image/format': self._StringFeature(fmt),
            }))
    serialized = proto.SerializeToString()

    with self.cached_session():
      serialized_tensor = array_ops.reshape(serialized, shape=[])

      features_spec = {
          # Exactly two encoded strings are expected for this feature.
          'image/encoded': parsing_ops.FixedLenFeature((2,), dtypes.string),
          'image/format': parsing_ops.FixedLenFeature(
              (), dtypes.string, default_value=fmt),
      }
      handlers = {'image': tfexample_decoder.Image(repeated=True)}
      decoder = tfexample_decoder.TFExampleDecoder(
          keys_to_features=features_spec, items_to_handlers=handlers)
      [image_tensor] = decoder.decode(serialized_tensor, ['image'])

      stacked = image_tensor.eval()

      # Two (2, 3, 3) images stacked along a new leading axis.
      self.assertEqual(stacked.shape, (2, 2, 3, 3))
      for index in range(2):
        self.assertAllEqual(np.squeeze(stacked[index, :, :, :]), source_image)