Example #1
    def test_smoke(self, device, dtype):
        x = torch.tensor([1., 2., 1.], device=device, dtype=dtype)
        y_expected = torch.tensor([0., 2., 1., 0.], device=device, dtype=dtype)

        y = _torch_histc_cast(x, bins=4, min=0, max=3)

        assert_allclose(y, y_expected)
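The helper under test, _torch_histc_cast, is not shown in this example. Judging from the call site, a minimal sketch (an assumption, not the actual implementation) would cast the input to a dtype that torch.histc supports, since the op is only implemented for float32/float64 on CPU, and cast the result back:

import torch

def _torch_histc_cast(input: torch.Tensor, bins: int, min: int, max: int) -> torch.Tensor:
    # Assumed helper: torch.histc only supports float32/float64 on CPU, so
    # temporarily cast other dtypes (e.g. float16) and restore them afterwards.
    dtype = input.dtype
    if dtype not in (torch.float32, torch.float64):
        input = input.to(torch.float32)
    return torch.histc(input, bins=bins, min=min, max=max).to(dtype)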
Example #2
def _compute_luts(tiles_x_im: torch.Tensor,
                  num_bins: int = 256,
                  clip: float = 40.0,
                  diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the LUT to. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. Default: 256.
        clip (float): threshold value for contrast limiting. If it is 0, clipping is disabled. Default: 40.
        diff (bool, optional): whether to use the differentiable histogram. Default: False.

    Returns:
        torch.Tensor: LUT for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.view(
        -1, pixels)  # (B*GH*GW*C, TH*TW)
    if not diff:
        if torch.jit.is_scripting():
            # map() is not scriptable, so use a comprehension under TorchScript.
            histos = torch.stack([
                _torch_histc_cast(tile, bins=num_bins, min=0, max=1)
                for tile in tiles
            ])
        else:
            histos = torch.stack(
                list(map(_my_histc, tiles, [num_bins] * len(tiles))))
    else:
        bins: torch.Tensor = torch.linspace(0,
                                            1,
                                            num_bins,
                                            device=tiles.device)
        histos = histogram(tiles, bins, torch.tensor(0.001)).squeeze()
        histos *= pixels  # densities -> pixel counts

    if clip > 0.0:
        # Contrast limiting: clamp each histogram at the clip limit, then
        # redistribute the clipped mass uniformly over all bins.
        max_val: float = max(clip * pixels // num_bins, 1)
        histos.clamp_(max=max_val)
        clipped: torch.Tensor = pixels - histos.sum(1)  # clipped mass per tile
        redist: torch.Tensor = clipped // num_bins
        histos += redist[None].transpose(0, 1)
        residual: torch.Tensor = clipped - redist * num_bins
        # Trick to assign the residual without a loop: adding the boolean mask
        # puts one extra count into the first `residual` bins of each histogram.
        v_range: torch.Tensor = torch.arange(num_bins, device=histos.device)
        mat_range: torch.Tensor = v_range.repeat(histos.shape[0], 1)
        histos += mat_range < residual[None].transpose(0, 1)

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    luts = luts.clamp(0, num_bins - 1).floor()  # matches int conversion while keeping the dtype
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts
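For reference, a quick invocation showing the expected shapes, using a hypothetical toy input (one image split into a 2x2 grid of 8x8 RGB tiles):

tiles = torch.rand(1, 2, 2, 3, 8, 8)  # (B, GH, GW, C, TH, TW), values in [0, 1]
luts = _compute_luts(tiles, num_bins=256, clip=40.0)
print(luts.shape)  # torch.Size([1, 2, 2, 3, 256])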
Example #3
def _scale_channel(im: torch.Tensor) -> torch.Tensor:
    r"""Scale the data in the channel to implement equalize.

    Args:
        im: image tensor with shape :math:`(H, W)` or :math:`(D, H, W)`.

    Returns:
        the equalized image tensor, with the same shape as the input.
    """
    min_ = im.min()
    max_ = im.max()

    if min_.item() < 0.0 and not torch.isclose(
            min_, torch.tensor(0.0, dtype=min_.dtype)):
        raise ValueError(
            f"Values in the input tensor must be greater than or equal to 0.0. Found {min_.item()}."
        )

    if max_.item() > 1.0 and not torch.isclose(
            max_, torch.tensor(1.0, dtype=max_.dtype)):
        raise ValueError(
            f"Values in the input tensor must be lower than or equal to 1.0. Found {max_.item()}."
        )

    ndims = len(im.shape)
    if ndims not in (2, 3):
        raise TypeError(
            f"Input tensor must have 2 or 3 dimensions. Found {ndims}.")

    im = im * 255.
    # Compute the histogram of the image channel.
    histo = _torch_histc_cast(im, bins=256, min=0, max=255)
    # For the purposes of computing the step, filter out the zeros.
    nonzero_histo = torch.reshape(histo[histo != 0], [-1])
    step = torch.div(torch.sum(nonzero_histo) - nonzero_histo[-1],
                     255,
                     rounding_mode='trunc')

    # If step is zero, return the original image.  Otherwise, build
    # lut from the full histogram and step and then index from it.
    if step == 0:
        result = im
    else:
        # Can't index with a 2D index; flatten, gather, then reshape.
        result = torch.gather(_build_lut(histo, step), 0, im.flatten().long())
        result = result.reshape_as(im)

    return result / 255.0
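_build_lut is referenced above but not shown. A sketch consistent with the classic equalize recipe (cumulative histogram shifted by half a step, then truncated division by the step) might look like the following; treat it as an assumption rather than the actual helper:

def _build_lut(histo: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    # Assumed helper: map each intensity through the shifted cumulative histogram.
    lut = torch.div(torch.cumsum(histo, 0) + torch.div(step, 2, rounding_mode='trunc'),
                    step, rounding_mode='trunc')
    # Shift the LUT right by one, prepending a zero, and clip to [0, 255].
    lut = torch.cat([torch.zeros(1, device=lut.device, dtype=lut.dtype), lut[:-1]])
    return torch.clamp(lut, 0, 255)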
Example #4
def _my_histc(tiles: torch.Tensor, bins: int) -> torch.Tensor:
    return _torch_histc_cast(tiles, bins=bins, min=0, max=1)
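This wrapper appears to exist so the eager branch of _compute_luts (Example #2) can hand map() a plain two-positional-argument callable with min/max already fixed to the [0, 1] range, e.g.:

hist = _my_histc(torch.rand(64), 256)  # 256-bin histogram over [0, 1]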
Example #5
def _compute_luts(tiles_x_im: torch.Tensor,
                  num_bins: int = 256,
                  clip: float = 40.,
                  diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the LUT to. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. Default: 256.
        clip (float): threshold value for contrast limiting. If it is 0, clipping is disabled. Default: 40.
        diff (bool, optional): whether to use the differentiable histogram. Default: False.

    Returns:
        torch.Tensor: LUT for each tile (B, GH, GW, C, num_bins)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.reshape(
        -1, pixels)  # (B*GH*GW*C, TH*TW)
    histos: torch.Tensor = torch.empty((tiles.shape[0], num_bins),
                                       device=tiles.device)
    if not diff:
        for i in range(tiles.shape[0]):
            histos[i] = _torch_histc_cast(tiles[i],
                                          bins=num_bins,
                                          min=0,
                                          max=1)
    else:
        bins: torch.Tensor = torch.linspace(0,
                                            1,
                                            num_bins,
                                            device=tiles.device)
        histos = histogram(tiles, bins, torch.tensor(0.001)).squeeze()
        histos *= pixels

    # clip limit (TODO: optimize the code)
    if clip > 0.:
        clip_limit: torch.Tensor = torch.tensor(max(clip * pixels // num_bins,
                                                    1),
                                                dtype=histos.dtype,
                                                device=tiles.device)

        clip_idxs: torch.Tensor = histos > clip_limit
        for i in range(histos.shape[0]):
            hist: torch.Tensor = histos[i]
            idxs = clip_idxs[i]
            if idxs.any():
                # Clamp the histogram at the clip limit ...
                clipped: float = float((hist[idxs] - clip_limit).sum().item())
                hist = torch.where(idxs, clip_limit, hist)

                # ... and redistribute the clipped mass uniformly over the bins.
                redist: float = clipped // num_bins
                hist += redist

                # Any leftover goes one count at a time into the first bins.
                residual: float = clipped - redist * num_bins
                if residual:
                    hist[0:int(residual)] += 1
                histos[i] = hist

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    luts = luts.clamp(0, num_bins - 1).floor()  # matches int conversion while keeping the dtype
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts
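To make the contrast-limiting step concrete, here is a toy walk-through of the same arithmetic with num_bins=4 and pixels=16 (illustrative numbers, not from the source):

hist = torch.tensor([10., 2., 3., 1.])            # 16 pixels total
clip_limit = 5.0
clipped = (hist - clip_limit).clamp(min=0).sum()  # 5.0 pixels above the limit
hist = hist.clamp(max=clip_limit)                 # [5., 2., 3., 1.]
hist += clipped // 4                              # uniform share: [6., 3., 4., 2.]
residual = clipped - (clipped // 4) * 4           # 1.0 left over
hist[:int(residual)] += 1                         # -> [7., 3., 4., 2.], sum is 16 again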