def run(model: BaseModel, dataset: BaseDataset, device, output_path, cfg):
    """Compute and save per-point features for every fragment of *dataset*.

    When ``cfg.data.is_patch`` is set, each fragment is processed patch by
    patch and the per-patch outputs are concatenated before saving;
    otherwise each fragment is fed through the model as a single batch.
    """
    num_fragment = dataset.num_fragment
    if cfg.data.is_patch:
        for i in range(num_fragment):
            dataset.set_patches(i)
            dataset.create_dataloaders(
                model,
                cfg.batch_size,
                False,
                cfg.num_workers,
                False,
            )
            loader = dataset.test_dataloaders()[0]
            features = []
            scene_name, pc_name = dataset.get_name(i)

            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                        features.append(model.get_output().cpu())
            # Stack the per-patch outputs into one (N, F) array.
            features = torch.cat(features, 0).numpy()
            log.info("save {} from {} in  {}".format(pc_name, scene_name,
                                                     output_path))
            save(output_path, scene_name, pc_name,
                 dataset.base_dataset[i].to("cpu"), features)
    else:
        dataset.create_dataloaders(
            model,
            1,
            False,
            cfg.num_workers,
            False,
        )
        loader = dataset.test_dataloaders()[0]
        with Ctq(loader) as tq_test_loader:
            for i, data in enumerate(tq_test_loader):
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                    features = model.get_output()[0]  # batch of 1
                    # Bug fix: scene_name / pc_name were undefined in this
                    # branch (NameError on the first save); resolve them per
                    # fragment, mirroring the patch branch above.
                    scene_name, pc_name = dataset.get_name(i)
                    save(output_path, scene_name, pc_name, data.to("cpu"),
                         features)
Exemplo n.º 2
0
def test_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run the model over every test dataloader and track/report metrics.

    Args:
        voting_runs: number of passes over each loader (vote averaging).
        tracker_options: optional kwargs forwarded to ``tracker.track`` and
            ``tracker.finalise``.
    """
    # Avoid the shared-mutable-default pitfall: default to a fresh dict.
    if tracker_options is None:
        tracker_options = {}

    loaders = dataset.test_dataloaders

    for loader in loaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        for i in range(voting_runs):
            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()

                    tracker.track(model, **tracker_options)
                    tq_test_loader.set_postfix(**tracker.get_metrics(),
                                               color=COLORS.TEST_COLOR)

        tracker.finalise(**tracker_options)
        tracker.print_summary()
Exemplo n.º 3
0
def run_epoch(model: BaseModel, loader, device: str, num_batches: int):
    """Evaluate *model* on at most *num_batches* batches from *loader*."""
    model.eval()
    with Ctq(loader) as progress:
        for idx, batch in enumerate(progress):
            # Guard clause: stop once the batch budget is exhausted.
            if idx >= num_batches:
                break
            process(model, batch, device)
Exemplo n.º 4
0
def run(model: BaseModel, dataset: BaseDataset, device, cfg):
    """Evaluate pairwise registration: forward every test pair, compute
    matching/registration metrics, and dump the per-pair results to a CSV.

    For each pair the ground-truth transform is estimated from the known
    correspondences, then metrics are computed on a random subsample of
    ``cfg.data.num_points`` points per cloud.
    """
    dataset.create_dataloaders(
        model,
        1,
        False,
        cfg.training.num_workers,
        False,
    )
    loader = dataset.test_dataloaders[0]
    list_res = []
    with Ctq(loader) as tq_test_loader:
        for i, data in enumerate(tq_test_loader):
            with torch.no_grad():
                model.set_input(data, device)
                model.forward()

                name_scene, name_pair_source, name_pair_target = dataset.test_dataset[
                    0].get_name(i)
                input, input_target = model.get_input()
                xyz, xyz_target = input.pos, input_target.pos
                ind, ind_target = input.ind, input_target.ind
                # Row k of matches_gt is a (source index, target index) pair.
                matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
                feat, feat_target = model.get_output()
                # Random subsample of points used for metric computation.
                rand = torch.randperm(len(feat))[:cfg.data.num_points]
                rand_target = torch.randperm(
                    len(feat_target))[:cfg.data.num_points]
                res = dict(name_scene=name_scene,
                           name_pair_source=name_pair_source,
                           name_pair_target=name_pair_target)
                # Ground-truth rigid transform estimated from known matches.
                T_gt = estimate_transfo(xyz[matches_gt[:, 0]],
                                        xyz_target[matches_gt[:, 1]])
                metric = compute_metrics(
                    xyz[rand],
                    xyz_target[rand_target],
                    feat[rand],
                    feat_target[rand_target],
                    T_gt,
                    sym=cfg.data.sym,
                    tau_1=cfg.data.tau_1,
                    tau_2=cfg.data.tau_2,
                    rot_thresh=cfg.data.rot_thresh,
                    trans_thresh=cfg.data.trans_thresh,
                    use_ransac=cfg.data.use_ransac,
                    ransac_thresh=cfg.data.first_subsampling,
                    use_teaser=cfg.data.use_teaser,
                    noise_bound_teaser=cfg.data.noise_bound_teaser,
                )
                res = dict(**res, **metric)
                list_res.append(res)

    df = pd.DataFrame(list_res)
    output_path = os.path.join(cfg.training.checkpoint_dir, cfg.data.name,
                               "matches")
    if not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
    df.to_csv(osp.join(output_path, "final_res.csv"))
    print(df.groupby("name_scene").mean())
Exemplo n.º 5
0
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Train *model* for one epoch and checkpoint the best weights.

    Tracks metrics every 10 batches, forwards visuals to *visualizer*, and
    honours the ``early_break`` / ``profiling`` switches on *debugging*.
    """

    early_break = getattr(debugging, "early_break", False)
    profiling = getattr(debugging, "profiling", False)

    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            model.set_input(data, device)
            # Time spent waiting on the dataloader for this batch.
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            # Tracking every batch is expensive; sample every 10th.
            if i % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(**tracker.get_metrics(),
                                        data_loading=float(t_data),
                                        iteration=float(time.time() -
                                                        iter_start_time),
                                        color=COLORS.TRAIN_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

            if profiling:
                # When profiling, stop after a fixed number of batches.
                if i > getattr(debugging, "num_batches", 50):
                    return 0

    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics,
                                                      tracker.metric_func)
    log.info("Learning rate = %f" % model.learning_rate)
Exemplo n.º 6
0
    def _test_epoch(self, epoch, stage_name: str):
        """Evaluate the model for one epoch on the test (or validation)
        dataloaders, tracking metrics per stage and per voting run.

        Args:
            epoch: current epoch index, forwarded to forward/finalize hooks.
            stage_name: "test" selects the test dataloaders, anything else
                selects the validation dataloader.
        """
        voting_runs = self._cfg.get("voting_runs", 1)
        if stage_name == "test":
            loaders = self._dataset.test_dataloaders
        else:
            loaders = [self._dataset.val_dataloader]

        self._model.eval()
        if self.enable_dropout:
            # Keep dropout active in eval mode (e.g. for MC-dropout style
            # voting runs).
            self._model.enable_dropout_in_eval()

        for loader in loaders:
            stage_name = loader.dataset.name
            self._tracker.reset(stage_name)
            if self.has_visualization:
                self._visualizer.reset(epoch, stage_name)
            if not self._dataset.has_labels(
                    stage_name) and not self.tracker_options.get(
                        "make_submission",
                        False):  # No label, no submission -> do nothing
                log.warning("No forward will be run on dataset %s." %
                            stage_name)
                continue

            for i in range(voting_runs):
                with Ctq(loader) as tq_loader:
                    for batch_idx, data in enumerate(tq_loader):
                        with torch.no_grad():
                            self._model.set_input(data, self._device)
                            with torch.cuda.amp.autocast(
                                    enabled=self._model.is_mixed_precision()):
                                self._model.forward(epoch=epoch)
                            self._tracker.track(self._model,
                                                data=data,
                                                **self.tracker_options)
                        tq_loader.set_postfix(**self._tracker.get_metrics(),
                                              color=COLORS.TEST_COLOR)

                        if self.has_visualization and self._visualizer.is_active:
                            self._visualizer.save_visuals(
                                self._model.get_current_visuals())

                        if self.early_break:
                            break

                        # Bug fix: this early-exit previously compared the
                        # *voting-run* index `i` (always 0 for a single run),
                        # so profiling never stopped; compare the batch index.
                        if self.profiling:
                            if batch_idx > self.num_batches:
                                return 0

            self._finalize_epoch(epoch)
            self._tracker.print_summary()
    def _train_epoch(self, epoch: int):
        """Train the model for one epoch, optionally under the profiler.

        Metrics are tracked every 10 batches; visuals are forwarded to the
        visualizer when it is active.
        """

        self._model.train()
        self._tracker.reset("train")
        self._visualizer.reset(epoch, "train")
        train_loader = self._dataset.train_dataloader

        with self.profiler_profile(epoch) as prof:
            iter_data_time = time.time()
            with Ctq(train_loader) as tq_train_loader:
                for i, data in enumerate(tq_train_loader):
                    # Time spent waiting on the dataloader for this batch.
                    t_data = time.time() - iter_data_time
                    iter_start_time = time.time()

                    with self.profiler_record_function('train_step'):
                        self._model.set_input(data, self._device)
                        self._model.optimize_parameters(
                            epoch, self._dataset.batch_size)

                    with self.profiler_record_function('track/log/visualize'):
                        # Tracking every batch is expensive; sample every 10th.
                        if i % 10 == 0:
                            with torch.no_grad():
                                self._tracker.track(self._model,
                                                    data=data,
                                                    **self.tracker_options)

                        tq_train_loader.set_postfix(
                            **self._tracker.get_metrics(),
                            data_loading=float(t_data),
                            iteration=float(time.time() - iter_start_time),
                            color=COLORS.TRAIN_COLOR)

                        if self._visualizer.is_active:
                            self._visualizer.save_visuals(
                                self._model.get_current_visuals())

                    iter_data_time = time.time()

                    # Advance the PyTorch profiler schedule, if enabled.
                    if self.pytorch_profiler_log:
                        prof.step()

                    if self.early_break:
                        break

                    if self.profiling:
                        if i > self.num_batches:
                            return 0

        self._finalize_epoch(epoch)
Exemplo n.º 8
0
def run(model: BaseModel, dataset: BaseDataset, device, output_path):
    """Run inference over every test dataloader and save the predictions.

    Predictions are remapped onto the original samples via
    ``dataset.predict_original_samples`` and accumulated in one dict.
    """
    loaders = dataset.test_dataloaders
    predicted = {}
    for loader in loaders:
        # Removed a dead `loader.dataset.name` expression whose result was
        # discarded.
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                # Merge this batch in place instead of rebuilding the whole
                # dict every iteration (was O(n**2) over all batches).
                predicted.update(
                    dataset.predict_original_samples(data, model.conv_type,
                                                     model.get_output()))

    save(output_path, predicted)
Exemplo n.º 9
0
 def val_one_epoch(self, epoch):
     """Validate the model for one epoch on the validation loader."""
     self.model.eval()
     self.metrics.reset()
     # Hoist the loop-invariant class-weight transfer out of the batch loop
     # (previously re-copied to the device on every iteration).
     class_weights = self.cfg.class_weights.to(self.device)
     with Ctq(self.dataset.val_loader) as tq_loader:
         for i, data in enumerate(tq_loader):
             tq_loader.set_description('Val epoch[{}]'.format(epoch))
             data = data.to(self.device)
             # Shift labels down by one (labels appear to be 1-based while
             # predictions are 0-based -- confirm against the dataset).
             y_target = data.y.reshape(-1) - 1
             with torch.no_grad():
                 y_pred = self.model(data)
             loss = F.cross_entropy(y_pred,
                                    y_target,
                                    weight=class_weights,
                                    ignore_index=self.cfg.ignore_index)
             tq_loader.set_postfix(loss=loss.item())
             self.metrics.update(y_target.cpu().numpy(),
                                 y_pred.max(dim=1)[1].cpu().numpy())
Exemplo n.º 10
0
def test_epoch(device):
    """Evaluate the global ``model`` on the first test dataloader,
    tracking metrics and per-batch timings."""
    model.to(device)
    model.eval()
    tracker.reset("test")
    test_loader = dataset.test_dataloaders[0]
    iter_data_time = time.time()
    with Ctq(test_loader) as tq_test_loader:
        for i, data in enumerate(tq_test_loader):
            # Time spent waiting on the dataloader for this batch.
            t_data = time.time() - iter_data_time
            iter_start_time = time.time()
            data.to(device)
            # Consistency fix: wrap the evaluation forward pass in no_grad
            # like every other test loop in this file (avoids building the
            # autograd graph during evaluation).
            with torch.no_grad():
                model.forward(data)
            tracker.track(model)

            tq_test_loader.set_postfix(
                **tracker.get_metrics(),
                data_loading=float(t_data),
                iteration=float(time.time() - iter_start_time),
            )
            iter_data_time = time.time()
Exemplo n.º 11
0
 def train_one_epoch(self, epoch):
     """Train the model for one epoch on the training loader."""
     self.model.train()
     self.metrics.reset()
     # Hoist the loop-invariant class-weight transfer out of the batch loop
     # (previously re-copied to the device on every iteration).
     class_weights = self.cfg.class_weights.to(self.device)
     with Ctq(self.dataset.train_loader) as tq_loader:
         for i, data in enumerate(tq_loader):
             tq_loader.set_description('Train epoch[{}]'.format(epoch))
             data = data.to(self.device)
             self.optimizer.zero_grad()
             y_pred = self.model(data)
             # Shift labels down by one (labels appear to be 1-based while
             # predictions are 0-based -- confirm against the dataset).
             y_target = data.y.reshape(-1) - 1
             loss = F.cross_entropy(y_pred,
                                    y_target,
                                    weight=class_weights,
                                    ignore_index=self.cfg.ignore_index)
             loss.backward()
             self.optimizer.step()
             tq_loader.set_postfix(loss=loss.item())
             self.metrics.update(y_target.cpu().numpy(),
                                 y_pred.max(dim=1)[1].cpu().numpy())
Exemplo n.º 12
0
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Run one evaluation epoch over every test dataloader, then publish
    metrics and checkpoint the best models."""
    early_break = getattr(debugging, "early_break", False)
    model.eval()

    for loader in dataset.test_dataloaders:
        stage = loader.dataset.name
        tracker.reset(stage)
        visualizer.reset(epoch, stage)
        with Ctq(loader) as progress:
            for batch in progress:
                with torch.no_grad():
                    model.set_input(batch, device)
                    model.forward()

                tracker.track(model)
                progress.set_postfix(**tracker.get_metrics(),
                                     color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if early_break:
                    break

        tracker.finalise()
        metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(
            model, metrics, tracker.metric_func)
Exemplo n.º 13
0
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Run one validation epoch, publish metrics and checkpoint the model."""
    early_break = getattr(debugging, "early_break", False)

    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")
    with Ctq(dataset.val_dataloader) as progress:
        for batch in progress:
            with torch.no_grad():
                model.set_input(batch, device)
                model.forward()

            tracker.track(model)
            progress.set_postfix(**tracker.get_metrics(),
                                 color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if early_break:
                break

    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics,
                                                      tracker.metric_func)
Exemplo n.º 14
0
def train_epoch(device):
    """Train the global ``model`` for one epoch on the train dataloader,
    tracking metrics every 10 batches and per-batch timings."""
    model.to(device)
    model.train()
    tracker.reset("train")
    loader = dataset.train_dataloader
    data_tic = time.time()
    with Ctq(loader) as progress:
        for step, batch in enumerate(progress):
            load_time = time.time() - data_tic
            step_tic = time.time()
            optimizer.zero_grad()
            batch.to(device)
            model.forward(batch)
            model.backward()
            optimizer.step()
            # Sample the (potentially costly) metric tracking.
            if step % 10 == 0:
                tracker.track(model)

            progress.set_postfix(
                **tracker.get_metrics(),
                data_loading=float(load_time),
                iteration=float(time.time() - step_tic),
            )
            data_tic = time.time()
Exemplo n.º 15
0
def eval_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run *voting_runs* validation passes and report tracked metrics.

    Args:
        voting_runs: number of passes over the validation loader.
        tracker_options: optional kwargs forwarded to ``tracker.track`` and
            ``tracker.finalise``.
    """
    # Avoid the shared-mutable-default pitfall: default to a fresh dict.
    if tracker_options is None:
        tracker_options = {}
    tracker.reset("val")
    loader = dataset.val_dataloader
    for i in range(voting_runs):
        with Ctq(loader) as tq_val_loader:
            for data in tq_val_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()

                tracker.track(model, **tracker_options)
                tq_val_loader.set_postfix(**tracker.get_metrics(),
                                          color=COLORS.VAL_COLOR)

    tracker.finalise(**tracker_options)
    tracker.print_summary()
Exemplo n.º 16
0
    def test(self, num_votes=100):
        """Vote-based test on Semantic3D: accumulate smoothed per-point
        probabilities over repeated passes, then project the predictions
        back to the original clouds and save them as PLY files.

        Args:
            num_votes: stop voting once the minimum sampling possibility of
                the validation set exceeds this value.
        """
        logging.info('Test {} on {} ...'.format(self.cfg.model_name,
                                                self.cfg.dataset))
        test_smooth = 0.98
        saving_path = 'results/Semantic3D/predictions'
        # Idiom fix: exist_ok=True replaces the race-prone check-then-create
        # conditional expression that was used only for its side effect.
        os.makedirs(saving_path, exist_ok=True)

        # load model checkpoints
        self.model.load(
            'checkpoints/PointConvBig_on_Semantic3D_bs_8_epochs_100_big_crf.ckpt'
        )
        self.model.to(self.device)
        self.model.eval()

        epoch = 0
        last_min = -0.5
        while last_min < num_votes:
            # test one epoch
            with Ctq(self.dataset.val_loader) as tq_loader:
                for i, data in enumerate(tq_loader):
                    tq_loader.set_description('Evaluation')
                    # model inference
                    data = data.to(self.device)
                    with torch.no_grad():
                        probs = F.softmax(self.model(data),
                                          dim=-1)  # get pred probs

                    # running means for each epoch on Test set
                    point_idx = data.point_idx.cpu().numpy()  # the point idx
                    cloud_idx = data.cloud_idx.cpu().numpy()  # the cloud idx
                    probs = probs.reshape(
                        self.cfg.batch_size, -1,
                        self.cfg.num_classes).cpu().numpy()  # [B, N, C]
                    for b in range(
                            self.cfg.batch_size):  # for each sample in batch
                        prob = probs[b, :, :]  # [N, C]
                        p_idx = point_idx[b, :]  # [N]
                        c_idx = cloud_idx[b][0]  # int
                        # Exponential running mean of the vote probabilities.
                        self.test_probs[c_idx][p_idx] = test_smooth * self.test_probs[c_idx][p_idx] \
                                                        + (1 - test_smooth) * prob  # running means

            # after each epoch
            new_min = np.min(self.dataset.val_set.min_possibility)
            print('Epoch {:3d} end, current min possibility = {:.2f}'.format(
                epoch, new_min))
            if last_min + 4 < new_min:
                print('Test procedure done, saving predicted clouds ...')
                last_min = new_min
                # projection prediction to original point cloud
                t1 = time.time()
                for i, file in enumerate(self.dataset.val_set.val_files):
                    proj_idx = self.dataset.val_set.test_proj[
                        i]  # already get the shape
                    probs = self.test_probs[i][
                        proj_idx, :]  # same shape with proj_idx
                    # [0 ~ 7] + 1 -> [1 ~ 8], because 0 for unlabeled
                    preds = np.argmax(probs, axis=1).astype(np.uint8) + 1
                    # saving prediction results
                    cloud_name = file.split('/')[-1]
                    filename = os.path.join(saving_path, cloud_name)
                    write_ply(filename, [preds], ['pred'])
                    print('Save {:s} succeed !'.format(filename))
                t2 = time.time()
                print('Done in {:.2f} s.'.format(t2 - t1))
                return
            epoch += 1
        return
Exemplo n.º 17
0
    def test_s3dis(self, num_votes=100):
        """Vote-based evaluation on S3DIS: accumulate smoothed per-point
        probabilities over repeated passes, report sub-cloud and full-cloud
        confusion/IoU, then stop once voting has converged.

        Args:
            num_votes: stop voting once the minimum sampling possibility of
                the test set exceeds this value.
        """
        logging.info('Evaluating {} on {} ...'.format(self.cfg.model_name,
                                                      self.cfg.dataset))
        test_smooth = 0.95
        # statistic label proportions in test set
        class_proportions = np.zeros(self.cfg.num_classes, dtype=np.float32)
        for i, label in enumerate(self.dataset.test_set.label_values):
            class_proportions[i] = np.sum([
                np.sum(labels == label)
                for labels in self.dataset.test_set.val_labels
            ])

        # load model checkpoints
        self.model.load(
            'checkpoints/RandLANet_on_S3DIS_bs_8_epochs_100_big.ckpt')
        self.model.to(self.device)
        self.model.eval()

        epoch = 0
        last_min = -0.5
        while last_min < num_votes:

            # test one epoch
            with Ctq(self.dataset.test_loader) as tq_loader:
                for i, data in enumerate(tq_loader):
                    tq_loader.set_description('Evaluation')

                    # model inference
                    data = data.to(self.device)
                    with torch.no_grad():
                        logits = self.model(data)  # get pred
                        y_pred = F.softmax(logits, dim=-1)

                    y_pred = y_pred.cpu().numpy()
                    y_target = data.y.cpu().numpy().reshape(-1)  # get target
                    point_idx = data.point_idx.cpu().numpy()  # the point idx
                    cloud_idx = data.cloud_idx.cpu().numpy()  # the cloud idx

                    # compute batch accuracy
                    correct = np.sum(np.argmax(y_pred, axis=1) == y_target)
                    acc = correct / float(np.prod(
                        np.shape(y_target)))  # accurate for each test batch
                    tq_loader.set_postfix(ACC=acc)

                    y_pred = y_pred.reshape(self.cfg.batch_size, -1,
                                            self.cfg.num_classes)  # [B, N, C]
                    for b in range(
                            self.cfg.batch_size):  # for each sample in batch
                        probs = y_pred[b, :, :]  # [N, C]
                        p_idx = point_idx[b, :]  # [N]
                        c_idx = cloud_idx[b][0]  # int
                        # Exponential running mean of the vote probabilities.
                        self.test_probs[c_idx][p_idx] = test_smooth * self.test_probs[c_idx][p_idx] \
                                                       + (1 - test_smooth) * probs   # running means

            new_min = np.min(self.dataset.test_set.min_possibility)
            print('Epoch {:3d} end, current min possibility = {:.2f}'.format(
                epoch, new_min))

            if last_min + 1 < new_min:
                # update last_min
                last_min += 1
                # show vote results
                print('Confusion on sub clouds.')
                confusion_list = []
                num_clouds = len(
                    self.dataset.test_set.input_labels)  # test cloud number
                for i in range(num_clouds):
                    probs = self.test_probs[i]
                    preds = self.dataset.test_set.label_values[np.argmax(
                        probs, axis=1)].astype(np.int32)
                    labels = self.dataset.test_set.input_labels[i]
                    confusion_list += [
                        confusion_matrix(labels, preds,
                                         self.dataset.test_set.label_values)
                    ]

                # re-group confusions
                C = np.sum(np.stack(confusion_list), axis=0).astype(np.float32)
                # re-scale with the right number of point per class
                C *= np.expand_dims(
                    class_proportions / (np.sum(C, axis=1) + 1e-6), 1)

                # compute IoU
                IoUs = self._iou_from_confusions(C)
                m_IoU = np.mean(IoUs)
                s = '{:5.2f} | '.format(100 * m_IoU)
                for IoU in IoUs:
                    s += '{:5.2f} '.format(100 * IoU)
                print(s)

                # NOTE(review): int(np.ceil(new_min) % 1) is always 0, so this
                # branch always runs -- confirm the intended condition.
                if int(np.ceil(new_min) % 1) == 0:  # ???
                    print('re-project vote #{:d}'.format(int(
                        np.floor(new_min))))
                    proj_prob_list = []

                    for i in range(num_clouds):
                        proj_idx = self.dataset.test_set.val_proj[i]
                        probs = self.test_probs[i][proj_idx, :]
                        proj_prob_list += [probs]

                    # show vote results
                    print('confusion on full cloud')
                    confusion_list = []
                    for i in range(num_clouds):
                        preds = self.dataset.test_set.label_values[np.argmax(
                            proj_prob_list[i], axis=1)].astype(np.uint8)
                        labels = self.dataset.test_set.val_labels[i]
                        acc = np.sum(preds == labels) / len(labels)
                        print(self.dataset.test_set.input_names[i] + 'ACC:' +
                              str(acc))
                        confusion_list += [
                            confusion_matrix(
                                labels, preds,
                                self.dataset.test_set.label_values)
                        ]
                        # name = self.dataset.test_set.label_values + '.ply'
                        # write_ply(join(path, 'val_preds', name), [preds, labels], ['pred', 'label'])

                    # re-group confusions
                    C = np.sum(np.stack(confusion_list), axis=0)
                    IoUs = self._iou_from_confusions(C)
                    m_IoU = np.mean(IoUs)
                    s = '{:5.2f} | '.format(100 * m_IoU)
                    for IoU in IoUs:
                        s += '{:5.2f} '.format(100 * IoU)

                    print(s)
                    print('finished.')
                    return
            epoch += 1
            continue
        return
Exemplo n.º 18
0
def run(model: BaseModel, dataset: BaseDataset, device, cfg):
    """Evaluate pairwise registration with timing: forward every test pair,
    compute matching/registration metrics (including registration recall),
    and dump the per-pair results to a timestamped CSV.
    """

    # Fall back to a default recall threshold when none is configured.
    reg_thresh = cfg.data.registration_recall_thresh
    if reg_thresh is None:
        reg_thresh = 0.2
    print(time.strftime("%Y%m%d-%H%M%S"))
    dataset.create_dataloaders(
        model, 1, False, cfg.training.num_workers, False,
    )
    loader = dataset.test_dataloaders[0]
    list_res = []
    with Ctq(loader) as tq_test_loader:
        for i, data in enumerate(tq_test_loader):
            with torch.no_grad():
                t0 = time.time()
                model.set_input(data, device)
                model.forward()
                t1 = time.time()
                name_scene, name_pair_source, name_pair_target = dataset.test_dataset[0].get_name(i)
                input, input_target = model.get_input()
                xyz, xyz_target = input.pos, input_target.pos
                ind, ind_target = input.ind, input_target.ind
                # Row k of matches_gt is a (source index, target index) pair.
                matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
                feat, feat_target = model.get_output()
                # rand = voxel_selection(xyz, grid_size=0.06, min_points=cfg.data.min_points)
                # rand_target = voxel_selection(xyz_target, grid_size=0.06, min_points=cfg.data.min_points)

                # Random subsample of points used for metric computation.
                rand = torch.randperm(len(feat))[: cfg.data.num_points]
                rand_target = torch.randperm(len(feat_target))[: cfg.data.num_points]
                res = dict(name_scene=name_scene, name_pair_source=name_pair_source, name_pair_target=name_pair_target)
                # Ground-truth rigid transform estimated from known matches.
                T_gt = estimate_transfo(xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]])
                t2 = time.time()
                metric = compute_metrics(
                    xyz[rand],
                    xyz_target[rand_target],
                    feat[rand],
                    feat_target[rand_target],
                    T_gt,
                    sym=cfg.data.sym,
                    tau_1=cfg.data.tau_1,
                    tau_2=cfg.data.tau_2,
                    rot_thresh=cfg.data.rot_thresh,
                    trans_thresh=cfg.data.trans_thresh,
                    use_ransac=cfg.data.use_ransac,
                    ransac_thresh=cfg.data.first_subsampling,
                    use_teaser=cfg.data.use_teaser,
                    noise_bound_teaser=cfg.data.noise_bound_teaser,
                    xyz_gt=xyz[matches_gt[:, 0]],
                    xyz_target_gt=xyz_target[matches_gt[:, 1]],
                    registration_recall_thresh=reg_thresh,
                )
                res = dict(**res, **metric)
                # Wall-clock timings: feature extraction and metric prep.
                res["time_feature"] = t1 - t0
                res["time_feature_per_point"] = (t1 - t0) / (len(input.pos) + len(input_target.pos))
                res["time_prep"] = t2 - t1

                list_res.append(res)

    df = pd.DataFrame(list_res)
    output_path = os.path.join(cfg.training.checkpoint_dir, cfg.data.name, "matches")
    if not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
    df.to_csv(osp.join(output_path, "final_res_{}.csv".format(time.strftime("%Y%m%d-%H%M%S"))))
    print(df.groupby("name_scene").mean())