Example #1
0
    def finalise(self,
                 full_res=False,
                 vote_miou=True,
                 ply_output="",
                 **kwargs):
        """ Finalise the tracked epoch.

        Always stores the per-class IoU of the low resolution confusion matrix in
        ``self._iou_per_class``. When votes were accumulated on ``self._test_area``,
        it can additionally compute the voted mIoU, the full resolution mIoU and
        export the voted predictions to a ply file.

        Keyword Arguments:
            full_res {bool} -- compute the mIoU on every point of the test area
                (default: {False})
            vote_miou {bool} -- compute the mIoU over the points that received at
                least one vote (default: {True})
            ply_output {str} -- if non-empty, path of the ply file written with the
                voted predictions (default: {""})
        """
        per_class_iou = self._confusion_matrix.get_intersection_union_per_class(
        )[0]
        self._iou_per_class = {
            self._dataset.INV_OBJECT_LABEL[k]: v
            for k, v in enumerate(per_class_iou)
        }

        # No votes were collected (e.g. training epoch or low resolution tracking):
        # the voted / full resolution metrics and the ply export have nothing to use.
        if self._test_area is None:
            return

        if vote_miou:
            # Complete for points that have a prediction
            self._test_area = self._test_area.to("cpu")
            c = ConfusionMatrix(self._num_classes)
            has_prediction = self._test_area.prediction_count > 0
            gt = self._test_area.y[has_prediction].numpy()
            pred = torch.argmax(self._test_area.votes[has_prediction],
                                1).numpy()
            c.count_predicted_batch(gt, pred)
            self._vote_miou = c.get_average_intersection_union() * 100

        if full_res:
            self._compute_full_miou()

        if ply_output:
            has_prediction = self._test_area.prediction_count > 0
            self._dataset.to_ply(
                self._test_area.pos[has_prediction].cpu(),
                torch.argmax(self._test_area.votes[has_prediction],
                             1).cpu().numpy(),
                ply_output,
            )
Example #2
0
    def _compute_full_miou(self):
        """ Compute the mIoU on every point of the test area.

        Points without a vote receive the prediction of their nearest voted point
        (``knn_interpolate`` with ``k=1``). The result is cached in
        ``self._full_vote_miou``; nothing is done when it is already set or when no
        votes were collected.
        """
        if self._full_vote_miou is not None or self._test_area is None:
            return

        # Move to CPU *before* building the mask: the votes may have been accumulated
        # on the GPU and a GPU mask cannot index the CPU tensors used below.
        self._test_area = self._test_area.to("cpu")
        has_prediction = self._test_area.prediction_count > 0

        num_points = has_prediction.shape[0]
        coverage = 100.0 * has_prediction.sum().item() / num_points if num_points else 0.0
        log.info(
            "Computing full res mIoU, we have predictions for %.2f%% of the points.",
            coverage)

        # Full res interpolation
        full_pred = knn_interpolate(
            self._test_area.votes[has_prediction],
            self._test_area.pos[has_prediction],
            self._test_area.pos,
            k=1,
        )

        # Full res pred
        c = ConfusionMatrix(self._num_classes)
        c.count_predicted_batch(self._test_area.y.numpy(),
                                torch.argmax(full_pred, 1).numpy())
        self._full_vote_miou = c.get_average_intersection_union() * 100
Example #3
0
class ScannetSegmentationTracker(SegmentationTracker):
    """ Segmentation tracker for ScanNet with optional full resolution evaluation.

    When ``track`` is called with ``full_res=True`` outside of training and with the
    batch passed as ``data=...``, the model scores are accumulated as votes on the raw
    scans returned by ``self._dataset.get_raw_data``. ``finalise`` turns those votes
    into full resolution predictions (1-nearest-neighbour interpolation for points
    without a vote), computes full resolution metrics when the stage has labels and can
    write a submission for the test stage.
    """

    def reset(self, stage="train"):
        super().reset(stage=stage)
        self._full_confusion_matrix = ConfusionMatrix(self._num_classes)
        self._raw_datas = {}  # int scan id -> raw (full resolution) data
        self._votes = {}  # int scan id -> summed scores, shape [num_raw_points, num_classes]
        self._vote_counts = {}  # int scan id -> number of votes per raw point
        self._full_preds = {}  # int scan id -> predicted label per raw point
        self._full_acc = None
        self._full_macc = None
        self._full_miou = None

    def track(self, model: model_interface.TrackerInterface, full_res=False, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking

        Votes for the full resolution evaluation are only collected when ``full_res`` is
        True, the stage is not ``"train"`` and the batch is given as ``data=...``.
        """
        super().track(model)

        # Set conv type
        self._conv_type = model.conv_type

        # Train mode or low res, nothing special to do
        data = kwargs.get("data")
        if not full_res or self._stage == "train" or data is None:
            return

        self._vote(data, model.get_output())

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns a dictionary of all metrics and losses being tracked
        """
        metrics = super().get_metrics(verbose)
        # Compare with None: an accuracy of 0.0 is a legitimate value to report
        if self._full_acc is not None:
            metrics["{}_full_acc".format(self._stage)] = self._full_acc
            metrics["{}_full_macc".format(self._stage)] = self._full_macc
            metrics["{}_full_miou".format(self._stage)] = self._full_miou
        return metrics

    def finalise(self, full_res=False, make_submission=False, **kwargs):
        """ Compute full resolution predictions and, when possible, metrics / submission.

        Keyword Arguments:
            full_res {bool} -- compute full resolution predictions and metrics
            make_submission {bool} -- write submission files (only for the "test" stage)
        """
        if not full_res and not make_submission:
            return

        self._predict_full_res()

        # Compute full res metrics
        if self._dataset.has_labels(self._stage):
            # Start from an empty matrix so that calling finalise twice does not count
            # every scan twice.
            self._full_confusion_matrix = ConfusionMatrix(self._num_classes)
            for scan_id, full_pred in self._full_preds.items():
                full_labels = self._raw_datas[scan_id].y
                # Mask ignored labels
                mask = full_labels != self._ignore_label
                self._full_confusion_matrix.count_predicted_batch(
                    self._convert(full_labels[mask]), self._convert(full_pred.cpu()[mask])
                )

            self._full_acc = 100 * self._full_confusion_matrix.get_overall_accuracy()
            self._full_macc = 100 * self._full_confusion_matrix.get_mean_class_accuracy()
            self._full_miou = 100 * self._full_confusion_matrix.get_average_intersection_union()

        # Save files to disk
        if make_submission and self._stage == "test":
            self._make_submission()

    def _make_submission(self):
        """ Write one ``<scan_name>.txt`` file per scan, one predicted label per line. """
        import os

        original_class_ids = np.asarray(self._dataset.train_dataset.valid_class_idx)
        path_to_submission = self._dataset.path_to_submission
        os.makedirs(path_to_submission, exist_ok=True)
        for scan_id, full_pred in self._full_preds.items():
            # remap labels to original labels between 0 and 40; index with the native
            # integer type (int8 would overflow for more than 127 classes)
            labels = original_class_ids[full_pred.cpu().numpy()]
            scan_name = self._raw_datas[scan_id].scan_name
            path_file = osp.join(path_to_submission, "{}.txt".format(scan_name))
            np.savetxt(path_file, labels, delimiter="\n", fmt="%d")

    def _vote(self, data, output):
        """ Populates scores for the points in data

        Parameters
        ----------
        data : Data
            should contain `id_scan`, `SaveOriginalPosId.KEY` and (non dense) `batch` keys
        output : torch.Tensor
            probabilities out of the model, shape: [N,nb_classes]
        """
        # reshape(-1) rather than squeeze(): squeeze turns a batch holding a single scan
        # into a 0-d tensor, which cannot be iterated.
        id_scans = data.id_scan.reshape(-1)
        if self._conv_type == "DENSE":
            batch_size = len(id_scans)
            output = output.view(batch_size, -1, output.shape[-1])

        for idx_batch, id_scan_tensor in enumerate(id_scans):
            # Tensors hash by identity: use a plain int so that the same scan always
            # maps to the same dictionary entry across batches.
            id_scan = int(id_scan_tensor)

            # First time we see this scan
            if id_scan not in self._raw_datas:
                raw_data = self._dataset.get_raw_data(self._stage, id_scan, remap_labels=True)
                num_raw_points = raw_data.pos.shape[0]
                self._raw_datas[id_scan] = raw_data
                self._vote_counts[id_scan] = torch.zeros(num_raw_points, dtype=torch.int)
                self._votes[id_scan] = torch.zeros((num_raw_points, self._num_classes), dtype=torch.float)

            batch_mask = idx_batch
            if self._conv_type != "DENSE":
                batch_mask = data.batch == idx_batch
            idx = data[SaveOriginalPosId.KEY][batch_mask].cpu().long()

            votes = self._votes[id_scan]
            counts = self._vote_counts[id_scan]
            # index_add_ accumulates repeated ids, whereas `votes[idx] += ...` would keep
            # only one of the duplicated updates.
            votes.index_add_(0, idx, output[batch_mask].detach().cpu().to(votes.dtype))
            counts.index_add_(0, idx, torch.ones_like(idx, dtype=counts.dtype))

    def _predict_full_res(self):
        """ Predict full resolution results based on votes """
        for id_scan, votes in self._votes.items():
            counts = self._vote_counts[id_scan]
            has_prediction = counts > 0
            # Average the votes without modifying the accumulators, so this is safe to
            # call more than once.
            mean_votes = votes[has_prediction] / counts[has_prediction].unsqueeze(-1)

            # Upsample and predict
            raw_pos = self._raw_datas[id_scan].pos
            full_pred = knn_interpolate(
                mean_votes,
                raw_pos[has_prediction],
                raw_pos,
                k=1,
            )
            self._full_preds[id_scan] = full_pred.argmax(-1)
Example #4
0
class SegmentationTracker(BaseTracker):
    def __init__(
        self, dataset, stage="train", wandb_log=False, use_tensorboard: bool = False, ignore_label: int = IGNORE_LABEL
    ):
        """ This is a generic tracker for segmentation tasks.
        It uses a confusion matrix in the back-end to track results.
        Use the tracker to track an epoch.
        You can use the reset function before you start a new epoch

        Arguments:
            dataset  -- dataset to track (used for the number of classes)

        Keyword Arguments:
            stage {str} -- current stage. (train, validation, test, etc...) (default: {"train"})
            wandb_log {bool} --  Log using weight and biases (default: {False})
            use_tensorboard {bool} -- forwarded to BaseTracker, presumably enables
                tensorboard logging (default: {False})
            ignore_label {int} -- targets equal to this label are excluded from the
                metrics (default: {IGNORE_LABEL})
        """
        super(SegmentationTracker, self).__init__(stage, wandb_log, use_tensorboard)
        self._num_classes = dataset.num_classes
        self._ignore_label = ignore_label
        self._dataset = dataset
        self.reset(stage)

    def reset(self, stage="train"):
        """ Start tracking a new epoch for ``stage`` with an empty confusion matrix. """
        super().reset(stage=stage)
        self._confusion_matrix = ConfusionMatrix(self._num_classes)
        self._acc = 0
        self._macc = 0
        self._miou = 0

    @staticmethod
    def detach_tensor(tensor):
        """ Return ``tensor.detach()`` for torch tensors, the input unchanged otherwise. """
        if torch.is_tensor(tensor):
            tensor = tensor.detach()
        return tensor

    @property
    def confusion_matrix(self):
        return self._confusion_matrix.confusion_matrix

    def track(self, model: model_interface.TrackerInterface, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking

        Nothing is tracked when the current stage has no labels. Targets equal to
        ``ignore_label`` are masked out before updating the confusion matrix.
        """
        if not self._dataset.has_labels(self._stage):
            return

        super().track(model)

        outputs = model.get_output()
        targets = model.get_labels()

        # Mask ignored label
        mask = targets != self._ignore_label
        outputs = outputs[mask]
        targets = targets[mask]

        outputs = self._convert(outputs)
        targets = self._convert(targets)

        # Every target of the batch was ignored: nothing to count
        if len(targets) == 0:
            return

        assert outputs.shape[0] == len(targets)
        self._confusion_matrix.count_predicted_batch(targets, np.argmax(outputs, 1))

        # Running metrics over everything counted since the last reset
        self._acc = 100 * self._confusion_matrix.get_overall_accuracy()
        self._macc = 100 * self._confusion_matrix.get_mean_class_accuracy()
        self._miou = 100 * self._confusion_matrix.get_average_intersection_union()

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns a dictionary of all metrics and losses being tracked
        """
        metrics = super().get_metrics(verbose)

        metrics["{}_acc".format(self._stage)] = self._acc
        metrics["{}_macc".format(self._stage)] = self._macc
        metrics["{}_miou".format(self._stage)] = self._miou
        return metrics

    @property
    def metric_func(self):
        """ Map each metric name to the function selecting its best value (max or min). """
        self._metric_func = {
            "miou": max,
            "macc": max,
            "acc": max,
            "loss": min,
            "map": max,
        }  # Those map subsentences to their optimization functions
        return self._metric_func
Example #5
0
class S3DISTracker(SegmentationTracker):
    """ Segmentation tracker for S3DIS that can also evaluate on the full test area.

    When ``track`` is called with ``full_res=True`` outside of training, the scores
    predicted for each sampled point are accumulated as votes on a copy of
    ``self._dataset.test_data``. ``finalise`` then computes the mIoU over the voted
    points and, optionally, over every point using nearest-neighbour interpolation.
    """

    def reset(self, *args, **kwargs):
        super().reset(*args, **kwargs)
        self._test_area = None  # copy of the test area holding the votes, created lazily in track()
        self._full_vote_miou = None
        self._vote_miou = None
        self._full_confusion = None
        self._iou_per_class = {}

    def track(self, model: model_interface.TrackerInterface, full_res=False, data=None, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking

        Keyword Arguments:
            full_res {bool} -- accumulate votes on the full test area (ignored in training)
            data -- batch that produced the outputs; ``model.get_input()`` when None

        Raises:
            ValueError -- the test area has no labels, the inputs have no original
                point ids, or those ids exceed the size of the test area
        """
        super().track(model)

        # Train mode or low res, nothing special to do
        if self._stage == "train" or not full_res:
            return

        # Test mode, compute votes in order to get full res predictions
        if self._test_area is None:
            self._test_area = self._dataset.test_data.clone()
            if self._test_area.y is None:
                raise ValueError("It seems that the test area data does not have labels (attribute y).")
            num_points = self._test_area.y.shape[0]
            self._test_area.prediction_count = torch.zeros(num_points, dtype=torch.int)
            self._test_area.votes = torch.zeros((num_points, self._num_classes), dtype=torch.float)
            self._test_area = self._test_area.to(model.device)

        # Gather origin ids and check that it fits with the test set
        inputs = data if data is not None else model.get_input()
        if inputs[SaveOriginalPosId.KEY] is None:
            raise ValueError("The inputs given to the model do not have a %s attribute." % SaveOriginalPosId.KEY)

        originids = inputs[SaveOriginalPosId.KEY].reshape(-1)
        if originids.numel() == 0:
            return
        if originids.max() >= self._test_area.pos.shape[0]:
            raise ValueError("Origin ids are larger than the number of points in the original point cloud.")

        # Set predictions. index_add_ accumulates ids that appear several times in the
        # batch (overlapping samples); `votes[ids] += outputs` keeps only one update.
        votes = self._test_area.votes
        counts = self._test_area.prediction_count
        originids = originids.to(votes.device).long()
        outputs = model.get_output().detach().to(device=votes.device, dtype=votes.dtype)
        votes.index_add_(0, originids, outputs)
        counts.index_add_(0, originids, torch.ones_like(originids, dtype=counts.dtype))

    def finalise(self, full_res=False, vote_miou=True, ply_output="", **kwargs):
        """ Finalise the tracked epoch.

        Keyword Arguments:
            full_res {bool} -- compute the mIoU on every point of the test area (default: {False})
            vote_miou {bool} -- compute the mIoU over the voted points (default: {True})
            ply_output {str} -- if non-empty, path of the ply file written with the voted
                predictions (default: {""})
        """
        per_class_iou = self._confusion_matrix.get_intersection_union_per_class()[0]
        self._iou_per_class = {self._dataset.INV_OBJECT_LABEL[k]: v for k, v in enumerate(per_class_iou)}

        # No votes were collected (training epoch or low resolution tracking)
        if self._test_area is None:
            return

        if vote_miou:
            # Complete for points that have a prediction
            self._test_area = self._test_area.to("cpu")
            c = ConfusionMatrix(self._num_classes)
            has_prediction = self._test_area.prediction_count > 0
            gt = self._test_area.y[has_prediction].numpy()
            pred = torch.argmax(self._test_area.votes[has_prediction], 1).numpy()
            c.count_predicted_batch(gt, pred)
            self._vote_miou = c.get_average_intersection_union() * 100

        if full_res:
            self._compute_full_miou()

        if ply_output:
            has_prediction = self._test_area.prediction_count > 0
            self._dataset.to_ply(
                self._test_area.pos[has_prediction].cpu(),
                torch.argmax(self._test_area.votes[has_prediction], 1).cpu().numpy(),
                ply_output,
            )

    def _compute_full_miou(self):
        """ Compute (once) the mIoU on every point of the test area via 1-NN interpolation. """
        if self._full_vote_miou is not None or self._test_area is None:
            return

        # Move to CPU before building the mask so the mask and the indexed tensors are
        # on the same device (votes may have been accumulated on the GPU).
        self._test_area = self._test_area.to("cpu")
        has_prediction = self._test_area.prediction_count > 0

        num_points = has_prediction.shape[0]
        coverage = 100.0 * has_prediction.sum().item() / num_points if num_points else 0.0
        log.info("Computing full res mIoU, we have predictions for %.2f%% of the points.", coverage)

        # Full res interpolation
        full_pred = knn_interpolate(
            self._test_area.votes[has_prediction], self._test_area.pos[has_prediction], self._test_area.pos, k=1,
        )

        # Full res pred
        self._full_confusion = ConfusionMatrix(self._num_classes)
        self._full_confusion.count_predicted_batch(self._test_area.y.numpy(), torch.argmax(full_pred, 1).numpy())
        self._full_vote_miou = self._full_confusion.get_average_intersection_union() * 100

    @property
    def full_confusion_matrix(self):
        return self._full_confusion

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns a dictionary of all metrics and losses being tracked
        """
        metrics = super().get_metrics(verbose)

        if verbose:
            metrics["{}_iou_per_class".format(self._stage)] = self._iou_per_class
            # Only report values that were actually computed (0.0 is a valid value)
            if self._full_vote_miou is not None:
                metrics["{}_full_vote_miou".format(self._stage)] = self._full_vote_miou
            if self._vote_miou is not None:
                metrics["{}_vote_miou".format(self._stage)] = self._vote_miou
        return metrics