def set_laf_orientation(LAF: torch.Tensor, angles_degrees: torch.Tensor) -> torch.Tensor:
    """Change the orientation of the LAFs.

    Args:
        LAF: tensor [BxNx2x3].
        angles_degrees: tensor [BxNx1], in degrees.

    Returns:
        tensor [BxNx2x3].

    Shape:
        - Input: :math:`(B, N, 2, 3)`, :math:`(B, N, 1)`
        - Output: :math:`(B, N, 2, 3)`
    """
    raise_error_if_laf_is_not_valid(LAF)
    B, N = LAF.shape[:2]
    rotmat: torch.Tensor = angle_to_rotation_matrix(angles_degrees).view(B * N, 2, 2)
    laf_out: torch.Tensor = torch.cat([
        torch.bmm(make_upright(LAF).view(B * N, 2, 3)[:, :2, :2], rotmat),
        LAF.view(B * N, 2, 3)[:, :2, 2:]
    ], dim=2).view(B, N, 2, 3)
    return laf_out
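# A minimal usage sketch for set_laf_orientation. `_demo_set_laf_orientation` is a
# hypothetical helper, not part of the library; it assumes the kornia-style
# utilities used above (angle_to_rotation_matrix, make_upright, etc.) are in scope.
def _demo_set_laf_orientation() -> None:
    # One LAF per image: 10px isotropic scale, centered at (32, 32)
    laf = torch.tensor([[[[10.0, 0.0, 32.0],
                          [0.0, 10.0, 32.0]]]])  # shape [1, 1, 2, 3]
    angles = torch.tensor([[[45.0]]])  # shape [1, 1, 1], degrees
    rotated = set_laf_orientation(laf, angles)
    # The center column is preserved; only the 2x2 affine part is rotated.
    assert rotated.shape == (1, 1, 2, 3)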
def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:  # type: ignore
    """
    Args:
        laf: (torch.Tensor), shape [BxNx2x3]
        img: (torch.Tensor), shape [Bx1xHxW]

    Returns:
        laf_out: (torch.Tensor), shape [BxNx2x3]
    """
    raise_error_if_laf_is_not_valid(laf)
    # Check the tensor type first: a non-tensor input has no `.shape` attribute,
    # so building the shape message before this check would raise the wrong error.
    if not torch.is_tensor(img):
        raise TypeError("img type is not a torch.Tensor. Got {}".format(type(img)))
    if len(img.shape) != 4:
        raise ValueError("Invalid img shape, we expect BxCxHxW. Got: {}".format(img.shape))
    if laf.size(0) != img.size(0):
        raise ValueError(
            "Batch size of laf and img should be the same. Got {}, {}".format(
                laf.size(0), img.size(0)))
    B, N = laf.shape[:2]
    patches: torch.Tensor = extract_patches_from_pyramid(
        img, laf, self.patch_size).view(-1, 1, self.patch_size, self.patch_size)
    angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
    rotmat: torch.Tensor = angle_to_rotation_matrix(
        rad2deg(angles_radians)).view(B * N, 2, 2)
    laf_out: torch.Tensor = torch.cat([
        torch.bmm(make_upright(laf).view(B * N, 2, 3)[:, :2, :2], rotmat),
        laf.view(B * N, 2, 3)[:, :2, 2:]
    ], dim=2).view(B, N, 2, 3)
    return laf_out
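# A hedged usage sketch for the orienter's forward pass. The module name
# LAFOrienter and its patch_size argument are assumptions based on the kornia
# API this code mirrors; `_demo_orient_lafs` is a hypothetical helper.
def _demo_orient_lafs() -> None:
    orienter = LAFOrienter(patch_size=32)  # assumed constructor signature
    img = torch.rand(1, 1, 64, 64)  # Bx1xHxW grayscale image
    laf = torch.tensor([[[[8.0, 0.0, 32.0],
                          [0.0, 8.0, 32.0]]]])  # BxNx2x3
    # The orienter estimates a dominant angle per patch and rotates each LAF.
    oriented = orienter(laf, img)
    assert oriented.shape == laf.shape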
def detect(self, img: torch.Tensor, num_feats: int) -> Tuple[torch.Tensor, torch.Tensor]:
    sp, sigmas, pix_dists = self.scale_pyr(img)
    all_responses = []
    all_lafs = []
    for oct_idx, octave in enumerate(sp):
        sigmas_oct = sigmas[oct_idx]
        pix_dists_oct = pix_dists[oct_idx]
        B, L, CH, H, W = octave.size()
        # Run response function
        oct_resp = self.resp(octave.view(B * L, CH, H, W),
                             sigmas_oct.view(-1)).view(B, L, CH, H, W)
        # We want nms for scale responses, so reorder to (B, CH, L, H, W)
        oct_resp = oct_resp.permute(0, 2, 1, 3, 4)
        # Differentiable nms
        coord_max, response_max = self.nms(oct_resp)
        # Now, let's crop out some small responses
        responses_flatten = response_max.view(response_max.size(0), -1)  # [B, N]
        max_coords_flatten = coord_max.view(response_max.size(0), 3,
                                            -1).permute(0, 2, 1)  # [B, N, 3]
        if responses_flatten.size(1) > num_feats:
            resp_flat_best, idxs = torch.topk(responses_flatten, k=num_feats, dim=1)
            max_coords_best = torch.gather(max_coords_flatten, 1,
                                           idxs.unsqueeze(-1).repeat(1, 1, 3))
        else:
            resp_flat_best = responses_flatten
            max_coords_best = max_coords_flatten
        B, N = resp_flat_best.size()
        # Convert the scale level index from ConvSoftArgmax3d into the actual
        # scale, using the sigmas of the current octave
        max_coords_best = _scale_index_to_scale(max_coords_best, sigmas_oct)
        # Create local affine frames (LAFs)
        rotmat = angle_to_rotation_matrix(
            torch.zeros(B, N).to(max_coords_best.device).to(max_coords_best.dtype))
        current_lafs = torch.cat([
            self.mr_size * max_coords_best[:, :, 0].view(B, N, 1, 1) * rotmat,
            max_coords_best[:, :, 1:3].view(B, N, 2, 1)
        ], dim=3)
        # Normalize LAFs; we don't need the number of scale levels, only the
        # spatial shape, so pass the first scale level of the octave
        current_lafs = normalize_laf(current_lafs, octave[:, 0])
        all_responses.append(resp_flat_best)
        all_lafs.append(current_lafs)
    # Sort the candidates from all octaves and keep the best ones. Cap k at the
    # number of available candidates so topk cannot fail on small images.
    responses: torch.Tensor = torch.cat(all_responses, dim=1)
    lafs: torch.Tensor = torch.cat(all_lafs, dim=1)
    responses, idxs = torch.topk(responses, k=min(num_feats, responses.size(1)), dim=1)
    lafs = torch.gather(lafs, 1, idxs.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 2, 3))
    return responses, denormalize_laf(lafs, img)
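# A hedged usage sketch for detect. ScaleSpaceDetector and its num_features
# argument follow the kornia API this method appears to come from, but both are
# assumptions here; `_demo_detect_features` is a hypothetical helper.
def _demo_detect_features() -> None:
    detector = ScaleSpaceDetector(num_features=500)  # assumed constructor
    img = torch.rand(1, 1, 256, 256)  # Bx1xHxW
    responses, lafs = detector.detect(img, num_feats=500)
    # responses: [B, num_feats] scores; lafs: [B, num_feats, 2, 3] LAFs in
    # image pixel coordinates (denormalized on return).
    assert lafs.shape[-2:] == (2, 3)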