示例#1
0
 def _report_rouge(self, gold_path, can_path, logger=None):
     """Run ``test_rouge`` on a candidate/gold file pair and return its result.

     Note that ``test_rouge`` receives the candidate path before the gold
     path, i.e. in the opposite order to this method's own parameters.

     Args:
         gold_path: forwarded to ``test_rouge`` as its third positional argument.
         can_path: forwarded to ``test_rouge`` as its second positional argument.
         logger: forwarded unchanged as the ``logger`` keyword of
             ``test_rouge``; the progress message itself always goes to
             ``self.logger``.

     Returns:
         Whatever ``test_rouge`` returns.
     """
     self.logger.info("Calculating Rouge")
     return test_rouge(self.args.temp_dir, can_path, gold_path, logger=logger)
示例#2
0
def report_rouge(logger, args, gold_path, can_path):
    """Compute ROUGE for a candidate/gold file pair, log it, and return the mean F1.

    Args:
        logger: object whose ``info`` method receives the progress and
            result messages.
        args: object providing ``temp_dir``, passed to ``test_rouge``.
        gold_path: forwarded to ``test_rouge`` as its third positional argument.
        can_path: forwarded to ``test_rouge`` as its second positional argument.

    Returns:
        The value produced by ``avg_rouge_f1`` for the computed scores.
    """
    logger.info("Calculating Rouge")
    rouge_scores = test_rouge(args.temp_dir, can_path, gold_path)
    rouge_report = rouge_results_to_str(rouge_scores)
    logger.info('Rouges:\n%s' % rouge_report)
    mean_f1 = avg_rouge_f1(rouge_scores)
    logger.info('Average Rouge F1: %s' % mean_f1)
    return mean_f1
示例#3
0
    def iter_test(self, test_iter, step, sum_sent_count=3):
        """Evaluate iterative sentence selection over ``test_iter``.

        Every batch is handed to ``self.model.infer_sentences`` together with
        ``sum_sent_count``.  The sentences it returns, filtered by the mask
        returned alongside them, are joined with ``<q>`` to form the predicted
        summary.  Predictions and references are written through
        ``self._output_predicted_summaries`` and scored with the ``du`` ROUGE
        helpers.  The precision and detailed ROUGE reports are only produced
        when ``step`` is not ``-1`` and the matching ``self.args`` flag is set.

        Args:
            test_iter: iterable of batches; each batch is read for ``doc_id``,
                ``label_seq``, ``src_str``, ``tgt_str`` and ``batch_size``.
            step: step number used in output file names and log messages.
            sum_sent_count: forwarded to ``infer_sentences``.

        Returns:
            The ``Statistics`` object that was passed to ``infer_sentences``
            for every batch.
        """
        self.model.eval()
        stats = Statistics()

        out_dir = os.path.dirname(self.args.result_path)
        out_name = os.path.basename(self.args.result_path)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        can_path = '%s/%s_step%d_itereval.candidate'%(out_dir, out_name, step)
        gold_path = '%s/%s_step%d_itereval.gold' % (out_dir, out_name, step)

        all_doc_ids, all_pred_ids, all_gold_ids = [], [], []
        all_pred_texts, all_gold_texts = [], []

        with torch.no_grad():
            for batch in test_iter:
                doc_ids = batch.doc_id
                # Negative entries in label_seq are dropped before building the sets.
                oracle_ids = [{j for j in seq if j > -1}
                              for seq in batch.label_seq.tolist()]

                chosen, chosen_mask = self.model.infer_sentences(
                    batch, sum_sent_count, stats=stats)
                chosen = chosen.tolist()
                all_pred_ids.extend(chosen)

                for i in range(batch.batch_size):
                    # Keep only the selections whose mask entry is truthy.
                    kept_sents = [
                        batch.src_str[i][sent_idx].strip()
                        for pos, sent_idx in enumerate(chosen[i])
                        if chosen_mask[i][pos]
                    ]
                    all_pred_texts.append('<q>'.join(kept_sents))
                    all_gold_texts.append(batch.tgt_str[i])
                    all_gold_ids.append(oracle_ids[i])
                    all_doc_ids.append(doc_ids[i])

        macro_precision, micro_precision = self._output_predicted_summaries(
            all_doc_ids, all_pred_ids, all_gold_ids,
            all_pred_texts, all_gold_texts, can_path, gold_path)
        rouge1_arr, rouge2_arr = du.cal_rouge_score(all_pred_texts, all_gold_texts)
        rouge_1, rouge_2 = du.aggregate_rouge(rouge1_arr, rouge2_arr)
        logger.info('[PERF]At step %d: rouge1:%.2f rouge2:%.2f' % (
            step, rouge_1 * 100, rouge_2 * 100))

        if step != -1:
            if self.args.report_precision:
                macro_arr = ["P@%s:%.2f%%" % (k + 1, macro_precision[k] * 100)
                             for k in range(3)]
                micro_arr = ["P@%s:%.2f%%" % (k + 1, micro_precision[k] * 100)
                             for k in range(3)]
                logger.info('[PERF]MacroPrecision at step %d: %s' % (step, '\t'.join(macro_arr)))
                logger.info('[PERF]MicroPrecision at step %d: %s' % (step, '\t'.join(micro_arr)))
            if self.args.report_rouge:
                rouge_str, detail_rouge = test_rouge(
                    self.args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
                logger.info('[PERF]Rouges at step %d: %s \n' % (step, rouge_str))
                rouge_out_path = '%s_step%d_itereval.rouge' % (self.args.result_path, step)
                if detail_rouge is not None:
                    du.output_rouge_file(rouge_out_path, rouge1_arr, rouge2_arr,
                                         detail_rouge, all_doc_ids)

        self._report_step(0, step, valid_stats=stats)
        return stats
示例#4
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Write candidate and gold summaries for ``test_iter`` and report ROUGE.

        For every batch, sentence indices are chosen in one of three ways:
        ``cal_lead`` takes the sentences in document order, ``cal_oracle``
        takes the sentences whose label equals 1, and otherwise the model
        scores the sentences and they are ranked by descending score.  The
        loss is accumulated into the returned statistics in model mode only.

        Args:
            test_iter: iterable of batches; each batch is read for ``src``,
                ``labels``, ``segs``, ``clss``, ``mask``, ``mask_cls``,
                ``src_str``, ``tgt_str`` and ``batch_size``.
            step: step number used in the output file names and the log
                message; ``-1`` disables the ROUGE report.
            cal_lead: select sentences in document order instead of running
                the model.
            cal_oracle: select the sentences labelled 1 instead of running
                the model (only consulted when ``cal_lead`` is false).

        Returns:
            The ``Statistics`` object accumulated over the batches.
        """
        # Helper: the set of n-grams (as tuples) of a token list.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        # Helper: True when candidate `c` shares at least one trigram with any
        # sentence already in `p`.
        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s))>0:
                    return True
            return False

        # Set model in validating mode; only needed when the model is actually run.
        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate'%(self.args.result_path,step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        src = batch.src
                        labels = batch.labels
                        segs = batch.segs
                        clss = batch.clss
                        mask = batch.mask
                        mask_cls = batch.mask_cls


                        gold = []
                        pred = []

                        if (cal_lead):
                            # Lead baseline: every document gets indices 0..clss.size(1)-1.
                            selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                        elif (cal_oracle):
                            # Oracle: indices whose label equals 1.
                            selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                            range(batch.batch_size)]
                        else:
                            # Note: `mask` is rebound to the mask returned by the model.
                            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            # Second argument is len(labels) — presumably the batch size; TODO confirm.
                            batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

                            # NOTE(review): adding the mask presumably lifts real sentences
                            # above padded ones before ranking — confirm the score range.
                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            # Rank sentence indices by descending score within each row.
                            selected_ids = np.argsort(-sent_scores, 1)
                        # selected_ids = np.sort(selected_ids,1)
                        for i, idx in enumerate(selected_ids):
                            _pred = []

                            # Documents without source sentences produce no output line at all.
                            if(len(batch.src_str[i])==0):
                                continue

                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                # Skip indices that point past the document's sentences.
                                if(j>=len( batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if(self.args.block_trigram):
                                    # Drop a candidate that repeats a trigram of an already chosen sentence.
                                    if(not _block_tri(candidate,_pred)):
                                        _pred.append(candidate)
                                else:
                                    _pred.append(candidate)

                                # Stop at three sentences unless in oracle or recall-eval mode.
                                if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
                                    break

                            _pred = '<q>'.join(_pred)
                            # NOTE(review): the full source text is appended to every candidate
                            # line, so it lands in the candidate file that is scored below —
                            # looks like debug output; confirm it is intended.
                            _pred=_pred+"   original txt:   "+" ".join(batch.src_str[i])
                            if(self.args.recall_eval):
                                # Truncate the candidate to the word count of the gold summary.
                                _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

                        # One line per document, written batch by batch.
                        for i in range(len(gold)):
                            save_gold.write(gold[i].strip()+'\n')
                        for i in range(len(pred)):
                            save_pred.write(pred[i].strip()+'\n')
        # Both files are closed here, so test_rouge can read them.
        if(step!=-1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
示例#5
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Write baseline summaries for ``test_iter`` and optionally report ROUGE.

        Sentence indices are chosen per batch either in document order
        (``cal_lead``) or from the sentences whose label equals 1
        (``cal_oracle``).  The chosen sentences are joined with ``<q>`` and
        written, one summary per line, to ``<result_path>_step<step>.candidate``;
        the matching references go to ``<result_path>_step<step>.gold``.

        Args:
            test_iter: iterable of batches; each batch is read for ``clss``,
                ``batch_size``, ``src_str``, ``tgt_str`` and, in oracle mode,
                ``labels``.
            step: step number used in the output file names and the log
                message; ``-1`` disables the ROUGE report.
            cal_lead: select sentences in document order.
            cal_oracle: select the sentences labelled 1 (only consulted when
                ``cal_lead`` is false).

        Returns:
            The ``Statistics`` object created at the start of the call; this
            method never updates it.

        Raises:
            NotImplementedError: if neither ``cal_lead`` nor ``cal_oracle`` is
                set and ``test_iter`` yields a batch.  This method has no
                model-scoring branch; previously that case failed with an
                ``UnboundLocalError`` on ``selected_ids``.
        """
        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred, \
                open(gold_path, 'w') as save_gold, \
                torch.no_grad():
            for batch in test_iter:
                if cal_lead:
                    # Lead baseline: every document gets indices 0..clss.size(1)-1.
                    selected_ids = [list(range(batch.clss.size(1)))
                                    ] * batch.batch_size
                elif cal_oracle:
                    # Oracle: indices whose label equals 1, as in the sibling
                    # implementations of this method.
                    labels = batch.labels
                    selected_ids = [
                        [j for j in range(batch.clss.size(1))
                         if labels[i][j] == 1]
                        for i in range(batch.batch_size)
                    ]
                else:
                    # Fail loudly instead of hitting an unbound `selected_ids`.
                    raise NotImplementedError(
                        'test() has no model-scoring branch here; '
                        'pass cal_lead=True or cal_oracle=True')

                for i, doc_selection in enumerate(selected_ids):
                    src_sents = batch.src_str[i]
                    # Documents without source sentences produce no output line.
                    if len(src_sents) == 0:
                        continue

                    _pred = []
                    for j in doc_selection[:len(src_sents)]:
                        # Skip indices that point past the document's sentences.
                        if j >= len(src_sents):
                            continue
                        _pred.append(src_sents[j].strip())

                        # Stop at three sentences unless in oracle or
                        # recall-eval mode.
                        if ((not cal_oracle)
                                and (not self.args.recall_eval)
                                and len(_pred) == 3):
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        # Truncate to the word count of the gold summary.
                        gold_len = len(batch.tgt_str[i].split())
                        _pred = ' '.join(_pred.split()[:gold_len])

                    save_gold.write(batch.tgt_str[i].strip() + '\n')
                    save_pred.write(_pred.strip() + '\n')

        # Both files are closed here, so test_rouge can read them.
        if (step != -1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
示例#6
0
 def _report_rouge(self, gold_path, can_path):
     """Run ``test_rouge`` on a candidate/gold file pair and return its result.

     Note that ``test_rouge`` receives the candidate path before the gold
     path, i.e. in the opposite order to this method's own parameters.

     Args:
         gold_path: forwarded to ``test_rouge`` as its third positional argument.
         can_path: forwarded to ``test_rouge`` as its second positional argument.

     Returns:
         Whatever ``test_rouge`` returns.
     """
     self.logger.info("Calculating Rouge")
     return test_rouge(self.args.temp_dir, can_path, gold_path)
示例#7
0
from others.utils import test_rouge, rouge_results_to_str
import os

# Score every candidate/gold pair found under results/ and print a summary.
# The unsuffixed pair ("results/candidate" / "results/gold") is reported but
# left out of the sum and average, which cover the numbered pairs only.
rouges = {}
rouge_totals = {}  # per-metric sum over the numbered pairs
n = 0              # number of numbered pairs that were scored
for i in ["", "1", "2", "3"]:
    can_path = f'results/candidate{i}'
    gold_path = f'results/gold{i}'
    if not (os.path.exists(can_path) and os.path.exists(gold_path)):
        continue
    print(f'{"*" * 40}\n{can_path} <--> {gold_path}')
    rouges[i] = test_rouge("results", can_path, gold_path)
    # Compare by value: `is not ""` only worked through string interning and
    # raises a SyntaxWarning on Python 3.8+.
    if i != "":
        for key, value in rouges[i].items():
            rouge_totals[key] = rouge_totals.get(key, 0) + value
        n += 1
    print(f'Rouges of results/candidate{i}: \n{rouge_results_to_str(rouges[i])}')

print('*' * 10 + ' Summary ' + '*' * 10)

for key in rouges:
    print(f'Rouges of results/candidate{key}: \n{rouge_results_to_str(rouges[key])}')

if n > 0:
    print(f'Sum of rouges: \n{rouge_results_to_str(rouge_totals)}')
    # Build the averages separately so the totals stay intact.
    rouge_averages = {key: total / n for key, total in rouge_totals.items()}
    print(f'Average of {n} rouges: \n{rouge_results_to_str(rouge_averages)}')
示例#8
0
import logger
from others.utils import test_rouge, rouge_results_to_str

# Score the candidate file against the gold file and log the formatted result.
# NOTE(review): `self`, `can_path` and `gold_path` are not defined anywhere in
# the visible snippet — confirm where they are expected to come from.
rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
rouge_report = rouge_results_to_str(rouges)
logger.info('Rouges at step %d \n%s' % (30000, rouge_report))
示例#9
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Select summary sentences for ``test_iter`` and report precision/ROUGE.

        For every batch, sentence indices are chosen in one of three ways:
        ``cal_lead`` takes the sentences in document order, ``cal_oracle``
        takes the sentences whose label equals 1, and otherwise the model
        scores the sentences and they are ranked by descending score.  The
        loss is accumulated into the returned statistics in model mode only.
        Predictions and references are written through
        ``self._output_predicted_summaries`` and scored with the ``du`` ROUGE
        helpers.

        Args:
            test_iter: iterable of batches; each batch is read for ``src``,
                ``labels``, ``segs``, ``clss``, ``mask``, ``mask_cls``,
                ``doc_id``, ``groups``, ``label_seq``, ``src_str``,
                ``tgt_str`` and ``batch_size``.
            step: step number used in output file names and log messages;
                ``-1`` disables the precision and detailed ROUGE reports.
            cal_lead: select sentences in document order instead of running
                the model.
            cal_oracle: select the sentences labelled 1 instead of running
                the model (only consulted when ``cal_lead`` is false).

        Returns:
            The ``Statistics`` object accumulated over the batches.
        """
        # Helper: the set of n-grams (as tuples) of a token list.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        # Helper: True when candidate `c` shares at least one trigram with any
        # sentence already in `p`.
        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s))>0:
                    return True
            return False

        # Set model in validating mode; only needed when the model is actually run.
        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        # Make sure the directory that will hold the result files exists.
        base_dir = os.path.dirname(self.args.result_path)
        if (not os.path.exists(base_dir)):
            os.makedirs(base_dir)

        can_path = '%s_step%d_initial.candidate'%(self.args.result_path,step)
        gold_path = '%s_step%d_initial.gold' % (self.args.result_path, step)

        # Per-document results; the five lists are appended to together below.
        all_pred_ids, all_gold_ids, all_doc_ids = [], [], []
        all_gold_texts, all_pred_texts = [], []

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                labels = batch.labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls
                doc_ids = batch.doc_id
                group_idxs = batch.groups

                # Negative entries in label_seq are dropped before building the sets.
                oracle_ids = [set([j for j in seq if j > -1]) for seq in batch.label_seq.tolist()]

                if (cal_lead):
                    # Lead baseline: every document gets indices 0..clss.size(1)-1.
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif (cal_oracle):
                    # Oracle: indices whose label equals 1.
                    selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                    range(batch.batch_size)]
                else:
                    # Note: `mask` is rebound to the mask returned by the model.
                    sent_scores, mask = self.model(src, mask, segs, clss, mask_cls, group_idxs, candi_sent_masks=mask_cls, is_test=True)
                    # Selected sentences in candi_masks can be set to 0.
                    loss = -self.logsoftmax(sent_scores) * labels.float() # presumably (batch_size, max_sent_count) — TODO confirm
                    loss = (loss*mask.float()).sum()

                    # Second argument is len(labels) — presumably the batch size; TODO confirm.
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                    stats.update(batch_stats)

                    # Positions where the mask is False get -inf so they rank last.
                    sent_scores[mask==False] = float('-inf')
                    # Original note: scores are capped at 1, so no large offset is added.
                    sent_scores = sent_scores.cpu().data.numpy()
                    # Rank sentence indices by descending score within each row.
                    selected_ids = np.argsort(-sent_scores, 1)
                for i, idx in enumerate(selected_ids):
                    _pred = []
                    _pred_ids = []
                    # Documents without source sentences are left out of every result list.
                    if(len(batch.src_str[i])==0):
                        continue
                    for j in selected_ids[i][:len(batch.src_str[i])]:
                        # Skip indices that point past the document's sentences.
                        if(j>=len( batch.src_str[i])):
                            continue
                        candidate = batch.src_str[i][j].strip()
                        if(self.args.block_trigram):
                            # Drop a candidate that repeats a trigram of an already chosen sentence.
                            if(not _block_tri(candidate,_pred)):
                                _pred.append(candidate)
                                _pred_ids.append(j)
                        else:
                            _pred.append(candidate)
                            _pred_ids.append(j)

                        # Stop at three sentences unless in oracle or recall-eval mode.
                        if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
                            break

                    _pred = '<q>'.join(_pred)
                    if(self.args.recall_eval):
                        # Truncate the candidate to the word count of the gold summary.
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                    all_pred_texts.append(_pred)
                    all_pred_ids.append(_pred_ids)
                    all_gold_texts.append(batch.tgt_str[i])
                    all_gold_ids.append(oracle_ids[i])
                    all_doc_ids.append(doc_ids[i])
        # Receives the candidate/gold paths and returns the two precision sequences.
        macro_precision, micro_precision = self._output_predicted_summaries(
                all_doc_ids, all_pred_ids, all_gold_ids,
                all_pred_texts, all_gold_texts, can_path, gold_path)
        rouge1_arr, rouge2_arr = du.cal_rouge_score(all_pred_texts, all_gold_texts)
        rouge_1, rouge_2 = du.aggregate_rouge(rouge1_arr, rouge2_arr)
        logger.info('[PERF]At step %d: rouge1:%.2f rouge2:%.2f' % (
            step, rouge_1 * 100, rouge_2 * 100))

        if(step!=-1 and self.args.report_precision):
            # Only the first three precision entries (P@1..P@3) are reported.
            macro_arr = ["P@%s:%.2f%%" % (i+1, macro_precision[i] * 100) for i in range(3)]
            micro_arr = ["P@%s:%.2f%%" % (i+1, micro_precision[i] * 100) for i in range(3)]
            logger.info('[PERF]MacroPrecision at step %d: %s' % (step, '\t'.join(macro_arr)))
            logger.info('[PERF]MicroPrecision at step %d: %s' % (step, '\t'.join(micro_arr)))

        if(step!=-1 and self.args.report_rouge):
            rouge_str, detail_rouge = test_rouge(self.args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('[PERF]Rouges at step %d: %s \n' % (step, rouge_str))
            result_path = '%s_step%d_initial.rouge' % (self.args.result_path, step)
            # The per-document ROUGE file is only written when detail scores were returned.
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)
        self._report_step(0, step, valid_stats=stats)

        return stats
示例#10
0
            print(model_dir)
            cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            for cp in cp_files:
                step = int(cp.split('.')[-2].split('_')[-1])
                test(args, device_id, cp, step)
        else:
            try:
                step = int(cp.split('.')[-2].split('_')[-1])
            except:
                step = 0
            test(args, device_id, cp, step)
    elif (args.mode == 'getrouge'):
        if args.model_name == 'base':
            pattern = '*step*initial.candidate'
        else:
            pattern = '*step*.candidate' #evaluate all
        candi_files = sorted(glob.glob("%s_%s" % (args.result_path, pattern)))
        #print(args.result_path)
        #print(candi_files)
        for can_path in candi_files:
            gold_path = can_path.replace('candidate', 'gold')
            rouge1_arr, rouge2_arr = du.compute_metrics(can_path, gold_path)
            step = os.path.basename(gold_path)
            precs_path = can_path.replace('candidate', 'precs')
            all_doc_ids = du.read_prec_file(precs_path)
            rouge_str, detail_rouge = test_rouge(args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('Rouges at step %s \n%s' % (step, rouge_str))
            result_path = can_path.replace('candidate', 'rouge')
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)
示例#11
0
from glob import glob
from os.path import basename
from others.utils import test_rouge

# Print the ROUGE result of every r.<num>.candidate / r.<num>.gold pair found
# in each "origin_*" result directory.
dirs = glob(
    "/Users/denisporplenko/Documents/UCU_master_degree/diploma/raw/PreSumm/result_final/results/origin_*"
)
for result_dir in dirs:
    print("Dir: ", basename(result_dir))
    for gold_file in glob("%s/r.*.gold" % result_dir):
        # File names look like "r.<num>.gold"; take the middle component.
        num = basename(gold_file).split(".")[1]
        print(num)
        cand = "%s/r.%s.candidate" % (result_dir, num)
        ref = "%s/r.%s.gold" % (result_dir, num)
        try:
            print(test_rouge("", cand, ref))
        except Exception:
            # Best effort: keep going with the next pair.  Catching Exception
            # rather than using a bare `except:` lets Ctrl-C and SystemExit
            # still stop the script.
            print("Error")