def reset(self, stage="train"):
    """Reset the parent tracker for ``stage`` and start new, empty accumulators.

    Arguments:
        stage {str} -- stage being (re)started (default: {"train"})
    """
    super().reset(stage=stage)
    # Both accumulators are rebuilt from scratch, so nothing tracked before
    # this call contributes to later metrics.
    n_classes = self._num_classes
    self._confusion_matrix = ConfusionMatrix(n_classes)
    self._ap_meter = APMeter()
class SegmentationTracker(BaseTracker):
    def __init__(
        self,
        dataset,
        stage="train",
        wandb_log=False,
        use_tensorboard: bool = False,
        ignore_label=IGNORE_LABEL,
    ):
        """Generic tracker for segmentation tasks.

        Predictions are accumulated in a confusion matrix (overall accuracy,
        mean class accuracy, mean IoU) and an ``APMeter`` (mean average
        precision). Use the tracker to track an epoch and call ``reset``
        before starting a new one.

        Arguments:
            dataset -- dataset to track; its ``num_classes`` sizes the accumulators

        Keyword Arguments:
            stage {str} -- current stage. (train, validation, test, etc...) (default: {"train"})
            wandb_log {bool} -- log using Weights & Biases; forwarded to the base tracker (default: {False})
            use_tensorboard {bool} -- log using TensorBoard; forwarded to the base tracker (default: {False})
            ignore_label -- target value whose points are excluded from every metric
                (default: {IGNORE_LABEL})
        """
        super(SegmentationTracker, self).__init__(stage, wandb_log, use_tensorboard)
        self._num_classes = dataset.num_classes
        self._ignore_label = ignore_label
        self._dataset = dataset
        self.reset(stage)

    def reset(self, stage="train"):
        """Start a new tracking period for ``stage``, discarding everything accumulated so far."""
        super().reset(stage=stage)
        self._confusion_matrix = ConfusionMatrix(self._num_classes)
        self._ap_meter = APMeter()
        # Metrics are only recomputed by track() on a non-empty batch; clear them
        # here so get_metrics() neither raises AttributeError nor reports stale
        # values from the previous period when no valid point has been seen yet.
        self._acc = 0.0
        self._macc = 0.0
        self._miou = 0.0
        self._map = 0.0

    @staticmethod
    def detach_tensor(tensor):
        """Return ``tensor`` detached from autograd; non-tensor inputs are returned unchanged."""
        if torch.is_tensor(tensor):
            tensor = tensor.detach()
        return tensor

    @property
    def confusion_matrix(self):
        """Raw matrix held by the underlying ``ConfusionMatrix`` accumulator."""
        return self._confusion_matrix.confusion_matrix

    def track(self, model: BaseModel, **kwargs):
        """Add current model predictions (usually the result of a batch) to the tracking.

        Points whose target equals ``ignore_label`` are dropped first. A batch
        with no remaining points leaves all accumulators and metrics untouched.

        Arguments:
            model -- provides ``get_output()`` (per-point class scores, presumably
                shape (N, num_classes) -- TODO confirm) and ``get_labels()``
                (per-point integer targets)
        """
        super().track(model)

        # Detach before masking so the boolean indexing below does not build
        # an autograd graph that is thrown away immediately.
        outputs = SegmentationTracker.detach_tensor(model.get_output())
        targets = SegmentationTracker.detach_tensor(model.get_labels())

        # Mask ignored label
        mask = targets != self._ignore_label
        outputs = outputs[mask]
        targets = targets[mask]
        if len(targets) == 0:
            return

        if not torch.is_tensor(targets):
            targets = torch.from_numpy(targets)
        # F.one_hot only accepts int64 index tensors; cast so e.g. int32 labels work.
        self._ap_meter.add(outputs, F.one_hot(targets.long(), self._num_classes).bool())

        outputs = self._convert(outputs)
        targets = self._convert(targets)
        assert outputs.shape[0] == len(targets)
        self._confusion_matrix.count_predicted_batch(targets, np.argmax(outputs, 1))

        # Metrics are cumulative over everything tracked since the last reset().
        self._acc = 100 * self._confusion_matrix.get_overall_accuracy()
        self._macc = 100 * self._confusion_matrix.get_mean_class_accuracy()
        self._miou = 100 * self._confusion_matrix.get_average_intersection_union()
        self._map = 100 * self._ap_meter.value().mean().item()

    def get_metrics(self, verbose=False) -> Dict[str, float]:
        """Returns a dictionary of all metrics and losses being tracked.

        Adds ``<stage>_acc``, ``<stage>_macc``, ``<stage>_miou`` and
        ``<stage>_map`` (all as percentages) to the base tracker's metrics.
        """
        metrics = super().get_metrics(verbose)

        metrics["{}_acc".format(self._stage)] = self._acc
        metrics["{}_macc".format(self._stage)] = self._macc
        metrics["{}_miou".format(self._stage)] = self._miou
        metrics["{}_map".format(self._stage)] = self._map
        return metrics