예제 #1
0
파일: app.py 프로젝트: zhiyiZeng/lit
  def _get_dataset(self, unused_data, dataset_name: Text = None):
    """Attempt to get dataset, or override with a specific path."""

    # TODO(lit-team): add functionality to load data from a given path, as
    # passed from the frontend?
    assert dataset_name is not None, 'No dataset specified.'
    # TODO(lit-team): possibly allow IDs from persisted dataset.
    return caching.add_hashes_to_input(self._datasets[dataset_name].examples)
예제 #2
0
 def _get_embedding(self, model, example, embedding_name, dataset_name):
     """Run the model on one example and return a single output field.

     Args:
       model: object providing `predict_with_metadata`.
       example: the input to run; it is wrapped in a one-element list.
       embedding_name: key read from each model output.
       dataset_name: forwarded to `predict_with_metadata`.

     Returns:
       The `embedding_name` entry of the first model output.
     """
     # TODO(b/158626879): no longer need the add_hashes call.
     hashed_inputs = caching.add_hashes_to_input([example])
     predictions = model.predict_with_metadata(
         hashed_inputs, dataset_name=dataset_name)
     # Extract the field from every output first, then take the leading one.
     embeddings = [prediction[embedding_name] for prediction in predictions]
     return embeddings[0]
예제 #3
0
 def _train_instance(self, model: lit_model.Model,
                     dataset: lit_dataset.Dataset, config: Dict[Text, Any],
                     name: Text) -> ProjectionInterpreter:
   """Build a ProjectionInterpreter from a dataset's model outputs.

   Args:
     model: model whose `predict_with_metadata` is run over the dataset.
     dataset: dataset whose `examples` are used as the training inputs.
     config: settings dict. Reads the optional "proj_kw" (keyword arguments
       for the projector factory) and "dataset_name" keys, and the required
       "field_name" key.
     name: passed through to the ProjectionInterpreter.

   Returns:
     A new ProjectionInterpreter over the hashed inputs and their outputs.

   Raises:
     KeyError: if `config` has no "field_name" entry.
   """
   # Ignore pytype warning about abstract methods, since this should always
   # be a subclass of ProjectorModel which has these implemented.
   projector_kwargs = config.get("proj_kw", {})
   projector = self._model_factory(**projector_kwargs)  # pytype: disable=not-instantiable
   # TODO(lit-dev): recomputing hashes here is a bit wasteful - consider
   # creating an 'IndexedDataset' class in the server, and passing that
   # around so that components can access IndexedInputs directly.
   hashed_inputs = caching.add_hashes_to_input(dataset.examples)
   # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
   # to track it here or elsewhere.
   predictions = model.predict_with_metadata(
       hashed_inputs, dataset_name=config.get("dataset_name"))
   # Materialize the predictions before handing them to the interpreter.
   model_outputs = list(predictions)
   num_points = len(hashed_inputs)
   logging.info("Creating new projection instance on %d points", num_points)
   return ProjectionInterpreter(
       model,
       hashed_inputs,
       model_outputs,
       projector=projector,
       field_name=config["field_name"],
       name=name)
예제 #4
0
 def _get_dataset(self, dataset_name: Text = None):
     """Convert examples into ones to be used by the model (adding hashes)."""
     assert dataset_name is not None, "No dataset specified."
     return caching.add_hashes_to_input(
         self._datasets[dataset_name].examples)