def report_rouge(logger, args, gold_path, can_path):
    """Score a candidate file against a gold file with ROUGE.

    Logs the full ROUGE breakdown and the average F1, then returns the
    average F1 as computed by ``avg_rouge_f1``.
    """
    logger.info("Calculating Rouge")
    scores = test_rouge(args.temp_dir, can_path, gold_path)
    logger.info('Rouges:\n%s' % rouge_results_to_str(scores))
    mean_f1 = avg_rouge_f1(scores)
    logger.info('Average Rouge F1: %s' % mean_f1)
    return mean_f1
def validate(self, data_iter, step, attn_debug=False):
    """Validation decoding pass with extractive-F1 bookkeeping.

    Writes predicted and gold summaries to temp files, accumulates the
    overlap counts needed for extractive sentence P/R/F1, then (when
    `step` != -1) logs BLEU, length stats, ROUGE and the ext score and
    returns ``(rouge-1 F, rouge-L F)``. Returns None when step == -1.
    """
    self.model.eval()
    # NOTE(review): unlike sibling methods these paths have no '.' between
    # result_path and 'step' — kept as-is for backward compatibility.
    gold_path = self.args.result_path + 'step.%d.gold_temp' % step
    pred_path = self.args.result_path + 'step.%d.pred_temp' % step
    ct = 0
    ext_acc_num = 0
    ext_pred_num = 0
    ext_gold_num = 0
    # FIX: context managers guarantee the temp files are closed (and their
    # buffers flushed) even if decoding raises mid-loop; previously the
    # handles leaked on any exception.
    with codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
            codecs.open(pred_path, 'w', 'utf-8') as pred_out_file:
        with torch.no_grad():
            for batch in data_iter:
                output_data, tgt_data, ext_pred, ext_gold = \
                    self.translate_batch(batch)
                translations = self.from_batch_dev(output_data, tgt_data)
                for idx in range(len(translations)):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    pred_summ, gold_data = translations[idx]
                    # Overlap count: |pred| + |gold| minus the size of their
                    # union = number of shared sentence indices (assumes no
                    # duplicates within each list).
                    acc_num = len(ext_pred[idx] + ext_gold[idx]) - len(
                        set(ext_pred[idx] + ext_gold[idx]))
                    ext_acc_num += acc_num
                    ext_pred_num += len(ext_pred[idx])
                    ext_gold_num += len(ext_gold[idx])
                    pred_out_file.write(pred_summ + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
    if (step != -1):
        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
        self.logger.info(
            'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
            (step, p * 100, r * 100, f1 * 100))
        self.logger.info(
            'Gold Length at step %d: %.2f' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                         (step, rouge_results_to_str(pred_rouges)))
        rouge_results = (pred_rouges["rouge-1"]['f'],
                         pred_rouges["rouge-l"]['f'])
        return rouge_results
def translate(self, data_iter, step, attn_debug=False):
    """Decode `data_iter`, writing candidate/gold/raw-src/example-id files.

    Output paths are derived from ``args.result_path`` and `step`. When
    `step` != -1, ROUGE is computed on the written files, logged, and sent
    to TensorBoard if a writer is configured. Returns None.
    """
    self.model.eval()
    gold_path = self.args.result_path + '.%d.gold' % step
    can_path = self.args.result_path + '.%d.candidate' % step
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    raw_src_path = self.args.result_path + '.%d.raw_src' % step
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
    example_id_path = self.args.result_path + '.%d.example_id' % step
    self.example_id_file = codecs.open(example_id_path, 'w', 'utf-8')
    ct = 0
    # FIX: close the four output files even if decoding raises, so handles
    # are not leaked and partial results are flushed to disk.
    try:
        with torch.no_grad():
            for batch in data_iter:
                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data, batch)
                for trans in translations:
                    pred, gold, src, example_id = trans
                    # Strip special tokens from the decoded sequence.
                    # NOTE(review): .replace(r' +', ' ') replaces the literal
                    # substring " +", not a regex; kept as-is to preserve
                    # output — confirm whether re.sub was intended.
                    pred_str = pred.replace('[unused0]', '')\
                        .replace('[unused1]', '')\
                        .replace('[PAD]', '')\
                        .replace('[SEP]', '')\
                        .replace('[UNK]', '')\
                        .replace(r' +', ' ').strip()
                    gold_str = gold.strip()
                    self.can_out_file.write(pred_str + '\n')
                    self.gold_out_file.write(gold_str + '\n')
                    self.src_out_file.write(src.strip() + '\n')
                    self.example_id_file.write(str(example_id) + '\n')
                    ct += 1
                self.can_out_file.flush()
                self.gold_out_file.flush()
                self.src_out_file.flush()
                self.example_id_file.flush()
    finally:
        self.can_out_file.close()
        self.gold_out_file.close()
        self.src_out_file.close()
        self.example_id_file.close()
    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
def validate(self, data_iter, step, attn_debug=False):
    """Validation pass for the extract-then-generate model.

    Writes the short-context extraction and the gold summary to temp
    files; when `step` != -1 logs BLEU/length/ROUGE and returns
    ``(rouge-1 F, rouge-L F)``, otherwise returns None.
    """
    self.model.eval()
    gold_path = self.args.result_path + '.step.%d.gold_temp' % step
    pred_path = self.args.result_path + '.step.%d.pred_temp' % step
    ct = 0
    # FIX: context managers ensure the temp files are always closed, even
    # if translate_batch raises partway through the iterator (previously
    # the handles leaked).
    with codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
            codecs.open(pred_path, 'w', 'utf-8') as pred_out_file:
        with torch.no_grad():
            for batch in data_iter:
                doc_data, summ_data = self.translate_batch(batch)
                translations = self.from_batch_dev(batch, doc_data)
                for idx in range(len(translations)):
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    # Only the short-context extraction (index 1 of each
                    # translation tuple) is scored here.
                    doc_short_context = translations[idx][1]
                    gold_data = summ_data[idx]
                    pred_out_file.write(doc_short_context + '\n')
                    gold_out_file.write(gold_data + '\n')
                    ct += 1
                pred_out_file.flush()
                gold_out_file.flush()
    if (step != -1):
        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        self.logger.info(
            'Gold Length at step %d: %.2f' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s\n' %
                         (step, rouge_results_to_str(pred_rouges)))
        rouge_results = (pred_rouges["rouge-1"]['f'],
                         pred_rouges["rouge-l"]['f'])
        return rouge_results
def test(self, test_iter, step, cal_lead=False, cal_oracle=False): """ Validate model. valid_iter: validate data iterator Returns: :obj:`nmt.Statistics`: validation loss statistics """ # Set model in validating mode. def _get_ngrams(n, text): ngram_set = set() text_length = len(text) max_index_ngram_start = text_length - n for i in range(max_index_ngram_start + 1): ngram_set.add(tuple(text[i:i + n])) return ngram_set def _block_tri(c, p): tri_c = _get_ngrams(3, c.split()) for s in p: tri_s = _get_ngrams(3, s.split()) if len(tri_c.intersection(tri_s)) > 0: return True return False if (not cal_lead and not cal_oracle): self.model.eval() stats = Statistics() can_path = '%s_step%d.candidate' % (self.args.result_path, step) gold_path = '%s_step%d.gold' % (self.args.result_path, step) with open(can_path, 'w') as save_pred: with open(gold_path, 'w') as save_gold: with torch.no_grad(): for batch in test_iter: gold = [] pred = [] if (cal_lead): selected_ids = [list(range(batch.clss.size(1))) ] * batch.batch_size for i, idx in enumerate(selected_ids): _pred = [] if (len(batch.src_str[i]) == 0): continue for j in selected_ids[i][:len(batch.src_str[i])]: if (j >= len(batch.src_str[i])): continue candidate = batch.src_str[i][j].strip() _pred.append(candidate) if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3): break _pred = '<q>'.join(_pred) if (self.args.recall_eval): _pred = ' '.join( _pred.split() [:len(batch.tgt_str[i].split())]) pred.append(_pred) gold.append(batch.tgt_str[i]) for i in range(len(gold)): save_gold.write(gold[i].strip() + '\n') for i in range(len(pred)): save_pred.write(pred[i].strip() + '\n') if (step != -1 and self.args.report_rouge): rouges = test_rouge(self.args.temp_dir, can_path, gold_path) logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges))) self._report_step(0, step, valid_stats=stats) return stats
def translate(self, data_iter, step):
    """Main control flow for test decoding.

    Writes gold/candidate/raw-src files under ``args.result_path``, and
    when `step` != -1 reports ROUGE to the logger, TensorBoard (if
    configured) and MLflow. Returns None.
    """
    # Set model to eval mode for decoding.
    self.model.eval()
    gold_path = os.path.join(self.args.result_path, 'test.%d.gold' % step)
    can_path = os.path.join(self.args.result_path,
                            'test.%d.candidate' % step)
    raw_src_path = os.path.join(self.args.result_path,
                                'test.%d.raw_src' % step)
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
    ct = 0
    # FIX: close the three output files even if decoding raises, so the
    # handles are not leaked and partial output reaches disk.
    try:
        with torch.no_grad():
            for batch in data_iter:
                # Constrain prediction length close to the gold length.
                if self.args.recall_eval:
                    gold_tgt_len = batch.tgt.size(1)
                    self.min_length = gold_tgt_len + 20
                    self.max_length = gold_tgt_len + 60
                # batch_data keys: 'predictions', 'scores', 'gold_score',
                # 'batch'; translations: (predict_sent, gold_sent, raw_src).
                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data)
                for trans in translations:
                    pred, gold, src = trans
                    src_str = src.strip()
                    # [unused0] -> BOS, [unused1] -> EOS, [unused2] -> EOQ.
                    # NOTE(review): .replace(r' +', ' ') is a literal
                    # replace, not a regex; kept to preserve output.
                    pred_str = pred.replace('[unused0]', '').replace(
                        '[unused3]', '').replace('[PAD]', '').replace(
                            '[unused1]', '').replace(r' +', ' ').replace(
                                ' [unused2] ', '<q>').replace(
                                    '[unused2]', '').strip()
                    gold_str = gold.strip()
                    if (self.args.recall_eval):
                        # Accumulate sentence by sentence; cut once the
                        # candidate is 10+ tokens longer than the gold.
                        _pred_str = ''
                        for sent in pred_str.split('<q>'):
                            can_pred_str = _pred_str + '<q>' + sent.strip()
                            if (len(can_pred_str.split()) >=
                                    len(gold_str.split()) + 10):
                                pred_str = _pred_str
                                break
                            else:
                                _pred_str = can_pred_str
                    self.src_out_file.write(src_str + '\n')
                    self.can_out_file.write(pred_str + '\n')
                    self.gold_out_file.write(gold_str + '\n')
                    ct += 1
                # Flush per batch so progress is visible on disk.
                self.can_out_file.flush()
                self.gold_out_file.flush()
                self.src_out_file.flush()
    finally:
        self.can_out_file.close()
        self.gold_out_file.close()
        self.src_out_file.close()
    # Report results to console, log, TensorBoard and MLflow.
    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
        mlflow.log_metric('Test_ROUGE1_F', rouges['rouge_1_f_score'], step)
        mlflow.log_metric('Test_ROUGE2_F', rouges['rouge_2_f_score'], step)
        mlflow.log_metric('Test_ROUGEL_F', rouges['rouge_l_f_score'], step)
def translate(self, data_iter, step, attn_debug=False):
    """Test decoding with a per-example diagnostic report.

    Writes three files derived from ``args.result_path``: a human-readable
    ``.output`` report (origin/gold/prediction plus per-example BLEU,
    ROUGE and extractive P/R/F1), a ``.pred_test`` file and a
    ``.gold_test`` file. When `step` != -1, corpus-level BLEU, length
    ratio, ROUGE and the aggregated extractive score are logged.
    Returns None.
    """
    self.model.eval()
    output_path = self.args.result_path + '.%d.output' % step
    output_file = codecs.open(output_path, 'w', 'utf-8')
    gold_path = self.args.result_path + '.%d.gold_test' % step
    pred_path = self.args.result_path + '.%d.pred_test' % step
    gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    pred_out_file = codecs.open(pred_path, 'w', 'utf-8')
    # pred_results, gold_results = [], []
    ct = 0
    # Running totals for corpus-level extractive P/R/F1.
    ext_acc_num = 0
    ext_pred_num = 0
    ext_gold_num = 0
    with torch.no_grad():
        rouge = Rouge()
        for batch in data_iter:
            output_data, tgt_data, ext_pred, ext_gold = self.translate_batch(
                batch)
            translations = self.from_batch_test(batch, output_data, tgt_data)
            for idx in range(len(translations)):
                origin_sent, pred_summ, gold_data = translations[idx]
                if ct % 100 == 0:
                    print("Processing %d" % ct)
                output_file.write("ID : %d\n" % ct)
                output_file.write("ORIGIN : \n " +
                                  origin_sent.replace('<S>', '\n ') + "\n")
                output_file.write("GOLD : " + gold_data.strip() + "\n")
                output_file.write("DOC_GEN : " + pred_summ.strip() + "\n")
                # Per-example scores are diagnostics only; corpus-level
                # metrics are computed from the written files below.
                rouge_score = rouge.get_scores(pred_summ, gold_data)
                bleu_score = sentence_bleu(
                    [gold_data.split()], pred_summ.split(),
                    smoothing_function=SmoothingFunction().method1)
                output_file.write(
                    "DOC_GEN bleu & rouge-f 1/2/l: %.4f & %.4f/%.4f/%.4f\n" %
                    (bleu_score, rouge_score[0]["rouge-1"]["f"],
                     rouge_score[0]["rouge-2"]["f"],
                     rouge_score[0]["rouge-l"]["f"]))
                # Extractive overlap: |pred| + |gold| minus |union| equals
                # the shared index count (assumes no in-list duplicates).
                acc_num = len(ext_pred[idx] + ext_gold[idx]) - len(
                    set(ext_pred[idx] + ext_gold[idx]))
                pred_num = len(ext_pred[idx])
                gold_num = len(ext_gold[idx])
                ext_acc_num += acc_num
                ext_pred_num += pred_num
                ext_gold_num += gold_num
                f1, p, r = test_f1(acc_num, pred_num, gold_num)
                output_file.write(
                    "EXT_GOLD: [" +
                    ','.join([str(i) for i in sorted(ext_gold[idx])]) + "]\n")
                output_file.write(
                    "EXT_PRED: [" +
                    ','.join([str(i) for i in sorted(ext_pred[idx])]) + "]\n")
                output_file.write(
                    "EXT_SCORE P/R/F1: %.4f/%.4f/%.4f\n\n" % (p, r, f1))
                pred_out_file.write(pred_summ.strip() + '\n')
                gold_out_file.write(gold_data.strip() + '\n')
                ct += 1
            # Flush per batch so progress is visible on disk.
            pred_out_file.flush()
            gold_out_file.flush()
            output_file.flush()
    pred_out_file.close()
    gold_out_file.close()
    output_file.close()
    if (step != -1):
        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        f1, p, r = test_f1(ext_acc_num, ext_pred_num, ext_gold_num)
        self.logger.info(
            'Ext Sent Score at step %d: \n>> P/R/F1: %.2f/%.2f/%.2f' %
            (step, p * 100, r * 100, f1 * 100))
        self.logger.info(
            'Gold Length at step %d: %.2f' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s' %
                         (step, rouge_results_to_str(pred_rouges)))
def translate(self, data_iter, step, attn_debug=False):
    """Test decoding for the extract-then-generate model, with a full
    diagnostic report.

    For each example, writes the origin text, gold summary, lead
    baseline, short ("DOC_EX") and long ("DOC_CONT") extractions and the
    generated summary ("DOC_GEN") to a ``.output`` report, each with
    per-example BLEU and ROUGE-1/2/L, plus plain-text files for each
    variant. When `step` != -1, corpus-level BLEU/length/ROUGE are
    logged for the short extraction, long extraction and generation.
    Returns None.
    """
    self.model.eval()
    output_path = self.args.result_path + '.%d.output' % step
    output_file = codecs.open(output_path, 'w', 'utf-8')
    gold_path = self.args.result_path + '.%d.gold_test' % step
    pred_path = self.args.result_path + '.%d.pred_test' % step
    ex_single_path = self.args.result_path + '.%d.ex_test' % step + ".short"
    ex_context_path = self.args.result_path + '.%d.ex_test' % step + ".long"
    gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    pred_out_file = codecs.open(pred_path, 'w', 'utf-8')
    short_ex_out_file = codecs.open(ex_single_path, 'w', 'utf-8')
    long_ex_out_file = codecs.open(ex_context_path, 'w', 'utf-8')
    # pred_results, gold_results = [], []
    ct = 0
    with torch.no_grad():
        rouge = Rouge()
        for batch in data_iter:
            doc_data, summ_data = self.translate_batch(batch)
            translations = self.from_batch_test(batch, doc_data)
            for idx in range(len(translations)):
                origin_sent, doc_extract, context_doc_extract, \
                    doc_pred, lead = translations[idx]
                if ct % 100 == 0:
                    print("Processing %d" % ct)
                output_file.write("ID : %d\n" % ct)
                output_file.write(
                    "ORIGIN : " + origin_sent.replace('<S>', '\n ') + "\n")
                gold_data = summ_data[idx]
                output_file.write("GOLD : " + gold_data + "\n")
                output_file.write("LEAD : " + lead + "\n")
                output_file.write("DOC_EX : " + doc_extract.strip() + "\n")
                output_file.write("DOC_CONT: " +
                                  context_doc_extract.strip() + "\n")
                output_file.write("DOC_GEN : " + doc_pred.strip() + "\n")
                # [unused2]/[unused3] appear to be sentence-marker tokens
                # stripped before token-level BLEU — TODO confirm.
                gold_list = gold_data.strip().split()
                lead_list = lead.strip().replace("[unused2]", "").replace(
                    "[unused3]", "").split()
                # LEAD baseline scores.
                rouge_score = rouge.get_scores(lead, gold_data)
                bleu_score = sentence_bleu(
                    [gold_list], lead_list,
                    smoothing_function=SmoothingFunction().method1)
                output_file.write(
                    "LEAD bleu & rouge-f 1/2/l: %.4f & %.4f/%.4f/%.4f\n" %
                    (bleu_score, rouge_score[0]["rouge-1"]["f"],
                     rouge_score[0]["rouge-2"]["f"],
                     rouge_score[0]["rouge-l"]["f"]))
                # Short (single-sentence-context) extraction scores.
                doc_extract_list = doc_extract.strip().replace(
                    "[unused2]", "").replace("[unused3]", "").split()
                rouge_score = rouge.get_scores(doc_extract, gold_data)
                bleu_score = sentence_bleu(
                    [gold_list], doc_extract_list,
                    smoothing_function=SmoothingFunction().method1)
                output_file.write(
                    "DOC_EX bleu & rouge-f 1/2/l: %.4f & %.4f/%.4f/%.4f\n" %
                    (bleu_score, rouge_score[0]["rouge-1"]["f"],
                     rouge_score[0]["rouge-2"]["f"],
                     rouge_score[0]["rouge-l"]["f"]))
                # Long (windowed-context) extraction scores.
                doc_context_list = context_doc_extract.strip().replace(
                    "[unused2]", "").replace("[unused3]", "").split()
                rouge_score = rouge.get_scores(context_doc_extract, gold_data)
                bleu_score = sentence_bleu(
                    [gold_list], doc_context_list,
                    smoothing_function=SmoothingFunction().method1)
                output_file.write(
                    "DOC_CONT bleu & rouge-f 1/2/l: %.4f & %.4f/%.4f/%.4f\n" %
                    (bleu_score, rouge_score[0]["rouge-1"]["f"],
                     rouge_score[0]["rouge-2"]["f"],
                     rouge_score[0]["rouge-l"]["f"]))
                # Abstractive generation scores.
                doc_long_list = doc_pred.strip().replace(
                    "[unused2]", "").replace("[unused3]", "").split()
                rouge_score = rouge.get_scores(doc_pred, gold_data)
                bleu_score = sentence_bleu(
                    [gold_list], doc_long_list,
                    smoothing_function=SmoothingFunction().method1)
                output_file.write(
                    "DOC_GEN bleu & rouge-f 1/2/l: %.4f & %.4f/%.4f/%.4f\n\n" %
                    (bleu_score, rouge_score[0]["rouge-1"]["f"],
                     rouge_score[0]["rouge-2"]["f"],
                     rouge_score[0]["rouge-l"]["f"]))
                short_ex_out_file.write(doc_extract.strip().replace(
                    "[unused2]", "").replace("[unused3]", "") + '\n')
                long_ex_out_file.write(context_doc_extract.strip().replace(
                    "[unused2]", "").replace("[unused3]", "") + '\n')
                pred_out_file.write(doc_pred.strip().replace(
                    "[unused2]", "").replace("[unused3]", "") + '\n')
                gold_out_file.write(gold_data.strip() + '\n')
                ct += 1
            # Flush per batch so progress is visible on disk.
            pred_out_file.flush()
            short_ex_out_file.flush()
            long_ex_out_file.flush()
            gold_out_file.flush()
            output_file.flush()
    pred_out_file.close()
    short_ex_out_file.close()
    long_ex_out_file.close()
    gold_out_file.close()
    output_file.close()
    if (step != -1):
        # NOTE(review): test_bleu is called here as (gold_path, hyp_path),
        # but sibling methods call test_bleu(pred_path, gold_path) —
        # confirm the expected argument order of test_bleu.
        ex_short_bleu = test_bleu(gold_path, ex_single_path)
        ex_long_bleu = test_bleu(gold_path, ex_context_path)
        pred_bleu = test_bleu(gold_path, pred_path)
        file_rouge = FilesRouge(hyp_path=ex_single_path, ref_path=gold_path)
        ex_short_rouges = file_rouge.get_scores(avg=True)
        file_rouge = FilesRouge(hyp_path=ex_context_path, ref_path=gold_path)
        ex_long_rouges = file_rouge.get_scores(avg=True)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        self.logger.info(
            'Gold Length at step %d: %.2f\n' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Short Extraction Length ratio at step %d: %.2f' %
                         (step, test_length(ex_single_path, gold_path)))
        self.logger.info('Short Extraction Bleu at step %d: %.2f' %
                         (step, ex_short_bleu * 100))
        self.logger.info('Short Extraction Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(ex_short_rouges)))
        self.logger.info('Long Extraction Length ratio at step %d: %.2f' %
                         (step, test_length(ex_context_path, gold_path)))
        self.logger.info('Long Extraction Bleu at step %d: %.2f' %
                         (step, ex_long_bleu * 100))
        self.logger.info('Long Extraction Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(ex_long_rouges)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(pred_rouges)))
def translate(self, data_iter, step, attn_debug=False):
    """Test decoding with per-example diagnostics.

    Writes a human-readable ``.output`` report (origin/gold/prediction
    plus per-example BLEU and ROUGE-1/L), plus plain ``.pred_test`` /
    ``.gold_test`` files; when `step` != -1 logs corpus-level BLEU,
    length ratio and ROUGE. Returns None.
    """
    self.model.eval()
    output_path = self.args.result_path + '.%d.output' % step
    gold_path = self.args.result_path + '.%d.gold_test' % step
    pred_path = self.args.result_path + '.%d.pred_test' % step
    ct = 0
    # FIX: context managers close the three output files even when an
    # exception interrupts decoding (previously the handles leaked).
    with codecs.open(output_path, 'w', 'utf-8') as output_file, \
            codecs.open(gold_path, 'w', 'utf-8') as gold_out_file, \
            codecs.open(pred_path, 'w', 'utf-8') as pred_out_file:
        with torch.no_grad():
            rouge = Rouge()
            for batch in data_iter:
                output_data, tgt_data = self.translate_batch(batch)
                translations = self.from_batch_test(batch, output_data,
                                                   tgt_data)
                for idx in range(len(translations)):
                    origin_sent, pred_summ, gold_data = translations[idx]
                    if ct % 100 == 0:
                        print("Processing %d" % ct)
                    output_file.write("ID : %d\n" % ct)
                    output_file.write("ORIGIN : \n " +
                                      origin_sent.replace('<S>', '\n ') +
                                      "\n")
                    output_file.write("GOLD : " + gold_data.strip() + "\n")
                    output_file.write("DOC_GEN : " + pred_summ.strip() +
                                      "\n")
                    # Per-example scores are diagnostics only; corpus-level
                    # metrics are computed from the files afterwards.
                    rouge_score = rouge.get_scores(pred_summ, gold_data)
                    bleu_score = sentence_bleu(
                        [gold_data.split()], pred_summ.split(),
                        smoothing_function=SmoothingFunction().method1)
                    output_file.write(
                        "DOC_GEN bleu & rouge-f 1/l: %.4f & %.4f/%.4f\n\n" %
                        (bleu_score, rouge_score[0]["rouge-1"]["f"],
                         rouge_score[0]["rouge-l"]["f"]))
                    pred_out_file.write(pred_summ.strip() + '\n')
                    gold_out_file.write(gold_data.strip() + '\n')
                    ct += 1
                # Flush per batch so progress is visible on disk.
                pred_out_file.flush()
                gold_out_file.flush()
                output_file.flush()
    if (step != -1):
        pred_bleu = test_bleu(pred_path, gold_path)
        file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
        pred_rouges = file_rouge.get_scores(avg=True)
        self.logger.info(
            'Gold Length at step %d: %.2f\n' %
            (step, test_length(gold_path, gold_path, ratio=False)))
        self.logger.info('Prediction Length ratio at step %d: %.2f' %
                         (step, test_length(pred_path, gold_path)))
        self.logger.info('Prediction Bleu at step %d: %.2f' %
                         (step, pred_bleu * 100))
        self.logger.info('Prediction Rouges at step %d: \n%s' %
                         (step, rouge_results_to_str(pred_rouges)))
def baseline(args, cal_lead=False, cal_oracle=False):
    """Score a LEAD or greedy-ORACLE extractive baseline on the test set.

    Writes ``<result_path>.<mode>.pred`` / ``.gold`` files and logs
    length ratio, BLEU and ROUGE.

    FIXES: (1) `mode` previously keyed off ``cal_lead`` while the
    prediction branch keys off ``cal_oracle``, so the default call (both
    flags False) produced lead predictions labeled "oracle" — the label
    now mirrors the branch actually taken; (2) output files are opened
    with context managers so they are closed even if scoring raises.
    """
    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args, 'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       args.test_batch_ex_size,
                                       'cpu', shuffle=False, is_test=True)
    # The label must match the branch below, which tests cal_oracle.
    mode = "oracle" if cal_oracle else "lead"
    rouge = Rouge()
    pred_path = '%s.%s.pred' % (args.result_path, mode)
    gold_path = '%s.%s.gold' % (args.result_path, mode)
    with open(pred_path, 'w', encoding='utf-8') as save_pred, \
            open(gold_path, 'w', encoding='utf-8') as save_gold:
        with torch.no_grad():
            for batch in test_iter:
                summaries = batch.summ_txt
                origin_sents = batch.original_str
                ex_segs = batch.ex_segs
                # Prefix sums turn per-example segment lengths into
                # slice boundaries into origin_sents.
                ex_segs = [sum(ex_segs[:i]) for i in range(len(ex_segs) + 1)]
                for idx in range(len(summaries)):
                    summary = summaries[idx]
                    txt = origin_sents[ex_segs[idx]:ex_segs[idx + 1]]
                    if cal_oracle:
                        # Greedy oracle: repeatedly add the sentence that
                        # most improves ROUGE-1+L F against the summary.
                        selected = []
                        max_rouge = 0.
                        while len(selected) < args.ranking_max_k:
                            cur_max_rouge = max_rouge
                            cur_id = -1
                            for i in range(len(txt)):
                                if (i in selected):
                                    continue
                                c = selected + [i]
                                temp_txt = " ".join([txt[j] for j in c])
                                rouge_score = rouge.get_scores(
                                    temp_txt, summary)
                                rouge_1 = rouge_score[0]["rouge-1"]["f"]
                                rouge_l = rouge_score[0]["rouge-l"]["f"]
                                rouge_score = rouge_1 + rouge_l
                                if rouge_score > cur_max_rouge:
                                    cur_max_rouge = rouge_score
                                    cur_id = i
                            if (cur_id == -1):
                                # No sentence improves the score: stop.
                                break
                            selected.append(cur_id)
                            max_rouge = cur_max_rouge
                        pred_txt = " ".join([txt[j] for j in selected])
                    else:
                        # LEAD: take the first k sentences, k scaled by
                        # the context window size.
                        k = min(max(len(txt) // (2 * args.win_size + 1), 1),
                                args.ranking_max_k)
                        pred_txt = " ".join(txt[:k])
                    save_gold.write(summary + "\n")
                    save_pred.write(pred_txt + "\n")
    length = test_length(pred_path, gold_path)
    bleu = test_bleu(pred_path, gold_path)
    file_rouge = FilesRouge(hyp_path=pred_path, ref_path=gold_path)
    pred_rouges = file_rouge.get_scores(avg=True)
    logger.info('Length ratio:\n%s' % str(length))
    logger.info('Bleu:\n%.2f' % (bleu * 100))
    logger.info('Rouges:\n%s' % rouge_results_to_str(pred_rouges))
def translate(self, data_iter, step, attn_debug=False):
    """Abstractive test decoding: write candidate/gold/raw-src files and
    report ROUGE.

    FIXES: the gold and candidate files were each opened TWICE, leaking
    the first pair of handles — duplicate opens removed. Unused locals
    (``pred_results``/``gold_results``, the ``gap``/``can_gap``
    bookkeeping whose consumer was commented out) and a large block of
    commented-out scratch code were dropped; a try/finally now closes
    the files even if decoding raises. Behavior is otherwise unchanged.
    """
    self.model.eval()
    gold_path = self.args.result_path + '.%d.gold' % step
    can_path = self.args.result_path + '.%d.candidate' % step
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    raw_src_path = self.args.result_path + '.%d.raw_src' % step
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
    ct = 0
    try:
        with torch.no_grad():
            for batch in data_iter:
                if (self.args.recall_eval):
                    # Tie decode length to the gold target length.
                    gold_tgt_len = batch.tgt.size(1)
                    self.min_length = gold_tgt_len + 20
                    self.max_length = gold_tgt_len + 60
                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data)
                for trans in translations:
                    pred, gold, src = trans
                    # Strip BERT special tokens; '[unused2]' marks sentence
                    # breaks. NOTE(review): .replace(r' +', ' ') is a
                    # literal replace, not a regex — kept to preserve output.
                    pred_str = pred.replace('[unused0]', '').replace(
                        '[unused3]', '').replace('[PAD]', '').replace(
                            '[unused1]', '').replace(r' +', ' ').replace(
                                ' [unused2] ', '<q>').replace(
                                    '[unused2]', '').strip()
                    gold_str = gold.strip()
                    if (self.args.recall_eval):
                        # Leftover debug trace; kept to preserve behavior.
                        print(
                            'this is recal-----------------------------------------------'
                        )
                        # Accumulate sentences until the candidate is 10+
                        # tokens longer than the gold, then cut.
                        _pred_str = ''
                        for sent in pred_str.split('<q>'):
                            can_pred_str = _pred_str + '<q>' + sent.strip()
                            if (len(can_pred_str.split()) >=
                                    len(gold_str.split()) + 10):
                                pred_str = _pred_str
                                break
                            else:
                                _pred_str = can_pred_str
                    self.can_out_file.write(pred_str + '\n')
                    self.gold_out_file.write(gold_str + '\n')
                    self.src_out_file.write(src.strip() + '\n')
                    ct += 1
                self.can_out_file.flush()
                self.gold_out_file.flush()
                self.src_out_file.flush()
    finally:
        self.can_out_file.close()
        self.gold_out_file.close()
        self.src_out_file.close()
    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
def translate(self, data_iter, step, attn_debug=False):
    """Decode with [MASK]-terminated targets, filter, and report ROUGE.

    FIXES: (1) the gold and candidate files were each opened twice,
    leaking the first pair of handles — duplicate opens removed, and a
    try/finally now guarantees the files close; (2) ``cur_word`` was
    initialised but never updated, so the 2-sentence/15-word truncation
    guard below could never trigger — it now accumulates the emitted
    word count (confirm intended threshold semantics with the author).
    """
    self.model.eval()
    gold_path = self.args.result_path + '.%d.gold' % step
    can_path = self.args.result_path + '.%d.candidate' % step
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    raw_src_path = self.args.result_path + '.%d.raw_src' % step
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
    ct = 0
    try:
        with torch.no_grad():
            for batch in data_iter:
                if (self.args.recall_eval):
                    # Tie decode length to the gold target length.
                    gold_tgt_len = batch.tgt.size(1)
                    self.min_length = gold_tgt_len + 20
                    self.max_length = gold_tgt_len + 60
                batch_data = self.translate_batch(batch)
                translations = self.from_batch(batch_data)
                for trans in translations:
                    pred, gold, src = trans
                    # Keep text up to the first sentence-final [MASK].
                    pred = pred.split('. [MASK]')[0] + '.'
                    # ' ホ ' is the sentence-separator token; map it to <q>.
                    pred_str = pred.replace(' ホ ', '<q>').replace(
                        '. [MASK]', '').replace('! [MASK]', '').replace(
                            '? [MASK]', '').replace(' [MASK]', '').replace(
                                ' [UNK]', '').replace(' [PAD]', '').replace(
                                    ' [CLS]', '').replace(' [SEP]',
                                                          '').strip()
                    gold_str = gold.lower().strip()
                    # Truncate: stop once we already have >= 2 sentences
                    # and more than 15 words.
                    _pred_str = []
                    cur_word = 0
                    for sent in pred_str.split('<q>'):
                        _pred_str.append(sent)
                        cur_word += len(sent.split())  # FIX: was never updated
                        if len(_pred_str) >= 2 and cur_word > 15:
                            break
                    pred_str = '<q>'.join(_pred_str)
                    self.can_out_file.write(pred_str + '\n')
                    self.gold_out_file.write(gold_str + '\n')
                    self.src_out_file.write(src.strip() + '\n')
                    ct += 1
                self.can_out_file.flush()
                self.gold_out_file.flush()
                self.src_out_file.flush()
    finally:
        self.can_out_file.close()
        self.gold_out_file.close()
        self.src_out_file.close()
    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
def translate(self, data_iter, epoch, cal_rouge=False, save=True,
              save_by_id=False, have_gold=False, info=""):
    """Run inference over ``data_iter`` and optionally save/score the output.

    Args:
        data_iter: iterable of batches accepted by ``self.translate_batch``.
        epoch: tag used to name the output directory under
            ``self.args.savepath`` and to index tensorboard scalars
            (the string 'best' is treated specially below).
        cal_rouge: if True (together with ``save`` and ``have_gold``),
            compute and log ROUGE between the saved gold/candidate files.
        save: if True, open and write out.txt / can.txt (and gold.txt when
            ``have_gold``) under ``savepath/<epoch>/``.
        save_by_id: if True, additionally append each prediction to a
            per-example file ``temp/<id>.txt``.
        have_gold: whether reference summaries are present in the batches.
        info: experiment tag used to select the row updated in the
            per-example ``gen_result/<id>.txt`` table.
    """
    self.model.eval()
    if save:
        base_path = pjoin(self.args.savepath, epoch)
        os.makedirs(base_path, exist_ok=True)
        out_path = pjoin(base_path, 'out.txt')
        self.out_file = codecs.open(out_path, 'w', 'utf-8')
        can_path = pjoin(base_path, 'can.txt')
        self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
        if have_gold:
            gold_path = pjoin(base_path, 'gold.txt')
            self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    #ididx_path = self.args.savepath + '/cond.txt'
    #self.ididx_out_file = codecs.open(ididx_path, 'w', 'utf-8')
    #raw_src_path = self.args.savepath + '/result/{}.raw_src'.format(epoch)
    #self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')
    # pred_results, gold_results = [], []
    #ct = 0

    def _translate(batch):
        # Decode one batch and write each translation to the open files.
        '''
        if(self.args.recall_eval):
            gold_tgt_len = batch.tgt.size(1)
            self.min_length = gold_tgt_len + 20
            self.max_length = gold_tgt_len + 60
        '''
        #id_idx = batch.id_idx
        batch_data = self.translate_batch(batch)
        translations = self.from_batch(batch_data, have_gold)
        for i, trans in enumerate(translations):
            pred, gold, src, id_, label = trans
            # Strip BERT special tokens from the decoded prediction;
            # ' [unused2] ' marks sentence boundaries and becomes '<q>'.
            pred_str = pred.replace('[unused0]', '').replace(
                '[unused3]', '').replace('[PAD]', '').replace(
                    '[unused1]', '').replace(r' +', ' ').replace(
                        ' [unused2] ', '<q>').replace('[unused2]', '').strip()
            if have_gold:
                gold_str = gold.strip()
            #id_ = id_idx[i][0].item()
            #parent_idx = id_idx[i][1].item()
            #self_idx = id_idx[i][2].item()
            '''
            if(self.args.recall_eval):
                _pred_str = ''
                gap = 1e3
                for sent in pred_str.split('<q>'):
                    can_pred_str = _pred_str+ '<q>'+sent.strip()
                    can_gap = math.fabs(len(_pred_str.split())-len(gold_str.split()))
                    # if(can_gap>=gap):
                    if(len(can_pred_str.split())>=len(gold_str.split())+10):
                        pred_str = _pred_str
                        break
                    else:
                        gap = can_gap
                        _pred_str = can_pred_str
            #pred_str = ' '.join(pred_str.split()[:len(gold_str.split())])
            '''
            #self.raw_can_out_file.write(' '.join(pred).strip() + '\n')
            #self.raw_gold_out_file.write(' '.join(gold).strip() + '\n')
            if save_by_id:
                # NOTE(review): parent_idx / self_idx are only assigned in the
                # commented-out id_idx lines above, so this branch raises
                # NameError as written — confirm before enabling save_by_id.
                with open(
                        pjoin(self.args.savepath,
                              'temp/{}.txt'.format(id_)), 'a') as f:
                    f.write('{}\t{}\t{}\n'.format(parent_idx, self_idx,
                                                  pred_str))
                #with open(pjoin(self.args.savepath,'temp_gold/{}.txt'.format(id_)), 'a') as f:
                #    f.write('{}\t{}\t{}\n'.format(parent_idx, self_idx, gold_str))
            if save:
                if have_gold:
                    self.out_file.write(
                        'Predict:\t{}\nGold:\t{}\nSource:\t{}\n\n'.format(
                            pred_str, gold_str, src.strip()))
                    self.gold_out_file.write(gold_str + '\n')
                else:
                    # NOTE(review): only this no-gold branch writes can.txt,
                    # yet ROUGE below is computed on can.txt only when
                    # have_gold is True — verify the intended nesting here.
                    self.can_out_file.write(pred_str + '\n')
                    self.out_file.write('Predict:\t{}\nSource:\t{}\n'.format(
                        pred_str, src.strip()))
                    # Update the 'generated' column for this experiment tag
                    # in the per-example result table.
                    ori_df = pd.read_csv(pjoin(
                        self.args.savepath,
                        "gen_result/{}.txt".format(id_)),
                                         delimiter="\t")
                    ori_df.loc[ori_df["exp"] == info, "generated"] = pred_str
                    ori_df.to_csv(pjoin(self.args.savepath,
                                        "gen_result/{}.txt".format(id_)),
                                  index=False,
                                  sep="\t")
                    #with open(os.path.join(self.args.savepath, "gen_result/{}.txt".format(id_)), 'a') as f:
                    #    f.write("[Generated Response]:\t{}\n".format(pred_str))
                    #    if have_gold:
                    #        f.write("[Reference Response]:\t{}\n".format(gold_str))
            #self.src_out_file.write(src.strip() + '\n')
            #conds = torch.load(self.args.cache_path+'/{}.pt'.format(id_))
            #cond = conds[int(self_idx),0].item()
            #self.ididx_out_file.write('{}\t{}\t{}\t{}\n'.format(id_,parent_idx,self_idx, cond))
            #ct += 1
        if save:
            self.out_file.flush()
            self.can_out_file.flush()
            if have_gold:
                self.gold_out_file.flush()
            #self.src_out_file.flush()
            #self.ididx_out_file.flush()

    # n_jobs=1 with the threading backend runs batches sequentially; the
    # Parallel wrapper is effectively a progress-friendly map here.
    Parallel(n_jobs=1, backend='threading')(delayed(_translate)(batch)
                                            for batch in tqdm(data_iter))
    if save:
        self.out_file.close()
        if have_gold:
            # NOTE(review): can_out_file is opened whenever ``save`` is True
            # but only closed here when ``have_gold`` is also True, so it
            # leaks when have_gold=False — confirm and fix upstream.
            self.can_out_file.close()
            self.gold_out_file.close()
        #self.src_out_file.close()
        #self.ididx_out_file.close()
    if save and cal_rouge and have_gold:
        rouges = self._report_rouge(gold_path, can_path)
        if self.logger:
            self.logger.info('Rouges at epoch {} \n{}'.format(
                epoch, rouge_results_to_str(rouges)))
        # Skip tensorboard for the 'best' checkpoint tag (non-numeric step).
        if self.tensorboard_writer is not None and epoch != 'best':
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               epoch)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               epoch)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               epoch)
        # Append a human-readable ROUGE report for this epoch.
        with open(os.path.join(self.args.savepath, 'result/rouge.txt'),
                  'a') as f:
            f.write('Epoch {}\n'.format(epoch))
            f.write('--------------------------------------\n')
            f.write('test/rouge1-R: {}\n'.format(rouges['rouge_1_recall']))
            f.write('test/rouge1-P: {}\n'.format(
                rouges['rouge_1_precision']))
            f.write('test/rouge1-F: {}\n'.format(
                rouges['rouge_1_f_score']))
            f.write('--------------------------------------\n')
            f.write('test/rouge2-R: {}\n'.format(rouges['rouge_2_recall']))
            f.write('test/rouge2-P: {}\n'.format(
                rouges['rouge_2_precision']))
            f.write('test/rouge2-F: {}\n'.format(
                rouges['rouge_2_f_score']))
            f.write('--------------------------------------\n')
            f.write('test/rougeL-R: {}\n'.format(rouges['rouge_l_recall']))
            f.write('test/rougeL-P: {}\n'.format(
                rouges['rouge_l_precision']))
            f.write('test/rougeL-F: {}\n'.format(
                rouges['rouge_l_f_score']))
            f.write('======================================\n')
def translate(self, data_iter, step, attn_debug=False):
    """Generate summaries for every batch in ``data_iter``, write them to
    step-tagged result files, and report ROUGE when ``step != -1``.

    Output files (all prefixed by ``self.args.result_path``):
      .%d.gold / .%d.candidate : token-cleaned gold / predicted summaries
      .%d.goldstr / .%d.canstr : raw gold / predicted strings from the batch
      .%d.gold_eng             : per-example ``gold_eng`` field (presumably
                                 English-side references — TODO confirm)
      .%d.raw_src              : raw source text

    Args:
        data_iter: iterator of batches for ``self.translate_batch``.
        step: training step used to tag output files; -1 skips ROUGE.
        attn_debug: unused here; kept for interface compatibility.
    """
    self.model.eval()

    gold_path = self.args.result_path + '.%d.gold' % step
    can_path = self.args.result_path + '.%d.candidate' % step
    gold_str_path = self.args.result_path + '.%d.goldstr' % step
    can_str_path = self.args.result_path + '.%d.canstr' % step
    eng_gold_path = self.args.result_path + '.%d.gold_eng' % step
    raw_src_path = self.args.result_path + '.%d.raw_src' % step
    self.gold_eng_out_file = codecs.open(eng_gold_path, 'w', 'utf-8')
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    self.gold_str_out_file = codecs.open(gold_str_path, 'w', 'utf-8')
    self.can_str_out_file = codecs.open(can_str_path, 'w', 'utf-8')
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')

    ct = 0
    with torch.no_grad():
        for batch in data_iter:
            if (self.args.recall_eval):
                # Bound generation length relative to the gold target length.
                gold_tgt_len = batch.tgt.size(1)
                self.min_length = gold_tgt_len + 20
                self.max_length = gold_tgt_len + 60
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
            for trans in translations:
                print(time.asctime(time.localtime(time.time())),
                      "----- now is the test sample: ", ct, end='\r')
                pred, gold, src, pred_strrr, gold_strrr, gold_eng = trans
                if self.args.bart:
                    # BART vocab uses 'madeupword000x' placeholders for the
                    # special tokens; ' madeupword0002 ' marks sentence
                    # boundaries and becomes the '<q>' separator.
                    pred_str = pred.replace('madeupword0000', '').replace(
                        'madeupword0001', '').replace('<pad>', '').replace(
                            '<unk>', '').replace(r' +', ' ').replace(
                                ' madeupword0002 ', '<q>').replace(
                                    'madeupword0002', '').strip()
                else:
                    # BERT-style vocab: strip [unusedN]/[PAD] special tokens.
                    pred_str = pred.replace('[unused1]', '').replace(
                        '[unused4]', '').replace('[PAD]', '').replace(
                            '[unused2]', '').replace(r' +', ' ').replace(
                                ' [unused3] ', '<q>').replace(
                                    '[unused3]', '').strip()
                # The gold string is also decoded from token ids, so apply
                # the same BERT special-token cleanup.  (This was duplicated
                # verbatim in both branches of the original; hoisted once.)
                gold_str = gold.replace('[unused1]', '').replace(
                    '[unused4]', '').replace('[PAD]', '').replace(
                        '[unused2]', '').replace(r' +', ' ').replace(
                            ' [unused3] ', '<q>').replace(
                                '[unused3]', '').strip()
                if (self.args.recall_eval):
                    # Truncate the prediction sentence-by-sentence so its
                    # length roughly matches the gold length.
                    _pred_str = ''
                    gap = 1e3
                    for sent in pred_str.split('<q>'):
                        can_pred_str = _pred_str + '<q>' + sent.strip()
                        can_gap = math.fabs(
                            len(_pred_str.split()) - len(gold_str.split()))
                        if (len(can_pred_str.split()) >=
                                len(gold_str.split()) + 10):
                            pred_str = _pred_str
                            break
                        else:
                            gap = can_gap
                            _pred_str = can_pred_str
                self.gold_eng_out_file.write(gold_eng + '\n')
                self.can_out_file.write(pred_str + '\n')
                self.gold_out_file.write(gold_str + '\n')
                self.src_out_file.write(src.strip() + '\n')
                self.can_str_out_file.write(pred_strrr + '\n')
                self.gold_str_out_file.write(gold_strrr + '\n')
                ct += 1
            # BUG FIX: gold_eng_out_file was never flushed in the original.
            self.gold_eng_out_file.flush()
            self.can_out_file.flush()
            self.gold_out_file.flush()
            self.src_out_file.flush()
            self.can_str_out_file.flush()
            self.gold_str_out_file.flush()
    # BUG FIX: gold_eng_out_file was never closed in the original
    # (file-handle leak, potentially truncated output).
    self.gold_eng_out_file.close()
    self.can_out_file.close()
    self.gold_out_file.close()
    self.src_out_file.close()
    self.can_str_out_file.close()
    self.gold_str_out_file.close()

    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
def translate(self, data_iter, step, attn_debug=False):
    """Generate summaries for ``data_iter``, write candidate/gold/source
    files tagged with ``step``, and report ROUGE when ``step != -1``.

    Args:
        data_iter: iterator of batches for ``self.translate_batch``.
        step: training step used to tag output files; -1 skips ROUGE.
        attn_debug: unused here; kept for interface compatibility.
    """
    self.model.eval()

    gold_path = self.args.result_path + '.%d.gold' % step
    can_path = self.args.result_path + '.%d.candidate' % step
    raw_src_path = self.args.result_path + '.%d.raw_src' % step
    # BUG FIX: the original opened gold_out_file and can_out_file TWICE
    # (two codecs.open calls each), leaking the first pair of handles.
    # Each output file is now opened exactly once.
    self.gold_out_file = codecs.open(gold_path, 'w', 'utf-8')
    self.can_out_file = codecs.open(can_path, 'w', 'utf-8')
    self.src_out_file = codecs.open(raw_src_path, 'w', 'utf-8')

    ct = 0
    with torch.no_grad():
        for batch in data_iter:
            if (self.args.recall_eval):
                # Bound generation length relative to the gold target length.
                gold_tgt_len = batch.tgt.size(1)
                self.min_length = gold_tgt_len + 20
                self.max_length = gold_tgt_len + 60
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
            for trans in translations:
                pred, gold, src = trans
                # Strip BERT special tokens; ' [unused2] ' marks sentence
                # boundaries and becomes the '<q>' separator.
                pred_str = pred.replace('[unused6]', '').replace(
                    '[unused3]', '').replace('[PAD]', '').replace(
                        '[unused1]', '').replace(r' +', ' ').replace(
                            ' [unused2] ', '<q>').replace(
                                '[unused2]', '').strip()
                gold_str = gold.strip()
                if (self.args.recall_eval):
                    # Truncate the prediction sentence-by-sentence so its
                    # length roughly matches the gold length.
                    _pred_str = ''
                    gap = 1e3
                    for sent in pred_str.split('<q>'):
                        can_pred_str = _pred_str + '<q>' + sent.strip()
                        can_gap = math.fabs(
                            len(_pred_str.split()) - len(gold_str.split()))
                        if (len(can_pred_str.split()) >=
                                len(gold_str.split()) + 10):
                            pred_str = _pred_str
                            break
                        else:
                            gap = can_gap
                            _pred_str = can_pred_str
                self.can_out_file.write(pred_str + '\n')
                self.gold_out_file.write(gold_str + '\n')
                self.src_out_file.write(src.strip() + '\n')
                ct += 1
            self.can_out_file.flush()
            self.gold_out_file.flush()
            self.src_out_file.flush()
    self.can_out_file.close()
    self.gold_out_file.close()
    self.src_out_file.close()

    if (step != -1):
        rouges = self._report_rouge(gold_path, can_path)
        self.logger.info('Rouges at step %d \n%s' %
                         (step, rouge_results_to_str(rouges)))
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('test/rouge1-F',
                                               rouges['rouge_1_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rouge2-F',
                                               rouges['rouge_2_f_score'],
                                               step)
            self.tensorboard_writer.add_scalar('test/rougeL-F',
                                               rouges['rouge_l_f_score'],
                                               step)
# NOTE(review): standalone scoring scratch snippet, not a function.  As
# written it cannot run at module scope: `self`, `can_path`, and `gold_path`
# are all undefined here (NameError) — confirm the intended context.
import logger
# NOTE(review): `import logger` imports a module literally named "logger";
# the rest of this file calls `logger.info(...)` on a logger object, so
# something like `from others.logging import logger` was presumably
# intended — TODO confirm.
from others.utils import test_rouge, rouge_results_to_str

# Score an existing candidate file against the gold file with ROUGE.
rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
# NOTE(review): the step is hard-coded to 30000 in this log message.
logger.info('Rouges at step %d \n%s' % (30000, rouge_results_to_str(rouges)))
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
    """ Validate model.
        valid_iter: validate data iterator

    Ranks sentences per document with the extractive model (or a LEAD /
    ORACLE baseline), writes candidate and gold summary files, and
    optionally reports ROUGE.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """

    # Set model in validating mode.
    def _get_ngrams(n, text):
        # Return the set of n-grams (as tuples) in the token list ``text``.
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        # True if candidate sentence ``c`` shares any trigram with an
        # already-selected sentence in ``p`` (trigram blocking).
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    # Baselines (LEAD / ORACLE) do not run the model.
    if (not cal_lead and not cal_oracle):
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate' % (self.args.result_path, step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    with open(can_path, 'w') as save_pred:
        with open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                for batch in test_iter:
                    src = batch.src
                    labels = batch.labels
                    segs = batch.segs
                    clss = batch.clss
                    mask = batch.mask
                    mask_cls = batch.mask_cls

                    gold = []
                    pred = []

                    if (cal_lead):
                        # LEAD baseline: take sentences in document order.
                        selected_ids = [list(range(batch.clss.size(1)))
                                        ] * batch.batch_size
                    elif (cal_oracle):
                        # ORACLE baseline: take the labelled sentences.
                        selected_ids = [[
                            j for j in range(batch.clss.size(1))
                            if labels[i][j] == 1
                        ] for i in range(batch.batch_size)]
                    else:
                        # Score every sentence and accumulate the masked loss
                        # into the running Statistics.
                        sent_scores, mask = self.model(src, segs, clss, mask,
                                                       mask_cls)
                        loss = self.loss(sent_scores, labels.float())
                        loss = (loss * mask.float()).sum()
                        batch_stats = Statistics(
                            float(loss.cpu().data.numpy()), len(labels))
                        stats.update(batch_stats)
                        # Adding the {0,1} mask keeps padded positions at the
                        # bottom of the ranking below.
                        sent_scores = sent_scores + mask.float()
                        sent_scores = sent_scores.cpu().data.numpy()
                        # Sentence indices sorted by descending score.
                        selected_ids = np.argsort(-sent_scores, 1)
                    # selected_ids = np.sort(selected_ids,1)
                    for i, idx in enumerate(selected_ids):
                        _pred = []
                        if (len(batch.src_str[i]) == 0):
                            continue
                        for j in selected_ids[i][:len(batch.src_str[i])]:
                            if (j >= len(batch.src_str[i])):
                                continue
                            candidate = batch.src_str[i][j].strip()
                            if (self.args.block_trigram):
                                if (not _block_tri(candidate, _pred)):
                                    _pred.append(candidate)
                            else:
                                _pred.append(candidate)
                            # Keep at most 3 sentences unless running the
                            # oracle or recall evaluation.
                            if ((not cal_oracle)
                                    and (not self.args.recall_eval)
                                    and len(_pred) == 3):
                                break
                        _pred = '<q>'.join(_pred)
                        # NOTE(review): appending the full source document to
                        # every candidate will inflate/distort ROUGE against
                        # the gold file — confirm this is an intentional
                        # debugging/inspection artifact.
                        _pred = _pred + " original txt: " + " ".join(
                            batch.src_str[i])
                        if (self.args.recall_eval):
                            # Truncate to the gold length for recall eval.
                            _pred = ' '.join(
                                _pred.split()[:len(batch.tgt_str[i].split())])
                        pred.append(_pred)
                        gold.append(batch.tgt_str[i])
                    for i in range(len(gold)):
                        save_gold.write(gold[i].strip() + '\n')
                    for i in range(len(pred)):
                        save_pred.write(pred[i].strip() + '\n')
    if (step != -1 and self.args.report_rouge):
        rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
        logger.info('Rouges at step %d \n%s' %
                    (step, rouge_results_to_str(rouges)))
    self._report_step(0, step, valid_stats=stats)

    return stats
# Dispatch for the abstractive-summarization task modes.
if (args.task == 'abs'):
    if (args.mode == 'train'):
        train_abs(args, device_id)
    elif (args.mode == 'validate'):
        validate_abs(args, device_id)
    elif (args.mode == 'score'):
        # Score-only mode: build a predictor just to re-score existing
        # candidate/gold files with ROUGE (no decoding is performed).
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                  do_lower_case=True,
                                                  cache_dir=args.temp_dir)
        symbols = {'BOS': tokenizer.vocab['[unused0]'],
                   'EOS': tokenizer.vocab['[unused1]'],
                   'PAD': tokenizer.vocab['[PAD]'],
                   'EOQ': tokenizer.vocab['[unused2]']}
        predictor = build_predictor(args, tokenizer, symbols, None, logger)
        gold_path = args.result_path + '.gold'
        can_path = args.result_path + '.candidate'
        rouges = predictor._report_rouge(gold_path, can_path)
        # BUG FIX: the original used '%d' with args.result_path (a string),
        # which raises TypeError at runtime; log the path with '%s'.
        logger.info('Rouges for %s \n%s' %
                    (args.result_path, rouge_results_to_str(rouges)))
    elif (args.mode == 'lead'):
        baseline(args, cal_lead=True)
    elif (args.mode == 'oracle'):
        baseline(args, cal_oracle=True)
if (args.mode == 'test'):
    cp = args.test_from
    try:
        # Checkpoint names look like '..._<step>.pt'; recover the step.
        step = int(cp.split('.')[-2].split('_')[-1])
    except (IndexError, ValueError):
        # BUG FIX: narrowed the bare 'except:' to the parse failures this
        # expression can actually raise (bare except also swallows
        # KeyboardInterrupt/SystemExit).
        step = 0
    test_abs(args, device_id, cp, step)