Exemplo n.º 1
0
 def setUp(self) -> None:
   """Run the parent class's setUp, then store a new tcav.TCAV() on self.tcav."""
   super(TCAVTest, self).setUp()
   self.tcav = tcav.TCAV()
Exemplo n.º 2
0
 def setUp(self) -> None:
   """Run the parent class's setUp, then build the objects used by the tests.

   Stores a new tcav.TCAV() on self.tcav and a glue_models.SST2Model
   constructed from BERT_TINY_PATH on self.model.
   """
   super(TCAVTest, self).setUp()
   self.tcav = tcav.TCAV()
   self.model = glue_models.SST2Model(BERT_TINY_PATH)
Exemplo n.º 3
0
 def setUp(self) -> None:
   """Run the parent class's setUp, then build the objects used by the tests.

   Stores a new tcav.TCAV() on self.tcav, and on self.model a
   caching.CachingModelWrapper around a glue_models.SST2Model constructed
   from BERT_TINY_PATH.
   """
   super(ModelBasedTCAVTest, self).setUp()
   self.tcav = tcav.TCAV()
   # 'test' is the wrapper's second positional argument.
   # NOTE(review): presumably the model name used for caching - confirm
   # against CachingModelWrapper's signature.
   self.model = caching.CachingModelWrapper(
       glue_models.SST2Model(BERT_TINY_PATH), 'test')
Exemplo n.º 4
0
Arquivo: app.py Projeto: PAIR-code/lit
    def __init__(
        self,
        models: Mapping[Text, lit_model.Model],
        datasets: Mapping[Text, lit_dataset.Dataset],
        generators: Optional[Mapping[Text, lit_components.Generator]] = None,
        interpreters: Optional[Mapping[Text,
                                       lit_components.Interpreter]] = None,
        annotators: Optional[List[lit_components.Annotator]] = None,
        layouts: Optional[dtypes.LitComponentLayouts] = None,
        # General server config; see server_flags.py.
        data_dir: Optional[Text] = None,
        warm_start: float = 0.0,
        warm_projections: bool = False,
        client_root: Optional[Text] = None,
        demo_mode: bool = False,
        default_layout: Optional[str] = None,
        canonical_url: Optional[str] = None,
        page_title: Optional[str] = None,
        development_demo: bool = False,
    ):
        """Set up models, datasets, components, and the WSGI application.

        Args:
          models: models keyed by name. Each one is wrapped in a
            caching.CachingModelWrapper that uses data_dir as its cache_dir.
          datasets: datasets keyed by name. They are copied into a new dict, a
            '_union_empty' entry is added, every entry is passed through
            self._run_annotators, and the result is indexed with
            caching.input_hash.
          generators: generators keyed by display name. When None, a default
            set (Ablation Flip, Hotflip, Scrambler, Word Replacer) is built.
          interpreters: interpreters keyed by name. When None, a default set is
            built (salience methods, TCAV, metrics, PCA/UMAP projections, ...).
          annotators: annotators to use; defaults to an empty list.
          layouts: component layouts; defaults to an empty dict.
          data_dir: optional directory, created if it does not exist (including
            missing parent directories) and handed to the caching wrappers.
          warm_start: when greater than 0, passed as `rate` to
            self._warm_start, after which self.save_cache() is called.
          warm_projections: when true, warm_start is forced to 1.0 and the
            'pca' and 'umap' projections are warmed after the models.
          client_root: project root passed to the WSGI app; required.
          demo_mode: stored on the instance as _demo_mode.
          default_layout: stored on the instance as _default_layout.
          canonical_url: stored on the instance as _canonical_url.
          page_title: stored on the instance as _page_title.
          development_demo: stored on the instance as _development_demo.

        Raises:
          ValueError: if client_root is None.
        """
        if client_root is None:
            raise ValueError('client_root must be set on application')
        self._demo_mode = demo_mode
        self._development_demo = development_demo
        self._default_layout = default_layout
        self._canonical_url = canonical_url
        self._page_title = page_title
        self._data_dir = data_dir
        self._layouts = layouts or {}
        if data_dir:
            # makedirs(exist_ok=True) also creates missing parent directories
            # and avoids the race between an isdir() check and mkdir() when
            # several processes start with the same data_dir.
            os.makedirs(data_dir, exist_ok=True)

        # Wrap models in caching wrapper
        self._models = {
            name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
            for name, model in models.items()
        }

        # Copy so that the caller's mapping is not mutated below.
        self._datasets = dict(datasets)
        self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

        self._annotators = annotators or []

        # Run annotation on each dataset, creating an annotated dataset and
        # replace the datasets with the annotated versions.
        # Only values of existing keys are reassigned, so the dict does not
        # change size while it is being iterated.
        for ds_key, ds in self._datasets.items():
            self._datasets[ds_key] = self._run_annotators(ds)

        # Index all datasets
        self._datasets = lit_dataset.IndexedDataset.index_all(
            self._datasets, caching.input_hash)

        if generators is not None:
            self._generators = generators
        else:
            self._generators = {
                'Ablation Flip': ablation_flip.AblationFlip(),
                'Hotflip': hotflip.HotFlip(),
                'Scrambler': scrambler.Scrambler(),
                'Word Replacer': word_replacer.WordReplacer(),
            }

        if interpreters is not None:
            self._interpreters = interpreters
        else:
            metrics_group = lit_components.ComponentGroup({
                'regression': metrics.RegressionMetrics(),
                'multiclass': metrics.MulticlassMetrics(),
                'paired': metrics.MulticlassPairedMetrics(),
                'bleu': metrics.CorpusBLEU(),
            })
            self._interpreters = {
                'Grad L2 Norm': gradient_maps.GradientNorm(),
                'Grad ⋅ Input': gradient_maps.GradientDotInput(),
                'Integrated Gradients': gradient_maps.IntegratedGradients(),
                'LIME': lime_explainer.LIME(),
                'Model-provided salience':
                    model_salience.ModelSalience(self._models),
                'counterfactual explainer': lemon_explainer.LEMON(),
                'tcav': tcav.TCAV(),
                'thresholder': thresholder.Thresholder(),
                'nearest neighbors': nearest_neighbors.NearestNeighbors(),
                'metrics': metrics_group,
                'pdp': pdp.PdpInterpreter(),
                # Embedding projectors expose a standard interface, but get
                # special handling so we can precompute the projections if
                # requested.
                'pca': projection.ProjectionManager(pca.PCAModel),
                'umap': projection.ProjectionManager(umap.UmapModel),
            }

        # Information on models, datasets, and other components.
        self._info = self._build_metadata()

        # Optionally, run models to pre-populate cache.
        if warm_projections:
            logging.info(
                'Projection (dimensionality reduction) warm-start requested; '
                'will do full warm-start for all models since predictions are needed.'
            )
            warm_start = 1.0

        if warm_start > 0:
            self._warm_start(rate=warm_start)
            self.save_cache()

        # If you add a new embedding projector that should be warm-started,
        # also add it to the list here.
        # TODO(lit-dev): add some registry mechanism / automation if this grows to
        # more than 2-3 projection types.
        # NOTE(review): these names are passed even when the caller supplied
        # custom `interpreters`; presumably those must then contain 'pca' and
        # 'umap' entries - confirm against _warm_projections.
        if warm_projections:
            self._warm_projections(['pca', 'umap'])

        handlers = {
            # Metadata endpoints.
            '/get_info': self._get_info,
            # Dataset-related endpoints.
            '/get_dataset': self._get_dataset,
            '/create_dataset': self._create_dataset,
            '/create_model': self._create_model,
            '/get_generated': self._get_generated,
            '/save_datapoints': self._save_datapoints,
            '/load_datapoints': self._load_datapoints,
            '/annotate_new_data': self._annotate_new_data,
            # Model prediction endpoints.
            '/get_preds': self._get_preds,
            '/get_interpretations': self._get_interpretations,
        }

        self._wsgi_app = wsgi_app.App(
            # Wrap endpoint fns to take (handler, request, environ)
            handlers={k: self.make_handler(v)
                      for k, v in handlers.items()},
            project_root=client_root,
            index_file='static/index.html',
        )