Beispiel #1
0
def kb_el_eval(ranker, valid_dataloader, params, device, all_cand_enc):
    """Evaluate entity linking against a fixed candidate encoding matrix.

    For each batch, the context is encoded with ``ranker.model``. Every
    context embedding is scored against every row of ``all_cand_enc`` by a
    dot product, and the arg-max row index is taken as the predicted label.
    That index is compared directly with the gold ids in
    ``label_ids[label_mask]``, so row ``i`` of ``all_cand_enc`` must
    correspond to label id ``i``. Accuracy and macro/micro F1 are printed.

    Args:
        ranker: object exposing ``ranker.model`` with ``eval()``,
            ``get_raw_ctxt_encoding`` and ``get_ctxt_embeds``.
        valid_dataloader: iterable of 8-tuples of tensors
            ``(token_ids, tags, cand_enc, cand_enc_mask, label_ids,
            label_mask, attn_mask, global_attn_mask)``.
        params: config mapping; ``params['is_biencoder']`` must be truthy.
        device: device every batch tensor is moved to via ``Tensor.to``.
        all_cand_enc: candidate encodings, presumably shaped
            ``(num_candidates, dim)`` -- TODO confirm.

    Returns:
        ``(acc, f1_macro, f1_micro)`` as produced by
        ``utils.get_metrics_result``. (Previously the function returned
        ``None``; callers ignoring the result are unaffected.)

    Raises:
        AssertionError: if ``params['is_biencoder']`` is falsy, or if the
            number of gold and predicted labels ever diverges.
    """
    ranker.model.eval()
    # Only bi-encoder batches have the 8-field layout unpacked below. The
    # check is loop-invariant, so perform it once rather than per batch.
    assert params['is_biencoder'], 'kb_el_eval requires a bi-encoder setup'
    y_true = []
    y_pred = []

    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        (token_ids, tags, _cand_enc, _cand_enc_mask, label_ids, label_mask,
         attn_mask, _global_attn_mask) = batch

        with torch.no_grad():
            # NOTE(review): unlike in_batch_el_eval, no global attention mask
            # is passed here; presumably get_raw_ctxt_encoding defaults it to
            # None -- confirm.
            raw_ctxt_encoding = ranker.model.get_raw_ctxt_encoding(
                token_ids, attn_mask)
            ctxt_embeds = ranker.model.get_ctxt_embeds(raw_ctxt_encoding, tags)
            # Score under no_grad as well, so no autograd graph is built even
            # if all_cand_enc happens to require grad.
            scores = ctxt_embeds.mm(all_cand_enc.t())

        # .bool() is a no-op for boolean masks and prevents a 0/1 integer mask
        # from being interpreted as gather indices by tensor indexing.
        y_true.extend(label_ids[label_mask.bool()].cpu().tolist())
        y_pred.extend(torch.argmax(scores, dim=1).cpu().tolist())
        assert len(y_true) == len(y_pred)

    acc, f1_macro, f1_micro = utils.get_metrics_result(y_true, y_pred)
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )
    return acc, f1_macro, f1_micro
Beispiel #2
0
def ner_eval(ranker, valid_dataloader, params, device, pos_tag=1):
    """Evaluate token-level tag prediction over a dataloader.

    Runs ``ranker`` on each batch without the batch's global attention mask.
    It then collects gold ``tags`` and predicted tags at every position
    selected by ``attn_mask``. Overall accuracy and macro/micro F1 are
    printed, plus precision/recall/F1 for ``pos_tag``.

    Args:
        ranker: callable returning a 3-tuple whose second element is the
            predicted tag tensor; must also expose ``ranker.model.eval()``.
        valid_dataloader: iterable of tensor tuples. The tuples have 8 fields
            ``(token_ids, tags, cand_enc, cand_enc_mask, label_ids,
            label_mask, attn_mask, global_attn_mask)`` when
            ``params['is_biencoder']`` is truthy. Otherwise they have 4 fields
            ``(token_ids, tags, attn_mask, global_attn_mask)``.
        params: config mapping read for ``'is_biencoder'``.
        device: device every batch tensor is moved to via ``Tensor.to``.
        pos_tag: tag id forwarded as ``b_tag`` to
            ``utils.get_metrics_result`` for the per-tag report.

    Returns:
        ``(acc, precision, recall, f1, f1_macro, f1_micro)`` as produced by
        ``utils.get_metrics_result``. (Previously the function returned
        ``None``; callers ignoring the result are unaffected.)
    """
    ranker.model.eval()
    y_true = []
    y_pred = []

    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        # Non-bi-encoder batches carry no candidate/label tensors; the ranker
        # then receives None for them.
        cand_enc = cand_enc_mask = label_ids = label_mask = None
        if params['is_biencoder']:
            (token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask,
             attn_mask, _batch_global_attn_mask) = batch
        else:
            token_ids, tags, attn_mask, _batch_global_attn_mask = batch

        # evaluate: not leak information about tags -- the batch's global
        # attention mask is deliberately withheld at evaluation time.
        # NOTE(review): gold `tags` are still passed to the ranker; presumably
        # only for the loss -- confirm they do not influence tags_pred.
        global_attn_mask = None
        with torch.no_grad():
            _, tags_pred, _ = ranker(token_ids,
                                     attn_mask,
                                     global_attn_mask,
                                     tags,
                                     golden_cand_enc=cand_enc,
                                     golden_cand_mask=cand_enc_mask,
                                     label_ids=label_ids,
                                     label_mask=label_mask)

        # Index with a boolean view of the attention mask. A 0/1 integer mask
        # would otherwise be treated as gather indices, not a selection.
        # .bool() is a no-op when the mask is already boolean.
        token_mask = attn_mask.bool()
        y_true.extend(tags[token_mask].cpu().tolist())
        y_pred.extend(tags_pred[token_mask].cpu().tolist())
        assert len(y_true) == len(y_pred)

    acc, precision, recall, f1, f1_macro, f1_micro = utils.get_metrics_result(
        y_true, y_pred, b_tag=pos_tag)

    # print result
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )
    print(
        f'Tag to investigate is {pos_tag}, metrics: precision {precision:.4f}, recall {recall:.4f}, F1 {f1:.4f}'
    )
    return acc, precision, recall, f1, f1_macro, f1_micro
Beispiel #3
0
def in_batch_el_eval(ranker, valid_dataloader, params, device):
    """Evaluate entity linking against the candidates present in each batch.

    For each batch, the masked candidate encodings ``cand_enc[cand_enc_mask]``
    form the candidate set. Every context embedding is scored against them by
    a dot product. The arg-max position ``i`` is mapped to the gold label at
    position ``i`` of ``label_ids[label_mask]``. Accuracy and macro/micro F1
    are printed.

    Args:
        ranker: object exposing ``ranker.model`` with ``eval()``,
            ``get_raw_ctxt_encoding`` and ``get_ctxt_embeds``.
        valid_dataloader: iterable of 8-tuples of tensors
            ``(token_ids, tags, cand_enc, cand_enc_mask, label_ids,
            label_mask, attn_mask, global_attn_mask)``.
        params: config mapping; ``params['is_biencoder']`` must be truthy.
        device: device every batch tensor is moved to via ``Tensor.to``.

    Returns:
        ``(acc, f1_macro, f1_micro)`` as produced by
        ``utils.get_metrics_result``. (Previously the function returned
        ``None``; callers ignoring the result are unaffected.)

    Raises:
        AssertionError: if ``params['is_biencoder']`` is falsy, or if the
            number of gold and predicted labels ever diverges.
    """
    ranker.model.eval()
    # Loop-invariant: only bi-encoder batches have the layout unpacked below.
    assert params['is_biencoder'], 'in_batch_el_eval requires a bi-encoder setup'
    y_true = []
    y_pred = []

    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        (token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask,
         attn_mask, _batch_global_attn_mask) = batch

        # evaluate: not leak information about tags -- the batch's global
        # attention mask is deliberately withheld at evaluation time.
        global_attn_mask = None
        with torch.no_grad():
            raw_ctxt_encoding = ranker.model.get_raw_ctxt_encoding(
                token_ids, attn_mask, global_attn_mask)
            ctxt_embeds = ranker.model.get_ctxt_embeds(raw_ctxt_encoding, tags)

            # get similarity scores against the in-batch candidates only.
            # .bool() is a no-op for boolean masks and keeps a 0/1 integer
            # mask from being treated as gather indices.
            in_batch_cands = cand_enc[cand_enc_mask.bool()]
            scores = ctxt_embeds.mm(in_batch_cands.t())

        true_labels = label_ids[label_mask.bool()].cpu().tolist()
        y_true.extend(true_labels)

        # Position i in the masked candidate set carries label true_labels[i],
        # so a predicted position maps straight to a label by list indexing.
        pred_positions = torch.argmax(scores, dim=1).cpu().tolist()
        y_pred.extend(true_labels[pos] for pos in pred_positions)
        assert len(y_true) == len(y_pred)

    acc, f1_macro, f1_micro = utils.get_metrics_result(y_true, y_pred)
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )
    return acc, f1_macro, f1_micro