示例#1
0
    def call(self,
             queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
             k: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
        """Query the index.

        Args:
          queries: Query features. If `query_model` was provided in the
            constructor, these can be raw query features that will be
            processed by the query model before performing retrieval. If
            `query_model` was not provided, these should be pre-computed
            query embeddings.
          k: The number of candidates to retrieve. Defaults to constructor `k`
            parameter if not supplied.

        Returns:
          Tuple of (top candidate scores, top candidate identifiers). A rank-2
          query tensor is searched with the searcher's batched search (parallel
          when `self._parallelize_batch_searches` is set); a rank-1 query tensor
          is searched as a single query. If `self._identifiers` is None, the
          raw searcher indices are returned in place of identifiers.

        Raises:
          ValueError if `index` has not been called.
          ValueError if `queries` is not a tensor (after being passed through
            the query model) or is not of rank 1 or 2.
        """
        k = k if k is not None else self._k

        if self._serialized_searcher is None:
            raise ValueError("The `index` method must be called first to "
                             "create the retrieval index.")

        searcher = scann_ops.searcher_from_module(self._serialized_searcher,
                                                  self._candidates)

        if self.query_model is not None:
            queries = self.query_model(queries)

        if not isinstance(queries, tf.Tensor):
            raise ValueError(f"Queries must be a tensor, got {type(queries)}.")

        # `.rank` is None for a shape of unknown rank (where `len()` would
        # raise a generic error); that case falls through to the explicit
        # rank error below. Computed once instead of once per branch test.
        rank = queries.shape.rank

        if rank == 2:
            if self._parallelize_batch_searches:
                result = searcher.search_batched_parallel(
                    queries, final_num_neighbors=k)
            else:
                result = searcher.search_batched(queries,
                                                 final_num_neighbors=k)
            indices = result.indices
            distances = result.distances
        elif rank == 1:
            # Single-query search exposes singular field names.
            result = searcher.search(queries, final_num_neighbors=k)
            indices = result.index
            distances = result.distance
        else:
            raise ValueError(
                f"Queries must be of rank 2 or 1, got {rank}.")

        # Without stored identifiers, fall back to the searcher's raw indices
        # rather than passing None to tf.gather.
        if self._identifiers is None:
            return distances, indices

        return distances, tf.gather(self._identifiers, indices)
    def call(self,
             queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
             k: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
        """Retrieve the top-`k` candidates for `queries` from the ScaNN index.

        Args:
          queries: A query tensor or, when `query_model` is set, the raw
            query features that are passed through `query_model` first.
          k: How many neighbours to return; `self._k` is used when omitted.

        Returns:
          A `(distances, identifiers)` pair. When `self._identifiers` is None,
          the second element is the raw searcher indices instead.

        Raises:
          ValueError: if `index` has not been called yet, if the (possibly
            transformed) queries are not a `tf.Tensor`, or if their rank is
            neither 1 nor 2.
        """
        num_neighbors = self._k if k is None else k

        serialized = self._serialized_searcher
        if serialized is None:
            raise ValueError("The `index` method must be called first to "
                             "create the retrieval index.")
        searcher = scann_ops.searcher_from_module(serialized)

        query_tensor = (queries if self.query_model is None
                        else self.query_model(queries))
        if not isinstance(query_tensor, tf.Tensor):
            raise ValueError(
                f"Queries must be a tensor, got {type(query_tensor)}.")

        ndims = len(query_tensor.shape)
        if ndims == 1:
            # Single-query search exposes singular field names.
            hit = searcher.search(query_tensor,
                                  final_num_neighbors=num_neighbors)
            indices, distances = hit.index, hit.distance
        elif ndims == 2:
            batch_search = (searcher.search_batched_parallel
                            if self._parallelize_batch_searches
                            else searcher.search_batched)
            hits = batch_search(query_tensor,
                                final_num_neighbors=num_neighbors)
            indices, distances = hits.indices, hits.distances
        else:
            raise ValueError(
                f"Queries must be of rank 2 or 1, got {ndims}.")

        if self._identifiers is not None:
            return distances, tf.gather(self._identifiers, indices)
        return distances, indices