class NerInference(object):
    """Character-level NER inference returning one raw class label per character.

    Loads the tokenizer, the class-name -> class-id mapping pickled at
    ``Class2NumFile`` and the fine-tuned model saved at ``NerFinetunePath``.
    Input is encoded without [CLS]/[SEP] and padded to ``SentenceLength``.

    NOTE(review): a second ``class NerInference`` appears later in this file and
    rebinds the name at import time, so this definition is shadowed.
    """

    def __init__(self):
        self.tokenizer = Tokenizer(VocabPath)
        with open(Class2NumFile, 'rb') as f:
            # NOTE: pickle can execute arbitrary code -- only load trusted files.
            self.class_to_num = pickle.load(f)
        self.num_to_class = {v: k for k, v in self.class_to_num.items()}
        # Inputs must be moved to the same device as the model, so remember it.
        self.device = device
        try:
            # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True),
            # which rejects whole pickled models -- confirm the torch version.
            self.model = torch.load(NerFinetunePath).to(self.device).eval()
        except (RuntimeError, AssertionError):
            # The checkpoint or target device needs CUDA but none is usable:
            # load onto the CPU and run inference there instead.
            self.device = torch.device('cpu')
            self.model = torch.load(NerFinetunePath, map_location='cpu').eval()
        print('加载模型完成!')

    def parse_inference_text(self, ori_line):
        """Encode text into fixed-length token ids and segment ids.

        Surrounding whitespace and all spaces are removed first. Characters the
        tokenizer does not know are mapped to id 100 ([UNK]) so that every
        character keeps a non-zero id.

        :param ori_line: raw input text.
        :return: ``(input_tokens_id, segment_ids)``, each ``SentenceLength`` long
            (1 marks a real token, 0 padding), or ``(None, None)`` when the text
            is longer than ``SentenceLength``.
        """
        ori_line = ori_line.strip().replace(' ', '')
        if len(ori_line) > SentenceLength:
            print('文本过长!')
            return None, None
        input_tokens_id = []
        for token in ori_line:
            token_id = self.tokenizer.token_to_id(token)
            input_tokens_id.append(token_id if token_id else 100)
        input_tokens_id.extend([0] * (SentenceLength - len(input_tokens_id)))
        segment_ids = [1 if x else 0 for x in input_tokens_id]
        return input_tokens_id, segment_ids

    def inference_single(self, text):
        """Predict one class label per (space-stripped) character of ``text``.

        :return: list of class-name strings, or ``[]`` if the text is too long.
        """
        input_tokens_id, segment_ids = self.parse_inference_text(text)
        if input_tokens_id is None:
            return []
        input_token = torch.tensor(input_tokens_id).unsqueeze(0).to(self.device)
        segment_tensor = torch.tensor(segment_ids).unsqueeze(0).to(self.device)
        # Only the non-padding positions carry real characters.
        input_len = sum(1 for x in input_tokens_id if x)
        mlm_output = self.model(input_token, segment_tensor)[:, :input_len, :]
        output_tensor = torch.nn.Softmax(dim=-1)(mlm_output)
        output_topk = torch.topk(output_tensor, 1).indices.squeeze(0).tolist()
        return [self.num_to_class[output[0]] for output in output_topk]
class RobertaTestSet(Dataset):
    """Evaluation dataset of (test text, label text) pairs.

    Each non-blank line of ``test_path`` has the form
    ``<label text>-***-<test text>``.
    """

    def __init__(self, test_path):
        self.tokenizer = Tokenizer(VocabPath)
        self.test_path = test_path
        self.test_lines = []
        self.label_lines = []
        # Read the data.
        with open(self.test_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Strip before the emptiness test: a raw line still carries its
                # '\n', so blank lines used to reach split() and raise IndexError.
                line = line.strip()
                if not line:
                    continue
                line_list = line.split('-***-')
                self.test_lines.append(line_list[1])
                self.label_lines.append(line_list[0])

    def __len__(self):
        return len(self.label_lines)

    def __getitem__(self, item):
        """Return a dict of LongTensors for sample ``item``.

        Keys: ``input_token_ids`` (test text), ``token_ids_labels`` (label
        text) and ``segment_ids`` (1 where the label token id is non-zero).
        """
        test_token = self.__gen_token(self.test_lines[item])
        label_token = self.__gen_token(self.label_lines[item])
        output = {
            'input_token_ids': test_token,
            'token_ids_labels': label_token,
            'segment_ids': [1 if x else 0 for x in label_token],
        }
        return {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()}

    def __gen_token(self, tokens):
        """Encode text as [101] + ids + [102], truncated and zero-padded to SentenceLength."""
        tar_token_ids = [101]
        for token in list(tokens)[:(SentenceLength - 2)]:
            tar_token_ids.append(self.tokenizer.token_to_id(token))
        tar_token_ids.append(102)
        # A non-positive count yields an empty list, i.e. no padding.
        tar_token_ids.extend([0] * (SentenceLength - len(tar_token_ids)))
        return tar_token_ids
def parse_ori_line(ori_line, class_to_num):
    """Parse one inline-annotated line into token ids and per-character classes.

    An annotation ``{<len>,<type>}`` labels the ``<len>`` plain characters
    immediately before it, e.g. ``六味地黄{3,ypcf}丸{1,yplb}``. A length-1 span is
    tagged ``'e' + type``; longer spans are tagged ``'b'``, ``'i'``..., ``'e'``
    (+ type). Unannotated characters get ``NormalChar``.

    The result is wrapped as ``[101] + ids + [102]`` and zero-padded to
    ``MedicineLength``; the class list is padded in parallel with ``'pad'``.
    Classes not yet in ``class_to_num`` are added to it in place.

    :param ori_line: annotated text; spaces are removed before parsing.
    :param class_to_num: mutable class-name -> class-id mapping.
    :return: ``(input_tokens, input_tokens_id, input_tokens_class,
        input_tokens_class_id, class_to_num)``. When the text is longer than
        ``MedicineLength - 2`` the first four items are ``None`` (the tuple
        still has five items so callers can always unpack it).
    :raises ValueError: if an annotation length is not a positive integer no
        larger than the number of preceding characters.
    """
    ori_line = ori_line.strip().replace(' ', '')
    input_tokens = ''
    input_tokens_id = []
    input_tokens_class = []
    input_tokens_class_id = []
    tokenizer = Tokenizer(VocabPath)

    n = len(ori_line)
    i = 0
    while i < n:
        ch = ori_line[i]
        if ch == '}':
            # Stray closing brace: skip it (the old loop never advanced here).
            i += 1
            continue
        if ch != '{':
            input_tokens += ch
            input_tokens_class.append(NormalChar)
            i += 1
            continue

        # Annotation "{<digits>,<type>}": digits before the first ',' give the
        # span length, everything after it up to '}' is the type.
        close = ori_line.index('}', i)
        length_part, _, current_type = ori_line[i + 1:close].partition(',')
        digits = ''.join(c for c in length_part if c.isdigit())
        current_len = int(digits) if digits else 0
        l = len(input_tokens_class)
        if not 1 <= current_len <= l:
            raise ValueError('invalid annotation length %r in %r' % (length_part, ori_line))
        if current_len == 1:
            input_tokens_class[l - 1] = 'e' + current_type
        else:
            input_tokens_class[l - current_len] = 'b' + current_type
            for k in range(l - current_len + 1, l - 1):
                input_tokens_class[k] = 'i' + current_type
            input_tokens_class[l - 1] = 'e' + current_type
        i = close + 1

    input_tokens_id = []
    for token in input_tokens:
        token_id = tokenizer.token_to_id(token)
        if not token_id:
            print('警告!本地vocab缺少以下字符:%s!' % token)
            # Use [UNK] (100) instead of skipping, so ids stay aligned with
            # input_tokens_class (which already has an entry for this char).
            token_id = 100
        input_tokens_id.append(token_id)

    # Pad the sequences.
    if len(input_tokens_id) > MedicineLength - 2:
        return None, None, None, None, class_to_num
    input_tokens_id.append(102)
    input_tokens_class.append(NormalChar)
    pad_count = MedicineLength - len(input_tokens_id) - 1
    input_tokens_id.extend([0] * pad_count)
    input_tokens_class.extend(['pad'] * pad_count)
    input_tokens_id = [101] + input_tokens_id
    input_tokens_class = [NormalChar] + input_tokens_class

    # Convert class names to numeric ids, registering unseen classes.
    for token_class in input_tokens_class:
        if token_class not in class_to_num:
            class_to_num[token_class] = len(class_to_num)
        input_tokens_class_id.append(class_to_num[token_class])
    return input_tokens, input_tokens_id, input_tokens_class, input_tokens_class_id, class_to_num
class NerInference(object):
    """Character-level NER inference that returns (entity text, entity type) pairs.

    Loads the tokenizer, the class-name -> class-id mapping pickled at
    ``Class2NumFile`` and the fine-tuned model at ``NerFinetunePath``. Input is
    encoded as ``[101] + ids + [102]`` zero-padded to ``MedicineLength``.
    """

    def __init__(self):
        self.NerClassDict = NerClassDict
        self.tokenizer = Tokenizer(VocabPath)
        with open(Class2NumFile, 'rb') as f:
            # NOTE: pickle can execute arbitrary code -- only load trusted files.
            self.class_to_num = pickle.load(f)
        self.num_to_class = {v: k for k, v in self.class_to_num.items()}
        # Inputs must be moved to the same device as the model, so remember it.
        self.device = device
        try:
            # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True),
            # which rejects whole pickled models -- confirm the torch version.
            self.model = torch.load(NerFinetunePath).to(self.device).eval()
        except (RuntimeError, AssertionError):
            # The checkpoint or target device needs CUDA but none is usable.
            self.device = torch.device('cpu')
            self.model = torch.load(NerFinetunePath, map_location='cpu').eval()
        print('加载模型完成!')

    def parse_inference_text(self, ori_line):
        """Encode text as ``[101] + ids + [102]`` padded to ``MedicineLength``.

        Surrounding whitespace and all spaces are removed first. Unknown
        characters are mapped to 100 ([UNK]) so each character keeps a non-zero
        id and the real length can be recovered from the segment mask.

        :return: ``(input_tokens_id, segment_ids)`` or ``(None, None)`` if the
            text is longer than ``MedicineLength - 2``.
        """
        ori_line = ori_line.strip().replace(' ', '')
        if len(ori_line) > MedicineLength - 2:
            print('文本过长!')
            return None, None
        input_tokens_id = [101]
        for token in ori_line:
            token_id = self.tokenizer.token_to_id(token)
            input_tokens_id.append(token_id if token_id else 100)
        input_tokens_id.append(102)
        input_tokens_id.extend([0] * (MedicineLength - len(input_tokens_id)))
        segment_ids = [1 if x else 0 for x in input_tokens_id]
        return input_tokens_id, segment_ids

    def inference_single(self, text):
        """Run NER on ``text`` and return a list of ``(entity_text, entity_type)``.

        Returns ``[]`` when the text is too long to encode.
        """
        input_tokens_id, segment_ids = self.parse_inference_text(text)
        if input_tokens_id is None:
            return []
        # Entity offsets refer to the normalized text that was actually encoded,
        # not the raw input (which may contain spaces / surrounding whitespace).
        clean_text = text.strip().replace(' ', '')
        input_len = len(clean_text)

        input_token = torch.tensor(input_tokens_id).unsqueeze(0).to(self.device)
        segment_tensor = torch.tensor(segment_ids).unsqueeze(0).to(self.device)
        # Drop the [CLS] position and everything from [SEP] onward.
        mlm_output = self.model(input_token, segment_tensor)[:, 1:input_len + 1, :]
        output_tensor = torch.nn.Softmax(dim=-1)(mlm_output)
        output_topk = torch.topk(output_tensor, 1).indices.squeeze(0).tolist()
        output2class = [self.num_to_class[output[0]] for output in output_topk]

        result = []
        # Presumably maps start offset -> list of per-character class labels
        # (e.g. 'bxxx', 'ixxx'); see extract_output_entities -- TODO confirm.
        entities = extract_output_entities(output2class)
        for key, val in entities.items():
            current_text = clean_text[key:key + len(val)]
            # Strip the one-letter position prefix to get the entity type key.
            current_entity = self.NerClassDict[val[0][1:]]
            result.append((current_text, current_entity))
        print('输入数据为:', text)
        print('实体识别结果为:', result)
        return result
def _parse_new_data_dir(dir_path, tokenizer, class2num, category_list, allow_new_classes):
    """Parse every ``*.txt`` file in ``dir_path`` into per-file samples.

    Each line is ``<char> <label>`` (lower-cased). Label ``o``/``0`` and the
    ignored categories qq/vx/mobile/email map to ``ptzf``. Non-ptzf categories
    are appended to ``category_list``.

    :param allow_new_classes: when True, unseen classes are added to
        ``class2num``; when False an unseen class raises ``KeyError``.
    :return: dict ``file_num -> {'sentence', 'tokens_id', 'tokens_class',
        'tokens_class_num'}``.
    """
    data = {}
    for data_file in os.listdir(dir_path):
        if '.txt' not in data_file:
            continue
        file_num = data_file.split('.')[0]
        with open(os.path.join(dir_path, data_file), 'r', encoding='utf-8') as f1:
            lines = f1.readlines()
        # NOTE(review): replace(',', ',') is a no-op as written (a full-width
        # comma was probably lost), and [:-1] drops the last line of every
        # file -- both kept as-is; confirm intent.
        lines = [x.strip().replace(',', ',') for x in lines if x][:-1]
        sample = {'sentence': '', 'tokens_id': [], 'tokens_class': [], 'tokens_class_num': []}
        data[file_num] = sample
        for i, line in enumerate(lines):
            try:
                ch, label = line.lower().split(' ')
            except ValueError:
                print(file_num)
                print(i)
                print(line)
                print('\n')
                ch = ','
                label = 'o'
            sample['sentence'] += ch
            token_id = tokenizer.token_to_id(ch)
            # Unknown chars -> [UNK] (100), matching parse_source_data, instead
            # of writing 'None'/'0' (the pad id) into the corpus.
            sample['tokens_id'].append(token_id if token_id else 100)
            if label in ('o', '0'):
                token_class = 'ptzf'
                token_class_num = 1
            else:
                token_class = label.replace('-', '')
                if token_class[1:] in ['qq', 'vx', 'mobile', 'email']:
                    token_class = 'ptzf'
                if token_class != 'ptzf':
                    category_list.append(token_class[1:])
                if token_class not in class2num:
                    if not allow_new_classes:
                        raise KeyError(token_class)
                    class2num[token_class] = len(class2num)
                token_class_num = class2num[token_class]
            sample['tokens_class'].append(token_class)
            sample['tokens_class_num'].append(token_class_num)
    return data


def _pad_new_data(samples, class2num):
    """Pad every sample in place to ``SentenceLength`` and stringify numeric lists."""
    for sample in samples.values():
        difference = SentenceLength - len(sample['sentence'])
        sample['tokens_id'].extend([0] * difference)
        sample['tokens_class'].extend(['pad'] * difference)
        sample['tokens_class_num'].extend([class2num['pad']] * difference)
        sample['tokens_id'] = [str(x) for x in sample['tokens_id']]
        sample['tokens_class_num'] = [str(x) for x in sample['tokens_class_num']]


def _write_new_data(samples, f):
    """Write non-empty samples as ``sentence,ids,classes,class_nums`` lines."""
    for sample in samples.values():
        if sample['sentence']:
            f.write(sample['sentence'] + ',' +
                    ' '.join(sample['tokens_id']) + ',' +
                    ' '.join(sample['tokens_class']) + ',' +
                    ' '.join(sample['tokens_class_num']) + '\n')


def parse_new_data():
    """Convert the per-character labelled data in data/train_new and data/eval_new.

    Train samples are appended to ``NerCorpusPath``; eval samples overwrite
    ``NerEvalPath``. Class ids come from the pickled ``Class2NumFile``; train
    data may add new classes, eval data may not (KeyError).
    """
    with open(Class2NumFile, 'rb') as f:
        class2num = pickle.load(f)
    tokenizer = Tokenizer(VocabPath)
    category_list = []

    new_train_data = _parse_new_data_dir('data/train_new', tokenizer, class2num,
                                         category_list, allow_new_classes=True)
    new_eval_data = _parse_new_data_dir('data/eval_new', tokenizer, class2num,
                                        category_list, allow_new_classes=False)
    print(set(category_list))

    # Pad all sentences.
    _pad_new_data(new_train_data, class2num)
    _pad_new_data(new_eval_data, class2num)

    # NOTE(review): classes added from train data are NOT saved back to
    # Class2NumFile (the dump was commented out upstream), so inference will
    # not know them -- confirm this is intended.
    with open(NerCorpusPath, 'a+', encoding='utf-8') as f_train:
        _write_new_data(new_train_data, f_train)
    with open(NerEvalPath, 'w', encoding='utf-8') as f_eval:
        _write_new_data(new_eval_data, f_eval)
def _load_source_document(data_file_path, label_file_path, file_num, tokenizer,
                          class2num, category_list):
    """Read one raw document and its label CSV into a per-character sample.

    Label CSV rows (after a header) are ``?,category,begin,end,words`` with an
    inclusive ``[begin, end]`` span. Spans are tagged BIO-style: ``'b'+category``
    at ``begin``, ``'i'+category`` on the rest; other characters are ``ptzf``.
    New classes are registered in ``class2num``; every category is appended to
    ``category_list``.

    :raises ValueError: if a label row does not have five fields.
    """
    with open(data_file_path, 'r', encoding='utf-8') as f1:
        sentence = f1.read().strip().replace(',', ',')
    with open(label_file_path, 'r', encoding='utf-8') as f2:
        label_lines = f2.readlines()[1:]

    # Token ids of the original sentence; unknown chars become 100 ([UNK]).
    tokens_id = []
    for token in sentence:
        token_id = tokenizer.token_to_id(token)
        if not token_id:
            print('警告!本地vocab缺少以下字符:%s!' % token)
            print(sentence)
            token_id = 100
        tokens_id.append(token_id)
    tokens_class = ['ptzf'] * len(sentence)
    tokens_class_num = [1] * len(sentence)

    for label_line in label_lines:
        fields = label_line.split(',', 4)
        if len(fields) != 5:
            raise ValueError('malformed label line in %s: %r' % (file_num, label_line))
        category = fields[1]
        begin = int(fields[2])
        end = int(fields[3])
        label_words = fields[4].strip()
        category_list.append(category)
        # Check that the label span matches the text it claims to cover.
        if sentence[begin:end + 1] != label_words:
            print('标记位置错误:%s,%s!' % (file_num, label_words))
        if category in ['QQ', 'vx', 'mobile', 'email'] or end < begin:
            continue
        b_class = 'b' + category
        if b_class not in class2num:
            class2num[b_class] = len(class2num)
        tokens_class[begin] = b_class
        tokens_class_num[begin] = class2num[b_class]
        if end > begin:
            i_class = 'i' + category
            if i_class not in class2num:
                class2num[i_class] = len(class2num)
            # Index assignment (not slice assignment) so the label lists can
            # never change length; the old slice [begin+1:end] received
            # end-begin items and inserted an extra element each time.
            for k in range(begin + 1, min(end + 1, len(tokens_class))):
                tokens_class[k] = i_class
                tokens_class_num[k] = class2num[i_class]

    return {
        'sentence': sentence,
        'tokens_id': tokens_id,
        'tokens_class': tokens_class,
        'tokens_class_num': tokens_class_num,
    }


def _split_source_documents(total_data):
    """Split documents longer than ``SentenceLength`` at punctuation.

    Short documents are kept whole. In long ones, a listed punctuation mark
    ends the current fragment and is dropped, unless it is inside an entity
    (class starts with 'i') or the fragment is shorter than 10 characters, in
    which case it is kept. The fragment after the last cut is emitted at the
    end of its own document instead of leaking into the next one.

    :return: ``(samples keyed 0..n-1, longest fragment length seen)``.
    """
    max_len = 0
    samples = {}

    def emit(sentence, ids, classes, nums):
        samples[len(samples)] = {
            'sentence': sentence,
            'tokens_id': ids,
            'tokens_class': classes,
            'tokens_class_num': nums,
        }

    for num in total_data:
        doc = total_data[num]
        if len(doc['sentence']) <= SentenceLength:
            emit(doc['sentence'], doc['tokens_id'], doc['tokens_class'], doc['tokens_class_num'])
            continue
        ti = doc['tokens_id']
        tc = doc['tokens_class']
        tn = doc['tokens_class_num']
        fragment = ['', [], [], []]
        for i, word in enumerate(doc['sentence']):
            if word in [',', ',', '。', '?', '?', '!', '!', '~', ':', ':']:
                max_len = max(max_len, len(fragment[0]))
                if not (tc[i][0] == 'i' or 0 < len(fragment[0]) < 10):
                    if fragment[0]:
                        emit(*fragment)
                    fragment = ['', [], [], []]
                    continue
            fragment[0] += word
            fragment[1].append(ti[i])
            fragment[2].append(tc[i])
            fragment[3].append(tn[i])
        if fragment[0]:
            max_len = max(max_len, len(fragment[0]))
            emit(*fragment)
    return samples, max_len


def parse_source_data():
    """Build the NER training corpus from ``NerSourcePath/{data,label}``.

    Reads each ``data/<n>.txt`` with its ``label/<n>.csv``, tags characters,
    splits long documents, pads every sample to ``SentenceLength``, pickles the
    class mapping to ``Class2NumFile`` and overwrites ``NerCorpusPath`` with
    ``sentence,ids,classes,class_nums`` lines.
    """
    class2num = {'pad': 0, 'ptzf': 1}
    tokenizer = Tokenizer(VocabPath)
    input_path = os.path.join(NerSourcePath, 'data')
    label_path = os.path.join(NerSourcePath, 'label')
    category_list = []

    total_data = {}
    for data_file in os.listdir(input_path):
        if '.txt' not in data_file:
            continue
        file_num = data_file.split('.')[0]
        total_data[int(file_num)] = _load_source_document(
            os.path.join(input_path, data_file),
            os.path.join(label_path, file_num + '.csv'),
            file_num, tokenizer, class2num, category_list)

    # Split long documents.
    total_data, max_len = _split_source_documents(total_data)
    print('最长句子为:', max_len)
    print(set(category_list))

    # Pad all sentences and stringify the numeric columns.
    for sample in total_data.values():
        difference = SentenceLength - len(sample['sentence'])
        sample['tokens_id'].extend([0] * difference)
        sample['tokens_class'].extend(['pad'] * difference)
        sample['tokens_class_num'].extend([class2num['pad']] * difference)
        sample['tokens_id'] = [str(x) for x in sample['tokens_id']]
        sample['tokens_class_num'] = [str(x) for x in sample['tokens_class_num']]

    # Persist the class-name -> id mapping.
    with open(Class2NumFile, 'wb') as f:
        pickle.dump(class2num, f)

    with open(NerCorpusPath, 'w', encoding='utf-8') as f_train:
        for sample in total_data.values():
            if sample['sentence']:
                f_train.write(sample['sentence'] + ',' +
                              ' '.join(sample['tokens_id']) + ',' +
                              ' '.join(sample['tokens_class']) + ',' +
                              ' '.join(sample['tokens_class_num']) + '\n')