def test_warp_grid_offset_x1_depth1(self, batch_size, device, dtype):
    """Check the warp grid for unit depth when the destination camera is shifted by +1 in x."""
    rows, cols = 3, 5  # output resolution (height, width)
    cam_src, cam_dst = self._create_pinhole_pair(batch_size, device, dtype)
    cam_dst.tx += 1.0  # shift the destination camera along x

    # constant depth of one for every pixel
    unit_depth = torch.ones(
        batch_size, 1, rows, cols, device=device, dtype=dtype)

    # build the warper, register the source camera and compute the grid
    depth_warper = kornia.DepthWarper(cam_dst, rows, cols)
    depth_warper.compute_projection_matrix(cam_src)
    warped = depth_warper.warp_grid(unit_depth)
    assert warped.shape == (batch_size, rows, cols, 2)

    # reference: the warper's base meshgrid (first two channels), normalised
    base_xy = depth_warper.grid[..., :2].to(device=device, dtype=dtype)
    reference = normalize_pixel_coordinates(base_xy, rows, cols)

    # x channel: the second-to-last warped column must equal the last
    # reference column, i.e. the grid moved by one pixel along x
    expected_x = reference[..., -1, 0].repeat(batch_size, 1)
    assert_allclose(warped[..., -2, 0], expected_x, atol=1e-4, rtol=1e-4)

    # y channel: the last column is unchanged
    expected_y = reference[..., -1, 1].repeat(batch_size, 1)
    assert_allclose(warped[..., -1, 1], expected_y, rtol=1e-4, atol=1e-4)
def test_warp_grid_offset_x1y1_depth1(self, batch_size):
    """Check the warp grid for unit depth when the destination camera is shifted by +1 in x and y."""
    rows, cols = 3, 5  # output resolution (height, width)
    cam_src, cam_dst = self._create_pinhole_pair(batch_size)
    cam_dst.tx += 1.  # shift the destination camera along x
    cam_dst.ty += 1.  # shift the destination camera along y

    # constant depth of one for every pixel
    unit_depth = torch.ones(batch_size, 1, rows, cols)

    # build the warper, register the source camera and compute the grid
    depth_warper = kornia.DepthWarper(cam_dst, rows, cols)
    depth_warper.compute_projection_matrix(cam_src)
    warped = depth_warper.warp_grid(unit_depth)
    assert warped.shape == (batch_size, rows, cols, 2)

    # reference: the warper's base meshgrid (first two channels), normalised
    base_xy = depth_warper.grid[..., :2]
    reference = normalize_pixel_coordinates(base_xy, rows, cols)

    # x channel: last reference column vs. second-to-last warped column
    assert utils.check_equal_torch(
        reference[..., -1, 0], warped[..., -2, 0])

    # y channel: last reference row vs. second-to-last warped row
    # (the grid is offset along y as well, since ty was shifted)
    assert utils.check_equal_torch(
        reference[..., -1, :, 1], warped[..., -2, :, 1])
def remap(tensor: torch.Tensor, map_x: torch.Tensor, map_y: torch.Tensor,
          align_corners: bool = False) -> torch.Tensor:
    r"""Resample a tensor at the pixel locations given by two coordinate maps.

    Each output location takes the value of the source at the mapped position:

    .. math::
        \text{dst}(x, y) = \text{src}(map_x(x, y), map_y(x, y))

    Args:
        tensor (torch.Tensor): the tensor to remap with shape (B, D, H, W).
          Where D is the number of channels.
        map_x (torch.Tensor): the flow in the x-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).
        map_y (torch.Tensor): the flow in the y-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).
        align_corners(bool): interpolation flag forwarded to
          ``torch.nn.functional.grid_sample``. Default: False.

    Returns:
        torch.Tensor: the warped tensor.

    Raises:
        TypeError: if any of ``tensor``, ``map_x`` or ``map_y`` is not a
          ``torch.Tensor``.
        ValueError: if the last two dimensions of the three inputs differ.

    Example:
        >>> from kornia.utils import create_meshgrid
        >>> grid = create_meshgrid(2, 2, False)  # 1x2x2x2
        >>> grid += 1  # apply offset in both directions
        >>> input = torch.ones(1, 1, 2, 2)
        >>> remap(input, grid[..., 0], grid[..., 1], align_corners=True)   # 1x1x2x2
        tensor([[[[1., 0.],
                  [0., 0.]]]])
    """
    # type validation, in the order tensor -> map_x -> map_y
    for arg_name, arg in (("tensor", tensor), ("map_x", map_x), ("map_y", map_y)):
        if not isinstance(arg, torch.Tensor):
            raise TypeError(
                "Input {} type is not a torch.Tensor. Got {}".format(
                    arg_name, type(arg)))

    # all three inputs must share the same spatial size
    spatial_size = tensor.shape[-2:]
    if map_x.shape[-2:] != spatial_size or map_y.shape[-2:] != spatial_size:
        raise ValueError("Inputs last two dimensions must match.")

    batch_size, _, height, width = tensor.shape

    # grid_sample expects (x, y) pairs in the last axis, scaled to [-1, 1]
    pixel_grid: torch.Tensor = torch.stack((map_x, map_y), dim=-1)
    sampling_grid: torch.Tensor = normalize_pixel_coordinates(
        pixel_grid, height, width)

    # grid_sample does not broadcast over the batch, so expand explicitly
    sampling_grid = sampling_grid.expand(batch_size, -1, -1, -1)

    return F.grid_sample(  # type: ignore
        tensor, sampling_grid, align_corners=align_corners)
def remap(tensor: torch.Tensor, map_x: torch.Tensor,
          map_y: torch.Tensor) -> torch.Tensor:
    r"""Resample a tensor at the pixel locations given by two coordinate maps.

    Each output location takes the value of the source at the mapped position:

    .. math::
        \text{dst}(x, y) = \text{src}(map_x(x, y), map_y(x, y))

    Args:
        tensor (torch.Tensor): the tensor to remap with shape (B, D, H, W).
          Where D is the number of channels.
        map_x (torch.Tensor): the flow in the x-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).
        map_y (torch.Tensor): the flow in the y-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).

    Returns:
        torch.Tensor: the warped tensor.

    Raises:
        TypeError: if any of ``tensor``, ``map_x`` or ``map_y`` is not a
          ``torch.Tensor``.
        ValueError: if the last two dimensions of the three inputs differ.

    Example:
        >>> grid = kornia.utils.create_meshgrid(2, 2, False)  # 1x2x2x2
        >>> grid += 1  # apply offset in both directions
        >>> input = torch.ones(1, 1, 2, 2)
        >>> kornia.remap(input, grid[..., 0], grid[..., 1])   # 1x1x2x2
        tensor([[[[1., 0.],
                  [0., 0.]]]])
    """
    # type validation, in the order tensor -> map_x -> map_y
    named_inputs = (("tensor", tensor), ("map_x", map_x), ("map_y", map_y))
    for label, candidate in named_inputs:
        if not torch.is_tensor(candidate):
            raise TypeError(
                "Input {} type is not a torch.Tensor. Got {}".format(
                    label, type(candidate)))

    # all three inputs must share the same spatial size
    hw = tensor.shape[-2:]
    if map_x.shape[-2:] != hw or map_y.shape[-2:] != hw:
        raise ValueError("Inputs last two dimensions must match.")

    batch_size, _, height, width = tensor.shape

    # grid_sample expects (x, y) pairs in the last axis, scaled to [-1, 1]
    coords: torch.Tensor = torch.stack((map_x, map_y), dim=-1)
    coords_norm: torch.Tensor = normalize_pixel_coordinates(
        coords, height, width)

    # grid_sample does not broadcast over the batch, so expand explicitly
    coords_norm = coords_norm.expand(batch_size, -1, -1, -1)

    return F.grid_sample(tensor, coords_norm)
def _get_window_grid_kernel2d(
        h: int, w: int,
        device: torch.device = torch.device('cpu')) -> torch.Tensor:
    r"""Helper that builds a kernel of window coordinates, residual to the window center.

    Args:
        h: kernel height.
        w: kernel width.
        device: device on which the kernel is generated.

    Returns:
        conv_kernel [2x1xhxw]
    """
    pixel_grid = create_meshgrid(h, w, False, device=device)
    normalized_grid = normalize_pixel_coordinates(pixel_grid, h, w)
    # move the trailing coordinate axis to the front so it becomes the
    # output-channel axis of the kernel
    return normalized_grid.permute(3, 0, 1, 2)
def warp_grid(self, depth_src: torch.Tensor) -> torch.Tensor:
    """Compute the sampling grid that warps with the depth of the reference pinhole camera.

    `compute_projection_matrix` must be called beforehand so that the
    relative projection matrix (relative pose plus intrinsics between the
    reference and the non-reference camera) is already available.

    Args:
        depth_src: depth map with shape Bx1xHxW.

    Returns:
        the projected pixel coordinates, normalised with the warper's
        ``height`` and ``width``.

    Raises:
        ValueError: if `compute_projection_matrix` has not been called, or
          if ``depth_src`` is not a 4-dimensional tensor.
    """
    # TODO: add type and value checkings
    if self._dst_proj_src is None or self._pinhole_src is None:
        raise ValueError("Please, call compute_projection_matrix.")

    if not len(depth_src.shape) == 4:
        raise ValueError(
            "Input depth_src has to be in the shape of Bx1xHxW. Got {}".format(
                depth_src.shape))

    # everything below follows the device / dtype of the incoming depth
    batch_size = depth_src.shape[0]
    target = dict(device=depth_src.device, dtype=depth_src.dtype)

    # replicate the base coordinate grid over the batch
    base_grid: torch.Tensor = self.grid.to(**target).expand(
        batch_size, -1, -1, -1)  # BxHxWx3

    # back-project the pixels into the source camera frame
    inv_intrinsics = self._pinhole_src.intrinsics_inverse().to(**target)
    points_cam: torch.Tensor = pixel2cam(
        depth_src, inv_intrinsics, base_grid)  # BxHxWx3

    # project the camera-frame points to pixel coordinates
    projection = self._dst_proj_src.to(**target)
    points_pix: torch.Tensor = cam2pixel(
        points_cam, projection)  # (B*N)xHxWx2

    # scale the coordinates to the [-1, 1] range
    return normalize_pixel_coordinates(points_pix, self.height, self.width)
def remap(
    tensor: torch.Tensor,
    map_x: torch.Tensor,
    map_y: torch.Tensor,
    mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: Optional[bool] = None,
    normalized_coordinates: bool = False,
) -> torch.Tensor:
    r"""Resample a tensor at the locations given by two coordinate maps.

    .. image:: _static/img/remap.png

    Each output location takes the value of the source at the mapped position:

    .. math::
        \text{dst}(x, y) = \text{src}(map_x(x, y), map_y(x, y))

    Args:
        tensor: the tensor to remap with shape (B, D, H, W).
          Where D is the number of channels.
        map_x: the flow in the x-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).
        map_y: the flow in the y-direction in pixel coordinates.
          The tensor must be in the shape of (B, H, W).
        mode: interpolation mode to calculate output values
          ``'bilinear'`` | ``'nearest'``.
        padding_mode: padding mode for outside grid values
          ``'zeros'`` | ``'border'`` | ``'reflection'``.
        align_corners: mode for grid_generation.
        normalized_coordinates: whether the input coordinates are
           normalised in the range of [-1, 1].

    Returns:
        the warped tensor with same shape as the input grid maps.

    Raises:
        TypeError: if any of ``tensor``, ``map_x`` or ``map_y`` is not a
          ``torch.Tensor``.
        ValueError: if the last two dimensions of the three inputs differ.

    Example:
        >>> from kornia.utils import create_meshgrid
        >>> grid = create_meshgrid(2, 2, False)  # 1x2x2x2
        >>> grid += 1  # apply offset in both directions
        >>> input = torch.ones(1, 1, 2, 2)
        >>> remap(input, grid[..., 0], grid[..., 1], align_corners=True)   # 1x1x2x2
        tensor([[[[1., 0.],
                  [0., 0.]]]])

    .. note::
        This function is often used in conjunction with :func:`create_meshgrid`.
    """
    # type validation, in the order tensor -> map_x -> map_y
    for arg_name, arg in (("tensor", tensor), ("map_x", map_x), ("map_y", map_y)):
        if not isinstance(arg, torch.Tensor):
            raise TypeError(
                "Input {} type is not a torch.Tensor. Got {}".format(
                    arg_name, type(arg)))

    # all three inputs must share the same spatial size
    spatial_size = tensor.shape[-2:]
    if map_x.shape[-2:] != spatial_size or map_y.shape[-2:] != spatial_size:
        raise ValueError("Inputs last two dimensions must match.")

    batch_size, _, height, width = tensor.shape

    # grid_sample expects (x, y) pairs in the last axis, in the [-1, 1] range
    sampling_grid: torch.Tensor = torch.stack((map_x, map_y), dim=-1)
    if not normalized_coordinates:
        # pixel coordinates were given: rescale them first
        sampling_grid = normalize_pixel_coordinates(
            sampling_grid, height, width)

    # grid_sample does not broadcast over the batch, so expand explicitly
    sampling_grid = sampling_grid.expand(batch_size, -1, -1, -1)

    return F.grid_sample(
        tensor,
        sampling_grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )
def warp_frame_depth(
    image_src: torch.Tensor,
    depth_dst: torch.Tensor,
    src_trans_dst: torch.Tensor,
    camera_matrix: torch.Tensor,
    normalize_points: bool = False,
) -> torch.Tensor:
    """Warp a tensor from a source to destination frame by the depth in the destination.

    Compute 3d points from the depth, transform them using given transformation, then project the point cloud to an
    image plane.

    Args:
        image_src: image tensor in the source frame with shape :math:`(B,D,H,W)`.
        depth_dst: depth tensor in the destination frame with shape :math:`(B,1,H,W)`.
        src_trans_dst: transformation matrix from destination to source with shape :math:`(B,4,4)`.
        camera_matrix: tensor containing the camera intrinsics with shape :math:`(B,3,3)`.
        normalize_points: whether to normalise the pointcloud. This must be set to ``True`` when the depth
           is represented as the Euclidean ray length from the camera position.

    Return:
        the warped tensor in the source frame with shape :math:`(B,D,H,W)`.

    Raises:
        TypeError: if any of the four tensor arguments is not a ``torch.Tensor``.
        ValueError: if any of the four tensor arguments does not have the
          shape documented above.
    """
    if not isinstance(image_src, torch.Tensor):
        raise TypeError(
            f"Input image_src type is not a torch.Tensor. Got {type(image_src)}.")

    if len(image_src.shape) != 4:
        raise ValueError(
            f"Input image_src must have a shape (B, D, H, W). Got: {image_src.shape}")

    if not isinstance(depth_dst, torch.Tensor):
        raise TypeError(
            f"Input depth_dst type is not a torch.Tensor. Got {type(depth_dst)}.")

    # NOTE: the whole conjunction must be negated; `not a and b` parses as
    # `(not a) and b` and would accept e.g. a 4-D depth with 2 channels.
    if not (len(depth_dst.shape) == 4 and depth_dst.shape[-3] == 1):
        raise ValueError(
            f"Input depth_dst must have a shape (B, 1, H, W). Got: {depth_dst.shape}")

    if not isinstance(src_trans_dst, torch.Tensor):
        raise TypeError(f"Input src_trans_dst type is not a torch.Tensor. "
                        f"Got {type(src_trans_dst)}.")

    # homogeneous transforms acting on 3d points are 4x4
    if not (len(src_trans_dst.shape) == 3 and src_trans_dst.shape[-2:] == (4, 4)):
        raise ValueError(f"Input src_trans_dst must have a shape (B, 4, 4). "
                         f"Got: {src_trans_dst.shape}.")

    if not isinstance(camera_matrix, torch.Tensor):
        raise TypeError(f"Input camera_matrix type is not a torch.Tensor. "
                        f"Got {type(camera_matrix)}.")

    if not (len(camera_matrix.shape) == 3 and camera_matrix.shape[-2:] == (3, 3)):
        raise ValueError(f"Input camera_matrix must have a shape (B, 3, 3). "
                         f"Got: {camera_matrix.shape}.")

    # back-project the destination depth to 3d points in the camera frame
    points_3d_dst: torch.Tensor = depth_to_3d(
        depth_dst, camera_matrix, normalize_points)  # Bx3xHxW

    # move the coordinate axis last so the points can be transformed
    points_3d_dst = points_3d_dst.permute(0, 2, 3, 1)  # BxHxWx3

    # transform the points from the destination to the source frame
    points_3d_src = transform_points(
        src_trans_dst[:, None], points_3d_dst)  # BxHxWx3

    # project back to pixels; two singleton axes are inserted so the
    # intrinsics line up with the HxW point grid
    camera_matrix_tmp: torch.Tensor = camera_matrix[:, None, None]  # Bx1x1x3x3
    points_2d_src: torch.Tensor = project_points(
        points_3d_src, camera_matrix_tmp)  # BxHxWx2

    # normalize points between [-1 / 1]
    height, width = depth_dst.shape[-2:]
    points_2d_src_norm: torch.Tensor = normalize_pixel_coordinates(
        points_2d_src, height, width)  # BxHxWx2

    return F.grid_sample(image_src, points_2d_src_norm, align_corners=True)  # type: ignore
def conv_soft_argmax2d(
    input: torch.Tensor,
    kernel_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (1, 1),
    temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
    normalized_coordinates: bool = True,
    eps: float = 1e-8,
    output_value: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 2D over the windows of a given heatmap.

    .. math::
        ij(X) = \frac{\sum{(i,j)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
        val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where :math:`T` is temperature.

    Args:
        input: the given heatmap with shape :math:`(N, C, H_{in}, W_{in})`.
        kernel_size: the size of the window.
        stride: the stride of the window.
        padding: input zero padding.
        temperature: factor to apply to input.
        normalized_coordinates: whether to return the coordinates normalized in the range
          of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
        eps: small value to avoid zero division.
        output_value: if True, val is output, if False, only ij.

    Returns:
        The argmax coordinates with shape :math:`(N, C, 2, H_{out}, W_{out})`. When ``output_value`` is
        True, a tuple of those coordinates and the softmax-pooled heatmap values with shape
        :math:`(N, C, H_{out}, W_{out})`, where

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional or ``temperature`` is not positive.

    Examples:
        >>> input = torch.randn(20, 16, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax2d(input, (3,3), (2,2), (1,1), output_value=True)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if len(input.shape) != 4:
        raise ValueError(
            f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if temperature <= 0:
        raise ValueError(
            f"Temperature should be positive float or tensor. Got: {temperature}")

    b, c, h, w = input.shape
    k_first, k_second = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # fold the channels into the batch so every map is handled independently
    flat = input.view(b * c, 1, h, w)

    center_kernel: torch.Tensor = _get_center_kernel2d(
        k_first, k_second, device).to(dtype)
    window_kernel: torch.Tensor = _get_window_grid_kernel2d(
        k_first, k_second, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    peak = F.adaptive_max_pool2d(flat, (1, 1))
    # max is detached to prevent undesired backprop loops in the graph
    weights = ((flat - peak.detach()) / temperature).exp()

    # F.avg_pool2d(.., divisor_override = 1.0) - proper way for sum pool in PyTorch 1.2.
    # Not available yet in version 1.0, so let's do manually
    pool_coef: float = float(k_first * k_second)

    def _window_sum(values: torch.Tensor) -> torch.Tensor:
        # sum over each window, emulated as (window area) * (window average)
        return pool_coef * F.avg_pool2d(
            values, kernel_size, stride=stride, padding=padding)

    # softmax denominator
    den = _window_sum(weights) + eps

    # softmax-weighted value of every window, unfolded back to (b, c, ...)
    pooled_value = _window_sum(weights * flat) / den
    pooled_value = pooled_value.view(b, c, *pooled_value.shape[2:])

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid(
        h, w, False, device).to(dtype).permute(0, 3, 1, 2)
    centers = F.conv2d(
        grid_global, center_kernel, stride=stride, padding=padding)

    # Coordinates of maxima residual to window center
    coords: torch.Tensor = F.conv2d(
        weights, window_kernel, stride=stride, padding=padding)
    coords = coords / den.expand_as(coords)
    coords = coords + centers.expand_as(coords)
    # [:,:, 0, ...] is x
    # [:,:, 1, ...] is y

    if normalized_coordinates:
        # the normaliser is fed channels-last input, so permute there and back
        coords = normalize_pixel_coordinates(
            coords.permute(0, 2, 3, 1), h, w).permute(0, 3, 1, 2)

    # Back B*C -> (b, c)
    coords = coords.view(b, c, 2, *coords.shape[2:])

    return (coords, pooled_value) if output_value else coords