def test_confusion_matrix():
    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()
    cm = confusion_matrix(pred, target, normalize=True)

    assert torch.allclose(
        cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))

    pred = torch.zeros_like(pred)
    cm = confusion_matrix(pred, target, normalize=True)
    assert torch.allclose(
        cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
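For reference, a minimal un-normalized sketch, assuming the legacy pytorch_lightning.metrics functional API exercised by the test above (rows index targets, columns index predictions):

# Sketch only; the import path and row/column convention follow the legacy API used in the test above.
import torch
from pytorch_lightning.metrics.functional.classification import confusion_matrix

target = torch.tensor([0, 0, 1, 1, 2, 2])
pred = torch.tensor([0, 0, 1, 2, 2, 2])
cm = confusion_matrix(pred, target, normalize=False)
# Expected counts: [[2, 0, 0], [0, 1, 1], [0, 0, 2]] -- one target-1 sample was predicted as class 2.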
    def validation_epoch_end(self, validation_step_outputs):
        if self.logger and self.current_epoch > 0 and self.current_epoch % 5 == 0:
            epoch_preds = torch.cat([x["preds"] for x in validation_step_outputs])
            epoch_targets = torch.cat([x["labels"] for x in validation_step_outputs])
            cm = confusion_matrix(epoch_preds, epoch_targets, num_classes=self.hparams.num_classes).cpu().numpy()
            class_names = getattr(self.train_dataloader().dataset, "classes", None)
            fig = generate_confusion_matrix(cm, class_names=class_names)
            self.logger.experiment.log({"confusion_matrix": fig})
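The generate_confusion_matrix helper is not shown in this snippet; a minimal sketch of one possible implementation, assuming plain matplotlib, that turns the numpy matrix into a figure the logger can record:

# Hypothetical helper; the real generate_confusion_matrix may differ.
import matplotlib.pyplot as plt
import numpy as np

def generate_confusion_matrix(cm, class_names=None):
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cm, cmap="Blues")
    ticks = np.arange(cm.shape[0])
    labels = class_names if class_names is not None else ticks
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Target")
    # annotate each cell with its value
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, f"{cm[i, j]:.0f}", ha="center", va="center")
    fig.tight_layout()
    return fig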
    def on_test_epoch_end(self):
        # confusion matrix; cm is already row-normalized here, so the plot below does not normalize again
        cm = confusion_matrix(self.test_pred[:, 0], self.test_pred[:, 1],
                              normalize=True, num_classes=10)
        plot_confusion_matrix(cm.cpu().numpy(),
                              target_names=[str(i) for i in range(10)],
                              title='Confusion Matrix', normalize=False,
                              cmap=plt.get_cmap('bwr'),
                              args=self.args)
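This hook assumes self.test_pred already holds predictions in column 0 and targets in column 1; a hedged sketch of a test_step that could accumulate it that way (the attribute handling is an assumption, not from the original):

    # Hypothetical companion step; self.test_pred is assumed to start as None in __init__.
    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = torch.argmax(self(x), dim=1)
        rows = torch.stack([preds, y], dim=1)  # columns: [prediction, target]
        self.test_pred = rows if self.test_pred is None else torch.cat([self.test_pred, rows])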
Example #4
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the confusion matrix.
        """
        return confusion_matrix(pred=pred, target=target,
                                normalize=self.normalize)
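A hedged usage sketch, assuming this forward belongs to the legacy module-style ConfusionMatrix metric whose constructor stores the normalize flag:

# Sketch only; the class name and constructor follow the old pytorch_lightning.metrics API.
metric = ConfusionMatrix(normalize=True)
pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
cm = metric(pred, target)  # dispatches to forward() and returns the normalized matrix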
Example #5
    def validation_epoch_end(self, outputs):

        avg_loss = torch.stack([b['loss'] for b in outputs]).mean()

        conf_mtx = confusion_matrix(torch.cat([b['preds'] for b in outputs]),
                                    torch.cat([b['labels'] for b in outputs]),
                                    normalize=False,
                                    num_classes=5)
        avg_acc = torch.diag(conf_mtx).sum() / conf_mtx.sum()

        self.logger.experiment.add_scalar('val_avg_loss', avg_loss,
                                          self.current_epoch)
        self.logger.experiment.add_scalar('val_avg_acc', avg_acc,
                                          self.current_epoch)
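The accuracy derived from the confusion-matrix trace is the usual correct/total ratio; a small sanity check, again assuming the legacy functional confusion_matrix:

# Sanity check: trace / sum of an un-normalized confusion matrix equals plain accuracy.
preds = torch.tensor([0, 1, 2, 2, 1])
labels = torch.tensor([0, 1, 1, 2, 1])
cm = confusion_matrix(preds, labels, normalize=False, num_classes=3)
acc_from_cm = torch.diag(cm).sum().float() / cm.sum()
acc_direct = (preds == labels).float().mean()
assert torch.isclose(acc_from_cm, acc_direct)  # both equal 0.8 here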
Example #6
    def training_epoch_end(self, outputs):

        #####  Print Progress  ###############
        if self.trainer.current_epoch < self.hparams.burn_in_epochs:
            print('[PROGRESS] Burning in classifier epoch %i' %
                  self.current_epoch)

        #####  Calculate Metrics  ############
        avg_loss = torch.stack([b['loss'] for b in outputs]).mean()

        conf_mtx = confusion_matrix(torch.cat([b['preds'] for b in outputs]),
                                    torch.cat([b['labels'] for b in outputs]),
                                    normalize=False,
                                    num_classes=5)
        avg_acc = torch.diag(conf_mtx).sum() / conf_mtx.sum()

        self.logger.experiment.add_scalar('train_avg_loss', avg_loss,
                                          self.current_epoch)
        self.logger.experiment.add_scalar('train_avg_acc', avg_acc,
                                          self.current_epoch)
Example #7
    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists in the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('test_loss', loss_val)

        y_hat = model_out['logits']
        labels_hat = torch.argmax(y_hat, dim=1)
        y = targets['labels']

        f1 = metrics.f1_score(labels_hat, y, class_reduction='weighted')
        prec = metrics.precision(labels_hat, y, class_reduction='weighted')
        recall = metrics.recall(labels_hat, y, class_reduction='weighted')
        acc = metrics.accuracy(labels_hat, y, class_reduction='weighted')

        self.log('test_batch_prec', prec)
        self.log('test_batch_f1', f1)
        self.log('test_batch_recall', recall)
        self.log('test_batch_weighted_acc', acc)

        cm = metrics.confusion_matrix(pred=labels_hat,
                                      target=y,
                                      normalize=False)
        self.test_conf_matrices.append(cm)
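The per-batch matrices collected in self.test_conf_matrices would normally be reduced once the test loop finishes; a hedged sketch of one way to do that (the hook name and reduction are assumptions consistent with the step above, and it presumes every batch matrix has the same shape, e.g. a fixed num_classes):

    # Hypothetical aggregation of the per-batch confusion matrices gathered in test_step.
    def test_epoch_end(self, outputs):
        total_cm = torch.stack(self.test_conf_matrices).sum(dim=0).float()  # element-wise sum of counts
        self.log('test_acc_from_cm', torch.diag(total_cm).sum() / total_cm.sum())
        self.test_conf_matrices.clear()  # reset for a possible next run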
def test(args, dm, net):

    from pytorch_lightning.metrics.functional.classification import confusion_matrix

    model = net.load_from_checkpoint(checkpoint_path=args.ckpt_path,
                                     num_channel=dm.n_channels,
                                     num_class=dm.n_classes)
    model.eval()

    dm.config["batch_size"] = 1
    dm.prepare_data()
    dm.setup()

    trainer = pl.Trainer()
    trainer.test(model, datamodule=dm)

    cmat = confusion_matrix(torch.stack(model.predictions),
                            torch.stack(model.targets),
                            num_classes=dm.n_classes)

    utils.plot_confusion_matrix(matrix=cmat.numpy(),
                                classes=dm.raw_Y.columns,
                                figure_name="./figures/cmat.jpg")
Example #9
    def test_epoch_end(self, outputs):

        if self.incorrect_type != 'boundary':
            #####  Confusion Matrix  #####
            conf_mtx = confusion_matrix(
                torch.cat([b['preds'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                normalize=False,
                num_classes=5)

            #####  Normalized Confusion Matrix  #####
            conf_mtx_normalized = confusion_matrix(
                torch.cat([b['preds'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                normalize=True,
                num_classes=5)

            #####  Weighted Confusion Matrix  #####
            conf_mtx_weighted = conf_mtx.clone()
            for c, w in enumerate(self.weights):
                conf_mtx_weighted[c, :] *= w

            #####  ACCURACY  #####
            accuracy = torch.diag(conf_mtx).sum() / conf_mtx.sum()
            accuracy_weighted = torch.diag(
                conf_mtx_weighted).sum() / conf_mtx_weighted.sum()

            #####  AUC_SCORE  #####
            roc_results = multiclass_roc(
                torch.cat([b['logits'] for b in outputs]),
                torch.cat([b['labels'] for b in outputs]),
                num_classes=5)
            AUROC_str = ''
            AUROC_list = {}
            for cls, roc_cls in enumerate(roc_results):
                fpr, tpr, threshold = roc_cls
                self.logger.experiment.add_scalar(f'val_AUC[{cls}]',
                                                  auc(fpr, tpr),
                                                  self.current_epoch)
                AUROC_str += '\tAUC_SCORE[CLS %d]: \t%.4f\n' % (cls,
                                                                auc(fpr, tpr))
                AUROC_list['AUC_SCORE[CLS %d]' % cls] = auc(fpr, tpr)

            #####  F1  #####
            f1_score = f1(torch.cat([b['preds'] for b in outputs]),
                          torch.cat([b['labels'] for b in outputs]),
                          num_classes=5)

            #####  Average Precision  #####
            # TODO

            #####  PRINT RESULTS  #####
            print('=' * 100)
            print(
                f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}'
            )
            print('RESULTS:')
            print('\tAccuracy: \t\t%.4f' % accuracy)
            print('\tWeighted Accuracy: \t%.4f' % accuracy_weighted)
            print('\tF1 Score: \t\t%.4f' % f1_score)
            print(AUROC_str)

            self.metrics_result[self.incorrect_type][self.model_name] = {
                'Accuracy': round(float(accuracy), 4),
                'Weighted Accuracy': round(float(accuracy_weighted), 4),
                'F1_score': round(float(f1_score), 4)
            }
            for key, val in AUROC_list.items():
                self.metrics_result[self.incorrect_type][
                    self.model_name].update({key: round(float(val), 4)})
            print('Confusion Matrix')
            fig, ax = plt.subplots(figsize=(4, 4))
            sn.heatmap(conf_mtx.cpu(),
                       annot=True,
                       cbar=False,
                       annot_kws={"size": 15},
                       fmt='g',
                       cmap='mako')
            plt.show()
            fig, ax = plt.subplots(figsize=(4, 4))
            sn.heatmap(conf_mtx_normalized.cpu(),
                       annot=True,
                       cbar=False,
                       annot_kws={"size": 12},
                       fmt='.2f',
                       cmap='mako')
            plt.show()
            print('=' * 100)

        else:
            tol_correct = 0
            tol_samples = 0
            tol_drop = 0
            for batch in outputs:
                preds = batch['preds']
                labels = batch['labels']
                slope_id = batch['doc_ids']
                ##### Change lizhong's code ####
                for idx, slope_idx in enumerate(slope_id):
                    agree_by_user = bool(
                        slope_df[slope_df['slope_id'] == slope_idx.item()]
                        ['sentiment_correct'].values[0])
                    possible_classes = slope_df[
                        slope_df['slope_id'] ==
                        slope_idx.item()]['label_from_score'].values[0]

                    pred_class = preds[idx]
                    # difference between pred and true label
                    diff = torch.abs(pred_class - possible_classes)

                    # if correct label
                    if agree_by_user:  # True
                        if diff == 0:
                            # correct prediction
                            tol_correct += 1
                            tol_samples += 1
                        elif diff == 1:
                            # discard
                            tol_drop += 1
                        else:
                            # wrong prediction
                            tol_samples += 1
                    # if incorrect label
                    else:  # False
                        if diff == 0:
                            # wrong
                            tol_samples += 1
                        elif diff == 1:
                            # discard
                            tol_drop += 1
                        else:
                            # Correct
                            tol_correct += 1
                            tol_samples += 1

            boundary_accuracy = round(tol_correct / tol_samples, 4)
            self.metrics_result[self.incorrect_type][self.model_name] = {}
            self.metrics_result[self.incorrect_type][
                self.model_name]['boundary_acc'] = boundary_accuracy
            self.metrics_result[self.incorrect_type][
                self.model_name]['total_drop_sample'] = tol_drop
            print('=' * 100)
            print(
                f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}'
            )
            print('\tBoundary Accuracy: \t\t%.4f' % boundary_accuracy)
            print('\tTotal Dropped Samples: \t%d' % tol_drop)
Example #10
def run_epoch(model,
              dataloader,
              criterion,
              optimizer=None,
              epoch=0,
              scheduler=None,
              device='cpu'):
    import pytorch_lightning.metrics.functional.classification as clmetrics
    from pytorch_lightning.metrics import Precision, Accuracy, Recall
    from sklearn.metrics import roc_auc_score, average_precision_score

    metrics = Accumulator()
    cnt = 0
    total_steps = len(dataloader)
    steps = 0
    running_corrects = 0

    accuracy = Accuracy()
    precision = Precision(num_classes=2)
    recall = Recall(num_classes=2)

    preds_epoch = []
    labels_epoch = []
    for inputs, labels in dataloader:
        steps += 1
        inputs = inputs.to(device)  # torch.Size([2, 1, 224, 224])
        labels = labels.to(device).unsqueeze(1).float()  ## torch.Size([2, 1])

        outputs = model(inputs)  # [batch_size, nb_classes]

        loss = criterion(outputs, labels)

        if optimizer:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        preds_epoch.extend(torch.sigmoid(outputs).tolist())
        labels_epoch.extend(labels.tolist())
        threshold = 0.5
        prob = (torch.sigmoid(outputs) > threshold).long()  # hard 0/1 predictions

        # 2x2 confusion matrix (rows = targets, cols = predictions) flattened row-major
        conf = torch.flatten(
            clmetrics.confusion_matrix(prob, labels, num_classes=2))
        tn, fp, fn, tp = conf

        metrics.add_dict({
            'data_count': len(inputs),
            'loss': loss.item() * len(inputs),
            'tp': tp.item(),
            'tn': tn.item(),
            'fp': fp.item(),
            'fn': fn.item(),
        })
        cnt += len(inputs)

        if scheduler:
            scheduler.step()
        del outputs, loss, inputs, labels, prob
    logger.info(f'cnt = {cnt}')

    metrics['loss'] /= cnt

    def safe_div(x, y):
        if y == 0:
            return 0
        return x / y

    _TP, _TN, _FP, _FN = metrics['tp'], metrics['tn'], metrics['fp'], metrics[
        'fn']
    acc = (_TP + _TN) / cnt
    sen = safe_div(_TP, (_TP + _FN))
    spe = safe_div(_TN, (_FP + _TN))
    prec = safe_div(_TP, (_TP + _FP))
    metrics.add('accuracy', acc)
    metrics.add('sensitivity', sen)
    metrics.add('specificity', spe)
    metrics.add('precision', prec)

    auc = roc_auc_score(labels_epoch, preds_epoch)
    aupr = average_precision_score(labels_epoch, preds_epoch)
    metrics.add('auroc', auc)
    metrics.add('aupr', aupr)

    logger.info(metrics)

    return metrics, preds_epoch, labels_epoch
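Accumulator is not defined in this snippet; a minimal dict-backed sketch with the interface run_epoch relies on (add, add_dict, item get/set, and a printable repr), offered as an assumed stand-in rather than the original helper:

# Hypothetical stand-in for the Accumulator used above.
class Accumulator:
    def __init__(self):
        self.metrics = {}

    def add(self, key, value):
        # accumulate into the running total for this key
        self.metrics[key] = self.metrics.get(key, 0) + value

    def add_dict(self, d):
        for key, value in d.items():
            self.add(key, value)

    def __getitem__(self, key):
        return self.metrics[key]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def __repr__(self):
        return str(self.metrics)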