Example #1
def checkpoint(
    data: dbag.Bag,
    name: str,
    checkpoint_dir: Path,
    respect_partial_checkpoints: bool = True,
    overwrite: bool = False,
    **compute_kwargs,
) -> dbag.Bag:
    """
  This function checkpoints a dask bag. The bag is broken into partitions, each
  partition is given a checkpoint task, and then the bag is recombined. Any
  partition that has been previously checkpointed will be restored.
  """

    if overwrite:
        respect_partial_checkpoints = False

    # Assure ourselves that we have a unique name
    assert name not in _CHECKPOINT_NAMES
    _CHECKPOINT_NAMES.add(name)

    # Setup directory
    assert checkpoint_dir.is_dir()
    part_dir = checkpoint_dir.joinpath(name)
    part_dir.mkdir(parents=True, exist_ok=True)
    assert part_dir.is_dir()

    if overwrite:
        done_path = part_dir.joinpath(file_util.DONE_FILE)
        if done_path.is_file():
            print("\t- Clearing existing ckpt")
            done_path.unlink()

    if file_util.is_result_saved(part_dir):
        print("\t- Cached")
    else:
        print("\t- Computing")
        file_util.save(
            data, part_dir,
            keep_partial_result=respect_partial_checkpoints).compute(
                **compute_kwargs)
    return file_util.load(part_dir)
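
A minimal usage sketch for the checkpoint function above, assuming file_util and _CHECKPOINT_NAMES are provided by the surrounding module. The ./ckpt directory, the "squared_numbers" name, and the _square helper are hypothetical.

import dask.bag as dbag
from pathlib import Path

def _square(x):
    return x * x

ckpt_root = Path("./ckpt")                    # hypothetical checkpoint root
ckpt_root.mkdir(parents=True, exist_ok=True)

numbers = dbag.from_sequence(range(100), npartitions=4).map(_square)
# The first run computes and saves the bag; a later run with the same name
# loads the saved partitions instead of recomputing them.
numbers = checkpoint(
    numbers,
    name="squared_numbers",
    checkpoint_dir=ckpt_root,
)
print(numbers.take(5))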
Example #2
def extract_predicates(config: cpb.AbstractGeneratorConfig):
    paths = get_paths(config)
    dask_client = connect_to_dask_cluster(config)

    preloader = dpg.WorkerPreloader()
    preloader.register(
        *predicate_util.get_scispacy_initalizer(config.predicate_spacy_model))
    preloader.register(*predicate_util.get_stopwordlist_initializer(
        config.predicate_stopword_list))
    dpg.add_global_preloader(client=dask_client, preloader=preloader)

    abstracts = file_util.load(
        paths["checkpoint_dir"].joinpath("medline_documents"))

    predicates = abstracts.map_partitions(
        predicate_util.abstracts_to_predicates)
    predicates = dask_checkpoint.checkpoint(
        predicates,
        name="predicates",
        checkpoint_dir=paths["model_ckpt_dir"],
        overwrite=False,
    )
    predicates.compute()
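
The same map_partitions-then-checkpoint pattern, sketched with toy data rather than Medline abstracts. to_upper_partition and the "toy_records" name are hypothetical; dask_checkpoint and paths are assumed to be available exactly as in the function above.

import dask.bag as dbag

def to_upper_partition(records):
    # Runs once per partition, mirroring abstracts_to_predicates above.
    return [r.upper() for r in records]

records = dbag.from_sequence(["alpha", "beta", "gamma"], npartitions=2)
upper = records.map_partitions(to_upper_partition)
upper = dask_checkpoint.checkpoint(
    upper,
    name="toy_records",
    checkpoint_dir=paths["model_ckpt_dir"],   # as returned by get_paths(config)
    overwrite=False,
)
upper.compute()   # materializes the checkpoint on the cluster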
Example #3
def prep(config: cpb.AbstractGeneratorConfig):
    # all important paths
    paths = get_paths(config)
    connect_to_dask_cluster(config)

    def ckpt(val, name, overwrite=False):
        print("Checkpoint", name)
        return dask_checkpoint.checkpoint(
            val,
            name=name,
            checkpoint_dir=paths["model_ckpt_dir"],
            overwrite=overwrite,
        )

    # Get the full set of abstracts
    parsed_abstracts = (
        file_util.load(
            paths["checkpoint_dir"].joinpath("sentences_with_lemmas"))
        .map_partitions(group_and_filter_parsed_sentences))
    parsed_abstracts = ckpt(parsed_abstracts, "parsed_abstracts")

    # Pair each record with a flag: True if it lands in the test split.
    is_test_data = parsed_abstracts.map(
        lambda rec: (random.random() <= config.sys.test_ratio, rec))
    is_test_data = ckpt(is_test_data, "is_test_data")

    testing_data = (
        is_test_data.filter(lambda b_r: b_r[0]).map(lambda b_r: b_r[1]))
    testing_data = ckpt(testing_data, "testing_data")

    training_data = (
        is_test_data.filter(lambda b_r: not b_r[0]).map(lambda b_r: b_r[1]))
    training_data = ckpt(training_data, "training_data")

    # write each partition of the training dataset to its own sqlitedict db
    # This allows for fast random access during distributed training
    print("Loading training database")
    to_training_database(training_data, paths["training_db_dir"])

    print("Collecting all mesh headings")
    all_mesh_headings = (
        training_data
        .map(lambda rec: rec["mesh_headings"])
        .flatten()
        .frequencies()
        .filter(lambda mesh_freq: mesh_freq[1] >= config.min_mesh_term_support)
        .map(lambda mesh_freq: mesh_freq[0])
        .compute())
    print(f"Indexing all {len(all_mesh_headings)} mesh headings")
    mesh_index = items_to_ordered_index(all_mesh_headings)

    ###

    print("Getting oldest year")
    oldest_year = (
        training_data
        .map(lambda rec: rec["year"])
        # Some records contain wildly invalid years; ignore them.
        .filter(lambda year: year > 1000)
        .min()
        .compute())
    print("\t-", oldest_year)

    ###

    print("Collecting training data for tokenizer")
    training_data_files = (
        training_data
        # Only collect 30% of abstracts
        .random_sample(0.3)
        .map(lambda rec: [s["text"] for s in rec["sentences"]])
        .flatten()
        # Only take 10% of sentences; ultimately we're subsetting again
        .random_sample(0.1)
        .map(lambda text: text.lower() if config.lowercase else text)
        # Reduce the total number of files
        .repartition(20)
        # Store results in textfiles
        .to_textfiles(f"{paths['tokenizer_training_data_dir']}/*.txt")
    )
    print("Training tokenizer")
    # SentencePiece writes tokenizer.model and tokenizer.vocab next to
    # tokenizer_model_path
    spm.SentencePieceTrainer.train(
        f"--input={','.join(training_data_files)} "
        f"--model_prefix={paths['tokenizer_model_path'].parent}/tokenizer "
        f"--vocab_size={config.vocab_size} "
        f"--character_coverage=1.0 "
        f"--model_type=unigram "
        f"--input_sentence_size={config.max_tokenizer_sentences} "
        f"--shuffle_input_sentence=true ")
    assert paths["tokenizer_model_path"].is_file()
    assert paths["tokenizer_vocab_path"].is_file()

    condition_index = {
        "mesh_index": mesh_index,
        "oldest_year": oldest_year,
    }
    with open(paths["model_extra_data_path"], 'wb') as f:
        pickle.dump(condition_index, f)
    print("\t- Written:", paths["model_extra_data_path"])

    if not paths["ngram_freqs_path"].is_file():
        # We're going to need the frequency distribution of ngrams for analysis
        print("Collecting a sample of ngram frequencies.")
        ngram_frequencies = dict(
            file_util.load(
                paths["model_ckpt_dir"].joinpath("sentences_with_bow"))
            .random_sample(0.1)
            .map(lambda r: r["bow"])
            .flatten()
            .frequencies()
            .compute())
        with open(paths["ngram_freqs_path"], 'wb') as f:
            pickle.dump(ngram_frequencies, f)
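
A sketch of how the artifacts written by prep might be read back later. This load-side code is an assumption about downstream usage, not part of the original module; it reuses the same paths dict.

import pickle
import sentencepiece as spm

# Recover the conditioning data pickled above.
with open(paths["model_extra_data_path"], "rb") as f:
    extra = pickle.load(f)
mesh_index = extra["mesh_index"]
oldest_year = extra["oldest_year"]

# Load the tokenizer trained above.
tokenizer = spm.SentencePieceProcessor()
tokenizer.load(str(paths["tokenizer_model_path"]))
print(tokenizer.encode_as_pieces("example abstract text"))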
Example #4
def checkpoint(name: str,
               bag: Optional[dbag.Bag] = None,
               verbose: Optional[bool] = None,
               allow_partial: Optional[bool] = None,
               halt_after: Optional[str] = None,
               textfile: bool = False,
               **compute_kw) -> Optional[dbag.Bag]:
    """Stores the contents of the bag as a series of files.

  This function takes each partition of the input bag and writes them to files
  within a directory associated with the input name. The location of each
  checkpoint directory is dependent on the `ckpt_root` option.

  For each optional argument of this function (other than `bag`), there is an
  associated module-level parameter that can be set globally.

  The module-level parameter `checkpoint_root`, set with `set_root`, must be
  set before calling `checkpoint`.

  Usage:
    checkpoint(name) - returns a load op for checkpoint "name"
    checkpoint(name, bag) - writes bag to checkpoint "name" (unless it has
      already been saved) and returns a load op; if disable() was called,
      returns the input bag unchanged

  Args:
    name: The name of the checkpoint directory to lookup or save to
    bag: If set, save this bag. Otherwise, we will require that this checkpoint
      has already been saved.
    verbose: Print helper info. If unspecified, defaults to module-level parameter.
    allow_partial: If true, partial files present in an unfinished checkpoint
      directory will not be overwritten. If false, unfinished checkpoints will
      be recomputed in full. Defaults to module-level parameter if unset.
    halt_after: If set to the name of the current checkpoint, the Agatha
      process will stop after computing its contents. This is useful for
      partial pipeline runs, for instance when only computing training data
      for an ML model.
    textfile: If set, the checkpoint is stored in plaintext format (used to
      save bags of strings). In this case the function returns `None`.

  Returns:
    A dask bag that, if computed, _LOADS_ the specified checkpoint. This means
    that future operations can depend on the loading of intermediate data,
    rather than the intermediate computations themselves.
  """
    if verbose is None:
        verbose = get_verbose()
    if allow_partial is None:
        allow_partial = get_allow_partial()
    if halt_after is None:
        halt_after = (_PARAM["halt_after_ckpt"] is not None
                      and _PARAM["halt_after_ckpt"] == name)
    else:
        # Only halt when the caller named this specific checkpoint.
        halt_after = (halt_after == name)

    def vprint(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    def check_halt():
        if halt_after:
            vprint("\t- Halting")
            exit(0)

    # If checkpoint is done, load no matter what
    vprint("Checkpoint:", name, "\t", datetime.now())
    if is_ckpt_done(name):
        vprint("\t- Ready")
        check_halt()
        if textfile:
            return None
        else:
            return file_util.load(get_or_make_ckpt_dir(name))

    # If checkpointing is enabled, we need to save the bag and return the load op
    if _PARAM["enabled"]:
        assert bag is not None, \
            f"Checkpoint {name} is unsaved; a bag argument is required to compute it"
        vprint("\t- Saving")
        file_util.save(
            bag=bag,
            path=get_or_make_ckpt_dir(name),
            keep_partial_result=allow_partial,
            textfile=textfile,
        ).compute(**compute_kw)
        vprint("\t- Done!")
        check_halt()
        if textfile:
            return None
        else:
            return file_util.load(get_or_make_ckpt_dir(name))

    # If checkpointing is disabled, we just return the in-progress bag.
    else:  # disabled
        assert bag is not None, \
            f"Checkpointing is disabled, and no bag specified for {name}"
        vprint("\t- Checkpoint Disabled")
        check_halt()
        return bag
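
A hypothetical pipeline fragment using the module-level API described in the docstring above. set_root is referenced there, but its exact signature is assumed; the bag contents and checkpoint names are made up.

import dask.bag as dbag
from pathlib import Path

set_root(Path("./ckpt_root"))   # assumed to take the checkpoint root directory

sentences = dbag.from_sequence(
    ["First sentence.", "Second sentence."], npartitions=2)
sentences = checkpoint("raw_sentences", sentences)   # saves, then returns a load op

lowered = sentences.map(str.lower)
lowered = checkpoint("lowered_sentences", lowered)

# In a later run the saved data can be recovered without rebuilding the bag:
lowered = checkpoint("lowered_sentences")
print(lowered.take(2))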