Example #1
def train(config:cpb.AbstractGeneratorConfig):
  paths = get_paths(config)
  model = get_model_from_config(config)

  print("Configuring trainer")
  # DEFAULTS used by the Trainer
  checkpoint_callback = ModelCheckpoint(
    filepath=paths["model_root_dir"],
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix=''
  )

  trainer = Trainer(
      gradient_clip_val=config.gradient_clip_val,
      gpus=-1,
      nb_gpu_nodes=config.num_nodes if config.HasField("num_nodes") else 1,
      distributed_backend='ddp',
      accumulate_grad_batches=config.accumulate_batches,
      train_percent_check=config.training_fraction,
      weights_summary='full',
      default_save_path=paths["model_root_dir"],
      checkpoint_callback=checkpoint_callback,
  )
  model.init_datasets()
  print("Training!")
  trainer.fit(model)
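
The Trainer keywords here (nb_gpu_nodes, distributed_backend, train_percent_check, default_save_path) belong to an older PyTorch Lightning API, so the example assumes a matching Lightning release. Below is a minimal sketch of driving train() from a text-format protobuf config; the file name and the assumption that cpb is the generated protobuf module are illustrative only.

from google.protobuf import text_format

def train_from_pbtxt(pbtxt_path: str):
  # Parse a text-format AbstractGeneratorConfig and hand it to train().
  config = cpb.AbstractGeneratorConfig()
  with open(pbtxt_path) as f:
    text_format.Parse(f.read(), config)
  train(config)

train_from_pbtxt("abstract_generator.pbtxt")  # hypothetical config file
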
Example #2
def get_model_from_config(
    config:cpb.AbstractGeneratorConfig,
)->AbstractGenerator:
  paths = get_paths(config)
  tokenizer_model_path = paths["tokenizer_model_path"]
  extra_data_path = paths["model_extra_data_path"]

  if config.HasField("restore_from_checkpoint"):
    return AbstractGenerator.load_from_checkpoint(config.restore_from_checkpoint)
  else:
    return AbstractGenerator(Namespace(
        tokenizer_model_path=str(tokenizer_model_path),
        extra_data_path=str(extra_data_path),
        lowercase=config.lowercase,
        embedding_dim=config.embedding_dim,
        max_text_length=config.text_length,
        num_attention_heads=config.num_attention_heads,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        intermediate_dropout=0.1,
        intermediate_feedforward_dim=config.hidden_fc_size,
        training_data_dir=str(paths["training_db_dir"]),
        batch_size=config.sys.batch_size,
        warmup_steps=config.num_warmup_steps,
        learning_rate=config.sys.learning_rate,
        dataset_workers=4,
        train_num_machines=config.num_nodes,
    ))
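
The HasField("restore_from_checkpoint") branch gives a single entry point for both fresh training and resuming: once the checkpoint field is set, the architecture fields in the config are ignored and the hyperparameters stored with the checkpoint are used instead. A small sketch of both paths; the checkpoint path is a placeholder.

# Fresh model: the architecture comes entirely from the config fields read above.
fresh_model = get_model_from_config(config)

# Resumed model: only the checkpoint path matters.
config.restore_from_checkpoint = "/path/to/abstract_generator.ckpt"  # placeholder
resumed_model = get_model_from_config(config)
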
Example #3
def name_thy_self(config: cpb.AbstractGeneratorConfig) -> str:
    assert config.HasField("restore_from_checkpoint"), \
        "Must supply restore_from_checkpoint config"
    paths = get_paths(config)
    model = AbstractGenerator.load_from_checkpoint(
        config.restore_from_checkpoint)
    model.init_tokenizer()
    model.freeze()
    model.eval()

    text = """
    Medical Hypothesis Generation via. Conditional Abstract Generation. In
    this work, we present a variant of GPT-2 that incorporates medical domain
    knowledge. This system, which we have named py
  """
    text = re.sub(r"\s+", " ", text)
    text = text.strip()

    abstract = dict(pmid=0000,
                    year=2019,
                    mesh_headings=[],
                    sentences=[
                        dict(
                            type="title",
                            text=text,
                            tags=[],
                            ents=[],
                        ),
                        dict(
                            type="abstract:raw",
                            text="Discard this.",
                            tags=[],
                            ents=[],
                        )
                    ])

    encoder = datasets.EncodedAbstracts(
        abstract_ds=[abstract],
        tokenizer_kwargs=model.hparams.tokenizer_kwargs,
        max_text_length=model.hparams.max_text_length,
        max_mesh_length=model.hparams.max_text_length - 1,
        title_only=True,
        return_abstract=True,
    )

    loader = torch.utils.data.DataLoader(
        dataset=encoder,
        batch_size=1,
        collate_fn=collate_for_generation,
    )

    for model_in, abstract in loader:
        new_sentence = generate_new_text(
            model,
            model_in,
            min_size=3,
            max_size=10,
        )
        print(new_sentence)
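
The record handed to datasets.EncodedAbstracts follows the same shape used throughout these examples: a dict with pmid, year, mesh_headings, and a sentences list whose entries carry type, text, tags, and ents. A small helper that wraps an arbitrary prompt in that shape, offered as a sketch rather than project API:

def make_prompt_abstract(title_text: str) -> dict:
    # Minimal abstract record accepted by EncodedAbstracts; with
    # title_only=True only the "title" sentence is actually encoded.
    return dict(
        pmid=0,
        year=2019,
        mesh_headings=[],
        sentences=[
            dict(type="title", text=title_text, tags=[], ents=[]),
            dict(type="abstract:raw", text="Discard this.", tags=[], ents=[]),
        ],
    )
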
Example #4
def extract_predicates(config: cpb.AbstractGeneratorConfig):
    paths = get_paths(config)
    dask_client = connect_to_dask_cluster(config)

    preloader = dpg.WorkerPreloader()
    preloader.register(
        *predicate_util.get_scispacy_initalizer(config.predicate_spacy_model))
    preloader.register(*predicate_util.get_stopwordlist_initializer(
        config.predicate_stopword_list))
    dpg.add_global_preloader(client=dask_client, preloader=preloader)

    abstracts = file_util.load(
        paths["checkpoint_dir"].joinpath("medline_documents"))

    predicates = abstracts.map_partitions(
        predicate_util.abstracts_to_predicates)
    predicates = dask_checkpoint.checkpoint(
        predicates,
        name="predicates",
        checkpoint_dir=paths["model_ckpt_dir"],
        overwrite=False,
    )
    predicates.compute()
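
Because dask_checkpoint.checkpoint is called with overwrite=False, re-running extract_predicates reuses the stored "predicates" partitions rather than recomputing them. A hedged sketch of reading that checkpoint back with the same file_util loader used above; treating the loaded value as a Dask bag is an assumption based on the map_partitions call.

paths = get_paths(config)  # same helper as above
predicates = file_util.load(paths["model_ckpt_dir"].joinpath("predicates"))
print(predicates.take(5))  # peek at a handful of extracted predicates
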
Example #5
def prep(config: cpb.AbstractGeneratorConfig):
    # all important paths
    paths = get_paths(config)
    connect_to_dask_cluster(config)

    def ckpt(val, name, overwrite=False):
        print("Checkpoint", name)
        return dask_checkpoint.checkpoint(
            val,
            name=name,
            checkpoint_dir=paths["model_ckpt_dir"],
            overwrite=overwrite,
        )

    # Get the full set of abstracts
    parsed_abstracts = (file_util.load(
        paths["checkpoint_dir"].joinpath("sentences_with_lemmas")).
                        map_partitions(group_and_filter_parsed_sentences))
    parsed_abstracts = ckpt(parsed_abstracts, "parsed_abstracts")

    is_test_data = (parsed_abstracts.map(
        lambda rec: (random.random() <= config.sys.test_ratio, rec)))
    is_test_data = ckpt(is_test_data, "is_test_data")

    testing_data = (
        is_test_data.filter(lambda b_r: b_r[0]).map(lambda b_r: b_r[1]))
    testing_data = ckpt(testing_data, "testing_data")

    training_data = (
        is_test_data.filter(lambda b_r: not b_r[0]).map(lambda b_r: b_r[1]))
    training_data = ckpt(training_data, "training_data")

    # write each partition of the training dataset to its own sqlitedict db
    # This allows for fast random access during distributed training
    print("Loading training database")
    to_training_database(training_data, paths["training_db_dir"])

    # print("Collecting all mesh headings")
    all_mesh_headings = (training_data.map(
        lambda rec: rec["mesh_headings"]).flatten().frequencies().filter(
            lambda mesh_freq: mesh_freq[1] >= config.min_mesh_term_support).
                         map(lambda mesh_freq: mesh_freq[0]).compute())
    print(f"Indexing all {len(all_mesh_headings)} mesh headings")
    mesh_index = items_to_ordered_index(all_mesh_headings)

    ###

    print("Getting oldest year")
    oldest_year = (
        training_data.map(lambda rec: rec["year"]).filter(
            lambda year: year > 1000)  # some invalid years are crazy
        .min().compute())
    print("\t-", oldest_year)

    ###

    print("Collecting training data for tokenizer")
    training_data_files = (
        training_data
        # Only collect 30% of abstracts
        .random_sample(0.3)
        .map(lambda rec: [s["text"] for s in rec["sentences"]])
        .flatten()
        # Only take 10% of sentences; ultimately we're subsetting again
        .random_sample(0.1)
        .map(lambda text: text.lower() if config.lowercase else text)
        # Reduce the total number of files
        .repartition(20)
        # Store results in textfiles
        .to_textfiles(f"{paths['tokenizer_training_data_dir']}/*.txt")
    )
    print("Training tokenizer")
    # SentencePiece writes tokenizer.model and tokenizer.vocab into the
    # directory containing tokenizer_model_path
    spm.SentencePieceTrainer.train(
        f"--input={','.join(training_data_files)} "
        f"--model_prefix={paths['tokenizer_model_path'].parent}/tokenizer "
        f"--vocab_size={config.vocab_size} "
        f"--character_coverage=1.0 "
        f"--model_type=unigram "
        f"--input_sentence_size={config.max_tokenizer_sentences} "
        f"--shuffle_input_sentence=true ")
    assert paths["tokenizer_model_path"].is_file()
    assert paths["tokenizer_vocab_path"].is_file()

    condition_index = {
        "mesh_index": mesh_index,
        "oldest_year": oldest_year,
    }
    with open(paths["model_extra_data_path"], 'wb') as f:
        pickle.dump(condition_index, f)
    print("\t- Written:", paths["model_extra_data_path"])

    if not paths["ngram_freqs_path"].is_file():
        # We're going to need the frequency distribution of ngrams for analysis
        print("Collecting a sample of ngram frequencies.")
        ngram_frequencies = dict(
            file_util.load(paths["model_ckpt_dir"].joinpath(
                "sentences_with_bow")).random_sample(0.1).map(
                    lambda r: r["bow"]).flatten().frequencies().compute())
        with open(paths["ngram_freqs_path"], 'wb') as f:
            pickle.dump(ngram_frequencies, f)
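
The condition_index pickle written above is the file that get_model_from_config (Example #2) later hands to the model as extra_data_path. A quick sketch of reading it back; it holds only the two keys written above.

paths = get_paths(config)  # same helper as above
with open(paths["model_extra_data_path"], "rb") as f:
    condition_index = pickle.load(f)
# mesh_index comes from items_to_ordered_index(all_mesh_headings);
# oldest_year is the minimum plausible publication year in the training data.
print("oldest year:", condition_index["oldest_year"])
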
Example #6
def evaluate(
    config: cpb.AbstractGeneratorConfig,
    gen_whole_abstract: bool = True,
    skip_metrics: bool = False,
):

    multilogger = MultiLogger(config)

    assert config.HasField("restore_from_checkpoint"), \
        "Must supply restore_from_checkpoint config"
    paths = get_paths(config)

    testing_data_dir = paths["model_ckpt_dir"].joinpath("testing_data")
    assert testing_data_dir.is_dir()

    model = AbstractGenerator.load_from_checkpoint(
        config.restore_from_checkpoint)
    model.init_tokenizer()
    model.cuda()
    model.freeze()
    model.eval()

    for test_pkl in testing_data_dir.glob("*.pkl"):
        with open(test_pkl, "rb") as pkl_file:
            abstracts = pickle.load(pkl_file)
            encoder = datasets.EncodedAbstracts(
                abstract_ds=abstracts,
                tokenizer_kwargs=model.hparams.tokenizer_kwargs,
                max_text_length=model.hparams.max_text_length,
                max_mesh_length=model.hparams.max_text_length - 1,
                title_only=True,
                return_abstract=True,
            )
            loader = torch.utils.data.DataLoader(
                dataset=encoder,
                batch_size=1,
                collate_fn=collate_for_generation,
                shuffle=True,
            )

            # The loader returns a list of abstracts per batch; with
            # batch_size=1 we unpack the single element
            for model_in, (abstract, ) in loader:
                original_abstract = " ".join([
                    sent["text"] for sent in abstract["sentences"]
                    if sent["type"] != "title"
                ]).lower()
                if len(original_abstract) == 0:
                    continue
                title = " ".join([
                    s["text"] for s in abstract["sentences"]
                    if s["type"] == "title"
                ]).lower()
                for trial_idx in range(config.trials_per_generation):
                    trial_model_in = deepcopy(model_in)
                    trial_model_in = {
                        k: v.cuda()
                        for k, v in trial_model_in.items()
                    }
                    new_abstract = generate_new_text(
                        model,
                        trial_model_in,
                        gen_whole_abstract,
                        min_size=3,
                        max_size=1000,
                    )
                    metrics = {}
                    metrics["pmid"] = abstract["pmid"]
                    metrics["title"] = title
                    metrics["generated_abstract"] = new_abstract
                    metrics["original_abstract"] = original_abstract
                    if config.trials_per_generation > 1:
                        metrics["trial"] = trial_idx
                    if not skip_metrics:
                        metrics.update({
                            k: float(v)
                            for k, v in
                            get_nlg_eval().compute_individual_metrics(
                                original_abstract,
                                new_abstract,
                            ).items()
                        })
                        metrics["CIDEr-Title"] = compute_cider_minus_title(
                            original_abstract=original_abstract,
                            original_title=title,
                            generated_abstract=new_abstract,
                            ngram_freqs_path=paths["ngram_freqs_path"],
                        )
                    multilogger.log_row(metrics)
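
A minimal driver for the evaluation entry point, assuming the same text-format protobuf config used in the other examples; both file names below are placeholders.

from google.protobuf import text_format

config = cpb.AbstractGeneratorConfig()
with open("abstract_generator.pbtxt") as f:  # hypothetical config file
    text_format.Parse(f.read(), config)
config.restore_from_checkpoint = "model.ckpt"  # hypothetical checkpoint
evaluate(config, gen_whole_abstract=True, skip_metrics=False)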