def spatial_soft_argmax2d(
    input: torch.Tensor,
    temperature: torch.Tensor = torch.tensor(1.0),
    normalized_coordinates: bool = True,
    eps: float = 1e-8,
) -> torch.Tensor:
    r"""Compute the Spatial Soft-Argmax 2D of an input heatmap.

    The heatmap is first passed through ``dsnt.spatial_softmax2d`` together
    with ``temperature``; the result is then reduced to coordinates by
    ``dsnt.spatial_expectation2d``.

    Args:
        input: the heatmap with shape :math:`(B, N, H, W)`.
        temperature: factor forwarded to the spatial softmax.
        normalized_coordinates: if ``True``, coordinates are returned
            normalized in the range :math:`[-1, 1]`; otherwise they are
            returned in the range of the input shape.
        eps: small value intended to avoid zero division. NOTE: it is kept
            in the signature but is not referenced by this function's body.

    Returns:
        the soft-argmax 2d coordinates of the given map, with shape
        :math:`(B, N, 2)`. The output order is x-coord and y-coord.

    Examples:
        >>> input = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 10., 0.],
        ... [0., 0., 0.]]]])
        >>> spatial_soft_argmax2d(input, normalized_coordinates=False)
        tensor([[[1.0000, 1.0000]]])
    """
    # Softmax over the spatial dimensions, then take the expected coordinate.
    return dsnt.spatial_expectation2d(
        dsnt.spatial_softmax2d(input, temperature),
        normalized_coordinates,
    )
Exemple #2
0
def spatial_soft_argmax2d(
        input: torch.Tensor,
        temperature: torch.Tensor = torch.tensor(1.0),
        normalized_coordinates: bool = True,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Function that computes the Spatial Soft-Argmax 2D of a heatmap.

    The input is turned into a spatial distribution with
    ``dsnt.spatial_softmax2d`` and then reduced to coordinates with
    ``dsnt.spatial_expectation2d``. The output order is x-coord and y-coord.

    Arguments:
        input (torch.Tensor): the input heatmap.
        temperature (torch.Tensor): factor to apply to input. Default is 1.
        normalized_coordinates (bool): whether to return the
          coordinates normalized in the range of [-1, 1]. Otherwise,
          it will return the coordinates in the range of the input shape.
          Default is True.
        eps (float): small value to avoid zero division. Default is 1e-8.
          NOTE: not referenced by this function's body.

    Shape:
        - Input: :math:`(B, N, H, W)`
        - Output: :math:`(B, N, 2)`

    Examples:
        >>> input = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 10., 0.],
        ... [0., 0., 0.]]]])
        >>> kornia.spatial_soft_argmax2d(input, normalized_coordinates=False)
        tensor([[[1.0000, 1.0000]]])
    """
    # Step 1: softmax over the spatial dimensions of every channel.
    probabilities = dsnt.spatial_softmax2d(input, temperature)
    # Step 2: expected (x, y) coordinate under that distribution.
    expected_coords = dsnt.spatial_expectation2d(
        probabilities, normalized_coordinates)
    return expected_coords
 def forward(self, inputs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
     """Map per-channel logits to a spatial distribution and keypoints.

     Layout per the original note: NCHW -> NC2.

     Args:
         inputs: logits tensor (NCHW according to the original docstring).

     Returns:
         A ``(prob, kpts)`` tuple: the output of ``spatial_softmax2d`` and
         the normalized-coordinate output of ``spatial_expectation2d``.
     """
     # Convert logits to probabilities using the module's temperature.
     # TODO(ycho): Consider if this is preferable to elementwise sigmoid.
     heatmap = spatial_softmax2d(inputs, temperature=self.temperature)
     # Expected keypoint location under each heatmap, normalized coordinates.
     return (
         heatmap,
         spatial_expectation2d(heatmap, normalized_coordinates=True),
     )
Exemple #4
0
    def forward(self, feat_f0, feat_f1, data):
        """Refine coarse matches from fine-level window features.

        The feature at the centre of each ``feat_f0`` window is correlated
        with every position of the matching ``feat_f1`` window; the softmax
        of that correlation is treated as a heatmap from which an expected
        coordinate and a spread (std) are computed.

        Args:
            feat_f0 (torch.Tensor): [M, WW, C]
            feat_f1 (torch.Tensor): [M, WW, C]
            data (dict): must provide 'hw0_i' and 'hw0_f'; 'mkpts0_c' and
                'mkpts1_c' are read when M == 0.
        Update:
            data (dict):{
                'expec_f' (torch.Tensor): [M, 3],
                'mkpts0_f' (torch.Tensor): [M, 2],
                'mkpts1_f' (torch.Tensor): [M, 2]}
            When M > 0 the 'mkpts*_f' entries are presumably written by
            ``self.get_fine_match`` (not visible here) -- TODO confirm.

        Raises:
            ValueError: if M == 0 while ``self.training`` is truthy.
        """
        num_matches, window_area, channels = feat_f0.shape
        window_side = int(math.sqrt(window_area))
        scale = data['hw0_i'][0] / data['hw0_f'][0]

        # Cache the problem sizes on the module for later use.
        self.M = num_matches
        self.W = window_side
        self.WW = window_area
        self.C = channels
        self.scale = scale

        # corner case: if no coarse matches found
        if num_matches == 0:
            if self.training:
                raise ValueError("M >0, when training, see coarse_matching.py")
            # logger.warning('No matches found in coarse-level.')
            fallback = {
                'expec_f': torch.empty(0, 3, device=feat_f0.device),
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            }
            data.update(fallback)
            return

        # Correlate the centre feature of each window in image 0 with all
        # positions of the corresponding window in image 1.
        center_feat = feat_f0[:, window_area // 2, :]
        similarity = torch.einsum('mc,mrc->mr', center_feat, feat_f1)
        temperature = 1. / channels**.5
        heatmap = torch.softmax(temperature * similarity, dim=1)
        heatmap = heatmap.view(-1, window_side, window_side)

        # compute coordinates from heatmap
        expected_xy = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
        grid = create_meshgrid(window_side, window_side, True, heatmap.device)
        grid = grid.reshape(1, -1, 2)  # [1, WW, 2]

        # compute std over <x, y>: weighted second moment minus squared mean
        weights = heatmap.view(-1, window_area, 1)
        second_moment = torch.sum(grid**2 * weights, dim=1)
        variance = second_moment - expected_xy**2  # [M, 2]
        # clamp needed for numerical stability
        spread = variance.clamp(min=1e-10).sqrt().sum(-1)  # [M]

        # for fine-level supervision
        expec_f = torch.cat([expected_xy, spread.unsqueeze(1)], -1)
        data.update({'expec_f': expec_f})

        # compute absolute kpt coords
        self.get_fine_match(expected_xy, data)