Пример #1
0
def report_rouge(logger, args, gold_path, can_path):
    """Score a candidate file against a reference file and log the result.

    Calls ``test_rouge`` with ``args.temp_dir``, the candidate path and the
    reference path, logs the formatted scores, then logs and returns the
    value computed by ``avg_rouge_f1`` from those scores.
    """
    logger.info("Calculating Rouge")
    scores = test_rouge(args.temp_dir, can_path, gold_path)
    report = rouge_results_to_str(scores)
    logger.info('Rouges:\n%s' % report)
    mean_f1 = avg_rouge_f1(scores)
    logger.info('Average Rouge F1: %s' % mean_f1)
    return mean_f1
Пример #2
0
    def validate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, dump predictions/references and log metrics.

        Every batch is decoded with ``self.translate_batch`` and converted to
        ``(prediction, reference)`` string pairs by ``self.from_batch_dev``.
        The pairs are written, one per line, to
        ``<result_path>step.<step>.pred_temp`` and
        ``<result_path>step.<step>.gold_temp``.

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips all scoring.
            attn_debug: unused by this method.

        Returns:
            ``(rouge-1 F, rouge-l F)`` taken from the ``FilesRouge`` averages,
            or ``None`` when ``step == -1`` (nothing is scored in that case).
        """
        self.model.eval()
        gold_path = self.args.result_path + 'step.%d.gold_temp' % step
        pred_path = self.args.result_path + 'step.%d.pred_temp' % step

        ct = 0
        ext_acc_num = 0
        ext_pred_num = 0
        ext_gold_num = 0

        # The context managers guarantee both files are closed even if
        # decoding raises part-way through.
        with codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
                codecs.open(pred_path, 'w', 'utf-8') as pred_out_file, \
                torch.no_grad():
            for batch in data_iter:
                output_data, tgt_data, ext_pred, ext_gold = \
                    self.translate_batch(batch)
                translations = self.from_batch_dev(output_data, tgt_data)

                for idx, (pred_summ, gold_data) in enumerate(translations):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    # Overlap between predicted and gold extraction ids,
                    # counted as the duplicates in the concatenated list
                    # (assumes each list holds unique ids -- TODO confirm).
                    merged = ext_pred[idx] + ext_gold[idx]
                    ext_acc_num += len(merged) - len(set(merged))
                    ext_pred_num += len(ext_pred[idx])
                    ext_gold_num += len(ext_gold[idx])
                    pred_out_file.write(pred_summ + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()

        if step == -1:
            # No scoring requested; previously this path hit an unbound
            # local at the return statement.
            return None

        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
        self.logger.info(
            'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
            (step, p * 100, r * 100, f1 * 100))
        self.logger.info(
            'Gold Length at step %d: %.2f' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                         (step, rouge_results_to_str(pred_rouges)))
        return (pred_rouges["rouge-1"]['f'], pred_rouges["rouge-l"]['f'])
Пример #3
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, write the outputs to disk and report ROUGE.

        Four files are written next to ``self.args.result_path``:
        ``.<step>.gold``, ``.<step>.candidate``, ``.<step>.raw_src`` and
        ``.<step>.example_id``, one line per decoded example. Special tokens
        are removed from each prediction and runs of spaces are collapsed
        before it is written.

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips ROUGE reporting.
            attn_debug: unused by this method.

        Returns:
            None.
        """
        # Local import so the method does not rely on a module-level import.
        import re

        self.model.eval()
        gold_path = self.args.result_path + '.%d.gold' % step
        can_path = self.args.result_path + '.%d.candidate' % step
        raw_src_path = self.args.result_path + '.%d.raw_src' % step
        example_id_path = self.args.result_path + '.%d.example_id' % step

        special_tokens = ('[unused0]', '[unused1]', '[PAD]', '[SEP]', '[UNK]')

        # Context managers close every file even if decoding fails.
        with codecs.open(gold_path, 'w', 'utf-8') as gold_out, \
                codecs.open(can_path, 'w', 'utf-8') as can_out, \
                codecs.open(raw_src_path, 'w', 'utf-8') as src_out, \
                codecs.open(example_id_path, 'w', 'utf-8') as id_out, \
                torch.no_grad():
            # The handles stay reachable on self, as they were before.
            self.gold_out_file = gold_out
            self.can_out_file = can_out
            self.src_out_file = src_out
            self.example_id_file = id_out

            for batch in data_iter:
                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data, batch)
                for pred, gold, src, example_id in translations:
                    pred_str = pred
                    for token in special_tokens:
                        pred_str = pred_str.replace(token, '')
                    # str.replace(r' +', ' ') only matched a literal " +",
                    # so whitespace runs were never collapsed; use a regex.
                    pred_str = re.sub(r' +', ' ', pred_str).strip()
                    can_out.write(pred_str + '\n')
                    gold_out.write(gold.strip() + '\n')
                    src_out.write(src.strip() + '\n')
                    id_out.write(str(example_id) + '\n')
                can_out.flush()
                gold_out.flush()
                src_out.flush()
                id_out.flush()

        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], step)
Пример #4
0
    def validate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, dump predictions/references and log metrics.

        For each batch ``self.translate_batch`` returns ``(doc_data,
        summ_data)``; ``self.from_batch_dev`` turns the batch and ``doc_data``
        into per-example tuples whose element ``1`` is written as the
        prediction, while ``summ_data[idx]`` is written as the reference.
        Output goes to ``<result_path>.step.<step>.pred_temp`` and
        ``<result_path>.step.<step>.gold_temp``, one example per line.

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips all scoring.
            attn_debug: unused by this method.

        Returns:
            ``(rouge-1 F, rouge-l F)`` taken from the ``FilesRouge`` averages,
            or ``None`` when ``step == -1`` (nothing is scored in that case).
        """
        self.model.eval()
        gold_path = self.args.result_path + '.step.%d.gold_temp' % step
        pred_path = self.args.result_path + '.step.%d.pred_temp' % step

        ct = 0
        # The context managers guarantee both files are closed even if
        # decoding raises part-way through.
        with codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
                codecs.open(pred_path, 'w', 'utf-8') as pred_out_file, \
                torch.no_grad():
            for batch in data_iter:
                doc_data, summ_data = self.translate_batch(batch)
                translations = self.from_batch_dev(batch, doc_data)

                for idx, translation in enumerate(translations):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    doc_short_context = translation[1]
                    gold_data = summ_data[idx]
                    pred_out_file.write(doc_short_context + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()

        if step == -1:
            # No scoring requested; previously this path hit an unbound
            # local at the return statement.
            return None

        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        self.logger.info(
            'Gold Length at step %d: %.2f' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                         (step, rouge_results_to_str(pred_rouges)))
        return (pred_rouges["rouge-1"]['f'], pred_rouges["rouge-l"]['f'])
Пример #5
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Write extractive candidates and references, then optionally score.

        For every batch the candidate summary of document ``i`` is built from
        ``batch.src_str[i]`` using the selected sentence indices, joined with
        ``'<q>'``. Candidates go to ``<result_path>_step<step>.candidate`` and
        the matching ``batch.tgt_str[i]`` to ``<result_path>_step<step>.gold``.

        Args:
            test_iter: iterable of batches exposing ``clss``, ``batch_size``,
                ``src_str`` and ``tgt_str``.
            step: step number used in file names and reporting; ROUGE is only
                computed when ``step != -1`` and ``self.args.report_rouge``
                is set.
            cal_lead: select sentences in document order (lead baseline).
            cal_oracle: when true, the three-sentence cap on a candidate is
                not applied.

        Returns:
            The ``Statistics`` object created at the start of the call;
            nothing in this method updates it.

        Raises:
            NotImplementedError: if a batch is processed with ``cal_lead``
                false, because only the lead selection is implemented here.
        """
        if not cal_lead and not cal_oracle:
            # Put the model in evaluation mode.
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred, \
                open(gold_path, 'w') as save_gold, \
                torch.no_grad():
            for batch in test_iter:
                if not cal_lead:
                    # Previously this path failed with a NameError because
                    # the selected ids were never assigned.
                    raise NotImplementedError(
                        'test() only implements sentence selection for '
                        'cal_lead=True')
                selected_ids = [list(range(batch.clss.size(1)))
                                ] * batch.batch_size

                gold = []
                pred = []
                for i, sent_ids in enumerate(selected_ids):
                    src_sents = batch.src_str[i]
                    if len(src_sents) == 0:
                        continue
                    picked = []
                    for j in sent_ids[:len(src_sents)]:
                        if j >= len(src_sents):
                            continue
                        picked.append(src_sents[j].strip())

                        # Cap the candidate at three sentences unless oracle
                        # or recall evaluation is requested.
                        if ((not cal_oracle)
                                and (not self.args.recall_eval)
                                and len(picked) == 3):
                            break

                    pred_text = '<q>'.join(picked)
                    if self.args.recall_eval:
                        # Truncate to the reference's whitespace-token count.
                        ref_len = len(batch.tgt_str[i].split())
                        pred_text = ' '.join(pred_text.split()[:ref_len])

                    pred.append(pred_text)
                    gold.append(batch.tgt_str[i])

                for line in gold:
                    save_gold.write(line.strip() + '\n')
                for line in pred:
                    save_pred.write(line.strip() + '\n')

        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
Пример #6
0
    def translate(self, data_iter, step):
        """Decode ``data_iter``, write the outputs to disk and report ROUGE.

        Three files are written inside the directory ``self.args.result_path``:
        ``test.<step>.gold``, ``test.<step>.candidate`` and
        ``test.<step>.raw_src``, one line per decoded example. Special tokens
        are removed from each prediction, runs of spaces are collapsed and
        ``' [unused2] '`` is turned into the sentence separator ``'<q>'``.

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and reporting; ``-1`` skips
                ROUGE reporting.

        Returns:
            None.
        """
        # Local import so the method does not rely on a module-level import.
        import re

        # Set model to eval mode for decoding
        self.model.eval()

        # Output file paths
        gold_path = os.path.join(self.args.result_path, 'test.%d.gold' % step)
        can_path = os.path.join(self.args.result_path,
                                'test.%d.candidate' % step)
        raw_src_path = os.path.join(self.args.result_path,
                                    'test.%d.raw_src' % step)

        # Context managers close every file even if decoding fails.
        with codecs.open(gold_path, 'w', 'utf-8') as gold_out, \
                codecs.open(can_path, 'w', 'utf-8') as can_out, \
                codecs.open(raw_src_path, 'w', 'utf-8') as src_out, \
                torch.no_grad():
            # The handles stay reachable on self, as they were before.
            self.gold_out_file = gold_out
            self.can_out_file = can_out
            self.src_out_file = src_out

            for batch in data_iter:
                # Constrain prediction length to be close to the gold length
                if self.args.recall_eval:
                    gold_tgt_len = batch.tgt.size(1)
                    self.min_length = gold_tgt_len + 20
                    self.max_length = gold_tgt_len + 60

                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data)

                for pred, gold, src in translations:
                    src_str = src.strip()

                    # [unused0] -> BOS
                    # [unused1] -> EOS
                    # [unused2] -> EOQ
                    pred_str = pred
                    for token in ('[unused0]', '[unused3]', '[PAD]',
                                  '[unused1]'):
                        pred_str = pred_str.replace(token, '')
                    # str.replace(r' +', ' ') only matched a literal " +",
                    # so whitespace runs were never collapsed; use a regex.
                    pred_str = re.sub(r' +', ' ', pred_str)
                    pred_str = pred_str.replace(' [unused2] ', '<q>').replace(
                        '[unused2]', '').strip()
                    gold_str = gold.strip()

                    # Constrain prediction length to be close to gold length
                    if self.args.recall_eval:
                        # NOTE(review): the first accumulated candidate starts
                        # with '<q>' because _pred_str begins empty -- confirm
                        # this is intended.
                        _pred_str = ''
                        for sent in pred_str.split('<q>'):
                            # Accumulate the prediction sentence by sentence
                            can_pred_str = _pred_str + '<q>' + sent.strip()

                            # Cut once the candidate reaches gold + 10 tokens
                            if (len(can_pred_str.split()) >=
                                    len(gold_str.split()) + 10):
                                pred_str = _pred_str
                                break
                            else:
                                _pred_str = can_pred_str

                    src_out.write(src_str + '\n')
                    can_out.write(pred_str + '\n')
                    gold_out.write(gold_str + '\n')

                # Flush the buffers after every batch
                can_out.flush()
                gold_out.flush()
                src_out.flush()

        # Report results in console and log
        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                self.tensorboard_writer.add_scalar('test/rouge1-F',
                                                   rouges['rouge_1_f_score'],
                                                   step)
                self.tensorboard_writer.add_scalar('test/rouge2-F',
                                                   rouges['rouge_2_f_score'],
                                                   step)
                self.tensorboard_writer.add_scalar('test/rougeL-F',
                                                   rouges['rouge_l_f_score'],
                                                   step)
                # NOTE(review): mlflow metrics are only logged when a
                # tensorboard writer is configured -- confirm this nesting
                # is intended.
                mlflow.log_metric('Test_ROUGE1_F', rouges['rouge_1_f_score'],
                                  step)
                mlflow.log_metric('Test_ROUGE2_F', rouges['rouge_2_f_score'],
                                  step)
                mlflow.log_metric('Test_ROUGEL_F', rouges['rouge_l_f_score'],
                                  step)
Пример #7
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, write a per-example report and log metrics.

        Three files are written next to ``self.args.result_path``:
        ``.<step>.output`` (a readable report with the source, reference,
        prediction, sentence-level BLEU/ROUGE and extraction P/R/F1 for each
        example), ``.<step>.gold_test`` and ``.<step>.pred_test`` (one
        stripped reference / prediction per line).

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips the corpus-level scoring.
            attn_debug: unused by this method.

        Returns:
            None.
        """
        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step

        ct = 0
        ext_acc_num = 0
        ext_pred_num = 0
        ext_gold_num = 0

        # Context managers close every file even if decoding or scoring
        # raises part-way through.
        with codecs.open(output_path, 'w', 'utf-8') as output_file, \
                codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
                codecs.open(pred_path, 'w', 'utf-8') as pred_out_file, \
                torch.no_grad():
            rouge = Rouge()
            # Built once instead of once per example.
            smoothing = SmoothingFunction().method1
            for batch in data_iter:
                output_data, tgt_data, ext_pred, ext_gold = \
                    self.translate_batch(batch)
                translations = self.from_batch_test(batch, output_data,
                                                    tgt_data)

                for idx, translation in enumerate(translations):
                    origin_sent, pred_summ, gold_data = translation
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write("ORIGIN  : \n    " +
                                      origin_sent.replace('<S>', '\n    ') +
                                      "\n")
                    output_file.write("GOLD    : " + gold_data.strip() + "\n")
                    output_file.write("DOC_GEN : " + pred_summ.strip() + "\n")
                    rouge_score = rouge.get_scores(pred_summ, gold_data)
                    bleu_score = sentence_bleu(
                        [gold_data.split()],
                        pred_summ.split(),
                        smoothing_function=smoothing)
                    output_file.write(
                        "DOC_GEN  bleu & rouge-f 1/2/l:    %.4f & %.4f/%.4f/%.4f\n"
                        % (bleu_score, rouge_score[0]["rouge-1"]["f"],
                           rouge_score[0]["rouge-2"]["f"],
                           rouge_score[0]["rouge-l"]["f"]))

                    # Overlap between predicted and gold extraction ids,
                    # counted as the duplicates in the concatenated list
                    # (assumes each list holds unique ids -- TODO confirm).
                    merged = ext_pred[idx] + ext_gold[idx]
                    acc_num = len(merged) - len(set(merged))
                    pred_num = len(ext_pred[idx])
                    gold_num = len(ext_gold[idx])
                    ext_acc_num += acc_num
                    ext_pred_num += pred_num
                    ext_gold_num += gold_num
                    f1, p, r = test_f1(acc_num, pred_num, gold_num)
                    output_file.write(
                        "EXT_GOLD: [" +
                        ','.join(str(i) for i in sorted(ext_gold[idx])) +
                        "]\n")
                    output_file.write(
                        "EXT_PRED: [" +
                        ','.join(str(i) for i in sorted(ext_pred[idx])) +
                        "]\n")
                    output_file.write(
                        "EXT_SCORE  P/R/F1:    %.4f/%.4f/%.4f\n\n" %
                        (p, r, f1))
                    pred_out_file.write(pred_summ.strip() + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
            self.logger.info(
                'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
                (step, p * 100, r * 100, f1 * 100))
            self.logger.info(
                'Gold Length at step %d: %.2f' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
Пример #8
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, write a per-example report and log metrics.

        Five files are written next to ``self.args.result_path``:
        ``.<step>.output`` (a readable report with the source, reference,
        lead, short/long extraction and generated summary plus sentence-level
        BLEU/ROUGE for each of them), ``.<step>.gold_test``,
        ``.<step>.pred_test``, ``.<step>.ex_test.short`` and
        ``.<step>.ex_test.long`` (one line per example, with the
        ``[unused2]``/``[unused3]`` markers removed from the hypotheses).

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips the corpus-level scoring.
            attn_debug: unused by this method.

        Returns:
            None.
        """
        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step
        ex_single_path = self.args.result_path + '.%d.ex_test' % step + ".short"
        ex_context_path = self.args.result_path + '.%d.ex_test' % step + ".long"

        def _clean(text):
            # Strip the text and drop the [unused2]/[unused3] marker tokens.
            return text.strip().replace("[unused2]", "").replace(
                "[unused3]", "")

        ct = 0
        # Context managers close every file even if decoding or scoring
        # raises part-way through.
        with codecs.open(output_path, 'w', 'utf-8') as output_file, \
                codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
                codecs.open(pred_path, 'w', 'utf-8') as pred_out_file, \
                codecs.open(ex_single_path, 'w', 'utf-8') as short_ex_out_file, \
                codecs.open(ex_context_path, 'w', 'utf-8') as long_ex_out_file, \
                torch.no_grad():
            rouge = Rouge()
            # Built once instead of once per scored hypothesis.
            smoothing = SmoothingFunction().method1

            def _write_scores(label, hyp, gold_data, gold_list, end):
                # Write one "<label> bleu & rouge-f" line for a hypothesis.
                # ROUGE is computed on the raw hypothesis text, BLEU on its
                # marker-free tokens, exactly as each section did before.
                hyp_tokens = _clean(hyp).split()
                rouge_score = rouge.get_scores(hyp, gold_data)
                bleu_score = sentence_bleu([gold_list],
                                           hyp_tokens,
                                           smoothing_function=smoothing)
                output_file.write(
                    "%-8s bleu & rouge-f 1/2/l:    %.4f & %.4f/%.4f/%.4f%s"
                    % (label, bleu_score, rouge_score[0]["rouge-1"]["f"],
                       rouge_score[0]["rouge-2"]["f"],
                       rouge_score[0]["rouge-l"]["f"], end))

            for batch in data_iter:
                doc_data, summ_data = self.translate_batch(batch)
                translations = self.from_batch_test(batch, doc_data)

                for idx, translation in enumerate(translations):
                    origin_sent, doc_extract, context_doc_extract, \
                        doc_pred, lead = translation
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write(
                        "ORIGIN  : " +
                        origin_sent.replace('<S>', '\n          ') + "\n")
                    gold_data = summ_data[idx]
                    output_file.write("GOLD    : " + gold_data + "\n")
                    output_file.write("LEAD    : " + lead + "\n")
                    output_file.write("DOC_EX  : " + doc_extract.strip() +
                                      "\n")
                    output_file.write("DOC_CONT: " +
                                      context_doc_extract.strip() + "\n")
                    output_file.write("DOC_GEN : " + doc_pred.strip() + "\n")

                    gold_list = gold_data.strip().split()
                    _write_scores("LEAD", lead, gold_data, gold_list, "\n")
                    _write_scores("DOC_EX", doc_extract, gold_data,
                                  gold_list, "\n")
                    _write_scores("DOC_CONT", context_doc_extract, gold_data,
                                  gold_list, "\n")
                    # The last score line ends the example with a blank line.
                    _write_scores("DOC_GEN", doc_pred, gold_data, gold_list,
                                  "\n\n")

                    short_ex_out_file.write(_clean(doc_extract) + '\n')
                    long_ex_out_file.write(_clean(context_doc_extract) + '\n')
                    pred_out_file.write(_clean(doc_pred) + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                short_ex_out_file.flush()
                long_ex_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        if step != -1:
            # NOTE(review): test_bleu is called as (gold, hypothesis) here,
            # while test_length and FilesRouge below take the hypothesis
            # first -- confirm the argument order against test_bleu.
            ex_short_bleu = test_bleu(gold_path, ex_single_path)
            ex_long_bleu = test_bleu(gold_path, ex_context_path)
            pred_bleu = test_bleu(gold_path, pred_path)

            file_rouge = FilesRouge(hyp_path=ex_single_path,
                                    ref_path=gold_path)
            ex_short_rouges = file_rouge.get_scores(avg=True)

            file_rouge = FilesRouge(hyp_path=ex_context_path,
                                    ref_path=gold_path)
            ex_long_rouges = file_rouge.get_scores(avg=True)

            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)

            self.logger.info(
                'Gold Length at step %d: %.2f\n' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Short Extraction Length ratio at step %d: %.2f' %
                             (step, test_length(ex_single_path, gold_path)))
            self.logger.info('Short Extraction Bleu at step %d: %.2f' %
                             (step, ex_short_bleu * 100))
            self.logger.info('Short Extraction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(ex_short_rouges)))
            self.logger.info('Long Extraction Length ratio at step %d: %.2f' %
                             (step, test_length(ex_context_path, gold_path)))
            self.logger.info('Long Extraction Bleu at step %d: %.2f' %
                             (step, ex_long_bleu * 100))
            self.logger.info('Long Extraction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(ex_long_rouges)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
    def translate(self, data_iter, step, attn_debug=False):
        """Decode ``data_iter``, write a per-example report and log metrics.

        Three files are written next to ``self.args.result_path``:
        ``.<step>.output`` (a readable report with the source, reference,
        prediction and sentence-level BLEU/ROUGE for each example),
        ``.<step>.gold_test`` and ``.<step>.pred_test`` (one stripped
        reference / prediction per line).

        Args:
            data_iter: iterable of batches accepted by ``translate_batch``.
            step: step number used in file names and log messages; ``-1``
                skips the corpus-level scoring.
            attn_debug: unused by this method.

        Returns:
            None.
        """
        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step
        ct = 0

        # Context managers close every file even if decoding or scoring
        # raises part-way through.
        with codecs.open(output_path, 'w', 'utf-8') as output_file, \
                codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
                codecs.open(pred_path, 'w', 'utf-8') as pred_out_file, \
                torch.no_grad():
            rouge = Rouge()
            # Built once instead of once per example.
            smoothing = SmoothingFunction().method1
            for batch in data_iter:
                output_data, tgt_data = self.translate_batch(batch)
                translations = self.from_batch_test(batch, output_data,
                                                    tgt_data)

                for origin_sent, pred_summ, gold_data in translations:
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write("ORIGIN  : \n    " +
                                      origin_sent.replace('<S>', '\n    ') +
                                      "\n")
                    output_file.write("GOLD    : " + gold_data.strip() + "\n")
                    output_file.write("DOC_GEN : " + pred_summ.strip() + "\n")
                    rouge_score = rouge.get_scores(pred_summ, gold_data)
                    bleu_score = sentence_bleu(
                        [gold_data.split()],
                        pred_summ.split(),
                        smoothing_function=smoothing)
                    output_file.write(
                        "DOC_GEN  bleu & rouge-f 1/l:    %.4f & %.4f/%.4f\n\n"
                        % (bleu_score, rouge_score[0]["rouge-1"]["f"],
                           rouge_score[0]["rouge-l"]["f"]))
                    pred_out_file.write(pred_summ.strip() + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            self.logger.info(
                'Gold Length at step %d: %.2f\n' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
Пример #10
0
def baseline(args, cal_lead=False, cal_oracle=False):
    """Score a lead-k or greedy-oracle extractive baseline on the test split.

    For every example yielded by the test data loader a prediction is built
    from the example's source sentences and written, together with the
    reference summary, to ``<result_path>.<mode>.pred`` / ``.gold``.  The two
    files are then scored (length ratio, BLEU, ROUGE) and the results logged.

    Args:
        args: run configuration; this function reads ``result_path``,
            ``test_batch_size``, ``test_batch_ex_size``, ``ranking_max_k``
            and ``win_size``.
        cal_lead: selects the ``lead`` tag for the output file names; any
            falsy value yields the ``oracle`` tag.
        cal_oracle: when true, sentences are chosen greedily to maximise
            ROUGE-1 F + ROUGE-L F against the reference; otherwise the first
            ``k`` source sentences are used.
    """
    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, args.test_batch_ex_size, 'cpu',
                                       shuffle=False, is_test=True)

    # NOTE(review): the file tag follows `cal_lead` while the selection
    # strategy below follows `cal_oracle`, so calling with both flags false
    # writes lead predictions under the "oracle" tag. Kept as-is so existing
    # output paths do not change -- confirm whether this is intended.
    mode = "lead" if cal_lead else "oracle"

    rouge = Rouge()
    pred_path = '%s.%s.pred' % (args.result_path, mode)
    gold_path = '%s.%s.gold' % (args.result_path, mode)

    # Context managers guarantee both files are flushed and closed even when
    # scoring raises part-way through the data.
    with open(pred_path, 'w', encoding='utf-8') as save_pred, \
            open(gold_path, 'w', encoding='utf-8') as save_gold, \
            torch.no_grad():
        for batch in test_iter:
            summaries = batch.summ_txt
            origin_sents = batch.original_str

            # offsets[i] is the index of the first sentence of example i;
            # built with a running total instead of re-summing every prefix.
            offsets = [0]
            for seg_len in batch.ex_segs:
                offsets.append(offsets[-1] + seg_len)

            for idx, summary in enumerate(summaries):
                txt = origin_sents[offsets[idx]:offsets[idx + 1]]
                if cal_oracle:
                    # Greedy selection: repeatedly add the sentence that
                    # raises the combined score the most; stop when no
                    # sentence improves it or the budget is exhausted.
                    selected = []
                    max_rouge = 0.
                    while len(selected) < args.ranking_max_k:
                        cur_max_rouge = max_rouge
                        cur_id = -1
                        for i in range(len(txt)):
                            if i in selected:
                                continue
                            temp_txt = " ".join(txt[j] for j in selected + [i])
                            scores = rouge.get_scores(temp_txt, summary)[0]
                            combined = scores["rouge-1"]["f"] + scores["rouge-l"]["f"]
                            if combined > cur_max_rouge:
                                cur_max_rouge = combined
                                cur_id = i
                        if cur_id == -1:
                            break
                        selected.append(cur_id)
                        max_rouge = cur_max_rouge
                    pred_txt = " ".join(txt[j] for j in selected)
                else:
                    # Lead-k: one sentence per (2 * win_size + 1) source
                    # sentences, at least 1 and at most ranking_max_k.
                    k = min(max(len(txt) // (2 * args.win_size + 1), 1), args.ranking_max_k)
                    pred_txt = " ".join(txt[:k])
                save_gold.write(summary + "\n")
                save_pred.write(pred_txt + "\n")

    length = test_length(pred_path, gold_path)
    bleu = test_bleu(pred_path, gold_path)
    file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
    pred_rouges = file_rouge.get_scores(avg=True)
    logger.info('Length ratio:\n%s' % str(length))
    logger.info('Bleu:\n%.2f' % (bleu*100))
    logger.info('Rouges:\n%s' % rouge_results_to_str(pred_rouges))
Пример #11
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode every batch and write candidate, gold and source files.

        Three UTF-8 files are written next to ``self.args.result_path``:
        ``.<step>.gold``, ``.<step>.candidate`` and ``.<step>.raw_src``, one
        line per translation returned by ``self.from_batch``.  Unless
        ``step`` is ``-1``, ROUGE is then computed over the gold/candidate
        files, logged, and (when a tensorboard writer is set) recorded.

        Args:
            data_iter: iterable of batches handed to ``self.translate_batch``.
            step: integer used in the output file names and as the logging /
                tensorboard step; ``-1`` skips ROUGE reporting.
            attn_debug: accepted for interface compatibility; not used here.
        """
        self.model.eval()
        gold_path = self.args.result_path + '.%d.gold' % step
        can_path = self.args.result_path + '.%d.candidate' % step
        raw_src_path = self.args.result_path + '.%d.raw_src' % step

        # Each output file is opened exactly once (the gold and candidate
        # files used to be opened twice, orphaning the first handles).
        self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
        self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')

        try:
            with torch.no_grad():
                for batch in data_iter:
                    if self.args.recall_eval:
                        # Tie the decoding length bounds to the gold length.
                        gold_tgt_len = batch.tgt.size(1)
                        self.min_length = gold_tgt_len + 20
                        self.max_length = gold_tgt_len + 60
                    batch_data = self.translate_batch(batch)
                    translations = self.from_batch(batch_data)

                    for pred, gold, src in translations:
                        # Strip special tokens; ' [unused2] ' separates
                        # sentences and becomes '<q>'.
                        # NOTE(review): str.replace treats r' +' as the
                        # literal two characters " +", not as a regex, so
                        # runs of spaces are NOT collapsed here. Left
                        # unchanged to keep written candidates identical.
                        pred_str = pred.replace('[unused0]', '').replace(
                            '[unused3]', '').replace('[PAD]', '').replace(
                                '[unused1]', '').replace(r' +', ' ').replace(
                                    ' [unused2] ', '<q>').replace('[unused2]',
                                                                  '').strip()
                        gold_str = gold.strip()
                        if self.args.recall_eval:
                            print(
                                'this is recal-----------------------------------------------'
                            )
                            # Keep whole sentences until adding the next one
                            # would reach gold length + 10 whitespace tokens.
                            # If that never happens pred_str is left as is.
                            word_limit = len(gold_str.split()) + 10
                            _pred_str = ''
                            for sent in pred_str.split('<q>'):
                                can_pred_str = _pred_str + '<q>' + sent.strip()
                                if len(can_pred_str.split()) >= word_limit:
                                    pred_str = _pred_str
                                    break
                                _pred_str = can_pred_str

                        self.can_out_file.write(pred_str + '\n')
                        self.gold_out_file.write(gold_str + '\n')
                        self.src_out_file.write(src.strip() + '\n')
                    self.can_out_file.flush()
                    self.gold_out_file.flush()
                    self.src_out_file.flush()
        finally:
            # Close the files even when decoding fails part-way through.
            self.can_out_file.close()
            self.gold_out_file.close()
            self.src_out_file.close()

        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], step)
Пример #12
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode every batch and write candidate, gold and source files.

        Writes ``.<step>.gold``, ``.<step>.candidate`` and ``.<step>.raw_src``
        next to ``self.args.result_path`` (UTF-8, one line per translation).
        Each prediction is cut at the first ``'. [MASK]'``, stripped of
        special tokens, and then truncated to whole sentences.  Unless
        ``step`` is ``-1``, ROUGE is computed over the gold/candidate files,
        logged, and (when a tensorboard writer is set) recorded.

        Args:
            data_iter: iterable of batches handed to ``self.translate_batch``.
            step: integer used in the output file names and as the logging /
                tensorboard step; ``-1`` skips ROUGE reporting.
            attn_debug: accepted for interface compatibility; not used here.
        """
        self.model.eval()
        gold_path = self.args.result_path + '.%d.gold' % step
        can_path = self.args.result_path + '.%d.candidate' % step
        raw_src_path = self.args.result_path + '.%d.raw_src' % step

        # Each output file is opened exactly once (the gold and candidate
        # files used to be opened twice, orphaning the first handles).
        self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
        self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')

        try:
            with torch.no_grad():
                for batch in data_iter:
                    if self.args.recall_eval:
                        # Tie the decoding length bounds to the gold length.
                        gold_tgt_len = batch.tgt.size(1)
                        self.min_length = gold_tgt_len + 20
                        self.max_length = gold_tgt_len + 60
                    batch_data = self.translate_batch(batch)
                    translations = self.from_batch(batch_data)

                    for pred, gold, src in translations:
                        # Keep only the text before the first '. [MASK]' and
                        # restore the full stop that the split removed.
                        pred = pred.split('. [MASK]')[0] + '.'
                        # ' ホ ' is the sentence separator in the raw
                        # prediction; the remaining replaces drop special
                        # tokens.
                        pred_str = pred.replace(' ホ ', '<q>').replace(
                            '. [MASK]', '').replace('! [MASK]', '').replace(
                                '? [MASK]', '').replace(' [MASK]', '').replace(
                                    ' [UNK]', '').replace(' [PAD]', '').replace(
                                        ' [CLS]', '').replace(' [SEP]',
                                                              '').strip()
                        gold_str = gold.lower().strip()

                        # Filter result: keep whole sentences and stop once
                        # there are at least two of them totalling more than
                        # 15 whitespace-separated words.  The word counter
                        # was previously never advanced, which made this
                        # cut-off unreachable.
                        kept_sents = []
                        word_count = 0
                        for sent in pred_str.split('<q>'):
                            kept_sents.append(sent)
                            word_count += len(sent.split())
                            if len(kept_sents) >= 2 and word_count > 15:
                                break
                        pred_str = '<q>'.join(kept_sents)

                        self.can_out_file.write(pred_str + '\n')
                        self.gold_out_file.write(gold_str + '\n')
                        self.src_out_file.write(src.strip() + '\n')
                    self.can_out_file.flush()
                    self.gold_out_file.flush()
                    self.src_out_file.flush()
        finally:
            # Close the files even when decoding fails part-way through.
            self.can_out_file.close()
            self.gold_out_file.close()
            self.src_out_file.close()

        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], step)
Пример #13
0
    def translate(self,
                  data_iter,
                  epoch,
                  cal_rouge=False,
                  save=True,
                  save_by_id=False,
                  have_gold=False,
                  info=""):
        """Decode every batch and optionally persist and score the output.

        When ``save`` is true, results go to ``<savepath>/<epoch>/``:
        ``out.txt`` (readable report), ``can.txt`` (one candidate per line)
        and, when ``have_gold`` is true, ``gold.txt``.  Without gold
        references the per-id table ``<savepath>/gen_result/<id>.txt`` is
        additionally updated: rows whose ``exp`` column equals ``info`` get
        the prediction in their ``generated`` column.

        Args:
            data_iter: iterable of batches handed to ``self.translate_batch``.
            epoch: name of the output sub-directory; also used as the
                tensorboard step unless it equals ``'best'``.
            cal_rouge: compute and record ROUGE (needs ``save`` and
                ``have_gold``).
            save: write the output files described above.
            save_by_id: append each prediction to
                ``<savepath>/temp/<id>.txt``.
            have_gold: the translations carry gold references.
            info: value matched against the ``exp`` column of the per-id
                result table.
        """
        self.model.eval()
        if save:
            base_path = pjoin(self.args.savepath, epoch)
            os.makedirs(base_path, exist_ok=True)
            out_path = pjoin(base_path, 'out.txt')
            self.out_file = codecs.open(out_path, 'w', 'utf-8')

            can_path = pjoin(base_path, 'can.txt')
            self.can_out_file = codecs.open(can_path, 'w', 'utf-8')

            if have_gold:
                gold_path = pjoin(base_path, 'gold.txt')
                self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')

        def _translate(batch):
            # Decode one batch and write every translation it yields.
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data, have_gold)

            for i, trans in enumerate(translations):
                pred, gold, src, id_, label = trans
                # Strip special tokens; ' [unused2] ' separates sentences
                # and becomes '<q>'.
                # NOTE(review): str.replace treats r' +' as the literal two
                # characters " +", not as a regex, so runs of spaces are NOT
                # collapsed here. Left unchanged to keep outputs identical.
                pred_str = pred.replace('[unused0]', '').replace(
                    '[unused3]',
                    '').replace('[PAD]', '').replace('[unused1]', '').replace(
                        r' +', ' ').replace(' [unused2] ',
                                            '<q>').replace('[unused2]',
                                                           '').strip()
                if have_gold:
                    gold_str = gold.strip()

                if save_by_id:
                    # `parent_idx` / `self_idx` were referenced without being
                    # defined (NameError). They are read from batch.id_idx the
                    # way the earlier, commented-out revision did.
                    # NOTE(review): assumes batch.id_idx rows are
                    # (id, parent_idx, self_idx) tensors -- confirm.
                    id_idx = batch.id_idx
                    parent_idx = id_idx[i][1].item()
                    self_idx = id_idx[i][2].item()
                    with open(
                            pjoin(self.args.savepath,
                                  'temp/{}.txt'.format(id_)), 'a') as f:
                        f.write('{}\t{}\t{}\n'.format(parent_idx, self_idx,
                                                      pred_str))

                if save:
                    # The candidate is written in both modes: ROUGE below
                    # reads can.txt, which used to stay empty whenever gold
                    # references were available.
                    self.can_out_file.write(pred_str + '\n')
                    if have_gold:
                        self.out_file.write(
                            'Predict:\t{}\nGold:\t{}\nSource:\t{}\n\n'.format(
                                pred_str, gold_str, src.strip()))
                        self.gold_out_file.write(gold_str + '\n')
                    else:
                        self.out_file.write(
                            'Predict:\t{}\nSource:\t{}\n'.format(
                                pred_str, src.strip()))

                        # Record the prediction in the per-id result table.
                        gen_path = pjoin(self.args.savepath,
                                         "gen_result/{}.txt".format(id_))
                        ori_df = pd.read_csv(gen_path, delimiter="\t")
                        ori_df.loc[ori_df["exp"] == info,
                                   "generated"] = pred_str
                        ori_df.to_csv(gen_path, index=False, sep="\t")
            if save:
                self.out_file.flush()
                self.can_out_file.flush()
                if have_gold:
                    self.gold_out_file.flush()

        try:
            Parallel(n_jobs=1, backend='threading')(delayed(_translate)(batch)
                                                    for batch in tqdm(data_iter))
        finally:
            if save:
                # Every opened file is closed, including can.txt when there
                # are no gold references, and also when decoding fails.
                self.out_file.close()
                self.can_out_file.close()
                if have_gold:
                    self.gold_out_file.close()

        if save and cal_rouge and have_gold:
            rouges = self._report_rouge(gold_path, can_path)
            if self.logger:
                self.logger.info('Rouges at epoch {} \n{}'.format(
                    epoch, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None and epoch != 'best':
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], epoch)
            # Append a plain-text recall/precision/F report for this epoch.
            with open(os.path.join(self.args.savepath, 'result/rouge.txt'),
                      'a') as f:
                f.write('Epoch {}\n'.format(epoch))
                for label, key in (('1', '1'), ('2', '2'), ('L', 'l')):
                    f.write('--------------------------------------\n')
                    f.write('test/rouge{}-R: {}\n'.format(
                        label, rouges['rouge_{}_recall'.format(key)]))
                    f.write('test/rouge{}-P: {}\n'.format(
                        label, rouges['rouge_{}_precision'.format(key)]))
                    f.write('test/rouge{}-F: {}\n'.format(
                        label, rouges['rouge_{}_f_score'.format(key)]))
                f.write('======================================\n')
Пример #14
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode every batch and write the prediction/reference files.

        Six UTF-8 files are written next to ``self.args.result_path``, one
        line per translation: ``.<step>.gold``, ``.<step>.candidate``,
        ``.<step>.goldstr``, ``.<step>.canstr``, ``.<step>.gold_eng`` and
        ``.<step>.raw_src``.  Unless ``step`` is ``-1``, ROUGE is computed
        over the gold/candidate files, logged, and (when a tensorboard
        writer is set) recorded.

        Args:
            data_iter: iterable of batches handed to ``self.translate_batch``.
            step: integer used in the output file names and as the logging /
                tensorboard step; ``-1`` skips ROUGE reporting.
            attn_debug: accepted for interface compatibility; not used here.
        """
        self.model.eval()
        gold_path = self.args.result_path + '.%d.gold' % step
        can_path = self.args.result_path + '.%d.candidate' % step
        gold_str_path = self.args.result_path + '.%d.goldstr' % step
        can_str_path = self.args.result_path + '.%d.canstr' % step
        eng_gold_path = self.args.result_path + '.%d.gold_eng' % step
        raw_src_path = self.args.result_path + '.%d.raw_src' % step

        self.gold_eng_out_file = codecs.open(eng_gold_path, 'w', 'utf-8')
        self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
        self.gold_str_out_file = codecs.open(gold_str_path, 'w', 'utf-8')
        self.can_str_out_file = codecs.open(can_str_path, 'w', 'utf-8')
        self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
        out_files = (self.can_out_file, self.gold_out_file,
                     self.src_out_file, self.can_str_out_file,
                     self.gold_str_out_file, self.gold_eng_out_file)

        ct = 0  # number of samples written so far, shown in the progress line

        try:
            with torch.no_grad():
                for batch in data_iter:
                    if self.args.recall_eval:
                        # Tie the decoding length bounds to the gold length.
                        gold_tgt_len = batch.tgt.size(1)
                        self.min_length = gold_tgt_len + 20
                        self.max_length = gold_tgt_len + 60
                    batch_data = self.translate_batch(batch)
                    translations = self.from_batch(batch_data)

                    for trans in translations:
                        print(time.asctime(time.localtime(time.time())),
                              "----- now is the test sample: ",
                              ct,
                              end='\r')

                        pred, gold, src, pred_strrr, gold_strrr, gold_eng = trans
                        # NOTE(review): in the chains below str.replace
                        # treats r' +' as the literal two characters " +",
                        # not as a regex, so runs of spaces are NOT
                        # collapsed. Left unchanged to keep outputs identical.
                        if self.args.bart:
                            pred_str = pred.replace('madeupword0000', '').replace(
                                'madeupword0001', '').replace('<pad>', '').replace(
                                    '<unk>', '').replace(r' +', ' ').replace(
                                        ' madeupword0002 ',
                                        '<q>').replace('madeupword0002',
                                                       '').strip()
                        else:
                            pred_str = pred.replace('[unused1]', '').replace(
                                '[unused4]', '').replace('[PAD]', '').replace(
                                    '[unused2]', '').replace(r' +', ' ').replace(
                                        ' [unused3] ',
                                        '<q>').replace('[unused3]', '').strip()

                        # The target text is also reconstructed from ids, so
                        # it gets the same special-token clean-up; this
                        # chain is identical for both model types.
                        gold_str = gold.replace('[unused1]', '').replace(
                            '[unused4]', '').replace('[PAD]', '').replace(
                                '[unused2]', '').replace(r' +', ' ').replace(
                                    ' [unused3] ',
                                    '<q>').replace('[unused3]', '').strip()

                        if self.args.recall_eval:
                            # Keep whole sentences until adding the next one
                            # would reach gold length + 10 whitespace tokens.
                            # If that never happens pred_str is left as is.
                            word_limit = len(gold_str.split()) + 10
                            _pred_str = ''
                            for sent in pred_str.split('<q>'):
                                can_pred_str = _pred_str + '<q>' + sent.strip()
                                if len(can_pred_str.split()) >= word_limit:
                                    pred_str = _pred_str
                                    break
                                _pred_str = can_pred_str

                        self.gold_eng_out_file.write(gold_eng + '\n')
                        self.can_out_file.write(pred_str + '\n')
                        self.gold_out_file.write(gold_str + '\n')
                        self.src_out_file.write(src.strip() + '\n')

                        # Added outputs: the extra prediction/reference
                        # strings carried by each translation tuple.
                        self.can_str_out_file.write(pred_strrr + '\n')
                        self.gold_str_out_file.write(gold_strrr + '\n')
                        ct += 1
                    for out_file in out_files:
                        out_file.flush()
        finally:
            # Close every file, including the gold_eng file that was
            # previously never closed, even when decoding fails.
            for out_file in out_files:
                out_file.close()

        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], step)
Пример #15
0
    def translate(self, data_iter, step, attn_debug=False):
        """Decode every batch and write candidate, gold and source files.

        Three UTF-8 files are written next to ``self.args.result_path``:
        ``.<step>.gold``, ``.<step>.candidate`` and ``.<step>.raw_src``, one
        line per translation returned by ``self.from_batch``.  Unless
        ``step`` is ``-1``, ROUGE is then computed over the gold/candidate
        files, logged, and (when a tensorboard writer is set) recorded.

        Args:
            data_iter: iterable of batches handed to ``self.translate_batch``.
            step: integer used in the output file names and as the logging /
                tensorboard step; ``-1`` skips ROUGE reporting.
            attn_debug: accepted for interface compatibility; not used here.
        """
        self.model.eval()
        gold_path = self.args.result_path + '.%d.gold' % step
        can_path = self.args.result_path + '.%d.candidate' % step
        raw_src_path = self.args.result_path + '.%d.raw_src' % step

        # Each output file is opened exactly once (the gold and candidate
        # files used to be opened twice, orphaning the first handles).
        self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
        self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')

        try:
            with torch.no_grad():
                for batch in data_iter:
                    if self.args.recall_eval:
                        # Tie the decoding length bounds to the gold length.
                        gold_tgt_len = batch.tgt.size(1)
                        self.min_length = gold_tgt_len + 20
                        self.max_length = gold_tgt_len + 60
                    batch_data = self.translate_batch(batch)
                    translations = self.from_batch(batch_data)

                    for pred, gold, src in translations:
                        # Strip special tokens; ' [unused2] ' separates
                        # sentences and becomes '<q>'.
                        # NOTE(review): str.replace treats r' +' as the
                        # literal two characters " +", not as a regex, so
                        # runs of spaces are NOT collapsed here. Left
                        # unchanged to keep written candidates identical.
                        pred_str = pred.replace('[unused6]', '').replace(
                            '[unused3]', '').replace('[PAD]', '').replace(
                                '[unused1]', '').replace(r' +', ' ').replace(
                                    ' [unused2] ', '<q>').replace('[unused2]',
                                                                  '').strip()
                        gold_str = gold.strip()
                        if self.args.recall_eval:
                            # Keep whole sentences until adding the next one
                            # would reach gold length + 10 whitespace tokens.
                            # If that never happens pred_str is left as is.
                            word_limit = len(gold_str.split()) + 10
                            _pred_str = ''
                            for sent in pred_str.split('<q>'):
                                can_pred_str = _pred_str + '<q>' + sent.strip()
                                if len(can_pred_str.split()) >= word_limit:
                                    pred_str = _pred_str
                                    break
                                _pred_str = can_pred_str

                        self.can_out_file.write(pred_str + '\n')
                        self.gold_out_file.write(gold_str + '\n')
                        self.src_out_file.write(src.strip() + '\n')
                    self.can_out_file.flush()
                    self.gold_out_file.flush()
                    self.src_out_file.flush()
        finally:
            # Close the files even when decoding fails part-way through.
            self.can_out_file.close()
            self.gold_out_file.close()
            self.src_out_file.close()

        if step != -1:
            rouges = self._report_rouge(gold_path, can_path)
            self.logger.info('Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(rouges)))
            if self.tensorboard_writer is not None:
                for tag, key in (('test/rouge1-F', 'rouge_1_f_score'),
                                 ('test/rouge2-F', 'rouge_2_f_score'),
                                 ('test/rougeL-F', 'rouge_l_f_score')):
                    self.tensorboard_writer.add_scalar(tag, rouges[key], step)
Пример #16
0
import logger
from others.utils import test_rouge, rouge_results_to_str

# Hand the temp dir plus the candidate/gold paths to test_rouge, then log the
# stringified result labelled with the hard-coded step 30000.
# NOTE(review): `self`, `can_path` and `gold_path` are not defined anywhere in
# this snippet; they have to come from surrounding context -- confirm.
rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
rouge_table = rouge_results_to_str(rouges)
logger.info('Rouges at step %d \n%s' % (30000, rouge_table))
Пример #17
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Select sentences for every batch of ``test_iter`` and write them out.

            For each example a ranking of sentence indices is built, sentences
            are taken from ``batch.src_str`` in that order, and the joined
            selection plus the matching ``batch.tgt_str`` entry are written to
            ``<result_path>_step<step>.candidate`` and
            ``<result_path>_step<step>.gold``.

            test_iter: iterable of batches exposing ``src``, ``labels``,
                ``segs``, ``clss``, ``mask``, ``mask_cls``, ``src_str``,
                ``tgt_str`` and ``batch_size``
            step: number used in the output file names and the log line;
                ``-1`` skips the ``test_rouge`` call
            cal_lead: rank sentences by position instead of running the model
            cal_oracle: (when ``cal_lead`` is false) keep the sentences whose
                label equals 1 instead of running the model
        Returns:
            the ``Statistics`` object created here; it is only updated in the
            branch that runs the model
        """
        def _ngrams(size, tokens):
            # Every contiguous window of ``size`` tokens, as a set of tuples.
            return {tuple(tokens[k:k + size])
                    for k in range(len(tokens) - size + 1)}

        def _shares_trigram(sentence, chosen):
            # True when ``sentence`` has a word trigram in common with any of
            # the sentences already chosen.
            own = _ngrams(3, sentence.split())
            return any(not own.isdisjoint(_ngrams(3, other.split()))
                       for other in chosen)

        # The model is only switched to eval mode when it is actually used.
        if not (cal_lead or cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)

        with open(can_path, 'w') as save_pred, \
                open(gold_path, 'w') as save_gold, \
                torch.no_grad():
            for batch in test_iter:
                src, labels, segs = batch.src, batch.labels, batch.segs
                clss, mask, mask_cls = batch.clss, batch.mask, batch.mask_cls

                if cal_lead:
                    # Same positional ordering 0..n-1 for every example.
                    ranking = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    ranking = [
                        [j for j in range(batch.clss.size(1))
                         if labels[i][j] == 1]
                        for i in range(batch.batch_size)
                    ]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

                    loss = (self.loss(sent_scores, labels.float())
                            * mask.float()).sum()
                    stats.update(
                        Statistics(float(loss.cpu().data.numpy()), len(labels)))

                    # Sort indices by descending (score + mask).
                    sent_scores = (sent_scores + mask.float()).cpu().data.numpy()
                    ranking = np.argsort(-sent_scores, 1)

                gold_lines = []
                pred_lines = []
                for i, order in enumerate(ranking):
                    sentences = batch.src_str[i]
                    if len(sentences) == 0:
                        continue

                    picked = []
                    for j in order[:len(sentences)]:
                        # Skip indices that point past the available sentences.
                        if j >= len(sentences):
                            continue
                        candidate = sentences[j].strip()
                        if not (self.args.block_trigram
                                and _shares_trigram(candidate, picked)):
                            picked.append(candidate)

                        # Stop at three sentences unless oracle / recall mode.
                        if (not (cal_oracle or self.args.recall_eval)
                                and len(picked) == 3):
                            break

                    # NOTE(review): the full source text is appended to every
                    # candidate line, so the file later passed to test_rouge
                    # contains it too -- confirm this is intended.
                    text = ('<q>'.join(picked)
                            + "   original txt:   "
                            + " ".join(sentences))
                    if self.args.recall_eval:
                        # Truncate to as many tokens as the target has.
                        budget = len(batch.tgt_str[i].split())
                        text = ' '.join(text.split()[:budget])

                    pred_lines.append(text)
                    gold_lines.append(batch.tgt_str[i])

                save_gold.writelines(line.strip() + '\n' for line in gold_lines)
                save_pred.writelines(line.strip() + '\n' for line in pred_lines)

        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
Пример #18
0
    # Dispatch on args.mode for the 'abs' task.
    if (args.task == 'abs'):
        if (args.mode == 'train'):
            train_abs(args, device_id)
        elif (args.mode == 'validate'):
            validate_abs(args, device_id)
        elif (args.mode == 'score'):
            # Build a predictor only to reuse its _report_rouge helper on the
            # '<result_path>.gold' / '<result_path>.candidate' files.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
            symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
                       'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
            predictor = build_predictor(args, tokenizer, symbols, None, logger)

            gold_path = args.result_path + '.gold'
            can_path = args.result_path + '.candidate'
            rouges = predictor._report_rouge(gold_path, can_path)
            # result_path is a string (it is concatenated with '.gold' above),
            # so it must be formatted with %s; the previous %d raised
            # TypeError before anything was logged.
            logger.info('Rouges for %s \n%s' % (args.result_path, rouge_results_to_str(rouges)))
        elif (args.mode == 'lead'):
            baseline(args, cal_lead=True)
        elif (args.mode == 'oracle'):
            baseline(args, cal_oracle=True)
        if (args.mode == 'test'):
            cp = args.test_from
            # The step is the last '_'-separated piece of the second-to-last
            # '.'-separated component of the checkpoint name, e.g.
            # 'model_step_30000.pt' -> 30000. Fall back to 0 when the name
            # does not have that shape.
            try:
                step = int(cp.split('.')[-2].split('_')[-1])
            except (AttributeError, IndexError, ValueError):
                step = 0
            test_abs(args, device_id, cp, step)