Example #1
0
    def test_caching_model_wrapper_mixed_list(self):
        """Mixed cached/uncached batch: only the cache misses hit the model."""
        model = testing_utils.TestIdentityRegressionModel()
        wrapper = caching.CachingModelWrapper(model, "test")

        # Prime the cache with a single example.
        first_batch = [{"data": {"val": 1}, "id": "my_id"}]
        outputs = wrapper.predict_with_metadata(first_batch, "dataset")
        self.assertEqual(1, model.count)
        self.assertEqual({"score": 1}, outputs[0])

        # Query again with the cached example sandwiched between two new ones.
        mixed_batch = [
            {"data": {"val": val}, "id": ex_id}
            for val, ex_id in [(0, "first_id"), (1, "my_id"), (2, "last_id")]
        ]
        outputs = wrapper.predict_with_metadata(mixed_batch, "dataset")
        # Count goes 1 -> 3: only the two uncached examples reached the model.
        self.assertEqual(3, model.count)
        for position, expected_score in enumerate([0, 1, 2]):
            self.assertEqual({"score": expected_score}, outputs[position])
Example #2
0
  def __init__(self, model: lit_model.Model, indexed_inputs: List[JsonDict],
               model_outputs: Optional[List[JsonDict]],
               projector: ProjectorModel, field_name: Text, name: Text):
    """Wrap the projector in a caching layer and fit it on the given inputs."""
    self._field_name = field_name
    self._projector = caching.CachingModelWrapper(projector, name=name)

    # Fit the projection immediately on the provided examples.
    self._run(model, indexed_inputs, model_outputs, do_fit=True)
Example #3
0
 def test_caching_model_wrapper_use_cache(self):
     """Repeating an identical query must be served from the cache."""
     model = testing_utils.TestIdentityRegressionModel()
     wrapper = caching.CachingModelWrapper(model, "test")
     batch = [{"data": {"val": 1}, "id": "id_to_cache"}]

     # Issue the same query twice; the wrapped model should run only once.
     for _ in range(2):
         outputs = wrapper.predict_with_metadata(batch, "dataset")
         self.assertEqual(1, model.count)
         self.assertEqual({"score": 1}, outputs[0])
Example #4
0
    def setUp(self):
        """Build a cached SST-2 model, a small indexed dataset, and predictions."""
        super(ThresholderTest, self).setUp()
        self.thresholder = thresholder.Thresholder()
        self.model = caching.CachingModelWrapper(
            glue_models.SST2Model(BERT_TINY_PATH), 'test')

        # Nine one-letter sentences: 'a'-'e' labeled '1', 'f'-'i' labeled '0'.
        examples = [{'sentence': s, 'label': '1'} for s in 'abcde']
        examples += [{'sentence': s, 'label': '0'} for s in 'fghi']

        self.indexed_inputs = [{'id': caching.input_hash(ex), 'data': ex}
                               for ex in examples]
        self.dataset = lit_dataset.IndexedDataset(
            id_fn=caching.input_hash,
            spec={
                'sentence': lit_types.TextSegment(),
                'label': lit_types.CategoryLabel(vocab=['0', '1'])
            },
            indexed_examples=self.indexed_inputs)
        # Precompute model outputs for all examples.
        self.model_outputs = list(
            self.model.predict_with_metadata(
                self.indexed_inputs, dataset_name='test'))
Example #5
0
File: app.py  Project: PAIR-code/lit
    def _create_model(self,
                      unused_data,
                      model_name: Optional[Text] = None,
                      model_path: Optional[Text] = None,
                      **unused_kw):
        """Create model from a path, updating and returning the metadata."""

        assert model_name is not None, 'No model specified.'
        assert model_path is not None, 'No model path specified.'
        # Load using the underlying model class, then wrap explicitly in a cache.
        new_model = self._models[model_name].wrapped.load(model_path)
        if new_model is None:
            return None
        # Register the reloaded model under a name derived from its path.
        new_model_name = f'{model_name}:{os.path.basename(model_path)}'
        self._models[new_model_name] = caching.CachingModelWrapper(
            new_model, new_model_name, cache_dir=self._data_dir)
        self._info = self._build_metadata()
        return (self._info, new_model_name)
Example #6
0
File: app.py  Project: zhiyiZeng/lit
  def __init__(
      self,
      models: Mapping[Text, lit_model.Model],
      datasets: MutableMapping[Text, lit_dataset.Dataset],
      generators: Optional[Mapping[Text, lit_components.Generator]] = None,
      interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
      # General server config; see server_flags.py.
      data_dir: Optional[Text] = None,
      warm_start: float = 0.0,
      warm_projections: bool = False,
      client_root: Optional[Text] = None,
      demo_mode: bool = False,
      default_layout: Optional[str] = None,
      canonical_url: Optional[str] = None,
  ):
    """Build the LIT server application.

    Wraps every model in a caching layer, installs default generators and
    interpreters when none are supplied, optionally warm-starts the
    prediction caches, and wires up the WSGI request handlers.

    Raises:
      ValueError: if client_root is not provided.
    """
    if client_root is None:
      raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    # Create the cache directory if it does not exist yet.
    if data_dir and not os.path.isdir(data_dir):
      os.mkdir(data_dir)
    # Wrap each model so repeated predictions can be served from cache
    # (persisted under data_dir when one is given).
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    self._datasets = datasets
    # NOTE(review): NoneDataset built over the wrapped models — presumably an
    # empty dataset compatible with every model; confirm against NoneDataset.
    self._datasets['_union_empty'] = NoneDataset(self._models)
    # Default generators, used only when the caller supplies none.
    if generators is not None:
      self._generators = generators
    else:
      self._generators = {
          'scrambler': scrambler.Scrambler(),
          'word_replacer': word_replacer.WordReplacer(),
      }

    # Default interpreters, used only when the caller supplies none.
    if interpreters is not None:
      self._interpreters = interpreters
    else:
      metrics_group = lit_components.ComponentGroup({
          'regression': metrics.RegressionMetrics(),
          'multiclass': metrics.MulticlassMetrics(),
          'paired': metrics.MulticlassPairedMetrics(),
          'bleu': metrics.CorpusBLEU(),
      })
      self._interpreters = {
          'grad_norm': gradient_maps.GradientNorm(),
          'lime': lime_explainer.LIME(),
          'grad_dot_input': gradient_maps.GradientDotInput(),
          'integrated gradients': gradient_maps.IntegratedGradients(),
          'counterfactual explainer': lemon_explainer.LEMON(),
          'metrics': metrics_group,
          # Embedding projectors expose a standard interface, but get special
          # handling so we can precompute the projections if requested.
          'pca': projection.ProjectionManager(pca.PCAModel),
          'umap': projection.ProjectionManager(umap.UmapModel),
      }

    # Information on models and datasets.
    self._build_metadata()

    # Optionally, run models to pre-populate cache.
    if warm_projections:
      logging.info(
          'Projection (dimensionality reduction) warm-start requested; '
          'will do full warm-start for all models since predictions are needed.'
      )
      warm_start = 1.0

    if warm_start > 0:
      self._warm_start(rate=warm_start)
      self.save_cache()

    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
      self._warm_projections(['pca', 'umap'])

    # URL-path -> bound-method routing table for the WSGI app.
    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/get_datapoint_ids': self._get_datapoint_ids,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }

    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request)
        handlers={k: make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )
Example #7
0
 def setUp(self):
   """Instantiate the TCAV interpreter and a cached BERT-tiny SST-2 model."""
   super(ModelBasedTCAVTest, self).setUp()
   self.tcav = tcav.TCAV()
   sst2_model = glue_models.SST2Model(BERT_TINY_PATH)
   self.model = caching.CachingModelWrapper(sst2_model, 'test')
Example #8
0
File: app.py  Project: PAIR-code/lit
    def __init__(
        self,
        models: Mapping[Text, lit_model.Model],
        datasets: Mapping[Text, lit_dataset.Dataset],
        generators: Optional[Mapping[Text, lit_components.Generator]] = None,
        interpreters: Optional[Mapping[Text,
                                       lit_components.Interpreter]] = None,
        annotators: Optional[List[lit_components.Annotator]] = None,
        layouts: Optional[dtypes.LitComponentLayouts] = None,
        # General server config; see server_flags.py.
        data_dir: Optional[Text] = None,
        warm_start: float = 0.0,
        warm_projections: bool = False,
        client_root: Optional[Text] = None,
        demo_mode: bool = False,
        default_layout: Optional[str] = None,
        canonical_url: Optional[str] = None,
        page_title: Optional[str] = None,
        development_demo: bool = False,
    ):
        """Build the LIT server application.

        Wraps every model in a caching layer, runs annotators over and indexes
        all datasets, installs default generators and interpreters when none
        are supplied, optionally warm-starts the prediction caches, and wires
        up the WSGI request handlers.

        Raises:
          ValueError: if client_root is not provided.
        """
        if client_root is None:
            raise ValueError('client_root must be set on application')
        self._demo_mode = demo_mode
        self._development_demo = development_demo
        self._default_layout = default_layout
        self._canonical_url = canonical_url
        self._page_title = page_title
        self._data_dir = data_dir
        self._layouts = layouts or {}
        # Create the cache directory if it does not exist yet.
        if data_dir and not os.path.isdir(data_dir):
            os.mkdir(data_dir)

        # Wrap models in caching wrapper
        self._models = {
            name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
            for name, model in models.items()
        }

        # Copy so the caller's mapping is not mutated by the entries below.
        self._datasets = dict(datasets)
        # NOTE(review): NoneDataset built over the wrapped models — presumably
        # an empty dataset compatible with every model; confirm in lit_dataset.
        self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

        self._annotators = annotators or []

        # Run annotation on each dataset, creating an annotated dataset and
        # replace the datasets with the annotated versions.
        for ds_key, ds in self._datasets.items():
            self._datasets[ds_key] = self._run_annotators(ds)

        # Index all datasets
        self._datasets = lit_dataset.IndexedDataset.index_all(
            self._datasets, caching.input_hash)

        # Default generators, used only when the caller supplies none.
        if generators is not None:
            self._generators = generators
        else:
            self._generators = {
                'Ablation Flip': ablation_flip.AblationFlip(),
                'Hotflip': hotflip.HotFlip(),
                'Scrambler': scrambler.Scrambler(),
                'Word Replacer': word_replacer.WordReplacer(),
            }

        # Default interpreters, used only when the caller supplies none.
        if interpreters is not None:
            self._interpreters = interpreters
        else:
            metrics_group = lit_components.ComponentGroup({
                'regression':
                metrics.RegressionMetrics(),
                'multiclass':
                metrics.MulticlassMetrics(),
                'paired':
                metrics.MulticlassPairedMetrics(),
                'bleu':
                metrics.CorpusBLEU(),
            })
            self._interpreters = {
                'Grad L2 Norm':
                gradient_maps.GradientNorm(),
                'Grad ⋅ Input':
                gradient_maps.GradientDotInput(),
                'Integrated Gradients':
                gradient_maps.IntegratedGradients(),
                'LIME':
                lime_explainer.LIME(),
                'Model-provided salience':
                model_salience.ModelSalience(self._models),
                'counterfactual explainer':
                lemon_explainer.LEMON(),
                'tcav':
                tcav.TCAV(),
                'thresholder':
                thresholder.Thresholder(),
                'nearest neighbors':
                nearest_neighbors.NearestNeighbors(),
                'metrics':
                metrics_group,
                'pdp':
                pdp.PdpInterpreter(),
                # Embedding projectors expose a standard interface, but get special
                # handling so we can precompute the projections if requested.
                'pca':
                projection.ProjectionManager(pca.PCAModel),
                'umap':
                projection.ProjectionManager(umap.UmapModel),
            }

        # Information on models, datasets, and other components.
        self._info = self._build_metadata()

        # Optionally, run models to pre-populate cache.
        if warm_projections:
            logging.info(
                'Projection (dimensionality reduction) warm-start requested; '
                'will do full warm-start for all models since predictions are needed.'
            )
            warm_start = 1.0

        if warm_start > 0:
            self._warm_start(rate=warm_start)
            self.save_cache()

        # If you add a new embedding projector that should be warm-started,
        # also add it to the list here.
        # TODO(lit-dev): add some registry mechanism / automation if this grows to
        # more than 2-3 projection types.
        if warm_projections:
            self._warm_projections(['pca', 'umap'])

        # URL-path -> bound-method routing table for the WSGI app.
        handlers = {
            # Metadata endpoints.
            '/get_info': self._get_info,
            # Dataset-related endpoints.
            '/get_dataset': self._get_dataset,
            '/create_dataset': self._create_dataset,
            '/create_model': self._create_model,
            '/get_generated': self._get_generated,
            '/save_datapoints': self._save_datapoints,
            '/load_datapoints': self._load_datapoints,
            '/annotate_new_data': self._annotate_new_data,
            # Model prediction endpoints.
            '/get_preds': self._get_preds,
            '/get_interpretations': self._get_interpretations,
        }

        self._wsgi_app = wsgi_app.App(
            # Wrap endpoint fns to take (handler, request, environ)
            handlers={k: self.make_handler(v)
                      for k, v in handlers.items()},
            project_root=client_root,
            index_file='static/index.html',
        )