示例#1
0
def test_v1_5_metric_auc_auroc():
    """Check that the AUC/ROC/AUROC metric classes and functions emit the v1.5 deprecation
    warning, and that the deprecated functional forms still return the expected values."""
    removal_msg = 'It will be removed in v1.5.0'

    # Each deprecated class warns once; clear the "already warned" flag on its
    # __init__ so the warning is guaranteed to fire inside this test.
    for metric_cls in (AUC, ROC, AUROC):
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_msg):
            metric_cls()

    # Functional auc: trapezoidal area under the polyline (x, y).
    auc._warned = False
    xs = torch.tensor([0, 1, 2, 3])
    ys = torch.tensor([0, 1, 2, 2])
    with pytest.deprecated_call(match=removal_msg):
        assert auc(xs, ys) == torch.tensor(4.)

    # Functional roc on integer scores with positive label 1.
    roc._warned = False
    roc_preds = torch.tensor([0, 1, 2, 3])
    roc_target = torch.tensor([0, 1, 1, 1])
    with pytest.deprecated_call(match=removal_msg):
        fpr, tpr, thrs = roc(roc_preds, roc_target, pos_label=1)
    expected_fpr = torch.tensor([0., 0., 0., 0., 1.])
    expected_tpr = torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
    expected_thrs = torch.tensor([4, 3, 2, 1, 0])
    assert torch.equal(fpr, expected_fpr)
    assert torch.allclose(tpr, expected_tpr, atol=1e-4)
    assert torch.equal(thrs, expected_thrs)

    # Functional auroc on probability-like scores.
    auroc._warned = False
    auroc_preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
    auroc_target = torch.tensor([0, 0, 1, 1, 1])
    with pytest.deprecated_call(match=removal_msg):
        assert auroc(auroc_preds, auroc_target) == torch.tensor(0.5)
示例#2
0
 def get_opt_threshold(self, categorical_feature_idcs):
     """
     Finds the ROC-optimal decision threshold(s) for the categorical features.

     The optimal threshold is the one maximising ``tpr - fpr`` (Youden's J
     statistic) on the ROC curve of the sigmoid-activated predictions.

     :param categorical_feature_idcs: index/indices along the second dimension
         selecting the categorical features in every tensor of
         ``self.all_targets`` / ``self.all_preds``.
     :returns: tensor of length ``max_seq_len`` (the length of the longest
         sequence in ``self.all_targets``). If ``self.args.variable_threshold``
         is set, entry ``t`` is the optimal threshold computed over all
         sequences that reach time step ``t``; otherwise every entry holds one
         threshold computed over all time steps pooled together.
     """
     max_seq_len = max(len(target) for target in self.all_targets)
     _thresholds = torch.empty(max_seq_len)
     if self.args.variable_threshold:
         targets = [target[:, categorical_feature_idcs] for target in self.all_targets]
         preds = [torch.sigmoid(pred[:, categorical_feature_idcs]) for pred in self.all_preds]
         for time_idx in range(max_seq_len):
             # torch.stack also works when several feature columns are selected
             # (torch.tensor() on a list of multi-element tensors raises).
             # detach()/cpu() keep the detached CPU result torch.tensor() gave
             # and match the pooled branch below.
             targets_timestep = torch.stack(
                 [target[time_idx] for target in targets if len(target) > time_idx]
             ).detach().cpu()
             preds_timestep = torch.stack(
                 [pred[time_idx] for pred in preds if len(pred) > time_idx]
             ).detach().cpu()
             fpr, tpr, thresholds = roc(preds_timestep, targets_timestep)
             _thresholds[time_idx] = thresholds[torch.argmax(tpr - fpr)]
     else:
         flat_targets = torch.cat(
             [target[:, categorical_feature_idcs] for target in self.all_targets]
         ).cpu()
         flat_preds = torch.cat(
             [torch.sigmoid(pred[:, categorical_feature_idcs]) for pred in self.all_preds]
         ).cpu()
         fpr, tpr, thresholds = roc(flat_preds, flat_targets)
         # One global threshold, broadcast to every time step.
         _thresholds[:] = thresholds[torch.argmax(tpr - fpr)]
     return _thresholds
 def _get_metrics_at_optimal_cutoff(
         self
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Computes the ROC to find the optimal cut-off, i.e. the probability threshold for which the
     difference between true positive rate and false positive rate is largest (Youden's J
     statistic). Then computes the false positive rate, false negative rate and accuracy at this
     threshold.

     A sample is predicted as label 1 when its probability is greater than or equal to the
     threshold, otherwise 0. This ``>=`` rule is the one under which a torchmetrics-style ``roc``
     reports tpr/fpr per threshold, so the returned rates and the accuracy describe the same
     classifier. (NOTE: confirm this convention if the metrics backend changes.)

     If the targets contain fewer than two distinct classes (including no targets at all), the
     ROC is undefined and four NaN tensors are returned instead.

     :returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
     """
     preds, targets = self._get_preds_and_targets()
     if torch.unique(targets).numel() < 2:
         return torch.tensor(np.nan), torch.tensor(np.nan), torch.tensor(
             np.nan), torch.tensor(np.nan)
     fpr, tpr, thresholds = roc(preds, targets)
     optimal_idx = torch.argmax(tpr - fpr)
     optimal_threshold = thresholds[optimal_idx]
     # '>=' (not '>') so the accuracy is measured on the same predictions that
     # produced tpr[optimal_idx] / fpr[optimal_idx].
     acc = accuracy(preds >= optimal_threshold, targets)
     false_negative_optimal = 1 - tpr[optimal_idx]
     false_positive_optimal = fpr[optimal_idx]
     return optimal_threshold, false_positive_optimal, false_negative_optimal, acc
示例#4
0
文件: eval.py 项目: WangXuhongCN/APAN
def eval_epoch(args, logger, g, dataloader, encoder, decoder, msg2mail,
               loss_fcn, device, num_samples):
    """Evaluate ``encoder``/``decoder`` over ``dataloader`` and return aggregated metrics.

    For every batch, the encoder embeds the reverse-edge-augmented positive graph. The decoder
    then produces logits/labels and the loss is recorded.

    Side effects: the node embeddings (``g.ndata['feat']``), generated mails
    (``g.ndata['mail']``) and, unless ``args.no_time``, last-update timestamps
    (``g.ndata['last_update']``) are written back into ``g``.

    If ``'LP'`` is in ``args.tasks``, AP / AUROC / accuracy (threshold 0.5 on the sigmoid
    scores) are computed per batch and averaged. Otherwise all sigmoid scores and labels are
    collected into length-``num_samples`` buffers, the metrics are computed once over them, and
    TPR/FPR/threshold summaries are reported via ``print_tp_fp_thres``.

    Both modules are put in eval mode for the pass and switched back to train mode afterwards,
    even if an exception is raised.

    :returns: ``(ap, auc, acc, mean_loss)``.
    """
    is_lp = 'LP' in args.tasks
    m_ap, m_auc, m_acc = [], [], []

    labels_all = torch.zeros((num_samples)).long()
    logits_all = torch.zeros((num_samples))
    # NOTE(review): never written to (filling it from attn_weight was
    # disabled), so the mean logged at the end is all zeros.
    attn_weight_all = torch.zeros((num_samples, args.n_mail))

    m_loss = []
    m_infer_time = []
    # Write position into labels_all/logits_all. It advances by the size of
    # each batch actually seen. Using batch_idx * n_sample with the current
    # batch's size placed a smaller final batch over earlier results.
    offset = 0
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            for input_nodes, pos_graph, neg_graph, blocks, frontier, current_ts in dataloader:
                n_sample = pos_graph.num_edges()
                start_idx = offset
                end_idx = min(num_samples, start_idx + n_sample)
                offset += n_sample

                pos_graph = pos_graph.to(device)
                neg_graph = neg_graph.to(device) if neg_graph is not None else None
                if not args.no_time or not args.no_pos:
                    current_ts, pos_ts, num_pos_nodes = get_current_ts(
                        args, pos_graph, neg_graph)
                    pos_graph.ndata['ts'] = current_ts
                else:
                    current_ts, pos_ts, num_pos_nodes = None, None, None

                neg_graph_bidir = (dgl.add_reverse_edges(neg_graph)
                                   if neg_graph is not None else None)

                # Inference time covers the encoder and decoder forward passes.
                start = time.perf_counter()
                emb, attn_weight = encoder(dgl.add_reverse_edges(pos_graph),
                                           neg_graph_bidir, num_pos_nodes)
                logits, labels = decoder(emb, pos_graph, neg_graph)
                m_infer_time.append(time.perf_counter() - start)

                loss = loss_fcn(logits, labels)
                m_loss.append(loss.item())
                mail = msg2mail.gen_mail(args, emb, input_nodes, pos_graph,
                                         frontier, 'val')
                if not args.no_time:
                    g.ndata['last_update'][pos_graph.ndata[dgl.NID]
                                           [:num_pos_nodes]] = pos_ts.to('cpu')
                g.ndata['feat'][pos_graph.ndata[dgl.NID]] = emb.to('cpu')
                g.ndata['mail'][input_nodes] = mail

                labels = labels.long()
                logits = logits.sigmoid()
                if is_lp:
                    pred = logits > 0.5
                    m_ap.append(average_precision(logits, labels).cpu().numpy())
                    m_auc.append(auroc(logits, labels).cpu().numpy())
                    m_acc.append(accuracy(pred, labels).cpu().numpy())
                else:
                    labels_all[start_idx:end_idx] = labels
                    logits_all[start_idx:end_idx] = logits

        if is_lp:
            ap, auc, acc = np.mean(m_ap), np.mean(m_auc), np.mean(m_acc)
        else:
            pred_all = logits_all > 0.5
            ap = average_precision(logits_all, labels_all).cpu().item()
            auc = auroc(logits_all, labels_all).cpu().item()
            acc = accuracy(pred_all, labels_all).cpu().item()

            fprs, tprs, thresholds = roc(logits_all, labels_all)
            fpr_l, tpr_l, thres_l = get_TPR_FPR_metrics(fprs, tprs, thresholds)
            print_tp_fp_thres(args.tasks, logger, fpr_l, tpr_l, thres_l)

        # Message reads "total inference time".
        print('总推理时间', np.sum(m_infer_time))
        logger.info(attn_weight_all.mean(0))
        return ap, auc, acc, np.mean(m_loss)
    finally:
        encoder.train()
        decoder.train()
示例#5
0
文件: utils.py 项目: vjoki/fsl-experi
def compute_evaluation_metrics(outputs: List[List[torch.Tensor]],
                               plot: bool = False,
                               prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
    """Aggregate per-step (scores, labels) outputs into verification metrics.

    Every element of ``outputs`` is indexed as ``step[0]`` (an iterable of score tensors) and
    ``step[1]`` (an iterable of label tensors). All of these are concatenated before scoring.

    Computes AUROC, precision/recall, the equal error rate (EER) and minimum detection cost
    (minDCF) with their thresholds. It also computes the accuracy obtained by predicting 1
    where ``score >= threshold`` for each of those two thresholds.

    :param outputs: per-step collections of score and label tensors.
    :param plot: if True, show a matplotlib ROC plot with the EER point marked.
    :param prefix: optional key prefix; keys then read ``'<prefix>_<metric>'``.
    :returns: dict mapping metric name to value.
    """
    scores = torch.cat([chunk for step in outputs for chunk in step[0]])
    # NOTE(review): no sigmoid is applied here (the call was commented out),
    # so scores are thresholded as given. Confirm whether forward() returns
    # probabilities or raw logits (loss is BCE-with-logits).
    lowest = torch.min(scores).item()
    highest = torch.max(scores).item()
    print('Score range: [{}, {}]'.format(lowest, highest))
    labels = torch.cat([chunk for step in outputs for chunk in step[1]])

    auc = auroc(scores, labels, pos_label=1)
    fpr, tpr, thresholds = roc(scores, labels, pos_label=1)
    prec, recall = precision_recall(scores, labels)

    # Binary case (num_classes unset): roc returns single tensors; the casts only inform mypy.
    fpr, tpr, thresholds = (cast(torch.Tensor, curve) for curve in (fpr, tpr, thresholds))

    fnr = 1 - tpr
    eer, eer_threshold, idx = equal_error_rate(fpr, fnr, thresholds)
    min_dcf, min_dcf_threshold = minDCF(fpr, fnr, thresholds)

    # Fraction of samples classified correctly at the EER / minDCF operating points.
    total = labels.numel()
    eer_acc = torch.sum((scores >= eer_threshold).long() == labels).float() / total
    min_dcf_acc = torch.sum((scores >= min_dcf_threshold).long() == labels).float() / total

    if plot:
        assert idx.dim() == 0 or (idx.dim() == 1 and idx.size(0) == 1)
        eer_pos = int(idx.item())
        fpr_np = fpr.cpu().numpy()
        tpr_np = tpr.cpu().numpy()
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.plot([0, 1], [0, 1], 'r--', label='Reference', alpha=0.6)
        plt.plot([1, 0], [0, 1], 'k--', label='EER line', alpha=0.6)
        plt.plot(fpr_np, tpr_np, label='ROC curve')
        plt.fill_between(fpr_np, tpr_np, 0, label='AUC', color='0.8')
        plt.plot(fpr_np[eer_pos], tpr_np[eer_pos], 'ko', label='EER = {:.2f}%'.format(eer * 100))
        plt.legend()
        plt.show()

    key_prefix = '{}_'.format(prefix) if prefix else ''
    named_metrics = (
        ('eer', eer),
        ('eer_acc', eer_acc),
        ('eer_threshold', eer_threshold),
        ('auc', auc),
        ('min_dcf', min_dcf),
        ('min_dcf_acc', min_dcf_acc),
        ('min_dcf_threshold', min_dcf_threshold),
        ('prec', prec),
        ('recall', recall),
    )
    return {key_prefix + name: value for name, value in named_metrics}