Example #1
# Assumed imports: the standard-library/PyTorch ones are exact; the project-specific
# paths (ConfusionMatrix, APMeter, BaseTracker, BaseModel, IGNORE_LABEL) are illustrative
# torch-points3d-style locations and should be adapted to the surrounding project layout.
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F

from torch_points3d.metrics.confusion_matrix import ConfusionMatrix  # assumed path
from torch_points3d.metrics.meters import APMeter  # assumed path
from torch_points3d.metrics.base_tracker import BaseTracker  # assumed path
from torch_points3d.models.base_model import BaseModel  # assumed path
from torch_points3d.datasets.segmentation import IGNORE_LABEL  # assumed path


class SegmentationTracker(BaseTracker):
    def __init__(self,
                 dataset,
                 stage="train",
                 wandb_log=False,
                 use_tensorboard: bool = False,
                 ignore_label=IGNORE_LABEL):
        """ This is a generic tracker for segmentation tasks.
        It uses a confusion matrix in the back-end to track results.
        Use the tracker to track an epoch.
        You can use the reset function before you start a new epoch

        Arguments:
            dataset  -- dataset to track (used for the number of classes)

        Keyword Arguments:
            stage {str} -- current stage. (train, validation, test, etc...) (default: {"train"})
            wandb_log {str} --  Log using weight and biases
        """
        super().__init__(stage, wandb_log, use_tensorboard)
        self._num_classes = dataset.num_classes
        self._ignore_label = ignore_label
        self._dataset = dataset
        self.reset(stage)

    def reset(self, stage="train"):
        super().reset(stage=stage)
        self._confusion_matrix = ConfusionMatrix(self._num_classes)
        self._ap_meter = APMeter()
        # Zero the metrics so get_metrics() is safe before the first tracked batch
        self._acc = self._macc = self._miou = self._map = 0

    @staticmethod
    def detach_tensor(tensor):
        # Detach torch tensors from the autograd graph; pass other inputs (e.g. numpy arrays) through unchanged
        if torch.is_tensor(tensor):
            tensor = tensor.detach()
        return tensor

    @property
    def confusion_matrix(self):
        return self._confusion_matrix.confusion_matrix

    def track(self, model: BaseModel, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking
        """
        super().track(model)

        outputs = model.get_output()
        targets = model.get_labels()

        # Mask ignored label
        mask = targets != self._ignore_label
        outputs = outputs[mask]
        targets = targets[mask]

        outputs = SegmentationTracker.detach_tensor(outputs)
        targets = SegmentationTracker.detach_tensor(targets)
        if not torch.is_tensor(targets):
            targets = torch.from_numpy(targets)
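        # APMeter expects raw per-class scores and boolean one-hot encoded targets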
        self._ap_meter.add(outputs,
                           F.one_hot(targets, self._num_classes).bool())

        outputs = self._convert(outputs)
        targets = self._convert(targets)

        if len(targets) == 0:
            return

        assert outputs.shape[0] == len(targets)
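        # Hard predictions via argmax over the class dimension feed the confusion matrix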
        self._confusion_matrix.count_predicted_batch(targets,
                                                     np.argmax(outputs, 1))

        self._acc = 100 * self._confusion_matrix.get_overall_accuracy()
        self._macc = 100 * self._confusion_matrix.get_mean_class_accuracy()
        self._miou = 100 * self._confusion_matrix.get_average_intersection_union()
        self._map = 100 * self._ap_meter.value().mean().item()

    def get_metrics(self, verbose=False) -> Dict[str, float]:
        """ Returns a dictionnary of all metrics and losses being tracked
        """
        metrics = super().get_metrics(verbose)

        metrics["{}_acc".format(self._stage)] = self._acc
        metrics["{}_macc".format(self._stage)] = self._macc
        metrics["{}_miou".format(self._stage)] = self._miou
        metrics["{}_map".format(self._stage)] = self._map
        return metrics
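
Below is a minimal usage sketch for one training run. It assumes a `dataset` exposing `num_classes`, a `model` implementing the `BaseModel` interface consumed above (`get_output` returning per-class scores, `get_labels` returning integer labels), plus an illustrative `train_loader`, `num_epochs`, and `run_one_batch` step; those names are assumptions made for the example, not part of the tracker itself.

# Illustrative usage; `dataset`, `model`, `train_loader`, `num_epochs` and
# `run_one_batch` are assumed to exist in the surrounding training code.
tracker = SegmentationTracker(dataset, stage="train", wandb_log=False)

for epoch in range(num_epochs):
    tracker.reset(stage="train")       # fresh confusion matrix and AP meter each epoch
    for batch in train_loader:
        run_one_batch(model, batch)    # hypothetical step that refreshes model outputs/labels
        tracker.track(model)           # accumulate accuracy, mIoU and mAP for this batch
    print(tracker.get_metrics())       # e.g. {"train_acc": ..., "train_macc": ..., "train_miou": ..., "train_map": ...}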