def _valid_epoch(self):
    """Run one validation pass and return averaged registration metrics.

    For each validation pair, per-point features of both clouds are computed
    with ``self.model``, matched with ``self.find_corr``, a rigid transform is
    estimated with ``te.est_quad_linear_robust`` and compared with ``T_gt``.

    Returns:
        dict with keys ``"loss"``, ``"rre"``, ``"rte"``, ``"feat_match_ratio"``
        and ``"hit_ratio"``; each value is the ``.avg`` of the corresponding
        AverageMeter over the evaluated pairs.
    """
    # Change the network to evaluation mode
    self.model.eval()
    # Presumably makes validation sampling repeatable across epochs — confirm.
    self.val_data_loader.dataset.reset_seed(0)
    num_data = 0
    hit_ratio_meter, feat_match_ratio, loss_meter, rte_meter, rre_meter = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter())
    data_timer, feat_timer, matching_timer = Timer(), Timer(), Timer()
    # NOTE(review): iterates len(dataset) batches, so this assumes the
    # validation loader yields one pair per batch — confirm.
    tot_num_data = len(self.val_data_loader.dataset)
    if self.val_max_iter > 0:
        tot_num_data = min(self.val_max_iter, tot_num_data)
    data_loader_iter = iter(self.val_data_loader)

    # Validation only needs forward passes; disabling autograd avoids keeping
    # computation graphs alive and saves GPU memory.
    with torch.no_grad():
        for batch_idx in range(tot_num_data):
            data_timer.tic()
            # Built-in next(): DataLoader iterators in newer PyTorch releases
            # no longer provide a `.next()` method.
            input_dict = next(data_loader_iter)
            data_timer.toc()

            feat_timer.tic()
            sinput0 = ME.SparseTensor(
                input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
            F0 = self.model(sinput0).F

            sinput1 = ME.SparseTensor(
                input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)
            F1 = self.model(sinput1).F
            feat_timer.toc()

            matching_timer.tic()
            xyz0, xyz1, T_gt = input_dict['pcd0'], input_dict['pcd1'], input_dict['T_gt']
            xyz0_corr, xyz1_corr = self.find_corr(xyz0, xyz1, F0, F1, subsample_size=5000)
            T_est = te.est_quad_linear_robust(xyz0_corr, xyz1_corr)

            loss = corr_dist(T_est, T_gt, xyz0, xyz1, weight=None)
            loss_meter.update(loss)

            # Relative translation error: distance between translation parts.
            rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])
            rte_meter.update(rte)
            # Relative rotation error (radians). Round-off can push the cosine
            # slightly outside [-1, 1], where arccos returns NaN even for a
            # near-perfect estimate, so clip it first.
            cos_rre = (np.trace(T_est[:3, :3].t() @ T_gt[:3, :3]) - 1) / 2
            rre = np.arccos(np.clip(cos_rre, -1.0, 1.0))
            # NaN can still arise if T_est itself contains NaN; skip those.
            if not np.isnan(rre):
                rre_meter.update(rre)

            hit_ratio = self.evaluate_hit_ratio(
                xyz0_corr, xyz1_corr, T_gt, thresh=self.config.hit_ratio_thresh)
            hit_ratio_meter.update(hit_ratio)
            feat_match_ratio.update(hit_ratio > 0.05)
            matching_timer.toc()

            num_data += 1
            torch.cuda.empty_cache()

            if batch_idx % 100 == 0 and batch_idx > 0:
                logging.info(' '.join([
                    f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f},",
                    f"Feature Extraction Time: {feat_timer.avg:.3f}, Matching Time: {matching_timer.avg:.3f},",
                    f"Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
                    f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
                ]))
                data_timer.reset()

    logging.info(' '.join([
        f"Final Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
        f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
    ]))
    return {
        "loss": loss_meter.avg,
        "rre": rre_meter.avg,
        "rte": rte_meter.avg,
        'feat_match_ratio': feat_match_ratio.avg,
        'hit_ratio': hit_ratio_meter.avg
    }
예제 #2
0
def main(config):
    """Benchmark feature-matching RANSAC registration on the test split.

    Loads the model weights from ``config.save_dir + '/checkpoint.pth'`` and
    extracts features for each test pair. Each pair is then registered with
    Open3D's feature-matching RANSAC. The function logs RTE/RRE averages and
    the success rate, where success means RTE < 2 and RRE < 5 degrees.

    Args:
        config: experiment configuration. It supplies the test phase, model
            hyper-parameters, voxel size, number of workers and checkpoint
            directory.
    """
    test_loader = make_data_loader(
        config, config.test_phase, 1, num_threads=config.test_num_workers, shuffle=True)

    num_feats = 1

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Model = load_model(config.model)
    model = Model(
        num_feats,
        config.model_n_out,
        bn_momentum=config.bn_momentum,
        conv1_kernel_size=config.conv1_kernel_size,
        normalize_feature=config.normalize_feature)
    # map_location lets a checkpoint saved on GPU load on the CPU fallback path.
    checkpoint = torch.load(config.save_dir + '/checkpoint.pth', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter()
    data_timer, feat_timer, reg_timer = Timer(), Timer(), Timer()

    test_iter = iter(test_loader)
    N = len(test_iter)
    n_gpu_failures = 0

    # Loop-invariant thresholds.
    distance_threshold = config.voxel_size * 1.0
    # Success criterion: rte < 2m and rre < 5 degrees, following
    # http://openaccess.thecvf.com/content_ECCV_2018/papers/Zi_Jian_Yew_3DFeat-Net_Weakly_Supervised_ECCV_2018_paper.pdf
    max_rte = 2
    max_rre = np.pi / 180 * 5

    for i in range(N):
        data_timer.tic()
        try:
            # Built-in next(): newer PyTorch DataLoader iterators have no `.next()`.
            data_dict = next(test_iter)
        except ValueError:
            n_gpu_failures += 1
            logging.info(f"# Erroneous GPU Pair {n_gpu_failures}")
            continue
        data_timer.toc()
        xyz0, xyz1 = data_dict['pcd0'], data_dict['pcd1']
        T_gth = data_dict['T_gt']
        xyz0np, xyz1np = xyz0.numpy(), xyz1.numpy()

        pcd0 = make_open3d_point_cloud(xyz0np)
        pcd1 = make_open3d_point_cloud(xyz1np)

        with torch.no_grad():
            feat_timer.tic()
            sinput0 = ME.SparseTensor(
                data_dict['sinput0_F'].to(device), coordinates=data_dict['sinput0_C'].to(device))
            F0 = model(sinput0).F.detach()
            sinput1 = ME.SparseTensor(
                data_dict['sinput1_F'].to(device), coordinates=data_dict['sinput1_C'].to(device))
            F1 = model(sinput1).F.detach()
            feat_timer.toc()

        # Use the actual feature width instead of a hard-coded 32 so that models
        # with a different model_n_out are handled correctly.
        feat0 = make_open3d_feature(F0, F0.shape[1], F0.shape[0])
        feat1 = make_open3d_feature(F1, F1.shape[1], F1.shape[0])

        reg_timer.tic()
        # NOTE(review): `o3d.registration` is the pre-0.10 Open3D namespace;
        # newer releases expose it as `o3d.pipelines.registration`.
        ransac_result = o3d.registration.registration_ransac_based_on_feature_matching(
            pcd0, pcd1, feat0, feat1, distance_threshold,
            o3d.registration.TransformationEstimationPointToPoint(False), 4, [
                o3d.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
                o3d.registration.CorrespondenceCheckerBasedOnDistance(
                    distance_threshold)
            ], o3d.registration.RANSACConvergenceCriteria(4000000, 10000))
        T_ransac = torch.from_numpy(
            ransac_result.transformation.astype(np.float32))
        reg_timer.toc()

        # Translation error
        rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3])
        # Rotation error (radians). Clip the cosine because round-off can push
        # it just past 1 for a perfect rotation, which would make arccos NaN and
        # wrongly count the pair as a failure.
        cos_rre = (np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2
        rre = np.arccos(np.clip(cos_rre, -1.0, 1.0))

        rte_ok = rte < max_rte
        rre_ok = not np.isnan(rre) and rre < max_rre

        if rte_ok:
            rte_meter.update(rte)

        if rre_ok:
            rre_meter.update(rre)

        if rte_ok and rre_ok:
            success_meter.update(1)
        else:
            success_meter.update(0)
            logging.info(f"Failed with RTE: {rte}, RRE: {rre}")

        if i % 10 == 0:
            logging.info(
                f"{i} / {N}: Data time: {data_timer.avg}, Feat time: {feat_timer.avg}," +
                f" Reg time: {reg_timer.avg}, RTE: {rte_meter.avg}," +
                f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}"
                + f" ({success_meter.avg * 100} %)")
            data_timer.reset()
            feat_timer.reset()
            reg_timer.reset()

    logging.info(
        f"RTE: {rte_meter.avg}, var: {rte_meter.var}," +
        f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " +
        f"/ {success_meter.count} ({success_meter.avg * 100} %)")
예제 #3
0
    def _valid_epoch(self, data_loader_iter):
        """Evaluate dense 2D correspondence quality on the validation set.

        For each validation batch, feature maps are computed for ``img0`` and
        ``img1``. For every image pair, up to ``config.num_pos_per_batch``
        correspondences are sampled from ``input_dict['pairs']``. Each query
        feature in img0 is matched to its nearest neighbour in img1 and back
        again for the reciprocity test.

        Args:
            data_loader_iter: iterator consumed through ``self.get_data``.

        Returns:
            dict with three ratios:
            ``'hit_ratio'``: fraction of forward matches that land within
            the inlier radius of the target location.
            ``'reciprocity_ratio'``: fraction of backward matches that land
            within the inlier radius of the query.
            ``'reciprocity_hit_ratio'``: hit ratio restricted to the
            reciprocal matches.
        """
        # Change the network to evaluation mode
        self.model.eval()
        num_data = 0
        hit_ratio_meter = AverageMeter()
        reciprocity_ratio_meter = AverageMeter()
        reciprocity_hit_ratio_meter = AverageMeter()
        data_timer, feat_timer = Timer(), Timer()
        tot_num_data = len(self.val_data_loader.dataset)
        if self.val_max_iter > 0:
            tot_num_data = min(self.val_max_iter, tot_num_data)

        # Squared inlier radius, converted from pixels to feature-map cells.
        inlier_thresh_sq = (self.config.ucn_inlier_threshold_pixel /
                            self.out_tensor_stride)**2

        for _ in range(tot_num_data):
            data_timer.tic()
            input_dict = self.get_data(data_loader_iter)
            data_timer.toc()

            feat_timer.tic()
            with torch.no_grad():
                F0 = self.model(input_dict['img0'].to(self.device))
                F1 = self.model(input_dict['img1'].to(self.device))
            feat_timer.toc()

            # Feature maps are BxCxHxW; only the widths are needed to turn
            # flat nearest-neighbour indices back into (x, y).
            _, _, _, W0 = F0.shape
            _, _, _, W1 = F1.shape
            for batch_idx, pair in enumerate(input_dict['pairs']):
                N = len(pair)
                if N == 0:
                    # Nothing to score, and the ratios below would divide by zero.
                    continue
                # Test at most config.num_pos_per_batch correspondences per pair.
                sel = np.random.choice(N,
                                       min(N, self.config.num_pos_per_batch),
                                       replace=False)
                curr_pair = pair[sel]
                # Columns are (x0, y0, x1, y1); scale them to feature-map cells.
                w0, h0, w1, h1 = torch.floor(curr_pair.t() /
                                             self.out_tensor_stride).long()
                feats0 = F0[batch_idx, :, h0, w0]
                nn_inds1 = find_nn_gpu(feats0,
                                       F1[batch_idx, :].view(F1.shape[1], -1),
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)

                # Convert flat indices into coordinates on img1's feature map.
                xs1 = nn_inds1 % W1
                ys1 = nn_inds1 // W1

                # Reciprocity: match the retrieved img1 features back into img0.
                nn_inds0 = find_nn_gpu(F1[batch_idx, :, ys1, xs1],
                                       F0[batch_idx, :].view(F0.shape[1], -1),
                                       nn_max_n=self.config.nn_max_n,
                                       transposed=True)

                # Convert flat indices into coordinates on img0's feature map.
                xs0 = nn_inds0 % W0
                ys0 = nn_inds0 // W0

                # Forward hit: nearest neighbour lies near the target in img1.
                dist_sq = (w1 - xs1)**2 + (h1 - ys1)**2
                is_correct = dist_sq < inlier_thresh_sq
                hit_ratio_meter.update(is_correct.sum().item() /
                                       len(is_correct))

                # Reciprocity test result: backward match returns near the query.
                dist_sq_nn = (w0 - xs0)**2 + (h0 - ys0)**2
                mask = dist_sq_nn < inlier_thresh_sq
                reciprocity_ratio_meter.update(mask.sum().item() /
                                               float(len(mask)))
                # eps (defined elsewhere in the module) guards an empty mask.
                reciprocity_hit_ratio_meter.update(
                    is_correct[mask].sum().item() / (mask.sum().item() + eps))

                torch.cuda.empty_cache()

            num_data += 1

            if num_data % 100 == 0:
                logging.info(', '.join([
                    f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f}",
                    f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
                    f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
                ]))
                data_timer.reset()

        logging.info(', '.join([
            f"Validation : Data Loading Time: {data_timer.avg:.3f}",
            f"Feature Extraction Time: {feat_timer.avg:.3f}, Hit Ratio: {hit_ratio_meter.avg}",
            f"Reciprocity Ratio: {reciprocity_ratio_meter.avg}, Reci Filtered Hit Ratio: {reciprocity_hit_ratio_meter.avg}"
        ]))

        return {
            'hit_ratio': hit_ratio_meter.avg,
            'reciprocity_ratio': reciprocity_ratio_meter.avg,
            'reciprocity_hit_ratio': reciprocity_hit_ratio_meter.avg,
        }
예제 #4
0
    def _valid_epoch(self):
        """Run one validation pass and return averaged registration metrics.

        For each validation pair, per-point features of both clouds are
        computed with ``self.model``, matched with ``self.find_corr``, a rigid
        transform is estimated with ``te.est_quad_linear_robust`` and compared
        with ``T_gt``.

        Returns:
            dict with keys ``"loss"``, ``"rre"``, ``"rte"``,
            ``"feat_match_ratio"`` and ``"hit_ratio"``; each value is the
            ``.avg`` of the corresponding AverageMeter.
        """
        # Change the network to evaluation mode
        self.model.eval()
        # Presumably makes validation sampling repeatable across epochs — confirm.
        self.val_data_loader.dataset.reset_seed(0)
        num_data = 0
        hit_ratio_meter, feat_match_ratio, loss_meter, rte_meter, rre_meter = (
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter())
        data_timer, feat_timer, matching_timer = Timer(), Timer(), Timer()
        # NOTE(review): iterates len(dataset) batches, so this assumes the
        # validation loader yields one pair per batch — confirm.
        tot_num_data = len(self.val_data_loader.dataset)
        if self.val_max_iter > 0:
            tot_num_data = min(self.val_max_iter, tot_num_data)
        data_loader_iter = iter(self.val_data_loader)

        # Validation only needs forward passes; disabling autograd avoids
        # keeping computation graphs alive and saves GPU memory.
        with torch.no_grad():
            for batch_idx in range(tot_num_data):
                data_timer.tic()
                # Built-in next(): DataLoader iterators in newer PyTorch
                # releases no longer provide a `.next()` method.
                input_dict = next(data_loader_iter)
                data_timer.toc()

                feat_timer.tic()
                # NOTE(review): coordinates are cast to float here; sparse
                # tensor libraries usually expect integer coordinates — kept
                # as-is, confirm it is intended.
                sinput0 = ME.SparseTensor(
                    input_dict['sinput0_F'].to(self.device),
                    coordinates=input_dict['sinput0_C'].to(self.device).type(torch.float))
                F0 = self.model(sinput0).F

                sinput1 = ME.SparseTensor(
                    input_dict['sinput1_F'].to(self.device),
                    coordinates=input_dict['sinput1_C'].to(self.device).type(torch.float))
                F1 = self.model(sinput1).F
                feat_timer.toc()

                matching_timer.tic()
                xyz0, xyz1, T_gt = input_dict['pcd0'], input_dict['pcd1'], input_dict['T_gt']
                xyz0_corr, xyz1_corr = self.find_corr(xyz0, xyz1, F0, F1, subsample_size=5000)
                T_est = te.est_quad_linear_robust(xyz0_corr, xyz1_corr)

                loss = corr_dist(T_est, T_gt, xyz0, xyz1, weight=None)
                loss_meter.update(loss)

                # Relative translation error: distance between translation parts.
                rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])
                rte_meter.update(rte)
                # Relative rotation error (radians). Round-off can push the
                # cosine slightly outside [-1, 1], where arccos returns NaN
                # even for a near-perfect estimate, so clip it first.
                cos_rre = (np.trace(T_est[:3, :3].t() @ T_gt[:3, :3]) - 1) / 2
                rre = np.arccos(np.clip(cos_rre, -1.0, 1.0))
                # NaN can still arise if T_est itself contains NaN; skip those.
                if not np.isnan(rre):
                    rre_meter.update(rre)

                hit_ratio = self.evaluate_hit_ratio(
                    xyz0_corr, xyz1_corr, T_gt, thresh=self.config.hit_ratio_thresh)
                hit_ratio_meter.update(hit_ratio)
                feat_match_ratio.update(hit_ratio > 0.05)
                matching_timer.toc()

                num_data += 1
                torch.cuda.empty_cache()

                if batch_idx % 100 == 0 and batch_idx > 0:
                    logging.info(' '.join([
                        f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f},",
                        f"Feature Extraction Time: {feat_timer.avg:.3f}, Matching Time: {matching_timer.avg:.3f},",
                        f"Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
                        f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
                    ]))
                    data_timer.reset()

        logging.info(' '.join([
            f"Final Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
            f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
        ]))
        return {
            "loss": loss_meter.avg,
            "rre": rre_meter.avg,
            "rte": rte_meter.avg,
            'feat_match_ratio': feat_match_ratio.avg,
            'hit_ratio': hit_ratio_meter.avg
        }