def scann(k, num_reordering_candidates=None):
    """Returns brute-force-like ScaNN for testing.

    The layer is configured with a single leaf that is also the only leaf
    searched (`num_leaves=1`, `num_leaves_to_search=1`).

    Args:
      k: Forwarded unchanged as the ScaNN layer's `k` argument.
      num_reordering_candidates: Forwarded as the ScaNN layer's
        `num_reordering_candidates` argument. When omitted, the value of the
        name `num_candidates` from an enclosing/module scope is used, which
        is what this helper previously read implicitly.

    Returns:
      A `factorized_top_k.ScaNN` instance.

    Raises:
      NameError: if `num_reordering_candidates` is omitted and no
        `num_candidates` name is visible from this function.
    """
    if num_reordering_candidates is None:
        # Backward compatibility: the original helper silently depended on a
        # free variable `num_candidates`; callers can now pass it explicitly.
        num_reordering_candidates = num_candidates
    return factorized_top_k.ScaNN(
        k=k,
        num_leaves=1,
        num_leaves_to_search=1,
        num_reordering_candidates=num_reordering_candidates)
Example #2
0
    def test_scann(self, identifier_dtype):
        # Checks that ScaNN results are identical before and after a
        # SavedModel save/load round trip.
        num_candidates = 1000
        num_queries = 4

        rng = np.random.RandomState(42)
        candidates = rng.normal(size=(num_candidates, 4)).astype(np.float32)
        query = rng.normal(size=(num_queries, 4)).astype(np.float32)
        candidate_names = np.arange(num_candidates).astype(identifier_dtype)

        index = factorized_top_k.ScaNN()
        index.index(candidates, candidate_names)

        # Only the first two queries are issued; each model is called
        # repeatedly and the result of the final call is kept.
        batch = query[:2]
        repeats = 100

        for _ in range(repeats):
            expected = index(batch)

        save_path = os.path.join(self.get_temp_dir(), "query_model")
        save_options = tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
        index.save(save_path, options=save_options)
        restored = tf.keras.models.load_model(save_path)

        for _ in range(repeats):
            actual = restored(tf.constant(batch))

        self.assertAllEqual(actual, expected)
Example #3
0
    def test_scann_dataset_arg_no_identifiers(self):
        """Indexing ScaNN from a batched dataset without identifiers works.

        This is a smoke test: it passes as long as `index` does not raise.
        """
        num_candidates = 100

        # Seeded, like the other ScaNN tests, so the generated candidates (and
        # therefore any failure) are reproducible across runs.
        rng = np.random.RandomState(42)
        candidates = tf.data.Dataset.from_tensor_slices(
            rng.normal(size=(num_candidates, 4)).astype(np.float32))

        index = factorized_top_k.ScaNN()
        index.index(candidates.batch(100))
    def test_scann_dataset_arg_no_identifiers(self):
        # A ScaNN index built from a dataset of vectors only (no identifiers)
        # is checked with the shared save-and-restore helper.
        num_candidates = 100
        num_queries = 4
        dim = 4

        rng = np.random.RandomState(42)
        candidate_vectors = rng.normal(size=(num_candidates, dim)).astype(np.float32)
        candidates = tf.data.Dataset.from_tensor_slices(candidate_vectors)
        query = rng.normal(size=(num_queries, dim)).astype(np.float32)

        index = factorized_top_k.ScaNN()
        index.index_from_dataset(candidates.batch(100))

        self.run_save_and_restore_test(index, query, 100)
    def test_scann(self, identifier_dtype):
        # A ScaNN index over explicitly named candidates is checked with the
        # shared save-and-restore helper.
        num_candidates = 1000
        num_queries = 4
        dim = 4

        rng = np.random.RandomState(42)
        candidates = rng.normal(size=(num_candidates, dim)).astype(np.float32)
        query = rng.normal(size=(num_queries, dim)).astype(np.float32)
        names = np.arange(num_candidates).astype(identifier_dtype)

        index = factorized_top_k.ScaNN()
        index.index(candidates, names)

        self.run_save_and_restore_test(index, query, 100)