Example #1
0
    def test_votenet_backbones(self):
        """Run a KPConv-backbone VoteNet forward pass on the fixed ScanNet test data.

        Builds ``VoteNet(original=False, backbone="kpconv", ...)`` from the
        ``data/scannet-fixed`` object-detection config. It takes one training
        batch of size 2, applies ``GridSampling3D(0.1)``, and runs ``forward``.
        It then checks that the model exposes a ``loss`` attribute and that
        every expected attribute of ``model.output`` has the expected shape.
        """
        from torch_points3d.applications.votenet import VoteNet

        cfg = OmegaConf.load(
            os.path.join(DIR_PATH,
                         "data/scannet-fixed/config_object_detection.yaml"))
        config_data = cfg.data
        config_data.is_test = True
        dataset = ScannetDataset(config_data)
        model = VoteNet(
            original=False,
            backbone="kpconv",
            input_nc=dataset.feature_dimension,
            num_classes=dataset.num_classes,
            mean_size_arr=dataset.mean_size_arr,
            compute_loss=True,
            in_feat=4,
        )

        dataset.create_dataloaders(model,
                                   batch_size=2,
                                   shuffle=True,
                                   num_workers=0,
                                   precompute_multi_scale=False)

        train_loader = dataset.train_dataloader
        data = next(iter(train_loader))
        # Apply the GridSampling3D(0.1) transform to the collated batch before
        # the forward pass; presumably required by the KPConv backbone — confirm.
        data = GridSampling3D(0.1)(data)
        model.verify_data(data)
        model.forward(data)

        self.assertTrue(hasattr(model, "loss"),
                        "model has no `loss` attribute after forward")

        # Expected shapes for a batch of 2. NOTE(review): seed_inds is flat
        # here ([2048], presumably 2 x 1024 seeds), unlike the [2, 1024] layout
        # of the original VoteNet; the 18 in the size_* entries presumably
        # comes from dataset.mean_size_arr — confirm both.
        attrs_test = {
            "center": [2, 256, 3],
            "heading_residuals": [2, 256, 1],
            "heading_residuals_normalized": [2, 256, 1],
            "heading_scores": [2, 256, 1],
            "object_assignment": [2, 256],
            "objectness_label": [2, 256],
            "objectness_mask": [2, 256],
            "objectness_scores": [2, 256, 2],
            "sampled_votes": [2, 256, 3],
            "seed_inds": [2048],
            "seed_pos": [2, 1024, 3],
            "seed_votes": [2, 1024, 3],
            "sem_cls_scores": [2, 256, 20],
            "size_residuals_normalized": [2, 256, 18, 3],
            "size_scores": [2, 256, 18],
        }

        output = model.output
        for key, expected_shape in attrs_test.items():
            # Name the offending key in failure messages so a shape regression
            # is immediately attributable.
            self.assertTrue(hasattr(output, key),
                            f"model.output is missing attribute {key!r}")
            self.assertEqual(getattr(output, key).shape,
                             torch.Size(expected_shape),
                             f"unexpected shape for model.output.{key}")
Example #2
0
    def test_votenet_paper(self):
        """Run the original (paper) VoteNet forward pass on the fixed ScanNet test data.

        Builds ``VoteNet(original=True, ...)`` from the ``data/scannet-fixed``
        object-detection config. It runs ``forward`` on one training batch of
        size 2, then checks that the model exposes a ``loss`` attribute and
        that every expected attribute of ``model.output`` has the expected
        shape.
        """
        from torch_points3d.applications.votenet import VoteNet

        current_dir = os.path.dirname(os.path.realpath(__file__))
        cfg = OmegaConf.load(
            os.path.join(current_dir,
                         "data/scannet-fixed/config_object_detection.yaml"))
        config_data = cfg.data
        config_data.is_test = True
        dataset = ScannetDataset(config_data)
        model = VoteNet(original=True,
                        input_nc=dataset.feature_dimension,
                        num_classes=dataset.num_classes,
                        compute_loss=True)

        dataset.create_dataloaders(model,
                                   batch_size=2,
                                   shuffle=True,
                                   num_workers=0,
                                   precompute_multi_scale=False)

        train_loader = dataset.train_dataloader
        data = next(iter(train_loader))
        model.verify_data(data)
        model.forward(data)

        self.assertTrue(hasattr(model, "loss"),
                        "model has no `loss` attribute after forward")

        # Expected shapes for a batch of 2. NOTE(review): the 0 in the size_*
        # entries presumably follows from not passing mean_size_arr to VoteNet
        # above (no size clusters) — confirm.
        attrs_test = {
            "center": [2, 256, 3],
            "heading_residuals": [2, 256, 1],
            "heading_residuals_normalized": [2, 256, 1],
            "heading_scores": [2, 256, 1],
            "object_assignment": [2, 256],
            "objectness_label": [2, 256],
            "objectness_mask": [2, 256],
            "objectness_scores": [2, 256, 2],
            "sampled_votes": [2, 256, 3],
            "seed_inds": [2, 1024],
            "seed_pos": [2, 1024, 3],
            "seed_votes": [2, 1024, 3],
            "sem_cls_scores": [2, 256, 20],
            "size_residuals_normalized": [2, 256, 0, 3],
            "size_scores": [2, 256, 0],
        }

        output = model.output
        for key, expected_shape in attrs_test.items():
            # Name the offending key in failure messages so a shape regression
            # is immediately attributable.
            self.assertTrue(hasattr(output, key),
                            f"model.output is missing attribute {key!r}")
            self.assertEqual(getattr(output, key).shape,
                             torch.Size(expected_shape),
                             f"unexpected shape for model.output.{key}")