def evaluate(self, eval_batches, data_type, result_dir=None, result_prefix=None, save_full_info=False):
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0
    for b_itx, batch in enumerate(eval_batches):
        feed_dict = {
            self.p: batch['passage_token_ids'],
            self.q: batch['question_token_ids'],
            self.qh: batch['question_char_ids'],
            self.ph: batch["passage_char_ids"],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
        }
        try:
            start_probs, end_probs, loss = self.sess.run(
                [self.logits1, self.logits2, self.loss], feed_dict)
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            padded_p_len = len(batch['passage_token_ids'][0])
            for sample, start_prob, end_prob in zip(
                    batch['raw_data'], start_probs, end_probs):
                best_answer = self.find_best_answer(
                    sample, start_prob, end_prob, padded_p_len)
                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    ref_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': sample['answers'],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
        except Exception:
            print('evaluate exception, skipping batch {}'.format(b_itx))
            continue
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    ave_loss = 1.0 * total_loss / total_num
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    # build TensorBoard summaries (note: this assumes bleu_rouge is not None,
    # i.e. reference answers were provided)
    ave_loss_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/loss".format(data_type), simple_value=ave_loss),
    ])
    bleu_4_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/bleu_4".format(data_type), simple_value=bleu_rouge['Bleu-4']),
    ])
    rougeL_sum = tf.Summary(value=[
        tf.Summary.Value(tag="{}/rouge-L".format(data_type), simple_value=bleu_rouge['Rouge-L']),
    ])
    return ave_loss, bleu_rouge, [ave_loss_sum, bleu_4_sum, rougeL_sum]
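# The variants in this section delegate span selection to find_best_answer, whose
# implementation is not shown here. The helper below is only an illustrative sketch
# (not the authors' code): it assumes start_probs / end_probs are laid out as one
# padded_p_len-wide block per passage, as in DuReader-style baselines, and that each
# passage dict carries 'passage_tokens'. max_p_num and max_a_len are assumed limits.
import numpy as np

def find_best_answer_sketch(sample, start_probs, end_probs, padded_p_len,
                            max_p_num=5, max_a_len=200):
    """Hedged sketch: pick the span maximizing start_prob[s] * end_prob[e]."""
    best_span, best_score, best_p_idx = None, 0.0, None
    for p_idx, passage in enumerate(sample['passages'][:max_p_num]):
        # probabilities for this passage occupy one padded_p_len-wide block
        offset = p_idx * padded_p_len
        p_len = min(len(passage['passage_tokens']), padded_p_len)
        sp = np.asarray(start_probs[offset:offset + p_len])
        ep = np.asarray(end_probs[offset:offset + p_len])
        for s in range(p_len):
            for e in range(s, min(s + max_a_len, p_len)):
                score = sp[s] * ep[e]
                if score > best_score:
                    best_score, best_p_idx, best_span = score, p_idx, (s, e)
    if best_span is None:
        return ''
    tokens = sample['passages'][best_p_idx]['passage_tokens']
    return ''.join(tokens[best_span[0]:best_span[1] + 1])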
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0
    for b_itx, batch in enumerate(eval_batches):
        feed_dict = {
            self.c: batch['passage_token_ids'],
            self.q: batch['question_token_ids'],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
            self.dropout: 0.0
        }
        # yp1 / yp2 hold the predicted start and end positions, so they are used as indices below
        start_probs, end_probs, loss = self.sess.run(
            [self.yp1, self.yp2, self.loss], feed_dict)
        total_loss += loss * len(batch['raw_data'])
        total_num += len(batch['raw_data'])
        for sample, start_prob, end_prob in zip(batch['raw_data'],
                                                start_probs.tolist(),
                                                end_probs.tolist()):
            best_answer = "".join(sample['passage_tokens'][start_prob:end_prob + 1])
            if save_full_info:
                sample['pred_answers'] = [best_answer]
                pred_answers.append(sample)
            else:
                pred_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': [best_answer],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
            if 'answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': sample['answers'],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w', encoding='utf-8') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    ave_loss = 1.0 * total_loss / total_num
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    start_eval_time = time.time()
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0
    for b_itx, batch in enumerate(eval_batches):
        start_batches_time = time.time()
        feed_dict = {
            self.c: batch['passage_token_ids'],
            self.q: batch['question_token_ids'],
            self.qh: batch['question_char_ids'],
            self.ch: batch["passage_char_ids"],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
            self.dropout: 0.0
        }
        try:
            start_sess_time = time.time()
            start_probs, end_probs, loss = self.sess.run(
                [self.logits1, self.logits2, self.loss], feed_dict)
            print("Sess time: ", time.time() - start_sess_time)
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            padded_p_len = len(batch['passage_token_ids'][0])
            for sample, start_prob, end_prob in zip(
                    batch['raw_data'], start_probs, end_probs):
                best_answer = self.find_best_answer(
                    sample, start_prob, end_prob, padded_p_len)
                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question': sample['question'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    ref_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': sample['answers'],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
        except Exception as e:
            # print(str(e))
            traceback.print_exc()
            continue
        print("Batch time: ", time.time() - start_batches_time)
        print("Predict answers: ", pred_answers)
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w', encoding='utf-8') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    if result_prefix == 'test.predicted':
        print(pred_answers)
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    ave_loss = 1.0 * total_loss / total_num
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    print("Eval time: ", time.time() - start_eval_time)
    return ave_loss, bleu_rouge
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    """
    Evaluates the model performance on eval_batches and results are saved if specified
    Args:
        eval_batches: iterable batch data
        result_dir: directory to save predicted answers, answers will not be saved if None
        result_prefix: prefix of the file for saving predicted answers,
                       answers will not be saved if None
        save_full_info: if True, the pred_answers will be added to raw sample and saved
    """
    pred_answers, ref_answers = [], []
    total_num, num_of_batch, correct_p_num, select_total_num, select_true_num = 0, 0, 0, 0, 0
    self.model.eval()
    for b_itx, batch in enumerate(eval_batches):
        num_of_batch += 1
        # batch_size * max_passage_num x padded_p_len
        p = Variable(torch.LongTensor(batch['passage_token_ids']), volatile=True).cuda()
        # batch_size * max_passage_num x padded_q_len
        q = Variable(torch.LongTensor(batch['question_token_ids']), volatile=True).cuda()
        # batch_size
        start_label = Variable(torch.LongTensor(batch['start_id']), volatile=True).cuda()
        # end_label = Variable(torch.LongTensor(batch['end_id']), volatile=True).cuda()
        # batch_size * max_passage_num x padded_p_len x 2
        answer_prob = self.model(p, q)
        # batch_size * max_passage_num x padded_p_len
        answer_begin_prob = answer_prob[:, :, 0].contiguous()
        answer_end_prob = answer_prob[:, :, 1].contiguous()
        total_num += len(batch['raw_data'])
        max_passage_num = p.size(0) // start_label.size(0)
        for idx, sample in enumerate(batch['raw_data']):
            select_total_num += 1
            # max_passage_num x padded_p_len
            start_prob = answer_begin_prob[idx * max_passage_num:(idx + 1) * max_passage_num, :]
            end_prob = answer_end_prob[idx * max_passage_num:(idx + 1) * max_passage_num, :]
            best_answer, best_p_idx = self.find_best_answer(sample, start_prob, end_prob)
            if best_p_idx in sample['answer_passages']:
                correct_p_num += 1
            if sample['passages'][best_p_idx]['is_selected']:
                select_true_num += 1
            if save_full_info:
                sample['pred_answers'] = [best_answer]
                pred_answers.append(sample)
            else:
                pred_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': [best_answer],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
            if 'answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': sample['answers'],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        print('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    # ave_loss = 1.0 * total_loss / num_of_batch
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    print('correct selected passage num is {} in {}'.format(select_true_num, select_total_num))
    print('correct passage num is {} in {}'.format(correct_p_num, total_num))
    return bleu_rouge
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    """
    Evaluates the model performance on eval_batches and results are saved if specified
    Args:
        eval_batches: iterable batch data
        result_dir: directory to save predicted answers, answers will not be saved if None
        result_prefix: prefix of the file for saving predicted answers,
                       answers will not be saved if None
        save_full_info: if True, the pred_answers will be added to raw sample and saved
    """
    pred_answers, ref_answers = [], []
    total_num, num_of_batch, correct_p_num, select_total_num, select_true_num = 0, 0, 0, 0, 0
    self.model.eval()
    for b_itx, batch in enumerate(eval_batches):
        num_of_batch += 1
        # [why edit] updated to the torch 0.4 API: volatile=True is deprecated,
        # so inputs are built with requires_grad=False instead.
        # batch_size * max_passage_num x padded_p_len
        p = Variable(torch.LongTensor(batch['passage_token_ids']), requires_grad=False).cuda()
        # batch_size * max_passage_num x padded_q_len
        q = Variable(torch.LongTensor(batch['question_token_ids']), requires_grad=False).cuda()
        # batch_size
        start_label = Variable(torch.LongTensor(batch['start_id']), requires_grad=False).cuda()
        # end_label = Variable(torch.LongTensor(batch['end_id']), requires_grad=False).cuda()
        # batch_size * max_passage_num x padded_p_len x 2
        answer_prob = self.model(p, q)
        # batch_size * max_passage_num x padded_p_len
        answer_begin_prob = answer_prob[:, :, 0].contiguous()
        answer_end_prob = answer_prob[:, :, 1].contiguous()
        total_num += len(batch['raw_data'])
        max_passage_num = p.size(0) // start_label.size(0)
        for idx, sample in enumerate(batch['raw_data']):
            select_total_num += 1
            # max_passage_num x padded_p_len
            start_prob = answer_begin_prob[idx * max_passage_num:(idx + 1) * max_passage_num, :]
            end_prob = answer_end_prob[idx * max_passage_num:(idx + 1) * max_passage_num, :]
            best_answer, best_p_idx = self.find_best_answer(sample, start_prob, end_prob)
            # [why] added 2018.8.22 to avoid a KeyError in prediction, where
            # 'answer_passages' and 'is_selected' are not available
            MODE = 'predict'
            if MODE == 'predict':
                pass
            else:
                if 'answer_passages' in sample.keys():
                    if best_p_idx in sample['answer_passages']:
                        correct_p_num += 1
                    try:
                        if sample['passages'][best_p_idx]['is_selected']:
                            select_true_num += 1
                    except KeyError:
                        print("\033[0;30;46m WHY Info: sample passages is:\n {} \033[0m"
                              .format(sample['passages']))
            if save_full_info:
                sample['pred_answers'] = [best_answer]
                pred_answers.append(sample)
            else:
                # heuristic post-processing of short answers
                if len(best_answer) < 2 and sample['question_type'] == 'entity':
                    # note: r'[0-9]*' matches the empty string, so this branch never triggers
                    if not re.match(r'[0-9]*', best_answer):
                        best_answer = '日'
                if sample['question_type'] == 'entity':
                    if re.match(r'.*月[0-9]*$', best_answer):
                        best_answer = best_answer + '日'
                if len(best_answer) < 4 and sample['question_type'] == 'DESCRIPTION':
                    best_answer = sample['documents'][0]['paragraphs'][0]
                pred_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': [best_answer],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
            if 'answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': sample['answers'],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        print('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    # ave_loss = 1.0 * total_loss / num_of_batch
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    print('correct selected passage num is {} in {}'.format(select_true_num, select_total_num))
    print('correct passage num is {} in {}'.format(correct_p_num, total_num))
    return bleu_rouge
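# The two PyTorch variants above still wrap inputs in Variable, which has been a
# no-op since torch 0.4 (and volatile=True was removed in that release). Below is
# a minimal sketch, not the authors' code, of the equivalent modern idiom: build
# plain tensors and run the forward pass under torch.no_grad(). The model and the
# batch layout are assumed to match the variants above; forward_eval_sketch is an
# illustrative name.
import torch

def forward_eval_sketch(model, batch, device='cuda'):
    """Hedged sketch: inference-mode forward pass without Variable wrappers."""
    p = torch.as_tensor(batch['passage_token_ids'], dtype=torch.long, device=device)
    q = torch.as_tensor(batch['question_token_ids'], dtype=torch.long, device=device)
    with torch.no_grad():                      # disables autograd bookkeeping during evaluation
        answer_prob = model(p, q)              # (batch_size * max_passage_num) x padded_p_len x 2
    return answer_prob[:, :, 0].contiguous(), answer_prob[:, :, 1].contiguous()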
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    """
    Evaluates the model performance on eval_batches and results are saved if specified
    Args:
        eval_batches: iterable batch data
        result_dir: directory to save predicted answers, answers will not be saved if None
        result_prefix: prefix of the file for saving predicted answers,
                       answers will not be saved if None
        save_full_info: if True, the pred_answers will be added to raw sample and saved
    """
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0

    def mul_dict(data):
        # group flat per-passage rows into lists of max_p_num rows per sample
        a = []
        i = 1
        a_one = []
        for line in data:
            a_one.append(line)
            if i % self.max_p_num == 0:
                a.append(a_one)
                a_one = []
            i += 1
        return a

    for b_itx, batch in enumerate(eval_batches):
        feed_dict = {
            self.c: batch['passage_token_ids'],
            self.q: batch['question_token_ids'],
            self.qh: batch['question_char_ids'],
            self.ch: batch["passage_char_ids"],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
            self.dropout: 0.0
        }
        try:
            # the extra tensors before logits1/logits2/loss are fetched only for inspection
            (f, mp, sq, ssc, sc, c_emb, q_emb, match, fuse, sa, sl, el, s, e,
             start_probs, end_probs, loss) = self.sess.run(
                [self.f, self.mp, self.sep_q_encodes, self.sc, self.sep_c_encodes,
                 self.c_emb, self.q_emb, self.match_p_encodes, self.fuse_p_encodes,
                 self.anttion_p, self.sl, self.el, self.s, self.e,
                 self.logits1, self.logits2, self.loss], feed_dict)
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            padded_p_len = len(batch['passage_token_ids'][0])
            for sample, start_prob, end_prob in zip(
                    batch['raw_data'], start_probs, end_probs):
                best_answer = self.find_best_answer(
                    sample, start_prob, end_prob, padded_p_len)
                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    ref_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': sample['answers'],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
        except Exception:
            # skip batches that fail during evaluation
            continue
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    ave_loss = 1.0 * total_loss / total_num
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge
def evaluate(self, eval_batches, result_dir=None, result_prefix=None, save_full_info=False):
    """
    Evaluates the model performance on eval_batches and results are saved if specified
    Args:
        eval_batches: iterable batch data
        result_dir: directory to save predicted answers, answers will not be saved if None
        result_prefix: prefix of the file for saving predicted answers,
                       answers will not be saved if None
        save_full_info: if True, the pred_answers will be added to raw sample and saved
    """
    pred_answers, ref_answers = [], []
    total_loss, total_num = 0, 0
    for b_itx, batch in enumerate(eval_batches):
        feed_dict = {
            self.p: batch['passage_token_ids'],
            self.q: batch['question_token_ids'],
            self.p_length: batch['passage_length'],
            self.q_length: batch['question_length'],
            self.start_label: batch['start_id'],
            self.end_label: batch['end_id'],
            self.dropout_keep_prob: 1.0
        }
        start_probs, end_probs, loss = self.sess.run(
            [self.start_probs, self.end_probs, self.loss], feed_dict)
        total_loss += loss * len(batch['raw_data'])
        total_num += len(batch['raw_data'])
        padded_p_len = len(batch['passage_token_ids'][0])
        for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs, end_probs):
            best_answer = self.find_best_answer(sample, start_prob, end_prob, padded_p_len)
            if save_full_info:
                sample['pred_answers'] = [best_answer]
                pred_answers.append(sample)
            else:
                pred_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': [best_answer],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
            if 'answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': sample['answers'],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in pred_answers:
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        self.logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    # this average loss is invalid on the test set, since we don't have the true start_id and end_id
    ave_loss = 1.0 * total_loss / total_num
    # compute the BLEU and ROUGE scores if reference answers are provided
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge
def evaluate(args, eval_batches, vocab, sess, result_dir=None, result_prefix=None):
    logger = logging.getLogger("brc")
    total_loss, total_num = 0, 0
    result_prob = []
    pred_answers, ref_answers = [], []
    with tf.device('/gpu:%s' % args.gpus[0]):
        with tf.variable_scope('model', reuse=True):
            model = Model(vocab, args)
    pp_scores = (0.43, 0.23, 0.16, 0.10, 0.09)
    for _, batch_data in enumerate(eval_batches, 1):
        feed_dict = {
            model.p: batch_data['passage_token_ids'],
            model.q: batch_data['question_token_ids'],
            model.p_length: batch_data['passage_length'],
            model.q_length: batch_data['question_length'],
            model.start_label: batch_data['start_ids'],
            model.end_label: batch_data['end_ids'],
            model.match_score: batch_data['match_scores']
        }
        start_probs, end_probs, loss = sess.run(
            [model.start_probs, model.end_probs, model.loss], feed_dict)
        total_loss += loss * len(batch_data['raw_data'])
        total_num += len(batch_data['raw_data'])
        padded_p_len = len(batch_data['passage_token_ids'][0])
        for sample, start_prob, end_prob in zip(batch_data['raw_data'], start_probs, end_probs):
            start_prob_list = [str(element) for element in list(start_prob)]
            end_prob_list = [str(element) for element in list(end_prob)]
            result_prob.append({
                "question_id": sample['question_id'],
                "start_prob": start_prob_list,
                "end_prob": end_prob_list,
                'padd_len': padded_p_len
            })
            best_answer, segmented_pred = find_best_answer(
                sample, start_prob, end_prob, padded_p_len, args,
                para_prior_scores=pp_scores)
            pred_answers.append({
                'question_id': sample['question_id'],
                'question_type': sample['question_type'],
                'answers': [best_answer],
                'entity_answers': [[]],
                'yesno_answers': [],
                'segmented_question': sample['segmented_question'],
                'segmented_answers': segmented_pred
            })
            if 'segmented_answers' in sample:
                ref_answers.append({
                    'question_id': sample['question_id'],
                    'question_type': sample['question_type'],
                    'answers': [''.join(seg_ans) for seg_ans in sample['segmented_answers']],
                    'entity_answers': [[]],
                    'yesno_answers': []
                })
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in tqdm(pred_answers):
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        prob_file = os.path.join(result_dir, result_prefix + 'probs.json')
        with open(prob_file, 'w') as f:
            for prob in tqdm(result_prob):
                f.write(json.dumps(prob, ensure_ascii=False) + "\n")
        logger.info('Saving {} results to {}'.format(result_prefix, result_file))
    ave_loss = 1.0 * total_loss / total_num
    if len(ref_answers) > 0:
        pred_dict, ref_dict = {}, {}
        for pred, ref in zip(pred_answers, ref_answers):
            question_id = ref['question_id']
            if len(ref['answers']) > 0:
                pred_dict[question_id] = normalize(pred['answers'])
                ref_dict[question_id] = normalize(ref['answers'])
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge
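# Most variants above return (ave_loss, bleu_rouge), where bleu_rouge is the dict
# produced by compute_bleu_rouge. A minimal usage sketch, assuming a DuReader-style
# dataset object with a gen_mini_batches method and a model object exposing one of
# the evaluate methods above; the names rc_model, brc_data, pad_id and the batch
# size are illustrative assumptions, not part of the original code.
def validate_sketch(rc_model, brc_data, pad_id, batch_size=32, result_dir=None):
    """Hedged sketch: run evaluation on the dev set and report Bleu-4 / Rouge-L."""
    eval_batches = brc_data.gen_mini_batches('dev', batch_size, pad_id, shuffle=False)
    ave_loss, bleu_rouge = rc_model.evaluate(eval_batches,
                                             result_dir=result_dir,
                                             result_prefix='dev.predicted')
    if bleu_rouge is not None:
        print('Dev loss: {:.4f}, Bleu-4: {:.4f}, Rouge-L: {:.4f}'.format(
            ave_loss, bleu_rouge['Bleu-4'], bleu_rouge['Rouge-L']))
    return bleu_rouge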