Exemple #1
0
 def on_batch_end(self, state):
     """Add the current batch's predictions to the confusion matrix.

     Reads the matrix from ``state['cm']``, the model output from
     ``state['pred']`` and the targets from ``state['batch'][1]``, then
     increments one cell per sample. The matrix is updated in place.

     Args:
         state: mapping providing the ``'cm'``, ``'pred'`` and ``'batch'``
             entries described above.
     """
     import torchelie.utils as tu
     matrix = state['cm']
     # NOTE(review): as_multiclass_shape presumably normalises the output to
     # (batch, classes) so that argmax(1) picks a class per sample -- confirm.
     scores = tu.as_multiclass_shape(state['pred'])
     predicted_classes = scores.argmax(1)
     targets = state['batch'][1]
     # Cells are addressed as [predicted class][true class].
     for predicted, target in zip(predicted_classes, targets):
         row = matrix[predicted]
         row[target] += 1
Exemple #2
0
    def analyze(self,
                batch: torch.Tensor,
                pred: torch.Tensor,
                true: torch.Tensor,
                pred_label: Optional[torch.Tensor] = None,
                paths: Optional[List[str]] = None) -> None:
        """Fold one batch into the best / worst / most-confused rankings.

        For every sample a record is built as the tuple
        ``(sample, true_class_score, true_label, is_correct, path,
        predicted_label, predicted_class_score)``. The records are appended
        to ``self.best``, ``self.worst`` and ``self.confused``; each list is
        then re-sorted and cut down to its first ``self.topk`` entries.

        Args:
            batch: the input samples, iterated one sample at a time.
            pred: model output for the batch.
            true: target class index of every sample.
            pred_label: predicted class of every sample; when ``None`` it is
                taken as the argmax over dimension 1 of the scores.
            paths: one entry per sample; when ``None`` every record stores
                ``None`` instead.
        """
        # NOTE(review): as_probs=True presumably yields per-class
        # probabilities of shape (batch, classes) -- confirm against helper.
        probs = as_multiclass_shape(pred, as_probs=True)
        true_class_score = probs.gather(1, true.unsqueeze(1))
        if pred_label is None:
            pred_label = probs.argmax(dim=1)
        predicted_class_score = probs.gather(1, pred_label.unsqueeze(1))
        if paths is None:
            paths = [None] * len(batch)
        is_correct = pred_label == true
        records = list(
            zip(batch, true_class_score, true, is_correct, paths, pred_label,
                predicted_class_score))

        # Each ranking orders records by a different function of the score
        # given to the true class (record[1]); smallest key comes first.
        rankings = (
            ('best', lambda record: -record[1]),
            ('worst', lambda record: record[1]),
            ('confused',
             lambda record: abs(self.center_value - record[1])),
        )
        for attribute, sort_key in rankings:
            pool = getattr(self, attribute)
            # Extend and sort the existing list in place, then rebind the
            # attribute to the truncated copy.
            pool.extend(records)
            pool.sort(key=sort_key)
            setattr(self, attribute, pool[:self.topk])
Exemple #3
0
    def on_batch_end(self, state):
        """Log the accuracy of the batch that just finished.

        Compares the argmax over dimension 1 of ``state['pred']`` with the
        targets in ``state['batch'][1]`` and feeds the result to
        ``self.avg``. When ``self.post_each_batch`` is truthy, the current
        value of ``self.avg`` is published as ``state['metrics']['acc']``.

        Args:
            state: mapping providing ``'pred'`` and ``'batch'``, plus
                ``'metrics'`` when the result is published.
        """
        raw_scores = state['pred']
        targets = state['batch'][1]
        scores = tu.as_multiclass_shape(raw_scores)
        # 1.0 where the predicted class equals the target, 0.0 elsewhere.
        hits = scores.argmax(1).eq(targets).float()

        if not isinstance(self.avg, RunningAvg):
            self.avg.log(hits.mean())
        else:
            # RunningAvg is given the number of hits and the batch size.
            self.avg.log(hits.sum(), scores.shape[0])

        if self.post_each_batch:
            state['metrics']['acc'] = self.avg.get()
Exemple #4
0
    def analyze(self, batch, pred, true, pred_label=None, paths=None):
        """Fold one batch into the best / worst / most-confused rankings.

        For every sample a record is built as the tuple
        ``(sample, true_class_score, true_label, is_correct, path,
        predicted_label)``. The records are appended to ``self.best``,
        ``self.worst`` and ``self.confused``; each list is then re-sorted
        and cut down to its first ``self.topk`` entries.

        Args:
            batch: the input samples, iterated one sample at a time.
            pred: model output for the batch.
            true: target class index of every sample.
            pred_label: predicted class of every sample; when ``None`` it is
                taken as the argmax over dimension 1 of the scores.
            paths: one entry per sample; when ``None`` every record stores
                ``None`` instead.
        """
        # NOTE(review): as_probs=True presumably yields per-class
        # probabilities of shape (batch, classes) -- confirm against helper.
        probs = as_multiclass_shape(pred, as_probs=True)
        true_class_score = probs.gather(1, true.unsqueeze(1))
        if pred_label is None:
            pred_label = probs.argmax(dim=1)
        if paths is None:
            paths = [None] * len(batch)
        is_correct = pred_label == true
        records = list(zip(batch, true_class_score, true, is_correct,
                           paths, pred_label))

        # Each ranking orders records by a different function of the score
        # given to the true class (record[1]); smallest key comes first.
        rankings = (
            ('best', lambda record: -record[1]),
            ('worst', lambda record: record[1]),
            ('confused',
             lambda record: abs(self.center_value - record[1])),
        )
        for attribute, sort_key in rankings:
            pool = getattr(self, attribute)
            # Extend and sort the existing list in place, then rebind the
            # attribute to the truncated copy.
            pool.extend(records)
            pool.sort(key=sort_key)
            setattr(self, attribute, pool[:self.topk])