Example #1
    def predict(self, model, dataloader, exam, feats, prediction_file, need_sp_logit_file=False):

        model.eval()
        answer_dict = {}
        sp_dict = {}
        # dataloader.refresh()
        total_test_loss = [0] * 5
        # context_idxs, context_mask, segment_idxs,
        # query_mapping, all_mapping,
        # ids, y1, y2, q_type,
        # start_mapping,
        # is_support
        tqdm_obj = tqdm(dataloader, ncols=80)
        for step, batch in enumerate(tqdm_obj):
            batch = tuple(t.to(self.device) for t in batch)
            # batch['context_mask'] = batch['context_mask'].float()
            start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(*batch)
            loss1 = self.criterion(start_logits, batch[6]) + self.criterion(end_logits, batch[7])  # batch[6]=y1, batch[7]=y2
            loss2 = self.config.type_lambda * self.criterion(type_logits, batch[8])  # batch[8]=q_type
            loss3 = self.config.sp_lambda * self.sp_loss_fct(
                sp_logits.view(-1),
                batch[10].float().view(-1)).sum() / batch[9].sum()  # batch[10]=is_support, batch[9]=start_mapping

            loss = loss1 + loss2 + loss3
            loss_list = [loss, loss1, loss2, loss3]

            for i, l in enumerate(loss_list):
                if not isinstance(l, int):  # skip losses left disabled as a plain 0
                    total_test_loss[i] += l.item()

            batchsize = batch[0].size(0)  # batch dimension (unused below)
            # batch[5] holds the example ids
            answer_dict_ = convert_to_tokens(exam, feats, batch[5], start_position.data.cpu().numpy().tolist(),
                                             end_position.data.cpu().numpy().tolist(),
                                             np.argmax(type_logits.data.cpu().numpy(), 1))
            answer_dict.update(answer_dict_)

            predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
            for i in range(predict_support_np.shape[0]):
                cur_sp_pred = []
                cur_id = batch[5][i].item()

                cur_sp_logit_pred = []  # for sp logit output
                for j in range(predict_support_np.shape[1]):
                    if j >= len(exam[cur_id].sent_names):
                        break
                    if need_sp_logit_file:
                        temp_title, temp_id = exam[cur_id].sent_names[j]
                        cur_sp_logit_pred.append((temp_title, temp_id, predict_support_np[i, j]))
                    if predict_support_np[i, j] > self.config.sp_threshold:
                        cur_sp_pred.append(exam[cur_id].sent_names[j])
                sp_dict.update({cur_id: cur_sp_pred})

        new_answer_dict = {}
        for key, value in answer_dict.items():
            new_answer_dict[key] = value.replace(" ", "")
        prediction = {'answer': new_answer_dict, 'sp': sp_dict}
        with open(prediction_file, 'w', encoding='utf8') as f:
            json.dump(prediction, f, indent=4, ensure_ascii=False)

        for i, l in enumerate(total_test_loss):
            print("Test Loss{}: {}".format(i, l / len(dataloader)))
Example #2
def predict(model,
            dataloader,
            example_dict,
            feature_dict,
            prediction_file,
            need_sp_logit_file=False):

    model.eval()
    answer_dict = {}
    sp_dict = {}
    dataloader.refresh()
    total_test_loss = [0] * 5

    for batch in tqdm(dataloader):

        batch['context_mask'] = batch['context_mask'].float()
        start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(
            batch)

        loss_list = compute_loss(batch, start_logits, end_logits, type_logits,
                                 sp_logits, start_position, end_position)

        for i, l in enumerate(loss_list):
            if not isinstance(l, int):
                total_test_loss[i] += l.item()

        answer_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            start_position.data.cpu().numpy().tolist(),
            end_position.data.cpu().numpy().tolist(),
            np.argmax(type_logits.data.cpu().numpy(), 1))
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = []
            cur_id = batch['ids'][i]

            cur_sp_logit_pred = []  # for sp logit output
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                if need_sp_logit_file:
                    temp_title, temp_id = example_dict[cur_id].sent_names[j]
                    cur_sp_logit_pred.append(
                        (temp_title, temp_id, predict_support_np[i, j]))
                if predict_support_np[i, j] > args.sp_threshold:  # args: module-level config namespace
                    cur_sp_pred.append(example_dict[cur_id].sent_names[j])
            sp_dict.update({cur_id: cur_sp_pred})

    new_answer_dict = {}
    for key, value in answer_dict.items():
        new_answer_dict[key] = value.replace(" ", "")
    prediction = {'answer': new_answer_dict, 'sp': sp_dict}
    with open(prediction_file, 'w', encoding='utf8') as f:
        json.dump(prediction, f, indent=4, ensure_ascii=False)

    for i, l in enumerate(total_test_loss):
        print("Test Loss {}: {}".format(i, l / len(dataloader)))
    # test_loss_record is a module-level list defined elsewhere
    test_loss_record.append(sum(total_test_loss[:3]) / len(dataloader))
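Example #2 calls a compute_loss helper that is not shown. A minimal sketch consistent with the inline loss computation in Example #1, assuming dict-style batch keys matching Example #1's layout comment and module-level criterion, sp_loss_fct, and args objects (all of these names are assumptions, not confirmed by the source):

def compute_loss(batch, start_logits, end_logits, type_logits,
                 sp_logits, start_position, end_position):
    # Sketch only: mirrors loss1/loss2/loss3 from Example #1. The dict keys
    # and the module-level criterion, sp_loss_fct, and args are assumptions.
    loss1 = criterion(start_logits, batch['y1']) + criterion(end_logits, batch['y2'])
    loss2 = args.type_lambda * criterion(type_logits, batch['q_type'])
    loss3 = args.sp_lambda * sp_loss_fct(
        sp_logits.view(-1),
        batch['is_support'].float().view(-1)).sum() / batch['start_mapping'].sum()
    loss = loss1 + loss2 + loss3
    return [loss, loss1, loss2, loss3]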
Example #3
    def bert_classification(self, content, question):
        logger.info('content: {}'.format(content))
        conv_dic = {
            '_id': 0,
            'context': self.process_context(content),
            'question': question,
            'answer': '',
            'supporting_facts': [],
        }
        rows = [conv_dic]
        filename = "data/{}.json".format(time.time())
        with open(filename, 'w', encoding='utf8') as fw:
            json.dump(rows, fw, ensure_ascii=False, indent=4)

        exam, feats, dataset = self.data.load_file(filename, False)

        data_loader = DataLoader(dataset, batch_size=self.config.batch_size)

        self.model.eval()
        answer_dict = {}
        sp_dict = {}
        tqdm_obj = tqdm(data_loader, ncols=80)
        for step, batch in enumerate(tqdm_obj):
            batch = tuple(t.to(self.device) for t in batch)
            start_logits, end_logits, type_logits, sp_logits, start_position, end_position = self.model(*batch)

            batchsize = batch[0].size(0)  # batch dimension (unused below)
            # batch[5] holds the example ids
            answer_dict_ = convert_to_tokens(exam, feats, batch[5], start_position.data.cpu().numpy().tolist(),
                                             end_position.data.cpu().numpy().tolist(),
                                             np.argmax(type_logits.data.cpu().numpy(), 1))
            answer_dict.update(answer_dict_)

            predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
            for i in range(predict_support_np.shape[0]):
                cur_sp_pred = []
                cur_id = batch[5][i].item()

                cur_sp_logit_pred = []  # kept for parity with predict(); unused here
                for j in range(predict_support_np.shape[1]):
                    if j >= len(exam[cur_id].sent_names):
                        break

                    if predict_support_np[i, j] > self.config.sp_threshold:
                        cur_sp_pred.append(exam[cur_id].sent_names[j])
                sp_dict.update({cur_id: cur_sp_pred})

        new_answer_dict = {}
        for key, value in answer_dict.items():
            new_answer_dict[key] = value.replace(" ", "")
        prediction = {'answer': new_answer_dict, 'sp': sp_dict}

        return {"data": prediction}
Example #4
        # loss2 = self.config.type_lambda * self.criterion(type_logits, batch[8])  # q_type
        # # sp_value = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum()
        # # loss3 = self.config.sp_lambda * sp_value / batch[9].sum()
        #
        # loss = loss1 + loss2
        # loss_list = [loss, loss1, loss2]
        #
        # for i, l in enumerate(loss_list):
        #     if not isinstance(l, int):
        #         total_test_loss[i] += l.item()

        batchsize = batch[0].size(0)  # batch dimension (unused below)
        # batch[5] holds the example ids
        answer_dict_ = convert_to_tokens(
            exam, feats, batch[5],
            start_position.data.cpu().numpy().tolist(),
            end_position.data.cpu().numpy().tolist(),
            np.argmax(type_logits.data.cpu().numpy(), 1))
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = []
            cur_id = batch[5][i].item()

            cur_sp_logit_pred = []  # for sp logit output
            for j in range(predict_support_np.shape[1]):
                if j >= len(exam[cur_id].sent_names):
                    break

                if predict_support_np[i, j] > config.sp_threshold:
                    cur_sp_pred.append(exam[cur_id].sent_names[j])
            sp_dict.update({cur_id: cur_sp_pred})