Example #1
0
    def test_compute(self):
        """Exercises MulticlassPairedMetrics.compute_with_metadata.

        Four examples form two parent/child pairs (via 'parentId' metadata);
        the metric reports the swap rate and mean Jensen-Shannon divergence
        across those pairs.
        """
        paired_metrics = metrics.MulticlassPairedMetrics()

        indices = ['7f7f85', '345ac4', '3a3112', '88bcda']
        metas = [{'parentId': '345ac4'}, {}, {}, {'parentId': '3a3112'}]
        labels = ['1', '1', '0', '0']

        def run(preds, null_idx=0):
            # Helper: run the metric with the shared labels/indices/metas and
            # a binary vocab, optionally omitting the null index.
            if null_idx is None:
                pred_spec = types.MulticlassPreds(vocab=['0', '1'])
            else:
                pred_spec = types.MulticlassPreds(vocab=['0', '1'],
                                                  null_idx=null_idx)
            return paired_metrics.compute_with_metadata(
                labels, preds, types.CategoryLabel(), pred_spec, indices,
                metas)

        # No swaps.
        self.assertAlmostEqual(
            run([[0, 1], [0, 1], [1, 0], [1, 0]]), {
                'mean_jsd': 0.0,
                'num_pairs': 2,
                'swap_rate': 0.0
            })

        # One swap.
        self.assertAlmostEqual(
            run([[0, 1], [1, 0], [1, 0], [1, 0]]), {
                'mean_jsd': 0.3465735902799726,
                'num_pairs': 2,
                'swap_rate': 0.5
            })

        # Two swaps.
        self.assertAlmostEqual(
            run([[0, 1], [1, 0], [1, 0], [0, 1]]), {
                'mean_jsd': 0.6931471805599452,
                'num_pairs': 2,
                'swap_rate': 1.0
            })

        # Two swaps, no null index.
        self.assertAlmostEqual(
            run([[0, 1], [1, 0], [1, 0], [0, 1]], null_idx=None), {
                'mean_jsd': 0.6931471805599452,
                'num_pairs': 2,
                'swap_rate': 1.0
            })

        # Empty predictions, indices, and meta.
        self.assertAlmostEqual(
            paired_metrics.compute_with_metadata(
                [], [], types.CategoryLabel(),
                types.MulticlassPreds(vocab=['0', '1'], null_idx=0), [], []),
            {})
Example #2
0
    def test_is_compatible(self):
        """MulticlassPairedMetrics accepts only MulticlassPreds specs."""
        paired_metrics = metrics.MulticlassPairedMetrics()

        # Compatible: a MulticlassPreds spec.
        self.assertTrue(
            paired_metrics.is_compatible(types.MulticlassPreds(vocab=[''])))

        # Incompatible: every other prediction spec type.
        for spec in (types.RegressionScore(), types.GeneratedText()):
            self.assertFalse(paired_metrics.is_compatible(spec))
Example #3
0
File: app.py  Project: zhiyiZeng/lit
  def __init__(
      self,
      models: Mapping[Text, lit_model.Model],
      datasets: MutableMapping[Text, lit_dataset.Dataset],
      generators: Optional[Mapping[Text, lit_components.Generator]] = None,
      interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
      # General server config; see server_flags.py.
      data_dir: Optional[Text] = None,
      warm_start: float = 0.0,
      warm_projections: bool = False,
      client_root: Optional[Text] = None,
      demo_mode: bool = False,
      # Fix: these defaulted to None but were annotated as plain `str`
      # (implicit Optional, deprecated by PEP 484).
      default_layout: Optional[str] = None,
      canonical_url: Optional[str] = None,
  ):
    """Builds the LIT WSGI application.

    Args:
      models: models to serve, keyed by name; each is wrapped in a
        prediction-caching wrapper.
      datasets: datasets to serve, keyed by name. NOTE(review): this mapping is
        mutated in place (a '_union_empty' entry is added) — callers passing a
        shared dict will see the change.
      generators: counterfactual generators; defaults to scrambler and
        word-replacer if None.
      interpreters: interpretation components; a default suite (salience,
        LEMON, metrics, projections) is built if None.
      data_dir: directory for prediction caches; created if missing.
      warm_start: fraction of each dataset to pre-run through the models.
      warm_projections: if True, precompute PCA/UMAP projections (forces a
        full warm start).
      client_root: path to built client assets; required.
      demo_mode: whether the server runs in (read-only) demo mode.
      default_layout: name of the default UI layout, if any.
      canonical_url: canonical URL for this server, if any.

    Raises:
      ValueError: if client_root is not set.
    """
    if client_root is None:
      raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    if data_dir and not os.path.isdir(data_dir):
      os.mkdir(data_dir)
    # Wrap each model in a caching wrapper so repeated predictions are cheap.
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    self._datasets = datasets
    # Placeholder dataset used when no real dataset applies.
    self._datasets['_union_empty'] = NoneDataset(self._models)
    if generators is not None:
      self._generators = generators
    else:
      self._generators = {
          'scrambler': scrambler.Scrambler(),
          'word_replacer': word_replacer.WordReplacer(),
      }

    if interpreters is not None:
      self._interpreters = interpreters
    else:
      metrics_group = lit_components.ComponentGroup({
          'regression': metrics.RegressionMetrics(),
          'multiclass': metrics.MulticlassMetrics(),
          'paired': metrics.MulticlassPairedMetrics(),
          'bleu': metrics.CorpusBLEU(),
      })
      self._interpreters = {
          'grad_norm': gradient_maps.GradientNorm(),
          'lime': lime_explainer.LIME(),
          'grad_dot_input': gradient_maps.GradientDotInput(),
          'integrated gradients': gradient_maps.IntegratedGradients(),
          'counterfactual explainer': lemon_explainer.LEMON(),
          'metrics': metrics_group,
          # Embedding projectors expose a standard interface, but get special
          # handling so we can precompute the projections if requested.
          'pca': projection.ProjectionManager(pca.PCAModel),
          'umap': projection.ProjectionManager(umap.UmapModel),
      }

    # Information on models and datasets.
    self._build_metadata()

    # Optionally, run models to pre-populate cache.
    if warm_projections:
      logging.info(
          'Projection (dimensionality reduction) warm-start requested; '
          'will do full warm-start for all models since predictions are needed.'
      )
      warm_start = 1.0

    if warm_start > 0:
      self._warm_start(rate=warm_start)
      self.save_cache()

    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
      self._warm_projections(['pca', 'umap'])

    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/get_datapoint_ids': self._get_datapoint_ids,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }

    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request)
        handlers={k: make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )
Example #4
0
File: app.py  Project: PAIR-code/lit
    def __init__(
        self,
        models: Mapping[Text, lit_model.Model],
        datasets: Mapping[Text, lit_dataset.Dataset],
        generators: Optional[Mapping[Text, lit_components.Generator]] = None,
        interpreters: Optional[Mapping[Text,
                                       lit_components.Interpreter]] = None,
        annotators: Optional[List[lit_components.Annotator]] = None,
        layouts: Optional[dtypes.LitComponentLayouts] = None,
        # General server config; see server_flags.py.
        data_dir: Optional[Text] = None,
        warm_start: float = 0.0,
        warm_projections: bool = False,
        client_root: Optional[Text] = None,
        demo_mode: bool = False,
        default_layout: Optional[str] = None,
        canonical_url: Optional[str] = None,
        page_title: Optional[str] = None,
        development_demo: bool = False,
    ):
        """Builds the LIT WSGI application.

        Args:
          models: models to serve, keyed by name; each is wrapped in a
            prediction-caching wrapper.
          datasets: datasets to serve, keyed by name; copied, annotated, and
            indexed before use (the caller's mapping is not mutated).
          generators: counterfactual generators; a default set is built if
            None.
          interpreters: interpretation components; a default suite is built if
            None.
          annotators: annotators run over every dataset at startup.
          layouts: custom UI layouts, if any.
          data_dir: directory for prediction caches; created if missing.
          warm_start: fraction of each dataset to pre-run through the models.
          warm_projections: if True, precompute PCA/UMAP projections (forces a
            full warm start).
          client_root: path to built client assets; required.
          demo_mode: whether the server runs in (read-only) demo mode.
          default_layout: name of the default UI layout, if any.
          canonical_url: canonical URL for this server, if any.
          page_title: custom browser page title, if any.
          development_demo: whether this is a development demo instance.

        Raises:
          ValueError: if client_root is not set.
        """
        if client_root is None:
            raise ValueError('client_root must be set on application')
        self._demo_mode = demo_mode
        self._development_demo = development_demo
        self._default_layout = default_layout
        self._canonical_url = canonical_url
        self._page_title = page_title
        self._data_dir = data_dir
        self._layouts = layouts or {}
        # Create the cache directory on first run if it doesn't exist yet.
        if data_dir and not os.path.isdir(data_dir):
            os.mkdir(data_dir)

        # Wrap models in caching wrapper
        self._models = {
            name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
            for name, model in models.items()
        }

        # Copy so the caller's mapping is not mutated by the entries we add.
        self._datasets = dict(datasets)
        # Placeholder dataset used when no real dataset applies.
        self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

        self._annotators = annotators or []

        # Run annotation on each dataset, creating an annotated dataset and
        # replace the datasets with the annotated versions.
        for ds_key, ds in self._datasets.items():
            self._datasets[ds_key] = self._run_annotators(ds)

        # Index all datasets
        self._datasets = lit_dataset.IndexedDataset.index_all(
            self._datasets, caching.input_hash)

        # Use caller-provided generators, or fall back to the default set.
        if generators is not None:
            self._generators = generators
        else:
            self._generators = {
                'Ablation Flip': ablation_flip.AblationFlip(),
                'Hotflip': hotflip.HotFlip(),
                'Scrambler': scrambler.Scrambler(),
                'Word Replacer': word_replacer.WordReplacer(),
            }

        # Use caller-provided interpreters, or fall back to the default suite.
        if interpreters is not None:
            self._interpreters = interpreters
        else:
            # Metrics are grouped so they appear as one component in the UI.
            metrics_group = lit_components.ComponentGroup({
                'regression':
                metrics.RegressionMetrics(),
                'multiclass':
                metrics.MulticlassMetrics(),
                'paired':
                metrics.MulticlassPairedMetrics(),
                'bleu':
                metrics.CorpusBLEU(),
            })
            self._interpreters = {
                'Grad L2 Norm':
                gradient_maps.GradientNorm(),
                'Grad ⋅ Input':
                gradient_maps.GradientDotInput(),
                'Integrated Gradients':
                gradient_maps.IntegratedGradients(),
                'LIME':
                lime_explainer.LIME(),
                'Model-provided salience':
                model_salience.ModelSalience(self._models),
                'counterfactual explainer':
                lemon_explainer.LEMON(),
                'tcav':
                tcav.TCAV(),
                'thresholder':
                thresholder.Thresholder(),
                'nearest neighbors':
                nearest_neighbors.NearestNeighbors(),
                'metrics':
                metrics_group,
                'pdp':
                pdp.PdpInterpreter(),
                # Embedding projectors expose a standard interface, but get special
                # handling so we can precompute the projections if requested.
                'pca':
                projection.ProjectionManager(pca.PCAModel),
                'umap':
                projection.ProjectionManager(umap.UmapModel),
            }

        # Information on models, datasets, and other components.
        self._info = self._build_metadata()

        # Optionally, run models to pre-populate cache.
        if warm_projections:
            logging.info(
                'Projection (dimensionality reduction) warm-start requested; '
                'will do full warm-start for all models since predictions are needed.'
            )
            warm_start = 1.0

        if warm_start > 0:
            self._warm_start(rate=warm_start)
            self.save_cache()

        # If you add a new embedding projector that should be warm-started,
        # also add it to the list here.
        # TODO(lit-dev): add some registry mechanism / automation if this grows to
        # more than 2-3 projection types.
        if warm_projections:
            self._warm_projections(['pca', 'umap'])

        # Route table: URL path -> bound handler method.
        handlers = {
            # Metadata endpoints.
            '/get_info': self._get_info,
            # Dataset-related endpoints.
            '/get_dataset': self._get_dataset,
            '/create_dataset': self._create_dataset,
            '/create_model': self._create_model,
            '/get_generated': self._get_generated,
            '/save_datapoints': self._save_datapoints,
            '/load_datapoints': self._load_datapoints,
            '/annotate_new_data': self._annotate_new_data,
            # Model prediction endpoints.
            '/get_preds': self._get_preds,
            '/get_interpretations': self._get_interpretations,
        }

        self._wsgi_app = wsgi_app.App(
            # Wrap endpoint fns to take (handler, request, environ)
            handlers={k: self.make_handler(v)
                      for k, v in handlers.items()},
            project_root=client_root,
            index_file='static/index.html',
        )