def __init__(self, index_path, args, inference, verbose=False, memtype='mmap'):
    """Memory-map the ColBERT index parts under ``index_path`` for re-ranking.

    Loads per-part document lengths, mmaps each .pt embedding file, and
    precomputes begin/end passage-id offsets per part so a pid can be mapped
    to the part file that contains it.

    index_path: directory holding the embedding parts and doclens files.
    args: must provide ``doc_maxlen`` (> 0).
    inference: model-inference helper kept for later query encoding.
    memtype: how the parts are loaded (e.g. 'mmap') — forwarded to
        ``re_ranker_mmap._load_parts``.
    """
    self.args = args
    self.doc_maxlen = args.doc_maxlen
    assert self.doc_maxlen > 0
    self.inference = inference
    self.dim = 128  # TODO: read the embedding dimension from index metadata
    self.verbose = verbose

    # Every pt file gets its own list of doc lengths
    self.part_doclens = load_doclens(index_path, flatten=False)
    assert len(self.part_doclens) > 0, "Did not find any indices at %s" % index_path

    # Local mmapped tensors with local, single file accesses
    self.part_mmap: List[file_part_mmap] = re_ranker_mmap._load_parts(
        index_path, self.part_doclens, memtype)

    # Last pid (inclusive — hence the -1) in each pt file, used with
    # np.searchsorted to map a pid to its part. With 1000-doc parts this
    # is [999, 1999, ...].
    self.part_pid_end_offsets = np.cumsum(
        [len(x) for x in self.part_doclens]) - 1

    # First pid (inclusive) in each pt file, e.g. [0, 1000, 2000, ...]:
    # take the cumulative counts, zero the last entry, and rotate it to
    # the front. (The original code also had a dangling no-op expression
    # re-reading this attribute; removed.)
    tmp = np.cumsum([len(x) for x in self.part_doclens])
    tmp[-1] = 0
    self.part_pid_begin_offsets = np.roll(tmp, 1)
def main():
    """CLI entry point: build a FAISS index over an existing ColBERT index.

    Parses the indexing arguments, derives a default number of IVF
    partitions from the embedding count when --partitions is omitted, and
    delegates the actual work to ``index_faiss``.
    """
    random.seed(12345)

    argparser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    argparser.add_index_use_input()
    argparser.add_argument('--sample', dest='sample', default=None, type=float)
    argparser.add_argument('--slices', dest='slices', default=1, type=int)
    args = argparser.parse()

    # Sanity-check the optional knobs before touching the index.
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        total_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", total_embeddings)

        if args.partitions is None:
            # Default heuristic: next power of two above 8 * sqrt(N).
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(total_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(total_embeddings))
            print('\n\n')

        index_faiss(args)
def __init__(self, index_path, faiss_index_path, nprobe, part_range=None):
    """Load a FAISS index and build the embedding-id -> passage-id mapping.

    index_path: directory of the ColBERT embedding index (used for doclens).
    faiss_index_path: path to the .faiss file. If its basename's
        second-to-last dotted component looks like "A-B", the faiss index
        only covers parts [A, B) and doclens are restricted accordingly.
    nprobe: number of IVF cells FAISS probes per query.
    part_range: optional range of parts this instance serves; when the
        filename encodes a part range, part_range must lie inside it.
    """
    print_message("#> Loading the FAISS index from", faiss_index_path, "..")

    # Filename convention: "name.A-B.faiss" encodes the covered part range.
    faiss_part_range = os.path.basename(faiss_index_path).split(
        '.')[-2].split('-')

    if len(faiss_part_range) == 2:
        faiss_part_range = range(*map(int, faiss_part_range))
        # NOTE(review): if the filename encodes a range, part_range must not
        # be None — these asserts would raise TypeError otherwise.
        assert part_range[0] in faiss_part_range, (part_range, faiss_part_range)
        assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range)
    else:
        faiss_part_range = None

    self.part_range = part_range
    self.faiss_part_range = faiss_part_range
    self.faiss_index = faiss.read_index(faiss_index_path)
    self.faiss_index.nprobe = nprobe

    print_message("#> Building the emb2pid mapping..")
    all_doclens = load_doclens(index_path, flatten=False)

    pid_offset = 0
    if faiss_part_range is not None:
        print(
            f"#> Restricting all_doclens to the range {faiss_part_range}.")
        # pids stay global: remember how many passages precede the first
        # covered part, then drop the doclens outside the range.
        pid_offset = len(flatten(all_doclens[:faiss_part_range.start]))
        all_doclens = all_doclens[faiss_part_range.start:faiss_part_range.
                                  stop]

    self.relative_range = None
    if self.part_range is not None:
        # Translate part_range into a passage-count range relative to the
        # (possibly restricted) doclens: len(flatten(...)) counts passages
        # in the leading parts.
        start = self.faiss_part_range.start if self.faiss_part_range is not None else 0
        a = len(flatten(all_doclens[:self.part_range.start - start]))
        b = len(flatten(all_doclens[:self.part_range.stop - start]))
        self.relative_range = range(a, b)
        print(f"self.relative_range = {self.relative_range}")

    all_doclens = flatten(all_doclens)
    total_num_embeddings = sum(all_doclens)

    # emb2pid[i] = global pid of the passage embedding i belongs to.
    self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        self.emb2pid[offset_doclens:offset_doclens + dlength] = pid_offset + pid
        offset_doclens += dlength

    print_message("len(self.emb2pid) =", len(self.emb2pid))

    # Worker pool for parallel decompression/lookup work.
    self.parallel_pool = Pool(16)
def index(self, iterator):
    """Encode the documents from ``iterator`` into a new ColBERT index and
    then build the FAISS IVF index over the stored embeddings.

    iterator: iterable of dicts, each with at least a 'docno' field. A
        sequential 0-based 'docid' is assigned in iteration order, and the
        docno list is persisted next to the index as docnos.pkl.gz.

    Raises AssertionError if the index path already exists (refuses to
    overwrite) or if encoding produced no index.
    """
    from timeit import default_timer as timer
    starttime = timer()

    # Refuse to clobber an existing index.
    assert not os.path.exists(self.args.index_path), self.args.index_path

    docnos = []
    docid = 0

    def convert_gen(iterator):
        # Wrap the user iterator: assign docids and collect docnos as a
        # side effect while streaming documents to the encoder.
        import pyterrier as pt
        nonlocal docnos
        nonlocal docid
        if self.num_docs is not None:
            iterator = pt.tqdm(iterator, total=self.num_docs, desc="encoding", unit="d")
        for l in iterator:
            l["docid"] = docid
            docnos.append(l['docno'])
            docid += 1
            yield l

    self.args.generator = convert_gen(iterator)
    ceg = CollectionEncoder_Generator(self.prepend_title, self.args, 0, 1)
    create_directory(self.args.index_root)
    create_directory(self.args.index_path)
    ceg.encode()
    self.colbert = ceg.colbert
    self.checkpoint = ceg.checkpoint

    assert os.path.exists(self.args.index_path), self.args.index_path
    num_embeddings = sum(load_doclens(self.args.index_path))
    print("#> num_embeddings =", num_embeddings)

    # Persist the docno mapping alongside the index.
    import pyterrier as pt
    with pt.io.autoopen(
            os.path.join(self.args.index_path, "docnos.pkl.gz"), "wb") as f:
        pickle.dump(docnos, f)

    if self.args.partitions is None:
        # Default heuristic: next power of two above 8 * sqrt(N).
        self.args.partitions = 1 << math.ceil(
            math.log2(8 * math.sqrt(num_embeddings)))
        warn("You did not specify --partitions!")
        warn("Default computation chooses", self.args.partitions,
             "partitions (for {} embeddings)".format(num_embeddings))

    index_faiss(self.args)
    print("#> Faiss encoding complete")
    endtime = timer()
    print("#> Indexing complete, Time elapsed %0.2f seconds" % (endtime - starttime))
def __init__(self, directory, dim=128, part_range=None, verbose=True):
    """Load a (possibly restricted) window of index parts from ``directory``
    and set up an IndexRanker over their embeddings.

    part_range: optional range selecting which parts to load; None loads
        every part.
    """
    if part_range is None:
        lo, hi = 0, None
    else:
        lo, hi = part_range.start, part_range.stop

    # Restrict the parts metadata to the requested window.
    all_parts, all_parts_paths, _ = get_parts(directory)
    self.parts = all_parts[lo:hi]
    self.parts_paths = all_parts_paths[lo:hi]

    # Per-part document lengths; doc_offset/doc_endpos delimit the global
    # passage-id range covered by the selected parts.
    all_doclens = load_doclens(directory, flatten=False)
    self.doc_offset = sum(len(dl) for dl in all_doclens[:lo])
    self.doc_endpos = sum(len(dl) for dl in all_doclens[:hi])
    self.pids_range = range(self.doc_offset, self.doc_endpos)

    self.parts_doclens = all_doclens[lo:hi]
    self.doclens = flatten(self.parts_doclens)
    self.num_embeddings = sum(self.doclens)

    self.tensor = self._load_parts(dim, verbose)
    self.ranker = IndexRanker(self.tensor, self.doclens)