Пример #1
0
  def __init__(
      self,
      models: Mapping[Text, lit_model.Model],
      datasets: MutableMapping[Text, lit_dataset.Dataset],
      generators: Optional[Mapping[Text, lit_components.Generator]] = None,
      interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
      # General server config; see server_flags.py.
      data_dir: Optional[Text] = None,
      warm_start: float = 0.0,
      warm_projections: bool = False,
      client_root: Optional[Text] = None,
      demo_mode: bool = False,
      default_layout: Optional[str] = None,
      canonical_url: Optional[str] = None,
  ):
    """Sets up models, datasets, components and the WSGI app.

    Args:
      models: named models; each one is wrapped in a
        caching.CachingModelWrapper that uses data_dir as its cache_dir.
      datasets: named datasets. NOTE: this mapping is stored as-is and a
        '_union_empty' entry is added to it, so the caller's object is
        mutated.
      generators: generators to expose; when None, the default set defined
        below is used.
      interpreters: interpreters to expose; when None, the default set
        defined below (including the 'pca' and 'umap' projectors) is used.
      data_dir: cache directory handed to the model wrappers. It is created,
        together with any missing parent directories, if it does not exist.
      warm_start: rate passed to self._warm_start(); warm-start only runs
        when this is > 0.
      warm_projections: if true, warm_start is forced to 1.0 and the 'pca'
        and 'umap' projections are warmed afterwards.
      client_root: passed to the WSGI app as project_root; required.
      demo_mode: stored on the instance as _demo_mode.
      default_layout: stored on the instance as _default_layout.
      canonical_url: stored on the instance as _canonical_url.

    Raises:
      ValueError: if client_root is None.
    """
    if client_root is None:
      raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    if data_dir:
      # makedirs(exist_ok=True) also creates missing parents and does not
      # fail if another process creates the directory concurrently.
      os.makedirs(data_dir, exist_ok=True)
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    self._datasets = datasets
    self._datasets['_union_empty'] = NoneDataset(self._models)
    if generators is not None:
      self._generators = generators
    else:
      self._generators = {
          'scrambler': scrambler.Scrambler(),
          'word_replacer': word_replacer.WordReplacer(),
      }

    if interpreters is not None:
      self._interpreters = interpreters
    else:
      metrics_group = lit_components.ComponentGroup({
          'regression': metrics.RegressionMetrics(),
          'multiclass': metrics.MulticlassMetrics(),
          'paired': metrics.MulticlassPairedMetrics(),
          'bleu': metrics.CorpusBLEU(),
      })
      self._interpreters = {
          'grad_norm': gradient_maps.GradientNorm(),
          'lime': lime_explainer.LIME(),
          'grad_dot_input': gradient_maps.GradientDotInput(),
          'integrated gradients': gradient_maps.IntegratedGradients(),
          'counterfactual explainer': lemon_explainer.LEMON(),
          'metrics': metrics_group,
          # Embedding projectors expose a standard interface, but get special
          # handling so we can precompute the projections if requested.
          'pca': projection.ProjectionManager(pca.PCAModel),
          'umap': projection.ProjectionManager(umap.UmapModel),
      }

    # Information on models and datasets.
    self._build_metadata()

    # Optionally, run models to pre-populate cache.
    if warm_projections:
      logging.info(
          'Projection (dimensionality reduction) warm-start requested; '
          'will do full warm-start for all models since predictions are needed.'
      )
      warm_start = 1.0

    if warm_start > 0:
      self._warm_start(rate=warm_start)
      self.save_cache()

    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
      self._warm_projections(['pca', 'umap'])

    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/get_datapoint_ids': self._get_datapoint_ids,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }

    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request)
        handlers={k: make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )
Пример #2
0
def create_lit_args(exp_name: str,
                    checkpoints: Set[str] = {"last"}) -> Tuple[tuple, dict]:
    """Build the positional and keyword arguments describing an experiment.

    One ``NMTModel`` is created per requested checkpoint, all sharing the
    same underlying model object and sentence-piece processors.

    Args:
        exp_name: Experiment name, passed to ``load_config``.
        checkpoints: Checkpoints to expose. Recognised values are ``"avg"``,
            ``"last"``, ``"best"``; anything else must be an integer step
            number given as a string. The default set is never mutated here.

    Returns:
        A ``((models, datasets), kwargs)`` pair, where ``kwargs`` holds the
        ``"generators"`` and ``"interpreters"`` mappings.

    Raises:
        ValueError: If a checkpoint is neither a recognised name nor an
            integer string.
    """
    config = load_config(exp_name)

    datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_test_dataset(config)
    }

    src_spp, trg_spp = config.create_sp_processors()

    model = create_model(config)

    models: Dict[str, NMTModel] = {}

    # Iterate in sorted order so the result does not depend on set iteration
    # order. Plain decimal step strings sort before "avg"/"best"/"last", so
    # an explicit step is registered first and the named checkpoints then
    # add their tag to it instead of being overwritten.
    for checkpoint in sorted(checkpoints):
        if checkpoint == "avg":
            checkpoint_path, _ = get_last_checkpoint(config.model_dir / "avg")
            models["avg"] = NMTModel(config,
                                     model,
                                     src_spp,
                                     trg_spp,
                                     -1,
                                     checkpoint_path,
                                     type="avg")
        elif checkpoint == "last":
            last_checkpoint_path, last_step = get_last_checkpoint(
                config.model_dir)
            step_str = str(last_step)
            if step_str in models:
                # Same step already registered: just tag it.
                models[step_str].types.append("last")
            else:
                models[step_str] = NMTModel(config,
                                            model,
                                            src_spp,
                                            trg_spp,
                                            last_step,
                                            last_checkpoint_path,
                                            type="last")
        elif checkpoint == "best":
            best_model_path, best_step = get_best_model_dir(config.model_dir)
            step_str = str(best_step)
            if step_str in models:
                # Same step already registered: just tag it.
                models[step_str].types.append("best")
            else:
                models[step_str] = NMTModel(config,
                                            model,
                                            src_spp,
                                            trg_spp,
                                            best_step,
                                            best_model_path / "ckpt",
                                            type="best")
        else:
            # Any other value is treated as an explicit checkpoint step.
            try:
                step = int(checkpoint)
            except ValueError as err:
                raise ValueError(
                    f"Unknown checkpoint {checkpoint!r}: expected 'avg', "
                    "'last', 'best' or an integer step number") from err
            # Never clobber an entry (and its type tags) that is already
            # registered under the same key.
            if checkpoint not in models:
                checkpoint_path = config.model_dir / f"ckpt-{checkpoint}"
                models[checkpoint] = NMTModel(config, model, src_spp, trg_spp,
                                              step, checkpoint_path)

    # NOTE(review): the training set is indexed under the "test" key —
    # presumably so lookups for the "test" dataset search the training data;
    # confirm this is intentional.
    index_datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_train_dataset(config)
    }
    indexer = IndexerEx(
        models,
        lit_dataset.IndexedDataset.index_all(index_datasets,
                                             caching.input_hash),
        data_dir=str(config.exp_dir / "lit-index"),
        initialize_new_indices=True,
    )

    generators: Dict[str, lit_components.Generator] = {
        "word_replacer": word_replacer.WordReplacer(),
        "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
    }

    interpreters: Dict[str, lit_components.Interpreter] = {
        "metrics":
        lit_components.ComponentGroup({"bleu": BLEUMetrics(config.data)}),
        "pca":
        projection.ProjectionManager(pca.PCAModel),
        "umap":
        projection.ProjectionManager(umap.UmapModel),
    }

    return ((models, datasets), {
        "generators": generators,
        "interpreters": interpreters
    })
Пример #3
0
    def __init__(
        self,
        models: Mapping[Text, lit_model.Model],
        datasets: Mapping[Text, lit_dataset.Dataset],
        generators: Optional[Mapping[Text, lit_components.Generator]] = None,
        interpreters: Optional[Mapping[Text,
                                       lit_components.Interpreter]] = None,
        annotators: Optional[List[lit_components.Annotator]] = None,
        layouts: Optional[dtypes.LitComponentLayouts] = None,
        # General server config; see server_flags.py.
        data_dir: Optional[Text] = None,
        warm_start: float = 0.0,
        warm_projections: bool = False,
        client_root: Optional[Text] = None,
        demo_mode: bool = False,
        default_layout: Optional[str] = None,
        canonical_url: Optional[str] = None,
        page_title: Optional[str] = None,
        development_demo: bool = False,
    ):
        """Sets up models, datasets, components and the WSGI app.

        Args:
            models: named models; each one is wrapped in a
                caching.CachingModelWrapper that uses data_dir as cache_dir.
            datasets: named datasets. They are copied into a new dict, a
                '_union_empty' entry is added, each one is passed through
                self._run_annotators(), and the result is indexed with
                caching.input_hash.
            generators: generators to expose; when None, the default set
                defined below is used.
            interpreters: interpreters to expose; when None, the default set
                defined below (including the 'pca' and 'umap' projectors) is
                used.
            annotators: stored as _annotators (empty list when None).
            layouts: stored as _layouts (empty dict when None).
            data_dir: cache directory handed to the model wrappers. It is
                created, together with any missing parent directories, if it
                does not exist.
            warm_start: rate passed to self._warm_start(); warm-start only
                runs when this is > 0.
            warm_projections: if true, warm_start is forced to 1.0 and the
                'pca' and 'umap' projections are warmed afterwards.
            client_root: passed to the WSGI app as project_root; required.
            demo_mode: stored on the instance as _demo_mode.
            default_layout: stored on the instance as _default_layout.
            canonical_url: stored on the instance as _canonical_url.
            page_title: stored on the instance as _page_title.
            development_demo: stored on the instance as _development_demo.

        Raises:
            ValueError: if client_root is None.
        """
        if client_root is None:
            raise ValueError('client_root must be set on application')
        self._demo_mode = demo_mode
        self._development_demo = development_demo
        self._default_layout = default_layout
        self._canonical_url = canonical_url
        self._page_title = page_title
        self._data_dir = data_dir
        self._layouts = layouts or {}
        if data_dir:
            # makedirs(exist_ok=True) also creates missing parents and does
            # not fail if another process creates the directory concurrently.
            os.makedirs(data_dir, exist_ok=True)

        # Wrap models in caching wrapper
        self._models = {
            name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
            for name, model in models.items()
        }

        # Work on a private copy so the caller's mapping is left untouched.
        self._datasets = dict(datasets)
        self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

        self._annotators = annotators or []

        # Run annotation on each dataset, creating an annotated dataset and
        # replace the datasets with the annotated versions. Iterate over a
        # snapshot so the dict is not reassigned while its live view is being
        # walked.
        for ds_key, ds in list(self._datasets.items()):
            self._datasets[ds_key] = self._run_annotators(ds)

        # Index all datasets
        self._datasets = lit_dataset.IndexedDataset.index_all(
            self._datasets, caching.input_hash)

        if generators is not None:
            self._generators = generators
        else:
            self._generators = {
                'Ablation Flip': ablation_flip.AblationFlip(),
                'Hotflip': hotflip.HotFlip(),
                'Scrambler': scrambler.Scrambler(),
                'Word Replacer': word_replacer.WordReplacer(),
            }

        if interpreters is not None:
            self._interpreters = interpreters
        else:
            metrics_group = lit_components.ComponentGroup({
                'regression':
                metrics.RegressionMetrics(),
                'multiclass':
                metrics.MulticlassMetrics(),
                'paired':
                metrics.MulticlassPairedMetrics(),
                'bleu':
                metrics.CorpusBLEU(),
            })
            self._interpreters = {
                'Grad L2 Norm':
                gradient_maps.GradientNorm(),
                'Grad ⋅ Input':
                gradient_maps.GradientDotInput(),
                'Integrated Gradients':
                gradient_maps.IntegratedGradients(),
                'LIME':
                lime_explainer.LIME(),
                'Model-provided salience':
                model_salience.ModelSalience(self._models),
                'counterfactual explainer':
                lemon_explainer.LEMON(),
                'tcav':
                tcav.TCAV(),
                'thresholder':
                thresholder.Thresholder(),
                'nearest neighbors':
                nearest_neighbors.NearestNeighbors(),
                'metrics':
                metrics_group,
                'pdp':
                pdp.PdpInterpreter(),
                # Embedding projectors expose a standard interface, but get special
                # handling so we can precompute the projections if requested.
                'pca':
                projection.ProjectionManager(pca.PCAModel),
                'umap':
                projection.ProjectionManager(umap.UmapModel),
            }

        # Information on models, datasets, and other components.
        self._info = self._build_metadata()

        # Optionally, run models to pre-populate cache.
        if warm_projections:
            logging.info(
                'Projection (dimensionality reduction) warm-start requested; '
                'will do full warm-start for all models since predictions are needed.'
            )
            warm_start = 1.0

        if warm_start > 0:
            self._warm_start(rate=warm_start)
            self.save_cache()

        # If you add a new embedding projector that should be warm-started,
        # also add it to the list here.
        # TODO(lit-dev): add some registry mechanism / automation if this grows to
        # more than 2-3 projection types.
        if warm_projections:
            self._warm_projections(['pca', 'umap'])

        handlers = {
            # Metadata endpoints.
            '/get_info': self._get_info,
            # Dataset-related endpoints.
            '/get_dataset': self._get_dataset,
            '/create_dataset': self._create_dataset,
            '/create_model': self._create_model,
            '/get_generated': self._get_generated,
            '/save_datapoints': self._save_datapoints,
            '/load_datapoints': self._load_datapoints,
            '/annotate_new_data': self._annotate_new_data,
            # Model prediction endpoints.
            '/get_preds': self._get_preds,
            '/get_interpretations': self._get_interpretations,
        }

        self._wsgi_app = wsgi_app.App(
            # Wrap endpoint fns to take (handler, request, environ)
            handlers={k: self.make_handler(v)
                      for k, v in handlers.items()},
            project_root=client_root,
            index_file='static/index.html',
        )