import numpy as np
import scann


def fit(self, X):
    # Normalize rows in place to unit length so that dot product equals
    # cosine similarity (note: this mutates the caller's array).
    X /= np.linalg.norm(X, axis=1)[:, np.newaxis]
    self.searcher = scann.ScannBuilder(X, 10, "dot_product").tree(
        self.n_leaves, 1, training_sample_size=350000, spherical=True,
        quantize_centroids=True).score_ah(
            self.dims_per_block,
            anisotropic_quantization_threshold=self.avq_threshold).reorder(
                1).create_pybind()
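# A minimal usage sketch (assumed shapes and parameter values, not from
# the source): fit() only builds self.searcher, so retrieval happens
# through the searcher afterwards. Calling the method with an explicit
# "self" via SimpleNamespace keeps the sketch self-contained.
from types import SimpleNamespace

model = SimpleNamespace(n_leaves=100, dims_per_block=2, avq_threshold=0.2)
X = np.random.rand(10000, 128).astype(np.float32)
fit(model, X)  # ScaNN caps the k-means training sample at the dataset size
# fit() normalized X in place, so its rows are valid unit-norm queries.
neighbors, distances = model.searcher.search_batched(X[:5])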
def scann(self, train_data, train_labels, test_data, test_labels):
    """Run ScaNN and compute training and testing times.

    Assumes module-level `import time` and `import scann`; as a class
    method, the name `scann` does not shadow the module import.

    [1] https://github.com/google-research/google-research/tree/master/scann
    """
    time0 = time.time()
    model = scann.ScannBuilder(train_data, self.ann_k, "dot_product").tree(
        num_leaves=330,
        num_leaves_to_search=self.scann_leaves,
        training_sample_size=25000).score_ah(
            dimensions_per_block=2,
            anisotropic_quantization_threshold=0.2).reorder(
                reordering_num_neighbors=self.scann_leaves).create_pybind()
    time1 = time.time()
    predicted_labels = self.scann_predict(model, test_data, train_labels)
    # Estimate the per-query test time once, on a fixed sample of queries.
    if self.test_time_ms is None:
        time2 = time.time()
        self.scann_predict(model, train_data[0:self.test_time_samples],
                           train_labels)
        time3 = time.time()
        self.test_time_ms = 1000.0 * (time3 - time2) / self.test_time_samples
    train_time = time1 - time0
    return predicted_labels, train_time
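# scann_predict() is called above but not defined in this section. A
# plausible sketch, assuming it performs majority-vote k-NN classification
# over the returned neighbor indices (the name matches the call above;
# the behavior is a guess):
import numpy as np

def scann_predict(self, model, data, train_labels):
    # neighbors has shape (num_queries, ann_k); look up the training label
    # of each neighbor and return the most frequent label per query
    # (assumes non-negative integer class labels for np.bincount).
    neighbors, _ = model.search_batched(data)
    neighbor_labels = np.asarray(train_labels)[neighbors]
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])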
def build_advanced_index(self, vecs: 'np.ndarray'):
    """Load vectors into ScaNN indexers.

    Evaluation is lazy: .tree(...), .score_ah(...) and .reorder(...) only
    accumulate configuration; .create_pybind() builds the actual object.

    1) (Optional) Partitioning: configured with .tree(...). The dataset is
       partitioned at training time, and at query time only the top
       partitions are searched.
    2) Scoring: if partitioning is disabled, the distance between the query
       and every datapoint is measured; if it is enabled, only datapoints
       within the selected partitions are scored.
    3) (Optional) Rescoring: highly recommended when AH scoring is used.
       The top candidates by approximate distance are re-scored with exact
       distances, and the final top-k is selected from the new measurements.
    """
    import scann

    # NOTE: ScannBuilder's second positional argument is the number of
    # neighbors to return.
    index = scann.ScannBuilder(
        vecs, self.training_iterations, self.distance_measure). \
        score_ah(self.dimensions_per_block,
                 self.anisotropic_quantization_threshold). \
        reorder(self.reordering_num_neighbors).create_pybind()
    return index
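# A sketch of the staged configuration described in the docstring above
# (illustrative values, not from the source). Without .tree(), AH scores
# every datapoint; with .tree(), only the top partitions are scored; in
# both cases .reorder() rescores the best candidates exactly.
import numpy as np
import scann

vecs = np.random.rand(5000, 64).astype(np.float32)

# Stages 2 + 3 only: asymmetric hashing over the whole dataset, then
# exact rescoring of the top 100 candidates.
flat_index = scann.ScannBuilder(vecs, 10, "dot_product").score_ah(
    2, anisotropic_quantization_threshold=0.2).reorder(100).create_pybind()

# Stages 1 + 2 + 3: partition into 50 leaves, score only the best 10
# leaves per query, then rescore exactly.
tree_index = scann.ScannBuilder(vecs, 10, "dot_product").tree(
    num_leaves=50, num_leaves_to_search=10,
    training_sample_size=5000).score_ah(
        2, anisotropic_quantization_threshold=0.2).reorder(100).create_pybind()

neighbors, distances = tree_index.search_batched(vecs[:3])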
import os
import time

import h5py
import scann


def main(_):
    query_ids_fn = FLAGS.input_path + "/queries_" + FLAGS.suffix + "_ids.h5py"
    passage_ids_fn = FLAGS.input_path + "/passage_ids.h5py"
    queries = h5py.File(query_ids_fn, "r")
    query_ids = queries["ids"][:]
    passages = h5py.File(passage_ids_fn, "r")
    passage_ids = passages["ids"][:]
    num_neighbors = FLAGS.num_vec_per_passage * FLAGS.num_neighbors
    neighbors_path = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".h5py"
    scores_path = FLAGS.output_path + "/scores_" + FLAGS.suffix + ".h5py"
    if not os.path.isfile(neighbors_path):
        query_encoding_fn = (
            FLAGS.input_path + "/queries_" + FLAGS.suffix + "_encodings.h5py")
        queries = h5py.File(query_encoding_fn, "r")
        query_encodings = queries["encodings"][:]
        passage_encoding_fn = FLAGS.input_path + "/passage_encodings.h5py"
        passages = h5py.File(passage_encoding_fn, "r")
        passage_encodings = passages["encodings"][:]
        print("Number of queries: " + str(query_ids.shape[0]))
        print("Number of passages: " + str(passage_ids.shape[0]))
        start = time.time()
        if FLAGS.brute_force:
            print("Start indexing (exact search)")
            searcher = scann.ScannBuilder(
                passage_encodings, num_neighbors,
                "dot_product").score_brute_force().create_pybind()
        else:
            print("Start indexing (approximate search)")
            searcher = scann.ScannBuilder(
                passage_encodings, num_neighbors, "dot_product").tree(
                    num_leaves=FLAGS.num_leaves,
                    num_leaves_to_search=FLAGS.num_leaves_to_search,
                    training_sample_size=passage_encodings.shape[0]).score_ah(
                        2, anisotropic_quantization_threshold=0.2).reorder(
                            num_neighbors).create_pybind()
        end = time.time()
        print("Indexing Time:", end - start)
        start = time.time()
        neighbors, distances = searcher.search_batched_parallel(query_encodings)
        end = time.time()
        print("Search Time:", end - start)
        h5f = h5py.File(neighbors_path, "w")
        h5f.create_dataset("neighbors", data=neighbors)
        h5f.close()
        h5f = h5py.File(scores_path, "w")
        h5f.create_dataset("scores", data=distances)
        h5f.close()
    else:
        print("Neighbors file exists: " + neighbors_path)
        neighbors = h5py.File(neighbors_path, "r")["neighbors"][:]
        print("Scores file exists: " + scores_path)
        distances = h5py.File(scores_path, "r")["scores"][:]
    if FLAGS.write_tsv:
        output_tsv_fn = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".tsv"
        write_result_to_tsv(neighbors, query_ids, passage_ids, num_neighbors,
                            output_tsv_fn)
    if FLAGS.write_json:
        output_json_fn = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".json"
        write_result_to_json(neighbors, query_ids, passage_ids, distances,
                             num_neighbors, output_json_fn)
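# The absl FLAGS used in main() are not defined in this section, and
# write_result_to_tsv / write_result_to_json are assumed to live elsewhere
# in the same file. A plausible flag-definition block (the flag names
# mirror the usages above; defaults and help strings are assumptions):
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Directory with the id/encoding h5py files.")
flags.DEFINE_string("output_path", None, "Directory for neighbors/scores outputs.")
flags.DEFINE_string("suffix", "dev", "Suffix selecting the query split.")
flags.DEFINE_integer("num_neighbors", 10, "Neighbors to retrieve per passage vector.")
flags.DEFINE_integer("num_vec_per_passage", 1, "Vectors per passage; scales retrieval depth.")
flags.DEFINE_integer("num_leaves", 2000, "Number of ScaNN partitions.")
flags.DEFINE_integer("num_leaves_to_search", 100, "Partitions searched per query.")
flags.DEFINE_bool("brute_force", False, "Use exact brute-force scoring instead of AH.")
flags.DEFINE_bool("write_tsv", False, "Also write results as TSV.")
flags.DEFINE_bool("write_json", False, "Also write results as JSON.")

if __name__ == "__main__":
    app.run(main)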
import csv
import json
import time

import h5py
import numpy as np
import scann


def write_to_csv(zh_json, zh_hdf5, ja_json, ja_hdf5, csv_file_name):
    ## Chinese side: used as the queries.
    with open(zh_json) as f:
        zdata = json.load(f)
    query_h5f = h5py.File(zh_hdf5, 'r')
    queries = query_h5f['zh']

    ## Japanese side: used as the dataset to search.
    with open(ja_json) as f:
        jdata = json.load(f)
    dataset_h5f = h5py.File(ja_hdf5, 'r')
    dataset = dataset_h5f['it']

    ## Build the searcher.
    normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]
    searcher = scann.ScannBuilder(normalized_dataset, 10, "dot_product").tree(
        num_leaves=5000, num_leaves_to_search=200,
        training_sample_size=750000).score_ah(
            2, anisotropic_quantization_threshold=0.2).reorder(
                100).create_pybind()
    ## Brute-force alternative (distance measure: dot_product / squared_l2):
    # searcher = scann.ScannBuilder(normalized_dataset, 10, "dot_product").tree(
    #     num_leaves=2000, num_leaves_to_search=100,
    #     training_sample_size=250000).score_brute_force(True).create_pybind()

    print("Start to search...")
    start = time.time()
    neighbors, distances = searcher.search_batched(queries, final_num_neighbors=1)
    end = time.time()
    print("Search Time:", end - start)

    ## Sort the indices by distance to get matched source/target index pairs.
    dis_top = distances[:, :5]  # distance of the top-1 neighbor (one column)
    indexs = np.argsort(dis_top, axis=0)  # sort by the top-1 distance
    save_indexs = indexs[::-1]  # reverse to descending order
    save_neighbors = neighbors[save_indexs, :]  # neighbor ids in sorted order
    source_index = save_indexs
    target_index = save_neighbors
    nbr_row, _ = source_index.shape
    zh_indices = list()
    ja_indices = list()
    start = time.time()
    for i in range(nbr_row):
        _source_index = source_index[i][0]
        _target_index = target_index[i][0]
        zh_indices.append(_source_index)
        ja_indices.append(_target_index)

    ## Save the results.
    submit_file = open(csv_file_name, 'w', newline='', encoding='utf8')
    writer = csv.writer(submit_file, delimiter=',')
    writer.writerow(
        ['file_source', 'location_source', 'file_target', 'location_target'])  # header row

    ## Iterate over the matched pairs and write them out.
    nbr_row = min(30000, len(zh_indices))
    for row in range(nbr_row):
        zh_key = str(zh_indices[row])
        file_source = zdata[zh_key]['file']
        location_source = zdata[zh_key][
            'location']  # keys: 'file', 'location', 'text'
        ja_key = str(ja_indices[row][0])
        file_target = jdata[ja_key]['file']
        location_target = jdata[ja_key]['location']
        writer.writerow(
            [file_source, location_source, file_target, location_target])

    ## Close the file.
    submit_file.close()
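# A usage sketch for write_to_csv() with hypothetical file names (the
# actual inputs are a JSON metadata file plus an h5py embedding file per
# language; the h5py keys 'zh' and 'it' are kept as in the source):
write_to_csv(
    zh_json="zh_sentences.json",
    zh_hdf5="zh_embeddings.h5",
    ja_json="ja_sentences.json",
    ja_hdf5="ja_embeddings.h5",
    csv_file_name="matches.csv")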