Code example #1
File: adjust.py Project: jeffshee/kornia
def equalize(input: torch.Tensor) -> torch.Tensor:
    r"""Apply equalize on the input tensor.

    Implements Equalize function from PIL using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355

    Args:
        input (torch.Tensor): image tensor to equalize with shapes like :math:`(C, H, W)` or :math:`(B, C, H, W)`.

    Returns:
        torch.Tensor: Equalized image or images with the same shape as the input.

    Example:
        >>> _ = torch.manual_seed(0)
        >>> x = torch.rand(1, 2, 3, 3)
        >>> equalize(x)
        tensor([[[[0.4963, 0.7682, 0.0885],
                  [0.1320, 0.3074, 0.6341],
                  [0.4901, 0.8964, 0.4556]],
        <BLANKLINE>
                 [[0.6323, 0.3489, 0.4017],
                  [0.0223, 0.1689, 0.2939],
                  [0.5185, 0.6977, 0.8000]]]])
    """
    input = _to_bchw(input)

    res = []
    for image in input:
        # Assumes RGB for now.  Scales each channel independently
        # and then stacks the result.
        scaled_image = torch.stack(
            [_scale_channel(image[i, :, :]) for i in range(len(image))])
        res.append(scaled_image)
    return torch.stack(res)
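
A minimal usage sketch of the function above. It assumes a recent kornia release where this function is exposed as kornia.enhance.equalize; input values are floats in [0, 1].

import torch
import kornia

img = torch.rand(2, 3, 32, 32)       # batch of float images with values in [0, 1]
out = kornia.enhance.equalize(img)   # per-channel histogram equalization
assert out.shape == img.shape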
Code example #2
def equalize(input: torch.Tensor) -> torch.Tensor:
    r"""Apply equalize on the input tensor.

    .. image:: _static/img/equalize.png

    Implements Equalize function from PIL using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355

    Args:
        input: image tensor to equalize with shapes like :math:`(C, H, W)` or :math:`(B, C, H, W)`.

    Returns:
        Equalized image tensor with shape :math:`(B, C, H, W)`.

    Example:
        >>> x = torch.rand(1, 2, 3, 3)
        >>> equalize(x).shape
        torch.Size([1, 2, 3, 3])
    """
    input = _to_bchw(input)

    res = []
    for image in input:
        # Assumes RGB for now.  Scales each channel independently
        # and then stacks the result.
        scaled_image = torch.stack(
            [_scale_channel(image[i, :, :]) for i in range(len(image))])
        res.append(scaled_image)
    return torch.stack(res)
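
The equalize variants above expect float images. A hedged sketch of round-tripping a uint8 image through the expected [0, 1] range (assuming the `equalize` defined above is in scope):

import torch

img_u8 = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)
img = img_u8.float() / 255.0                      # float image in [0, 1]
out = equalize(img)                               # equalized float image, same shape
out_u8 = (out * 255.0).round().to(torch.uint8)    # back to uint8 if needed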
Code example #3
File: adjust.py Project: jeffshee/kornia
def sharpness(input: torch.Tensor,
              factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Apply sharpness to the input tensor.

    Implemented Sharpness function from PIL using torch ops. This implementation refers to:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor (float or torch.Tensor): factor of sharpness strength. Must be above 0.
            If float or one element tensor, input will be sharpened by the same factor across the whole batch.
            If a 1-d tensor, input will be sharpened element-wise, with len(factor) == len(input).

    Returns:
        torch.Tensor: Sharpened image or images.

    Example:
        >>> _ = torch.manual_seed(0)
        >>> sharpness(torch.randn(1, 1, 5, 5), 0.5)
        tensor([[[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487],
                  [ 0.6920, -0.1580, -1.0576,  0.1765, -0.1577],
                  [ 1.4437,  0.1998,  0.1799,  0.6588, -0.1435],
                  [-0.1116, -0.3068,  0.8381,  1.3477,  0.0537],
                  [ 0.6181, -0.4128, -0.8411, -2.3160, -0.1023]]]])
    """
    input = _to_bchw(input)
    if not isinstance(factor, torch.Tensor):
        factor = torch.tensor(factor, device=input.device, dtype=input.dtype)

    if len(factor.size()) != 0:
        assert factor.shape == torch.Size([input.size(0)]), (
            "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
            f"Got {input.size(0)} and {factor.shape}")

    kernel = torch.tensor([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                          dtype=input.dtype,
                          device=input.device).view(1, 1, 3, 3).repeat(
                              input.size(1), 1, 1, 1) / 13

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input,
                                            kernel,
                                            bias=None,
                                            stride=1,
                                            groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0., 1.)

    # For the borders of the resulting image, fill in the values of the original image.
    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    if len(factor.size()) == 0:
        return _blend_one(result, input, factor)
    return torch.stack([
        _blend_one(result[i], input[i], factor[i]) for i in range(len(factor))
    ])
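
A short usage sketch of the 1-d `factor` path, where each batch element gets its own sharpness strength (assuming the `sharpness` defined above, or kornia.enhance.sharpness, is in scope):

import torch

imgs = torch.rand(3, 3, 16, 16)            # B = 3 float images in [0, 1]
factors = torch.tensor([0.1, 0.5, 1.0])    # one factor per image, len(factor) == len(input)
out = sharpness(imgs, factors)             # each image is blended with its own factor
assert out.shape == imgs.shape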
Code example #4
def sharpness(input: torch.Tensor,
              factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Implements Sharpness function from PIL using torch ops.
    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor (float or torch.Tensor): factor of sharpness strength. Must be above 0.
            If float or one element tensor, input will be sharpened by the same factor across the whole batch.
            If 1-d tensor, input will be sharpened element-wisely, len(factor) == len(input).
    Returns:
        torch.Tensor: Sharpened image or images.
    """
    input = _to_bchw(input)
    if isinstance(factor, torch.Tensor):
        factor = factor.squeeze()
        if len(factor.size()) != 0:
            assert input.size(0) == factor.size(0), \
                f"Input batch size shall match with factor size if 1d array. Got {input.size(0)} and {factor.size(0)}"
    else:
        factor = float(factor)
    kernel = torch.tensor([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                          dtype=input.dtype,
                          device=input.device).view(1, 1, 3, 3).repeat(
                              input.size(1), 1, 1, 1) / 13

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input,
                                            kernel,
                                            bias=None,
                                            stride=1,
                                            groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0., 1.)

    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    def _blend_one(input1: torch.Tensor, input2: torch.Tensor,
                   factor: Union[float, torch.Tensor]) -> torch.Tensor:
        if isinstance(factor, torch.Tensor):
            factor = factor.squeeze()
            assert len(
                factor.size()
            ) == 0, f"Factor shall be a float or single element tensor. Got {factor}"
        if factor == 0.:
            return input1
        if factor == 1.:
            return input2
        diff = (input2 - input1) * factor
        res = input1 + diff
        if factor > 0. and factor < 1.:
            return res
        return torch.clamp(res, 0, 1)

    if isinstance(factor, (float)) or len(factor.size()) == 0:
        return _blend_one(input, result, factor)
    return torch.stack([
        _blend_one(input[i], result[i], factor[i]) for i in range(len(factor))
    ])
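
The `_blend_one` helper above is a plain linear interpolation, res = input1 + (input2 - input1) * factor. A tiny numeric check of that formula:

import torch

a = torch.tensor([0.0, 1.0])   # stand-in for input1 (original pixels)
b = torch.tensor([1.0, 0.0])   # stand-in for input2 (smoothed pixels)
print(a + (b - a) * 0.0)       # tensor([0., 1.])  -> factor 0 keeps input1
print(a + (b - a) * 1.0)       # tensor([1., 0.])  -> factor 1 gives input2
print(a + (b - a) * 0.5)       # tensor([0.5000, 0.5000]) -> halfway blend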
Code example #5
def sharpness(input: torch.Tensor,
              factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Apply sharpness to the input tensor.

    .. image:: _static/img/sharpness.png

    Implemented Sharpness function from PIL using torch ops. This implementation refers to:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

    Args:
        input: image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
        factor: factor of sharpness strength. Must be above 0.
            If float or one element tensor, input will be sharpened by the same factor across the whole batch.
            If a 1-d tensor, input will be sharpened element-wise, with len(factor) == len(input).

    Returns:
        Sharpened image or images with shape :math:`(B, C, H, W)`.

    Example:
        >>> x = torch.rand(1, 1, 5, 5)
        >>> sharpness(x, 0.5).shape
        torch.Size([1, 1, 5, 5])
    """
    input = _to_bchw(input)
    if not isinstance(factor, torch.Tensor):
        factor = torch.tensor(factor, device=input.device, dtype=input.dtype)

    if len(factor.size()) != 0:
        assert factor.shape == torch.Size([input.size(0)]), (
            "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
            f"Got {input.size(0)} and {factor.shape}")

    kernel = (torch.tensor([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                           dtype=input.dtype,
                           device=input.device).view(1, 1, 3, 3).repeat(
                               input.size(1), 1, 1, 1) / 13)

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    degenerate = torch.nn.functional.conv2d(input,
                                            kernel,
                                            bias=None,
                                            stride=1,
                                            groups=input.size(1))
    degenerate = torch.clamp(degenerate, 0.0, 1.0)

    # For the borders of the resulting image, fill in the values of the original image.
    mask = torch.ones_like(degenerate)
    padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
    padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
    result = torch.where(padded_mask == 1, padded_degenerate, input)

    if len(factor.size()) == 0:
        return _blend_one(result, input, factor)
    return torch.stack([
        _blend_one(result[i], input[i], factor[i]) for i in range(len(factor))
    ])
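
The border handling above relies on padding a mask of ones with zeros, so the smoothed values are used only in the interior and the 1-pixel border keeps the original pixels. A toy check of that trick with stand-in tensors:

import torch
import torch.nn.functional as F

orig = torch.arange(16.0).view(1, 1, 4, 4)               # stand-in for the input image
interior = torch.zeros(1, 1, 2, 2)                       # stand-in for `degenerate`
mask = F.pad(torch.ones_like(interior), [1, 1, 1, 1])    # 1 inside, 0 on the border
blended = torch.where(mask == 1, F.pad(interior, [1, 1, 1, 1]), orig)
print(blended)   # zeros in the 2x2 interior, original values on the 1-pixel border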
Code example #6
def _transform_input(input: torch.Tensor) -> torch.Tensor:
    r"""Reshape an input tensor to be (*, C, H, W). Accept either (H, W), (C, H, W) or (*, C, H, W).
    Args:
        input: torch.Tensor

    Returns:
        torch.Tensor
    """

    return _to_bchw(input)
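
A hedged illustration of the shape handling the docstring describes. `_to_bchw` is a private kornia helper, so the sketch below only mimics the documented (H, W), (C, H, W) and (B, C, H, W) cases with plain unsqueeze calls:

import torch

def to_bchw_like(x: torch.Tensor) -> torch.Tensor:
    # (H, W) -> (1, 1, H, W); (C, H, W) -> (1, C, H, W); (B, C, H, W) is left unchanged
    while x.dim() < 4:
        x = x.unsqueeze(0)
    return x

print(to_bchw_like(torch.rand(5, 5)).shape)       # torch.Size([1, 1, 5, 5])
print(to_bchw_like(torch.rand(3, 5, 5)).shape)    # torch.Size([1, 3, 5, 5])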
Code example #7
def equalize(input: torch.Tensor) -> torch.Tensor:
    """Implements Equalize function from PIL using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352
    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to equalize.
    Returns:
        torch.Tensor: Sharpened image or images.
    """
    input = _to_bchw(input) * 255

    # Code taken from: https://github.com/pytorch/vision/pull/796
    def scale_channel(im, c):
        """Scale the data in the channel to implement equalize."""
        im = im[c, :, :]
        # Compute the histogram of the image channel.
        histo = torch.histc(im, bins=256, min=0, max=255)
        # For the purposes of computing the step, filter out the nonzeros.
        nonzero_histo = torch.reshape(histo[histo != 0], [-1])
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

        def build_lut(histo, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalization by step.
            lut = (torch.cumsum(histo, 0) + (step // 2)) // step
            # Shift lut, prepending with 0.
            lut = torch.cat([torch.zeros(1), lut[:-1]])
            # Clip the counts to be in range.  This is done
            # in the C code for image.point.
            return torch.clamp(lut, 0, 255)

        # If step is zero, return the original image.  Otherwise, build
        # lut from the full histogram and step and then index from it.
        if step == 0:
            result = im
        else:
            # can't index using 2d index. Have to flatten and then reshape
            result = torch.gather(build_lut(histo, step), 0,
                                  im.flatten().long())
            result = result.reshape_as(im)

        return result / 255.

    res = []
    for image in input:
        # Assumes RGB for now.  Scales each channel independently
        # and then stacks the result.
        scaled_image = torch.stack(
            [scale_channel(image, i) for i in range(len(image))])
        res.append(scaled_image)
    return torch.stack(res)
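
A quick sanity check of the step logic above: a constant image has a single non-zero histogram bin, so step == 0 and the image is returned unchanged (assuming the `equalize` defined above is in scope):

import torch

flat = torch.full((1, 1, 4, 4), 0.5)           # constant image -> one histogram bin
print(torch.allclose(equalize(flat), flat))    # True: the step == 0 path is taken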
Code example #8
File: equalization.py Project: copaah/kornia
def _compute_tiles(
        imgs: torch.Tensor,
        grid_size: Tuple[int, int],
        even_tile_size: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute tiles on an image according to a grid size.

    Note that padding can be added to the image so that it can be tiled evenly.
    Thus, grid_size (GH, GW) x tile_size (TH, TW) >= image_size (H, W).

    Args:
        imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W) or (C, H, W).
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW)
        even_tile_size (bool, optional): Determine if the width and height of the tiles must be even. Default: False.

    Returns:
        torch.Tensor: tensor with tiles (B, GH, GW, C, TH, TW). B = 1 in case a single image is provided.
        torch.Tensor: tensor with the padded batch of 2D images with shape (B, C, H', W')

    """
    batch: torch.Tensor = _to_bchw(imgs)  # B x C x H x W

    # compute stride and kernel size
    h, w = batch.shape[-2:]
    kernel_vert: int = math.ceil(h / grid_size[0])
    kernel_horz: int = math.ceil(w / grid_size[1])

    if even_tile_size:
        kernel_vert += 1 if kernel_vert % 2 else 0
        kernel_horz += 1 if kernel_horz % 2 else 0

    # add padding (with that kernel size we could need some extra cols and rows...)
    pad_vert = kernel_vert * grid_size[0] - h
    pad_horz = kernel_horz * grid_size[1] - w
    # add the padding to the last columns and rows
    if pad_vert > 0 or pad_horz > 0:
        batch = F.pad(batch, [0, pad_horz, 0, pad_vert],
                      mode='reflect')  # B x C x H' x W'

    # compute tiles
    c: int = batch.shape[-3]
    tiles: torch.Tensor = (
        batch.unfold(1, c, c)  # unfold(dimension, size, step)
        .unfold(2, kernel_vert, kernel_vert).unfold(
            3, kernel_horz,
            kernel_horz).squeeze(1)).contiguous()  # B x GH x GW x C x TH x TW
    assert tiles.shape[-5] == grid_size[0]  # check the grid size
    assert tiles.shape[-4] == grid_size[1]
    return tiles, batch
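
A hedged usage sketch of `_compute_tiles` (a private helper, so the exact shapes are only what the docstring promises): splitting a 1 x 3 x 40 x 60 image into an 8 x 8 grid.

import torch

imgs = torch.rand(1, 3, 40, 60)
tiles, padded = _compute_tiles(imgs, grid_size=(8, 8))
print(tiles.shape)    # (B, GH, GW, C, TH, TW), here torch.Size([1, 8, 8, 3, 5, 8])
print(padded.shape)   # padded batch (B, C, H', W'), here torch.Size([1, 3, 40, 64])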
Code example #9
File: adjust.py Project: yibit/kornia
def equalize(input: torch.Tensor) -> torch.Tensor:
    r"""Apply equalize on the input tensor.
    Implements Equalize function from PIL using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

    Args:
        input (torch.Tensor): image tensor with shapes like :math:(C, H, W) or :math:(B, C, H, W) to equalize.

    Returns:
        torch.Tensor: Sharpened image or images.
    """
    input = _to_bchw(input)

    res = []
    for image in input:
        # Assumes RGB for now.  Scales each channel independently
        # and then stacks the result.
        scaled_image = torch.stack([_scale_channel(image[i, :, :]) for i in range(len(image))])
        res.append(scaled_image)
    return torch.stack(res)
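
A small sanity-check sketch of the "scales each channel independently" comment: equalizing a two-channel image matches equalizing each channel on its own (assuming the `equalize` defined above is in scope).

import torch

x = torch.rand(1, 2, 8, 8)
full = equalize(x)
per_channel = torch.cat([equalize(x[:, :1]), equalize(x[:, 1:])], dim=1)
print(torch.allclose(full, per_channel))   # True: channels do not interact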
Code example #10
File: equalization.py Project: copaah/kornia
def equalize_clahe(
    input: torch.Tensor,
    clip_limit: float = 40.0,
    grid_size: Tuple[int, int] = (8, 8)) -> torch.Tensor:
    r"""Apply clahe equalization on the input tensor.

    NOTE: The LUT computation uses the same approach as in OpenCV; this may change in future versions.

    Args:
        input (torch.Tensor): images tensor to equalize with values in the range [0, 1] and shapes like
                              :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        clip_limit (float): threshold value for contrast limiting. If 0 clipping is disabled. Default: 40.
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW). Default: (8, 8).

    Returns:
        torch.Tensor: Equalized image or images with shape as the input.

    Examples:
        >>> img = torch.rand(1, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([1, 10, 20])

        >>> img = torch.rand(2, 3, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([2, 3, 10, 20])

    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(
            f"Input input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() not in [3, 4]:
        raise ValueError(
            f"Invalid input shape, we expect CxHxW or BxCxHxW. Got: {input.shape}"
        )

    if input.numel() == 0:
        raise ValueError("Invalid input tensor, it is empty.")

    if not isinstance(clip_limit, float):
        raise TypeError(
            f"Input clip_limit type is not float. Got {type(clip_limit)}")

    if not isinstance(grid_size, tuple):
        raise TypeError(
            f"Input grid_size type is not Tuple. Got {type(grid_size)}")

    if len(grid_size) != 2:
        raise TypeError(
            f"Input grid_size is not a Tuple with 2 elements. Got {len(grid_size)}"
        )

    if isinstance(grid_size[0], float) or isinstance(grid_size[1], float):
        raise TypeError(
            "Input grid_size type is not valid, must be a Tuple[int, int].")

    if grid_size[0] <= 0 or grid_size[1] <= 0:
        raise ValueError(
            "Input grid_size elements must be positive. Got {grid_size}")

    imgs: torch.Tensor = _to_bchw(input)  # B x C x H x W

    # hist_tiles: torch.Tensor  # B x GH x GW x C x TH x TW  # not supported by JIT
    # img_padded: torch.Tensor  # B x C x H' x W'  # not supported by JIT
    # the size of the tiles must be even in order to divide them into 4 tiles for the interpolation
    hist_tiles, img_padded = _compute_tiles(imgs, grid_size, True)
    tile_size: Tuple[int, int] = (hist_tiles.shape[-2], hist_tiles.shape[-1])
    interp_tiles: torch.Tensor = _compute_interpolation_tiles(
        img_padded, tile_size)  # B x 2GH x 2GW x C x TH/2 x TW/2
    luts: torch.Tensor = _compute_luts(hist_tiles,
                                       clip=clip_limit)  # B x GH x GW x C x 256
    equalized_tiles: torch.Tensor = _compute_equalized_tiles(
        interp_tiles, luts)  # B x 2GH x 2GW x C x TH/2 x TW/2

    # reconstruct the images form the tiles
    #    try permute + contiguous + view
    eq_imgs: torch.Tensor = equalized_tiles.permute(0, 3, 1, 4, 2,
                                                    5).reshape_as(img_padded)
    h, w = imgs.shape[-2:]
    eq_imgs = eq_imgs[..., :h, :w]  # crop imgs if they were padded

    # remove batch if the input was not in batch form
    if input.dim() != eq_imgs.dim():
        eq_imgs = eq_imgs.squeeze(0)
    return eq_imgs
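
A short usage sketch with a custom clip limit and tile grid; it assumes the function is exposed as kornia.enhance.equalize_clahe, as in recent kornia releases.

import torch
import kornia

img = torch.rand(1, 1, 64, 64)   # values must be in [0, 1]
out = kornia.enhance.equalize_clahe(img, clip_limit=2.0, grid_size=(4, 4))
print(out.shape)                 # torch.Size([1, 1, 64, 64])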
Code example #11
File: adjust.py Project: yibit/kornia
def posterize(input: torch.Tensor, bits: Union[int, torch.Tensor]) -> torch.Tensor:
    r"""Reduce the number of bits for each color channel. Non-differentiable function, uint8 involved.

    Args:
        input (torch.Tensor): image tensor with shapes like (C, H, W) or (B, C, H, W) to posterize.
        bits (int or torch.Tensor): number of high bits. Must be in range [0, 8].
            If int or one element tensor, input will be posterized by this number of bits.
            If a 1-d tensor, input will be posterized element-wise, with len(bits) == input.shape[0].
            If an n-d tensor, input will be posterized per element and channel, with bits.shape == input.shape[:len(bits.shape)]

    Returns:
        torch.Tensor: Image with reduced color channels.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if isinstance(bits, int):
        bits = torch.tensor(bits)

    if not torch.all((bits >= 0) * (bits <= 8)) and bits.dtype == torch.int:
        raise ValueError(f"bits must be integers within range [0, 8]. Got {bits}.")

    # TODO: Make a differentiable version
    # Current version:
    # Ref: https://github.com/open-mmlab/mmcv/pull/132/files#diff-309c9320c7f71bedffe89a70ccff7f3bR19
    # Ref: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L222
    # Potential approach: implementing kornia.LUT with floating points
    # https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/functional.py#L472
    def _left_shift(input: torch.Tensor, shift: torch.Tensor):
        return ((input * 255).to(torch.uint8) * (2 ** shift)).to(input.dtype) / 255.

    def _right_shift(input: torch.Tensor, shift: torch.Tensor):
        return (input * 255).to(torch.uint8) / (2 ** shift).to(input.dtype) / 255.

    def _posterize_one(input: torch.Tensor, bits: torch.Tensor):
        # Single bits value condition
        if bits == 0:
            return torch.zeros_like(input)
        if bits == 8:
            return input.clone()
        bits = 8 - bits
        return _left_shift(_right_shift(input, bits), bits)

    if len(bits.shape) == 0 or (len(bits.shape) == 1 and len(bits) == 1):
        return _posterize_one(input, bits)

    res = []
    if len(bits.shape) == 1:
        input = _to_bchw(input)

        assert bits.shape[0] == input.shape[0], \
            f"Batch size must be equal between bits and input. Got {bits.shape[0]}, {input.shape[0]}."

        for i in range(input.shape[0]):
            res.append(_posterize_one(input[i], bits[i]))
        return torch.stack(res, dim=0)

    assert bits.shape == input.shape[:len(bits.shape)], \
        f"Batch and channel must be equal between bits and input. Got {bits.shape}, {input.shape[:len(bits.shape)]}."
    _input = input.view(-1, *input.shape[len(bits.shape):])
    _bits = bits.flatten()
    for i in range(len(_bits)):
        res.append(_posterize_one(_input[i], _bits[i]))
    return torch.stack(res, dim=0).reshape(*input.shape)
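
A hedged sketch of what the posterization above does: only the top `bits` bits of each 8-bit channel value are kept (assuming the `posterize` defined above is in scope).

import torch

x = torch.tensor([[[[0.2, 0.4], [0.6, 0.8]]]])   # float image in [0, 1]
print(posterize(x, 8))   # unchanged: all 8 bits are kept
print(posterize(x, 1))   # only the most significant bit survives per pixel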
Code example #12
File: affwarp.py Project: working-girl/kornia
def resize(input: torch.Tensor, size: Union[int, Tuple[int, int]],
           interpolation: str = 'bilinear', align_corners: Optional[bool] = None,
           side: str = "short", antialias: bool = False) -> torch.Tensor:
    r"""Resize the input torch.Tensor to the given size.

    Args:
        input (torch.Tensor): The image tensor to be resized with shape of :math:`(B, C, H, W)`.
        size (int, tuple(int, int)): Desired output size. If size is a sequence like (h, w),
            output size will be matched to this. If size is an int, smaller edge of the image will
            be matched to this number. i.e, if height > width, then image will be rescaled
            to (size * height / width, size)
        interpolation (str):  algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' |
            'bicubic' | 'trilinear' | 'area'. Default: 'bilinear'.
        align_corners(bool): interpolation flag. Default: None. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail
        side (str): Corresponding side if ``size`` is an integer. Can be one of ``"short"``, ``"long"``, ``"vert"``,
            or ``"horz"``. Defaults to ``"short"``.
        antialias (bool): if True, then image will be filtered with Gaussian before downscaling.
            No effect for upscaling. Default: False


    Returns:
        torch.Tensor: The resized tensor with the spatial dimensions given by ``size``.

    Example:
        >>> img = torch.rand(1, 3, 4, 4)
        >>> out = resize(img, (6, 8))
        >>> print(out.shape)
        torch.Size([1, 3, 6, 8])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input tensor type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    input_size = h, w = input.shape[-2:]
    if isinstance(size, int):
        aspect_ratio = w / h
        size = _side_to_image_size(size, aspect_ratio, side)

    if size == input_size:
        return input

    # TODO: find a proper way to handle these cases in the future
    input_tmp = _to_bchw(input)

    factors = (h / size[0], w / size[1])

    # We apply blurring only when downscaling
    antialias = antialias and (max(factors) > 1)

    if antialias:
        # First, we have to determine sigma
        sigmas = (max(factors[0], 1.0), max(factors[1], 1.0))

        # Now kernel size. Good results are obtained with 3 sigma, but that is kind of slow. Pillow uses 1 sigma
        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
        # but they do it in two passes, which gives better results. Let's try 2 sigmas for now
        ks = int(2.0 * 2 * sigmas[0] + 1), int(2.0 * 2 * sigmas[1] + 1)
        input_tmp = kornia.filters.gaussian_blur2d(input_tmp, ks, sigmas)

    output = torch.nn.functional.interpolate(
        input_tmp, size=size, mode=interpolation, align_corners=align_corners)

    if len(input.shape) != len(output.shape):
        output = output.squeeze()

    return output
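
A usage sketch of the integer-`size` path: the side named by `side` is matched to `size` and the other side keeps the aspect ratio (the exact output size below assumes `_side_to_image_size` behaves as the docstring describes).

import torch

img = torch.rand(1, 3, 100, 200)        # H = 100, W = 200
out = resize(img, 50, side="short")     # the short side (H) becomes 50
print(out.shape)                        # torch.Size([1, 3, 50, 100])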