def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: MutableMapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
):
    """Initialize the LIT server application.

    Args:
      models: named models to serve; each is wrapped in a caching layer below.
      datasets: named datasets. NOTE(review): this mapping is mutated in place
        (a '_union_empty' entry is added) — the MutableMapping annotation
        suggests this is deliberate, but confirm callers expect it.
      generators: optional counterfactual generators; if None, a default set
        (scrambler, word replacer) is installed.
      interpreters: optional interpretation components; if None, a default set
        (salience maps, metrics, projectors) is installed.
      data_dir: directory for on-disk prediction caches; created if missing.
      warm_start: fraction in [0, 1] of data to pre-compute predictions for.
      warm_projections: if True, pre-compute embedding projections; this forces
        a full warm start since projections require predictions.
      client_root: path to the built client bundle; required.
      demo_mode: if True, run in public-demo mode.
      default_layout: name of the default client UI layout.
      canonical_url: canonical URL used for sharing links.

    Raises:
      ValueError: if client_root is not provided.
    """
    if client_root is None:
        raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    # NOTE(review): os.mkdir creates only the leaf directory; assumes the
    # parent of data_dir already exists.
    if data_dir and not os.path.isdir(data_dir):
        os.mkdir(data_dir)
    # Wrap each model in a caching layer so repeated predictions hit the cache.
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    self._datasets = datasets
    # Placeholder dataset; presumably compatible with any model — confirm
    # NoneDataset semantics.
    self._datasets['_union_empty'] = NoneDataset(self._models)
    if generators is not None:
        self._generators = generators
    else:
        # Default counterfactual generators.
        self._generators = {
            'scrambler': scrambler.Scrambler(),
            'word_replacer': word_replacer.WordReplacer(),
        }
    if interpreters is not None:
        self._interpreters = interpreters
    else:
        # Default metrics, grouped so they appear as one component.
        metrics_group = lit_components.ComponentGroup({
            'regression': metrics.RegressionMetrics(),
            'multiclass': metrics.MulticlassMetrics(),
            'paired': metrics.MulticlassPairedMetrics(),
            'bleu': metrics.CorpusBLEU(),
        })
        self._interpreters = {
            'grad_norm': gradient_maps.GradientNorm(),
            'lime': lime_explainer.LIME(),
            'grad_dot_input': gradient_maps.GradientDotInput(),
            'integrated gradients': gradient_maps.IntegratedGradients(),
            'counterfactual explainer': lemon_explainer.LEMON(),
            'metrics': metrics_group,
            # Embedding projectors expose a standard interface, but get special
            # handling so we can precompute the projections if requested.
            'pca': projection.ProjectionManager(pca.PCAModel),
            'umap': projection.ProjectionManager(umap.UmapModel),
        }
    # Information on models and datasets.
    # NOTE(review): return value is discarded here; presumably
    # _build_metadata() stores its result on self — confirm.
    self._build_metadata()
    # Optionally, run models to pre-populate cache.
    if warm_projections:
        logging.info(
            'Projection (dimensionality reduction) warm-start requested; '
            'will do full warm-start for all models since predictions are needed.'
        )
        warm_start = 1.0
    if warm_start > 0:
        self._warm_start(rate=warm_start)
        self.save_cache()
    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
        self._warm_projections(['pca', 'umap'])
    # URL route -> bound handler method.
    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/get_datapoint_ids': self._get_datapoint_ids,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }
    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request)
        handlers={k: make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )
def create_lit_args(exp_name: str,
                    checkpoints: Set[str] = frozenset(("last",))) -> Tuple[tuple, dict]:
    """Build the (args, kwargs) needed to construct a LIT server for an experiment.

    Args:
      exp_name: experiment name used to locate and load the config.
      checkpoints: which model checkpoints to serve; each element is one of
        "avg", "last", "best", or an explicit step number as a string.
        The default is an immutable frozenset rather than a set literal to
        avoid the shared-mutable-default pitfall; only iteration is needed.

    Returns:
      ((models, datasets), {"generators": ..., "interpreters": ...}) suitable
      for splatting into the LIT application constructor.
    """
    config = load_config(exp_name)
    datasets: Dict[str, lit_dataset.Dataset] = {"test": create_test_dataset(config)}
    src_spp, trg_spp = config.create_sp_processors()
    model = create_model(config)
    models: Dict[str, NMTModel] = {}
    for checkpoint in checkpoints:
        if checkpoint == "avg":
            checkpoint_path, _ = get_last_checkpoint(config.model_dir / "avg")
            # Averaged checkpoint has no single training step; use -1.
            models["avg"] = NMTModel(config, model, src_spp, trg_spp, -1, checkpoint_path, type="avg")
        elif checkpoint == "last":
            last_checkpoint_path, last_step = get_last_checkpoint(config.model_dir)
            step_str = str(last_step)
            if step_str in models:
                # The same step may already be loaded under another alias
                # (e.g. "best"); just tag it instead of loading twice.
                models[step_str].types.append("last")
            else:
                models[step_str] = NMTModel(config, model, src_spp, trg_spp, last_step,
                                            last_checkpoint_path, type="last")
        elif checkpoint == "best":
            best_model_path, best_step = get_best_model_dir(config.model_dir)
            step_str = str(best_step)
            if step_str in models:
                models[step_str].types.append("best")
            else:
                models[step_str] = NMTModel(config, model, src_spp, trg_spp, best_step,
                                            best_model_path / "ckpt", type="best")
        else:
            # Explicit step number, e.g. "5000" -> ckpt-5000.
            checkpoint_path = config.model_dir / f"ckpt-{checkpoint}"
            step = int(checkpoint)
            models[checkpoint] = NMTModel(config, model, src_spp, trg_spp, step, checkpoint_path)

    # Index the training data so the similarity searcher can retrieve neighbors.
    index_datasets: Dict[str, lit_dataset.Dataset] = {"test": create_train_dataset(config)}
    indexer = IndexerEx(
        models,
        lit_dataset.IndexedDataset.index_all(index_datasets, caching.input_hash),
        data_dir=str(config.exp_dir / "lit-index"),
        initialize_new_indices=True,
    )

    generators: Dict[str, lit_components.Generator] = {
        "word_replacer": word_replacer.WordReplacer(),
        "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
    }

    interpreters: Dict[str, lit_components.Interpreter] = {
        "metrics": lit_components.ComponentGroup({"bleu": BLEUMetrics(config.data)}),
        "pca": projection.ProjectionManager(pca.PCAModel),
        "umap": projection.ProjectionManager(umap.UmapModel),
    }

    return ((models, datasets), {"generators": generators, "interpreters": interpreters})
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    annotators: Optional[List[lit_components.Annotator]] = None,
    layouts: Optional[dtypes.LitComponentLayouts] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
    page_title: Optional[str] = None,
    development_demo: bool = False,
):
    """Initialize the LIT server application.

    Args:
      models: named models to serve; each is wrapped in a caching layer below.
      datasets: named datasets; copied before use, so the caller's mapping is
        never mutated.
      generators: optional counterfactual generators; if None, a default set
        is installed.
      interpreters: optional interpretation components; if None, a default set
        (salience maps, metrics, projectors, etc.) is installed.
      annotators: optional annotators run over every dataset at startup.
      layouts: optional custom client UI layouts.
      data_dir: directory for on-disk prediction caches; created if missing.
      warm_start: fraction in [0, 1] of data to pre-compute predictions for.
      warm_projections: if True, pre-compute embedding projections; this forces
        a full warm start since projections require predictions.
      client_root: path to the built client bundle; required.
      demo_mode: if True, run in public-demo mode.
      default_layout: name of the default client UI layout.
      canonical_url: canonical URL used for sharing links.
      page_title: custom browser page title.
      development_demo: if True, run in development-demo mode.

    Raises:
      ValueError: if client_root is not provided.
    """
    if client_root is None:
        raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._development_demo = development_demo
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    self._page_title = page_title
    self._data_dir = data_dir
    self._layouts = layouts or {}
    # NOTE(review): os.mkdir creates only the leaf directory; assumes the
    # parent of data_dir already exists.
    if data_dir and not os.path.isdir(data_dir):
        os.mkdir(data_dir)
    # Wrap models in caching wrapper
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    # Copy so the '_union_empty' placeholder never leaks into the caller's
    # mapping.
    self._datasets = dict(datasets)
    self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

    self._annotators = annotators or []
    # Run annotation on each dataset, replacing each dataset with its
    # annotated version. Built as a new dict rather than assigning into the
    # dict while iterating it.
    self._datasets = {
        name: self._run_annotators(ds) for name, ds in self._datasets.items()
    }
    # Index all datasets
    self._datasets = lit_dataset.IndexedDataset.index_all(
        self._datasets, caching.input_hash)

    if generators is not None:
        self._generators = generators
    else:
        # Default counterfactual generators.
        self._generators = {
            'Ablation Flip': ablation_flip.AblationFlip(),
            'Hotflip': hotflip.HotFlip(),
            'Scrambler': scrambler.Scrambler(),
            'Word Replacer': word_replacer.WordReplacer(),
        }

    if interpreters is not None:
        self._interpreters = interpreters
    else:
        # Default metrics, grouped so they appear as one component.
        metrics_group = lit_components.ComponentGroup({
            'regression': metrics.RegressionMetrics(),
            'multiclass': metrics.MulticlassMetrics(),
            'paired': metrics.MulticlassPairedMetrics(),
            'bleu': metrics.CorpusBLEU(),
        })
        self._interpreters = {
            'Grad L2 Norm': gradient_maps.GradientNorm(),
            'Grad ⋅ Input': gradient_maps.GradientDotInput(),
            'Integrated Gradients': gradient_maps.IntegratedGradients(),
            'LIME': lime_explainer.LIME(),
            'Model-provided salience': model_salience.ModelSalience(self._models),
            'counterfactual explainer': lemon_explainer.LEMON(),
            'tcav': tcav.TCAV(),
            'thresholder': thresholder.Thresholder(),
            'nearest neighbors': nearest_neighbors.NearestNeighbors(),
            'metrics': metrics_group,
            'pdp': pdp.PdpInterpreter(),
            # Embedding projectors expose a standard interface, but get special
            # handling so we can precompute the projections if requested.
            'pca': projection.ProjectionManager(pca.PCAModel),
            'umap': projection.ProjectionManager(umap.UmapModel),
        }

    # Information on models, datasets, and other components.
    self._info = self._build_metadata()

    # Optionally, run models to pre-populate cache.
    if warm_projections:
        logging.info(
            'Projection (dimensionality reduction) warm-start requested; '
            'will do full warm-start for all models since predictions are needed.'
        )
        warm_start = 1.0
    if warm_start > 0:
        self._warm_start(rate=warm_start)
        self.save_cache()
    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
        self._warm_projections(['pca', 'umap'])

    # URL route -> bound handler method.
    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/create_dataset': self._create_dataset,
        '/create_model': self._create_model,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/annotate_new_data': self._annotate_new_data,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }
    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request, environ)
        handlers={k: self.make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )