def create_lit_args(exp_name: str, checkpoints: Set[str] = frozenset({"last"})) -> Tuple[tuple, dict]:
    """Build the positional and keyword arguments for a LIT server.

    Loads the experiment config, constructs one NMTModel wrapper per requested
    checkpoint, and wires up datasets, generators (word replacement, similarity
    search over an index of the training data), and interpreters (BLEU metrics,
    PCA/UMAP projections).

    Args:
        exp_name: Name of the experiment whose config should be loaded.
        checkpoints: Checkpoint selectors — any of "avg", "last", "best", or a
            literal step number string (loaded from ``ckpt-<step>``).
            NOTE: the default is a ``frozenset`` rather than a set literal so a
            single mutable default object is never shared across calls; the
            function only iterates it, so any iterable of str works.

    Returns:
        A ``((models, datasets), {"generators": ..., "interpreters": ...})``
        tuple matching the LIT server constructor signature.
    """
    config = load_config(exp_name)
    datasets: Dict[str, lit_dataset.Dataset] = {"test": create_test_dataset(config)}
    src_spp, trg_spp = config.create_sp_processors()
    model = create_model(config)

    models: Dict[str, NMTModel] = {}
    for checkpoint in checkpoints:
        if checkpoint == "avg":
            checkpoint_path, _ = get_last_checkpoint(config.model_dir / "avg")
            # Averaged checkpoint has no meaningful step; -1 marks it as such.
            models["avg"] = NMTModel(config, model, src_spp, trg_spp, -1, checkpoint_path, type="avg")
        elif checkpoint == "last":
            last_checkpoint_path, last_step = get_last_checkpoint(config.model_dir)
            step_str = str(last_step)
            if step_str in models:
                # Same step already loaded (e.g. "best" == "last"): just tag it.
                models[step_str].types.append("last")
            else:
                models[step_str] = NMTModel(config, model, src_spp, trg_spp, last_step,
                                            last_checkpoint_path, type="last")
        elif checkpoint == "best":
            best_model_path, best_step = get_best_model_dir(config.model_dir)
            step_str = str(best_step)
            if step_str in models:
                models[step_str].types.append("best")
            else:
                models[step_str] = NMTModel(config, model, src_spp, trg_spp, best_step,
                                            best_model_path / "ckpt", type="best")
        else:
            # Explicit step number, e.g. "5000" -> <model_dir>/ckpt-5000.
            checkpoint_path = config.model_dir / f"ckpt-{checkpoint}"
            step = int(checkpoint)
            models[checkpoint] = NMTModel(config, model, src_spp, trg_spp, step, checkpoint_path)

    # Index the *training* data so the similarity searcher can retrieve
    # neighbors of test examples; the index is cached under the experiment dir.
    index_datasets: Dict[str, lit_dataset.Dataset] = {"test": create_train_dataset(config)}
    indexer = IndexerEx(
        models,
        lit_dataset.IndexedDataset.index_all(index_datasets, caching.input_hash),
        data_dir=str(config.exp_dir / "lit-index"),
        initialize_new_indices=True,
    )

    generators: Dict[str, lit_components.Generator] = {
        "word_replacer": word_replacer.WordReplacer(),
        "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
    }

    interpreters: Dict[str, lit_components.Interpreter] = {
        "metrics": lit_components.ComponentGroup({"bleu": BLEUMetrics(config.data)}),
        "pca": projection.ProjectionManager(pca.PCAModel),
        "umap": projection.ProjectionManager(umap.UmapModel),
    }

    return ((models, datasets), {"generators": generators, "interpreters": interpreters})
def main(_):
    """Build multi-task T5 models/datasets from flags and start the LIT server."""
    ##
    # Load models. You can specify several here, if you want to compare different
    # models side-by-side, and can also include models of different types that use
    # different datasets.
    shared_models = {}
    for spec in FLAGS.models:
        # Ignore path prefix, if using /path/to/<model_name> to load from a
        # specific directory rather than the default shortcut.
        short_name = os.path.basename(spec)
        if spec.startswith("SavedModel"):
            # Spec looks like "SavedModel:<path>" — everything after the first
            # colon is the saved-model directory.
            shared_models[short_name] = t5.T5SavedModel(spec.split(":", 1)[1])
        else:
            # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
            # tokens in a long document can get very, very large. Re-enable once we
            # can send this to the frontend more efficiently.
            shared_models[short_name] = t5.T5HFModel(
                model_name=spec,
                num_to_generate=FLAGS.num_to_generate,
                token_top_k=FLAGS.token_top_k,
                output_attention=False)

    ##
    # Load eval sets and model wrappers for each task.
    # Model wrappers share the same in-memory T5 model, but add task-specific pre-
    # and post-processing code.
    models = {}
    datasets = {}
    if "summarization" in FLAGS.tasks:
        models.update(
            (name + "_summarization", t5.SummarizationWrapper(base))
            for name, base in shared_models.items())
        datasets["CNNDM"] = summarization.CNNDMData(
            split="validation", max_examples=FLAGS.max_examples)
    if "mt" in FLAGS.tasks:
        models.update(
            (name + "_translation", t5.TranslationWrapper(base))
            for name, base in shared_models.items())
        datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
        datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

    # Truncate datasets if --max_examples is set.
    for name in list(datasets):
        full = datasets[name]
        logging.info("Dataset: '%s' with %d examples", name, len(full))
        truncated = full.slice[:FLAGS.max_examples]
        datasets[name] = truncated
        logging.info(" truncated to %d examples", len(truncated))

    ##
    # We can also add custom components. Generators are used to create new
    # examples by perturbing or modifying existing ones.
    generators = {
        # Word-substitution, like "great" -> "terrible"
        "word_replacer": word_replacer.WordReplacer(),
    }
    if FLAGS.use_indexer:
        # Wrap the indexer into a Generator component that we can query.
        generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
            indexer=build_indexer(models))

    ##
    # Actually start the LIT server, using the models, datasets, and other
    # components constructed above.
    server = dev_server.Server(
        models, datasets, generators=generators, **server_flags.get_flags())
    return server.serve()
def main(_):
    """Build T5 summarization models/datasets from flags and start the LIT server."""
    ##
    # Load models. You can specify several here, if you want to compare different
    # models side-by-side, and can also include models of different types that use
    # different datasets.
    models = {}
    for model_name_or_path in FLAGS.models:
        # Ignore path prefix, if using /path/to/<model_name> to load from a
        # specific directory rather than the default shortcut.
        model_name = os.path.basename(model_name_or_path)
        # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
        # tokens in a long document can get very, very large. Re-enable once we can
        # send this to the frontend more efficiently.
        models[model_name] = t5.T5GenerationModel(
            model_name=model_name_or_path,
            input_prefix="summarize: ",
            top_k=FLAGS.top_k,
            output_attention=False)

    ##
    # Load datasets. Typically you'll have the constructor actually read the
    # examples and do any pre-processing, so that they're in memory and ready to
    # send to the frontend when you open the web UI.
    datasets = {
        "CNNDM": summarization.CNNDMData(
            split="validation", max_examples=FLAGS.max_examples),
    }
    for name, ds in datasets.items():
        logging.info("Dataset: '%s' with %d examples", name, len(ds))

    ##
    # We can also add custom components. Generators are used to create new
    # examples by perturbing or modifying existing ones.
    generators = {
        # Word-substitution, like "great" -> "terrible"
        "word_replacer": word_replacer.WordReplacer(),
    }
    if FLAGS.use_indexer:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this flag validation.
        if not FLAGS.data_dir:
            raise ValueError("--data_dir must be set to use the indexer.")
        # Datasets for indexer - this one loads the training corpus instead of val.
        index_datasets = {
            "CNNDM": summarization.CNNDMData(
                split="train", max_examples=FLAGS.max_index_examples),
        }
        # Set up the Indexer, building index if necessary (this may be slow).
        indexer = index.Indexer(
            datasets=index_datasets,
            models=models,
            data_dir=FLAGS.data_dir,
            initialize_new_indices=FLAGS.initialize_index)
        # Wrap the indexer into a Generator component that we can query.
        generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
            indexer=indexer)

    ##
    # Actually start the LIT server, using the models, datasets, and other
    # components constructed above.
    lit_demo = dev_server.Server(
        models, datasets, generators=generators, **server_flags.get_flags())
    lit_demo.serve()