def homography_warp3d(patch_src: torch.Tensor,
                      src_homo_dst: torch.Tensor,
                      dsize: Tuple[int, int, int],
                      mode: str = 'bilinear',
                      padding_mode: str = 'zeros',
                      align_corners: bool = False,
                      normalized_coordinates: bool = True) -> torch.Tensor:
    r"""Warp image patches or tensors by normalized 3D homographies.

    Args:
        patch_src (torch.Tensor): The image or tensor to warp. Should be from
          source of shape :math:`(N, C, D, H, W)`.
        src_homo_dst (torch.Tensor): The homography or stack of homographies
          from destination to source of shape :math:`(N, 4, 4)`.
        dsize (Tuple[int, int, int]): The depth, height and width of the image
          to warp.
        mode (str): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (str): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        align_corners(bool): interpolation flag. Default: False. See
          https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate
          for details.
        normalized_coordinates (bool): Whether the homography assumes [-1, 1]
          normalized coordinates or not.

    Return:
        torch.Tensor: Patch sampled at locations from source to destination.

    Example:
        >>> input = torch.rand(1, 3, 16, 32, 32)
        >>> homography = torch.eye(4).view(1, 4, 4)
        >>> output = homography_warp3d(input, homography, (16, 32, 32))
    """
    if not src_homo_dst.device == patch_src.device:
        # Keep TypeError for backward compatibility with existing callers.
        raise TypeError(
            "Patch and homography must be on the same device. "
            "Got patch.device: {} src_H_dst.device: {}.".format(
                patch_src.device, src_homo_dst.device))

    depth, height, width = dsize
    # Build the sampling grid of the destination volume (normalized to [-1, 1]
    # when `normalized_coordinates` is True) ...
    grid = create_meshgrid3d(depth, height, width,
                             normalized_coordinates=normalized_coordinates,
                             device=patch_src.device)
    # ... map it through the homography, then sample the source patch at the
    # warped locations.
    warped_grid = warp_grid3d(grid, src_homo_dst)

    return F.grid_sample(patch_src, warped_grid, mode=mode,
                         padding_mode=padding_mode,
                         align_corners=align_corners)
def conv_quad_interp3d(input: torch.Tensor,
                       strict_maxima_bonus: float = 1.0,
                       eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Function that computes the single iteration of quadratic interpolation of
    of the extremum (max or min) location and value per each 3x3x3 window which
    contains strict extremum, similar to one done in SIFT.

    Args:
        input (torch.Tensor): the given heatmap of shape
          :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        strict_maxima_bonus (float): pixels, which are strict maxima will score
          (1 + strict_maxima_bonus) * value. This is needed for mimic behavior
          of strict NMS in classic local features
        eps (float): parameter to control the hessian matrix ill-condition number.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`,
          :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
              (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
              (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
              (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))
    if not len(input.shape) == 5:
        raise ValueError("Invalid input shape, we expect BxCxDxHxW. Got: {}"
                         .format(input.shape))
    B, CH, D, H, W = input.shape
    dev: torch.device = input.device
    # Per-voxel integer coordinate grid, laid out channels-first as (1, 3, D, H, W).
    grid_global: torch.Tensor = create_meshgrid3d(D, H, W, False,
                                                  device=input.device).permute(0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # to determine the location we are solving system of linear equations Ax = b, where b is 1st order gradient
    # and A is Hessian matrix
    b: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=1, mode='diff')
    # Flatten every voxel's gradient into a (num_voxels, 3, 1) right-hand side.
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    A: torch.Tensor = kornia.filters.spatial_gradient3d(input, order=2, mode='diff')
    # One row of 6 unique 2nd-order derivatives per voxel.
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = A[..., 3]
    dys = A[..., 4]
    dxs = A[..., 5]

    # for the Hessian
    Hes = torch.stack([dxx, dxy, dxs,
                       dxy, dyy, dys,
                       dxs, dys, dss]).view(-1, 3, 3)
    # Tikhonov-style regularization: `eps` on the diagonal bounds the
    # ill-condition number of the Hessian before solving.
    Hes += torch.eye(3, device=Hes.device)[None] * eps

    # Only solve at strict local maxima found by 3x3x3 non-maxima suppression.
    nms_mask: torch.Tensor = kornia.feature.nms3d(input, (3, 3, 3), True)
    x_solved: torch.Tensor = torch.zeros_like(b)
    # NOTE(review): torch.solve is deprecated and removed in recent PyTorch;
    # torch.linalg.solve(Hes, b) is the modern equivalent — confirm the
    # minimum supported torch version before upgrading.
    x_solved_masked, _ = torch.solve(b[nms_mask.view(-1)], Hes[nms_mask.view(-1)])
    # Scatter the solved offsets back only at the NMS-selected voxels.
    x_solved.masked_scatter_(nms_mask.view(-1, 1, 1), x_solved_masked)
    dx: torch.Tensor = -x_solved

    # Ignore ones, which are far from window,
    dx[(dx.abs().max(dim=1, keepdim=True)[0] > 0.7).view(-1), :, :] = 0

    # Interpolated value change at the refined extremum: 0.5 * g^T * dx.
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        # Boost values at strict maxima to mimic classical strict NMS behavior.
        y_max *= (1.0 + strict_maxima_bonus * nms_mask.to(input.dtype))

    # Offsets stored as (x, y, s) -> flip to (s, y, x)? The flip reverses the
    # component order so channels align with the (depth, y, x) grid layout.
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W, 3).permute(0, 1, 5, 2, 3, 4)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res
    return coords_max, y_max
def conv_soft_argmax3d(input: torch.Tensor,
                       kernel_size: Tuple[int, int, int] = (3, 3, 3),
                       stride: Tuple[int, int, int] = (1, 1, 1),
                       padding: Tuple[int, int, int] = (1, 1, 1),
                       temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
                       normalized_coordinates: bool = False,
                       eps: float = 1e-8,
                       output_value: bool = True,
                       strict_maxima_bonus: float = 0.0) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Function that computes the convolutional spatial Soft-Argmax 3D over
    the windows of a given input heatmap. Function has two outputs: argmax
    coordinates and the softmaxpooled heatmap values themselves. On each window,
    the function computed is:

    .. math::
        ijk(X) = \frac{\sum{(i,j,k)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
        val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where T is temperature.

    Args:
        input (torch.Tensor): the given heatmap of shape
          :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        kernel_size (Tuple[int,int,int]): size of the window
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding
        temperature (torch.Tensor): factor to apply to input. Default is 1.
        normalized_coordinates (bool): whether to return the coordinates
          normalized in the range of [-1, 1]. Otherwise, it will return the
          coordinates in the range of the input shape. Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, val is outputed, if False, only ij
        strict_maxima_bonus (float): pixels, which are strict maxima will score
          (1 + strict_maxima_bonus) * value. This is needed for mimic behavior
          of strict NMS in classic local features

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`,
          :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
              (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
              (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
              (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))
    if not len(input.shape) == 5:
        raise ValueError("Invalid input shape, we expect BxCxDxHxW. Got: {}"
                         .format(input.shape))
    if temperature <= 0:
        raise ValueError("Temperature should be positive float or tensor. Got: {}"
                         .format(temperature))
    b, c, d, h, w = input.shape
    kx, ky, kz = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    # Fold batch and channels together so every channel is pooled independently.
    input = input.view(b * c, 1, d, h, w)

    # Conv kernels: one extracts the window-center coordinate, the other the
    # per-cell coordinate offsets relative to the window center.
    center_kernel: torch.Tensor = _get_center_kernel3d(kx, ky, kz, device).to(dtype)
    window_kernel: torch.Tensor = _get_window_grid_kernel3d(kx, ky, kz, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    pool_coef: float = float(kx * ky * kz)

    # softmax denominator
    den = pool_coef * F.avg_pool3d(x_exp.view_as(input),
                                   kernel_size,
                                   stride=stride,
                                   padding=padding) + eps

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(
        d, h, w, False, device=device).to(dtype).permute(0, 4, 1, 2, 3)

    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel,
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    # prepare kernel
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)
    # Normalize the softmax-weighted offsets and add the window-center grid to
    # obtain absolute coordinates.
    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y
    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2), coords_max.size(3), coords_max.size(4))
    if not output_value:
        return coords_max
    # Softmax-pooled heatmap values: softmax-weighted average of the input.
    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp.view(input.size()) * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    if strict_maxima_bonus > 0:
        # Boost windows containing a strict maximum; crop the NMS response so
        # its depth levels line up with the (possibly strided) pooled output.
        in_levels: int = input.size(2)
        out_levels: int = x_softmaxpool.size(2)
        skip_levels: int = (in_levels - out_levels) // 2
        strict_maxima: torch.Tensor = F.avg_pool3d(kornia.feature.nms3d(input, kernel_size), 1, stride, 0)
        strict_maxima = strict_maxima[:, :, skip_levels:out_levels - skip_levels]
        x_softmaxpool *= 1.0 + strict_maxima_bonus * strict_maxima
    x_softmaxpool = x_softmaxpool.view(b,
                                       c,
                                       x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))
    return coords_max, x_softmaxpool
def conv_soft_argmax3d(
        input: torch.Tensor,
        kernel_size: Tuple[int, int, int] = (3, 3, 3),
        stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (1, 1, 1),
        temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
        normalized_coordinates: bool = False,
        eps: float = 1e-8,
        output_value: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Function that computes the convolutional spatial Soft-Argmax 3D over
    the windows of a given input heatmap. Function has two outputs: argmax
    coordinates and the softmaxpooled heatmap values themselves. On each window,
    the function computed is:

    .. math::
        ijk(X) = \frac{\sum{(i,j,k)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
        val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where T is temperature.

    Args:
        input (torch.Tensor): the given heatmap of shape
          :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        kernel_size (Tuple[int,int,int]): size of the window
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding
        temperature (torch.Tensor): factor to apply to input. Default is 1.
        normalized_coordinates (bool): whether to return the coordinates
          normalized in the range of [-1, 1]. Otherwise, it will return the
          coordinates in the range of the input shape. Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, val is outputed, if False, only ij

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`,
          :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
              (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
              (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
              (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))
    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))
    if temperature <= 0:
        raise ValueError(
            "Temperature should be positive float or tensor. Got: {}".format(
                temperature))
    b, c, d, h, w = input.shape
    # Fold batch and channels together so each channel is pooled independently.
    input = input.view(b * c, 1, d, h, w)
    dev: torch.device = input.device
    # Conv kernels: one extracts the window-center coordinate, the other the
    # per-cell coordinate offsets relative to the window center.
    center_kernel = _get_center_kernel3d(kernel_size[0],
                                         kernel_size[1],
                                         kernel_size[2],
                                         dev)
    window_kernel = _get_window_grid_kernel3d(kernel_size[0],
                                              kernel_size[1],
                                              kernel_size[2],
                                              dev)
    window_kernel = window_kernel.to(input.dtype)

    # Exp-normalize trick: subtract the global max before exponentiating so
    # large activations cannot overflow to inf; softmax ratios are invariant
    # to this shift. The max is detached to avoid backprop loops in the graph.
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))
    x_exp = ((input - x_max.detach()) / temperature).exp()

    pool_coef: float = float(kernel_size[0] * kernel_size[1] * kernel_size[2])

    # softmax denominator; `eps` guards against division by zero
    # (previously the parameter was ignored and 1e-12 was hardcoded).
    den = pool_coef * F.avg_pool3d(
        x_exp.view(input.size()),
        kernel_size,
        stride=stride,
        padding=padding) + eps
    # Softmax-pooled heatmap values: softmax-weighted average of the input.
    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp.view(input.size()) * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    x_softmaxpool = x_softmaxpool.view(b,
                                       c,
                                       x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))
    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(d, h, w, False,
                                                  device=input.device).permute(
        0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)
    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel.to(input.dtype),
                                  stride=stride,
                                  padding=padding)
    # Coordinates of maxima residual to window center
    # prepare kernel
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)
    # Normalize the softmax-weighted offsets and add the window-center grid to
    # obtain absolute coordinates.
    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y
    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(
            coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3,
                                 coords_max.size(2),
                                 coords_max.size(3),
                                 coords_max.size(4))
    if output_value:
        return coords_max, x_softmaxpool
    return coords_max
def conv_quad_interp3d(input: torch.Tensor,
                       strict_maxima_bonus: float = 10.0,
                       eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the single iteration of quadratic interpolation of the extremum (max or min).

    Args:
        input: the given heatmap with shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        strict_maxima_bonus: pixels, which are strict maxima will score
          (1 + strict_maxima_bonus) * value. This is needed for mimic behavior
          of strict NMS in classic local features
        eps: parameter to control the hessian matrix ill-condition number.

    Returns:
        the location and value per each 3x3x3 window which contains strict
        extremum, similar to one done in SIFT.
        :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`,
        :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

        .. math::
            D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
            (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

        .. math::
            H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
            (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

        .. math::
            W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
            (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
    B, CH, D, H, W = input.shape
    # Per-voxel integer coordinate grid, laid out channels-first as (1, 3, D, H, W).
    grid_global: torch.Tensor = create_meshgrid3d(D, H, W, False,
                                                  device=input.device).permute(
        0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # to determine the location we are solving system of linear equations Ax = b, where b is 1st order gradient
    # and A is Hessian matrix
    b: torch.Tensor = spatial_gradient3d(input, order=1, mode='diff')
    # Flatten every voxel's gradient into a (num_voxels, 3, 1) right-hand side.
    b = b.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)
    A: torch.Tensor = spatial_gradient3d(input, order=2, mode='diff')
    # One row of 6 unique 2nd-order derivatives per voxel.
    A = A.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)
    dxx = A[..., 0]
    dyy = A[..., 1]
    dss = A[..., 2]
    dxy = 0.25 * A[..., 3]  # normalization to match OpenCV implementation
    dys = 0.25 * A[..., 4]  # normalization to match OpenCV implementation
    dxs = 0.25 * A[..., 5]  # normalization to match OpenCV implementation

    # Assemble the symmetric 3x3 Hessian per voxel.
    Hes = torch.stack([dxx, dxy, dxs,
                       dxy, dyy, dys,
                       dxs, dys, dss], dim=-1).view(-1, 3, 3)
    if not torch_version_geq(1, 10):
        # The following is needed to avoid singular cases
        Hes += torch.rand(Hes[0].size(), device=Hes.device).abs()[None] * eps

    # Only solve at strict local maxima found by 3x3x3 non-maxima suppression.
    nms_mask: torch.Tensor = nms3d(input, (3, 3, 3), True)
    x_solved: torch.Tensor = torch.zeros_like(b)
    x_solved_masked, _, solved_correctly = safe_solve_with_mask(
        b[nms_mask.view(-1)], Hes[nms_mask.view(-1)])

    # Kill those points, where we cannot solve
    new_nms_mask = nms_mask.masked_scatter(nms_mask, solved_correctly)

    # Scatter the successfully-solved offsets back at the surviving voxels.
    x_solved.masked_scatter_(new_nms_mask.view(-1, 1, 1),
                             x_solved_masked[solved_correctly])
    dx: torch.Tensor = -x_solved

    # Ignore ones, which are far from window center
    mask1 = dx.abs().max(dim=1, keepdim=True)[0] > 0.7
    dx.masked_fill_(mask1.expand_as(dx), 0)

    # Interpolated value change at the refined extremum: 0.5 * g^T * dx.
    dy: torch.Tensor = 0.5 * torch.bmm(b.permute(0, 2, 1), dx)
    y_max = input + dy.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        # Additive bonus at strict maxima to mimic classical strict NMS behavior.
        y_max += strict_maxima_bonus * new_nms_mask.to(input.dtype)

    # The flip reverses the offset component order so channels align with the
    # (depth, y, x) layout of `grid_global`.
    dx_res: torch.Tensor = dx.flip(1).reshape(B, CH, D, H, W, 3).permute(0, 1, 5, 2, 3, 4)
    coords_max: torch.Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + dx_res
    return coords_max, y_max