Exemplo n.º 1
0
def homography_warp3d(patch_src: torch.Tensor,
                      src_homo_dst: torch.Tensor,
                      dsize: Tuple[int, int, int],
                      mode: str = 'bilinear',
                      padding_mode: str = 'zeros',
                      align_corners: bool = False,
                      normalized_coordinates: bool = True) -> torch.Tensor:
    r"""Warp volumetric patches or tensors by 3D homographies.

    A destination sampling grid of size ``dsize`` is created, mapped into the
    source frame with ``src_homo_dst`` and used to sample ``patch_src`` via
    :func:`torch.nn.functional.grid_sample`.

    Args:
        patch_src (torch.Tensor): The volume or tensor to warp, from the source
                                  frame, of shape :math:`(N, C, D, H, W)`.
        src_homo_dst (torch.Tensor): The homography or stack of homographies
                                     from destination to source of shape
                                     :math:`(N, 4, 4)`.
        dsize (Tuple[int, int, int]): The depth, height and width of the
                                      destination sampling grid.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners (bool): interpolation flag. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details.
        normalized_coordinates (bool): Whether the homography assumes [-1, 1] normalized
            coordinates or not.

    Return:
        torch.Tensor: Patch sampled at locations from source to destination,
        with spatial size given by ``dsize`` and the same dtype as ``patch_src``.

    Raises:
        TypeError: if ``patch_src`` and ``src_homo_dst`` are on different devices.

    Example:
        >>> input = torch.rand(1, 3, 4, 32, 32)
        >>> homography = torch.eye(4).view(1, 4, 4)
        >>> output = homography_warp3d(input, homography, (4, 32, 32))
    """
    if src_homo_dst.device != patch_src.device:
        # Implicit string concatenation keeps the message free of the
        # indentation that a backslash continuation would embed.
        raise TypeError("Patch and homography must be on the same device. "
                        "Got patch.device: {} src_H_dst.device: {}.".format(
                            patch_src.device, src_homo_dst.device))

    depth, height, width = dsize
    grid = create_meshgrid3d(depth,
                             height,
                             width,
                             normalized_coordinates=normalized_coordinates,
                             device=patch_src.device)
    warped_grid = warp_grid3d(grid, src_homo_dst)

    # grid_sample requires input and grid to share a dtype; the warped grid may
    # follow the homography's dtype rather than the patch's, so align it here.
    warped_grid = warped_grid.to(patch_src.dtype)

    return F.grid_sample(patch_src,
                         warped_grid,
                         mode=mode,
                         padding_mode=padding_mode,
                         align_corners=align_corners)
Exemplo n.º 2
0
def conv_quad_interp3d(input: torch.Tensor, strict_maxima_bonus: float = 1.0, eps: float = 1e-6):
    r"""Compute a single iteration of quadratic interpolation of the extremum (max or min).

    The location and value are refined per each 3x3x3 window which contains a strict
    extremum, similar to the refinement done in SIFT.

    Args:
        input (torch.Tensor): the given heatmap with shape :math:`(N, C, D, H, W)`.
        strict_maxima_bonus (float): pixels selected by the 3x3x3 ``nms3d`` mask score
                                     (1 + strict_maxima_bonus) * value.
                                     This is needed for mimic behavior of strict NMS in classic local features
        eps (float): regularization added to the Hessian diagonal to control its ill-conditioning.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: refined coordinates and refined values.

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, 3, D, H, W)`, :math:`(N, C, D, H, W)`. The spatial size is the
          same as the input; no padding or stride is applied.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))
    if not len(input.shape) == 5:
        raise ValueError("Invalid input shape, we expect BxCxDxHxW. Got: {}"
                         .format(input.shape))
    B, CH, D, H, W = input.shape
    grid_global: torch.Tensor = create_meshgrid3d(D, H, W, False,
                                                  device=input.device).permute(0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # To determine the location we solve the linear system A x = b, where b is the
    # 1st order gradient and A is the Hessian matrix, at every location.
    b: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=1, mode='diff')
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    A: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=2, mode='diff')
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = A[..., 3]
    dys = A[..., 4]
    dxs = A[..., 5]
    # Stack along the LAST dim so that each row holds the 9 entries of one
    # location's Hessian. Stacking along dim 0 and then viewing as (-1, 3, 3)
    # would build each "matrix" from one entry type at 9 different locations.
    Hes = torch.stack([dxx, dxy, dxs, dxy, dyy, dys, dxs, dys, dss], dim=-1).view(-1, 3, 3)
    Hes = Hes + torch.eye(3, device=Hes.device, dtype=Hes.dtype)[None] * eps

    nms_mask: torch.Tensor = kornia.feature.nms3d(input, (3, 3, 3), True)
    mask_flat: torch.Tensor = nms_mask.view(-1)
    hes_masked = Hes[mask_flat]
    b_masked = b[mask_flat]
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        # torch.solve is deprecated and removed in recent PyTorch releases;
        # torch.linalg.solve takes (A, B) and returns only the solution.
        x_solved_masked = torch.linalg.solve(hes_masked, b_masked)
    else:
        x_solved_masked, _ = torch.solve(b_masked, hes_masked)

    x_solved: torch.Tensor = torch.zeros_like(b)
    x_solved.masked_scatter_(nms_mask.view(-1, 1, 1), x_solved_masked)
    dx: torch.Tensor = -x_solved

    # Ignore offsets which move the point far away from the window center.
    far_from_center = (dx.abs().max(dim=1, keepdim=True)[0] > 0.7).view(-1)
    dx[far_from_center, :, :] = 0

    # Value correction from the second-order Taylor expansion: 0.5 * b^T dx.
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        y_max *= (1.0 + strict_maxima_bonus * nms_mask.to(input.dtype))

    # flip(1) reverses the order of the 3 offset components, presumably to match
    # the coordinate order of the meshgrid channels.
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W, 3).permute(0, 1, 5, 2, 3, 4)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res
    return coords_max, y_max
Exemplo n.º 3
0
def conv_soft_argmax3d(input: torch.Tensor,
                       kernel_size: Tuple[int, int, int] = (3, 3, 3),
                       stride: Tuple[int, int, int] = (1, 1, 1),
                       padding: Tuple[int, int, int] = (1, 1, 1),
                       temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
                       normalized_coordinates: bool = False,
                       eps: float = 1e-8,
                       output_value: bool = True,
                       strict_maxima_bonus: float = 0.0) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 3D over the windows of a given input heatmap.

    The function has two outputs: the soft-argmax coordinates and the softmax-pooled
    heatmap values themselves. On each window, the function computed is:

    .. math::
             ijk(X) = \frac{\sum{(i,j,k)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
             val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where T is temperature.

    Args:
        kernel_size (Tuple[int,int,int]): size of the window
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding
        temperature (torch.Tensor): softmax temperature; the input is divided by it. Default is 1.
        normalized_coordinates (bool): whether to return the coordinates normalized in the range of [-1, 1]. Otherwise,
                                       it will return the coordinates in the range of the input shape. Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, both coordinates and values are returned; if False, only the coordinates.
        strict_maxima_bonus (float): if positive, the pooled values are multiplied by
                                     ``1 + strict_maxima_bonus * nms3d(input, kernel_size)``.
                                     This is needed for mimic behavior of strict NMS in classic local features
    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional or ``temperature`` is not positive.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) == 5:
        raise ValueError("Invalid input shape, we expect BxCxDxHxW. Got: {}"
                         .format(input.shape))

    if temperature <= 0:
        raise ValueError("Temperature should be positive float or tensor. Got: {}"
                         .format(temperature))

    b, c, d, h, w = input.shape
    kx, ky, kz = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    # Fold channels into the batch so each channel is pooled independently;
    # reshape (unlike view) also accepts non-contiguous inputs.
    input = input.reshape(b * c, 1, d, h, w)

    center_kernel: torch.Tensor = _get_center_kernel3d(kx, ky, kz, device).to(dtype)
    window_kernel: torch.Tensor = _get_window_grid_kernel3d(kx, ky, kz, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    # avg_pool3d (zero padding counted) times the window volume equals the window sum.
    pool_coef: float = float(kx * ky * kz)

    # softmax denominator
    den = pool_coef * F.avg_pool3d(x_exp,
                                   kernel_size,
                                   stride=stride,
                                   padding=padding) + eps

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(
        d, h, w, False, device=device).to(dtype).permute(0, 4, 1, 2, 3)

    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel,
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)

    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2), coords_max.size(3), coords_max.size(4))

    if not output_value:
        return coords_max
    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    if strict_maxima_bonus > 0:
        out_levels: int = x_softmaxpool.size(2)
        strict_maxima: torch.Tensor = F.avg_pool3d(kornia.feature.nms3d(input, kernel_size), 1, stride, 0)
        # Crop the stride-subsampled NMS response symmetrically in depth so it keeps
        # exactly ``out_levels`` levels aligned with the pooled output. The offset is
        # measured on the subsampled tensor, and the slice end is offset + out_levels
        # (a ``[skip:out_levels - skip]`` slice would drop 2 * skip levels).
        skip_levels: int = (strict_maxima.size(2) - out_levels) // 2
        strict_maxima = strict_maxima[:, :, skip_levels:skip_levels + out_levels]
        x_softmaxpool *= 1.0 + strict_maxima_bonus * strict_maxima
    x_softmaxpool = x_softmaxpool.view(b,
                                       c,
                                       x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))
    return coords_max, x_softmaxpool
Exemplo n.º 4
0
def conv_soft_argmax3d(
    input: torch.Tensor,
    kernel_size: Tuple[int, int, int] = (3, 3, 3),
    stride: Tuple[int, int, int] = (1, 1, 1),
    padding: Tuple[int, int, int] = (1, 1, 1),
    temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
    normalized_coordinates: bool = False,
    eps: float = 1e-8,
    output_value: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 3D over the windows of a given input heatmap.

    The function has two outputs: the soft-argmax coordinates and the softmax-pooled
    heatmap values themselves. On each window, the function computed is:

    .. math::
             ijk(X) = \frac{\sum{(i,j,k)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
             val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where T is temperature.

    Args:
        kernel_size (Tuple[int,int,int]): size of the window
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding
        temperature (torch.Tensor): softmax temperature; the input is divided by it. Default is 1.
        normalized_coordinates (bool): whether to return the coordinates normalized in the range of [-1, 1]. Otherwise,
                                       it will return the coordinates in the range of the input shape. Default is False.
        eps (float): small value added to the softmax denominator to avoid zero division. Default is 1e-8.
        output_value (bool): if True, both coordinates and values are returned; if False, only the coordinates.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional or ``temperature`` is not positive.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))
    if temperature <= 0:
        raise ValueError(
            "Temperature should be positive float or tensor. Got: {}".format(
                temperature))

    b, c, d, h, w = input.shape
    # Fold channels into the batch so each channel is pooled independently;
    # reshape (unlike view) also accepts non-contiguous inputs.
    input = input.reshape(b * c, 1, d, h, w)
    dev: torch.device = input.device
    dtype: torch.dtype = input.dtype

    center_kernel = _get_center_kernel3d(kernel_size[0], kernel_size[1],
                                         kernel_size[2], dev).to(dtype)
    window_kernel = _get_window_grid_kernel3d(kernel_size[0], kernel_size[1],
                                              kernel_size[2], dev).to(dtype)

    # Exp-normalize trick: subtract the per-channel max before exponentiating so
    # large heatmap values do not overflow to inf (which would give inf / inf = NaN).
    # The shift cancels between numerator and denominator. The max is detached to
    # avoid undesired backprop loops in the graph.
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))
    x_exp = ((input - x_max.detach()) / temperature).exp()

    # avg_pool3d (zero padding counted) times the window volume equals the window sum.
    pool_coef: float = float(kernel_size[0] * kernel_size[1] * kernel_size[2])

    # softmax denominator, regularized by the documented ``eps``
    den = pool_coef * F.avg_pool3d(
        x_exp, kernel_size, stride=stride, padding=padding) + eps

    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    x_softmaxpool = x_softmaxpool.view(b, c, x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(d,
                                                  h,
                                                  w,
                                                  False,
                                                  device=dev).permute(
                                                      0, 4, 1, 2, 3)
    grid_global = grid_global.to(dtype)
    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel,
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)

    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(
            coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2),
                                 coords_max.size(3), coords_max.size(4))

    if output_value:
        return coords_max, x_softmaxpool
    return coords_max
Exemplo n.º 5
0
def conv_quad_interp3d(input: torch.Tensor,
                       strict_maxima_bonus: float = 10.0,
                       eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute a single iteration of quadratic interpolation of the extremum (max or min).

    For every location, the offset ``dx`` solving ``Hessian @ dx = -gradient`` is
    computed for points flagged by a 3x3x3 ``nms3d`` mask. This is similar to the
    sub-pixel refinement done in SIFT.

    Args:
        input: the given heatmap with shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        strict_maxima_bonus: if positive, this amount is ADDED to the interpolated value of every
          point that is flagged by the ``nms3d`` mask and whose linear system was solved
          successfully. This is needed for mimic behavior of strict NMS in classic local features
        eps: scale of a random perturbation added to the Hessian only when
          ``torch_version_geq(1, 10)`` is False (presumably torch older than 1.10), to avoid
          singular systems. It is unused otherwise.

    Returns:
        the refined location and value per each 3x3x3 window which contains a strict extremum:

        - coordinates of shape :math:`(N, C, 3, D_{in}, H_{in}, W_{in})`;
        - values of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.

        The outputs keep the spatial size of the input; no padding or stride is involved.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(
            f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    B, CH, D, H, W = input.shape
    # Pixel-coordinate grid moved to channel-first layout: (1, 3, D, H, W).
    grid_global: torch.Tensor = create_meshgrid3d(D,
                                                  H,
                                                  W,
                                                  False,
                                                  device=input.device).permute(
                                                      0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # to determine the location we are solving system of linear equations Ax = b, where b is 1st order gradient
    # and A is Hessian matrix
    # b is flattened to one (3, 1) gradient column per location.
    b: torch.Tensor = spatial_gradient3d(input, order=1, mode='diff')  #
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    # A is flattened to the 6 unique second derivatives per location.
    A: torch.Tensor = spatial_gradient3d(input, order=2, mode='diff')
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = 0.25 * A[..., 3]  # normalization to match OpenCV implementation
    dys = 0.25 * A[..., 4]  # normalization to match OpenCV implementation
    dxs = 0.25 * A[..., 5]  # normalization to match OpenCV implementation

    # Symmetric 3x3 Hessian per location; stacking on dim=-1 keeps the 9 entries
    # of one location contiguous before the view.
    Hes = torch.stack([dxx, dxy, dxs, dxy, dyy, dys, dxs, dys, dss],
                      dim=-1).view(-1, 3, 3)
    if not torch_version_geq(1, 10):
        # The following is needed to avoid singular cases
        # NOTE: torch.rand makes the result non-deterministic on this code path.
        Hes += torch.rand(Hes[0].size(), device=Hes.device).abs()[None] * eps

    nms_mask: torch.Tensor = nms3d(input, (3, 3, 3), True)
    x_solved: torch.Tensor = torch.zeros_like(b)
    # Solve only at NMS-flagged locations. The third output is used below as a
    # boolean mask of systems that were solved successfully (semantics of
    # safe_solve_with_mask are defined elsewhere — confirm there).
    x_solved_masked, _, solved_correctly = safe_solve_with_mask(
        b[nms_mask.view(-1)], Hes[nms_mask.view(-1)])

    #  Kill those points, where we cannot solve
    # (writes solved_correctly into the True slots of nms_mask, so failed solves become False)
    new_nms_mask = nms_mask.masked_scatter(nms_mask, solved_correctly)

    x_solved.masked_scatter_(new_nms_mask.view(-1, 1, 1),
                             x_solved_masked[solved_correctly])
    dx: torch.Tensor = -x_solved

    # Ignore ones, which are far from window center
    mask1 = dx.abs().max(dim=1, keepdim=True)[0] > 0.7
    dx.masked_fill_(mask1.expand_as(dx), 0)
    # Value correction from the second-order Taylor expansion: 0.5 * b^T dx.
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        # Additive bonus for successfully refined NMS points.
        y_max += strict_maxima_bonus * new_nms_mask.to(input.dtype)

    # flip(1) reverses the order of the 3 offset components, presumably to match
    # the coordinate order of the meshgrid channels.
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W,
                                              3).permute(0, 1, 5, 2, 3, 4)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res

    return coords_max, y_max