Example #1
  def fit(self, X):
    # Normalize rows to unit L2 norm so that dot-product search is
    # equivalent to cosine similarity.
    X /= np.linalg.norm(X, axis=1)[:, np.newaxis]

    # Partitioned index with asymmetric hashing (AH) scoring and exact
    # reordering of the top candidate.
    self.searcher = scann.ScannBuilder(X, 10, "dot_product").tree(
        self.n_leaves, 1, training_sample_size=350000, spherical=True, quantize_centroids=True).score_ah(
            self.dims_per_block, anisotropic_quantization_threshold=self.avq_threshold).reorder(
                1).create_pybind()
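A minimal usage sketch for the index built by fit() above (hypothetical, not part of the original example): queries must be L2-normalized the same way as X, and search_batched / final_num_neighbors are the same pybind calls used in the later examples. The variable name index_obj and the 128-dimensional random queries are assumptions.

import numpy as np

# index_obj is assumed to be an instance of the class above, after fit(X) has run.
queries = np.random.rand(5, 128).astype(np.float32)
queries /= np.linalg.norm(queries, axis=1)[:, np.newaxis]

# Returns per-query neighbor indices and dot-product scores, here of shape (5, 10).
neighbors, distances = index_obj.searcher.search_batched(queries, final_num_neighbors=10)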
Example #2
    def scann(self, train_data, train_labels, test_data, test_labels):
        """Run ScaNN and compute training and testing times.

        [1] https://github.com/google-research/google-research/tree/master/scann
        """
        train_time = 0
        time0 = time.time()

        model = scann.ScannBuilder(train_data, self.ann_k, "dot_product").tree(
            num_leaves=330,
            num_leaves_to_search=self.scann_leaves,
            training_sample_size=25000).score_ah(
                dimensions_per_block=2,
                anisotropic_quantization_threshold=0.2).reorder(
                    reordering_num_neighbors=self.scann_leaves).create_pybind()

        time1 = time.time()
        predicted_labels = self.scann_predict(model, test_data, train_labels)

        if self.test_time_ms is None:
            time2 = time.time()
            self.scann_predict(model, train_data[0:self.test_time_samples],
                               train_labels)
            time3 = time.time()
            self.test_time_ms = 1000.0 * (time3 -
                                          time2) / (self.test_time_samples)

        train_time = time1 - time0
        return predicted_labels, train_time
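The scann_predict helper called above is not included in this snippet. A hypothetical sketch of what it might look like, assuming integer class labels, numpy imported as np at module level, and a simple majority vote over the self.ann_k retrieved neighbors; this is an illustration, not the original implementation.

    def scann_predict(self, model, data, labels):
        # Hypothetical helper: retrieve the ann_k nearest training points for each query row.
        neighbors, _ = model.search_batched(data, final_num_neighbors=self.ann_k)
        neighbor_labels = np.asarray(labels)[neighbors]  # shape (n_queries, ann_k)
        # Majority vote over the retrieved integer labels.
        return np.array([np.bincount(row).argmax() for row in neighbor_labels])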
Example #3
    def build_advanced_index(self, vecs: 'np.ndarray'):
        """Load vectors into ScaNN indexers.

        The builder is evaluated lazily: .score_ah(...) and .reorder(...)
        only record configuration, and only .create_pybind() actually
        builds the searcher.

        1) (Optional) Partitioning: configured with .tree(...); the
           partitions are learned at training time, and at query time only
           the top partitions are selected for searching.
        2) Scoring: if partitioning is not enabled, the distance between
           the query and every datapoint is measured; if partitioning is
           enabled, only datapoints in the selected partitions are scored.
        3) (Optional) Reordering: highly recommended when AH scoring was
           used. The distances of the top-k candidates are recomputed
           exactly, and the final top-k is selected from this new
           measurement.
        """
        import scann
        index = scann.ScannBuilder(vecs, self.training_iterations, self.distance_measure). \
            score_ah(self.dimensions_per_block, self.anisotropic_quantization_threshold). \
            reorder(self.reordering_num_neighbors).create_pybind()
        return index
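Because the chained calls above only record configuration until .create_pybind(), the scoring and reordering budgets can also be adjusted per query at search time. A short sketch, assuming indexer is an instance of the class above and vecs / queries are float32 arrays; final_num_neighbors and pre_reorder_num_neighbors are standard ScaNN search arguments, while the surrounding variable names are assumptions.

index = indexer.build_advanced_index(vecs)

# Rescore the top 100 AH candidates exactly and return the best 10 per query.
neighbors, distances = index.search_batched(
    queries, final_num_neighbors=10, pre_reorder_num_neighbors=100)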
Example #4
def main(_):
  query_ids_fn = FLAGS.input_path + "/queries_" + FLAGS.suffix + "_ids.h5py"
  passage_ids_fn = FLAGS.input_path + "/passage_ids.h5py"
  queries = h5py.File(query_ids_fn, "r")
  query_ids = queries["ids"][:]
  passages = h5py.File(passage_ids_fn, "r")
  passage_ids = passages["ids"][:]
  num_neighbors = FLAGS.num_vec_per_passage * FLAGS.num_neighbors
  neighbors_path = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".h5py"
  scores_path = FLAGS.output_path + "/scores_" + FLAGS.suffix + ".h5py"
  if not os.path.isfile(neighbors_path):
    query_encoding_fn = FLAGS.input_path + "/queries_" + FLAGS.suffix + "_encodings.h5py"
    queries = h5py.File(query_encoding_fn, "r")
    query_encodings = queries["encodings"][:]
    passage_encoding_fn = FLAGS.input_path + "/passage_encodings.h5py"
    passages = h5py.File(passage_encoding_fn, "r")
    passage_encodings = passages["encodings"][:]
    print("Number of queries: " + str(query_ids.shape[0]))
    print("Number of passages: " + str(passage_ids.shape[0]))

    start = time.time()
    if FLAGS.brute_force:
      print("Start indexing (exact search)")
      searcher = scann.ScannBuilder(
          passage_encodings, num_neighbors,
          "dot_product").score_brute_force().create_pybind()
    else:
      print("Start indexing (approximate search)")
      searcher = scann.ScannBuilder(
          passage_encodings, num_neighbors, "dot_product").tree(
              num_leaves=FLAGS.num_leaves,
              num_leaves_to_search=FLAGS.num_leaves_to_search,
              training_sample_size=passage_encodings.shape[0]).score_ah(
                  2, anisotropic_quantization_threshold=0.2).reorder(
                      num_neighbors).create_pybind()

    end = time.time()
    print("Indexing Time:", end - start)
    start = time.time()
    neighbors, distances = searcher.search_batched_parallel(query_encodings)
    end = time.time()
    print("Search Time:", end - start)

    h5f = h5py.File(neighbors_path, "w")
    h5f.create_dataset("neighbors", data=neighbors)
    h5f.close()
    h5f = h5py.File(scores_path, "w")
    h5f.create_dataset("scores", data=distances)
    h5f.close()

  else:
    print("Neighbors file exists: " + neighbors_path)
    neighbors = h5py.File(neighbors_path, "r")["neighbors"][:]
    print("Scores file exists: " + scores_path)
    distances = h5py.File(scores_path, "r")["scores"][:]

  if FLAGS.write_tsv:
    output_tsv_fn = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".tsv"
    write_result_to_tsv(neighbors, query_ids, passage_ids, num_neighbors,
                        output_tsv_fn)
  if FLAGS.write_json:
    output_json_fn = FLAGS.output_path + "/neighbors_" + FLAGS.suffix + ".json"
    write_result_to_json(neighbors, query_ids, passage_ids, distances,
                         num_neighbors, output_json_fn)
def write_to_csv(zh_json, zh_hdf5, ja_json, ja_hdf5, csv_file_name):
    ## Chinese, used as the query
    with open(zh_json) as f:
        zdata = json.load(f)
    query_h5f = h5py.File(zh_hdf5, 'r')
    queries = query_h5f['zh']

    ## Japanese, used as the dataset
    with open(ja_json) as f:
        jdata = json.load(f)
    dataset_h5f = h5py.File(ja_hdf5, 'r')
    dataset = dataset_h5f['it']

    # pdb.set_trace()
    ## Create the searcher
    normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:,
                                                                   np.newaxis]

    searcher = scann.ScannBuilder(normalized_dataset, 10, "dot_product").tree(
        num_leaves=5000, num_leaves_to_search=200,
        training_sample_size=750000).score_ah(
            2, anisotropic_quantization_threshold=0.2).reorder(
                100).create_pybind()

    ## brute_force
    # searcher = scann.ScannBuilder(normalized_dataset, 10, "dot_product").tree(  # dot_product /  squared_l2
    # num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_brute_force(True).create_pybind()

    print("Start to search...")
    start = time.time()
    neighbors, distances = searcher.search_batched(queries,
                                                   final_num_neighbors=1)
    end = time.time()

    print("Search Time:", end - start)

    ## Sort the indices to get the corresponding source/target indices
    dis_top = distances[:, :5]  # top-1 distance
    # pdb.set_trace()
    indexs = np.argsort(dis_top, axis=0)  # sort by top-1 distance
    save_indexs = indexs[::-1]

    # pdb.set_trace()
    save_neighbors = neighbors[save_indexs, :]  # get all top-20 neighbor ids
    source_index = save_indexs
    target_index = save_neighbors
    nbr_row, _ = source_index.shape

    zh_indices = list()
    ja_indices = list()
    start = time.time()
    for i in range(nbr_row):
        # pdb.set_trace()
        _source_index = source_index[i][0]
        _target_index = target_index[i][0]

        zh_indices.append(_source_index)
        ja_indices.append(_target_index)

    ## Save the results
    submit_file = open(csv_file_name, 'w', newline='', encoding='utf8')
    writer = csv.writer(submit_file, delimiter=',')
    writer.writerow(
        ['file_source', 'location_source', 'file_target',
         'location_target'])  # header row

    ## Iterate over the matches and write them out
    nbr_row = min(30000, len(zh_indices))
    for row in range(nbr_row):
        zh_key = str(zh_indices[row])
        file_source = zdata[zh_key]['file']
        location_source = zdata[zh_key][
            'location']  # 'file', 'location', 'text'

        ja_key = str(ja_indices[row][0])
        file_target = jdata[ja_key]['file']
        location_target = jdata[ja_key]['location']

        writer.writerow(
            [file_source, location_source, file_target, location_target])

    ## Close the file
    submit_file.close()