Code example #1
def load_model(args, do_print=True):
    # colbert = ColBERT.from_pretrained('bert-base-uncased',
    #                                   query_maxlen=args.query_maxlen,
    #                                   doc_maxlen=args.doc_maxlen,
    #                                   dim=args.dim,
    #                                   similarity_metric=args.similarity,
    #                                   mask_punctuation=args.mask_punctuation)
    bert_config = BertConfig.from_pretrained('bert-base-uncased')  # classmethod: no throwaway instance needed

    colbert = ColKBERT(bert_config,
                       query_maxlen=args.query_maxlen,
                       doc_maxlen=args.doc_maxlen,
                       dim=args.dim,
                       similarity_metric=args.similarity,
                       mask_punctuation=args.mask_punctuation)

    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
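
Every example on this page calls ColBERT's `print_message` helper. A minimal sketch consistent with the call sites (variadic message parts joined with spaces, plus a `condition` keyword that suppresses output); the timestamp prefix is an assumption, not something these excerpts show:

import datetime

def print_message(*parts, condition=True):
    # Join all positional arguments into a single message string.
    message = ' '.join(str(p) for p in parts)
    # The timestamp prefix is an assumption; the call sites above only
    # require joining the parts and honoring `condition`.
    message = f"[{datetime.datetime.now().strftime('%b %d, %H:%M:%S')}] {message}"
    if condition:
        print(message, flush=True)
    return message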
Code example #2
File: faiss_index.py  Project: vjeronymo2/ColBERT
    def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True):
        # Flatten into a matrix for the faiss search.
        num_queries, embeddings_per_query, dim = Q.size()
        Q_faiss = Q.view(num_queries * embeddings_per_query,
                         dim).cpu().contiguous()

        # Search in large batches with faiss.
        print_message(
            "#> Search in batches with faiss. \t\t",
            f"Q.size() = {Q.size()}, Q_faiss.size() = {Q_faiss.size()}",
            condition=verbose)

        embedding_id_batches = []
        faiss_bsize = embeddings_per_query * 5000
        for offset in range(0, Q_faiss.size(0), faiss_bsize):
            endpos = min(offset + faiss_bsize, Q_faiss.size(0))

            print_message("#> Searching from {} to {}...".format(
                offset, endpos),
                          condition=verbose)

            some_Q_faiss = Q_faiss[offset:endpos].float().numpy()
            _, some_embedding_ids = self.faiss_index.search(
                some_Q_faiss, faiss_depth)
            embedding_id_batches.append(torch.from_numpy(some_embedding_ids))

        embedding_ids = torch.cat(embedding_id_batches)

        # Reshape to (number of queries, non-unique embedding IDs per query)
        embedding_ids = embedding_ids.view(
            num_queries, embeddings_per_query * embedding_ids.size(1))

        return embedding_ids
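
The method above follows a flatten → batched-search → reshape pattern. A self-contained sketch of the same shape bookkeeping on toy data (the plain `IndexFlatL2` and the sizes here are assumptions for illustration):

import faiss
import numpy as np
import torch

num_queries, embeddings_per_query, dim, faiss_depth = 4, 32, 128, 10

# Toy index over random "document" embeddings.
index = faiss.IndexFlatL2(dim)
index.add(np.random.rand(1000, dim).astype('float32'))

Q = torch.rand(num_queries, embeddings_per_query, dim)
Q_faiss = Q.view(num_queries * embeddings_per_query, dim).numpy()

# One row of IDs per query embedding...
_, ids = index.search(Q_faiss, faiss_depth)

# ...then regroup: one row per query, embeddings_per_query * faiss_depth IDs each.
ids = torch.from_numpy(ids).view(num_queries, embeddings_per_query * faiss_depth)
print(ids.shape)  # torch.Size([4, 320])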
Code example #3
File: runs.py  Project: vjeronymo2/ColBERT
    def init(self, rank, root, experiment, name):
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                response = input()
                if response.strip() != 'yes':
                    assert not os.path.exists(self.path), self.path
            else:
                create_directory(self.path)

        distributed.barrier(rank)

        self._logger = Logger(rank, self)
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact
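
The TODO above (fail after 60 seconds with no response) could be handled with a `select`-based read of stdin. A POSIX-only sketch; the helper name and its integration are hypothetical:

import select
import sys

def prompt_with_timeout(prompt, timeout_s=60):
    # Wait up to `timeout_s` seconds for a line on stdin; return None on
    # timeout. select() on sys.stdin is POSIX-only (it does not work on Windows).
    print(prompt, flush=True)
    ready, _, _ = select.select([sys.stdin], [], [], timeout_s)
    return sys.stdin.readline().strip() if ready else None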
Code example #4
    def add(self, index, data, offset):
        assert self.ngpu > 0
        print_message('Adding...')
        t0 = time.time()
        nb = data.shape[0]
        for i0 in range(0, nb, self.add_batch_size):
            i1 = min(i0 + self.add_batch_size, nb)
            xs = data[i0:i1]

            # print_message(f"Add with ids {type(xs)} {xs.shape}")
            self.gpu_index.add_with_ids(xs, np.arange(offset + i0,
                                                      offset + i1))

            if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
                self._flush_to_cpu(index, nb, offset)

            print('\r%d/%d (%.3f s)  ' % (i0, nb, time.time() - t0), end=' ')
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, nb, offset)

        assert index.ntotal == offset + nb, (index.ntotal, offset + nb, offset,
                                             nb)
        print(f"add(.) time: {time.time() - t0:.3f} s \t\t--\t\t index.ntotal = {index.ntotal}")
Code example #5
def main(args):
    Rankings = defaultdict(list)

    for path in args.input:
        print_message(f"#> Loading the rankings in {path} ..")

        with open(path) as f:
            for line in file_tqdm(f):
                qid, pid, rank, score = line.strip().split('\t')
                qid, pid, rank = map(int, [qid, pid, rank])
                score = float(score)

                Rankings[qid].append((score, rank, pid))

    with open(args.output, 'w') as f:
        print_message(f"#> Writing the output rankings to {args.output} ..")

        for qid in tqdm.tqdm(Rankings):
            ranking = sorted(Rankings[qid], reverse=True)

            for rank, (score, original_rank, pid) in enumerate(ranking):
                rank = rank + 1  # 1-indexed

                if (args.depth > 0) and (rank > args.depth):
                    break

                line = [qid, pid, rank, score]
                line = '\t'.join(map(str, line)) + '\n'
                f.write(line)
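
The `sorted(..., reverse=True)` call works because Python compares tuples lexicographically: the `(score, rank, pid)` triples are ordered by score descending, with ties falling through to the stored rank. A quick check on toy values (note that `reverse=True` also reverses the tie-break, so the higher original rank wins a tie):

rankings = [(12.5, 3, 101), (13.1, 1, 202), (12.5, 1, 303)]

for new_rank, (score, original_rank, pid) in enumerate(sorted(rankings, reverse=True), start=1):
    print(new_rank, pid, score)
# 1 202 13.1
# 2 101 12.5   <- score tie: original rank 3 sorts before rank 1 under reverse=True
# 3 303 12.5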
Code example #6
def main(args):
    print_message("#> Starting...")

    collectionX_path = os.path.join(args.datadir, 'wiki.abstracts.2017/collection.json')
    queries_path = os.path.join(args.datadir, 'hover/dev/questions.tsv')
    qas_path = os.path.join(args.datadir, 'hover/dev/qas.json')

    checkpointL1 = os.path.join(args.datadir, 'hover.checkpoints-v1.0/condenserL1-v1.0.dnn')
    checkpointL2 = os.path.join(args.datadir, 'hover.checkpoints-v1.0/condenserL2-v1.0.dnn')

    with Run().context(RunConfig(root=args.root)):
        searcher = HopSearcher(index=args.index)
        condenser = Condenser(checkpointL1=checkpointL1, checkpointL2=checkpointL2,
                              collectionX_path=collectionX_path, deviceL1='cuda:0', deviceL2='cuda:0')

        baleen = Baleen(collectionX_path, searcher, condenser)
        baleen.searcher.configure(nprobe=2, ncandidates=8192)

    queries = Queries(path=queries_path)
    outputs = {}

    for qid, query in tqdm.tqdm(queries.items()):
        facts, pids_bag, _ = baleen.search(query, num_hops=4)
        outputs[qid] = (facts, pids_bag)

    with Run().open('output.json', 'w') as f:
        f.write(ujson.dumps(outputs) + '\n')
Code example #7
    def __init__(self, tensor, doclens):
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        self.strides = [
            torch_percentile(self.doclens, p) for p in [25, 50, 75]
        ]
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(list(set(self.strides)))

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)
        device = 'cuda:0' if DEVICE == 'cuda' else 'cpu'
        print_message(f"device: {device}")
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype,
                                            {'cpu', device})
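
The prefix sum built here is what later maps a passage ID to its slice of the flat embedding tensor: passage `pid` occupies rows `doclens_pfxsum[pid]` through `doclens_pfxsum[pid + 1]`. A toy illustration:

from itertools import accumulate
import torch

doclens = [3, 5, 2]  # passage 0 has 3 embeddings, passage 1 has 5, passage 2 has 2
doclens_pfxsum = torch.tensor([0] + list(accumulate(doclens)))

pid = 1
start, end = doclens_pfxsum[pid].item(), doclens_pfxsum[pid + 1].item()
print(start, end)  # 3 8  -> passage 1's embeddings are rows 3..7 of the flat tensor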
Code example #8
File: batch_reranking.py  Project: vjeronymo2/ColBERT
def prepare_ranges(index_path, dim, step, part_range):
    print_message(
        "#> Launching a separate thread to load index parts asynchronously.")
    parts, _, _ = get_parts(index_path)

    positions = [(offset, offset + step)
                 for offset in range(0, len(parts), step)]

    if part_range is not None:
        positions = positions[part_range.start:part_range.stop]

    loaded_parts = queue.Queue(maxsize=2)

    def _loader_thread(index_path, dim, positions):
        for offset, endpos in positions:
            index = IndexPart(index_path,
                              dim=dim,
                              part_range=range(offset, endpos),
                              verbose=True)
            loaded_parts.put(index, block=True)

    thread = threading.Thread(target=_loader_thread,
                              args=(
                                  index_path,
                                  dim,
                                  positions,
                              ))
    thread.start()

    return positions, loaded_parts, thread
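
The `maxsize=2` queue provides backpressure: the loader thread can run at most two index parts ahead of whoever consumes them, which bounds memory. A self-contained toy version of the pattern (the consuming side of the real function is not shown in this excerpt):

import queue
import threading
import time

parts = queue.Queue(maxsize=2)  # loader blocks once it is 2 parts ahead

def _loader():
    for part_id in range(5):
        time.sleep(0.1)                 # simulate loading an index part from disk
        parts.put(part_id, block=True)  # blocks while the queue is full

thread = threading.Thread(target=_loader)
thread.start()

for _ in range(5):
    print(f"consuming part {parts.get()}")  # blocks until a part is ready

thread.join()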
Code example #9
    def _log_exception(self, etype, value, tb):
        if not self.is_main:
            return

        output_path = os.path.join(self.logs_path, 'exception.txt')
        trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
        print_message(trace, '\n\n')

        self.log_new_artifact(output_path, trace)
Code example #10
File: qa_loaders.py  Project: vjeronymo2/ColBERT
def load_qas_(path):
    print_message("#> Loading the reference QAs from", path)

    triples = []

    with open(path) as f:
        for line in f:
            qa = ujson.loads(line)
            triples.append((qa['qid'], qa['question'], qa['answers']))

    return triples
Code example #11
    def train(self, train_data):
        print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...")

        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        s = time.time()
        self.index.train(train_data)
        print(time.time() - s)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()
Code example #12
    def _load_queries(self, path):
        print_message("#> Loading queries...")

        queries = {}

        with open(path) as f:
            for line in f:
                qid, query = line.strip().split('\t')
                qid = int(qid)
                queries[qid] = query

        return queries
Code example #13
    def _prepare_gpu_resources(self):
        print_message(f"Preparing resources for {self.ngpu} GPUs.")

        gpu_resources = []

        for _ in range(self.ngpu):
            res = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                res.setTempMemory(self.tempmem)
            gpu_resources.append(res)

        return gpu_resources
Code example #14
    def add(self, data):
        print_message(f"Add data with shape {data.shape} (offset = {self.offset})..")

        if self.gpu.ngpu > 0 and self.offset == 0:
            self.gpu.adding_initialize(self.index)

        if self.gpu.ngpu > 0:
            self.gpu.add(self.index, data, self.offset)
        else:
            self.index.add(data)

        self.offset += data.shape[0]
Code example #15
def main(args):
    print_message("#> Starting...")

    collection_path = os.path.join(args.datadir,
                                   'wiki.abstracts.2017/collection.tsv')
    checkpoint_path = os.path.join(args.datadir,
                                   'hover.checkpoints-v1.0/flipr-v1.0.dnn')

    with Run().context(RunConfig(root=args.root)):
        config = ColBERTConfig(doc_maxlen=256, nbits=args.nbits)
        indexer = Indexer(checkpoint_path, config=config)
        indexer.index(name=args.index, collection=collection_path)
Code example #16
File: index.py  Project: ryparmar/master-thesis
def main():
    random.seed(12345)

    parser = Arguments(
        description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    parser.add_argument('--chunksize',
                        dest='chunksize',
                        default=6.0,
                        required=False,
                        type=float)  # in GiBs

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        # try:
        assert not os.path.exists(args.index_path), args.index_path
        # except:
        #     print("\n\nNOT EXISTING:", args.index_path, args.index_path, '\n\n')

        distributed.barrier(args.rank)

        if args.rank < 1:
            create_directory(args.index_root)
            create_directory(args.index_path)

        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args,
                                    process_idx=process_idx,
                                    num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path,
                          "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)
Code example #17
File: retrieval.py  Project: vjeronymo2/ColBERT
def retrieve(args):
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    milliseconds = 0

    with ranking_logger.context('ranking.tsv',
                                also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset + query_idx, q, len(scores), len(pids),
                          scores[0], pids[0],
                          milliseconds / (qoffset + query_idx + 1), 'ms')

                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(
                        f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [
                    (score, pid, None)
                    for pid, score in itertools.islice(ranking, args.depth)
                ]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')
Code example #18
def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    training_sample = load_sample(slice_samples_paths,
                                  sample_fraction=sample_fraction)

    dim = training_sample.shape[-1]
    index = FaissIndex(dim, partitions)

    print_message("#> Training with the vectors...")

    index.train(training_sample)

    print_message("Done training!\n")

    return index
Code example #19
    def _load_collection(self, path):
        print_message("#> Loading collection...")

        collection = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                pid, passage, title, *_ = line.strip().split('\t')
                assert pid == 'id' or int(pid) == line_idx

                passage = title + ' | ' + passage
                collection.append(passage)

        return collection
Code example #20
File: load_model.py  Project: vjeronymo2/ColBERT
def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
Code example #21
File: faiss_index.py  Project: ryparmar/master-thesis
    def add(self, data):
        print_message(
            f"Add data with shape {data.shape} (offset = {self.offset})..")

        # if self.gpu.ngpu > 0 and self.offset == 0:
        #     print_message(f"Initialiting GPU index...\nngpu: {self.gpu.ngpu}\noffset: {self.offset}")
        #     self.gpu.adding_initialize(self.index)

        # if self.gpu.ngpu > 0:
        #     print_message(f"Adding index... {self.gpu.ngpu}")
        #     self.gpu.add(self.index, data, self.offset)
        # else:
        print_message(f"Adding index... {self.gpu.ngpu}")
        self.index.add(data)

        self.offset += data.shape[0]
Code example #22
File: loaders.py  Project: vjeronymo2/ColBERT
def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
Code example #23
File: index_part.py  Project: vjeronymo2/ColBERT
    def _load_parts(self, dim, verbose):
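        # NOTE: the 512 rows of headroom below appear to match the `+ 512`
        # over-read in batch_rank's tensor_endpos (see example #29), keeping
        # end-of-tensor slices in bounds.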
        tensor = torch.zeros(self.num_embeddings + 512, dim, dtype=torch.float16)

        if verbose:
            print_message("tensor.size() = ", tensor.size())

        offset = 0
        for idx, filename in enumerate(self.parts_paths):
            print_message("|> Loading", filename, "...", condition=verbose)

            endpos = offset + sum(self.parts_doclens[idx])
            part = load_index_part(filename, verbose=verbose)

            tensor[offset:endpos] = part
            offset = endpos

        return tensor
Code example #24
def main(args):
    qas = load_qas_(args.qas)
    collection = load_collection_(args.collection, retain_titles=True)
    rankings = load_ranking(args.ranking)
    parallel_pool = Pool(30)

    print_message('#> Tokenize the answers in the Q&As in parallel...')
    qas = list(parallel_pool.map(tokenize_all_answers, qas))

    qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
    assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

    print_message('#> Lookup passages from PIDs...')
    expanded_rankings = [(qid, pid, rank, collection[pid], qid2answers[qid])
                         for qid, pid, rank, *_ in rankings]

    print_message('#> Assign labels in parallel...')
    labeled_rankings = list(
        parallel_pool.map(assign_label_to_passage,
                          enumerate(expanded_rankings)))

    # Dump output.
    print_message("#> Dumping output to", args.output, "...")
    qid2rankings = groupby_first_item(labeled_rankings)

    num_judged_queries, num_ranked_queries = check_sizes(
        qid2answers, qid2rankings)

    # Evaluation metrics and depths.
    success, counts = compute_and_write_labels(args.output, qid2answers,
                                               qid2rankings)

    # Dump metrics.
    with open(args.output_metrics, 'w') as f:
        d = {
            'num_ranked_queries': num_ranked_queries,
            'num_judged_queries': num_judged_queries
        }

        extra = '__WARNING' if num_judged_queries != num_ranked_queries else ''
        d[f'success{extra}'] = {
            k: v / num_judged_queries
            for k, v in success.items()
        }
        d[f'counts{extra}'] = {
            k: v / num_judged_queries
            for k, v in counts.items()
        }
        d['arguments'] = get_metadata(args)

        f.write(format_metadata(d) + '\n')

    print('\n\n')
    print(args.output)
    print(args.output_metrics)
    print("#> Done\n")
Code example #25
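    # NOTE: presumably decorated with @contextlib.contextmanager upstream, since it
    # yields and is entered via `with ranking_logger.context(...)` in example #17.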
    def context(self, filename, also_save_annotations=False):
        assert self.filename is None
        assert self.also_save_annotations is None

        filename = os.path.join(self.directory, filename)
        self.filename, self.also_save_annotations = filename, also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(filename, 'w') as f:
            self.f = f
            with (open(filename + '.annotated', 'w')
                  if also_save_annotations else NullContextManager()) as g:
                self.g = g
                try:
                    yield self
                finally:
                    pass
Code example #26
    def _load_triples(self, path, rank, nranks):
        """
        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
        """
        print_message("#> Loading triples...")

        triples = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                if line_idx % nranks == rank:
                    qid, pos, neg = ujson.loads(line)
                    triples.append((qid, pos, neg))

        return triples
Code example #27
def load_sample(samples_paths, sample_fraction=None):
    sample = []

    for filename in samples_paths:
        print_message(f"#> Loading {filename} ...")
        part = load_index_part(filename)
        if sample_fraction:
            part = part[torch.randint(0,
                                      high=part.size(0),
                                      size=(int(
                                          part.size(0) * sample_fraction), ))]
        sample.append(part)

    sample = torch.cat(sample).float().numpy()

    print("#> Sample has shape", sample.shape)

    return sample
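
Note that `torch.randint` draws row indices with replacement, so the `sample_fraction` subset above can contain duplicate rows. If sampling without replacement were intended, a `randperm`-based variant (an alternative sketch, not the project's code) would be:

import torch

def subsample_without_replacement(part, sample_fraction):
    # Keep the first k rows of a random permutation, so no row repeats.
    k = int(part.size(0) * sample_fraction)
    return part[torch.randperm(part.size(0))[:k]]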
Code example #28
ファイル: dev_subsample.py プロジェクト: vjeronymo2/ColBERT
def main(args):
    print_message("#> Loading all..")
    qas = load_qas_(args.qas)
    rankings = load_ranking(args.ranking)
    qid2rankings = groupby_first_item(rankings)

    print_message("#> Subsampling all..")
    qas_sample = random.sample(qas, args.sample)

    with open(args.output, 'w') as f:
        for qid, *_ in qas_sample:
            for items in qid2rankings[qid]:
                items = [qid] + items
                line = '\t'.join(map(str, items)) + '\n'
                f.write(line)

    print('\n\n')
    print(args.output)
    print("#> Done.")
Code example #29
    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids,
                   sorted_pids):
        assert sorted_pids is True

        ######

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            range_start = range_start + (all_pids[range_start:] <
                                         pid_offset).sum()
            range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(
                f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}."
            )

            if len(pids) == 0:
                continue

            print_message(
                f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range."
            )

            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(
                f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message(
                        "#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[
                    offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                scores.extend(
                    self.rank(Q, batch_pids, views, shift=tensor_offset))
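
The two cursors above only work because `sorted_pids` is asserted: with `all_pids` sorted, each boolean sum counts how many of the remaining PIDs fall below the boundary, advancing the window monotonically. The same pointer arithmetic on toy values:

import numpy as np

all_pids = np.array([2, 5, 7, 12, 18, 25])  # must be sorted, as asserted above
range_start, range_end = 0, 0

for pid_offset in range(0, 30, 10):  # toy sub-ranges [0, 10), [10, 20), [20, 30)
    pid_endpos = pid_offset + 10
    range_start = range_start + (all_pids[range_start:] < pid_offset).sum()
    range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()
    print(pid_offset, all_pids[range_start:range_end])
# 0  [2 5 7]
# 10 [12 18]
# 20 [25]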
Code example #30
def main(args):
    qid_to_file_idx = {}

    for qrels_idx, qrels in enumerate(args.all_queries):
        with open(qrels) as f:
            for line in f:
                qid, *_ = line.strip().split('\t')
                qid = int(qid)

                assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx)
                qid_to_file_idx[qid] = qrels_idx

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    assert all(not os.path.exists(path) for path in all_outputs_paths)

    all_outputs = [open(path, 'w') for path in all_outputs_paths]

    with open(args.ranking) as f:
        print_message(f"#> Loading ranked lists from {f.name} ..")

        last_file_idx = -1

        for line in file_tqdm(f):
            qid, *_ = line.strip().split('\t')

            file_idx = qid_to_file_idx[int(qid)]

            if file_idx != last_file_idx:
                print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")

            last_file_idx = file_idx

            all_outputs[file_idx].write(line)

    print()

    for f in all_outputs:
        print(f.name)
        f.close()

    print("#> Done!")