Exemplo n.º 1
0
def create_lit_args(exp_name: str,
                    checkpoints: Set[str] = {"last"}) -> Tuple[tuple, dict]:
    """Build the ``(args, kwargs)`` used to construct a LIT server for an experiment.

    Loads the experiment config for ``exp_name``, wraps one shared model in an
    ``NMTModel`` per requested checkpoint, and assembles the datasets,
    generators and interpreters for the LIT UI.

    Args:
        exp_name: Name of the experiment whose config is loaded with
            ``load_config``.
        checkpoints: Checkpoints to expose. Each entry is ``"avg"``,
            ``"last"``, ``"best"``, or a training step as a decimal string
            (e.g. ``"5000"``). The default set is only read, never mutated.

    Returns:
        ``((models, datasets), {"generators": ..., "interpreters": ...})``.
        ``models`` maps ``str(step)`` (or ``"avg"``) to an ``NMTModel``.
        When several requested checkpoints resolve to the same step, they
        share one wrapper and their kinds are collected in its ``types``.

    Raises:
        ValueError: if an entry of ``checkpoints`` is not ``"avg"``,
            ``"last"``, ``"best"`` or an integer.
    """
    config = load_config(exp_name)

    datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_test_dataset(config)
    }

    src_spp, trg_spp = config.create_sp_processors()

    # A single underlying model object is shared by every NMTModel wrapper
    # below; each wrapper is only given its own checkpoint path and step.
    model = create_model(config)

    models: Dict[str, NMTModel] = {}

    def add_step_model(step, checkpoint_path, checkpoint_type=None):
        """Register a wrapper keyed by ``str(step)``, or tag an existing one.

        If a wrapper for the same step already exists (e.g. the last
        checkpoint is also the best one, or it was requested by number), it is
        reused and ``checkpoint_type`` is appended to its ``types`` instead of
        creating a duplicate or overwriting the earlier entry.
        """
        step_str = str(step)
        existing = models.get(step_str)
        if existing is not None:
            if checkpoint_type is not None:
                existing.types.append(checkpoint_type)
        elif checkpoint_type is None:
            models[step_str] = NMTModel(config, model, src_spp, trg_spp, step,
                                        checkpoint_path)
        else:
            models[step_str] = NMTModel(config,
                                        model,
                                        src_spp,
                                        trg_spp,
                                        step,
                                        checkpoint_path,
                                        type=checkpoint_type)

    # Iterate in sorted order: set iteration order for strings is randomized
    # per process, and it decides which kind becomes the wrapper's primary
    # ``type`` when two entries resolve to the same step.
    for checkpoint in sorted(checkpoints):
        if checkpoint == "avg":
            # The averaged checkpoint lives in its own "avg" subdirectory; the
            # step returned for it is discarded and -1 is passed instead.
            avg_path, _ = get_last_checkpoint(config.model_dir / "avg")
            models["avg"] = NMTModel(config,
                                     model,
                                     src_spp,
                                     trg_spp,
                                     -1,
                                     avg_path,
                                     type="avg")
        elif checkpoint == "last":
            last_path, last_step = get_last_checkpoint(config.model_dir)
            add_step_model(last_step, last_path, "last")
        elif checkpoint == "best":
            best_model_path, best_step = get_best_model_dir(config.model_dir)
            add_step_model(best_step, best_model_path / "ckpt", "best")
        else:
            try:
                step = int(checkpoint)
            except ValueError as err:
                raise ValueError(
                    f"Unknown checkpoint {checkpoint!r}: expected 'avg', "
                    "'last', 'best' or an integer step") from err
            add_step_model(step, config.model_dir / f"ckpt-{checkpoint}")

    # NOTE(review): the *training* set is registered under the key "test",
    # the same name as the evaluation dataset above. Presumably this is so
    # similarity searches issued from the "test" dataset in the UI resolve to
    # this index — confirm before renaming.
    index_datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_train_dataset(config)
    }
    indexer = IndexerEx(
        models,
        lit_dataset.IndexedDataset.index_all(index_datasets,
                                             caching.input_hash),
        data_dir=str(config.exp_dir / "lit-index"),
        initialize_new_indices=True,
    )

    generators: Dict[str, lit_components.Generator] = {
        "word_replacer": word_replacer.WordReplacer(),
        "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
    }

    interpreters: Dict[str, lit_components.Interpreter] = {
        "metrics":
        lit_components.ComponentGroup({"bleu": BLEUMetrics(config.data)}),
        "pca":
        projection.ProjectionManager(pca.PCAModel),
        "umap":
        projection.ProjectionManager(umap.UmapModel),
    }

    return ((models, datasets), {
        "generators": generators,
        "interpreters": interpreters
    })
Exemplo n.º 2
0
def main(_):
    """Start a LIT dev server for T5 on the tasks selected by ``--tasks``.

    Each ``--models`` entry is loaded once as a base model. Every selected
    task ("summarization", "mt") then wraps all base models with its own
    pre-/post-processing and registers its evaluation datasets. Datasets are
    truncated to ``--max_examples`` before the server starts.

    Returns:
        The value returned by ``dev_server.Server.serve()``.
    """
    # Base models, keyed by the last path component of each --models entry so
    # that /path/to/<model_name> and a bare shortcut name both work. Entries
    # starting with "SavedModel" are expected to look like
    # "SavedModel:<path>"; everything else goes to the HuggingFace loader.
    base_models = {}
    for spec in FLAGS.models:
        key = os.path.basename(spec)
        if not spec.startswith("SavedModel"):
            # TODO(lit-dev): attention is temporarily disabled, because O(n^2)
            # between tokens in a long document can get very, very large.
            # Re-enable once we can send this to the frontend more efficiently.
            base_models[key] = t5.T5HFModel(
                model_name=spec,
                num_to_generate=FLAGS.num_to_generate,
                token_top_k=FLAGS.token_top_k,
                output_attention=False)
            continue
        base_models[key] = t5.T5SavedModel(spec.split(":", 1)[1])

    # Task wrappers share the in-memory base model and only add
    # task-specific pre- and post-processing.
    models = {}
    datasets = {}

    if "summarization" in FLAGS.tasks:
        models.update(
            (f"{name}_summarization", t5.SummarizationWrapper(base))
            for name, base in base_models.items())
        datasets["CNNDM"] = summarization.CNNDMData(
            split="validation", max_examples=FLAGS.max_examples)

    if "mt" in FLAGS.tasks:
        models.update(
            (f"{name}_translation", t5.TranslationWrapper(base))
            for name, base in base_models.items())
        datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
        datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

    # Log each dataset's size, then keep only its first --max_examples rows.
    for name, full in list(datasets.items()):
        logging.info("Dataset: '%s' with %d examples", name, len(full))
        truncated = full.slice[:FLAGS.max_examples]
        datasets[name] = truncated
        logging.info("  truncated to %d examples", len(truncated))

    # Generators create new examples by perturbing or modifying existing ones;
    # the word replacer does substitutions like "great" -> "terrible".
    generators = {"word_replacer": word_replacer.WordReplacer()}
    if FLAGS.use_indexer:
        # Expose the nearest-neighbour index as a queryable generator.
        generators["similarity_searcher"] = (
            similarity_searcher.SimilaritySearcher(
                indexer=build_indexer(models)))

    server = dev_server.Server(
        models, datasets, generators=generators, **server_flags.get_flags())
    return server.serve()
Exemplo n.º 3
0
def main(_):
  """Start a LIT dev server for T5 summarization on CNN/DailyMail.

  Loads one T5 generation model per ``--models`` entry, the CNNDM validation
  split as the UI dataset, and optionally (``--use_indexer``) a nearest-
  neighbour index over the CNNDM training split.

  Raises:
    ValueError: if ``--use_indexer`` is set but ``--data_dir`` is empty.
  """
  # Validate flag combinations before any expensive model/dataset loading.
  # An explicit check is used instead of `assert`, which is stripped when
  # Python runs with -O.
  if FLAGS.use_indexer and not FLAGS.data_dir:
    raise ValueError("--data_dir must be set to use the indexer.")

  ##
  # Load models. You can specify several here, if you want to compare different
  # models side-by-side, and can also include models of different types that use
  # different datasets.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
    # tokens in a long document can get very, very large. Re-enable once we can
    # send this to the frontend more efficiently.
    models[model_name] = t5.T5GenerationModel(
        model_name=model_name_or_path,
        input_prefix="summarize: ",
        top_k=FLAGS.top_k,
        output_attention=False)

  ##
  # Load datasets. Typically you'll have the constructor actually read the
  # examples and do any pre-processing, so that they're in memory and ready to
  # send to the frontend when you open the web UI.
  datasets = {
      "CNNDM":
          summarization.CNNDMData(
              split="validation", max_examples=FLAGS.max_examples),
  }
  for name, ds in datasets.items():
    logging.info("Dataset: '%s' with %d examples", name, len(ds))

  ##
  # We can also add custom components. Generators are used to create new
  # examples by perturbing or modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible"
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    # Datasets for indexer - this one loads the training corpus instead of val.
    index_datasets = {
        "CNNDM":
            summarization.CNNDMData(
                split="train", max_examples=FLAGS.max_index_examples),
    }
    # Set up the Indexer, building index if necessary (this may be slow).
    indexer = index.Indexer(
        datasets=index_datasets,
        models=models,
        data_dir=FLAGS.data_dir,
        initialize_new_indices=FLAGS.initialize_index)

    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Actually start the LIT server, using the models, datasets, and other
  # components constructed above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  lit_demo.serve()