Example #1
# Imports used by this snippet (module paths as in the ColBERT repository; adjust to your version):
import os
import math
import random

from colbert.utils.parser import Arguments
from colbert.utils.runs import Run
from colbert.indexing.faiss import index_faiss
from colbert.indexing.loaders import load_doclens


def main():
    random.seed(12345)

    parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    parser.add_index_use_input()

    parser.add_argument('--sample', dest='sample', default=None, type=float)
    parser.add_argument('--slices', dest='slices', default=1, type=int)

    args = parser.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(num_embeddings))
            print('\n\n')

        index_faiss(args)
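
When --partitions is omitted, the default rounds 8 · sqrt(num_embeddings) up to the next power of two. A standalone check of that arithmetic, with 1,000,000 embeddings as a purely illustrative input:

import math

num_embeddings = 1_000_000  # illustrative value, not taken from the snippet
partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
print(partitions)  # 8 * sqrt(1e6) = 8000, so the next power of two is 8192
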
Example #2
    def skip_to_batch(self, batch_idx, intended_batch_size):
        self._reset_triples()

        Run.warn(
            f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.'
        )

        # Consume and discard the lines belonging to the skipped batches.
        for _ in range(batch_idx * intended_batch_size):
            self.reader.readline()
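
This is the eager, file-backed reader: resuming means physically reading and discarding batch_idx * intended_batch_size triples so the file position lands where training left off. A self-contained sketch of the same idea (an in-memory buffer stands in for the triples file):

import io

def skip_lines(f, n):
    # Advance a file handle past n lines without keeping them.
    for _ in range(n):
        f.readline()

f = io.StringIO('\n'.join(f'line {i}' for i in range(10)))
skip_lines(f, 2 * 3)   # e.g., resume at batch 2 with a batch size of 3
print(f.readline())    # -> 'line 6'
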
Example #3
def main():
    random.seed(12345)

    parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()

    parser.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = parser.parse()

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)

        if args.collection or args.queries:
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)

        else:
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

        assert (not args.shortcircuit) or args.qrels, \
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) " \
            "can only be applied if qrels is provided."

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)
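
Note the both-or-neither guard: if either --collection or --queries is given, the assert insists on both. A tiny standalone illustration of the pattern (file names are made up):

for collection, queries in [(None, None), ('coll.tsv', 'q.tsv'), ('coll.tsv', None)]:
    if collection or queries:
        print(collection, queries, 'ok' if (collection and queries) else 'error: provide both')
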
Example #4
    def log(self, query_idx):
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added,
                       query_idx)

        for name, sums in [('MRR', self.mrr_sums), ('Success', self.success_sums),
                           ('Recall', self.recall_sums)]:
            for depth in sorted(sums):
                score = sums[depth] / (query_idx + 1.0)
                Run.log_metric(f"ranking/{name}.{depth}", score, query_idx)
Example #5
def main():
    random.seed(12345)

    parser = Arguments(
        description='End-to-end retrieval and ranking with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_ranking_input()
    parser.add_retrieval_input()

    parser.add_argument('--faiss_name',
                        dest='faiss_name',
                        default=None,
                        type=str)
    parser.add_argument('--faiss_depth',
                        dest='faiss_depth',
                        default=1024,
                        type=int)
    parser.add_argument('--part-range',
                        dest='part_range',
                        default=None,
                        type=str)
    parser.add_argument('--batch',
                        dest='batch',
                        default=False,
                        action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    args.depth = args.depth if args.depth > 0 else None

    if args.part_range:
        part_offset, part_endpos = map(int, args.part_range.split('..'))
        args.part_range = range(part_offset, part_endpos)

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)
        args.queries = load_queries(args.queries)

        args.index_path = os.path.join(args.index_root, args.index_name)

        if args.faiss_name is not None:
            args.faiss_index_path = os.path.join(args.index_path,
                                                 args.faiss_name)
        else:
            args.faiss_index_path = os.path.join(args.index_path,
                                                 get_faiss_index_name(args))

        if args.batch:
            batch_retrieve(args)
        else:
            retrieve(args)
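
--part-range takes a 'start..end' string and is parsed into a half-open Python range over index parts:

part_offset, part_endpos = map(int, '3..7'.split('..'))
print(list(range(part_offset, part_endpos)))  # -> [3, 4, 5, 6]; the endpoint is exclusive
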
Example #6
def main():
    random.seed(12345)

    parser = Arguments(
        description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    parser.add_argument('--chunksize',
                        dest='chunksize',
                        default=6.0,
                        required=False,
                        type=float)  # in GiBs

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert not os.path.exists(args.index_path), args.index_path

        distributed.barrier(args.rank)

        if args.rank < 1:
            create_directory(args.index_root)
            create_directory(args.index_path)

        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args,
                                    process_idx=process_idx,
                                    num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path,
                          "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)
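
Only the lead process (rank 0, or rank -1 in single-process mode, hence the rank < 1 test) touches the filesystem; the surrounding barriers keep the other ranks in step. The metadata write itself is just a JSON dump of the parsed arguments, e.g. (json stands in for ujson here, and the paths and fields are illustrative):

import json
from argparse import Namespace

args = Namespace(index_root='/indexes', index_name='msmarco.demo')  # illustrative
with open('/tmp/metadata.json', 'w') as output_metadata:
    json.dump(args.__dict__, output_metadata)
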
Example #7
def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
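
As the TODO notes, a mismatch only triggers a warning; the command-line value, not the checkpoint value, is what actually gets used. A condensed, self-contained version of the comparison (the checkpoint contents here are hypothetical):

from types import SimpleNamespace

checkpoint = {'arguments': {'doc_maxlen': 180, 'dim': 128}}  # hypothetical values
args = SimpleNamespace(doc_maxlen=300, dim=128)

for k in ['doc_maxlen', 'dim']:
    if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
        a, b = checkpoint['arguments'][k], getattr(args, k)
        print(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")
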
Example #8
def main():
    random.seed(12345)

    parser = Arguments(description='Re-ranking over a ColBERT index')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_index_use_input()

    parser.add_argument('--step', dest='step', default=1, type=int)
    parser.add_argument('--part-range',
                        dest='part_range',
                        default=None,
                        type=str)
    parser.add_argument('--log-scores',
                        dest='log_scores',
                        default=False,
                        action='store_true')
    parser.add_argument('--batch',
                        dest='batch',
                        default=False,
                        action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    if args.part_range:
        part_offset, part_endpos = map(int, args.part_range.split('..'))
        args.part_range = range(part_offset, part_endpos)

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)

        args.queries = load_queries(args.queries)
        args.qrels = load_qrels(args.qrels)
        args.topK_pids, args.qrels = load_topK_pids(args.topK,
                                                    qrels=args.qrels)

        args.index_path = os.path.join(args.index_root, args.index_name)

        if args.batch:
            batch_rerank(args)
        else:
            rerank(args)
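
This driver mirrors Example #5, but instead of retrieving candidates from a FAISS index it re-ranks an existing candidate list: load_topK_pids supplies the per-query pids, and rerank or batch_rerank scores them against the ColBERT index at args.index_path. The --part-range value uses the same 'start..end' scheme illustrated under Example #5.
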
Example #9
def main():
    parser = Arguments(description='Training ColBERT with <query, positive passage, negative passage> triples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    args = parser.parse()

    assert args.bsize % args.accumsteps == 0, \
        f"The batch size ({args.bsize}) must be divisible by the number of gradient accumulation steps ({args.accumsteps})."
    assert args.query_maxlen <= 512
    assert args.doc_maxlen <= 512

    args.lazy = args.collection is not None

    with Run.context(consider_failed_if_interrupted=False):
        train(args)
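
The divisibility check guarantees that every gradient-accumulation step processes a whole sub-batch. Worked through with illustrative numbers:

bsize, accumsteps = 32, 4   # illustrative values
assert bsize % accumsteps == 0
print(bsize // accumsteps)  # -> 8 examples per forward/backward pass
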
Example #10
    def parse(self):
        args = self.parser.parse_args()
        self.check_arguments(args)

        args.input_arguments = copy.deepcopy(args)

        args.nranks, args.distributed = distributed.init(args.rank)

        args.nthreads = int(
            max(os.cpu_count(), faiss.omp_get_max_threads()) * 0.8)
        args.nthreads = max(1, args.nthreads // args.nranks)

        if args.nranks > 1:
            print_message(
                f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                condition=(args.rank == 0))
            faiss.omp_set_num_threads(args.nthreads)

        Run.init(args.rank, args.root, args.experiment, args.run)
        Run._log_args(args)
        Run.info(args.input_arguments.__dict__, '\n')

        return args
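
The thread budget takes 80% of the larger of the CPU count and FAISS's own maximum, then splits it evenly across ranks, with a floor of one thread per process. With illustrative numbers:

cpu_count, faiss_max_threads, nranks = 16, 16, 4         # illustrative values
nthreads = int(max(cpu_count, faiss_max_threads) * 0.8)  # -> 12
nthreads = max(1, nthreads // nranks)                    # -> 3 threads per process
print(nthreads)
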
Example #11
# Imports used by this snippet (module paths as in the ColBERT repository; adjust to your version):
import time
import random
import numpy as np
import torch
import torch.nn as nn

from transformers import AdamW
from colbert.utils.runs import Run
from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.training.eager_batcher import EagerBatcher
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints


def train(args):
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)

    if args.distributed:
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize,
              "(per process) and args.accumsteps =", args.accumsteps)

    if args.lazy:
        reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank),
                             args.nranks)
    else:
        reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank),
                              args.nranks)

    if args.rank not in [-1, 0]:
        torch.distributed.barrier()

    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)

    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(
            f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!"
        )

        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        colbert.load_state_dict(checkpoint['model_state_dict'])

    if args.rank == 0:
        torch.distributed.barrier()

    colbert = colbert.to(DEVICE)
    colbert.train()

    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(
            colbert,
            device_ids=[args.rank],
            output_device=args.rank,
            find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()),
                      lr=args.lr,
                      eps=1e-8)
    optimizer.zero_grad()

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
    labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0

    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])

    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps),
                                     reader):
        this_batch_loss = 0.0

        for queries, passages in BatchSteps:
            with amp.context():
                scores = colbert(queries, passages).view(2, -1).permute(1, 0)
                loss = criterion(scores, labels[:scores.size(0)])
                loss = loss / args.accumsteps

            if args.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            train_loss += loss.item()
            this_batch_loss += loss.item()

        amp.step(colbert, optimizer)

        if args.rank < 1:
            avg_loss = train_loss / (batch_idx + 1)

            num_examples_seen = (batch_idx -
                                 start_batch_idx) * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            log_to_mlflow = (batch_idx % 20 == 0)
            Run.log_metric('train/avg_loss',
                           avg_loss,
                           step=batch_idx,
                           log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/batch_loss',
                           this_batch_loss,
                           step=batch_idx,
                           log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/examples',
                           num_examples_seen,
                           step=batch_idx,
                           log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/throughput',
                           num_examples_seen / elapsed,
                           step=batch_idx,
                           log_to_mlflow=log_to_mlflow)

            print_message(batch_idx, avg_loss)
            manage_checkpoints(args, colbert, optimizer, batch_idx + 1)
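
The reshaping in the inner loop is the heart of the loss: the batch of scores is laid out as all positive-passage scores followed by all negative-passage scores, so .view(2, -1).permute(1, 0) pairs them up per query, and CrossEntropyLoss with all-zero labels reduces to a pairwise softmax loss in which the positive (column 0) should win. A self-contained illustration with made-up scores:

import torch

scores = torch.tensor([9.0, 8.0, 7.0,     # positives for queries 0..2 (made up)
                       4.0, 5.0, 6.0])    # negatives for queries 0..2 (made up)
pairs = scores.view(2, -1).permute(1, 0)  # -> [[9., 4.], [8., 5.], [7., 6.]]
labels = torch.zeros(3, dtype=torch.long) # column 0, the positive, is "correct"
print(torch.nn.CrossEntropyLoss()(pairs, labels))
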
Example #12
    def skip_to_batch(self, batch_idx, intended_batch_size):
        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
        self.position = intended_batch_size * batch_idx
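
Unlike the eager reader in Example #2, which physically consumes lines from a file, this lazy variant resumes by bookkeeping alone: it records the absolute example offset and lets later reads start from there. A minimal sketch of the idea (not the ColBERT class itself):

class LazyReader:
    # Minimal sketch: resumption as bookkeeping, not I/O.
    def __init__(self, items):
        self.items, self.position = items, 0

    def skip_to_batch(self, batch_idx, intended_batch_size):
        self.position = intended_batch_size * batch_idx

reader = LazyReader(list(range(100)))
reader.skip_to_batch(3, 8)
print(reader.position)  # -> 24: the next batch starts at example 24
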