def obtain_sorted_indices(src, tgt_seqs, sort_by):
    """Return an index array that reorders tgt_seqs according to sort_by.

    :param src: source token sequence wrapped in a 1-element list (only the
        first element is used); needed for the 'verbatim*' orderings.
    :param tgt_seqs: list of target phrases, each wrapped in a 1-element list.
    :param sort_by: one of 'no_sort', 'random', 'alphabetical', 'length', or a
        'verbatim*' variant. For verbatim, the suffix controls absent phrases:
        'verbatim_append' puts them after present ones, 'verbatim_prepend'
        before, and plain 'verbatim' drops them.
    :return: np.ndarray of int indices into tgt_seqs.
    :raises ValueError: if sort_by is not one of the supported modes
        (previously an unknown mode fell through to a NameError at return).
    """
    num_tgt = len(tgt_seqs)
    src = src[0]
    tgt_seqs = [tgt[0] for tgt in tgt_seqs]

    if sort_by == 'no_sort':
        sorted_id = list(range(len(tgt_seqs)))
    elif sort_by == 'random':
        sorted_id = np.random.permutation(num_tgt)
    elif sort_by.startswith('verbatim'):
        # obtain present flags as well as their positions; lowercasing should
        # be done beforehand
        present_tgt_flags, present_indices, _ = if_present_duplicate_phrases(
            src, tgt_seqs, stemming=False, lowercase=False)
        # separate present/absent phrases
        present_tgt_idx = np.arange(num_tgt)[present_tgt_flags]
        # use `not present` rather than `~present`: `~True` is -2 (truthy)
        # for plain Python bools, so `~` is only safe on numpy bools
        absent_tgt_idx = [
            t_id for t_id, present in zip(range(num_tgt), present_tgt_flags)
            if not present
        ]
        absent_tgt_idx = np.random.permutation(absent_tgt_idx)
        # sort present phrases by their positions of occurrence in src
        present_indices = present_indices[present_tgt_flags]
        present_tgt_idx = sorted(zip(present_tgt_idx, present_indices),
                                 key=lambda x: x[1])
        present_tgt_idx = [t[0] for t in present_tgt_idx]
        if sort_by.endswith('append'):
            sorted_id = np.concatenate((present_tgt_idx, absent_tgt_idx),
                                       axis=None)
        elif sort_by.endswith('prepend'):
            sorted_id = np.concatenate((absent_tgt_idx, present_tgt_idx),
                                       axis=None)
        else:
            # plain 'verbatim': absent phrases are simply dropped
            sorted_id = present_tgt_idx
    elif sort_by == 'alphabetical':
        sorted_tgts = sorted(enumerate(tgt_seqs), key=lambda x: '_'.join(x[1]))
        sorted_id = [t[0] for t in sorted_tgts]
    elif sort_by == 'length':
        sorted_tgts = sorted(enumerate(tgt_seqs), key=lambda x: len(x[1]))
        sorted_id = [t[0] for t in sorted_tgts]
    else:
        raise ValueError('Unsupported sort_by: %s' % sort_by)
    return np.asarray(sorted_id, dtype=int)
def evaluate(src_list, tgt_list, pred_list, unk_token, logger=None,
             verbose=False, report_path=None, eval_topbeam=False):
    """Score keyphrase predictions against ground truth, one example at a time.

    NOTE(review): a second `evaluate` with a different signature appears later
    in this file and shadows this definition at import time — confirm which
    one callers actually use.

    :param src_list: list of source dicts; each must provide "src" (raw text,
        split on whitespace here), plus "title" and "abstract" for printing.
    :param tgt_list: list of target dicts; each provides "tgt", a list of
        ground-truth keyphrase strings.
    :param pred_list: list of prediction dicts; "pred_sents" (or
        "topseq_pred_sents" when eval_topbeam) plus optional score/index/copy
        keys.
    :param unk_token: token treated as <unk> when filtering invalid phrases.
    :param logger: optional logger used for verbose output.
    :param verbose: if True, print/log the per-example report.
    :param report_path: if given, per-example reports are appended to this file.
    :param eval_topbeam: if True, evaluate the "topseq_*" prediction fields
        instead of the regular beam output.
    :return: dict mapping field names (e.g. 'present_exact_precision@5') to
        lists of per-example scores.
    """
    # progbar = Progbar(logger=logger, title='', target=len(pred_list), total_examples=len(pred_list))
    if report_path:
        report_file = open(report_path, 'w+')
    else:
        report_file = None
    # 'k' means the number of phrases in ground-truth
    topk_range = [5, 10, 'k', 'M']
    absent_topk_range = [10, 30, 50]
    # 'precision_hard' and 'f_score_hard' mean that precision is calculated with denominator strictly as K (say 5 or 10), won't be lessened even number of preds is smaller
    metric_names = [
        'correct', 'precision', 'recall', 'f_score', 'precision_hard',
        'f_score_hard'
    ]
    score_dict = {
    }  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}
    ''' process each example in current batch '''
    for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(
            enumerate(zip(src_list, tgt_list, pred_list))):
        src_seq = src_dict["src"].split()
        tgt_seqs = [t.split() for t in tgt_dict["tgt"]]
        if eval_topbeam:
            # evaluate only the top beam sequence's predictions
            pred_sents = pred_dict["topseq_pred_sents"]
            pred_idxs = pred_dict[
                "topseq_preds"] if "topseq_preds" in pred_dict else None
            pred_scores = pred_dict[
                "topseq_pred_scores"] if "topseq_pred_scores" in pred_dict else None
            copied_flags = None
        else:
            pred_sents = pred_dict["pred_sents"]
            pred_idxs = pred_dict["preds"] if "preds" in pred_dict else None
            pred_scores = pred_dict[
                "pred_scores"] if "pred_scores" in pred_dict else None
            copied_flags = pred_dict[
                "copied_flags"] if "copied_flags" in pred_dict else None
        # src, src_str, tgt, tgt_str_seqs, tgt_copy, pred_seq, oov
        print_out = '====================== %d =========================' % (
            i)
        print_out += '\n[Title]: %s \n' % (src_dict["title"])
        print_out += '[Abstract]: %s \n' % (src_dict["abstract"])
        # print_out += '[Source tokenized][%d]: %s \n' % (len(src_seq), ' '.join(src_seq))
        # print_out += 'Real Target [%d] \n\t\t%s \n' % (len(tgt_seqs), str(tgt_seqs))
        # check which phrases are present in source text
        present_tgt_flags, _, _ = if_present_duplicate_phrases(
            src_seq, tgt_seqs)
        print_out += '[GROUND-TRUTH] #(present)/#(all targets)=%d/%d\n' % (
            sum(present_tgt_flags), len(present_tgt_flags))
        # present targets are rendered in brackets, absent ones plain
        print_out += '\n'.join([
            '\t\t[%s]' % ' '.join(phrase)
            if is_present else '\t\t%s' % ' '.join(phrase)
            for phrase, is_present in zip(tgt_seqs, present_tgt_flags)
        ])
        # 1st filtering, ignore phrases having <unk> and puncs
        valid_pred_flags = process_predseqs(pred_sents, unk_token)
        # 2nd filtering: if filter out phrases that don't appear in text, and keep unique ones after stemming
        present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(
            src_seq, pred_sents)
        # treat duplicates as invalid
        # NOTE(review): `~` is safe only if these flags are numpy bool arrays
        valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(
            valid_pred_flags) > 0 else []
        valid_and_present = valid_pred_flags * present_pred_flags if len(
            valid_pred_flags) > 0 else []
        print_out += '\n[PREDICTION] #(valid)=%d, #(present)=%d, #(retained&present)=%d, #(all)=%d\n' % (
            sum(valid_pred_flags), sum(present_pred_flags),
            sum(valid_and_present), len(pred_sents))
        print_out += ''
        # compute match scores (exact, partial and mixed), for exact it's a list otherwise matrix
        match_scores_exact = get_match_result(true_seqs=tgt_seqs,
                                              pred_seqs=pred_sents,
                                              type='exact')
        match_scores_partial = get_match_result(true_seqs=tgt_seqs,
                                                pred_seqs=pred_sents,
                                                type='ngram')
        # simply add full-text to n-grams might not be good as its contribution is not clear
        match_scores_mixed = get_match_result(true_seqs=tgt_seqs,
                                              pred_seqs=pred_sents,
                                              type='mixed')
        # sanity check of pred (does not work for eval_topbeam, discard)
        # num_pred = len(pred_dict["pred_sents"])
        # for d_name, d in zip(['pred_idxs', 'pred_sents', 'pred_scores',
        #                       'match_scores_exact', 'valid_pred_flags',
        #                       'present_pred_flags', 'copied_flags'],
        #                      [pred_idxs, pred_sents, pred_scores,
        #                       match_scores_exact, valid_pred_flags,
        #                       present_pred_flags, copied_flags]):
        #     if d is not None:
        #         if len(d) != num_pred:
        #             logger.error('%s number does not match' % d_name)
        #         assert len(d) == num_pred
        ''' Print and export predictions '''
        # NOTE(review): preds_out is accumulated below but never used or
        # written anywhere in this function
        preds_out = ''
        for p_id, (word, match, match_soft, is_valid, is_present) in enumerate(
                zip(pred_sents, match_scores_exact, match_scores_partial,
                    valid_pred_flags, present_pred_flags)):
            # if p_id > 5:
            #     break
            score = pred_scores[p_id] if pred_scores else "Score N/A"
            pred_idx = pred_idxs[p_id] if pred_idxs else "Index N/A"
            copied_flag = copied_flags[p_id] if copied_flags else "CopyFlag N/A"
            preds_out += '%s\n' % (' '.join(word))
            if is_present:
                print_phrase = '[%s]' % ' '.join(word)
            else:
                print_phrase = ' '.join(word)
            if match == 1.0:
                correct_str = '[correct!]'
            else:
                correct_str = ''
            if any(copied_flag):
                copy_str = '[copied!]'
            else:
                copy_str = ''
            # scores are negative log-likelihoods here, hence the negation
            pred_str = '\t\t%s\t%s \t %s %s%s\n' % (
                '[%.4f]' % (-score) if pred_scores else "Score N/A",
                print_phrase, str(pred_idx), correct_str, copy_str)
            if not is_valid:
                # invalid predictions are indented one extra level
                pred_str = '\t%s' % pred_str
            print_out += pred_str
        # split tgts by present/absent
        present_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if present
        ]
        absent_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
            if ~present
        ]
        # filter out results of invalid preds
        valid_pred_sents = [
            seq for seq, valid in zip(pred_sents, valid_pred_flags) if valid
        ]
        present_pred_flags = present_pred_flags[valid_pred_flags]
        match_scores_exact = match_scores_exact[valid_pred_flags]
        match_scores_partial = match_scores_partial[valid_pred_flags]
        match_scores_mixed = match_scores_mixed[valid_pred_flags]
        # split preds by present/absent and exact/partial/mixed
        present_preds = [
            pred for pred, present in zip(valid_pred_sents, present_pred_flags)
            if present
        ]
        absent_preds = [
            pred for pred, present in zip(valid_pred_sents, present_pred_flags)
            if ~present
        ]
        if len(present_pred_flags) > 0:
            # partial/mixed score matrices are additionally column-sliced to
            # the matching (present or absent) subset of targets
            present_exact_match_scores = match_scores_exact[present_pred_flags]
            present_partial_match_scores = match_scores_partial[
                present_pred_flags][:, present_tgt_flags]
            present_mixed_match_scores = match_scores_mixed[
                present_pred_flags][:, present_tgt_flags]
            absent_exact_match_scores = match_scores_exact[~present_pred_flags]
            absent_partial_match_scores = match_scores_partial[
                ~present_pred_flags][:, ~present_tgt_flags]
            absent_mixed_match_scores = match_scores_mixed[
                ~present_pred_flags][:, ~present_tgt_flags]
        else:
            present_exact_match_scores = []
            present_partial_match_scores = []
            present_mixed_match_scores = []
            absent_exact_match_scores = []
            absent_partial_match_scores = []
            absent_mixed_match_scores = []
        # assert len(valid_pred_sents) == len(match_scores_exact) == len(present_pred_flags)
        # assert len(present_preds) == len(present_exact_match_scores) == len(present_partial_match_scores) == len(present_mixed_match_scores)
        # assert present_partial_match_scores.shape == present_mixed_match_scores.shape
        # assert len(absent_preds) == len(absent_exact_match_scores) == len(absent_partial_match_scores) == len(absent_mixed_match_scores)
        # assert absent_partial_match_scores.shape == absent_mixed_match_scores.shape
        # Compute metrics
        print_out += "\n ======================================================="
        # get the scores on different scores (for absent results, only recall matters)
        present_exact_results = run_metrics(present_exact_match_scores,
                                            present_preds, present_tgts,
                                            metric_names, topk_range)
        absent_exact_results = run_metrics(absent_exact_match_scores,
                                           absent_preds, absent_tgts,
                                           metric_names, absent_topk_range)
        present_partial_results = run_metrics(present_partial_match_scores,
                                              present_preds, present_tgts,
                                              metric_names, topk_range,
                                              type='partial')
        absent_partial_results = run_metrics(absent_partial_match_scores,
                                             absent_preds, absent_tgts,
                                             metric_names, absent_topk_range,
                                             type='partial')
        # present_mixed_results = run_metrics(present_mixed_match_scores, present_preds, present_tgts, metric_names, topk_range, type='partial')
        # absent_mixed_results = run_metrics(absent_mixed_match_scores, absent_preds, absent_tgts, metric_names, absent_topk_range, type='partial')

        def _gather_scores(gathered_scores, results_names, results_dicts):
            # Append each per-example score into the accumulating score_dict,
            # prefixing metric names with their result-group name.
            for result_name, result_dict in zip(results_names, results_dicts):
                for metric_name, score in result_dict.items():
                    if metric_name.endswith('_num'):
                        # if it's 'present_tgt_num' or 'absent_tgt_num', leave as is
                        field_name = result_name
                    else:
                        # if it's other score like 'precision@5' is renamed to like 'present_exact_precision@'
                        field_name = result_name + '_' + metric_name
                    if field_name not in gathered_scores:
                        gathered_scores[field_name] = []
                    gathered_scores[field_name].append(score)
            return gathered_scores

        results_names = [
            'present_exact', 'absent_exact', 'present_partial',
            'absent_partial',
            # 'present_mixed', 'absent_mixed'
        ]
        results_list = [
            present_exact_results, absent_exact_results,
            present_partial_results, absent_partial_results,
            # present_mixed_results, absent_mixed_results
        ]
        # update score_dict, appending new scores (results_list) to it
        score_dict = _gather_scores(score_dict, results_names, results_list)
        # NOTE(review): 'resutls' is a typo for 'results' (kept as-is);
        # '{:2f}' below is min-width 2 with default precision, probably
        # intended to be '{:.2f}' — confirm before changing output format
        for name, resutls in zip(results_names, results_list):
            if name.startswith('present'):
                topk = 5
            else:
                topk = 50
            print_out += "\n --- batch {} P/R/F1/Corr @{}: \t".format(name, topk) \
                + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(resutls['precision@{}'.format(topk)],
                                                             resutls['recall@{}'.format(topk)],
                                                             resutls['f_score@{}'.format(topk)],
                                                             resutls['correct@{}'.format(topk)])
            print_out += "\n --- total {} P/R/F1/Corr @{}: \t".format(name, topk) \
                + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(np.average(score_dict['{}_precision@{}'.format(name, topk)]),
                                                             np.average(score_dict['{}_recall@{}'.format(name, topk)]),
                                                             np.average(score_dict['{}_f_score@{}'.format(name, topk)]),
                                                             np.sum(score_dict['{}_correct@{}'.format(name, topk)]))
            if name.startswith('present'):
                # present groups get a second report at top-10
                topk = 10
                print_out += "\n --- batch {} P/R/F1/Corr @{}: \t".format(name, topk) \
                    + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(resutls['precision@{}'.format(topk)],
                                                                 resutls['recall@{}'.format(topk)],
                                                                 resutls['f_score@{}'.format(topk)],
                                                                 resutls['correct@{}'.format(topk)])
                print_out += "\n --- total {} P/R/F1/Corr @{}: \t".format(name, topk) \
                    + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(np.average(score_dict['{}_precision@{}'.format(name, topk)]),
                                                                 np.average(score_dict['{}_recall@{}'.format(name, topk)]),
                                                                 np.average(score_dict['{}_f_score@{}'.format(name, topk)]),
                                                                 np.sum(score_dict['{}_correct@{}'.format(name, topk)]))
        print_out += "\n ======================================================="
        if verbose:
            if logger:
                logger.info(print_out)
            else:
                print(print_out)
        if report_file:
            report_file.write(print_out)
        # add tgt/pred count for computing average performance on non-empty items
        results_names = [
            'present_tgt_num', 'absent_tgt_num', 'present_pred_num',
            'absent_pred_num', 'unique_pred_num', 'dup_pred_num', 'beam_num',
            'beamstep_num'
        ]
        results_list = [
            {
                'present_tgt_num': len(present_tgts)
            },
            {
                'absent_tgt_num': len(absent_tgts)
            },
            {
                'present_pred_num': len(present_preds)
            },
            {
                'absent_pred_num': len(absent_preds)
            },
            {
                'unique_pred_num':
                pred_dict['unique_pred_num']
                if 'unique_pred_num' in pred_dict else 0
            },
            {
                'dup_pred_num':
                pred_dict['dup_pred_num'] if 'dup_pred_num' in pred_dict else 0
            },
            {
                'beam_num':
                pred_dict['beam_num'] if 'beam_num' in pred_dict else 0
            },
            {
                'beamstep_num':
                pred_dict['beamstep_num'] if 'beamstep_num' in pred_dict else 0
            },
        ]
        score_dict = _gather_scores(score_dict, results_names, results_list)
    # for k, v in score_dict.items():
    #     print('%s, num=%d, mean=%f' % (k, len(v), np.average(v)))
    if report_file:
        report_file.close()
    return score_dict
def eval_and_print(src_text, tgt_kps, pred_kps, pred_scores, unk_token='<unk>'):
    """Evaluate one example's predicted keyphrases and return a printable report.

    :param src_text: raw source text.
    :param tgt_kps: list of ground-truth keyphrase strings.
    :param pred_kps: list of predicted keyphrase strings.
    :param pred_scores: per-prediction scores, forwarded to the report printer.
    :param unk_token: token treated as <unk> when filtering predictions.
    :return: the report string produced by print_predeval_result.
    """
    def _lower_tokens(text):
        # spacy tokenization, lowercased; textcat pipe disabled
        return [tok.text.lower() for tok in spacy_nlp(text, disable=["textcat"])]

    src_tokens = _lower_tokens(src_text)
    tgt_tokens = [_lower_tokens(kp) for kp in tgt_kps]
    pred_tokens = [_lower_tokens(kp) for kp in pred_kps]

    topk_range = ['k', 10]
    absent_topk_range = ['M']
    metric_names = ['f_score']

    # 1st filtering, ignore phrases having <unk> and puncs
    keep_flags = validate_phrases(pred_tokens, unk_token)
    # 2nd filtering: locate predictions in the source text and mark
    # stem-level duplicates
    pred_in_src_flags, _, dup_flags = if_present_duplicate_phrases(
        src_tokens, pred_tokens)
    # duplicates count as invalid
    if len(keep_flags) > 0:
        keep_flags = keep_flags * ~dup_flags
        keep_and_present_flags = keep_flags * pred_in_src_flags
        keep_and_absent_flags = keep_flags * ~pred_in_src_flags
    else:
        keep_flags = []
        keep_and_present_flags = []
        keep_and_absent_flags = []

    # exact-match matrix between (lowercased, stemmed) targets and predictions
    exact_matrix = compute_match_scores(tgt_seqs=tgt_tokens,
                                        pred_seqs=pred_tokens,
                                        do_lower=True,
                                        do_stem=True,
                                        type='exact')

    # split targets into present/absent w.r.t. the source text
    tgt_in_src_flags, _, _ = if_present_duplicate_phrases(src_tokens, tgt_tokens)
    present_tgts = [t for t, flag in zip(tgt_tokens, tgt_in_src_flags) if flag]
    absent_tgts = [t for t, flag in zip(tgt_tokens, tgt_in_src_flags) if ~flag]

    # keep only predictions that survived both filters
    kept_preds = [p for p, keep in zip(pred_tokens, keep_flags) if keep]
    kept_in_src_flags = pred_in_src_flags[keep_flags]
    kept_exact_matrix = exact_matrix[keep_flags]

    # split surviving predictions into present/absent groups
    kept_present_preds = [
        p for p, flag in zip(kept_preds, kept_in_src_flags) if flag
    ]
    kept_absent_preds = [
        p for p, flag in zip(kept_preds, kept_in_src_flags) if ~flag
    ]
    present_exact_scores = kept_exact_matrix[kept_in_src_flags]
    absent_exact_scores = kept_exact_matrix[~kept_in_src_flags]

    all_exact_results = run_classic_metrics(kept_exact_matrix, kept_preds,
                                            tgt_tokens, metric_names,
                                            topk_range)
    present_exact_results = run_classic_metrics(present_exact_scores,
                                                kept_present_preds,
                                                present_tgts, metric_names,
                                                topk_range)
    absent_exact_results = run_classic_metrics(absent_exact_scores,
                                               kept_absent_preds, absent_tgts,
                                               metric_names,
                                               absent_topk_range)

    eval_results_names = ['all_exact', 'present_exact', 'absent_exact']
    eval_results_list = [
        all_exact_results, present_exact_results, absent_exact_results
    ]

    return print_predeval_result(src_text, tgt_tokens, tgt_in_src_flags,
                                 pred_tokens, pred_scores, pred_in_src_flags,
                                 keep_flags, keep_and_present_flags,
                                 keep_and_absent_flags, exact_matrix,
                                 eval_results_names, eval_results_list)
def evaluate(src_list, tgt_list, pred_list, unk_token, logger=None,
             verbose=False, report_path=None, tokenizer=None):
    """Score keyphrase predictions against ground truth (tokenizer-aware version).

    NOTE(review): this shadows the earlier `evaluate` defined above in this
    file — the earlier definition is unreachable by name after import.

    :param src_list: list of source dicts; each must provide "src" (raw text).
    :param tgt_list: list of target dicts; each provides "tgt", a list of
        ground-truth keyphrase strings.
    :param pred_list: list of prediction dicts with "pred_sents" plus optional
        score/index/copy keys.
    :param unk_token: token treated as <unk> when filtering; overridden to
        'unk' for the 'spacy' and 'split_nopunc' tokenizers.
    :param logger: optional logger used for verbose output.
    :param verbose: if True, print/log the per-example report.
    :param report_path: if given, per-example reports are written to this file.
    :param tokenizer: one of 'spacy', 'split', 'split_nopunc'; anything else
        raises.
    :return: dict mapping field names to lists of per-example scores.
    :raises Exception: if tokenizer is unset or unsupported.
    """
    if report_path:
        report_file = open(report_path, 'w+')
    else:
        report_file = None
    # 'k' means the number of phrases in ground-truth, add 1,3 for openkp
    topk_range = [5, 10, 'k', 'M', 1, 3]
    absent_topk_range = [10, 50, 'k', 'M']
    # 'precision_hard' and 'f_score_hard' mean that precision is calculated with denominator strictly as K (say 5 or 10), won't be lessened even number of preds is smaller
    metric_names = [
        'correct', 'precision', 'recall', 'f_score', 'precision_hard',
        'f_score_hard'
    ]
    # NOTE(review): individual_score_dicts is never appended to (the append
    # below is commented out) — it is always returned empty/unused
    individual_score_dicts = [
    ]  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}
    gathered_score_dict = {
    }  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}
    # for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(enumerate(zip(src_list, tgt_list, pred_list))):
    for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(
            enumerate(zip(src_list, tgt_list, pred_list))):
        """ 1. Process each data example and predictions """
        pred_seqs = pred_dict["pred_sents"]
        pred_idxs = pred_dict["preds"] if "preds" in pred_dict else None
        pred_scores = pred_dict[
            "pred_scores"] if "pred_scores" in pred_dict else None
        copied_flags = pred_dict[
            "copied_flags"] if "copied_flags" in pred_dict else None
        # @memray 20200316 change to spacy tokenization rather than simple splitting or Meng's tokenization
        if tokenizer == 'spacy':
            src_seq = [
                t.text
                for t in spacy_nlp(src_dict["src"], disable=["textcat"])
            ]
            tgt_seqs = [[t.text for t in spacy_nlp(p, disable=["textcat"])]
                        for p in tgt_dict["tgt"]]
            pred_seqs = [[
                t.text for t in spacy_nlp(' '.join(p), disable=["textcat"])
            ] for p in pred_seqs]
            unk_token = 'unk'
        elif tokenizer == 'split':
            src_seq = src_dict["src"].split()
            tgt_seqs = [t.split() for t in tgt_dict["tgt"]]
            pred_seqs = pred_dict["pred_sents"]
        elif tokenizer == 'split_nopunc':
            # split on non-word characters and drop empty tokens
            src_seq = [
                t for t in re.split(r'\W', src_dict["src"]) if len(t) > 0
            ]
            tgt_seqs = [[t for t in re.split(r'\W', p) if len(t) > 0]
                        for p in tgt_dict["tgt"]]
            pred_seqs = [[
                t for t in re.split(r'\W', ' '.join(p)) if len(t) > 0
            ] for p in pred_dict["pred_sents"]]
            unk_token = 'unk'
        else:
            raise Exception(
                'Unset or unsupported tokenizer for evaluation: %s' %
                str(tokenizer))
        # 1st filtering, ignore phrases having <unk> and puncs
        valid_pred_flags = validate_phrases(pred_seqs, unk_token)
        # 2nd filtering: filter out phrases that don't appear in text, and keep unique ones after stemming
        present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(
            src_seq, pred_seqs, stemming=True, lowercase=True)
        # treat duplicates as invalid
        # NOTE(review): `~` is safe only if these flags are numpy bool arrays
        valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(
            valid_pred_flags) > 0 else []
        valid_and_present_flags = valid_pred_flags * present_pred_flags if len(
            valid_pred_flags) > 0 else []
        valid_and_absent_flags = valid_pred_flags * ~present_pred_flags if len(
            valid_pred_flags) > 0 else []
        # compute match scores (exact, partial and mixed), for exact it's a list otherwise matrix
        match_scores_exact = compute_match_scores(tgt_seqs=tgt_seqs,
                                                  pred_seqs=pred_seqs,
                                                  do_lower=True,
                                                  do_stem=True,
                                                  type='exact')
        match_scores_partial = compute_match_scores(tgt_seqs=tgt_seqs,
                                                    pred_seqs=pred_seqs,
                                                    do_lower=True,
                                                    do_stem=True,
                                                    type='ngram')
        # simply add full-text to n-grams might not be good as its contribution is not clear
        # match_scores_mixed = compute_match_scores(tgt_seqs=tgt_seqs, pred_seqs=pred_seqs, type='mixed')
        # split tgts by present/absent
        present_tgt_flags, _, _ = if_present_duplicate_phrases(src_seq,
                                                               tgt_seqs,
                                                               stemming=True,
                                                               lowercase=True)
        present_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if present
        ]
        absent_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
            if ~present
        ]
        # filter out results of invalid preds
        valid_preds = [
            seq for seq, valid in zip(pred_seqs, valid_pred_flags) if valid
        ]
        valid_present_pred_flags = present_pred_flags[valid_pred_flags]
        valid_match_scores_exact = match_scores_exact[valid_pred_flags]
        valid_match_scores_partial = match_scores_partial[valid_pred_flags]
        # match_scores_mixed = match_scores_mixed[valid_pred_flags]
        # split preds by present/absent and exact/partial/mixed
        valid_present_preds = [
            pred for pred, present in zip(valid_preds,
                                          valid_present_pred_flags) if present
        ]
        valid_absent_preds = [
            pred for pred, present in zip(valid_preds,
                                          valid_present_pred_flags)
            if ~present
        ]
        if len(valid_present_pred_flags) > 0:
            # partial score matrices are column-sliced to the matching
            # (present or absent) subset of targets
            present_exact_match_scores = valid_match_scores_exact[
                valid_present_pred_flags]
            present_partial_match_scores = valid_match_scores_partial[
                valid_present_pred_flags][:, present_tgt_flags]
            # present_mixed_match_scores = match_scores_mixed[present_pred_flags][:, present_tgt_flags]
            absent_exact_match_scores = valid_match_scores_exact[
                ~valid_present_pred_flags]
            absent_partial_match_scores = valid_match_scores_partial[
                ~valid_present_pred_flags][:, ~present_tgt_flags]
            # absent_mixed_match_scores = match_scores_mixed[~present_pred_flags][:, ~present_tgt_flags]
        else:
            present_exact_match_scores = []
            present_partial_match_scores = []
            # present_mixed_match_scores = []
            absent_exact_match_scores = []
            absent_partial_match_scores = []
            # absent_mixed_match_scores = []
        # assert len(valid_pred_seqs) == len(match_scores_exact) == len(present_pred_flags)
        # assert len(present_preds) == len(present_exact_match_scores) == len(present_partial_match_scores) == len(present_mixed_match_scores)
        # assert present_partial_match_scores.shape == present_mixed_match_scores.shape
        # assert len(absent_preds) == len(absent_exact_match_scores) == len(absent_partial_match_scores) == len(absent_mixed_match_scores)
        # assert absent_partial_match_scores.shape == absent_mixed_match_scores.shape
        """ 2. Compute metrics """
        # get the scores on different scores (for absent results, only recall matters)
        all_exact_results = run_classic_metrics(valid_match_scores_exact,
                                                valid_preds, tgt_seqs,
                                                metric_names, topk_range)
        present_exact_results = run_classic_metrics(present_exact_match_scores,
                                                    valid_present_preds,
                                                    present_tgts, metric_names,
                                                    topk_range)
        absent_exact_results = run_classic_metrics(absent_exact_match_scores,
                                                   valid_absent_preds,
                                                   absent_tgts, metric_names,
                                                   absent_topk_range)
        all_partial_results = run_classic_metrics(valid_match_scores_partial,
                                                  valid_preds, tgt_seqs,
                                                  metric_names, topk_range,
                                                  type='partial')
        present_partial_results = run_classic_metrics(
            present_partial_match_scores, valid_present_preds, present_tgts,
            metric_names, topk_range, type='partial')
        absent_partial_results = run_classic_metrics(
            absent_partial_match_scores, valid_absent_preds, absent_tgts,
            metric_names, absent_topk_range, type='partial')
        # present_mixed_results = run_metrics(present_mixed_match_scores, present_preds, present_tgts, metric_names, topk_range, type='partial')
        # absent_mixed_results = run_metrics(absent_mixed_match_scores, absent_preds, absent_tgts, metric_names, absent_topk_range, type='partial')
        all_exact_advanced_results = run_advanced_metrics(
            valid_match_scores_exact, valid_preds, tgt_seqs)
        present_exact_advanced_results = run_advanced_metrics(
            present_exact_match_scores, valid_present_preds, present_tgts)
        absent_exact_advanced_results = run_advanced_metrics(
            absent_exact_match_scores, valid_absent_preds, absent_tgts)
        # print(advanced_present_exact_results)
        # print(advanced_absent_exact_results)
        """ 3. Gather scores """
        eval_results_names = [
            'all_exact', 'all_partial', 'present_exact', 'absent_exact',
            'present_partial', 'absent_partial',
            # 'present_mixed', 'absent_mixed'
            'all_exact_advanced', 'present_exact_advanced',
            'absent_exact_advanced',
        ]
        eval_results_list = [
            all_exact_results, all_partial_results, present_exact_results,
            absent_exact_results, present_partial_results,
            absent_partial_results,
            # present_mixed_results, absent_mixed_results
            all_exact_advanced_results, present_exact_advanced_results,
            absent_exact_advanced_results
        ]
        # update score_dict, appending new scores (results_list) to it
        individual_score_dict = {
            result_name: results
            for result_name, results in zip(eval_results_names,
                                            eval_results_list)
        }
        gathered_score_dict = gather_scores(gathered_score_dict,
                                            eval_results_names,
                                            eval_results_list)
        # add tgt/pred count for computing average performance on non-empty items
        stats_results_names = [
            'present_tgt_num', 'absent_tgt_num', 'present_pred_num',
            'absent_pred_num', 'unique_pred_num', 'dup_pred_num', 'beam_num',
            'beamstep_num'
        ]
        stats_results_list = [
            {
                'present_tgt_num': len(present_tgts)
            },
            {
                'absent_tgt_num': len(absent_tgts)
            },
            {
                'present_pred_num': len(valid_present_preds)
            },
            {
                'absent_pred_num': len(valid_absent_preds)
            },
            # TODO some stat should be calculated here since exhaustive/self-terminating makes difference
            {
                'unique_pred_num':
                pred_dict['unique_pred_num']
                if 'unique_pred_num' in pred_dict else 0
            },
            {
                'dup_pred_num':
                pred_dict['dup_pred_num'] if 'dup_pred_num' in pred_dict else 0
            },
            {
                'beam_num':
                pred_dict['beam_num'] if 'beam_num' in pred_dict else 0
            },
            {
                'beamstep_num':
                pred_dict['beamstep_num'] if 'beamstep_num' in pred_dict else 0
            },
        ]
        for result_name, result_dict in zip(stats_results_names,
                                            stats_results_list):
            individual_score_dict[result_name] = result_dict[result_name]
        gathered_score_dict = gather_scores(gathered_score_dict,
                                            stats_results_names,
                                            stats_results_list)
        # individual_score_dicts.append(individual_score_dict)
        """ 4. Print results if necessary """
        if verbose or report_file:
            print_out = print_predeval_result(
                i, ' '.join(src_seq), tgt_seqs, present_tgt_flags, pred_seqs,
                pred_scores, pred_idxs, copied_flags, present_pred_flags,
                valid_pred_flags, valid_and_present_flags,
                valid_and_absent_flags, match_scores_exact,
                match_scores_partial, eval_results_names, eval_results_list,
                gathered_score_dict)
            if verbose:
                if logger:
                    logger.info(print_out)
                else:
                    print(print_out)
            if report_file:
                report_file.write(print_out)
    # for k, v in score_dict.items():
    #     print('%s, num=%d, mean=%f' % (k, len(v), np.average(v)))
    if report_file:
        report_file.close()
    return gathered_score_dict