def find_corr(self, xyz0, xyz1, F0, F1, subsample_size=-1):
    # Optionally subsample both feature sets to at most subsample_size
    # entries to bound the cost of the nearest-neighbor search.
    subsample = len(F0) > subsample_size
    if subsample_size > 0 and subsample:
      N0 = min(len(F0), subsample_size)
      N1 = min(len(F1), subsample_size)
      inds0 = np.random.choice(len(F0), N0, replace=False)
      inds1 = np.random.choice(len(F1), N1, replace=False)
      F0, F1 = F0[inds0], F1[inds1]

    # For each (possibly subsampled) feature in F0, find its nearest neighbor in F1
    nn_inds = find_nn_gpu(F0, F1, nn_max_n=self.config.nn_max_n)
    if subsample_size > 0 and subsample:
      return xyz0[inds0], xyz1[inds1[nn_inds]]
    else:
      return xyz0, xyz1[nn_inds]
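
Every example on this page leans on find_nn_gpu to match descriptors, but the helper itself is never shown. As a rough sketch (a hypothetical re-implementation, not the library's actual code), it can be read as a chunked brute-force L2 nearest-neighbor search in which nn_max_n bounds how large a distance matrix is materialized at once:

import torch

def find_nn_gpu_sketch(F0, F1, nn_max_n=-1, transposed=False):
    """For each row of F0, return the index of its nearest row in F1."""
    if transposed:
        # Treat inputs as C x N feature maps instead of N x C.
        F0, F1 = F0.t(), F1.t()
    if nn_max_n <= 0:
        nn_max_n = len(F0)
    nn_inds = []
    # Process F0 in chunks so the N0 x N1 distance matrix stays bounded.
    for chunk in torch.split(F0, nn_max_n):
        dist = torch.cdist(chunk, F1)       # pairwise L2 distances
        nn_inds.append(dist.argmin(dim=1))  # nearest neighbor per row
    return torch.cat(nn_inds)
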
Example #2
    def _valid_epoch(self, data_loader_iter):
        # Change the network to evaluation mode
        self.model.eval()
        num_data = 0
        hit_ratio_meter = AverageMeter()
        reciprocity_ratio_meter = AverageMeter()
        reciprocity_hit_ratio_meter = AverageMeter()
        data_timer, feat_timer = Timer(), Timer()
        tot_num_data = len(self.val_data_loader.dataset)
        if self.val_max_iter > 0:
            tot_num_data = min(self.val_max_iter, tot_num_data)

        for curr_iter in range(tot_num_data):
            data_timer.tic()
            input_dict = self.get_data(data_loader_iter)
            data_timer.toc()

            # Each row of input_dict['pairs'][b] holds matched pixel coordinates (x0, y0, x1, y1)
            feat_timer.tic()
            with torch.no_grad():
                F0 = self.model(input_dict['img0'].to(self.device))
                F1 = self.model(input_dict['img1'].to(self.device))
            feat_timer.toc()

            # Evaluate at most self.config.num_pos_per_batch correspondences per pair.
            _, _, H0, W0 = F0.shape
            _, _, H1, W1 = F1.shape
            for batch_idx, pair in enumerate(input_dict['pairs']):
                N = len(pair)
                sel = np.random.choice(N,
                                       min(N, self.config.num_pos_per_batch),
                                       replace=False)
                curr_pair = pair[sel]
                w0, h0, w1, h1 = torch.floor(curr_pair.t() /
                                             self.out_tensor_stride).long()
                feats0 = F0[batch_idx, :, h0, w0]
                nn_inds1 = find_nn_gpu(feats0,
                                       F1[batch_idx, :].view(F1.shape[1], -1),
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)

                # Convert the index to coordinate: BxCxHxW
                xs1 = nn_inds1 % W1
                ys1 = nn_inds1 // W1

                # Test reciprocity
                nn_inds0 = find_nn_gpu(F1[batch_idx, :, ys1, xs1],
                                       F0[batch_idx, :].view(F0.shape[1], -1),
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)

                # Convert the index to coordinate: BxCxHxW
                xs0 = nn_inds0 % W0
                ys0 = nn_inds0 // W0

                dist_sq = (w1 - xs1)**2 + (h1 - ys1)**2
                is_correct = dist_sq < (self.config.ucn_inlier_threshold_pixel
                                        / self.out_tensor_stride)**2
                hit_ratio_meter.update(is_correct.sum().item() /
                                       len(is_correct))

                # Reciprocity test result
                dist_sq_nn = (w0 - xs0)**2 + (h0 - ys0)**2
                mask = dist_sq_nn < (self.config.ucn_inlier_threshold_pixel /
                                     self.out_tensor_stride)**2
                reciprocity_ratio_meter.update(mask.sum().item() /
                                               float(len(mask)))
                # eps: a small constant defined elsewhere in the module, guarding
                # against division by zero when no match passes the reciprocity test.
                reciprocity_hit_ratio_meter.update(
                    is_correct[mask].sum().item() / (mask.sum().item() + eps))

                torch.cuda.empty_cache()
                # visualize_image_correspondence(input_dict['img0'][batch_idx, 0].numpy() + 0.5,
                #                                input_dict['img1'][batch_idx, 0].numpy() + 0.5,
                #                                F0[batch_idx], F1[batch_idx], curr_iter,
                #                                self.config)

            num_data += 1

            if num_data % 100 == 0:
                logging.info(', '.join([
                    f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f}",
                    f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
                    f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
                ]))
                data_timer.reset()

        logging.info(', '.join([
            f"Validation : Data Loading Time: {data_timer.avg:.3f}",
            f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
            f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
        ]))

        return {
            'hit_ratio': hit_ratio_meter.avg,
            'reciprocity_ratio': reciprocity_ratio_meter.avg,
            'reciprocity_hit_ratio': reciprocity_hit_ratio_meter.avg,
        }
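
The reciprocity check above is a cycle-consistency test: a forward match x0 -> x1 is trusted only if the nearest neighbor of x1, mapped back into image 0, lands near x0. The same idea on plain N x C descriptor arrays, as a self-contained sketch (a hypothetical helper, using torch.cdist for brevity instead of the chunked search):

import torch

def reciprocity_mask(feats0, feats1, xy0, thresh_px):
    """Keep matches whose reverse nearest neighbor returns to within
    thresh_px pixels of the original keypoint location."""
    d01 = torch.cdist(feats0, feats1)  # N0 x N1 pairwise distances
    nn1 = d01.argmin(dim=1)            # forward match: image 0 -> image 1
    nn0 = d01.t().argmin(dim=1)        # reverse match: image 1 -> image 0
    back = nn0[nn1]                    # where each forward match lands back in image 0
    dist_sq = ((xy0 - xy0[back]) ** 2).sum(dim=1)
    return dist_sq < thresh_px ** 2
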
Example #3
    def contrastive_loss(self,
                         img0,
                         img1,
                         F0,
                         F1,
                         pairs,
                         num_pos=5192,
                         num_hn_samples=2048):
        """
    F0: B x C x H0 x W0
    F0: B x C x H1 x W1
    Generate negative pairs
    """
        B, C, H0, W0 = F0.shape
        B1, C1, H1, W1 = F1.shape
        assert B == B1
        assert C == C1
        pos_loss_sum, neg_loss_sum = 0, 0
        sq_thresh = (self.config.ucn_inlier_threshold_pixel /
                     self.out_tensor_stride)**2
        for curr_F0, curr_F1, curr_pairs in zip(F0, F1, pairs):
            flat_F0 = curr_F0.view(C, -1)
            flat_F1 = curr_F1.view(C, -1)

            # Sample num_pos positive pairs and num_hn_samples candidate
            # features per image for hardest-negative mining.
            N = len(curr_pairs)
            num_pos = min(num_pos, N)
            num_hn_samples = min(num_hn_samples, min(H0, H1) * min(W0, W1))

            sel_pos = np.random.choice(N, num_pos, replace=False)
            sel_pairs = curr_pairs[sel_pos]
            sel_neg0 = torch.from_numpy(
                np.random.choice(H0 * W0, num_hn_samples, replace=False))
            sel_neg1 = torch.from_numpy(
                np.random.choice(H1 * W1, num_hn_samples, replace=False))

            w0, h0, w1, h1 = torch.floor(sel_pairs.t() /
                                         self.out_tensor_stride).long()

            sel_pos0 = h0 * W0 + w0
            sel_pos1 = h1 * W1 + w1

            # Find negatives for all F1[positive_pairs[:, 1]]
            subF0, subF1 = flat_F0[:, sel_neg0], flat_F1[:, sel_neg1]
            posF0, posF1 = flat_F0[:, sel_pos0], flat_F1[:, sel_pos1]

            with torch.no_grad():
                nn_inds1 = find_nn_gpu(posF0,
                                       subF1,
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)
                nn_inds0 = find_nn_gpu(posF1,
                                       subF0,
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)

            D1ind = sel_neg1[nn_inds1]
            D0ind = sel_neg0[nn_inds0]

            neg_w1 = D1ind % W1
            neg_h1 = D1ind // W1

            neg_w0 = D0ind % W0
            neg_h0 = D0ind // W0

            # Check if they are outside the pixel thresh
            mask0 = ((h0 - neg_h0)**2 + (w0 - neg_w0)**2) > sq_thresh
            mask1 = ((h1 - neg_h1)**2 + (w1 - neg_w1)**2) > sq_thresh

            D01min = (posF0[:, mask0] -
                      subF1[:, nn_inds1[mask0]]).pow(2).sum(0)
            D10min = (posF1[:, mask1] -
                      subF0[:, nn_inds0[mask1]]).pow(2).sum(0)

            pw0, ph0, pw1, ph1 = torch.floor(curr_pairs.t() /
                                             self.out_tensor_stride).long()

            pos_loss = F.relu((curr_F0[:, ph0, pw0] -
                               curr_F1[:, ph1, pw1]).pow(2).sum(0) -
                              self.pos_thresh)
            neg_loss0 = F.relu(self.neg_thresh - D01min).pow(2)
            neg_loss1 = F.relu(self.neg_thresh - D10min).pow(2)

            pos_loss_sum += pos_loss.mean()
            neg_loss_sum += (neg_loss0.mean() + neg_loss1.mean()) / 2

        return pos_loss_sum / B, neg_loss_sum / B
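
The two hinge terms above follow the hardest-contrastive pattern: positive pairs are pulled together until their squared distance falls below pos_thresh, and each anchor's hardest spatially-distant negative is pushed until it exceeds neg_thresh. A toy restatement of just the hinges on flat N x C descriptors (illustrative names, not the class API; note that, as in the code above, only the negative hinge is squared):

import torch
import torch.nn.functional as F

def hinge_losses(anchor, positive, hardest_negative, pos_thresh, neg_thresh):
    # Positive hinge: penalize matched descriptors farther apart than pos_thresh.
    pos_loss = F.relu((anchor - positive).pow(2).sum(1) - pos_thresh)
    # Negative hinge (squared): penalize hardest negatives closer than neg_thresh.
    d_min = (anchor - hardest_negative).pow(2).sum(1)
    neg_loss = F.relu(neg_thresh - d_min).pow(2)
    return pos_loss.mean(), neg_loss.mean()
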
Example #4
def visualize_image_correspondence(img0,
                                   img1,
                                   F0,
                                   F1,
                                   filename,
                                   mode='gpu-all',
                                   config=None,
                                   visualize=True):
    use_stability_test = True
    use_cyclic_test = False
    keypoint = 'sift'
    if keypoint == 'sift':
        sift = cv2.xfeatures2d.SIFT_create(
            0,     # nfeatures: 0 keeps all detected keypoints
            9,     # nOctaveLayers
            0.01,  # contrastThreshold: smaller -> more keypoints, default 0.04
            100    # edgeThreshold: larger -> more keypoints, default 10
        )
        kp0 = sift.detect(img0, None)
        kp1 = sift.detect(img1, None)
        xy_kp0 = np.floor(np.array([k.pt for k in kp0]).T).astype(int)
        xy_kp1 = np.floor(np.array([k.pt for k in kp1]).T).astype(int)
        x0, y0 = xy_kp0[0], xy_kp0[1]
        x1, y1 = xy_kp1[0], xy_kp1[1]
    elif keypoint == 'all':
        x0, y0 = None, None
        x1, y1 = None, None

    H0, W0 = img0.shape
    H1, W1 = img1.shape

    if mode == 'cpu-keypoints':
        matches1 = util_2d.feature_match(F0[:, y0, x0].t().cpu().numpy(),
                                         F1[:, y1, x1].t().cpu().numpy(),
                                         ratio_test=True,
                                         ratio=0.95)

        # Select the coordinates of the matched keypoints
        x0 = x0[matches1[:, 0]]
        y0 = y0[matches1[:, 0]]
        xs1 = x1[matches1[:, 1]]
        ys1 = y1[matches1[:, 1]]

        # Test reciprocity
        nn_inds0 = find_nn_gpu(F1[:, ys1, xs1],
                               F0[:, y0, x0],
                               nn_max_n=config.nn_max_n,
                               transposed=True)

        # Map the reverse NN indices back to keypoint coordinates in image 0
        xs0 = x0[nn_inds0.numpy()]
        ys0 = y0[nn_inds0.numpy()]

        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    elif mode == 'gpu-keypoints':
        nn_inds1 = find_nn_gpu(F0[:, y0, x0],
                               F1[:, y1, x1],
                               nn_max_n=config.nn_max_n,
                               transposed=True).numpy()

        # Map the forward NN indices to keypoint coordinates in image 1
        xs1 = x1[nn_inds1]
        ys1 = y1[nn_inds1]

        if use_stability_test:
            # Stability test: check stable under perturbation
            noisex = 2 * (np.random.rand(len(xs1)) < 0.5) - 1
            noisey = 2 * (np.random.rand(len(ys1)) < 0.5) - 1
            xs1n = np.clip(xs1 + noisex, 0, W1 - 1)
            ys1n = np.clip(ys1 + noisey, 0, H1 - 1)
        else:
            xs1n = xs1
            ys1n = ys1

        # Test reciprocity
        nn_inds0 = find_nn_gpu(F1[:, ys1n, xs1n],
                               F0[:, y0, x0],
                               nn_max_n=config.nn_max_n,
                               transposed=True).numpy()

        # Map the reverse NN indices back to keypoint coordinates in image 0
        xs0 = x0[nn_inds0]
        ys0 = y0[nn_inds0]

        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    elif mode == 'gpu-all':
        nn_inds1 = find_nn_faiss(
            F0[:, y0, x0],
            F1.view(F1.shape[0], -1),
        )

        # Convert the index to coordinate: BxCxHxW
        xs1 = nn_inds1 % W1
        ys1 = nn_inds1 // W1

        if use_stability_test:
            # Stability test: check stable under perturbation
            noisex = 2 * (np.random.rand(len(xs1)) < 0.5) - 1
            noisey = 2 * (np.random.rand(len(ys1)) < 0.5) - 1
            xs1n = np.clip(xs1 + noisex, 0, W1 - 1)
            ys1n = np.clip(ys1 + noisey, 0, H1 - 1)
        else:
            xs1n = xs1
            ys1n = ys1

        if use_cyclic_test:
            # Test reciprocity
            nn_inds0 = find_nn_faiss(
                F1[:, ys1n, xs1n],
                F0.view(F0.shape[0], -1),
            )

            # Convert the index to coordinate: BxCxHxW
            xs0 = (nn_inds0 % W0)
            ys0 = (nn_inds0 // W0)

            # Test cyclic consistency
            dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
            mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

        else:
            xs0 = x0
            ys0 = y0
            mask = np.ones(len(x0)).astype(bool)

    elif mode == 'gpu-all-all':
        nn_inds1 = find_nn_faiss(
            F0.view(F0.shape[0], -1),
            F1.view(F1.shape[0], -1),
        )

        inds0 = np.arange(len(nn_inds1))
        x0 = inds0 % W0
        y0 = inds0 // W0

        xs1 = nn_inds1 % W1
        ys1 = nn_inds1 // W1

        if use_stability_test:
            # Stability test: check stable under perturbation
            noisex = 2 * (np.random.rand(len(xs1)) < 0.5) - 1
            noisey = 2 * (np.random.rand(len(ys1)) < 0.5) - 1
            xs1n = np.clip(xs1 + noisex, 0, W1 - 1)
            ys1n = np.clip(ys1 + noisey, 0, H1 - 1)
        else:
            xs1n = xs1
            ys1n = ys1

        # Test reciprocity
        nn_inds0 = find_nn_faiss(
            F1[:, ys1n, xs1n],
            F0.view(F0.shape[0], -1),
        )

        # Convert the index to coordinate: BxCxHxW
        xs0 = nn_inds0 % W0
        ys0 = nn_inds0 // W0

        # Filter out the points that fail the cycle consistency
        dist_sq_nn = (x0 - xs0)**2 + (y0 - ys0)**2
        mask = dist_sq_nn < (config.ucn_inlier_threshold_pixel**2)

    if visualize:
        color = x0[mask] + y0[mask] * W0
        plt.clf()
        fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2)
        fig.set_size_inches(9, 6)

        ax0.imshow(img0 * 0.5, vmin=0, vmax=255, cmap='gray')
        ax0.scatter(x=x0[mask], y=y0[mask], c=color, s=2, cmap="jet")
        ax0.axis('off')

        ax1.imshow(img1 * 0.5, vmin=0, vmax=255, cmap='gray')
        ax1.scatter(x=xs1[mask], y=ys1[mask], c=color, s=2, cmap="jet")
        ax1.axis('off')

        fig.tight_layout()
        ensure_dir('./ucn_outputs')
        plt.savefig(f"./ucn_outputs/{filename:03d}.png", dpi=300)
    else:
        return x0[mask], y0[mask], xs1[mask], ys1[mask]
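
A hedged usage sketch for the visualizer (the image paths, model, and config below are placeholders, not part of the original examples; it assumes the model maps a 1 x 1 x H x W grayscale tensor to a 1 x C x H x W feature map, matching the commented-out call in Example #2):

import cv2
import torch

# Placeholder inputs: SIFT detection expects uint8 grayscale images.
img0 = cv2.imread('pair_0.png', cv2.IMREAD_GRAYSCALE)
img1 = cv2.imread('pair_1.png', cv2.IMREAD_GRAYSCALE)

with torch.no_grad():
    # model is assumed to be a dense feature extractor; [0] drops the batch dim.
    F0 = model(torch.from_numpy(img0 / 255.0 - 0.5).float()[None, None])[0]
    F1 = model(torch.from_numpy(img1 / 255.0 - 0.5).float()[None, None])[0]

visualize_image_correspondence(img0, img1, F0, F1, filename=0,
                               mode='gpu-all', config=config)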