def equalize(input: torch.Tensor) -> torch.Tensor: r"""Apply equalize on the input tensor. Implements Equalize function from PIL using PyTorch ops based on uint8 format: https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355 Args: input (torch.Tensor): image tensor to equalizr with shapes like :math:`(C, H, W)` or :math:`(B, C, H, W)`. Returns: torch.Tensor: Sharpened image or images with shape as the input. Example: >>> _ = torch.manual_seed(0) >>> x = torch.rand(1, 2, 3, 3) >>> equalize(x) tensor([[[[0.4963, 0.7682, 0.0885], [0.1320, 0.3074, 0.6341], [0.4901, 0.8964, 0.4556]], <BLANKLINE> [[0.6323, 0.3489, 0.4017], [0.0223, 0.1689, 0.2939], [0.5185, 0.6977, 0.8000]]]]) """ input = _to_bchw(input) res = [] for image in input: # Assumes RGB for now. Scales each channel independently # and then stacks the result. scaled_image = torch.stack( [_scale_channel(image[i, :, :]) for i in range(len(image))]) res.append(scaled_image) return torch.stack(res)
def equalize(input: torch.Tensor) -> torch.Tensor: r"""Apply equalize on the input tensor. .. image:: _static/img/equalize.png Implements Equalize function from PIL using PyTorch ops based on uint8 format: https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355 Args: input: image tensor to equalize with shapes like :math:`(C, H, W)` or :math:`(B, C, H, W)`. Returns: Equalized image tensor with shape :math:`(B, C, H, W)`. Example: >>> x = torch.rand(1, 2, 3, 3) >>> equalize(x).shape torch.Size([1, 2, 3, 3]) """ input = _to_bchw(input) res = [] for image in input: # Assumes RGB for now. Scales each channel independently # and then stacks the result. scaled_image = torch.stack( [_scale_channel(image[i, :, :]) for i in range(len(image))]) res.append(scaled_image) return torch.stack(res)
def sharpness(input: torch.Tensor, factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Apply sharpness to the input tensor.

    Implements the Sharpness function from PIL using torch ops. This implementation refers to:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor (float or torch.Tensor): factor of sharpness strength. Must be above 0.
            If float or a one-element tensor, the input will be sharpened by the same factor
            across the whole batch. If a 1-d tensor, the input will be sharpened element-wise,
            with len(factor) == len(input).

    Returns:
        torch.Tensor: Sharpened image or images.

    Example:
        >>> _ = torch.manual_seed(0)
        >>> sharpness(torch.randn(1, 1, 5, 5), 0.5)
        tensor([[[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487],
                  [ 0.6920, -0.1580, -1.0576,  0.1765, -0.1577],
                  [ 1.4437,  0.1998,  0.1799,  0.6588, -0.1435],
                  [-0.1116, -0.3068,  0.8381,  1.3477,  0.0537],
                  [ 0.6181, -0.4128, -0.8411, -2.3160, -0.1023]]]])
    """
    input = _to_bchw(input)

    if not isinstance(factor, torch.Tensor):
        factor = torch.tensor(factor, device=input.device, dtype=input.dtype)

    if len(factor.size()) != 0:
        assert factor.shape == torch.Size([input.size(0)]), (
            "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
            f"Got {input.size(0)} and {factor.shape}")

    kernel = torch.tensor([
        [1, 1, 1],
        [1, 5, 1],
        [1, 1, 1]
    ], dtype=input.dtype, device=input.device).view(1, 1, 3, 3).repeat(input.size(1), 1, 1, 1) / 13

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input, kernel, bias=None, stride=1, groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0., 1.)

    # For the borders of the resulting image, fill in the values of the original image.
    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    if len(factor.size()) == 0:
        return _blend_one(result, input, factor)
    return torch.stack([_blend_one(result[i], input[i], factor[i]) for i in range(len(factor))])
def sharpness(input: torch.Tensor, factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Implements the Sharpness function from PIL using torch ops.

    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor (float or torch.Tensor): factor of sharpness strength. Must be above 0.
            If float or a one-element tensor, the input will be sharpened by the same factor
            across the whole batch. If a 1-d tensor, the input will be sharpened element-wise,
            with len(factor) == len(input).

    Returns:
        torch.Tensor: Sharpened image or images.
    """
    input = _to_bchw(input)
    if isinstance(factor, torch.Tensor):
        factor = factor.squeeze()
        if len(factor.size()) != 0:
            assert input.size(0) == factor.size(0), \
                f"Input batch size shall match with factor size if 1d array. Got {input.size(0)} and {factor.size(0)}"
    else:
        factor = float(factor)

    # 3x3 smoothing kernel (entries sum to 13, hence the normalization), repeated for
    # each input channel so the grouped convolution below acts as a depthwise filter.
    kernel = torch.tensor([
        [1, 1, 1],
        [1, 5, 1],
        [1, 1, 1]
    ], dtype=input.dtype, device=input.device).view(1, 1, 3, 3).repeat(input.size(1), 1, 1, 1) / 13

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input, kernel, bias=None, stride=1, groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0., 1.)

    # For the borders of the resulting image, fill in the values of the original image.
    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    def _blend_one(input1: torch.Tensor, input2: torch.Tensor, factor: Union[float, torch.Tensor]) -> torch.Tensor:
        if isinstance(factor, torch.Tensor):
            factor = factor.squeeze()
            assert len(factor.size()) == 0, f"Factor shall be a float or single element tensor. Got {factor}"
        if factor == 0.:
            return input1
        if factor == 1.:
            return input2
        diff = (input2 - input1) * factor
        res = input1 + diff
        if factor > 0. and factor < 1.:
            return res
        return torch.clamp(res, 0, 1)

    if isinstance(factor, float) or len(factor.size()) == 0:
        return _blend_one(input, result, factor)
    return torch.stack([_blend_one(input[i], result[i], factor[i]) for i in range(len(factor))])
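# A minimal usage sketch for the sharpness variant above (illustrative only, not
# part of the original source): assumes float images in [0, 1] and the function
# defined in this module, with either a shared factor or one factor per sample.
imgs = torch.rand(2, 3, 5, 5)                        # batch of two RGB images
out_shared = sharpness(imgs, 0.7)                    # same factor for the whole batch
out_per = sharpness(imgs, torch.tensor([0.3, 0.9]))  # element-wise factors, len == batch size
assert out_shared.shape == imgs.shape
assert out_per.shape == imgs.shape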
def sharpness(input: torch.Tensor, factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Apply sharpness to the input tensor.

    .. image:: _static/img/sharpness.png

    Implements the Sharpness function from PIL using torch ops. This implementation refers to:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

    Args:
        input: image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor: factor of sharpness strength. Must be above 0.
            If float or a one-element tensor, the input will be sharpened by the same factor
            across the whole batch. If a 1-d tensor, the input will be sharpened element-wise,
            with len(factor) == len(input).

    Returns:
        Sharpened image or images with shape :math:`(B, C, H, W)`.

    Example:
        >>> x = torch.rand(1, 1, 5, 5)
        >>> sharpness(x, 0.5).shape
        torch.Size([1, 1, 5, 5])
    """
    input = _to_bchw(input)

    if not isinstance(factor, torch.Tensor):
        factor = torch.tensor(factor, device=input.device, dtype=input.dtype)

    if len(factor.size()) != 0:
        assert factor.shape == torch.Size([input.size(0)]), (
            "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
            f"Got {input.size(0)} and {factor.shape}")

    kernel = (torch.tensor([
        [1, 1, 1],
        [1, 5, 1],
        [1, 1, 1]
    ], dtype=input.dtype, device=input.device).view(1, 1, 3, 3).repeat(input.size(1), 1, 1, 1) / 13)

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input, kernel, bias=None, stride=1, groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0.0, 1.0)

    # For the borders of the resulting image, fill in the values of the original image.
    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    if len(factor.size()) == 0:
        return _blend_one(result, input, factor)
    return torch.stack([_blend_one(result[i], input[i], factor[i]) for i in range(len(factor))])
def _transform_input(input: torch.Tensor) -> torch.Tensor:
    r"""Reshape an input tensor to be (B, C, H, W). Accepts either (H, W), (C, H, W) or (B, C, H, W).

    Args:
        input (torch.Tensor): tensor to reshape.

    Returns:
        torch.Tensor: tensor with shape (B, C, H, W).
    """
    return _to_bchw(input)
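# Hedged sketch of the shape promotion performed by _transform_input / _to_bchw,
# assuming the behaviour documented above (missing batch and/or channel
# dimensions of size 1 get prepended):
assert _transform_input(torch.rand(4, 5)).shape == (1, 1, 4, 5)        # (H, W)
assert _transform_input(torch.rand(3, 4, 5)).shape == (1, 3, 4, 5)     # (C, H, W)
assert _transform_input(torch.rand(2, 3, 4, 5)).shape == (2, 3, 4, 5)  # already (B, C, H, W)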
def equalize(input: torch.Tensor) -> torch.Tensor: """Implements Equalize function from PIL using PyTorch ops based on uint8 format: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352 Args: input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to equalize. Returns: torch.Tensor: Sharpened image or images. """ input = _to_bchw(input) * 255 # Code taken from: https://github.com/pytorch/vision/pull/796 def scale_channel(im, c): """Scale the data in the channel to implement equalize.""" im = im[c, :, :] # Compute the histogram of the image channel. histo = torch.histc(im, bins=256, min=0, max=255) # For the purposes of computing the step, filter out the nonzeros. nonzero_histo = torch.reshape(histo[histo != 0], [-1]) step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255 def build_lut(histo, step): # Compute the cumulative sum, shifting by step // 2 # and then normalization by step. lut = (torch.cumsum(histo, 0) + (step // 2)) // step # Shift lut, prepending with 0. lut = torch.cat([torch.zeros(1), lut[:-1]]) # Clip the counts to be in range. This is done # in the C code for image.point. return torch.clamp(lut, 0, 255) # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. if step == 0: result = im else: # can't index using 2d index. Have to flatten and then reshape result = torch.gather(build_lut(histo, step), 0, im.flatten().long()) result = result.reshape_as(im) return result / 255. res = [] for image in input: # Assumes RGB for now. Scales each channel independently # and then stacks the result. scaled_image = torch.stack( [scale_channel(image, i) for i in range(len(image))]) res.append(scaled_image) return torch.stack(res)
def _compute_tiles(
        imgs: torch.Tensor,
        grid_size: Tuple[int, int],
        even_tile_size: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute tiles on an image according to a grid size.

    Note that padding can be added to the image so that the tiles can be cropped properly.
    Thus, grid_size (GH, GW) x tile_size (TH, TW) >= image_size (H, W)

    Args:
        imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W) or (C, H, W).
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW).
        even_tile_size (bool, optional): Determine if the width and height of the tiles must be even.
            Default: False.

    Returns:
        torch.Tensor: tensor with tiles (B, GH, GW, C, TH, TW). B = 1 in case a single image is provided.
        torch.Tensor: tensor with the padded batch of 2D images with shape (B, C, H', W').
    """
    batch: torch.Tensor = _to_bchw(imgs)  # B x C x H x W

    # compute stride and kernel size
    h, w = batch.shape[-2:]
    kernel_vert: int = math.ceil(h / grid_size[0])
    kernel_horz: int = math.ceil(w / grid_size[1])

    if even_tile_size:
        kernel_vert += 1 if kernel_vert % 2 else 0
        kernel_horz += 1 if kernel_horz % 2 else 0

    # add padding (with that kernel size we could need some extra cols and rows...)
    pad_vert = kernel_vert * grid_size[0] - h
    pad_horz = kernel_horz * grid_size[1] - w

    # add the padding in the last columns and rows
    if pad_vert > 0 or pad_horz > 0:
        batch = F.pad(batch, [0, pad_horz, 0, pad_vert], mode='reflect')  # B x C x H' x W'

    # compute tiles
    c: int = batch.shape[-3]
    tiles: torch.Tensor = (
        batch.unfold(1, c, c)  # unfold(dimension, size, step)
        .unfold(2, kernel_vert, kernel_vert)
        .unfold(3, kernel_horz, kernel_horz)
        .squeeze(1)).contiguous()  # B x GH x GW x C x TH x TW

    assert tiles.shape[-5] == grid_size[0]  # check the grid size
    assert tiles.shape[-4] == grid_size[1]
    return tiles, batch
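# Quick sketch of the tiling helper above (illustrative, assuming the private
# name _compute_tiles stays as defined): a (2, 4) grid over an 8x8 image gives
# 4x2 tiles with no padding, since the grid divides the image evenly.
tiles, padded = _compute_tiles(torch.rand(1, 1, 8, 8), (2, 4))
assert tiles.shape == (1, 2, 4, 1, 4, 2)  # B x GH x GW x C x TH x TW
assert padded.shape == (1, 1, 8, 8)       # no reflect padding was needed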
def equalize(input: torch.Tensor) -> torch.Tensor: r"""Apply equalize on the input tensor. Implements Equalize function from PIL using PyTorch ops based on uint8 format: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352 Args: input (torch.Tensor): image tensor with shapes like :math:(C, H, W) or :math:(B, C, H, W) to equalize. Returns: torch.Tensor: Sharpened image or images. """ input = _to_bchw(input) res = [] for image in input: # Assumes RGB for now. Scales each channel independently # and then stacks the result. scaled_image = torch.stack([_scale_channel(image[i, :, :]) for i in range(len(image))]) res.append(scaled_image) return torch.stack(res)
def equalize_clahe(
        input: torch.Tensor,
        clip_limit: float = 40.0,
        grid_size: Tuple[int, int] = (8, 8)) -> torch.Tensor:
    r"""Apply clahe equalization on the input tensor.

    NOTE: The LUT computation uses the same approach as in OpenCV; in next versions this can change.

    Args:
        input (torch.Tensor): images tensor to equalize with values in the range [0, 1] and shapes like
            :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        clip_limit (float): threshold value for contrast limiting. If 0, clipping is disabled. Default: 40.
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW). Default: (8, 8).

    Returns:
        torch.Tensor: Equalized image or images with the same shape as the input.

    Examples:
        >>> img = torch.rand(1, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([1, 10, 20])
        >>> img = torch.rand(2, 3, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([2, 3, 10, 20])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() not in [3, 4]:
        raise ValueError(f"Invalid input shape, we expect CxHxW or BxCxHxW. Got: {input.shape}")

    if input.numel() == 0:
        raise ValueError("Invalid input tensor, it is empty.")

    if not isinstance(clip_limit, float):
        raise TypeError(f"Input clip_limit type is not float. Got {type(clip_limit)}")

    if not isinstance(grid_size, tuple):
        raise TypeError(f"Input grid_size type is not Tuple. Got {type(grid_size)}")

    if len(grid_size) != 2:
        raise TypeError(f"Input grid_size is not a Tuple with 2 elements. Got {len(grid_size)}")

    if isinstance(grid_size[0], float) or isinstance(grid_size[1], float):
        raise TypeError("Input grid_size type is not valid, must be a Tuple[int, int].")

    if grid_size[0] <= 0 or grid_size[1] <= 0:
        raise ValueError(f"Input grid_size elements must be positive. Got {grid_size}")

    imgs: torch.Tensor = _to_bchw(input)  # B x C x H x W

    # hist_tiles: torch.Tensor  # B x GH x GW x C x TH x TW  # not supported by JIT
    # img_padded: torch.Tensor  # B x C x H' x W'  # not supported by JIT
    # the size of the tiles must be even in order to divide them into 4 tiles for the interpolation
    hist_tiles, img_padded = _compute_tiles(imgs, grid_size, True)
    tile_size: Tuple[int, int] = (hist_tiles.shape[-2], hist_tiles.shape[-1])
    interp_tiles: torch.Tensor = _compute_interpolation_tiles(img_padded, tile_size)  # B x 2GH x 2GW x C x TH/2 x TW/2
    luts: torch.Tensor = _compute_luts(hist_tiles, clip=clip_limit)  # B x GH x GW x C x 256
    equalized_tiles: torch.Tensor = _compute_equalized_tiles(interp_tiles, luts)  # B x 2GH x 2GW x C x TH/2 x TW/2

    # reconstruct the images from the tiles
    # try permute + contiguous + view
    eq_imgs: torch.Tensor = equalized_tiles.permute(0, 3, 1, 4, 2, 5).reshape_as(img_padded)
    h, w = imgs.shape[-2:]
    eq_imgs = eq_imgs[..., :h, :w]  # crop imgs if they were padded

    # remove batch if the input was not in batch form
    if input.dim() != eq_imgs.dim():
        eq_imgs = eq_imgs.squeeze(0)

    return eq_imgs
def posterize(input: torch.Tensor, bits: Union[int, torch.Tensor]) -> torch.Tensor:
    r"""Reduce the number of bits for each color channel.

    Non-differentiable function, uint8 involved.

    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to posterize.
        bits (int or torch.Tensor): number of high bits. Must be in range [0, 8].
            If int or a one-element tensor, the input will be posterized by these bits.
            If a 1-d tensor, the input will be posterized element-wise, with len(bits) == input.shape[0].
            If an n-d tensor, the input will be posterized element-channel-wise, with
            bits.shape == input.shape[:len(bits.shape)].

    Returns:
        torch.Tensor: Image with reduced color channels.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if isinstance(bits, int):
        bits = torch.tensor(bits)

    if not torch.all((bits >= 0) * (bits <= 8)):
        raise ValueError(f"bits must be integers within range [0, 8]. Got {bits}.")

    # TODO: Make a differentiable version
    # Current version:
    # Ref: https://github.com/open-mmlab/mmcv/pull/132/files#diff-309c9320c7f71bedffe89a70ccff7f3bR19
    # Ref: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L222
    # Potential approach: implementing kornia.LUT with floating points
    # https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/functional.py#L472

    def _left_shift(input: torch.Tensor, shift: torch.Tensor):
        return ((input * 255).to(torch.uint8) * (2 ** shift)).to(input.dtype) / 255.

    def _right_shift(input: torch.Tensor, shift: torch.Tensor):
        return ((input * 255).to(torch.uint8) // (2 ** shift)).to(input.dtype) / 255.

    def _posterize_one(input: torch.Tensor, bits: torch.Tensor):
        # Single bits value condition
        if bits == 0:
            return torch.zeros_like(input)
        if bits == 8:
            return input.clone()
        bits = 8 - bits
        return _left_shift(_right_shift(input, bits), bits)

    if len(bits.shape) == 0 or (len(bits.shape) == 1 and len(bits) == 1):
        return _posterize_one(input, bits)

    res = []
    if len(bits.shape) == 1:
        input = _to_bchw(input)
        assert bits.shape[0] == input.shape[0], \
            f"Batch size must be equal between bits and input. Got {bits.shape[0]}, {input.shape[0]}."
        for i in range(input.shape[0]):
            res.append(_posterize_one(input[i], bits[i]))
        return torch.stack(res, dim=0)

    assert bits.shape == input.shape[:len(bits.shape)], \
        f"Batch and channel must be equal between bits and input. Got {bits.shape}, {input.shape[:len(bits.shape)]}."
    _input = input.view(-1, *input.shape[len(bits.shape):])
    _bits = bits.flatten()
    for i in range(len(_input)):
        res.append(_posterize_one(_input[i], _bits[i]))
    return torch.stack(res, dim=0).reshape(*input.shape)
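# A small sketch of the posterize behaviour above (illustrative assumption:
# float input in [0, 1]); keeping only the top 2 bits leaves at most 2**2
# distinct intensity levels in the whole tensor:
img = torch.rand(1, 3, 4, 4)
out = posterize(img, 2)
assert out.shape == img.shape
assert len(out.unique()) <= 2 ** 2
per_sample = posterize(img, torch.tensor([3]))  # one-element tensor: shared bits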
def resize(input: torch.Tensor,
           size: Union[int, Tuple[int, int]],
           interpolation: str = 'bilinear',
           align_corners: Optional[bool] = None,
           side: str = "short",
           antialias: bool = False) -> torch.Tensor:
    r"""Resize the input torch.Tensor to the given size.

    Args:
        input (torch.Tensor): The image tensor to be resized with shape of :math:`(B, C, H, W)`.
        size (int, tuple(int, int)): Desired output size. If size is a sequence like (h, w),
            output size will be matched to this. If size is an int, the smaller edge of the image
            will be matched to this number. I.e, if height > width, then the image will be rescaled
            to (size * height / width, size).
        interpolation (str): algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' |
            'bicubic' | 'trilinear' | 'area'. Default: 'bilinear'.
        align_corners (bool): interpolation flag. Default: None. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail.
        side (str): Corresponding side if ``size`` is an integer. Can be one of ``"short"``,
            ``"long"``, ``"vert"``, or ``"horz"``. Defaults to ``"short"``.
        antialias (bool): if True, then the image will be filtered with Gaussian before downscaling.
            No effect for upscaling. Default: False.

    Returns:
        torch.Tensor: The resized tensor with the shape of the specified size.

    Example:
        >>> img = torch.rand(1, 3, 4, 4)
        >>> out = resize(img, (6, 8))
        >>> print(out.shape)
        torch.Size([1, 3, 6, 8])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input tensor type is not a torch.Tensor. Got {}".format(type(input)))

    input_size = h, w = input.shape[-2:]
    if isinstance(size, int):
        aspect_ratio = w / h
        size = _side_to_image_size(size, aspect_ratio, side)

    if size == input_size:
        return input

    # TODO: find a proper way to handle these cases in the future
    input_tmp = _to_bchw(input)

    factors = (h / size[0], w / size[1])

    # We do blurring only for downscaling
    antialias = antialias and (max(factors) > 1)

    if antialias:
        # First, we have to determine sigma
        sigmas = (max(factors[0], 1.0), max(factors[1], 1.0))

        # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
        # But they do it in 2 passes, which gives better results. Let's try 2 sigmas for now
        ks = int(2.0 * 2 * sigmas[0] + 1), int(2.0 * 2 * sigmas[1] + 1)
        input_tmp = kornia.filters.gaussian_blur2d(input_tmp, ks, sigmas)

    output = torch.nn.functional.interpolate(input_tmp, size=size, mode=interpolation, align_corners=align_corners)

    if len(input.shape) != len(output.shape):
        output = output.squeeze()

    return output