def evaluate(args, model):
    """ Encode the dev/test queries into a memmap of embeddings """
    dev_dataset = SequenceDataset(
        TextTokenIdsCache(args.preprocess_dir, f"{args.mode}-query"),
        args.max_seq_length)
    collate_fn = get_collate_function(args.max_seq_length)
    batch_size = args.pergpu_eval_batch_size
    if args.n_gpu > 1:
        batch_size *= args.n_gpu
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=batch_size,
                                collate_fn=collate_fn)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    qembedding_memmap = np.memmap(args.qmemmap_path,
                                  dtype="float32",
                                  shape=(len(dev_dataset), 768),
                                  mode="w+")
    with torch.no_grad():
        for step, (batch, qoffsets) in enumerate(tqdm(dev_dataloader)):
            batch = {k: v.to(args.model_device) for k, v in batch.items()}
            model.eval()
            embeddings = model(input_ids=batch["input_ids"],
                               attention_mask=batch["attention_mask"],
                               is_query=True)
            embeddings = embeddings.detach().cpu().numpy()
            # Write each query embedding to its row in the memmap.
            qembedding_memmap[qoffsets] = embeddings
    return qembedding_memmap
def doc_inference(model, args, embedding_size):
    if os.path.exists(args.doc_memmap_path):
        print(f"{args.doc_memmap_path} exists, skip inference")
        return
    doc_collator = single_get_collate_function(args.max_doc_length)
    ids_cache = TextTokenIdsCache(data_dir=args.preprocess_dir,
                                  prefix="passages")
    subset = list(range(len(ids_cache)))
    doc_dataset = SubsetSeqDataset(subset=subset,
                                   ids_cache=ids_cache,
                                   max_seq_length=args.max_doc_length)
    assert not os.path.exists(args.doc_memmap_path)
    doc_memmap = np.memmap(args.doc_memmap_path,
                           dtype=np.float32,
                           mode="w+",
                           shape=(len(doc_dataset), embedding_size))
    docid_memmap = np.memmap(args.docid_memmap_path,
                             dtype=np.int32,
                             mode="w+",
                             shape=(len(doc_dataset), ))
    try:
        prediction(model, doc_collator, args, doc_dataset, doc_memmap,
                   docid_memmap, is_query=False)
    except:
        # Remove the partially written memmaps so a later run can retry cleanly.
        subprocess.check_call(["rm", args.doc_memmap_path])
        subprocess.check_call(["rm", args.docid_memmap_path])
        raise
def query_inference(model, args, embedding_size):
    if os.path.exists(args.query_memmap_path):
        print(f"{args.query_memmap_path} exists, skip inference")
        return
    query_collator = single_get_collate_function(args.max_query_length)
    query_dataset = SequenceDataset(
        ids_cache=TextTokenIdsCache(data_dir=args.preprocess_dir,
                                    prefix=f"{args.mode}-query"),
        max_seq_length=args.max_query_length)
    query_memmap = np.memmap(args.query_memmap_path,
                             dtype=np.float32,
                             mode="w+",
                             shape=(len(query_dataset), embedding_size))
    queryids_memmap = np.memmap(args.queryids_memmap_path,
                                dtype=np.int32,
                                mode="w+",
                                shape=(len(query_dataset), ))
    try:
        prediction(model, query_collator, args, query_dataset, query_memmap,
                   queryids_memmap, is_query=True)
    except:
        subprocess.check_call(["rm", args.query_memmap_path])
        subprocess.check_call(["rm", args.queryids_memmap_path])
        raise
def query_inference(model, args, embedding_size):
    query_collator = single_get_collate_function(args.max_query_length)
    query_dataset = SequenceDataset(
        ids_cache=TextTokenIdsCache(data_dir=args.preprocess_dir,
                                    prefix=f"{args.mode}-query"),
        max_seq_length=args.max_query_length)
    query_memmap = np.memmap(args.query_memmap_path,
                             dtype=np.float32,
                             mode="w+",
                             shape=(len(query_dataset), embedding_size))
    queryids_memmap = np.memmap(args.queryids_memmap_path,
                                dtype=np.int32,
                                mode="w+",
                                shape=(len(query_dataset), ))
    prediction(model, query_collator, args, query_dataset, query_memmap,
               queryids_memmap, is_query=True)
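# Usage sketch (not part of the original code): assuming `args` carries the
# preprocess_dir, mode, max lengths, and memmap paths referenced above, the two
# helpers dump query and passage embeddings to disk for later indexing. The
# embedding size of 768 is an assumption matching a RoBERTa-base encoder.
def encode_corpus_and_queries(model, args, embedding_size=768):
    model.eval()
    query_inference(model, args, embedding_size)
    doc_inference(model, args, embedding_size)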
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, MyTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s-%(levelname)s-%(name)s- %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity of the Transformers logger to info (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = RobertaConfig.from_pretrained(
        model_args.model_name_or_path,
        finetuning_task="msmarco",
        gradient_checkpointing=model_args.gradient_checkpointing)
    tokenizer = RobertaTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=False,
    )
    config.gradient_checkpointing = model_args.gradient_checkpointing

    data_args.label_path = os.path.join(data_args.data_dir, "train-qrel.tsv")
    rel_dict = load_rel(data_args.label_path)

    # Choose dataset, collator, and model class based on the negative-sampling
    # strategy: hard negatives, random negatives, or plain in-batch negatives.
    if training_args.hard_neg:
        train_dataset = TrainInbatchWithHardDataset(
            rel_file=data_args.label_path,
            rank_file=data_args.hardneg_path,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length,
            hard_num=training_args.per_query_hard_num)
        data_collator = triple_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_InBatch
    elif training_args.rand_neg:
        train_dataset = TrainInbatchWithRandDataset(
            rel_file=data_args.label_path,
            rand_num=training_args.per_query_hard_num,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length)
        data_collator = triple_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_Rand
    else:
        train_dataset = TrainInbatchDataset(
            rel_file=data_args.label_path,
            queryids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                             prefix="train-query"),
            docids_cache=TextTokenIdsCache(data_dir=data_args.data_dir,
                                           prefix="passages"),
            max_query_length=data_args.max_query_length,
            max_doc_length=data_args.max_doc_length)
        data_collator = dual_get_collate_function(
            data_args.max_query_length,
            data_args.max_doc_length,
            rel_dict=rel_dict,
            padding=training_args.padding)
        model_class = RobertaDot_InBatch

    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
    )

    # Initialize our Trainer
    trainer = DRTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        compute_metrics=None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.remove_callback(TensorBoardCallback)
    trainer.add_callback(
        MyTensorBoardCallback(tb_writer=SummaryWriter(
            os.path.join(training_args.output_dir, "log"))))
    trainer.add_callback(MyTrainerCallback())

    # Training
    trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload
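# Conventional entry-point guard (not in the original excerpt), assuming this
# module is executed directly as the training script.
if __name__ == "__main__":
    main()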
def train(args, model):
    """ Train the model """
    tb_writer = SummaryWriter(args.log_dir)

    passage_embeddings = np.memmap(args.pembed_path,
                                   dtype=np.float32,
                                   mode="r").reshape(
                                       -1, model.output_embedding_size)

    args.train_batch_size = args.per_gpu_batch_size
    train_dataset = TrainQueryDataset(
        TextTokenIdsCache(args.preprocess_dir, "train-query"),
        os.path.join(args.preprocess_dir, "train-qrel.MSMARCO.tsv"),
        args.max_seq_length)

    train_sampler = RandomSampler(train_dataset)
    collate_fn = get_collate_function(args.max_seq_length)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Apply weight decay to all parameters except biases and LayerNorm weights.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    index = load_index(passage_embeddings, args.index_path,
                       args.faiss_gpu_index,
                       args.use_gpu and not args.index_cpu)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Total train batch size (w. accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_mrr, logging_mrr = 0.0, 0.0
    tr_recall, logging_recall = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for epoch_idx, _ in enumerate(train_iterator):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, (batch, _, all_rel_poffsets) in enumerate(epoch_iterator):
            batch = {k: v.to(args.model_device) for k, v in batch.items()}
            model.train()
            query_embeddings = model(query_ids=batch["input_ids"],
                                     attention_mask_q=batch["attention_mask"],
                                     is_query=True)
            # Retrieve the top-k nearest passages for each query from the index.
            I_nearest_neighbor = index.search(
                query_embeddings.detach().cpu().numpy(), args.neg_topk)[1]

            loss = 0
            for retrieve_poffsets, cur_rel_poffsets, qembedding in zip(
                    I_nearest_neighbor, all_rel_poffsets, query_embeddings):
                target_labels = np.isin(retrieve_poffsets,
                                        cur_rel_poffsets).astype(np.int32)

                # Running MRR@10 and Recall@10 over the retrieved list.
                first_rel_pos = np.where(target_labels[:10])[0]
                mrr = 1 / (1 + first_rel_pos[0]) if len(first_rel_pos) > 0 else 0
                tr_mrr += mrr / args.train_batch_size
                recall = 1 if mrr > 0 else 0
                tr_recall += recall / args.train_batch_size

                # If no relevant passage was retrieved, append the labeled
                # relevant passages so at least one positive pair exists.
                if np.sum(target_labels) == 0:
                    retrieve_poffsets = np.hstack(
                        [retrieve_poffsets, cur_rel_poffsets])
                    target_labels = np.hstack(
                        [target_labels, [True] * len(cur_rel_poffsets)])

                target_labels = target_labels.reshape(-1, 1)
                rel_diff = target_labels - target_labels.T
                pos_pairs = (rel_diff > 0).astype(np.float32)
                num_pos_pairs = np.sum(pos_pairs, (0, 1))
                assert num_pos_pairs > 0
                neg_pairs = (rel_diff < 0).astype(np.float32)
                num_pairs = 2 * num_pos_pairs  # num pos pairs and neg pairs are always the same

                pos_pairs = torch.FloatTensor(pos_pairs).to(args.model_device)
                neg_pairs = torch.FloatTensor(neg_pairs).to(args.model_device)

                topK_passage_embeddings = torch.FloatTensor(
                    passage_embeddings[retrieve_poffsets]).to(
                        args.model_device)
                # Relevance scores: dot product between the query embedding and
                # each retrieved passage embedding, shape (k, 1).
                y_pred = (qembedding.unsqueeze(0) *
                          topK_passage_embeddings).sum(-1, keepdim=True)
                # RankNet-style pairwise costs: C_pos penalizes a relevant
                # passage scored below an irrelevant one, C_neg the mirrored
                # direction.
                C_pos = torch.log(1 + torch.exp(y_pred.t() - y_pred))
                C_neg = torch.log(1 + torch.exp(y_pred - y_pred.t()))
                C = pos_pairs * C_pos + neg_pairs * C_neg
                if args.metric is not None:
                    with torch.no_grad():
                        weights = metric_weights(y_pred, args.metric_cut)
                    C = C * weights
                cur_loss = torch.sum(C, (0, 1)) / num_pairs
                loss += cur_loss

            loss /= (args.train_batch_size * args.gradient_accumulation_steps)
            loss.backward()
            tr_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0],
                                         global_step)
                    cur_loss = (tr_loss - logging_loss) / args.logging_steps
                    tb_writer.add_scalar('train/all_loss', cur_loss,
                                         global_step)
                    logging_loss = tr_loss

                    cur_mrr = (tr_mrr - logging_mrr) / (
                        args.logging_steps * args.gradient_accumulation_steps)
                    tb_writer.add_scalar('train/mrr_10', cur_mrr, global_step)
                    logging_mrr = tr_mrr

                    cur_recall = (tr_recall - logging_recall) / (
                        args.logging_steps * args.gradient_accumulation_steps)
                    tb_writer.add_scalar('train/recall_10', cur_recall,
                                         global_step)
                    logging_recall = tr_recall

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_model(model, args.model_save_dir,
                               'ckpt-{}'.format(global_step), args)
        save_model(model, args.model_save_dir,
                   'epoch-{}'.format(epoch_idx + 1), args)
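# Minimal, self-contained sketch (not from the original file) of the pairwise
# RankNet-style loss computed inside train() above, on toy scores. It assumes
# binary relevance labels and mirrors the pos_pairs / neg_pairs masking; the
# names and values are illustrative only.
def pairwise_ranknet_loss_example():
    import torch
    y_pred = torch.tensor([[2.0], [0.5], [1.0]])  # scores for 3 passages, shape (k, 1)
    labels = torch.tensor([[1.0], [0.0], [0.0]])  # passage 0 is relevant
    rel_diff = labels - labels.t()                # (k, k) label differences
    pos_pairs = (rel_diff > 0).float()            # i relevant, j not
    neg_pairs = (rel_diff < 0).float()            # i not relevant, j is
    C_pos = torch.log(1 + torch.exp(y_pred.t() - y_pred))
    C_neg = torch.log(1 + torch.exp(y_pred - y_pred.t()))
    C = pos_pairs * C_pos + neg_pairs * C_neg
    num_pairs = 2 * pos_pairs.sum()               # pos and neg pair counts are equal
    return C.sum() / num_pairs                    # scalar loss for this query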