def test_smart_resize(self):
    """smart_resize should produce exactly the requested spatial size.

    A single (20, 40, 3) input is upscaled, downscaled and resized to
    several aspect ratios; the channel dimension must stay at 3.
    """
    image = np.random.random((20, 40, 3))

    resized = image_utils.smart_resize(image, size=(50, 50))
    # A NumPy input is expected to come back as a NumPy array.
    self.assertIsInstance(resized, np.ndarray)
    self.assertListEqual(list(resized.shape), [50, 50, 3])

    for height, width in [(10, 10), (100, 50), (5, 15)]:
        resized = image_utils.smart_resize(image, size=(height, width))
        self.assertListEqual(list(resized.shape), [height, width, 3])
def test_smart_resize_tf_dataset(self, size):
    """smart_resize should work as a `tf.data.Dataset.map` function.

    Args:
      size: target (height, width); presumably supplied by a test
        parameterization decorator outside this view — confirm.
    """
    batch = np.random.random((2, 20, 40, 3))
    expected_shape = [size[0], size[1], 3]

    # Named function instead of an assigned lambda (PEP 8 E731).
    def _resize(image):
        return image_utils.smart_resize(image, size=size)

    # from_tensor_slices splits the batch into two (20, 40, 3) elements;
    # map feeds them to _resize as tf.Tensor values, not NumPy arrays.
    dataset = tf.data.Dataset.from_tensor_slices(batch).map(_resize)
    for element in dataset.as_numpy_iterator():
        self.assertIsInstance(element, np.ndarray)
        self.assertListEqual(list(element.shape), expected_shape)
def test_smart_resize_errors(self):
    """smart_resize should reject a 3-element size and 2-D / 5-D inputs."""
    cases = (
        # (expected error-message fragment, input shape, requested size)
        ('a tuple of 2 integers', (20, 20, 2), (10, 5, 3)),
        ('incorrect rank', (2, 4), (10, 5)),
        ('incorrect rank', (2, 4, 4, 5, 3), (10, 5)),
    )
    for pattern, input_shape, target_size in cases:
        with self.assertRaisesRegex(ValueError, pattern):
            image_utils.smart_resize(
                np.random.random(input_shape), size=target_size
            )
def load_image(
    path, image_size, num_channels, interpolation, crop_to_aspect_ratio=False
):
    """Read an image file, decode it and resize it to `image_size`.

    Args:
      path: path of the image file, passed to `tf.io.read_file`.
      image_size: target (height, width) of the returned image.
      num_channels: number of channels to decode the image with.
      interpolation: resize method forwarded to the resizing op.
      crop_to_aspect_ratio: if True, resize with `image_utils.smart_resize`;
        otherwise use plain `tf.image.resize`, which does not crop and so
        may change the aspect ratio.

    Returns:
      The resized image tensor, with its static shape set to
      `(image_size[0], image_size[1], num_channels)`.
    """
    contents = tf.io.read_file(path)
    # expand_animations=False: animated GIFs decode to their first frame
    # only, so the result is always a single 3-D image.
    image = tf.image.decode_image(
        contents, channels=num_channels, expand_animations=False
    )
    if not crop_to_aspect_ratio:
        image = tf.image.resize(image, image_size, method=interpolation)
    else:
        image = image_utils.smart_resize(
            image, image_size, interpolation=interpolation
        )
    # decode_image does not fix a static shape; declare it for downstream code.
    image.set_shape((image_size[0], image_size[1], num_channels))
    return image
def test_smart_resize_batch(self):
    """smart_resize should accept a batched (4-D) input.

    Resizing 20x40 images to 20x20 is expected to keep the height as-is
    and take the central 20 columns (dropping 10 from each side).
    """
    batch = np.random.random((2, 20, 40, 3))
    resized = image_utils.smart_resize(batch, size=(20, 20))
    center_crop = batch[:, :, 10:-10, :]
    self.assertListEqual(list(resized.shape), [2, 20, 20, 3])
    self.assertAllClose(resized, center_crop)