Example #1
    def forward(self, line_seg1, line_seg2, desc1, desc2):
        """
            Find the best matches between two sets of line segments
            and their corresponding descriptors.
        """
        img_size1 = (desc1.shape[2] * self.grid_size,
                     desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size,
                     desc2.shape[3] * self.grid_size)
        device = desc1.device
        
        # Default case when an image has no lines
        if len(line_seg1) == 0:
            return np.empty((0), dtype=int)
        if len(line_seg2) == 0:
            return -np.ones(len(line_seg1), dtype=int)

        # Sample points regularly along each line
        if self.sampling_mode == "regular":
            line_points1, valid_points1 = self.sample_line_points(line_seg1)
            line_points2, valid_points2 = self.sample_line_points(line_seg2)
        else:
            line_points1, valid_points1 = self.sample_salient_points(
                line_seg1, desc1, img_size1, self.sampling_mode)
            line_points2, valid_points2 = self.sample_salient_points(
                line_seg2, desc2, img_size2, self.sampling_mode)
        line_points1 = torch.tensor(line_points1.reshape(-1, 2),
                                    dtype=torch.float, device=device)
        line_points2 = torch.tensor(line_points2.reshape(-1, 2),
                                    dtype=torch.float, device=device)

        # Extract the descriptors for each point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
        desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)

        # Precompute the descriptor similarity between line points
        # for every pair of lines. Assign a score of -1 to invalid points.
        scores = (desc1.t() @ desc2).cpu().numpy()
        scores[~valid_points1.flatten()] = -1
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(len(line_seg1), self.num_samples,
                                len(line_seg2), self.num_samples)
        scores = scores.transpose(0, 2, 1, 3)
        # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

        # Pre-filter the line candidates and find the best match for each line
        matches = self.filter_and_match_lines(scores)

        # [Optionally] filter the matches with a mutual nearest neighbor check
        if self.cross_check:
            matches2 = self.filter_and_match_lines(
                scores.transpose(1, 0, 3, 2))
            mutual = matches2[matches] == np.arange(len(line_seg1))
            matches[~mutual] = -1

        return matches
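The helper keypoints_to_grid is called in every example on this page but is not shown. Below is a minimal standalone sketch of what it plausibly does, inferred from how it is used here: pixel keypoints (assumed to be in (y, x) order) are normalized to the [-1, 1] range and reordered to the (x, y) convention expected by F.grid_sample. The actual implementation may differ in details such as align_corners handling.

import torch

def keypoints_to_grid(keypoints, img_size):
    """Hypothetical sketch: map (y, x) pixel keypoints to a grid_sample grid.

    keypoints: (N, 2) or (B, N, 2) tensor of (y, x) pixel coordinates.
    img_size: (H, W) of the original image.
    Returns a (B, N, 1, 2) grid of (x, y) coordinates normalized to [-1, 1].
    """
    n_points = keypoints.size(-2)
    size = torch.tensor(img_size, dtype=torch.float, device=keypoints.device)
    grid = keypoints.float() * 2. / size - 1.   # normalize to [-1, 1]
    grid = grid[..., [1, 0]]                    # (y, x) -> (x, y) for grid_sample
    return grid.view(-1, n_points, 1, 2)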
Example #2
    def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2):
        """
            Compute the OPPOSITE of the NW score for pairs of line segments
            and their corresponding descriptors.
        """
        num_lines = len(line_seg1)
        assert num_lines == len(line_seg2), "Pairwise scoring requires the same number of lines in both sets."
        img_size1 = (desc1.shape[2] * self.grid_size,
                     desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size,
                     desc2.shape[3] * self.grid_size)
        device = desc1.device

        # Sample points regularly along each line
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
        line_points1 = torch.tensor(line_points1.reshape(-1, 2),
                                    dtype=torch.float, device=device)
        line_points2 = torch.tensor(line_points2.reshape(-1, 2),
                                    dtype=torch.float, device=device)

        # Extract the descriptors for each point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
        desc1 = desc1.reshape(-1, num_lines, self.num_samples)
        desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)
        desc2 = desc2.reshape(-1, num_lines, self.num_samples)

        # Compute the descriptor similarity between line points
        # for every pair of lines. Assign a score of -1 to invalid points.
        scores = torch.einsum('dns,dnt->nst', desc1, desc2).cpu().numpy()
        scores = scores.reshape(num_lines * self.num_samples,
                                self.num_samples)
        scores[~valid_points1.flatten()] = -1
        scores = scores.reshape(num_lines, self.num_samples, self.num_samples)
        scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1)
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(self.num_samples, num_lines, self.num_samples)
        scores = scores.transpose(1, 0, 2)
        # scores.shape = (num_lines, num_samples, num_samples)

        # Compute the NW score for each pair of lines
        pairwise_scores = np.array([self.needleman_wunsch(s) for s in scores])
        return -pairwise_scores
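self.needleman_wunsch is referenced above but not shown. The sketch below is a textbook Needleman-Wunsch dynamic program over a single (num_samples, num_samples) similarity matrix; the gap penalty value and the zero boundary (free leading gaps) are assumptions, not necessarily what the method above uses.

import numpy as np

def needleman_wunsch_score(scores, gap=0.1):
    """Alignment score of one pair of lines (sketch).

    scores: (num_samples, num_samples) matrix of point-to-point similarities.
    gap: assumed gap penalty; the value used by the method above is not shown.
    """
    n, m = scores.shape
    nw = np.zeros((n + 1, m + 1))
    # The boundary is left at zero, so leading gaps are free; a stricter
    # variant would initialize it with cumulative -gap terms.
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            nw[i, j] = max(nw[i - 1, j - 1] + scores[i - 1, j - 1],
                           nw[i - 1, j] - gap,
                           nw[i, j - 1] - gap)
    return nw[-1, -1]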
Example #3
    def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices):
        b_size, _, Hc, Wc = desc_pred1.size()
        img_size = (Hc * self.grid_size, Wc * self.grid_size)
        device = desc_pred1.device

        # Extract valid keypoints
        n_points = line_indices.size()[1]
        valid_points = line_indices.bool().flatten()
        n_correct_points = torch.sum(valid_points).item()
        if n_correct_points == 0:
            return torch.tensor(0., dtype=torch.float, device=device)

        # Convert the keypoints to a grid suitable for interpolation
        grid1 = keypoints_to_grid(points1, img_size)
        grid2 = keypoints_to_grid(points2, img_size)

        # Extract the descriptors
        desc1 = F.grid_sample(desc_pred1,
                              grid1).permute(0, 2, 3,
                                             1).reshape(b_size * n_points,
                                                        -1)[valid_points]
        desc1 = F.normalize(desc1, dim=1)
        desc2 = F.grid_sample(desc_pred2,
                              grid2).permute(0, 2, 3,
                                             1).reshape(b_size * n_points,
                                                        -1)[valid_points]
        desc2 = F.normalize(desc2, dim=1)
        desc_dists = 2 - 2 * (desc1 @ desc2.t())

        # Compute percentage of correct matches
        matches0 = torch.min(desc_dists, dim=1)[1]
        matches1 = torch.min(desc_dists, dim=0)[1]
        matching_score = (matches1[matches0] == torch.arange(
            len(matches0)).to(device))
        matching_score = matching_score.float().mean()
        return matching_score
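The core of Example #3, mutual nearest-neighbor checking on a descriptor distance matrix, can be isolated into a small standalone sketch; the function name and shapes below are illustrative, not part of the original code.

import torch
import torch.nn.functional as F

def mutual_nn_matching_score(desc1, desc2):
    """Fraction of points whose nearest neighbor is mutual (sketch).

    desc1, desc2: (N, D) descriptors of corresponding points, where row i of
    desc1 corresponds to row i of desc2.
    """
    desc1 = F.normalize(desc1, dim=1)
    desc2 = F.normalize(desc2, dim=1)
    dists = 2 - 2 * (desc1 @ desc2.t())    # squared L2 distance on the unit sphere
    matches0 = torch.min(dists, dim=1)[1]  # best match in desc2 for each row of desc1
    matches1 = torch.min(dists, dim=0)[1]  # best match in desc1 for each row of desc2
    mutual = matches1[matches0] == torch.arange(len(matches0))
    return mutual.float().mean()

# Illustrative call on random descriptors
score = mutual_nn_matching_score(torch.randn(100, 128), torch.randn(100, 128))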
Example #4
def triplet_loss(desc_pred1,
                 desc_pred2,
                 points1,
                 points2,
                 line_indices,
                 epoch,
                 grid_size=8,
                 dist_threshold=8,
                 init_dist_threshold=64,
                 margin=1):
    """ Regular triplet loss for descriptor learning. """
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        return torch.tensor(0., dtype=torch.float, device=device)

    # Check which keypoints are too close to be matched
    # dist_threshold is decreased at each epoch for easier training
    dist_threshold = max(dist_threshold,
                         2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)

    # Additionally ban negative mining along the same line
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Extract the descriptors
    desc1 = F.grid_sample(desc_pred1,
                          grid1).permute(0, 2, 3,
                                         1).reshape(b_size * n_points,
                                                    -1)[valid_points]
    desc1 = F.normalize(desc1, dim=1)
    desc2 = F.grid_sample(desc_pred2,
                          grid2).permute(0, 2, 3,
                                         1).reshape(b_size * n_points,
                                                    -1)[valid_points]
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distance loss
    pos_dist = torch.diag(desc_dists)

    # Negative distance loss
    max_dist = torch.tensor(4., dtype=torch.float, device=device)
    desc_dists[torch.arange(n_correct_points, dtype=torch.long),
               torch.arange(n_correct_points, dtype=torch.long)] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(
        torch.min(desc_dists, dim=1)[0],
        torch.min(desc_dists, dim=0)[0])

    triplet_loss = F.relu(margin + pos_dist - neg_dist)
    return triplet_loss, grid1, grid2, valid_points
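get_dist_mask and get_common_line_mask are used above but not shown. The sketch below is a simplified, single-image (b_size = 1) version of the distance mask, assuming the semantics implied by its use: point pairs that lie within dist_threshold pixels of each other in either image are banned from negative mining. The real helper also handles batching and may differ in details.

import torch

def get_dist_mask_sketch(points1, points2, valid_points, dist_threshold):
    """Boolean mask over valid points (sketch); True marks pairs that are
    too close, in either image, to serve as negatives.

    points1, points2: (N, 2) keypoints of one image pair.
    valid_points: (N,) boolean mask of non-padded points.
    """
    d1 = torch.norm(points1[:, None] - points1[None], dim=-1)  # (N, N) distances in image 1
    d2 = torch.norm(points2[:, None] - points2[None], dim=-1)  # (N, N) distances in image 2
    too_close = torch.min(d1, d2) <= dist_threshold
    return too_close[valid_points][:, valid_points]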
Example #5
    def sample_salient_points(self, line_seg, desc, img_size,
                              saliency_type='d2_net'):
        """
        Sample the most salient points along each line segment, with a
        minimal distance between points. Pad the remaining points.
        Inputs:
            line_seg: an Nx2x2 torch.Tensor.
            desc: a NxDxHxW torch.Tensor.
            img_size: the original image size.
            saliency_type: 'd2_net' or 'asl_feat'.
        Outputs:
            line_points: an Nxnum_samplesx2 np.array.
            valid_points: a boolean Nxnum_samples np.array.
        """
        device = desc.device
        if not self.line_score:
            # Compute the score map
            if saliency_type == "d2_net":
                score = self.d2_net_saliency_score(desc)
            else:
                score = self.asl_feat_saliency_score(desc)

        num_lines = len(line_seg)
        line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)

        # The number of samples depends on the length of the line
        num_samples_lst = np.clip(line_lengths // self.min_dist_pts,
                                  2, self.num_samples)
        line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
        valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
        
        # Sample the score at a fixed number of points on each line
        n_samples_per_region = 4
        for n in np.arange(2, self.num_samples + 1):
            sample_rate = n * n_samples_per_region
            # Consider all lines where we can fit up to n points
            cur_mask = num_samples_lst == n
            cur_line_seg = line_seg[cur_mask]
            cur_num_lines = len(cur_line_seg)
            if cur_num_lines == 0:
                continue
            line_points_x = np.linspace(cur_line_seg[:, 0, 0],
                                        cur_line_seg[:, 1, 0],
                                        sample_rate, axis=-1)
            line_points_y = np.linspace(cur_line_seg[:, 0, 1],
                                        cur_line_seg[:, 1, 1],
                                        sample_rate, axis=-1)
            cur_line_points = np.stack([line_points_x, line_points_y],
                                       axis=-1).reshape(-1, 2)
            # cur_line_points is of shape (n_cur_lines * sample_rate, 2)
            cur_line_points = torch.tensor(cur_line_points, dtype=torch.float,
                                           device=device)
            grid_points = keypoints_to_grid(cur_line_points, img_size)

            if self.line_score:
                # The saliency score is high when the activations are locally
                # maximal along the line (and not in a square neighborhood)
                line_desc = F.grid_sample(desc, grid_points).squeeze()
                line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate)
                line_desc = line_desc.permute(1, 0, 2)
                if saliency_type == "d2_net":
                    scores = self.d2_net_saliency_score(line_desc)
                else:
                    scores = self.asl_feat_saliency_score(line_desc)
            else:
                scores = F.grid_sample(score.unsqueeze(1),
                                       grid_points).squeeze()

            # Take the most salient point in n distinct regions
            scores = scores.reshape(-1, n, n_samples_per_region)
            best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy()
            cur_line_points = cur_line_points.reshape(-1, n,
                                                      n_samples_per_region, 2)
            cur_line_points = np.take_along_axis(
                cur_line_points, best[..., None], axis=2)[:, :, 0]

            # Pad
            cur_valid_points = np.ones((cur_num_lines, self.num_samples),
                                       dtype=bool)
            cur_valid_points[:, n:] = False
            cur_line_points = np.concatenate([
                cur_line_points,
                np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float)],
                axis=1)
            
            line_points[cur_mask] = cur_line_points
            valid_points[cur_mask] = cur_valid_points

        return line_points, valid_points
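sample_line_points, the regular-sampling counterpart used in Examples #1 and #2, is not shown on this page. The sketch below is consistent with how it is called and with the padding convention of Example #5 (outputs: an N x num_samples x 2 point array and an N x num_samples validity mask); the default values of num_samples and min_dist_pts and the exact sampling rule are assumptions.

import numpy as np

def sample_line_points_sketch(line_seg, num_samples=10, min_dist_pts=8):
    """Regularly sample up to num_samples points along each segment, spaced
    at least min_dist_pts pixels apart, and pad the rest with zeros (sketch).

    line_seg: an Nx2x2 array of segment endpoints.
    """
    num_lines = len(line_seg)
    line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)
    # Number of points per line, between 2 and num_samples
    n_per_line = np.clip(line_lengths // min_dist_pts, 2, num_samples)
    line_points = np.zeros((num_lines, num_samples, 2), dtype=float)
    valid_points = np.zeros((num_lines, num_samples), dtype=bool)
    for n in range(2, num_samples + 1):
        # Consider all lines where exactly n points fit
        cur_mask = n_per_line == n
        if not np.any(cur_mask):
            continue
        cur_seg = line_seg[cur_mask]
        xs = np.linspace(cur_seg[:, 0, 0], cur_seg[:, 1, 0], n, axis=-1)
        ys = np.linspace(cur_seg[:, 0, 1], cur_seg[:, 1, 1], n, axis=-1)
        line_points[cur_mask, :n] = np.stack([xs, ys], axis=-1)
        valid_points[cur_mask, :n] = True
    return line_points, valid_points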