def main():
    random.seed(12345)

    parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    parser.add_index_use_input()

    parser.add_argument('--sample', dest='sample', default=None, type=float)
    parser.add_argument('--slices', dest='slices', default=1, type=int)

    args = parser.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(num_embeddings))
            print('\n\n')

        index_faiss(args)
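# Illustrative sketch (not part of the original script): the default --partitions
# heuristic above rounds 8 * sqrt(num_embeddings) up to the next power of two.
# Hypothetical example for a collection with 100M embeddings:
import math

num_embeddings = 100_000_000
partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))   # 8 * 10,000 = 80,000 -> 2^17
assert partitions == 131072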
def skip_to_batch(self, batch_idx, intended_batch_size):
    self._reset_triples()

    Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')

    # Consume and discard the lines belonging to the skipped batches.
    _ = [self.reader.readline() for _ in range(batch_idx * intended_batch_size)]

    return None
def main():
    random.seed(12345)

    parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()

    parser.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = parser.parse()

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)

        if args.collection or args.queries:
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)
        else:
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

        assert (not args.shortcircuit) or args.qrels, \
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) " \
            "can only be applied if qrels is provided."

        evaluate_recall(args.qrels, args.queries, args.topK_pids)
        evaluate(args)
def log(self, query_idx):
    assert query_idx >= self.max_query_idx
    self.max_query_idx = query_idx

    Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
    Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

    for depth in sorted(self.mrr_sums):
        score = self.mrr_sums[depth] / (query_idx + 1.0)
        Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

    for depth in sorted(self.success_sums):
        score = self.success_sums[depth] / (query_idx + 1.0)
        Run.log_metric("ranking/Success." + str(depth), score, query_idx)

    for depth in sorted(self.recall_sums):
        score = self.recall_sums[depth] / (query_idx + 1.0)
        Run.log_metric("ranking/Recall." + str(depth), score, query_idx)
def main():
    random.seed(12345)

    parser = Arguments(description='End-to-end retrieval and ranking with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_ranking_input()
    parser.add_retrieval_input()

    parser.add_argument('--faiss_name', dest='faiss_name', default=None, type=str)
    parser.add_argument('--faiss_depth', dest='faiss_depth', default=1024, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    args.depth = args.depth if args.depth > 0 else None

    if args.part_range:
        part_offset, part_endpos = map(int, args.part_range.split('..'))
        args.part_range = range(part_offset, part_endpos)

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.qrels = load_qrels(args.qrels)
        args.queries = load_queries(args.queries)

        args.index_path = os.path.join(args.index_root, args.index_name)

        if args.faiss_name is not None:
            args.faiss_index_path = os.path.join(args.index_path, args.faiss_name)
        else:
            args.faiss_index_path = os.path.join(args.index_path, get_faiss_index_name(args))

        if args.batch:
            batch_retrieve(args)
        else:
            retrieve(args)
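# Illustrative sketch (not part of the original script): --part-range is parsed above as an
# end-exclusive 'start..end' span over index parts; a hypothetical value of '0..3'
# selects parts 0, 1, and 2.
part_offset, part_endpos = map(int, '0..3'.split('..'))
assert range(part_offset, part_endpos) == range(0, 3)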
def main():
    random.seed(12345)

    parser = Arguments(description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    parser.add_argument('--chunksize', dest='chunksize', default=6.0, required=False, type=float)   # in GiBs

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)

        # The target index directory must not already exist.
        assert not os.path.exists(args.index_path), args.index_path

        distributed.barrier(args.rank)

        if args.rank < 1:
            create_directory(args.index_root)
            create_directory(args.index_path)

        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args, process_idx=process_idx, num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path, "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)
def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
def main():
    random.seed(12345)

    parser = Arguments(description='Re-ranking over a ColBERT index')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_reranking_input()
    parser.add_index_use_input()

    parser.add_argument('--step', dest='step', default=1, type=int)
    parser.add_argument('--part-range', dest='part_range', default=None, type=str)
    parser.add_argument('--log-scores', dest='log_scores', default=False, action='store_true')
    parser.add_argument('--batch', dest='batch', default=False, action='store_true')
    parser.add_argument('--depth', dest='depth', default=1000, type=int)

    args = parser.parse()

    if args.part_range:
        part_offset, part_endpos = map(int, args.part_range.split('..'))
        args.part_range = range(part_offset, part_endpos)

    with Run.context():
        args.colbert, args.checkpoint = load_colbert(args)
        args.queries = load_queries(args.queries)
        args.qrels = load_qrels(args.qrels)
        args.topK_pids, args.qrels = load_topK_pids(args.topK, qrels=args.qrels)

        args.index_path = os.path.join(args.index_root, args.index_name)

        if args.batch:
            batch_rerank(args)
        else:
            rerank(args)
def main():
    parser = Arguments(description='Training ColBERT with <query, positive passage, negative passage> triples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    args = parser.parse()

    assert args.bsize % args.accumsteps == 0, ((args.bsize, args.accumsteps),
                                               "The batch size must be divisible by the number of gradient accumulation steps.")
    assert args.query_maxlen <= 512
    assert args.doc_maxlen <= 512

    args.lazy = args.collection is not None

    with Run.context(consider_failed_if_interrupted=False):
        train(args)
def parse(self):
    args = self.parser.parse_args()
    self.check_arguments(args)

    args.input_arguments = copy.deepcopy(args)

    args.nranks, args.distributed = distributed.init(args.rank)

    # Reserve roughly 80% of the available threads for FAISS, split across processes.
    args.nthreads = int(max(os.cpu_count(), faiss.omp_get_max_threads()) * 0.8)
    args.nthreads = max(1, args.nthreads // args.nranks)

    if args.nranks > 1:
        print_message(f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                      condition=(args.rank == 0))

    faiss.omp_set_num_threads(args.nthreads)

    Run.init(args.rank, args.root, args.experiment, args.run)
    Run._log_args(args)
    Run.info(args.input_arguments.__dict__, '\n')

    return args
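# Illustrative sketch (not part of the original method): FAISS is given roughly 80% of the
# available threads, divided across processes. Hypothetical example with 32 CPU threads
# available and 4 processes (nranks = 4):
nthreads = int(32 * 0.8)           # 25 (int() truncates 25.6)
nthreads = max(1, nthreads // 4)   # 6 threads per process
assert nthreads == 6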
def train(args):
    # Fix all random seeds for reproducibility.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    if args.distributed:
        torch.cuda.manual_seed_all(12345)

    if args.distributed:
        assert args.bsize % args.nranks == 0, (args.bsize, args.nranks)
        assert args.accumsteps == 1
        args.bsize = args.bsize // args.nranks

        print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps)

    if args.lazy:
        reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)
    else:
        reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks)

    if args.rank not in [-1, 0]:
        torch.distributed.barrier()

    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)

    if args.checkpoint is not None:
        assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too."
        print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!")

        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        colbert.load_state_dict(checkpoint['model_state_dict'])

    if args.rank == 0:
        torch.distributed.barrier()

    colbert = colbert.to(DEVICE)
    colbert.train()

    if args.distributed:
        colbert = torch.nn.parallel.DistributedDataParallel(colbert,
                                                            device_ids=[args.rank],
                                                            output_device=args.rank,
                                                            find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8)
    optimizer.zero_grad()

    amp = MixedPrecisionManager(args.amp)
    criterion = nn.CrossEntropyLoss()
    labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = 0.0

    start_batch_idx = 0

    if args.resume:
        assert args.checkpoint is not None
        start_batch_idx = checkpoint['batch']

        reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])

    for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader):
        this_batch_loss = 0.0

        for queries, passages in BatchSteps:
            with amp.context():
                scores = colbert(queries, passages).view(2, -1).permute(1, 0)
                loss = criterion(scores, labels[:scores.size(0)])
                loss = loss / args.accumsteps

            if args.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            train_loss += loss.item()
            this_batch_loss += loss.item()

        amp.step(colbert, optimizer)

        # Only the first process logs metrics and saves checkpoints.
        if args.rank < 1:
            avg_loss = train_loss / (batch_idx + 1)

            num_examples_seen = (batch_idx - start_batch_idx) * args.bsize * args.nranks
            elapsed = float(time.time() - start_time)

            log_to_mlflow = (batch_idx % 20 == 0)
            Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/batch_loss', this_batch_loss, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=log_to_mlflow)
            Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=log_to_mlflow)

            print_message(batch_idx, avg_loss)
            manage_checkpoints(args, colbert, optimizer, batch_idx + 1)
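# Illustrative sketch (not part of the original training loop), assuming the model returns
# the scores of all positive passages followed by all negative passages. Under that assumption,
# view(2, -1).permute(1, 0) yields one (positive_score, negative_score) row per query, and
# CrossEntropyLoss with all-zero labels rewards the positive (column 0) for outscoring the negative.
import torch
import torch.nn as nn

raw_scores = torch.tensor([10.0, 9.0, 2.0, 8.0])     # hypothetical: [pos_1, pos_2, neg_1, neg_2]
scores = raw_scores.view(2, -1).permute(1, 0)         # [[10., 2.], [9., 8.]]
labels = torch.zeros(scores.size(0), dtype=torch.long)
loss = nn.CrossEntropyLoss()(scores, labels)          # small when positives outscore negatives
print(loss.item())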
def skip_to_batch(self, batch_idx, intended_batch_size):
    Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')

    # Only advance the reading position; no lines are consumed here.
    self.position = intended_batch_size * batch_idx