def test_caching_model_wrapper_mixed_list(self):
  model = testing_utils.TestIdentityRegressionModel()
  wrapper = caching.CachingModelWrapper(model, "test")
  examples = [{"data": {"val": 1}, "id": "my_id"}]
  results = wrapper.predict_with_metadata(examples, "dataset")
  self.assertEqual(1, model.count)
  self.assertEqual({"score": 1}, results[0])
  # "my_id" is already cached, so only the two new examples hit the model.
  examples = [
      {"data": {"val": 0}, "id": "first_id"},
      {"data": {"val": 1}, "id": "my_id"},
      {"data": {"val": 2}, "id": "last_id"},
  ]
  results = wrapper.predict_with_metadata(examples, "dataset")
  self.assertEqual(3, model.count)
  self.assertEqual({"score": 0}, results[0])
  self.assertEqual({"score": 1}, results[1])
  self.assertEqual({"score": 2}, results[2])
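# For reference, a minimal sketch of what testing_utils.TestIdentityRegressionModel
# presumably looks like, inferred from the assertions above (the class name, spec
# types, and body here are assumptions, not the actual implementation): it echoes
# "val" back as "score" and counts every example that reaches the underlying
# model, which is what lets the tests detect cache hits.
class IdentityRegressionSketch(lit_model.Model):
  """Hypothetical stand-in for testing_utils.TestIdentityRegressionModel."""

  def __init__(self):
    self.count = 0  # number of examples actually predicted (i.e., not cached)

  def input_spec(self):
    return {"val": lit_types.Scalar()}

  def output_spec(self):
    return {"score": lit_types.RegressionScore()}

  def predict_minibatch(self, inputs):
    self.count += len(inputs)
    return [{"score": ex["val"]} for ex in inputs]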
def __init__(self, model: lit_model.Model, indexed_inputs: List[JsonDict],
             model_outputs: Optional[List[JsonDict]],
             projector: ProjectorModel, field_name: Text, name: Text):
  self._projector = caching.CachingModelWrapper(projector, name=name)
  self._field_name = field_name
  # Train on the given examples.
  self._run(model, indexed_inputs, model_outputs, do_fit=True)
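# Hypothetical usage of the constructor above (the class name ProjectionTrainer
# and the cache name 'pca:sst_dev' are illustrative, not from the source):
# because the projector is wrapped in CachingModelWrapper, repeated projection
# requests for already-seen example ids are answered from cache rather than
# re-running the projector.
trainer = ProjectionTrainer(
    model=model,
    indexed_inputs=indexed_inputs,
    model_outputs=None,  # if None, presumably recomputed (and cached) in _run()
    projector=pca.PCAModel(),
    field_name='cls_emb',  # assumed embedding field in the model's output spec
    name='pca:sst_dev',
)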
def test_caching_model_wrapper_use_cache(self):
  model = testing_utils.TestIdentityRegressionModel()
  wrapper = caching.CachingModelWrapper(model, "test")
  examples = [{"data": {"val": 1}, "id": "id_to_cache"}]
  results = wrapper.predict_with_metadata(examples, "dataset")
  self.assertEqual(1, model.count)
  self.assertEqual({"score": 1}, results[0])
  # Second call with the same id is served from cache; the model is not
  # called again, so the count stays at 1.
  results = wrapper.predict_with_metadata(examples, "dataset")
  self.assertEqual(1, model.count)
  self.assertEqual({"score": 1}, results[0])
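# A consequence of id-keyed caching, written as a hypothetical test (not part
# of the original suite): results are stored under the example id, so a
# repeated id is presumably served from cache even if the "data" payload
# changes. In normal use this never matters, since ids are content hashes.
def test_caching_model_wrapper_keyed_by_id_sketch(self):
  model = testing_utils.TestIdentityRegressionModel()
  wrapper = caching.CachingModelWrapper(model, "test")
  examples = [{"data": {"val": 1}, "id": "my_id"}]
  wrapper.predict_with_metadata(examples, "dataset")
  # Same id, different data: expected to be a cache hit.
  examples = [{"data": {"val": 5}, "id": "my_id"}]
  results = wrapper.predict_with_metadata(examples, "dataset")
  self.assertEqual(1, model.count)  # model not called again
  self.assertEqual({"score": 1}, results[0])  # cached value, stale by design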
def setUp(self):
  super(ThresholderTest, self).setUp()
  self.thresholder = thresholder.Thresholder()
  self.model = caching.CachingModelWrapper(
      glue_models.SST2Model(BERT_TINY_PATH), 'test')
  examples = [
      {'sentence': 'a', 'label': '1'},
      {'sentence': 'b', 'label': '1'},
      {'sentence': 'c', 'label': '1'},
      {'sentence': 'd', 'label': '1'},
      {'sentence': 'e', 'label': '1'},
      {'sentence': 'f', 'label': '0'},
      {'sentence': 'g', 'label': '0'},
      {'sentence': 'h', 'label': '0'},
      {'sentence': 'i', 'label': '0'},
  ]
  self.indexed_inputs = [{
      'id': caching.input_hash(ex),
      'data': ex
  } for ex in examples]
  self.dataset = lit_dataset.IndexedDataset(
      id_fn=caching.input_hash,
      spec={
          'sentence': lit_types.TextSegment(),
          'label': lit_types.CategoryLabel(vocab=['0', '1'])
      },
      indexed_examples=self.indexed_inputs)
  self.model_outputs = list(
      self.model.predict_with_metadata(
          self.indexed_inputs, dataset_name='test'))
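# Sketch (not in the original suite): caching.input_hash is assumed to be a
# deterministic content hash, so equal examples map to equal ids. That is what
# makes the indexed inputs above stable across runs and lets the disk-backed
# cache match examples between sessions.
def test_input_hash_deterministic_sketch(self):
  ex = {'sentence': 'a', 'label': '1'}
  self.assertEqual(caching.input_hash(ex), caching.input_hash(dict(ex)))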
def _create_model(self,
                  unused_data,
                  model_name: Optional[Text] = None,
                  model_path: Optional[Text] = None,
                  **unused_kw):
  """Create model from a path, updating and returning the metadata."""
  assert model_name is not None, 'No model specified.'
  assert model_path is not None, 'No model path specified.'
  # Load using the underlying model class, then wrap explicitly in a cache.
  new_model = self._models[model_name].wrapped.load(model_path)
  if new_model is not None:
    new_model_name = model_name + ':' + os.path.basename(model_path)
    self._models[new_model_name] = caching.CachingModelWrapper(
        new_model, new_model_name, cache_dir=self._data_dir)
    self._info = self._build_metadata()
    return (self._info, new_model_name)
  else:
    return None
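# Hypothetical direct call to the handler above (the model name and checkpoint
# path are placeholders): this loads a second checkpoint of an already-
# registered model and caches it under a name derived from the path. The new
# wrapper shares the same cache_dir, so its predictions can persist to disk.
info, new_name = app_instance._create_model(
    None, model_name='sst', model_path='/tmp/checkpoints/sst-500')
# new_name would be 'sst:sst-500' (the path's basename appended after a colon).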
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: MutableMapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
):
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._default_layout = default_layout
  self._canonical_url = canonical_url
  if data_dir and not os.path.isdir(data_dir):
    os.mkdir(data_dir)

  # Wrap models in the caching wrapper so repeated predictions are served
  # from (optionally disk-backed) cache.
  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }

  self._datasets = datasets
  self._datasets['_union_empty'] = NoneDataset(self._models)

  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'scrambler': scrambler.Scrambler(),
        'word_replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'grad_norm': gradient_maps.GradientNorm(),
        'lime': lime_explainer.LIME(),
        'grad_dot_input': gradient_maps.GradientDotInput(),
        'integrated gradients': gradient_maps.IntegratedGradients(),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'metrics': metrics_group,
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models and datasets.
  self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.')
    warm_start = 1.0
  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/get_datapoint_ids': self._get_datapoint_ids,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request).
      handlers={k: make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )
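# A minimal construction sketch for the app above (the class name LitApp and
# all model, dataset, and path values are assumptions): warm_start=0.25 asks
# the app to pre-compute predictions for ~25% of each dataset at startup,
# while warm_projections=True forces a full warm start (warm_start is bumped
# to 1.0) because the projectors need predictions for every example.
app = LitApp(
    models={'sst': MySstModel()},          # hypothetical model
    datasets={'sst_dev': MySstDataset()},  # hypothetical dataset
    client_root='/path/to/client/build',   # placeholder path
    data_dir='/tmp/lit_cache',             # enables disk-backed prediction caches
    warm_start=0.25,
    warm_projections=True,
)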
def setUp(self):
  super(ModelBasedTCAVTest, self).setUp()
  self.tcav = tcav.TCAV()
  self.model = caching.CachingModelWrapper(
      glue_models.SST2Model(BERT_TINY_PATH), 'test')
def __init__(
    self,
    models: Mapping[Text, lit_model.Model],
    datasets: Mapping[Text, lit_dataset.Dataset],
    generators: Optional[Mapping[Text, lit_components.Generator]] = None,
    interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
    annotators: Optional[List[lit_components.Annotator]] = None,
    layouts: Optional[dtypes.LitComponentLayouts] = None,
    # General server config; see server_flags.py.
    data_dir: Optional[Text] = None,
    warm_start: float = 0.0,
    warm_projections: bool = False,
    client_root: Optional[Text] = None,
    demo_mode: bool = False,
    default_layout: Optional[str] = None,
    canonical_url: Optional[str] = None,
    page_title: Optional[str] = None,
    development_demo: bool = False,
):
  if client_root is None:
    raise ValueError('client_root must be set on application')
  self._demo_mode = demo_mode
  self._development_demo = development_demo
  self._default_layout = default_layout
  self._canonical_url = canonical_url
  self._page_title = page_title
  self._data_dir = data_dir
  self._layouts = layouts or {}
  if data_dir and not os.path.isdir(data_dir):
    os.mkdir(data_dir)

  # Wrap models in caching wrapper.
  self._models = {
      name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
      for name, model in models.items()
  }

  self._datasets = dict(datasets)
  self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

  self._annotators = annotators or []
  # Run annotation on each dataset, replacing it with its annotated version.
  for ds_key, ds in self._datasets.items():
    self._datasets[ds_key] = self._run_annotators(ds)

  # Index all datasets.
  self._datasets = lit_dataset.IndexedDataset.index_all(
      self._datasets, caching.input_hash)

  if generators is not None:
    self._generators = generators
  else:
    self._generators = {
        'Ablation Flip': ablation_flip.AblationFlip(),
        'Hotflip': hotflip.HotFlip(),
        'Scrambler': scrambler.Scrambler(),
        'Word Replacer': word_replacer.WordReplacer(),
    }

  if interpreters is not None:
    self._interpreters = interpreters
  else:
    metrics_group = lit_components.ComponentGroup({
        'regression': metrics.RegressionMetrics(),
        'multiclass': metrics.MulticlassMetrics(),
        'paired': metrics.MulticlassPairedMetrics(),
        'bleu': metrics.CorpusBLEU(),
    })
    self._interpreters = {
        'Grad L2 Norm': gradient_maps.GradientNorm(),
        'Grad ⋅ Input': gradient_maps.GradientDotInput(),
        'Integrated Gradients': gradient_maps.IntegratedGradients(),
        'LIME': lime_explainer.LIME(),
        'Model-provided salience': model_salience.ModelSalience(self._models),
        'counterfactual explainer': lemon_explainer.LEMON(),
        'tcav': tcav.TCAV(),
        'thresholder': thresholder.Thresholder(),
        'nearest neighbors': nearest_neighbors.NearestNeighbors(),
        'metrics': metrics_group,
        'pdp': pdp.PdpInterpreter(),
        # Embedding projectors expose a standard interface, but get special
        # handling so we can precompute the projections if requested.
        'pca': projection.ProjectionManager(pca.PCAModel),
        'umap': projection.ProjectionManager(umap.UmapModel),
    }

  # Information on models, datasets, and other components.
  self._info = self._build_metadata()

  # Optionally, run models to pre-populate cache.
  if warm_projections:
    logging.info(
        'Projection (dimensionality reduction) warm-start requested; '
        'will do full warm-start for all models since predictions are needed.')
    warm_start = 1.0
  if warm_start > 0:
    self._warm_start(rate=warm_start)
    self.save_cache()

  # If you add a new embedding projector that should be warm-started,
  # also add it to the list here.
  # TODO(lit-dev): add some registry mechanism / automation if this grows to
  # more than 2-3 projection types.
  if warm_projections:
    self._warm_projections(['pca', 'umap'])

  handlers = {
      # Metadata endpoints.
      '/get_info': self._get_info,
      # Dataset-related endpoints.
      '/get_dataset': self._get_dataset,
      '/create_dataset': self._create_dataset,
      '/create_model': self._create_model,
      '/get_generated': self._get_generated,
      '/save_datapoints': self._save_datapoints,
      '/load_datapoints': self._load_datapoints,
      '/annotate_new_data': self._annotate_new_data,
      # Model prediction endpoints.
      '/get_preds': self._get_preds,
      '/get_interpretations': self._get_interpretations,
  }

  self._wsgi_app = wsgi_app.App(
      # Wrap endpoint fns to take (handler, request, environ).
      handlers={k: self.make_handler(v) for k, v in handlers.items()},
      project_root=client_root,
      index_file='static/index.html',
  )
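# In practice this constructor is usually reached through LIT's demo-server
# wrapper rather than instantiated directly. A minimal sketch, assuming
# `models` and `datasets` are already-populated dicts as in LIT's demo
# scripts:
from lit_nlp import dev_server
from lit_nlp import server_flags

lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
lit_demo.serve()  # builds the app above and serves its WSGI endpoints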