示例#1
0
def infer(model,dev_load,opt,device,ent2id):
    """Run inference over ``dev_load`` and write one prediction line per text.

    The source texts are read from ``./tcdata/final_test.txt``. When
    ``opt.cv_num != -1`` and ``opt.cv_infer`` is falsy, the lines of
    ``./data/raw_data/final_test.txt`` are appended to them; that second file
    is only opened in this case.

    Each output line is ``<text>\\u0001<space-joined predictions>``. Results go
    to ``./result.txt`` when ``opt.cv_num == -1`` and to
    ``./cv_tmp/temp_result_<cv_num>`` otherwise. The output file is opened
    only after decoding has finished and is always closed, even on error.

    Args:
        model: callable taking the batch tensors as keyword arguments; must
            also provide ``eval()``.
        dev_load: iterable of batch dicts. Each batch holds a ``'raw_text'``
            entry, optionally a ``'labels'`` entry (ignored here), and
            tensor-like values exposing ``.to(device)``. Batches are not
            modified.
        opt: options object providing ``cv_num``, ``task_type`` (``'crf'`` or
            ``'span'``) and, when ``cv_num != -1``, ``cv_infer``.
        device: target passed to ``.to()`` for every model input.
        ent2id: mapping from entity label to integer id.
    """
    if opt.cv_num == -1:
        result_path = './result.txt'
    else:
        result_path = './cv_tmp/temp_result_{}'.format(opt.cv_num)

    with open(file='./tcdata/final_test.txt', mode='r', encoding='utf-8') as src:
        raw_texts = [line.strip() for line in src]

    # The second corpus is only needed for cross-validation runs that are not
    # in cv_infer mode, so it is not touched otherwise.
    if opt.cv_num != -1 and not opt.cv_infer:
        with open(file='./data/raw_data/final_test.txt', mode='r', encoding='utf-8') as src:
            raw_texts += [line.strip() for line in src]

    id2ent = {idx: label for label, idx in ent2id.items()}
    is_crf = opt.task_type == 'crf'
    is_span = opt.task_type == 'span'

    model.eval()
    decode_output = []
    with torch.no_grad():
        for batch_data in dev_load:
            batch_texts = batch_data['raw_text']
            # Build a fresh dict of model inputs instead of deleting keys from
            # (and overwriting tensors in) the caller's batch.
            model_inputs = {
                key: value.to(device)
                for key, value in batch_data.items()
                if key not in ('raw_text', 'labels')
            }

            if is_crf:
                tag_paths = model(**model_inputs)[0]
                # Drop the first and last tag of every decoded path
                # (presumably the special-token positions - confirm).
                decode_output.extend(path[1:-1] for path in tag_paths)
            elif is_span:
                outputs = model(**model_inputs)
                start_logits = outputs[0].cpu().numpy()
                end_logits = outputs[1].cpu().numpy()
                for text_start, text_end, text in zip(start_logits, end_logits, batch_texts):
                    # Skip position 0 and keep one position per character.
                    text_start = text_start[1:1 + len(text)]
                    text_end = text_end[1:1 + len(text)]
                    decode_output.append(span_decode(text_start, text_end, text, id2ent))

    # NOTE: zip() stops at the shorter sequence, so texts without a matching
    # prediction (or the reverse) are silently skipped, as before.
    with open(file=result_path, mode='w', encoding='utf-8') as out:
        for text, decode in zip(raw_texts, decode_output):
            # CRF predictions are label ids; span predictions are written as-is.
            tmp_decode_output = " ".join(id2ent[x] if is_crf else x for x in decode)
            out.write('{}\n'.format('\u0001'.join([text, tmp_decode_output])))
示例#2
0
    def vote_entities(self, model_inputs, sent, id2ent, threshold):
        """Decode span entities with every model in ``self.models`` and vote.

        Each model is called with ``model_inputs`` as keyword arguments and
        must return a ``(start_logits, end_logits)`` pair. Only the first
        batch item is used. The per-model ``span_decode`` results are
        collected in model order and handed to ``vote`` together with
        ``threshold``; its return value is returned unchanged.
        """
        per_model_entities = []

        for member in self.models:
            start_out, end_out = member(**model_inputs)

            # Skip position 0 and keep len(sent) positions
            # (presumably dropping the leading special token - confirm).
            start_scores = start_out[0].cpu().numpy()[1:1 + len(sent)]
            end_scores = end_out[0].cpu().numpy()[1:1 + len(sent)]

            per_model_entities.append(
                span_decode(start_scores, end_scores, sent, id2ent)
            )

        return vote(per_model_entities, threshold)
示例#3
0
def _span_entities(start_logits, end_logits, sent, id2ent):
    """Run ``span_decode`` on the first batch item's logits, trimmed to ``sent``.

    Position 0 is skipped and the following ``len(sent)`` positions are kept
    (presumably dropping the leading special token - confirm against the
    tokenizer in use).
    """
    start_scores = start_logits[0].cpu().numpy()[1 : 1 + len(sent)]
    end_scores = end_logits[0].cpu().numpy()[1 : 1 + len(sent)]
    return span_decode(start_scores, end_scores, sent, id2ent)


def base_predict(model, device, info_dict, ensemble=False, mixed=""):
    """Predict entities for every example in ``info_dict["examples"]``.

    Each example text is split with ``cut_sent``, every piece is encoded and
    decoded separately, and the entity offsets are shifted back into the
    coordinates of the full text.

    Args:
        model: a single model (called as ``model(**inputs)``) or, when
            ``ensemble`` is true, an object exposing ``predict(inputs)`` and
            ``vote_entities(inputs, sent, id2ent, threshold)``.
        device: target passed to ``.to()`` for every input tensor.
        info_dict: dict with the keys ``"tokenizer"``, ``"id2ent"`` and
            ``"examples"`` (an iterable of dicts with ``"id"`` and ``"text"``).
        ensemble: use the ensemble interface of ``model``.
        mixed: for a single model, overrides the global ``TASK_TYPE`` when
            non-empty; ``"crf"`` selects CRF decoding, any other non-empty
            value selects span decoding.

    Returns:
        ``defaultdict(list)`` mapping example id to a list of
        ``(entity_type, start, end, entity_text)`` tuples. Examples with
        empty text or without entities map to an empty list.

    Raises:
        AssertionError: if a decoded entity does not match the text at its
            computed offsets.
    """
    labels = defaultdict(list)

    tokenizer = info_dict["tokenizer"]
    id2ent = info_dict["id2ent"]

    with torch.no_grad():
        for _ex in info_dict["examples"]:
            ex_idx = _ex["id"]
            raw_text = _ex["text"]

            if not raw_text:
                labels[ex_idx] = []
                print("{}为空".format(ex_idx))
                continue

            # The lookup registers the example in the result even when no
            # entity is found for it.
            ex_labels = labels[ex_idx]

            # Offset of the current sentence inside raw_text.
            start_index = 0

            for sent in cut_sent(raw_text, MAX_SEQ_LEN):
                sent_tokens = fine_grade_tokenize(sent, tokenizer)

                encode_dict = tokenizer.encode_plus(
                    text=sent_tokens,
                    max_length=MAX_SEQ_LEN,
                    is_pretokenized=True,
                    pad_to_max_length=False,
                    return_tensors="pt",
                    return_token_type_ids=True,
                    return_attention_mask=True,
                )

                model_inputs = {
                    "token_ids": encode_dict["input_ids"].to(device),
                    "attention_masks": encode_dict["attention_mask"].to(device),
                    "token_type_ids": encode_dict["token_type_ids"].to(device),
                }

                if ensemble:
                    use_crf = TASK_TYPE == "crf"
                    if VOTE:
                        # Voting is delegated to the ensemble for both task types.
                        decode_entities = model.vote_entities(
                            model_inputs, sent, id2ent, THRESHOLD
                        )
                    elif use_crf:
                        pred_tokens = model.predict(model_inputs)[0]
                        decode_entities = crf_decode(pred_tokens, sent, id2ent)
                    else:
                        start_logits, end_logits = model.predict(model_inputs)
                        decode_entities = _span_entities(
                            start_logits, end_logits, sent, id2ent
                        )
                elif (mixed or TASK_TYPE) == "crf":
                    # A non-empty `mixed` takes precedence over TASK_TYPE.
                    pred_tokens = model(**model_inputs)[0][0]
                    decode_entities = crf_decode(pred_tokens, sent, id2ent)
                else:
                    start_logits, end_logits = model(**model_inputs)
                    decode_entities = _span_entities(
                        start_logits, end_logits, sent, id2ent
                    )

                for _ent_type in decode_entities:
                    for _ent in decode_entities[_ent_type]:
                        # _ent is (entity_text, offset_in_sentence, ...).
                        tmp_start = _ent[1] + start_index
                        tmp_end = tmp_start + len(_ent[0])

                        # Internal consistency check of the offset arithmetic.
                        assert raw_text[tmp_start:tmp_end] == _ent[0], (
                            "entity text does not match raw_text at its offsets"
                        )

                        ex_labels.append((_ent_type, tmp_start, tmp_end, _ent[0]))

                start_index += len(sent)

    return labels
示例#4
0
def _span_entities(start_logits, end_logits, sent, id2ent):
    """Run ``span_decode`` on the first batch item's logits, trimmed to ``sent``.

    Position 0 is skipped and the following ``len(sent)`` positions are kept
    (presumably dropping the leading special token - confirm against the
    tokenizer in use).
    """
    start_scores = start_logits[0].cpu().numpy()[1:1 + len(sent)]
    end_scores = end_logits[0].cpu().numpy()[1:1 + len(sent)]
    return span_decode(start_scores, end_scores, sent, id2ent)


def base_predict(model, device, info_dict, ensemble=False, mixed=""):
    """Predict entities for every example in ``info_dict["examples"]``.

    Every element obtained by iterating an example's ``"text"`` is encoded
    and decoded on its own; entity offsets are shifted by the summed length
    of the preceding elements.

    Args:
        model: a single model (called as ``model(**inputs)``) or, when
            ``ensemble`` is true, an object exposing ``predict(inputs)`` and
            ``vote_entities(inputs, sent, id2ent, threshold)``.
        device: target passed to ``.to()`` for every input tensor.
        info_dict: dict with the keys ``"tokenizer"``, ``"id2ent"`` and
            ``"examples"`` (an iterable of dicts with ``"id"`` and ``"text"``).
        ensemble: use the ensemble interface of ``model``.
        mixed: for a single model, overrides ``args.task_type`` when
            non-empty; ``"crf"`` selects CRF decoding, any other non-empty
            value selects span decoding.

    Returns:
        ``defaultdict(list)`` mapping example id to a list of
        ``(entity_type, start, end, entity_text)`` tuples. Examples with
        empty text or without entities map to an empty list.
    """
    labels = defaultdict(list)

    tokenizer = info_dict["tokenizer"]
    id2ent = info_dict["id2ent"]

    with torch.no_grad():
        for _ex in info_dict["examples"]:
            ex_idx = _ex["id"]
            raw_text = _ex["text"]

            if not raw_text:
                labels[ex_idx] = []
                print("{}为空".format(ex_idx))
                continue

            # The lookup registers the example in the result even when no
            # entity is found for it.
            ex_labels = labels[ex_idx]

            # Running offset of the current element inside the example.
            start_index = 0

            # NOTE(review): raw_text is iterated directly. This presumably
            # expects "text" to be a sequence of sentences; a plain string
            # would be processed one character at a time - confirm with the
            # code that builds info_dict.
            for sent in raw_text:
                sent_tokens = fine_grade_tokenize(sent, tokenizer)

                encode_dict = tokenizer.encode_plus(
                    text=sent_tokens,
                    max_length=args.max_seq_len,
                    # TODO: find out why padding used to be disabled here.
                    padding="max_length",
                    is_split_into_words=True,
                    return_tensors="pt",
                    return_token_type_ids=True,
                    return_attention_mask=True,
                )

                model_inputs = {
                    "token_ids": encode_dict["input_ids"].to(device),
                    "attention_masks": encode_dict["attention_mask"].to(device),
                    "token_type_ids": encode_dict["token_type_ids"].to(device),
                }

                if ensemble:
                    use_crf = args.task_type == "crf"
                    if VOTE:
                        # Voting is delegated to the ensemble for both task types.
                        decode_entities = model.vote_entities(
                            model_inputs, sent, id2ent, THRESHOLD)
                    elif use_crf:
                        pred_tokens = model.predict(model_inputs)[0]
                        decode_entities = crf_decode(pred_tokens, sent, id2ent)
                    else:
                        start_logits, end_logits = model.predict(model_inputs)
                        decode_entities = _span_entities(
                            start_logits, end_logits, sent, id2ent)
                elif (mixed or args.task_type) == "crf":
                    # A non-empty `mixed` takes precedence over args.task_type.
                    pred_tokens = model(**model_inputs)[0][0]
                    decode_entities = crf_decode(pred_tokens, sent, id2ent)
                else:
                    start_logits, end_logits = model(**model_inputs)
                    decode_entities = _span_entities(
                        start_logits, end_logits, sent, id2ent)

                for _ent_type in decode_entities:
                    for _ent in decode_entities[_ent_type]:
                        # _ent is (entity_text, offset_in_sentence, ...).
                        # NOTE(review): the offsets are not verified against
                        # the source text in this function.
                        tmp_start = _ent[1] + start_index
                        tmp_end = tmp_start + len(_ent[0])

                        ex_labels.append(
                            (_ent_type, tmp_start, tmp_end, _ent[0]))

                start_index += len(sent)

    return labels