Example #1
File: laf.py  Project: jhacsonmeza/kornia
def set_laf_orientation(LAF: torch.Tensor,
                        angles_degrees: torch.Tensor) -> torch.Tensor:
    """Changes the orientation of the LAFs.

    Args:
        LAF: tensor [BxNx2x3].
        angles_degrees: tensor BxNx1, in degrees.

    Returns:
        tensor [BxNx2x3].

    Shape:
        - Input: :math:`(B, N, 2, 3)`, :math:`(B, N, 1)`
        - Output: :math:`(B, N, 2, 3)`

    """
    raise_error_if_laf_is_not_valid(LAF)
    B, N = LAF.shape[:2]
    rotmat: torch.Tensor = angle_to_rotation_matrix(angles_degrees).view(
        B * N, 2, 2)
    laf_out: torch.Tensor = torch.cat([
        torch.bmm(make_upright(LAF).view(B * N, 2, 3)[:, :2, :2], rotmat),
        LAF.view(B * N, 2, 3)[:, :2, 2:],
    ], dim=2).view(B, N, 2, 3)
    return laf_out
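
A minimal usage sketch (mine, not from the project above), assuming set_laf_orientation is importable from kornia.feature.laf as the file/project header suggests:

import torch
from kornia.feature.laf import set_laf_orientation

# One upright LAF with scale 10, centered at (32, 32)
lafs = torch.tensor([[[[10.0, 0.0, 32.0],
                       [0.0, 10.0, 32.0]]]])  # [1x1x2x3]
angles = torch.full((1, 1, 1), 45.0)  # degrees
rotated = set_laf_orientation(lafs, angles)
print(rotated.shape)  # torch.Size([1, 1, 2, 3])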
Example #2
    def forward(self, laf: torch.Tensor,
                img: torch.Tensor) -> torch.Tensor:  # type: ignore
        """
        Args:
            laf: (torch.Tensor), shape [BxNx2x3]
            img: (torch.Tensor), shape [Bx1xHxW]

        Returns:
            laf_out: (torch.Tensor), shape [BxNx2x3] """
        raise_error_if_laf_is_not_valid(laf)
        # Check the type before touching img.shape, otherwise a non-tensor
        # without a .shape attribute raises AttributeError instead of TypeError
        if not torch.is_tensor(img):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(
                type(img)))
        if len(img.shape) != 4:
            raise ValueError(
                "Invalid img shape, we expect BxCxHxW. Got: {}".format(
                    img.shape))
        if laf.size(0) != img.size(0):
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".format(
                    laf.size(0), img.size(0)))
        B, N = laf.shape[:2]
        patches: torch.Tensor = extract_patches_from_pyramid(
            img, laf, self.patch_size).view(-1, 1, self.patch_size, self.patch_size)
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        rotmat: torch.Tensor = angle_to_rotation_matrix(
            rad2deg(angles_radians)).view(B * N, 2, 2)

        laf_out: torch.Tensor = torch.cat([
            torch.bmm(make_upright(laf).view(B * N, 2, 3)[:, :2, :2], rotmat),
            laf.view(B * N, 2, 3)[:, :2, 2:],
        ], dim=2).view(B, N, 2, 3)
        return laf_out
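
The enclosing class is not shown in this snippet; kornia's LAFOrienter has a forward with this signature, so here is a usage sketch under that assumption:

import torch
from kornia.feature import LAFOrienter

orienter = LAFOrienter(patch_size=32)
img = torch.rand(1, 1, 64, 64)  # [Bx1xHxW]
lafs = torch.tensor([[[[8.0, 0.0, 32.0],
                       [0.0, 8.0, 32.0]]]])  # [BxNx2x3]
lafs_oriented = orienter(lafs, img)  # orientation re-estimated per patch
print(lafs_oriented.shape)  # torch.Size([1, 1, 2, 3])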
Example #3
    def detect(self, img: torch.Tensor, num_feats: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sp, sigmas, pix_dists = self.scale_pyr(img)
        all_responses = []
        all_lafs = []
        for oct_idx, octave in enumerate(sp):
            sigmas_oct = sigmas[oct_idx]
            pix_dists_oct = pix_dists[oct_idx]
            B, L, CH, H, W = octave.size()
            # Run response function
            oct_resp = self.resp(octave.view(B * L, CH, H, W), sigmas_oct.view(-1)).view(B, L, CH, H, W)

            # We want nms for scale responses, so reorder to (B, CH, L, H, W)
            oct_resp = oct_resp.permute(0, 2, 1, 3, 4)

            # Differentiable nms
            coord_max, response_max = self.nms(oct_resp)

            # Now, let's drop the small responses
            responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
            max_coords_flatten = coord_max.view(response_max.size(0), 3, -1).permute(0, 2, 1)  # [B, N, 3]

            if responses_flatten.size(1) > num_feats:
                resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
                max_coords_best = torch.gather(max_coords_flatten, 1, idxs.unsqueeze(-1).repeat(1, 1, 3))
            else:
                resp_flat_best = responses_flatten
                max_coords_best = max_coords_flatten
            B, N = resp_flat_best.size()

            # Convert the scale-level index from ConvSoftArgmax3d into an actual scale, using the sigmas
            max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)

            # Create local affine frames (LAFs)
            angles = torch.zeros(B, N, device=max_coords_best.device, dtype=max_coords_best.dtype)
            rotmat = angle_to_rotation_matrix(angles)
            current_lafs = torch.cat([self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
                                      max_coords_best[:, :, 1:3].view(B, N, 2, 1)], dim=3)
            # Normalize LAFs
            # We don't need the number of scale levels here, only the spatial shape
            current_lafs = normalize_laf(current_lafs, octave[:, 0])

            all_responses.append(resp_flat_best)
            all_lafs.append(current_lafs)

        # Sort responses across all octaves and keep the best num_feats
        responses: torch.Tensor = torch.cat(all_responses, dim=1)
        lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
        responses, idxs = torch.topk(responses, k=num_feats, dim=1)
        lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
        return responses, denormalize_laf(lafs, img)
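
Again the enclosing class is not shown; this detect matches kornia's ScaleSpaceDetector, so a usage sketch under that assumption (the call signature is taken from the snippet; the constructor name and num_features argument are assumptions):

import torch
from kornia.feature import ScaleSpaceDetector

detector = ScaleSpaceDetector(num_features=100)
img = torch.rand(1, 1, 128, 128)  # [Bx1xHxW]
responses, lafs = detector.detect(img, num_feats=100)
print(responses.shape)  # torch.Size([1, 100])
print(lafs.shape)       # torch.Size([1, 100, 2, 3])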