def _augment_batch_(self, batch, random_state, parents, hooks):
    """Median-blur each image of ``batch`` in-place and return ``batch``.

    One kernel size per image is drawn from ``self.k``. Images whose sampled
    size is not ``> 1`` and images without any pixels are left untouched.
    Even sizes are rounded up to the next odd value before blurring.
    Blurred single-channel results always carry a trailing channel axis.
    """
    if batch.images is None:
        return batch

    imgs = batch.images
    kernel_sizes = self.k.draw_samples((len(imgs),),
                                       random_state=random_state)

    for idx, (img, kernel_size) in enumerate(zip(imgs, kernel_sizes)):
        if (not kernel_size > 1) or img.size == 0:
            continue

        if kernel_size % 2 == 0:
            kernel_size = kernel_size + 1

        many_channels = img.ndim != 2 and img.shape[-1] > 512
        if many_channels:
            # cv2.medianBlur() cannot process more than 512 channels at
            # once, so the channels are blurred one after another.
            # TODO this is quite inefficient
            blurred = np.stack(
                [cv2.medianBlur(_normalize_cv2_input_arr_(img[..., ch]),
                                kernel_size)
                 for ch in sm.xrange(img.shape[-1])],
                axis=-1)
        else:
            blurred = cv2.medianBlur(_normalize_cv2_input_arr_(img),
                                     kernel_size)
            # cv2.medianBlur() removes the channel axis for single-channel
            # images, so it is re-added here.
            if blurred.ndim == 2:
                blurred = blurred[..., np.newaxis]

        batch.images[idx] = blurred

    return batch
def _fliplr_cv2(arr): # cv2.flip() returns None for arrays with zero height or width # and turns channels=0 to channels=512 if arr.size == 0: return np.copy(arr) # cv2.flip() fails for more than 512 channels if arr.ndim == 3 and arr.shape[-1] > 512: # TODO this is quite inefficient right now channels = [ cv2.flip(_normalize_cv2_input_arr_(arr[..., c]), 1) for c in sm.xrange(arr.shape[-1]) ] result = np.stack(channels, axis=-1) else: # Normalization from imgaug.imgaug._normalize_cv2_input_arr_(). # Moved here for performance reasons. Keep this aligned. # TODO recalculate timings, they were computed without this. flags = arr.flags if not flags["OWNDATA"]: arr = np.copy(arr) flags = arr.flags if not flags["C_CONTIGUOUS"]: arr = np.ascontiguousarray(arr) result = cv2.flip(_normalize_cv2_input_arr_(arr), 1) if result.ndim == 2 and arr.ndim == 3: return result[..., np.newaxis] return result
def _augment_batch_(self, batch, random_state, parents, hooks):  # pylint: disable=invalid-name
    """Apply ``cv2.bilateralFilter()`` to every image of ``batch`` in-place.

    Per image, a diameter and two sigmas are sampled from ``self.d``,
    ``self.sigma_color`` and ``self.sigma_space``. Images with a sampled
    diameter of ``1`` or without any pixels are left unchanged.

    Returns
    -------
    Batch
        The same ``batch`` object, with its ``images`` entries replaced.

    Raises
    ------
    AssertionError
        If any image does not have three channels. This includes 2D images,
        which previously raised an ``IndexError`` instead.
    """
    if batch.images is None:
        return batch

    images = batch.images

    # Make sure that all images have 3 channels. ``ndim`` is checked first
    # so that (H, W) images trigger this assertion (reported as ``None``
    # channels) instead of an IndexError on ``shape[2]``.
    assert all(image.ndim > 2 and image.shape[2] == 3
               for image in images), (
        "BilateralBlur can currently only be applied to images with 3 "
        "channels. Got channels: %s" % (
            [image.shape[2] if image.ndim > 2 else None
             for image in images], ))

    nb_images = len(images)
    rss = random_state.duplicate(3)
    samples_d = self.d.draw_samples((nb_images,), random_state=rss[0])
    samples_sigma_color = self.sigma_color.draw_samples(
        (nb_images,), random_state=rss[1])
    samples_sigma_space = self.sigma_space.draw_samples(
        (nb_images,), random_state=rss[2])

    gen = enumerate(
        zip(images, samples_d, samples_sigma_color, samples_sigma_space))
    for i, (image, di, sigma_color_i, sigma_space_i) in gen:
        has_zero_sized_axes = (image.size == 0)
        # d == 1 is skipped (single-pixel neighbourhood), as are images
        # without pixels.
        if di != 1 and not has_zero_sized_axes:
            batch.images[i] = cv2.bilateralFilter(
                _normalize_cv2_input_arr_(image),
                di, sigma_color_i, sigma_space_i)
    return batch
def _find_edges_canny(image, edge_multiplier, from_colorspace):
    """Detect edges via ``cv2.Canny()`` on a grayscale copy of ``image``.

    Parameters
    ----------
    image : numpy.ndarray
        Image to analyze. It is copied before the colorspace conversion, so
        the input is not modified.
    edge_multiplier : number
        Higher values lower the Canny threshold and therefore yield more
        edges. The threshold is ``200 / edge_multiplier``, capped at
        ``254``. Non-positive values use the cap (fewest edges) instead of
        raising ``ZeroDivisionError`` or producing a negative threshold.
    from_colorspace : str
        Colorspace of `image`, used for the grayscale conversion.

    Returns
    -------
    numpy.ndarray
        The edge map returned by ``cv2.Canny()``.
    """
    image_gray = colorlib.change_colorspace_(
        np.copy(image),
        to_colorspace=colorlib.CSPACE_GRAY,
        from_colorspace=from_colorspace)
    image_gray = image_gray[..., 0]

    if edge_multiplier > 0:
        # int(min(x, 254)) == min(int(x), 254) for finite x, but it also
        # avoids int(inf) for tiny multipliers.
        thresh = int(min(200 * (1 / edge_multiplier), 254))
    else:
        thresh = 254

    edges = cv2.Canny(_normalize_cv2_input_arr_(image_gray), thresh, thresh)
    return edges
def _suppress_edge_blobs(edges, size, thresh, inverse):
    """Zero out edge pixels depending on the edge density around them.

    ``edges`` is divided by ``255`` and summed over a ``size x size`` window
    with ``cv2.filter2D()``. Assuming edge pixels are ``255`` and all other
    pixels are ``0``, this sum is the number of edge pixels in each window.

    If ``inverse`` is ``False``, pixels whose window sum is ``>= thresh`` are
    set to ``0``. Otherwise, pixels whose window sum is ``< thresh`` are set
    to ``0``. The input array is not modified; a modified copy is returned.
    """
    window = np.ones((size, size), dtype=np.float32)
    density = cv2.filter2D(_normalize_cv2_input_arr_(edges / 255.0),
                           -1, window)
    to_suppress = (density < thresh) if inverse else (density >= thresh)

    result = np.copy(edges)
    result[to_suppress] = 0
    return result
def _augment_batch_(self, batch, random_state, parents, hooks):
    """Replace images in ``batch`` by colorized Canny edges, blended in.

    Only ``uint8`` images are accepted. Per image, an alpha, a pair of
    hysteresis thresholds and a Sobel aperture size are drawn via
    ``self._draw_samples()``. An image is only processed if
    ``alpha > 0``, the aperture is ``> 1`` and height and width are
    non-zero. In that case the first (up to) three channels are passed to
    ``cv2.Canny()``. The binarized edge map is colorized by
    ``self.colorizer`` and alpha-blended with the original image.

    Raises
    ------
    AssertionError
        If an image's last axis is not of size 1, 3 or 4.
    """
    if batch.images is None:
        return batch

    imgs = batch.images

    iadt.gate_dtypes(imgs,
                     allowed=["uint8"],
                     disallowed=[
                         "bool",
                         "uint16", "uint32", "uint64", "uint128", "uint256",
                         "int8", "int16", "int32", "int64", "int128",
                         "int256",
                         "float32", "float64", "float96", "float128",
                         "float256"],
                     augmenter=self)

    rngs = random_state.duplicate(len(imgs))
    samples = self._draw_samples(imgs, rngs[-1])
    alphas, hysteresis_pairs, sobel_sizes = (
        samples[0], samples[1], samples[2])

    per_image = zip(imgs, alphas, hysteresis_pairs, sobel_sizes)
    for idx, (img, alpha, hthreshs, sobel) in enumerate(per_image):
        nb_channels = img.shape[-1]
        assert nb_channels in [1, 3, 4], (
            "Canny edge detector can currently only handle images with "
            "channel numbers that are 1, 3 or 4. Got %d." % (nb_channels,))

        if not (alpha > 0 and sobel > 1) or 0 in img.shape[0:2]:
            continue

        # Binarize the Canny response to a boolean (H, W) mask; the
        # colorizer then turns it into an image that can be blended with
        # the input.
        edge_mask = cv2.Canny(
            _normalize_cv2_input_arr_(img[:, :, 0:3]),
            threshold1=hthreshs[0],
            threshold2=hthreshs[1],
            apertureSize=sobel,
            L2gradient=True) > 0

        colored_edges = self.colorizer.colorize(
            edge_mask, img, nth_image=idx, random_state=rngs[idx])
        batch.images[idx] = blend.blend_alpha(colored_edges, img, alpha)

    return batch
def _find_edges_laplacian(image, edge_multiplier, from_colorspace):
    """Detect edges via the Laplacian of a grayscale copy of ``image``.

    Steps:

    1. Compute the squared absolute Laplacian response (float64).
    2. Normalize it to ``[0, 1]`` by clipping at a percentile
       (``90 / edge_multiplier``, capped at ``99``).
    3. Scale to ``uint8``.
    4. Median-blur with kernel size ``3``.
    5. Pass through ``_threshold(..., 50)``.

    Parameters
    ----------
    image : numpy.ndarray
        Image to analyze. It is copied before the colorspace conversion.
    edge_multiplier : number
        Higher values lower the clipping percentile and therefore saturate
        more responses, i.e. yield more edges. Non-positive values use the
        cap of ``99`` instead of raising ``ZeroDivisionError`` or passing a
        negative percentile to ``np.percentile()``.
    from_colorspace : str
        Colorspace of `image`, used for the grayscale conversion.

    Returns
    -------
    numpy.ndarray
        ``uint8`` edge map.
    """
    image_gray = colorlib.change_colorspace_(
        np.copy(image),
        to_colorspace=colorlib.CSPACE_GRAY,
        from_colorspace=from_colorspace)
    image_gray = image_gray[..., 0]
    edges_f = cv2.Laplacian(_normalize_cv2_input_arr_(image_gray / 255.0),
                            cv2.CV_64F)
    edges_f = np.abs(edges_f)
    edges_f = edges_f ** 2

    if edge_multiplier > 0:
        # int(min(x, 99)) == min(int(x), 99) for finite x
        percentile = int(min(90 * (1 / edge_multiplier), 99))
    else:
        percentile = 99
    vmax = np.percentile(edges_f, percentile)

    if vmax > 0:
        edges_f = np.clip(edges_f, 0.0, vmax) / vmax
    else:
        # At least `percentile` % of all pixels have a zero response, so
        # vmax == 0. Dividing would give 0/0 = NaN, whose uint8 cast is
        # undefined. In the limit vmax -> 0+, every non-zero response
        # saturates to 1.0 and zero responses stay 0.0.
        edges_f = (edges_f > 0).astype(np.float64)

    edges_uint8 = np.clip(np.round(edges_f * 255), 0, 255.0).astype(np.uint8)
    edges_uint8 = _blur_median(edges_uint8, 3)
    edges_uint8 = _threshold(edges_uint8, 50)
    return edges_uint8
def _augment_batch_(self, batch, random_state, parents, hooks):
    """Apply an average (box) blur to every image of ``batch`` in-place.

    Per image, a kernel height and width are sampled. If
    ``self.mode == "single"``, one value from ``self.k`` is used for both.
    Otherwise the height comes from ``self.k[0]`` and the width from
    ``self.k[1]``.

    An image is skipped if:

    * either kernel size is ``0``,
    * both kernel sizes are ``1``, or
    * the image has no pixels.

    ``bool``/``float16`` inputs are blurred as ``float32`` and ``int8`` as
    ``int16``; the input dtype is restored afterwards.

    Returns
    -------
    Batch
        The same ``batch`` object, with its ``images`` entries replaced.
    """
    if batch.images is None:
        return batch

    images = batch.images

    iadt.gate_dtypes(images,
                     allowed=[
                         "bool",
                         "uint8", "uint16", "int8", "int16",
                         "float16", "float32", "float64"],
                     disallowed=[
                         "uint32", "uint64", "uint128", "uint256",
                         "int32", "int64", "int128", "int256",
                         "float96", "float128", "float256"],
                     augmenter=self)

    nb_images = len(images)
    if self.mode == "single":
        samples = self.k.draw_samples((nb_images,), random_state=random_state)
        samples = (samples, samples)
    else:
        rss = random_state.duplicate(2)
        samples = (
            self.k[0].draw_samples((nb_images,), random_state=rss[0]),
            self.k[1].draw_samples((nb_images,), random_state=rss[1]),
        )

    gen = enumerate(zip(images, samples[0], samples[1]))
    for i, (image, ksize_h, ksize_w) in gen:
        kernel_impossible = (ksize_h == 0 or ksize_w == 0)
        kernel_does_nothing = (ksize_h == 1 and ksize_w == 1)
        has_zero_sized_axes = (image.size == 0)
        if kernel_impossible or kernel_does_nothing or has_zero_sized_axes:
            continue

        input_dtype = image.dtype
        if image.dtype.name in ["bool", "float16"]:
            image = image.astype(np.float32, copy=False)
        elif image.dtype.name == "int8":
            image = image.astype(np.int16, copy=False)

        # cv2.blur() interprets its ksize tuple as (width, height), so the
        # sampled width has to come first. Passing (ksize_h, ksize_w)
        # swapped the two axes.
        ksize_wh = (ksize_w, ksize_h)

        if image.ndim == 2 or image.shape[-1] <= 512:
            image_aug = cv2.blur(_normalize_cv2_input_arr_(image), ksize_wh)
            # cv2.blur() removes channel axis for single-channel images
            if image_aug.ndim == 2:
                image_aug = image_aug[..., np.newaxis]
        else:
            # TODO this is quite inefficient
            # handling more than 512 channels in cv2.blur()
            channels = [
                cv2.blur(_normalize_cv2_input_arr_(image[..., c]), ksize_wh)
                for c in sm.xrange(image.shape[-1])
            ]
            image_aug = np.stack(channels, axis=-1)

        if input_dtype.name == "bool":
            image_aug = image_aug > 0.5
        elif input_dtype.name in ["int8", "float16"]:
            image_aug = iadt.restore_dtypes_(image_aug, input_dtype)

        batch.images[i] = image_aug

    return batch
def blur_gaussian_(image, sigma, ksize=None, backend="auto", eps=1e-3):
    """Blur an image using gaussian blurring in-place.

    This operation *may* change the input image in-place.

    dtype support::

        if (backend="auto")::

            * ``uint8``: yes; fully tested (1)
            * ``uint16``: yes; tested (1)
            * ``uint32``: yes; tested (2)
            * ``uint64``: yes; tested (2)
            * ``int8``: yes; tested (1)
            * ``int16``: yes; tested (1)
            * ``int32``: yes; tested (1)
            * ``int64``: yes; tested (2)
            * ``float16``: yes; tested (1)
            * ``float32``: yes; tested (1)
            * ``float64``: yes; tested (1)
            * ``float128``: no
            * ``bool``: yes; tested (1)

            - (1) Handled by ``cv2``. See ``backend="cv2"``.
            - (2) Handled by ``scipy``. See ``backend="scipy"``.

        if (backend="cv2")::

            * ``uint8``: yes; fully tested
            * ``uint16``: yes; tested
            * ``uint32``: no (2)
            * ``uint64``: no (3)
            * ``int8``: yes; tested (4)
            * ``int16``: yes; tested
            * ``int32``: yes; tested (5)
            * ``int64``: no (6)
            * ``float16``: yes; tested (7)
            * ``float32``: yes; tested
            * ``float64``: yes; tested
            * ``float128``: no (8)
            * ``bool``: yes; tested (1)

            - (1) Mapped internally to ``float32``. Otherwise causes
                  ``TypeError: src data type = 0 is not supported``.
            - (2) Causes ``TypeError: src data type = 6 is not supported``.
            - (3) Causes ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957:
                  error: (-213:The function/feature is not implemented)
                  Unsupported combination of source format (=4), and buffer
                  format (=5) in function 'getLinearRowFilter'``.
            - (4) Mapped internally to ``int16``. Otherwise causes
                  ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957: error:
                  (-213:The function/feature is not implemented) Unsupported
                  combination of source format (=1), and buffer format (=5)
                  in function 'getLinearRowFilter'``.
            - (5) Mapped internally to ``float64``. Otherwise causes
                  ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957: error:
                  (-213:The function/feature is not implemented) Unsupported
                  combination of source format (=4), and buffer format (=5)
                  in function 'getLinearRowFilter'``.
            - (6) Causes ``cv2.error: OpenCV(3.4.5) (...)/filter.cpp:2957:
                  error: (-213:The function/feature is not implemented)
                  Unsupported combination of source format (=4), and buffer
                  format (=5) in function 'getLinearRowFilter'``.
            - (7) Mapped internally to ``float32``. Otherwise causes
                  ``TypeError: src data type = 23 is not supported``.
            - (8) Causes ``TypeError: src data type = 13 is not supported``.

        if (backend="scipy")::

            * ``uint8``: yes; fully tested
            * ``uint16``: yes; tested
            * ``uint32``: yes; tested
            * ``uint64``: yes; tested
            * ``int8``: yes; tested
            * ``int16``: yes; tested
            * ``int32``: yes; tested
            * ``int64``: yes; tested
            * ``float16``: yes; tested (1)
            * ``float32``: yes; tested
            * ``float64``: yes; tested
            * ``float128``: no (2)
            * ``bool``: yes; tested (3)

            - (1) Mapped internally to ``float32``. Otherwise causes
                  ``RuntimeError: array type dtype('float16') not supported``.
            - (2) Causes ``RuntimeError: array type dtype('float128') not
                  supported``.
            - (3) Mapped internally to ``float32``. Otherwise too inaccurate.

    Parameters
    ----------
    image : numpy.ndarray
        The image to blur. Expected to be of shape ``(H, W)`` or
        ``(H, W, C)``. With the ``cv2`` backend, images with more than
        ``512`` channels are blurred channel by channel.

    sigma : number
        Standard deviation of the gaussian blur. Larger numbers result in
        more large-scale blurring, which is overall slower than small-scale
        blurring.

    ksize : None or int, optional
        Size in height/width of the gaussian kernel. This argument is only
        understood by the ``cv2`` backend. If it is set to ``None``, an
        appropriate value for `ksize` will automatically be derived from
        `sigma`. The value is chosen tighter for larger sigmas to avoid as
        much as possible very large kernel sizes and thereby improve
        performance.

    backend : {'auto', 'cv2', 'scipy'}, optional
        Backend library to use. If ``auto``, then the likely best library
        will be automatically picked per image. That is usually equivalent
        to ``cv2`` (OpenCV) and it will fall back to ``scipy`` for datatypes
        not supported by OpenCV.

    eps : number, optional
        A threshold used to decide whether `sigma` can be considered zero.

    Returns
    -------
    numpy.ndarray
        The blurred image. Same shape and dtype as the input.
        (Input image *might* have been altered in-place.)

    """
    has_zero_sized_axes = (image.size == 0)
    if sigma > 0 + eps and not has_zero_sized_axes:
        dtype = image.dtype

        iadt.gate_dtypes(image,
                         allowed=["bool",
                                  "uint8", "uint16", "uint32",
                                  "int8", "int16", "int32", "int64",
                                  "uint64",
                                  "float16", "float32", "float64"],
                         disallowed=["uint128", "uint256",
                                     "int128", "int256",
                                     "float96", "float128", "float256"],
                         augmenter=None)

        dts_not_supported_by_cv2 = ["uint32", "uint64", "int64", "float128"]
        backend_to_use = backend
        if backend == "auto":
            backend_to_use = (
                "cv2"
                if image.dtype.name not in dts_not_supported_by_cv2
                else "scipy")
        elif backend == "cv2":
            assert image.dtype.name not in dts_not_supported_by_cv2, (
                "Requested 'cv2' backend, but provided %s input image, which "
                "cannot be handled by that backend. Choose a different "
                "backend or set backend to 'auto' or use a different "
                "datatype." % (image.dtype.name, ))
        elif backend == "scipy":
            # can handle all dtypes that were allowed in gate_dtypes()
            pass

        if backend_to_use == "scipy":
            if dtype.name == "bool":
                # We convert bool to float32 here, because gaussian_filter()
                # seems to only return True when the underlying value is
                # approximately 1.0, not when it is above 0.5. So we do that
                # here manually. cv2 does not support bool for gaussian blur.
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "float16":
                image = image.astype(np.float32, copy=False)

            # gaussian_filter() has no ksize argument
            # TODO it does have a truncate argument that truncates at x
            #      standard deviations -- maybe can be used similarly to ksize
            if ksize is not None:
                ia.warn(
                    "Requested 'scipy' backend or picked it automatically by "
                    "backend='auto' in blur_gaussian_(), but also provided "
                    "'ksize' argument, which is not understood by that "
                    "backend and will be ignored.")

            # Note that while gaussian_filter can be applied to all channels
            # at the same time, that should not be done here, because then
            # the blurring would also happen across channels (e.g. red values
            # might be mixed with blue values in RGB)
            if image.ndim == 2:
                image[:, :] = ndimage.gaussian_filter(image[:, :], sigma,
                                                      mode="mirror")
            else:
                nb_channels = image.shape[2]
                for channel in sm.xrange(nb_channels):
                    image[:, :, channel] = ndimage.gaussian_filter(
                        image[:, :, channel], sigma, mode="mirror")
        else:
            if dtype.name == "bool":
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "float16":
                image = image.astype(np.float32, copy=False)
            elif dtype.name == "int8":
                image = image.astype(np.int16, copy=False)
            elif dtype.name == "int32":
                image = image.astype(np.float64, copy=False)

            # ksize here is derived from the equation to compute sigma based
            # on ksize, see
            # https://docs.opencv.org/3.1.0/d4/d86/group__imgproc__filter.html
            # -> cv::getGaussianKernel()
            # example values:
            #   sig = 0.1 -> ksize = -1.666
            #   sig = 0.5 -> ksize = 0.9999
            #   sig = 1.0 -> ksize = 1.0
            #   sig = 2.0 -> ksize = 11.0
            #   sig = 3.0 -> ksize = 17.666
            # ksize = ((sig - 0.8)/0.3 + 1)/0.5 + 1
            if ksize is None:
                ksize = _compute_gaussian_blur_ksize(sigma)
            else:
                assert ia.is_single_integer(ksize), (
                    "Expected 'ksize' argument to be a number, "
                    "got %s." % (type(ksize), ))

            ksize = ksize + 1 if ksize % 2 == 0 else ksize

            if ksize > 0:
                if image.ndim == 3 and image.shape[-1] > 512:
                    # OpenCV matrices support at most 512 channels
                    # (CV_CN_MAX), so blur each channel separately. Gaussian
                    # blurring is per channel anyway, so the result is the
                    # same as blurring all channels at once.
                    image_warped = np.stack(
                        [cv2.GaussianBlur(
                            _normalize_cv2_input_arr_(image[..., c]),
                            (ksize, ksize),
                            sigmaX=sigma,
                            sigmaY=sigma,
                            borderType=cv2.BORDER_REFLECT_101)
                         for c in sm.xrange(image.shape[-1])],
                        axis=-1)
                else:
                    image_warped = cv2.GaussianBlur(
                        _normalize_cv2_input_arr_(image),
                        (ksize, ksize),
                        sigmaX=sigma,
                        sigmaY=sigma,
                        borderType=cv2.BORDER_REFLECT_101)

                # re-add channel axis removed by cv2 if input was (H, W, 1)
                image = (
                    image_warped[..., np.newaxis]
                    if image.ndim == 3 and image_warped.ndim == 2
                    else image_warped)

        if dtype.name == "bool":
            image = image > 0.5
        elif dtype.name != image.dtype.name:
            image = iadt.restore_dtypes_(image, dtype)

    return image
def blur_mean_shift_(image, spatial_window_radius, color_window_radius):
    """Apply a pyramidic mean shift filter to the input image in-place.

    The result resembles the output of a bilateral filter. This is not mean
    shift *segmentation*, which would average the colors within the
    segments found by mean shift clustering.

    This is a thin wrapper around ``cv2.pyrMeanShiftFiltering``.

    .. note::

        The image's colorspace is *not* converted to ``RGB`` before the
        filter is applied, so non-``RGB`` inputs will produce different
        results.

    .. note::

        This function is quite slow.

    dtype support::

        * ``uint8``: yes; fully tested
        * ``uint16``: no (1)
        * ``uint32``: no (1)
        * ``uint64``: no (1)
        * ``int8``: no (1)
        * ``int16``: no (1)
        * ``int32``: no (1)
        * ``int64``: no (1)
        * ``float16``: no (1)
        * ``float32``: no (1)
        * ``float64``: no (1)
        * ``float128``: no (1)
        * ``bool``: no (1)

        - (1) Not supported by ``cv2.pyrMeanShiftFiltering``.

    Parameters
    ----------
    image : ndarray
        ``(H,W)``, ``(H,W,1)`` or ``(H,W,3)`` image to blur. Images with no
        or one channel are temporarily tiled to three channels. Images with
        zero height or width are returned immediately (before any dtype
        check).

    spatial_window_radius : number
        Spatial radius for pixels that are assumed to be similar. Negative
        values are clamped to ``0``.

    color_window_radius : number
        Color radius for pixels that are assumed to be similar. Negative
        values are clamped to ``0``.

    Returns
    -------
    ndarray
        Blurred input image. Same shape and dtype as the input.
        (Input image *might* have been altered in-place.)

    """
    if 0 in image.shape[0:2]:
        return image

    # The OpenCV function only accepts uint8 input.
    assert image.dtype.name == "uint8", (
        'Expected image with dtype "uint8", got "%s".' % (image.dtype.name,))

    is_2d = image.ndim == 2
    nb_channels = image.shape[-1] if image.ndim == 3 else None
    is_single_channel = nb_channels == 1
    assert is_2d or nb_channels in (1, 3), (
        "Expected (H,W) or (H,W,1) or (H,W,3) image, got shape %s."
        % (image.shape,))

    # The OpenCV function requires (H,W,3), so 2D and single-channel images
    # are tiled to three channels first.
    if is_2d:
        image = np.tile(image[..., np.newaxis], (1, 1, 3))
    elif is_single_channel:
        image = np.tile(image, (1, 1, 3))

    image = _normalize_cv2_input_arr_(image)
    image = cv2.pyrMeanShiftFiltering(image,
                                      sp=max(spatial_window_radius, 0),
                                      sr=max(color_window_radius, 0),
                                      dst=image)

    if is_2d:
        return image[..., 0]
    if is_single_channel:
        return image[..., 0:1]
    return image
def stylize_cartoon(image, blur_ksize=3, segmentation_size=1.0,
                    saturation=2.0, edge_prevalence=1.0,
                    suppress_edges=True,
                    from_colorspace=colorlib.CSPACE_RGB):
    """Convert the style of an image to a more cartoonish one.

    This function was primarily designed for images with a size of ``200``
    to ``800`` pixels. Smaller or larger images may cause issues.

    Note that the quality of the results can currently not compete with
    learned style transfer, let alone human-made images. A lack of detected
    edges or also too many detected edges are probably the most significant
    drawbacks.

    This method is loosely based on the one proposed in
    https://stackoverflow.com/a/11614479/3760780

    Added in 0.4.0.

    **Supported dtypes**:

        * ``uint8``: yes; fully tested
        * ``uint16``: no
        * ``uint32``: no
        * ``uint64``: no
        * ``int8``: no
        * ``int16``: no
        * ``int32``: no
        * ``int64``: no
        * ``float16``: no
        * ``float32``: no
        * ``float64``: no
        * ``float128``: no
        * ``bool``: no

    Parameters
    ----------
    image : ndarray
        A ``(H,W,3) uint8`` image array. Images without any pixels are
        returned as an unchanged copy.

    blur_ksize : int, optional
        Kernel size of the median blur filter applied initially to the input
        image. Expected to be an odd value and ``>=0``. If an even value,
        then automatically increased to an odd one. If ``<=1``, no blur will
        be applied.

    segmentation_size : float, optional
        Size multiplier to decrease/increase the base size of the initial
        mean-shift segmentation of the image. Expected to be ``>=0``.
        Note that the base size is increased by roughly a factor of two for
        images with height and/or width ``>=400``.

    saturation : float, optional
        Multiplier for the saturation. Set to ``1.0`` to not change the
        image's saturation.

    edge_prevalence : float, optional
        Multiplier for the prevalence of edges. Higher values lead to more
        edges. Note that the default value of ``1.0`` is already fairly
        conservative, so there is limited effect from lowering it further.

    suppress_edges : bool, optional
        Whether to run edge suppression to remove blobs containing too many
        or too few edge pixels.

    from_colorspace : str, optional
        The source colorspace. Use one of ``imgaug.augmenters.color.CSPACE_*``.
        Defaults to ``RGB``.

    Returns
    -------
    ndarray
        Image in cartoonish style.

    """
    iadt.gate_dtypes(
        image,
        allowed=["uint8"],
        disallowed=["bool",
                    "uint16", "uint32", "uint64", "uint128", "uint256",
                    "int8", "int16", "int32", "int64", "int128", "int256",
                    "float16", "float32", "float64", "float96", "float128",
                    "float256"],
        augmenter=None)

    assert image.ndim == 3 and image.shape[2] == 3, (
        "Expected to get a (H,W,C) image, got shape %s." % (image.shape, ))

    # Nothing to stylize; the cv2 calls below are not meant for images
    # without pixels.
    if image.size == 0:
        return np.copy(image)

    blur_ksize = max(int(np.round(blur_ksize)), 1)
    segmentation_size = max(segmentation_size, 0.0)
    saturation = max(saturation, 0.0)

    is_small_image = max(image.shape[0:2]) < 400

    image = _blur_median(image, blur_ksize)

    if is_small_image:
        spatial_window_radius = int(10 * segmentation_size)
        color_window_radius = int(20 * segmentation_size)
    else:
        spatial_window_radius = int(15 * segmentation_size)
        color_window_radius = int(40 * segmentation_size)

    if segmentation_size <= 0:
        image_seg = image
    else:
        # Use the returned array instead of relying on cv2 writing into a
        # preallocated `dst`. cv2 can only reuse `dst` if its layout is
        # compatible, which np.zeros_like() does not guarantee.
        image_seg = cv2.pyrMeanShiftFiltering(
            _normalize_cv2_input_arr_(image),
            sp=spatial_window_radius,
            sr=color_window_radius)

    if is_small_image:
        edges_raw = _find_edges_canny(image_seg,
                                      edge_prevalence,
                                      from_colorspace)
    else:
        edges_raw = _find_edges_laplacian(image_seg,
                                          edge_prevalence,
                                          from_colorspace)

    edges = ((edges_raw > 100) * 255).astype(np.uint8)

    if suppress_edges:
        # Suppress dense 3x3 blobs full of detected edges. They are visually
        # ugly.
        edges = _suppress_edge_blobs(edges, 3, 8, inverse=False)

        # Suppress spurious few-pixel edges (5x5 size with <=3 edge pixels).
        edges = _suppress_edge_blobs(edges, 5, 3, inverse=True)

    return _saturate(_blend_edges(image_seg, edges),
                     saturation,
                     from_colorspace)
def _blur_median(image, ksize): if ksize % 2 == 0: ksize += 1 if ksize <= 1: return image return cv2.medianBlur(_normalize_cv2_input_arr_(image), ksize)
def _augment_batch_(self, batch, random_state, parents, hooks):
    """Convolve each channel of every image in ``batch`` in-place.

    The per-channel kernels depend on ``self.matrix_type``:

    * ``"None"``: no kernel for any channel, so no filtering happens.
    * ``"constant"``: ``self.matrix`` is used for every channel.
    * ``"function"``: ``self.matrix(image, nb_channels, random_state)``
      must return one of:

      * a list with one kernel (or ``None``) per channel,
      * a 2D array, used for all channels, or
      * a 3D array whose last axis matches the channel count.

    ``bool``/``float16`` images are filtered as ``float32`` and ``int8``
    images as ``int16``; the input dtype is restored afterwards. Images
    without any pixels are skipped.

    Returns
    -------
    Batch
        The same ``batch`` object, with its ``images`` entries replaced.

    Raises
    ------
    AssertionError
        If the callable returns kernels in an unsupported format.
    ValueError
        If ``self.matrix_type`` is none of the values above. This is a
        subclass of ``Exception``, which was raised previously.
    """
    if batch.images is None:
        return batch

    images = batch.images

    iadt.gate_dtypes(images,
                     allowed=[
                         "bool",
                         "uint8", "uint16", "int8", "int16",
                         "float16", "float32", "float64"],
                     disallowed=[
                         "uint32", "uint64", "uint128", "uint256",
                         "int32", "int64", "int128", "int256",
                         "float96", "float128", "float256"],
                     augmenter=self)

    rss = random_state.duplicate(len(images))

    for i, image in enumerate(images):
        _height, _width, nb_channels = image.shape

        # currently we don't have to worry here about alignment with
        # non-image data and therefore can just place this before any
        # sampling
        if image.size == 0:
            continue

        input_dtype = image.dtype
        if image.dtype.name in ["bool", "float16"]:
            image = image.astype(np.float32, copy=False)
        elif image.dtype.name == "int8":
            image = image.astype(np.int16, copy=False)

        if self.matrix_type == "None":
            matrices = [None] * nb_channels
        elif self.matrix_type == "constant":
            matrices = [self.matrix] * nb_channels
        elif self.matrix_type == "function":
            matrices = self.matrix(images[i], nb_channels, rss[i])
            if ia.is_np_array(matrices) and matrices.ndim == 2:
                matrices = np.tile(
                    matrices[..., np.newaxis], (1, 1, nb_channels))

            is_valid_list = (isinstance(matrices, list)
                             and len(matrices) == nb_channels)
            is_valid_array = (ia.is_np_array(matrices)
                              and matrices.ndim == 3
                              and matrices.shape[2] == nb_channels)
            assert is_valid_list or is_valid_array, (
                "Callable provided to Convolve must return either a "
                "list of 2D matrices (one per image channel) "
                "or a 2D numpy array "
                "or a 3D numpy array where the last dimension's size "
                "matches the number of image channels. "
                "Got type %s." % (type(matrices), ))

            if ia.is_np_array(matrices):
                # Shape of matrices is currently (H, W, C), but in the
                # loop below we need the first axis to be the channel
                # index to unify handling of lists of arrays and arrays.
                # So we move the channel axis here to the start.
                matrices = matrices.transpose((2, 0, 1))
        else:
            raise ValueError("Invalid matrix type")

        # TODO check if sampled matrices are identical over channels
        #      and then just apply once. (does that really help wrt speed?)
        image_aug = image
        for channel in sm.xrange(nb_channels):
            if matrices[channel] is not None:
                # ndimage.convolve caused problems here
                # cv2.filter2D() always returns same output dtype as input
                # dtype
                image_aug[..., channel] = cv2.filter2D(
                    _normalize_cv2_input_arr_(image_aug[..., channel]),
                    -1,
                    matrices[channel])

        if input_dtype.name == "bool":
            image_aug = image_aug > 0.5
        elif input_dtype.name in ["int8", "float16"]:
            image_aug = iadt.restore_dtypes_(image_aug, input_dtype)

        batch.images[i] = image_aug

    return batch