Beispiel #1
0
def main(_):
    """Train and save a model for the task named by ``FLAGS.task``.

    Picks the train/validation datasets and the model wrapper that belong to
    the selected task, then hands them to ``train_and_save`` together with
    ``FLAGS.train_path``, ``FLAGS.save_intermediates`` and
    ``FLAGS.num_epochs``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.task`` is not one of "sst2", "mnli", "stsb"
            or "toxicity".
    """

    ##
    # Pick the model and datasets
    # TODO(lit-dev): add remaining GLUE tasks? These three cover all the major
    # features (single segment, two segment, classification, regression).
    if FLAGS.task == "sst2":
        train_data = glue.SST2Data("train")
        val_data = glue.SST2Data("validation")
        model = glue_models.SST2Model(FLAGS.encoder_name)
    elif FLAGS.task == "mnli":
        train_data = glue.MNLIData("train")
        # Only the matched validation split is used here.
        val_data = glue.MNLIData("validation_matched")
        model = glue_models.MNLIModel(FLAGS.encoder_name)
    elif FLAGS.task == "stsb":
        train_data = glue.STSBData("train")
        val_data = glue.STSBData("validation")
        model = glue_models.STSBModel(FLAGS.encoder_name)
    elif FLAGS.task == "toxicity":
        train_data = classification.ToxicityData("train")
        # This task validates on the "test" split.
        val_data = classification.ToxicityData("test")
        model = glue_models.ToxicityModel(FLAGS.encoder_name)
    else:
        # Deliberately no ":s" format spec: a non-empty spec raises TypeError
        # for non-string values (e.g. a flag left as None), which would hide
        # this ValueError. For string values the message is unchanged.
        raise ValueError(f"Unrecognized task name: '{FLAGS.task}'")

    ##
    # Run training and save model.
    train_and_save(model,
                   train_data,
                   val_data,
                   FLAGS.train_path,
                   save_intermediates=FLAGS.save_intermediates,
                   num_epochs=FLAGS.num_epochs)
Beispiel #2
0
def main(_):
  """Load models and dev sets for the tasks in FLAGS.tasks, then serve LIT.

  For every supported task found in FLAGS.tasks, a model is loaded from the
  sub-directory of FLAGS.models_path named after the task, and the task's
  validation data is loaded alongside it. Each dataset is then replaced by
  its `.slice[:FLAGS.max_examples]` before the server is started.

  Args:
    _: Unused positional argument.
  """
  # One row per supported task:
  #   (task key, model loader, dev-set loaders, log line once loaded).
  # Loaders are callables so nothing is constructed for unrequested tasks.
  task_table = (
      ("sst2",
       lambda path: glue_models.SST2Model(path),
       (("sst_dev", lambda: glue.SST2Data("validation")),),
       "Loaded models and data for SST-2 task."),
      ("stsb",
       lambda path: glue_models.STSBModel(path),
       (("stsb_dev", lambda: glue.STSBData("validation")),),
       "Loaded models and data for STS-B task."),
      ("mnli",
       lambda path: glue_models.MNLIModel(path),
       (("mnli_dev", lambda: glue.MNLIData("validation_matched")),
        ("mnli_dev_mm", lambda: glue.MNLIData("validation_mismatched"))),
       "Loaded models and data for MultiNLI task."),
  )

  models = {}
  datasets = {}
  for task, load_model, dev_sets, done_message in task_table:
    if task not in FLAGS.tasks:
      continue
    # The model directory shares its name with the task key.
    models[task] = load_model(os.path.join(FLAGS.models_path, task))
    for dataset_name, load_dataset in dev_sets:
      datasets[dataset_name] = load_dataset()
    logging.info(done_message)

  # Truncate datasets if --max_examples is set.
  # Iterate over a snapshot because the dict's values are replaced in place.
  for dataset_name, full_dataset in list(datasets.items()):
    logging.info("Dataset: '%s' with %d examples", dataset_name,
                 len(full_dataset))
    truncated = full_dataset.slice[:FLAGS.max_examples]
    datasets[dataset_name] = truncated
    logging.info("  truncated to %d examples", len(truncated))

  # Start the LIT server. See server_flags.py for server options.
  server_options = server_flags.get_flags()
  lit_demo = dev_server.Server(models, datasets, **server_options)
  lit_demo.serve()
Beispiel #3
0
def main(_):
    """Load the models listed in FLAGS.models plus matching data; serve LIT.

    Each entry of FLAGS.models is a "name:task:path" string. The model class
    is looked up in MODELS_BY_TASK by task, and validation data is loaded for
    every task that has at least one model. When FLAGS.quickstart is set,
    FLAGS.models is replaced by QUICK_START_MODELS and FLAGS.max_examples is
    capped at 1000.

    Args:
        _: Unused positional argument.

    Returns:
        Whatever the LIT server's ``serve()`` call returns.
    """
    # Quick-start mode.
    if FLAGS.quickstart:
        FLAGS.models = QUICK_START_MODELS  # smaller, faster models
        example_cap = FLAGS.max_examples
        if example_cap is None or example_cap > 1000:
            FLAGS.max_examples = 1000  # truncate larger eval sets
        logging.info(
            "Quick-start mode; overriding --models and --max_examples.")

    models = {}
    tasks_to_load = set()
    for spec in FLAGS.models:
        # maxsplit=2 keeps any later ':' (e.g. the one in 'https://') inside
        # the path component.
        name, task, location = spec.split(":", 2)
        logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                     location)
        # Normally the path is a directory; if it's an archive file, download
        # and extract to the transformers cache.
        is_archive = location.endswith(".tar.gz")
        if is_archive:
            location = transformers.file_utils.cached_path(
                location, extract_compressed_file=True)
        # Load the model from disk.
        model_cls = MODELS_BY_TASK[task]
        models[name] = model_cls(location)
        tasks_to_load.add(task)

    ##
    # Load datasets for each task that we have a model for.
    # Maps task -> (log line, ((dataset name, loader), ...)). Loaders are
    # callables so data is only read for tasks that were actually requested.
    dev_sets_by_task = {
        "sst2": ("Loading data for SST-2 task.",
                 (("sst_dev", lambda: glue.SST2Data("validation")),)),
        "stsb": ("Loading data for STS-B task.",
                 (("stsb_dev", lambda: glue.STSBData("validation")),)),
        "mnli": ("Loading data for MultiNLI task.",
                 (("mnli_dev",
                   lambda: glue.MNLIData("validation_matched")),
                  ("mnli_dev_mm",
                   lambda: glue.MNLIData("validation_mismatched")))),
    }
    datasets = {}
    for task, (message, loaders) in dev_sets_by_task.items():
        if task not in tasks_to_load:
            continue
        logging.info(message)
        for dataset_name, load in loaders:
            datasets[dataset_name] = load()

    # Truncate datasets if --max_examples is set.
    # Walk a snapshot of the keys because the values are replaced in place.
    for dataset_name in tuple(datasets):
        original = datasets[dataset_name]
        logging.info("Dataset: '%s' with %d examples", dataset_name,
                     len(original))
        datasets[dataset_name] = original.slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples",
                     len(datasets[dataset_name]))

    # Start the LIT server. See server_flags.py for server options.
    server_options = server_flags.get_flags()
    return dev_server.Server(models, datasets, **server_options).serve()