def _compute_luts(tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40.0, diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    For every tile a histogram with ``num_bins`` bins over [0, 1] is built, optionally
    limited to a per-bin maximum (the removed mass is spread back over the bins), and
    its scaled cumulative sum is floored to obtain the lookup table.

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. default: 256
        clip (float): threshold value for contrast limiting. If it is 0 then the clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False

    Returns:
        torch.Tensor: Lut for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    # reshape instead of view: view raises on non-contiguous inputs.  # T x (THxTW)
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)
    if not diff:
        if torch.jit.is_scripting():
            histos = torch.stack([_torch_histc_cast(tile, bins=num_bins, min=0, max=1) for tile in tiles])
        else:
            histos = torch.stack([_my_histc(tile, num_bins) for tile in tiles])
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        # reshape instead of squeeze: squeeze would also drop the tile axis when
        # there is a single tile, breaking the dim-1 reductions below.
        # NOTE(review): assumes `histogram` yields T * num_bins values per call - confirm.
        histos = histogram(tiles, bins, torch.tensor(0.001)).reshape(-1, num_bins)
        histos *= pixels

    if clip > 0.0:
        # Per-bin ceiling, never below one count.
        max_val: float = max(clip * pixels // num_bins, 1)
        histos.clamp_(max=max_val)

        # Mass missing from each tile after clamping, spread evenly over all bins.
        clipped: torch.Tensor = pixels - histos.sum(1)
        redist: torch.Tensor = clipped // num_bins
        histos += redist.unsqueeze(1)

        # What is left after the even split goes, one count each, to the first
        # `residual` bins; the comparison broadcasts (1, num_bins) against (T, 1),
        # which avoids both a loop and a materialised index matrix.
        residual: torch.Tensor = clipped - redist * num_bins
        v_range: torch.Tensor = torch.arange(num_bins, device=histos.device)
        histos += v_range.unsqueeze(0) < residual.unsqueeze(1)

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    # floor: to get the same values as converting to int maintaining the type
    luts = luts.clamp(0, num_bins - 1).floor()
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts
def _compute_luts(tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40., diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    For every tile a histogram with ``num_bins`` bins over [0, 1] is built, optionally
    clipped at a per-bin limit (the clipped excess is redistributed over the bins), and
    its scaled cumulative sum is floored to obtain the lookup table.

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. default: 256
        clip (float): threshold value for contrast limiting. If it is 0 then the clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False

    Returns:
        torch.Tensor: Lut for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)  # T x (THxTW)

    if not diff:
        # The buffer is created without an explicit dtype, so the histograms are
        # stored in torch's default dtype.
        histos: torch.Tensor = torch.empty((tiles.shape[0], num_bins), device=tiles.device)
        for i in range(tiles.shape[0]):
            histos[i] = torch.histc(tiles[i], bins=num_bins, min=0, max=1)
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        # reshape instead of squeeze: squeeze would also drop the tile axis when
        # there is a single tile, breaking the dim-1 cumsum below.
        # NOTE(review): assumes `histogram` yields T * num_bins values per call - confirm.
        histos = histogram(tiles, bins, torch.tensor(0.001)).reshape(-1, num_bins)
        histos *= pixels

    if clip > 0.:
        # Per-bin ceiling, never below one count.
        clip_limit: float = max(clip * pixels // num_bins, 1)

        # Total mass above the ceiling, per tile (zero for tiles with no bin over it,
        # which therefore stay untouched by the steps below).
        excess: torch.Tensor = (histos - clip_limit).clamp(min=0).sum(1)
        histos = histos.clamp(max=clip_limit)

        # Split the excess evenly over all bins ...
        residual: torch.Tensor = torch.remainder(excess, num_bins)
        redist: torch.Tensor = torch.round((excess - residual) / num_bins)
        histos = histos + redist.unsqueeze(1)

        # ... and give one extra count to each of the first floor(residual) bins.
        # Broadcasting (1, num_bins) against (T, 1) replaces the per-tile loop.
        bin_idx: torch.Tensor = torch.arange(num_bins, device=histos.device)
        extra: torch.Tensor = bin_idx.unsqueeze(0) < residual.floor().unsqueeze(1)
        histos = histos + extra.to(histos.dtype)

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    # floor: to get the same values as converting to int maintaining the type
    luts = luts.clamp(0, num_bins - 1).floor()
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts