Example #1
def do_fscore(sentdb, model, device, args):
    """
    Micro-averaged segment-level F1 score.
    """
    model.eval()
    total_pred, total_gold, total_crct = 0.0, 0.0, 0.0
    for i in range(len(sentdb.val_minibatches)):
        x, neighbs, Cx, Cn, tag2mask, gold = sentdb.predword_batch(
            i, args.eval_ne_per_sent)
        x, neighbs = x.to(device), neighbs.to(device)
        Cx, Cn = Cx.to(device), Cn.to(device)
        tag2mask = [(tag, mask.to(device)) for (tag, mask) in tag2mask]
        # Tne, nesz = neighbs.size()
        batch_reps = model.get_word_reps(x, Cx)  # bsz x T x dim
        ne_reps = model.get_word_reps(  # nesz x max_len x dim
            neighbs,
            Cn,
            shard_bsz=args.pred_shard_size)
        preds = get_batch_preds(batch_reps, ne_reps.view(-1, ne_reps.size(2)),
                                tag2mask, args)[0]  # bsz x T
        if args.acc_eval:
            bpred, bcrct = eval_util.batch_acc_eval(preds, gold)
            bgold = bpred
        else:
            bpred, bgold, bcrct = eval_util.batch_span_eval(preds, gold)
        total_pred += bpred
        total_gold += bgold
        total_crct += bcrct
    microp = total_crct / total_pred if total_pred > 0 else 0
    micror = total_crct / total_gold if total_gold > 0 else 0
    microf1 = 2 * microp * micror / (microp + micror) if (microp + micror) > 0 else 0
    return microp, micror, microf1
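
The helper get_batch_preds is not shown in this listing. Based on the shapes in the comments (batch_reps is bsz x T x dim, the neighbor representations are flattened to (nesz * max_len) x dim, and tag2mask pairs each tag with a mask over the neighbor words), a minimal sketch of nearest-neighbor tag scoring might look like the following; the actual helper may aggregate similarities differently and return tag strings rather than indices.

import torch

def knn_tag_scores(batch_reps, ne_reps_flat, tag2mask):
    # batch_reps: bsz x T x dim word representations for the batch (hypothetical inputs).
    # ne_reps_flat: (nesz * max_len) x dim word representations of the neighbor sentences.
    # tag2mask: list of (tag, mask) pairs; each mask is 1 x (nesz * max_len) and marks
    #           which neighbor words carry that tag.
    sims = torch.matmul(batch_reps, ne_reps_flat.t())    # bsz x T x (nesz * max_len)
    per_tag = []
    for _, mask in tag2mask:
        # score each tag by its best-matching neighbor word; masked-out words get -inf
        masked = sims.masked_fill(mask == 0, float("-inf"))
        per_tag.append(masked.max(dim=2)[0])             # bsz x T
    scores = torch.stack(per_tag, dim=2)                 # bsz x T x ntags
    return scores.argmax(dim=2)                          # bsz x T predicted tag indices
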
Example #2
def do_fscore(sentdb, model, device, idx2tag, args, scripteval=False, labelmap=None):
    """
    Micro-averaged segment-level F1 score.
    """
    model.eval()
    total_pred, total_gold, total_crct = 0.0, 0.0, 0.0
    goldss, predss = [], []
    zero_shot = labelmap is not None
    for i in range(len(sentdb.val_minibatches)):
        x, Cx, tgts = sentdb.pp_word_batch(i, val=True, gold_as_str=zero_shot)
        x, Cx = x.to(device), Cx.to(device)
        if not isinstance(tgts, list):
            tgts = tgts.to(device)
        #bsz, T = x.size()
        mask = (x != 0).long()
        logits = model(x, attention_mask=mask) # bsz x num_wrd_pcs x num_labels
        wlogits = torch.bmm(Cx, logits) # bsz x T x num_labels
        _, preds = wlogits.max(2) # bsz x T
        if not zero_shot:
            gold = [[idx2tag[idx.item()] for idx in row] for row in tgts]
        else:
            gold = tgts
        preds = [[idx2tag[idx.item()] for idx in row] for row in preds]
        if labelmap is not None:
            preds = [[labelmap[lab] for lab in pred] for pred in preds]
        if scripteval:
            goldss.append(gold)
            predss.append(preds)
        else:
            if args.acc_eval:
                bpred, bcrct = eval_util.batch_acc_eval(preds, gold)
                bgold = bpred
            else:
                bpred, bgold, bcrct = eval_util.batch_span_eval(preds, gold)
            total_pred += bpred
            total_gold += bgold
            total_crct += bcrct
    if scripteval:
        acc, prec, rec, f1 = eval_util.run_conll(goldss, predss)
        if args.acc_eval:
            microp, micror, microf1 = acc, acc, acc
        else:
            microp, micror, microf1 = prec, rec, f1
    else:
        microp = total_crct / total_pred if total_pred > 0 else 0
        micror = total_crct / total_gold if total_gold > 0 else 0
        microf1 = 2 * microp * micror / (microp + micror) if (microp + micror) > 0 else 0
    return microp, micror, microf1
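
In this example, Cx appears to act as a word-to-word-piece alignment matrix: torch.bmm(Cx, logits) pools the per-word-piece logits (bsz x num_wrd_pcs x num_labels) into per-word logits (bsz x T x num_labels). A small self-contained sketch with hypothetical sizes (the example itself assumes torch, eval_util, and the module's other helpers are already imported):

import torch

# Hypothetical sizes: 2 sentences, 5 word pieces, 3 words, 4 labels.
bsz, num_wrd_pcs, T, num_labels = 2, 5, 3, 4
logits = torch.randn(bsz, num_wrd_pcs, num_labels)

# Each row of Cx distributes word pieces to one word; here each word averages its pieces.
Cx = torch.zeros(bsz, T, num_wrd_pcs)
Cx[:, 0, 0] = 1.0        # word 0 <- piece 0
Cx[:, 1, 1:3] = 0.5      # word 1 <- average of pieces 1-2
Cx[:, 2, 3:5] = 0.5      # word 2 <- average of pieces 3-4

wlogits = torch.bmm(Cx, logits)    # bsz x T x num_labels
preds = wlogits.max(2)[1]          # bsz x T, as in the example above
print(wlogits.shape, preds.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3])
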
Example #3
def do_single_fscore(sentdb, model, device, args):
    """
    Micro-averaged segment-level F1 score.
    """
    model.eval()
    total_pred, total_gold, total_crct = 0.0, 0.0, 0.0
    total_copies, total_words = 0, 0
    print("predicting on", len(sentdb.vsent_words), "sentences")
    for i in range(len(sentdb.vsent_words)):
        #if i > 20:
        #    break
        if i % 200 == 0:
            print("sent", i)
        x, neighbs, Cx, Cn, tag2mask, gold, ne_tags = sentdb.pred_single_batch(
            i, args.eval_ne_per_sent)
        x, neighbs = x.to(device), neighbs.to(device)
        Cx, Cn = Cx.to(device), Cn.to(device)
        tag2mask = [(tag, mask.to(device)) for (tag, mask) in tag2mask]
        batch_reps = model.get_word_reps(x, Cx)  # 1 x T x dim
        ne_reps = model.get_word_reps(  # nesz x max_len x dim
            neighbs,
            Cn,
            shard_bsz=args.pred_shard_size)
        preds, ncopies = get_batch_preds(batch_reps,
                                         ne_reps.view(-1, ne_reps.size(2)),
                                         tag2mask,
                                         args,
                                         ne_tag_seqs=ne_tags)  # 1 x T
        if ncopies is not None:
            total_copies += ncopies
        if args.acc_eval:
            bpred, bcrct = eval_util.batch_acc_eval(preds, gold)
            bgold = bpred
        else:
            bpred, bgold, bcrct = eval_util.batch_span_eval(preds, gold)
        total_pred += bpred
        total_gold += bgold
        total_crct += bcrct
        total_words += Cx.size(1)
    if args.dp_pred:
        print("avg moves/sent", total_copies / len(sentdb.vsent_words))
        print("avg words/move", total_words / total_copies)
    microp = total_crct / total_pred if total_pred > 0 else 0
    micror = total_crct / total_gold if total_gold > 0 else 0
    microf1 = 2 * microp * micror / (microp + micror) if (microp + micror) > 0 else 0
    return microp, micror, microf1
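
All three examples delegate span counting to eval_util.batch_span_eval, which is not part of the listing. A minimal sketch of what such a helper might compute, assuming BIO-style tag sequences (the real helper may support other tagging schemes and edge cases):

def bio_spans(tags):
    # Collect (start, end_exclusive, label) spans from one BIO tag sequence.
    spans, start, label = set(), None, None
    for i, t in enumerate(tags):
        if t.startswith("I-") and label == t[2:]:
            continue                        # current span continues
        if start is not None:               # close the open span
            spans.add((start, i, label))
            start, label = None, None
        if t != "O":                        # B-X (or a stray I-X) opens a new span
            start, label = i, t[2:]
    if start is not None:                   # close a span that runs to the end
        spans.add((start, len(tags), label))
    return spans

def span_counts(preds, golds):
    # Return (#predicted spans, #gold spans, #correct spans) over a batch.
    npred = ngold = ncrct = 0
    for p, g in zip(preds, golds):
        ps, gs = bio_spans(p), bio_spans(g)
        npred += len(ps)
        ngold += len(gs)
        ncrct += len(ps & gs)
    return npred, ngold, ncrct

Micro precision, recall, and F1 then follow from the three accumulated counts exactly as at the end of each example.
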