示例#1
0
 def _reset_stats(self):
     if self._version == "tnt":
         self.confusion_matrix = ConfusionMeter(self.num_classes)
     elif self._version == "sklearn":
         self.outputs = []
         self.targets = []
示例#2
0
class ConfusionMatrixCallback(Callback):
    """Log a confusion-matrix image at the end of every loader.

    For every batch, the callback collects model outputs from
    ``state.output[output_key]`` and targets from ``state.input[input_key]``.
    When the loader ends, it computes the confusion matrix, renders it with
    ``utils.plot_confusion_matrix`` and writes it as an image to that
    loader's tensorboard logger under the tag ``f"{prefix}/epoch"``.

    Two accumulation backends are selected by ``version``:

    * ``"tnt"`` -- feeds each batch to a ``ConfusionMeter`` (presumably
      torchnet's meter; confirm against the file's imports).
    * ``"sklearn"`` -- keeps argmax-ed outputs and targets as CPU values and
      computes the matrix with ``confusion_matrix_fn`` (presumably
      ``sklearn.metrics.confusion_matrix``).
    """

    def __init__(self,
                 input_key: str = "targets",
                 output_key: str = "logits",
                 prefix: str = "confusion_matrix",
                 version: str = "tnt",
                 class_names: List[str] = None,
                 num_classes: int = None,
                 plot_params: Dict = None):
        """
        Args:
            input_key: key in ``state.input`` holding the targets.
            output_key: key in ``state.output`` holding the model outputs.
                For ``"sklearn"`` they are reduced with ``argmax`` over
                axis 1, so per-class scores are expected there.
            prefix: prefix of the tensorboard image tag.
            version: accumulation backend, ``"tnt"`` or ``"sklearn"``.
            class_names: optional axis labels; when given,
                ``len(class_names)`` is used as the number of classes and
                ``num_classes`` is ignored.
            num_classes: number of classes; required when ``class_names``
                is not given.
            plot_params: extra keyword arguments forwarded to
                ``utils.plot_confusion_matrix``. They may override the
                ``normalize=True`` / ``show=False`` defaults.

        Raises:
            ValueError: if ``version`` is not supported, or if neither
                ``class_names`` nor ``num_classes`` is provided.
        """
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key

        # Explicit raise instead of ``assert``: asserts vanish under -O.
        if version not in ("tnt", "sklearn"):
            raise ValueError(
                f"version must be 'tnt' or 'sklearn', got {version!r}")
        self._version = version
        self._plot_params = plot_params or {}

        self.class_names = class_names
        self.num_classes = num_classes \
            if class_names is None \
            else len(class_names)

        if self.num_classes is None:
            raise ValueError(
                "either class_names or num_classes must be provided")
        self._reset_stats()

    def _reset_stats(self):
        """Drop everything accumulated for the previous loader."""
        if self._version == "tnt":
            self.confusion_matrix = ConfusionMeter(self.num_classes)
        elif self._version == "sklearn":
            self.outputs = []
            self.targets = []

    def _add_to_stats(self, outputs, targets):
        """Accumulate one batch of model outputs and targets.

        For ``"sklearn"`` both are moved to CPU numpy arrays and the outputs
        are converted to predicted class indices via ``argmax`` on axis 1.
        """
        if self._version == "tnt":
            self.confusion_matrix.add(predicted=outputs, target=targets)
        elif self._version == "sklearn":
            outputs = outputs.cpu().numpy()
            targets = targets.cpu().numpy()

            outputs = np.argmax(outputs, axis=1)

            self.outputs.extend(outputs)
            self.targets.extend(targets)

    def _compute_confusion_matrix(self):
        """Return the confusion matrix for the data accumulated so far.

        Raises:
            NotImplementedError: for an unsupported ``_version``.
        """
        if self._version == "tnt":
            confusion_matrix = self.confusion_matrix.value()
        elif self._version == "sklearn":
            # Pin the label set so the matrix is always
            # num_classes x num_classes: without it, classes absent from
            # this loader are dropped and rows/columns no longer line up
            # with ``class_names`` used for the plot axes.
            confusion_matrix = confusion_matrix_fn(
                y_true=self.targets,
                y_pred=self.outputs,
                labels=list(range(self.num_classes)))
        else:
            raise NotImplementedError()
        return confusion_matrix

    def _plot_confusion_matrix(self,
                               logger,
                               epoch,
                               confusion_matrix,
                               class_names=None):
        """Render ``confusion_matrix`` and add it to ``logger`` as an image.

        The image tag is ``f"{self.prefix}/epoch"`` and ``epoch`` is used
        as the global step.
        """
        # Defaults first so user-supplied ``plot_params`` override them
        # instead of raising "got multiple values for keyword argument".
        plot_kwargs = {"normalize": True, "show": False, **self._plot_params}
        fig = utils.plot_confusion_matrix(confusion_matrix,
                                          class_names=class_names,
                                          **plot_kwargs)
        fig = utils.render_figure_to_tensor(fig)
        logger.add_image(f"{self.prefix}/epoch", fig, global_step=epoch)

    def on_loader_start(self, state: RunnerState):
        """Start each loader with empty statistics."""
        self._reset_stats()

    def on_batch_end(self, state: RunnerState):
        """Accumulate this batch's (detached) outputs and targets."""
        self._add_to_stats(state.output[self.output_key].detach(),
                           state.input[self.input_key].detach())

    def on_loader_end(self, state: RunnerState):
        """Compute the matrix and log it to this loader's tensorboard logger.

        Falls back to ``"0"``, ``"1"``, ... as axis labels when no
        ``class_names`` were given.
        """
        class_names = \
            self.class_names or \
            [str(i) for i in range(self.num_classes)]
        confusion_matrix = self._compute_confusion_matrix()
        self._plot_confusion_matrix(
            logger=state.loggers["tensorboard"].loggers[state.loader_name],
            epoch=state.epoch,
            confusion_matrix=confusion_matrix,
            class_names=class_names)