def call(
    self,
    queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
    k: Optional[int] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Query the index.

    Args:
      queries: Query features. If `query_model` was provided in the
        constructor, these can be raw query features that will be processed
        by the query model before performing retrieval. If `query_model` was
        not provided, these should be pre-computed query embeddings.
      k: The number of candidates to retrieve. Defaults to constructor `k`
        parameter if not supplied.

    Returns:
      Tuple of (top candidate scores, top candidate identifiers). Rank-2
      queries are searched with the searcher's batched API (parallel or not,
      per `self._parallelize_batch_searches`); rank-1 queries with its
      single-query API. If `self._identifiers` is None, the raw candidate
      indices are returned in place of identifiers.

    Raises:
      ValueError if `index` has not been called.
      ValueError if `queries` is not a tensor (after being passed through the
        query model) or is not of rank 1 or 2 (including unknown rank).
    """
    k = k if k is not None else self._k

    if self._serialized_searcher is None:
        raise ValueError("The `index` method must be called first to "
                         "create the retrieval index.")

    # A searcher is constructed from the serialized module on every call.
    searcher = scann_ops.searcher_from_module(self._serialized_searcher,
                                              self._candidates)

    if self.query_model is not None:
        queries = self.query_model(queries)

    if not isinstance(queries, tf.Tensor):
        raise ValueError(f"Queries must be a tensor, got {type(queries)}.")

    # `shape.rank` is None for unknown-rank tensors. Unlike `len(shape)`, it
    # does not raise, so such inputs reach the descriptive error below.
    rank = queries.shape.rank
    if rank == 2:
        if self._parallelize_batch_searches:
            result = searcher.search_batched_parallel(
                queries, final_num_neighbors=k)
        else:
            result = searcher.search_batched(queries, final_num_neighbors=k)
        indices = result.indices
        distances = result.distances
    elif rank == 1:
        result = searcher.search(queries, final_num_neighbors=k)
        indices = result.index
        distances = result.distance
    else:
        raise ValueError(f"Queries must be of rank 2 or 1, got {rank}.")

    if self._identifiers is None:
        # No identifiers to map to: hand back the raw candidate indices
        # instead of failing inside `tf.gather(None, ...)`.
        return distances, indices

    return distances, tf.gather(self._identifiers, indices)
def call(
    self,
    queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
    k: Optional[int] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Query the index.

    Args:
      queries: Query features. If `query_model` is set, these are passed
        through it first (so they may be raw features). Otherwise they should
        already be query embeddings.
      k: The number of candidates to retrieve. Falls back to `self._k` when
        not supplied.

    Returns:
      Tuple of (top candidate scores, top candidate identifiers). Rank-2
      queries are searched with the searcher's batched API (parallel or not,
      per `self._parallelize_batch_searches`); rank-1 queries with its
      single-query API. If `self._identifiers` is None, the raw candidate
      indices are returned in place of identifiers.

    Raises:
      ValueError if `index` has not been called.
      ValueError if `queries` is not a tensor (after being passed through the
        query model) or is not of rank 1 or 2 (including unknown rank).
    """
    k = k if k is not None else self._k

    if self._serialized_searcher is None:
        raise ValueError("The `index` method must be called first to "
                         "create the retrieval index.")

    # A searcher is constructed from the serialized module on every call.
    searcher = scann_ops.searcher_from_module(self._serialized_searcher)

    if self.query_model is not None:
        queries = self.query_model(queries)

    if not isinstance(queries, tf.Tensor):
        raise ValueError(f"Queries must be a tensor, got {type(queries)}.")

    # `shape.rank` is None for unknown-rank tensors. Unlike `len(shape)`, it
    # does not raise, so such inputs reach the descriptive error below.
    rank = queries.shape.rank
    if rank == 2:
        if self._parallelize_batch_searches:
            result = searcher.search_batched_parallel(
                queries, final_num_neighbors=k)
        else:
            result = searcher.search_batched(queries, final_num_neighbors=k)
        indices = result.indices
        distances = result.distances
    elif rank == 1:
        result = searcher.search(queries, final_num_neighbors=k)
        indices = result.index
        distances = result.distance
    else:
        raise ValueError(f"Queries must be of rank 2 or 1, got {rank}.")

    if self._identifiers is None:
        # No identifiers to map to: return the raw candidate indices.
        return distances, indices

    return distances, tf.gather(self._identifiers, indices)