Example #1
import torch as ch
import torch.nn.functional as F
from torch.distributions import Uniform
from kornia import filters

# `gaussian_motionfilter2d` is a helper defined elsewhere in the snippet's source module.
def motion_blur(inp, k_radius, k_stdev, angle_offset=0):
    n, c, h, w = inp.shape
    # Sample one blur direction per image in the batch.
    angle = angle_offset + Uniform(-45., 45.).sample((n,))
    clip_h, clip_v = [(fn(ch.deg2rad(angle)) * k_radius).round().long()
                      for fn in (ch.cos, ch.sin)]
    # Move everything onto the right device.
    clip_h, clip_v = clip_h.to(inp.device), clip_v.to(inp.device)

    kernel = gaussian_motionfilter2d(k_radius * 2 + 1, k_stdev, angle)
    inp = F.pad(inp, [k_radius] * 4, 'reflect')
    new_inp = filters.filter2D(inp, kernel, 'reflect', normalized=True)

    if ch.cuda.is_available():
        # Shift each image by its sampled offset with a single batched grid_sample.
        y_grid, x_grid = ch.meshgrid(ch.arange(h), ch.arange(w))
        grid = ch.stack((x_grid, y_grid), -1).repeat(n, 1, 1,
                                                     1).to(inp.device).float()
        # Map pixel coordinates to [-1, 1]: x is normalized by width, y by height.
        grid[..., 0] = (grid[..., 0] + k_radius +
                        clip_h.view(-1, 1, 1)) * 2 / (new_inp.shape[3] - 1) - 1
        grid[..., 1] = (grid[..., 1] + k_radius +
                        clip_v.view(-1, 1, 1)) * 2 / (new_inp.shape[2] - 1) - 1
        res = F.grid_sample(new_inp, grid, mode='nearest', align_corners=True)
    else:
        # CPU fallback: crop each padded image at its own sampled offset.
        res = ch.stack([im[:, k_radius+x:k_radius+x+h, k_radius+y:k_radius+y+w]
                        for im, y, x in zip(new_inp, clip_h, clip_v)])
    return ch.clamp(res, 0., 1.)
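A minimal smoke test for the snippet above, assuming its module-level aliases (`ch` for torch, `filters` for kornia.filters) and that `gaussian_motionfilter2d` is available from the same source file; the shapes and parameter values are illustrative:

import torch as ch

# Illustrative batch of four RGB images in [0, 1].
batch = ch.rand(4, 3, 32, 32)
blurred = motion_blur(batch, k_radius=5, k_stdev=2.0)
assert blurred.shape == batch.shape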
Example #2
def pyrup(input: torch.Tensor, border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor:
    r"""Upsamples a tensor and then blurs it.

    Args:
        input (tensor): the tensor to be upsampled.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
        align_corners (bool): interpolation flag. Default: ``False``. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail.

    Return:
        torch.Tensor: the upsampled tensor.

    Examples:
        >>> input = torch.arange(4, dtype=torch.float32).reshape(1, 1, 2, 2)
        >>> pyrup(input, align_corners=True)
        tensor([[[[0.7500, 0.8750, 1.1250, 1.2500],
                  [1.0000, 1.1250, 1.3750, 1.5000],
                  [1.5000, 1.6250, 1.8750, 2.0000],
                  [1.7500, 1.8750, 2.1250, 2.2500]]]])
    """
    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
    kernel: torch.Tensor = _get_pyramid_gaussian_kernel()
    # upsample tensor
    b, c, height, width = input.shape
    x_up: torch.Tensor = F.interpolate(input, size=(height * 2, width * 2),
                                       mode='bilinear', align_corners=align_corners)

    # blurs upsampled tensor
    x_blur: torch.Tensor = filter2D(x_up, kernel, border_type)
    return x_blur
Example #3
def pyrdown(
        input: torch.Tensor,
        border_type: str = 'reflect', align_corners: bool = False) -> torch.Tensor:
    r"""Blurs a tensor and downsamples it.

    Args:
        input (tensor): the tensor to be downsampled.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
        align_corners (bool): interpolation flag. Default: ``False``. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail.

    Return:
        torch.Tensor: the downsampled tensor.

    Examples:
        >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
        >>> pyrdown(input, align_corners=True)
        tensor([[[[ 3.7500,  5.2500],
                  [ 9.7500, 11.2500]]]])
    """
    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
    kernel: torch.Tensor = _get_pyramid_gaussian_kernel()
    b, c, height, width = input.shape
    # blur image
    x_blur: torch.Tensor = filter2D(input, kernel, border_type)

    # downsample.
    out: torch.Tensor = F.interpolate(x_blur, size=(height // 2, width // 2), mode='bilinear',
                                      align_corners=align_corners)
    return out
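A quick shape check tying the two pyramid functions together; this sketch is illustrative and not part of the original source:

import torch

# Downsampling then upsampling restores the spatial size for even dimensions,
# though the values are smoothed rather than recovered exactly.
x = torch.rand(1, 3, 8, 8)
assert pyrup(pyrdown(x)).shape == x.shape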
Example #4
import torch
from kornia.filters import filter2D


def get_normalized_responses(exp_preds,
                             filter_dict,
                             section='horizontal',
                             eps=1e-6):
    # Normalize each response map by the summed boundary responses of its section.
    sum_boundary_responses = filter2D(exp_preds, filter_dict[section]) + eps
    norm = torch.div(exp_preds, sum_boundary_responses)

    return torch.clamp(norm, eps, 1)
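A hedged usage sketch; the kernel below is a stand-in, since the real `filter_dict` is built elsewhere in the snippet's source:

import torch

# Hypothetical 1x3x3 averaging kernel standing in for the real boundary filter.
filter_dict = {'horizontal': torch.ones(1, 3, 3) / 9.0}
exp_preds = torch.rand(2, 1, 16, 16)
norm = get_normalized_responses(exp_preds, filter_dict)  # values in [eps, 1]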
Example #5
import torch as ch
from kornia import filters

def defocus_blur(inp, disk_radius, alias_blur):
    # Build a disk kernel on a grid that covers at least radius 8.
    kernel_size = (3, 3) if disk_radius <= 8 else (5, 5)
    mesh_range = ch.arange(-max(8, disk_radius), max(8, disk_radius) + 1)
    X, Y = ch.meshgrid(mesh_range, mesh_range)

    aliased_disk = ((X.pow(2) + Y.pow(2)) <= disk_radius**2).float()
    aliased_disk /= aliased_disk.sum()
    # Smooth the hard-edged disk slightly to suppress aliasing.
    kernel = filters.gaussian_blur2d(aliased_disk[None, None, ...],
                                     kernel_size, (alias_blur, alias_blur))[0]
    return ch.clamp(filters.filter2D(inp, kernel), 0, 1)
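An illustrative smoke test for the function above; the parameter values are arbitrary:

import torch as ch

batch = ch.rand(2, 3, 64, 64)
out = defocus_blur(batch, disk_radius=4, alias_blur=0.5)
assert out.shape == batch.shape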
Example #6
    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                             .format(input.shape))
        # blur image
        x_blur: torch.Tensor = filter2D(input, self.kernel, self.border_type)

        # reject even rows and columns.
        out: torch.Tensor = F.avg_pool2d(x_blur, 2, 2)
        return out
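The snippet is a bare `forward`; a minimal module it could live in might look like the following. The class name and the 5x5 Gaussian are assumptions standing in for whatever the original class registers as `self.kernel`:

import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.filters import filter2D, get_gaussian_kernel2d

class BlurDownsample(nn.Module):  # hypothetical host class for the forward above
    def __init__(self, border_type: str = 'reflect') -> None:
        super().__init__()
        self.border_type = border_type
        # Stand-in smoothing kernel with the (1, kH, kW) shape filter2D expects.
        self.kernel: torch.Tensor = get_gaussian_kernel2d((5, 5), (1.0, 1.0)).unsqueeze(0)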
Example #7
    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                             .format(input.shape))
        # blur image
        x_blur: torch.Tensor = filter2D(input, self.kernel, self.border_type)

        # downsample.
        out: torch.Tensor = F.interpolate(x_blur, scale_factor=0.5, mode='bilinear',
                                          align_corners=False)
        return out
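Note the contrast with Example #6: both halve the resolution after the same filter2D blur, but Example #6 averages each 2x2 block with avg_pool2d, while this version resamples the blurred tensor bilinearly at scale_factor=0.5.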
Example #8
    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                             .format(input.shape))
        # upsample tensor
        b, c, height, width = input.shape
        x_up: torch.Tensor = F.interpolate(input, size=(height * 2, width * 2),
                                           mode='bilinear', align_corners=True)

        # blurs upsampled tensor
        x_blur: torch.Tensor = filter2D(x_up, self.kernel, self.border_type)
        return x_blur
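This is the module form of the pyrup function from Example #2, with align_corners hard-coded to True rather than exposed as a parameter.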
Example #9
    def forward(self, inputs):
        _device = inputs.device

        batch_size, num_channels, height, width = inputs.size()

        # Derive an odd kernel size of roughly a tenth of the image height.
        kernel_size = height // 10
        radius = int(kernel_size / 2)
        kernel_size = radius * 2 + 1

        # Sample a blur strength per call from the configured range.
        sigma = np.random.uniform(*self.sigma_range)
        kernel = get_gaussian_kernel2d(
            (kernel_size, kernel_size), (sigma, sigma)).unsqueeze(0)
        blurred = filter2D(inputs, kernel, "reflect")

        return blurred
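A minimal host module for the `forward` above; the class name and the default sigma range are illustrative assumptions:

import numpy as np
import torch
import torch.nn as nn
from kornia.filters import filter2D, get_gaussian_kernel2d

class RandomGaussianBlur(nn.Module):  # hypothetical host class
    def __init__(self, sigma_range=(0.1, 2.0)):
        super().__init__()
        self.sigma_range = sigma_range  # bounds for the per-call sigma sample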
Example #10
    def forward(self, x):
        f = self.f
        # Build a 2D kernel as the outer product of the stored 1D filter.
        f = f[None, None, :] * f[None, :, None]
        return filter2D(x, f, normalized=True)
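A plausible host for this `forward`, assuming the 1D filter is a binomial kernel registered as a buffer (both the class name and the filter values are illustrative):

import torch
import torch.nn as nn
from kornia.filters import filter2D

class Blur(nn.Module):  # illustrative; the original class definition is not shown
    def __init__(self):
        super().__init__()
        # 1D binomial filter; forward() expands it to 2D via an outer product.
        self.register_buffer('f', torch.tensor([1., 2., 1.]))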
Example #11
import torch
from kornia.filters import filter2D, get_gaussian_kernel2d

def ssim(img1: torch.Tensor,
         img2: torch.Tensor,
         window_size: int,
         max_val: float = 1.0,
         eps: float = 1e-12) -> torch.Tensor:
    r"""Function that computes the Structural Similarity (SSIM) index map between two images.

    Measures the SSIM index between each element of the input `x` and the target `y`.

    The index can be described as:

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
      {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where:
      - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
        stabilize the division with weak denominator.
      - :math:`L` is the dynamic range of the pixel-values (typically this is
        :math:`2^{\#\text{bits per pixel}}-1`).

    Args:
        img1 (torch.Tensor): the first input image with shape :math:`(B, C, H, W)`.
        img2 (torch.Tensor): the second input image with shape :math:`(B, C, H, W)`.
        window_size (int): the size of the gaussian kernel to smooth the images.
        max_val (float): the dynamic range of the images. Default: 1.
        eps (float): small value for numerical stability when dividing. Default: 1e-12.

    Returns:
        torch.Tensor: The ssim index map with shape :math:`(B, C, H, W)`.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> ssim_map = ssim(input1, input2, 5)  # 1x4x5x5
    """
    if not isinstance(img1, torch.Tensor):
        raise TypeError("Input img1 type is not a torch.Tensor. Got {}".format(
            type(img1)))

    if not isinstance(img2, torch.Tensor):
        raise TypeError("Input img2 type is not a torch.Tensor. Got {}".format(
            type(img2)))

    if not isinstance(max_val, float):
        raise TypeError(
            f"Input max_val type is not a float. Got {type(max_val)}")

    if not len(img1.shape) == 4:
        raise ValueError(
            "Invalid img1 shape, we expect BxCxHxW. Got: {}".format(
                img1.shape))

    if not len(img2.shape) == 4:
        raise ValueError(
            "Invalid img2 shape, we expect BxCxHxW. Got: {}".format(
                img2.shape))

    if not img1.shape == img2.shape:
        raise ValueError(
            "img1 and img2 shapes must be the same. Got: {} and {}".format(
                img1.shape, img2.shape))

    # prepare kernel
    kernel: torch.Tensor = (get_gaussian_kernel2d((window_size, window_size),
                                                  (1.5, 1.5)).unsqueeze(0))

    # compute coefficients
    C1: float = (0.01 * max_val)**2
    C2: float = (0.03 * max_val)**2

    # compute local mean per channel
    mu1: torch.Tensor = filter2D(img1, kernel)
    mu2: torch.Tensor = filter2D(img2, kernel)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # compute local sigma per channel
    sigma1_sq = filter2D(img1**2, kernel) - mu1_sq
    sigma2_sq = filter2D(img2**2, kernel) - mu2_sq
    sigma12 = filter2D(img1 * img2, kernel) - mu1_mu2

    # compute the similarity index map
    num: torch.Tensor = (2. * mu1_mu2 + C1) * (2. * sigma12 + C2)
    den: torch.Tensor = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    return num / (den + eps)
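The returned similarity map can be turned into a scalar loss the same way Example #13 does below; a short illustrative sketch:

import torch

img1 = torch.rand(1, 3, 32, 32)
img2 = torch.rand(1, 3, 32, 32)
# SSIM is a similarity; (1 - ssim) / 2 maps it to a distance in [0, 1].
ssim_map = ssim(img1, img2, window_size=11)
loss = torch.clamp(1. - ssim_map, min=0., max=1.).mean() / 2.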
Example #12
    def forward(self, x):
        # Apply the stored blur kernel, normalized so its entries sum to one.
        return filter2D(x, self.blur_kernel, normalized=True)
Example #13
    def forward(  # type: ignore
            self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:

        if not torch.is_tensor(img1):
            raise TypeError(
                "Input img1 type is not a torch.Tensor. Got {}".format(
                    type(img1)))

        if not torch.is_tensor(img2):
            raise TypeError(
                "Input img2 type is not a torch.Tensor. Got {}".format(
                    type(img2)))

        if not len(img1.shape) == 4:
            raise ValueError(
                "Invalid img1 shape, we expect BxCxHxW. Got: {}".format(
                    img1.shape))

        if not len(img2.shape) == 4:
            raise ValueError(
                "Invalid img2 shape, we expect BxCxHxW. Got: {}".format(
                    img2.shape))

        if not img1.shape == img2.shape:
            raise ValueError(
                "img1 and img2 shapes must be the same. Got: {} and {}".format(
                    img1.shape, img2.shape))

        if not img1.device == img2.device:
            raise ValueError(
                "img1 and img2 must be in the same device. Got: {} and {}".
                format(img1.device, img2.device))

        if not img1.dtype == img2.dtype:
            raise ValueError(
                "img1 and img2 must be in the same dtype. Got: {} and {}".
                format(img1.dtype, img2.dtype))

        # prepare kernel
        b, c, h, w = img1.shape
        tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype)
        tmp_kernel = torch.unsqueeze(tmp_kernel, dim=0)

        # compute local mean per channel
        mu1: torch.Tensor = filter2D(img1, tmp_kernel)
        mu2: torch.Tensor = filter2D(img2, tmp_kernel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # compute local sigma per channel
        sigma1_sq = filter2D(img1 * img1, tmp_kernel) - mu1_sq
        sigma2_sq = filter2D(img2 * img2, tmp_kernel) - mu2_sq
        sigma12 = filter2D(img1 * img2, tmp_kernel) - mu1_mu2

        # SSIM similarity map in [-1, 1].
        ssim_map = ((2. * mu1_mu2 + self.C1) * (2. * sigma12 + self.C2)) / \
            ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))

        # Convert the similarity to a distance in [0, 1].
        loss = torch.clamp(-ssim_map + 1., min=0, max=1) / 2.

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction == "none":
            pass
        return loss
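A hedged usage sketch, assuming this `forward` belongs to the SSIM loss module of a kornia version that spells `filter2D` with a capital D; the constructor signature below is an assumption:

import torch
import kornia

# Assumed constructor: kornia.losses.SSIM(window_size, reduction, max_val).
criterion = kornia.losses.SSIM(window_size=11, reduction='mean')
img1 = torch.rand(1, 3, 32, 32)
img2 = torch.rand(1, 3, 32, 32)
loss = criterion(img1, img2)  # scalar distance in [0, 1]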