示例#1
0
    def vote_entities(self, model_inputs, sent, id2ent, threshold):
        """Decode ``sent`` with every ensemble member and merge the results by voting.

        Each model in ``self.models`` is called with ``model_inputs`` unpacked as
        keyword arguments; element ``[0][0]`` of its output is passed to
        ``crf_decode`` together with ``sent`` and ``id2ent``. The per-model
        results are collected in model order and handed to ``vote`` along with
        ``threshold``.

        Returns:
            Whatever ``vote`` returns for the collected per-model entities.
        """
        per_model_entities = [
            crf_decode(member(**model_inputs)[0][0], sent, id2ent)
            for member in self.models
        ]
        return vote(per_model_entities, threshold)
示例#2
0
def base_predict(model, device, info_dict, ensemble=False, mixed=""):
    """Predict entities for every example and return them as text-level spans.

    Each example's text is split with ``cut_sent``; every piece is tokenized,
    run through ``model`` and decoded, and the decoded entity offsets are
    shifted from piece-local to offsets within the full example text.

    Args:
        model: A single model (called as ``model(**model_inputs)``) or, when
            ``ensemble`` is true, an object exposing ``predict`` and
            ``vote_entities``.
        device: Device the input tensors are moved to before the forward pass.
        info_dict: Mapping with the keys ``"tokenizer"``, ``"id2ent"`` and
            ``"examples"``; each example is a mapping with ``"id"`` and
            ``"text"``.
        ensemble: Treat ``model`` as an ensemble wrapper.
        mixed: Optional task type for a single model; ``"crf"`` selects CRF
            decoding, any other non-empty value selects span decoding. When
            empty, the global ``TASK_TYPE`` decides. Ignored when ``ensemble``
            is true.

    Returns:
        A ``defaultdict(list)`` mapping each example id to a list of
        ``(entity_type, start, end, entity_text)`` tuples, where
        ``text[start:end] == entity_text``. Every example gets an entry, even
        when nothing was found.

    Raises:
        AssertionError: If a decoded entity does not match the example text at
            its computed offsets.
    """
    labels = defaultdict(list)

    tokenizer = info_dict["tokenizer"]
    id2ent = info_dict["id2ent"]

    with torch.no_grad():
        for _ex in info_dict["examples"]:
            ex_idx = _ex["id"]
            raw_text = _ex["text"]

            if not raw_text:
                labels[ex_idx] = []
                print("{}为空".format(ex_idx))
                continue

            # Guarantee an entry for this example even if the splitter yields
            # no pieces or no entity is decoded.
            labels.setdefault(ex_idx, [])

            # Offset of the current piece inside raw_text.
            start_index = 0

            for sent in cut_sent(raw_text, MAX_SEQ_LEN):
                sent_tokens = fine_grade_tokenize(sent, tokenizer)

                # NOTE(review): is_pretokenized / pad_to_max_length look like
                # legacy tokenizer argument names -- confirm against the
                # pinned transformers version before upgrading.
                encode_dict = tokenizer.encode_plus(
                    text=sent_tokens,
                    max_length=MAX_SEQ_LEN,
                    is_pretokenized=True,
                    pad_to_max_length=False,
                    return_tensors="pt",
                    return_token_type_ids=True,
                    return_attention_mask=True,
                )

                model_inputs = {
                    "token_ids": encode_dict["input_ids"].to(device),
                    "attention_masks": encode_dict["attention_mask"].to(device),
                    "token_type_ids": encode_dict["token_type_ids"].to(device),
                }

                if ensemble and VOTE:
                    # The ensemble decodes and merges its members' outputs itself.
                    decode_entities = model.vote_entities(
                        model_inputs, sent, id2ent, THRESHOLD
                    )
                else:
                    # `mixed` may override the global task type, but only for
                    # a single (non-ensemble) model.
                    task_type = TASK_TYPE if ensemble else (mixed or TASK_TYPE)

                    if task_type == "crf":
                        if ensemble:
                            pred_tokens = model.predict(model_inputs)[0]
                        else:
                            pred_tokens = model(**model_inputs)[0][0]
                        decode_entities = crf_decode(pred_tokens, sent, id2ent)
                    else:
                        if ensemble:
                            start_logits, end_logits = model.predict(model_inputs)
                        else:
                            start_logits, end_logits = model(**model_inputs)

                        # Skip the first position (presumably the [CLS] token
                        # -- confirm) and keep len(sent) positions.
                        sent_span = slice(1, 1 + len(sent))
                        start_logits = start_logits[0].cpu().numpy()[sent_span]
                        end_logits = end_logits[0].cpu().numpy()[sent_span]

                        decode_entities = span_decode(
                            start_logits, end_logits, sent, id2ent
                        )

                for _ent_type in decode_entities:
                    for _ent in decode_entities[_ent_type]:
                        # _ent[0] is the entity text, _ent[1] its offset in `sent`.
                        tmp_start = _ent[1] + start_index
                        tmp_end = tmp_start + len(_ent[0])

                        assert raw_text[tmp_start:tmp_end] == _ent[0], (
                            "entity {!r} does not match text[{}:{}] of example {}"
                            .format(_ent[0], tmp_start, tmp_end, ex_idx)
                        )

                        labels[ex_idx].append((_ent_type, tmp_start, tmp_end, _ent[0]))

                start_index += len(sent)

    return labels
示例#3
0
def base_predict(model, device, info_dict, ensemble=False, mixed=""):
    """Predict entities for every example and return them with global offsets.

    Every element obtained by iterating an example's ``"text"`` is tokenized,
    padded to ``args.max_seq_len``, run through ``model`` and decoded. Decoded
    entity offsets are shifted by the accumulated length of the preceding
    elements.

    Args:
        model: A single model (called as ``model(**model_inputs)``) or, when
            ``ensemble`` is true, an object exposing ``predict`` and
            ``vote_entities``.
        device: Device the input tensors are moved to before the forward pass.
        info_dict: Mapping with the keys ``"tokenizer"``, ``"id2ent"`` and
            ``"examples"``; each example is a mapping with ``"id"`` and
            ``"text"``.
        ensemble: Treat ``model`` as an ensemble wrapper.
        mixed: Optional task type for a single model; ``"crf"`` selects CRF
            decoding, any other non-empty value selects span decoding. When
            empty, ``args.task_type`` decides. Ignored when ``ensemble`` is
            true.

    Returns:
        A ``defaultdict(list)`` mapping each example id to a list of
        ``(entity_type, start, end, entity_text)`` tuples. Every example gets
        an entry, even when nothing was found. The spans are not validated
        against the source text.
    """
    labels = defaultdict(list)

    tokenizer = info_dict["tokenizer"]
    id2ent = info_dict["id2ent"]

    with torch.no_grad():
        for _ex in info_dict["examples"]:
            ex_idx = _ex["id"]
            raw_text = _ex["text"]

            if not raw_text:
                labels[ex_idx] = []
                print("{}为空".format(ex_idx))
                continue

            # Guarantee an entry for this example even if no entity is decoded.
            labels.setdefault(ex_idx, [])

            # NOTE(review): the text is iterated directly, without a sentence
            # splitter. Presumably `_ex["text"]` is already a sequence of
            # sentences; if it is a plain string this loops character by
            # character -- confirm against the caller.
            sentences = raw_text

            # Accumulated length of the elements processed so far.
            start_index = 0

            for sent in sentences:
                sent_tokens = fine_grade_tokenize(sent, tokenizer)

                # TODO: padding was previously disabled here; confirm that
                # padding every input to max_length is really intended.
                encode_dict = tokenizer.encode_plus(
                    text=sent_tokens,
                    max_length=args.max_seq_len,
                    padding="max_length",
                    is_split_into_words=True,
                    return_tensors="pt",
                    return_token_type_ids=True,
                    return_attention_mask=True,
                )

                model_inputs = {
                    "token_ids": encode_dict["input_ids"].to(device),
                    "attention_masks": encode_dict["attention_mask"].to(device),
                    "token_type_ids": encode_dict["token_type_ids"].to(device),
                }

                if ensemble and VOTE:
                    # The ensemble decodes and merges its members' outputs itself.
                    decode_entities = model.vote_entities(
                        model_inputs, sent, id2ent, THRESHOLD)
                else:
                    # `mixed` may override the configured task type, but only
                    # for a single (non-ensemble) model.
                    task_type = (
                        args.task_type if ensemble
                        else (mixed or args.task_type))

                    if task_type == "crf":
                        if ensemble:
                            pred_tokens = model.predict(model_inputs)[0]
                        else:
                            pred_tokens = model(**model_inputs)[0][0]
                        decode_entities = crf_decode(
                            pred_tokens, sent, id2ent)
                    else:
                        if ensemble:
                            start_logits, end_logits = model.predict(
                                model_inputs)
                        else:
                            start_logits, end_logits = model(**model_inputs)

                        # Skip the first position (presumably the [CLS] token
                        # -- confirm) and keep len(sent) positions, which also
                        # drops the padding.
                        sent_span = slice(1, 1 + len(sent))
                        start_logits = start_logits[0].cpu().numpy()[sent_span]
                        end_logits = end_logits[0].cpu().numpy()[sent_span]

                        decode_entities = span_decode(
                            start_logits, end_logits, sent, id2ent)

                for _ent_type in decode_entities:
                    for _ent in decode_entities[_ent_type]:
                        # _ent[0] is the entity text, _ent[1] its offset in `sent`.
                        tmp_start = _ent[1] + start_index
                        tmp_end = tmp_start + len(_ent[0])

                        labels[ex_idx].append(
                            (_ent_type, tmp_start, tmp_end, _ent[0]))

                start_index += len(sent)

    return labels