def testWrongCenterCropImages(self):
  """Tests that all ValueErrors are triggered for CenterCropImages.

  Every case below has exactly one malformed argument. The other argument is
  valid for a 32x32x3 input, so the expected ValueError is attributable to the
  shape check under test. It cannot instead be caused by a target larger than
  the input.
  """
  with tf.Graph().as_default():
    input_shape = [32, 32, 3]
    batch_size = 4
    images = self._CreateRampTestImages(batch_size, input_shape[0],
                                        input_shape[1])
    with self.assertRaises(ValueError):
      # The input shape is (height, width) but (height, width, channels) is
      # required.
      image_transformations.CenterCropImages([images], [32, 32], [20, 20])
    with self.assertRaises(ValueError):
      # The input shape has a fourth, extraneous dimension
      # (height, width, channels, extra) but (height, width, channels) is
      # required.
      image_transformations.CenterCropImages([images], [32, 32, 3, 4],
                                             [20, 20])
    with self.assertRaises(ValueError):
      # The output shape is (height,) but (height, width) is required.
      image_transformations.CenterCropImages([images], [32, 32, 3], [20])
    with self.assertRaises(ValueError):
      # The output shape has a third, extraneous dimension
      # (height, width, extra) but (height, width) is required.
      image_transformations.CenterCropImages([images], [32, 32, 3],
                                             [20, 32, 64])
def testCenterCrop(self, input_shape, output_shape):
  """Checks the shape and corner values of a center crop of ramp images.

  In the ramp test images, the R channel holds the x-coordinate and the G
  channel holds the y-coordinate of each pixel. The corner values of the crop
  therefore reveal where the crop window was placed.

  Args:
    input_shape: (height, width) of the generated test images; a channel
      dimension of 3 is appended. A list or a tuple is accepted.
    output_shape: (height, width) of the requested crop. A list or a tuple is
      accepted.
  """
  # list() lets tuple-valued test parameters concatenate with lists below.
  input_shape = list(input_shape) + [3]
  output_shape = list(output_shape)
  with tf.Graph().as_default():
    batch_size = 4
    images = self._CreateRampTestImages(batch_size, input_shape[0],
                                        input_shape[1])
    cropped = image_transformations.CenterCropImages([images], input_shape,
                                                     output_shape)[0]
    with tf.Session() as sess:
      cropped_image = sess.run(cropped)

    # Expected position of the crop window inside the input image.
    offset_y = (input_shape[0] - output_shape[0]) // 2
    offset_x = (input_shape[1] - output_shape[1]) // 2

    # Check cropped shape.
    self.assertAllEqual(cropped_image.shape,
                        [batch_size] + output_shape + [3])
    # Check top-left corner on G-channel (y-coordinates).
    self.assertEqual(cropped_image[0, 0, 0, 1], offset_y)
    # Check bottom-left corner on G-channel (y-coordinates).
    self.assertEqual(cropped_image[0, -1, 0, 1],
                     offset_y + output_shape[0] - 1)
    # Check top-left corner on R-channel (x-coordinates).
    self.assertEqual(cropped_image[0, 0, 0, 0], offset_x)
    # Check top-right corner on R-channel (x-coordinates).
    self.assertEqual(cropped_image[0, 0, -1, 0],
                     offset_x + output_shape[1] - 1)
def crop_image(img, mode, input_size=(512, 640), target_size=(472, 472)):
  """Takes a crop of the image, either randomly or from the center.

  The crop is consistent across all images given in the batch.

  Args:
    img: 4D image Tensor [batch, height, width, channels].
    mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
    input_size: (height, width) of input. Any 2-element sequence (tuple or
      list) is accepted.
    target_size: (height, width) of desired crop. Any 2-element sequence
      (tuple or list) is accepted.

  Returns:
    img cropped to the desired size, randomly if mode == TRAIN and from the
    center otherwise. If input_size equals target_size, img is returned
    unchanged and no crop ops are added.
  """
  # Compare element-wise as tuples. In Python (512, 640) != [512, 640], so
  # mixing a list and a tuple would otherwise skip this shortcut and add an
  # identity crop.
  if tuple(input_size) == tuple(target_size):
    # Don't even bother adding the ops.
    return img
  input_height, input_width = input_size
  # NOTE(review): the channel count is hard-coded to 3; images with another
  # number of channels presumably fail inside the crop helpers - confirm.
  input_shape = (input_height, input_width, 3)
  if mode == tf.estimator.ModeKeys.TRAIN:
    crop_fn = image_transformations.RandomCropImages
  else:
    crop_fn = image_transformations.CenterCropImages
  return crop_fn([img], input_shape=input_shape, target_shape=target_size)[0]
def _preprocess_fn(self, features, labels, mode):
  """Crops, converts and, when training, distorts the state image.

  This is executed prior to the model_fn. In TRAIN mode the image is randomly
  cropped and photometric distortions are applied after conversion. In every
  other mode it is center cropped and only converted. The conversion uses
  tf.image.convert_image_dtype to float32. The result replaces
  features.state.image in place.

  Args:
    features: The input features extracted from a single example in our
      in_feature_specification format.
    labels: (Optional) The input labels extracted from a single example in
      our in_label_specification format.
    mode: (ModeKeys) Specifies if this is training, evaluation or prediction.

  Returns:
    features: The preprocessed features, potentially adding additional
      tensors derived from the input features.
    labels: (Optional) The preprocessed labels, potentially adding additional
      tensors derived from the input features and labels.
  """
  is_training = mode == TRAIN
  if is_training:
    crop_fn = image_transformations.RandomCropImages
  else:
    crop_fn = image_transformations.CenterCropImages
  processed = crop_fn([features.state.image], INPUT_SHAPE, TARGET_SHAPE)[0]
  processed = tf.image.convert_image_dtype(processed, tf.float32)
  if is_training:
    # Distortions operate on the float image produced above.
    distorted = image_transformations.ApplyPhotometricImageDistortions(
        [processed])
    processed = distorted[0]
  features.state.image = processed
  return features, labels