def main(_):
  """Fine-tune and save a model for the task named by --task.

  Builds the train/validation datasets and the model for the selected task,
  then hands everything to train_and_save. Raises ValueError for an unknown
  --task value.
  """
  # Registry of supported tasks: task name -> (dataset class, train split,
  # validation split, model class). These cover all the major feature
  # combinations (single segment, two segment, classification, regression).
  # TODO(lit-dev): add remaining GLUE tasks?
  task_registry = {
      "sst2": (glue.SST2Data, "train", "validation", glue_models.SST2Model),
      "mnli": (glue.MNLIData, "train", "validation_matched",
               glue_models.MNLIModel),
      "stsb": (glue.STSBData, "train", "validation", glue_models.STSBModel),
      "toxicity": (classification.ToxicityData, "train", "test",
                   glue_models.ToxicityModel),
  }
  if FLAGS.task not in task_registry:
    raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")

  data_cls, train_split, val_split, model_cls = task_registry[FLAGS.task]
  train_data = data_cls(train_split)
  val_data = data_cls(val_split)
  model = model_cls(FLAGS.encoder_name)

  # Run training and save model.
  train_and_save(
      model,
      train_data,
      val_data,
      FLAGS.train_path,
      save_intermediates=FLAGS.save_intermediates,
      num_epochs=FLAGS.num_epochs)
def main(_):
  """Load pre-trained models and eval data for --tasks, then serve LIT.

  For each requested task, loads the fine-tuned model from under
  --models_path and the matching validation split(s), optionally truncates
  each dataset to --max_examples, and starts the LIT dev server.
  """
  models = {}
  datasets = {}

  # Per-task loading spec: (task key, model class, dataset specs, log msg).
  # MNLI contributes two eval sets (matched and mismatched).
  task_specs = [
      ("sst2", glue_models.SST2Model,
       [("sst_dev", glue.SST2Data, "validation")],
       "Loaded models and data for SST-2 task."),
      ("stsb", glue_models.STSBModel,
       [("stsb_dev", glue.STSBData, "validation")],
       "Loaded models and data for STS-B task."),
      ("mnli", glue_models.MNLIModel,
       [("mnli_dev", glue.MNLIData, "validation_matched"),
        ("mnli_dev_mm", glue.MNLIData, "validation_mismatched")],
       "Loaded models and data for MultiNLI task."),
  ]
  for task, model_cls, dataset_specs, done_msg in task_specs:
    if task not in FLAGS.tasks:
      continue
    models[task] = model_cls(os.path.join(FLAGS.models_path, task))
    for key, data_cls, split in dataset_specs:
      datasets[key] = data_cls(split)
    logging.info(done_msg)

  # Truncate datasets if --max_examples is set.
  for key in list(datasets):
    logging.info("Dataset: '%s' with %d examples", key, len(datasets[key]))
    datasets[key] = datasets[key].slice[:FLAGS.max_examples]
    logging.info(" truncated to %d examples", len(datasets[key]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
def main(_):
  """Load the models named by --models plus their eval data, then serve LIT.

  Each --models entry is a 'name:task:path' spec; tar.gz paths are fetched
  and extracted via the transformers cache. Eval datasets are loaded only
  for tasks that have at least one model, then truncated to --max_examples.
  """
  # Quick-start mode: substitute small models and cap the eval-set size.
  if FLAGS.quickstart:
    FLAGS.models = QUICK_START_MODELS  # smaller, faster models
    if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
      FLAGS.max_examples = 1000  # truncate larger eval sets
    logging.info("Quick-start mode; overriding --models and --max_examples.")

  models = {}
  datasets = {}
  tasks_to_load = set()

  for spec in FLAGS.models:
    # Only split on the first two ':', because path may be a URL
    # containing 'https://'
    name, task, path = spec.split(":", 2)
    logging.info("Loading model '%s' for task '%s' from '%s'", name, task, path)
    # Normally path is a directory; an archive file is downloaded and
    # extracted into the transformers cache first.
    if path.endswith(".tar.gz"):
      path = transformers.file_utils.cached_path(
          path, extract_compressed_file=True)
    # Load the model from disk.
    models[name] = MODELS_BY_TASK[task](path)
    tasks_to_load.add(task)

  # Load datasets for each task that we have a model for.
  dataset_specs = {
      "sst2": ("Loading data for SST-2 task.",
               [("sst_dev", glue.SST2Data, "validation")]),
      "stsb": ("Loading data for STS-B task.",
               [("stsb_dev", glue.STSBData, "validation")]),
      "mnli": ("Loading data for MultiNLI task.",
               [("mnli_dev", glue.MNLIData, "validation_matched"),
                ("mnli_dev_mm", glue.MNLIData, "validation_mismatched")]),
  }
  for task in ("sst2", "stsb", "mnli"):
    if task not in tasks_to_load:
      continue
    log_msg, specs = dataset_specs[task]
    logging.info(log_msg)
    for key, data_cls, split in specs:
      datasets[key] = data_cls(split)

  # Truncate datasets if --max_examples is set.
  for key in list(datasets):
    logging.info("Dataset: '%s' with %d examples", key, len(datasets[key]))
    datasets[key] = datasets[key].slice[:FLAGS.max_examples]
    logging.info(" truncated to %d examples", len(datasets[key]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()