示例#1
0
def optimize(model: NERDaLUKE, dataset: NERDataset, args: dict[str, Any], sampler: Sampler
             ) -> tuple[list, list[dict[str, Any]], int | None]:
    """Run a sampled hyperparameter search, saving every run and tracking the best one.

    Parameter sets are drawn from ``sampler.sample()`` until it returns None. For each set, a
    deep copy of ``model`` is passed to ``objective_function`` together with ``args`` overridden
    by the sampled parameters, and the result is saved to ``<args['location']>/res-optim<i>``.

    Returns:
        ``(results, tried_params, best)``: the result of every run, the parameters sampled for
        each run (same order), and the index of the run with the highest micro-average F1
        score (None if the sampler produced no parameter sets).
    """
    results, tried_params = list(), list()
    best, best_score = None, None
    i = 0
    while (sampled_params := sampler.sample()) is not None:
        log.section(f"Sampling #{i}: chose", f(sampled_params))
        # Deep copy so every run starts from the same, untouched model
        result = objective_function(deepcopy(model), dataset, {**args, **sampled_params})
        score = result.statistics["micro avg"]["f1-score"]
        if best is None or score > best_score:
            log(f"Found new best at F1 of {score}")
            best, best_score = i, score
        result.save(out := os.path.join(args['location'], f"res-optim{i}"))
        log.debug(f"Saved results to {out}")
        results.append(result)
        tried_params.append(sampled_params)
        i += 1
    # Previously all collected state was discarded; hand it back to the caller
    return results, tried_params, best
示例#2
0
def cross_validate(model: NERDaLUKE, dataset: NERDataset, k: int,
                   train_args: dict[str, Any]) -> list[NER_Results]:
    """Run k-fold cross-validation on all splits of ``dataset`` pooled together.

    All data in ``dataset.data`` is merged and randomly divided into ``k`` parts. For each part,
    a deep copy of ``model`` is trained on the other ``k - 1`` parts (without dev evaluation)
    and then evaluated on the held-out part.

    Args:
        model: Model to start every fold from. Each fold trains its own deep copy.
        dataset: Dataset whose splits are pooled. Each fold works on its own deep copy.
        k: Number of folds.
        train_args: Training hyperparameters. Must contain "batch_size", "epochs", "lr",
            "warmup_prop", "weight_decay" and "loss_weight".

    Returns:
        One evaluation result per fold, in fold order.
    """
    cv_splits = random_divide(merge_data(list(dataset.data.values())), k)
    results = list()
    log(f"Split into {k} subdatasets with lengths {[len(c.texts) for c in cv_splits]}")
    for i, test_data in enumerate(cv_splits):
        log.section(f"Cross-validation split {i}")
        train_data = merge_data([s for j, s in enumerate(cv_splits) if j != i])
        # Create split specific model and data
        split_model = deepcopy(model)
        split_dataset = deepcopy(dataset)
        split_dataset.data[Split.TRAIN] = train_data
        split_dataloader = split_dataset.build(Split.TRAIN,
                                               train_args["batch_size"])

        log("Training")
        split_dataset.document(split_dataloader, Split.TRAIN)
        type_distribution(split_dataset.data[Split.TRAIN].annotations)
        trainer = TrainNER(
            split_model,
            split_dataloader,
            split_dataset,
            device=next(split_model.parameters()).device,
            epochs=train_args["epochs"],
            lr=train_args["lr"],
            warmup_prop=train_args["warmup_prop"],
            weight_decay=train_args["weight_decay"],
            dev_dataloader=None,  # Don't eval
            loss_weight=train_args["loss_weight"])
        trainer.run()

        split_dataset.data[Split.TEST] = test_data
        split_test_dataloader = split_dataset.build(Split.TEST, EVAL_BATCH)

        log("Evaluation")
        # Document the TEST split with its own dataloader (previously the train loader was passed)
        split_dataset.document(split_test_dataloader, Split.TEST)
        type_distribution(split_dataset.data[Split.TEST].annotations)
        results.append(
            evaluate_ner(split_model,
                         split_test_dataloader,
                         split_dataset,
                         trainer.device,
                         Split.TEST,
                         also_no_misc=False))
    return results
def main(path: str, n: int):
    """Log the examples at both extremes of selected dimensions of saved geometry results.

    For each ``field -> axis`` pair in ``OF_INTEREST``, the rows of ``getattr(res, field)`` are
    ordered by their value in column ``axis``. The ``n`` highest and the ``n`` lowest are shown
    with ``_show_examples``, using the training split of DaNE as the example source.
    """
    log.configure(os.path.join(path, "geometry-examples.log"),
                  "daLUKE examples",
                  print_level=Levels.DEBUG)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Examples are always drawn from the training split
    dane = load_dataset(dict(dataset="DaNE"), DUMMY_METADATA, device)
    data = dane.data[Split.TRAIN]
    set_seeds()
    GeometryResults.subfolder = ""
    res = GeometryResults.load(path)
    for field, axis in OF_INTEREST.items():
        log.section(field)
        values = getattr(res, field)
        ascending = values[:, axis].argsort()
        descending = ascending[::-1]

        log(f"Examples where dim. {axis} is high")
        _show_examples(res, values, descending[:n], data)
        log(f"Examples where dim. {axis} is low")
        _show_examples(res, values, ascending[:n], data)
示例#4
0
def preprocess(
    dump_db_file: str,
    function: str,
    entity_vocab_file: str | None,
    dagw_sections: str | None,
    min_entity_length: int,
    max_entity_length: int,
    max_articles: int | None,
):
    """Write a preprocessed, bz2-compressed wiki dump built from a Wikipedia dump and DAGW files.

    Steps:
    1. Load the entity vocab, lowercased and passed through ``_insert_xml_special_characters``.
    2. Copy the DAGW files selected by ``dagw_sections`` and every Wikipedia text block whose
       title is not ignored into a ``tmpdir`` folder next to ``dump_db_file``.
    3. Run ``func`` over (up to ``max_articles`` of) those files with ``process_map``. The
       same files are read back afterwards, so ``func`` presumably rewrites them in place.
    4. Write ``<dump_db_file without extension>.<function>.bz2``. It contains the original
       dump's header up to its first ``<page>`` line, one ``<page>`` element per temp file,
       and a closing ``</mediawiki>`` tag.
    5. Remove the temporary directory.

    Raises:
        RuntimeError: if ``entity_vocab_file`` is not given, or if ``dump_db_file`` contains
            no ``<page>`` line.
    """
    if not entity_vocab_file:
        raise RuntimeError("entity-vocab-file must be given")

    log.configure(
        os.path.join(os.path.split(dump_db_file)[0], "preprocessing.log"),
        "Preprocessing",
        log_commit=True,
    )

    log.section("Collecting data")
    log(
        "Wikidump path: %s" % dump_db_file,
        "Function:      %s" % function,
    )

    log("Loading entity vocab")
    entity_vocab = {
        _insert_xml_special_characters(e.lower())
        for e in load_entity_vocab(entity_vocab_file)
    }

    dagw_files = list()
    if dagw_sections:
        n_words = 0
        log("Finding gigaword data files and counting words")
        dagw_files = list(_get_dagw_files(dagw_sections))
        for dagw_file in tqdm(dagw_files):
            with open(dagw_file, encoding="utf-8") as f:
                n_words += len(f.read().split())
        log("Found %i dagw files containing %i words" %
            (len(dagw_files), n_words))

    # tempdir is not used, as the temporary files can take up more space than what temporary
    # directories usually allow
    tmpdir = os.path.join(os.path.split(dump_db_file)[0], "tmpdir")
    os.makedirs(tmpdir, exist_ok=True)
    log("Saving all articles to temporary directory %s" % tmpdir)
    for dagw_file in tqdm(dagw_files):
        shutil.copy2(
            dagw_file,
            os.path.join(tmpdir, fix_filename(os.path.split(dagw_file)[-1])))
    log("Saving Wikipedia files to temporary directory")
    for is_text, text, title in tqdm(_get_lineblocks(dump_db_file),
                                     unit=" blocks"):
        if is_text and not ignore_title(title):
            # Keep only what lies between the first ">" (end of the opening tag)
            # and the trailing "</text>\n"
            text_start = text.index(">") + 1
            text_end = -len("</text>\n")
            # NOTE: titles whose fixed filenames share the first 100 characters map to the
            # same file, so later articles overwrite earlier ones
            wiki_path = os.path.join(tmpdir, fix_filename(title)[:100] + ".wiki")
            # Explicit UTF-8 so the round trip does not depend on the platform's default encoding
            with open(wiki_path, "w", encoding="utf-8") as f:
                f.write(text[text_start:text_end])

    files = [
        os.path.join(tmpdir, x) for x in os.listdir(tmpdir)[:max_articles]
    ]
    log("Saved a total of %i articles to %s" % (len(files), tmpdir))

    log.section("Beginning preprocessing on %i threads" % os.cpu_count())
    process_map(
        func,
        [(function, f, entity_vocab, min_entity_length, max_entity_length)
         for f in files],
        max_workers=os.cpu_count(),
        chunksize=1024,
    )

    dump_file = os.path.splitext(dump_db_file)[0] + ".%s.bz2" % function
    log.info("Saving preprocessed files to %s" % dump_file)
    with bz2.BZ2File(dump_file, "w") as dump:
        with bz2.BZ2File(dump_db_file) as old_dump:
            # Copy the dump header (everything before the first <page> line). Iterating the
            # file stops at EOF, whereas the former readline loop spun forever there.
            for line in old_dump:
                if line.strip().startswith(b"<page>"):
                    break
                dump.write(line)
            else:
                raise RuntimeError("No <page> element found in %s" % dump_db_file)
        for i, fname in tqdm(enumerate(files), total=len(files)):
            with open(fname, encoding="utf-8") as f:
                text = f.read()
            s = """
            <page>
                <title>{title}</title>
                <id>{id}</id>
                <revision>
                    <text bytes="{bytes}" xml:space="preserve">{text}</text>
                </revision>
            </page>""".format(
                title=fname,
                id=i + 1,
                bytes=len(text),
                text=text,
            )
            if i == 0:
                # Drop the template's leading newline for the first page
                s = s[1:]
            dump.write(s.encode("utf-8"))
        dump.write(b"\n</mediawiki>")

    log.info("Removing temporary files")
    shutil.rmtree(tmpdir)
    log.info("Done preprocessing data")
示例#5
0
class DatasetBuilder:

    tokenizer_language = "da"

    # Files saved by the build method
    metadata_file = "metadata.json"
    entity_vocab_file = "entity-vocab.json"
    data_file = "data.jsonl"
    token_map_file = "token-map.npy"

    def __init__(
        self,
        dump_db_file: str,  # Location of file build by build-dump-db
        tokenizer_name:
        str,  # Tokenizer to use, e.g. Maltehb/danish-bert-botxo for Danish BERT
        entity_vocab_file: str,  # Build by build-entity-vocab
        out_dir:
        str,  # Where to put finished dataset. All contents will be removed before saving dataset
        validation_prob:
        float,  # Chance of each finished document to be marked as part of validation set
        max_entities:
        int,  # Only up to this many entities are included in each sequence
        max_entity_span:
        int,  # Maximum number tokens an entity can span before sequence is discarded
        min_sentence_length:
        int,  # Minimum number of tokens a sentence must span to be included
        max_articles: int | None,
        max_vocab_size: int,
    ):
        if not wikipedia2vec_available:
            raise ModuleNotFoundError(
                "Pretrain data generation requires installation of the optional requirement `wikipedia2vec`"
            )
        log("Reading dump database at %s" % dump_db_file)
        self.dump_db = DumpDB(dump_db_file)
        log("Building tokeninizer: %s" % tokenizer_name)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        log("Building sentence tokenizer: %s" % self.tokenizer_language)
        self.sentence_tokenizer = ICUSentenceTokenizer(self.tokenizer_language)
        log("Loading entity vocab at %s" % entity_vocab_file)
        self.entity_vocab = load_entity_vocab(entity_vocab_file)
        # Make sure IDs on non-ignored entities are contiguous
        num = 0
        for entity_info in self.entity_vocab.values():
            entity_info["id"] = num
            num += 1
        log("Entity vocab has size %i" % num)

        self.out_dir = out_dir
        self.data_file = os.path.join(self.out_dir, self.data_file)
        self.token_map_file = os.path.join(self.out_dir, self.token_map_file)
        self.max_seq_length = self.tokenizer.model_max_length
        self.validation_prob = validation_prob
        self.max_entities = max_entities
        self.max_entity_span = max_entity_span
        self.min_sentence_length = min_sentence_length
        # Get maximum number of tokens in a sequence excluding start and end tokens
        self.max_num_tokens = self.max_seq_length - 2
        self.max_articles = max_articles
        self.vocab_size = self.tokenizer.vocab_size if max_vocab_size == -1 else min(
            max_vocab_size, max_vocab_size)

        # Filter titles so only real articles are included
        self.target_titles = list(self.dump_db.titles())

        # Remove old datafile if it exists
        if os.path.isfile(self.data_file):
            log.debug("Removing old datafile '%s'" % self.data_file)
            os.remove(self.data_file)

        self.examples = list()

    def _tokenize(self, text: str, paragraph_text: str, idx: int) -> list[str]:
        if not text:
            return list()
        try:
            if isinstance(self.tokenizer, RobertaTokenizer):
                tokens = self.tokenizer.tokenize(
                    text,
                    add_prefix_space=idx == 0 or text.startswith(" ")
                    or paragraph_text[idx - 1] == " ",
                )
            else:
                tokens = self.tokenizer.tokenize(text)
        except KeyboardInterrupt:
            # Make sure program can be keyboard interrupted despite needing to catch BaseException
            raise
        except BaseException as e:
            # Catch an exception caused by rust panicking in the tokenizer
            log.warning(
                "Failed to tokenize text with exception '%s'\nText: '%s'" %
                (e, text))
            return list()

        return tokens

    def build(self):
        """Process the target pages and write dataset files into ``self.out_dir``.

        Writes the entity vocab (``self.entity_vocab_file``) and the metadata
        (``self.metadata_file``) as JSON. Each title in ``self.target_titles[:self.max_articles]``
        is handled by ``self._process_page``, and its returned counts are summed into the
        metadata. ``self.examples`` is shuffled and the first ``validation_prob`` fraction is
        flagged with ``"is_validation"``. If ``self.vocab_size`` is below the tokenizer's
        vocabulary size, the token vocabulary is reduced and the dataset rewritten.

        NOTE(review): ``self.examples`` is not written to ``self.data_file`` in the lines
        shown here; presumably that happens elsewhere — confirm.
        """
        log("Saving tokenizer config and word token config to '%s'" %
            self.out_dir)
        with open(path := os.path.join(self.out_dir, self.entity_vocab_file),
                  "w",
                  encoding="utf-8") as ev:
            log("Saving entity vocab to '%s'" % path)
            ujson.dump(self.entity_vocab, ev, indent=2)

        log.section("Processing %i pages" %
                    len(self.target_titles[:self.max_articles]))
        # Running totals of sequences, entities, word tokens and words over all pages
        n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
        for title in log.tqdm(tqdm(self.target_titles[:self.max_articles])):
            log("Processing %s" % title)
            with TT.profile("Process page"):
                s, e, nt, nw = self._process_page(title)
                n_seqs += s
                n_ents += e
                n_word_toks += nt
                n_words += nw

        log("Shuffling data")
        random.shuffle(self.examples)
        # After shuffling, the first n_vals examples form a random validation subset
        n_vals = int(self.validation_prob * len(self.examples))
        for i in range(n_vals):
            self.examples[i]["is_validation"] = True

        # Save metadata
        metadata = {
            "number-of-items": n_seqs,
            "number-of-word-tokens": n_word_toks,
            "number-of-words": n_words,
            "number-of-entities": n_ents,
            "number-of-val-items": n_vals,
            "max-seq-length": self.max_seq_length,
            "max-entities": self.max_entities,
            "max-entity-span": self.max_entity_span,
            "min-sentence-length": self.min_sentence_length,
            "base-model": self.tokenizer_name,
            "tokenizer-class": self.tokenizer.__class__.__name__,
            "language": self.dump_db.language,
            "reduced-vocab": self.vocab_size < self.tokenizer.vocab_size,
            "vocab-size": self.vocab_size,
        }

        if self.vocab_size < self.tokenizer.vocab_size:
            log.section("Reducing token number")
            with TT.profile("Reduce token vocab"):
                # _reduce_tokens returns the token map and the resulting vocab size,
                # which replaces the "vocab-size" metadata entry
                token_map, metadata["vocab-size"] = self._reduce_tokens()
            with TT.profile("Rewrite dataset with new tokens"):
                self._update_tokens(token_map)

        # Metadata is written last so it reflects any vocab reduction above
        with open(path := os.path.join(self.out_dir, self.metadata_file),
                  "w") as f:
            log.section("Saving metadata to %s" % path)
            ujson.dump(metadata, f, indent=4)