def scann(k):
    """Build a brute-force-like ScaNN layer for use in tests.

    The layer is configured with a single leaf that is always searched, and
    with ``num_candidates`` reordering candidates.

    Note: ``num_candidates`` is not a parameter; it is resolved from the
    surrounding scope at call time.

    Args:
        k: Forwarded unchanged as the ``k`` argument of
            ``factorized_top_k.ScaNN``.

    Returns:
        The newly constructed ``factorized_top_k.ScaNN`` instance.
    """
    layer = factorized_top_k.ScaNN(
        k=k,
        num_leaves=1,
        num_leaves_to_search=1,
        num_reordering_candidates=num_candidates,
    )
    return layer
def test_scann(self, identifier_dtype):
    """Check that ScaNN results are unchanged by a save/load round trip.

    A ScaNN layer is indexed with seeded random candidates, queried, saved
    to a temporary directory, reloaded, and queried again; the results from
    the reloaded model must equal those from the original layer.

    Args:
        identifier_dtype: dtype the integer candidate identifiers are cast
            to before indexing (presumably supplied by a parameterized
            decorator that is not visible here).
    """
    candidate_count = 1000
    query_count = 4
    random_state = np.random.RandomState(42)

    # Candidates are drawn before queries; with a seeded generator the draw
    # order determines which values each array receives.
    embeddings = random_state.normal(
        size=(candidate_count, 4)).astype(np.float32)
    queries = random_state.normal(size=(query_count, 4)).astype(np.float32)
    identifiers = np.arange(candidate_count).astype(identifier_dtype)

    layer = factorized_top_k.ScaNN()
    layer.index(embeddings, identifiers)

    # Only the first two queries are used on both sides of the round trip.
    probe = queries[:2]

    # The layer is called repeatedly; only the last result is compared.
    for _ in range(100):
        results_before = layer(probe)

    export_dir = os.path.join(self.get_temp_dir(), "query_model")
    # The "Scann" op namespace is whitelisted for saving.
    save_options = tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
    layer.save(export_dir, options=save_options)
    restored = tf.keras.models.load_model(export_dir)

    for _ in range(100):
        results_after = restored(tf.constant(probe))

    self.assertAllEqual(results_after, results_before)
def test_scann_dataset_arg_no_identifiers(self):
    """Smoke test: ScaNN can be indexed from a dataset without identifiers.

    The dataset yields only candidate embeddings (no identifier component).
    No assertion is made; the test passes if building the index does not
    raise.
    """
    num_candidates = 100

    # A seeded generator keeps the candidate embeddings reproducible across
    # runs instead of depending on the global NumPy random state.
    rng = np.random.RandomState(42)
    candidates = tf.data.Dataset.from_tensor_slices(
        rng.normal(size=(num_candidates, 4)).astype(np.float32))

    index = factorized_top_k.ScaNN()
    # Dataset inputs are indexed through index_from_dataset(); index() is
    # the entry point for in-memory candidate tensors.
    index.index_from_dataset(candidates.batch(100))
def test_scann_dataset_arg_no_identifiers(self):
    """Index ScaNN from an identifier-less dataset, then run the shared check.

    Candidate embeddings are supplied as a batched ``tf.data.Dataset`` with
    no identifier component. The indexed layer and the query array are then
    handed to ``self.run_save_and_restore_test`` (defined outside this
    block; presumably it verifies results survive a save/load round trip --
    confirm against the helper).
    """
    candidate_count = 100
    query_count = 4
    random_state = np.random.RandomState(42)

    # Candidates are drawn before queries; with a seeded generator the draw
    # order determines which values each array receives.
    candidate_embeddings = random_state.normal(
        size=(candidate_count, 4)).astype(np.float32)
    queries = random_state.normal(size=(query_count, 4)).astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices(candidate_embeddings)

    layer = factorized_top_k.ScaNN()
    layer.index_from_dataset(dataset.batch(100))

    self.run_save_and_restore_test(layer, queries, 100)
def test_scann(self, identifier_dtype):
    """Index ScaNN with explicit identifiers, then run the shared check.

    Seeded random candidates and integer identifiers are indexed into a
    default-constructed ScaNN layer. The layer and the query array are then
    handed to ``self.run_save_and_restore_test`` (defined outside this
    block; presumably it verifies results survive a save/load round trip --
    confirm against the helper).

    Args:
        identifier_dtype: dtype the integer candidate identifiers are cast
            to before indexing (presumably supplied by a parameterized
            decorator that is not visible here).
    """
    candidate_count = 1000
    query_count = 4
    random_state = np.random.RandomState(42)

    # Candidates are drawn before queries; with a seeded generator the draw
    # order determines which values each array receives.
    embeddings = random_state.normal(
        size=(candidate_count, 4)).astype(np.float32)
    queries = random_state.normal(size=(query_count, 4)).astype(np.float32)
    identifiers = np.arange(candidate_count).astype(identifier_dtype)

    layer = factorized_top_k.ScaNN()
    layer.index(embeddings, identifiers)

    self.run_save_and_restore_test(layer, queries, 100)