def batch_extract_features(self, desc, heatmap_nms_batch, residual):
    """Extract points, per-point residuals, and descriptors.

    return: -- type: tensorFloat
        pts: tensor [batch, N, 2] (no grad) (x, y)
        pts_offset: tensor [batch, N, 2] (grad) (x, y)
        pts_desc: tensor [batch, N, 256] (grad)
    """
    from utils.utils import crop_or_pad_choice

    batch_size = heatmap_nms_batch.shape[0]
    pts_int, pts_offset, pts_desc = [], [], []
    pts_idx = heatmap_nms_batch[...].nonzero()  # [N, 4] = (batch, 0, y, x)
    for i in range(batch_size):
        mask_b = (pts_idx[:, 0] == i)  # select points of the i-th sample
        pts_int_b = pts_idx[mask_b][:, 2:].float()  # [N, 2] = (y, x)
        pts_int_b = pts_int_b[:, [1, 0]]  # swap to (x, y)
        res_b = residual[mask_b]
        pts_b = pts_int_b + res_b  # sub-pixel point positions
        # sample descriptors at the (sub-pixel) point locations
        pts_desc_b = self.sample_desc_from_points(desc[i].unsqueeze(0),
                                                  pts_b).squeeze(0)
        # randomly crop or pad to a fixed number of points per sample
        choice = crop_or_pad_choice(pts_int_b.shape[0],
                                    out_num_points=self.out_num_points,
                                    shuffle=True)
        choice = torch.tensor(choice)
        pts_int.append(pts_int_b[choice])
        pts_offset.append(res_b[choice])
        pts_desc.append(pts_desc_b[choice])
    pts_int = torch.stack(pts_int, dim=0)
    pts_offset = torch.stack(pts_offset, dim=0)
    pts_desc = torch.stack(pts_desc, dim=0)
    return {'pts_int': pts_int, 'pts_offset': pts_offset, 'pts_desc': pts_desc}
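# --- Hedged sketch (plain torch, illustrative values only): shows the
# nonzero-based extraction used in batch_extract_features above. nonzero()
# on a [batch, 1, H, W] heatmap yields [N, 4] indices (batch, 0, y, x),
# which are then masked per sample and swapped to (x, y) order.
def _demo_nonzero_point_extraction():
    import torch
    heatmap = torch.zeros(2, 1, 4, 5)     # [batch, 1, H, W]
    heatmap[0, 0, 1, 3] = 1.              # detection at (y=1, x=3) in image 0
    heatmap[1, 0, 2, 0] = 1.              # detection at (y=2, x=0) in image 1
    pts_idx = heatmap.nonzero()           # [N, 4] = (batch, 0, y, x)
    mask_b = (pts_idx[:, 0] == 0)         # keep points of image 0
    pts = pts_idx[mask_b][:, 2:].float()  # (y, x)
    pts = pts[:, [1, 0]]                  # swap to (x, y)
    print(pts)                            # tensor([[3., 1.]])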
def descriptor_loss_sparse(descriptors, descriptors_warped, homographies,
                           mask_valid=None, cell_size=8, device='cpu',
                           descriptor_dist=4, lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           dist='cos', method='1d', **config):
    """Sparse descriptor loss for a pair of images related by homographies.

    :param descriptors: output of the descriptor head, tensor [D, Hc, Wc]
    :param descriptors_warped: descriptor head output of the warped image,
        tensor [D, Hc, Wc]
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b,
                       dist='cos', method='1d'):
        # loss over the ground-truth correspondences
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                                                matches_a, matches_b,
                                                dist=dist, method=method)
        return match_loss

    def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        # sample non-matches around each ground-truth match
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple, img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple,
                               num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b, dist='cos'):
        # hinge loss over sampled non-matches, normalized by the number of
        # hard negatives
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(
                image_a_pred, image_b_pred,
                non_matches_a.long().squeeze(),
                non_matches_b.long().squeeze(),
                M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(0, 1)  # [D, Hc, Wc] --> [Hc*Wc, D]
        descriptors = descriptors.unsqueeze(0)  # [1, Hc*Wc, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped)  # [1, Hc*Wc, D]

    # matches: warp the grid of cell centers with the homography
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # the homography is defined at image scale; rescale it to the
    # descriptor-map scale
    homographies_H = scale_homography_torch(homographies, img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points
    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop or pad to a fixed number of matching attempts
    choice = crop_or_pad_choice(uv_b_matches.shape[0], num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # (u, v)
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    if method == '2d':
        match_loss = get_match_loss(descriptors, descriptors_warped,
                                    matches_a.to(device),
                                    matches_b.to(device),
                                    dist=dist, method='2d')
    else:
        match_loss = get_match_loss(image_a_pred, image_b_pred,
                                    matches_a.long().to(device),
                                    matches_b.long().to(device), dist=dist)

    # non-matches
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape, uv_a, uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device), dist=dist)

    loss = lamda_d * match_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss
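# --- Hedged sketch (plain torch): checks that the uv=True convention in
# uv_to_1d above (index = x + y * W) agrees with the row-major flattening
# that descriptor_reshape applies to a [D, Hc, Wc] descriptor map, so the
# 1d indices address the right rows of image_a_pred / image_b_pred.
def _demo_uv_to_1d_convention():
    import torch
    D, Hc, Wc = 3, 4, 5
    desc = torch.randn(D, Hc, Wc)
    flat = desc.view(D, Hc * Wc).transpose(0, 1)  # [Hc*Wc, D], as in descriptor_reshape
    x, y = 2, 3                                   # a cell at (u, v) = (x, y)
    idx = x + y * Wc                              # uv_to_1d with uv=True
    assert torch.allclose(flat[idx], desc[:, y, x])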
def descriptor_loss_sparse_reliability(descriptors, descriptors_warped,
                                       reliability, reliability_warp,
                                       homographies, aplosser,
                                       mask_valid=None, cell_size=8,
                                       device='cpu', descriptor_dist=4,
                                       lamda_d=250, lamda_r=1,
                                       num_matching_attempts=1000,
                                       num_masked_non_matches_per_match=10,
                                       dist='cos', method='1d', sos=True,
                                       reli_base=0.5, **config):
    """Sparse descriptor loss with an additional reliability (AP) term.

    :param descriptors: output of the descriptor head, tensor [D, Hc, Wc]
    :param descriptors_warped: descriptor head output of the warped image,
        tensor [D, Hc, Wc]
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b,
                       dist='cos', method='1d', sos=True):
        # sos enables the second-order similarity (SOS) regularization inside
        # match_loss; the sampled descriptors are returned so the reliability
        # term can reuse them
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                                                matches_a, matches_b,
                                                dist=dist, method=method,
                                                sos=sos)
        return match_loss, matches_a_descriptors, matches_b_descriptors

    def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        # sample non-matches around each ground-truth match
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple, img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple,
                               num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b, dist='cos'):
        # hinge loss over sampled non-matches, normalized by the number of
        # hard negatives
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(
                image_a_pred, image_b_pred,
                non_matches_a.long().squeeze(),
                non_matches_b.long().squeeze(),
                M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss

    def reliability_loss(descriptors_a, descriptors_b, image_b_pred,
                         reliability_a, reliability_b, uv_a, uv_b,
                         img_b_shape, aplosser, method='1d', device='cpu',
                         reli_base=0.5):
        # sample one negative per match
        uv_b = uv_b.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple, img_b_shape, num_non_matches_per_match=1,
            img_b_mask=None)
        _, uv_b_non_matches_tuple = create_non_matches(
            uv_to_tuple(uv_a), uv_b_non_matches_tuple, 1)
        uv_b_non = tuple_to_uv(uv_b_non_matches_tuple).squeeze().transpose(
            1, 0)  # negative sample coordinates

        # sample negative descriptors
        if method == '2d':
            def sampleDescriptors(image_a_pred, matches_a, mode, norm=False):
                image_a_pred = image_a_pred.unsqueeze(0)  # [1, D, H, W]
                matches_a.unsqueeze_(0).unsqueeze_(2)
                matches_a_descriptors = torch.nn.functional.grid_sample(
                    image_a_pred, matches_a, mode=mode, align_corners=True)
                matches_a_descriptors = matches_a_descriptors.squeeze(
                ).transpose(0, 1)
                if norm:
                    dn = torch.norm(matches_a_descriptors, p=2,
                                    dim=1)  # descriptor norms
                    matches_a_descriptors = matches_a_descriptors.div(
                        torch.unsqueeze(dn, 1))  # L2-normalize
                return matches_a_descriptors

            matches_b_non = normPts(
                uv_b_non,
                torch.tensor([img_b_shape[1], img_b_shape[0]]).float())
            descriptors_b_non = sampleDescriptors(image_b_pred,
                                                  matches_b_non.to(device),
                                                  mode='bilinear', norm=False)
        else:
            # index the sampled negatives (uv_b_non), not the matches
            matches_b_non = uv_to_1d(uv_b_non, img_b_shape[1])
            descriptors_b_non = torch.index_select(
                image_b_pred, 1, matches_b_non.long().to(device))

        # per-point reliability: average of both maps at the matched locations
        qconf = (reliability_a[:, uv_a[:, 1].long(), uv_a[:, 0].long()] +
                 reliability_b[:, uv_b[:, 1].long(), uv_b[:, 0].long()]) / 2

        # one positive and one negative score per point
        pscores = (descriptors_a * descriptors_b).sum(-1)[:, None]
        nscores = (descriptors_a * descriptors_b_non).sum(-1)[:, None]
        scores = torch.cat((pscores, nscores), dim=1)
        gt = torch.zeros_like(scores, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        ap = aplosser(scores, gt)
        # reliable points are graded by AP; unreliable points fall back to
        # the constant reli_base
        ap_loss = 1 - ap * qconf - (1 - qconf) * reli_base
        return ap_loss.mean()

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(0, 1)  # [D, Hc, Wc] --> [Hc*Wc, D]
        descriptors = descriptors.unsqueeze(0)  # [1, Hc*Wc, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped)  # [1, Hc*Wc, D]

    # matches: warp the grid of cell centers with the homography
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # the homography is defined at image scale; rescale it to the
    # descriptor-map scale
    homographies_H = scale_homography_torch(homographies, img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points
    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    uv_a = uv_a[mask]  # uv_a, uv_b_matches are the ground-truth correspondences

    # crop or pad to a fixed number of matching attempts
    choice = crop_or_pad_choice(uv_b_matches.shape[0], num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # (u, v)
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    if method == '2d':
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            descriptors, descriptors_warped, matches_a.to(device),
            matches_b.to(device), dist=dist, method='2d', sos=sos)
    else:
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            image_a_pred, image_b_pred, matches_a.long().to(device),
            matches_b.long().to(device), dist=dist, sos=sos)

    # reliability loss
    matches_a_descriptors = matches_a_descriptors.squeeze()
    matches_b_descriptors = matches_b_descriptors.squeeze()
    if method == '2d':
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors, descriptors_warped,
                                   reliability, reliability_warp, uv_a,
                                   uv_b_matches, img_shape, aplosser,
                                   method=method, device=device,
                                   reli_base=reli_base)
    else:
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors, image_b_pred,
                                   reliability, reliability_warp, uv_a,
                                   uv_b_matches, img_shape, aplosser,
                                   method=method, device=device,
                                   reli_base=reli_base)

    # non-matches
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape, uv_a, uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device), dist=dist)

    loss = lamda_d * match_loss + lamda_r * ap_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss, lamda_r * ap_loss
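# --- Hedged sketch of the AP/reliability blend used in reliability_loss
# above. The hard (non-differentiable) AP stand-in below is hypothetical:
# with one positive and one negative score per point it is 1 when the
# positive outscores the negative, else 0. A real `aplosser` (e.g. a
# differentiable AP loss module) would replace it.
def _demo_reliability_blend(reli_base=0.5):
    import torch
    pscores = torch.tensor([0.9, 0.2])[:, None]  # positive-pair similarities
    nscores = torch.tensor([0.1, 0.8])[:, None]  # negative-pair similarities
    ap = (pscores > nscores).float().squeeze(1)  # stand-in AP per point: 1 or 0
    qconf = torch.tensor([0.9, 0.1])             # predicted reliability per point
    # reliable points are graded by AP; unreliable points fall back to reli_base
    ap_loss = 1 - ap * qconf - (1 - qconf) * reli_base
    print(ap_loss)  # tensor([0.0500, 0.5500])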
# NOTE: earlier variant of descriptor_loss_sparse (no dist/method options);
# if both definitions are kept in one module, this one shadows the one above.
def descriptor_loss_sparse(descriptors, descriptors_warped, homographies,
                           mask_valid=None, cell_size=8, device='cpu',
                           descriptor_dist=4, lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10, **config):
    """Sparse descriptor loss for a pair of images related by homographies.

    :param descriptors: output of the descriptor head, tensor [D, Hc, Wc]
    :param descriptors_warped: descriptor head output of the warped image,
        tensor [D, Hc, Wc]
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice

    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W):
        # row-major flattening: index = x + y * W
        return uv_tuple[0] + uv_tuple[1] * W

    def uv_to_1d(points, W):
        return points[..., 0] + points[..., 1] * W

    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                                                matches_a.long(),
                                                matches_b.long())
        return match_loss

    def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches,
                             num_masked_non_matches_per_match=10):
        # sample non-matches around each ground-truth match
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple, img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple,
                               num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b):
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(
                image_a_pred, image_b_pred,
                non_matches_a.long().squeeze(),
                non_matches_b.long().squeeze())
        return non_match_loss

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)
    image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # [1, Hc*Wc, D]
    image_b_pred = descriptors_warped.view(1, -1, Hc * Wc).transpose(1, 2)  # [1, Hc*Wc, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True)
    homographies_H = scale_homography_torch(homographies, img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(uv_a, homographies_H,
                                                     uv=True)
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points
    uv_b_matches, mask = filter_points(uv_b_matches, torch.tensor([Wc, Hc]),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop or pad to a fixed number of matching attempts
    choice = crop_or_pad_choice(uv_b_matches.shape[0], num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = uv_to_1d(uv_a, Wc)
    matches_b = uv_to_1d(uv_b_matches, Wc)

    match_loss = get_match_loss(image_a_pred, image_b_pred, matches_a,
                                matches_b)

    # non-matches
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape, uv_a, uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a, non_matches_b)
    non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + non_match_loss
    return loss
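# --- Hedged sketch of the margin hinge behind non-match terms of this kind.
# This is a generic illustration, not the exact PixelwiseContrastiveLoss
# formulation: non-matching descriptor pairs are penalized only while their
# distance is still below a margin M, and the sum is normalized by the
# hard-negative count, mirroring sum() / (num_hard_negatives + 1) above.
def _demo_non_match_hinge(M=0.2):
    import torch
    d_a = torch.nn.functional.normalize(torch.randn(10, 256), dim=1)
    d_b = torch.nn.functional.normalize(torch.randn(10, 256), dim=1)
    dists = (d_a - d_b).norm(dim=1)             # L2 distance per non-match pair
    hinge = torch.clamp(M - dists, min=0)       # zero once a pair is M apart
    num_hard = (hinge > 0).sum()                # pairs still inside the margin
    loss = hinge.pow(2).sum() / (num_hard + 1)  # normalized hinge loss
    print(loss)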
def quadruplet_descriptor_loss_sparse(descriptors, descriptors_warped,
                                      homographies, mask_valid=None,
                                      cell_size=8, device='cpu',
                                      descriptor_dist=4, lamda_d=250,
                                      num_matching_attempts=1000,
                                      num_masked_non_matches_per_match=10,
                                      dist='cos', method='1d', **config):
    """Sparse quadruplet descriptor loss for a pair of images related by
    homographies.

    :param descriptors: output of the descriptor head, tensor [D, Hc, Wc]
    :param descriptors_warped: descriptor head output of the warped image,
        tensor [D, Hc, Wc]
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    Hc, Wc, Dim = (descriptors.shape[1], descriptors.shape[2],
                   descriptors.shape[0])
    img_shape = (Hc, Wc)

    # descriptor_reshape, tuple_to_1d, and get_first_term_corr are assumed
    # module-level helpers here; descriptor_reshape maps [D, Hc, Wc] to
    # [1, Hc*Wc, D], and tuple_to_1d flattens (u, v) to row-major indices.
    image_a_pred = descriptor_reshape(descriptors, Hc, Wc)  # [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped, Hc, Wc)  # [1, Hc*Wc, D]

    # matches: warp the grid of cell centers with the homography
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # rescale the homography from image scale to descriptor-map scale
    homographies_H = scale_homography_torch(homographies, img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points
    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop or pad to a fixed number of matching attempts
    choice = crop_or_pad_choice(uv_b_matches.shape[0], num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # (u, v)
    matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())

    # non-matches: get_first_term_corr is analogous to get_non_matches_corr
    # but also returns the tiled match tuple alongside the non-match tuple
    uv_a_tuple, uv_b_matches_tuple, uv_b_non_matches_tuple = get_first_term_corr(
        img_shape, uv_a, uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    long_matches_b = tuple_to_1d(uv_b_matches_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_matches_a_descriptors = torch.index_select(
        image_a_pred, 1, non_matches_a.to(device).long().squeeze()).squeeze()
    non_matches_b_descriptors = torch.index_select(
        image_b_pred, 1, non_matches_b.to(device).long().squeeze()).squeeze()
    long_matches_b_descriptors = torch.index_select(
        image_b_pred, 1, long_matches_b.to(device).long().squeeze()).squeeze()

    # first term: anchor-negative distances should exceed anchor-positive
    # distances by the adaptive margin alpha
    all_neg_pairs = (non_matches_a_descriptors -
                     non_matches_b_descriptors).pow(2).sum(dim=-1)
    all_pos_pairs = (non_matches_a_descriptors -
                     long_matches_b_descriptors).pow(2).sum(dim=-1)
    alpha = (all_neg_pairs.sum() - all_pos_pairs.sum()) / (
        num_matching_attempts * num_masked_non_matches_per_match)
    first_term = torch.clamp(all_pos_pairs - all_neg_pairs + alpha,
                             min=0).sum() / (num_matching_attempts *
                                             num_masked_non_matches_per_match)

    # second term: compare each positive distance against the distances
    # between ordered pairs of negatives belonging to the same anchor
    pos_matches_a = non_matches_a_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match,
                      1).reshape((-1, Dim))
    pos_matches_b = long_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match,
                      1).reshape((-1, Dim))
    neg_matches_b = non_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match, 1,
         Dim)).repeat(1, 1, num_masked_non_matches_per_match,
                      1).reshape((-1, Dim))
    neg_matches_b_permute = non_matches_b_descriptors.reshape(
        (num_matching_attempts, 1, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1,
                      1).reshape((-1, Dim))

    pos_matches = (pos_matches_a - pos_matches_b).pow(2).sum(dim=-1)
    neg_matches = (neg_matches_b - neg_matches_b_permute).pow(2).sum(dim=-1)
    # ones-minus-eye mask: exclude i == j pairs of negatives
    match_mask = (torch.ones(num_masked_non_matches_per_match,
                             num_masked_non_matches_per_match) -
                  torch.eye(num_masked_non_matches_per_match)).repeat(
                      num_matching_attempts, 1, 1).flatten().to(device)
    second_term = (torch.clamp(
        (pos_matches - neg_matches) * match_mask + 0.5 * alpha, min=0) /
                   (num_masked_non_matches_per_match *
                    (num_masked_non_matches_per_match - 1) *
                    num_matching_attempts)).sum()

    loss = lamda_d * first_term + second_term
    return loss, lamda_d * first_term, second_term
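# --- Hedged sketch (plain torch, toy sizes): the pair mask used in the
# quadruplet second term above. For each anchor, all ordered (i, j) pairs of
# its k negatives are compared, and the ones-minus-eye mask zeroes the
# i == j diagonal so a negative is never contrasted with itself; this is why
# the term is normalized by k * (k - 1) * num_matching_attempts.
def _demo_quadruplet_pair_mask(num_attempts=2, k=3):
    import torch
    match_mask = (torch.ones(k, k) -
                  torch.eye(k)).repeat(num_attempts, 1, 1).flatten()
    # k * (k - 1) valid ordered pairs per anchor, matching the normalizer
    assert match_mask.sum() == num_attempts * k * (k - 1)
    print(match_mask.view(num_attempts, k, k)[0])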