def train(opt):
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        print("opt.start_from: " + str(opt.start_from))
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    # load model
    if os.path.isfile("log_sc/model.pth"):
        model_path = "log_sc/model.pth"
        state_dict = torch.load(model_path)
        dp_model.load_state_dict(state_dict)
    dp_model.train()

    # create/load vector model
    vectorModel = models.setup_vectorModel().cuda()
    dp_vectorModel = torch.nn.DataParallel(vectorModel)

    # load vector model
    if os.path.isfile("log_sc/model_vec.pth"):
        model_vec_path = "log_sc/model_vec.pth"
        state_dict_vec = torch.load(model_vec_path)
        dp_vectorModel.load_state_dict(state_dict_vec)
    dp_vectorModel.train()

    optimizer = utils.build_optimizer(
        list(model.parameters()) + list(vectorModel.parameters()), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Loss functions
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    vec_crit = nn.L1Loss()

    # create idxs for doc2vec vectors
    with open('paragraphs_image_ids.txt', 'r') as file:
        paragraph_image_ids = file.readlines()
    paragraph_image_ids = [int(i) for i in paragraph_image_ids]

    # select corresponding vectors
    with open('paragraphs_vectors.txt', 'r') as the_file:
        vectors = the_file.readlines()
    vectors_list = []
    for string in vectors:
        vectors_list.append([float(s) for s in string.split(' ')])
    vectors_list_np = np.asarray(vectors_list)

    print("Starting training loop!")

    # Training loop
    while True:
        # Update learning rate once per epoch
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # pad data['att_feats'] axis=1 to have length = 83
        def pad_along_axis(array, target_length, axis=0):
            pad_size = target_length - array.shape[axis]
            axis_nb = len(array.shape)
            if pad_size < 0:
                return array
            npad = [(0, 0) for x in range(axis_nb)]
            npad[axis] = (0, pad_size)
            b = np.pad(array,
                       pad_width=npad,
                       mode='constant',
                       constant_values=0)
            return b

        data['att_feats'] = pad_along_axis(data['att_feats'], 83, axis=1)

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        idx = []
        for element in data['infos']:
            idx.append(paragraph_image_ids.index(element['id']))
        batch_vectors = vectors_list_np[idx]

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result,
                                              opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        att_feats_reshaped = att_feats.permute(0, 2, 1).cuda()
        semantic_features = dp_vectorModel(att_feats_reshaped.cuda(),
                                           fc_feats)  # (10, 2048)
        batch_vectors = torch.from_numpy(
            batch_vectors).float().cuda()  # (10, 512)
        vec_loss = vec_crit(semantic_features, batch_vectors)
        alpha_ = 1
        loss = loss + (alpha_ * vec_loss)

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if True:
            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(dp_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))

            # save vec model
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_vec.pth')
            torch.save(dp_vectorModel.state_dict(), checkpoint_path)
            print("model_vec saved to {}".format(checkpoint_path))

            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(dp_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                # best vec
                model_fname_vec = 'model-best-vec-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname_vec)
                torch.save(dp_vectorModel.state_dict(), checkpoint_path)
                print("model_vec saved to {}".format(checkpoint_path))

                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
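

# --- Illustrative standalone sketch (not part of the original scripts): the
# image-id -> doc2vec-vector lookup that the train() above performs with
# list.index() over 'paragraphs_image_ids.txt' / 'paragraphs_vectors.txt'.
# A dict keyed by image id makes each batch lookup O(1); the toy ids and
# vectors under __main__ are made up for illustration.
import numpy as np


def build_vector_lookup(image_ids, vectors):
    """Map each image id to its row in the stacked doc2vec matrix."""
    matrix = np.asarray(vectors, dtype=np.float32)
    id_to_row = {img_id: row for row, img_id in enumerate(image_ids)}
    return matrix, id_to_row


def batch_vectors_for(batch_infos, matrix, id_to_row):
    """Gather the paragraph vectors for the images in one batch."""
    rows = [id_to_row[info['id']] for info in batch_infos]
    return matrix[rows]


if __name__ == '__main__':
    ids = [101, 202, 303]                        # toy image ids
    vecs = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # toy 2-d "doc2vec" vectors
    matrix, id_to_row = build_vector_lookup(ids, vecs)
    print(batch_vectors_for([{'id': 303}, {'id': 101}], matrix, id_to_row))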
def train(opt): # Deal with feature things before anything acc_steps = getattr(opt, 'acc_steps', 1) loader = DataLoaderRaw(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories = utils.pickle_load(f) else: infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) opt.vocab = loader.get_vocab() dp_model = models.setup(opt) model = dp_model.cuda() del opt.vocab dp_lw_model = LossWrapper(dp_model, opt) lw_model = dp_lw_model epoch_done = True # Assure in training mode dp_lw_model.train() if opt.noamopt: assert opt.caption_model in [ 'transformer', 'mngrcnn' ], 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) def save_checkpoint(model, infos, optimizer, histories=None, append=''): if len(append) > 0: append = '-' + append # if checkpoint_path doesn't exist if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append)) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append)) torch.save(optimizer.state_dict(), optimizer_path) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(infos, f) if histories: with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(histories, f) try: while True: if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // 
opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate # set the decayed rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False start = time.time() if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup): opt.current_lr = opt.learning_rate * \ (iteration+1) / opt.noamopt_warmup utils.set_lr(optimizer, opt.current_lr) # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) if (iteration % acc_steps == 0): optimizer.zero_grad() torch.cuda.synchronize() start = time.time() tmp = [data['attr'], data['img'], data['labels'], data['masks']] tmp = [_ if _ is None else _.cuda() for _ in tmp] attrs, imgs, labels, masks = tmp model_out = dp_lw_model(attrs, imgs, labels, masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() loss_sp = loss / acc_steps loss_sp.backward() if ((iteration + 1) % acc_steps == 0): utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() torch.cuda.synchronize() train_loss = loss.item() end = time.time() if not sc_flag: print( "{}: iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" .format( time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), iteration, epoch, train_loss, end - start)) else: print( "{}: iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" .format( time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) loss_history[ iteration] = train_loss if not sc_flag else model_out[ 'reward'].mean() lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation 
                                  loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if it is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
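

# --- Standalone sketch (a toy stand-in, not the repo's code) of the gradient
# accumulation pattern used in the loop above: each batch contributes
# loss/acc_steps to the gradients and optimizer.step() runs once every
# acc_steps batches, emulating a batch size acc_steps times larger.
# clip_grad_norm_ stands in for utils.clip_gradient.
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
acc_steps = 4  # mirrors opt.acc_steps in the loop above

for iteration in range(16):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    if iteration % acc_steps == 0:
        optimizer.zero_grad()          # clear grads at the start of a window
    loss = criterion(model(x), y)
    (loss / acc_steps).backward()      # scale so the accumulated grads average out
    if (iteration + 1) % acc_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()               # one update per acc_steps batches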
def train(opt): # on node-13 this line cauuses a bug from torch.utils.tensorboard import SummaryWriter ################################ # Build dataloader ################################ # the loader here needs to be fixed actually... # so that data loading is correct # need to modify opt here so everything else is correct loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length ########################## # Initialize infos ########################## infos = { 'iter': 0, 'epoch': 0, 'loader_state_dict': None, 'vocab': loader.get_vocab(), } # Load old infos(if there is) and check if models are compatible if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')): raise Exception("not implemented") infos['opt'] = opt ######################### # Build logger ######################### # naive dict logger histories = defaultdict(dict) if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories.update(utils.pickle_load(f)) # tensorboard logger tb_summary_writer = SummaryWriter(opt.checkpoint_path) ########################## # Build model ########################## # opt.vocab = loader.get_vocab() opt.vocab = loader.get_vocab() model = TransformerLM(opt).cuda() # only set up the language model del opt.vocab # Load pretrained weights: if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'model.pth')): model.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) # Wrap generation model with loss function(used for training) # This allows loss function computed separately on each machine lw_model = LossWrapper(model, opt) # Wrap with dataparallel dp_model = torch.nn.DataParallel(model) dp_lw_model = torch.nn.DataParallel(lw_model) ########################## # Build optimizer ########################## if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) ######################### # Get ready to start ######################### iteration = infos['iter'] epoch = infos['epoch'] # For back compatibility if 'iterators' in infos: infos['loader_state_dict'] = { split: { 'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split] } for split in ['train', 'val', 'test'] } loader.load_state_dict(infos['loader_state_dict']) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) if opt.noamopt: optimizer._step = iteration # flag indicating finish of an epoch # Always set to True at the beginning to initialize the lr or etc. 
epoch_done = True # Assure in training mode dp_lw_model.train() # Start training try: while True: if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # If start structure loss training if opt.structure_after != -1 and epoch >= opt.structure_after: struc_flag = True init_scorer(opt.cached_tokens) else: struc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) loss = model_out['loss'].mean() loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if struc_flag: print( "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start)) elif not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): tb_summary_writer.add_scalar('train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration) tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration) elif struc_flag: tb_summary_writer.add_scalar( 'lm_loss', model_out['lm_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'struc_loss', model_out['struc_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'reward', model_out['reward'].mean().item(), iteration) histories['loss_history'][ iteration] = train_loss if not sc_flag else model_out[ 'reward'].mean() histories['lr_history'][iteration] = opt.current_lr 
histories['ss_prob_history'][iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['loader_state_dict'] = loader.state_dict() # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if lang_stats is not None: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) else: optimizer.scheduler_step(val_loss) # Write validation result into summary tb_summary_writer.add_scalar('validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): tb_summary_writer.add_scalar(k, v, iteration) histories['val_result_history'][iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscellaneous information infos['best_val_score'] = best_val_score utils.save_checkpoint(opt, model, infos, optimizer, histories) if opt.save_history_ckpt: utils.save_checkpoint(opt, model, infos, optimizer, append=str(iteration)) if best_flag: utils.save_checkpoint(opt, model, infos, optimizer, append='best') # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') utils.save_checkpoint(opt, model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
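

# --- Toy, self-contained sketch of the self-critical baseline behind sc_flag
# (not the repo's get_self_critical_reward / RewardCriterion): the reward of
# each sampled caption is its sentence score minus the greedy caption's score,
# and the policy-gradient loss is -logprob * reward averaged over non-pad tokens.
import torch


def self_critical_loss(sample_logprobs, sample_scores, greedy_scores, mask):
    """sample_logprobs/mask: (batch, seq); *_scores: (batch,) sentence scores, e.g. CIDEr."""
    reward = (sample_scores - greedy_scores).unsqueeze(1)  # baseline-subtracted, (batch, 1)
    return -(sample_logprobs * reward * mask).sum() / mask.sum()


if __name__ == '__main__':
    torch.manual_seed(0)
    logp = torch.log_softmax(torch.randn(2, 5), dim=-1)    # toy per-token log-probs
    mask = torch.tensor([[1., 1., 1., 0., 0.], [1., 1., 1., 1., 1.]])
    print(self_critical_loss(logp, torch.tensor([0.8, 0.3]),
                             torch.tensor([0.5, 0.5]), mask).item())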
def train(opt): print("=================Training Information==============") print("start from {}".format(opt.start_from)) print("box from {}".format(opt.input_box_dir)) print("input json {}".format(opt.input_json)) print("attributes from {}".format(opt.input_att_dir)) print("features from {}".format(opt.input_fc_dir)) print("batch size ={}".format(opt.batch_size)) print("#GPU={}".format(torch.cuda.device_count())) # Deal with feature things before anything opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 acc_steps = getattr(opt, 'acc_steps', 1) name_append = opt.name_append if len(name_append) > 0 and name_append[0] != '-': name_append = '_' + name_append loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length opt.write_summary = write_summary if opt.write_summary: print("write summary to {}".format(opt.checkpoint_path)) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible infors_path = os.path.join(opt.start_from, 'infos' + name_append + '.pkl') print("Load model information {}".format(infors_path)) with open(infors_path, 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme histories_path = os.path.join(opt.start_from, 'histories_' + name_append + '.pkl') if os.path.isfile(histories_path): with open(histories_path, 'rb') as f: histories = utils.pickle_load(f) else: # start from scratch print("Initialize training process from all begining") infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) # sanity check for the saved model name has a correct index if opt.name_append.isdigit() and int(opt.name_append) < 100: assert int( opt.name_append ) == epoch, "dismatch in the model index and the real epoch number" epoch += 1 print( "==================start from {} epoch================".format(epoch)) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) # pdb.set_trace() loader.iterators = infos.get('iterators', loader.iterators) start_Img_idx = loader.iterators['train'] loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) opt.vocab = loader.get_vocab() model = models.setup(opt).cuda() del opt.vocab dp_model = torch.nn.DataParallel(model) lw_model = LossWrapper(model, opt) # wrap loss into model dp_lw_model = torch.nn.DataParallel(lw_model) epoch_done = True # Assure in training mode dp_lw_model.train() if opt.noamopt: assert opt.caption_model in [ 'transformer', 'aoa' ], 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = 
utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None: optimizer_path = os.path.join(opt.start_from, 'optimizer' + name_append + '.pth') if os.path.isfile(optimizer_path): print("Loading optimizer............") optimizer.load_state_dict(torch.load(optimizer_path)) def save_checkpoint(model, infos, optimizer, histories=None, append=''): if len(append) > 0: append = '_' + append # if checkpoint_path doesn't exist if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append)) torch.save(model.state_dict(), checkpoint_path) print("Save model state to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append)) torch.save(optimizer.state_dict(), optimizer_path) print("Save model optimizer to {}".format(optimizer_path)) with open( os.path.join(opt.checkpoint_path, 'infos' + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(infos, f) print("Save training information to {}".format( os.path.join(opt.checkpoint_path, 'infos' + '%s.pkl' % (append)))) if histories: with open( os.path.join(opt.checkpoint_path, 'histories_' + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(histories, f) print("Save training historyes to {}".format( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)))) try: while True: # pdb.set_trace() if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False print("{}th Epoch Training starts now!".format(epoch)) with tqdm(total=len(loader.split_ix['train']), initial=start_Img_idx) as pbar: for i in range(start_Img_idx, len(loader.split_ix['train']), opt.batch_size): # import ipdb; ipdb.set_trace() start = time.time() if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup): opt.current_lr = opt.learning_rate * ( iteration + 1) / opt.noamopt_warmup utils.set_lr(optimizer, opt.current_lr) # Load data from train split (0) data = loader.get_batch('train') # print('Read data:', time.time() - start) if (iteration % acc_steps == 0): optimizer.zero_grad() torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() loss_sp = loss / acc_steps loss_sp.backward() if ((iteration + 1) % acc_steps == 0): 
utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() torch.cuda.synchronize() train_loss = loss.item() end = time.time() # if not sc_flag: # print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" # .format(iteration, epoch, train_loss, end - start)) # else: # print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" # .format(iteration, epoch, model_out['reward'].mean(), end - start)) if not sc_flag: pbar.set_description( "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" .format(iteration, epoch, train_loss, end - start)) else: pbar.set_description( "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 pbar.update(opt.batch_size) if data['bounds']['wrapped']: # save after each epoch save_checkpoint(model, infos, optimizer, append=str(epoch)) epoch += 1 # infos['epoch'] = epoch epoch_done = True # Write validation result into summary if (iteration % opt.losses_log_every == 0) and opt.write_summary: add_summary_value(tb_summary_writer, 'loss/train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'hyperparam/learning_rate', opt.current_lr, iteration) add_summary_value( tb_summary_writer, 'hyperparam/scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) loss_history[ iteration] = train_loss if not sc_flag else model_out[ 'reward'].mean() lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix # make evaluation on validation set, and save model # TODO modify it to evaluate by each epoch # ipdb.set_trace() if (iteration % opt.save_checkpoint_every == 0) and eval_ and epoch > 20: model_path = os.path.join( opt.checkpoint_path, 'model_itr%s.pth' % (iteration)) eval_kwargs = { 'split': 'val', 'dataset': opt.input_json, 'model': model_path } eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary if opt.write_summary: add_summary_value(tb_summary_writer, 'loss/validation loss', val_loss, iteration) if lang_stats is not None: bleu_dict = {} for k, v in lang_stats.items(): if 'Bleu' in k: bleu_dict[k] = v if len(bleu_dict) > 0: tb_summary_writer.add_scalars( 'val/Bleu', bleu_dict, epoch) for k, v in lang_stats.items(): if 'Bleu' not in k: add_summary_value( tb_summary_writer, 'val/' + k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history 
histories['ss_prob_history'] = ss_prob_history # save_checkpoint(model, infos, optimizer, histories, append=str(iteration)) save_checkpoint(model, infos, optimizer, histories) # if opt.save_history_ckpt: # save_checkpoint(model, infos, optimizer, append=str(iteration)) if best_flag: save_checkpoint(model, infos, optimizer, append='best') print( "update best model at {} iteration--{} epoch". format(iteration, epoch)) start_Img_idx = 0 # if epoch_done: # go through the set, start a new epoch loop # break # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: print("epoch {} break all".format(epoch)) save_checkpoint(model, infos, optimizer) tb_summary_writer.close() print("============{} Training Done !==============".format( 'Refine' if opt.use_test or opt.use_val else '')) break except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') save_checkpoint(model, infos, optimizer, append='_interrupt') print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
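

# --- Pure-Python sketch (illustrative only) of the two epoch-driven schedules
# every train() above recomputes when epoch_done / update_lr_flag fires:
# step-wise learning-rate decay and the scheduled-sampling probability ramp.
# The argument names mirror the opt.* fields; the numbers under __main__ are
# made up.
def decayed_lr(epoch, base_lr, decay_start, decay_every, decay_rate):
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr


def scheduled_ss_prob(epoch, ss_start, increase_every, increase_prob, max_prob):
    if ss_start >= 0 and epoch > ss_start:
        frac = (epoch - ss_start) // increase_every
        return min(increase_prob * frac, max_prob)
    return 0.0


if __name__ == '__main__':
    print(decayed_lr(epoch=7, base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8))
    print(scheduled_ss_prob(epoch=12, ss_start=0, increase_every=5, increase_prob=0.05, max_prob=0.25))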
def train(opt): ################################ # Build dataloader ################################ loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length ########################## # Initialize infos ########################## infos = { 'iter': 0, 'epoch': 0, 'loader_state_dict': None, 'vocab': loader.get_vocab(), } # Load old infos(if there is) and check if models are compatible if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert getattr(saved_model_opt, checkme) == getattr( opt, checkme ), "Command line argument and saved model disagree on '%s' " % checkme infos['opt'] = opt ######################### # Build logger ######################### # naive dict logger histories = defaultdict(dict) if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories.update(utils.pickle_load(f)) # tensorboard logger tb_summary_writer = SummaryWriter(opt.checkpoint_path) ########################## # Build model ########################## USE_CUDA = torch.cuda.is_available() device = torch.device("cuda:0" if USE_CUDA else "cpu") opt.vocab = loader.get_vocab() model = models.setup(opt).cuda() del opt.vocab # Load pretrained weights: if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'model.pth')): model.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) # Wrap generation model with loss function(used for training) # This allows loss function computed separately on each machine lw_model = LossWrapper(model, opt) # Wrap with dataparallel dp_model = torch.nn.DataParallel(model) dp_lw_model = torch.nn.DataParallel(lw_model) dp_model.to(device) dp_lw_model.to(device) ########################## # Build optimizer ########################## optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) ######################### # Get ready to start ######################### iteration = infos['iter'] epoch = infos['epoch'] # For back compatibility if 'iterators' in infos: infos['loader_state_dict'] = { split: { 'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split] } for split in ['train', 'val', 'test'] } loader.load_state_dict(infos['loader_state_dict']) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # flag indicating finish of an epoch # Always set to True at the beginning to initialize the lr or etc. 
epoch_done = True # Assure in training mode dp_lw_model.train() # Start training try: while True: # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break if epoch_done: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # If start structure loss training if opt.structure_after != -1 and epoch >= opt.structure_after: struc_flag = True init_scorer(opt.cached_tokens) else: struc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') torch.cuda.synchronize() start = time.time() tmp = [ data['semantic_feat'], data["semantic1_feat"], data['att_feats'], data["att1_feats"], data["box_feat"], data["box1_feat"], data['labels'], data['masks'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat, box1_feat, labels, masks = tmp optimizer.zero_grad() model_out = dp_lw_model(semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat, box1_feat, labels, masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) loss = model_out['loss'].mean() loss.backward() if opt.grad_clip_value != 0: getattr(torch.nn.utils, 'clip_grad_%s_' % (opt.grad_clip_mode))(model.parameters(), opt.grad_clip_value) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if struc_flag and iteration % opt.losses_log_every == 0: print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, cider = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), model_out['cider'].mean().item(), end - start)) elif not sc_flag and iteration % opt.losses_log_every == 0: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: if iteration % opt.losses_log_every == 0: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): tb_summary_writer.add_scalar('train_loss', train_loss, iteration) tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration) tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration) elif struc_flag: tb_summary_writer.add_scalar( 'lm_loss', model_out['lm_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)
                    tb_summary_writer.add_scalar(
                        'reward_var', model_out['reward'].var(1).mean(),
                        iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                    (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if it is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss
                print("val_loss = {:.3f}".format(val_loss))

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt, model, infos, optimizer,
                        append=str(epoch) if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt, model, infos, optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
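

# --- Minimal standalone sketch (assumption: mode is 'norm' or 'value') of the
# dynamic gradient-clipping dispatch used in the function above, where
# getattr(torch.nn.utils, 'clip_grad_%s_' % opt.grad_clip_mode) resolves to
# torch.nn.utils.clip_grad_norm_ or torch.nn.utils.clip_grad_value_.
import torch
import torch.nn as nn


def clip_gradients(parameters, clip_value, mode='norm'):
    # look up the real clipping function by name, as the training loop does
    clip_fn = getattr(torch.nn.utils, 'clip_grad_%s_' % mode)
    clip_fn(parameters, clip_value)


if __name__ == '__main__':
    model = nn.Linear(3, 1)
    model(torch.randn(5, 3)).pow(2).mean().backward()
    clip_gradients(model.parameters(), clip_value=0.1, mode='norm')
    print(max(p.grad.abs().max().item() for p in model.parameters()))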
def train(opt): torch.cuda.set_device(opt.device) # opt.use_att = utils.if_use_att(opt.caption_model) opt.use_att = True if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length print(opt.seq_length) print(opt.checkpoint_path) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) critic_loss_history = histories.get('critic_loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) variance_history = histories.get('variance_history', {}) time_history = histories.get('time_history', {}) pseudo_num_history = histories.get('pseudo_num_history', {}) pseudo_num_length_history = histories.get('pseudo_num_length_history', {}) pseudo_num_batch_history = histories.get('pseudo_num_batch_history', {}) sum_logits_history = histories.get('sum_logits_history', {}) reward_main_history = histories.get('reward_main_history', {}) first_order = histories.get('first_order_history', np.zeros(1)) second_order = histories.get('second_order_history', np.zeros(1)) first_order = torch.from_numpy(first_order).float().cuda() second_order = torch.from_numpy(second_order).float().cuda() loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = model target_actor = models.setup(opt).cuda() ####################### Critic pretrain ##################################################################### ##### Critic with state as input # if opt.critic_model == 'state_critic': # critic_model = CriticModel(opt) # else: critic_model = AttCriticModel(opt) target_critic = AttCriticModel(opt) if vars(opt).get('start_from_critic', None) is not None and True: # check if all necessary files exist assert os.path.isdir(opt.start_from_critic), " %s must be a a path" % opt.start_from_critic print(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')) critic_model.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))) target_critic.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))) critic_model = critic_model.cuda() target_critic = target_critic.cuda() critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt) dp_model.eval() critic_iter = 0 init_scorer(opt.cached_tokens) critic_model.train() error_sum = 0 loss_vector_sum = 0 while opt.pretrain_critic == 1: if critic_iter > opt.pretrain_critic_steps: print('****************Finished critic training!') break data = 
loader.get_batch('train') torch.cuda.synchronize() start = time.time() tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp critic_model.train() critic_optimizer.zero_grad() # assert opt.critic_model == 'att_critic_vocab' # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data) crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data, target_critic, target_actor) crit_loss.backward() critic_optimizer.step() #TODO update target. for cp, tp in zip(critic_model.parameters(), target_critic.parameters()): tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data) crit_train_loss = crit_loss.item() torch.cuda.synchronize() end = time.time() error_sum += crit_train_loss**0.5-std if (critic_iter % opt.losses_log_every == 0): print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \ .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start)) print(opt.checkpoint_path) opt.importance_sampling = 1 critic_model.eval() _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True) critic_iter += 1 # make evaluation on validation set, and save model if (critic_iter % opt.save_checkpoint_every == 0): if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') torch.save(critic_model.state_dict(), checkpoint_path) ######################### Actor-critic Training ##################################################################### update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) # first_order = 0 # second_order = 0 while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) data = loader.get_batch('train') # if data['bounds']['it_pos_now'] > 5000: # loader.reset_iterator('train') # continue dp_model.train() critic_model.eval() torch.cuda.synchronize() start = time.time() gen_result = None tmp = [data['fc_feats'], data['att_feats'], data['labels'], 
data['masks'], data['att_masks']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]) else: if opt.rl_type == 'sc': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) pseudo_num = 0 pseudo_num_length = 0 elif opt.rl_type == 'reinforce': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') reward = get_reward(data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) pseudo_num_length = 0 pseudo_num = 0 elif opt.rl_type == 'arsm': loss, pseudo_num, pseudo_num_length, pseudo_num_batch, rewards_main, sum_logits = get_arm_loss_daniel(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) #print(loss) reward = np.zeros([2,2]) elif opt.rl_type == 'rf4': loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) # print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'importance_sampling': opt.importance_sampling = 1 loss, gen_result, reward, sample_logprobs_total = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1) std = np.std(reward) elif opt.rl_type == 'importance_sampling_critic': opt.importance_sampling = 1 loss, gen_result, reward, sample_logprobs_total = get_rf_loss(target_actor, fc_feats, att_feats, att_masks, data, opt, loader, target_critic) reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1) std = np.std(reward) elif opt.rl_type == 'ar': loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.zeros([2,2]) elif opt.rl_type == 'mct': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) pseudo_num = 0 pseudo_num_length = 0 reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1) final_reward = final_reward - torch.mean(final_reward) if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, final_reward) elif opt.rl_type == 'mct_sc': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) pseudo_num = 0 pseudo_num_length = 0 reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1) gen_result_sc, sample_logprobs_sc = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 1}, mode='sample') reward = get_reward(data, gen_result_sc, opt) final_reward = final_reward - torch.from_numpy(reward).float().cuda() loss = rl_crit(sample_logprobs, gen_result.data, final_reward) elif opt.rl_type == 'mct_critic': #TODO change the critic to attention if opt.critic_model == 'state_critic': opt.rf_demean = 
0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) gen_result_pad = torch.cat( [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1) critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2) reward, std = get_reward(data, gen_result, opt, critic=True) pseudo_num = 0 pseudo_num_length = 0 reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1) print(critic_value.shape) loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value) critic_value, gen_result, sample_logprobs = critic_model(dp_model, fc_feats, att_feats, opt, att_masks) reward, std = get_reward(data, gen_result, opt, critic=True) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:,:-1].data) elif opt.critic_model == 'att_critic': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) gen_result_pad = torch.cat( [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1) critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2) reward, std = get_reward(data, gen_result, opt, critic=True) pseudo_num = 0 pseudo_num_length = 0 reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1) print(critic_value.shape) loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value) elif opt.rl_type =='mct_baseline': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) pseudo_num = 0 pseudo_num_length = 0 reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline) elif opt.rl_type == 'arsm_baseline': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0] if opt.arm_step_sample == 'greedy' and False: sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline) elif opt.rl_type == 'ars_indicator': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline) elif opt.rl_type == 'arsm_baseline_critic': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model) reward, std = 
get_reward(data, gen_result, opt, critic=True) if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - arm_baseline) elif opt.rl_type == 'arsm_critic': #print(opt.critic_model) tic = time.time() loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model) #print('arm_loss time', str(time.time()-tic)) reward = np.zeros([2, 2]) elif opt.rl_type == 'critic_vocab_sum': assert opt.critic_model == 'att_critic_vocab' tic = time.time() gen_result, sample_logprobs_total = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, total_probs=True, mode='sample') #batch, seq, vocab #print('generation time', time.time()-tic) gen_result_pad = torch.cat( [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1) tic = time.time() critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks) #batch, seq, vocab #print('critic time', time.time() - tic) probs = torch.sum(F.softmax(sample_logprobs_total, 2) * critic_value.detach(), 2) mask = (gen_result > 0).float() mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1) loss = -torch.sum(probs * mask) / torch.sum(mask) reward = np.zeros([2, 2]) elif opt.rl_type == 'reinforce_critic': #TODO change the critic to attention if opt.critic_model == 'state_critic': critic_value, gen_result, sample_logprobs = critic_model(dp_model, fc_feats, att_feats, opt, att_masks) reward, std = get_reward(data, gen_result, opt, critic=True) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:,:-1].data) elif opt.critic_model == 'att_critic': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') gen_result_pad = torch.cat( [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1) critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2) reward, std = get_reward(data, gen_result, opt, critic=True) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value.data) if opt.mle_weights != 0: loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) #TODO make sure all sampling replaced by greedy for critic #### update the actor loss.backward() # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f) # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f) ## compute variance gradient = torch.zeros([0]).cuda() for i in model.parameters(): gradient = torch.cat((gradient, i.grad.view(-1)), 0) first_order = 0.9999 * first_order + 0.0001 * gradient second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2) # print(torch.max(torch.abs(gradient))) variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item() if opt.rl_type != 'arsm' or not sc_flag: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() # ### update the critic if 'critic' in opt.rl_type: dp_model.eval() critic_model.train() utils.set_lr(critic_optimizer, opt.critic_learning_rate) critic_optimizer.zero_grad() #assert opt.critic_model == 'att_critic_vocab' crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, 
dp_model, critic_model, opt, data, target_critic, target_actor, gen_result=gen_result, sample_logprobs_total=sample_logprobs_total, reward=reward) crit_loss.backward() critic_optimizer.step() for cp, tp in zip(critic_model.parameters(), target_critic.parameters()): tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data) for cp, tp in zip(dp_model.parameters(), target_actor.parameters()): tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data) crit_train_loss = crit_loss.item() error_sum += crit_train_loss ** 0.5 - std train_loss = loss.item() torch.cuda.synchronize() end = time.time() if (iteration % opt.losses_log_every == 0): if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) print(opt.checkpoint_path) elif 'critic' in opt.rl_type: print( "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \ .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start)) print(opt.checkpoint_path) critic_model.eval() _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True) else: print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start)) print("pseudo num: ", pseudo_num) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration) add_summary_value(tb_summary_writer, 'variance', variance, iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean(reward) critic_loss_history[iteration] = crit_train_loss if 'critic' in opt.rl_type else 0 lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob variance_history[iteration] = variance pseudo_num_history[iteration] = pseudo_num reward_main_history[iteration] = rewards_main #print(pseudo_num_length) #print(type(pseudo_num_length).__module__) if type(pseudo_num_length).__module__ != 'torch': print('not right') pseudo_num_length_history[iteration] = pseudo_num_length pseudo_num_batch_history[iteration] = pseudo_num_batch sum_logits_history[iteration] = sum_logits else: pseudo_num_length_history[iteration] = pseudo_num_length.data.cpu().numpy() pseudo_num_batch_history[iteration] = pseudo_num_batch.data.cpu().numpy() sum_logits_history[iteration] = sum_logits.data.cpu().numpy() time_history[iteration] = end - start # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 
'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = - val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') torch.save(critic_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['critic_loss_history'] = critic_loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history histories['variance_history'] = variance_history histories['pseudo_num_history'] = pseudo_num_history histories['pseudo_num_length_history'] = pseudo_num_length_history histories['pseudo_num_batch_history'] = pseudo_num_batch_history histories['sum_logits_history'] = sum_logits_history histories['reward_main_history'] = reward_main_history histories['time'] = time_history histories['first_order_history'] = first_order.data.cpu().numpy() histories['second_order_history'] = second_order.data.cpu().numpy() # histories['variance'] = 0 with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(infos, f) with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
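# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the soft (Polyak) target-network update
# used in the critic branch above: after each optimizer step, every target
# parameter moves a small fraction `gamma` toward the online parameter, i.e.
# target <- target + gamma * (online - target). The toy `nn.Linear` modules and
# the gamma value are illustrative assumptions, not the original networks.
# ---------------------------------------------------------------------------
import copy

import torch
import torch.nn as nn


def soft_update(online: nn.Module, target: nn.Module, gamma: float) -> None:
    """In-place Polyak update of `target` toward `online`."""
    with torch.no_grad():
        for p_online, p_target in zip(online.parameters(), target.parameters()):
            p_target.data.add_(gamma * (p_online.data - p_target.data))


if __name__ == '__main__':
    online_net = nn.Linear(8, 4)            # stands in for critic_model / dp_model
    target_net = copy.deepcopy(online_net)  # stands in for target_critic / target_actor
    online_net.weight.data.add_(1.0)        # pretend an optimizer step just happened
    soft_update(online_net, target_net, gamma=0.001)  # assumed opt.gamma_critic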
def train(opt): # opt.use_att = utils.if_use_att(opt.caption_model) opt.use_att = True if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length print(opt.checkpoint_path) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) critic_loss_history = histories.get('critic_loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) variance_history = histories.get('variance_history', {}) time_history = histories.get('time_history', {}) pseudo_num_history = histories.get('pseudo_num_history', {}) pseudo_num_depth_history = histories.get('pseudo_num_depth_history', {}) pseudo_num_length_history = histories.get('pseudo_num_length_history', {}) pseudo_num_batch_history = histories.get('pseudo_num_batch_history', {}) reward_batch_history = histories.get('reward_batch_history', {}) entropy_batch_history = histories.get('entropy_batch_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = model ######################### Actor-critic Training ##################################################################### update_lr_flag = True # Assure in training mode dp_model.train() #TODO: change this to a flag crit = utils.LanguageModelCriterion_binary() rl_crit = utils.RewardCriterion_binary() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) first_order = 0 second_order = 0 while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = 
True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) data = loader.get_batch('train') if data['bounds']['it_pos_now'] > 10000: loader.reset_iterator('train') continue dp_model.train() torch.cuda.synchronize() start = time.time() gen_result = None tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:], dp_model.depth, dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size) else: if opt.rl_type == 'sc': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth) elif opt.rl_type == 'reinforce': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_reward(data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth) elif opt.rl_type == 'arm': loss, pseudo_num, pseudo_num_depth, pseudo_num_length, pseudo_num_batch, reward_batch, entropy_batch = dp_model.get_arm_loss_binary_fast( fc_feats, att_feats, att_masks, opt, data, loader) #print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'rf4': loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) # print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'ar': loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.zeros([2, 2]) elif opt.rl_type == 'mct_baseline': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit( sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline) elif opt.rl_type == 'arsm_baseline': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0] if opt.arm_step_sample == 'greedy' and False: sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline) elif opt.rl_type == 'ars_indicator': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline) if opt.mle_weights != 0: loss += opt.mle_weights * crit( dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) #TODO make sure all 
sampling replaced by greedy for critic #### update the actor loss.backward() # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f) ## compute variance gradient = torch.zeros([0]).cuda() for i in model.parameters(): gradient = torch.cat((gradient, i.grad.view(-1)), 0) first_order = 0.999 * first_order + 0.001 * gradient second_order = 0.999 * second_order + 0.001 * gradient.pow(2) # print(torch.max(torch.abs(gradient))) variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item() if opt.rl_type != 'arsm' or not sc_flag: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() # ### update the critic train_loss = loss.item() torch.cuda.synchronize() end = time.time() if (iteration % opt.losses_log_every == 0): if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) print(opt.checkpoint_path) else: print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}, pseudo num = {:.3f}, " \ .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start, pseudo_num)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration) add_summary_value(tb_summary_writer, 'variance', variance, iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob variance_history[iteration] = variance time_history[iteration] = end - start pseudo_num_history[iteration] = pseudo_num.item() pseudo_num_length_history[iteration] = pseudo_num_length.data.cpu( ).numpy() pseudo_num_depth_history[iteration] = pseudo_num_depth.data.cpu( ).numpy() pseudo_num_batch_history[iteration] = pseudo_num_batch.data.cpu( ).numpy() reward_batch_history[iteration] = reward_batch entropy_batch_history[iteration] = entropy_batch.data.cpu().numpy() # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils_binary.eval_split( dp_model, crit, loader, eval_kwargs) print('1') # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss print('2') best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), 
checkpoint_path) print('3') checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) print('4') # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['critic_loss_history'] = critic_loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history histories['variance_history'] = variance_history histories['time'] = time_history histories['pseudo_num_history'] = pseudo_num_history histories[ 'pseudo_num_length_history'] = pseudo_num_length_history histories[ 'pseudo_num_depth_history'] = pseudo_num_depth_history histories[ 'pseudo_num_batch_history'] = pseudo_num_batch_history histories['reward_batch_history'] = reward_batch_history histories['entropy_batch_history'] = entropy_batch_history # histories['variance'] = 0 with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
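# ---------------------------------------------------------------------------
# A small sketch of the exponential-moving-average gradient statistics kept in
# the loops above (`first_order`, `second_order`): the mean of
# |E[g^2] - E[g]^2| over all parameters is logged as a rough per-iteration
# gradient-variance estimate. The decay of 0.999 matches the variant above; the
# toy model and input are assumptions for illustration only.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def gradient_variance(model: nn.Module, first_order, second_order, decay=0.999):
    """Update EMAs of the flattened gradient and its square; return (variance, ema1, ema2)."""
    grads = torch.cat([p.grad.view(-1) for p in model.parameters()
                       if p.grad is not None])
    first_order = decay * first_order + (1.0 - decay) * grads
    second_order = decay * second_order + (1.0 - decay) * grads.pow(2)
    variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
    return variance, first_order, second_order


if __name__ == '__main__':
    toy = nn.Linear(4, 2)
    toy(torch.randn(3, 4)).sum().backward()
    var, ema1, ema2 = gradient_variance(toy, first_order=0.0, second_order=0.0)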
def train(opt): # Load data print('Loading dataset...') loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length loader.input_encoding_size = opt.input_encoding_size BERT_features = None if opt.cached_bert_features == "": # Extract BERT features print('Extracting pretrained BERT features...') BERT_features = process_bert.extract_BERT_features(loader, opt) with open(opt.data_path + 'BERT_features.pkl', 'wb') as f: pickle.dump(BERT_features, f) else: # Load BERT tokenization results print('Loading pretrained BERT features...') with open(opt.data_path + 'BERT_features.pkl', 'rb') as f: BERT_features = pickle.load(f) bert_vocab_path = opt.data_path + 'bert-base-cased-vocab.txt' opt.vocab_size = loader.update_bert_tokens(bert_vocab_path, BERT_features) print('Vocabulary size: ' + str(opt.vocab_size)) # Load pretrained model, info file, histories file infos = {} histories = {} if opt.start_from is not None: with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = ["rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # Create model model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) dp_model.train() # Loss function crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() # Optimizer and learning rate adjustment flag optimizer = utils.build_optimizer(model.parameters(), opt) update_lr_flag = True # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) # Training loop while True: # Update learning rate once per epoch if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob update_lr_flag = False # Load data from train split (0) start = time.time() data = loader.get_batch('train') data_time = time.time() - start start = time.time() # Unpack data torch.cuda.synchronize() tmp = [ data['bert_feats'], data['fc_feats'], data['att_feats'], 
data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] bert_feats, fc_feats, att_feats, labels, masks, att_masks = tmp bert_feats.requires_grad = False # Forward pass and loss optimizer.zero_grad() outputs = dp_model(bert_feats, fc_feats, att_feats, labels, att_masks) loss = crit(outputs, labels[:, 1:], masks[:, 1:]) # Backward pass loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() # Print total_time = time.time() - start if iteration % opt.print_freq == 1: print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, data_time, total_time)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Validate and save model if (iteration % opt.save_checkpoint_every == 0): # Evaluate model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Our metric is CIDEr if available, otherwise validation loss if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss # Save model in checkpoint path best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) # Save model to unique file if new best model if best_flag: model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format( iteration, best_val_score) infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration) checkpoint_path = os.path.join(opt.checkpoint_path, model_fname) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
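# ---------------------------------------------------------------------------
# A hedged sketch of the compute-once / load-later caching pattern used for the
# pretrained BERT features above: the expensive extraction runs only when no
# cached pickle exists, otherwise the features are reloaded from disk. The
# cache path and the `extract_features` callable below are placeholders, not
# the original `process_bert` API.
# ---------------------------------------------------------------------------
import os
import pickle


def get_or_build_features(cache_path, extract_features):
    """Load features from `cache_path` if present, otherwise build and cache them."""
    if os.path.isfile(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    features = extract_features()
    with open(cache_path, 'wb') as f:
        pickle.dump(features, f)
    return features


if __name__ == '__main__':
    feats = get_or_build_features('/tmp/BERT_features.pkl',
                                  lambda: {'dummy': [0.0, 1.0]})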
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) opt.use_fc = utils.if_use_fc(opt.caption_model) loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length infos = load_info(opt) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = infos.get('val_result_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # Define and load model, optimizer, critics decoder = setup(opt).train().cuda() crit = utils.LanguageModelCriterion().cuda() rl_crit = utils.RewardCriterion().cuda() optimizer = utils.build_optimizer(decoder.parameters(), opt) models = {'decoder': decoder} optimizers = {'decoder': optimizer} save_nets_structure(models, opt) load_checkpoint(models, optimizers, opt) epoch_done = True sc_flag = False while True: if epoch_done: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) decoder.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False # 1. fetch a batch of data from train split data = loader.get_batch('train') tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp sg_data = { key: data['sg_data'][key] if data['sg_data'][key] is None else torch.from_numpy(data['sg_data'][key]).cuda() for key in data['sg_data'] } # 2. Forward model and compute loss torch.cuda.synchronize() optimizer.zero_grad() if not sc_flag: out = decoder(sg_data, fc_feats, att_feats, labels, att_masks) loss = crit(out, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs, core_args = decoder( sg_data, fc_feats, att_feats, att_masks, opt={ 'sample_max': 0, 'return_core_args': True }, mode='sample') reward = get_self_critical_reward(decoder, core_args, sg_data, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) # 3. 
Update model loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() # Update the iteration and epoch iteration += 1 # Write the training loss summary if (iteration % opt.log_loss_every == 0): # logging log logger.info("{} ({}), loss: {:.3f}".format(iteration, epoch, train_loss)) tb.add_values('loss', {'train': train_loss}, iteration) if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Make evaluation and save checkpoint if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0) or (opt.save_checkpoint_every == -1 and epoch_done): # eval model eval_kwargs = { 'split': 'val', 'dataset': opt.input_json, 'expand_features': False } eval_kwargs.update(vars(opt)) predictions, lang_stats = eval_utils.eval_split( decoder, loader, eval_kwargs) # log val results if not lang_stats is None: logger.info("Scores: {}".format(lang_stats)) tb.add_values('scores', lang_stats, epoch) val_result_history[epoch] = { 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result current_score = 0 if lang_stats is None else lang_stats['CIDEr'] best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() infos['val_result_history'] = val_result_history save_checkpoint(models, optimizers, infos, best_flag, opt) # Stop if reaching max epochs if epoch > opt.max_epochs and opt.max_epochs != -1: break
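# ---------------------------------------------------------------------------
# A simplified, illustrative version of the self-critical (SCST) reward that
# `get_self_critical_reward` computes above: each sampled caption's reward is
# its sentence-level score minus the score of the greedily decoded baseline,
# broadcast over the time dimension so the REINFORCE criterion can weight every
# token. The scores are assumed to come from a sentence scorer such as CIDEr;
# this is a sketch of the idea, not the original function.
# ---------------------------------------------------------------------------
import numpy as np


def self_critical_reward(sampled_scores, greedy_scores, seq_length):
    """Return a [batch, seq_length] array of (sample - greedy) advantages."""
    advantage = np.asarray(sampled_scores) - np.asarray(greedy_scores)
    return np.repeat(advantage[:, np.newaxis], seq_length, axis=1)


if __name__ == '__main__':
    reward = self_critical_reward([0.9, 0.4], [0.7, 0.5], seq_length=5)
    # reward[0] is all 0.2 (sample beat the greedy baseline),
    # reward[1] is all -0.1 (sample fell short of it).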
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length training_mode = 0 optimizer_reset = 0 change_mode1 = 0 change_mode2 = 0 use_rela = getattr(opt, 'use_rela', 0) if use_rela: opt.rela_dict_size = loader.rela_dict_size #need another parameter to control how to train the model tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open( os.path.join( opt.checkpoint_path, 'infos_' + opt.id + format(int(opt.start_from), '04') + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from), '04') + '.pkl')): with open( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from), '04') + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) if epoch >= opt.step2_train_after and epoch < opt.step3_train_after: training_mode = 1 elif epoch >= opt.step3_train_after: training_mode = 2 else: training_mode = 0 val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() #dp_model = torch.nn.DataParallel(model) #dp_model = torch.nn.DataParallel(model, [0, 1]) dp_model = model for name, param in model.named_parameters(): print(name) update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) optimizer_mem = optim.Adam([model.memory_cell], opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from), '04') + '.pth')): optimizer.load_state_dict( torch.load( os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from), '04') + '.pth'))) if (training_mode == 1 or training_mode == 2) and os.path.isfile( os.path.join( opt.checkpoint_path, 'optimizer_mem' + opt.id + format(int(opt.start_from), '04') + '.pth')): optimizer_mem.load_state_dict( torch.load( os.path.join( opt.checkpoint_path, 'optimizer_mem' + opt.id + format(int(opt.start_from), '04') + '.pth'))) optimizer.zero_grad() optimizer_mem.zero_grad() accumulate_iter = 0 reward = np.zeros([1, 1]) train_loss = 0 while True: # if optimizer_reset == 1: # print("++++++++++++++++++++++++++++++") # print('reset optimizer') # print("++++++++++++++++++++++++++++++") # optimizer = utils.build_optimizer(model.parameters(), opt) # optimizer_mem = optim.Adam([model.memory_cell], 
opt.learning_rate, (opt.optim_alpha, opt.optim_beta), # opt.optim_epsilon, # weight_decay=opt.weight_decay) # optimizer_reset = 0 if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch(opt.train_split) print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() if epoch >= opt.step2_train_after and epoch < opt.step3_train_after: training_mode = 1 if change_mode1 == 0: change_mode1 = 1 optimizer_reset = 1 elif epoch >= opt.step3_train_after: training_mode = 2 if change_mode2 == 0: change_mode2 = 1 optimizer_reset = 1 else: training_mode = 0 fc_feats = None att_feats = None att_masks = None ssg_data = None rela_data = None tmp = [data['fc_feats'], data['labels'], data['masks']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, labels, masks = tmp tmp = [ data['att_feats'], data['att_masks'], data['rela_rela_matrix'], data['rela_rela_masks'], data['rela_attr_matrix'], data['rela_attr_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] att_feats, att_masks, rela_rela_matrix, rela_rela_masks, \ rela_attr_matrix, rela_attr_masks = tmp rela_data = {} rela_data['att_feats'] = att_feats rela_data['att_masks'] = att_masks rela_data['rela_matrix'] = rela_rela_matrix rela_data['rela_masks'] = rela_rela_masks rela_data['attr_matrix'] = rela_attr_matrix rela_data['attr_masks'] = rela_attr_masks tmp = [ data['ssg_rela_matrix'], data['ssg_rela_masks'], data['ssg_obj'], data['ssg_obj_masks'], data['ssg_attr'], data['ssg_attr_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp ssg_data = {} ssg_data['ssg_rela_matrix'] = ssg_rela_matrix ssg_data['ssg_rela_masks'] = ssg_rela_masks ssg_data['ssg_obj'] = ssg_obj ssg_data['ssg_obj_masks'] = ssg_obj_masks ssg_data['ssg_attr'] = ssg_attr ssg_data['ssg_attr_masks'] = ssg_attr_masks if not sc_flag: loss = crit( dp_model(fc_feats, att_feats, labels, att_masks, rela_data, ssg_data, use_rela, training_mode), labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, rela_data, ssg_data, use_rela, training_mode, opt={'sample_max': 0}, mode='sample') rela_data = {} rela_data['att_feats'] = att_feats rela_data['att_masks'] = att_masks rela_data['rela_matrix'] = rela_rela_matrix rela_data['rela_masks'] = rela_rela_masks rela_data['attr_matrix'] = rela_attr_matrix rela_data['attr_masks'] = rela_attr_masks reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, rela_data, ssg_data, use_rela, 
training_mode, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) accumulate_iter = accumulate_iter + 1 loss = loss / opt.accumulate_number loss.backward() if accumulate_iter % opt.accumulate_number == 0: if training_mode == 0: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() optimizer.zero_grad() elif training_mode == 1: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() optimizer.zero_grad() utils.clip_gradient(optimizer_mem, opt.grad_clip) optimizer_mem.step() optimizer_mem.zero_grad() elif training_mode == 2: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() optimizer.zero_grad() utils.clip_gradient(optimizer_mem, opt.grad_clip) optimizer_mem.step() optimizer_mem.zero_grad() iteration += 1 accumulate_iter = 0 train_loss = loss.item() * opt.accumulate_number end = time.time() text_file = open(opt.id + '.txt', "aw") if not sc_flag: print("iter {} (epoch {}), train_model {}, train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, training_mode, train_loss, end - start)) text_file.write("iter {} (epoch {}), train_model {}, train_loss = {:.3f}, time/batch = {:.3f}\n" \ .format(iteration, epoch, training_mode, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:, 0]), end - start)) text_file.write("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}\n" \ .format(iteration, epoch, np.mean(reward[:, 0]), end - start)) text_file.close() torch.cuda.synchronize() # Update the iteration and epoch if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0) and (accumulate_iter % opt.accumulate_number == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0) and (accumulate_iter % opt.accumulate_number == 0): # eval model eval_kwargs = { 'split': 'test', 'dataset': opt.input_json, 'use_rela': use_rela, 'num_images': 1, } eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils_mem.eval_split( dp_model, crit, loader, training_mode, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true save_id = iteration / opt.save_checkpoint_every if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join( opt.checkpoint_path, 'model' + opt.id + format(int(save_id), '04') + '.pth') torch.save(model.state_dict(), 
checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(save_id), '04') + '.pth') torch.save(optimizer.state_dict(), optimizer_path) if training_mode == 1 or training_mode == 2 or opt.caption_model == 'lstm_mem': optimizer_mem_path = os.path.join( opt.checkpoint_path, 'optimizer_mem' + opt.id + format(int(save_id), '04') + '.pth') torch.save(optimizer_mem.state_dict(), optimizer_mem_path) memory_cell = dp_model.memory_cell.data.cpu().numpy() memory_cell_path = os.path.join( opt.checkpoint_path, 'memory_cell' + opt.id + format(int(save_id), '04') + '.npz') np.savez(memory_cell_path, memory_cell=memory_cell) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join( opt.checkpoint_path, 'infos_' + opt.id + format(int(save_id), '04') + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(save_id), '04') + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
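# ---------------------------------------------------------------------------
# A compact sketch of the gradient-accumulation scheme in the loop above: the
# loss is divided by `accumulate_number`, gradients accumulate over that many
# batches, and the optimizer only steps (and zeroes) once per group. The toy
# model, random data, and the accumulate_number/clip values are assumptions.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def accumulation_demo(accumulate_number=4, num_batches=8):
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    for it in range(1, num_batches + 1):
        x, y = torch.randn(16, 10), torch.randn(16, 1)
        loss = nn.functional.mse_loss(model(x), y) / accumulate_number
        loss.backward()                      # gradients keep accumulating
        if it % accumulate_number == 0:      # one real update per group
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            optimizer.zero_grad()


if __name__ == '__main__':
    accumulation_demo()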
def train(opt): # Load data loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length # Tensorboard summaries (they're great!) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) # Load pretrained model, info file, histories file infos = {} histories = {} if opt.start_from is not None: with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = ["rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) #ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # Create model model = convcap(opt).cuda() # pretrained_dict = torch.load('log_xe_final_before_review/all2model12000.pth') # model.load_state_dict(pretrained_dict, strict=False) back_model = convcap(opt).cuda() back_model.train() # d_pretrained_dict = torch.load('log_xe_final_before_review/all2d_model12000.pth') # back_model.load_state_dict(d_pretrained_dict, strict=False) dp_model = model dp_model.train() dis_model = Discriminator(512, 512, 512, 0.2) dis_model = dis_model.cuda() dis_model.train() # dis_pretrained_dict = torch.load('./log_xe_final_before_review/all2dis_model12000.pth') # dis_model.load_state_dict(dis_pretrained_dict, strict=False) d_optimizer = utils.build_optimizer(dis_model.parameters(), opt) back_model.train() # Loss functio} crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() # Optimizer and learning rate adjustment flag optimizer = utils.build_optimizer_adam( chain(model.parameters(), back_model.parameters()), opt) #back_optimizer = utils.build_optimizer(back_model.parameters(), opt) update_lr_flag = True #Load the optimizer # if os.path.isfile(os.path.join('log_xe_final_before_review/',"optimizer.pth")): # optimizer.load_state_dict(torch.load(os.path.join('log_xe_final_before_review/', 'optimizer.pth'))) # print ('optimiser loaded') # print (optimizer) # Training loop while True: # Update learning rate once per epoch if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every #opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) #model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag 
= True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) start = time.time() data = loader.get_batch('train') data_time = time.time() - start start = time.time() # Unpack data torch.cuda.synchronize() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['dist'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp labels = labels.long() labels[:, :, 0] = 8667 nd_labels = labels batchsize = fc_feats.size(0) # Forward pass and loss optimizer.zero_grad() d_steps = 1 g_steps = 1 #print (torch.sum(labels!=0), torch.sum(masks!=0)) if 1: if iteration >= 0: if 1: dp_model.eval() back_model.eval() with torch.no_grad(): _, x_all_d = dp_model(fc_feats, att_feats, nd_labels.long(), 30, 6) labels_nd = nd_labels.view(batchsize, -1) idx = [ i for i in range(labels_nd.size()[1] - 1, -1, -1) ] labels_flip_nd = labels_nd[:, idx] labels_flip_nd = labels_flip_nd.view(batchsize, 6, 30) labels_flip_nd[:, :, 0] = 8667 _, x_all_flip_d = back_model(fc_feats, att_feats, labels_flip_nd, 30, 6) x_all_d = x_all_d[:, :, :-1] x_all_flip_d = x_all_flip_d[:, :, :-1] idx = [ i for i in range(x_all_flip_d.size()[2] - 1, -1, -1) ] idx = torch.LongTensor(idx[1:]) idx = Variable(idx).cuda() invert_backstates = x_all_flip_d.index_select(2, idx) x_all_d.detach() invert_backstates.detach() x_all_d = x_all_d[:, :, :-1] autoregressive_scores = dis_model( x_all_d.transpose(2, 1).cuda()) teacher_forcing_scores = dis_model( invert_backstates.transpose(2, 1).cuda()) tf_loss, ar_loss = _calcualte_discriminator_loss( teacher_forcing_scores, autoregressive_scores) tf_loss.backward(retain_graph=True) ar_loss.backward() d_optimizer.step() for p in dis_model.parameters(): p.data.clamp_(-0.01, 0.01) torch.cuda.synchronize() total_time = time.time() - start if 1: dp_model.train() back_model.train() wordact, x_all = dp_model(fc_feats, att_feats, labels, 30, 6) mask = masks.view(batchsize, -1) mask = mask[:, 1:].contiguous() wordact = wordact[:, :, :-1] wordact_t = wordact.permute(0, 2, 1).contiguous() wordact_t = wordact_t.view( wordact_t.size(0) * wordact_t.size(1), -1) labels_flat = labels.view(batchsize, -1) wordclass_v = labels_flat[:, 1:] wordclass_t = wordclass_v.contiguous().view(\ wordclass_v.size(0) * wordclass_v.size(1), 1) maskids = torch.nonzero( mask.view(-1).cpu()).numpy().reshape(-1) loss_xe = F.cross_entropy(wordact_t[maskids, ...], \ wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda() idx = [i for i in range(labels_flat.size()[1] - 1, -1, -1)] labels_flip = labels_flat[:, idx] labels_flip = labels_flip.view(batchsize, 6, 30) labels_flip[:, :, 0] = 8667 wordact, x_all_flip = back_model(fc_feats, att_feats, labels_flip, 30, 6) mask = masks.view(batchsize, -1).flip((1, )) reverse_mask = mask[:, 1:].contiguous() wordact = wordact[:, :, :-1] wordact_t = wordact.permute(0, 2, 1).contiguous() wordact_t = wordact_t.view( wordact_t.size(0) * wordact_t.size(1), -1) labels_flip = labels_flip.contiguous().view(-1, 6 * 30) wordclass_v = labels_flip[:, 1:] wordclass_t = wordclass_v.contiguous().view(\ wordclass_v.size(0) * wordclass_v.size(1), 1) maskids = torch.nonzero( reverse_mask.view(-1).cpu()).numpy().reshape(-1) loss_xe_flip = F.cross_entropy(wordact_t[maskids, ...], \ wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda() train_loss = loss_xe x_all_flip = x_all_flip[:, :, :-1].cuda() x_all = x_all[:, :, :-1].cuda() idx 
= [i for i in range(x_all_flip.size()[2] - 1, -1, -1)] idx = torch.LongTensor(idx[1:]) idx = Variable(idx).cuda() invert_backstates = x_all_flip.index_select(2, idx) invert_backstates = invert_backstates.detach() l2_loss = ((x_all[:, :, :-1] - invert_backstates)**2).mean() autoregressive_scores = dis_model( x_all.transpose(2, 1).cuda()) ad_loss = _calculate_generator_loss( autoregressive_scores).sum() all_loss = loss_xe + loss_xe_flip + l2_loss ad_loss.backward(retain_graph=True) all_loss.backward() # utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() if 1: if iteration % opt.print_freq == 1: print('Read data:', time.time() - start) if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f},l2_loss= {:.3f}, flip_loss = {:.3f}, loss_ad = {:.3f}, fake = {:.3f}, real = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, loss_xe, l2_loss, loss_xe_flip, ad_loss, ar_loss, tf_loss, data_time, total_time)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[ iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr #ss_prob_history[iteration] = model.ss_prob # Validate and save model if (iteration % opt.save_checkpoint_every == 0): checkpoint_path = os.path.join( opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration)) torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join( opt.checkpoint_path, 'all2d_model{:05d}.pth'.format(iteration)) torch.save(back_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) checkpoint_path = os.path.join( opt.checkpoint_path, 'all2dis_model{:05d}.pth'.format(iteration)) torch.save(dis_model.state_dict(), checkpoint_path) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Evaluate model if (iteration % 1000 == 0): eval_kwargs = {'split': 'test', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Our metric is CIDEr if available, otherwise validation loss if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss # Save model in checkpoint path best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 
'd_model.pth') torch.save(back_model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'dis_model.pth') torch.save(dis_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history #histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) # Save model to unique file if new best model if best_flag: model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format( iteration, best_val_score) infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration) checkpoint_path = os.path.join(opt.checkpoint_path, model_fname) torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'd_model-best.pth') torch.save(back_model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'dis_model-best.pth') torch.save(dis_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f: cPickle.dump(infos, f)
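# A minimal, self-contained sketch of the twin regularization and critic
# clipping used in the loop above, assuming toy tensors in place of the real
# models: `h_fwd` and `h_bwd` stand in for the forward and backward decoders'
# hidden states of shape (batch, hidden, time), and `critic` stands in for
# dis_model. The backward states are flipped along time and detached, an L2
# penalty ties the forward states to them, and the critic's weights are
# clamped WGAN-style after each of its updates.
import torch
import torch.nn as nn

batch, hidden, time_steps = 4, 16, 10
h_fwd = torch.randn(batch, hidden, time_steps, requires_grad=True)
h_bwd = torch.randn(batch, hidden, time_steps)

# Reverse the backward trajectory so step t of h_bwd lines up with step t of
# h_fwd, and detach it so only the forward model receives this gradient.
invert_backstates = torch.flip(h_bwd, dims=[2]).detach()

# L2 "twin" loss pulling forward states toward the reversed backward states.
l2_loss = ((h_fwd - invert_backstates) ** 2).mean()
l2_loss.backward()

# 1-D conv critic scoring a state sequence; after its optimizer step the
# weights are clamped into a small box, as in the WGAN weight-clipping trick.
critic = nn.Conv1d(hidden, 1, kernel_size=1)
scores = critic(h_fwd.detach())          # (batch, 1, time)
for p in critic.parameters():
    p.data.clamp_(-0.01, 0.01)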
def main(configs, args): global net, dataloader, optimizer, lr_scheduler, writer, epochs, logger best_acc = 0 torch.manual_seed(6666) configs = init_configs(configs) net = build_model(configs) net = init_model(net, configs) net = net.cuda().train() print(net) if args.debug: configs.log_dir = os.path.join('debug', configs.log_dir) configs.ckpt.save_config_path = os.path.join( 'debug', configs.ckpt.save_config_path) configs.ckpt.save_model_path = os.path.join( 'debug', configs.ckpt.save_model_path) configs.ckpt.save_optim_path = os.path.join( 'debug', configs.ckpt.save_optim_path) check_dir(configs.log_dir) if not configs.do_test: config_path = configs.ckpt.save_config_path torch.save({'configs': configs}, os.path.join(config_path, 'configs.pth')) logger = create_logger(configs.log_dir, configs.cfg_name) writer = SummaryWriter(configs.log_dir) for name, param in net.named_parameters(): print('%s required grad is %s' % (name, param.requires_grad)) dataloader = build_dataset(configs) optimizer = build_optimizer(net.parameters(), configs.optimizer) optimizer = init_optim(optimizer, configs) lr_scheduler = get_lr_scheduler(configs.training) max_iterations = configs.training.max_episodes test_every_iterations = configs.testing.test_every_episodes for iteration in range(1, max_iterations + 1): try: if iteration % test_every_iterations == 0 or configs.do_test or ( args.debug and args.debug_test): epochs += 1 acc = test('test', configs) optim_path = configs.ckpt.save_optim_path model_path = configs.ckpt.save_model_path z, ind_z, den_z, images, labels = extract_features( 'test', configs) if not configs.do_test: torch.save({'model': net.state_dict()}, os.path.join(model_path, 'model_%d.pth' % iteration)) torch.save({'optim': optimizer.state_dict()}, os.path.join(optim_path, 'optim_%d.pth' % iteration)) torch.save( { 'z': z.numpy(), 'ind_z': ind_z.numpy(), 'den_z': den_z.numpy(), 'labels': labels, 'images': images }, os.path.join(model_path, 'results_%d.pth' % iteration)) if acc > best_acc: best_acc = acc torch.save({'model': net.state_dict()}, os.path.join(model_path, 'model_best.pth')) torch.save({'optim': optimizer.state_dict()}, os.path.join(optim_path, 'optim_best.pth')) if configs.do_test or (args.debug and args.debug_test): return train(iteration, configs) except KeyboardInterrupt: import ipdb ipdb.set_trace()
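# A compact sketch of the evaluate-every-N-iterations / keep-best-checkpoint
# pattern in main() above. `train_step` and `evaluate` are placeholders for the
# real train()/test() calls, and the paths are illustrative rather than the
# repository's config-driven ones.
import os
import torch
import torch.nn as nn

net = nn.Linear(8, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

def train_step():
    # one placeholder optimization step
    optimizer.zero_grad()
    loss = net(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()

def evaluate(step):
    # placeholder accuracy; the real code evaluates on the test split
    return step / 100.0

best_acc, save_dir, test_every = 0.0, './debug_ckpt', 25
os.makedirs(save_dir, exist_ok=True)
for iteration in range(1, 101):
    if iteration % test_every == 0:
        acc = evaluate(iteration)
        torch.save({'model': net.state_dict()},
                   os.path.join(save_dir, 'model_%d.pth' % iteration))
        torch.save({'optim': optimizer.state_dict()},
                   os.path.join(save_dir, 'optim_%d.pth' % iteration))
        if acc > best_acc:
            best_acc = acc
            torch.save({'model': net.state_dict()},
                       os.path.join(save_dir, 'model_best.pth'))
    train_step()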
def train(opt): # opt.use_att = utils.if_use_att(opt.caption_model) opt.use_att = True if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 opt.vocab_size = 50 opt.seq_length = 10 opt.fc_feat_size = 100 opt.train_true = True opt.train_true_step = 100 np.random.seed(0) data_num = 5000 data_features = np.random.normal(size=[data_num, opt.fc_feat_size]) test_data_num = 1000 test_data_features = np.random.normal( size=[test_data_num, opt.fc_feat_size]) print(opt.checkpoint_path) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) critic_loss_history = histories.get('critic_loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) variance_history = histories.get('variance_history', {}) time_history = histories.get('time_history', {}) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = model #TODO: save true model true_model = models.setup(opt).cuda() if vars(opt).get('start_from', None) is not None: # check if all necessary files exist assert os.path.isdir( opt.start_from), " %s must be a a path" % opt.start_from assert os.path.isfile( os.path.join(opt.start_from, "infos_" + opt.id + ".pkl") ), "infos.pkl file does not exist in path %s" % opt.start_from true_model.load_state_dict( torch.load(os.path.join(opt.start_from, 'truemodel.pth'))) true_model.eval() ######################### Actor-critic Training ##################################################################### update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) tm_optimizer = utils.build_optimizer(true_model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) first_order = 0 second_order = 0 while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob 
* frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False dp_model.train() torch.cuda.synchronize() start = time.time() gen_result = None start_index = (iteration * opt.batch_size) % data_num end_index = start_index + opt.batch_size fc_feats = torch.from_numpy( data_features[start_index:end_index, :]).cuda().float() att_feats = None att_masks = None labels, total_logits = true_model(fc_feats, att_feats, att_masks, opt={'sample_max': 1}, total_probs=True, mode='sample') labels = torch.cat( [torch.zeros(labels.size(0), 1).cuda().long(), labels], 1) masks = (labels > 0).float() # train true model: if iteration < opt.train_true_step and opt.train_true: tm_optimizer.zero_grad() loss = -((total_logits * F.softmax(total_logits, 2)).sum(2)).mean() loss.backward() tm_optimizer.step() optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) else: if opt.rl_type == 'sc': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') gen_result_sc, _ = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 1}, mode='sample') reward = reward_fun(gen_result, fc_feats, true_model).unsqueeze(1).repeat( 1, sample_logprobs.size(1)) reward_sc = reward_fun(gen_result_sc, fc_feats, true_model).unsqueeze(1).repeat( 1, sample_logprobs.size(1)) reward = reward - reward_sc loss = rl_crit(sample_logprobs, gen_result.data, reward) reward = np.zeros([2, 2]) elif opt.rl_type == 'reinforce': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = reward_fun(gen_result, fc_feats, true_model).unsqueeze(1).repeat( 1, sample_logprobs.size(1)) loss = rl_crit(sample_logprobs, gen_result.data, reward) reward = np.zeros([2, 2]) elif opt.rl_type == 'reinforce_demean': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = reward_fun(gen_result, fc_feats, true_model).unsqueeze(1).repeat( 1, sample_logprobs.size(1)) loss = rl_crit(sample_logprobs, gen_result.data, reward - reward.mean()) reward = np.zeros([2, 2]) elif opt.rl_type == 'arsm': loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, true_model, opt) #print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'ars': loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, true_model, opt, type='ars') #print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'ar': loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, true_model, opt) # print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'mct_baseline': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss( dp_model, fc_feats, att_feats, att_masks, opt, true_model) reward = reward_fun(gen_result, fc_feats, true_model).unsqueeze(1).repeat( 1, sample_logprobs.size(1)) reward_cuda = reward #mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] loss = rl_crit(sample_logprobs, gen_result.data, reward - mct_baseline) if opt.mle_weights != 0: loss += opt.mle_weights * crit( dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) #TODO make sure all sampling replaced by greedy for critic #### update the actor loss.backward() # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as 
f: # cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f) # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f) ## compute variance gradient = torch.zeros([0]).cuda() for i in model.parameters(): gradient = torch.cat((gradient, i.grad.view(-1)), 0) first_order = 0.9999 * first_order + 0.0001 * gradient second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2) # print(torch.max(torch.abs(gradient))) variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item() if opt.rl_type != 'arsm' or not sc_flag: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if (iteration % opt.losses_log_every == 0): if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) print(opt.checkpoint_path) else: print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \ .format(iteration, epoch, reward.mean(), variance, end - start)) # Update the iteration and epoch iteration += 1 if (iteration * opt.batch_size) % data_num == 0: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', reward.mean(), iteration) add_summary_value(tb_summary_writer, 'variance', variance, iteration) #loss_history[iteration] = train_loss if not sc_flag else reward.mean() lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob variance_history[iteration] = variance time_history[iteration] = end - start # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model val_loss, lang_stats = eval_utils_syn(dp_model, true_model, test_data_features, opt.batch_size, crit) lang_stats = lang_stats.item() val_loss = val_loss.item() # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats } # Save model if is improving on validation result print('loss', val_loss, 'lang_stats', lang_stats) if True: # if true checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'truemodel.pth') torch.save(true_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = opt.vocab_size histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['critic_loss_history'] = critic_loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history histories['variance_history'] = variance_history histories['time'] = time_history # 
histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
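# A small sketch of the gradient-variance estimate tracked above: the flattened
# gradient is folded into exponential moving averages of its first and second
# moments, and the reported variance is mean(|E[g^2] - E[g]^2|). The linear
# model and random batches are stand-ins for the captioning model and its loss.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
first_order, second_order = 0.0, 0.0

for step in range(100):
    model.zero_grad()
    loss = model(torch.randn(32, 10)).pow(2).mean()
    loss.backward()

    # Flatten all parameter gradients into a single vector.
    gradient = torch.cat([p.grad.view(-1) for p in model.parameters()])

    # EMAs of the gradient and of its elementwise square.
    first_order = 0.9999 * first_order + 0.0001 * gradient
    second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)

    # Variance estimate of the stochastic gradient.
    variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()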
def train(opt, num_switching=None): global internal if opt.gpu2 is None: torch.cuda.set_device(opt.gpu) RL_count = 0 pure_reward = None # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 # set dataloder loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length opt.baseline_concat = 0 # setting of record result_path = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.checkpoint_path tb_summary_writer = tb and tb.SummaryWriter(result_path) infos = {} histories = {} # --- pretrained model loading --- # if opt.start_from is not None: opt.start_from = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.start_from if opt.start_from is not None: # open old infos and check if models are compatible infos = cPickle.load(open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), mode='rb')) saved_model_opt = infos['opt'] # need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] need_be_same = ["rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[ checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): histories = cPickle.load(open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl') , mode='rb')) if opt.sf_epoch is not None and opt.sf_itr is not None: iteration = opt.sf_itr epoch = opt.sf_epoch else: iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) #---------------------------------------# # I forget about these parameter, they maybe are not used. 
b_regressor = None opt.regressor = b_regressor # model setting if opt.gpu2 is not None: model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) else: model = models.setup(opt).cuda() dp_model = model update_lr_flag = True # Assure in training mode dp_model.train() # set rl mode and internal critic and similairty model info_json = json.load(open(opt.input_json)) sim_model = None new_internal = None if opt.internal_model == 'sim' or opt.internal_model == 'sim_newr' or opt.internal_model == 'sim_dammy': # setting internal critic and similarity prediction network sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word'])) if opt.region_bleu_flg == 0: if opt.sim_pred_type == 0: # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt' model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt' elif opt.sim_pred_type == 1: model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt' elif opt.sim_pred_type == 2: model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt' else: print('select 0 or 1') exit() checkpoint = torch.load(model_root, map_location='cuda:0') sim_model.load_state_dict(checkpoint['model_state_dict']) sim_model.cuda() sim_model.eval() for param in sim_model.parameters(): param.requires_grad = False sim_model_optimizer = None elif opt.region_bleu_flg == 1: sim_model.cuda() if opt.sf_internal_epoch is not None: sim_model.load_state_dict( torch.load(os.path.join(opt.start_from, 'sim_model_' + str(opt.sf_internal_epoch) + '_' + str( opt.sf_internal_itr) + '.pth'))) # sim_model_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str( # opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth'))) sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt) else: print('not implimented') exit() if opt.only_critic_train == 1: random.seed(100) if opt.critic_encode==1: internal = models.CriticModel_with_encoder(opt) elif opt.bag_flg == 1: internal = models.CriticModel_bag(opt) elif opt.ppo == 1: # internal = models.CriticModel_sim(opt) internal = models.CriticModel_nodropout(opt) new_internal = models.CriticModel_nodropout(opt) internal.load_state_dict(new_internal.state_dict()) elif opt.input_h_flg == 1: internal = models.CriticModel_sim(opt) else: internal = models.CriticModel_sim_h(opt) internal = internal.cuda() if new_internal is not None: new_internal = new_internal.cuda() if opt.ppo == 1: internal_optimizer = utils.build_internal_optimizer(new_internal.parameters(), opt) else: internal_optimizer = utils.build_internal_optimizer(internal.parameters(), opt) if opt.sf_internal_epoch is not None: internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str( opt.sf_internal_itr) + '.pth'))) internal_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str( opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth'))) # new_internal = models.CriticModel_nodropout(opt) new_internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str( opt.sf_internal_itr) + '.pth'))) if opt.multi_learn_flg != 1: if opt.internal_rl_flg == 1: internal_rl_flg = True dp_model.eval() else: internal.eval() internal_rl_flg = False else: internal_rl_flg = True else: if opt.sim_reward_flg > 0: # setting 
internal critic and similarity prediction network sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word'])) if opt.sim_pred_type == 0: # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt' # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt' model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/no_shuffle_simforcoco/model_37_34000.pt' elif opt.sim_pred_type == 1: model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt' elif opt.sim_pred_type == 2: model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt' else: print('select 0 or 1') exit() if opt.region_bleu_flg == 0: if opt.sim_pred_type == 0: # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt' opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt' elif opt.sim_pred_type == 1: opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt' elif opt.sim_pred_type == 2: opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt' else: opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/log_python3/log_' + opt.id + '/sim_model' + opt.model[-13:-4] + '.pth' checkpoint = torch.load(opt.sim_model_dir, map_location='cuda:0') sim_model.load_state_dict(checkpoint['model_state_dict']) sim_model.cuda() sim_model.eval() for param in sim_model.parameters(): param.requires_grad = False sim_model_optimizer = None elif opt.region_bleu_flg == 1: sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt) sim_model.cuda() internal = None internal_optimizer = None internal_rl_flg = False opt.c_current_lr = 0 # opt.internal = internal # set Discriminator if opt.discriminator_weight > 0: dis_opt = opt if opt.dis_type == 'coco': discrimiantor_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/coco/discriminator_150.pth' dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator_label.h5' dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator.json' elif opt.dis_type == 'iapr': discrimiantor_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/iapr_dict/discriminator_125.pth' dis_opt.input_label_h5 = '/mnt/workspace2019/visual_genome_pretrain/iapr_talk_cocodict_label.h5' dis_opt.input_json = '/mnt/workspace2018/nakamura/IAPR/iapr_talk_cocodict.json' elif opt.dis_type == 'ss': discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/shuttorstock_dict/discriminator_900.pth' dis_opt.input_label_h5 = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict_label.h5' dis_opt.input_json = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict.json' elif opt.dis_type == 'sew': discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew/discriminator_900.pth' dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5' dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json' elif opt.dis_type == 'sew_cut5': discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew_cut5/discriminator_90.pth' dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5' 
dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json' opt.cut_length = 5 elif opt.dis_type == 'vg_cut5': opt.cut_length = 5 discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/vg_cut5/discriminator_200.pth' dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_label.h5' dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_addregions.json' else: print('select existing discriminative model!') exit() discriminator_path_learned = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration)) Discriminator = dis_utils.Discriminator(opt) if os.path.isfile(discriminator_path_learned): Discriminator.load_state_dict(torch.load(discriminator_path_learned, map_location='cuda:' + str(opt.gpu))) else: Discriminator.load_state_dict(torch.load(discrimiantor_model_dir, map_location='cuda:' + str(opt.gpu))) Discriminator = Discriminator.cuda() # change discriminator learning rate # opt.learning_rate = opt.learning_rate/10 dis_optimizer = utils.build_optimizer(Discriminator.parameters(), opt) # for group in dis_optimizer.param_groups: # group['lr'] = opt.learning_rate/100 Discriminator.eval() dis_loss_func = nn.BCELoss().cuda() dis_loader = dis_dataloader.DataLoader(dis_opt) else: Discriminator = None dis_loader = None dis_optimizer = None # set Acter Critic network if opt.actor_critic_flg == 1: Q_net = models.Actor_Critic_Net_upper(opt) target_Q_net = models.Actor_Critic_Net_upper(opt) Q_net.load_state_dict(target_Q_net.state_dict()) target_model = models.setup(opt).cuda() target_model.load_state_dict(model.state_dict()) target_model.eval() Q_net.cuda() target_Q_net.cuda() Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt) elif opt.actor_critic_flg == 2: Q_net = models.Actor_Critic_Net_seq(opt) target_Q_net = models.Actor_Critic_Net_seq(opt) Q_net.load_state_dict(target_Q_net.state_dict()) target_model = models.setup(opt).cuda() target_model.load_state_dict(model.state_dict()) target_model.eval() Q_net.cuda() target_Q_net.cuda() Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt) seq_mask = torch.zeros((opt.batch_size * opt.seq_per_img, opt.seq_length, opt.seq_length)).cuda().type(torch.cuda.LongTensor) for i in range(opt.seq_length): seq_mask[:, i, :i] += 1 elif opt.t_model_flg == 1: target_model = models.setup(opt).cuda() target_model.load_state_dict(model.state_dict()) target_model.eval() else: target_model = None baseline = None new_model = None # set functions calculating loss if opt.caption_model == 'hcatt_hard' or opt.caption_model == 'basicxt_hard' or opt.caption_model == 'hcatt_hard_nregion' or opt.caption_model == 'basicxt_hard_nregion' : if opt.ppo == 1: new_model = models.setup(opt).cuda() new_model.load_state_dict(model.state_dict()) # new_optimizer = utils.build_optimizer(new_model.parameters(), opt) # new_model.eval() # If you use hard attention, use this setting (but is is not implemented completely) crit = utils.LanguageModelCriterion_hard() else: crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() rl_crit_hard = utils.RewardCriterion_hard() rl_crit_conly = utils.RewardCriterion_conly() rl_crit_hard_base = utils.RewardCriterion_hard_baseline() att_crit = utils.AttentionCriterion() if opt.caption_model == 'hcatt_hard' and opt.ppo == 1: optimizer = utils.build_optimizer(new_model.parameters(), opt) else: # set optimizer optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if 
vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): if opt.sf_epoch is None: optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) else: optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer_' + str(opt.sf_epoch) + '_' +str(opt.sf_itr) + '.pth'))) critic_train_count = 0 total_critic_reward = 0 pre_para = None #------------------------------------------------------------------------------------------------------------# # training start while True: train_loss = 0 if update_lr_flag: # cahnge lr opt, optimizer, model, internal_optimizer, dis_optimizer = utils.change_lr(opt, epoch, optimizer, model, internal_optimizer, dis_optimizer) # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True # internal_rl_flg == False init_scorer(opt.cached_tokens, len(info_json['ix_to_word'])) else: sc_flag = False update_lr_flag = False # # !!!!! # internal_rl_flg = False # model.train() # internal.eval() # #!!!!! # Load data from train split (0) data = loader.get_batch('train') torch.cuda.synchronize() start = time.time() # get datch tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['bbox'], data['sub_att'], data['fixed_region']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks, bbox, sub_att, fixed_region = tmp optimizer.zero_grad() # calculating loss... if not sc_flag: # use cross entropy if opt.weight_deterministic_flg > 0: weight_index = np.array(data['weight_index']) # fc_feats = fc_feats * 0.0 output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index) # output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=None) else: output = dp_model(fc_feats, att_feats, labels, att_masks, internal) if opt.caption_model == 'hcatt_prob': print(torch.exp(output).mean(), model.probs.mean()) output = output + model.probs.view(output.size(0), output.size(1), 1) loss = crit(output, labels[:,1:], masks[:,1:]) elif opt.caption_model != 'hcatt_hard' and opt.caption_model != 'hcatt_hard_nregion'and opt.caption_model != 'basicxt_hard_nregion' and opt.caption_model != 'basicxt_hard': loss = crit(output, labels[:,1:], masks[:,1:]) else: if baseline is None: baseline = torch.zeros((output.size()[0], output.size()[1]))/output.size()[1] baseline = baseline.cuda() # baseline = torch.log(baseline) # print('pre:', baseline.mean().item()) loss, baseline = crit(output, labels[:,1:], masks[:,1:], baseline, dp_model.weights_p, dp_model.weights) # print('after:', baseline.mean().item()) else: # use rl if opt.weight_deterministic_flg > 0: weight_index = np.array(data['weight_index']) else: weight_index = None if dp_model.training: sample_max_flg = 0 else: sample_max_flg = 1 # get predicted captions and logprops, similarity gen_result, sample_logprobs, word_exist_seq = dp_model(fc_feats, att_feats, att_masks,internal, opt={'sample_max':sample_max_flg}, sim_model = sim_model, New_Critic=new_internal, bbox=bbox, sub_att=sub_att, label_region = data['label_region'], weight_index=weight_index,mode='sample') train_similarity = dp_model.similarity # ---------- learning discriminator ---------------- if Discriminator is not None and opt.dis_adv_flg == 1 and internal_rl_flg == False: correct = 0 Discriminator.train() fake_data = gen_result.data.cpu() hokan = torch.zeros((len(fake_data), 
1)).type(torch.LongTensor) fake_data = torch.cat((hokan, fake_data, hokan), 1).cuda() fake_data = fake_data[:, 1:] label = torch.ones((fake_data.size(0))).cuda() # pdb.set_trace() Discriminator, dis_optimizer, correct, neg_loss = \ dis_utils.learning_func(Discriminator, dis_optimizer, fake_data, label, correct, 0, opt.cut_length, opt.random_disc, opt.all_switch_end_dis, opt.all_switch_dis, loss_func=dis_loss_func, weight_index=weight_index, model_gate=model.gate.data.cpu().numpy()) dis_data = dis_loader.get_batch('train', batch_size=fake_data.size(0)) real_data = torch.from_numpy(dis_data['labels']).cuda() real_data = real_data[:, 1:] Discriminator, dis_optimizer, correct, pos_loss = \ dis_utils.learning_func(Discriminator, dis_optimizer, real_data, label, correct, 1, opt.cut_length, 0, 0, 0, loss_func=dis_loss_func, weight_index=weight_index) loss_mean = (pos_loss + neg_loss) / 2 dis_accuracy = correct/(fake_data.size(0) * 2) print('Discriminator loss: {}, accuracy: {}'.format(loss_mean, dis_accuracy)) Discriminator.eval() else: loss_mean = -1.0 dis_accuracy = -1.0 # -------------------------------------------------- # ---------- calculate att loss ----------- if opt.att_reward_flg == 1 and model.training: # if opt.att_reward_flg == 1 : att_loss = att_crit(model, gen_result.data.cpu().numpy()) att_loss_num = att_loss.data.cpu().numpy() else: att_loss = 0.0 att_loss_num = 0.0 # ------------------------------------------ # --- get states and actions xt and weights, ccs, seqs --- if opt.actor_critic_flg==1 and model.training: xts = model.all_xts weights_p = model.weights_p ccs = internal.output_action if opt.actor_critic_flg == 2 and model.training: all_logprops = model.all_logprops weight_state = model.state_weights # xts = model.all_xts gen_result_repeat = gen_result.repeat(1, opt.seq_length).view(all_logprops.size(0), opt.seq_length, opt.seq_length) # xts = seq_mask * gen_result_repeat xts = gen_result_repeat weights_p = model.weights_p # pdb.set_trace() if internal is not None: ccs = internal.output_action else: ccs = torch.zeros((len(xts), weights_p.size(1))).cuda() if opt.caption_model == 'hcatt_hard' and opt.ppo==1: xts = model.all_xts weights_p = model.weights_p weights = model.weights # ---------------------------------------------------------- # ---------------- Calculate reward (CIDEr, Discriminator, Similarity...)--------------------- if opt.actor_critic_flg == 2 and model.training: reward, pure_reward = get_self_critical_and_similarity_reward_for_actor_critic(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt, train_similarity, internal=internal, sim_model=sim_model, label_region=data['label_region'], D=Discriminator) else: reward, pure_reward, actor_critic_reward, target_update_flg = get_self_critical_and_similarity_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt, train_similarity, internal=internal, sim_model=sim_model, label_region=data['label_region'], bbox=bbox, D=Discriminator, weight_index=weight_index, fixed_region=fixed_region, target_model=target_model) if target_update_flg and target_model is not None: print('----- target model updated ! 
-----') target_model.load_state_dict(model.state_dict()) # print(train_similarity.mean(), model.similarity.mean()) #---------------------------------------------------------- #-------------------------------- calculate captioning model loss ----------------------------------------- #------------ Calculate actor critic loss ---------------- if opt.actor_critic_flg == 1 and model.training: # get q_value q_value = Q_net(fc_feats, att_feats, xts, weights_p, gen_result) # get target_sample with torch.no_grad(): gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks, seqs=gen_result, ccs=ccs, mode='sample') target_q_value = target_Q_net(fc_feats, att_feats, target_model.all_xts, target_model.weights_p, gen_result) # calculate actor critic loss actor_critic_loss = Q_net.loss_func(actor_critic_reward, q_value, target_q_value) add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration, opt.tag) Q_net_optimizer.zero_grad() elif opt.actor_critic_flg == 2 and model.training: # get q_value q_value = Q_net(fc_feats, att_feats, xts, weight_state.detach(), weights_p, all_logprops[:,:-1,:], gen_result) # get target_sample with torch.no_grad(): gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks, seqs=gen_result, ccs=ccs, mode='sample', state_weights=weight_state) # pdb.set_trace() target_q_value = target_Q_net(fc_feats, att_feats, xts, target_model.state_weights, target_model.weights_p, target_model.all_logprops[:,:-1,:], gen_result) # calculate actor critic loss if reward is None: pdb.set_trace() actor_critic_loss = Q_net.loss_func(reward, q_value, target_q_value, gen_result) print('actor_critic_loss', actor_critic_loss.item()) add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration, opt.tag) Q_net_optimizer.zero_grad() else: actor_critic_loss = 0 model.att_score = att_loss_num # update ppo old policy if new_internal is not None and internal.iteration % 1 == 0: internal.load_state_dict(new_internal.state_dict()) if opt.caption_model == 'hcatt_hard' and opt.ppo == 1: model.load_state_dict(new_model.state_dict()) if not internal_rl_flg or opt.multi_learn_flg == 1: # if opt.ppo == 1 and opt.caption_model == 'hcatt_hard': # -------------- calculaete self critical loss --------------- if False: # get coeffitient and calculate new_gen_result, new_sample_logprobs = new_model(fc_feats, att_feats, att_masks, seqs=gen_result, mode='sample', decided_att=weights) new_model.pre_weights_p = new_model.weights_p new_model.pre_weights = new_model.weights att_index = np.where(weights.data.cpu() > 0) weights_p_ = weights_p[att_index].view(weights_p.size(0), weights_p.size(1)) # (batch, seq_length) reward_coefficient = 1 / (torch.exp(sample_logprobs) * weights_p_).data.cpu() # train caption network get reward and calculate loss reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base, new_sample_logprobs, gen_result, reward, baseline, new_model, reward_coefficient=reward_coefficient) elif (not internal_rl_flg or opt.multi_learn_flg == 1) and opt.actor_critic_flg == 0: # train caption network get reward and calculate loss if opt.weight_deterministic_flg == 7: reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base, sample_logprobs, word_exist_seq, reward, baseline, model) else: reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base, sample_logprobs, gen_result, reward, baseline, model) else: reward_loss = 0 # 
-------------- calculaete self critical loss --------------- if (opt.caption_model == 'hcatt_simple' or opt.caption_model == 'hcatt_simple_switch') and opt.xe_weight > 0.0: output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index) xe_loss = crit(output, labels[:, 1:], masks[:, 1:]) print('r_loss: {}, xe_loss: {}'.format(reward_loss.item(), xe_loss.item())) add_summary_value(tb_summary_writer, 'xe_loss', xe_loss.item(), iteration, opt.tag) add_summary_value(tb_summary_writer, 'r_loss', reward_loss.item(), iteration, opt.tag) else: xe_loss = 0.0 loss = opt.rloss_weight * reward_loss + opt.att_lambda * att_loss + actor_critic_loss + opt.xe_weight * xe_loss # -------------------------------------------------------------------------------------------------------- # ------------------------- calculate internal critic loss and update --------------------------- if internal_optimizer is not None and internal_rl_flg == True and sc_flag: internal_optimizer.zero_grad() if opt.region_bleu_flg == 1: sim_model_optimizer.zero_grad() if opt.only_critic_train == 0: internal_loss = rl_crit(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(), reward_coefficient=internal.pre_reward_coefficient) else: internal_loss = rl_crit_conly(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(), reward_coefficient=internal.pre_reward_coefficient, c_count=critic_train_count) q_value_prop = torch.exp(internal.pre_output) entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2( 1 - q_value_prop + 1e-8)) internal_loss = internal_loss internal_loss.backward() internal_optimizer.step() if opt.region_bleu_flg == 1: sim_model_optimizer.step() # ----- record loss and reward to tensorboard ----- # q_value_prop = torch.exp(internal.pre_output) # entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2(1 - q_value_prop + 1e-8)) if opt.only_critic_train == 1: if internal is not None and sc_flag: num_internal_switching = internal.same_action_flg.mean().item() else: num_internal_switching = 0 total_critic_reward += np.mean(pure_reward) total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, critic_train_count, opt.tag, tb_summary_writer, reward, pure_reward, entropy, opt.sim_sum_flg,num_internal_switching, total_critic_reward=total_critic_reward) else: if internal is not None and sc_flag: num_internal_switching = internal.same_action_flg.mean().item() else: num_internal_switching = 0 total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, iteration, opt.tag, tb_summary_writer, reward, pure_reward, entropy, opt.sim_sum_flg, num_internal_switching) # ------------------------------------------------- critic_train_count += 1 internal.reset() internal.iteration+=1 print('iter {} (epoch {}), internal_loss: {}, avg_reward: {}, entropy: {}'.format(iteration, epoch,internal_loss, reward.mean(), entropy)) # -------------------------------------------------------------------------------------------------------- else: #------------------------- updating captioning model ---------------------------- loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() if opt.actor_critic_flg > 0 and model.training: utils.clip_gradient(Q_net_optimizer, opt.grad_clip) Q_net_optimizer.step() utils.soft_update(target_model, model, 0.001) utils.soft_update(target_Q_net, Q_net, 0.001) # if 
iteration % 1000 == 0: # utils.hard_update(target_model, model) # utils.hard_update(target_Q_net, Q_net) # else: # utils.soft_update(target_model, model, 0.001) # utils.soft_update(target_Q_net, Q_net, 0.001) train_loss = loss.item() torch.cuda.synchronize() del loss end = time.time() if internal is not None and sc_flag: num_internal_switching = internal.same_action_flg.mean().item() else: num_internal_switching = 0 if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: try: print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), model.att_score.item(), end - start)) utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag, opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching) except AttributeError: print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:, 0]), model.att_score, end - start)) utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag, opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching) RL_count += 1 # -------------------------------------------------------------------------------- # Update the iteration and epoch iteration += 1 # -------------------- change train internal critic or caption network ----------------------------- if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True if opt.cycle is None and internal is not None and opt.multi_learn_flg != 1: # and entropy < 1.0 if internal_rl_flg == True and opt.only_critic_train == 0: if opt.actor_critic_flg == 1: utils.hard_update(target_model, model) utils.hard_update(target_Q_net, Q_net) internal_rl_flg = False internal.eval() dp_model.train() if weight_index is not None and loader.weight_deterministic_flg == 4: loader.weight_deterministic_flg = 5 if opt.region_bleu_flg == 1: sim_model.eval() train_loss = None # elif internal_optimizer is not None and internal_rl_flg == False: # elif internal_optimizer is not None and internal_rl_flg == False and (epoch + 1) % 3 == 0 and opt.internal_model != 'sim_dammy': # elif internal_optimizer is not None and internal_rl_flg == False and opt.internal_model != 'sim_dammy': else: internal_rl_flg = True # internal.load_state_dict(torch.load(result_path + '/internal_best.pth')) if opt.ppo == 1: internal_optimizer = optim.Adam(new_internal.parameters(), opt.c_learning_rate, weight_decay=1e-5) else: internal_optimizer = optim.Adam(internal.parameters(), opt.c_learning_rate, weight_decay=1e-5) internal.train() if opt.region_bleu_flg == 1: sim_model.train() dp_model.eval() if weight_index is not None and loader.weight_deterministic_flg == 5: loader.weight_deterministic_flg = 4 internal.reset() internal.max_r = 0 # -------------------------------------------------------------------------------------------------- # ------------------- Write the training loss summary ------------------------------ if (iteration % opt.losses_log_every == 0) and internal_rl_flg == False and train_loss is not None: add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration, opt.tag) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration, opt.tag) add_summary_value(tb_summary_writer, 'critic_learning_rate', opt.c_current_lr, iteration, opt.tag) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration, opt.tag) if 
sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration, opt.tag) loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # ---------------------------------------------------------------------------------- # ------------------------ make evaluation on validation set, and save model ------------------------------ wdf7_eval_flg = (opt.weight_deterministic_flg != 7 or sc_flag) if ((iteration % opt.save_checkpoint_every == 0) or iteration == 39110 or iteration == 113280 or iteration == 151045 or iteration == 78225 or iteration == 31288 or iteration == 32850 or iteration == 46934) and train_loss is not None: if sc_flag and (opt.caption_model == 'hcatt_hard' or opt.caption_model == 'basicxt_hard' or opt.caption_model == 'hcatt_hard_nregion' or opt.caption_model == 'basicxt_hard_nregion'): if baseline is None: baseline = torch.zeros((sample_logprobs.size()[0], sample_logprobs.size()[1] + 1)) / sample_logprobs.size()[1] baseline = baseline.cuda() # baseline = torch.log(baseline) # eval model varbose_loss = not sc_flag eval_kwargs = {'split': 'val', 'internal': internal, 'sim_model': sim_model, 'caption_model': opt.caption_model, 'baseline': baseline, 'gts': data['gts'], 'dataset': opt.dataset, 'verbose_loss': varbose_loss, 'weight_deterministic_flg': opt.weight_deterministic_flg } eval_kwargs.update(vars(opt)) # pdb.set_trace() if wdf7_eval_flg: # eval_utils.eval_writer(dp_model, iteration, loader, tb_summary_writer, eval_kwargs) val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration, opt.tag) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration, opt.tag) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = - val_loss else: val_result_history[iteration] = {'loss': None, 'lang_stats': None, 'predictions': None} current_score = 0 best_flag = False if True: # if true if internal_rl_flg == False: if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration)) torch.save(model.state_dict(), checkpoint_path) optimizer_path = os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration)) torch.save(optimizer.state_dict(), optimizer_path) print("model saved to {}".format(checkpoint_path)) if internal is not None: internal.eval() checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration)) torch.save(internal.state_dict(), checkpoint_path) optimizer_path = os.path.join(result_path, 'internal_optimizer_{}_{}.pth'.format(epoch, iteration)) torch.save(internal_optimizer.state_dict(), optimizer_path) print("internal model saved to {}".format(checkpoint_path)) checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration)) torch.save(sim_model.state_dict(), checkpoint_path) print("sim_model saved to {}".format(checkpoint_path)) else: checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration)) torch.save(model.state_dict(), checkpoint_path) optimizer_path 
= os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration)) torch.save(optimizer.state_dict(), optimizer_path) print("model saved to {}".format(checkpoint_path)) if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration)) torch.save(internal.state_dict(), checkpoint_path) optimizer_path = os.path.join(result_path, 'internal_optimizer_{}_{}.pth'.format(epoch, iteration)) torch.save(internal_optimizer.state_dict(), optimizer_path) print("internal model saved to {}".format(checkpoint_path)) checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration)) torch.save(sim_model.state_dict(), checkpoint_path) print("sim_model saved to {}".format(checkpoint_path)) dp_model.eval() if Discriminator is not None: discriminator_path = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration)) torch.save(Discriminator.state_dict(), discriminator_path) dis_optimizer_path = os.path.join(result_path, 'dis_optimizer_{}_{}.pth'.format(epoch, iteration)) torch.save(dis_optimizer.state_dict(), dis_optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() infos['internal_rl_flg'] = internal_rl_flg histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open(os.path.join(result_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(infos, f) with open(os.path.join(result_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(result_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(result_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # pdb.set_trace() # --------------------------------------------------------------------------------------------------------- # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
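# A minimal sketch of the reward-weighted sequence loss that the self-critical
# branch above relies on (the shape of utils.RewardCriterion): per-token
# log-probabilities of the sampled caption are scaled by a reward and masked so
# that tokens after the first end-of-sequence token do not contribute. The
# tensors below are random stand-ins for sample_logprobs, gen_result and the
# (sample reward - greedy baseline reward) difference.
import torch

def reward_criterion(logprobs, seq, reward):
    # logprobs, seq, reward: (batch, seq_length)
    mask = (seq > 0).float()
    # keep the step that emits the first end token, as the original criterion does
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
    loss = -logprobs * reward * mask
    return loss.sum() / mask.sum()

batch, length = 5, 7
sample_logprobs = torch.randn(batch, length)
gen_result = torch.randint(0, 50, (batch, length))
reward = torch.randn(batch, 1).expand(batch, length)   # e.g. CIDEr(sample) - CIDEr(greedy)
print(reward_criterion(sample_logprobs, gen_result, reward))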
def train(opt): acc_steps = getattr(opt, 'acc_steps', 1) loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length opt.ix_to_word = loader.ix_to_word tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories = utils.pickle_load(f) else: infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) opt.vocab = loader.get_vocab() model = models.setup(opt).cuda() del opt.vocab dp_model = torch.nn.DataParallel(model) lw_model = LossWrapper(model, opt) dp_lw_model = torch.nn.DataParallel(lw_model) epoch_done = True # Assure in training mode dp_lw_model.train() if opt.noamopt: optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) def save_checkpoint(model, infos, optimizer, histories=None, append=''): if len(append) > 0: append = '-' + append # if checkpoint_path doesn't exist if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append)) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append)) torch.save(optimizer.state_dict(), optimizer_path) with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f: utils.pickle_dump(infos, f) if histories: with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f: utils.pickle_dump(histories, f) try: while True: sys.stdout.flush() if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) print('Learning Rate: ', opt.current_lr) if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True 
init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False data = loader.get_batch('train') if (iteration % acc_steps == 0): optimizer.zero_grad() torch.cuda.synchronize() start = time.time() tmp = [data['fc_feats'], data['att_feats'], data['c3d_feats'], data['labels'], data['masks'], data['att_masks'], data['c3d_masks']] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks = tmp model_out = dp_lw_model(fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() loss_sp = loss / acc_steps loss_sp.backward() if ((iteration + 1) % acc_steps == 0): utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() torch.cuda.synchronize() train_loss = loss.item() end = time.time() if iteration % 1 == 0: if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), reward1 = {:.3f}, reward2 = {:.3f}, reward3 = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, model_out['reward_layer1'].mean(), model_out['reward_layer2'].mean(), model_out['reward_layer3'].mean(), train_loss, end - start)) iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'reward1', model_out['reward_layer1'].mean(), iteration) add_summary_value(tb_summary_writer, 'reward2', model_out['reward_layer2'].mean(), iteration) add_summary_value(tb_summary_writer, 'reward3', model_out['reward_layer3'].mean(), iteration) loss_history[iteration] = train_loss lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, lw_model.crit, loader, eval_kwargs) print('Summary Epoch {} Iteration {}: CIDEr: {} BLEU-4: {}'.format(epoch, iteration, lang_stats['CIDEr'], lang_stats['Bleu_4'])) if opt.reduce_on_plateau: if opt.reward_metric == 'cider': optimizer.scheduler_step(-lang_stats['CIDEr']) elif opt.reward_metric == 'bleu': optimizer.scheduler_step(-lang_stats['Bleu_4']) elif opt.reward_metric == 'meteor': optimizer.scheduler_step(-lang_stats['METEOR']) elif opt.reward_metric == 'rouge': optimizer.scheduler_step(-lang_stats['ROUGE_L']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: if opt.reward_metric == 'cider': 
current_score = lang_stats['CIDEr'] elif opt.reward_metric == 'bleu': current_score = lang_stats['Bleu_4'] elif opt.reward_metric == 'meteor': current_score = lang_stats['METEOR'] elif opt.reward_metric == 'rouge': current_score = lang_stats['ROUGE_L'] else: current_score = - val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history save_checkpoint(model, infos, optimizer, histories) if opt.save_history_ckpt: save_checkpoint(model, infos, optimizer, append=str(iteration)) if best_flag: save_checkpoint(model, infos, optimizer, append='best') # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') save_checkpoint(model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
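# The variant above accumulates gradients over `acc_steps` micro-batches before each optimizer
# step. Below is a minimal, self-contained sketch of that pattern on a toy linear model;
# the model, data and clip_grad_norm_ call are stand-ins, not the project's own classes.
import torch
import torch.nn as nn

def accumulation_demo(acc_steps=4, iters=8):
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for iteration in range(iters):
        # Clear gradients only at the start of an accumulation window.
        if iteration % acc_steps == 0:
            optimizer.zero_grad()
        x = torch.randn(8, 16)
        y = torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        # Scale the loss so the accumulated gradient matches one large batch.
        (loss / acc_steps).backward()
        # Step once the window is full.
        if (iteration + 1) % acc_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()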
def train(opt): # Deal with feature things before anything opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[ checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories = utils.pickle_load(f) else: infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # cnn_model = utils.build_cnn(opt) cnn_model = create_extractor("/root/PycharmProjects/vgg_vae_best_model.pth") cnn_model = cnn_model.cuda() if vars(opt).get('start_from', None) is not None: cnn_model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model-cnn.pth'))) print("load cnn model parameters from {}".format(os.path.join(opt.start_from, 'model-cnn.pth'))) model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) lw_model = LossWrapper(model, opt) dp_lw_model = torch.nn.DataParallel(lw_model) # dp_lw_model = lw_model epoch_done = True # Assure in training mode dp_lw_model.train() if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt( model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # if opt.finetune_cnn_after != -1: # # only finetune the layer2 to layer4 cnn_optimizer = optim.Adam([ {'params': module.parameters()} for module in cnn_model.finetune_modules ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay) # Load the optimizer if vars(opt).get('start_from', None) is not None: if os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict(torch.load( os.path.join(opt.start_from, 'optimizer.pth'))) if opt.finetune_cnn_after != -1: if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')): cnn_optimizer.load_state_dict(torch.load( os.path.join(opt.start_from, 'optimizer-cnn.pth'))) def save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer, histories=None, append=''): if len(append) > 0: append = '-' + append # if checkpoint_path doesn't exist 
if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) checkpoint_path = os.path.join( opt.checkpoint_path, 'model%s.pth' % (append)) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) cnn_checkpoint_path = os.path.join( opt.checkpoint_path, 'model-cnn%s.pth' % (append)) torch.save(cnn_model.state_dict(), cnn_checkpoint_path) print("cnn model saved to {}".format(cnn_checkpoint_path)) optimizer_path = os.path.join( opt.checkpoint_path, 'optimizer%s.pth' % (append)) torch.save(optimizer.state_dict(), optimizer_path) if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: cnn_optimizer_path = os.path.join( opt.checkpoint_path, 'optimizer%s-cnn.pth' % (append)) torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path) with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(infos, f) if histories: with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(histories, f) try: while True: if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = ( epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate # set the decayed rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = ( epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # Update the training stage of cnn if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after: for p in cnn_model.parameters(): p.requires_grad = False cnn_model.eval() else: for p in cnn_model.parameters(): p.requires_grad = True # Fix the first few layers: for module in cnn_model.fixed_modules: for p in module.parameters(): p.requires_grad = False cnn_model.train() # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') torch.cuda.synchronize() print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp # att_feats 8x672x224 att_feats = att_feats.view(att_feats.size(0), 3, 224, 224) att_feats, fc_feats = cnn_model(att_feats) # fc_feats = att_feats.mean(3).mean(2) # att_feats = torch.nn.functional.adaptive_avg_pool2d( # att_feats, [7, 7]).permute(0, 2, 3, 1) att_feats = att_feats.permute(0, 2, 3, 1) att_feats = att_feats.view(att_feats.size(0), 49, -1) att_feats = att_feats.unsqueeze(1).expand(*((att_feats.size(0), opt.seq_per_img,) + att_feats.size( )[1:])).contiguous().view((att_feats.size(0) * opt.seq_per_img), -1, att_feats.size()[-1]) fc_feats = fc_feats.unsqueeze(1).expand(*((fc_feats.size(0), opt.seq_per_img,) + fc_feats.size( )[1:])).contiguous().view(*((fc_feats.size(0) * 
opt.seq_per_img,) + fc_feats.size()[1:])) optimizer.zero_grad() if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: cnn_optimizer.zero_grad() model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() # loss.backward() # utils.clip_gradient(optimizer, opt.grad_clip) # optimizer.step() if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: utils.clip_gradient(cnn_optimizer, opt.grad_clip) cnn_optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value( tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value( tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean( ) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( cnn_model, model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = - val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer, histories) if opt.save_history_ckpt: save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer, append=str(iteration)) if best_flag: save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer, append='best') # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break except 
(RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace) # test model test_kwargs = {'split': 'test', 'dataset': opt.input_json} test_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( cnn_model, model, lw_model.crit, loader, test_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'test loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
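# The CNN-finetuning variant above alternates a captioning optimizer with a separate
# cnn_optimizer once epoch >= finetune_cnn_after (its own loss.backward()/optimizer.step()
# calls are currently commented out, so only the CNN optimizer steps). A minimal sketch of
# the two-optimizer step on toy stand-in modules; clip_grad_value_ stands in for the
# project's utils.clip_gradient.
import torch
import torch.nn as nn

def finetune_step(cnn_model, cap_model, cnn_optimizer, optimizer,
                  images, loss_fn, epoch, finetune_cnn_after, grad_clip=0.1):
    """One training step with an optional CNN finetuning phase (toy stand-ins)."""
    finetuning = finetune_cnn_after != -1 and epoch >= finetune_cnn_after
    optimizer.zero_grad()
    if finetuning:
        cnn_optimizer.zero_grad()
    feats = cnn_model(images)
    loss = loss_fn(cap_model(feats))
    loss.backward()
    torch.nn.utils.clip_grad_value_(cap_model.parameters(), grad_clip)
    optimizer.step()
    if finetuning:
        torch.nn.utils.clip_grad_value_(cnn_model.parameters(), grad_clip)
        cnn_optimizer.step()
    return loss.item()

# Example usage with toy modules:
cnn = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
cap = nn.Linear(8, 1)
opt_cap = torch.optim.Adam(cap.parameters(), lr=1e-4)
opt_cnn = torch.optim.Adam(cnn.parameters(), lr=1e-5)
imgs = torch.randn(2, 3, 32, 32)
finetune_step(cnn, cap, opt_cnn, opt_cap, imgs, lambda out: out.pow(2).mean(),
              epoch=5, finetune_cnn_after=3)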
def train(opt): print(opt) # To reproduce training results init_seed() # Image Preprocessing # For normalization, see https://github.com/pytorch/vision#models transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(degrees=10), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)) ]) # Deal with feature things before anything opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt, transform=transform) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[ checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories = utils.pickle_load(f) else: infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) opt.vocab = loader.get_vocab() if torch.cuda.is_available(): model = models.setup(opt).cuda() else: model = models.setup(opt) del opt.vocab dp_model = torch.nn.DataParallel(model) lw_model = LossWrapper(model, opt) dp_lw_model = torch.nn.DataParallel(lw_model) #fgm = FGM(model) cnn_model = ResnetBackbone() if torch.cuda.is_available(): cnn_model = cnn_model.cuda() if opt.start_from is not None: model_dict = cnn_model.state_dict() predict_dict = torch.load(os.path.join(opt.start_from, 'cnn_model-best.pth')) model_dict = {k: predict_dict["module."+k] for k, _ in model_dict.items() if "module."+ k in predict_dict} cnn_model.load_state_dict(model_dict) cnn_model = torch.nn.DataParallel(cnn_model) epoch_done = True # Assure in training mode dp_lw_model.train() if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth'))) def save_checkpoint(model, cnn_model, infos, optimizer, histories=None, append=''): if 
len(append) > 0: append = '-' + append # if checkpoint_path doesn't exist if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) #Transformer model checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append)) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) #CNN model checkpoint_path = os.path.join(opt.checkpoint_path, 'cnn_model%s.pth' % (append)) if not os.path.exists(checkpoint_path): torch.save(cnn_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append)) torch.save(optimizer.state_dict(), optimizer_path) with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(infos, f) if histories: with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f: utils.pickle_dump(histories, f) cnn_after = 3 try: while True: if epoch_done: if opt.fix_cnn or epoch < cnn_after: for p in cnn_model.parameters(): p.requires_grad = False cnn_model.eval() cnn_optimizer = None else: for p in cnn_model.parameters(): p.requires_grad = True # Fix the first few layers: for module in cnn_model._modules['module']._modules['resnet_conv'][:5]._modules.values(): for p in module.parameters(): p.requires_grad = False cnn_model.train() # Constructing CNN parameters for optimization, only fine-tuning higher layers cnn_optimizer = torch.optim.Adam( (filter(lambda p: p.requires_grad, cnn_model.parameters())), lr=2e-6 if (opt.self_critical_after != -1 and epoch >= opt.self_critical_after) else 5e-5, betas=(0.8, 0.999)) if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') if iteration % opt.losses_log_every == 0: print('Read data:', time.time() - start) if torch.cuda.is_available(): torch.cuda.synchronize() start = time.time() if torch.cuda.is_available(): data['att_feats'] = cnn_model( data['att_feats'].cuda()) else: data['att_feats'] = cnn_model( data['att_feats'] ) data['att_feats'] = repeat_feat(data['att_feats']) tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] if torch.cuda.is_available(): tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if cnn_optimizer is not None: cnn_optimizer.zero_grad() # if epoch >= cnn_after: # att_feats.register_hook(save_grad("att_feats")) model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, 
data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() loss.backward() #loss.backward(retain_graph=True) # adversarial training #fgm.attack(emb_name='model.tgt_embed.0.lut.weight') #adv_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], # torch.arange(0, len(data['gts'])), sc_flag) #adv_loss = adv_out['loss'].mean() #adv_loss.backward() #fgm.restore(emb_name="model.tgt_embed.0.lut.weight") # utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() if cnn_optimizer is not None: cnn_optimizer.step() train_loss = loss.item() if torch.cuda.is_available(): torch.cuda.synchronize() end = time.time() if not sc_flag and iteration % opt.losses_log_every == 0: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) elif iteration % opt.losses_log_every == 0: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean() lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) eval_kwargs["cnn_model"] = cnn_model val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = - val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history save_checkpoint(model, cnn_model, infos, optimizer, histories) if opt.save_history_ckpt: save_checkpoint(model, cnn_model, infos, optimizer, append=str(iteration)) if best_flag: save_checkpoint(model, cnn_model, infos, 
optimizer, append='best') # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') save_checkpoint(model, cnn_model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
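# The variant above unfreezes its ResNet backbone after a few epochs while keeping the first
# blocks fixed, and builds an optimizer only over trainable parameters. A minimal sketch of
# that freezing pattern on a plain torchvision ResNet (assumed stand-in for the project's
# ResnetBackbone); lr and betas mirror the values used above.
import torch
import torchvision

def build_backbone_optimizer(unfreeze=True, num_fixed_children=5, lr=5e-5):
    backbone = torchvision.models.resnet18()
    for p in backbone.parameters():
        p.requires_grad = unfreeze
    if unfreeze:
        # Keep the earliest children (conv1, bn1, relu, maxpool, layer1) frozen.
        for module in list(backbone.children())[:num_fixed_children]:
            for p in module.parameters():
                p.requires_grad = False
        backbone.train()
    else:
        backbone.eval()
    trainable = [p for p in backbone.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.8, 0.999)) if trainable else None
    return backbone, optimizer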
def train(opt):
    loader = Loader(opt)
    infos = {}
    histories = {}
    Model = model.setup(opt).cuda()
    LW_model = LossWrapper(Model, opt)
    # DP_lw_model = torch.nn.DataParallel(LW_model)
    LW_model.train()
    optimizer = utils.build_optimizer(Model.parameters(), opt)
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
        if os.path.isfile(os.path.join(opt.start_from, 'histories-best.pkl')):
            with open(os.path.join(opt.start_from, 'histories-best.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer-best.pth')):
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth')))
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['opt'] = opt
        infos['label2id'] = load_label(opt.input_label2id)
    # Integer defaults so the counters below can be incremented.
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    best_val_score = infos.get('best_val_score', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    epoch_done = True
    best_epoch = -1
    try:
        while True:
            if epoch_done:
                iteration = 0
                if epoch != 0:
                    # Evaluate on the dev split at the start of every epoch after the first.
                    predictions, targets, _, metrics = eval_utils.evaluate(
                        Model, loader, infos['label2id'], opt.eval_batch_size, opt.rel_num, 'dev')
                    val_result_history[iteration] = {'predictions': predictions,
                                                     'metrics': metrics,
                                                     'targets': targets}
                    # print('dev res: ', metrics)
                    current_score = metrics['F1']
                    histories['val_result_history'] = val_result_history
                    histories['loss_history'] = loss_history
                    histories['lr_history'] = lr_history
                    best_flag = False
                    if current_score > best_val_score:
                        best_epoch = epoch
                        best_val_score = current_score
                        best_flag = True
                    infos['best_val_score'] = best_val_score
                    # save_checkpoint is assumed to be defined at module level in this script.
                    save_checkpoint(Model, infos, optimizer, histories)
                    if best_flag:
                        save_checkpoint(Model, infos, optimizer, append='best')
                epoch_done = False
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)
            start = time.time()
            data = loader.get_batch_train(opt.batch_size)
            # data = sorted(data, key=lambda x: x[-1], reverse=True)
            wrapped = data[-1]
            data = data[:-1]
            # print('Read data:', time.time() - start)
            torch.cuda.synchronize()
            start = time.time()
            data = [t.cuda() for t in data]
            sents, rels, labels, poses, chars, sen_lens = data
            if not opt.use_char:
                chars = None
            if not opt.use_pos:
                poses = None
            # Padding mask: 1 for real tokens, 0 for padding.
            mask = torch.zeros(sents.size()).cuda()
            for i in range(sents.size(0)):
                mask[i][:sen_lens[i]] = 1
            # Label weighting mask: weight 1 where labels == 8, otherwise 10.
            mask2 = torch.where(labels == 8, torch.ones_like(sents), torch.ones_like(sents) * 10).cuda()
            mask2 = mask2.float() * mask.float()
            optimizer.zero_grad()
            sum_loss = LW_model(sents, sen_lens, rels, mask, labels, mask2, poses, chars)
            loss = sum_loss / sents.shape[0]
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            if iteration % 200 == 0:
                end = time.time()
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            iteration += 1
            if wrapped:
                epoch += 1
                epoch_done = True
            infos['iter'] = iteration
            infos['epoch'] = epoch
            if iteration % opt.save_loss_every == 0:
                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
            if opt.max_epoch != -1 and epoch >= opt.max_epoch:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(Model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
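# The relation-extraction variant above builds its padding mask with a Python loop over the
# batch (mask[i][:sen_lens[i]] = 1). The same mask can be built with one broadcasted
# comparison; this is an equivalent sketch, not code from the project.
import torch

def length_mask(sen_lens, max_len=None):
    """Return a (batch, max_len) float mask with 1s up to each sentence length."""
    sen_lens = torch.as_tensor(sen_lens)
    if max_len is None:
        max_len = int(sen_lens.max())
    positions = torch.arange(max_len, device=sen_lens.device)
    return (positions.unsqueeze(0) < sen_lens.unsqueeze(1)).float()

# e.g. length_mask([3, 1, 4], max_len=5) ->
# tensor([[1., 1., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 0.]])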
def train(opt): if vars(opt).get('start_from', None) is not None: opt.checkpoint_path = opt.start_from opt.id = opt.checkpoint_path.split('/')[-1] print('Point to folder: {}'.format(opt.checkpoint_path)) else: opt.id = datetime.datetime.now().strftime( '%Y%m%d_%H%M%S') + '_' + opt.caption_model opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id) if not os.path.exists(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) print('Create folder: {}'.format(opt.checkpoint_path)) # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) # opt.use_att = False if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader_UP(opt) opt.vocab_size = loader.vocab_size if opt.use_rela == 1: opt.rela_dict_size = loader.rela_dict_size opt.seq_length = loader.seq_length use_rela = getattr(opt, 'use_rela', 0) try: tb_summary_writer = tf and tf.compat.v1.summary.FileWriter( opt.checkpoint_path) except: print('Set tensorboard error!') pdb.set_trace() infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.checkpoint_path, 'infos.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.checkpoint_path, 'histories.pkl')): with open(os.path.join(opt.checkpoint_path, 'histories.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() # dp_model = torch.nn.DataParallel(model) # dp_model = torch.nn.DataParallel(model, [0,2,3]) dp_model = model print('### Model summary below###\n {}\n'.format(str(model))) model_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print('model parameter:{}'.format(model_params)) update_lr_flag = True # Assure in training mode dp_model.train() parameters = model.named_children() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer( filter(lambda p: p.requires_grad, model.parameters()), opt) optimizer.zero_grad() accumulate_iter = 0 train_loss = 0 reward = np.zeros([1, 1]) while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if 
opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch(opt.train_split) # print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() fc_feats = None att_feats = None att_masks = None ssg_data = None rela_data = None if getattr(opt, 'use_ssg', 0) == 1: if getattr(opt, 'use_isg', 0) == 1: tmp = [ data['fc_feats'], data['labels'], data['masks'], data['att_feats'], data['att_masks'], data['isg_rela_matrix'], data['isg_rela_masks'], data['isg_obj'], data['isg_obj_masks'], data['isg_attr'], data['isg_attr_masks'], data['ssg_rela_matrix'], data['ssg_rela_masks'], data['ssg_obj'], data['ssg_obj_masks'], data['ssg_attr'], data['ssg_attr_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, labels, masks, att_feats, att_masks, \ isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \ ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp # image graph domain isg_data = {} isg_data['att_feats'] = att_feats isg_data['att_masks'] = att_masks isg_data['isg_rela_matrix'] = isg_rela_matrix isg_data['isg_rela_masks'] = isg_rela_masks isg_data['isg_obj'] = isg_obj isg_data['isg_obj_masks'] = isg_obj_masks isg_data['isg_attr'] = isg_attr isg_data['isg_attr_masks'] = isg_attr_masks # text graph domain ssg_data = {} ssg_data['ssg_rela_matrix'] = ssg_rela_matrix ssg_data['ssg_rela_masks'] = ssg_rela_masks ssg_data['ssg_obj'] = ssg_obj ssg_data['ssg_obj_masks'] = ssg_obj_masks ssg_data['ssg_attr'] = ssg_attr ssg_data['ssg_attr_masks'] = ssg_attr_masks else: tmp = [ data['fc_feats'], data['ssg_rela_matrix'], data['ssg_rela_masks'], data['ssg_obj'], data['ssg_obj_masks'], data['ssg_attr'], data['ssg_attr_masks'], data['labels'], data['masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp ssg_data = {} ssg_data['ssg_rela_matrix'] = ssg_rela_matrix ssg_data['ssg_rela_masks'] = ssg_rela_masks ssg_data['ssg_obj'] = ssg_obj ssg_data['ssg_obj_masks'] = ssg_obj_masks ssg_data['ssg_attr'] = ssg_attr isg_data = None ssg_data['ssg_attr_masks'] = ssg_attr_masks else: tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp if not sc_flag: # loss = crit(dp_model(model_zh,model_en,itow_zh,itow, fc_feats, labels, isg_data, ssg_data), labels[:, 1:], masks[:, 1:]) # print('ssg:') # print(ssg_data['ssg_obj']) # print('predict:') # print(dp_model(fc_feats, labels, isg_data, ssg_data)) # print('label:') # print(labels[:, 1:]) loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data), labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, isg_data, ssg_data, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, isg_data, ssg_data, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) accumulate_iter = accumulate_iter + 1 loss = loss / opt.accumulate_number loss.backward() if accumulate_iter % opt.accumulate_number == 0: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() optimizer.zero_grad() iteration += 
1 accumulate_iter = 0 train_loss = loss.item() * opt.accumulate_number end = time.time() if not sc_flag: print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \ .format(opt.id, iteration, epoch, train_loss, end - start)) else: print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \ .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start)) torch.cuda.synchronize() # Update the iteration and epoch if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0) and (iteration != 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model # if (iteration %2 == 0) and (iteration != 0): if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0): # eval model if use_rela: eval_kwargs = { 'split': 'val', 'dataset': opt.input_json, 'use_real': 1 } else: eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) # val_loss, predictions, lang_stats = eval_utils.eval_split(model_zh,model_en,itow_zh,itow, dp_model, crit, loader, eval_kwargs) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true save_id = iteration / opt.save_checkpoint_every if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'wb') as f: cPickle.dump(infos, f) with open(os.path.join(opt.checkpoint_path, 'histories.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max 
epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
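# Several of the variants above share the same staircase schedules: the learning rate decays
# by learning_rate_decay_rate every learning_rate_decay_every epochs once past
# learning_rate_decay_start, and the scheduled-sampling probability grows the same way.
# A small self-contained sketch of both schedules; the parameter names mirror the options
# used above but the functions themselves are illustrative, not from the project.
def decayed_lr(epoch, base_lr, decay_start, decay_every, decay_rate):
    """Staircase learning-rate decay used by the training loops above."""
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr

def scheduled_sampling_prob(epoch, ss_start, ss_every, ss_increase, ss_max):
    """Staircase increase of the scheduled-sampling probability."""
    if ss_start >= 0 and epoch > ss_start:
        frac = (epoch - ss_start) // ss_every
        return min(ss_increase * frac, ss_max)
    return 0.0

# e.g. decayed_lr(epoch=7, base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8)
# -> 5e-4 * 0.8**2 = 3.2e-4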
def load_info(loader, start_from, checkpoint_path, p_flag):
    # NOTE: relies on the module-level `opt`, `models`, `utils`, `cPickle` and `np`.
    infos = {}
    histories = {}
    if start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        # need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(os.path.join(checkpoint_path, 'histories.pkl')):
            with open(os.path.join(checkpoint_path, 'histories.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    # p_flag selects which caption model (zh or en) this model is built for.
    opt.p_flag = p_flag
    if getattr(opt, 'p_flag', 0) == 0:
        opt.caption_model = opt.caption_model_zh
    else:
        opt.caption_model = opt.caption_model_en
    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model
    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)
    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    train_loss_kl = 0
    train_loss_all = 0
    reward = np.zeros([1, 1])
    return loader, iteration, epoch, val_result_history, loss_history, lr_history, ss_prob_history, best_val_score, \
        infos, histories, update_lr_flag, model, dp_model, parameters, crit, rl_crit, optimizer, accumulate_iter, \
        train_loss, reward, train_loss_kl, train_loss_all
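# load_info above (and most of the train() variants) asserts that a handful of saved-model
# options match the current command line before resuming. A tiny sketch of that check with
# argparse.Namespace stand-ins for the project's option objects:
from argparse import Namespace

def check_compatible(saved_opt, current_opt, need_be_same=("rnn_type", "rnn_size", "num_layers")):
    for checkme in need_be_same:
        assert vars(saved_opt)[checkme] == vars(current_opt)[checkme], \
            "Command line argument and saved model disagree on '%s'" % checkme

# This passes ...
check_compatible(Namespace(rnn_type='lstm', rnn_size=512, num_layers=1),
                 Namespace(rnn_type='lstm', rnn_size=512, num_layers=1))
# ... while mismatched options would raise AssertionError on 'rnn_size':
# check_compatible(Namespace(rnn_type='lstm', rnn_size=512, num_layers=1),
#                  Namespace(rnn_type='lstm', rnn_size=1024, num_layers=1))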
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from_path is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from_path, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from_path, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from_path, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) #print(val_result_history.get(3000)) #exit(0) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() no = sum(p.numel() for p in model.parameters()) pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Trainable Params:" + str(pytorch_total_params)) print("Total Params:" + str(no)) #exit(0) dp_model = torch.nn.DataParallel(model) epoch_done = True # Assure in training mode dp_model.train() if (opt.use_obj_mcl_loss == 1): mcl_crit = utils.MultiLabelClassification() if opt.label_smoothing > 0: crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) else: crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from_path', None) is not None and os.path.isfile( os.path.join(opt.start_from_path, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from_path, 'optimizer.pth'))) time_epoch_start = time.time() data_time_sum = 0 batch_time_sum = 0 while True: if epoch_done: torch.cuda.synchronize() time_epoch_end = time.time() time_elapsed = (time_epoch_end - time_epoch_start) print('[DEBUG] Epoch Time: ' + str(time_elapsed)) print('[DEBUG] Sum Data Time: ' + str(data_time_sum)) print('[DEBUG] Sum Batch Time: ' + str(batch_time_sum)) #if epoch==1: # exit(0) if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = 
opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) data_time_sum += time.time() - start torch.cuda.synchronize() start = time.time() if (opt.use_obj_mcl_loss == 0): tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp else: if opt.use_obj_att and opt.use_seg_feat: tmp = [ data['fc_feats'], data['att_feats'], data['obj_att_feats'], data['seg_feat_feats'], data['labels'], data['masks'], data['obj_labels'], data['att_masks'], data['obj_att_masks'], data['seg_feat_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, att_feats, obj_att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, obj_att_masks, seg_feat_masks = tmp elif not opt.use_obj_att and opt.use_seg_feat: tmp = [ data['fc_feats'], data['att_feats'], data['seg_feat_feats'], data['labels'], data['masks'], data['obj_labels'], data['att_masks'], data['seg_feat_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, seg_feat_masks = tmp elif not opt.use_obj_att and not opt.use_seg_feat: tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['obj_labels'], data['att_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, att_feats, labels, masks, obj_labels, att_masks = tmp elif opt.use_obj_att and not opt.use_seg_feat: tmp = [ data['fc_feats'], data['att_feats'], data['obj_att_feats'], data['labels'], data['masks'], data['obj_labels'], data['att_masks'], data['obj_att_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, att_feats, obj_att_feats, labels, masks, obj_labels, att_masks, obj_att_masks = tmp optimizer.zero_grad() if (opt.use_obj_mcl_loss == 0): if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) else: if opt.use_obj_att and opt.use_seg_feat: if not sc_flag: logits, out = dp_model( fc_feats, [att_feats, obj_att_feats, seg_feat_feats], labels, [att_masks, obj_att_masks, seg_feat_masks]) caption_loss = crit(logits, labels[:, 1:], masks[:, 1:]) obj_loss = mcl_crit(out, obj_labels) loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss #loss = 0.1*caption_loss + obj_loss #loss = caption_loss + 0 * obj_loss else: gen_result, sample_logprobs = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 
0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) elif not opt.use_obj_att and opt.use_seg_feat: if not sc_flag: logits, out = dp_model(fc_feats, [att_feats, seg_feat_feats], labels, [att_masks, seg_feat_masks]) caption_loss = crit(logits, labels[:, 1:], masks[:, 1:]) obj_loss = mcl_crit(out, obj_labels) loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss #loss = caption_loss + 0 * obj_loss else: gen_result, sample_logprobs = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) if not opt.use_obj_att and not opt.use_seg_feat: if not sc_flag: logits, out = dp_model(fc_feats, att_feats, labels, att_masks) caption_loss = crit(logits, labels[:, 1:], masks[:, 1:]) obj_loss = mcl_crit(out, obj_labels) loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss #loss = caption_loss + 0 * obj_loss else: gen_result, sample_logprobs = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) elif opt.use_obj_att and not opt.use_seg_feat: if not sc_flag: logits, out = dp_model(fc_feats, [att_feats, obj_att_feats], labels, [att_masks, obj_att_masks]) caption_loss = crit(logits, labels[:, 1:], masks[:, 1:]) obj_loss = mcl_crit(out, obj_labels) loss = 0.1 * caption_loss + obj_loss #loss = caption_loss + 0 * obj_loss else: gen_result, sample_logprobs = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() batch_time_sum += end - start if not sc_flag: if (opt.use_obj_mcl_loss == 1): print("iter {} (epoch {}), train_loss = {:.3f}, caption_loss = {:.3f}, object_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, caption_loss.item(), obj_loss.item(), end - start)) else: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if (opt.use_obj_mcl_loss == 1): add_summary_value(tb_summary_writer, 'obj_loss', obj_loss.item(), iteration) add_summary_value(tb_summary_writer, 'caption_loss', caption_loss.item(), iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', 
model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model orig_batch_size = opt.batch_size opt.batch_size = 1 eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) loader.batch_size = eval_kwargs.get('batch_size', 1) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) opt.batch_size = orig_batch_size loader.batch_size = orig_batch_size if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
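# The variant above mixes a captioning loss with a multi-label object loss weighted by
# lambda_caption and lambda_obj. The sketch below shows the same weighted combination with
# standard PyTorch criteria standing in for the project's crit and mcl_crit.
import torch
import torch.nn as nn

def combined_loss(word_logits, word_targets, obj_logits, obj_targets,
                  lambda_caption=1.0, lambda_obj=0.1):
    """Weighted sum of a per-word caption loss and a multi-label object loss (stand-in criteria)."""
    caption_loss = nn.functional.cross_entropy(
        word_logits.reshape(-1, word_logits.size(-1)), word_targets.reshape(-1))
    obj_loss = nn.functional.binary_cross_entropy_with_logits(obj_logits, obj_targets)
    return lambda_caption * caption_loss + lambda_obj * obj_loss

# Example shapes: batch of 2, caption length 5, vocab 100, 20 object classes.
word_logits = torch.randn(2, 5, 100)
word_targets = torch.randint(0, 100, (2, 5))
obj_logits = torch.randn(2, 20)
obj_targets = torch.randint(0, 2, (2, 20)).float()
loss = combined_loss(word_logits, word_targets, obj_logits, obj_targets)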
def train(opt): # Deal with feature things before anything opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) #opt.ss_prob=0.0 infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f: histories = utils.pickle_load(f) else: infos['iter'] = 0 infos['epoch'] = 0 infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['vocab'] = loader.get_vocab() infos['pix_perss']=loader.get_personality() infos['opt'] = opt iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) print("current epoch: ",epoch) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) opt.vocab = loader.get_vocab() opt.xpersonality=loader.get_personality() if opt.use_joint==0: #torch.cuda.set_device(0) model = models.setup(opt).cuda() elif opt.use_joint==1: model = models.JointModel(opt) model.cuda() #model=models.setup(opt) del opt.vocab if opt.start_from is not None: opt.model=os.path.join(opt.start_from, 'model'+'.pth') model.load_state_dict(torch.load(opt.model)) dp_model = torch.nn.DataParallel(model) lw_model = LossWrapper(model, opt) dp_lw_model = torch.nn.DataParallel(lw_model) #dp_lw_model=LossWrapper(model, opt) # this is for no cuda epoch_done = True # Assure in training mode #dp_lw_model=lw_model dp_lw_model.train() if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration elif opt.reduce_on_plateau: optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) else: print('Optimizer param group number not matched? There must be new parameters. 
Reinit the optimizer.') def save_checkpoint(model, infos, optimizer, histories=None, append=''): if len(append) > 0: append = '-' + append # if checkpoint_path doesn't exist if not os.path.isdir(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append)) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append)) torch.save(optimizer.state_dict(), optimizer_path) with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f: utils.pickle_dump(infos, f) if histories: with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f: utils.pickle_dump(histories, f) try: while True: if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # Assign retrieval loss weight if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0: frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every model.retrieval_reward_weight = opt.retrieval_reward_weight * (opt.retrieval_reward_weight_decay_rate ** frac) epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() with torch.autograd.set_detect_anomaly(True): tmp = [data['fc_feats'], data['att_feats'],data['densecap'], data['labels'], data['masks'], data['att_masks'], data['personality']] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats,densecap, labels, masks, att_masks,personality = tmp optimizer.zero_grad() model_out = dp_lw_model(fc_feats, att_feats,densecap, labels, masks, att_masks,personality, data['gts'], torch.arange(0, len(data['gts'])), sc_flag) loss = model_out['loss'].mean() loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f},train_loss = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start,train_loss)) if opt.use_joint==1: for k, v in model.loss().items(): prt_str += "{} = {:.3f} ".format(k, v) print(prt_str) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if 
(iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration) loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean() lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: if opt.use_joint==1: current_score = lang_stats['SPICE']*100 elif opt.use_joint==0: current_score = lang_stats['CIDEr'] # could use SPICE else: if opt.use_joint==0: current_score = - val_loss elif opt.use_joint==1: current_score= - val_loss['loss_cap'] if opt.use_joint==1: current_score_vse = val_loss.get(opt.vse_eval_criterion, 0)*100 best_flag = False best_flag_vse= False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True if opt.use_joint==1: if best_val_score_vse is None or current_score_vse > best_val_score_vse: best_val_score_vse = current_score_vse best_flag_vse = True infos['best_val_score_vse'] = best_val_score_vse # Dump miscalleous informations infos['best_val_score'] = best_val_score histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history save_checkpoint(model, infos, optimizer, histories) if opt.save_history_ckpt: save_checkpoint(model, infos, optimizer, append=str(iteration)) if best_flag: save_checkpoint(model, infos, optimizer, append='best') if best_flag_vse: save_checkpoint(model, infos, optimizer, append='vse-best') # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') save_checkpoint(model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
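# --- Illustrative sketch, not part of the original training script ---------------------
# Both epoch-driven schedules coded inline above are simple step functions of the epoch
# index.  Factored out for clarity (helper names are hypothetical):
def step_decay_lr(epoch, base_lr, decay_start, decay_every, decay_rate):
    """Geometric decay every `decay_every` epochs once `decay_start` has passed."""
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * (decay_rate ** frac)

def scheduled_sampling_prob(epoch, ss_start, ss_every, ss_increase, ss_max):
    """Linearly growing probability of feeding the model its own previous word, capped."""
    if ss_start < 0 or epoch <= ss_start:
        return 0.0
    frac = (epoch - ss_start) // ss_every
    return min(ss_increase * frac, ss_max)

# e.g. base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8:
#   epochs 1-2 keep 5e-4, epochs 3-5 use 5e-4*0.8, epochs 6-8 use 5e-4*0.8**2, ...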
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(log_dir=opt.checkpoint_path) print(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible print(os.getcwd()) with open( os.path.join(os.getcwd(), opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() # dp_model = torch.nn.DataParallel(model) dp_model = model update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: # # [added] reproduce straight line learning rate decay in supplementary # # ---- the original paper used 60k iters # # ---- if lr goes to zero just stay at the last lr # linear_lr = -(iteration+1)*opt.learning_rate/60000 + opt.learning_rate # if linear_lr <= 0: # pass # else: # opt.current_lr = linear_lr # utils.set_lr(optimizer, opt.current_lr) if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) # [naxin] knn_data is the nearest neighbour batch, the format is identical to data data, knn_data = loader.get_batch('train') print('Read data:', time.time() - start) 
torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] # tmp = [knn_data['fc_feats'], knn_data['att_feats'], knn_data['labels'], knn_data['masks'], knn_data['att_masks']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs, eval_knn=opt.use_knn) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history 
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)
                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
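# --- Illustrative sketch, not part of the original training script ---------------------
# In the self-critical branch above, `get_self_critical_reward` typically scores each
# sampled caption with CIDEr and subtracts the score of a greedy-decoded baseline, and
# `utils.RewardCriterion` turns that per-token reward into a REINFORCE-style loss.  The
# function below is an assumption about what that criterion computes (masking details
# may differ in the real implementation).
import torch

def reward_criterion(sample_logprobs, gen_result, reward):
    """sample_logprobs, gen_result, reward: (B, T) tensors; returns a scalar loss."""
    # Keep the first position unconditionally, then every position whose previous
    # sampled word is non-zero (i.e. up to and including the first end token).
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
    loss = -(sample_logprobs * reward * mask).sum() / mask.sum()
    return loss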
def train(opt): import random random.seed(0) torch.manual_seed(0) torch.cuda.manual_seed_all(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(0) # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 from dataloader_pair import DataLoader loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) if opt.log_to_file: if os.path.exists(os.path.join(opt.checkpoint_path, 'log')): suffix = time.strftime("%Y-%m-%d %X", time.localtime()) print('Warning !!! %s already exists ! use suffix ! ' % os.path.join(opt.checkpoint_path, 'log')) sys.stdout = open( os.path.join(opt.checkpoint_path, 'log' + suffix), "w") else: print('logging to file %s' % os.path.join(opt.checkpoint_path, 'log')) sys.stdout = open(os.path.join(opt.checkpoint_path, 'log'), "w") infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible if os.path.isfile(opt.start_from): with open(os.path.join(opt.infos)) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme else: if opt.load_best != 0: print('loading best info') with open( os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme else: with open( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: try: histories = cPickle.load(f) except: print('load history error!') histories = {} iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) start_epoch = epoch val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) #Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) if opt.caption_model == 'att2in2p': optimized = [ 
'logit2', 'ctx2att2', 'core2', 'prev_sent_emb', 'prev_sent_wrap' ] optimized_param = [] optimized_param1 = [] for name, param in model.named_parameters(): second = False for n in optimized: if n in name: print('second', name) optimized_param.append(param) second = True if 'embed' in name: print('all', name) optimized_param1.append(param) optimized_param.append(param) elif not second: print('first', name) optimized_param1.append(param) while True: if opt.val_only: eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) print('start evaluating') val_loss, predictions, lang_stats = eval_utils_pair.eval_split( dp_model, crit, loader, eval_kwargs) exit(0) if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [ data['pair_fc_feats'], data['pair_att_feats'], data['pair_labels'], data['pair_masks'], data['pair_att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp masks = masks.float() optimizer.zero_grad() if not sc_flag: if opt.onlysecond: # only using the second sentence from a visual paraphrase pair. opt.caption_model should be a one-stage decoding model loss = crit( dp_model(fc_feats, att_feats, labels[:, 1, :], att_masks), labels[:, 1, 1:], masks[:, 1, 1:]) loss1 = loss2 = loss / 2 elif opt.first: # using the first sentence tmp = [ data['first_fc_feats'], data['first_att_feats'], data['first_labels'], data['first_masks'], data['first_att_masks'] ] tmp = [ _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp ] fc_feats, att_feats, labels, masks, att_masks = tmp masks = masks.float() loss = crit( dp_model(fc_feats, att_feats, labels[:, :], att_masks), labels[:, 1:], masks[:, 1:]) loss1 = loss2 = loss / 2 elif opt.onlyfirst: # only using the second sentence from a visual paraphrase pair loss = crit( dp_model(fc_feats, att_feats, labels[:, 0, :], att_masks), labels[:, 0, 1:], masks[:, 0, 1:]) loss1 = loss2 = loss / 2 else: # proposed DCVP model, opt.caption_model should be att2inp output1, output2 = dp_model(fc_feats, att_feats, labels, att_masks, masks[:, 0, 1:]) loss1 = crit(output1, labels[:, 0, 1:], masks[:, 0, 1:]) loss2 = crit(output2, labels[:, 1, 1:], masks[:, 1, 1:]) loss = loss1 + loss2 else: raise NotImplementedError # Our DCVP model does not support self-critical sequence training # We found that RL(SCST) with CIDEr reward will improve conventional metrics (BLEU, CIDEr, etc.) 
# but harm diversity and descriptiveness # Please refer to the paper for the details loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, loss1 = {:.3f}, loss2 = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, loss.item(), loss1.item(), loss2.item(), end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), end - start)) sys.stdout.flush() # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils_pair.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) checkpoint_path = os.path.join( opt.checkpoint_path, 'model' + str(iteration) + '.pth') 
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path,
                                           'infos_' + opt.id + '_' + str(iteration) + '.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
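# --- Illustrative sketch, not part of the original training script ---------------------
# Each crit(...) call in the pairwise losses above is a masked per-token cross-entropy:
# the decoder emits (B, T, V) log-probabilities, targets and masks are shifted by one
# token, and padded positions are ignored.  This is an assumption about what
# utils.LanguageModelCriterion computes; the helper name is hypothetical.
def masked_language_model_loss(logprobs, target, mask):
    """logprobs: (B, T, V) log-softmax outputs; target, mask: (B, T') with T' >= T."""
    target = target[:, :logprobs.size(1)].long()
    mask = mask[:, :logprobs.size(1)].float()
    nll = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2)  # (B, T)
    return (nll * mask).sum() / mask.sum()

# For a visual-paraphrase pair this yields loss1 + loss2, one term per sentence:
#   loss = masked_language_model_loss(out1, labels[:, 0, 1:], masks[:, 0, 1:]) + \
#          masked_language_model_loss(out2, labels[:, 1, 1:], masks[:, 1, 1:])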
def train(opt): ################################ # Build dataloader ################################ loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length ########################## # Initialize infos ########################## infos = { 'iter': 0, 'epoch': 0, 'loader_state_dict': None, 'vocab': loader.get_vocab(), } # Load old infos(if there is) and check if models are compatible if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert getattr(saved_model_opt, checkme) == getattr( opt, checkme ), "Command line argument and saved model disagree on '%s' " % checkme infos['opt'] = opt ######################### # Build logger ######################### # naive dict logger histories = defaultdict(dict) if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories.update(utils.pickle_load(f)) # tensorboard logger tb_summary_writer = SummaryWriter(opt.checkpoint_path) ########################## # Build model ########################## opt.vocab = loader.get_vocab() multi_models_list = [] for order in range(opt.number_of_models): multi_models_list.append(models.setup(opt).cuda()) for order in range(opt.number_of_models): multi_models_list.append(models.setup(opt).cuda()) for order in range(opt.number_of_models, 2 * opt.number_of_models): for param in multi_models_list[order].parameters(): param.detach_() for order in range(opt.number_of_models): for param, param_ema in zip( multi_models_list[order].parameters(), multi_models_list[order + opt.number_of_models].parameters()): param_ema.data = param.data.clone() # multi_models = MultiModels(multi_models_list) # multi_models_list.append(SenEncodeModel(opt).cuda()) multi_models = nn.ModuleList(multi_models_list) del opt.vocab # Load pretrained weights: if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'model.pth')): multi_models.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) # Wrap generation model with loss function(used for training) # This allows loss function computed separately on each machine lw_models = nn.ModuleList([ LossWrapper(multi_models[index], opt) for index in range(opt.number_of_models) ]) kdlw_models = nn.ModuleList([ KDLossWrapper(multi_models[index], opt) for index in range(opt.number_of_models) ]) lw_models_ema = nn.ModuleList([ LossWrapper(multi_models[opt.number_of_models + index], opt) for index in range(opt.number_of_models) ]) kdlw_models_ema = nn.ModuleList([ KDLossWrapper(multi_models[opt.number_of_models + index], opt) for index in range(opt.number_of_models) ]) # Wrap with dataparallel dp_models = nn.ModuleList([ torch.nn.DataParallel(multi_models[index]) for index in range(opt.number_of_models) ]) dp_lw_models = nn.ModuleList([ torch.nn.DataParallel(lw_models[index]) for index in range(opt.number_of_models) ]) dp_kdlw_models = nn.ModuleList([ torch.nn.DataParallel(kdlw_models[index]) for index in range(opt.number_of_models) ]) dp_models_ema = nn.ModuleList([ torch.nn.DataParallel(multi_models[opt.number_of_models + index]) for index in range(opt.number_of_models) ]) dp_lw_models_ema = 
nn.ModuleList([ torch.nn.DataParallel(lw_models_ema[index]) for index in range(opt.number_of_models) ]) dp_kdlw_models_ema = nn.ModuleList([ torch.nn.DataParallel(kdlw_models_ema[index]) for index in range(opt.number_of_models) ]) ########################## # Build optimizer ########################## if opt.noamopt: assert opt.caption_model in [ 'transformer', 'bert', 'm2transformer' ], 'noamopt can only work with transformer' optimizer = utils.get_std_opt(multi_models, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(multi_models.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(multi_models.parameters(), opt) # Load the optimizer if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) ########################## # Build loss ########################## # triplet_loss = nn.TripletMarginLoss() ######################### # Get ready to start ######################### iteration = infos['iter'] epoch = infos['epoch'] # For back compatibility if 'iterators' in infos: infos['loader_state_dict'] = { split: { 'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split] } for split in [ 'paired_train', 'unpaired_images_train', 'unpaired_captions_train', 'train', 'val', 'test' ] } loader.load_state_dict(infos['loader_state_dict']) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) if opt.noamopt: optimizer._step = iteration # flag indicating finish of an epoch # Always set to True at the beginning to initialize the lr or etc. epoch_done = True # Assure in training mode dp_lw_models.train() dp_kdlw_models.train() dp_lw_models_ema.train() dp_kdlw_models_ema.train() # Build the ensemble model # # Setup the model model_ensemble = AttEnsemble(multi_models_list[opt.number_of_models:2 * opt.number_of_models], weights=None) # model_ensemble.seq_length = 20 model_ensemble.cuda() # model_ensemble.eval() kd_model_outs_list = [] # Start training try: while True: # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) for index in range(opt.number_of_models): multi_models[index].ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # If start structure loss training if opt.structure_after != -1 and epoch >= opt.structure_after: struc_flag = True init_scorer(opt.cached_tokens) else: struc_flag = False if epoch >= opt.paired_train_epoch: opt.current_lambda_x = 
opt.hyper_parameter_lambda_x * \ (epoch - (opt.paired_train_epoch - 1)) /\ (opt.max_epochs - opt.paired_train_epoch) opt.current_lambda_y = opt.hyper_parameter_lambda_y * \ (epoch - (opt.paired_train_epoch - 1)) / \ (opt.max_epochs - opt.paired_train_epoch) epoch_done = False start = time.time() # Load data from train split (0) if epoch < opt.language_pretrain_epoch: data = loader.get_batch('unpaired_captions_train') elif epoch < opt.paired_train_epoch: data = loader.get_batch('paired_train') else: data = loader.get_batch('paired_train') unpaired_data = loader.get_batch('unpaired_images_train') unpaired_caption = loader.get_batch('unpaired_captions_train') print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() if epoch < opt.language_pretrain_epoch: tmp = [ data['fc_feats'] * 0, data['att_feats'] * 0, data['labels'], data['masks'], data['att_masks'] ] elif epoch < opt.paired_train_epoch: tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] else: tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] unpaired_tmp = [ unpaired_data['fc_feats'], unpaired_data['att_feats'], unpaired_data['labels'], unpaired_data['masks'], unpaired_data['att_masks'] ] unpaired_caption_tmp = [ unpaired_caption['fc_feats'] * 0, unpaired_caption['att_feats'] * 0, unpaired_caption['labels'], unpaired_caption['masks'], unpaired_caption['att_masks'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp if epoch >= opt.paired_train_epoch: unpaired_tmp = [ _ if _ is None else _.cuda() for _ in unpaired_tmp ] unpaired_fc_feats, unpaired_att_feats, unpaired_labels, unpaired_masks, unpaired_att_masks = unpaired_tmp unpaired_caption_tmp = [ _ if _ is None else _.cuda() for _ in unpaired_caption_tmp ] unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks = unpaired_caption_tmp unpaired_caption_fc_feats = unpaired_caption_fc_feats.repeat( 5, 1) unpaired_caption_fc_feats = opt.std_pseudo_visual_feature * torch.randn_like( unpaired_caption_fc_feats) unpaired_caption_att_feats = unpaired_caption_att_feats.repeat( 5, 1, 1) unpaired_caption_fc_feats.requires_grad = True unpaired_caption_att_feats.requires_grad = True unpaired_caption_labels = unpaired_caption_labels.reshape( unpaired_caption_fc_feats.shape[0], -1) unpaired_caption_masks = unpaired_caption_masks.reshape( unpaired_caption_fc_feats.shape[0], -1) optimizer.zero_grad() if epoch < opt.language_pretrain_epoch: language_loss = 0 model_outs_list = [] for index in range(opt.number_of_models): model_out = dp_lw_models[index]( fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) model_outs_list.append(model_out) language_loss += model_out['loss'].mean() loss = language_loss elif epoch < opt.paired_train_epoch: language_loss = 0 model_outs_list = [] for index in range(opt.number_of_models): model_out = dp_lw_models[index]( fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) model_outs_list.append(model_out) language_loss += model_out['loss'].mean() loss = language_loss else: language_loss = 0 model_outs_list = [] for index in range(opt.number_of_models): model_out = dp_lw_models[index]( fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) 
model_outs_list.append(model_out) language_loss += model_out['loss'].mean() loss = language_loss # else: # for unpaired image sentences # # Setup the model # model_ensemble = AttEnsemble(multi_models_list[:opt.number_of_models], weights=None) # model_ensemble.seq_length = 16 # model_ensemble.cuda() # model_ensemble.eval() model_ensemble.eval() eval_kwargs = dict() eval_kwargs.update(vars(opt)) with torch.no_grad(): seq, seq_logprobs = model_ensemble(unpaired_fc_feats, unpaired_att_feats, unpaired_att_masks, opt=eval_kwargs, mode='sample') # val_loss, predictions, lang_stats = eval_utils.eval_split(model_ensemble, lw_models[0].crit, loader, # eval_kwargs) # print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in # model_ensemble.done_beams[0]])) # print('++' * 10) # for ii in range(10): # sents = utils.decode_sequence(loader.get_vocab(), seq[ii].unsqueeze(0)) # gt_sent = utils.decode_sequence(loader.get_vocab(), labels[ii,0].unsqueeze(0)) # a=1 model_ensemble.train() model_ensemble_sudo_labels = labels.new_zeros( (opt.batch_size, opt.beam_size, eval_kwargs['max_length'] + 2)) model_ensemble_sudo_log_prob = masks.new_zeros( (opt.batch_size, opt.beam_size, eval_kwargs['max_length'] + 2, len(loader.get_vocab()) + 1)) model_ensemble_sum_log_prob = masks.new_zeros( (opt.batch_size, opt.beam_size)) for batch_index in range(opt.batch_size): for beam_index in range(opt.beam_size): # for beam_index in range(3): pred = model_ensemble.done_beams[batch_index][ beam_index]['seq'] log_prob = model_ensemble.done_beams[batch_index][ beam_index]['logps'] model_ensemble_sudo_labels[batch_index, beam_index, 1:pred.shape[0] + 1] = pred model_ensemble_sudo_log_prob[batch_index, beam_index, 1:pred.shape[0] + 1] = log_prob model_ensemble_sum_log_prob[batch_index][ beam_index] = model_ensemble.done_beams[ batch_index][beam_index]['p'] # model_ensemble_prob = F.softmax(model_ensemble_sum_log_prob) data_ensemble_sudo_gts = list() for data_ensemble_sudo_gts_index in range( model_ensemble_sudo_labels.shape[0]): data_ensemble_sudo_gts.append(model_ensemble_sudo_labels[ data_ensemble_sudo_gts_index, :, 1:-1].data.cpu().numpy()) # generated_sentences = list() # for i in range(unpaired_fc_feats.shape[0]): # generated_sentences.append( # [utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in # model_ensemble.done_beams[i]]) # # pos_tag_results = list() # for i in range(unpaired_fc_feats.shape[0]): # generated_sentences_i = generated_sentences[i] # pos_tag_results_i = [] # for text in generated_sentences_i: # text_tokenize = nltk.word_tokenize(text) # pos_tag_results_i_jbeam = [] # for vob, vob_type in nltk.pos_tag(text_tokenize): # if vob_type == 'NN' or vob_type == 'NNS': # pos_tag_results_i_jbeam.append(vob) # pos_tag_results_i.append(pos_tag_results_i_jbeam) # pos_tag_results.append(pos_tag_results_i) # for i in range(fc_feats.shape[0]): # print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in # model_ensemble.done_beams[i]])) # print('--' * 10) # dets = data['dets'] # # promising_flag = labels.new_zeros(opt.batch_size, opt.beam_size) # for batch_index in range(opt.batch_size): # dets_batch = dets[batch_index] # for beam_index in range(opt.beam_size): # indicator = [0] * len(dets_batch) # pos_tag_batch_beam = pos_tag_results[batch_index][beam_index] # for pos_tag_val in pos_tag_batch_beam: # for ii in range(len(dets_batch)): # possible_list = vob_transform_list[dets_batch[ii]] # if pos_tag_val in possible_list: # 
indicator[ii] = 1 # if sum(indicator) == len(dets_batch) or sum(indicator) >= 2: # promising_flag[batch_index, beam_index] = 1 # # # model_ensemble_sudo_log_prob = model_ensemble_sudo_log_prob * promising_flag.unsqueeze(-1).unsqueeze(-1) # model_ensemble_sudo_labels = model_ensemble_sudo_labels * promising_flag.unsqueeze(-1) #sudo_masks_for_model = sudo_masks_for_model.detach() distilling_loss = 0 # We use the random study machinism who_to_study = random.randint(0, opt.number_of_models - 1) # for index in range(opt.number_of_models): # model_out = dp_kdlw_models[index](unpaired_fc_feats, unpaired_att_feats, model_ensemble_sudo_labels, # model_ensemble_sudo_log_prob, att_masks, data_ensemble_sudo_gts, # torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag, # struc_flag, model_ensemble_sum_log_prob) # kd_model_outs_list.append(model_out) model_out = dp_kdlw_models[who_to_study]( unpaired_fc_feats, unpaired_att_feats, model_ensemble_sudo_labels, model_ensemble_sudo_log_prob, att_masks, data_ensemble_sudo_gts, torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag, struc_flag, model_ensemble_sum_log_prob) # kd_model_outs_list.append(model_out) distilling_loss += model_out['loss'].mean() loss += opt.number_of_models * opt.current_lambda_x * distilling_loss ################################################################### # use unlabelled captions # simple_sgd = utils.gradient_descent(unpaired_caption_fc_feats, stepsize=1e3) simple_sgd = utils.gradient_descent_adagrad( unpaired_caption_fc_feats, stepsize=1) gts_tmp = unpaired_caption['gts'] new_gts = [] for ii in range(len(data['gts'])): for jj in range(gts_tmp[ii].shape[0]): new_gts.append(gts_tmp[ii][jj]) unpaired_caption['gts'] = new_gts for itr in range(opt.inner_iteration): unlabelled_caption_model_out = dp_lw_models_ema[ itr % opt.number_of_models]( unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks, unpaired_caption['gts'], torch.arange(0, len(unpaired_caption['gts'])), sc_flag, struc_flag) unlabelled_caption_loss = unlabelled_caption_model_out[ 'loss'].mean() unlabelled_caption_loss.backward() # print(unlabelled_caption_loss) simple_sgd.update(unpaired_caption_fc_feats) # a=1 unpaired_caption_fc_feats.requires_grad = False unpaired_caption_att_feats.requires_grad = False unlabelled_caption_model_out = dp_lw_models[who_to_study]( unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks, unpaired_caption['gts'], torch.arange(0, len(unpaired_caption['gts'])), sc_flag, struc_flag) unlabelled_caption_loss = unlabelled_caption_model_out[ 'loss'].mean() loss += opt.number_of_models * opt.current_lambda_y * unlabelled_caption_loss loss.backward() if opt.grad_clip_value != 0: getattr(torch.nn.utils, 'clip_grad_%s_' % (opt.grad_clip_mode))(multi_models.parameters(), opt.grad_clip_value) optimizer.step() for order in range(opt.number_of_models): for param, param_ema in zip( multi_models_list[order].parameters(), multi_models_list[order + opt.number_of_models].parameters()): param_ema.data = opt.alpha * param_ema.data + ( 1 - opt.alpha) * param.data train_loss = loss.item() torch.cuda.synchronize() end = time.time() # if struc_flag: # print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \ # .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start)) # elif not sc_flag: # 
print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ # .format(iteration, epoch, train_loss, end - start)) # else: # print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ # .format(iteration, epoch, model_out['reward'].mean(), end - start)) if struc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss/opt.number_of_models, sum([model_outs_list[index]['lm_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, sum([model_outs_list[index]['struc_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, end - start)) elif not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, language_loss.item()/opt.number_of_models, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, sum([model_outs_list[index]['reward'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, end - start)) # Update the iteration and epoch iteration += 1 if epoch < opt.paired_train_epoch: if data['bounds']['wrapped']: epoch += 1 epoch_done = True else: if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): # tb_summary_writer.add_scalar('train_loss', train_loss, iteration) for index in range(opt.number_of_models): model_id = 'model_{}'.format(index) tb_summary_writer.add_scalars('language_loss', { model_id: model_outs_list[index]['loss'].mean().item() }, iteration) if epoch >= opt.paired_train_epoch: # for index in range(opt.number_of_models): # model_id = 'model_{}'.format(index) # kd_model_outs_val = 0 if len(kd_model_outs_list) == 0 else kd_model_outs_list[index]['loss'].mean().item() # tb_summary_writer.add_scalars('distilling_loss', # {model_id: kd_model_outs_val}, # iteration) tb_summary_writer.add_scalar('distilling_loss', distilling_loss.item(), iteration) tb_summary_writer.add_scalar( 'unlabelled_caption_loss', unlabelled_caption_loss.item(), iteration) tb_summary_writer.add_scalar('hyper_parameter_lambda_x', opt.current_lambda_x, iteration) tb_summary_writer.add_scalar('hyper_parameter_lambda_y', opt.current_lambda_y, iteration) # tb_summary_writer.add_scalar('triplet_loss', triplet_loss_val.item(), iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration) tb_summary_writer.add_scalar('scheduled_sampling_prob', multi_models[0].ss_prob, iteration) if sc_flag: for index in range(opt.number_of_models): # tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration) model_id = 'model_{}'.format(index) tb_summary_writer.add_scalars( 'avg_reward', { model_id: model_outs_list[index]['reward'].mean().item() }, iteration) elif struc_flag: # tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration) # tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration) # tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration) # tb_summary_writer.add_scalar('reward_var', model_out['reward'].var(1).mean(), iteration) model_id = 'model_{}'.format(index) for index in range(opt.number_of_models): tb_summary_writer.add_scalars( 'lm_loss', { model_id: model_outs_list[index] 
['lm_loss'].mean().item() }, iteration) tb_summary_writer.add_scalars( 'struc_loss', { model_id: model_outs_list[index] ['struc_loss'].mean().item() }, iteration) tb_summary_writer.add_scalars( 'reward', { model_id: model_outs_list[index]['reward'].mean().item() }, iteration) tb_summary_writer.add_scalars( 'reward_var', { model_id: model_outs_list[index]['reward'].var(1).mean() }, iteration) histories['loss_history'][ iteration] = train_loss if not sc_flag else sum([ model_outs_list[index]['reward'].mean().item() for index in range(opt.number_of_models) ]) / opt.number_of_models histories['lr_history'][iteration] = opt.current_lr histories['ss_prob_history'][iteration] = multi_models[ 0].ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['loader_state_dict'] = loader.state_dict() # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch and epoch >= opt.paired_train_epoch) or \ (epoch_done and opt.save_every_epoch and epoch >= opt.paired_train_epoch): # load ensemble # Setup the model model = AttEnsemble(multi_models_list[opt.number_of_models:2 * opt.number_of_models], weights=None) model.seq_length = opt.max_length model.cuda() model.eval() # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) # eval_kwargs['beam_size'] = 5 # eval_kwargs['verbose_beam'] = 1 # eval_kwargs['verbose_loss'] = 1 # val_loss, predictions, lang_stats = eval_utils.eval_split( # dp_model, lw_model.crit, loader, eval_kwargs) with torch.no_grad(): val_loss, predictions, lang_stats = eval_utils.eval_split( model, lw_models[0].crit, loader, eval_kwargs) model.train() if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary tb_summary_writer.add_scalar('validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): tb_summary_writer.add_scalar(k, v, iteration) histories['val_result_history'][iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score utils.save_checkpoint(opt, multi_models, infos, optimizer, histories) if opt.save_history_ckpt: utils.save_checkpoint( opt, multi_models, infos, optimizer, append=str(epoch) if opt.save_every_epoch else str(iteration)) if best_flag: utils.save_checkpoint(opt, multi_models, infos, optimizer, append='best') # if epoch_done and epoch == opt.paired_train_epoch: # utils.save_checkpoint(opt, multi_models, infos, optimizer, histories) # if opt.save_history_ckpt: # utils.save_checkpoint(opt, multi_models, infos, optimizer, # append=str(epoch) if opt.save_every_epoch else str(iteration)) # cmd = 'cp -r ' + 'log_' + opt.id + ' ' + 'log_' + opt.id + '_backup' # os.system(cmd) except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') utils.save_checkpoint(opt, multi_models, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
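# --- Illustrative sketch, not part of the original training script ---------------------
# The "ema" copies in the model list above are maintained as an exponential moving
# average of the student weights; this restates the in-place update applied after every
# optimizer.step() (the function name is hypothetical).
import torch

@torch.no_grad()
def ema_update(student, teacher, alpha=0.999):
    """teacher <- alpha * teacher + (1 - alpha) * student, parameter by parameter."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(alpha).add_(p_s.data, alpha=1.0 - alpha)

# Called once per iteration for each student/teacher pair, e.g.:
#   ema_update(multi_models_list[k], multi_models_list[k + opt.number_of_models], opt.alpha)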
def train(opt): ################################ # Build dataloader ################################ loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length ########################## # Initialize infos ########################## infos = { 'iter': 0, 'epoch': 0, 'loader_state_dict': None, 'vocab': loader.get_vocab(), } # Load old infos(if there is) and check if models are compatible if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f: infos = utils.pickle_load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert getattr(saved_model_opt, checkme) == getattr( opt, checkme ), "Command line argument and saved model disagree on '%s' " % checkme infos['opt'] = opt ######################### # Build logger ######################### # naive dict logger histories = defaultdict(dict) if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f: histories.update(utils.pickle_load(f)) # tensorboard logger tb_summary_writer = SummaryWriter(opt.checkpoint_path) ########################## # Build model ########################## opt.vocab = loader.get_vocab() model = models.setup(opt).cuda() del opt.vocab # Load pretrained weights: if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, 'model.pth')): model.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) # Wrap generation model with loss function(used for training) # This allows loss function computed separately on each machine lw_model = LossWrapper(model, opt) # Wrap with dataparallel # dp_model = torch.nn.DataParallel(model) # dp_lw_model = torch.nn.DataParallel(lw_model) dp_model = model dp_lw_model = lw_model ########################## # Build optimizer ########################## if opt.noamopt: assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) elif opt.reduce_on_plateau: optimizer = utils.build_optimizer(model.parameters(), opt) optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if opt.start_from is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) ######################### # Get ready to start ######################### iteration = infos['iter'] epoch = infos['epoch'] # For back compatibility if 'iterators' in infos: infos['loader_state_dict'] = { split: { 'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split] } for split in ['train', 'val', 'test'] } loader.load_state_dict(infos['loader_state_dict']) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) if opt.noamopt: optimizer._step = iteration # flag indicating finish of an epoch # Always set to True at the beginning to initialize the lr or etc. 
epoch_done = True # Assure in training mode dp_lw_model.train() # Start training try: while True: # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break if epoch_done: if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min( opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False # If start structure loss training if opt.structure_after != -1 and epoch >= opt.structure_after: struc_flag = True init_scorer(opt.cached_tokens) else: struc_flag = False epoch_done = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['topics'] ] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks, topic_vecs = tmp # gts_end = data['gts_end'] # import ipdb; ipdb.set_trace() train_loss = 0.0 #### Model forward pass #### Liangming: Add for loop sent_num = opt.seq_per_img labels = labels.view(opt.batch_size, sent_num, -1) masks = masks.view(opt.batch_size, sent_num, -1) # topic_vecs = topic_vecs.view(opt.batch_size, sent_num, -1) # initilize topic vec, the shape is: [batch_size 10, max_seq_len 31, hidden_size 512] topic_vec = torch.zeros((opt.batch_size, labels.shape[2] - 1, opt.rnn_size)).float().cuda() total_loss = 0.0 for sent_n in range(sent_num): # prepare sentence data sent_label = labels[:, sent_n, :] sent_mask = masks[:, sent_n, :] # We should skip the batch in which the sentences for all examples in the batch are 0s. # This is likely to happen at the end of the paragraph) if torch.sum(sent_label).item() == 0: continue # model forward pass optimizer.zero_grad() model_out = dp_lw_model(fc_feats, att_feats, sent_label, sent_mask, att_masks, topic_vec, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) # loss calculation loss = model_out['loss'].mean() total_loss += loss decoder_output = model_out['decoder_output'] # Cannot backward here, this will cause multiple backwards... # Specify retain_graph=True, if you still want to backward here #loss.backward() #getattr(torch.nn.utils, 'clip_grad_%s_' %(opt.grad_clip_mode))(model.parameters(), opt.grad_clip_value) #optimizer.step() #train_loss = loss.item() # PLM: treat decoder output as "topic vec" # topic_vec = decoder_output # Optional: Shrink the size of it based on the mask? 
max_sent_len = -1 for row in range(sent_mask.shape[0]): sent_len = int(sum(sent_mask[row, :]).data.item()) if sent_len > max_sent_len: max_sent_len = sent_len topic_vec = decoder_output[:, 0:max_sent_len, :] avg_loss = total_loss / sent_num avg_loss.backward() #loss.backward() getattr(torch.nn.utils, 'clip_grad_%s_' % (opt.grad_clip_mode))( model.parameters(), opt.grad_clip_value) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if iteration % opt.print_freq == 1: print('Read data:', time.time() - start) if struc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start)) elif not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, model_out['reward'].mean(), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): tb_summary_writer.add_scalar('train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration) tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration) elif struc_flag: tb_summary_writer.add_scalar( 'lm_loss', model_out['lm_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'struc_loss', model_out['struc_loss'].mean().item(), iteration) tb_summary_writer.add_scalar( 'reward', model_out['reward'].mean().item(), iteration) histories['loss_history'][ iteration] = train_loss if not sc_flag else model_out[ 'reward'].mean() histories['lr_history'][iteration] = opt.current_lr histories['ss_prob_history'][iteration] = model.ss_prob # update infos infos['iter'] = iteration infos['epoch'] = epoch infos['loader_state_dict'] = loader.state_dict() # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \ (epoch_done and opt.save_every_epoch): # eval model eval_kwargs = {'split': 'val', 'dataset': 'val'} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, lw_model.crit, loader, eval_kwargs) if opt.reduce_on_plateau: if 'CIDEr' in lang_stats: optimizer.scheduler_step(-lang_stats['CIDEr']) else: optimizer.scheduler_step(val_loss) # Write validation result into summary tb_summary_writer.add_scalar('validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): tb_summary_writer.add_scalar(k, v, iteration) histories['val_result_history'][iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True # Dump miscalleous informations infos['best_val_score'] = best_val_score utils.save_checkpoint(opt, model, infos, optimizer, 
histories) if opt.save_history_ckpt: utils.save_checkpoint( opt, model, infos, optimizer, append=str(epoch) if opt.save_every_epoch else str(iteration)) if best_flag: utils.save_checkpoint(opt, model, infos, optimizer, append='best') except (RuntimeError, KeyboardInterrupt): print('Save ckpt on exception ...') utils.save_checkpoint(opt, model, infos, optimizer) print('Save ckpt done.') stack_trace = traceback.format_exc() print(stack_trace)
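# --- Illustrative sketch (not part of the original training code) ----------
# The epoch-driven schedules used in the loops above and below (step learning
# rate decay and the scheduled-sampling ramp-up) share one closed form. This is
# a minimal standalone sketch of those two schedules; the parameter names mirror
# the corresponding `opt.*` fields, but these helpers are hypothetical and are
# not defined elsewhere in this codebase.


def lr_at_epoch(epoch, base_lr, decay_start, decay_every, decay_rate):
    # Hold base_lr until decay_start, then multiply by decay_rate once every
    # decay_every epochs (step decay).
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr


def ss_prob_at_epoch(epoch, ss_start, ss_every, ss_increase, ss_max):
    # Ramp the scheduled-sampling probability in steps of ss_increase every
    # ss_every epochs, saturating at ss_max.
    if ss_start >= 0 and epoch > ss_start:
        frac = (epoch - ss_start) // ss_every
        return min(ss_increase * frac, ss_max)
    return 0.0


# Example: with base_lr=5e-4 decayed by 0.8 every 3 epochs from epoch 0,
# lr_at_epoch(10, 5e-4, 0, 3, 0.8) == 5e-4 * 0.8 ** 3; with a 0.05 ramp every
# 5 epochs capped at 0.25, ss_prob_at_epoch(10, 0, 5, 0.05, 0.25) == 0.10.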
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) ac = 0 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open( os.path.join( opt.checkpoint_path, 'infos_' + opt.id + format(int(opt.start_from), '04') + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from), '04') + '.pkl')): with open( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(opt.start_from), '04') + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() #dp_model = torch.nn.DataParallel(model) #dp_model = torch.nn.DataParallel(model, [0,2,3]) dp_model = model update_lr_flag = True # Assure in training mode dp_model.train() for name, param in model.named_parameters(): print(name) crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() CE_ac = utils.CE_ac() optim_para = model.parameters() optimizer = utils.build_optimizer(optim_para, opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from), '04') + '.pth')): optimizer.load_state_dict( torch.load( os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(opt.start_from), '04') + '.pth'))) optimizer.zero_grad() accumulate_iter = 0 train_loss = 0 reward = np.zeros([1, 1]) sim_lambda = opt.sim_lambda while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch(opt.train_split) print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [data['labels'], 
data['masks'], data['mods']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] labels, masks, mods = tmp tmp = [ data['att_feats'], data['att_masks'], data['attr_feats'], data['attr_masks'], data['rela_feats'], data['rela_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] att_feats, att_masks, attr_feats, attr_masks, rela_feats, rela_masks = tmp rs_data = {} rs_data['att_feats'] = att_feats rs_data['att_masks'] = att_masks rs_data['attr_feats'] = attr_feats rs_data['attr_masks'] = attr_masks rs_data['rela_feats'] = rela_feats rs_data['rela_masks'] = rela_masks if not sc_flag: logits, cw_logits = dp_model(rs_data, labels) ac = CE_ac(logits, labels[:, 1:], masks[:, 1:]) print('ac :{0}'.format(ac)) loss_lan = crit(logits, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs, cw_logits = dp_model( rs_data, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, rs_data, data, gen_result, opt) loss_lan = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:]) ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:]) print('ac :{0}'.format(ac2)) if epoch < opt.step2_train_after: loss = loss_lan + sim_lambda * loss_cw else: loss = loss_lan accumulate_iter = accumulate_iter + 1 loss = loss / opt.accumulate_number loss.backward() if accumulate_iter % opt.accumulate_number == 0: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() optimizer.zero_grad() iteration += 1 accumulate_iter = 0 train_loss = loss.item() * opt.accumulate_number train_loss_lan = loss_lan.item() train_loss_cw = loss_cw.item() end = time.time() if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \ .format(train_loss_lan, train_loss_cw)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:, 0]), end - start)) print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \ .format(train_loss_lan, train_loss_cw)) print('lr:{0}'.format(opt.current_lr)) torch.cuda.synchronize() # Update the iteration and epoch if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0) and (accumulate_iter % opt.accumulate_number == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'train_loss_lan', train_loss_lan, iteration) add_summary_value(tb_summary_writer, 'train_loss_cw', train_loss_cw, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) add_summary_value(tb_summary_writer, 'ac', ac, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0) and (accumulate_iter % opt.accumulate_number == 0): # eval model eval_kwargs = {'split': 'test', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) #val_loss, predictions, lang_stats = eval_utils_rs3.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation 
result into summary # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) # if lang_stats is not None: # for k,v in lang_stats.items(): # add_summary_value(tb_summary_writer, k, v, iteration) # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result # if opt.language_eval == 1: # current_score = lang_stats['CIDEr'] # else: # current_score = - val_loss current_score = 0 best_flag = False if True: # if true save_id = iteration / opt.save_checkpoint_every if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join( opt.checkpoint_path, 'model' + opt.id + format(int(save_id), '04') + '.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join( opt.checkpoint_path, 'optimizer' + opt.id + format(int(save_id), '04') + '.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join( opt.checkpoint_path, 'infos_' + opt.id + format(int(save_id), '04') + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join( opt.checkpoint_path, 'histories_' + opt.id + format(int(save_id), '04') + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
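# --- Illustrative sketch (not part of the original training code) ----------
# The loop above steps the optimizer only once every `opt.accumulate_number`
# mini-batches, scaling each loss by 1/accumulate_number so that the summed
# gradient matches the average over the group. A minimal self-contained sketch
# of that accumulation pattern; `model`, `criterion`, `batches`, and the
# clipping call are placeholders (the original uses utils.clip_gradient).

import torch


def train_with_accumulation(model, optimizer, criterion, batches,
                            accumulate_number=4, grad_clip=0.1):
    """Accumulate gradients over accumulate_number batches per optimizer step."""
    optimizer.zero_grad()
    accumulate_iter = 0
    for inputs, targets in batches:
        loss = criterion(model(inputs), targets)
        # Scale so the accumulated gradient equals the average over the group.
        (loss / accumulate_number).backward()
        accumulate_iter += 1
        if accumulate_iter % accumulate_number == 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()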
def train(opt):
    iteration = 0
    epoch = 0

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}

    # Create model
    model = convcap(opt).cuda()
    pretrained_dict = torch.load(opt.model)
    model.load_state_dict(pretrained_dict, strict=False)
    start = time.time()
    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True
    samplenet = sampleNet(dp_model, opt)

    # Reward criterion for the policy-gradient loss below (it was never
    # instantiated in the original listing, which would raise a NameError).
    rl_crit = utils.RewardCriterion()

    while True:
        # Unpack data
        # torch.cuda.synchronize()
        data = loader.get_batch('train')
        data_time = time.time() - start
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp
        batchsize = fc_feats.size(0)

        # Forward pass and loss
        optimizer.zero_grad()

        # Decode the ground-truth labels back to strings for the CIDEr reward
        labels_decode = labels.view(-1, 180)
        captions = utils.decode_sequence(loader.get_vocab(), labels_decode, None)
        captions_all = []
        for index, caption in enumerate(captions):
            caption = caption.replace('<start>', '').replace(' ,', '') \
                             .replace('  ', ' ')  # collapse double spaces
            captions_all.append(caption)

        # Sample captions (and their log-probs) with the current model
        target, outcap, sample_right = samplenet(batchsize, 30 * 6,
                                                 loader.get_vocab(), att_feats)
        # Alternative computation of the sample log-probabilities, left
        # commented out in the original:
        # wordclass_feed = wordclass_feed.reshape((batchsize, 6, 30))
        # out, _ = dp_model(fc_feats, att_feats, torch.tensor(wordclass_feed))
        # Logprobs = torch.log(F.softmax(out.transpose(2, 1)))
        # target = target.view((batchsize, (30 * 6), -1))
        # sampleLogprobs = Logprobs.gather(2, target.long().unsqueeze(2))

        # Self-critical reward: CIDEr of the sampled captions vs. the greedy baseline
        with torch.no_grad():
            reward, cider_sample, cider_greedy = get_self_critical_reward(
                batchsize, dp_model, att_feats, outcap, captions_all,
                loader.get_vocab(), 30 * 6)
        loss_rl = rl_crit(sample_right, target.data,
                          torch.from_numpy(reward).float())

        # Teacher-forced pass for the word-level cross-entropy term
        wordact, x_all = dp_model(fc_feats, att_feats, labels, 30, 6)
        mask = masks[:, 1:].contiguous()
        wordact = wordact[:, :, :-1]
        wordact_t = wordact.permute(0, 2, 1).contiguous()
        wordact_t = wordact_t.view(wordact_t.size(0) * wordact_t.size(1), -1)
        labels = labels.contiguous().view(-1, 6 * 30).cpu()
        wordclass_v = labels[:, 1:]
        wordclass_t = wordclass_v.contiguous().view(
            wordclass_v.size(0) * wordclass_v.size(1), 1)
        maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
        loss_xe = F.cross_entropy(
            wordact_t[maskids, ...],
            wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))

        # Only the RL loss is optimized in this variant; loss_xe (and the
        # commented-out MSE term) are computed but not added to the objective.
        loss_xe_all = loss_rl  # + F.mse_loss(x_all_inference.cuda(), x_all.cuda()).cuda()

        # Backward pass
        loss_xe_all.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss_xe_all.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        reward = reward[:, 0].mean()
        cider_sample = cider_sample.mean()
        cider_greedy = cider_greedy.mean()
        # (the debug `if 1:` / `if 0:` switches from the original are collapsed here)
        if iteration % 2 == 1:
            print('Read data:', time.time() - start)
            print("iter {} (epoch {}), train_loss = {:.3f}, avg_reward = {:.3f}, "
                  "cider_sample = {:.3f}, cider_greedy = {:.3f}, "
                  "data_time = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, reward, cider_sample,
                          cider_greedy, data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary (disabled in the original)
        # if (iteration % opt.losses_log_every == 0):
        #     add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
        #     add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
        #     # add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
        #     if sc_flag:
        #         add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
        #     loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
        #     lr_history[iteration] = opt.current_lr
        #     # ss_prob_history[iteration] = model.ss_prob

        # Save model
        if iteration % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model' + str(iteration) + '.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
        # Evaluate model: the validation/evaluation block was commented out in
        # the original and is omitted here.
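# --- Illustrative sketch (not part of the original training code) ----------
# The implementations of get_self_critical_reward and utils.RewardCriterion are
# not included in this file. The sketch below shows the idea they are built
# around: the per-image reward is the CIDEr score of the sampled caption minus
# the CIDEr score of the greedy (baseline) caption, and the policy-gradient loss
# is that reward times the negative log-probability of each sampled word, masked
# beyond the end of the generated sequence. Function names and the exact masking
# convention here are assumptions, not the project's actual code.

import numpy as np
import torch


def self_critical_advantage(cider_sample, cider_greedy, seq_len):
    # (batch,) advantages broadcast over every timestep of the sampled sequence.
    advantage = np.asarray(cider_sample) - np.asarray(cider_greedy)
    return np.repeat(advantage[:, np.newaxis], seq_len, axis=1)  # (batch, seq_len)


def reward_criterion(sample_logprobs, gen_result, reward):
    # REINFORCE-style loss: -reward * log p(sampled word), with a mask that keeps
    # every generated word plus one step for the end-of-sequence token.
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
    loss = -sample_logprobs.reshape(-1) * reward.reshape(-1) * mask.reshape(-1)
    return loss.sum() / mask.sum()


# usage sketch:
#   adv = self_critical_advantage(cider_sample, cider_greedy, gen_result.size(1))
#   loss_rl = reward_criterion(sample_logprobs, gen_result,
#                              torch.from_numpy(adv).float())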
def train(opt): # Load data loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length # Tensorboard summaries (they're great!) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) # Load pretrained model, info file, histories file infos = {} histories = {} if opt.start_from is not None: with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = ["rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # Create model model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model, device_ids=[0]) dp_model.train() # Loss function crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() # Optimizer and learning rate adjustment flag optimizer = utils.build_optimizer(model.parameters(), opt) update_lr_flag = True # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) # Training loop while True: # Update learning rate once per epoch if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) start = time.time() data = loader.get_batch('train') data_time = time.time() - start start = time.time() # Unpack data torch.cuda.synchronize() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp # Forward pass and loss optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = 
get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) # Backward pass loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() # Print total_time = time.time() - start if iteration % opt.print_freq == 1: print('Read data:', time.time() - start) if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, data_time, total_time)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # Validate and save model if (iteration % opt.save_checkpoint_every == 0): # Evaluate model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Our metric is CIDEr if available, otherwise validation loss if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss # Save model in checkpoint path best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscalleous informations infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) # Save model to unique file if new best model if best_flag: model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format( iteration, best_val_score) infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration) 
checkpoint_path = os.path.join(opt.checkpoint_path, model_fname) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
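# --- Illustrative sketch (not part of the original training code) ----------
# The checkpoint rule above (and in the other loops) scores the model by CIDEr
# when language evaluation is enabled and by negative validation loss otherwise,
# then keeps an extra "best" copy whenever that score improves. A small sketch
# of that selection logic; the file names here are illustrative.

import os
import torch


def save_and_track_best(model, optimizer, checkpoint_dir, lang_stats, val_loss,
                        language_eval, best_val_score):
    # Higher is better: CIDEr if available, otherwise the negated loss.
    current_score = lang_stats['CIDEr'] if language_eval == 1 else -val_loss

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'model.pth'))
    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'optimizer.pth'))

    if best_val_score is None or current_score > best_val_score:
        best_val_score = current_score
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'model-best.pth'))
    return best_val_score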
def train(opt): logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log')) print = logger.info if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length # Print out the option variables print("*" * 20) for k, v in opt.__dict__.items(): print("%r: %r" % (k, v)) print("*" * 20) infos = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f: infos = json.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) else: best_val_score = None model = models.setup(opt).to(device) dp_model = torch.nn.DataParallel(model) update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) start_time = time.time() while True: if update_lr_flag: # Assign the learning rate if 0 <= opt.learning_rate_decay_start < epoch: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor utils.set_lr(optimizer, opt.current_lr) # set the decayed rate else: opt.current_lr = opt.learning_rate # Assign the scheduled sampling prob if 0 <= opt.scheduled_sampling_start < epoch: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer() else: sc_flag = False update_lr_flag = False # Load data from train split (0) batch_data = loader.get_batch('train') torch.cuda.synchronize(device) tmp = [ batch_data['fc_feats'], batch_data['att_feats'], batch_data['labels'], batch_data['masks'], batch_data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).to(device) for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: outputs = dp_model(fc_feats, att_feats, labels, att_masks) loss = crit(outputs, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, batch_data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().to(device)) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.data torch.cuda.synchronize(device) # Update the iteration and epoch iteration += 1 if batch_data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Print train loss or avg reward if iteration % opt.losses_print_every == 0: if not sc_flag: print( "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format( iteration, epoch, loss.item(), time.time() - start_time)) else: print("iter {} (epoch 
{}), avg_reward = {:.3f}, time = {:.3f}". format(iteration, epoch, np.mean(reward[:, 0]), time.time() - start_time)) start_time = time.time() # make evaluation on validation set, and save model if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\ or (opt.save_checkpoint_every <= 0 and update_lr_flag): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.simple_eval_split( dp_model, loader, eval_kwargs) # Save model if is improving on validation result if not os.path.exists(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = vars(opt) infos['vocab'] = loader.get_vocab() with open(os.path.join(opt.checkpoint_path, 'infos.json'), 'w') as f: json.dump(infos, f, sort_keys=True, indent=4) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, 'infos-best.json'), 'w') as f: json.dump(infos, f, sort_keys=True, indent=4) # Stop if reaching max epochs if opt.max_epochs != -1 and epoch >= opt.max_epochs: break
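# --- Illustrative sketch (not part of the original training code) ----------
# Every variant above trains its word-level loss with utils.LanguageModelCriterion,
# whose implementation is not part of this excerpt. The class below is a minimal
# sketch of the masked sequence cross-entropy such a criterion is expected to
# compute: log-probabilities of shape (batch, seq_len, vocab), targets and mask
# of shape (batch, seq_len). The class name is hypothetical.

import torch.nn as nn


class MaskedLanguageModelCriterion(nn.Module):
    """Negative log-likelihood of the target words, averaged over unmasked positions."""

    def forward(self, log_probs, target, mask):
        # Trim targets/mask to the number of predicted steps.
        target = target[:, :log_probs.size(1)]
        mask = mask[:, :log_probs.size(1)].float()
        # Log-probability the model assigned to each ground-truth word.
        nll = -log_probs.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
        return nll.sum() / mask.sum()


# usage sketch (mirrors crit(dp_model(...), labels[:, 1:], masks[:, 1:]) above):
#   crit = MaskedLanguageModelCriterion()
#   loss = crit(log_probs, labels[:, 1:], masks[:, 1:])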