Example #1
    def detect(self, img: torch.Tensor, num_feats: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sp, sigmas, pix_dists = self.scale_pyr(img)
        all_responses = []
        all_lafs = []
        for oct_idx, octave in enumerate(sp):
            sigmas_oct = sigmas[oct_idx]
            pix_dists_oct = pix_dists[oct_idx]
            B, L, CH, H, W = octave.size()
            # Run response function
            oct_resp = self.resp(octave.view(B * L, CH, H, W), sigmas_oct.view(-1)).view(B, L, CH, H, W)

            # We want nms for scale responses, so reorder to (B, CH, L, H, W)
            oct_resp = oct_resp.permute(0, 2, 1, 3, 4)

            # Differentiable nms
            coord_max, response_max = self.nms(oct_resp)

            # Now, let's filter out the small responses
            responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
            max_coords_flatten = coord_max.view(response_max.size(0), 3, -1).permute(0, 2, 1)  # [B, N, 3]

            if responses_flatten.size(1) > num_feats:
                resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
                max_coords_best = torch.gather(max_coords_flatten, 1, idxs.unsqueeze(-1).repeat(1, 1, 3))
            else:
                resp_flat_best = responses_flatten
                max_coords_best = max_coords_flatten
            B, N = resp_flat_best.size()

            # Converts scale level index from ConvSoftArgmax3d to the actual scale, using the sigmas
            max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)

            # Create local affine frames (LAFs)
            rotmat = angle_to_rotation_matrix(torch.zeros(B, N).to(max_coords_best.device).to(max_coords_best.dtype))
            current_lafs = torch.cat([self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
                                      max_coords_best[:, :, 1:3].view(B, N, 2, 1)], dim=3)
            # Normalize LAFs
            current_lafs = normalize_laf(current_lafs, octave[:, 0])  # We don't need the number of scale levels, only the shape

            all_responses.append(resp_flat_best)
            all_lafs.append(current_lafs)

        # Sort and keep best n
        responses: torch.Tensor = torch.cat(all_responses, dim=1)
        lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
        responses, idxs = torch.topk(responses, k=num_feats, dim=1)
        lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
        return responses, denormalize_laf(lafs, img)
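
A minimal usage sketch for Example #1. It assumes the method sits on a kornia-style ScaleSpaceDetector; the constructor call and the get_laf_center helper below are assumptions based on the kornia API, not part of the snippet:

# Usage sketch (assumed wiring): drive the detect() method above.
import torch
from kornia.feature import ScaleSpaceDetector, get_laf_center

detector = ScaleSpaceDetector(num_features=500)  # assumed host class
img = torch.rand(1, 1, 256, 256)                 # grayscale batch, B=1
with torch.no_grad():
    responses, lafs = detector.detect(img, num_feats=500)
print(responses.shape)          # (1, 500) keypoint responses
print(lafs.shape)               # (1, 500, 2, 3) local affine frames
centers = get_laf_center(lafs)  # (1, 500, 2) keypoint centers in pixels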
Example #2
def extract_patches_from_pyramid(
    img: torch.Tensor, laf: torch.Tensor, PS: int = 32,
    normalize_lafs_before_extraction: bool = True
) -> torch.Tensor:
    """Extract patches defined by LAFs from image tensor.
    Copied from kornia.feature.laf.extract_patches_from_pyramid with one minor
    difference - highlighted below.
    """
    raise_error_if_laf_is_not_valid(laf)
    if normalize_lafs_before_extraction:
        nlaf: torch.Tensor = normalize_laf(laf, img)
    else:
        nlaf = laf
    B, N, _, _ = laf.size()
    _, ch, h, w = img.size()
    # Choose the pyramid level whose resolution best matches each LAF's scale
    scale = 2.0 * get_laf_scale(denormalize_laf(nlaf, img)) / float(PS)
    pyr_idx = scale.log2().relu().long()  # diff: floor instead of round
    cur_img = img
    cur_pyr_level = 0
    out = torch.zeros(B, N, ch, PS, PS).to(nlaf.dtype).to(nlaf.device)
    while min(cur_img.size(2), cur_img.size(3)) >= PS:
        _, ch, h, w = cur_img.size()
        # for loop temporarily, to be refactored
        for i in range(B):
            # Select only the LAFs assigned to the current pyramid level
            scale_mask = (pyr_idx[i] == cur_pyr_level).squeeze()
            if scale_mask.float().sum() == 0:
                continue
            scale_mask = (scale_mask > 0).view(-1)
            grid = generate_patch_grid_from_normalized_LAF(
                cur_img[i: i + 1], nlaf[i: i + 1, scale_mask, :, :], PS)
            patches = F.grid_sample(
                cur_img[i: i + 1].expand(grid.size(0), ch, h, w),
                grid,  # type: ignore
                padding_mode="border",
                align_corners=False,
            )
            out[i].masked_scatter_(scale_mask.view(-1, 1, 1, 1), patches)
        cur_img = pyrdown(cur_img)
        cur_pyr_level += 1
    return out
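
A short call sketch for Example #2, with illustrative shapes: a LAF is a (B, N, 2, 3) tensor holding a 2x2 scale/shape matrix plus an (x, y) center per keypoint, and the function returns one PS x PS patch per LAF, sampled from the pyramid level that best matches its scale:

# Usage sketch; the kornia helpers called by the function above
# (normalize_laf, get_laf_scale, pyrdown, ...) are presumed imported.
import torch

B, N, PS = 1, 4, 32
img = torch.rand(B, 1, 128, 128)
laf = torch.zeros(B, N, 2, 3)
laf[..., 0, 0] = 8.0  # isotropic scale: 8 px
laf[..., 1, 1] = 8.0
laf[..., 0, 2] = torch.tensor([20.0, 40.0, 60.0, 100.0])  # x centers
laf[..., 1, 2] = 64.0                                     # y centers
patches = extract_patches_from_pyramid(img, laf, PS=PS)
print(patches.shape)  # torch.Size([1, 4, 1, 32, 32])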
Example #3
    def detect(
        self,
        img: torch.Tensor,
        num_feats: int,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        dev: torch.device = img.device
        dtype: torch.dtype = img.dtype
        sp, sigmas, pix_dists = self.scale_pyr(img)
        all_responses = []
        all_lafs = []
        for oct_idx, octave in enumerate(sp):
            sigmas_oct = sigmas[oct_idx]
            pix_dists_oct = pix_dists[oct_idx]

            B, L, CH, H, W = octave.size()
            # Run response function
            oct_resp = self.resp(octave.view(B * L, CH, H, W),
                                 sigmas_oct.view(-1))

            # We want nms for scale responses, so reorder to (B, CH, L, H, W)
            oct_resp = oct_resp.view_as(octave).permute(0, 2, 1, 3, 4)

            if mask is not None:
                oct_mask: torch.Tensor = _create_octave_mask(
                    mask, oct_resp.shape)
                oct_resp = oct_mask * oct_resp

            # Differentiable nms
            coord_max, response_max = self.nms(oct_resp)
            if self.minima_are_also_good:
                # Also detect minima and keep whichever response is stronger
                coord_min, response_min = self.nms(-oct_resp)
                take_min_mask = (response_min > response_max).to(response_max.dtype)
                response_max = response_min * take_min_mask + (1 - take_min_mask) * response_max
                coord_max = coord_min * take_min_mask.unsqueeze(1) + (1 - take_min_mask.unsqueeze(1)) * coord_max

            # Now, let's filter out the small responses
            responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
            max_coords_flatten = coord_max.view(response_max.size(0), 3, -1).permute(0, 2, 1)  # [B, N, 3]

            if responses_flatten.size(1) > num_feats:
                resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
                max_coords_best = torch.gather(max_coords_flatten, 1, idxs.unsqueeze(-1).repeat(1, 1, 3))
            else:
                resp_flat_best = responses_flatten
                max_coords_best = max_coords_flatten
            B, N = resp_flat_best.size()

            # Converts scale level index from ConvSoftArgmax3d to the actual scale, using the sigmas
            max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)

            # Create local affine frames (LAFs)
            rotmat = torch.eye(2, dtype=dtype, device=dev).view(1, 1, 2, 2)
            current_lafs = torch.cat([
                self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
                max_coords_best[:, :, 1:3].view(B, N, 2, 1),
            ], dim=3)

            # Zero response lafs, which touch the boundary
            good_mask = laf_is_inside_image(current_lafs, octave[:, 0])
            resp_flat_best = resp_flat_best * good_mask.to(dev, dtype)

            # Normalize LAFs
            current_lafs = normalize_laf(current_lafs, octave[:, 0])  # We don't need the number of scale levels, only the shape

            all_responses.append(resp_flat_best)
            all_lafs.append(current_lafs)

        # Sort and keep best n
        responses: torch.Tensor = torch.cat(all_responses, dim=1)
        lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
        responses, idxs = torch.topk(responses, k=num_feats, dim=1)
        lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
        return responses, denormalize_laf(lafs, img)
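
Example #3 extends Example #1 with an optional binary mask that zeroes responses outside a region of interest before NMS, optional detection of minima, and zeroing of LAFs that touch the image boundary. A call sketch, where detector stands in for an instance of whatever class hosts this method (an assumption, as above):

# Call sketch (assumed wiring): restrict detection to the left image half.
import torch

B, H, W = 1, 256, 256
img = torch.rand(B, 1, H, W)
mask = torch.zeros(B, 1, H, W)
mask[..., : W // 2] = 1.0  # 1 = keep, 0 = suppress

responses, lafs = detector.detect(img, num_feats=200, mask=mask)
# Masked-out and boundary-touching detections come back with zero response,
# so they can be filtered downstream:
keep = responses > 0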