def obtain_sorted_indices(src, tgt_seqs, sort_by):
    """
    Return the order in which the target phrases should be arranged.

    :param src: sequence whose first element is the tokenized source text;
        only used by the 'verbatim*' strategies (to locate present phrases).
        Lowercasing must be done beforehand.
    :param tgt_seqs: list of targets; element 0 of each entry is the token
        sequence of that target phrase.
    :param sort_by: ordering strategy, one of
        - 'no_sort': keep the original order;
        - 'random': a random permutation;
        - 'verbatim...append': present phrases ordered by their position in
          ``src``, followed by the absent phrases in random order;
        - 'verbatim...prepend': absent phrases (random order) followed by the
          present phrases ordered by position;
        - any other 'verbatim...' value: present phrases only, ordered by
          position (absent phrases are dropped);
        - 'alphabetical': sort by the '_'-joined tokens of each phrase;
        - 'length': sort by number of tokens (stable).
    :return: 1-D integer numpy array of indices into ``tgt_seqs``.
    :raises ValueError: if ``sort_by`` is not a supported strategy.
    """
    num_tgt = len(tgt_seqs)
    src = src[0]
    tgt_seqs = [tgt[0] for tgt in tgt_seqs]

    if sort_by == 'no_sort':
        sorted_id = list(range(num_tgt))
    elif sort_by == 'random':
        sorted_id = np.random.permutation(num_tgt)
    elif sort_by.startswith('verbatim'):
        # obtain present flags as well as their positions, lowercase should be done beforehand
        present_tgt_flags, present_indices, _ = if_present_duplicate_phrases(
            src, tgt_seqs, stemming=False, lowercase=False)
        # normalise to numpy so boolean masking works for lists as well
        present_tgt_flags = np.asarray(present_tgt_flags, dtype=bool)
        present_indices = np.asarray(present_indices)
        # separate present/absent phrases; `not` is correct for both Python
        # and numpy booleans (bitwise `~True` would be -2, i.e. truthy)
        present_tgt_idx = np.arange(num_tgt)[present_tgt_flags]
        absent_tgt_idx = [
            t_id for t_id, present in enumerate(present_tgt_flags)
            if not present
        ]
        absent_tgt_idx = np.random.permutation(absent_tgt_idx)
        # sort present phrases by their positions in the source (stable)
        present_indices = present_indices[present_tgt_flags]
        present_tgt_idx = sorted(zip(present_tgt_idx, present_indices),
                                 key=lambda x: x[1])
        present_tgt_idx = [t[0] for t in present_tgt_idx]

        if sort_by.endswith('append'):
            sorted_id = np.concatenate((present_tgt_idx, absent_tgt_idx),
                                       axis=None)
        elif sort_by.endswith('prepend'):
            sorted_id = np.concatenate((absent_tgt_idx, present_tgt_idx),
                                       axis=None)
        else:
            sorted_id = present_tgt_idx

    elif sort_by == 'alphabetical':
        sorted_tgts = sorted(enumerate(tgt_seqs), key=lambda x: '_'.join(x[1]))
        sorted_id = [t[0] for t in sorted_tgts]
    elif sort_by == 'length':
        sorted_tgts = sorted(enumerate(tgt_seqs), key=lambda x: len(x[1]))
        sorted_id = [t[0] for t in sorted_tgts]
    else:
        # previously fell through to an UnboundLocalError on `sorted_id`
        raise ValueError('Unsupported sort_by: %s' % str(sort_by))

    return np.asarray(sorted_id, dtype=int)
def evaluate(src_list,
             tgt_list,
             pred_list,
             unk_token,
             logger=None,
             verbose=False,
             report_path=None,
             eval_topbeam=False):
    """
    Score keyphrase predictions against the ground truth, example by example.

    For every (source, target, prediction) triple, predictions are filtered
    (phrases rejected by ``process_predseqs`` and duplicates are treated as
    invalid), split into present/absent phrases relative to the source text,
    and scored with exact and partial (n-gram) matching via ``run_metrics``.

    :param src_list: list of dicts with at least "src" (whitespace-tokenized
        text), "title" and "abstract" keys.
    :param tgt_list: list of dicts whose "tgt" value is a list of
        whitespace-tokenized target phrase strings.
    :param pred_list: list of prediction dicts. "pred_sents" (or
        "topseq_pred_sents" when ``eval_topbeam``) is required; index, score,
        copy-flag and count fields are optional.
    :param unk_token: unknown-word token, forwarded to ``process_predseqs``.
    :param logger: optional logger used for the per-example report.
    :param verbose: if True, emit the per-example report via ``logger``
        (or ``print`` when no logger is given).
    :param report_path: optional path; the per-example report is written there.
    :param eval_topbeam: evaluate the "topseq_*" fields instead of the beam.
    :return: dict mapping metric names (e.g. 'present_exact_precision@5',
        'present_tgt_num') to lists holding one value per example.
    """
    # 'k' means the number of phrases in ground-truth
    topk_range = [5, 10, 'k', 'M']
    absent_topk_range = [10, 30, 50]
    # 'precision_hard' and 'f_score_hard' mean that precision is calculated with
    # denominator strictly as K (say 5 or 10), won't be lessened even if the
    # number of preds is smaller
    metric_names = [
        'correct', 'precision', 'recall', 'f_score', 'precision_hard',
        'f_score_hard'
    ]

    def _gather_scores(gathered_scores, results_names, results_dicts):
        # Append every metric of every result dict to `gathered_scores`.
        # Metrics like 'precision@5' are stored as '<result_name>_precision@5'.
        for result_name, result_dict in zip(results_names, results_dicts):
            for metric_name, score in result_dict.items():
                field_name = result_name + '_' + metric_name
                gathered_scores.setdefault(field_name, []).append(score)
        return gathered_scores

    def _format_scores(name, topk, results, gathered_scores):
        # One line for the current example ("batch") and one for the running
        # average / sum over all examples gathered so far ("total").
        batch_line = "\n --- batch {} P/R/F1/Corr @{}: \t".format(name, topk) \
            + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(
                results['precision@{}'.format(topk)],
                results['recall@{}'.format(topk)],
                results['f_score@{}'.format(topk)],
                results['correct@{}'.format(topk)])
        total_line = "\n --- total {} P/R/F1/Corr @{}: \t".format(name, topk) \
            + " {:.4f} , {:.4f} , {:.4f} , {:2f}".format(
                np.average(gathered_scores['{}_precision@{}'.format(name, topk)]),
                np.average(gathered_scores['{}_recall@{}'.format(name, topk)]),
                np.average(gathered_scores['{}_f_score@{}'.format(name, topk)]),
                np.sum(gathered_scores['{}_correct@{}'.format(name, topk)]))
        return batch_line + total_line

    results_names = [
        'present_exact',
        'absent_exact',
        'present_partial',
        'absent_partial',
    ]

    # {'present_exact_precision@5': [...], ..., 'present_tgt_num': [...]}
    score_dict = {}

    report_file = open(report_path, 'w+') if report_path else None
    try:
        for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(
                enumerate(zip(src_list, tgt_list, pred_list))):
            src_seq = src_dict["src"].split()
            tgt_seqs = [t.split() for t in tgt_dict["tgt"]]

            if eval_topbeam:
                pred_sents = pred_dict["topseq_pred_sents"]
                pred_idxs = pred_dict.get("topseq_preds")
                pred_scores = pred_dict.get("topseq_pred_scores")
                copied_flags = None
            else:
                pred_sents = pred_dict["pred_sents"]
                pred_idxs = pred_dict.get("preds")
                pred_scores = pred_dict.get("pred_scores")
                copied_flags = pred_dict.get("copied_flags")

            print_out = '======================  %d =========================' % (
                i)
            print_out += '\n[Title]: %s \n' % (src_dict["title"])
            print_out += '[Abstract]: %s \n' % (src_dict["abstract"])

            # check which phrases are present in source text
            present_tgt_flags, _, _ = if_present_duplicate_phrases(
                src_seq, tgt_seqs)

            print_out += '[GROUND-TRUTH] #(present)/#(all targets)=%d/%d\n' % (
                sum(present_tgt_flags), len(present_tgt_flags))
            print_out += '\n'.join([
                '\t\t[%s]' % ' '.join(phrase) if is_present else
                '\t\t%s' % ' '.join(phrase)
                for phrase, is_present in zip(tgt_seqs, present_tgt_flags)
            ])

            # 1st filtering, ignore phrases having <unk> and puncs
            valid_pred_flags = process_predseqs(pred_sents, unk_token)
            # 2nd filtering: find phrases present in the text and flag duplicates
            present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(
                src_seq, pred_sents)
            # treat duplicates as invalid
            valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(
                valid_pred_flags) > 0 else []
            valid_and_present = valid_pred_flags * present_pred_flags if len(
                valid_pred_flags) > 0 else []
            print_out += '\n[PREDICTION] #(valid)=%d, #(present)=%d, #(retained&present)=%d, #(all)=%d\n' % (
                sum(valid_pred_flags), sum(present_pred_flags),
                sum(valid_and_present), len(pred_sents))

            # compute match scores (exact and partial); for exact it's a list,
            # otherwise a matrix. Mixed scores were never used in the metrics.
            match_scores_exact = get_match_result(true_seqs=tgt_seqs,
                                                  pred_seqs=pred_sents,
                                                  type='exact')
            match_scores_partial = get_match_result(true_seqs=tgt_seqs,
                                                    pred_seqs=pred_sents,
                                                    type='ngram')

            # Print predictions
            for p_id, (word, match, _match_soft, is_valid, is_present) in enumerate(
                    zip(pred_sents, match_scores_exact, match_scores_partial,
                        valid_pred_flags, present_pred_flags)):
                score_str = '[%.4f]' % (
                    -pred_scores[p_id]) if pred_scores else "Score N/A"
                pred_idx = pred_idxs[p_id] if pred_idxs else "Index N/A"

                if is_present:
                    print_phrase = '[%s]' % ' '.join(word)
                else:
                    print_phrase = ' '.join(word)

                correct_str = '[correct!]' if match == 1.0 else ''

                # Only tag a copy when copy information exists; the former
                # any("CopyFlag N/A") placeholder check was always True.
                copied = bool(copied_flags) and any(copied_flags[p_id])
                copy_str = '[copied!]' if copied else ''

                pred_str = '\t\t%s\t%s \t %s %s%s\n' % (
                    score_str, print_phrase, str(pred_idx), correct_str,
                    copy_str)
                if not is_valid:
                    pred_str = '\t%s' % pred_str

                print_out += pred_str

            # split tgts by present/absent (`not` works for Python and numpy bools)
            present_tgts = [
                tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
                if present
            ]
            absent_tgts = [
                tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
                if not present
            ]

            # filter out results of invalid preds
            valid_pred_sents = [
                seq for seq, valid in zip(pred_sents, valid_pred_flags) if valid
            ]
            present_pred_flags = present_pred_flags[valid_pred_flags]

            match_scores_exact = match_scores_exact[valid_pred_flags]
            match_scores_partial = match_scores_partial[valid_pred_flags]

            # split preds by present/absent and exact/partial
            present_preds = [
                pred for pred, present in zip(valid_pred_sents,
                                              present_pred_flags) if present
            ]
            absent_preds = [
                pred for pred, present in zip(valid_pred_sents,
                                              present_pred_flags) if not present
            ]
            if len(present_pred_flags) > 0:
                present_exact_match_scores = match_scores_exact[present_pred_flags]
                present_partial_match_scores = match_scores_partial[
                    present_pred_flags][:, present_tgt_flags]
                absent_exact_match_scores = match_scores_exact[~present_pred_flags]
                absent_partial_match_scores = match_scores_partial[
                    ~present_pred_flags][:, ~present_tgt_flags]
            else:
                present_exact_match_scores = []
                present_partial_match_scores = []
                absent_exact_match_scores = []
                absent_partial_match_scores = []

            # Compute metrics
            print_out += "\n ======================================================="
            # for absent results, only recall matters
            present_exact_results = run_metrics(present_exact_match_scores,
                                                present_preds, present_tgts,
                                                metric_names, topk_range)
            absent_exact_results = run_metrics(absent_exact_match_scores,
                                               absent_preds, absent_tgts,
                                               metric_names, absent_topk_range)
            present_partial_results = run_metrics(present_partial_match_scores,
                                                  present_preds,
                                                  present_tgts,
                                                  metric_names,
                                                  topk_range,
                                                  type='partial')
            absent_partial_results = run_metrics(absent_partial_match_scores,
                                                 absent_preds,
                                                 absent_tgts,
                                                 metric_names,
                                                 absent_topk_range,
                                                 type='partial')

            results_list = [
                present_exact_results,
                absent_exact_results,
                present_partial_results,
                absent_partial_results,
            ]
            # update score_dict, appending new scores (results_list) to it
            score_dict = _gather_scores(score_dict, results_names, results_list)

            for name, results in zip(results_names, results_list):
                # present phrases are reported @5 and @10, absent ones @50
                topks = [5, 10] if name.startswith('present') else [50]
                for topk in topks:
                    print_out += _format_scores(name, topk, results, score_dict)

            print_out += "\n ======================================================="

            if verbose:
                if logger:
                    logger.info(print_out)
                else:
                    print(print_out)

            if report_file is not None:
                report_file.write(print_out)

            # add tgt/pred counts for computing average performance on non-empty
            # items; '*_num' counters are stored under their own name
            counts = {
                'present_tgt_num': len(present_tgts),
                'absent_tgt_num': len(absent_tgts),
                'present_pred_num': len(present_preds),
                'absent_pred_num': len(absent_preds),
                'unique_pred_num': pred_dict.get('unique_pred_num', 0),
                'dup_pred_num': pred_dict.get('dup_pred_num', 0),
                'beam_num': pred_dict.get('beam_num', 0),
                'beamstep_num': pred_dict.get('beamstep_num', 0),
            }
            for count_name, count in counts.items():
                score_dict.setdefault(count_name, []).append(count)
    finally:
        # close the report even if an example raises
        if report_file is not None:
            report_file.close()

    return score_dict
Example #3
0
def eval_and_print(src_text,
                   tgt_kps,
                   pred_kps,
                   pred_scores,
                   unk_token='<unk>'):
    """
    Evaluate predicted keyphrases for one document and return a printable report.

    Source, targets and predictions are tokenized with ``spacy_nlp`` and
    lowercased. Predictions are filtered (``validate_phrases`` plus duplicate
    removal) and split into present/absent phrases relative to the source.
    Exact-match F-scores over all, present and absent phrases are then
    computed with ``run_classic_metrics``.

    :param src_text: raw source text.
    :param tgt_kps: list of ground-truth keyphrase strings.
    :param pred_kps: list of predicted keyphrase strings.
    :param pred_scores: scores of ``pred_kps``, forwarded to
        ``print_predeval_result``.
    :param unk_token: unknown-word token, forwarded to ``validate_phrases``.
        NOTE(review): spaCy may split '<unk>' into several tokens, in which
        case this default never matches — confirm what callers should pass.
    :return: the report string produced by ``print_predeval_result``.
    """
    def _tokenize(text):
        # spaCy tokenization, lowercased; the "textcat" pipe is not needed
        return [t.text.lower() for t in spacy_nlp(text, disable=["textcat"])]

    src_seq = _tokenize(src_text)
    tgt_seqs = [_tokenize(p) for p in tgt_kps]
    pred_seqs = [_tokenize(p) for p in pred_kps]

    topk_range = ['k', 10]
    absent_topk_range = ['M']
    metric_names = ['f_score']

    # 1st filtering, ignore phrases having <unk> and puncs
    valid_pred_flags = validate_phrases(pred_seqs, unk_token)
    # 2nd filtering: find phrases present in the text and flag duplicates
    present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(
        src_seq, pred_seqs)
    # treat duplicates as invalid
    valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(
        valid_pred_flags) > 0 else []
    valid_and_present_flags = valid_pred_flags * present_pred_flags if len(
        valid_pred_flags) > 0 else []
    valid_and_absent_flags = valid_pred_flags * ~present_pred_flags if len(
        valid_pred_flags) > 0 else []

    # compute exact match scores (one score per prediction)
    match_scores_exact = compute_match_scores(tgt_seqs=tgt_seqs,
                                              pred_seqs=pred_seqs,
                                              do_lower=True,
                                              do_stem=True,
                                              type='exact')
    # split tgts by present/absent (`not` works for Python and numpy bools)
    present_tgt_flags, _, _ = if_present_duplicate_phrases(src_seq, tgt_seqs)
    present_tgts = [
        tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if present
    ]
    absent_tgts = [
        tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
        if not present
    ]

    # filter out results of invalid preds
    valid_preds = [
        seq for seq, valid in zip(pred_seqs, valid_pred_flags) if valid
    ]
    valid_present_pred_flags = present_pred_flags[valid_pred_flags]

    valid_match_scores_exact = match_scores_exact[valid_pred_flags]

    # split valid preds by present/absent
    valid_present_preds = [
        pred for pred, present in zip(valid_preds, valid_present_pred_flags)
        if present
    ]
    valid_absent_preds = [
        pred for pred, present in zip(valid_preds, valid_present_pred_flags)
        if not present
    ]
    present_exact_match_scores = valid_match_scores_exact[
        valid_present_pred_flags]
    absent_exact_match_scores = valid_match_scores_exact[
        ~valid_present_pred_flags]

    all_exact_results = run_classic_metrics(valid_match_scores_exact,
                                            valid_preds, tgt_seqs,
                                            metric_names, topk_range)
    present_exact_results = run_classic_metrics(present_exact_match_scores,
                                                valid_present_preds,
                                                present_tgts, metric_names,
                                                topk_range)
    absent_exact_results = run_classic_metrics(absent_exact_match_scores,
                                               valid_absent_preds, absent_tgts,
                                               metric_names, absent_topk_range)

    eval_results_names = ['all_exact', 'present_exact', 'absent_exact']
    eval_results_list = [
        all_exact_results, present_exact_results, absent_exact_results
    ]

    print_out = print_predeval_result(
        src_text, tgt_seqs, present_tgt_flags, pred_seqs, pred_scores,
        present_pred_flags, valid_pred_flags, valid_and_present_flags,
        valid_and_absent_flags, match_scores_exact, eval_results_names,
        eval_results_list)
    return print_out
def evaluate(src_list,
             tgt_list,
             pred_list,
             unk_token,
             logger=None,
             verbose=False,
             report_path=None,
             tokenizer=None):
    if report_path:
        report_file = open(report_path, 'w+')
    else:
        report_file = None
    # 'k' means the number of phrases in ground-truth, add 1,3 for openkp
    topk_range = [5, 10, 'k', 'M', 1, 3]
    absent_topk_range = [10, 50, 'k', 'M']
    # 'precision_hard' and 'f_score_hard' mean that precision is calculated with denominator strictly as K (say 5 or 10), won't be lessened even number of preds is smaller
    metric_names = [
        'correct', 'precision', 'recall', 'f_score', 'precision_hard',
        'f_score_hard'
    ]

    individual_score_dicts = [
    ]  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}
    gathered_score_dict = {
    }  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}

    # for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(enumerate(zip(src_list, tgt_list, pred_list))):
    for i, (src_dict, tgt_dict, pred_dict) in tqdm.tqdm(
            enumerate(zip(src_list, tgt_list, pred_list))):
        """
        1. Process each data example and predictions
        """
        pred_seqs = pred_dict["pred_sents"]
        pred_idxs = pred_dict["preds"] if "preds" in pred_dict else None
        pred_scores = pred_dict[
            "pred_scores"] if "pred_scores" in pred_dict else None
        copied_flags = pred_dict[
            "copied_flags"] if "copied_flags" in pred_dict else None

        # @memray 20200316 change to spacy tokenization rather than simple splitting or Meng's tokenization
        if tokenizer == 'spacy':
            src_seq = [
                t.text for t in spacy_nlp(src_dict["src"], disable=["textcat"])
            ]
            tgt_seqs = [[t.text for t in spacy_nlp(p, disable=["textcat"])]
                        for p in tgt_dict["tgt"]]
            pred_seqs = [[
                t.text for t in spacy_nlp(' '.join(p), disable=["textcat"])
            ] for p in pred_seqs]
            unk_token = 'unk'
        elif tokenizer == 'split':
            src_seq = src_dict["src"].split()
            tgt_seqs = [t.split() for t in tgt_dict["tgt"]]
            pred_seqs = pred_dict["pred_sents"]
        elif tokenizer == 'split_nopunc':
            src_seq = [
                t for t in re.split(r'\W', src_dict["src"]) if len(t) > 0
            ]
            tgt_seqs = [[t for t in re.split(r'\W', p) if len(t) > 0]
                        for p in tgt_dict["tgt"]]
            pred_seqs = [[
                t for t in re.split(r'\W', ' '.join(p)) if len(t) > 0
            ] for p in pred_dict["pred_sents"]]
            unk_token = 'unk'
        else:
            raise Exception(
                'Unset or unsupported tokenizer for evaluation: %s' %
                str(tokenizer))

        # 1st filtering, ignore phrases having <unk> and puncs
        valid_pred_flags = validate_phrases(pred_seqs, unk_token)
        # 2nd filtering: filter out phrases that don't appear in text, and keep unique ones after stemming
        present_pred_flags, _, duplicate_flags = if_present_duplicate_phrases(
            src_seq, pred_seqs, stemming=True, lowercase=True)
        # treat duplicates as invalid
        valid_pred_flags = valid_pred_flags * ~duplicate_flags if len(
            valid_pred_flags) > 0 else []
        valid_and_present_flags = valid_pred_flags * present_pred_flags if len(
            valid_pred_flags) > 0 else []
        valid_and_absent_flags = valid_pred_flags * ~present_pred_flags if len(
            valid_pred_flags) > 0 else []

        # compute match scores (exact, partial and mixed), for exact it's a list otherwise matrix
        match_scores_exact = compute_match_scores(tgt_seqs=tgt_seqs,
                                                  pred_seqs=pred_seqs,
                                                  do_lower=True,
                                                  do_stem=True,
                                                  type='exact')
        match_scores_partial = compute_match_scores(tgt_seqs=tgt_seqs,
                                                    pred_seqs=pred_seqs,
                                                    do_lower=True,
                                                    do_stem=True,
                                                    type='ngram')
        # simply add full-text to n-grams might not be good as its contribution is not clear
        # match_scores_mixed = compute_match_scores(tgt_seqs=tgt_seqs, pred_seqs=pred_seqs, type='mixed')

        # split tgts by present/absent
        present_tgt_flags, _, _ = if_present_duplicate_phrases(src_seq,
                                                               tgt_seqs,
                                                               stemming=True,
                                                               lowercase=True)
        present_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags) if present
        ]
        absent_tgts = [
            tgt for tgt, present in zip(tgt_seqs, present_tgt_flags)
            if ~present
        ]

        # filter out results of invalid preds
        valid_preds = [
            seq for seq, valid in zip(pred_seqs, valid_pred_flags) if valid
        ]
        valid_present_pred_flags = present_pred_flags[valid_pred_flags]

        valid_match_scores_exact = match_scores_exact[valid_pred_flags]
        valid_match_scores_partial = match_scores_partial[valid_pred_flags]
        # match_scores_mixed = match_scores_mixed[valid_pred_flags]

        # split preds by present/absent and exact/partial/mixed
        valid_present_preds = [
            pred
            for pred, present in zip(valid_preds, valid_present_pred_flags)
            if present
        ]
        valid_absent_preds = [
            pred
            for pred, present in zip(valid_preds, valid_present_pred_flags)
            if ~present
        ]
        if len(valid_present_pred_flags) > 0:
            present_exact_match_scores = valid_match_scores_exact[
                valid_present_pred_flags]
            present_partial_match_scores = valid_match_scores_partial[
                valid_present_pred_flags][:, present_tgt_flags]
            # present_mixed_match_scores = match_scores_mixed[present_pred_flags][:, present_tgt_flags]
            absent_exact_match_scores = valid_match_scores_exact[
                ~valid_present_pred_flags]
            absent_partial_match_scores = valid_match_scores_partial[
                ~valid_present_pred_flags][:, ~present_tgt_flags]
            # absent_mixed_match_scores = match_scores_mixed[~present_pred_flags][:, ~present_tgt_flags]
        else:
            present_exact_match_scores = []
            present_partial_match_scores = []
            # present_mixed_match_scores = []
            absent_exact_match_scores = []
            absent_partial_match_scores = []
            # absent_mixed_match_scores = []

        # assert len(valid_pred_seqs) == len(match_scores_exact) == len(present_pred_flags)
        # assert len(present_preds) == len(present_exact_match_scores) == len(present_partial_match_scores) == len(present_mixed_match_scores)
        # assert present_partial_match_scores.shape == present_mixed_match_scores.shape
        # assert len(absent_preds) == len(absent_exact_match_scores) == len(absent_partial_match_scores) == len(absent_mixed_match_scores)
        # assert absent_partial_match_scores.shape == absent_mixed_match_scores.shape
        """
        2. Compute metrics
        """
        # get the scores on different scores (for absent results, only recall matters)
        all_exact_results = run_classic_metrics(valid_match_scores_exact,
                                                valid_preds, tgt_seqs,
                                                metric_names, topk_range)
        present_exact_results = run_classic_metrics(present_exact_match_scores,
                                                    valid_present_preds,
                                                    present_tgts, metric_names,
                                                    topk_range)
        absent_exact_results = run_classic_metrics(absent_exact_match_scores,
                                                   valid_absent_preds,
                                                   absent_tgts, metric_names,
                                                   absent_topk_range)

        all_partial_results = run_classic_metrics(valid_match_scores_partial,
                                                  valid_preds,
                                                  tgt_seqs,
                                                  metric_names,
                                                  topk_range,
                                                  type='partial')
        present_partial_results = run_classic_metrics(
            present_partial_match_scores,
            valid_present_preds,
            present_tgts,
            metric_names,
            topk_range,
            type='partial')
        absent_partial_results = run_classic_metrics(
            absent_partial_match_scores,
            valid_absent_preds,
            absent_tgts,
            metric_names,
            absent_topk_range,
            type='partial')
        # present_mixed_results = run_metrics(present_mixed_match_scores, present_preds, present_tgts, metric_names, topk_range, type='partial')
        # absent_mixed_results = run_metrics(absent_mixed_match_scores, absent_preds, absent_tgts, metric_names, absent_topk_range, type='partial')

        all_exact_advanced_results = run_advanced_metrics(
            valid_match_scores_exact, valid_preds, tgt_seqs)
        present_exact_advanced_results = run_advanced_metrics(
            present_exact_match_scores, valid_present_preds, present_tgts)
        absent_exact_advanced_results = run_advanced_metrics(
            absent_exact_match_scores, valid_absent_preds, absent_tgts)
        # print(advanced_present_exact_results)
        # print(advanced_absent_exact_results)
        """
        3. Gather scores
        """
        eval_results_names = [
            'all_exact',
            'all_partial',
            'present_exact',
            'absent_exact',
            'present_partial',
            'absent_partial',
            # 'present_mixed', 'absent_mixed'
            'all_exact_advanced',
            'present_exact_advanced',
            'absent_exact_advanced',
        ]
        eval_results_list = [
            all_exact_results,
            all_partial_results,
            present_exact_results,
            absent_exact_results,
            present_partial_results,
            absent_partial_results,
            # present_mixed_results, absent_mixed_results
            all_exact_advanced_results,
            present_exact_advanced_results,
            absent_exact_advanced_results
        ]
        # update score_dict, appending new scores (results_list) to it
        individual_score_dict = {
            result_name: results
            for result_name, results in zip(eval_results_names,
                                            eval_results_list)
        }
        gathered_score_dict = gather_scores(gathered_score_dict,
                                            eval_results_names,
                                            eval_results_list)

        # add tgt/pred count for computing average performance on non-empty items
        stats_results_names = [
            'present_tgt_num', 'absent_tgt_num', 'present_pred_num',
            'absent_pred_num', 'unique_pred_num', 'dup_pred_num', 'beam_num',
            'beamstep_num'
        ]
        stats_results_list = [
            {
                'present_tgt_num': len(present_tgts)
            },
            {
                'absent_tgt_num': len(absent_tgts)
            },
            {
                'present_pred_num': len(valid_present_preds)
            },
            {
                'absent_pred_num': len(valid_absent_preds)
            },
            # TODO some stat should be calculated here since exhaustive/self-terminating makes difference
            {
                'unique_pred_num':
                pred_dict['unique_pred_num']
                if 'unique_pred_num' in pred_dict else 0
            },
            {
                'dup_pred_num':
                pred_dict['dup_pred_num'] if 'dup_pred_num' in pred_dict else 0
            },
            {
                'beam_num':
                pred_dict['beam_num'] if 'beam_num' in pred_dict else 0
            },
            {
                'beamstep_num':
                pred_dict['beamstep_num'] if 'beamstep_num' in pred_dict else 0
            },
        ]
        for result_name, result_dict in zip(stats_results_names,
                                            stats_results_list):
            individual_score_dict[result_name] = result_dict[result_name]
        gathered_score_dict = gather_scores(gathered_score_dict,
                                            stats_results_names,
                                            stats_results_list)
        # individual_score_dicts.append(individual_score_dict)
        """
        4. Print results if necessary
        """
        if verbose or report_file:
            print_out = print_predeval_result(
                i, ' '.join(src_seq), tgt_seqs, present_tgt_flags, pred_seqs,
                pred_scores, pred_idxs, copied_flags, present_pred_flags,
                valid_pred_flags, valid_and_present_flags,
                valid_and_absent_flags, match_scores_exact,
                match_scores_partial, eval_results_names, eval_results_list,
                gathered_score_dict)

            if verbose:
                if logger:
                    logger.info(print_out)
                else:
                    print(print_out)

            if report_file:
                report_file.write(print_out)

    # for k, v in score_dict.items():
    #     print('%s, num=%d, mean=%f' % (k, len(v), np.average(v)))

    if report_file:
        report_file.close()

    return gathered_score_dict