from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.feature.laf import (
    denormalize_laf,
    generate_patch_grid_from_normalized_LAF,
    get_laf_scale,
    laf_is_inside_image,
    normalize_laf,
    raise_error_if_laf_is_not_valid,
)
from kornia.geometry import angle_to_rotation_matrix
from kornia.geometry.transform import pyrdown

# _scale_index_to_scale and _create_octave_mask are module-local helpers
# defined alongside these functions.


def detect(self, img: torch.Tensor, num_feats: int) -> Tuple[torch.Tensor, torch.Tensor]:
    sp, sigmas, pix_dists = self.scale_pyr(img)
    all_responses = []
    all_lafs = []
    for oct_idx, octave in enumerate(sp):
        sigmas_oct = sigmas[oct_idx]
        pix_dists_oct = pix_dists[oct_idx]
        B, L, CH, H, W = octave.size()
        # Run response function
        oct_resp = self.resp(octave.view(B * L, CH, H, W), sigmas_oct.view(-1)).view(B, L, CH, H, W)
        # We want nms for scale responses, so reorder to (B, CH, L, H, W)
        oct_resp = oct_resp.permute(0, 2, 1, 3, 4)
        # Differentiable nms
        coord_max, response_max = self.nms(oct_resp)
        # Now, let's crop out some small responses
        responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
        max_coords_flatten = coord_max.view(response_max.size(0), 3, -1).permute(0, 2, 1)  # [B, N, 3]
        if responses_flatten.size(1) > num_feats:
            resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
            max_coords_best = torch.gather(max_coords_flatten, 1, idxs.unsqueeze(-1).repeat(1, 1, 3))
        else:
            resp_flat_best = responses_flatten
            max_coords_best = max_coords_flatten
        B, N = resp_flat_best.size()
        # Convert the scale level index from ConvSoftArgmax3d to the actual scale, using the sigmas
        max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)
        # Create local affine frames (LAFs)
        rotmat = angle_to_rotation_matrix(torch.zeros(B, N).to(max_coords_best.device).to(max_coords_best.dtype))
        current_lafs = torch.cat([
            self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
            max_coords_best[:, :, 1:3].view(B, N, 2, 1),
        ], dim=3)
        # Normalize LAFs; we don't need the number of scale levels, only the shape
        current_lafs = normalize_laf(current_lafs, octave[:, 0])
        all_responses.append(resp_flat_best)
        all_lafs.append(current_lafs)
    # Sort and keep the best num_feats responses
    responses: torch.Tensor = torch.cat(all_responses, dim=1)
    lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
    responses, idxs = torch.topk(responses, k=num_feats, dim=1)
    lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
    return responses, denormalize_laf(lafs, img)
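# Worked example (not part of the detector): how one (scale, x, y) triple
# becomes a 2x3 LAF via the torch.cat() above. Plain PyTorch sketch;
# mr_size = 6.0 is a hypothetical stand-in for the detector's
# measurement-region multiplier.
def _laf_construction_sketch() -> torch.Tensor:
    mr_size = 6.0  # hypothetical value
    coords = torch.tensor([[[2.0, 10.0, 20.0]]])  # (B=1, N=1, 3) as (scale, x, y)
    B, N = coords.shape[:2]
    rotmat = torch.eye(2).view(1, 1, 2, 2)  # zero orientation -> identity rotation
    laf = torch.cat([
        mr_size * coords[:, :, 0].view(B, N, 1, 1) * rotmat,  # scaled 2x2 rotation part
        coords[:, :, 1:3].view(B, N, 2, 1),                   # (x, y) center column
    ], dim=3)  # -> (B, N, 2, 3)
    # laf == [[[[12., 0., 10.],
    #           [ 0., 12., 20.]]]]
    return laf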
def extract_patches_from_pyramid(
    img: torch.Tensor, laf: torch.Tensor, PS: int = 32, normalize_lafs_before_extraction: bool = True
) -> torch.Tensor:
    """Extract patches defined by LAFs from image tensor.

    Copied from kornia.feature.laf.extract_patches_from_pyramid with one
    minor difference - highlighted below.
    """
    raise_error_if_laf_is_not_valid(laf)
    if normalize_lafs_before_extraction:
        nlaf: torch.Tensor = normalize_laf(laf, img)
    else:
        nlaf = laf
    B, N, _, _ = laf.size()
    _, ch, h, w = img.size()
    scale = 2.0 * get_laf_scale(denormalize_laf(nlaf, img)) / float(PS)
    pyr_idx = scale.log2().relu().long()  # diff: floor instead of round
    cur_img = img
    cur_pyr_level = 0
    out = torch.zeros(B, N, ch, PS, PS).to(nlaf.dtype).to(nlaf.device)
    while min(cur_img.size(2), cur_img.size(3)) >= PS:
        _, ch, h, w = cur_img.size()
        # for loop temporarily, to be refactored
        for i in range(B):
            scale_mask = (pyr_idx[i] == cur_pyr_level).squeeze()
            if (scale_mask.float().sum()) == 0:
                continue
            scale_mask = (scale_mask > 0).view(-1)
            grid = generate_patch_grid_from_normalized_LAF(
                cur_img[i: i + 1], nlaf[i: i + 1, scale_mask, :, :], PS)
            patches = F.grid_sample(
                cur_img[i: i + 1].expand(grid.size(0), ch, h, w),
                grid,  # type: ignore
                padding_mode="border",
                align_corners=False,
            )
            out[i].masked_scatter_(scale_mask.view(-1, 1, 1, 1), patches)
        cur_img = pyrdown(cur_img)
        cur_pyr_level += 1
    return out
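# Smoke-test sketch for extract_patches_from_pyramid, with hypothetical shapes
# and values: two square, axis-aligned LAFs in pixel coordinates, extracted at
# PS=32 from a random single-channel image. Both LAFs land on pyramid level 0
# here (2 * scale / PS <= 1), so no downsampling is actually exercised.
def _patch_extraction_sketch() -> torch.Tensor:
    img = torch.rand(1, 1, 128, 128)  # (B, CH, H, W)
    lafs = torch.tensor([[
        [[8.0, 0.0, 64.0],   # keypoint 1: scale 8 px, centered at (64, 64)
         [0.0, 8.0, 64.0]],
        [[16.0, 0.0, 32.0],  # keypoint 2: scale 16 px, centered at (32, 96)
         [0.0, 16.0, 96.0]],
    ]])  # (B=1, N=2, 2, 3)
    patches = extract_patches_from_pyramid(img, lafs, PS=32)
    # patches.shape == torch.Size([1, 2, 1, 32, 32])
    return patches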
def detect(
    self, img: torch.Tensor, num_feats: int, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    dev: torch.device = img.device
    dtype: torch.dtype = img.dtype
    sp, sigmas, pix_dists = self.scale_pyr(img)
    all_responses = []
    all_lafs = []
    for oct_idx, octave in enumerate(sp):
        sigmas_oct = sigmas[oct_idx]
        pix_dists_oct = pix_dists[oct_idx]
        B, L, CH, H, W = octave.size()
        # Run response function
        oct_resp = self.resp(octave.view(B * L, CH, H, W), sigmas_oct.view(-1))
        # We want nms for scale responses, so reorder to (B, CH, L, H, W)
        oct_resp = oct_resp.view_as(octave).permute(0, 2, 1, 3, 4)
        if mask is not None:
            oct_mask: torch.Tensor = _create_octave_mask(mask, oct_resp.shape)
            oct_resp = oct_mask * oct_resp
        # Differentiable nms
        coord_max, response_max = self.nms(oct_resp)
        if self.minima_are_also_good:
            coord_min, response_min = self.nms(-oct_resp)
            take_min_mask = (response_min > response_max).to(response_max.dtype)
            response_max = response_min * take_min_mask + (1 - take_min_mask) * response_max
            coord_max = coord_min * take_min_mask.unsqueeze(1) + (1 - take_min_mask.unsqueeze(1)) * coord_max
        # Now, let's crop out some small responses
        responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
        max_coords_flatten = coord_max.view(response_max.size(0), 3, -1).permute(0, 2, 1)  # [B, N, 3]
        if responses_flatten.size(1) > num_feats:
            resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
            max_coords_best = torch.gather(max_coords_flatten, 1, idxs.unsqueeze(-1).repeat(1, 1, 3))
        else:
            resp_flat_best = responses_flatten
            max_coords_best = max_coords_flatten
        B, N = resp_flat_best.size()
        # Convert the scale level index from ConvSoftArgmax3d to the actual scale, using the sigmas
        max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)
        # Create local affine frames (LAFs)
        rotmat = torch.eye(2, dtype=dtype, device=dev).view(1, 1, 2, 2)
        current_lafs = torch.cat([
            self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
            max_coords_best[:, :, 1:3].view(B, N, 2, 1),
        ], dim=3)
        # Zero out responses of LAFs that touch the image boundary
        good_mask = laf_is_inside_image(current_lafs, octave[:, 0])
        resp_flat_best = resp_flat_best * good_mask.to(dev, dtype)
        # Normalize LAFs; we don't need the number of scale levels, only the shape
        current_lafs = normalize_laf(current_lafs, octave[:, 0])
        all_responses.append(resp_flat_best)
        all_lafs.append(current_lafs)
    # Sort and keep the best num_feats responses
    responses: torch.Tensor = torch.cat(all_responses, dim=1)
    lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
    responses, idxs = torch.topk(responses, k=num_feats, dim=1)
    lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
    return responses, denormalize_laf(lafs, img)
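# Usage sketch, assuming kornia's ScaleSpaceDetector is the class hosting
# detect(); num_features and mr_size are real constructor arguments, but the
# mask-aware detect() signature may vary across kornia versions, so treat this
# as an illustration rather than a guaranteed API.
def _detector_usage_sketch():
    from kornia.feature import ScaleSpaceDetector  # assumed host class
    detector = ScaleSpaceDetector(num_features=512, mr_size=6.0)
    img = torch.rand(1, 1, 256, 256)  # grayscale batch, (B, CH, H, W)
    mask = torch.ones_like(img)       # optional: zeros suppress responses there
    responses, lafs = detector.detect(img, 512, mask)
    # responses: (1, 512); lafs: (1, 512, 2, 3) in denormalized (pixel) coords
    return responses, lafs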