Example no. 1
0
def main(_):
    """Train the model selected by --task and save it to --train_path."""
    ##
    # Pick the model and datasets for the requested task. Each entry builds
    # (train_data, val_data, model) lazily, so only the selected task's
    # resources are ever constructed.
    # TODO(lit-dev): add remaining GLUE tasks? These three cover all the major
    # features (single segment, two segment, classification, regression).
    task_builders = {
        "sst2": lambda: (glue.SST2Data("train"),
                         glue.SST2Data("validation"),
                         glue_models.SST2Model(FLAGS.encoder_name)),
        "mnli": lambda: (glue.MNLIData("train"),
                         glue.MNLIData("validation_matched"),
                         glue_models.MNLIModel(FLAGS.encoder_name)),
        "stsb": lambda: (glue.STSBData("train"),
                         glue.STSBData("validation"),
                         glue_models.STSBModel(FLAGS.encoder_name)),
        "toxicity": lambda: (classification.ToxicityData("train"),
                             classification.ToxicityData("test"),
                             glue_models.ToxicityModel(FLAGS.encoder_name)),
    }
    if FLAGS.task not in task_builders:
        raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")
    train_data, val_data, model = task_builders[FLAGS.task]()

    ##
    # Run training and save model.
    train_and_save(model,
                   train_data,
                   val_data,
                   FLAGS.train_path,
                   save_intermediates=FLAGS.save_intermediates,
                   num_epochs=FLAGS.num_epochs)
Example no. 2
0
def run_finetuning(train_path):
    """Fine-tune a transformer model."""
    # SST-2 sentiment data: train split for tuning, validation for eval.
    sst_train = glue.SST2Data("train")
    sst_val = glue.SST2Data("validation")
    sentiment_model = glue_models.SST2Model(
        FLAGS.encoder_name, for_training=True)
    sentiment_model.train(sst_train.examples,
                          validation_inputs=sst_val.examples)
    sentiment_model.save(train_path)
Example no. 3
0
def main(_):
  """Load the models/datasets named in --tasks and serve them with LIT."""
  models = {}
  datasets = {}

  # (flag value, display name, model constructor, dataset builders).
  # Constructors are wrapped in lambdas so nothing is built unless the task
  # was actually requested; the list order fixes the loading order.
  task_table = [
      ("sst2", "SST-2",
       lambda path: glue_models.SST2Model(path),
       [("sst_dev", lambda: glue.SST2Data("validation"))]),
      ("stsb", "STS-B",
       lambda path: glue_models.STSBModel(path),
       [("stsb_dev", lambda: glue.STSBData("validation"))]),
      ("mnli", "MultiNLI",
       lambda path: glue_models.MNLIModel(path),
       [("mnli_dev", lambda: glue.MNLIData("validation_matched")),
        ("mnli_dev_mm", lambda: glue.MNLIData("validation_mismatched"))]),
  ]
  for task, label, build_model, data_builders in task_table:
    if task not in FLAGS.tasks:
      continue
    models[task] = build_model(os.path.join(FLAGS.models_path, task))
    for key, build_data in data_builders:
      datasets[key] = build_data()
    logging.info("Loaded models and data for %s task.", label)

  # Truncate datasets if --max_examples is set.
  for name, data in list(datasets.items()):
    logging.info("Dataset: '%s' with %d examples", name, len(data))
    truncated = data.slice[:FLAGS.max_examples]
    datasets[name] = truncated
    logging.info("  truncated to %d examples", len(truncated))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
Example no. 4
0
def main(_):
  """Serve the custom sentiment model over the SST-2 validation split."""
  # Load the model we defined above.
  sentiment_model = SimpleSentimentModel(FLAGS.model_path)
  # Load SST-2 validation set from TFDS.
  sst_validation = glue.SST2Data("validation")

  # Start the LIT server. See server_flags.py for server options.
  demo = dev_server.Server(
      {"sst": sentiment_model},
      {"sst_dev": sst_validation},
      **server_flags.get_flags())
  demo.serve()
Example no. 5
0
def main(_):
    """Fine-tune an SST-2 model, then serve it with LIT."""
    # Use --model_path if given, otherwise a fresh temporary directory.
    work_dir = FLAGS.model_path or tempfile.mkdtemp()
    logging.info("Working directory: %s", work_dir)
    run_finetuning(work_dir)

    # Load our trained model.
    models = {"sst": glue_models.SST2Model(work_dir)}
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    dev_server.Server(models, datasets, **server_flags.get_flags()).serve()
Example no. 6
0
def main(_):
    """Load the models named in --models, their eval data, and serve LIT."""
    # Quick-start mode: swap in small models and cap the eval set size.
    if FLAGS.quickstart:
        FLAGS.models = QUICK_START_MODELS  # smaller, faster models
        if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
            FLAGS.max_examples = 1000  # truncate larger eval sets
        logging.info(
            "Quick-start mode; overriding --models and --max_examples.")

    models = {}
    datasets = {}

    tasks_to_load = set()
    for spec in FLAGS.models:
        # Only split on the first two ':', because path may be a URL
        # containing 'https://'
        name, task, path = spec.split(":", 2)
        logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                     path)
        # Normally path is a directory; if it's an archive file, download and
        # extract to the transformers cache.
        if path.endswith(".tar.gz"):
            path = transformers.file_utils.cached_path(
                path, extract_compressed_file=True)
        # Load the model from disk.
        models[name] = MODELS_BY_TASK[task](path)
        tasks_to_load.add(task)

    ##
    # Load datasets for each task that we have a model for. The list order
    # below fixes the dataset insertion order; builders are lambdas so data
    # is only fetched for tasks actually loaded.
    dataset_table = [
        ("sst2", "SST-2",
         [("sst_dev", lambda: glue.SST2Data("validation"))]),
        ("stsb", "STS-B",
         [("stsb_dev", lambda: glue.STSBData("validation"))]),
        ("mnli", "MultiNLI",
         [("mnli_dev", lambda: glue.MNLIData("validation_matched")),
          ("mnli_dev_mm", lambda: glue.MNLIData("validation_mismatched"))]),
    ]
    for task, label, builders in dataset_table:
        if task not in tasks_to_load:
            continue
        logging.info("Loading data for %s task.", label)
        for key, build_data in builders:
            datasets[key] = build_data()

    # Truncate datasets if --max_examples is set.
    for name, data in list(datasets.items()):
        logging.info("Dataset: '%s' with %d examples", name, len(data))
        truncated = data.slice[:FLAGS.max_examples]
        datasets[name] = truncated
        logging.info("  truncated to %d examples", len(truncated))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    return lit_demo.serve()
Example no. 7
0
def main(_):
  """Serve pretrained language models (BERT MLM / GPT-2) with LIT."""
  ##
  # Load models, according to the --models flag.
  models = {}
  for name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    short_name = os.path.basename(name_or_path)
    if short_name.startswith("bert-"):
      models[short_name] = pretrained_lms.BertMLM(
          name_or_path, top_k=FLAGS.top_k)
    elif short_name.startswith("gpt2") or short_name == "distilgpt2":
      models[short_name] = pretrained_lms.GPT2LanguageModel(
          name_or_path, top_k=FLAGS.top_k)
    else:
      raise ValueError(
          f"Unsupported model name '{short_name}' from path '{name_or_path}'"
      )

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  # Truncate every dataset to --max_examples and report the final sizes.
  for name, data in list(datasets.items()):
    datasets[name] = data.slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
Example no. 8
0
def main(_):
    """Serve the custom sentiment model on the SST-2 validation set."""
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    path = FLAGS.model_path
    if path.endswith(".tar.gz"):
        path = transformers.file_utils.cached_path(
            path, extract_compressed_file=True)

    # Load the model we defined above.
    models = {"sst": SimpleSentimentModel(path)}
    # Load SST-2 validation set from TFDS.
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    dev_server.Server(models, datasets, **server_flags.get_flags()).serve()