Example #1
    def build_from_wikipedia(
        dump_db: DumpDB,
        tokenizer,
        normalizer,
        out_file,
        max_candidate_size,
        min_mention_count,
        max_mention_length,
        pool_size,
        chunk_size,
    ):
        logger.info("Extracting all entity names...")

        title_dict = defaultdict(Counter)
        with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
            initargs = (dump_db, tokenizer, normalizer, max_mention_length)
            with closing(
                    Pool(pool_size,
                         initializer=EntityDB._initialize_worker,
                         initargs=initargs)) as pool:
                for ret in pool.imap_unordered(
                        EntityDB._extract_name_entity_pairs,
                        dump_db.titles(),
                        chunksize=chunk_size):
                    for (name, title) in ret:
                        title_dict[title][name] += 1
                    pbar.update()

        logger.info("Building DB...")

        mentions = frozenset([
            mention for mention_counter in title_dict.values()
            for mention in mention_counter.keys()
        ])
        title_trie = frozenset(title_dict.keys())
        mention_trie = marisa_trie.Trie(mentions)

        def item_generator():
            for title, mention_counter in title_dict.items():
                candidates = mention_counter.most_common()[:max_candidate_size]
                for mention, mention_count in candidates:
                    if mention_count < min_mention_count:
                        continue
                    yield (title, (mention_trie[mention], mention_count))

        data_trie = marisa_trie.RecordTrie("<II", item_generator())

        joblib.dump(
            dict(
                title_trie=title_trie,
                mention_trie=mention_trie,
                data_trie=data_trie,
                tokenizer=tokenizer,
                normalizer=normalizer,
                max_mention_length=max_mention_length,
            ),
            out_file,
        )
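
Note: a minimal sketch (with a hypothetical output path and example title) of querying the dumped structure above: data_trie maps a title to (mention_id, count) records, and mention_trie.restore_key turns the stored id back into the mention string.

import joblib

db = joblib.load("mention_db.joblib")  # hypothetical out_file from build_from_wikipedia
mention_trie = db["mention_trie"]
data_trie = db["data_trie"]

title = "Python (programming language)"  # hypothetical title key
if title in data_trie:
    for mention_id, count in data_trie[title]:
        # restore_key recovers the mention text from the id stored in the record
        print(mention_trie.restore_key(mention_id), count)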
Example #2
def setUp():
    global dump_db, dump_db_file

    dump_file = pkg_resources.resource_filename(
        __name__, 'test_data/enwiki-pages-articles-sample.xml.bz2')
    dump_reader = WikiDumpReader(dump_file)
    dump_db_file = NamedTemporaryFile()

    DumpDB.build(dump_reader, dump_db_file.name, 1, 1)
    dump_db = DumpDB(dump_db_file.name)
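
Note: a matching module-level tearDown is the natural counterpart (a sketch; the actual test module may define this differently), closing the temporary file so the database built in setUp is cleaned up.

def tearDown():
    global dump_db_file
    # Closing the NamedTemporaryFile removes the underlying DB file.
    dump_db_file.close()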
Example #3
    def build(
            dump_db: DumpDB,
            out_file: str,
            vocab_size: int,
            white_list: List[str],
            white_list_only: bool,
            include_all: bool,
            pool_size: int,
            chunk_size: int,
            language: str,
    ):
        counter = Counter()
        with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
            with closing(
                    Pool(pool_size,
                         initializer=EntityVocabCustom._initialize_worker,
                         initargs=(dump_db,))) as pool:
                for ret in pool.imap_unordered(EntityVocabCustom._count_entities,
                                               dump_db.titles(),
                                               chunksize=chunk_size):
                    counter.update(ret)
                    pbar.update()

        title_dict = OrderedDict()
        title_dict[PAD_TOKEN] = 0
        title_dict[UNK_TOKEN] = 0
        title_dict[MASK_TOKEN] = 0

        for title_link in white_list:
            if counter[title_link] != 0:
                title_dict[title_link] = counter[title_link]

        if not white_list_only:
            valid_titles = frozenset(dump_db.titles())
            for title_link, count in counter.most_common():
                if title_link in valid_titles and not title_link.startswith("Category:"):
                    title_dict[title_link] = count
                    if not include_all and len(title_dict) == vocab_size:
                        break

        with open(out_file, "w") as f:
            for ent_id, (title_link, count) in enumerate(title_dict.items()):
                info_box = {}
                if '#' not in title_link and ':' not in title_link and count != 0:
                    info_box = dump_db.get_infobox(title_link)

                json.dump({"id": ent_id, "entities": [[title_link, language]],
                           "count": count, "info_box": info_box}, f, default=str)
                f.write("\n")
Example #4
    def build(
        dump_db: DumpDB,
        out_file: str,
        vocab_size: int,
        white_list: List[str],
        white_list_only: bool,
        pool_size: int,
        chunk_size: int,
        language: str,
    ):
        counter = Counter()
        with tqdm(total=dump_db.page_size(), mininterval=0.5) as pbar:
            with closing(
                    Pool(pool_size,
                         initializer=EntityVocab._initialize_worker,
                         initargs=(dump_db, ))) as pool:
                for ret in pool.imap_unordered(EntityVocab._count_entities,
                                               dump_db.titles(),
                                               chunksize=chunk_size):
                    counter.update(ret)
                    pbar.update()

        title_dict = OrderedDict()
        title_dict[PAD_TOKEN] = 0
        title_dict[UNK_TOKEN] = 0
        title_dict[MASK_TOKEN] = 0

        for title in white_list:
            if counter[title] != 0:
                title_dict[title] = counter[title]

        if not white_list_only:
            valid_titles = frozenset(dump_db.titles())
            for title, count in counter.most_common():
                if title in valid_titles and not title.startswith("Category:"):
                    title_dict[title] = count
                    if len(title_dict) == vocab_size:
                        break

        with open(out_file, "w") as f:
            for ent_id, (title, count) in enumerate(title_dict.items()):
                json.dump(
                    {
                        "id": ent_id,
                        "entities": [[title, language]],
                        "count": count
                    }, f)
                f.write("\n")
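
Note: the vocabulary is written as one JSON object per line; a minimal sketch (hypothetical path) of reading it back into a title-to-id mapping, based on the format shown above.

import json

title2id = {}
with open("entity_vocab.jsonl") as f:  # hypothetical out_file written by build()
    for line in f:
        record = json.loads(line)
        for title, _language in record["entities"]:
            title2id[title] = record["id"]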
Example #5
def build_entity_vocab(dump_db_file: str, white_list: List[TextIO], **kwargs):
    dump_db = DumpDB(dump_db_file)
    white_list = [line.rstrip() for f in white_list for line in f]
    EntityVocab.build(dump_db,
                      white_list=white_list,
                      language=dump_db.language,
                      **kwargs)
Example #6
def build_from_p_e_m_file(p_e_m_file, dump_db_file, wiki_mention_db_file,
                          **kwargs):
    dump_db = DumpDB(dump_db_file)
    tokenizer = BasicTokenizer(do_lower_case=False)
    normalizer = BertLowercaseNormalizer()
    wiki_mention_db = MentionDB(wiki_mention_db_file)
    MentionDB.build_from_p_e_m_file(p_e_m_file, dump_db, wiki_mention_db,
                                    tokenizer, normalizer, **kwargs)
Example #7
File: main.py Project: yifding/luke
def create_candidate_list(dump_db_file, out_file, data_dir):
    dump_db = DumpDB(dump_db_file)

    titles = set()
    valid_titles = frozenset(dump_db.titles())

    reader = EntityDisambiguationDataset(data_dir)
    for documents in reader.get_all_datasets():
        for document in documents:
            for mention in document.mentions:
                candidates = mention.candidates
                for candidate in candidates:
                    title = dump_db.resolve_redirect(candidate.title)
                    if title in valid_titles:
                        titles.add(title)

    for title in titles:
        out_file.write(title + "\n")
Example #8
def build_wikipedia_pretraining_dataset(dump_db_file: str, tokenizer_name: str,
                                        entity_vocab_file: str,
                                        output_dir: str,
                                        sentence_tokenizer: str, **kwargs):
    dump_db = DumpDB(dump_db_file)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    sentence_tokenizer = SentenceTokenizer.from_name(sentence_tokenizer)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    entity_vocab = EntityVocab(entity_vocab_file)
    WikipediaPretrainingDataset.build(dump_db, tokenizer, sentence_tokenizer,
                                      entity_vocab, output_dir, **kwargs)
Example #9
    def build(
        cls,
        dump_db: DumpDB,
        tokenizer: PreTrainedTokenizer,
        sentence_tokenizer: SentenceTokenizer,
        entity_vocab: EntityVocab,
        output_dir: str,
        max_seq_length: int,
        max_entity_length: int,
        max_mention_length: int,
        min_sentence_length: int,
        include_sentences_without_entities: bool,
        include_unk_entities: bool,
        pool_size: int,
        chunk_size: int,
        max_num_documents: int,
    ):

        target_titles = [
            title for title in dump_db.titles()
            if not (":" in title and title.lower().split(":")[0] in
                    ("image", "file", "category"))
        ]
        random.shuffle(target_titles)

        if max_num_documents is not None:
            target_titles = target_titles[:max_num_documents]

        max_num_tokens = max_seq_length - 2  # 2 for [CLS] and [SEP]

        tokenizer.save_pretrained(output_dir)

        entity_vocab.save(os.path.join(output_dir, ENTITY_VOCAB_FILE))
        number_of_items = 0
        tf_file = os.path.join(output_dir, DATASET_FILE)
        options = tf.io.TFRecordOptions(
            tf.compat.v1.io.TFRecordCompressionType.GZIP)
        with TFRecordWriter(tf_file, options=options) as writer:
            with tqdm(total=len(target_titles)) as pbar:
                initargs = (
                    dump_db,
                    tokenizer,
                    sentence_tokenizer,
                    entity_vocab,
                    max_num_tokens,
                    max_entity_length,
                    max_mention_length,
                    min_sentence_length,
                    include_sentences_without_entities,
                    include_unk_entities,
                )
                with closing(
                        Pool(pool_size,
                             initializer=WikipediaPretrainingDataset._initialize_worker,
                             initargs=initargs)) as pool:
                    for ret in pool.imap(
                            WikipediaPretrainingDataset._process_page,
                            target_titles,
                            chunksize=chunk_size):
                        for data in ret:
                            writer.write(data)
                            number_of_items += 1
                        pbar.update()

        with open(os.path.join(output_dir, METADATA_FILE),
                  "w") as metadata_file:
            json.dump(
                dict(
                    number_of_items=number_of_items,
                    max_seq_length=max_seq_length,
                    max_entity_length=max_entity_length,
                    max_mention_length=max_mention_length,
                    min_sentence_length=min_sentence_length,
                    tokenizer_class=tokenizer.__class__.__name__,
                    language=dump_db.language,
                ),
                metadata_file,
                indent=2,
            )
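
Note: a brief sketch of inspecting the artifacts written above. The metadata keys come from the json.dump call; the constants METADATA_FILE and DATASET_FILE are not shown here, so "metadata.json" and "dataset.tf" are assumptions, as is the output directory.

import json
import os
import tensorflow as tf

output_dir = "pretraining_data"  # hypothetical output_dir passed to build()

with open(os.path.join(output_dir, "metadata.json")) as f:  # assumes METADATA_FILE == "metadata.json"
    metadata = json.load(f)
print(metadata["number_of_items"], metadata["max_seq_length"])

# The records were written with GZIP compression, so read them back the same way.
dataset = tf.data.TFRecordDataset(
    os.path.join(output_dir, "dataset.tf"),  # assumes DATASET_FILE == "dataset.tf"
    compression_type="GZIP",
)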
Example #10
File: main.py Project: yifding/luke
def create_title_list(dump_db_file, out_file):
    dump_db = DumpDB(dump_db_file)

    for title in dump_db.titles():
        out_file.write(f"{title}\n")
Example #11
def generate_redirect_file(dump_db_file, out_file, compress):
    data = {k: v for k, v in DumpDB(dump_db_file).redirects()}
    joblib.dump(data, out_file, compress=compress)
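
Note: joblib stores the redirect mapping as a plain dict, so consuming the resulting file is straightforward (hypothetical path and title below).

import joblib

redirects = joblib.load("redirects.pkl")  # hypothetical out_file from generate_redirect_file
# Resolve a redirect, falling back to the original title when none exists.
print(redirects.get("USA", "USA"))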
Example #12
def build_wiki_link_db(common_args, dump_db_file, mention_db_file, **kwargs):
    dump_db = DumpDB(dump_db_file)
    mention_db = MentionDB(mention_db_file)
    WikiLinkDB.build(dump_db, mention_db, **kwargs)
Example #13
def build_dump_db(dump_file, out_file, **kwargs):
    dump_reader = WikiDumpReader(dump_file)
    DumpDB.build(dump_reader,
                 out_file,
                 preprocess_func=normalize_text,
                 **kwargs)
Example #14
def build_entity_linker(dump_db_file, **kwargs):
    dump_db = DumpDB(dump_db_file)
    tokenizer = RegexpTokenizer()
    EntityLinker.build(dump_db, tokenizer, **kwargs)
Example #15
def build_from_wikipedia(dump_db_file, **kwargs):
    dump_db = DumpDB(dump_db_file)
    tokenizer = BasicTokenizer(do_lower_case=False)
    normalizer = BertLowercaseNormalizer()
    MentionDB.build_from_wikipedia(dump_db, tokenizer, normalizer, **kwargs)
Example #16
# -*- coding: utf-8 -*-

import sys
import Levenshtein
from collections import Counter
from wikipedia2vec.dump_db import DumpDB

dump_db = DumpDB(sys.argv[1])
pair_counter = Counter()

for (title1, title2) in dump_db.redirects():
    ops = Levenshtein.editops(title1.lower(), title2.lower())
    if len(ops) == 1:
        (op, p1, p2) = ops[0]
        if op == 'replace':
            pair_counter[frozenset((title1[p1], title2[p2]))] += 1

for (pair, count) in pair_counter.most_common():
    print('%s\t%s\t%d' % (*list(pair), count))
Example #17
def build_dump_db(dump_file: str, out_file: str, **kwargs):
    dump_reader = WikiDumpReader(dump_file)
    DumpDB.build(dump_reader, out_file, **kwargs)
Example #18
File: main.py Project: yifding/luke
def create_redirect_tsv(dump_db_file, out_file):
    dump_db = DumpDB(dump_db_file)

    for src, dest in dump_db.redirects():
        out_file.write(f"{src}\t{dest}\n")
Example #19
import argparse
import copy

import joblib
import numpy as np
from marisa_trie import Trie
from wikipedia2vec.dictionary import Dictionary
from wikipedia2vec.dump_db import DumpDB

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--entity_file", type=str, required=True)
    parser.add_argument("--src", type=str, required=True)
    parser.add_argument("--tgt", type=str, required=True)
    parser.add_argument("--dumpdb", type=str, required=True)

    args = parser.parse_args()

    dictionary = Dictionary.load(args.src)
    dumpdb = DumpDB(args.dumpdb)

    with open(args.entity_file) as handle:
        # Strip trailing newlines so the titles match the redirect keys below.
        all_needed_entities_raw = set(line.rstrip("\n") for line in handle)

    title2dest_title = dict(dumpdb.redirects())
    all_needed_entities = set(
        title2dest_title.get(title, title) for title in all_needed_entities_raw
    )

    src_file = joblib.load(args.src)

    old_word_dict = Trie()
    old_word_dict.frombytes(src_file['word_dict'])

    old_word_stats = src_file['word_stats']
Example #20
class DatasetBuilder:

    tokenizer_language = "da"

    # Files saved by the build method
    metadata_file = "metadata.json"
    entity_vocab_file = "entity-vocab.json"
    data_file = "data.jsonl"
    token_map_file = "token-map.npy"

    def __init__(
        self,
        dump_db_file: str,         # Location of the file built by build-dump-db
        tokenizer_name: str,       # Tokenizer to use, e.g. Maltehb/danish-bert-botxo for Danish BERT
        entity_vocab_file: str,    # Built by build-entity-vocab
        out_dir: str,              # Where to put the finished dataset; all contents are removed before saving
        validation_prob: float,    # Chance of each finished document being marked as part of the validation set
        max_entities: int,         # Only up to this many entities are included in each sequence
        max_entity_span: int,      # Maximum number of tokens an entity can span before the sequence is discarded
        min_sentence_length: int,  # Minimum number of tokens a sentence must span to be included
        max_articles: int | None,
        max_vocab_size: int,
    ):
        if not wikipedia2vec_available:
            raise ModuleNotFoundError(
                "Pretrain data generation requires installation of the optional requirement `wikipedia2vec`"
            )
        log("Reading dump database at %s" % dump_db_file)
        self.dump_db = DumpDB(dump_db_file)
        log("Building tokeninizer: %s" % tokenizer_name)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        log("Building sentence tokenizer: %s" % self.tokenizer_language)
        self.sentence_tokenizer = ICUSentenceTokenizer(self.tokenizer_language)
        log("Loading entity vocab at %s" % entity_vocab_file)
        self.entity_vocab = load_entity_vocab(entity_vocab_file)
        # Make sure IDs on non-ignored entities are contiguous
        num = 0
        for entity_info in self.entity_vocab.values():
            entity_info["id"] = num
            num += 1
        log("Entity vocab has size %i" % num)

        self.out_dir = out_dir
        self.data_file = os.path.join(self.out_dir, self.data_file)
        self.token_map_file = os.path.join(self.out_dir, self.token_map_file)
        self.max_seq_length = self.tokenizer.model_max_length
        self.validation_prob = validation_prob
        self.max_entities = max_entities
        self.max_entity_span = max_entity_span
        self.min_sentence_length = min_sentence_length
        # Get maximum number of tokens in a sequence excluding start and end tokens
        self.max_num_tokens = self.max_seq_length - 2
        self.max_articles = max_articles
        self.vocab_size = (
            self.tokenizer.vocab_size if max_vocab_size == -1
            else min(max_vocab_size, self.tokenizer.vocab_size)
        )

        # Filter titles so only real articles are included
        self.target_titles = list(self.dump_db.titles())

        # Remove old datafile if it exists
        if os.path.isfile(self.data_file):
            log.debug("Removing old datafile '%s'" % self.data_file)
            os.remove(self.data_file)

        self.examples = list()

    def _tokenize(self, text: str, paragraph_text: str, idx: int) -> list[str]:
        if not text:
            return list()
        try:
            if isinstance(self.tokenizer, RobertaTokenizer):
                tokens = self.tokenizer.tokenize(
                    text,
                    add_prefix_space=idx == 0 or text.startswith(" ")
                    or paragraph_text[idx - 1] == " ",
                )
            else:
                tokens = self.tokenizer.tokenize(text)
        except KeyboardInterrupt:
            # Make sure program can be keyboard interrupted despite needing to catch BaseException
            raise
        except BaseException as e:
            # Catch an exception caused by rust panicking in the tokenizer
            log.warning(
                "Failed to tokenize text with exception '%s'\nText: '%s'" %
                (e, text))
            return list()

        return tokens

    def build(self):
        log("Saving tokenizer config and word token config to '%s'" %
            self.out_dir)
        with open(path := os.path.join(self.out_dir, self.entity_vocab_file),
                  "w",
                  encoding="utf-8") as ev:
            log("Saving entity vocab to '%s'" % path)
            ujson.dump(self.entity_vocab, ev, indent=2)

        log.section("Processing %i pages" %
                    len(self.target_titles[:self.max_articles]))
        n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
        for title in log.tqdm(tqdm(self.target_titles[:self.max_articles])):
            log("Processing %s" % title)
            with TT.profile("Process page"):
                s, e, nt, nw = self._process_page(title)
                n_seqs += s
                n_ents += e
                n_word_toks += nt
                n_words += nw

        log("Shuffling data")
        random.shuffle(self.examples)
        n_vals = int(self.validation_prob * len(self.examples))
        for i in range(n_vals):
            self.examples[i]["is_validation"] = True

        # Save metadata
        metadata = {
            "number-of-items": n_seqs,
            "number-of-word-tokens": n_word_toks,
            "number-of-words": n_words,
            "number-of-entities": n_ents,
            "number-of-val-items": n_vals,
            "max-seq-length": self.max_seq_length,
            "max-entities": self.max_entities,
            "max-entity-span": self.max_entity_span,
            "min-sentence-length": self.min_sentence_length,
            "base-model": self.tokenizer_name,
            "tokenizer-class": self.tokenizer.__class__.__name__,
            "language": self.dump_db.language,
            "reduced-vocab": self.vocab_size < self.tokenizer.vocab_size,
            "vocab-size": self.vocab_size,
        }

        if self.vocab_size < self.tokenizer.vocab_size:
            log.section("Reducing token number")
            with TT.profile("Reduce token vocab"):
                token_map, metadata["vocab-size"] = self._reduce_tokens()
            with TT.profile("Rewrite dataset with new tokens"):
                self._update_tokens(token_map)

        with open(path := os.path.join(self.out_dir, self.metadata_file),
                  "w") as f:
            log.section("Saving metadata to %s" % path)
            ujson.dump(metadata, f, indent=4)
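
Note: a hedged usage sketch of the class above; every path and parameter value is hypothetical and simply mirrors the constructor signature and its comments.

builder = DatasetBuilder(
    dump_db_file="dawiki.db",                  # hypothetical DB from build-dump-db
    tokenizer_name="Maltehb/danish-bert-botxo",
    entity_vocab_file="entity-vocab.json",
    out_dir="danish-pretrain-data",
    validation_prob=0.01,
    max_entities=128,
    max_entity_span=30,
    min_sentence_length=5,
    max_articles=None,
    max_vocab_size=-1,                         # -1 keeps the tokenizer's full vocabulary
)
builder.build()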