def test_smoke(self, device, dtype):
    """Smoke-check ``_torch_histc_cast`` with four bins over the range [0, 3].

    The input holds the value 1 twice and the value 2 once; the expected
    histogram therefore has a count of 2 in the second bin, 1 in the third
    bin and 0 in the outer bins.
    """
    tensor_kwargs = dict(device=device, dtype=dtype)
    samples = torch.tensor([1., 2., 1.], **tensor_kwargs)
    expected_counts = torch.tensor([0., 2., 1., 0.], **tensor_kwargs)

    counts = _torch_histc_cast(samples, bins=4, min=0, max=3)

    assert_allclose(counts, expected_counts)
def _compute_luts(tiles_x_im: torch.Tensor,
                  num_bins: int = 256,
                  clip: float = 40.0,
                  diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Each tile is flattened and turned into a ``num_bins`` histogram over the range [0, 1].
    When clipping is enabled every histogram is clamped to the clip limit and the removed
    mass is redistributed: an equal share to every bin, plus one extra count for the first
    ``residual`` bins. The lut is the cumulative histogram scaled by ``(num_bins - 1) / pixels``,
    clamped to ``[0, num_bins - 1]`` and floored.

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. default: 256
        clip (float): threshold value for contrast limiting. If it is 0 (or negative) then the
          clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False

    Returns:
        torch.Tensor: Lut for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    # reshape instead of view: view raises when the tiles are not laid out
    # contiguously in memory, reshape falls back to a copy in that case.
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)  # T x (THxTW)

    if not diff:
        # One hard histogram per tile. The comprehension is the same expression
        # for eager and scripted execution, so no is_scripting() split is needed.
        histos = torch.stack([_torch_histc_cast(tile, bins=num_bins, min=0, max=1) for tile in tiles])
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        # Keep the result 2-D (T x num_bins). squeeze() would also drop the tile
        # dimension when there is a single tile and break the dim-1 reductions
        # below. The final view already requires T * num_bins elements.
        histos = histogram(tiles, bins, torch.tensor(0.001)).reshape(-1, num_bins)
        histos *= pixels

    if clip > 0.0:
        # The clip limit is never lower than one count per bin.
        max_val: float = max(clip * pixels // num_bins, 1)
        histos.clamp_(max=max_val)

        # Mass removed by the clamp, per tile.
        clipped: torch.Tensor = pixels - histos.sum(1)
        # Share that every bin receives.
        redist: torch.Tensor = clipped // num_bins
        histos += redist.unsqueeze(1)

        # What is left after the even split goes, one count each, to the first bins.
        residual: torch.Tensor = clipped - redist * num_bins
        # Broadcasting (num_bins,) against (T, 1) avoids both a Python loop and
        # materializing a repeated T x num_bins index matrix.
        v_range: torch.Tensor = torch.arange(num_bins, device=histos.device)
        histos += v_range < residual.unsqueeze(1)

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    # floor() gives the same values as converting to int while maintaining the type
    luts = luts.clamp(0, num_bins - 1).floor()
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts
def _scale_channel(im: torch.Tensor) -> torch.Tensor:
    r"""Scale the data in a single channel to implement equalize.

    The channel is multiplied by 255, a 256-bin histogram over [0, 255] is computed and a
    step is derived from the non-empty bins. If the step is zero the values are returned
    unchanged; otherwise every value is replaced by the entry it indexes in the lookup
    table returned by ``_build_lut``.

    Args:
        im: channel tensor with 2 or 3 dimensions and values in the range [0, 1]
          (values within ``torch.isclose`` tolerance of the bounds are accepted).

    Returns:
        tensor with the same shape as ``im``, divided by 255 after the remapping.

    Raises:
        ValueError: if the minimum is below 0 or the maximum is above 1.
        TypeError: if ``im`` does not have 2 or 3 dimensions.

    """
    min_ = im.min()
    max_ = im.max()

    # Range validation tolerates tiny numerical overshoot around the bounds.
    below_range = min_.item() < 0.0 and not torch.isclose(min_, torch.tensor(0.0, dtype=min_.dtype))
    if below_range:
        raise ValueError(f"Values in the input tensor must greater or equal to 0.0. Found {min_.item()}.")

    above_range = max_.item() > 1.0 and not torch.isclose(max_, torch.tensor(1.0, dtype=max_.dtype))
    if above_range:
        raise ValueError(f"Values in the input tensor must lower or equal to 1.0. Found {max_.item()}.")

    ndims = im.dim()
    if ndims not in (2, 3):
        raise TypeError(f"Input tensor must have 2 or 3 dimensions. Found {ndims}.")

    scaled = im * 255.

    # Histogram of the rescaled channel.
    histo = _torch_histc_cast(scaled, bins=256, min=0, max=255)

    # Only the non-empty bins take part in the step computation; the last of
    # them is left out of the total.
    populated = histo[histo != 0].reshape(-1)
    step = torch.div(populated.sum() - populated[-1], 255, rounding_mode='trunc')

    # A zero step means there is nothing to remap.
    if step == 0:
        return scaled / 255.0

    # gather needs a 1-D index, so look the values up flattened and restore the shape.
    lut = _build_lut(histo, step)
    flat_indices = scaled.flatten().long()
    remapped = torch.gather(lut, 0, flat_indices).reshape_as(scaled)
    return remapped / 255.0
def _my_histc(tiles: torch.Tensor, bins: int) -> torch.Tensor:
    """Histogram ``tiles`` into ``bins`` bins over the fixed range [0, 1].

    Positional-argument adapter around ``_torch_histc_cast`` so it can be
    applied to a sequence of tiles without spelling out the range each time.
    """
    range_min = 0
    range_max = 1
    return _torch_histc_cast(tiles, bins=bins, min=range_min, max=range_max)
def _compute_luts(tiles_x_im: torch.Tensor,
                  num_bins: int = 256,
                  clip: float = 40.,
                  diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Each tile is flattened and turned into a ``num_bins`` histogram over the range [0, 1].
    When clipping is enabled, the part of every bin above the clip limit is cut off and
    redistributed within the same tile: an equal share to every bin, plus one extra count
    for the first ``int(residual)`` bins. The lut is the cumulative histogram scaled by
    ``(num_bins - 1) / pixels``, clamped to ``[0, num_bins - 1]`` and floored.

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. default: 256
        clip (float): threshold value for contrast limiting. If it is 0 (or negative) then the
          clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False

    Returns:
        torch.Tensor: Lut for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)  # T x (THxTW)

    if not diff:
        # The buffer is only needed on this path; the differentiable path
        # produces its own tensor, so allocating up front would be wasted work.
        histos = torch.empty((tiles.shape[0], num_bins), device=tiles.device)
        for i in range(tiles.shape[0]):
            histos[i] = _torch_histc_cast(tiles[i], bins=num_bins, min=0, max=1)
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        # Keep the result 2-D (T x num_bins). squeeze() would also drop the tile
        # dimension when there is a single tile and break the per-tile loop and
        # the dim-1 cumsum below. The final view already requires T * num_bins
        # elements.
        histos = histogram(tiles, bins, torch.tensor(0.001)).reshape(-1, num_bins)
        histos *= pixels

    # clip limit (TODO: optimize the code)
    if clip > 0.:
        # The clip limit is never lower than one count per bin.
        clip_limit: torch.Tensor = torch.tensor(
            max(clip * pixels // num_bins, 1), dtype=histos.dtype, device=tiles.device)
        clip_idxs: torch.Tensor = histos > clip_limit

        for i in range(histos.shape[0]):
            idxs = clip_idxs[i]
            if not idxs.any():
                # Nothing exceeds the limit in this tile: leave it untouched.
                continue

            hist: torch.Tensor = histos[i]
            # Total mass above the limit in this tile.
            clipped: float = float((hist[idxs] - clip_limit).sum().item())
            hist = torch.where(idxs, clip_limit, hist)

            # Even share for every bin ...
            redist: float = clipped // num_bins
            hist += redist
            # ... and one extra count for the first bins with whatever is left.
            residual: float = clipped - redist * num_bins
            if residual:
                hist[0:int(residual)] += 1
            histos[i] = hist

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    # floor() gives the same values as converting to int while maintaining the type
    luts = luts.clamp(0, num_bins - 1).floor()
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts