Beispiel #1
0
    def iter_test(self, test_iter, step, sum_sent_count=3):
        """Evaluate the model with iterative sentence selection.

        Sentences are obtained from ``self.model.infer_sentences``, which is
        expected to select them iteratively: given the sentences selected so
        far, predict the next one. For each document, the selected sentences
        whose mask entry is truthy are stripped and joined with ``'<q>'`` into
        the candidate summary. Candidates and references are written to
        ``<result_path>_step<step>_itereval.candidate`` / ``.gold`` through
        ``_output_predicted_summaries`` and scored with ROUGE.

        Args:
            test_iter: iterable of batches; each batch is read for ``doc_id``,
                ``label_seq``, ``batch_size``, ``src_str`` and ``tgt_str``.
            step: step number used in output file names and log messages.
                With ``step == -1`` the precision and full ROUGE reports are
                skipped.
            sum_sent_count: number of sentences requested per document,
                forwarded to ``infer_sentences``.

        Returns:
            The ``Statistics`` instance that was passed to ``infer_sentences``.
        """
        self.model.eval()
        stats = Statistics()
        base_dir = os.path.dirname(self.args.result_path)
        base_name = os.path.basename(self.args.result_path)

        # dirname() returns '' for a bare file name; os.makedirs('') raises,
        # so only create a directory when there is one to create.
        if base_dir and not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)

        # os.path.join keeps the path relative when base_dir is '' instead of
        # producing a root-anchored '/<name>...' path.
        can_path = os.path.join(base_dir, '%s_step%d_itereval.candidate' % (base_name, step))
        gold_path = os.path.join(base_dir, '%s_step%d_itereval.gold' % (base_name, step))

        all_pred_ids, all_gold_ids, all_doc_ids = [], [], []
        all_gold_texts, all_pred_texts = [], []

        with torch.no_grad():
            for batch in test_iter:
                doc_ids = batch.doc_id
                # Gold sentence indices per document; negative entries are dropped.
                oracle_ids = [set(j for j in seq if j > -1) for seq in batch.label_seq.tolist()]

                sel_sent_idxs, sel_sent_masks = self.model.infer_sentences(batch, sum_sent_count, stats=stats)
                sel_sent_idxs = sel_sent_idxs.tolist()
                # NOTE(review): ids are recorded unfiltered while the candidate
                # text below keeps only mask-truthy positions; confirm that
                # _output_predicted_summaries expects this.
                all_pred_ids.extend(sel_sent_idxs)

                for i in range(batch.batch_size):
                    _pred = '<q>'.join([batch.src_str[i][idx].strip()
                                        for j, idx in enumerate(sel_sent_idxs[i])
                                        if sel_sent_masks[i][j]])
                    all_pred_texts.append(_pred)
                    all_gold_texts.append(batch.tgt_str[i])
                    all_gold_ids.append(oracle_ids[i])
                    all_doc_ids.append(doc_ids[i])

        macro_precision, micro_precision = self._output_predicted_summaries(
                all_doc_ids, all_pred_ids, all_gold_ids,
                all_pred_texts, all_gold_texts, can_path, gold_path)
        rouge1_arr, rouge2_arr = du.cal_rouge_score(all_pred_texts, all_gold_texts)
        rouge_1, rouge_2 = du.aggregate_rouge(rouge1_arr, rouge2_arr)
        logger.info('[PERF]At step %d: rouge1:%.2f rouge2:%.2f' % (
            step, rouge_1 * 100, rouge_2 * 100))
        if step != -1 and self.args.report_precision:
            macro_arr = ["P@%s:%.2f%%" % (i+1, macro_precision[i] * 100) for i in range(3)]
            micro_arr = ["P@%s:%.2f%%" % (i+1, micro_precision[i] * 100) for i in range(3)]
            logger.info('[PERF]MacroPrecision at step %d: %s' % (step, '\t'.join(macro_arr)))
            logger.info('[PERF]MicroPrecision at step %d: %s' % (step, '\t'.join(micro_arr)))
        if step != -1 and self.args.report_rouge:
            rouge_str, detail_rouge = test_rouge(self.args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('[PERF]Rouges at step %d: %s \n' % (step, rouge_str))
            result_path = '%s_step%d_itereval.rouge' % (self.args.result_path, step)
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)
        self._report_step(0, step, valid_stats=stats)

        return stats
Beispiel #2
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Evaluate extractive summaries from the model or a baseline.

        For each document the sentence indices are ranked:
        - by model score (default; the model is put in eval mode and the loss
          is accumulated into the returned statistics),
        - in document order (``cal_lead``), or
        - restricted to sentences whose label is 1 (``cal_oracle``).

        Ranked sentences are taken in order, optionally skipping any that share
        a trigram with an already chosen one (``args.block_trigram``). Selection
        stops at 3 sentences unless ``cal_oracle`` or ``args.recall_eval`` is
        set. The chosen sentences are joined with ``'<q>'``. With
        ``args.recall_eval`` the candidate is truncated to the reference's
        word count. Candidates and references are written to
        ``<result_path>_step<step>_initial.candidate`` / ``.gold`` and scored
        with ROUGE.

        Args:
            test_iter: iterable of batches providing ``src``, ``labels``,
                ``segs``, ``clss``, ``mask``, ``mask_cls``, ``doc_id``,
                ``groups``, ``label_seq``, ``batch_size``, ``src_str`` and
                ``tgt_str``.
            step: step number used in output file names and log messages.
                With ``step == -1`` the precision and full ROUGE reports are
                skipped.
            cal_lead: rank sentences by position instead of running the model.
            cal_oracle: select the label-1 sentences instead of running the model.

        Returns:
            ``Statistics`` holding the summed masked loss; it is only updated
            when the model is run (neither ``cal_lead`` nor ``cal_oracle``).
        """
        def _get_ngrams(n, text):
            """Return the set of n-gram tuples in the token list ``text``."""
            return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

        def _block_tri(c, p):
            """Return True if sentence ``c`` shares a trigram with any sentence in ``p``."""
            tri_c = _get_ngrams(3, c.split())
            return any(tri_c & _get_ngrams(3, s.split()) for s in p)

        # Set model in validating mode (baselines never call the model).
        if not cal_lead and not cal_oracle:
            self.model.eval()
        stats = Statistics()

        base_dir = os.path.dirname(self.args.result_path)
        # dirname() returns '' for a bare file name and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when there is one.
        if base_dir and not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)

        can_path = '%s_step%d_initial.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d_initial.gold' % (self.args.result_path, step)

        all_pred_ids, all_gold_ids, all_doc_ids = [], [], []
        all_gold_texts, all_pred_texts = [], []

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                labels = batch.labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls
                doc_ids = batch.doc_id
                group_idxs = batch.groups

                # Gold sentence indices per document; negative entries are dropped.
                oracle_ids = [set(j for j in seq if j > -1) for seq in batch.label_seq.tolist()]

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1]
                                    for i in range(batch.batch_size)]
                else:
                    sent_scores, mask = self.model(src, mask, segs, clss, mask_cls, group_idxs,
                                                   candi_sent_masks=mask_cls, is_test=True)
                    # selected sentences in candi_masks can be set to 0
                    loss = -self.logsoftmax(sent_scores) * labels.float()  # batch_size, max_sent_count
                    loss = (loss * mask.float()).sum()

                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                    stats.update(batch_stats)

                    # Masked-out positions get -inf so they rank last below.
                    sent_scores[mask == False] = float('-inf')
                    sent_scores = sent_scores.cpu().data.numpy()
                    selected_ids = np.argsort(-sent_scores, 1)

                for i, ranked_ids in enumerate(selected_ids):
                    src_sents = batch.src_str[i]
                    if len(src_sents) == 0:
                        # Empty documents are skipped for every output list,
                        # which keeps the per-document lists aligned.
                        continue
                    _pred = []
                    _pred_ids = []
                    for j in ranked_ids[:len(src_sents)]:
                        if j >= len(src_sents):
                            # Index beyond this document's sentences (e.g. padding).
                            continue
                        candidate = src_sents[j].strip()
                        if self.args.block_trigram:
                            if not _block_tri(candidate, _pred):
                                _pred.append(candidate)
                                _pred_ids.append(j)
                        else:
                            _pred.append(candidate)
                            _pred_ids.append(j)

                        if (not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3:
                            break

                    _pred = '<q>'.join(_pred)
                    if self.args.recall_eval:
                        # Truncate to the reference's word count.
                        _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                    all_pred_texts.append(_pred)
                    all_pred_ids.append(_pred_ids)
                    all_gold_texts.append(batch.tgt_str[i])
                    all_gold_ids.append(oracle_ids[i])
                    all_doc_ids.append(doc_ids[i])

        macro_precision, micro_precision = self._output_predicted_summaries(
                all_doc_ids, all_pred_ids, all_gold_ids,
                all_pred_texts, all_gold_texts, can_path, gold_path)
        rouge1_arr, rouge2_arr = du.cal_rouge_score(all_pred_texts, all_gold_texts)
        rouge_1, rouge_2 = du.aggregate_rouge(rouge1_arr, rouge2_arr)
        logger.info('[PERF]At step %d: rouge1:%.2f rouge2:%.2f' % (
            step, rouge_1 * 100, rouge_2 * 100))

        if step != -1 and self.args.report_precision:
            macro_arr = ["P@%s:%.2f%%" % (i+1, macro_precision[i] * 100) for i in range(3)]
            micro_arr = ["P@%s:%.2f%%" % (i+1, micro_precision[i] * 100) for i in range(3)]
            logger.info('[PERF]MacroPrecision at step %d: %s' % (step, '\t'.join(macro_arr)))
            logger.info('[PERF]MicroPrecision at step %d: %s' % (step, '\t'.join(micro_arr)))

        if step != -1 and self.args.report_rouge:
            rouge_str, detail_rouge = test_rouge(self.args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('[PERF]Rouges at step %d: %s \n' % (step, rouge_str))
            result_path = '%s_step%d_initial.rouge' % (self.args.result_path, step)
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)
        self._report_step(0, step, valid_stats=stats)

        return stats
Beispiel #3
0
            print(model_dir)
            cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')))
            for cp in cp_files:
                step = int(cp.split('.')[-2].split('_')[-1])
                test(args, device_id, cp, step)
        else:
            try:
                step = int(cp.split('.')[-2].split('_')[-1])
            except:
                step = 0
            test(args, device_id, cp, step)
    elif (args.mode == 'getrouge'):
        if args.model_name == 'base':
            pattern = '*step*initial.candidate'
        else:
            pattern = '*step*.candidate' #evaluate all
        candi_files = sorted(glob.glob("%s_%s" % (args.result_path, pattern)))
        #print(args.result_path)
        #print(candi_files)
        for can_path in candi_files:
            gold_path = can_path.replace('candidate', 'gold')
            rouge1_arr, rouge2_arr = du.compute_metrics(can_path, gold_path)
            step = os.path.basename(gold_path)
            precs_path = can_path.replace('candidate', 'precs')
            all_doc_ids = du.read_prec_file(precs_path)
            rouge_str, detail_rouge = test_rouge(args.temp_dir, can_path, gold_path, all_doc_ids, show_all=True)
            logger.info('Rouges at step %s \n%s' % (step, rouge_str))
            result_path = can_path.replace('candidate', 'rouge')
            if detail_rouge is not None:
                du.output_rouge_file(result_path, rouge1_arr, rouge2_arr, detail_rouge, all_doc_ids)