Beispiel #1
0
class NerInference(object):
    """Character-level NER inference around a fine-tuned model checkpoint.

    The instance loads the tokenizer, the pickled class-name -> class-id
    mapping and the serialized model once, then tags one sentence per call to
    :meth:`inference_single`.
    """

    def __init__(self):
        """Load the tokenizer, the class mapping and the fine-tuned model."""
        self.tokenizer = Tokenizer(VocabPath)
        with open(Class2NumFile, 'rb') as f:
            self.class_to_num = pickle.load(f)
        # Inverse mapping: predicted class id -> class name.
        self.num_to_class = {
            num: name for name, num in self.class_to_num.items()
        }
        try:
            self.model = torch.load(NerFinetunePath).to(device).eval()
            self.device = device
        except Exception:
            # Best-effort fallback: if the checkpoint cannot be placed on the
            # configured device, load it on the CPU instead.  `Exception`
            # (rather than a bare except) keeps Ctrl-C / SystemExit working.
            self.model = torch.load(NerFinetunePath, map_location='cpu').eval()
            self.device = torch.device('cpu')
        print('加载模型完成!')

    def parse_inference_text(self, ori_line):
        """Convert raw text into padded token ids and a matching mask.

        Surrounding whitespace and all spaces are removed first; every
        remaining character is mapped through the tokenizer and the id list is
        right-padded with 0 up to ``SentenceLength``.

        :param ori_line: raw input text.
        :return: ``(input_tokens_id, segment_ids)``, where ``segment_ids``
            holds 1 for every truthy token id and 0 otherwise, or
            ``(None, None)`` when the cleaned text is longer than
            ``SentenceLength``.
        """
        ori_line = ori_line.strip().replace(' ', '')
        if len(ori_line) > SentenceLength:
            print('文本过长!')
            return None, None

        input_tokens_id = [
            self.tokenizer.token_to_id(token) for token in ori_line
        ]
        input_tokens_id.extend([0] * (SentenceLength - len(input_tokens_id)))
        segment_ids = [1 if token_id else 0 for token_id in input_tokens_id]
        return input_tokens_id, segment_ids

    def inference_single(self, text):
        """Predict one class name per character of ``text``.

        :param text: raw input text.
        :return: list of class names, one per non-padding token.  An empty
            list is returned when the text is rejected as too long (the
            rejection itself is reported by :meth:`parse_inference_text`).
        """
        input_tokens_id, segment_ids = self.parse_inference_text(text)
        if input_tokens_id is None:
            # Over-long input: previously this crashed inside torch.tensor().
            return []

        # Send the inputs to wherever the model actually lives; after the CPU
        # fallback in __init__ that differs from the global `device`.
        target_device = getattr(self, 'device', device)
        input_token = torch.tensor(input_tokens_id).unsqueeze(0).to(
            target_device)
        segment_tensor = torch.tensor(segment_ids).unsqueeze(0).to(
            target_device)

        # Number of real (non-padding) tokens; only their outputs are kept.
        input_len = sum(1 for token_id in input_tokens_id if token_id)
        with torch.no_grad():  # inference only, no autograd bookkeeping
            mlm_output = self.model(input_token,
                                    segment_tensor)[:, :input_len, :]
        output_tensor = torch.nn.Softmax(dim=-1)(mlm_output)
        output_topk = torch.topk(output_tensor, 1).indices.squeeze(0).tolist()
        return [self.num_to_class[output[0]] for output in output_topk]
Beispiel #2
0
class RobertaTestSet(Dataset):
    def __init__(self, test_path):
        self.tokenizer = Tokenizer(VocabPath)
        self.test_path = test_path
        self.test_lines = []
        self.label_lines = []
        # 读取数据
        with open(self.test_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line:
                    line = line.strip()
                    line_list = line.split('-***-')
                    self.test_lines.append(line_list[1])
                    self.label_lines.append(line_list[0])

    def __len__(self):
        return len(self.label_lines)

    def __getitem__(self, item):
        output = {}
        test_text = self.test_lines[item]
        label_text = self.label_lines[item]
        test_token = self.__gen_token(test_text)
        label_token = self.__gen_token(label_text)
        segment_ids = [1 if x else 0 for x in label_token]
        output['input_token_ids'] = test_token
        output['token_ids_labels'] = label_token
        output['segment_ids'] = segment_ids
        instance = {
            k: torch.tensor(v, dtype=torch.long)
            for k, v in output.items()
        }
        return instance

    def __gen_token(self, tokens):
        tar_token_ids = [101]
        tokens = list(tokens)
        tokens = tokens[:(SentenceLength - 2)]
        for token in tokens:
            token_id = self.tokenizer.token_to_id(token)
            tar_token_ids.append(token_id)
        tar_token_ids.append(102)
        if len(tar_token_ids) < SentenceLength:
            for i in range(SentenceLength - len(tar_token_ids)):
                tar_token_ids.append(0)
        return tar_token_ids
Beispiel #3
0
def parse_ori_line(ori_line, class_to_num, tokenizer=None):
    """Parse one inline-annotated line into tokens, ids and class labels.

    An annotation ``{n,type}`` labels the ``n`` characters immediately before
    it: the first gets ``'b' + type``, the last ``'e' + type`` and the ones in
    between ``'i' + type`` (a single character only gets ``'e' + type``).
    All other characters get ``NormalChar``.

    Example input: ``六味地黄{3,ypcf}丸{1,yplb}``

    :param ori_line: annotated source line; spaces are removed before parsing.
    :param class_to_num: class-name -> id mapping.  It is extended in place
        with any class name not seen before and returned.
    :param tokenizer: optional tokenizer to reuse; when omitted a new
        ``Tokenizer(VocabPath)`` is built for this call.
    :return: ``(input_tokens, input_tokens_id, input_tokens_class,
        input_tokens_class_id, class_to_num)``.  The id and class sequences
        both start with 101 / ``NormalChar``, end the text with
        102 / ``NormalChar`` and are padded with 0 / ``'pad'`` up to
        ``MedicineLength``.  When the text does not fit, the first four
        entries are ``None`` (the tuple always has five entries).
    :raises ValueError: if an annotation is malformed (missing ``,`` or ``}``,
        no length digits, or a length that exceeds the preceding text).
    """
    ori_line = ori_line.strip().replace(' ', '')
    if tokenizer is None:
        tokenizer = Tokenizer(VocabPath)

    token_chars = []
    input_tokens_class = []
    pos = 0
    while pos < len(ori_line):
        ch = ori_line[pos]
        if ch == '}':
            # Stray closing brace without an opening one: skip it instead of
            # spinning forever on the same position.
            pos += 1
            continue
        if ch != '{':
            token_chars.append(ch)
            input_tokens_class.append(NormalChar)
            pos += 1
            continue

        # Annotation "{<len>,<type>}" describing the preceding characters.
        comma = ori_line.find(',', pos + 1)
        close = ori_line.find('}', comma + 1) if comma != -1 else -1
        if comma == -1 or close == -1:
            raise ValueError('malformed annotation in line: %r' % ori_line)
        len_digits = ''.join(
            c for c in ori_line[pos + 1:comma] if c.isdigit())
        current_type = ori_line[comma + 1:close]
        current_len = int(len_digits)  # ValueError when no digits are present
        if not 0 < current_len <= len(input_tokens_class):
            raise ValueError(
                'annotation length %d does not match the preceding text in '
                'line: %r' % (current_len, ori_line))

        start = len(input_tokens_class) - current_len
        input_tokens_class[start:] = ['i' + current_type] * current_len
        input_tokens_class[start] = 'b' + current_type
        # Assigned last so that a one-character entity ends up as 'e'.
        input_tokens_class[-1] = 'e' + current_type
        pos = close + 1

    input_tokens = ''.join(token_chars)

    # 100 is assumed to be the [UNK] id of the vocabulary, in line with the
    # hard-coded 101/102 markers below -- TODO confirm against VocabPath.
    # Substituting it (instead of dropping the character) keeps the ids
    # aligned one-to-one with input_tokens_class.
    unk_token_id = 100
    input_tokens_id = []
    for token in input_tokens:
        token_id = tokenizer.token_to_id(token)
        if not token_id:
            print('警告!本地vocab缺少以下字符:%s!' % token)
            token_id = unk_token_id
        input_tokens_id.append(token_id)

    # Two positions are reserved for the 101 / 102 markers.
    if len(input_tokens_id) > MedicineLength - 2:
        return None, None, None, None, class_to_num

    input_tokens_id = [101] + input_tokens_id + [102]
    input_tokens_class = [NormalChar] + input_tokens_class + [NormalChar]
    pad_count = MedicineLength - len(input_tokens_id)
    input_tokens_id.extend([0] * pad_count)
    input_tokens_class.extend(['pad'] * pad_count)

    # Convert class names to numeric ids, registering unseen names on the fly.
    input_tokens_class_id = []
    for token_class in input_tokens_class:
        if token_class not in class_to_num:
            class_to_num[token_class] = len(class_to_num)
        input_tokens_class_id.append(class_to_num[token_class])

    return (input_tokens, input_tokens_id, input_tokens_class,
            input_tokens_class_id, class_to_num)
Beispiel #4
0
class NerInference(object):
    """NER inference that returns (entity text, entity type) pairs.

    The instance loads the tokenizer, the pickled class-name -> class-id
    mapping and the fine-tuned model once; :meth:`inference_single` then
    extracts the entities of one sentence per call.
    """

    def __init__(self):
        """Load the tokenizer, the class mapping and the fine-tuned model."""
        self.NerClassDict = NerClassDict
        self.tokenizer = Tokenizer(VocabPath)
        with open(Class2NumFile, 'rb') as f:
            self.class_to_num = pickle.load(f)
        # Inverse mapping: predicted class id -> class name.
        self.num_to_class = {
            num: name for name, num in self.class_to_num.items()
        }
        self.model = torch.load(NerFinetunePath).to(device).eval()
        print('加载模型完成!')

    def parse_inference_text(self, ori_line):
        """Convert raw text into padded token ids and a matching mask.

        Surrounding whitespace and all spaces are removed; the characters are
        wrapped in the 101 / 102 marker ids and right-padded with 0 up to
        ``MedicineLength``.

        :param ori_line: raw input text.
        :return: ``(input_tokens_id, segment_ids)``, where ``segment_ids``
            holds 1 for every truthy token id and 0 otherwise, or
            ``(None, None)`` when the cleaned text is longer than
            ``MedicineLength - 2``.
        """
        ori_line = ori_line.strip().replace(' ', '')
        if len(ori_line) > MedicineLength - 2:
            print('文本过长!')
            return None, None

        input_tokens_id = [101]
        input_tokens_id.extend(
            self.tokenizer.token_to_id(token) for token in ori_line)
        input_tokens_id.append(102)
        input_tokens_id.extend([0] * (MedicineLength - len(input_tokens_id)))

        segment_ids = [1 if token_id else 0 for token_id in input_tokens_id]
        return input_tokens_id, segment_ids

    def inference_single(self, text):
        """Extract the entities found in ``text``.

        :param text: raw input text.
        :return: list of ``(entity_text, entity_type)`` tuples, where
            ``entity_type`` comes from ``NerClassDict``.  An empty list is
            returned when the text is rejected as too long (the rejection
            itself is reported by :meth:`parse_inference_text`).
        """
        input_tokens_id, segment_ids = self.parse_inference_text(text)
        if input_tokens_id is None:
            # Over-long input: previously this crashed inside torch.tensor().
            return []

        # Predictions are positions in the cleaned text, so entity text must
        # be sliced from the same cleaned string, not from the raw argument.
        clean_text = text.strip().replace(' ', '')

        input_token = torch.tensor(input_tokens_id).unsqueeze(0).to(device)
        segment_tensor = torch.tensor(segment_ids).unsqueeze(0).to(device)

        # Non-padding tokens minus the 101 / 102 markers.
        input_len = sum(1 for token_id in input_tokens_id if token_id) - 2
        with torch.no_grad():  # inference only, no autograd bookkeeping
            # Skip position 0 (the 101 marker) and keep the text tokens only.
            mlm_output = self.model(input_token,
                                    segment_tensor)[:, 1:input_len + 1, :]
        output_tensor = torch.nn.Softmax(dim=-1)(mlm_output)
        output_topk = torch.topk(output_tensor, 1).indices.squeeze(0).tolist()
        output2class = [self.num_to_class[output[0]] for output in output_topk]

        # As used below, `entities` maps a start offset to the list of class
        # labels of that entity; the label minus its first character is the
        # NerClassDict key.
        entities = extract_output_entities(output2class)
        result = []
        for key, val in entities.items():
            current_text = clean_text[key:key + len(val)]
            current_entity = self.NerClassDict[val[0][1:]]
            result.append((current_text, current_entity))
        print('输入数据为:', text)
        print('实体识别结果为:', result)
        return result
Beispiel #5
0
def _new_data_read_dir(dir_path, tokenizer, class2num, category_list,
                       plain_labels, allow_new_classes):
    """Read every ``*.txt`` file of ``dir_path`` into one record per file.

    Each line of a file is expected to be ``<char> <label>``.  Returns a dict
    keyed by the file name up to its first dot; ``class2num`` and
    ``category_list`` are updated in place.
    """
    records = {}
    for data_file in os.listdir(dir_path):
        if '.txt' not in data_file:
            continue
        file_num = data_file.split('.')[0]
        with open(os.path.join(dir_path, data_file), 'r',
                  encoding='utf-8') as f1:
            # NOTE(review): as written both replace() arguments are the same
            # character, so the call is a no-op; presumably one side was
            # meant to be a full-width comma -- confirm.
            # The last line of every file is dropped, as before.
            lines = [x.strip().replace(',', ',') for x in f1.readlines()][:-1]

        chars = []
        tokens_id = []
        tokens_class = []
        tokens_class_num = []
        for i, line in enumerate(lines):
            try:
                ch, label = line.lower().split(' ')
            except ValueError:
                # Not exactly "<char> <label>": report the line and treat it
                # as an unlabelled comma.
                print(file_num)
                print(i)
                print(line)
                print('\n')
                ch = ','
                label = 'o'
            chars.append(ch)
            tokens_id.append(tokenizer.token_to_id(ch))
            if label in plain_labels:
                token_class = 'ptzf'
                token_class_num = 1
            else:
                token_class = label.lower().replace('-', '')
                # Contact-detail categories are folded into the plain class.
                if token_class[1:] in ['qq', 'vx', 'mobile', 'email']:
                    token_class = 'ptzf'
                if token_class != 'ptzf':
                    category_list.append(token_class[1:])
                if allow_new_classes and token_class not in class2num:
                    class2num[token_class] = len(class2num)
                # Raises KeyError for an unknown class when new classes are
                # not allowed.
                token_class_num = class2num[token_class]
            tokens_class.append(token_class)
            tokens_class_num.append(token_class_num)

        records[file_num] = {
            'sentence': ''.join(chars),
            'tokens_id': tokens_id,
            'tokens_class': tokens_class,
            'tokens_class_num': tokens_class_num,
        }
    return records


def _new_data_pad(records, class2num):
    """Pad every record up to ``SentenceLength`` and stringify numeric fields."""
    for record in records.values():
        # A non-positive difference adds nothing; longer records stay as-is.
        difference = SentenceLength - len(record['sentence'])
        record['tokens_id'].extend([0] * difference)
        record['tokens_class'].extend(['pad'] * difference)
        record['tokens_class_num'].extend([class2num['pad']] * difference)
        record['tokens_id'] = [str(x) for x in record['tokens_id']]
        record['tokens_class_num'] = [
            str(x) for x in record['tokens_class_num']
        ]


def _new_data_write(records, out_file):
    """Write every record with a non-empty sentence as one line."""
    for record in records.values():
        if record['sentence']:
            out_file.write(record['sentence'] + ',' +
                           ' '.join(record['tokens_id']) + ',' +
                           ' '.join(record['tokens_class']) + ',' +
                           ' '.join(record['tokens_class_num']) + '\n')


def parse_new_data():
    """Convert the per-character labelled files into corpus lines.

    Files under ``data/train_new`` are appended to ``NerCorpusPath``; files
    under ``data/eval_new`` overwrite ``NerEvalPath``.  Every output line is
    ``sentence,token ids,class names,class ids`` with the last three fields
    space-separated and padded up to ``SentenceLength``.

    The class mapping is loaded from ``Class2NumFile``.  Training data may
    introduce new classes (in memory only -- the mapping is not written
    back); evaluation data must only use known classes.

    :raises KeyError: if an evaluation file uses a class missing from the
        mapping.
    """
    with open(Class2NumFile, 'rb') as f:
        class2num = pickle.load(f)
    tokenizer = Tokenizer(VocabPath)
    category_list = []

    # Training data accepts both 'o' and '0' as the plain label and may add
    # classes; evaluation data accepts only 'o' and never adds classes.
    new_train_data = _new_data_read_dir(
        'data/train_new', tokenizer, class2num, category_list,
        plain_labels=('o', '0'), allow_new_classes=True)
    new_eval_data = _new_data_read_dir(
        'data/eval_new', tokenizer, class2num, category_list,
        plain_labels=('o',), allow_new_classes=False)

    print(set(category_list))

    # Pad all sentences.
    _new_data_pad(new_train_data, class2num)
    _new_data_pad(new_eval_data, class2num)

    # Context managers guarantee the outputs are flushed and closed.
    with open(NerCorpusPath, 'a+', encoding='utf-8') as f_train:
        _new_data_write(new_train_data, f_train)
    with open(NerEvalPath, 'w', encoding='utf-8') as f_eval:
        _new_data_write(new_eval_data, f_eval)
Beispiel #6
0
def _source_read_documents(tokenizer, class2num, category_list):
    """Read every source document and its label file into per-token records.

    Returns a dict keyed by the integer file number.  ``class2num`` gains the
    ``'b' + category`` / ``'i' + category`` classes that occur and
    ``category_list`` receives every category seen; both are updated in place.
    """
    input_path = os.path.join(NerSourcePath, 'data')
    label_path = os.path.join(NerSourcePath, 'label')
    total_data = {}
    for data_file in os.listdir(input_path):
        if '.txt' not in data_file:
            continue
        file_num = data_file.split('.')[0]
        doc_key = int(file_num)
        with open(os.path.join(input_path, data_file), 'r',
                  encoding='utf-8') as f1:
            # NOTE(review): as written both replace() arguments are the same
            # character, so the call is a no-op; presumably one side was
            # meant to be a full-width comma -- confirm.
            sentence = f1.read().strip().replace(',', ',')
        with open(os.path.join(label_path, file_num + '.csv'), 'r',
                  encoding='utf-8') as f2:
            label_lines = f2.readlines()[1:]  # skip the header row

        # Token ids of the sentence; characters unknown to the vocab get 100,
        # which the original code documents as UNK.
        tokens_id = []
        for token in sentence:
            token_id = tokenizer.token_to_id(token)
            if not token_id:
                print('警告!本地vocab缺少以下字符:%s!' % token)
                print(sentence)
                token_id = 100
            tokens_id.append(token_id)
        tokens_class = ['ptzf'] * len(sentence)
        tokens_class_num = [1] * len(sentence)

        for label_line in label_lines:
            fields = label_line.split(',', 4)
            assert len(fields) == 5
            category = fields[1]
            begin = int(fields[2])
            end = int(fields[3])  # inclusive end offset
            label_words = fields[4].strip()
            category_list.append(category)

            # Verify that the offsets really point at the labelled words.
            if sentence[begin:end + 1] != label_words:
                print('标记位置错误:%s,%s!' % (file_num, label_words))

            if category in ['QQ', 'vx', 'mobile', 'email']:
                continue
            if end < begin:
                continue

            b_label = 'b' + category
            if b_label not in class2num:
                class2num[b_label] = len(class2num)
            if begin == end:
                tokens_class[end] = b_label
                tokens_class_num[end] = class2num[b_label]
                continue

            i_label = 'i' + category
            if i_label not in class2num:
                class2num[i_label] = len(class2num)
            tokens_class[begin] = b_label
            tokens_class_num[begin] = class2num[b_label]
            # Label positions begin+1 .. end (inclusive) one by one.  The
            # former slice assignment [begin + 1:end] inserted end - begin
            # items into a slice of end - begin - 1, which made the label
            # lists longer than the sentence and shifted all later labels.
            for pos in range(begin + 1, end + 1):
                tokens_class[pos] = i_label
                tokens_class_num[pos] = class2num[i_label]

        total_data[doc_key] = {
            'sentence': sentence,
            'tokens_id': tokens_id,
            'tokens_class': tokens_class,
            'tokens_class_num': tokens_class_num,
        }
    return total_data


def _source_split_long(total_data):
    """Split documents longer than ``SentenceLength`` at punctuation marks.

    Returns ``(records, max_len)``: the list of resulting records and the
    longest segment length observed while splitting.
    """
    split_chars = [',', ',', '。', '?', '?', '!', '!', '~', ':', ':']
    records = []
    max_len = 0

    def flush(chars, ids, classes, class_nums):
        """Store the buffered segment as a record unless it is empty."""
        if chars:
            records.append({
                'sentence': ''.join(chars),
                'tokens_id': ids,
                'tokens_class': classes,
                'tokens_class_num': class_nums,
            })

    for doc in total_data.values():
        sentence = doc['sentence']
        if len(sentence) <= SentenceLength:
            records.append({
                'sentence': sentence,
                'tokens_id': doc['tokens_id'],
                'tokens_class': doc['tokens_class'],
                'tokens_class_num': doc['tokens_class_num'],
            })
            continue

        ti = doc['tokens_id']
        tc = doc['tokens_class']
        tn = doc['tokens_class_num']
        chars, ids, classes, class_nums = [], [], [], []
        for i, word in enumerate(sentence):
            if word in split_chars:
                max_len = max(max_len, len(chars))
                # Do not cut inside an entity, and do not emit segments
                # shorter than 10 characters: keep the punctuation instead.
                keep = tc[i][0] == 'i' or 0 < len(chars) < 10
                if not keep:
                    # The punctuation mark itself is dropped.
                    flush(chars, ids, classes, class_nums)
                    chars, ids, classes, class_nums = [], [], [], []
                    continue
            chars.append(word)
            ids.append(ti[i])
            classes.append(tc[i])
            class_nums.append(tn[i])

        # Flush the tail of the document.  Previously the tail stayed in the
        # buffer and was either discarded by a following short document or
        # glued onto the start of the next long one.
        max_len = max(max_len, len(chars))
        flush(chars, ids, classes, class_nums)
    return records, max_len


def parse_source_data():
    """Build the NER training corpus from the raw documents and label files.

    Reads ``<NerSourcePath>/data/*.txt`` together with the matching
    ``<NerSourcePath>/label/<n>.csv`` files, labels every character
    (``'ptzf'`` by default, ``'b' + category`` for the first character of an
    entity and ``'i' + category`` for the rest), splits over-long documents,
    pads every record up to ``SentenceLength`` and writes one line per record
    to ``NerCorpusPath`` as ``sentence,token ids,class names,class ids``.
    The resulting class mapping is pickled to ``Class2NumFile``.

    :raises AssertionError: if a label row has fewer than five fields.
    :raises IndexError: if a label offset lies outside its sentence.
    """
    class2num = {'pad': 0, 'ptzf': 1}
    tokenizer = Tokenizer(VocabPath)
    category_list = []

    total_data = _source_read_documents(tokenizer, class2num, category_list)

    # Split long sentences.
    records, max_len = _source_split_long(total_data)
    print('最长句子为:', max_len)
    print(set(category_list))

    # Pad all sentences (a non-positive difference adds nothing).
    for record in records:
        difference = SentenceLength - len(record['sentence'])
        record['tokens_id'].extend([0] * difference)
        record['tokens_class'].extend(['pad'] * difference)
        record['tokens_class_num'].extend([class2num['pad']] * difference)
        record['tokens_id'] = [str(x) for x in record['tokens_id']]
        record['tokens_class_num'] = [
            str(x) for x in record['tokens_class_num']
        ]

    # Persist the class-name -> id mapping.
    with open(Class2NumFile, 'wb') as f:
        pickle.dump(class2num, f)

    # The context manager guarantees the corpus is flushed and closed.
    with open(NerCorpusPath, 'w', encoding='utf-8') as f_train:
        for record in records:
            if record['sentence']:
                f_train.write(record['sentence'] + ',' +
                              ' '.join(record['tokens_id']) + ',' +
                              ' '.join(record['tokens_class']) + ',' +
                              ' '.join(record['tokens_class_num']) + '\n')