Example #1
def launch(task_type: TaskType, profiler: TextProfiler,
           dataset: Dataset[Union[LabeledTextInstance, ScoredTextInstance]]):
    if task_type == TaskType.CLASSIFICATION:
        if not isinstance(dataset.instances[0], LabeledTextInstance):
            raise ValueError("Inconsistent type between Instance and TaskType")
        examples, labels = _setup_classification_dataset(dataset)
        models = {
            "text_classifier": LITModelForTextClassifier(profiler, labels)
        }
        lit_datasets = {
            "classification_dataset":
            TextClassificationLITDataset(examples, labels)
        }
        lit_demo = dev_server.Server(models, lit_datasets,
                                     **server_flags.get_flags())
        lit_demo.serve()
    elif task_type == TaskType.REGRESSION:
        if not isinstance(dataset.instances[0], ScoredTextInstance):
            raise ValueError("Inconsistent type between Instance and TaskType")
        examples = _setup_regression_dataset(dataset)
        models = {"text_regressor": LITModelForTextRegressor(profiler)}
        lit_datasets = {
            "regression_dataset": TextRegressionLITDataset(examples)
        }
        lit_demo = dev_server.Server(models, lit_datasets,
                                     **server_flags.get_flags())
        lit_demo.serve()
    else:
        raise ValueError(
            f"Unsupported task({task_type}) for launching LIT server")
Example #2
def main(_):
    demo_layout = lit_dtypes.LitComponentLayout(
        components={
            'Main': [
                'data-table-module',
                'datapoint-editor-module',
                'lit-slice-module',
                'color-module',
            ],
            'Predictions': ['classification-module', 'scalar-module'],
            'Explanations': ['classification-module', 'salience-map-module'],
        },
        description='Basic layout for image demo',
    )
    datasets = {'imagenette': imagenette.ImagenetteDataset()}
    models = {'mobilenet': mobilenet.MobileNet()}
    interpreters = {
        'Grad': image_gradient_maps.VanillaGradients(),
        'Integrated Gradients': image_gradient_maps.IntegratedGradients(),
        'Blur IG': image_gradient_maps.BlurIG(),
        'Guided IG': image_gradient_maps.GuidedIG(),
        'XRAI': image_gradient_maps.XRAI(),
        'XRAI GIG': image_gradient_maps.XRAIGIG(),
    }

    lit_demo = dev_server.Server(models,
                                 datasets,
                                 interpreters=interpreters,
                                 generators={},
                                 layouts={'demo_layout': demo_layout},
                                 **server_flags.get_flags())
    return lit_demo.serve()
Example #3
def main():
    model = SentimentClassifierModel()
    models = {"sst": model}
    datasets = {"sst": SSTData(labels=model.labels)}

    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #4
def main(_):
    models = {"gector": GectorBertModel('bert_0_gector.th')}
    datasets = {"test_data": GeceProdigyData('test_sample.jsonl')}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
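
These demos declare main(_) in the absl style, so the module is usually finished with a small entry-point stub that parses flags and hands control to main. A minimal sketch of that wiring, assuming the standard absl imports; the --model_path flag is hypothetical and stands in for whatever flags a given demo actually defines:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Hypothetical flag; each demo defines its own flags (see server_flags.py).
flags.DEFINE_string("model_path", None, "Path to the trained model to load.")


def main(_):
    # Build the models/datasets dicts and start the server, as in the examples.
    ...


if __name__ == "__main__":
    app.run(main)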
Example #5
def run_server(load_path: str):
    """Run a LIT server with the trained coreference model."""
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    if load_path.endswith(".tar.gz"):
        load_path = transformers.file_utils.cached_path(
            load_path, extract_compressed_file=True)
    # Load model from disk.
    full_model = model.FrozenEncoderCoref.from_saved(
        load_path,
        encoder_cls=encoders.BertEncoderWithOffsets,
        classifier_cls=edge_predictor.SingleEdgePredictor)

    # Set up the LIT server.
    models = {"model": full_model}
    datasets = {"winogender": winogender.WinogenderDataset()}
    if FLAGS.ontonotes_edgeprobe_path:
        datasets["ontonotes_dev"] = ontonotes.OntonotesCorefDataset(
            os.path.join(FLAGS.ontonotes_edgeprobe_path, "development.json"))
    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models,
                                 datasets,
                                 layouts=CUSTOM_LAYOUTS,
                                 **server_flags.get_flags())
    return lit_demo.serve()
Example #6
def main(_):
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    model_path = FLAGS.model_path
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(
            model_path, extract_compressed_file=True)

    models = {
        "nli": glue_models.MNLIModel(model_path, inference_batch_size=16)
    }
    datasets = {
        "xnli": classification.XNLIData("validation", FLAGS.languages),
        "mnli_dev": glue.MNLIData("validation_matched"),
        "mnli_dev_mm": glue.MNLIData("validation_mismatched"),
    }

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #7
def main(_):

  models = {}
  datasets = {}

  if "sst2" in FLAGS.tasks:
    models["sst2"] = glue_models.SST2Model(
        os.path.join(FLAGS.models_path, "sst2"))
    datasets["sst_dev"] = glue.SST2Data("validation")
    logging.info("Loaded models and data for SST-2 task.")

  if "stsb" in FLAGS.tasks:
    models["stsb"] = glue_models.STSBModel(
        os.path.join(FLAGS.models_path, "stsb"))
    datasets["stsb_dev"] = glue.STSBData("validation")
    logging.info("Loaded models and data for STS-B task.")

  if "mnli" in FLAGS.tasks:
    models["mnli"] = glue_models.MNLIModel(
        os.path.join(FLAGS.models_path, "mnli"))
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
    logging.info("Loaded models and data for MultiNLI task.")

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
Example #8
    def __init__(self, *args, height=1000, render=False, proxy_url=None, **kw):
        """Start LIT server and optionally render the UI immediately.

        Args:
          *args: Positional arguments for the LitApp.
          height: Height to display the LIT UI in pixels. Defaults to 1000.
          render: Whether to render the UI when this object is constructed.
            Defaults to False.
          proxy_url: Optional proxy URL, if using in a notebook with a server
            proxy. Defaults to None.
          **kw: Keyword arguments for the LitApp.
        """
        app_flags = server_flags.get_flags()
        app_flags['server_type'] = 'notebook'
        app_flags['host'] = 'localhost'
        app_flags['port'] = None
        app_flags.update(kw)

        lit_demo = dev_server.Server(*args, **app_flags)
        self._server = typing.cast(wsgi_serving.NotebookWsgiServer,
                                   lit_demo.serve())
        self._height = height
        self._proxy_url = proxy_url

        if render:
            self.render()
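
A constructor like this is normally used from a notebook cell rather than called directly. A minimal usage sketch, assuming the class above is LIT's notebook widget (exposed as lit_nlp.notebook.LitWidget) and that models and datasets are dicts of LIT Model and Dataset objects like those in the other examples:

from lit_nlp import notebook

# Renders the LIT UI inline in the notebook output, 800 pixels tall.
widget = notebook.LitWidget(models, datasets, height=800)
widget.render()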
Example #9
def main(_):
    # Load the model we defined above.
    models = {"sst": SimpleSentimentModel(FLAGS.model_path)}
    # Load SST-2 validation set from TFDS.
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #10
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Analyzes an NMT model using LIT")
    parser.add_argument("experiment", help="Experiment name")
    parser.add_argument("--memory-growth",
                        default=False,
                        action="store_true",
                        help="Enable memory growth")
    parser.add_argument(
        "--eager-execution",
        default=False,
        action="store_true",
        help="Enable TensorFlow eager execution.",
    )
    parser.add_argument("--checkpoint", type=str, help="Analyze checkpoint")
    parser.add_argument("--last",
                        default=False,
                        action="store_true",
                        help="Analyze last checkpoint")
    parser.add_argument("--best",
                        default=False,
                        action="store_true",
                        help="Analyze best evaluated checkpoint")
    parser.add_argument("--avg",
                        default=False,
                        action="store_true",
                        help="Analyze averaged checkpoint")
    args = parser.parse_args()

    get_git_revision_hash()

    set_tf_log_level()

    if args.eager_execution:
        tf.config.run_functions_eagerly(True)

    if args.memory_growth:
        enable_memory_growth()

    checkpoints: Set[str] = set()
    if args.avg:
        checkpoints.add("avg")
    if args.checkpoint is not None:
        checkpoints.add(args.checkpoint)
    if args.last:
        checkpoints.add("last")
    if args.best:
        checkpoints.add("best")
    if len(checkpoints) == 0:
        checkpoints.add("last")

    server_args, server_kw = create_lit_args(args.experiment, checkpoints)

    lit_server = dev_server.Server(*server_args, **server_kw,
                                   **server_flags.get_flags())
    lit_server.serve()
Example #11
def main(_):
    # Load the model we defined above.
    models = {
        "NCBI BERT Finetuned": NerModel(FLAGS.model_path,
                                        labels_file=FLAGS.labels)
    }
    datasets = {"I2b2 2014": I2b2Dataset(data_dir=FLAGS.test_data_dir)}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #12
def main(_):
    # Quick-start mode.
    if FLAGS.quickstart:
        FLAGS.models = QUICK_START_MODELS  # smaller, faster models
        if FLAGS.max_examples is None or FLAGS.max_examples > 1000:
            FLAGS.max_examples = 1000  # truncate larger eval sets
        logging.info(
            "Quick-start mode; overriding --models and --max_examples.")

    models = {}
    datasets = {}

    tasks_to_load = set()
    for model_string in FLAGS.models:
        # Only split on the first two ':', because path may be a URL
        # containing 'https://'
        name, task, path = model_string.split(":", 2)
        logging.info("Loading model '%s' for task '%s' from '%s'", name, task,
                     path)
        # Normally path is a directory; if it's an archive file, download and
        # extract to the transformers cache.
        if path.endswith(".tar.gz"):
            path = transformers.file_utils.cached_path(
                path, extract_compressed_file=True)
        # Load the model from disk.
        models[name] = MODELS_BY_TASK[task](path)
        tasks_to_load.add(task)

    ##
    # Load datasets for each task that we have a model for
    if "sst2" in tasks_to_load:
        logging.info("Loading data for SST-2 task.")
        datasets["sst_dev"] = glue.SST2Data("validation")

    if "stsb" in tasks_to_load:
        logging.info("Loading data for STS-B task.")
        datasets["stsb_dev"] = glue.STSBData("validation")

    if "mnli" in tasks_to_load:
        logging.info("Loading data for MultiNLI task.")
        datasets["mnli_dev"] = glue.MNLIData("validation_matched")
        datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    return lit_demo.serve()
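
The --models flag above encodes each model as name:task:path, which is why the code splits on at most two ':' characters: the path component may itself contain "://". A small sketch of that parsing with a hypothetical flag value:

# Hypothetical entry in the name:task:path format expected by the loop above.
model_string = "sst2-base:sst2:https://storage.example.com/models/sst2_base.tar.gz"
name, task, path = model_string.split(":", 2)
# name == "sst2-base", task == "sst2",
# path == "https://storage.example.com/models/sst2_base.tar.gz"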
Example #13
def main(_):
    model_path = FLAGS.model_path or tempfile.mkdtemp()
    logging.info("Working directory: %s", model_path)
    run_finetuning(model_path)

    # Load our trained model.
    models = {"sst": glue_models.SST2Model(model_path)}
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #14
def main(_):
    model_path = FLAGS.model_path

    models = {'species classifier': penguin_model.PenguinModel(model_path)}
    datasets = {'penguins': penguin_data.PenguinDataset()}
    generators = {
        'Minimal Targeted Counterfactuals':
        minimal_targeted_counterfactuals.TabularMTC()
    }
    lit_demo = dev_server.Server(models,
                                 datasets,
                                 generators=generators,
                                 **server_flags.get_flags())
    return lit_demo.serve()
Example #15
def main(_):

  ##
  # Load models, according to the --models flag.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    if model_name.startswith("bert-"):
      models[model_name] = pretrained_lms.BertMLM(
          model_name_or_path, top_k=FLAGS.top_k)
    elif model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
      models[model_name] = pretrained_lms.GPT2LanguageModel(
          model_name_or_path, top_k=FLAGS.top_k)
    else:
      raise ValueError(
          f"Unsupported model name '{model_name}' from path '{model_name_or_path}'"
      )

  datasets = {
      # Single sentences from movie reviews (SST dev set).
      "sst_dev": glue.SST2Data("validation").remap({"sentence": "text"}),
      # Longer passages from movie reviews (IMDB dataset, test split).
      "imdb_train": classification.IMDBData("test"),
      # Empty dataset, if you just want to type sentences into the UI.
      "blank": lm.PlaintextSents(""),
  }
  # Guard this with a flag, because TFDS will download and process 1.67 GB
  # of data if you haven't loaded `lm1b` before.
  if FLAGS.load_bwb:
    # A few sentences from the Billion Word Benchmark (Chelba et al. 2013).
    datasets["bwb"] = lm.BillionWordBenchmark(
        "train", max_examples=FLAGS.max_examples)

  for name in datasets:
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))

  generators = {"word_replacer": word_replacer.WordReplacer()}

  lit_demo = dev_server.Server(
      models,
      datasets,
      generators=generators,
      layouts=CUSTOM_LAYOUTS,
      **server_flags.get_flags())
  return lit_demo.serve()
Example #16
def main(_):
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    model_path = FLAGS.model_path
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(
            model_path, extract_compressed_file=True)

    # Load the model we defined above.
    models = {"sst": SimpleSentimentModel(model_path)}
    # Load SST-2 validation set from TFDS.
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #17
def main(_):
    model_path = FLAGS.model_path
    logging.info("Working directory: %s", model_path)

    # Load our trained model.
    models = {"toxicity": glue_models.ToxicityModel(model_path)}
    datasets = {"toxicity_test": classification.ToxicityData("test")}

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    return lit_demo.serve()
Example #18
def main(_):

    models = {
        "nli": glue_models.MNLIModel(FLAGS.model_path, inference_batch_size=16)
    }
    datasets = {
        "xnli": classification.XNLIData("validation", FLAGS.languages),
        "mnli_dev": glue.MNLIData("validation_matched"),
        "mnli_dev_mm": glue.MNLIData("validation_mismatched"),
    }

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
Example #19
def main(_):
  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(FLAGS.model_path)}
  # Load SST-2 validation set from TFDS.
  #datasets = {"sst_dev": glue.SST2Data("validation")}
    
     ## RATING
  #path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/aws_data_raw/dataset_en_test.json"
  #datasets = {"amz_rev": AMAZONes_5cat(path)}
  
     ## CATEGORY AMAZON
  #path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/aws_data_final/test/amazonEN_test.csv"
  #datasets = {"amz_rev": AMAZONes_31cat(path)}
  
      ## CATEGORY CHATBOT
  path = "/Users/lucialarraona/Desktop/Codigo TFG ordenado/01. Data/chatbot_empresa/dato_chatbot_procesado/chatbot_testEN.csv"
  datasets = {"amz_rev": CHATBOTes_30cat(path)}


  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
Example #20
  def __init__(self,
               *args,
               height=1000,
               render=False,
               proxy_url=None,
               layouts: Optional[dtypes.LitComponentLayouts] = None,
               **kw):
    """Start LIT server and optionally render the UI immediately.

    Args:
      *args: Positional arguments for the LitApp.
      height: Height to display the LIT UI in pixels. Defaults to 1000.
      render: Whether to render the UI when this object is constructed. Defaults
        to False.
      proxy_url: Optional proxy URL, if using in a notebook with a server proxy.
        Defaults to None.
      layouts: Optional custom UI layouts. TODO(lit-dev): support simple module
        lists here as well.
      **kw: Keyword arguments for the LitApp.
    """
    app_flags = server_flags.get_flags()
    app_flags['server_type'] = 'notebook'
    app_flags['host'] = 'localhost'
    app_flags['port'] = None
    app_flags['warm_start'] = 1
    layouts = dict(layouts or {})
    if 'notebook' not in layouts:
      layouts['notebook'] = LIT_NOTEBOOK_LAYOUT
    # This will be 'notebook' unless custom layouts are also given in Python.
    app_flags['default_layout'] = list(layouts.keys())[0]
    app_flags.update(kw)

    lit_demo = dev_server.Server(*args, layouts=layouts, **app_flags)
    self._server = cast(wsgi_serving.NotebookWsgiServer, lit_demo.serve())
    self._height = height
    self._proxy_url = proxy_url

    if render:
      self.render()
Example #21
def main(_):
    ##
    # Load models. You can specify several here, if you want to compare different
    # models side-by-side, and can also include models of different types that use
    # different datasets.
    base_models = {}
    for model_name_or_path in FLAGS.models:
        # Ignore path prefix, if using /path/to/<model_name> to load from a
        # specific directory rather than the default shortcut.
        model_name = os.path.basename(model_name_or_path)
        if model_name_or_path.startswith("SavedModel"):
            saved_model_path = model_name_or_path.split(":", 1)[1]
            base_models[model_name] = t5.T5SavedModel(saved_model_path)
        else:
            # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
            # tokens in a long document can get very, very large. Re-enable once we
            # can send this to the frontend more efficiently.
            base_models[model_name] = t5.T5HFModel(
                model_name=model_name_or_path,
                num_to_generate=FLAGS.num_to_generate,
                token_top_k=FLAGS.token_top_k,
                output_attention=False)

    ##
    # Load eval sets and model wrappers for each task.
    # Model wrappers share the same in-memory T5 model, but add task-specific pre-
    # and post-processing code.
    models = {}
    datasets = {}

    if "summarization" in FLAGS.tasks:
        for k, m in base_models.items():
            models[k + "_summarization"] = t5.SummarizationWrapper(m)
        datasets["CNNDM"] = summarization.CNNDMData(
            split="validation", max_examples=FLAGS.max_examples)

    if "mt" in FLAGS.tasks:
        for k, m in base_models.items():
            models[k + "_translation"] = t5.TranslationWrapper(m)
        datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
        datasets["wmt14_ende"] = mt.WMT14Data(version="de-en", reverse=True)

    # Truncate datasets if --max_examples is set.
    for name in datasets:
        logging.info("Dataset: '%s' with %d examples", name,
                     len(datasets[name]))
        datasets[name] = datasets[name].slice[:FLAGS.max_examples]
        logging.info("  truncated to %d examples", len(datasets[name]))

    ##
    # We can also add custom components. Generators are used to create new
    # examples by perturbing or modifying existing ones.
    generators = {
        # Word-substitution, like "great" -> "terrible"
        "word_replacer": word_replacer.WordReplacer(),
    }

    if FLAGS.use_indexer:
        indexer = build_indexer(models)
        # Wrap the indexer into a Generator component that we can query.
        generators[
            "similarity_searcher"] = similarity_searcher.SimilaritySearcher(
                indexer=indexer)

    ##
    # Actually start the LIT server, using the models, datasets, and other
    # components constructed above.
    lit_demo = dev_server.Server(models,
                                 datasets,
                                 generators=generators,
                                 **server_flags.get_flags())
    return lit_demo.serve()
Example #22
def main(_):
  ##
  # Load models. You can specify several here, if you want to compare different
  # models side-by-side, and can also include models of different types that use
  # different datasets.
  models = {}
  for model_name_or_path in FLAGS.models:
    # Ignore path prefix, if using /path/to/<model_name> to load from a
    # specific directory rather than the default shortcut.
    model_name = os.path.basename(model_name_or_path)
    # TODO(lit-dev): attention is temporarily disabled, because O(n^2) between
    # tokens in a long document can get very, very large. Re-enable once we can
    # send this to the frontend more efficiently.
    models[model_name] = t5.T5GenerationModel(
        model_name=model_name_or_path,
        input_prefix="summarize: ",
        top_k=FLAGS.top_k,
        output_attention=False)

  ##
  # Load datasets. Typically you'll have the constructor actually read the
  # examples and do any pre-processing, so that they're in memory and ready to
  # send to the frontend when you open the web UI.
  datasets = {
      "CNNDM":
          summarization.CNNDMData(
              split="validation", max_examples=FLAGS.max_examples),
  }
  for name, ds in datasets.items():
    logging.info("Dataset: '%s' with %d examples", name, len(ds))

  ##
  # We can also add custom components. Generators are used to create new
  # examples by perturbing or modifying existing ones.
  generators = {
      # Word-substitution, like "great" -> "terrible"
      "word_replacer": word_replacer.WordReplacer(),
  }

  if FLAGS.use_indexer:
    assert FLAGS.data_dir, "--data_dir must be set to use the indexer."
    # Datasets for indexer - this one loads the training corpus instead of val.
    index_datasets = {
        "CNNDM":
            summarization.CNNDMData(
                split="train", max_examples=FLAGS.max_index_examples),
    }
    # Set up the Indexer, building index if necessary (this may be slow).
    indexer = index.Indexer(
        datasets=index_datasets,
        models=models,
        data_dir=FLAGS.data_dir,
        initialize_new_indices=FLAGS.initialize_index)

    # Wrap the indexer into a Generator component that we can query.
    generators["similarity_searcher"] = similarity_searcher.SimilaritySearcher(
        indexer=indexer)

  ##
  # Actually start the LIT server, using the models, datasets, and other
  # components constructed above.
  lit_demo = dev_server.Server(
      models, datasets, generators=generators, **server_flags.get_flags())
  lit_demo.serve()