Ejemplo n.º 1
0
def load_model(args, do_print=True):
    """Build a ColKBERT model, restore its checkpoint and switch it to eval mode.

    The model is constructed from the `bert-base-uncased` BertConfig together
    with the ColBERT settings read from `args` (query_maxlen, doc_maxlen, dim,
    similarity, mask_punctuation), then moved to `DEVICE`.

    Args:
        args: Namespace-like object carrying the model settings and
            `checkpoint`, which is forwarded to `load_checkpoint`.
        do_print: Forwarded to `print_message` (as `condition`) and to
            `load_checkpoint`.

    Returns:
        A `(model, checkpoint)` pair, where `checkpoint` is whatever
        `load_checkpoint` returned.
    """
    config = BertConfig().from_pretrained('bert-base-uncased')

    model_kwargs = dict(query_maxlen=args.query_maxlen,
                        doc_maxlen=args.doc_maxlen,
                        dim=args.dim,
                        similarity_metric=args.similarity,
                        mask_punctuation=args.mask_punctuation)
    model = ColKBERT(config, **model_kwargs).to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, model, do_print=do_print)

    model.eval()
    return model, checkpoint
Ejemplo n.º 2
0
    def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True):
        """Look up the nearest embedding IDs for every query embedding.

        Args:
            faiss_depth: Number of neighbours requested from
                `self.faiss_index.search` for each query embedding.
            Q: 3-D tensor unpacked as (num_queries, embeddings_per_query, dim).
            verbose: Forwarded to `print_message` as its `condition`.

        Returns:
            A tensor viewed as (num_queries, embeddings_per_query * k), where
            k is the second dimension of the concatenated search results.
            IDs may repeat within a row.
        """
        num_queries, embeddings_per_query, dim = Q.size()

        # faiss searches a 2-D matrix, so collapse the first two dimensions.
        flat_queries = Q.view(num_queries * embeddings_per_query, dim)
        flat_queries = flat_queries.cpu().contiguous()

        print_message(
            "#> Search in batches with faiss. \t\t",
            f"Q.size() = {Q.size()}, Q_faiss.size() = {flat_queries.size()}",
            condition=verbose)

        total_rows = flat_queries.size(0)
        rows_per_search = embeddings_per_query * 5000

        id_chunks = []
        for start in range(0, total_rows, rows_per_search):
            stop = min(start + rows_per_search, total_rows)

            print_message("#> Searching from {} to {}...".format(start, stop),
                          condition=verbose)

            chunk = flat_queries[start:stop].float().numpy()
            _, chunk_ids = self.faiss_index.search(chunk, faiss_depth)
            id_chunks.append(torch.from_numpy(chunk_ids))

        all_ids = torch.cat(id_chunks)

        # One row per query: (number of queries, non-unique embedding IDs).
        return all_ids.view(num_queries,
                            embeddings_per_query * all_ids.size(1))
Ejemplo n.º 3
0
    def init(self, rank, root, experiment, name):
        """Resolve the run directory, prepare it, and bind the logger helpers.

        Args:
            rank: Process rank; only ranks below 1 touch the filesystem or
                prompt the user.
            root: Experiments root directory (stored as an absolute path).
            experiment: Experiment name; must not contain '/'.
            name: Run name; must not contain '/'.

        Raises:
            AssertionError: If `experiment` or `name` contains '/', or if the
                run directory already exists and the user does not answer
                'yes' to the overwrite prompt.
        """
        assert '/' not in experiment, experiment
        assert '/' not in name, name

        self.experiments_root = os.path.abspath(root)
        self.experiment, self.name = experiment, name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if not os.path.exists(self.path):
                create_directory(self.path)
            else:
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.

                answer = input().strip()
                if answer != 'yes':
                    # The path exists on this branch, so this aborts the run.
                    assert not os.path.exists(self.path), self.path

        distributed.barrier(rank)

        self._logger = Logger(rank, self)

        # Re-export the logger's methods directly on this object.
        for attr in ('_log_args', 'warn', 'info', 'info_all', 'log_metric',
                     'log_new_artifact'):
            setattr(self, attr, getattr(self._logger, attr))
Ejemplo n.º 4
0
    def add(self, index, data, offset):
        """Add `data` to the GPU index in batches, flushing to `index`.

        Rows are added to `self.gpu_index` with consecutive IDs starting at
        `offset`. `self._flush_to_cpu` is called whenever the GPU index grows
        past `self.max_add` (when that limit is positive), and once more at
        the end if anything is still held on the GPU.

        Args:
            index: Destination index; its `ntotal` must equal `offset + len`
                of `data` once everything has been flushed.
            data: Array-like with a `.shape`; rows are sliced along axis 0.
            offset: ID assigned to the first row of `data`.

        Raises:
            AssertionError: If no GPU is configured or the final `ntotal`
                does not match.
        """
        assert self.ngpu > 0
        print_message('Adding...')

        started = time.time()
        total = data.shape[0]

        for lo in range(0, total, self.add_batch_size):
            hi = min(lo + self.add_batch_size, total)
            batch_ids = np.arange(offset + lo, offset + hi)

            self.gpu_index.add_with_ids(data[lo:hi], batch_ids)

            over_limit = self.max_add > 0 and self.gpu_index.ntotal > self.max_add
            if over_limit:
                self._flush_to_cpu(index, total, offset)

            # Carriage return keeps the progress counter on a single line.
            print('\r%d/%d (%.3f s)  ' % (lo, total, time.time() - started), end=' ')
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, total, offset)

        expected_total = offset + total
        assert index.ntotal == expected_total, (index.ntotal, expected_total,
                                                offset, total)
        print(f"add(.) time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" %
              (time.time() - started))
Ejemplo n.º 5
0
def main(args):
    """Merge several ranking files and rewrite each query's list by score.

    Every input line is `qid<TAB>pid<TAB>rank<TAB>score`. All entries for a
    query are pooled across the files in `args.input`, sorted in descending
    (score, original rank, pid) order, renumbered from 1, and written to
    `args.output` as `qid<TAB>pid<TAB>rank<TAB>score`. When `args.depth` is
    positive, only the first `args.depth` entries per query are written.
    """
    merged = defaultdict(list)

    for path in args.input:
        print_message(f"#> Loading the rankings in {path} ..")

        with open(path) as f:
            for line in file_tqdm(f):
                qid, pid, rank, score = line.strip().split('\t')
                qid, pid, rank = int(qid), int(pid), int(rank)

                # Score first so that plain tuple sorting orders by score.
                merged[qid].append((float(score), rank, pid))

    with open(args.output, 'w') as f:
        print_message(f"#> Writing the output rankings to {args.output} ..")

        for qid in tqdm.tqdm(merged):
            best_first = sorted(merged[qid], reverse=True)

            # New ranks are 1-indexed.
            for new_rank, (score, _original_rank, pid) in enumerate(best_first, start=1):
                if args.depth > 0 and new_rank > args.depth:
                    break

                fields = [qid, pid, new_rank, score]
                f.write('\t'.join(map(str, fields)) + '\n')
Ejemplo n.º 6
0
def main(args):
    """Run 4-hop Baleen search over the HoVer dev questions and dump results.

    All data and checkpoint paths are resolved relative to `args.datadir`.
    For each query, the first two items returned by `baleen.search` (facts
    and the bag of pids) are stored under the query id, and the resulting
    mapping is written as JSON to 'output.json' through `Run().open`.
    """
    print_message("#> Starting...")

    datadir = args.datadir
    collectionX_path = os.path.join(datadir, 'wiki.abstracts.2017/collection.json')
    queries_path = os.path.join(datadir, 'hover/dev/questions.tsv')
    qas_path = os.path.join(datadir, 'hover/dev/qas.json')

    checkpointL1 = os.path.join(datadir, 'hover.checkpoints-v1.0/condenserL1-v1.0.dnn')
    checkpointL2 = os.path.join(datadir, 'hover.checkpoints-v1.0/condenserL2-v1.0.dnn')

    with Run().context(RunConfig(root=args.root)):
        searcher = HopSearcher(index=args.index)
        condenser = Condenser(checkpointL1=checkpointL1,
                              checkpointL2=checkpointL2,
                              collectionX_path=collectionX_path,
                              deviceL1='cuda:0',
                              deviceL2='cuda:0')

        baleen = Baleen(collectionX_path, searcher, condenser)
        baleen.searcher.configure(nprobe=2, ncandidates=8192)

    queries = Queries(path=queries_path)
    outputs = {}

    for qid, query in tqdm.tqdm(queries.items()):
        result = baleen.search(query, num_hops=4)
        facts, pids_bag, _ = result
        outputs[qid] = (facts, pids_bag)

    serialized = ujson.dumps(outputs) + '\n'
    with Run().open('output.json', 'w') as f:
        f.write(serialized)
Ejemplo n.º 7
0
    def __init__(self, tensor, doclens):
        """Store the embedding tensor and precompute per-document bookkeeping.

        Args:
            tensor: Embedding tensor; its last dimension is recorded as
                `self.dim`.
            doclens: Iterable of per-document lengths. Kept as a tensor in
                `self.doclens`, with prefix sums (leading 0) in
                `self.doclens_pfxsum`.

        `self.strides` holds the sorted, de-duplicated 25th/50th/75th
        percentiles of the document lengths plus the maximum length.
        """
        self.tensor = tensor
        self.maxsim_dtype = torch.float32

        prefix_sums = [0, *accumulate(doclens)]
        self.doclens = torch.tensor(doclens)
        self.doclens_pfxsum = torch.tensor(prefix_sums)

        self.dim = self.tensor.size(-1)

        quartiles = [torch_percentile(self.doclens, p) for p in (25, 50, 75)]
        longest = self.doclens.max().item()
        self.strides = sorted({*quartiles, longest})

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)

        if DEVICE == 'cuda':
            device = 'cuda:0'
        else:
            device = 'cpu'
        print_message(f"device: {device}")

        # A set, so 'cpu' is not listed twice when running without CUDA.
        buffer_devices = {'cpu', device}
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype,
                                            buffer_devices)
Ejemplo n.º 8
0
def prepare_ranges(index_path, dim, step, part_range):
    """Start a background thread that loads index parts `step` at a time.

    Args:
        index_path: Location of the index, passed to `get_parts` and
            `IndexPart`.
        dim: Embedding dimension forwarded to `IndexPart`.
        step: Number of parts covered by each loaded `IndexPart`.
        part_range: Optional object with `.start` / `.stop` used to slice the
            list of (offset, endpos) windows; `None` keeps all of them.

    Returns:
        `(positions, loaded_parts, thread)`: the part windows, the bounded
        queue (maxsize 2) that the loader fills in order, and the started
        thread.
    """
    print_message(
        "#> Launching a separate thread to load index parts asynchronously.")
    parts, _, _ = get_parts(index_path)

    positions = []
    for start in range(0, len(parts), step):
        positions.append((start, start + step))

    if part_range is not None:
        positions = positions[part_range.start:part_range.stop]

    # put() blocks once two parts are waiting to be consumed.
    loaded_parts = queue.Queue(maxsize=2)

    def _loader_thread(index_path, dim, positions):
        for start, stop in positions:
            loaded = IndexPart(index_path,
                               dim=dim,
                               part_range=range(start, stop),
                               verbose=True)
            loaded_parts.put(loaded, block=True)

    loader_args = (index_path, dim, positions)
    thread = threading.Thread(target=_loader_thread, args=loader_args)
    thread.start()

    return positions, loaded_parts, thread
Ejemplo n.º 9
0
    def _log_exception(self, etype, value, tb):
        """Print a formatted traceback and save it as 'exception.txt'.

        Does nothing unless `self.is_main` is truthy. The arguments are the
        (type, value, traceback) triple accepted by
        `traceback.format_exception`.
        """
        if self.is_main:
            formatted = traceback.format_exception(etype, value, tb)
            trace = ''.join(formatted) + '\n'
            print_message(trace, '\n\n')

            artifact_path = os.path.join(self.logs_path, 'exception.txt')
            self.log_new_artifact(artifact_path, trace)
Ejemplo n.º 10
0
def load_qas_(path):
    """Read a JSON-lines QA file into a list of (qid, question, answers).

    Each line of `path` is parsed with `ujson.loads` and must provide the
    keys 'qid', 'question' and 'answers'.
    """
    print_message("#> Loading the reference QAs from", path)

    with open(path) as f:
        records = (ujson.loads(line) for line in f)
        return [(qa['qid'], qa['question'], qa['answers']) for qa in records]
Ejemplo n.º 11
0
    def train(self, train_data):
        """Train `self.index` on `train_data` and print the elapsed seconds.

        When at least one GPU is configured, `self.gpu.training_initialize`
        is called before training and `self.gpu.training_finalize` after.
        """
        print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...")

        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        started = time.time()
        self.index.train(train_data)
        elapsed = time.time() - started
        print(elapsed)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()
Ejemplo n.º 12
0
    def _load_queries(self, path):
        """Read a two-column TSV of `qid<TAB>query` into a dict.

        Returns:
            Mapping from integer qid to query text. A line that does not
            split into exactly two tab-separated fields raises ValueError.
        """
        print_message("#> Loading queries...")

        queries = {}

        with open(path) as f:
            for raw_line in f:
                qid_text, query = raw_line.strip().split('\t')
                queries[int(qid_text)] = query

        return queries
Ejemplo n.º 13
0
    def _prepare_gpu_resources(self):
        """Create one `faiss.StandardGpuResources` object per configured GPU.

        When `self.tempmem` is non-negative it is applied to each resource
        through `setTempMemory`.

        Returns:
            List of `self.ngpu` resource objects.
        """
        print_message(f"Preparing resources for {self.ngpu} GPUs.")

        def _new_resource():
            resource = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                resource.setTempMemory(self.tempmem)
            return resource

        return [_new_resource() for _ in range(self.ngpu)]
Ejemplo n.º 14
0
    def add(self, data):
        """Append `data` to the index and advance `self.offset` by its rows.

        With GPUs configured, the GPU helper is initialised on the first call
        (offset 0) and the rows are added through it; otherwise they go
        straight into `self.index`.
        """
        print_message(f"Add data with shape {data.shape} (offset = {self.offset})..")

        is_first_chunk = self.offset == 0
        if is_first_chunk and self.gpu.ngpu > 0:
            self.gpu.adding_initialize(self.index)

        if self.gpu.ngpu <= 0:
            self.index.add(data)
        else:
            self.gpu.add(self.index, data, self.offset)

        self.offset += data.shape[0]
Ejemplo n.º 15
0
def main(args):
    """Index the 2017 Wikipedia abstracts collection with the FLIPR checkpoint.

    Paths are resolved under `args.datadir`; the index is built inside a
    `Run` context rooted at `args.root`, named `args.index`, using
    `doc_maxlen=256` and `nbits=args.nbits`.
    """
    print_message("#> Starting...")

    datadir = args.datadir
    collection_path = os.path.join(datadir,
                                   'wiki.abstracts.2017/collection.tsv')
    checkpoint_path = os.path.join(datadir,
                                   'hover.checkpoints-v1.0/flipr-v1.0.dnn')

    with Run().context(RunConfig(root=args.root)):
        indexer = Indexer(checkpoint_path,
                          config=ColBERTConfig(doc_maxlen=256,
                                               nbits=args.nbits))
        indexer.index(name=args.index, collection=collection_path)
Ejemplo n.º 16
0
def main():
    """Encode a collection into a new index directory and save its metadata.

    Parses the command-line arguments, refuses to run if the target index
    path already exists, lets the process with rank < 1 create the
    directories, runs `CollectionEncoder.encode()` on every process, and
    finally has rank < 1 write `metadata.json`. `distributed.barrier` calls
    separate these phases.
    """
    random.seed(12345)

    parser = Arguments(
        description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    # The chunk size is expressed in GiBs.
    parser.add_argument('--chunksize',
                        dest='chunksize',
                        default=6.0,
                        required=False,
                        type=float)

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)

        # Never write into an existing index.
        assert not os.path.exists(args.index_path), args.index_path

        # Every process must finish the existence check before the
        # directories are created below.
        distributed.barrier(args.rank)

        if args.rank < 1:
            for directory in (args.index_root, args.index_path):
                create_directory(directory)

        distributed.barrier(args.rank)

        encoder = CollectionEncoder(args,
                                    process_idx=max(0, args.rank),
                                    num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path,
                          "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as metadata_file:
                ujson.dump(args.input_arguments.__dict__, metadata_file)

        distributed.barrier(args.rank)
Ejemplo n.º 17
0
def retrieve(args):
    """Rank every query in `args.queries` one at a time and log the results.

    Queries are processed in groups of 100. Each query is encoded and ranked
    on its own, with CUDA synchronisation around the timed region, and the
    running average latency is printed whenever the query got any results.
    The top `args.depth` (pid, score) pairs per query are written to
    'ranking.tsv' through a `RankingLogger`.
    """
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    elapsed_ms = 0

    with ranking_logger.context('ranking.tsv',
                                also_save_annotations=False) as rlogger:
        queries = args.queries
        ordered_qids = list(queries.keys())

        for qoffset, qbatch in batch(ordered_qids, 100, provide_offset=True):
            texts = [queries[qid] for qid in qbatch]

            rankings = []

            # `seen` is the global, zero-based position of the query.
            for seen, q in enumerate(texts, start=qoffset):
                torch.cuda.synchronize('cuda:0')
                started = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                elapsed_ms += (time.time() - started) * 1000.0

                if len(pids):
                    print(seen, q, len(scores), len(pids),
                          scores[0], pids[0],
                          elapsed_ms / (seen + 1), 'ms')

                # Kept lazy; it is consumed once when logged below.
                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings),
                                                       start=qoffset):
                if query_idx % 100 == 0:
                    print_message(
                        f"#> Logging query #{query_idx} (qid {qid}) now...")

                top_entries = itertools.islice(ranking, args.depth)
                ranking = [(score, pid, None) for pid, score in top_entries]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')
Ejemplo n.º 18
0
def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    """Create a `FaissIndex` and train it on a sample of the saved embeddings.

    Args:
        slice_samples_paths: Paths handed to `load_sample`.
        partitions: Second argument to the `FaissIndex` constructor.
        sample_fraction: Optional sub-sampling fraction for `load_sample`.

    Returns:
        The trained `FaissIndex`, whose dimension is the last axis of the
        loaded sample.
    """
    training_sample = load_sample(slice_samples_paths,
                                  sample_fraction=sample_fraction)

    index = FaissIndex(training_sample.shape[-1], partitions)

    print_message("#> Training with the vectors...")
    index.train(training_sample)
    print_message("Done training!\n")

    return index
Ejemplo n.º 19
0
    def _load_collection(self, path):
        """Read a TSV collection into a list of 'title | passage' strings.

        Each line supplies at least `pid`, `passage` and `title` columns, in
        that order; extra columns are ignored. The pid must equal the
        zero-based line index, except that a pid of 'id' is accepted on any
        line (this lets a header row through).

        Returns:
            List of strings indexed by line position.
        """
        print_message("#> Loading collection...")

        collection = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                columns = line.strip().split('\t')
                pid, passage, title, *_ = columns
                assert pid == 'id' or int(pid) == line_idx

                collection.append(' | '.join((title, passage)))

        return collection
Ejemplo n.º 20
0
def load_model(args, do_print=True):
    """Build a multilingual ColBERT model, load its checkpoint, set eval mode.

    The model starts from 'bert-base-multilingual-uncased' with the ColBERT
    settings read from `args`, and is moved to `DEVICE` before the checkpoint
    at `args.checkpoint` is loaded.

    Returns:
        A `(model, checkpoint)` pair, where `checkpoint` is whatever
        `load_checkpoint` returned.
    """
    model_kwargs = dict(query_maxlen=args.query_maxlen,
                        doc_maxlen=args.doc_maxlen,
                        dim=args.dim,
                        similarity_metric=args.similarity,
                        mask_punctuation=args.mask_punctuation)
    model = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                    **model_kwargs)
    model = model.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, model, do_print=do_print)

    model.eval()
    return model, checkpoint
Ejemplo n.º 21
0
    def add(self, data):
        """Append `data` to `self.index` and advance `self.offset` by its rows.

        Only the direct `self.index.add` path is taken here; no GPU helper is
        called, although `self.gpu.ngpu` is still reported in the log line.
        """
        shape, offset = data.shape, self.offset
        print_message(f"Add data with shape {shape} (offset = {offset})..")

        print_message(f"Adding index... {self.gpu.ngpu}")
        self.index.add(data)

        self.offset += data.shape[0]
Ejemplo n.º 22
0
def load_queries(queries_path):
    """Load `qid<TAB>query[...]` lines into an OrderedDict keyed by int qid.

    Columns after the second are ignored. Insertion order follows the file.

    Raises:
        AssertionError: If the same qid appears more than once.
    """
    print_message("#> Loading the queries from", queries_path, "...")

    queries = OrderedDict()

    with open(queries_path) as f:
        for raw_line in f:
            columns = raw_line.strip().split('\t')
            qid_text, query, *_ = columns
            qid = int(qid_text)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
Ejemplo n.º 23
0
    def _load_parts(self, dim, verbose):
        """Load every index part into one contiguous float16 tensor.

        The tensor has `self.num_embeddings + 512` rows and `dim` columns;
        the 512 trailing rows stay zero. Part `idx` fills as many rows as the
        sum of `self.parts_doclens[idx]`, directly after the previous part.

        Args:
            dim: Embedding dimension (number of columns).
            verbose: Enables progress messages and is forwarded to
                `load_index_part`.

        Returns:
            The filled tensor.
        """
        tensor = torch.zeros(self.num_embeddings + 512, dim, dtype=torch.float16)

        if verbose:
            print_message("tensor.size() = ", tensor.size())

        cursor = 0
        for idx, filename in enumerate(self.parts_paths):
            print_message("|> Loading", filename, "...", condition=verbose)

            part_end = cursor + sum(self.parts_doclens[idx])
            tensor[cursor:part_end] = load_index_part(filename, verbose=verbose)
            cursor = part_end

        return tensor
Ejemplo n.º 24
0
def main(args):
    """Label ranked passages against reference answers and write metrics.

    Steps:
      1. Load the QAs, the collection (with titles) and the ranking.
      2. Tokenize all answers in parallel.
      3. Attach the passage text and tokenized answers to every ranked
         (qid, pid, rank) entry.
      4. Assign a label to every entry in parallel.
      5. Write the labels to `args.output` and the summary metrics to
         `args.output_metrics`.

    The worker pool is used as a context manager so that its processes are
    released once both parallel stages are done, even if one of them raises.

    NOTE(review): the metrics divide by `num_judged_queries`; a run with zero
    judged queries raises ZeroDivisionError.
    """
    qas = load_qas_(args.qas)
    collection = load_collection_(args.collection, retain_titles=True)
    rankings = load_ranking(args.ranking)

    # Both map() calls return fully materialized lists before the pool is
    # torn down on exit from the `with` block.
    with Pool(30) as parallel_pool:
        print_message('#> Tokenize the answers in the Q&As in parallel...')
        qas = list(parallel_pool.map(tokenize_all_answers, qas))

        qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
        # Equal sizes mean no qid was repeated in the QA file.
        assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

        print_message('#> Lookup passages from PIDs...')
        expanded_rankings = [(qid, pid, rank, collection[pid], qid2answers[qid])
                             for qid, pid, rank, *_ in rankings]

        print_message('#> Assign labels in parallel...')
        labeled_rankings = list(
            parallel_pool.map(assign_label_to_passage,
                              enumerate(expanded_rankings)))

    # Dump output.
    print_message("#> Dumping output to", args.output, "...")
    qid2rankings = groupby_first_item(labeled_rankings)

    num_judged_queries, num_ranked_queries = check_sizes(
        qid2answers, qid2rankings)

    # Evaluation metrics and depths.
    success, counts = compute_and_write_labels(args.output, qid2answers,
                                               qid2rankings)

    # Dump metrics.
    with open(args.output_metrics, 'w') as f:
        d = {
            'num_ranked_queries': num_ranked_queries,
            'num_judged_queries': num_judged_queries
        }

        # Flag the metric keys when the judged and ranked query sets differ
        # in size.
        extra = '__WARNING' if num_judged_queries != num_ranked_queries else ''
        d[f'success{extra}'] = {
            k: v / num_judged_queries
            for k, v in success.items()
        }
        d[f'counts{extra}'] = {
            k: v / num_judged_queries
            for k, v in counts.items()
        }
        d['arguments'] = get_metadata(args)

        f.write(format_metadata(d) + '\n')

    print('\n\n')
    print(args.output)
    print(args.output_metrics)
    print("#> Done\n")
Ejemplo n.º 25
0
    def context(self, filename, also_save_annotations=False):
        """Open the ranking output file(s) and yield this logger.

        The ranking file is created inside `self.directory` and exposed as
        `self.f`. When `also_save_annotations` is true, a sibling file with
        the '.annotated' suffix is opened too and exposed as `self.g`;
        otherwise `self.g` is whatever `NullContextManager()` yields.

        Raises:
            AssertionError: If `self.filename` or
                `self.also_save_annotations` is already set.
        """
        assert self.filename is None
        assert self.also_save_annotations is None

        full_path = os.path.join(self.directory, filename)
        self.filename = full_path
        self.also_save_annotations = also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(full_path, 'w') as ranking_file:
            self.f = ranking_file

            # The annotations file is opened only after the ranking file.
            if also_save_annotations:
                annotations_cm = open(full_path + '.annotated', 'w')
            else:
                annotations_cm = NullContextManager()

            with annotations_cm as annotations_file:
                self.g = annotations_file
                yield self
Ejemplo n.º 26
0
    def _load_triples(self, path, rank, nranks):
        """Load this rank's share of the (qid, pos, neg) training triples.

        Line `i` of `path` is kept when `i % nranks == rank`; each kept line
        is parsed with `ujson.loads` and must hold exactly three items.

        NOTE: For distributed sampling this is not the same as perfectly
        uniform sampling, since every subset is represented in every batch.
        That is not a concern here: the file is pre-shuffled, so the split
        across nodes is random, and the data is never passed over twice, so
        no triple is ever repeated.
        """
        print_message("#> Loading triples...")

        triples = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                if line_idx % nranks != rank:
                    continue

                qid, pos, neg = ujson.loads(line)
                triples.append((qid, pos, neg))

        return triples
Ejemplo n.º 27
0
def load_sample(samples_paths, sample_fraction=None):
    """Load index parts, optionally sub-sample them, and stack as float32.

    Args:
        samples_paths: Paths passed one by one to `load_index_part`.
        sample_fraction: When truthy, each part is replaced by
            `int(rows * sample_fraction)` rows whose indices are drawn with
            `torch.randint`, i.e. with replacement.

    Returns:
        A NumPy array holding the concatenation of all (sampled) parts.
    """
    parts = []

    for filename in samples_paths:
        print_message(f"#> Loading {filename} ...")
        part = load_index_part(filename)

        if sample_fraction:
            num_rows = part.size(0)
            num_kept = int(num_rows * sample_fraction)
            picked = torch.randint(0, high=num_rows, size=(num_kept, ))
            part = part[picked]

        parts.append(part)

    sample = torch.cat(parts).float().numpy()

    print("#> Sample has shape", sample.shape)

    return sample
Ejemplo n.º 28
0
def main(args):
    """Write the rankings of a random subset of queries to `args.output`.

    `args.sample` QAs are drawn with `random.sample`; for each sampled qid,
    every ranking entry is written as a tab-separated line prefixed by the
    qid.
    """
    print_message("#> Loading all..")
    qas = load_qas_(args.qas)
    qid2rankings = groupby_first_item(load_ranking(args.ranking))

    print_message("#> Subsampling all..")
    qas_sample = random.sample(qas, args.sample)

    with open(args.output, 'w') as f:
        for qid, *_ in qas_sample:
            for items in qid2rankings[qid]:
                row = [qid] + items
                f.write('\t'.join(map(str, row)) + '\n')

    print('\n\n')
    print(args.output)
    print("#> Done.")
Ejemplo n.º 29
0
    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids,
                   sorted_pids):
        """Score query--passage pairs, sweeping the collection in sub-ranges.

        The passages are walked in windows of 50,000 pids. For each window,
        the matching slice of `all_pids` / `all_query_indexes` is located,
        the corresponding slice of `self.tensor` is moved to `DEVICE`, and
        the pairs are scored in batches of `BSIZE` through `self.rank`.

        Args:
            all_query_embeddings: Indexable by a batch of query indexes to
                obtain the query embeddings for that batch.
            all_query_indexes: Query index of each pair, aligned with
                `all_pids`.
            all_pids: Passage id of each pair. The window search below relies
                on these being in ascending order -- presumably a tensor,
                since comparisons are reduced with `.sum()`; verify against
                the caller.
            sorted_pids: Must be exactly `True`.

        Returns:
            List with the scores of all pairs, in the order the pairs are
            processed. Previously the accumulated list was never returned,
            so callers received `None`.

        Raises:
            AssertionError: If `sorted_pids` is not `True`.
        """
        assert sorted_pids is True

        ######

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            # Advance both cursors past every pid below the window bounds;
            # only the not-yet-consumed tail of all_pids is scanned.
            range_start = range_start + (all_pids[range_start:] <
                                         pid_offset).sum()
            range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(
                f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}."
            )

            if len(pids) == 0:
                continue

            print_message(
                f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range."
            )

            # Rows of self.tensor covered by this window of passages.
            # NOTE(review): the extra 512 rows look like padding past the
            # last document -- confirm against how self.tensor is built.
            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(
                f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message(
                        "#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[
                    offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                # `shift` tells rank() where this slice starts in the full
                # tensor.
                scores.extend(
                    self.rank(Q, batch_pids, views, shift=tensor_offset))

        return scores
Ejemplo n.º 30
0
def main(args):
    """Split one ranking file into one output file per query file.

    Every qid found in `args.all_queries[i]` is mapped to file index `i`.
    Each line of `args.ranking` is then copied verbatim to
    `{args.ranking}.{i}` for the file index of its qid.

    The output files are managed by an `ExitStack`, so all of them are closed
    even if opening a later file or processing a line raises.

    Raises:
        AssertionError: If a qid appears in more than one query file, or if
            any output path already exists.
        KeyError: If the ranking contains a qid absent from every query file.
    """
    # Local import: this keeps the block self-contained.
    from contextlib import ExitStack

    qid_to_file_idx = {}

    for qrels_idx, qrels in enumerate(args.all_queries):
        with open(qrels) as f:
            for line in f:
                qid, *_ = line.strip().split('\t')
                qid = int(qid)

                # A qid may repeat within one file but not across files.
                assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx)
                qid_to_file_idx[qid] = qrels_idx

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    assert all(not os.path.exists(path) for path in all_outputs_paths)

    with ExitStack() as stack:
        all_outputs = [stack.enter_context(open(path, 'w'))
                       for path in all_outputs_paths]

        with open(args.ranking) as f:
            print_message(f"#> Loading ranked lists from {f.name} ..")

            last_file_idx = -1

            for line in file_tqdm(f):
                qid, *_ = line.strip().split('\t')

                file_idx = qid_to_file_idx[int(qid)]

                if file_idx != last_file_idx:
                    print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")

                last_file_idx = file_idx

                all_outputs[file_idx].write(line)

        print()

        for output_file in all_outputs:
            print(output_file.name)

    print("#> Done!")