def _get_dataset(self, unused_data, dataset_name: Text = None): """Attempt to get dataset, or override with a specific path.""" # TODO(lit-team): add functionality to load data from a given path, as # passed from the frontend? assert dataset_name is not None, 'No dataset specified.' # TODO(lit-team): possibly allow IDs from persisted dataset. return caching.add_hashes_to_input(self._datasets[dataset_name].examples)
def _get_embedding(self, model, example, embedding_name, dataset_name):
    """Run `model` on a single example and return one field of its output.

    Args:
        model: Object exposing `predict_with_metadata`.
        example: One input example; it is wrapped in a one-element list
            before hashing and prediction.
        embedding_name: Key read from each model output.
        dataset_name: Forwarded to `predict_with_metadata`.

    Returns:
        `output[embedding_name]` for the first output the model yields.

    Raises:
        IndexError: If the model yields no outputs.
        Whatever `output[embedding_name]` raises (KeyError for a dict) if any
        output lacks the key — every output is indexed, not only the first.
    """
    # TODO(b/158626879): no longer need the add_hashes call.
    hashed_batch = caching.add_hashes_to_input([example])
    predictions = model.predict_with_metadata(
        hashed_batch, dataset_name=dataset_name)
    # Consume and index every prediction eagerly, then keep only the first.
    embeddings = []
    for prediction in predictions:
        embeddings.append(prediction[embedding_name])
    return embeddings[0]
def _train_instance(self, model: lit_model.Model,
                    dataset: lit_dataset.Dataset,
                    config: Dict[Text, Any],
                    name: Text) -> ProjectionInterpreter:
    """Build a ProjectionInterpreter from a model's predictions on a dataset.

    Steps, in order:
      1. Instantiate a projector via `self._model_factory(**config["proj_kw"])`
         (an empty kwargs dict if "proj_kw" is absent).
      2. Add hashes to `dataset.examples`.
      3. Run `model.predict_with_metadata` over those inputs and materialize
         the outputs as a list.
      4. Wrap model, inputs, outputs and projector in a ProjectionInterpreter.

    NOTE(review): any fitting of the projector presumably happens inside
    ProjectionInterpreter — not visible here.

    Args:
        model: Model whose outputs are projected.
        dataset: Dataset whose `examples` are fed to the model.
        config: Reads optional "proj_kw" and "dataset_name" and required
            "field_name" (KeyError if missing).
        name: Forwarded to ProjectionInterpreter as `name`.

    Returns:
        The newly constructed ProjectionInterpreter.
    """
    factory_kwargs = config.get("proj_kw", {})
    # Ignore pytype warning about abstract methods, since this should always
    # be a subclass of ProjectorModel which has these implemented.
    projector = self._model_factory(**factory_kwargs)  # pytype: disable=not-instantiable

    # TODO(lit-dev): recomputing hashes here is a bit wasteful - consider
    # creating an 'IndexedDataset' class in the server, and passing that
    # around so that components can access IndexedInputs directly.
    hashed_examples = caching.add_hashes_to_input(dataset.examples)

    # TODO(lit-dev): remove 'dataset_name' from caching logic so we don't need
    # to track it here or elsewhere.
    source_name = config.get("dataset_name")
    prediction_stream = model.predict_with_metadata(
        hashed_examples, dataset_name=source_name)
    materialized_outputs = list(prediction_stream)

    logging.info("Creating new projection instance on %d points",
                 len(hashed_examples))
    target_field = config["field_name"]
    return ProjectionInterpreter(
        model,
        hashed_examples,
        materialized_outputs,
        projector=projector,
        field_name=target_field,
        name=name,
    )
def _get_dataset(self, dataset_name: Text = None): """Convert examples into ones to be used by the model (adding hashes).""" assert dataset_name is not None, "No dataset specified." return caching.add_hashes_to_input( self._datasets[dataset_name].examples)