コード例 #1
0
    def evaluate(self,
                 eval_batches,
                 data_type,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model on eval_batches, optionally saves the predictions
        and builds TensorBoard summaries of the results.

        Args:
            eval_batches: iterable of batch dicts; each batch provides the
                token/char id arrays, 'start_id', 'end_id' and 'raw_data'
            data_type: prefix of the summary tags, e.g. "<data_type>/loss"
            result_dir: directory to save predicted answers, answers will not
                be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                answers will not be saved if None
            save_full_info: if True, the prediction is stored in the raw
                sample under 'pred_answers' and the whole sample is saved

        Returns:
            (ave_loss, bleu_rouge, summaries). ave_loss is NaN when no batch
            could be evaluated; bleu_rouge is None when no sample carries
            reference 'answers'; summaries always holds the loss summary and,
            when bleu_rouge is available, the Bleu-4 and Rouge-L summaries.
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_loss, total_num = 0, 0
        for b_itx, batch in enumerate(eval_batches):
            feed_dict = {
                self.p: batch['passage_token_ids'],
                self.q: batch['question_token_ids'],
                self.qh: batch['question_char_ids'],
                self.ph: batch["passage_char_ids"],
                self.start_label: batch['start_id'],
                self.end_label: batch['end_id'],
            }
            try:
                start_probs, end_probs, loss = self.sess.run(
                    [self.logits1, self.logits2, self.loss], feed_dict)
                padded_p_len = len(batch['passage_token_ids'][0])
                batch_preds, batch_triples = [], []
                for sample, start_prob, end_prob in zip(
                        batch['raw_data'], start_probs, end_probs):
                    best_answer = self.find_best_answer(
                        sample, start_prob, end_prob, padded_p_len)
                    if save_full_info:
                        sample['pred_answers'] = [best_answer]
                        batch_preds.append(sample)
                    else:
                        batch_preds.append({
                            'question_id': sample['question_id'],
                            'question_type': sample['question_type'],
                            'answers': [best_answer],
                            'entity_answers': [[]],
                            'yesno_answers': []
                        })
                    if 'answers' in sample:
                        batch_triples.append((sample['question_id'],
                                              [best_answer],
                                              sample['answers']))
            except Exception:
                # Best effort: skip a failing batch instead of aborting the
                # whole evaluation, but keep the traceback in the log.
                self.logger.exception(
                    'evaluate: skipping batch %d after an error', b_itx)
                continue
            # Commit the batch only after it was fully processed, so the loss
            # and the predictions always cover the same samples.
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            pred_answers.extend(batch_preds)
            metric_triples.extend(batch_triples)

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            # ensure_ascii=False writes non-ASCII characters verbatim, so the
            # file encoding must not depend on the platform's default locale.
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            self.logger.info('Saving {} results to {}'.format(
                result_prefix, result_file))

        # This average loss is meaningless on the test set, since there are
        # no real start_id and end_id labels there.
        if total_num:
            ave_loss = 1.0 * total_loss / total_num
        else:
            self.logger.warning('evaluate: no batch could be evaluated')
            ave_loss = float('nan')

        # Compute the BLEU and ROUGE scores if reference answers are provided.
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None

        summaries = [
            tf.Summary(value=[
                tf.Summary.Value(tag="{}/loss".format(data_type),
                                 simple_value=ave_loss),
            ])
        ]
        # Without references there are no scores to summarize (indexing a
        # None bleu_rouge would raise a TypeError).
        if bleu_rouge is not None:
            summaries.append(tf.Summary(value=[
                tf.Summary.Value(tag="{}/bleu_4".format(data_type),
                                 simple_value=bleu_rouge['Bleu-4']),
            ]))
            summaries.append(tf.Summary(value=[
                tf.Summary.Value(tag="{}/rouge-L".format(data_type),
                                 simple_value=bleu_rouge['Rouge-L']),
            ]))
        return ave_loss, bleu_rouge, summaries
コード例 #2
0
ファイル: model2.py プロジェクト: Yaozeng/MRC
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model on eval_batches and saves the predictions if
        requested.

        Args:
            eval_batches: iterable of batch dicts; each batch provides
                'passage_token_ids', 'question_token_ids', 'start_id',
                'end_id' and 'raw_data'
            result_dir: directory to save predicted answers, answers will not
                be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                answers will not be saved if None
            save_full_info: if True, the prediction is stored in the raw
                sample under 'pred_answers' and the whole sample is saved

        Returns:
            (ave_loss, bleu_rouge). ave_loss is NaN when eval_batches is
            empty; bleu_rouge is None when no sample carries reference
            'answers'.
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_loss, total_num = 0, 0
        for batch in eval_batches:
            feed_dict = {
                self.c: batch['passage_token_ids'],
                self.q: batch['question_token_ids'],
                self.start_label: batch['start_id'],
                self.end_label: batch['end_id'],
                self.dropout: 0.0
            }
            start_ids, end_ids, loss = self.sess.run(
                [self.yp1, self.yp2, self.loss], feed_dict)
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])

            # yp1 / yp2 are used as the inclusive start/end token positions
            # of the predicted span.
            for sample, start_idx, end_idx in zip(batch['raw_data'],
                                                  start_ids.tolist(),
                                                  end_ids.tolist()):
                best_answer = "".join(
                    sample['passage_tokens'][start_idx:end_idx + 1])
                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    metric_triples.append((sample['question_id'],
                                           [best_answer],
                                           sample['answers']))

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            self.logger.info('Saving {} results to {}'.format(
                result_prefix, result_file))

        # this average loss is invalid on test set, since we don't have true start_id and end_id
        if total_num:
            ave_loss = 1.0 * total_loss / total_num
        else:
            self.logger.warning('evaluate: no batch was evaluated')
            ave_loss = float('nan')

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None
        return ave_loss, bleu_rouge
コード例 #3
0
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model on eval_batches, saves the predictions if
        requested and logs timing information.

        Args:
            eval_batches: iterable of batch dicts; each batch provides the
                token/char id arrays, 'start_id', 'end_id' and 'raw_data'
            result_dir: directory to save predicted answers, answers will not
                be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                answers will not be saved if None
            save_full_info: if True, the prediction is stored in the raw
                sample under 'pred_answers' and the whole sample is saved

        Returns:
            (ave_loss, bleu_rouge). ave_loss is NaN when no batch could be
            evaluated; bleu_rouge is None when no sample carries reference
            'answers'.
        """
        start_eval_time = time.time()
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_loss, total_num = 0, 0
        for b_itx, batch in enumerate(eval_batches):
            start_batch_time = time.time()
            feed_dict = {
                self.c: batch['passage_token_ids'],
                self.q: batch['question_token_ids'],
                self.qh: batch['question_char_ids'],
                self.ch: batch["passage_char_ids"],
                self.start_label: batch['start_id'],
                self.end_label: batch['end_id'],
                self.dropout: 0.0
            }

            try:
                start_sess_time = time.time()
                start_probs, end_probs, loss = self.sess.run(
                    [self.logits1, self.logits2, self.loss], feed_dict)
                self.logger.debug('Sess time: %.3fs',
                                  time.time() - start_sess_time)

                padded_p_len = len(batch['passage_token_ids'][0])
                batch_preds, batch_triples = [], []
                for sample, start_prob, end_prob in zip(
                        batch['raw_data'], start_probs, end_probs):
                    best_answer = self.find_best_answer(
                        sample, start_prob, end_prob, padded_p_len)
                    if save_full_info:
                        sample['pred_answers'] = [best_answer]
                        batch_preds.append(sample)
                    else:
                        batch_preds.append({
                            'question_id': sample['question_id'],
                            'question': sample['question'],
                            'question_type': sample['question_type'],
                            'answers': [best_answer],
                            'entity_answers': [[]],
                            'yesno_answers': []
                        })
                    if 'answers' in sample:
                        batch_triples.append((sample['question_id'],
                                              [best_answer],
                                              sample['answers']))
            except Exception:
                # Best effort: skip a failing batch instead of aborting the
                # whole evaluation, but keep the traceback in the log.
                self.logger.exception(
                    'evaluate: skipping batch %d after an error', b_itx)
                continue
            # Commit the batch only after it was fully processed, so the loss
            # and the predictions always cover the same samples.
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            pred_answers.extend(batch_preds)
            metric_triples.extend(batch_triples)
            self.logger.debug('Batch time: %.3fs',
                              time.time() - start_batch_time)

        # The full prediction list can be very large; only emit it when debug
        # logging is enabled instead of printing it unconditionally.
        self.logger.debug('Predict answers: %s', pred_answers)
        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            self.logger.info('Saving {} results to {}'.format(
                result_prefix, result_file))

        # this average loss is invalid on test set, since we don't have true start_id and end_id
        if total_num:
            ave_loss = 1.0 * total_loss / total_num
        else:
            self.logger.warning('evaluate: no batch could be evaluated')
            ave_loss = float('nan')

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None

        self.logger.info('Eval time: %.3fs', time.time() - start_eval_time)
        return ave_loss, bleu_rouge
コード例 #4
0
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model performance on eval_batches and results are saved if specified
        Args:
            eval_batches: iterable batch data
            result_dir: directory to save predicted answers, answers will not be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                           answers will not be saved if None
            save_full_info: if True, the pred_answers will be added to raw sample and saved
        Returns:
            the BLEU/ROUGE scores, or None when no sample carries reference 'answers'
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_num, correct_p_num, select_total_num, select_true_num = 0, 0, 0, 0
        # NOTE(review): the model is left in eval mode; presumably the training
        # loop calls self.model.train() again before resuming — confirm.
        self.model.eval()
        for batch in eval_batches:
            # NOTE(review): `volatile=True` only disables autograd on
            # PyTorch < 0.4; newer versions ignore it (with a warning) and the
            # loop should run under `torch.no_grad()` instead — confirm the
            # torch version this project uses.
            # batch_size * max_passage_num x padded_p_len
            p = Variable(torch.LongTensor(batch['passage_token_ids']),
                         volatile=True).cuda()
            # batch_size * max_passage_num x padded_q_len
            q = Variable(torch.LongTensor(batch['question_token_ids']),
                         volatile=True).cuda()
            # batch_size * max_passage_num x padded_p_len x 2
            answer_prob = self.model(p, q)
            # batch_size * max_passage_num x padded_p_len
            answer_begin_prob = answer_prob[:, :, 0].contiguous()
            # batch_size * max_passage_num x padded_p_len
            answer_end_prob = answer_prob[:, :, 1].contiguous()
            total_num += len(batch['raw_data'])
            # 'start_id' holds one entry per sample (batch_size), so this is
            # the number of passage rows belonging to each sample. Its length
            # is read directly instead of building a GPU tensor just for it.
            max_passage_num = p.size(0) // len(batch['start_id'])
            for idx, sample in enumerate(batch['raw_data']):
                select_total_num += 1
                # max_passage_num x padded_p_len
                rows = slice(idx * max_passage_num,
                             (idx + 1) * max_passage_num)
                start_prob = answer_begin_prob[rows, :]
                end_prob = answer_end_prob[rows, :]

                best_answer, best_p_idx = self.find_best_answer(
                    sample, start_prob, end_prob)
                # Passage-selection labels are absent on unlabeled data;
                # .get() avoids a KeyError there and simply counts no hit.
                if best_p_idx in sample.get('answer_passages', ()):
                    correct_p_num += 1
                if sample['passages'][best_p_idx].get('is_selected'):
                    select_true_num += 1

                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    metric_triples.append((sample['question_id'],
                                           [best_answer],
                                           sample['answers']))

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            # ensure_ascii=False writes non-ASCII characters verbatim, so the
            # file encoding must not depend on the platform's default locale.
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            print('Saving {} results to {}'.format(result_prefix, result_file))

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None
        print('correct selected passage num is {} in {}'.format(
            select_true_num, select_total_num))
        print('correct passage num is {} in {}'.format(correct_p_num,
                                                       total_num))
        return bleu_rouge
コード例 #5
0
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model performance on eval_batches and results are saved if specified
        Args:
            eval_batches: iterable batch data
            result_dir: directory to save predicted answers, answers will not be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                           answers will not be saved if None
            save_full_info: if True, the pred_answers will be added to raw sample and saved
        Returns:
            the BLEU/ROUGE scores, or None when no sample carries reference 'answers'
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_num, correct_p_num, select_total_num, select_true_num = 0, 0, 0, 0
        self.model.eval()
        # This code targets PyTorch 0.4: requires_grad=False inputs do not stop
        # autograd from recording the model's own operations, no_grad() does,
        # which saves memory and time during evaluation.
        with torch.no_grad():
            for batch in eval_batches:
                # batch_size * max_passage_num x padded_p_len
                p = torch.LongTensor(batch['passage_token_ids']).cuda()
                # batch_size * max_passage_num x padded_q_len
                q = torch.LongTensor(batch['question_token_ids']).cuda()
                # batch_size * max_passage_num x padded_p_len x 2
                answer_prob = self.model(p, q)
                # batch_size * max_passage_num x padded_p_len
                answer_begin_prob = answer_prob[:, :, 0].contiguous()
                # batch_size * max_passage_num x padded_p_len
                answer_end_prob = answer_prob[:, :, 1].contiguous()
                total_num += len(batch['raw_data'])
                # 'start_id' holds one entry per sample (batch_size), so this
                # is the number of passage rows belonging to each sample.
                max_passage_num = p.size(0) // len(batch['start_id'])
                for idx, sample in enumerate(batch['raw_data']):
                    select_total_num += 1
                    # max_passage_num x padded_p_len
                    rows = slice(idx * max_passage_num,
                                 (idx + 1) * max_passage_num)
                    start_prob = answer_begin_prob[rows, :]
                    end_prob = answer_end_prob[rows, :]

                    best_answer, best_p_idx = self.find_best_answer(
                        sample, start_prob, end_prob)

                    # Passage-selection statistics. The labels are missing on
                    # prediction data (previously a KeyError there), so look
                    # them up with .get() instead of disabling the statistics
                    # with a hard-coded mode flag.
                    if best_p_idx in sample.get('answer_passages', ()):
                        correct_p_num += 1
                    if sample['passages'][best_p_idx].get('is_selected'):
                        select_true_num += 1

                    if save_full_info:
                        sample['pred_answers'] = [best_answer]
                        pred_answers.append(sample)
                    else:
                        question_type = sample['question_type']
                        # NOTE(review): a rule replacing short non-numeric
                        # entity answers with '日' was removed: its check
                        # re.match(r'[0-9]*', ...) also matches the empty
                        # string, so it was always truthy and the rule never
                        # fired. Confirm the intent before reinstating it.
                        # NOTE(review): 'entity' is compared in lowercase while
                        # 'DESCRIPTION' is uppercase — confirm which casing the
                        # dataset's question_type values use.
                        if question_type == 'entity' and re.match(
                                r'.*月[0-9]*$', best_answer):
                            # Complete a "<month>月<day>" date with '日' (day).
                            best_answer = best_answer + '日'
                        if len(best_answer) < 4 and question_type == 'DESCRIPTION':
                            # Fall back to the first paragraph of the first
                            # document for very short descriptive answers.
                            best_answer = sample['documents'][0]['paragraphs'][0]
                        pred_answers.append({
                            'question_id': sample['question_id'],
                            'question_type': question_type,
                            'answers': [best_answer],
                            'entity_answers': [[]],
                            'yesno_answers': []
                        })
                    if 'answers' in sample:
                        metric_triples.append((sample['question_id'],
                                               [best_answer],
                                               sample['answers']))

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            # ensure_ascii=False writes non-ASCII characters verbatim, so the
            # file encoding must not depend on the platform's default locale.
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            print('Saving {} results to {}'.format(result_prefix, result_file))

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None
        print('correct selected passage num is {} in {}'.format(
            select_true_num, select_total_num))
        print('correct passage num is {} in {}'.format(correct_p_num,
                                                       total_num))
        return bleu_rouge
コード例 #6
0
ファイル: model.py プロジェクト: MiHuangLan/reader
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model performance on eval_batches and results are saved if specified
        Args:
            eval_batches: iterable batch data
            result_dir: directory to save predicted answers, answers will not be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                           answers will not be saved if None
            save_full_info: if True, the pred_answers will be added to raw sample and saved
        Returns:
            (ave_loss, bleu_rouge). ave_loss is NaN when no batch could be
            evaluated; bleu_rouge is None when no sample carries reference
            'answers'.
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_loss, total_num = 0, 0

        for b_itx, batch in enumerate(eval_batches):
            feed_dict = {
                self.c: batch['passage_token_ids'],
                self.q: batch['question_token_ids'],
                self.qh: batch['question_char_ids'],
                self.ch: batch["passage_char_ids"],
                self.start_label: batch['start_id'],
                self.end_label: batch['end_id'],
                self.dropout: 0.0
            }

            try:
                # Fetch only what is used below; evaluating the intermediate
                # debug tensors cost extra compute and memory on every batch.
                start_probs, end_probs, loss = self.sess.run(
                    [self.logits1, self.logits2, self.loss], feed_dict)

                padded_p_len = len(batch['passage_token_ids'][0])
                batch_preds, batch_triples = [], []
                for sample, start_prob, end_prob in zip(
                        batch['raw_data'], start_probs, end_probs):
                    best_answer = self.find_best_answer(
                        sample, start_prob, end_prob, padded_p_len)
                    if save_full_info:
                        sample['pred_answers'] = [best_answer]
                        batch_preds.append(sample)
                    else:
                        batch_preds.append({
                            'question_id': sample['question_id'],
                            'question_type': sample['question_type'],
                            'answers': [best_answer],
                            'entity_answers': [[]],
                            'yesno_answers': []
                        })
                    if 'answers' in sample:
                        batch_triples.append((sample['question_id'],
                                              [best_answer],
                                              sample['answers']))
            except Exception:
                # Best effort: skip a failing batch instead of aborting the
                # whole evaluation, but keep the traceback in the log.
                self.logger.exception(
                    'evaluate: skipping batch %d after an error', b_itx)
                continue
            # Commit the batch only after it was fully processed, so the loss
            # and the predictions always cover the same samples.
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])
            pred_answers.extend(batch_preds)
            metric_triples.extend(batch_triples)

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            # ensure_ascii=False writes non-ASCII characters verbatim, so the
            # file encoding must not depend on the platform's default locale.
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            self.logger.info('Saving {} results to {}'.format(
                result_prefix, result_file))

        # this average loss is invalid on test set, since we don't have true start_id and end_id
        if total_num:
            ave_loss = 1.0 * total_loss / total_num
        else:
            self.logger.warning('evaluate: no batch could be evaluated')
            ave_loss = float('nan')

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None

        return ave_loss, bleu_rouge
コード例 #7
0
ファイル: mymodel.py プロジェクト: MiHuangLan/reader
    def evaluate(self,
                 eval_batches,
                 result_dir=None,
                 result_prefix=None,
                 save_full_info=False):
        """
        Evaluates the model performance on eval_batches and results are saved if specified
        Args:
            eval_batches: iterable batch data
            result_dir: directory to save predicted answers, answers will not be saved if None
            result_prefix: prefix of the file for saving predicted answers,
                           answers will not be saved if None
            save_full_info: if True, the pred_answers will be added to raw sample and saved
        Returns:
            (ave_loss, bleu_rouge). ave_loss is NaN when eval_batches is
            empty; bleu_rouge is None when no sample carries reference
            'answers'.
        """
        pred_answers = []
        # (question_id, predicted answers, reference answers) for every sample
        # that has references. Recording them together keeps predictions and
        # references aligned even when only some samples carry 'answers', and
        # always uses the prediction (with save_full_info the saved sample's
        # own 'answers' key holds the references, not the prediction).
        metric_triples = []
        total_loss, total_num = 0, 0

        for batch in eval_batches:
            feed_dict = {
                self.p: batch['passage_token_ids'],
                self.q: batch['question_token_ids'],
                self.p_length: batch['passage_length'],
                self.q_length: batch['question_length'],
                self.start_label: batch['start_id'],
                self.end_label: batch['end_id'],
                self.dropout_keep_prob: 1.0
            }
            start_probs, end_probs, loss = self.sess.run(
                [self.start_probs, self.end_probs, self.loss], feed_dict)
            total_loss += loss * len(batch['raw_data'])
            total_num += len(batch['raw_data'])

            padded_p_len = len(batch['passage_token_ids'][0])
            for sample, start_prob, end_prob in zip(batch['raw_data'],
                                                    start_probs, end_probs):
                best_answer = self.find_best_answer(sample, start_prob,
                                                    end_prob, padded_p_len)
                if save_full_info:
                    sample['pred_answers'] = [best_answer]
                    pred_answers.append(sample)
                else:
                    pred_answers.append({
                        'question_id': sample['question_id'],
                        'question_type': sample['question_type'],
                        'answers': [best_answer],
                        'entity_answers': [[]],
                        'yesno_answers': []
                    })
                if 'answers' in sample:
                    metric_triples.append((sample['question_id'],
                                           [best_answer],
                                           sample['answers']))

        if result_dir is not None and result_prefix is not None:
            result_file = os.path.join(result_dir, result_prefix + '.json')
            # ensure_ascii=False writes non-ASCII characters verbatim, so the
            # file encoding must not depend on the platform's default locale.
            with open(result_file, 'w', encoding='utf-8') as fout:
                for pred_answer in pred_answers:
                    fout.write(
                        json.dumps(pred_answer, ensure_ascii=False) + '\n')

            self.logger.info('Saving {} results to {}'.format(
                result_prefix, result_file))

        # this average loss is invalid on test set, since we don't have true start_id and end_id
        if total_num:
            ave_loss = 1.0 * total_loss / total_num
        else:
            self.logger.warning('evaluate: no batch was evaluated')
            ave_loss = float('nan')

        # compute the bleu and rouge scores if reference answers is provided
        if metric_triples:
            pred_dict, ref_dict = {}, {}
            for question_id, pred, ref in metric_triples:
                if len(ref) > 0:
                    pred_dict[question_id] = normalize(pred)
                    ref_dict[question_id] = normalize(ref)
            bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
        else:
            bleu_rouge = None
        return ave_loss, bleu_rouge
コード例 #8
0
def evaluate(args,
             eval_batches,
             vocab,
             sess,
             result_dir=None,
             result_prefix=None):
    """Run the model over ``eval_batches`` and score its predicted answers.

    A ``Model`` is built inside the ``'model'`` variable scope with
    ``reuse=True`` on GPU ``args.gpus[0]``, so the variables it uses must
    already exist in that scope (presumably created by the training graph --
    confirm against caller).

    For every sample, the start/end probability vectors are recorded. The best
    answer span is then chosen with ``find_best_answer``, using a fixed tuple
    of paragraph prior scores. When the sample carries ``'segmented_answers'``,
    its prediction is also paired with the joined reference answers for
    BLEU/ROUGE scoring.

    Args:
        args: run configuration. ``args.gpus`` is read here, and ``args`` is
            also forwarded to ``Model`` and ``find_best_answer``.
        eval_batches: iterable of batch dicts providing ``passage_token_ids``,
            ``question_token_ids``, ``passage_length``, ``question_length``,
            ``start_ids``, ``end_ids``, ``match_scores`` and ``raw_data``.
        vocab: vocabulary passed to ``Model``.
        sess: session used to evaluate the model tensors.
        result_dir: output directory. Files are written only when both this
            and ``result_prefix`` are given.
        result_prefix: filename prefix. Writes ``<prefix>.json`` (one
            prediction JSON object per line) and ``<prefix>probs.json`` (one
            probability JSON object per line).

    Returns:
        A tuple ``(ave_loss, bleu_rouge)``:

        - ``ave_loss`` is the sample-weighted mean loss, or ``nan`` when
          ``eval_batches`` yields no samples.
        - ``bleu_rouge`` is the ``compute_bleu_rouge`` result, or ``None``
          when no sample carries reference answers.
    """
    logger = logging.getLogger("brc")
    total_loss, total_num = 0, 0
    result_prob = []
    pred_answers = []
    # (question_id, predicted answers, reference answers) for samples that
    # carry references. They are recorded together so that a prediction is
    # never scored against another sample's reference when only some samples
    # have 'segmented_answers' (zipping two independently built lists would
    # misalign in that case).
    scored_pairs = []
    with tf.device('/gpu:%s' % args.gpus[0]):
        with tf.variable_scope('model', reuse=True):
            model = Model(vocab, args)
    # Fixed prior weights forwarded to find_best_answer as para_prior_scores.
    pp_scores = (0.43, 0.23, 0.16, 0.10, 0.09)
    for batch_data in eval_batches:
        feed_dict = {
            model.p: batch_data['passage_token_ids'],
            model.q: batch_data['question_token_ids'],
            model.p_length: batch_data['passage_length'],
            model.q_length: batch_data['question_length'],
            model.start_label: batch_data['start_ids'],
            model.end_label: batch_data['end_ids'],
            model.match_score: batch_data['match_scores']
        }
        start_probs, end_probs, loss = sess.run(
            [model.start_probs, model.end_probs, model.loss], feed_dict)
        batch_size = len(batch_data['raw_data'])
        # Weight each batch's loss by its sample count so that the final
        # average is per-sample rather than per-batch.
        total_loss += loss * batch_size
        total_num += batch_size
        padded_p_len = len(batch_data['passage_token_ids'][0])
        for sample, start_prob, end_prob in zip(batch_data['raw_data'],
                                                start_probs, end_probs):
            # Probabilities are stringified so they can be dumped as JSON.
            result_prob.append({
                "question_id": sample['question_id'],
                "start_prob": [str(element) for element in start_prob],
                "end_prob": [str(element) for element in end_prob],
                'padd_len': padded_p_len
            })
            best_answer, segmented_pred = find_best_answer(
                sample,
                start_prob,
                end_prob,
                padded_p_len,
                args,
                para_prior_scores=pp_scores)
            pred_answer = {
                'question_id': sample['question_id'],
                'question_type': sample['question_type'],
                'answers': [best_answer],
                'entity_answers': [[]],
                'yesno_answers': [],
                'segmented_question': sample['segmented_question'],
                'segmented_answers': segmented_pred
            }
            pred_answers.append(pred_answer)
            if 'segmented_answers' in sample:
                # References are stored as token lists; join them back into
                # plain strings before scoring.
                ref_texts = [
                    ''.join(seg_ans) for seg_ans in sample['segmented_answers']
                ]
                scored_pairs.append(
                    (sample['question_id'], pred_answer['answers'], ref_texts))
    if result_dir is not None and result_prefix is not None:
        result_file = os.path.join(result_dir, result_prefix + '.json')
        with open(result_file, 'w') as fout:
            for pred_answer in tqdm(pred_answers):
                fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
        prob_file = os.path.join(result_dir, result_prefix + 'probs.json')
        with open(prob_file, 'w') as f:
            for prob in tqdm(result_prob):
                f.write(json.dumps(prob, ensure_ascii=False) + "\n")
        logger.info('Saving {} results to {}'.format(result_prefix,
                                                     result_file))
        logger.info('Saving {} probabilities to {}'.format(
            result_prefix, prob_file))
    # The loss is presumably meaningless when batches lack true start/end
    # labels (e.g. a test set). Return nan for an empty evaluation set
    # instead of raising ZeroDivisionError.
    ave_loss = 1.0 * total_loss / total_num if total_num else float('nan')
    if scored_pairs:
        pred_dict, ref_dict = {}, {}
        for question_id, pred_texts, ref_texts in scored_pairs:
            # Samples whose reference list is empty are left out of scoring.
            if ref_texts:
                pred_dict[question_id] = normalize(pred_texts)
                ref_dict[question_id] = normalize(ref_texts)
        bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
    else:
        bleu_rouge = None
    return ave_loss, bleu_rouge