def test_parse_sub_string(self):
  """parse_subs_string maps each source token to a list of target tokens."""
  generator = word_replacer.WordReplacer()
  cases = [
      # Two well-formed rules.
      ('foo -> bar, spam -> eggs', {'foo': ['bar'], 'spam': ['eggs']}),
      # The malformed rule (no '->') should be ignored.
      ('foo -> bar, spam eggs', {'foo': ['bar']}),
      # Multiple target tokens separated by '|', with optional whitespace.
      ('foo -> bar, spam -> eggs|donuts | cream',
       {'foo': ['bar'], 'spam': ['eggs', 'donuts', 'cream']}),
      # An empty query produces no rules.
      ('', {}),
      # Non-ASCII tokens on both sides.
      ('♞ -> ♟', {'♞': ['♟']}),
  ]
  for query_string, expected in cases:
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)
def main(_):
  """Start a LIT demo for pretrained language models.

  Loads every model listed in --models (BERT-style masked LMs or GPT-2-style
  causal LMs), builds the demo datasets, truncates each to --max_examples,
  and serves them. Returns whatever ``Server.serve()`` returns.

  Raises:
    ValueError: if a model name matches neither supported family.
  """

  def _model_class_for(model_name):
    """Return the wrapper class for model_name, or None if unsupported."""
    if model_name.startswith("bert-"):
      return pretrained_lms.BertMLM
    if model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
      return pretrained_lms.GPT2LanguageModel
    return None

  models = {}
  for model_name_or_path in FLAGS.models:
    # Only the last path component picks the model family; a prefix such as
    # /path/to/<model_name> just points at a local copy of the weights.
    model_name = os.path.basename(model_name_or_path)
    model_cls = _model_class_for(model_name)
    if model_cls is None:
      raise ValueError(
          f"Unsupported model name '{model_name}' from path '{model_name_or_path}'"
      )
    models[model_name] = model_cls(model_name_or_path, top_k=FLAGS.top_k)

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer IMDB movie-review passages.
      # NOTE(review): keyed "imdb_train" but loads the "test" split.
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, for typing sentences directly into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Opt-in only: TFDS downloads and processes ~1.67 GB the first time `lm1b`
  # is loaded.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  for name, dataset in list(datasets.items()):
    truncated = dataset.slice[:FLAGS.max_examples]
    datasets[name] = truncated
    logging.info("Dataset: '%s' with %d examples", name, len(truncated))

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators={"word_replacer": word_replacer.WordReplacer()},
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
def test_parse_sub_string(self):
  """parse_subs_string maps each source token to a single target string."""
  generator = word_replacer.WordReplacer()
  cases = [
      # Two well-formed rules.
      ('foo -> bar, spam -> eggs', {'foo': 'bar', 'spam': 'eggs'}),
      # The malformed rule (no '->') should be ignored.
      ('foo -> bar, spam eggs', {'foo': 'bar'}),
      # An empty query produces no rules.
      ('', {}),
      # Non-ASCII tokens on both sides.
      ('♞ -> ♟', {'♞': '♟'}),
  ]
  for query_string, expected in cases:
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)
def main(_):
  """Start a LIT demo for T5 on the tasks selected by --tasks.

  Each entry in --models is either a Hugging Face T5 name/path, or a
  "SavedModel:<path>" spec for an exported SavedModel. The base models are
  shared by task-specific wrappers (summarization, translation), and each
  dataset is truncated to --max_examples. Returns whatever
  ``Server.serve()`` returns.
  """
  ##
  # Base models, keyed by the last component of their name/path. Several can
  # be loaded to compare them side by side.
  base_models = {}
  for model_name_or_path in FLAGS.models:
    model_name = os.path.basename(model_name_or_path)
    if not model_name_or_path.startswith("SavedModel"):
      # TODO(lit-dev): attention is temporarily disabled, because O(n^2)
      # between tokens in a long document can get very, very large. Re-enable
      # once we can send this to the frontend more efficiently.
      base_models[model_name] = t5.T5HFModel(
          model_name=model_name_or_path,
          num_to_generate=FLAGS.num_to_generate,
          token_top_k=FLAGS.token_top_k,
          output_attention=False)
      continue
    # Everything after the first ':' is the SavedModel directory.
    saved_model_path = model_name_or_path.split(":", 1)[1]
    base_models[model_name] = t5.T5SavedModel(saved_model_path)

  ##
  # Task wrappers share the in-memory base model and add task-specific
  # pre- and post-processing.
  models = {}
  datasets = {}
  if "summarization" in FLAGS.tasks:
    models.update({
        base_name + "_summarization": t5.SummarizationWrapper(base)
        for base_name, base in base_models.items()
    })
    datasets["CNNDM"] = summarization.CNNDMData(
        split="validation", max_examples=FLAGS.max_examples)
  if "mt" in FLAGS.tasks:
    models.update({
        base_name + "_translation": t5.TranslationWrapper(base)
        for base_name, base in base_models.items()
    })
    datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
    datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

  # Apply --max_examples, logging sizes before and after truncation.
  for name in datasets:
    full_dataset = datasets[name]
    logging.info("Dataset: '%s' with %d examples", name, len(full_dataset))
    datasets[name] = full_dataset.slice[:FLAGS.max_examples]
    logging.info(" truncated to %d examples", len(datasets[name]))

  ##
  # Generators create new examples by perturbing existing ones.
  generators = {
      # Word substitution, e.g. "great" -> "terrible".
      "word_replacer": word_replacer.WordReplacer(),
  }
  if FLAGS.use_indexer:
    # Expose the nearest-neighbor index as a queryable generator.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=build_indexer(models))

  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  return lit_demo.serve()
def create_lit_args(exp_name: str, checkpoints: Set[str] = {"last"}) -> Tuple[tuple, dict]:
  """Assemble the positional and keyword arguments for a LIT server.

  Args:
    exp_name: experiment name passed to ``load_config``.
    checkpoints: which checkpoints to expose as LIT models. Each entry is
      "avg", "last", "best", or a training step number given as a string.
      The collection is only read, never mutated, so the shared default is
      not modified across calls.

  Returns:
    ``((models, datasets), kwargs)``, where ``kwargs`` holds the
    "generators" and "interpreters" mappings.

  Raises:
    ValueError: if an entry is neither a known name nor an integer step.
  """
  named_kinds = ("avg", "last", "best")

  config = load_config(exp_name)
  datasets: Dict[str, lit_dataset.Dataset] = {
      "test": create_test_dataset(config)
  }
  src_spp, trg_spp = config.create_sp_processors()
  model = create_model(config)

  # Resolve named checkpoints first, in a fixed order, then numeric steps in
  # ascending order. Iterating the set directly made the result depend on
  # per-process string-hash randomization whenever two entries resolved to
  # the same step (e.g. "last" and its step number).
  named = [c for c in named_kinds if c in checkpoints]
  numeric = sorted((c for c in checkpoints if c not in named_kinds), key=int)

  models: Dict[str, NMTModel] = {}
  for checkpoint in named:
    if checkpoint == "avg":
      checkpoint_path, _ = get_last_checkpoint(config.model_dir / "avg")
      models["avg"] = NMTModel(config, model, src_spp, trg_spp, -1,
                               checkpoint_path, type="avg")
      continue
    if checkpoint == "last":
      checkpoint_path, step = get_last_checkpoint(config.model_dir)
    else:  # "best"
      best_model_path, step = get_best_model_dir(config.model_dir)
      checkpoint_path = best_model_path / "ckpt"
    step_str = str(step)
    if step_str in models:
      # The same step is already loaded; tag the existing entry instead.
      models[step_str].types.append(checkpoint)
    else:
      models[step_str] = NMTModel(config, model, src_spp, trg_spp, step,
                                  checkpoint_path, type=checkpoint)

  for checkpoint in numeric:
    step = int(checkpoint)
    step_str = str(step)
    if step_str in models:
      # Already exposed via "last"/"best" (or a duplicate spelling such as
      # "0100" vs "100"). Keep that entry and its type tags rather than
      # overwriting it with an untagged model.
      continue
    models[step_str] = NMTModel(config, model, src_spp, trg_spp, step,
                                config.model_dir / f"ckpt-{checkpoint}")

  # NOTE(review): keyed "test" but built from the training dataset. The key
  # is left unchanged in case it names previously built indices under
  # lit-index.
  index_datasets: Dict[str, lit_dataset.Dataset] = {
      "test": create_train_dataset(config)
  }
  indexer = IndexerEx(
      models,
      lit_dataset.IndexedDataset.index_all(index_datasets, caching.input_hash),
      data_dir=str(config.exp_dir / "lit-index"),
      initialize_new_indices=True,
  )

  generators: Dict[str, lit_components.Generator] = {
      "word_replacer": word_replacer.WordReplacer(),
      "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
  }

  interpreters: Dict[str, lit_components.Interpreter] = {
      "metrics": lit_components.ComponentGroup(
          {"bleu": BLEUMetrics(config.data)}),
      "pca": projection.ProjectionManager(pca.PCAModel),
      "umap": projection.ProjectionManager(umap.UmapModel),
  }

  return ((models, datasets), {
      "generators": generators,
      "interpreters": interpreters
  })
def main(_):
  """Start a LIT demo for T5 summarization on CNN/DailyMail.

  Loads each model in --models as a T5 generation model with the
  "summarize: " prefix, loads the CNNDM validation split, and optionally
  builds a similarity-search index over the training split.

  Raises:
    ValueError: if --use_indexer is set without --data_dir.
  """
  ##
  # Load models. Several can be specified to compare them side by side, and
  # models of different types that use different datasets can be mixed.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
    # tokens in a long document can get very, very large. Re-enable once we
    # can send this to the frontend more efficiently.
    models[model_name] = t5.T5GenerationModel(
        model_name=model_name_or_path,
        input_prefix="summarize: ",
        top_k=FLAGS.top_k,
        output_attention=False)

  ##
  # Load datasets. Typically you'll have the constructor actually read the
  # examples and do any pre-processing, so that they're in memory and ready to
  # send to the frontend when you open the web UI.
  datasets = {
      "CNNDM":
          summarization.CNNDMData(
              split="validation", max_examples=FLAGS.max_examples),
  }
  for name, ds in datasets.items():
    logging.info("Dataset: '%s' with %d examples", name, len(ds))

  ##
  # Custom components. Generators create new examples by perturbing or
  # modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible"
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would let a missing --data_dir fail later in a less obvious way.
    if not FLAGS.data_dir:
      raise ValueError("--data_dir must be set to use the indexer.")
    # Datasets for the indexer load the training corpus instead of val.
    index_datasets = {
        "CNNDM":
            summarization.CNNDMData(
                split="train", max_examples=FLAGS.max_index_examples),
    }
    # Set up the Indexer, building the index if necessary (may be slow).
    indexer = index.Indexer(
        datasets=index_datasets,
        models=models,
        data_dir=FLAGS.data_dir,
        initialize_new_indices=FLAGS.initialize_index)

    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Start the LIT server with the models, datasets, and components above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  lit_demo.serve()
def test_all_replacements(self):
  """Exercise WordReplacer.generate across substitution edge cases."""
  input_spec = {'text': lit_types.TextSegment()}
  model = testing_utils.TestRegressionModel(input_spec)
  # word_replacer only reads the dataset's spec, so one example is enough.
  dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

  def make_config(subs=None, ignore_casing=None):
    """Build a generator config that always targets the 'text' field."""
    config = {}
    if subs is not None:
      config[word_replacer.SUBSTITUTIONS_KEY] = subs
    if ignore_casing is not None:
      config[word_replacer.IGNORE_CASING_KEY] = ignore_casing
    config[word_replacer.FIELDS_TO_REPLACE_KEY] = ['text']
    return config

  def check(gen, text, expected, subs=None, ignore_casing=None,
            unordered=False):
    """Run gen on a one-field input and compare against expected."""
    outputs = gen.generate({'text': text}, model, dataset,
                           config=make_config(subs, ignore_casing))
    compare = self.assertCountEqual if unordered else self.assertEqual
    compare(outputs, expected)

  ## Substitutions supplied through the config.
  default_gen = word_replacer.WordReplacer()

  # Unicode to Unicode.
  check(default_gen, '♞ is a black chess knight.',
        [{'text': '♟ is a black chess knight.'}], subs='♞ -> ♟')
  # Unicode to ASCII.
  check(default_gen, 'Is répertoire a unicode word?',
        [{'text': 'Is repertoire a unicode word?'}],
        subs='répertoire -> repertoire')
  # Casing is ignored by default.
  check(default_gen, 'Capitalization is ignored.',
        [{'text': 'blank is ignored.'}], subs='Capitalization -> blank')
  check(default_gen, 'Capitalization is ignored.',
        [{'text': 'blank is ignored.'}], subs='capitalization -> blank')
  # Casing is respected when ignore-casing is turned off.
  check(default_gen, 'Capitalization is important.',
        [{'text': 'blank is important.'}], subs='Capitalization -> blank',
        ignore_casing=False)
  check(default_gen, 'Capitalization is important.', [],
        subs='capitalization -> blank', ignore_casing=False)
  # Each occurrence yields its own output.
  check(default_gen, 'maybe repetition repetition maybe',
        [{'text': 'maybe blank repetition maybe'},
         {'text': 'maybe repetition blank maybe'}],
        subs='repetition -> blank', unordered=True)
  # Substrings inside a longer word are not matched.
  check(default_gen, 'A catastrophic storm', [], subs='cat -> blank')

  ## Special characters.
  # Punctuation.
  check(default_gen, 'A catastrophic storm .',
        [{'text': 'A catastrophic storm -'}], subs='. -> -')
  check(default_gen, 'A.catastrophic. storm',
        [{'text': 'A-catastrophic. storm'},
         {'text': 'A.catastrophic- storm'}], subs='. -> -')
  check(default_gen, 'A...catastrophic.... storm',
        [{'text': 'A--.catastrophic.... storm'},
         {'text': 'A...catastrophic--.. storm'},
         {'text': 'A...catastrophic..-- storm'}], subs='.. -> --')
  # Underscore.
  check(default_gen, 'A catastrophic_storm is raging.',
        [{'text': 'A nice_storm is raging.'}],
        subs='catastrophic_storm -> nice_storm')
  # Deletion.
  check(default_gen, 'A storm is raging.', [{'text': 'A is raging.'}],
        subs='storm -> ')
  # Word next to punctuation, and words containing punctuation.
  check(default_gen, 'It`s raining cats and dogs.',
        [{'text': 'It`s raining cats and blank.'}], subs='dogs -> blank')
  # Multiple target tokens.
  check(default_gen, 'It`s raining cats and dogs.',
        [{'text': 'It`s raining cats and horses.'},
         {'text': 'It`s raining cats and donkeys.'}],
        subs='dogs -> horses|donkeys')

  ## Replacements given at construction time are applied by default.
  preset_gen = word_replacer.WordReplacer(replacements={'tree': ['car']})
  check(preset_gen, 'black truck hit the tree',
        [{'text': 'black truck hit the car'}])

  ## Calling with no replacements at all must not break.
  fresh_gen = word_replacer.WordReplacer()
  self.assertEqual(
      fresh_gen.generate({'text': 'xyz yzy zzz.'}, model, dataset), [])

  # Multi-word match.
  check(fresh_gen, 'A red cat is coming.',
        [{'text': 'A black dog is coming.'}], subs='red cat -> black dog')
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
):
  """Wire up models, datasets, components, and the WSGI app.

  Args:
    models: models to serve. Each is wrapped in a CachingModelWrapper.
    datasets: datasets to serve. The mapping is copied, so the caller's
      object is not modified.
    generators: generator components. Defaults to a scrambler and a word
      replacer when None.
    interpreters: interpreter components. Defaults to a standard set
      (salience, LIME, LEMON, metrics, PCA/UMAP) when None.
    data_dir: cache directory for the model wrappers. It is created
      (including parents) if it does not exist.
    warm_start: rate passed to ``_warm_start``. 0 disables it.
    warm_projections: if True, forces a full warm start and precomputes the
      'pca' and 'umap' projections.
    client_root: project root for the frontend WSGI app. Required.
    demo_mode: stored on the server as ``_demo_mode``.
    default_layout: stored on the server as ``_default_layout``.
    canonical_url: stored on the server as ``_canonical_url``.

  Raises:
    ValueError: if client_root is None.
  """
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._default_layout = default_layout
  self._canonical_url = canonical_url
  if data_dir:
    # makedirs also creates missing parent directories, and exist_ok avoids
    # the check-then-create race of isdir() + mkdir(). A non-directory at
    # this path still raises FileExistsError, as before.
    os.makedirs(data_dir, exist_ok=True)
  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }
  # Copy so that adding the '_union_empty' placeholder does not mutate the
  # caller's mapping.
  self._datasets = dict(datasets)
  self._datasets['_union_empty'] = NoneDataset(self._models)
  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'scrambler': scrambler.Scrambler(),
        'word_replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'grad_norm': gradient_maps.GradientNorm(),
        'lime': lime_explainer.LIME(),
        'grad_dot_input': gradient_maps.GradientDotInput(),
        'integrated gradients': gradient_maps.IntegratedGradients(),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'metrics': metrics_group,
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models and datasets.
  self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.'
    )
    warm_start = 1.0

  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/get_datapoint_ids': self._get_datapoint_ids,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request)
      handlers={k: make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )
def test_all_replacements(self):
  """Exercise WordReplacer.generate with the 'subs' config key."""
  input_spec = {'text': lit_types.TextSegment()}
  model = testing_utils.TestRegressionModel(input_spec)
  # word_replacer only reads the dataset's spec, so one example is enough.
  # Examples must be a list of dicts; a bare dict is not a list of examples.
  dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

  def check(gen, text, subs, expected, unordered=False):
    """Run gen on a one-field input with the given 'subs' rule string."""
    outputs = gen.generate({'text': text}, model, dataset,
                           config={'subs': subs})
    compare = self.assertCountEqual if unordered else self.assertEqual
    compare(outputs, expected)

  ## Test replacements
  generator = word_replacer.WordReplacer()

  # Unicode to Unicode
  check(generator, '♞ is a black chess knight.', '♞ -> ♟',
        [{'text': '♟ is a black chess knight.'}])
  # Unicode to ASCII
  check(generator, 'Is répertoire a unicode word?',
        'répertoire -> repertoire',
        [{'text': 'Is repertoire a unicode word?'}])
  # Capitalization must match exactly.
  check(generator, 'Capitalization is important.',
        'Capitalization -> blank', [{'text': 'blank is important.'}])
  check(generator, 'Capitalization is important.',
        'capitalization -> blank', [])
  # Repetition: each occurrence yields its own output.
  check(generator, 'maybe repetition repetition maybe', 'repetition -> blank',
        [{'text': 'maybe blank repetition maybe'},
         {'text': 'maybe repetition blank maybe'}], unordered=True)
  # No partial match inside a longer word.
  check(generator, 'A catastrophic storm', 'cat -> blank', [])

  ## Special characters
  # Punctuation
  check(generator, 'A catastrophic storm .', '. -> -',
        [{'text': 'A catastrophic storm -'}])
  # Underscore
  check(generator, 'A catastrophic_storm is raging.',
        'catastrophic_storm -> nice_storm',
        [{'text': 'A nice_storm is raging.'}])
  # Word next to punctuation and words with punctuation.
  check(generator, 'It`s raining cats and dogs.', 'dogs -> blank',
        [{'text': 'It`s raining cats and blank.'}])

  ## Test default_replacements applied at init.
  replacements = {'tree': 'car'}
  generator = word_replacer.WordReplacer(replacements=replacements)
  input_dict = {'text': 'black truck hit the tree'}
  expected = [{'text': 'black truck hit the car'}]
  self.assertEqual(
      generator.generate(input_dict, model, dataset), expected)

  ## Test not passing replacements not breaking.
  generator = word_replacer.WordReplacer()
  input_dict = {'text': 'xyz yzy zzz.'}
  expected = []
  self.assertEqual(
      generator.generate(input_dict, model, dataset), expected)
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    annotators: Optional[List[lit_components.Annotator]] = None,
    layouts: Optional[dtypes.LitComponentLayouts] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
    page_title: Optional[str] = None,
    development_demo: bool = False,
):
  """Wire up models, datasets, components, and the WSGI app.

  Args:
    models: models to serve. Each is wrapped in a CachingModelWrapper.
    datasets: datasets to serve. The mapping is copied, run through the
      annotators, and indexed with ``caching.input_hash``.
    generators: generator components. A default set is used when None.
    interpreters: interpreter components. A default set is used when None.
    annotators: annotators applied to every dataset. Defaults to none.
    layouts: custom UI layouts. Defaults to an empty dict.
    data_dir: cache directory for the model wrappers. It is created
      (including parents) if it does not exist.
    warm_start: rate passed to ``_warm_start``. 0 disables it.
    warm_projections: if True, forces a full warm start and precomputes the
      'pca' and 'umap' projections.
    client_root: project root for the frontend WSGI app. Required.
    demo_mode: stored on the server as ``_demo_mode``.
    default_layout: stored on the server as ``_default_layout``.
    canonical_url: stored on the server as ``_canonical_url``.
    page_title: stored on the server as ``_page_title``.
    development_demo: stored on the server as ``_development_demo``.

  Raises:
    ValueError: if client_root is None.
  """
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._development_demo = development_demo
  self._default_layout = default_layout
  self._canonical_url = canonical_url
  self._page_title = page_title
  self._data_dir = data_dir
  self._layouts = layouts or {}
  if data_dir:
    # makedirs also creates missing parent directories, and exist_ok avoids
    # the check-then-create race of isdir() + mkdir(). A non-directory at
    # this path still raises FileExistsError, as before.
    os.makedirs(data_dir, exist_ok=True)

  # Wrap models in caching wrapper
  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }

  # Copy so the '_union_empty' placeholder doesn't touch the caller's mapping.
  self._datasets = dict(datasets)
  self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

  self._annotators = annotators or []

  # Run annotation on each dataset and replace it with the annotated version.
  # Only values are reassigned, so iterating items() here is safe.
  for ds_key, ds in self._datasets.items():
    self._datasets[ds_key] = self._run_annotators(ds)

  # Index all datasets
  self._datasets = lit_dataset.IndexedDataset.index_all(
      self._datasets, caching.input_hash)

  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'Ablation Flip': ablation_flip.AblationFlip(),
        'Hotflip': hotflip.HotFlip(),
        'Scrambler': scrambler.Scrambler(),
        'Word Replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'Grad L2 Norm': gradient_maps.GradientNorm(),
        'Grad ⋅ Input': gradient_maps.GradientDotInput(),
        'Integrated Gradients': gradient_maps.IntegratedGradients(),
        'LIME': lime_explainer.LIME(),
        'Model-provided salience': model_salience.ModelSalience(self._models),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'tcav': tcav.TCAV(),
        'thresholder': thresholder.Thresholder(),
        'nearest neighbors': nearest_neighbors.NearestNeighbors(),
        'metrics': metrics_group,
        'pdp': pdp.PdpInterpreter(),
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models, datasets, and other components.
  self._info = self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.'
    )
    warm_start = 1.0

  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/create_dataset': self._create_dataset,
      '/create_model': self._create_model,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/annotate_new_data': self._annotate_new_data,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request, environ)
      handlers={k: self.make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )