Example #1
def handle_normal_dataset(dataset, ignore_subword_match=False):
    """
    if ignore_subword_match is true, find entities with whitespace around, e.g. "entity" -> " entity "
    """
    # load the preprocessor
    if config["encoder"] == "BERT":
        tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"],
                                                      add_special_tokens=False,
                                                      do_lower_case=False)
        tokenize = tokenizer.tokenize
        get_tok2char_span_map = lambda text: tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False)[
                "offset_mapping"]
    elif config["encoder"] == "BiLSTM":
        tokenize = lambda text: text.split(" ")

        def get_tok2char_span_map(text):
            tokens = tokenize(text)
            tok2char_span = []
            char_num = 0
            for tok in tokens:
                tok2char_span.append((char_num, char_num + len(tok)))
                char_num += len(tok) + 1  # +1: whitespace
            return tok2char_span

    preprocessor = Preprocessor(
        tokenize_func=tokenize,
        get_tok2char_span_map_func=get_tok2char_span_map)
    # add char span
    dataset, miss_sample_list = preprocessor.add_char_span(
        dataset, ignore_subword_match=ignore_subword_match)

    if len(miss_sample_list) > 0:
        print("=========存在不匹配实体,请检查===========")
        print(miss_sample_list)
        print("========================================")

    # add token span
    dataset = preprocessor.add_tok_span(dataset)

    return dataset
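
For reference, the offset-mapping lambda above relies on the fast tokenizer's
return_offsets_mapping flag. A minimal self-contained sketch of what it
returns (the checkpoint name "bert-base-cased" and the sample sentence are
illustrative assumptions):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
enc = tokenizer.encode_plus("TPLinker extracts triples",
                            return_offsets_mapping=True,
                            add_special_tokens=False)
# one (char_start, char_end) pair per subword token
print(enc["offset_mapping"])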
Example #2
    def __init__(self, config, model_path, rel2id_path):
        self.config = config
        self.hyper_parameters = config["hyper_parameters"]

        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        torch.backends.cudnn.deterministic = True

        rel2id = json.load(open(rel2id_path, "r", encoding="utf-8"))
        self.handshaking_tagger = HandshakingTaggingScheme(
            rel2id=rel2id,
            max_seq_len=self.hyper_parameters["max_test_seq_len"])

        tokenizer = BertTokenizerFast.from_pretrained(config["bert_path"],
                                                      add_special_tokens=False,
                                                      do_lower_case=False)
        self.data_maker = DataMaker4Bert(tokenizer, self.handshaking_tagger)
        tokenize = tokenizer.tokenize
        get_tok2char_span_map = lambda text: tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False)[
                "offset_mapping"]
        self.preprocessor = Preprocessor(
            tokenize_func=tokenize,
            get_tok2char_span_map_func=get_tok2char_span_map)

        # encoder backbone loaded from bert_path (BERT or RoBERTa weights)
        encoder = AutoModel.from_pretrained(config["bert_path"])
        self.model = TPLinkerBert(
            encoder,
            len(rel2id),
            self.hyper_parameters["shaking_type"],
            self.hyper_parameters["inner_enc_type"],
            self.hyper_parameters["dist_emb_size"],
            self.hyper_parameters["ent_add_dist"],
            self.hyper_parameters["rel_add_dist"],
        ).to(self.device)
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        self.model.eval()
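
For context, the HandshakingTaggingScheme built above tags every token pair
(i, j) with i <= j, so a sequence of n tokens yields n * (n + 1) / 2 pair
slots. A quick self-contained check of that count in plain Python:

n = 4
pairs = [(i, j) for i in range(n) for j in range(i, n)]
print(len(pairs))  # 10 == 4 * 5 // 2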
Example #3
if config["encoder"] in {"BiLSTM", }:  # the preceding BERT branch is elided in this snippet
    tokenize = lambda text: text.split(" ")

    def get_tok2char_span_map(text):
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1  # +1: whitespace
        return tok2char_span
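
# In[ ]:

# Quick self-contained check of the whitespace span mapping defined above
# (the sample sentence is made up): each span slices its token back out of
# the original string.
text = "New York is big"
spans = get_tok2char_span_map(text)
for tok, (s, e) in zip(text.split(" "), spans):
    assert text[s:e] == tok
print(spans)  # [(0, 3), (4, 8), (9, 11), (12, 15)]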


# In[ ]:

preprocessor = Preprocessor(tokenize_func=tokenize,
                            get_tok2char_span_map_func=get_tok2char_span_map)

# In[ ]:

# train and valid max token num
max_tok_num = 0
all_data = train_data + valid_data

for sample in all_data:
    tokens = tokenize(sample["text"])
    max_tok_num = max(max_tok_num, len(tokens))
max_tok_num

# In[ ]:

if max_tok_num > hyper_parameters["max_seq_len"]:
    ...  # body truncated in the original snippet

if config["encoder"] in {"BiLSTM", }:  # the preceding BERT branch is elided in this snippet
    tokenize = lambda text: text.split(" ")

    def get_tok2char_span_map(text):
        tokens = text.split(" ")
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1  # +1: whitespace
        return tok2char_span


# In[ ]:

preprocessor = Preprocessor(tokenize_func=tokenize,
                            get_tok2char_span_map_func=get_tok2char_span_map)

# In[ ]:

all_data = []
for data in list(test_data_dict.values()):
    all_data.extend(data)

max_tok_num = 0
for sample in tqdm(all_data, desc="Calculate the max token number"):
    tokens = tokenize(sample["text"])
    max_tok_num = max(len(tokens), max_tok_num)

# In[ ]:

split_test_data = False
def tplinker_predict(config, test_data_path, model_state_path):

    config = config.eval_config
    hyper_parameters = config["hyper_parameters"]

    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config["device_num"])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_home = config["data_home"]
    experiment_name = config["exp_name"]
    # test_data_path = os.path.join(data_home, experiment_name, config["test_data"])
    batch_size = hyper_parameters["batch_size"]
    rel2id_path = os.path.join(data_home, experiment_name, config["rel2id"])
    save_res_dir = os.path.join(config["save_res_dir"], experiment_name)
    max_test_seq_len = hyper_parameters["max_test_seq_len"]
    sliding_len = hyper_parameters["sliding_len"]
    force_split = hyper_parameters["force_split"]
    # for reproducibility
    torch.backends.cudnn.deterministic = True
    if force_split:
        global split_test_data  # update the module-level flag, not a local
        split_test_data = True
        print("force to split the test dataset!")
    # read test data
    test_data = json.load(open(test_data_path, "r", encoding="utf-8"))
    # get tokenizer
    tokenize, get_tok2char_span_map = get_tokenizer(config["encoder"],
                                                    config["bert_path"])
    # get data token num
    max_tok_num = get_token_num(test_data, tokenize)
    max_seq_len = min(max_tok_num, max_test_seq_len)

    # data preprocessor
    preprocessor = Preprocessor(
        tokenize_func=tokenize,
        get_tok2char_span_map_func=get_tok2char_span_map)
    data = preprocessor.split_into_short_samples(test_data,
                                                 max_seq_len,
                                                 sliding_len=sliding_len,
                                                 encoder=config["encoder"],
                                                 data_type="test")

    rel2id = json.load(open(rel2id_path, "r", encoding="utf-8"))
    handshaking_tagger = HandshakingTaggingScheme(rel2id=rel2id,
                                                  max_seq_len=max_seq_len)
    metrics = MetricsCalculator(handshaking_tagger)

    # get data maker and model
    if config["encoder"] == "BERT":
        data_maker = get_data_bert_data_maker(config["bert_path"],
                                              handshaking_tagger)
        rel_extractor = get_tplinker_bert_model(config["bert_path"], rel2id,
                                                hyper_parameters)
    elif config["encoder"] == "BiLSTM":
        token2idx_path = os.path.join(data_home, experiment_name,
                                      config["token2idx"])
        data_maker, token2idx = get_data_bilstm_data_maker(
            token2idx_path, handshaking_tagger)
        rel_extractor = get_tplinker_lstm_model(token2idx, hyper_parameters,
                                                rel2id)

    # load model
    rel_extractor.load_state_dict(
        torch.load(model_state_path, map_location=torch.device('cpu')))
    rel_extractor.eval()

    result = predict(config, data, data_maker, max_seq_len, batch_size, device,
                     rel_extractor, True, handshaking_tagger)

    with open("./results/nyt_demo/predict_result.json", "w",
              encoding="utf-8") as f:
        f.write(json.dumps(result, ensure_ascii=False, indent=2))
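
# In[ ]:

# split_into_short_samples above slides a fixed-size window over long inputs.
# A self-contained sketch of the windowing idea (illustrative only, not the
# library's actual implementation):
def sliding_spans(num_tokens, max_seq_len, sliding_len):
    start = 0
    while True:
        yield (start, min(start + max_seq_len, num_tokens))
        if start + max_seq_len >= num_tokens:
            break
        start += sliding_len

print(list(sliding_spans(num_tokens=10, max_seq_len=4, sliding_len=2)))
# [(0, 4), (2, 6), (4, 8), (6, 10)]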
elif config["encoder"] == "BiLSTM":
    tokenize = lambda text: text.split(" ")

    def get_tok2char_span_map(text):
        tokens = tokenize(text)
        tok2char_span = []
        char_num = 0
        for tok in tokens:
            tok2char_span.append((char_num, char_num + len(tok)))
            char_num += len(tok) + 1  # +1: whitespace
        return tok2char_span


# In[7]:

preprocessor = Preprocessor(tokenize_func=tokenize,
                            get_tok2char_span_map_func=get_tok2char_span_map)

# ## Transform

# In[8]:

ori_format = config["ori_data_format"]
if ori_format != "tplinker":  # if tplinker, skip transforming
    for file_name, data in file_name2data.items():
        if "train" in file_name:
            data_type = "train"
        if "valid" in file_name:
            data_type = "valid"
        if "test" in file_name:
            data_type = "test"
        # build "text" / "relation_list" records in TPLinker's relation
        # extraction dataset format from the original entities and relations
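
# In[ ]:

# A minimal example of the target record format named in the comment above;
# the field names follow TPLinker's convention, while the sentence and
# relation are made up for illustration.
sample = {
    "text": "Edison was born in Ohio",
    "relation_list": [
        {"subject": "Edison", "object": "Ohio", "predicate": "born_in"},
    ],
}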