def test_compute(self):
  """Checks mean JSD, pair count, and swap rate of MulticlassPairedMetrics."""
  multiclass_paired_metrics = metrics.MulticlassPairedMetrics()

  def assert_metrics_close(actual, expected):
    # assertAlmostEqual() on two dicts only passes on exact equality and
    # raises TypeError (dict - dict) on any mismatch, so float metrics would
    # be compared bit-for-bit. Compare key sets exactly and each value
    # approximately instead.
    self.assertEqual(set(actual), set(expected))
    for metric_name, expected_value in expected.items():
      self.assertAlmostEqual(
          actual[metric_name], expected_value, msg=metric_name)

  # The two 'parentId' entries link '7f7f85' -> '345ac4' and
  # '88bcda' -> '3a3112', which is why every non-empty case has num_pairs == 2.
  indices = ['7f7f85', '345ac4', '3a3112', '88bcda']
  metas = [{'parentId': '345ac4'}, {}, {}, {'parentId': '3a3112'}]

  # No swaps: every example has the same argmax as its linked example.
  result = multiclass_paired_metrics.compute_with_metadata(
      ['1', '1', '0', '0'], [[0, 1], [0, 1], [1, 0], [1, 0]],
      types.CategoryLabel(),
      types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas)
  assert_metrics_close(result, {
      'mean_jsd': 0.0,
      'num_pairs': 2,
      'swap_rate': 0.0
  })

  # One swap: the prediction for '345ac4' flips relative to '7f7f85'.
  # mean_jsd is approximately ln(2) / 2.
  result = multiclass_paired_metrics.compute_with_metadata(
      ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [1, 0]],
      types.CategoryLabel(),
      types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas)
  assert_metrics_close(result, {
      'mean_jsd': 0.3465735902799726,
      'num_pairs': 2,
      'swap_rate': 0.5
  })

  # Two swaps: both pairs disagree. mean_jsd is approximately ln(2).
  result = multiclass_paired_metrics.compute_with_metadata(
      ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [0, 1]],
      types.CategoryLabel(),
      types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas)
  assert_metrics_close(result, {
      'mean_jsd': 0.6931471805599452,
      'num_pairs': 2,
      'swap_rate': 1.0
  })

  # Two swaps, no null index: results should match the null_idx=0 case.
  result = multiclass_paired_metrics.compute_with_metadata(
      ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [0, 1]],
      types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1']),
      indices, metas)
  assert_metrics_close(result, {
      'mean_jsd': 0.6931471805599452,
      'num_pairs': 2,
      'swap_rate': 1.0
  })

  # Empty predictions, indices, and meta produce no metrics at all.
  result = multiclass_paired_metrics.compute_with_metadata(
      [], [], types.CategoryLabel(),
      types.MulticlassPreds(vocab=['0', '1'], null_idx=0), [], [])
  assert_metrics_close(result, {})
def test_is_compatible(self):
  """MulticlassPairedMetrics should accept only MulticlassPreds output specs."""
  paired_metrics = metrics.MulticlassPairedMetrics()
  self.assertTrue(
      paired_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))
  # Every other spec type exercised here must be rejected.
  for incompatible_spec_cls in (types.RegressionScore, types.GeneratedText):
    self.assertFalse(paired_metrics.is_compatible(incompatible_spec_cls()))
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
):
  """Sets up models, datasets, components, and HTTP endpoints for the app.

  Args:
    models: Models by name. Each is wrapped in `caching.CachingModelWrapper`
      with `data_dir` as its cache directory.
    datasets: Datasets by name. The mapping is copied, so the caller's object
      is not modified. A `NoneDataset` built from the wrapped models is added
      under the key '_union_empty'.
    generators: Generators by name. If None, defaults to a scrambler and a
      word replacer.
    interpreters: Interpreters by name. If None, defaults to gradient-based
      salience methods, LIME, LEMON, a metrics group, and PCA/UMAP projection
      managers.
    data_dir: Optional cache directory, created (including missing parent
      directories) if it does not already exist.
    warm_start: If > 0, passed as `rate` to `self._warm_start()`, after which
      `self.save_cache()` is called. Forced to 1.0 when `warm_projections`
      is set.
    warm_projections: If True, performs a full warm start and then
      precomputes the 'pca' and 'umap' projections.
    client_root: Project root for the WSGI app, which serves
      'static/index.html'. Required.
    demo_mode: Stored on the instance as `_demo_mode`.
    default_layout: Stored on the instance as `_default_layout`.
    canonical_url: Stored on the instance as `_canonical_url`.

  Raises:
    ValueError: If `client_root` is None.
  """
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._default_layout = default_layout
  self._canonical_url = canonical_url

  if data_dir:
    # makedirs creates missing parents, and exist_ok=True avoids the race
    # between an isdir() check and creating the directory.
    os.makedirs(data_dir, exist_ok=True)

  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }
  # Copy before adding the synthetic dataset so the caller's mapping is not
  # mutated as a side effect of constructing the app.
  self._datasets = dict(datasets)
  self._datasets['_union_empty'] = NoneDataset(self._models)

  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'scrambler': scrambler.Scrambler(),
        'word_replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'grad_norm': gradient_maps.GradientNorm(),
        'lime': lime_explainer.LIME(),
        'grad_dot_input': gradient_maps.GradientDotInput(),
        'integrated gradients': gradient_maps.IntegratedGradients(),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'metrics': metrics_group,
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models and datasets.
  # NOTE(review): the return value is not stored here; presumably
  # _build_metadata() records the metadata on the instance itself. Verify.
  self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.'
    )
    warm_start = 1.0

  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/get_datapoint_ids': self._get_datapoint_ids,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request)
      handlers={k: make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    annotators: Optional[List[lit_components.Annotator]] = None,
    layouts: Optional[dtypes.LitComponentLayouts] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
    page_title: Optional[str] = None,
    development_demo: bool = False,
):
  """Sets up models, datasets, components, and HTTP endpoints for the app.

  Args:
    models: Models by name. Each is wrapped in `caching.CachingModelWrapper`
      with `data_dir` as its cache directory.
    datasets: Datasets by name. The mapping is copied. A `NoneDataset` built
      from the wrapped models is added under '_union_empty'. Every dataset is
      then passed through `self._run_annotators()` and finally indexed with
      `lit_dataset.IndexedDataset.index_all` using `caching.input_hash`.
    generators: Generators by name. If None, a default set is used
      (ablation flip, HotFlip, scrambler, word replacer).
    interpreters: Interpreters by name. If None, a default set is used
      (salience methods, LEMON, TCAV, thresholder, nearest neighbors,
      metrics, PDP, and PCA/UMAP projection managers).
    annotators: Annotators stored as `_annotators` (empty list if None).
      They are presumably applied by `_run_annotators()`; verify there.
    layouts: Component layouts stored as `_layouts` (empty dict if None).
    data_dir: Optional cache directory, created (including missing parent
      directories) if it does not already exist. Also stored as `_data_dir`.
    warm_start: If > 0, passed as `rate` to `self._warm_start()`, after which
      `self.save_cache()` is called. Forced to 1.0 when `warm_projections`
      is set.
    warm_projections: If True, performs a full warm start and then
      precomputes the 'pca' and 'umap' projections.
    client_root: Project root for the WSGI app, which serves
      'static/index.html'. Required.
    demo_mode: Stored on the instance as `_demo_mode`.
    default_layout: Stored on the instance as `_default_layout`.
    canonical_url: Stored on the instance as `_canonical_url`.
    page_title: Stored on the instance as `_page_title`.
    development_demo: Stored on the instance as `_development_demo`.

  Raises:
    ValueError: If `client_root` is None.
  """
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._development_demo = development_demo
  self._default_layout = default_layout
  self._canonical_url = canonical_url
  self._page_title = page_title
  self._data_dir = data_dir
  self._layouts = layouts or {}

  if data_dir:
    # makedirs creates missing parents, and exist_ok=True avoids the race
    # between an isdir() check and creating the directory.
    os.makedirs(data_dir, exist_ok=True)

  # Wrap models in caching wrapper
  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }

  # Copy so the caller's mapping is not mutated by the insertions below.
  self._datasets = dict(datasets)
  self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

  # Must be set before _run_annotators() is called below.
  self._annotators = annotators or []

  # Run annotation on each dataset, creating an annotated dataset and
  # replace the datasets with the annotated versions.
  # Only values of existing keys are reassigned, so the dict's size does not
  # change during iteration.
  for ds_key, ds in self._datasets.items():
    self._datasets[ds_key] = self._run_annotators(ds)

  # Index all datasets
  self._datasets = lit_dataset.IndexedDataset.index_all(
      self._datasets, caching.input_hash)

  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'Ablation Flip': ablation_flip.AblationFlip(),
        'Hotflip': hotflip.HotFlip(),
        'Scrambler': scrambler.Scrambler(),
        'Word Replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'Grad L2 Norm': gradient_maps.GradientNorm(),
        'Grad ⋅ Input': gradient_maps.GradientDotInput(),
        'Integrated Gradients': gradient_maps.IntegratedGradients(),
        'LIME': lime_explainer.LIME(),
        'Model-provided salience': model_salience.ModelSalience(self._models),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'tcav': tcav.TCAV(),
        'thresholder': thresholder.Thresholder(),
        'nearest neighbors': nearest_neighbors.NearestNeighbors(),
        'metrics': metrics_group,
        'pdp': pdp.PdpInterpreter(),
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models, datasets, and other components.
  self._info = self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.'
    )
    warm_start = 1.0

  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/create_dataset': self._create_dataset,
      '/create_model': self._create_model,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/annotate_new_data': self._annotate_new_data,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request, environ)
      handlers={k: self.make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )