Example #1
File: __init__.py Project: qidanrui/DQE
def build_table(filename):
    try:
        workbook = xlrd.open_workbook(filename)
        table = workbook.sheets()[0]
    except Exception:
        return None
    attrs = table.row_values(0)
    attrs = map(lambda x: strQ2B(x), attrs)
    attrs = map(lambda x: x.encode('utf-8')
                if isinstance(x, unicode) else x, attrs)
    schema = dict()
    instance = list()
    for i in xrange(1, table.nrows):
        row_data = table.row_values(i)
        row_data = map(lambda x: strQ2B(x), row_data)  # normalize full-width characters in every cell
        row_dict = dict()
        for (k, v) in zip(attrs, row_data):
            if v == '':
                v = None
            if v is not None and k not in schema:
                t = type(v)
                if t == int:
                    t = float
                schema[k] = t
            row_dict[k] = v
        instance.append(row_dict)
    return Table(instance, schema, attrs)
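Note: every example on this page calls strQ2B(), which maps full-width (quanjiao) characters to their half-width (banjiao) equivalents. The helper itself is defined elsewhere in each project; the sketch below is a minimal version of the conventional implementation, assumed for reference rather than taken from any of the projects listed here.

def strQ2B(ustring):
    """Convert full-width characters in a string to half-width ones."""
    chars = []
    for ch in ustring:
        code = ord(ch)
        if code == 0x3000:                  # ideographic (full-width) space -> ASCII space
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:      # full-width forms of '!'..'~' -> ASCII
            code -= 0xFEE0
        chars.append(chr(code))
    return ''.join(chars)

# e.g. strQ2B('ＡＢＣ１２３，') == 'ABC123,'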
Example #2
def read_test(files):
    total_sentences = []
    for file in files:
        sen = []
        with open(file, encoding='utf8') as f:
            temp_token = Token(-1, NULL, NULL, None, None)
            for line in f:
                line = line.strip()
                if line:
                    line = strQ2B(line)
                    i, w, _, p, _, _, h, d, _, _ = line.split()
                    new_token = Token(int(i), w, p, [d], [int(h)])
                    if new_token.token_id == temp_token.token_id:
                        sen[-1].head_id.append(int(h))
                        sen[-1].dep.append(d)
                    else:
                        sen.append(new_token)
                    temp_token = new_token
                else:
                    total_sentences.append(Sentence(sen))
                    sen = []
                    temp_token = Token(-1, NULL, NULL, None, None)

            if len(sen) > 0:
                total_sentences.append(Sentence(sen))

    return total_sentences
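Note: the dependency-parsing examples (#2, #4, #12, #15) also rely on Token and Sentence classes from their own projects. Their real definitions are not shown on this page; the following is a hypothetical minimal sketch, with the constructor signature and field names inferred from how the snippets use them.

class Token:
    """One CoNLL-style token; dep/head_id may be scalars or lists depending on the example."""
    def __init__(self, token_id, word, pos, dep, head_id):
        self.token_id = token_id
        self.word = word
        self.pos = pos
        self.dep = dep
        self.head_id = head_id


class Sentence:
    """A thin wrapper around an ordered list of Token objects."""
    def __init__(self, tokens):
        self.tokens = tokens

    def __len__(self):
        return len(self.tokens)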
Example #3
 def read_sentences(self):
   file = open(self.input_file, 'r', encoding='utf-8')
   content = file.read()
   sentences = re.sub('[ ]+', self.SPLIT_CHAR, strQ2B(content)).splitlines()  # unify word separators to double spaces
   sentences = list(filter(None, sentences))  # drop empty lines
   file.close()
   return sentences
Example #4
def build_and_read_train(file):
    vocab_file = os.path.join(Config.data.processed_path, Config.data.vocab_file)
    pos_file = os.path.join(Config.data.processed_path, Config.data.pos_file)
    dep_file = os.path.join(Config.data.processed_path, Config.data.dep_file)
    vocab, pos, dep = set(), set(), set()
    sen = []
    total_sentences = []

    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                i, w, _, p, _, _, h, d, _, _ = line.split()
                vocab.add(w)
                pos.add(p)
                dep.add(d)
                sen.append(Token(int(i), w, p, d, int(h)))
            else:
                if 5 <= len(sen) <= 30:
                    total_sentences.append(Sentence(sen))
                sen = []

        if len(sen) > 0:
            total_sentences.append(Sentence(sen))

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([UNK, ROOT] + sorted(vocab)))
    with open(pos_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([ROOT] + sorted(pos)))
    with open(dep_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(sorted(dep)))

    return total_sentences
Example #5
def build_and_read_train(file):
    vocab_file = os.path.join(Config.data.processed_path,
                              Config.data.vocab_file)
    rel_file = os.path.join(Config.data.processed_path, Config.data.rel_file)
    entity_file = os.path.join(Config.data.processed_path,
                               Config.data.entity_file)
    vocab, rel, entity_pairs = set(), set(), set()

    all_samples = []

    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                en1, en2, r, sen = line.split('\t')
                entity_pairs.add(en1 + ' ' + en2)
                vocab.update(sen)
                rel.add(r)
                all_samples.append((en1, en2, r, sen))

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([PAD, UNK] + sorted(vocab)))
    with open(rel_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(sorted(rel)))
    with open(entity_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(sorted(entity_pairs)))
    return all_samples
Example #6
def build_and_read_train(file):
    vocab_file = os.path.join(Config.data.processed_path, Config.data.vocab_file)
    tag_file = os.path.join(Config.data.processed_path, Config.data.tag_file)
    vocab, tag = set(), set()
    sen = []
    total_sen = []

    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                w, t, p, l = line.split()
                vocab.add(w)
                tag.add(t)
                sen.append([w, t, int(p), l])
            else:
                total_sen.append(sen)
                sen = []

        if sen:
            total_sen.append(sen)

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([PAD, UNK] + sorted(vocab)))
    with open(tag_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([PAD] + sorted(tag)))

    return total_sen
Example #7
def read_test(file):
    all_samples = []
    sample = []

    count = 0
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                content = line.split()
                if count == 0:
                    sample.append(int(content[0]))
                elif count in [2, 3, 5, 6]:
                    sample.append(pad_to_fixed_len(content))

                count += 1
            else:
                count = 0
                all_samples.append(sample)
                sample = []

        if sample:
            all_samples.append(sample)

    return all_samples
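Note: examples #7 and #17 call a pad_to_fixed_len helper whose body is not shown here. A hypothetical sketch, assuming it truncates or right-pads a token list to a fixed length with a PAD symbol (the length and symbol below are placeholders; the projects presumably read them from Config):

def pad_to_fixed_len(tokens, max_len=50, pad_symbol='<PAD>'):
    """Truncate or right-pad a token list so it always has max_len entries."""
    tokens = tokens[:max_len]
    return tokens + [pad_symbol] * (max_len - len(tokens))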
Example #8
def get_testing_data(test_path: str):
    data = []
    with open(test_path) as f:
        articles = strQ2B(f.read()).split('\n\n--------------------\n\n')[:-1]
        for article in articles:
            lines = article.split('\n')

            article_id = lines[0][12:]
            tokens = list(lines[1])

            data.append(tokens)

    return data
Example #9
    def read_corpus_from_file(self, file_path):
        with open(file_path, encoding='UTF-8') as f:
            for line in f:
                sent = []
                for word in utils.strQ2B(line).strip().split():
                    # TODO: decide whether digits should be replaced with #NUM and English tokens with #ENG
                    sent.append(word)
                    if self.is_puns(word):
                        yield sent
                        sent = []

                if len(sent) > 0:
                    yield sent
Example #10
 def read_content(self):
     words = []
     labels = []
     with open(self.corpus_path, 'r', encoding='utf8') as corpus_file:
         sentences = corpus_file.read().splitlines()
         for sentence in sentences:
             word = []
             label = []
             sections = sentence.strip().split(' ')
             for section in sections:
                 pair = section.split('/')
                 word.append(strQ2B(pair[0]))
                 label.append(pair[1])
             words.append(word)
             labels.append(label)
     return words, labels
Example #11
def read_test(file):
    sen = []
    total_sen = []
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                w, t, p, l = line.split()
                sen.append([w, t, int(p), l])
            else:
                total_sen.append(sen)
                sen = []

        if sen:
            total_sen.append(sen)

    return total_sen
Example #12
def read_test(file):
    sen = []
    total_sentences = []
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                i, w, _, p, _, _, h, d, _, _ = line.split()
                sen.append(Token(int(i), w, p, d, int(h)))
            else:
                if 5 <= len(sen) <= 30:
                    total_sentences.append(Sentence(sen))
                sen = []
        if len(sen) > 0:
            total_sentences.append(Sentence(sen))

    return total_sentences
Example #13
 def read_words(self):
     """
     Read the characters of each sentence in the corpus.
     """
     with open(self.input_file, 'r', encoding='utf-8') as file:
         # convert full-width (Chinese) punctuation to half-width (ASCII) punctuation
         data = strQ2B(file.read())
         self.sentences = data.splitlines()
         # split long sentences on punctuation marks
         self.sentences = re.split(u'[。,?;!]', ''.join(self.sentences))
         self.sentences = [
             sentence.strip() for sentence in self.sentences
             if len(sentence) >= 3
         ]
         self.sample_nums = len(self.sentences)
         words = data.replace('\n', "").split(self.SPLIT_CHAR)
         words = [word.strip() for word in words]
         self.words = [char for word in words for char in word]
Example #14
def infer():
    with tf.Graph().as_default():
        hparams = create_hparams(FLAGS)
        hparams.is_user_input = True
        cws_model = BaseModel(hparams)
        with tf.Session() as session:
            # in_sent = "中华人民共和国中央人民政府今天成立了。"
            while True:
                in_sent = input("\n====== in_sent: ")
                chars = list(utils.strQ2B(in_sent).strip())
                chars_length = len(chars)
                char_ids = [
                    vocab.chr2id.get(chr, vocab.OOV_ID) for chr in chars
                ]

                tags, scores = cws_model.infer(session, [char_ids],
                                               [chars_length])
                print("------ score: %s ------" % scores[0])
                for item in zip(chars, tags[0].tolist()):
                    print(item[0], vocab.id2tag[item[1]])
Example #15
def build_and_read_train(files):
    vocab_file = os.path.join(Config.data.processed_path, Config.data.vocab_file)
    pos_file = os.path.join(Config.data.processed_path, Config.data.pos_file)
    dep_file = os.path.join(Config.data.processed_path, Config.data.dep_file)
    vocab, pos, dep = set(), set(), set()
    total_sentences = []
    for file in files:
        sen = []
        with open(file, encoding='utf8') as f:
            temp_token = Token(-1, NULL, NULL, None, None)
            for line in f:
                line = line.strip()
                if line:
                    line = strQ2B(line)
                    i, w, _, p, _, _, h, d, _, _ = line.split()
                    vocab.add(w)
                    pos.add(p)
                    dep.add(d)
                    new_token = Token(int(i), w, p, [d], [int(h)])
                    if new_token.token_id == temp_token.token_id:
                        sen[-1].head_id.append(int(h))
                        sen[-1].dep.append(d)
                    else:
                        sen.append(new_token)
                    temp_token = new_token
                else:
                    total_sentences.append(Sentence(sen))
                    sen = []
                    temp_token = Token(-1, NULL, NULL, None, None)

            if len(sen) > 0:
                total_sentences.append(Sentence(sen))

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([NULL, UNK, ROOT] + sorted(vocab)))
    with open(pos_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([NULL, ROOT] + sorted(pos)))
    with open(dep_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(sorted(dep)))

    return total_sentences
Example #16
def read_test(file):
    total_sen, total_pos, total_arc, total_dep = [], [], [], []
    sen, pos, arc, dep = ['<ROOT>'], ['<ROOT>'], [], []
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                _, w, _, p, _, _, a, d, _, _ = line.split()
                sen.append(w)
                pos.append(p)
                arc.append(int(a))
                dep.append(d)
            else:
                total_sen.append(sen)
                total_pos.append(pos)
                total_arc.append(arc)
                total_dep.append(dep)
                sen, pos, arc, dep = ['<ROOT>'], ['<ROOT>'], [], []

    return total_sen, total_pos, total_arc, total_dep
Example #17
def build_and_read_train(file):
    vocab_file = os.path.join(Config.data.processed_path,
                              Config.data.vocab_file)
    pos_file = os.path.join(Config.data.processed_path, Config.data.pos_file)
    vocab, pos = set(), set()

    all_samples = []
    sample = []

    count = 0
    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                content = line.split()
                if count == 0:
                    sample.append(int(content[0]))
                elif count in [2, 5]:
                    sample.append(pad_to_fixed_len(content))
                    vocab.update(content)
                elif count in [3, 6]:
                    sample.append(pad_to_fixed_len(content))
                    pos.update(content)

                count += 1
            else:
                count = 0
                all_samples.append(sample)
                sample = []

        if sample:
            all_samples.append(sample)

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([PAD, UNK] + sorted(vocab)))
    with open(pos_file, 'w', encoding='utf8') as f:
        f.write('\n'.join([PAD] + sorted(pos)))

    return all_samples
Example #18
def build_and_read_train(file):
    vocab_file = os.path.join(Config.data.processed_path,
                              Config.data.vocab_file)
    pos_file = os.path.join(Config.data.processed_path, Config.data.pos_file)
    dep_file = os.path.join(Config.data.processed_path, Config.data.dep_file)
    vocab, pos_tag, dep_tag = set(), set(), set()
    total_sen, total_pos, total_arc, total_dep = [], [], [], []
    sen, pos, arc, dep = ['<ROOT>'], ['<ROOT>'], [], []

    with open(file, encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line:
                line = strQ2B(line)
                _, w, _, p, _, _, a, d, _, _ = line.split()
                vocab.add(w)
                pos_tag.add(p)
                dep_tag.add(d)

                sen.append(w)
                pos.append(p)
                arc.append(int(a))
                dep.append(d)
            else:
                total_sen.append(sen)
                total_pos.append(pos)
                total_arc.append(arc)
                total_dep.append(dep)
                sen, pos, arc, dep = ['<ROOT>'], ['<ROOT>'], [], []

    with open(vocab_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(['<PAD>', '<UNK>', '<ROOT>'] + sorted(vocab)))
    with open(pos_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(['<PAD>', '<ROOT>'] + sorted(pos_tag)))
    with open(dep_file, 'w', encoding='utf8') as f:
        f.write('\n'.join(sorted(dep_tag)))

    return total_sen, total_pos, total_arc, total_dep
Example #19
        predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={
            'word_id': total_sen,
            'en_indicator': total_indicator,
            'length': length
        },
                                                              batch_size=512,
                                                              num_epochs=1,
                                                              shuffle=False)
        results = list(self.estimator.predict(input_fn=predict_input_fn))
        results = data_loader.id2rel(results, self.rel_dict)
        return results


if __name__ == '__main__':

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.logging.set_verbosity(tf.logging.ERROR)

    Config('config/bigru-adv-soft_label.yml')
    Config.train.model_dir = os.path.expanduser(Config.train.model_dir)
    Config.data.processed_path = os.path.expanduser(Config.data.processed_path)

    p = Predictor()
    while True:
        text = input('input text -> ')
        text = strQ2B(text)
        entity = input('input entity (separated by space) -> ')
        results = p.predict([[text, entity]])
        print('result ->', results[0])
Example #20
        results = []
        for i in range(len(pred)):
            arc = mst(pred[i]['arc_logits'][:length[i], :length[i]])[1:]
            label = np.argmax(pred[i]['label_logits'][range(1, length[i]), arc, :], -1)
            label = data_loader.id2dep(label, self.dep_dict)
            result = [(w, p, str(a), l) for w, p, a, l in zip(sen[i], pos[i], arc, label)]
            results.append(result)
        return results


if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.logging.set_verbosity(tf.logging.ERROR)

    Config('config/biaffine.yml')
    Config.train.model_dir = os.path.expanduser(Config.train.model_dir)
    Config.data.processed_path = os.path.expanduser(Config.data.processed_path)

    p = Predictor()
    while True:
        text = input('input words (separated by space) -> ')
        text = strQ2B(text)
        words = text.split(' ')
        pos = input('input tags (separated by space) -> ')
        pos = strQ2B(pos)
        pos = pos.split(' ')
        results = p.predict(words, pos)
        print('result ->')
        print('\n'.join(['\t'.join(n) for n in results[0]]) + '\n')
Example #21
        results = list(self.estimator.predict(input_fn=predict_input_fn))
        return results


if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.logging.set_verbosity(tf.logging.ERROR)

    Config('config/diin.yml')
    Config.train.model_dir = os.path.expanduser(Config.train.model_dir)
    Config.data.processed_path = os.path.expanduser(Config.data.processed_path)

    p = Predictor()
    word_map = {1: '相同意图', 0: '不同意图'}  # 1: "same intent", 0: "different intent"
    while True:
        origin_premise = input('input premise words (separated by space) -> ')
        premise_tags = input('input premise tags (separated by space) -> ')
        premise_words = strQ2B(origin_premise).split(' ')
        premise_tags = premise_tags.split(' ')

        origin_hypothesis = input('input hypothesis words (separated by space) -> ')
        hypothesis_tags = input('input hypothesis tags (separated by space) -> ')
        hypothesis_words = strQ2B(origin_hypothesis).split(' ')
        hypothesis_tags = hypothesis_tags.split(' ')

        results = p.predict([[premise_words, premise_tags, hypothesis_words, hypothesis_tags]])
        print('result ->')
        print('句子一:', ''.join(premise_words))
        print('句子二:', ''.join(hypothesis_words))
        print('结果:', word_map[results[0]])
Example #22
    def preprocess_train(self, dataset):
        TEXT_NORM = self.config['text_norm']
        MAX_TOKENS_LENGTH = self.config['max_tokens_length']
        MAX_SENT_NUM = self.config['max_sent_num']
        sentence_sum, sentence_length_sum, truncate_span, span_sum = 0, 0, 0, 0
        pbar = tqdm(total=len(dataset))
        lengths = []
        for _, ins in enumerate(dataset):
            if self.config['use_bert'] or self.config['use_transformer'] or self.config['use_rnn_basic_encoder']:
                UNK_ID = self.tokenizer.vocab['[UNK]']
                PAD_ID = self.tokenizer.vocab['[PAD]']
            elif self.config['use_xlnet']:
                UNK_ID = self.tokenizer.convert_tokens_to_ids('<unk>')
                PAD_ID = self.tokenizer.convert_tokens_to_ids('<pad>')

            ids_list = []
            ids_length = []
            attention_mask = []
            labels_list = []
            cut_word_labels_list = []
            pos_tag_labels_list = []
            parser_labels_list = []

            sentences = ins['sentences'][:MAX_SENT_NUM]
            ins['merge_sentences'] = sentences
            for sentence in sentences:
                ids = []
                mask = []
                labels = []
                cut_word_labels = []
                pos_tag_labels = []
                parser_labels = []

                if TEXT_NORM:
                    sentence_norm = Traditional2Simplified(strQ2B(sentence)).lower()
                    assert len(sentence_norm) == len(sentence)
                    sentence = sentence_norm
                for char in sentence:
                    ids.append(self.tokenizer.convert_tokens_to_ids(char))
                ids_length.append(len(ids))
                mask = [1 for _ in range(ids_length[-1])]
                labels = [0 for _ in range(ids_length[-1])]
                #pos_tag_labels = [0 for _ in range(ids_length[-1])]
                #cut_word_labels = [0 for _ in range(ids_length[-1])]

                pad_num = MAX_TOKENS_LENGTH - ids_length[-1]
                ids.extend([PAD_ID for _ in range(pad_num)])
                mask.extend([0 for _ in range(pad_num)])
                labels.extend([-1 for _ in range(pad_num)])
                #pos_tag_labels.extend([-1 for _ in range(pad_num)])
                #cut_word_labels.extend([-1 for _ in range(pad_num)])

                words = None
                if self.config['cut_word_task']:
                    words = list(self.segmentor.segment(sentence))
                    #words = list(self.segmentor.segment(re.sub('\s', '#', sentence)))
                    for word in words:
                        cut_word_labels.append(1)
                        for _ in word[1:]:
                            cut_word_labels.append(0)
                    assert len(cut_word_labels) == ids_length[-1]
                    cut_word_labels.extend([-1 for _ in range(pad_num)])
                
                postags = None
                if self.config['pos_tag_task']:
                    if words is None:
                        words = list(self.segmentor.segment(sentence))
                    postags = list(self.postagger.postag(words))
                    assert len(postags) == len(words)
                    for idx, word in enumerate(words):
                        if postags[idx].startswith('n'):
                            postags[idx] = 'n'

                        postag_id = POS_TAG2ID.get(postags[idx])
                        if postag_id is None:
                            pos_tag_labels.append(5)
                        else:
                            pos_tag_labels.append(postag_id)

                        for _ in word[1:]:
                            pos_tag_labels.append(0)

                    assert len(pos_tag_labels) == ids_length[-1]
                    pos_tag_labels.extend([-1 for _ in range(pad_num)])
                
                if self.config['parser_task']:
                    if words is None:
                        words = list(self.segmentor.segment(sentence))
                    if postags is None:
                        postags = list(self.postagger.postag(words))
                    arcs = list(self.parser.parse(words, postags))
                    for idx, word in enumerate(words):
                        arc_head = len(''.join(words[:arcs[idx].head - 1]))
                        if arcs[idx].head == 0 or arc_head >= MAX_TOKENS_LENGTH:
                            parser_labels.extend([-1 for _ in word])
                            continue
                        parser_labels.append(arc_head)
                        for _ in word[1:]:
                            parser_labels.append(-1)
                    assert len(parser_labels) == ids_length[-1]
                    parser_labels.extend([-1 for _ in range(pad_num)])

                ids_list.append(ids)
                attention_mask.append(mask)
                labels_list.append(labels)
                cut_word_labels_list.append(cut_word_labels)
                pos_tag_labels_list.append(pos_tag_labels)
                parser_labels_list.append(parser_labels)
                assert len(ids) == len(mask) == len(labels_list[-1])

            for idx, span in enumerate(ins['ann_valid_mspans']):
                dranges = ins['ann_mspan2dranges'].get(span)
                label = ins['ann_mspan2guess_field'].get(span)
                assert label is not None and dranges is not None
                if label == 'OtherType':
                    continue
                for sent_idx, beg, end in dranges:
                    if sent_idx >= MAX_SENT_NUM:
                        continue
                    labels_list[sent_idx][beg] = NER_LABEL2ID['B-' + label]
                    for k in range(beg + 1, end):
                        labels_list[sent_idx][k] = NER_LABEL2ID['I-' + label]
            events = []
            event_cls = [0 for _ in EVENT_TYPES]
            for _, event_type, event in ins['recguid_eventname_eventdict_list']:
                event['event_type'] = event_type
                events.append(event)
                event_cls[EVENT_TYPE2ID.get(event_type)] = 1
            ins['events'] = events

            assert len(ids_list) == len(labels_list) == len(attention_mask)
            for idx in range(len(ids_list)):
                ids_list[idx] = ids_list[idx][:MAX_TOKENS_LENGTH]
                labels_list[idx] = labels_list[idx][:MAX_TOKENS_LENGTH]
                attention_mask[idx] = attention_mask[idx][:MAX_TOKENS_LENGTH]
                ids_length[idx] = MAX_TOKENS_LENGTH if ids_length[idx] > MAX_TOKENS_LENGTH else ids_length[idx]
                #assert len(ids_list[idx]) == len(mask) == len(labels_list[-1])

                cut_word_labels_list[idx] = cut_word_labels_list[idx][:MAX_TOKENS_LENGTH]
                pos_tag_labels_list[idx] = pos_tag_labels_list[idx][:MAX_TOKENS_LENGTH]
                parser_labels_list[idx] = parser_labels_list[idx][:MAX_TOKENS_LENGTH]

            ins['ids_list'] = ids_list
            ins['labels_list'] = labels_list
            ins['attention_mask'] = attention_mask
            ins['event_cls'] = event_cls
            ins['ids_length'] = ids_length
            lengths.extend(ids_length)

            ins['cw_labels_list'] = cut_word_labels_list
            ins['pos_tag_labels_list'] = pos_tag_labels_list
            ins['parser_labels_list'] = parser_labels_list
            pbar.update()
        pbar.close()
        random.shuffle(dataset)
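Note: examples #22–#24 also pass the text through Traditional2Simplified() before the strQ2B normalization. That converter is not defined on this page; one possible stand-in built on the opencc package is sketched below (an assumption — the projects may ship their own mapping-table converter).

from opencc import OpenCC  # pip install opencc-python-reimplemented

_t2s = OpenCC('t2s')  # Traditional-to-Simplified conversion profile

def Traditional2Simplified(text):
    """Map Traditional Chinese characters to Simplified ones.

    Assumed to be character-level and length-preserving, which the
    len(...) asserts in the examples above depend on.
    """
    return _t2s.convert(text)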
Example #23
File: task.py Project: tujie-jiangye/dee
    def preprocess_train(self):
        TEXT_NORM = self.config['text_norm']
        MAX_TOKENS_LENGTH = self.config['max_tokens_length']
        MAX_SENT_NUM = self.config['max_sent_num']
        sentence_sum, sentence_length_sum, truncate_span, span_sum = 0, 0, 0, 0
        pbar = tqdm(total=len(self.train))
        for i, ins in enumerate(self.train):
            events = ins['events']
            content = ins['content']
            #content = re.sub('\s', '/', ins['content'])

            if TEXT_NORM:
                content_norm = Traditional2Simplified(strQ2B(content)).lower()
                assert len(content) == len(content_norm)
                ins['content_norm'] = content_norm

            if self.config['use_bert']:
                UNK_ID = self.tokenizer.vocab['[UNK]']
                PAD_ID = self.tokenizer.vocab['[PAD]']
            elif self.config['use_xlnet']:
                UNK_ID = self.tokenizer.convert_tokens_to_ids('<unk>')
                PAD_ID = self.tokenizer.convert_tokens_to_ids('<pad>')

            ids = []
            for i, char in enumerate(content):
                if TEXT_NORM:
                    char = content_norm[i]
                ids.append(self.tokenizer.convert_tokens_to_ids(char))

                # if char in self.tokenizer.vocab:
                #     ids.append(self.tokenizer.vocab[char])
                # else:
                #     ids.append(UNK_ID)

            labels = [0 for _ in ids]
            event_cls = [0 for _ in EVENT_TYPES]
            for event in events:
                event_cls[EVENT_TYPES.index(event['event_type'])] = 1
                label_drange_dict = {}
                for event_role, span in event.items():
                    if event_role == 'event_type' or event_role == 'event_id' or not span:
                        continue
                    label_drange_dict[event_role] = []
                    span_sum += 1
                    find_idx = -0.5
                    while find_idx != -1:
                        find_idx = content.find(span, int(find_idx + 1))
                        if find_idx != -1:
                            assert content[find_idx:find_idx +
                                           len(span)] == span
                            label_drange_dict[event_role].append(
                                (find_idx, find_idx + len(span)))
                            # labels[find_idx] = NER_LABEL2ID['B-' + event_role]
                            # for k in range(1, len(span)):
                            #     labels[find_idx + k] = NER_LABEL2ID['I-' + event_role]

                if self.config['ner_label_count_limit'] is None:
                    for event_role, dranges in label_drange_dict.items():
                        for drange in dranges:
                            labels[drange[0]] = NER_LABEL2ID['B-' + event_role]
                            for k in range(drange[0] + 1, drange[1]):
                                labels[k] = NER_LABEL2ID['I-' + event_role]
                else:
                    # filter NER labels by event position, which is computed from the role positions
                    count_limit = self.config['ner_label_count_limit']
                    event_pos = []
                    for event_role, dranges in label_drange_dict.items():
                        role_pos = list(map(lambda drange: drange[0], dranges))
                        role_pos = sum(role_pos) / len(role_pos)
                        event_pos.append(role_pos)
                    event_pos = sum(event_pos) / len(event_pos)

                    for event_role, dranges in label_drange_dict.items():
                        rela_pos = list(
                            map(
                                lambda drange:
                                (abs(drange[0] - event_pos), drange), dranges))
                        rela_pos = sorted(rela_pos, key=lambda x: x[0])
                        for drange_count, (pos_diff,
                                           drange) in enumerate(rela_pos):
                            if drange_count < count_limit or pos_diff < self.config[
                                    'ner_label_sentence_length']:
                                assert content[drange[0]:drange[1]] == event[
                                    event_role]
                                labels[drange[0]] = NER_LABEL2ID['B-' +
                                                                 event_role]
                                for k in range(drange[0] + 1, drange[1]):
                                    labels[k] = NER_LABEL2ID['I-' + event_role]
            assert len(ids) == len(content) == len(labels)

            if self.config['cut_word_task']:
                cw_labels = []
                cw_list = list(self.segmentor.segment(content_norm))
                # temp = ''.join(cw_list)
                # for idx, char in enumerate(temp):
                #     if content_norm[idx] != char:
                #         assert 1 == 2
                for word in cw_list:
                    cw_labels.append(1)
                    for _ in word[1:]:
                        cw_labels.append(0)
                assert len(content) == len(cw_labels)

            sentences = []
            raw_sentences = list(
                filter(lambda x: bool(x), re.split('([^。;]+[。;])', content)))
            curr_pos = 0
            sentence_sum += len(raw_sentences)
            for sentence in raw_sentences:
                sentence_length_sum += len(sentence)
                # print(len(sentence))
                if len(sentence) < MAX_TOKENS_LENGTH:
                    sentences.append(sentence)
                    curr_pos += len(sentence)
                else:
                    while len(sentence) > 0:
                        sentences.append(sentence[:MAX_TOKENS_LENGTH])
                        curr_pos += len(sentences[-1])
                        if curr_pos < len(labels) and labels[
                                curr_pos] != 0 and labels[curr_pos - 1] != 0:
                            truncate_span += 1
                            #print(truncate_span / span_sum)
                        sentence = sentence[MAX_TOKENS_LENGTH:]

            merge_sentences = []
            curr_sentence = ''
            for sentence in sentences:
                if len(sentence) + len(curr_sentence) <= MAX_TOKENS_LENGTH:
                    curr_sentence += sentence
                else:
                    merge_sentences.append(curr_sentence)
                    curr_sentence = sentence
            if curr_sentence:
                merge_sentences.append(curr_sentence)

            curr_pos = 0
            ids_list = []
            labels_list = []
            attention_mask = []
            ids_length = []
            cw_labels_list = []
            if len(merge_sentences) > 3:
                filted_merge_sentences = []
                for sentence in merge_sentences:
                    if sum(labels[curr_pos:curr_pos + len(sentence)]) == 0:
                        curr_pos += len(sentence)
                    else:
                        filted_merge_sentences.append(sentence)
                        ids_list.append(ids[curr_pos:curr_pos + len(sentence)])
                        labels_list.append(labels[curr_pos:curr_pos +
                                                  len(sentence)])
                        attention_mask.append(
                            [1 for _ in range(len(sentence))])
                        ids_length.append(len(ids_list[-1]))
                        if self.config['cut_word_task']:
                            cw_labels_list.append(cw_labels[curr_pos:curr_pos +
                                                            len(sentence)])

                        if len(ids_list[-1]) < MAX_TOKENS_LENGTH:
                            pad_num = MAX_TOKENS_LENGTH - len(ids_list[-1])
                            ids_list[-1].extend(
                                [PAD_ID for _ in range(pad_num)])
                            labels_list[-1].extend(
                                [-1 for _ in range(pad_num)])
                            attention_mask[-1].extend(
                                [0 for _ in range(pad_num)])
                            if self.config['cut_word_task']:
                                cw_labels_list[-1].extend(
                                    [-1 for _ in range(pad_num)])

                        curr_pos += len(sentence)
                        assert ids_length[-1] == sum(attention_mask[-1])

                assert len(filted_merge_sentences) == len(ids_list) == len(
                    labels_list) == len(attention_mask)
                if len(filted_merge_sentences
                       ) > MAX_SENT_NUM:  # truncate sentence
                    filted_merge_sentences = filted_merge_sentences[:
                                                                    MAX_SENT_NUM]
                    ids_list = ids_list[:MAX_SENT_NUM]
                    labels_list = labels_list[:MAX_SENT_NUM]
                    attention_mask = attention_mask[:MAX_SENT_NUM]
                    ids_length = ids_length[:MAX_SENT_NUM]
                    if self.config['cut_word_task']:
                        cw_labels_list = cw_labels_list[:MAX_SENT_NUM]
                ins['merge_sentences'] = filted_merge_sentences
                assert len(filted_merge_sentences) == len(ids_list) == len(
                    labels_list) == len(attention_mask)
            else:
                for sentence in merge_sentences:
                    ids_list.append(ids[curr_pos:curr_pos + len(sentence)])
                    labels_list.append(labels[curr_pos:curr_pos +
                                              len(sentence)])
                    attention_mask.append([1 for _ in range(len(sentence))])
                    ids_length.append(len(ids_list[-1]))
                    if self.config['cut_word_task']:
                        cw_labels_list.append(cw_labels[curr_pos:curr_pos +
                                                        len(sentence)])

                    if len(ids_list[-1]) < MAX_TOKENS_LENGTH:
                        pad_num = MAX_TOKENS_LENGTH - len(ids_list[-1])
                        ids_list[-1].extend([PAD_ID for _ in range(pad_num)])
                        labels_list[-1].extend([-1 for _ in range(pad_num)])
                        attention_mask[-1].extend([0 for _ in range(pad_num)])
                        if self.config['cut_word_task']:
                            cw_labels_list[-1].extend(
                                [-1 for _ in range(pad_num)])

                    curr_pos += len(sentence)
                    assert ids_length[-1] == sum(attention_mask[-1])
                ins['merge_sentences'] = merge_sentences
                assert len(merge_sentences) == len(ids_list) == len(
                    labels_list) == len(attention_mask)
            ins['ids_list'] = ids_list
            ins['labels_list'] = labels_list
            ins['attention_mask'] = attention_mask
            ins['ids'] = ids
            ins['labels'] = labels
            ins['event_cls'] = event_cls
            ins['ids_length'] = ids_length
            ins['cw_labels_list'] = cw_labels_list
            assert ''.join(merge_sentences) == content
            pbar.update()
        #doc_id: 2047486 for test
        #chr(1627)
        random.shuffle(self.train)
Example #24
File: task.py Project: tujie-jiangye/dee
    def preprocess_test(self):
        TEXT_NORM = self.config['text_norm']
        MAX_TOKENS_LENGTH = self.config['max_tokens_length']
        MAX_SENT_NUM = self.config['max_sent_num']
        pbar = tqdm(total=len(self.test))
        for i, ins in enumerate(self.test):
            content = ins['content']

            if self.config['use_bert']:
                UNK_ID = self.tokenizer.vocab['[UNK]']
                PAD_ID = self.tokenizer.vocab['[PAD]']
            elif self.config['use_xlnet']:
                UNK_ID = self.tokenizer.convert_tokens_to_ids('<unk>')
                PAD_ID = self.tokenizer.convert_tokens_to_ids('<pad>')

            if TEXT_NORM:
                content_norm = Traditional2Simplified(strQ2B(content)).lower()
                assert len(content) == len(content_norm)
                ins['content_norm'] = content_norm

            ids = []
            for i, char in enumerate(content):
                if TEXT_NORM:
                    char = content_norm[i]
                ids.append(self.tokenizer.convert_tokens_to_ids(char))
                # if char in self.tokenizer.vocab:
                #     ids.append(self.tokenizer.vocab[char])
                # else:
                #     ids.append(UNK_ID)

            sentences_ids = []
            sentences = []
            raw_sentences = list(
                filter(lambda x: bool(x), re.split('([^。;]+[。;])', content)))
            curr_pos = 0
            for sentence in raw_sentences:
                if len(sentence) < MAX_TOKENS_LENGTH:
                    sentences.append(sentence)
                    curr_pos += len(sentence)
                else:
                    while len(sentence) > 0:
                        sentences.append(sentence[:MAX_TOKENS_LENGTH])
                        curr_pos += len(sentences[-1])
                        sentence = sentence[MAX_TOKENS_LENGTH:]

            merge_sentences = []
            curr_sentence = ''
            for sentence in sentences:
                if len(sentence) + len(curr_sentence) <= MAX_TOKENS_LENGTH:
                    curr_sentence += sentence
                else:
                    merge_sentences.append(curr_sentence)
                    curr_sentence = sentence
            if curr_sentence:
                merge_sentences.append(curr_sentence)

            curr_pos = 0
            ids_list = []
            attention_mask = []
            ids_length = []
            filted_merge_sentences = []
            for sentence in merge_sentences:
                filted_merge_sentences.append(sentence)
                ids_list.append(ids[curr_pos:curr_pos + len(sentence)])
                attention_mask.append([1 for _ in range(len(sentence))])
                ids_length.append(len(ids_list[-1]))
                if len(ids_list[-1]) < MAX_TOKENS_LENGTH:
                    pad_num = MAX_TOKENS_LENGTH - len(ids_list[-1])
                    ids_list[-1].extend([PAD_ID for _ in range(pad_num)])
                    attention_mask[-1].extend([0 for _ in range(pad_num)])
                curr_pos += len(sentence)
                assert ids_length[-1] == sum(attention_mask[-1])
            assert len(filted_merge_sentences) == len(ids_list) == len(
                attention_mask)

            if len(filted_merge_sentences) > MAX_SENT_NUM:  # truncate sentence
                for truncate_beg in [3, 2, 1, 0]:
                    if truncate_beg + MAX_SENT_NUM <= len(
                            filted_merge_sentences):
                        break
                filted_merge_sentences = filted_merge_sentences[
                    truncate_beg:truncate_beg + MAX_SENT_NUM]
                ids_list = ids_list[truncate_beg:truncate_beg + MAX_SENT_NUM]
                attention_mask = attention_mask[truncate_beg:truncate_beg +
                                                MAX_SENT_NUM]
                ids_length = ids_length[truncate_beg:truncate_beg +
                                        MAX_SENT_NUM]
            ins['merge_sentences'] = filted_merge_sentences
            assert len(filted_merge_sentences) == len(ids_list) == len(
                attention_mask)

            ins['ids_list'] = ids_list
            ins['attention_mask'] = attention_mask
            ins['ids'] = ids
            ins['ids_length'] = ids_length
            assert ''.join(merge_sentences) == content
            pbar.update()
        pickle.dump(self.test, open(self.config['test_doc_file'], mode='wb'))
Example #25
                                                              shuffle=False)
        labels = list(self.estimator.predict(input_fn=predict_input_fn))
        return [
            data_loader.id2label(labels[i][:length[i]], self.label_dict)
            for i in range(len(inputs))
        ]


if __name__ == '__main__':

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.logging.set_verbosity(tf.logging.ERROR)

    Config('config/bilstm-highway.yml')
    Config.train.model_dir = os.path.expanduser(Config.train.model_dir)
    Config.data.processed_path = os.path.expanduser(Config.data.processed_path)

    p = Predictor()
    while True:
        text = input('input words (separated by space) -> ')
        text = strQ2B(text)
        words = text.split(' ')
        tags = input('input tags (separated by space) -> ')
        tags = strQ2B(tags)
        tags = tags.split(' ')
        predicate = input('input predicate -> ')
        results = p.predict([[words, tags, predicate]])
        print('result ->')
        for i in range(len(words)):
            print("{:\u3000<5} {:<5}".format(words[i], results[0][i]))