import json

from tqdm import tqdm
from pytrie import SortedStringTrie as Trie

# JSON_OUTPUT_DIR, domain2entities, entities2name and decode are module-level
# names assumed to be defined elsewhere in this module.


def context_entities(setname, domain_set):
    for domain in domain_set:
        with open(
                JSON_OUTPUT_DIR +
                '/mentions/{}/{}.json'.format(setname, domain), 'r') as f:
            # add entities to this trie
            trie = Trie()
            for ent_id in domain2entities[domain]:
                # single-name case (multi-category aliases not considered)
                trie.setdefault(entities2name[ent_id].lower(), 0)
                # to handle the multi-category case, index aliases instead:
                # trie.setdefault(entities2alias[ent_id], 0)
            total_alias_count = 0
            total_doc_count = 0
            matched_text_list = []
            for line in tqdm(f, desc='processing {}...'.format(domain)):
                datum = json.loads(line.strip())
                text = decode(datum['mention_context_tokens'])
                total_doc_count += 1
                i = 0
                matched_text = ''
                # greedy longest-prefix scan over the text; whole-word alias
                # matches are wrapped in '##' markers
                while i < len(text):
                    item = trie.longest_prefix_item(text[i:], default=None)
                    if item is not None:
                        prefix, key_id = item
                        if (i == 0 or text[i-1] == ' ') \
                                and (i+len(prefix) == len(text) or text[i+len(prefix)] == ' '):
                            total_alias_count += 1
                            i += len(prefix)
                            matched_text += '##' + prefix + '##'
                        else:
                            matched_text += prefix
                            i += len(prefix)
                    else:
                        matched_text += text[i]
                        i += 1
                matched_text_list.append(matched_text)
            print('Avg alias count {:.2f} in {} documents.'.format(
                total_alias_count / total_doc_count, domain))
        with open(
                JSON_OUTPUT_DIR +
                '/mentions/{}/{}_matched.json'.format(setname, domain),
                'w') as f:
            for text in matched_text_list:
                f.write(text + '\n\n')
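
The whole-word, longest-prefix matching done inline above can also be expressed as a small standalone helper. The sketch below is illustrative only (mark_aliases and the sample sentence are not part of the original module); it assumes pytrie's SortedStringTrie, the same trie class imported in Example #2.

from pytrie import SortedStringTrie as Trie


def mark_aliases(text, names):
    """Wrap whole-word, longest-prefix matches of `names` in '##' markers."""
    trie = Trie()
    for name in names:
        trie.setdefault(name.lower(), 0)
    out, i = '', 0
    while i < len(text):
        item = trie.longest_prefix_item(text[i:], default=None)
        if item is None:
            out += text[i]
            i += 1
            continue
        prefix, _ = item
        # count the match only when it is delimited by spaces or string edges
        at_boundary = ((i == 0 or text[i - 1] == ' ')
                       and (i + len(prefix) == len(text)
                            or text[i + len(prefix)] == ' '))
        out += '##' + prefix + '##' if at_boundary else prefix
        i += len(prefix)
    return out

# mark_aliases('the beatles played in liverpool', ['The Beatles', 'Liverpool'])
# returns '##the beatles## played in ##liverpool##'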
Example #2
def create_training_instances(input_file, max_seq_length, tokenizer, rng,
                              alias2entities):
    """Create `TrainingInstance`s from raw text."""

    def is_whitespace(c):
        # treat standard whitespace and the narrow no-break space (U+202F)
        # as token separators
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    all_documents = []
    all_alias_token_spans = []
    from pytrie import SortedStringTrie as Trie
    trie = Trie()
    # add entities to this trie
    for alias, ents in alias2entities.items():
        trie.setdefault(alias, 0)

    with open(input_file, "r") as reader:
        for line in tqdm(reader, desc='converting tokens'):
            line = tokenization.convert_to_unicode(line.strip())
            line = json.loads(line)['text']

            tokens = []
            if do_lower_case:  # module-level flag assumed to be defined elsewhere
                line = line.lower()
            # coarse (whitespace-delimited) tokenization; char_to_word_offset
            # maps each character position to the index of its coarse token
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in line:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        tokens.append(c)
                    else:
                        tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(tokens) - 1)

            # match aliases occurring in this document
            alias_spans = match_alias(line, trie, alias2entities)
            # the resulting spans refer to coarse-grained tokens; the span end
            # is the index of the alias's last coarse token
            alias_token_spans = [(char_to_word_offset[span[0]],
                                  char_to_word_offset[span[1] - 1])
                                 for span in alias_spans]

            # sanity check: each token span must reconstruct a known alias
            for span, token_span in zip(alias_spans, alias_token_spans):
                alias_tokens = ' '.join(tokens[token_span[0]:token_span[1] + 1])
                alias_texts = line[span[0]:span[1]]
                assert alias_tokens in alias2entities, '{} {} {} {}'.format(
                    alias_tokens, token_span, alias_texts, span)
            # assert all(' '.join(tokens[span[0]: span[1] + 1]) in alias2entities for span in alias_token_spans), \
            #     print([' '.join(tokens[span[0]: span[1] + 1]) for span in alias_token_spans])

            tok_to_orig_index = []  # fine-grained (subword) -> coarse-grained index
            orig_to_tok_index = []  # coarse-grained -> fine-grained (subword) index
            real_tokens = []
            for (i, token) in enumerate(tokens):
                orig_to_tok_index.append(len(real_tokens))
                sub_tokens = tokenizer.tokenize(token)
                for sub_token in sub_tokens:
                    tok_to_orig_index.append(i)
                    real_tokens.append(sub_token)
            # If the span's last coarse token is the final token of the document,
            # end the span at the last fine-grained token; otherwise end it just
            # before the first fine-grained token of the next coarse token.
            real_alias_token_spans = []
            for span in alias_token_spans:
                if span[1] == len(tokens) - 1:
                    real_end = orig_to_tok_index[-1]
                else:
                    real_end = orig_to_tok_index[span[1] + 1] - 1
                real_start = orig_to_tok_index[span[0]]
                real_alias_token_spans.append((real_start, real_end))

            # alias_token_spans = [(orig_to_tok_index[span[0]], orig_to_tok_index[span[1]])
            #                      for span in alias_token_spans]

            all_documents.append(real_tokens)
            all_alias_token_spans.append(real_alias_token_spans)

    vocab_words = list(tokenizer.vocab.keys())
    instances = []
    for document_index in tqdm(range(len(all_documents)),
                               total=len(all_documents),
                               desc='creating instances'):
        instances.extend(
            create_instances_from_document(all_documents, document_index,
                                           all_alias_token_spans,
                                           max_seq_length, vocab_words, rng))

    rng.shuffle(instances)
    return instances
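
match_alias is called above but not defined in this excerpt. The sketch below is one plausible implementation, consistent with how its result is used (character-level (start, end) spans with an exclusive end, whole-word matches only, and matched text that is always a key of alias2entities); the original implementation may differ.

def match_alias(line, trie, alias2entities):
    """Hypothetical sketch: return character-level (start, end) spans, end
    exclusive, of whole-word alias matches found via longest-prefix lookups."""
    # alias2entities is unused here because `trie` already indexes its keys;
    # it is kept only to mirror the call site above.
    spans = []
    i = 0
    while i < len(line):
        item = trie.longest_prefix_item(line[i:], default=None)
        if item is None:
            i += 1
            continue
        alias = item[0]
        end = i + len(alias)
        # keep the match only when it is delimited by spaces or string edges
        if (i == 0 or line[i - 1] == ' ') and (end == len(line) or line[end] == ' '):
            spans.append((i, end))
        i = end
    return spans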