def kb_el_eval(ranker, valid_dataloader, params, device, all_cand_enc):
    """Evaluate entity linking against a fixed table of candidate encodings.

    Puts ``ranker.model`` in eval mode, computes context embeddings for every
    batch, scores them against each row of ``all_cand_enc`` with a dot
    product and uses the best-scoring row index as the prediction.  The
    accumulated gold and predicted labels are handed to
    ``utils.get_metrics_result`` and the metrics are printed.

    Args:
        ranker: Object exposing ``model.eval()``,
            ``model.get_raw_ctxt_encoding`` and ``model.get_ctxt_embeds``.
        valid_dataloader: Iterable yielding batches of eight tensors in the
            bi-encoder layout unpacked below.
        params: Mapping; ``params['is_biencoder']`` must be truthy.
        device: Target handed to ``Tensor.to`` for every batch tensor.
        all_cand_enc: 2-D tensor with one candidate encoding per row.  It is
            not moved to ``device`` here -- presumably the caller has already
            placed it there (TODO confirm).

    Returns:
        None.  The metrics are only printed.
    """
    ranker.model.eval()
    y_true, y_pred = [], []

    for raw_batch in valid_dataloader:
        on_device = tuple(tensor.to(device) for tensor in raw_batch)
        assert params['is_biencoder']
        # Only part of the batch is needed here; the full 8-way unpack is
        # kept so that a batch with the wrong arity still fails loudly.
        (token_ids, tags, _cand_enc, _cand_enc_mask, label_ids, label_mask,
         attn_mask, _global_attn_mask) = on_device

        with torch.no_grad():
            # NOTE(review): in_batch_el_eval passes an explicit third
            # ``global_attn_mask=None`` argument to this helper; here the
            # helper's default is used -- confirm the two are equivalent.
            encoding = ranker.model.get_raw_ctxt_encoding(token_ids, attn_mask)
            context_vecs = ranker.model.get_ctxt_embeds(encoding, tags)
            # Rows: context embeddings; columns: rows of all_cand_enc.
            similarity = context_vecs.mm(all_cand_enc.t())

            y_true += label_ids[label_mask].cpu().tolist()
            # The predicted label is the row index of the best candidate, so
            # gold labels are presumably row indices into all_cand_enc too.
            y_pred += similarity.argmax(dim=1).cpu().tolist()

    assert len(y_true) == len(y_pred)
    acc, f1_macro, f1_micro = utils.get_metrics_result(y_true, y_pred)
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )
def ner_eval(ranker, valid_dataloader, params, device, pos_tag=1):
    """Evaluate tag prediction and print overall and per-tag metrics.

    Puts ``ranker.model`` in eval mode, runs ``ranker`` on every batch with
    the global attention mask forced to ``None``, and collects gold and
    predicted tags at the positions selected by ``attn_mask``.  The metrics
    come from ``utils.get_metrics_result`` called with ``b_tag=pos_tag``.

    Args:
        ranker: Callable returning a 3-tuple whose second item holds the
            predicted tags; also exposes ``model.eval()``.
        valid_dataloader: Iterable yielding batches of eight tensors when
            ``params['is_biencoder']`` is truthy, otherwise four.
        params: Mapping providing the ``'is_biencoder'`` flag.
        device: Target handed to ``Tensor.to`` for every batch tensor.
        pos_tag: Tag whose precision / recall / F1 are reported separately.

    Returns:
        None.  The metrics are only printed.
    """
    ranker.model.eval()
    y_true, y_pred = [], []

    for raw_batch in valid_dataloader:
        on_device = tuple(tensor.to(device) for tensor in raw_batch)

        if params['is_biencoder']:
            (token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask,
             attn_mask, global_attn_mask) = on_device
        else:
            token_ids, tags, attn_mask, global_attn_mask = on_device
            # No candidate / label tensors exist outside bi-encoder mode.
            cand_enc = cand_enc_mask = label_ids = label_mask = None

        # evaluate: not leak information about tags
        global_attn_mask = None

        with torch.no_grad():
            _, tags_pred, _ = ranker(
                token_ids,
                attn_mask,
                global_attn_mask,
                tags,
                golden_cand_enc=cand_enc,
                golden_cand_mask=cand_enc_mask,
                label_ids=label_ids,
                label_mask=label_mask,
            )

        # attn_mask is used as an index to pick the positions to compare --
        # assumes it is a boolean mask (TODO confirm).
        y_true += tags[attn_mask].cpu().tolist()
        y_pred += tags_pred[attn_mask].cpu().tolist()

    assert len(y_true) == len(y_pred)
    acc, precision, recall, f1, f1_macro, f1_micro = utils.get_metrics_result(
        y_true, y_pred, b_tag=pos_tag)

    # Overall metrics first, then those of the tag under investigation.
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )
    print(
        f'Tag to investigate is {pos_tag}, metrics: precision {precision:.4f}, recall {recall:.4f}, F1 {f1:.4f}'
    )
def in_batch_el_eval(ranker, valid_dataloader, params, device):
    """Evaluate entity linking using only the candidates present in a batch.

    Puts ``ranker.model`` in eval mode.  For every batch the context
    embeddings are scored (dot product) against the candidate encodings
    selected by ``cand_enc_mask``; the index of the best-scoring candidate is
    translated into a label by looking up the batch's gold label at that same
    position.  Gold and predicted labels over all batches are passed to
    ``utils.get_metrics_result`` and the metrics are printed.

    Args:
        ranker: Object exposing ``model.eval()``,
            ``model.get_raw_ctxt_encoding`` and ``model.get_ctxt_embeds``.
        valid_dataloader: Iterable yielding batches of eight tensors in the
            bi-encoder layout unpacked below.
        params: Mapping; ``params['is_biencoder']`` must be truthy.
        device: Target handed to ``Tensor.to`` for every batch tensor.

    Returns:
        None.  The metrics are only printed.

    Raises:
        IndexError: If a predicted candidate index has no gold label at the
            same position in the batch.
    """
    ranker.model.eval()
    y_true = []
    y_pred = []

    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        assert params['is_biencoder']
        (token_ids, tags, cand_enc, cand_enc_mask, label_ids, label_mask,
         attn_mask, global_attn_mask) = batch

        # evaluate: not leak information about tags
        global_attn_mask = None

        with torch.no_grad():
            raw_ctxt_encoding = ranker.model.get_raw_ctxt_encoding(
                token_ids, attn_mask, global_attn_mask)
            ctxt_embeds = ranker.model.get_ctxt_embeds(raw_ctxt_encoding, tags)

            # Similarity of every context embedding to every in-batch
            # candidate kept by the mask.
            cand_enc = cand_enc[cand_enc_mask]
            scores = ctxt_embeds.mm(cand_enc.t())

        true_labels = label_ids[label_mask].cpu().tolist()
        y_true.extend(true_labels)

        # NOTE(review): a batch with no selected candidates would give
        # ``scores`` an empty second dimension, on which argmax(dim=1)
        # raises -- confirm the dataloader never yields such a batch.
        pred_inds = torch.argmax(scores, dim=1).cpu().tolist()
        # Candidate i is taken to carry the i-th gold label of the batch, so
        # the gold-label list itself serves as the index -> label table.
        y_pred.extend(true_labels[i] for i in pred_inds)

    assert len(y_true) == len(y_pred)
    acc, f1_macro, f1_micro = utils.get_metrics_result(y_true, y_pred)
    print(
        f'Accuracy: {acc:.4f}, F1 macro: {f1_macro:.4f}, F1 micro: {f1_micro:.4f}'
    )