def load_model(args, checkpoint_path):
    """Instantiate the configured model and restore its weights.

    NOTE(review): a second ``load_model`` (taking an extra ``load_flag``
    parameter) is defined immediately after this one in the file. If both
    sit at module level, the later definition rebinds the name and this
    version is not the one callers receive -- confirm whether it is still
    needed.

    Side effects on ``args``: ``args.model_type`` is lower-cased in place and
    ``args.model_name_or_path`` is overwritten with ``checkpoint_path``.

    Args:
        args: Run configuration. This function reads ``model_type``,
            ``device`` and ``local_rank`` and hands the whole object to the
            model constructor.
        checkpoint_path: Location handed to ``load_states_from_checkpoint``.

    Returns:
        The model moved to ``args.device``; wrapped in
        ``DistributedDataParallel`` when ``args.local_rank != -1``.
    """
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    args.model_name_or_path = checkpoint_path

    model = configObj.model_class(args)

    # Weights are loaded into whatever object get_model_obj() returns for
    # the freshly built model.
    saved_state = load_states_from_checkpoint(checkpoint_path)
    model_to_load = get_model_obj(model)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(saved_state.model_dict)

    model.to(args.device)
    logger.info("Inference parameters %s", args)

    # Distributed run: one process per device, identified by local_rank.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model
def load_model(args, checkpoint_path, load_flag=False):
    """Instantiate the configured model and restore its weights.

    Two checkpoint layouts are handled:

    * When ``args.fp16`` is set and ``load_flag`` is true, the file is read
      with ``torch.load`` and the state dict stored under its ``'model'``
      key is loaded directly into the model.
    * Otherwise the checkpoint goes through ``load_states_from_checkpoint``
      and its ``model_dict`` is loaded into ``get_model_obj(model)``.

    When ``args.init_from_fp16_ckpt`` is set, the ``load_flag`` argument is
    ignored and recomputed: it becomes true only if the step suffix of
    ``checkpoint_path`` is greater than the step suffix of
    ``args.pretrained_checkpoint_dir``.

    Side effects on ``args``: ``args.model_type`` is lower-cased in place and
    ``args.model_name_or_path`` is overwritten with ``checkpoint_path``.

    Args:
        args: Run configuration (reads ``model_type``,
            ``init_from_fp16_ckpt``, ``pretrained_checkpoint_dir``, ``fp16``,
            ``representation_l2_normalization``, ``device``, ``local_rank``).
        checkpoint_path: Checkpoint to restore; its text after the last
            ``'-'`` is treated as the step suffix.
        load_flag: Selects the ``torch.load`` path together with
            ``args.fp16`` (overridden when ``args.init_from_fp16_ckpt``).

    Returns:
        The model moved to ``args.device``; wrapped in
        ``DistributedDataParallel`` when ``args.local_rank != -1``.
    """
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    args.model_name_or_path = checkpoint_path

    model = configObj.model_class(args)

    if args.init_from_fp16_ckpt:
        load_flag = _checkpoint_step_is_greater(
            checkpoint_path, args.pretrained_checkpoint_dir)

    if args.fp16 and load_flag:
        checkpoint = torch.load(
            checkpoint_path,
            map_location=lambda storage, location: default_restore_location(
                storage, 'cpu'),
        )
        # Drops the first 7 characters of every parameter name --
        # presumably the "module." prefix left by a DataParallel/DDP
        # wrapper at save time; TODO confirm against the checkpoint writer.
        new_state_dict = OrderedDict(
            (key[7:], value) for key, value in checkpoint['model'].items())
        model.load_state_dict(new_state_dict)
    else:
        saved_state = load_states_from_checkpoint(checkpoint_path)
        model_to_load = get_model_obj(model)
        logger.info('Loading saved model state ...')
        model_to_load.load_state_dict(saved_state.model_dict)

    model.is_representation_l2_normalization = args.representation_l2_normalization
    model.to(args.device)
    logger.info("Inference parameters %s", args)

    # Distributed run: one process per device, identified by local_rank.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model


def _checkpoint_step_is_greater(checkpoint_path, init_path):
    """Return True if ``checkpoint_path``'s step suffix exceeds ``init_path``'s.

    The suffix is the text after the last ``'-'`` with any ``'/'`` removed.
    Suffixes are compared as integers: comparing them as strings orders
    them lexicographically, so e.g. ``"9000" > "10000"`` would be true. If
    either suffix is not an integer, the plain string comparison is used as
    a fallback so such paths do not start raising.
    """
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')
    init_step = init_path.split('-')[-1].replace('/', '')
    try:
        return int(checkpoint_step) > int(init_step)
    except ValueError:
        return checkpoint_step > init_step
def train(args, model, tokenizer, query_cache, passage_cache):
    """Train the model until ``args.max_steps`` optimizer steps are done.

    Training data comes from one of two sources, re-checked whenever an
    accumulation boundary coincides with a multiple of
    ``args.logging_steps`` optimizer steps:

    * ``args.num_epoch == 0``: the newest file reported by
      ``get_latest_ann_data(args.ann_dir)`` is loaded each time its number
      changes, and (unless ``args.single_warmup``) the LR schedule is
      rebuilt.
    * ``args.num_epoch != 0``: ``<args.data_dir>/train-data`` is loaded once
      at step 0; every exhaustion of the dataloader counts as one pass and
      training stops after more than ``args.num_epoch`` passes.

    Whenever the dataloader is exhausted, the latest ANN file (if any) is
    loaded and iteration restarts.

    Side effects: sets ``args.train_batch_size``; writes TensorBoard scalars
    and checkpoints on the first worker.

    Args:
        args: Run configuration.
        model: Model to train; wrapped here in amp / DataParallel /
            DistributedDataParallel depending on ``args``.
        tokenizer: Unused in this function; kept for interface
            compatibility.
        query_cache: Passed to the training-data processing functions.
        passage_cache: Passed to the training-data processing functions.

    Returns:
        The number of optimizer steps performed in this call.
    """
    logger.info("Training/evaluation parameters %s", args)

    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    world_size = (torch.distributed.get_world_size()
                  if args.local_rank != -1 else 1)
    real_batch_size = (args.train_batch_size
                       * args.gradient_accumulation_steps * world_size)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Max steps = %d", args.max_steps)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        real_batch_size,
    )
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    step = 0  # forward/backward passes
    iter_count = 0  # completed passes over the data (num_epoch != 0 mode)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0  # optimizer steps
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            " Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info(" Continuing training from global step %d", global_step)

    print(args.num_epoch)
    print(step, args.max_steps, global_step)
    # NOTE(review): the restored step counter is discarded here, so a resumed
    # run still performs args.max_steps fresh optimizer steps -- this looks
    # deliberate (fine-tuning from a checkpoint) but confirm.
    global_step = 0

    while global_step < args.max_steps:
        at_refresh_point = (step % args.gradient_accumulation_steps == 0
                            and global_step % args.logging_steps == 0)
        if at_refresh_point:
            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, _ = get_latest_ann_data(args.ann_dir)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    # NOTE(review): fixed 3-minute wait before reading --
                    # presumably gives the producer time to finish writing
                    # the file; confirm.
                    time.sleep(180)
                    ann_training_data = _read_aligned_ann_lines(
                        args, ann_path)
                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    train_dataloader = _build_train_dataloader(
                        args, ann_training_data, query_cache, passage_cache,
                        args.triplet)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:
                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                train_dataloader = _build_train_dataloader(
                    args, training_data, query_cache, passage_cache,
                    args.triplet)
                # Count batches without keeping them all in memory.
                batch_count = sum(1 for _ in train_dataloader)
                logger.info("Total batch count: %d", batch_count)
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)

            ann_no, ann_path, _ = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                print(ann_path)
                ann_training_data = _read_aligned_ann_lines(args, ann_path)
                # NOTE(review): this path always builds the non-triplet
                # dataloader, even when args.triplet is set -- confirm that
                # is intended.
                train_dataloader = _build_train_dataloader(
                    args, ann_training_data, query_cache, passage_cache,
                    False)
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            # A barrier needs an initialised process group; guard it the same
            # way as the barrier in the refresh branch above.
            if args.local_rank != -1:
                dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, _ = do_biencoder_fwd_pass(args, model, batch)

        at_sync_boundary = step % args.gradient_accumulation_steps == 0
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        elif at_sync_boundary or args.local_rank == -1:
            # no_sync() exists only on DistributedDataParallel; outside
            # distributed mode there is no gradient all-reduce to skip.
            loss.backward()
        else:
            with model.no_sync():
                loss.backward()
        tr_loss += loss.item()

        if at_sync_boundary:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {
                    "learning_rate": scheduler.get_lr()[0],
                    "loss": tr_loss / args.logging_steps,
                }
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if (is_first_worker() and args.save_steps > 0
                    and global_step % args.save_steps == 0):
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    # Close exactly when a writer was created above, instead of re-deriving
    # the condition with a different predicate.
    if tb_writer is not None:
        tb_writer.close()

    return global_step


def _read_aligned_ann_lines(args, ann_path):
    """Read an ANN training file, filter its rows and align to world size.

    Rows are kept only when their third tab-separated field contains more
    than one comma-separated entry; the result is truncated to a multiple of
    ``args.world_size``.
    """
    with open(ann_path, 'r') as f:
        ann_training_data = f.readlines()
    logger.info("Training data line count: %d", len(ann_training_data))
    ann_training_data = [
        line for line in ann_training_data
        if len(line.split('\t')[2].split(',')) > 1
    ]
    logger.info("Filtered training data line count: %d",
                len(ann_training_data))
    aligned_size = (len(ann_training_data) //
                    args.world_size) * args.world_size
    return ann_training_data[:aligned_size]


def _build_train_dataloader(args, lines, query_cache, passage_cache, triplet):
    """Wrap raw training lines in a StreamingDataset-backed DataLoader.

    The triplet variant uses ``args.train_batch_size``; the non-triplet
    variant uses twice that batch size.
    """
    if triplet:
        processing_fn = GetTripletTrainingDataProcessingFn(
            args, query_cache, passage_cache)
        batch_size = args.train_batch_size
    else:
        processing_fn = GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache)
        batch_size = args.train_batch_size * 2
    train_dataset = StreamingDataset(lines, processing_fn)
    return DataLoader(train_dataset, batch_size=batch_size)