def _create_tfrecord_dataset(tmpdir):
  """Writes one TFRecord file under `tmpdir` and wraps it in a `dataset.Dataset`.

  The directory is created first if it is missing. The returned dataset reads
  with `io_ops.TFRecordReader` and decodes two items:
    * 'image' -- via `tfexample_decoder.Image()` (default keys/settings).
    * 'label' -- the int64 'image/class/label' feature, shape [1].

  Args:
    tmpdir: Directory in which the TFRecord file is created.

  Returns:
    A `dataset.Dataset` over the generated file.
  """
  if not gfile.Exists(tmpdir):
    gfile.MakeDirs(tmpdir)

  sources = test_utils.create_tfrecord_files(tmpdir, num_files=1)

  string_feature = parsing_ops.FixedLenFeature
  features = {
      'image/encoded': string_feature(
          shape=(), dtype=dtypes.string, default_value=''),
      # Records lacking a format field are treated as JPEG.
      'image/format': string_feature(
          shape=(), dtype=dtypes.string, default_value='jpeg'),
      'image/class/label': parsing_ops.FixedLenFeature(
          shape=[1],
          dtype=dtypes.int64,
          default_value=array_ops.zeros([1], dtype=dtypes.int64)),
  }
  handlers = {
      'image': tfexample_decoder.Image(),
      'label': tfexample_decoder.Tensor('image/class/label'),
  }
  example_decoder = tfexample_decoder.TFExampleDecoder(features, handlers)

  # NOTE(review): num_samples is hard-coded to 100; presumably this matches
  # the record count written by create_tfrecord_files -- confirm.
  return dataset.Dataset(
      data_sources=sources,
      reader=io_ops.TFRecordReader,
      decoder=example_decoder,
      num_samples=100,
      items_to_descriptions=None)
def testDecodeExampleWithJpegEncoding(self):
  """A 'jpeg'-formatted example decodes back to (nearly) the source image."""
  fmt = 'jpeg'
  original, serialized = self.GenerateImage(
      image_format=fmt, image_shape=(2, 3, 3))
  handler = tfexample_decoder.Image()
  decoded = self.RunDecodeExample(serialized, handler, image_format=fmt)
  # JPEG encode/decode is noisy, so allow a per-pixel tolerance of about 1.
  self.assertAllClose(original, decoded, atol=1.001)
def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
  """Requesting uint16 output for a uint8 JPEG source decodes without error."""
  fmt = 'jpeg'
  # The generated image is uint8; the handler is asked for uint16 output.
  original, serialized = self.GenerateImage(
      image_format=fmt, image_shape=(2, 3, 3))
  handler = tfexample_decoder.Image(dtype=dtypes.uint16)
  decoded = self.RunDecodeExample(serialized, handler, image_format=fmt)
  self.assertAllClose(original, decoded, atol=1.001)
def testDecodeExampleWithRawEncodingFloatDtype(self):
  """A float32 image stored as 'raw' bytes decodes back exactly."""
  shape = (2, 3, 3)
  original, serialized = self.GenerateImage(
      image_format='raw', image_shape=shape, image_dtype=np.float32)
  # Raw bytes carry no shape or dtype, so both are given to the handler.
  handler = tfexample_decoder.Image(shape=shape, dtype=dtypes.float32)
  decoded = self.RunDecodeExample(serialized, handler, image_format='raw')
  # Zero tolerance: the raw round trip is expected to be exact.
  self.assertAllClose(original, decoded, atol=0)
def testDecodeExampleWithRAWEncoding(self):
  """An upper-case 'RAW' format string decodes back exactly."""
  shape = (2, 3, 3)
  fmt = 'RAW'
  original, serialized = self.GenerateImage(
      image_format=fmt, image_shape=shape)
  handler = tfexample_decoder.Image(shape=shape)
  decoded = self.RunDecodeExample(serialized, handler, image_format=fmt)
  # Zero tolerance: the raw round trip is expected to be exact.
  self.assertAllClose(original, decoded, atol=0)
def testDecodeExampleWithNoShapeInfo(self):
  """Without a static shape, the decoded JPEG tensor still has rank 3.

  This is checked for 1- and 3-channel images, with the channel count passed
  to the handler but `shape=None`.
  """
  for num_channels in (1, 3):
    _, serialized = self.GenerateImage(
        image_format='jpeg', image_shape=(2, 3, num_channels))
    handler = tfexample_decoder.Image(shape=None, channels=num_channels)
    decoded_tensor = self.DecodeExample(
        serialized, handler, image_format='jpeg')
    self.assertEqual(decoded_tensor.get_shape().ndims, 3)
def testDecodeExampleWithPNGEncoding(self):
  """'PNG' images with 1, 3 and 4 channels decode back exactly."""
  fmt = 'PNG'
  for num_channels in (1, 3, 4):
    original, serialized = self.GenerateImage(
        image_format=fmt, image_shape=(2, 3, num_channels))
    handler = tfexample_decoder.Image(channels=num_channels)
    decoded = self.RunDecodeExample(serialized, handler, image_format=fmt)
    # Zero tolerance: the PNG round trip is expected to be exact.
    self.assertAllClose(original, decoded, atol=0)
def testDecodeExampleWithJPEGEncoding(self):
  """'JPEG' images with 1 and 3 channels decode back within tolerance."""
  fmt = 'JPEG'
  for num_channels in (1, 3):
    original, serialized = self.GenerateImage(
        image_format=fmt, image_shape=(2, 3, num_channels))
    handler = tfexample_decoder.Image(channels=num_channels)
    decoded = self.RunDecodeExample(serialized, handler, image_format=fmt)
    # JPEG encode/decode is noisy, so allow a per-pixel tolerance of about 1.
    self.assertAllClose(original, decoded, atol=1.001)
def testDecodeExampleWithRepeatedImages(self):
  """Two copies of one PNG in a single feature decode into a stacked batch.

  The example stores the same encoded PNG twice under 'image/encoded'. An
  `Image(repeated=True)` handler is expected to yield a (2, 2, 3, 3) array
  whose two slices both equal the source image.
  """
  image_format = 'png'
  image, _ = self.GenerateImage(
      image_format=image_format, image_shape=(2, 3, 3))
  encoded_tensor = self._Encoder(image, image_format)
  with self.cached_session():
    encoded_bytes = encoded_tensor.eval()

  # Same encoded bytes, listed twice, in one bytes_list feature.
  repeated_feature = feature_pb2.Feature(
      bytes_list=feature_pb2.BytesList(value=[encoded_bytes] * 2))
  features = feature_pb2.Features(feature={
      'image/encoded': repeated_feature,
      'image/format': self._StringFeature(image_format),
  })
  serialized_bytes = example_pb2.Example(features=features).SerializeToString()

  with self.cached_session():
    # The decoder expects a scalar string tensor.
    serialized_tensor = array_ops.reshape(serialized_bytes, shape=[])
    keys_to_features = {
        # Fixed length 2 matches the two copies written above.
        'image/encoded': parsing_ops.FixedLenFeature((2,), dtypes.string),
        'image/format': parsing_ops.FixedLenFeature(
            (), dtypes.string, default_value=image_format),
    }
    items_to_handlers = {'image': tfexample_decoder.Image(repeated=True)}
    example_decoder = tfexample_decoder.TFExampleDecoder(
        keys_to_features=keys_to_features,
        items_to_handlers=items_to_handlers)
    (image_tensor,) = example_decoder.decode(serialized_tensor, ['image'])
    decoded_batch = image_tensor.eval()

    self.assertEqual(decoded_batch.shape, (2, 2, 3, 3))
    for copy_index in (0, 1):
      self.assertAllEqual(
          np.squeeze(decoded_batch[copy_index, :, :, :]), image)