Ejemplo n.º 1
0
def test_warning_on_nan(tmpdir):
    """A normalized confusion matrix over more classes than appear in the data
    produces NaN rows; verify the library warns that those NaNs were zeroed."""
    y_hat = torch.randint(3, size=(20, ))
    y_true = torch.randint(3, size=(20, ))

    expected_msg = '.* nan values found in confusion matrix have been replaced with zeros.'
    with pytest.warns(UserWarning, match=expected_msg):
        confusion_matrix(y_hat, y_true, num_classes=5, normalize='true')
Ejemplo n.º 2
0
def run_epoch(model, dataloader, criterion, optimizer=None, epoch=0, scheduler=None, device='cpu'):
    """Run one pass over ``dataloader`` — training when ``optimizer`` is given,
    evaluation otherwise — for a binary (single-logit) classifier.

    Args:
        model: module mapping a batch of inputs to one logit per sample.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss taking ``(outputs, labels)``; labels are reshaped to
            ``[batch, 1]`` floats before the call.
        optimizer: if provided, backprop + step + zero_grad are run per batch.
        epoch: epoch index; kept for interface compatibility (unused here).
        scheduler: if provided, ``scheduler.step()`` is called once per batch.
        device: device string the batches are moved to.

    Returns:
        ``(metrics, preds_epoch, labels_epoch)`` — an ``Accumulator`` with
        loss / accuracy / sensitivity / specificity / precision / auroc / aupr,
        plus the per-sample sigmoid probabilities and labels for the epoch.
    """
    import torchmetrics.functional as clmetrics
    from sklearn.metrics import roc_auc_score, average_precision_score

    metrics = Accumulator()
    cnt = 0  # total number of samples seen this epoch

    preds_epoch = []
    labels_epoch = []
    for inputs, labels in dataloader:
        inputs = inputs.to(device)  # e.g. torch.Size([B, 1, 224, 224])
        labels = labels.to(device).unsqueeze(1).float()  # [B] -> [B, 1]

        outputs = model(inputs)  # raw logits, [B, 1]
        loss = criterion(outputs, labels)

        if optimizer:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Compute the sigmoid once and reuse it for both the epoch-level
        # probability list and the hard predictions (was computed twice).
        probs = torch.sigmoid(outputs)
        preds_epoch.extend(probs.tolist())
        labels_epoch.extend(labels.tolist())
        threshold = 0.5
        prob = (probs > threshold).long()

        conf = torch.flatten(clmetrics.confusion_matrix(prob, labels.to(prob.device, dtype=torch.int), num_classes=2))
        tn, fp, fn, tp = conf

        metrics.add_dict({
            'data_count': len(inputs),
            'loss': loss.item() * len(inputs),
            'tp': tp.item(),
            'tn': tn.item(),
            'fp': fp.item(),
            'fn': fn.item(),
        })
        cnt += len(inputs)

        if scheduler:
            scheduler.step()
        # Drop per-batch tensors eagerly to keep peak memory down.
        del outputs, loss, inputs, labels, prob
    logger.info(f'cnt = {cnt}')

    def safe_div(x, y):
        # Guard against empty denominators (empty dataloader, absent class).
        if y == 0:
            return 0
        return x / y

    # safe_div also protects the cnt == 0 case (previously a ZeroDivisionError).
    metrics['loss'] = safe_div(metrics['loss'], cnt)

    _TP, _TN, _FP, _FN = metrics['tp'], metrics['tn'], metrics['fp'], metrics['fn']
    acc = safe_div(_TP + _TN, cnt)
    sen = safe_div(_TP, (_TP + _FN))  # sensitivity / recall
    spe = safe_div(_TN, (_FP + _TN))  # specificity
    prec = safe_div(_TP, (_TP + _FP))
    metrics.add('accuracy', acc)
    metrics.add('sensitivity', sen)
    metrics.add('specificity', spe)
    metrics.add('precision', prec)

    # sklearn raises ValueError when only one class is present in the labels;
    # report 0 in that case rather than aborting the epoch.
    try:
        auc = roc_auc_score(labels_epoch, preds_epoch)
    except ValueError:
        auc = 0.
        print('ValueError. set auc = 0')
    try:
        aupr = average_precision_score(labels_epoch, preds_epoch)
    except ValueError:
        aupr = 0.
        print('ValueError. set aupr = 0')
    metrics.add('auroc', auc)
    metrics.add('aupr', aupr)

    logger.info(metrics)

    return metrics, preds_epoch, labels_epoch