    def validate(self, data_iter, step, attn_debug=False):

        self.model.eval()
        gold_path = self.args.result_path + '.step.%d.gold_temp' % step
        pred_path = self.args.result_path + '.step.%d.pred_temp' % step
        gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        pred_out_file = codecs.open(pred_path, 'w', 'utf-8')

        ct = 0
        ext_acc_num = 0
        ext_pred_num = 0
        ext_gold_num = 0

        with torch.no_grad():
            for batch in data_iter:
                output_data, tgt_data, ext_pred, ext_gold = self.translate_batch(
                    batch)
                translations = self.from_batch_dev(output_data, tgt_data)

                for idx in range(len(translations)):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    pred_summ, gold_data = translations[idx]
                    # extractive overlap: with duplicate-free index lists,
                    # len(a + b) - len(set(a + b)) equals |a ∩ b|
                    # (see the test_f1 sketch after this function)
                    acc_num = len(ext_pred[idx] + ext_gold[idx]) - len(
                        set(ext_pred[idx] + ext_gold[idx]))
                    pred_num = len(ext_pred[idx])
                    gold_num = len(ext_gold[idx])
                    ext_acc_num += acc_num
                    ext_pred_num += pred_num
                    ext_gold_num += gold_num
                    pred_out_file.write(pred_summ + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()

        pred_out_file.close()
        gold_out_file.close()

        rouge_results = None  # stays None when step == -1 (no scoring pass)
        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
            self.logger.info(
                'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
                (step, p * 100, r * 100, f1 * 100))
            self.logger.info(
                'Gold Length at step %d: %.2f' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                             (step, rouge_results_to_str(pred_rouges)))
            rouge_results = (pred_rouges["rouge-1"]['f'],
                             pred_rouges["rouge-l"]['f'])
        return rouge_results
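
# ---------------------------------------------------------------------------
# `test_f1` is a repo helper imported elsewhere; a minimal sketch that is
# consistent with the call sites above (hypothetical reimplementation, for
# illustration only):
def _test_f1_sketch(acc_num, pred_num, gold_num):
    """Return (f1, p, r) given overlap, prediction, and gold counts."""
    p = acc_num / pred_num if pred_num > 0 else 0.0
    r = acc_num / gold_num if gold_num > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return f1, p, r
# Example: _test_f1_sketch(2, 3, 4) -> (~0.571, ~0.667, 0.5)
# ---------------------------------------------------------------------------
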
    def validate(self, data_iter, step, attn_debug=False):

        self.model.eval()
        gold_path = self.args.result_path + '.step.%d.gold_temp' % step
        pred_path = self.args.result_path + '.step.%d.pred_temp' % step
        gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        pred_out_file = codecs.open(pred_path, 'w', 'utf-8')

        ct = 0
        with torch.no_grad():
            for batch in data_iter:
                doc_data, summ_data = self.translate_batch(batch)
                translations = self.from_batch_dev(batch, doc_data)

                for idx in range(len(translations)):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    doc_short_context = translations[idx][1]
                    gold_data = summ_data[idx]
                    pred_out_file.write(doc_short_context + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()

        pred_out_file.close()
        gold_out_file.close()

        rouge_results = None  # stays None when step == -1 (no scoring pass)
        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            self.logger.info(
                'Gold Length at step %d: %.2f' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                             (step, rouge_results_to_str(pred_rouges)))
            rouge_results = (pred_rouges["rouge-1"]['f'],
                             pred_rouges["rouge-l"]['f'])
        return rouge_results
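
# ---------------------------------------------------------------------------
# `test_length` is another repo helper; judging from the call sites, with
# ratio=False it returns the mean token length of the hypothesis file, and
# with ratio=True (the default) the hyp/ref length ratio. A hypothetical
# sketch under that assumption:
def _test_length_sketch(hyp_path, ref_path, ratio=True):
    import codecs  # already imported by this module
    with codecs.open(hyp_path, 'r', 'utf-8') as f:
        hyp_lens = [len(line.split()) for line in f]
    with codecs.open(ref_path, 'r', 'utf-8') as f:
        ref_lens = [len(line.split()) for line in f]
    if not ratio:
        # absolute average length of the hypothesis file
        return sum(hyp_lens) / max(len(hyp_lens), 1)
    # total-token ratio between hypothesis and reference
    return sum(hyp_lens) / max(sum(ref_lens), 1)
# ---------------------------------------------------------------------------
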
    def translate(self, data_iter, step, attn_debug=False):

        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        output_file = codecs.open(output_path, 'w', 'utf-8')
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step
        gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        pred_out_file = codecs.open(pred_path, 'w', 'utf-8')

        ct = 0
        ext_acc_num = 0
        ext_pred_num = 0
        ext_gold_num = 0

        with torch.no_grad():
            rouge = Rouge()
            for batch in data_iter:
                output_data, tgt_data, ext_pred, ext_gold = self.translate_batch(
                    batch)
                translations = self.from_batch_test(batch, output_data,
                                                    tgt_data)

                for idx in range(len(translations)):
                    origin_sent, pred_summ, gold_data = translations[idx]
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write("ORIGIN  : \n    " +
                                      origin_sent.replace('<S>', '\n    ') +
                                      "\n")
                    output_file.write("GOLD    : " + gold_data.strip() + "\n")
                    output_file.write("DOC_GEN : " + pred_summ.strip() + "\n")
                    rouge_score = rouge.get_scores(pred_summ, gold_data)
                    bleu_score = sentence_bleu(
                        [gold_data.split()],
                        pred_summ.split(),
                        smoothing_function=SmoothingFunction().method1)
                    output_file.write(
                        "DOC_GEN  bleu & rouge-f 1/2/l:    %.4f & %.4f/%.4f/%.4f\n"
                        % (bleu_score, rouge_score[0]["rouge-1"]["f"],
                           rouge_score[0]["rouge-2"]["f"],
                           rouge_score[0]["rouge-l"]["f"]))
                    # extractive overlap: |ext_pred ∩ ext_gold| computed via
                    # the duplicate-counting identity (see the sketch above)
                    acc_num = len(ext_pred[idx] + ext_gold[idx]) - len(
                        set(ext_pred[idx] + ext_gold[idx]))
                    pred_num = len(ext_pred[idx])
                    gold_num = len(ext_gold[idx])
                    ext_acc_num += acc_num
                    ext_pred_num += pred_num
                    ext_gold_num += gold_num
                    f1, p, r = test_f1(acc_num, pred_num, gold_num)
                    output_file.write(
                        "EXT_GOLD: [" +
                        ','.join([str(i)
                                  for i in sorted(ext_gold[idx])]) + "]\n")
                    output_file.write(
                        "EXT_PRED: [" +
                        ','.join([str(i)
                                  for i in sorted(ext_pred[idx])]) + "]\n")
                    output_file.write(
                        "EXT_SCORE  P/R/F1:    %.4f/%.4f/%.4f\n\n" %
                        (p, r, f1))
                    pred_out_file.write(pred_summ.strip() + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        pred_out_file.close()
        gold_out_file.close()
        output_file.close()

        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
            self.logger.info(
                'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
                (step, p * 100, r * 100, f1 * 100))
            self.logger.info(
                'Gold Length at step %d: %.2f' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
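
# ---------------------------------------------------------------------------
# Note on the smoothing above: sentence-level BLEU collapses toward zero
# whenever a higher-order n-gram has no match, which is common for short
# summaries; NLTK's SmoothingFunction().method1 adds a small count to empty
# n-gram buckets. A self-contained usage example:
def _bleu_demo():
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    ref = "the cat sat on the mat".split()
    hyp = "the cat is on the mat".split()
    return sentence_bleu([ref], hyp,
                         smoothing_function=SmoothingFunction().method1)
# ---------------------------------------------------------------------------
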
    def translate(self, data_iter, step, attn_debug=False):

        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        output_file = codecs.open(output_path, 'w', 'utf-8')
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step
        ex_single_path = self.args.result_path + '.%d.ex_test' % step + ".short"
        ex_context_path = self.args.result_path + '.%d.ex_test' % step + ".long"
        gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        pred_out_file = codecs.open(pred_path, 'w', 'utf-8')
        short_ex_out_file = codecs.open(ex_single_path, 'w', 'utf-8')
        long_ex_out_file = codecs.open(ex_context_path, 'w', 'utf-8')

        ct = 0
        with torch.no_grad():
            rouge = Rouge()
            for batch in data_iter:
                doc_data, summ_data = self.translate_batch(batch)
                translations = self.from_batch_test(batch, doc_data)

                for idx in range(len(translations)):
                    origin_sent, doc_extract, context_doc_extract, \
                        doc_pred, lead = translations[idx]
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write(
                        "ORIGIN  : " +
                        origin_sent.replace('<S>', '\n          ') + "\n")
                    gold_data = summ_data[idx]
                    output_file.write("GOLD    : " + gold_data + "\n")
                    output_file.write("LEAD    : " + lead + "\n")
                    output_file.write("DOC_EX  : " + doc_extract.strip() +
                                      "\n")
                    output_file.write("DOC_CONT: " +
                                      context_doc_extract.strip() + "\n")
                    output_file.write("DOC_GEN : " + doc_pred.strip() + "\n")

                    gold_list = gold_data.strip().split()

                    # score every output variant against the gold summary;
                    # BERT special tokens are stripped so that both BLEU and
                    # ROUGE are computed on clean text
                    for name, text in (("LEAD    ", lead),
                                       ("DOC_EX  ", doc_extract),
                                       ("DOC_CONT", context_doc_extract),
                                       ("DOC_GEN ", doc_pred)):
                        clean = text.strip().replace("[unused2]", "").replace(
                            "[unused3]", "")
                        rouge_score = rouge.get_scores(clean, gold_data)
                        bleu_score = sentence_bleu(
                            [gold_list],
                            clean.split(),
                            smoothing_function=SmoothingFunction().method1)
                        output_file.write(
                            "%s bleu & rouge-f 1/2/l:    %.4f & %.4f/%.4f/%.4f\n"
                            % (name, bleu_score,
                               rouge_score[0]["rouge-1"]["f"],
                               rouge_score[0]["rouge-2"]["f"],
                               rouge_score[0]["rouge-l"]["f"]))
                    output_file.write("\n")

                    short_ex_out_file.write(doc_extract.strip().replace(
                        "[unused2]", "").replace("[unused3]", "") + '\n')
                    long_ex_out_file.write(context_doc_extract.strip().replace(
                        "[unused2]", "").replace("[unused3]", "") + '\n')
                    pred_out_file.write(doc_pred.strip().replace(
                        "[unused2]", "").replace("[unused3]", "") + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                short_ex_out_file.flush()
                long_ex_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        pred_out_file.close()
        short_ex_out_file.close()
        long_ex_out_file.close()
        gold_out_file.close()
        output_file.close()

        if step != -1:
            # test_bleu takes (hyp_path, ref_path), matching every other
            # call site in this file
            ex_short_bleu = test_bleu(ex_single_path, gold_path)
            ex_long_bleu = test_bleu(ex_context_path, gold_path)
            pred_bleu = test_bleu(pred_path, gold_path)

            file_rouge = FilesRouge(hyp_path=ex_single_path,
                                    ref_path=gold_path)
            ex_short_rouges = file_rouge.get_scores(avg=True)

            file_rouge = FilesRouge(hyp_path=ex_context_path,
                                    ref_path=gold_path)
            ex_long_rouges = file_rouge.get_scores(avg=True)

            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)

            self.logger.info(
                'Gold Length at step %d: %.2f\n' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Short Extraction Length ratio at step %d: %.2f' %
                             (step, test_length(ex_single_path, gold_path)))
            self.logger.info('Short Extraction Bleu at step %d: %.2f' %
                             (step, ex_short_bleu * 100))
            self.logger.info('Short Extraction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(ex_short_rouges)))
            self.logger.info('Long Extraction Length ratio at step %d: %.2f' %
                             (step, test_length(ex_context_path, gold_path)))
            self.logger.info('Long Extraction Bleu at step %d: %.2f' %
                             (step, ex_long_bleu * 100))
            self.logger.info('Long Extraction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(ex_long_rouges)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
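
# ---------------------------------------------------------------------------
# The [unused2]/[unused3] strings are spare BERT vocabulary tokens that this
# code appears to use as segment delimiters; the strip-and-rejoin chain is
# repeated for every output file above. A small helper that would factor it
# out (hypothetical, for illustration):
def _strip_special_tokens(text):
    """Drop the delimiter tokens and normalize whitespace."""
    for tok in ("[unused2]", "[unused3]"):
        text = text.replace(tok, "")
    return " ".join(text.split())
# ---------------------------------------------------------------------------
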
def baseline(args, cal_lead=False, cal_oracle=False):
    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.test_batch_size, args.test_batch_ex_size, 'cpu',
                                       shuffle=False, is_test=True)

    mode = "lead" if cal_lead else "oracle"

    rouge = Rouge()
    pred_path = '%s.%s.pred' % (args.result_path, mode)
    gold_path = '%s.%s.gold' % (args.result_path, mode)
    save_pred = open(pred_path, 'w', encoding='utf-8')
    save_gold = open(gold_path, 'w', encoding='utf-8')

    with torch.no_grad():
        for batch in test_iter:
            summaries = batch.summ_txt
            origin_sents = batch.original_str
            ex_segs = batch.ex_segs
            # prefix sums: ex_segs[i] is the offset in origin_sents at which
            # example i starts
            ex_segs = [sum(ex_segs[:i]) for i in range(len(ex_segs)+1)]

            for idx in range(len(summaries)):
                summary = summaries[idx]
                txt = origin_sents[ex_segs[idx]:ex_segs[idx+1]]
                if cal_oracle:
                    # greedy oracle: repeatedly add the sentence that most
                    # improves ROUGE-1 + ROUGE-L F1 against the gold summary
                    selected = []
                    max_rouge = 0.
                    while len(selected) < args.ranking_max_k:
                        cur_max_rouge = max_rouge
                        cur_id = -1
                        for i in range(len(txt)):
                            if i in selected:
                                continue
                            c = selected + [i]
                            temp_txt = " ".join([txt[j] for j in c])
                            rouge_score = rouge.get_scores(temp_txt, summary)
                            rouge_1 = rouge_score[0]["rouge-1"]["f"]
                            rouge_l = rouge_score[0]["rouge-l"]["f"]
                            rouge_score = rouge_1 + rouge_l
                            if rouge_score > cur_max_rouge:
                                cur_max_rouge = rouge_score
                                cur_id = i
                        if cur_id == -1:
                            # no remaining sentence improves the score
                            break
                        selected.append(cur_id)
                        max_rouge = cur_max_rouge
                    pred_txt = " ".join([txt[j] for j in selected])
                else:
                    # lead baseline: take the first k sentences, with k
                    # scaled by the extraction window size
                    k = min(max(len(txt) // (2*args.win_size+1), 1), args.ranking_max_k)
                    pred_txt = " ".join(txt[:k])
                save_gold.write(summary + "\n")
                save_pred.write(pred_txt + "\n")
    save_gold.flush()
    save_pred.flush()
    save_gold.close()
    save_pred.close()

    length = test_length(pred_path, gold_path)
    bleu = test_bleu(pred_path, gold_path)
    file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
    pred_rouges = file_rouge.get_scores(avg=True)
    logger.info('Length ratio:\n%s' % str(length))
    logger.info('Bleu:\n%.2f' % (bleu*100))
    logger.info('Rouges:\n%s' % rouge_results_to_str(pred_rouges))
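
# ---------------------------------------------------------------------------
# Example invocations (argument names as above, values hypothetical):
#
#     baseline(args, cal_lead=True)    # LEAD-k baseline
#     baseline(args, cal_oracle=True)  # greedy ROUGE oracle (extractive
#                                      # upper bound)
# ---------------------------------------------------------------------------
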
    def translate(self, data_iter, step, attn_debug=False):

        self.model.eval()
        output_path = self.args.result_path + '.%d.output' % step
        output_file = codecs.open(output_path, 'w', 'utf-8')
        gold_path = self.args.result_path + '.%d.gold_test' % step
        pred_path = self.args.result_path + '.%d.pred_test' % step
        gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
        pred_out_file = codecs.open(pred_path, 'w', 'utf-8')
        ct = 0

        with torch.no_grad():
            rouge = Rouge()
            for batch in data_iter:
                output_data, tgt_data = self.translate_batch(batch)
                translations = self.from_batch_test(batch, output_data,
                                                    tgt_data)

                for idx in range(len(translations)):
                    origin_sent, pred_summ, gold_data = translations[idx]
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID      : %d\n" % ct)
                    output_file.write("ORIGIN  : \n    " +
                                      origin_sent.replace('<S>', '\n    ') +
                                      "\n")
                    output_file.write("GOLD    : " + gold_data.strip() + "\n")
                    output_file.write("DOC_GEN : " + pred_summ.strip() + "\n")
                    rouge_score = rouge.get_scores(pred_summ, gold_data)
                    bleu_score = sentence_bleu(
                        [gold_data.split()],
                        pred_summ.split(),
                        smoothing_function=SmoothingFunction().method1)
                    output_file.write(
                        "DOC_GEN  bleu & rouge-f 1/l:    %.4f & %.4f/%.4f\n\n"
                        % (bleu_score, rouge_score[0]["rouge-1"]["f"],
                           rouge_score[0]["rouge-l"]["f"]))
                    pred_out_file.write(pred_summ.strip() + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
                output_file.flush()

        pred_out_file.close()
        gold_out_file.close()
        output_file.close()

        if step != -1:
            pred_bleu = test_bleu(pred_path, gold_path)
            file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
            pred_rouges = file_rouge.get_scores(avg=True)
            self.logger.info(
                'Gold Length at step %d: %.2f\n' %
                (step, test_length(gold_path, gold_path, ratio=False)))
            self.logger.info('Prediction Length ratio at step %d: %.2f' %
                             (step, test_length(pred_path, gold_path)))
            self.logger.info('Prediction Bleu at step %d: %.2f' %
                             (step, pred_bleu * 100))
            self.logger.info('Prediction Rouges at step %d: \n%s' %
                             (step, rouge_results_to_str(pred_rouges)))
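
# ---------------------------------------------------------------------------
# The (rouge-1 F, rouge-l F) pair returned by `validate` lends itself to
# checkpoint selection; a hedged sketch of such a loop (all names
# hypothetical):
def _select_best_checkpoint(trainer, data_iter, steps):
    best_step, best_score = None, -1.0
    for step in steps:
        result = trainer.validate(data_iter, step)
        if result is None:  # validate returns None when step == -1
            continue
        r1, rl = result
        score = r1 + rl  # same combination the greedy oracle maximizes
        if score > best_score:
            best_step, best_score = step, score
    return best_step
# ---------------------------------------------------------------------------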