def main(_):
  """Pick the task's data and model per --task, then train and save."""
  # TODO(lit-dev): add remaining GLUE tasks? These three cover all the major
  # features (single segment, two segment, classification, regression).
  # Each entry lazily builds (train_data, val_data, model) for one task, so
  # only the selected task's datasets and encoder are ever constructed.
  builders = {
      "sst2": lambda: (glue.SST2Data("train"),
                       glue.SST2Data("validation"),
                       glue_models.SST2Model(FLAGS.encoder_name)),
      "mnli": lambda: (glue.MNLIData("train"),
                       glue.MNLIData("validation_matched"),
                       glue_models.MNLIModel(FLAGS.encoder_name)),
      "stsb": lambda: (glue.STSBData("train"),
                       glue.STSBData("validation"),
                       glue_models.STSBModel(FLAGS.encoder_name)),
      "toxicity": lambda: (classification.ToxicityData("train"),
                           classification.ToxicityData("test"),
                           glue_models.ToxicityModel(FLAGS.encoder_name)),
  }
  if FLAGS.task not in builders:
    raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")
  train_data, val_data, model = builders[FLAGS.task]()

  # Run training and save model.
  train_and_save(
      model,
      train_data,
      val_data,
      FLAGS.train_path,
      save_intermediates=FLAGS.save_intermediates,
      num_epochs=FLAGS.num_epochs)
def run_finetuning(train_path):
  """Fine-tune a transformer model."""
  # Build the SST-2 splits and a trainable sentiment model, train on the
  # training split while validating on the dev split, then save the weights.
  sst_train = glue.SST2Data("train")
  sst_val = glue.SST2Data("validation")
  sentiment_model = glue_models.SST2Model(
      FLAGS.encoder_name, for_training=True)
  sentiment_model.train(
      sst_train.examples, validation_inputs=sst_val.examples)
  sentiment_model.save(train_path)
def main(_):
  """Load pre-trained GLUE models selected via --tasks and serve LIT."""
  loaded_models = {}
  loaded_datasets = {}

  def model_dir(task_name):
    # Checkpoints are expected under --models_path/<task_name>.
    return os.path.join(FLAGS.models_path, task_name)

  if "sst2" in FLAGS.tasks:
    loaded_models["sst2"] = glue_models.SST2Model(model_dir("sst2"))
    loaded_datasets["sst_dev"] = glue.SST2Data("validation")
    logging.info("Loaded models and data for SST-2 task.")
  if "stsb" in FLAGS.tasks:
    loaded_models["stsb"] = glue_models.STSBModel(model_dir("stsb"))
    loaded_datasets["stsb_dev"] = glue.STSBData("validation")
    logging.info("Loaded models and data for STS-B task.")
  if "mnli" in FLAGS.tasks:
    loaded_models["mnli"] = glue_models.MNLIModel(model_dir("mnli"))
    loaded_datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    loaded_datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
    logging.info("Loaded models and data for MultiNLI task.")

  # Truncate datasets if --max_examples is set.
  for ds_name in loaded_datasets:
    logging.info("Dataset: '%s' with %d examples", ds_name,
                 len(loaded_datasets[ds_name]))
    loaded_datasets[ds_name] = loaded_datasets[ds_name].slice[:FLAGS
                                                              .max_examples]
    logging.info(" truncated to %d examples", len(loaded_datasets[ds_name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(loaded_models, loaded_datasets,
                               **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  """Serve LIT with the simple sentiment model over the SST-2 dev set."""
  # Load the model we defined above.
  sentiment_models = {"sst": SimpleSentimentModel(FLAGS.model_path)}
  # Load SST-2 validation set from TFDS.
  eval_datasets = {"sst_dev": glue.SST2Data("validation")}
  # Start the LIT server. See server_flags.py for server options.
  server = dev_server.Server(
      sentiment_models, eval_datasets, **server_flags.get_flags())
  server.serve()
def main(_):
  """Fine-tune an SST-2 model in a working directory, then serve LIT."""
  # Use --model_path if given, otherwise a fresh temporary directory.
  work_dir = FLAGS.model_path or tempfile.mkdtemp()
  logging.info("Working directory: %s", work_dir)
  run_finetuning(work_dir)

  # Load our trained model.
  trained_models = {"sst": glue_models.SST2Model(work_dir)}
  eval_datasets = {"sst_dev": glue.SST2Data("validation")}

  # Start the LIT server. See server_flags.py for server options.
  server = dev_server.Server(
      trained_models, eval_datasets, **server_flags.get_flags())
  server.serve()
def main(_):
  """Load models per --models, matching eval data, and launch a LIT server.

  Decomposed: model loading and dataset loading each do one distinct job,
  so they live in private helpers below; behavior is unchanged.
  """
  # Quick-start mode.
  if FLAGS.quickstart:
    FLAGS.models = QUICK_START_MODELS  # smaller, faster models
    if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
      FLAGS.max_examples = 1000  # truncate larger eval sets
    logging.info(
        "Quick-start mode; overriding --models and --max_examples.")

  models, tasks_to_load = _load_models()
  datasets = _load_datasets(tasks_to_load)

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info(" truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()


def _load_models():
  """Load each model named in --models; returns (models dict, tasks seen)."""
  models = {}
  tasks_to_load = set()
  for model_string in FLAGS.models:
    # Only split on the first two ':', because path may be a URL
    # containing 'https://'
    name, task, path = model_string.split(":", 2)
    logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                 path)
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    if path.endswith(".tar.gz"):
      path = transformers.file_utils.cached_path(
          path, extract_compressed_file=True)
    # Load the model from disk.
    models[name] = MODELS_BY_TASK[task](path)
    tasks_to_load.add(task)
  return models, tasks_to_load


def _load_datasets(tasks_to_load):
  """Load the eval dataset(s) for each task that we have a model for."""
  datasets = {}
  if "sst2" in tasks_to_load:
    logging.info("Loading data for SST-2 task.")
    datasets["sst_dev"] = glue.SST2Data("validation")
  if "stsb" in tasks_to_load:
    logging.info("Loading data for STS-B task.")
    datasets["stsb_dev"] = glue.STSBData("validation")
  if "mnli" in tasks_to_load:
    logging.info("Loading data for MultiNLI task.")
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
  return datasets
def main(_):
  """Launch LIT with pretrained language models and text datasets."""

  # Load models, according to the --models flag.
  def build_lm(name_or_path):
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    short_name = os.path.basename(name_or_path)
    if short_name.startswith("bert-"):
      return short_name, pretrained_lms.BertMLM(
          name_or_path, top_k=FLAGS.top_k)
    if short_name.startswith("gpt2") or short_name in ["distilgpt2"]:
      return short_name, pretrained_lms.GPT2LanguageModel(
          name_or_path, top_k=FLAGS.top_k)
    raise ValueError(
        f"Unsupported model name '{short_name}' from path '{name_or_path}'"
    )

  models = dict(build_lm(p) for p in FLAGS.models)

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  # Truncate each dataset to --max_examples and report final sizes.
  for ds_name in datasets:
    datasets[ds_name] = datasets[ds_name].slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", ds_name,
                 len(datasets[ds_name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
def main(_):
  """Serve LIT with the simple sentiment model loaded from --model_path."""
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  resolved_path = FLAGS.model_path
  if resolved_path.endswith(".tar.gz"):
    resolved_path = transformers.file_utils.cached_path(
        resolved_path, extract_compressed_file=True)

  # Load the model we defined above.
  sentiment_models = {"sst": SimpleSentimentModel(resolved_path)}
  # Load SST-2 validation set from TFDS.
  eval_datasets = {"sst_dev": glue.SST2Data("validation")}

  # Start the LIT server. See server_flags.py for server options.
  server = dev_server.Server(
      sentiment_models, eval_datasets, **server_flags.get_flags())
  server.serve()