def test_bertmlm(self):
  # Run prediction to ensure no failure.
  model_path = "bert-base-uncased"
  model = pretrained_lms.BertMLM(model_path)
  model_in = [{"text": "test text", "tokens": ["test", "[MASK]"]}]
  model_out = list(model.predict(model_in))
  # Sanity-check entries exist in output.
  self.assertLen(model_out, 1)
  self.assertIn("pred_tokens", model_out[0])
  self.assertIn("cls_emb", model_out[0])
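
# The test above presumably lives inside an absltest.TestCase subclass; a
# minimal sketch of that scaffolding is below. The class name and the
# `pretrained_lms` import path are assumptions based on the references above,
# not confirmed from this excerpt.
#
#   from absl.testing import absltest
#   from lit_nlp.examples.models import pretrained_lms
#
#   class PretrainedLmsTest(absltest.TestCase):
#     # ... test_bertmlm (above) is defined here ...
#
#   if __name__ == "__main__":
#     absltest.main()
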
def main(_):
  ##
  # Load models, according to the --models flag.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    if model_name.startswith("bert-"):
      models[model_name] = pretrained_lms.BertMLM(
          model_name_or_path, top_k=FLAGS.top_k)
    elif model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
      models[model_name] = pretrained_lms.GPT2LanguageModel(
          model_name_or_path, top_k=FLAGS.top_k)
    else:
      raise ValueError(
          f"Unsupported model name '{model_name}' from path "
          f"'{model_name_or_path}'")

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }

  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  for name in datasets:
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
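
# A sketch of the flag definitions and entry point that `main` relies on.
# The flag names match the FLAGS.* references above; the default values and
# help strings are illustrative assumptions, not taken from the original file.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_list(
    "models", ["bert-base-uncased", "gpt2"],
    "Models to load, by shortcut name or local path.")
flags.DEFINE_integer(
    "top_k", 10, "Number of candidate tokens to return per prediction.")
flags.DEFINE_integer(
    "max_examples", 1000, "Maximum number of examples to load per dataset.")
flags.DEFINE_bool(
    "load_bwb", False,
    "If true, also load the Billion Word Benchmark (large TFDS download).")

if __name__ == "__main__":
  app.run(main)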