Example #1
    def test_parse_sub_string(self):
        generator = word_replacer.WordReplacer()

        query_string = 'foo -> bar, spam -> eggs'
        expected = {'foo': ['bar'], 'spam': ['eggs']}
        self.assertDictEqual(generator.parse_subs_string(query_string),
                             expected)

        # Should ignore the malformed rule
        query_string = 'foo -> bar, spam eggs'
        expected = {'foo': ['bar']}
        self.assertDictEqual(generator.parse_subs_string(query_string),
                             expected)

        # Multiple target tokens.
        query_string = 'foo -> bar, spam -> eggs|donuts | cream'
        expected = {'foo': ['bar'], 'spam': ['eggs', 'donuts', 'cream']}
        self.assertDictEqual(generator.parse_subs_string(query_string),
                             expected)

        query_string = ''
        expected = {}
        self.assertDictEqual(generator.parse_subs_string(query_string),
                             expected)

        query_string = '♞ -> ♟'
        expected = {'♞': ['♟']}
        self.assertDictEqual(generator.parse_subs_string(query_string),
                             expected)
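
The parser exercised by this test accepts comma-separated rules of the form source -> target1|target2 and skips rules without an arrow. A rough, hypothetical re-implementation consistent only with the assertions above (not the library's actual parse_subs_string code) could look like this:

from typing import Dict, List

def parse_subs_string(subs_string: str) -> Dict[str, List[str]]:
  # Parse 'source -> target1|target2, ...' into a dict of replacement lists.
  replacements = {}
  for rule in subs_string.split(','):
    if '->' not in rule:
      continue  # Skip malformed rules such as 'spam eggs'.
    source, targets = rule.split('->', 1)
    replacements[source.strip()] = [t.strip() for t in targets.split('|')]
  return replacements

print(parse_subs_string('foo -> bar, spam -> eggs|donuts | cream'))
# {'foo': ['bar'], 'spam': ['eggs', 'donuts', 'cream']}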
Example #2
def main(_):

  ##
  # Load models, according to the --models flag.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    if model_name.startswith("bert-"):
      models[model_name] = pretrained_lms.BertMLM(
          model_name_or_path, top_k=FLAGS.top_k)
    elif model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
      models[model_name] = pretrained_lms.GPT2LanguageModel(
          model_name_or_path, top_k=FLAGS.top_k)
    else:
      raise ValueError(
          f"Unsupported model name '{model_name}' from path '{model_name_or_path}'"
      )

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  for name in datasets:
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
Example #3
  def test_parse_sub_string(self):
    generator = word_replacer.WordReplacer()

    query_string = 'foo -> bar, spam -> eggs'
    expected = {'foo': 'bar', 'spam': 'eggs'}
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)

    # Should ignore the malformed rule
    query_string = 'foo -> bar, spam eggs'
    expected = {'foo': 'bar'}
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)

    query_string = ''
    expected = {}
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)

    query_string = '♞ -> ♟'
    expected = {'♞': '♟'}
    self.assertDictEqual(generator.parse_subs_string(query_string), expected)
Example #4
def main(_):
    ##
    # Load models. You can specify several here, if you want to compare different
    # models side-by-side, and can also include models of different types that use
    # different datasets.
    base_models = {}
    for model_name_or_path in FLAGS.models:
        # Ignore path prefix, if using /path/to/<model_name> to load from a
        # specific directory rather than the default shortcut.
        model_name = os.path.basename(model_name_or_path)
        if model_name_or_path.startswith("SavedModel"):
            saved_model_path = model_name_or_path.split(":", 1)[1]
            base_models[model_name] = t5.T5SavedModel(saved_model_path)
        else:
            # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
            # tokens in a long document can get very, very large. Re-enable once we
            # can send this to the frontend more efficiently.
            base_models[model_name] = t5.T5HFModel(
                model_name=model_name_or_path,
                num_to_generate=FLAGS.num_to_generate,
                token_top_k=FLAGS.token_top_k,
                output_attention=False)

    ##
    # Load eval sets and model wrappers for each task.
    # Model wrappers share the same in-memory T5 model, but add task-specific pre-
    # and post-processing code.
    models = {}
    datasets = {}

    if "summarization" in FLAGS.tasks:
        for k, m in base_models.items():
            models[k + "_summarization"] = t5.SummarizationWrapper(m)
        datasets["CNNDM"] = summarization.CNNDMData(
            split="validation", max_examples=FLAGS.max_examples)

    if "mt" in FLAGS.tasks:
        for k, m in base_models.items():
            models[k + "_translation"] = t5.TranslationWrapper(m)
        datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
        datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    ##
    # We can also add custom components. Generators are used to create new
    # examples by perturbing or modifying existing ones.
    generators = {
        # Word-substitution, like "great" -> "terrible"
        "word_replacer": word_replacer.WordReplacer(),
    }

    if FLAGS.use_indexer:
        indexer = build_indexer(models)
        # Wrap the indexer into a Generator component that we can query.
        generators[
            "similarity_searcher"] = similarity_searcher.SimilaritySearcher(
                indexer=indexer)

    ##
    # Actually start the LIT server, using the models, datasets, and other
    # components constructed above.
    lit_demo = dev_server.Server(models,
                                 datasets,
                                 generators=generators,
                                 **server_flags.get_flags())
    return lit_demo.serve()
Example #5
def create_lit_args(exp_name: str,
                    checkpoints: Set[str] = {"last"}) -> Tuple[tuple, dict]:
    config = load_config(exp_name)

    datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_test_dataset(config)
    }

    src_spp, trg_spp = config.create_sp_processors()

    model = create_model(config)

    models: Dict[str, NMTModel] = {}

    for checkpoint in checkpoints:
        if checkpoint == "avg":
            checkpoint_path, _ = get_last_checkpoint(config.model_dir / "avg")
            models["avg"] = NMTModel(config,
                                     model,
                                     src_spp,
                                     trg_spp,
                                     -1,
                                     checkpoint_path,
                                     type="avg")
        elif checkpoint == "last":
            last_checkpoint_path, last_step = get_last_checkpoint(
                config.model_dir)
            step_str = str(last_step)
            if step_str in models:
                models[step_str].types.append("last")
            else:
                models[str(last_step)] = NMTModel(config,
                                                  model,
                                                  src_spp,
                                                  trg_spp,
                                                  last_step,
                                                  last_checkpoint_path,
                                                  type="last")
        elif checkpoint == "best":
            best_model_path, best_step = get_best_model_dir(config.model_dir)
            step_str = str(best_step)
            if step_str in models:
                models[step_str].types.append("best")
            else:
                models[step_str] = NMTModel(config,
                                            model,
                                            src_spp,
                                            trg_spp,
                                            best_step,
                                            best_model_path / "ckpt",
                                            type="best")
        else:
            checkpoint_path = config.model_dir / f"ckpt-{checkpoint}"
            step = int(checkpoint)
            models[checkpoint] = NMTModel(config, model, src_spp, trg_spp,
                                          step, checkpoint_path)

    index_datasets: Dict[str, lit_dataset.Dataset] = {
        "test": create_train_dataset(config)
    }
    indexer = IndexerEx(
        models,
        lit_dataset.IndexedDataset.index_all(index_datasets,
                                             caching.input_hash),
        data_dir=str(config.exp_dir / "lit-index"),
        initialize_new_indices=True,
    )

    generators: Dict[str, lit_components.Generator] = {
        "word_replacer": word_replacer.WordReplacer(),
        "similarity_searcher": similarity_searcher.SimilaritySearcher(indexer),
    }

    interpreters: Dict[str, lit_components.Interpreter] = {
        "metrics": lit_components.ComponentGroup(
            {"bleu": BLEUMetrics(config.data)}),
        "pca": projection.ProjectionManager(pca.PCAModel),
        "umap": projection.ProjectionManager(umap.UmapModel),
    }

    return ((models, datasets), {
        "generators": generators,
        "interpreters": interpreters
    })
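
Because create_lit_args returns a (positional args, keyword args) pair, a call site can unpack it straight into the LIT dev server. A hedged sketch of such a call site (the experiment name is a placeholder, and the dev_server / server_flags imports are the standard LIT modules rather than part of the snippet above):

from lit_nlp import dev_server
from lit_nlp import server_flags

args, kwargs = create_lit_args("my_experiment", checkpoints={"best", "last"})
lit_demo = dev_server.Server(*args, **kwargs, **server_flags.get_flags())
lit_demo.serve()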
Example #6
def main(_):
  ##
  # Load models. You can specify several here, if you want to compare different
  # models side-by-side, and can also include models of different types that use
  # different datasets.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
    # tokens in a long document can get very, very large. Re-enable once we can
    # send this to the frontend more efficiently.
    models[model_name] = t5.T5GenerationModel(
        model_name=model_name_or_path,
        input_prefix="summarize: ",
        top_k=FLAGS.top_k,
        output_attention=False)

  ##
  # Load datasets. Typically you'll have the constructor actually read the
  # examples and do any pre-processing, so that they're in memory and ready to
  # send to the frontend when you open the web UI.
  datasets = {
      "CNNDM":
          summarization.CNNDMData(
              split="validation", max_examples=FLAGS.max_examples),
  }
  for name, ds in datasets.items():
    logging.info("Dataset: '%s' with %d examples", name, len(ds))

  ##
  # We can also add custom components. Generators are used to create new
  # examples by perturbing or modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible"
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    assert FLAGS.data_dir, "--data_dir must be set to use the indexer."
    # Datasets for indexer - this one loads the training corpus instead of val.
    index_datasets = {
        "CNNDM":
            summarization.CNNDMData(
                split="train", max_examples=FLAGS.max_index_examples),
    }
    # Set up the Indexer, building index if necessary (this may be slow).
    indexer = index.Indexer(
        datasets=index_datasets,
        models=models,
        data_dir=FLAGS.data_dir,
        initialize_new_indices=FLAGS.initialize_index)

    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Actually start the LIT server, using the models, datasets, and other
  # components constructed above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  lit_demo.serve()
Example #7
  def test_all_replacements(self):
    input_spec = {'text': lit_types.TextSegment()}
    model = testing_utils.TestRegressionModel(input_spec)
    # Dataset is only used for spec in word_replacer so define once
    dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

    ## Test replacements
    generator = word_replacer.WordReplacer()
    # Unicode to Unicode
    input_dict = {'text': '♞ is a black chess knight.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: '♞ -> ♟',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': '♟ is a black chess knight.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Unicode to ASCII
    input_dict = {'text': 'Is répertoire a unicode word?'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'répertoire -> repertoire',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'Is repertoire a unicode word?'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Ignore capitalization
    input_dict = {'text': 'Capitalization is ignored.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'Capitalization -> blank',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'blank is ignored.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    input_dict = {'text': 'Capitalization is ignored.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'capitalization -> blank',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'blank is ignored.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Do not ignore capitalization
    input_dict = {'text': 'Capitalization is important.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'Capitalization -> blank',
        word_replacer.IGNORE_CASING_KEY: False,
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'blank is important.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    input_dict = {'text': 'Capitalization is important.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'capitalization -> blank',
        word_replacer.IGNORE_CASING_KEY: False,
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = []
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Repetition
    input_dict = {'text': 'maybe repetition repetition maybe'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'repetition -> blank',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'maybe blank repetition maybe'},
                {'text': 'maybe repetition blank maybe'}]
    self.assertCountEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # No partial match
    input_dict = {'text': 'A catastrophic storm'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'cat -> blank',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = []
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    ## Special characters
    # Punctuation
    input_dict = {'text': 'A catastrophic storm .'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: '. -> -',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A catastrophic storm -'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    input_dict = {'text': 'A.catastrophic. storm'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: '. -> -',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A-catastrophic. storm'},
                {'text': 'A.catastrophic- storm'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    input_dict = {'text': 'A...catastrophic.... storm'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: '.. -> --',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A--.catastrophic.... storm'},
                {'text': 'A...catastrophic--.. storm'},
                {'text': 'A...catastrophic..-- storm'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Underscore
    input_dict = {'text': 'A catastrophic_storm is raging.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'catastrophic_storm -> nice_storm',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A nice_storm is raging.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Deletion
    input_dict = {'text': 'A storm is raging.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'storm -> ',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A  is raging.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Word next to punctuation and words with punctuation.
    input_dict = {'text': 'It`s raining cats and dogs.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'dogs -> blank',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'It`s raining cats and blank.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Multiple target tokens.
    input_dict = {'text': 'It`s raining cats and dogs.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'dogs -> horses|donkeys',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'It`s raining cats and horses.'},
                {'text': 'It`s raining cats and donkeys.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    ## Test default_replacements applied at init.
    replacements = {'tree': ['car']}
    generator = word_replacer.WordReplacer(replacements=replacements)
    input_dict = {'text': 'black truck hit the tree'}
    expected = [{'text': 'black truck hit the car'}]
    config_dict = {
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }

    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    ## Test not passing replacements not breaking.
    generator = word_replacer.WordReplacer()
    input_dict = {'text': 'xyz yzy zzz.'}
    expected = []

    self.assertEqual(
        generator.generate(input_dict, model, dataset), expected)

    # Multi word match.
    input_dict = {'text': 'A red cat is coming.'}
    config_dict = {
        word_replacer.SUBSTITUTIONS_KEY: 'red cat -> black dog',
        word_replacer.FIELDS_TO_REPLACE_KEY: ['text'],
    }
    expected = [{'text': 'A black dog is coming.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)
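
Two behaviors worth noting from the repetition and multi-target cases above: the generator emits one new example per (occurrence, target) pair rather than substituting every occurrence at once, and matches are whole-word only. A simplified, hypothetical helper illustrating just those two behaviors (it does not reproduce the punctuation tokenization that the real WordReplacer also handles):

import re
from typing import List

def replace_once_per_occurrence(text: str, token: str, targets: List[str],
                                ignore_case: bool = True) -> List[str]:
  # One output string per (occurrence, target) pair; whole-word matches only.
  flags = re.IGNORECASE if ignore_case else 0
  pattern = re.compile(r'(?<!\w)' + re.escape(token) + r'(?!\w)', flags)
  outputs = []
  for match in pattern.finditer(text):
    for target in targets:
      outputs.append(text[:match.start()] + target + text[match.end():])
  return outputs

print(replace_once_per_occurrence(
    'maybe repetition repetition maybe', 'repetition', ['blank']))
# ['maybe blank repetition maybe', 'maybe repetition blank maybe']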
Example #8
  def __init__(
      self,
      models: Mapping[Text, lit_model.Model],
      datasets: MutableMapping[Text, lit_dataset.Dataset],
      generators: Optional[Mapping[Text, lit_components.Generator]] = None,
      interpreters: Optional[Mapping[Text, lit_components.Interpreter]] = None,
      # General server config; see server_flags.py.
      data_dir: Optional[Text] = None,
      warm_start: float = 0.0,
      warm_projections: bool = False,
      client_root: Optional[Text] = None,
      demo_mode: bool = False,
      default_layout: Optional[str] = None,
      canonical_url: Optional[str] = None,
  ):
    if client_root is None:
      raise ValueError('client_root must be set on application')
    self._demo_mode = demo_mode
    self._default_layout = default_layout
    self._canonical_url = canonical_url
    if data_dir and not os.path.isdir(data_dir):
      os.mkdir(data_dir)
    self._models = {
        name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
        for name, model in models.items()
    }
    self._datasets = datasets
    self._datasets['_union_empty'] = NoneDataset(self._models)
    if generators is not None:
      self._generators = generators
    else:
      self._generators = {
          'scrambler': scrambler.Scrambler(),
          'word_replacer': word_replacer.WordReplacer(),
      }

    if interpreters is not None:
      self._interpreters = interpreters
    else:
      metrics_group = lit_components.ComponentGroup({
          'regression': metrics.RegressionMetrics(),
          'multiclass': metrics.MulticlassMetrics(),
          'paired': metrics.MulticlassPairedMetrics(),
          'bleu': metrics.CorpusBLEU(),
      })
      self._interpreters = {
          'grad_norm': gradient_maps.GradientNorm(),
          'lime': lime_explainer.LIME(),
          'grad_dot_input': gradient_maps.GradientDotInput(),
          'integrated gradients': gradient_maps.IntegratedGradients(),
          'counterfactual explainer': lemon_explainer.LEMON(),
          'metrics': metrics_group,
          # Embedding projectors expose a standard interface, but get special
          # handling so we can precompute the projections if requested.
          'pca': projection.ProjectionManager(pca.PCAModel),
          'umap': projection.ProjectionManager(umap.UmapModel),
      }

    # Information on models and datasets.
    self._build_metadata()

    # Optionally, run models to pre-populate cache.
    if warm_projections:
      logging.info(
          'Projection (dimensionality reduction) warm-start requested; '
          'will do full warm-start for all models since predictions are needed.'
      )
      warm_start = 1.0

    if warm_start > 0:
      self._warm_start(rate=warm_start)
      self.save_cache()

    # If you add a new embedding projector that should be warm-started,
    # also add it to the list here.
    # TODO(lit-dev): add some registry mechanism / automation if this grows to
    # more than 2-3 projection types.
    if warm_projections:
      self._warm_projections(['pca', 'umap'])

    handlers = {
        # Metadata endpoints.
        '/get_info': self._get_info,
        # Dataset-related endpoints.
        '/get_dataset': self._get_dataset,
        '/get_generated': self._get_generated,
        '/save_datapoints': self._save_datapoints,
        '/load_datapoints': self._load_datapoints,
        '/get_datapoint_ids': self._get_datapoint_ids,
        # Model prediction endpoints.
        '/get_preds': self._get_preds,
        '/get_interpretations': self._get_interpretations,
    }

    self._wsgi_app = wsgi_app.App(
        # Wrap endpoint fns to take (handler, request)
        handlers={k: make_handler(v) for k, v in handlers.items()},
        project_root=client_root,
        index_file='static/index.html',
    )
Example #9
  def test_all_replacements(self):
    input_spec = {'text': lit_types.TextSegment()}
    model = testing_utils.TestRegressionModel(input_spec)
    # Dataset is only used for spec in word_replacer so define once
    dataset = lit_dataset.Dataset(input_spec, [{'text': 'blank'}])

    ## Test replacements
    generator = word_replacer.WordReplacer()
    # Unicode to Unicode
    input_dict = {'text': '♞ is a black chess knight.'}
    config_dict = {'subs': '♞ -> ♟'}
    expected = [{'text': '♟ is a black chess knight.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Unicode to ASCII
    input_dict = {'text': 'Is répertoire a unicode word?'}
    config_dict = {'subs': 'répertoire -> repertoire'}
    expected = [{'text': 'Is repertoire a unicode word?'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Capitalization
    input_dict = {'text': 'Capitalization is important.'}
    config_dict = {'subs': 'Capitalization -> blank'}
    expected = [{'text': 'blank is important.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    input_dict = {'text': 'Capitalization is important.'}
    config_dict = {'subs': 'capitalization -> blank'}
    expected = []
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Repetition
    input_dict = {'text': 'maybe repetition repetition maybe'}
    config_dict = {'subs': 'repetition -> blank'}
    expected = [{'text': 'maybe blank repetition maybe'},
                {'text': 'maybe repetition blank maybe'}]
    self.assertCountEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # No partial match
    input_dict = {'text': 'A catastrophic storm'}
    config_dict = {'subs': 'cat -> blank'}
    expected = []
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    ## Special characters
    # Punctuation
    input_dict = {'text': 'A catastrophic storm .'}
    config_dict = {'subs': '. -> -'}
    expected = [{'text': 'A catastrophic storm -'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Underscore
    input_dict = {'text': 'A catastrophic_storm is raging.'}
    config_dict = {'subs': 'catastrophic_storm -> nice_storm'}
    expected = [{'text': 'A nice_storm is raging.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    # Word next to punctuation and words with punctuation.
    input_dict = {'text': 'It`s raining cats and dogs.'}
    config_dict = {'subs': 'dogs -> blank'}
    expected = [{'text': 'It`s raining cats and blank.'}]
    self.assertEqual(
        generator.generate(input_dict, model, dataset, config=config_dict),
        expected)

    ## Test default_replacements applied at init.
    replacements = {'tree': 'car'}
    generator = word_replacer.WordReplacer(replacements=replacements)
    input_dict = {'text': 'black truck hit the tree'}
    expected = [{'text': 'black truck hit the car'}]

    self.assertEqual(
        generator.generate(input_dict, model, dataset), expected)

    ## Test not passing replacements not breaking.
    generator = word_replacer.WordReplacer()
    input_dict = {'text': 'xyz yzy zzz.'}
    expected = []

    self.assertEqual(
        generator.generate(input_dict, model, dataset), expected)
Example #10
    def __init__(
        self,
        models: Mapping[Text, lit_model.Model],
        datasets: Mapping[Text, lit_dataset.Dataset],
        generators: Optional[Mapping[Text, lit_components.Generator]] = None,
        interpreters: Optional[Mapping[Text,
                                       lit_components.Interpreter]] = None,
        annotators: Optional[List[lit_components.Annotator]] = None,
        layouts: Optional[dtypes.LitComponentLayouts] = None,
        # General server config; see server_flags.py.
        data_dir: Optional[Text] = None,
        warm_start: float = 0.0,
        warm_projections: bool = False,
        client_root: Optional[Text] = None,
        demo_mode: bool = False,
        default_layout: Optional[str] = None,
        canonical_url: Optional[str] = None,
        page_title: Optional[str] = None,
        development_demo: bool = False,
    ):
        if client_root is None:
            raise ValueError('client_root must be set on application')
        self._demo_mode = demo_mode
        self._development_demo = development_demo
        self._default_layout = default_layout
        self._canonical_url = canonical_url
        self._page_title = page_title
        self._data_dir = data_dir
        self._layouts = layouts or {}
        if data_dir and not os.path.isdir(data_dir):
            os.mkdir(data_dir)

        # Wrap models in caching wrapper
        self._models = {
            name: caching.CachingModelWrapper(model, name, cache_dir=data_dir)
            for name, model in models.items()
        }

        self._datasets = dict(datasets)
        self._datasets['_union_empty'] = lit_dataset.NoneDataset(self._models)

        self._annotators = annotators or []

        # Run annotation on each dataset, creating an annotated dataset and
        # replace the datasets with the annotated versions.
        for ds_key, ds in self._datasets.items():
            self._datasets[ds_key] = self._run_annotators(ds)

        # Index all datasets
        self._datasets = lit_dataset.IndexedDataset.index_all(
            self._datasets, caching.input_hash)

        if generators is not None:
            self._generators = generators
        else:
            self._generators = {
                'Ablation Flip': ablation_flip.AblationFlip(),
                'Hotflip': hotflip.HotFlip(),
                'Scrambler': scrambler.Scrambler(),
                'Word Replacer': word_replacer.WordReplacer(),
            }

        if interpreters is not None:
            self._interpreters = interpreters
        else:
            metrics_group = lit_components.ComponentGroup({
                'regression': metrics.RegressionMetrics(),
                'multiclass': metrics.MulticlassMetrics(),
                'paired': metrics.MulticlassPairedMetrics(),
                'bleu': metrics.CorpusBLEU(),
            })
            self._interpreters = {
                'Grad L2 Norm': gradient_maps.GradientNorm(),
                'Grad ⋅ Input': gradient_maps.GradientDotInput(),
                'Integrated Gradients': gradient_maps.IntegratedGradients(),
                'LIME': lime_explainer.LIME(),
                'Model-provided salience': model_salience.ModelSalience(
                    self._models),
                'counterfactual explainer': lemon_explainer.LEMON(),
                'tcav': tcav.TCAV(),
                'thresholder': thresholder.Thresholder(),
                'nearest neighbors': nearest_neighbors.NearestNeighbors(),
                'metrics': metrics_group,
                'pdp': pdp.PdpInterpreter(),
                # Embedding projectors expose a standard interface, but get special
                # handling so we can precompute the projections if requested.
                'pca': projection.ProjectionManager(pca.PCAModel),
                'umap': projection.ProjectionManager(umap.UmapModel),
            }

        # Information on models, datasets, and other components.
        self._info = self._build_metadata()

        # Optionally, run models to pre-populate cache.
        if warm_projections:
            logging.info(
                'Projection (dimensionality reduction) warm-start requested; '
                'will do full warm-start for all models since predictions are needed.'
            )
            warm_start = 1.0

        if warm_start > 0:
            self._warm_start(rate=warm_start)
            self.save_cache()

        # If you add a new embedding projector that should be warm-started,
        # also add it to the list here.
        # TODO(lit-dev): add some registry mechanism / automation if this grows to
        # more than 2-3 projection types.
        if warm_projections:
            self._warm_projections(['pca', 'umap'])

        handlers = {
            # Metadata endpoints.
            '/get_info': self._get_info,
            # Dataset-related endpoints.
            '/get_dataset': self._get_dataset,
            '/create_dataset': self._create_dataset,
            '/create_model': self._create_model,
            '/get_generated': self._get_generated,
            '/save_datapoints': self._save_datapoints,
            '/load_datapoints': self._load_datapoints,
            '/annotate_new_data': self._annotate_new_data,
            # Model prediction endpoints.
            '/get_preds': self._get_preds,
            '/get_interpretations': self._get_interpretations,
        }

        self._wsgi_app = wsgi_app.App(
            # Wrap endpoint fns to take (handler, request, environ)
            handlers={k: self.make_handler(v)
                      for k, v in handlers.items()},
            project_root=client_root,
            index_file='static/index.html',
        )