Example #1
    def __init__(self,
                 index_path,
                 args,
                 inference,
                 verbose=False,
                 memtype='mmap'):
        self.args = args
        self.doc_maxlen = args.doc_maxlen
        assert self.doc_maxlen > 0
        self.inference = inference
        self.dim = 128  # TODO: hard-coded embedding dimension
        self.verbose = verbose

        # Every pt file gets its own list of doc lengths
        self.part_doclens = load_doclens(index_path, flatten=False)
        assert len(self.part_doclens) > 0, "Did not find any indices at %s" % index_path
        # Local mmapped tensors with local, single file accesses
        self.part_mmap: List[file_part_mmap] = re_ranker_mmap._load_parts(
            index_path, self.part_doclens, memtype)

        # Last pid (inclusive) in each pt file; the -1 makes the offsets
        # inclusive so they can be fed directly to np.searchsorted.
        # E.g., if each partition has 1000 docs, the array is [999, 1999, ...],
        # which lets us map from a passage id to its part.
        self.part_pid_end_offsets = np.cumsum([len(x) for x in self.part_doclens]) - 1

        # First pid (inclusive) in each pt file, e.g., [0, 1000, 2000, ...]:
        # shift the cumulative sums right by one, starting from zero.
        tmp = np.cumsum([len(x) for x in self.part_doclens])
        tmp[-1] = 0
        self.part_pid_begin_offsets = np.roll(tmp, 1)
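
Given the two offset arrays above, mapping a global passage id back to its part (and its local offset within that part) is a single np.searchsorted lookup. A minimal sketch of that lookup; the helper name pid_to_part is hypothetical and not part of the class above:

import numpy as np

def pid_to_part(pid, part_pid_end_offsets, part_pid_begin_offsets):
    # End offsets are inclusive, so searchsorted (side='left') returns the
    # first part whose last pid is >= pid, i.e., the part containing pid.
    part = int(np.searchsorted(part_pid_end_offsets, pid))
    local_pid = pid - part_pid_begin_offsets[part]  # offset inside that part
    return part, local_pid

# With two parts of 1000 docs each: ends = [999, 1999], begins = [0, 1000].
ends, begins = np.array([999, 1999]), np.array([0, 1000])
assert pid_to_part(999, ends, begins) == (0, 999)
assert pid_to_part(1000, ends, begins) == (1, 0)
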
Example #2
# Imports as in ColBERT v1's index_faiss.py (module paths may differ across versions):
import os
import math
import random

from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
from colbert.indexing.faiss import index_faiss
from colbert.indexing.loaders import load_doclens

def main():
    random.seed(12345)

    parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    parser.add_index_use_input()

    parser.add_argument('--sample', dest='sample', default=None, type=float)
    parser.add_argument('--slices', dest='slices', default=1, type=int)

    args = parser.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(num_embeddings))
            print('\n\n')

        index_faiss(args)
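
The default --partitions rule above rounds 8 * sqrt(num_embeddings) up to the next power of two. A quick worked check of that formula in plain Python (no ColBERT imports required):

import math

def default_partitions(num_embeddings):
    # Next power of two >= 8 * sqrt(num_embeddings).
    return 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))

# e.g., for 10M embeddings: 8 * sqrt(1e7) ~ 25298, and the next power of two is 2**15.
assert default_partitions(10_000_000) == 32768
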
Example #3
    def __init__(self, index_path, faiss_index_path, nprobe, part_range=None):
        print_message("#> Loading the FAISS index from", faiss_index_path,
                      "..")

        faiss_part_range = os.path.basename(faiss_index_path).split('.')[-2].split('-')

        if len(faiss_part_range) == 2:
            faiss_part_range = range(*map(int, faiss_part_range))
            assert part_range[0] in faiss_part_range, (part_range, faiss_part_range)
            assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range)
        else:
            faiss_part_range = None

        self.part_range = part_range
        self.faiss_part_range = faiss_part_range

        self.faiss_index = faiss.read_index(faiss_index_path)
        self.faiss_index.nprobe = nprobe

        print_message("#> Building the emb2pid mapping..")
        all_doclens = load_doclens(index_path, flatten=False)

        pid_offset = 0
        if faiss_part_range is not None:
            print(f"#> Restricting all_doclens to the range {faiss_part_range}.")
            pid_offset = len(flatten(all_doclens[:faiss_part_range.start]))
            all_doclens = all_doclens[faiss_part_range.start:faiss_part_range.stop]

        self.relative_range = None
        if self.part_range is not None:
            start = self.faiss_part_range.start if self.faiss_part_range is not None else 0
            a = len(flatten(all_doclens[:self.part_range.start - start]))
            b = len(flatten(all_doclens[:self.part_range.stop - start]))
            self.relative_range = range(a, b)
            print(f"self.relative_range = {self.relative_range}")

        all_doclens = flatten(all_doclens)

        total_num_embeddings = sum(all_doclens)
        self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

        offset_doclens = 0
        for pid, dlength in enumerate(all_doclens):
            self.emb2pid[offset_doclens:offset_doclens + dlength] = pid_offset + pid
            offset_doclens += dlength

        print_message("len(self.emb2pid) =", len(self.emb2pid))

        self.parallel_pool = Pool(16)
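
The emb2pid tensor built above converts the embedding ids returned by a FAISS search into candidate passage ids with one indexing operation. A minimal, self-contained sketch with toy data (the actual search call is outside this snippet):

import torch

# Toy emb2pid: six embeddings belonging to pids [0, 0, 1, 1, 1, 2].
emb2pid = torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int)

# Hypothetical embedding ids from faiss_index.search, one row per query.
embedding_ids = torch.tensor([[0, 2, 5], [3, 4, 1]])

all_pids = emb2pid[embedding_ids]                    # same shape as embedding_ids
candidate_pids = [row.unique() for row in all_pids]  # deduplicate per query
# candidate_pids == [tensor([0, 1, 2]), tensor([0, 1])]
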
Example #4
    def index(self, iterator):
        from timeit import default_timer as timer
        starttime = timer()
        assert not os.path.exists(self.args.index_path), self.args.index_path
        docnos = []
        docid = 0

        def convert_gen(iterator):
            import pyterrier as pt
            nonlocal docnos
            nonlocal docid
            if self.num_docs is not None:
                iterator = pt.tqdm(iterator, total=self.num_docs, desc="encoding", unit="d")
            for doc in iterator:
                doc["docid"] = docid
                docnos.append(doc['docno'])
                docid += 1
                yield doc

        self.args.generator = convert_gen(iterator)
        ceg = CollectionEncoder_Generator(self.prepend_title, self.args, 0, 1)
        create_directory(self.args.index_root)
        create_directory(self.args.index_path)
        ceg.encode()
        self.colbert = ceg.colbert
        self.checkpoint = ceg.checkpoint

        assert os.path.exists(self.args.index_path), self.args.index_path
        num_embeddings = sum(load_doclens(self.args.index_path))
        print("#> num_embeddings =", num_embeddings)

        import pyterrier as pt
        with pt.io.autoopen(os.path.join(self.args.index_path, "docnos.pkl.gz"), "wb") as f:
            pickle.dump(docnos, f)

        if self.args.partitions is None:
            self.args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            warn("You did not specify --partitions!")
            warn("Default computation chooses", self.args.partitions,
                 "partitions (for {} embeddings)".format(num_embeddings))
        index_faiss(self.args)
        print("#> Faiss encoding complete")
        endtime = timer()
        print("#> Indexing complete, Time elapsed %0.2f seconds" %
              (endtime - starttime))
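
The docnos list pickled above is what later maps ColBERT's integer docids back to external docno strings. Reading it back is symmetric; a minimal sketch, assuming index_path points at the same index directory (pt.io.autoopen handles the .gz transparently):

import os
import pickle
import pyterrier as pt

index_path = "/path/to/index"  # hypothetical location of the ColBERT index

with pt.io.autoopen(os.path.join(index_path, "docnos.pkl.gz"), "rb") as f:
    docnos = pickle.load(f)

print(docnos[0])  # the docno of docid 0
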
Example #5
    def __init__(self, directory, dim=128, part_range=None, verbose=True):
        first_part, last_part = (0, None) if part_range is None else (part_range.start, part_range.stop)

        # Load parts metadata
        all_parts, all_parts_paths, _ = get_parts(directory)
        self.parts = all_parts[first_part:last_part]
        self.parts_paths = all_parts_paths[first_part:last_part]

        # Load doclens metadata
        all_doclens = load_doclens(directory, flatten=False)

        self.doc_offset = sum([len(part_doclens) for part_doclens in all_doclens[:first_part]])
        self.doc_endpos = sum([len(part_doclens) for part_doclens in all_doclens[:last_part]])
        self.pids_range = range(self.doc_offset, self.doc_endpos)

        self.parts_doclens = all_doclens[first_part:last_part]
        self.doclens = flatten(self.parts_doclens)
        self.num_embeddings = sum(self.doclens)

        self.tensor = self._load_parts(dim, verbose)
        self.ranker = IndexRanker(self.tensor, self.doclens)
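
The offset arithmetic above decides which global pids this index slice covers: doc_offset counts the passages in the parts before first_part, and doc_endpos counts those up to last_part. A toy illustration with hypothetical doclens:

# Three parts holding 2, 3, and 2 passages respectively.
all_doclens = [[5, 7], [4, 6, 3], [8, 2]]
first_part, last_part = 1, 3  # i.e., part_range = range(1, 3)

doc_offset = sum(len(d) for d in all_doclens[:first_part])  # 2
doc_endpos = sum(len(d) for d in all_doclens[:last_part])   # 7
pids_range = range(doc_offset, doc_endpos)                  # pids 2..6
assert list(pids_range) == [2, 3, 4, 5, 6]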