def find_corr(self, xyz0, xyz1, F0, F1, subsample_size=-1):
    """Pair each row of ``xyz0`` with the ``xyz1`` row of its feature-space match.

    Matches are the indices returned by ``find_nn_gpu(F0, F1, ...)`` (the
    nearest-neighbour search for every ``F0`` row among the ``F1`` rows).

    Args:
        xyz0, xyz1: coordinates aligned row-for-row with ``F0`` / ``F1``.
        F0, F1: feature matrices to match.
        subsample_size: when positive and ``F0`` has more rows than this, both
            ``F0`` and ``F1`` are first reduced to at most ``subsample_size``
            randomly chosen rows (without replacement).

    Returns:
        ``(xyz0_sel, xyz1_matched)`` where ``xyz1_matched[i]`` is the
        coordinate of the match found for ``xyz0_sel[i]``.
    """
    nn_max_n = self.config.nn_max_n

    if not (subsample_size > 0 and len(F0) > subsample_size):
        # Full-resolution matching.
        return xyz0, xyz1[find_nn_gpu(F0, F1, nn_max_n=nn_max_n)]

    # Random subsets; the F0 draw happens before the F1 draw.
    count0, count1 = len(F0), len(F1)
    keep0 = np.random.choice(count0, min(count0, subsample_size), replace=False)
    keep1 = np.random.choice(count1, min(count1, subsample_size), replace=False)

    matched = find_nn_gpu(F0[keep0], F1[keep1], nn_max_n=nn_max_n)
    # ``matched`` indexes the F1 subset; map it back to original F1 rows.
    return xyz0[keep0], xyz1[keep1[matched]]
def _valid_epoch(self, data_loader_iter):
    """Run one validation pass and report dense-correspondence metrics.

    For every validation batch the model produces feature maps ``F0``/``F1``
    for ``img0``/``img1``. For each batch element, up to
    ``config.num_pos_per_batch`` ground-truth pairs are sampled. The image-0
    feature of each pair is matched against all image-1 features, and the
    match is then matched back against all image-0 features:

    * hit ratio: fraction of forward matches within the inlier radius of the
      ground-truth image-1 location;
    * reciprocity ratio: fraction whose back-match lands within the inlier
      radius of the starting image-0 location;
    * reciprocity-filtered hit ratio: hit ratio over the pairs that pass the
      reciprocity test (0.0 when none pass).

    The inlier radius is ``config.ucn_inlier_threshold_pixel /
    out_tensor_stride`` feature-map cells.

    Args:
        data_loader_iter: iterator consumed through ``self.get_data``.

    Returns:
        dict with keys ``'hit_ratio'``, ``'reciprocity_ratio'`` and
        ``'reciprocity_hit_ratio'`` holding the running averages.
    """
    # Change the network to evaluation mode
    self.model.eval()
    num_data = 0
    hit_ratio_meter, reciprocity_ratio_meter = AverageMeter(), AverageMeter()
    reciprocity_hit_ratio_meter = AverageMeter()
    data_timer, feat_timer = Timer(), Timer()
    # NOTE(review): this counts dataset *samples*, but each iteration pulls a
    # whole batch from the iterator; confirm the val batch size is 1.
    tot_num_data = len(self.val_data_loader.dataset)
    if self.val_max_iter > 0:
        tot_num_data = min(self.val_max_iter, tot_num_data)

    nn_max_n = self.config.nn_max_n
    # Squared inlier radius in feature-map cells (loop invariant).
    sq_thresh = (self.config.ucn_inlier_threshold_pixel / self.out_tensor_stride)**2

    for curr_iter in range(tot_num_data):
        data_timer.tic()
        input_dict = self.get_data(data_loader_iter)
        data_timer.toc()

        feat_timer.tic()
        with torch.no_grad():
            F0 = self.model(input_dict['img0'].to(self.device))
            F1 = self.model(input_dict['img1'].to(self.device))
        feat_timer.toc()

        # Test self.num_pos_per_batch * self.batch_size features only.
        _, C0, _, W0 = F0.shape
        _, C1, _, W1 = F1.shape
        # Each row of a pair array is unpacked as (w0, h0, w1, h1) and divided
        # by out_tensor_stride to obtain feature-map coordinates.
        for batch_idx, pair in enumerate(input_dict['pairs']):
            N = len(pair)
            if N == 0:
                # Nothing to score; the ratios below would divide by zero.
                continue
            sel = np.random.choice(N, min(N, self.config.num_pos_per_batch), replace=False)
            curr_pair = pair[sel]
            w0, h0, w1, h1 = torch.floor(curr_pair.t() / self.out_tensor_stride).long()

            # Forward match: image-0 features against every image-1 location.
            feats0 = F0[batch_idx, :, h0, w0]
            nn_inds1 = find_nn_gpu(
                feats0, F1[batch_idx, :].view(C1, -1), nn_max_n=nn_max_n, transposed=True)
            # Flat H1*W1 index -> (x, y) on the feature map.
            xs1 = nn_inds1 % W1
            ys1 = nn_inds1 // W1

            # Back match (reciprocity): matched image-1 features against image 0.
            nn_inds0 = find_nn_gpu(
                F1[batch_idx, :, ys1, xs1],
                F0[batch_idx, :].view(C0, -1),
                nn_max_n=nn_max_n,
                transposed=True)
            xs0 = nn_inds0 % W0
            ys0 = nn_inds0 // W0

            dist_sq = (w1 - xs1)**2 + (h1 - ys1)**2
            is_correct = dist_sq < sq_thresh
            hit_ratio_meter.update(is_correct.sum().item() / len(is_correct))

            # Reciprocity test result
            dist_sq_nn = (w0 - xs0)**2 + (h0 - ys0)**2
            mask = dist_sq_nn < sq_thresh
            num_reciprocal = mask.sum().item()
            reciprocity_ratio_meter.update(num_reciprocal / float(len(mask)))
            # Explicit zero guard instead of an epsilon in the denominator.
            reciprocity_hit_ratio_meter.update(
                is_correct[mask].sum().item() / num_reciprocal if num_reciprocal > 0 else 0.0)

            torch.cuda.empty_cache()
            # Debug hook (disabled). ``config`` must be passed by keyword;
            # positionally it would bind to ``mode``.
            # visualize_image_correspondence(input_dict['img0'][batch_idx, 0].numpy() + 0.5,
            #                                input_dict['img1'][batch_idx, 0].numpy() + 0.5,
            #                                F0[batch_idx], F1[batch_idx], curr_iter,
            #                                config=self.config)

        num_data += 1

        if num_data % 100 == 0:
            logging.info(', '.join([
                f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f}",
                f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
                f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
            ]))
            data_timer.reset()

    logging.info(', '.join([
        f"Validation : Data Loading Time: {data_timer.avg:.3f}",
        f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
        f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
    ]))

    return {
        'hit_ratio': hit_ratio_meter.avg,
        'reciprocity_ratio': reciprocity_ratio_meter.avg,
        'reciprocity_hit_ratio': reciprocity_hit_ratio_meter.avg,
    }
def contrastive_loss(self,
                     img0,
                     img1,
                     F0,
                     F1,
                     pairs,
                     num_pos=5192,
                     num_hn_samples=2048):
    """Hardest-negative contrastive loss on dense feature maps.

    F0: B x C x H0 x W0
    F1: B x C x H1 x W1
    pairs: per-batch-element pair arrays. Each row is unpacked as
        (w0, h0, w1, h1) and divided by ``self.out_tensor_stride`` to get
        feature-map coordinates.

    Positive term: ``relu(||f0 - f1||^2 - pos_thresh)``, averaged over *all*
    pairs of an element.

    Negative term: up to ``num_pos`` pairs are sampled. For each sampled
    image-0 feature, the closest of ``num_hn_samples`` random image-1
    locations is found, and vice versa. A candidate inside the inlier radius
    (``config.ucn_inlier_threshold_pixel / out_tensor_stride`` cells) of the
    true match is discarded. The remaining candidates contribute
    ``relu(neg_thresh - d^2)^2``.

    ``img0`` and ``img1`` are not used.

    Returns:
        ``(pos_loss, neg_loss)``, each summed over batch elements and
        divided by ``B``.
    """
    B, C, H0, W0 = F0.shape
    B1, C1, H1, W1 = F1.shape
    assert B == B1
    assert C == C1

    pos_loss_sum, neg_loss_sum = 0, 0
    sq_thresh = (self.config.ucn_inlier_threshold_pixel / self.out_tensor_stride)**2
    # Negatives are drawn without replacement, so cap the pool by the smaller map.
    num_hn_samples = min(num_hn_samples, min(H0, H1) * min(W0, W1))

    def _safe_mean(x):
        # mean() of an empty tensor is NaN; an empty sum() is 0 and keeps the graph.
        return x.mean() if x.numel() > 0 else x.sum()

    for curr_F0, curr_F1, curr_pairs in zip(F0, F1, pairs):
        flat_F0 = curr_F0.view(C, -1)
        flat_F1 = curr_F1.view(C, -1)

        # Sample positives, plus num_hn_samples candidates for hardest-negative mining.
        N = len(curr_pairs)
        # Per-element cap. Do not overwrite ``num_pos``, or a sparse element
        # would shrink the sample size of every element after it.
        curr_num_pos = min(num_pos, N)
        sel_pos = np.random.choice(N, curr_num_pos, replace=False)
        sel_pairs = curr_pairs[sel_pos]
        sel_neg0 = torch.from_numpy(
            np.random.choice(H0 * W0, num_hn_samples, replace=False))
        sel_neg1 = torch.from_numpy(
            np.random.choice(H1 * W1, num_hn_samples, replace=False))

        w0, h0, w1, h1 = torch.floor(sel_pairs.t() / self.out_tensor_stride).long()
        sel_pos0 = h0 * W0 + w0
        sel_pos1 = h1 * W1 + w1

        subF0, subF1 = flat_F0[:, sel_neg0], flat_F1[:, sel_neg1]
        posF0, posF1 = flat_F0[:, sel_pos0], flat_F1[:, sel_pos1]

        with torch.no_grad():
            # nn_inds1: hardest image-1 candidate for each positive image-0 feature.
            nn_inds1 = find_nn_gpu(posF0, subF1, nn_max_n=self.config.nn_max_n, transposed=True)
            # nn_inds0: hardest image-0 candidate for each positive image-1 feature.
            nn_inds0 = find_nn_gpu(posF1, subF0, nn_max_n=self.config.nn_max_n, transposed=True)

        D1ind = sel_neg1[nn_inds1]
        D0ind = sel_neg0[nn_inds0]
        neg_w1, neg_h1 = D1ind % W1, D1ind // W1
        neg_w0, neg_h0 = D0ind % W0, D0ind // W0

        # Keep only negatives that lie outside the inlier radius of the true match.
        # mask1 checks the image-1 candidates (nn_inds1) against (w1, h1);
        # mask0 checks the image-0 candidates (nn_inds0) against (w0, h0).
        mask1 = ((h1 - neg_h1)**2 + (w1 - neg_w1)**2) > sq_thresh
        mask0 = ((h0 - neg_h0)**2 + (w0 - neg_w0)**2) > sq_thresh

        # Each distance must be filtered by the mask of the same NN search.
        D01min = (posF0[:, mask1] - subF1[:, nn_inds1[mask1]]).pow(2).sum(0)
        D10min = (posF1[:, mask0] - subF0[:, nn_inds0[mask0]]).pow(2).sum(0)

        # Positive loss is computed on all pairs, not only the sampled ones.
        pw0, ph0, pw1, ph1 = torch.floor(curr_pairs.t() / self.out_tensor_stride).long()
        pos_loss = F.relu((curr_F0[:, ph0, pw0] - curr_F1[:, ph1, pw1]).pow(2).sum(0) -
                          self.pos_thresh)
        neg_loss0 = F.relu(self.neg_thresh - D01min).pow(2)
        neg_loss1 = F.relu(self.neg_thresh - D10min).pow(2)

        pos_loss_sum += pos_loss.mean()
        neg_loss_sum += (_safe_mean(neg_loss0) + _safe_mean(neg_loss1)) / 2

    return pos_loss_sum / B, neg_loss_sum / B
def visualize_image_correspondence(img0,
                                   img1,
                                   F0,
                                   F1,
                                   filename,
                                   mode='gpu-all',
                                   config=None,
                                   visualize=True):
    """Match SIFT keypoints between two images using dense features.

    SIFT keypoints are detected in both images. The features at those pixels
    (``F[:, y, x]``) are matched according to ``mode``:

    * ``'cpu-keypoints'``: ``util_2d.feature_match`` with a ratio test
      between the two keypoint sets, then a back-match of each matched
      image-1 keypoint against the image-0 keypoints (``find_nn_gpu``).
    * ``'gpu-keypoints'``: ``find_nn_gpu`` from image-0 keypoints to image-1
      keypoints, then back.
    * ``'gpu-all'``: ``find_nn_faiss`` from image-0 keypoints to every image-1
      location. The back-match only runs when ``use_cyclic_test`` is set;
      otherwise every match is kept.
    * ``'gpu-all-all'``: every image-0 location against every image-1
      location, then back.

    A match is kept when its back-match lands strictly within
    ``config.ucn_inlier_threshold_pixel`` of the starting point. When
    ``use_stability_test`` is set, the image-1 location is jittered by +-1
    pixel in x and y (clipped to the image) before the back-match.

    Args:
        img0, img1: 2-D images (``shape`` is unpacked as ``(H, W)``).
            Presumably uint8 grayscale, as OpenCV SIFT expects; TODO confirm.
        F0, F1: feature maps indexed as ``F[:, y, x]``. Flat feature indices
            are converted with the image width, so the maps are assumed to
            share the images' spatial size.
        filename: integer used to name the figure (``{filename:03d}.png``).
        mode: one of the four strategies above.
        config: must provide ``ucn_inlier_threshold_pixel``. The
            ``*-keypoints`` modes also read ``nn_max_n``.
        visualize: if True, save a side-by-side scatter plot under
            ``./ucn_outputs/`` and return None. Otherwise return
            ``(x0, y0, x1, y1)`` of the kept matches.

    Raises:
        ValueError: if ``mode`` is not a supported strategy.
    """
    use_stability_test = True
    use_cyclic_test = False
    keypoint = 'sift'

    valid_modes = ('cpu-keypoints', 'gpu-keypoints', 'gpu-all', 'gpu-all-all')
    if mode not in valid_modes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")

    def _jitter(xs, ys, width, height):
        """Shift every point by +-1 pixel in x and y, clipped to the image."""
        noisex = 2 * (np.random.rand(len(xs)) < 0.5) - 1
        noisey = 2 * (np.random.rand(len(ys)) < 0.5) - 1
        return np.clip(xs + noisex, 0, width - 1), np.clip(ys + noisey, 0, height - 1)

    if keypoint == 'sift':
        sift = cv2.xfeatures2d.SIFT_create(
            0,  # nfeatures: 0 keeps all
            9,  # nOctaveLayers
            0.01,  # contrastThreshold: smaller -> more keypoints, default 0.04
            100  # edgeThreshold: larger -> more keypoints, default 10
        )
        kp0 = sift.detect(img0, None)
        kp1 = sift.detect(img1, None)
        # Integer pixel coordinates: float arrays cannot index torch tensors.
        # reshape keeps a (2, 0) shape when no keypoint is found.
        xy_kp0 = np.floor(np.array([k.pt for k in kp0], dtype=float).reshape(-1, 2).T).astype(np.int64)
        xy_kp1 = np.floor(np.array([k.pt for k in kp1], dtype=float).reshape(-1, 2).T).astype(np.int64)
        x0, y0 = xy_kp0[0], xy_kp0[1]
        x1, y1 = xy_kp1[0], xy_kp1[1]
    elif keypoint == 'all':
        x0, y0 = None, None
        x1, y1 = None, None

    H0, W0 = img0.shape
    H1, W1 = img1.shape

    if mode == 'cpu-keypoints':
        matches1 = util_2d.feature_match(F0[:, y0, x0].t().cpu().numpy(),
                                         F1[:, y1, x1].t().cpu().numpy(),
                                         ratio_test=True,
                                         ratio=0.95)
        # Keep only the keypoints that survived the ratio test.
        x0 = x0[matches1[:, 0]]
        y0 = y0[matches1[:, 0]]
        xs1 = x1[matches1[:, 1]]
        ys1 = y1[matches1[:, 1]]

        # Test reciprocity
        nn_inds0 = find_nn_gpu(F1[:, ys1, xs1],
                               F0[:, y0, x0],
                               nn_max_n=config.nn_max_n,
                               transposed=True)
        xs0 = x0[nn_inds0.numpy()]
        ys0 = y0[nn_inds0.numpy()]

        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    elif mode == 'gpu-keypoints':
        nn_inds1 = find_nn_gpu(F0[:, y0, x0],
                               F1[:, y1, x1],
                               nn_max_n=config.nn_max_n,
                               transposed=True).numpy()
        xs1 = x1[nn_inds1]
        ys1 = y1[nn_inds1]
        xs1n, ys1n = _jitter(xs1, ys1, W1, H1) if use_stability_test else (xs1, ys1)

        # Test reciprocity
        nn_inds0 = find_nn_gpu(F1[:, ys1n, xs1n],
                               F0[:, y0, x0],
                               nn_max_n=config.nn_max_n,
                               transposed=True).numpy()
        xs0 = x0[nn_inds0]
        ys0 = y0[nn_inds0]

        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    elif mode == 'gpu-all':
        nn_inds1 = find_nn_faiss(
            F0[:, y0, x0],
            F1.view(F1.shape[0], -1),
        )
        # Flat index -> (x, y), assuming an H1 x W1 feature map.
        xs1 = nn_inds1 % W1
        ys1 = nn_inds1 // W1
        xs1n, ys1n = _jitter(xs1, ys1, W1, H1) if use_stability_test else (xs1, ys1)

        if use_cyclic_test:
            # Test reciprocity
            nn_inds0 = find_nn_faiss(
                F1[:, ys1n, xs1n],
                F0.view(F0.shape[0], -1),
            )
            xs0 = nn_inds0 % W0
            ys0 = nn_inds0 // W0
            dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
            mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)
        else:
            mask = np.ones(len(x0), dtype=bool)

    else:  # 'gpu-all-all'
        nn_inds1 = find_nn_faiss(
            F0.view(F0.shape[0], -1),
            F1.view(F1.shape[0], -1),
        )
        inds0 = np.arange(len(nn_inds1))
        x0 = inds0 % W0
        y0 = inds0 // W0
        xs1 = nn_inds1 % W1
        ys1 = nn_inds1 // W1
        xs1n, ys1n = _jitter(xs1, ys1, W1, H1) if use_stability_test else (xs1, ys1)

        # Test reciprocity
        nn_inds0 = find_nn_faiss(
            F1[:, ys1n, xs1n],
            F0.view(F0.shape[0], -1),
        )
        xs0 = nn_inds0 % W0
        ys0 = nn_inds0 // W0

        # Filter out the points that fail the cycle consistency
        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    if not visualize:
        return x0[mask], y0[mask], xs1[mask], ys1[mask]

    # Colour each match by its image-0 linear index so pairs share a colour.
    color = x0[mask] + y0[mask] * W0
    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2)
    try:
        fig.set_size_inches(9, 6)
        ax0.imshow(img0 * 0.5, vmin=0, vmax=255, cmap='gray')
        ax0.scatter(x=x0[mask], y=y0[mask], c=color, s=2, cmap="jet")
        ax0.axis('off')
        ax1.imshow(img1 * 0.5, vmin=0, vmax=255, cmap='gray')
        ax1.scatter(x=xs1[mask], y=ys1[mask], c=color, s=2, cmap="jet")
        ax1.axis('off')
        fig.tight_layout()
        ensure_dir('./ucn_outputs')
        fig.savefig(f"./ucn_outputs/{filename:03d}.png", dpi=300)
    finally:
        # Release the figure; pyplot otherwise keeps every one alive.
        plt.close(fig)