def forward(self, line_seg1, line_seg2, desc1, desc2):
    """ Find the best matches between two sets of line segments
        and their corresponding descriptors. """
    img_size1 = (desc1.shape[2] * self.grid_size,
                 desc1.shape[3] * self.grid_size)
    img_size2 = (desc2.shape[2] * self.grid_size,
                 desc2.shape[3] * self.grid_size)
    device = desc1.device

    # Default case when an image has no lines
    if len(line_seg1) == 0:
        return np.empty((0), dtype=int)
    if len(line_seg2) == 0:
        return -np.ones(len(line_seg1), dtype=int)

    # Sample points regularly along each line
    if self.sampling_mode == "regular":
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
    else:
        line_points1, valid_points1 = self.sample_salient_points(
            line_seg1, desc1, img_size1, self.sampling_mode)
        line_points2, valid_points2 = self.sample_salient_points(
            line_seg2, desc2, img_size2, self.sampling_mode)
    line_points1 = torch.tensor(line_points1.reshape(-1, 2),
                                dtype=torch.float, device=device)
    line_points2 = torch.tensor(line_points2.reshape(-1, 2),
                                dtype=torch.float, device=device)

    # Extract the descriptors for each point
    grid1 = keypoints_to_grid(line_points1, img_size1)
    grid2 = keypoints_to_grid(line_points2, img_size2)
    desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
    desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)

    # Precompute the distance between line points for every pair of lines
    # Assign a score of -1 for invalid points
    scores = (desc1.t() @ desc2).cpu().numpy()
    scores[~valid_points1.flatten()] = -1
    scores[:, ~valid_points2.flatten()] = -1
    scores = scores.reshape(len(line_seg1), self.num_samples,
                            len(line_seg2), self.num_samples)
    scores = scores.transpose(0, 2, 1, 3)
    # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

    # Pre-filter the line candidates and find the best match for each line
    matches = self.filter_and_match_lines(scores)

    # [Optionally] filter matches with mutual nearest neighbor filtering
    if self.cross_check:
        matches2 = self.filter_and_match_lines(
            scores.transpose(1, 0, 3, 2))
        mutual = matches2[matches] == np.arange(len(line_seg1))
        matches[~mutual] = -1

    return matches
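# --- Illustrative usage sketch (assumption-heavy, not part of the module) ---
# Shows how `forward` might be called to match two sets of line segments given
# dense descriptor maps. The class name `WunschLineMatcher` and its constructor
# arguments are assumptions for illustration; adapt them to the actual class
# wrapping this method.
def _example_forward_usage():
    import numpy as np
    import torch

    # Assumed constructor; the real class may expose different options
    matcher = WunschLineMatcher(cross_check=True, num_samples=5,
                                sampling_mode="regular")
    # 3 and 4 line segments, each given by its two endpoints
    line_seg1 = np.random.rand(3, 2, 2) * 255
    line_seg2 = np.random.rand(4, 2, 2) * 255
    # Dense descriptor maps of shape 1xDxHcxWc (values here are arbitrary)
    desc1 = torch.randn(1, 128, 32, 32)
    desc2 = torch.randn(1, 128, 32, 32)
    matches = matcher.forward(line_seg1, line_seg2, desc1, desc2)
    # matches[i] is the index of the matched line in image 2, or -1 when the
    # mutual nearest neighbor check rejects the match
    return matches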
def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2):
    """ Compute the OPPOSITE of the NW score for pairs of line segments
        and their corresponding descriptors. """
    num_lines = len(line_seg1)
    assert num_lines == len(line_seg2), (
        "The same number of lines is required in pairwise score.")
    img_size1 = (desc1.shape[2] * self.grid_size,
                 desc1.shape[3] * self.grid_size)
    img_size2 = (desc2.shape[2] * self.grid_size,
                 desc2.shape[3] * self.grid_size)
    device = desc1.device

    # Sample points regularly along each line
    line_points1, valid_points1 = self.sample_line_points(line_seg1)
    line_points2, valid_points2 = self.sample_line_points(line_seg2)
    line_points1 = torch.tensor(line_points1.reshape(-1, 2),
                                dtype=torch.float, device=device)
    line_points2 = torch.tensor(line_points2.reshape(-1, 2),
                                dtype=torch.float, device=device)

    # Extract the descriptors for each point
    grid1 = keypoints_to_grid(line_points1, img_size1)
    grid2 = keypoints_to_grid(line_points2, img_size2)
    desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
    desc1 = desc1.reshape(-1, num_lines, self.num_samples)
    desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)
    desc2 = desc2.reshape(-1, num_lines, self.num_samples)

    # Compute the distance between line points for every pair of lines
    # Assign a score of -1 for invalid points
    scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy()
    scores = scores.reshape(num_lines * self.num_samples, self.num_samples)
    scores[~valid_points1.flatten()] = -1
    scores = scores.reshape(num_lines, self.num_samples, self.num_samples)
    scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1)
    scores[:, ~valid_points2.flatten()] = -1
    scores = scores.reshape(self.num_samples, num_lines, self.num_samples)
    scores = scores.transpose(1, 0, 2)
    # scores.shape = (num_lines, num_samples, num_samples)

    # Compute the NW score for each pair of lines
    pairwise_scores = np.array([self.needleman_wunsch(s) for s in scores])
    return -pairwise_scores
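# --- Illustrative usage sketch (assumption-heavy, not part of the module) ---
# `get_pairwise_distance` scores the i-th line of the first set against the
# i-th line of the second set only (not all pairs), which is convenient when
# the correspondences are already known. `matcher` is assumed to be an
# instance of the enclosing matcher class.
def _example_pairwise_distance(matcher, line_seg1, line_seg2, desc1, desc2):
    # Both sets must contain the same number of lines
    assert len(line_seg1) == len(line_seg2)
    dists = matcher.get_pairwise_distance(line_seg1, line_seg2, desc1, desc2)
    # dists[i] is the negated Needleman-Wunsch score of pair i, so smaller
    # values mean better descriptor agreement along the two lines
    return dists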
def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices):
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * self.grid_size, Wc * self.grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        return torch.tensor(0., dtype=torch.float, device=device)

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Extract the descriptors
    desc1 = F.grid_sample(desc_pred1, grid1).permute(
        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
    desc1 = F.normalize(desc1, dim=1)
    desc2 = F.grid_sample(desc_pred2, grid2).permute(
        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Compute percentage of correct matches
    matches0 = torch.min(desc_dists, dim=1)[1]
    matches1 = torch.min(desc_dists, dim=0)[1]
    matching_score = (matches1[matches0]
                      == torch.arange(len(matches0)).to(device))
    matching_score = matching_score.float().mean()

    return matching_score
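# --- Illustrative usage sketch (assumption-heavy, not part of the module) ---
# The matching score above is a mutual nearest neighbor check between point
# descriptors sampled along ground-truth corresponding lines. `MatchingScore`
# is an assumed name for the class owning `__call__`, constructed here with an
# assumed descriptor grid size of 8; all tensor shapes are illustrative.
def _example_matching_score():
    import torch

    metric = MatchingScore(grid_size=8)  # assumed constructor
    b_size, n_points = 2, 50
    # Corresponding points in both images, shape (b_size, n_points, 2)
    points1 = torch.rand(b_size, n_points, 2) * 256
    points2 = torch.rand(b_size, n_points, 2) * 256
    # Dense descriptor predictions of shape (b_size, D, Hc, Wc)
    desc_pred1 = torch.randn(b_size, 128, 32, 32)
    desc_pred2 = torch.randn(b_size, 128, 32, 32)
    # Non-zero entries of line_indices mark valid sampled points
    line_indices = torch.ones(b_size, n_points)
    score = metric(points1, points2, desc_pred1, desc_pred2, line_indices)
    return score  # scalar tensor in [0, 1]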
def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices,
                 epoch, grid_size=8, dist_threshold=8,
                 init_dist_threshold=64, margin=1):
    """ Regular triplet loss for descriptor learning. """
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        return torch.tensor(0., dtype=torch.float, device=device)

    # Check which keypoints are too close to be matched
    # dist_threshold is decreased at each epoch for easier training
    dist_threshold = max(dist_threshold,
                         2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)

    # Additionally ban negative mining along the same line
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Extract the descriptors
    desc1 = F.grid_sample(desc_pred1, grid1).permute(
        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
    desc1 = F.normalize(desc1, dim=1)
    desc2 = F.grid_sample(desc_pred2, grid2).permute(
        0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points]
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distance loss
    pos_dist = torch.diag(desc_dists)

    # Negative distance loss
    max_dist = torch.tensor(4., dtype=torch.float, device=device)
    desc_dists[torch.arange(n_correct_points, dtype=torch.long),
               torch.arange(n_correct_points, dtype=torch.long)] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(torch.min(desc_dists, dim=1)[0],
                         torch.min(desc_dists, dim=0)[0])

    triplet_loss = F.relu(margin + pos_dist - neg_dist)
    return triplet_loss, grid1, grid2, valid_points
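# --- Illustrative usage sketch (assumption-heavy, not part of the module) ---
# How the per-point triplet losses might be reduced to a scalar inside a
# training step. The helper name `compute_descriptor_loss` is made up for this
# example. Note that `triplet_loss` returns a bare 0-dim tensor when the batch
# contains no valid points, and a 4-tuple otherwise.
def compute_descriptor_loss(desc_pred1, desc_pred2, points1, points2,
                            line_indices, epoch):
    out = triplet_loss(desc_pred1, desc_pred2, points1, points2,
                       line_indices, epoch)
    if torch.is_tensor(out):
        # Degenerate batch without any valid point: the zero loss is returned
        return out
    loss, _grid1, _grid2, _valid_points = out
    # Average the per-point triplet terms into a single training loss
    return loss.mean()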
def sample_salient_points(self, line_seg, desc, img_size,
                          saliency_type='d2_net'):
    """ Sample the most salient points along each line segment,
        with a minimal distance between each point. Pad the remaining points.
    Inputs:
        line_seg: an Nx2x2 torch.Tensor.
        desc: a NxDxHxW torch.Tensor.
        img_size: the original image size.
        saliency_type: 'd2_net' or 'asl_feat'.
    Outputs:
        line_points: an Nxnum_samplesx2 np.array.
        valid_points: a boolean Nxnum_samples np.array.
    """
    device = desc.device
    if not self.line_score:
        # Compute the score map
        if saliency_type == "d2_net":
            score = self.d2_net_saliency_score(desc)
        else:
            score = self.asl_feat_saliency_score(desc)

    num_lines = len(line_seg)
    line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)

    # The number of samples depends on the length of the line
    num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
                              2, self.num_samples)
    line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
    valid_points = np.empty((num_lines, self.num_samples), dtype=bool)

    # Sample the score on a fixed number of points of each line
    n_samples_per_region = 4
    for n in np.arange(2, self.num_samples + 1):
        sample_rate = n * n_samples_per_region
        # Consider all lines where we can fit up to n points
        cur_mask = num_samples_lst == n
        cur_line_seg = line_seg[cur_mask]
        cur_num_lines = len(cur_line_seg)
        if cur_num_lines == 0:
            continue
        line_points_x = np.linspace(cur_line_seg[:, 0, 0],
                                    cur_line_seg[:, 1, 0],
                                    sample_rate, axis=-1)
        line_points_y = np.linspace(cur_line_seg[:, 0, 1],
                                    cur_line_seg[:, 1, 1],
                                    sample_rate, axis=-1)
        cur_line_points = np.stack([line_points_x, line_points_y],
                                   axis=-1).reshape(-1, 2)
        # cur_line_points is of shape (n_cur_lines * sample_rate, 2)
        cur_line_points = torch.tensor(cur_line_points, dtype=torch.float,
                                       device=device)
        grid_points = keypoints_to_grid(cur_line_points, img_size)

        if self.line_score:
            # The saliency score is high when the activations are locally
            # maximal along the line (and not in a square neighborhood)
            line_desc = F.grid_sample(desc, grid_points).squeeze()
            line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate)
            line_desc = line_desc.permute(1, 0, 2)
            if saliency_type == "d2_net":
                scores = self.d2_net_saliency_score(line_desc)
            else:
                scores = self.asl_feat_saliency_score(line_desc)
        else:
            scores = F.grid_sample(score.unsqueeze(1),
                                   grid_points).squeeze()

        # Take the most salient point in n distinct regions
        scores = scores.reshape(-1, n, n_samples_per_region)
        best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy()
        cur_line_points = cur_line_points.reshape(-1, n,
                                                  n_samples_per_region, 2)
        cur_line_points = np.take_along_axis(
            cur_line_points, best[..., None], axis=2)[:, :, 0]

        # Pad
        cur_valid_points = np.ones((cur_num_lines, self.num_samples),
                                   dtype=bool)
        cur_valid_points[:, n:] = False
        cur_line_points = np.concatenate([
            cur_line_points,
            np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
            axis=1)

        line_points[cur_mask] = cur_line_points
        valid_points[cur_mask] = cur_valid_points

    return line_points, valid_points
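# --- Illustrative usage sketch (assumption-heavy, not part of the module) ---
# What the salient sampling produces. `matcher` is an assumed instance of the
# enclosing class (with attributes such as grid_size, num_samples and
# min_dist_pts already configured); the line endpoints below are arbitrary
# image coordinates chosen for illustration.
def _example_salient_sampling(matcher, desc):
    import numpy as np

    img_size = (desc.shape[2] * matcher.grid_size,
                desc.shape[3] * matcher.grid_size)
    # One long and one short line segment, each given by its two endpoints
    line_seg = np.array([[[10., 10.], [10., 200.]],
                         [[30., 40.], [35., 45.]]])
    line_points, valid_points = matcher.sample_salient_points(
        line_seg, desc, img_size, saliency_type='d2_net')
    # line_points: (2, num_samples, 2) sampled coordinates, zero-padded
    # valid_points: (2, num_samples) mask; the short line keeps fewer samples
    return line_points, valid_points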