def traditional_cider(self, fc_feats, att_feats, att_masks, data, loss,
                      gen_result, greedy_res, sample_logprobs, gen_masks):
    # Use the differenced rewards
    if self.use_gen_cider_scores == 0:
        reward, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res)
    else:  # use the original rewards
        reward, _, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res, return_gen_scores=True)
    self._loss['avg_reward'] = reward.mean()
    self._loss['cider_greedy'] = cider_greedy

    # Masked policy-gradient loss: -reward * logprob over generated tokens.
    loss_cider = sample_logprobs * utils.var_wrapper(
        -reward.astype('float32'),
        cuda=torch.cuda.is_available()).unsqueeze(1) * (gen_masks[:, 1:].detach())
    loss_cider = loss_cider.sum() / gen_masks[:, 1:].data.float().sum()
    loss += self.cider_optimization * loss_cider
    self._loss['loss_cider'] = loss_cider.item()
    return loss
def forward(self, fc_feats, att_feats, obj_label, rela_label, rela_sub,
            rela_obj, rela_n2r, geometry, adj1, adj2, labels, masks,
            att_masks, rela_masks, gts, gt_indices, sc_flag):
    out = {}
    if not sc_flag:
        loss = self.crit(
            self.model(fc_feats, att_feats, obj_label, rela_label, rela_sub,
                       rela_obj, rela_n2r, geometry, adj1, adj2, rela_masks,
                       labels, att_masks),
            labels[:, 1:], masks[:, 1:])
    else:
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(fc_feats, att_feats, obj_label,
                                       rela_label, rela_sub, rela_obj,
                                       rela_n2r, geometry, adj1, adj2,
                                       rela_masks, labels, att_masks,
                                       mode='sample')
        self.model.train()
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, obj_label, rela_label, rela_sub, rela_obj,
            rela_n2r, geometry, adj1, adj2, rela_masks, labels, att_masks,
            opt={'sample_method': 'sample'}, mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag):
    out = {}
    if not sc_flag:
        loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                         labels[:, 1:], masks[:, 1:])
    else:
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
                                       mode='sample')
        self.model.train()
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, att_masks,
            opt={'sample_method': 'sample'}, mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()

    if self.opt.caption_model == 'aat':
        all_aat_loss = torch.stack(self.model.all_att_cost).t()
        if not sc_flag:
            mask_ = masks[:, :all_aat_loss.size()[1]]
        else:
            mask_ = (torch.cat((gen_result.new_ones(gen_result.size(0), 1),
                                gen_result), dim=1) > 0)[:, :all_aat_loss.size()[1]]
        aat_loss = (all_aat_loss * mask_.float()).sum(1).mean()
        out['aat_loss'] = aat_loss
        out['att_step'] = self.model.all_att_step
        out['avg_att_time'] = (np.array(self.model.all_att_step).transpose()
                               * mask_.cpu().numpy()).sum() / mask_.cpu().numpy().sum()
        out['loss_'] = loss.clone()
        loss += self.opt.aat_lambda * aat_loss
    out['loss'] = loss
    return out
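# Several of these forward passes feed sample_logprobs, gen_result.data, and
# the time-broadcast reward into self.rl_crit (utils.RewardCriterion). The
# criterion itself is not shown in this collection; below is a minimal sketch
# of the usual masked REINFORCE loss, assuming sample_logprobs is already
# gathered per time step with shape (batch, time) and token id 0 marks padding.
import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    """Masked policy-gradient loss: -logprob * reward, averaged over real tokens."""

    def forward(self, logprobs, seq, reward):
        # Keep every generated token up to and including the first end/pad token.
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        loss = -logprobs * reward * mask
        return loss.sum() / mask.sum()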
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    for epoch in range(opt["epochs"]):
        lr_scheduler.step()
        iteration = 0
        # If start self-critical training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()

            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, _ = model(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                seq_probs, seq_preds = model(fc_feats, mode='inference', opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data, seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               Variable(torch.from_numpy(reward).float().cuda()))

            loss.backward()
            utils.clip_gradient(optimizer, opt["grad_clip"])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch != 0 and epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
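# The cross-entropy criterion (crit) used during warm-up in these train loops
# is a masked negative log-likelihood over the shifted labels. A minimal
# sketch, assuming the model emits per-token log-probabilities of shape
# (batch, time, vocab); the class name here is illustrative:
import torch
import torch.nn as nn

class LanguageModelCriterionSketch(nn.Module):
    """Masked NLL: average -log p(target) over non-padded positions."""

    def forward(self, logprobs, target, mask):
        # Truncate references and masks to the generated length.
        target = target[:, :logprobs.size(1)]
        mask = mask[:, :logprobs.size(1)].float()
        nll = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return nll.sum() / mask.sum()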
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag, ppo_flag, clipped_lambda, sc_lambda):
    # ppo_flag and the stored old-policy log-probs were added for PPO (Sep 2019).
    out = {}
    if ppo_flag:
        gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                                                 opt={'sample_max': 0},
                                                 mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(self.model, fc_feats, att_feats,
                                          att_masks, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        # Combine the PPO-clip loss with the SCST (REINFORCE) loss from
        # utils.RewardCriterion() (added 24/Sep/2019).
        loss_ppo = self.ppo_crit(sample_logprobs, self.old_sample_logprobs,
                                 gen_result.data, reward)
        loss_sc = self.rl_crit(sample_logprobs, gen_result.data, reward)
        self.old_sample_logprobs = sample_logprobs.clone()
        print("Using sc_lambda: {}\tclipped_lambda: {}".format(sc_lambda, clipped_lambda))
        loss = sc_lambda * loss_sc + clipped_lambda * loss_ppo
        # loss = loss_ppo  # activate for the clipped-only (Clipped-SC) variant
        out['reward'] = reward[:, 0].mean()
    else:
        if not sc_flag:
            loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                             labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = self.model(fc_feats, att_feats,
                                                     att_masks,
                                                     opt={'sample_max': 0},
                                                     mode='sample')
            gts = [gts[_] for _ in gt_indices.tolist()]
            reward = get_self_critical_reward(self.model, fc_feats, att_feats,
                                              att_masks, gts, gen_result,
                                              self.opt)
            reward = torch.from_numpy(reward).float().to(gen_result.device)
            loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
            out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out
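# self.ppo_crit above is not defined anywhere in this collection. A plausible
# clipped-surrogate criterion consistent with the call
# ppo_crit(sample_logprobs, old_sample_logprobs, seq, reward) is sketched
# below; the class name and the clip_eps default are assumptions, and the mask
# follows the same convention as the REINFORCE criterion.
import torch
import torch.nn as nn

class PPOClipCriterionSketch(nn.Module):
    """Hypothetical PPO clipped-surrogate loss over sampled captions."""

    def __init__(self, clip_eps=0.2):
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self, logprobs, old_logprobs, seq, reward):
        # Importance ratio between the current policy and the stored old policy.
        ratio = torch.exp(logprobs - old_logprobs.detach())
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        surr1 = ratio * reward
        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * reward
        loss = -torch.min(surr1, surr2) * mask
        return loss.sum() / mask.sum()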
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag, struc_flag):
    opt = self.opt
    out = {}
    if struc_flag:
        if opt.structure_loss_weight < 1:
            lm_loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                                labels[..., 1:], masks[..., 1:])
        else:
            lm_loss = torch.tensor(0).type_as(fc_feats)
        if opt.structure_loss_weight > 0:
            gen_result, sample_logprobs = self.model(
                fc_feats, att_feats, att_masks,
                opt={'sample_method': opt.train_sample_method,
                     'beam_size': opt.train_beam_size,
                     'output_logsoftmax': opt.struc_use_logsoftmax
                         or opt.structure_loss_type == 'softmax_margin'
                         or 'margin' not in opt.structure_loss_type,
                     'sample_n': opt.train_sample_n},
                mode='sample')
            gts = [gts[_] for _ in gt_indices.tolist()]
            struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
        else:
            struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
                          'reward': torch.tensor(0).type_as(fc_feats)}
        loss = (1 - opt.structure_loss_weight) * lm_loss \
               + opt.structure_loss_weight * struc_loss['loss']
        out['lm_loss'] = lm_loss
        out['struc_loss'] = struc_loss['loss']
        out['reward'] = struc_loss['reward']
    elif not sc_flag:
        loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                         labels[..., 1:], masks[..., 1:])
    else:
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(
                fc_feats, att_feats, att_masks, mode='sample',
                opt={'sample_method': opt.sc_sample_method,
                     'beam_size': opt.sc_beam_size})
        self.model.train()
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, att_masks,
            opt={'sample_method': opt.train_sample_method,
                 'beam_size': opt.train_beam_size,
                 'sample_n': opt.train_sample_n},
            mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag):
    out = {}
    if not sc_flag:
        loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                         labels[:, 1:], masks[:, 1:])
    else:
        gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                                                 opt={'sample_max': 0},
                                                 mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(self.model, fc_feats, att_feats,
                                          att_masks, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag, itera):
    out = {}
    if not sc_flag:
        # Cross-entropy (or label-smoothed, when self.ls > 0) training.
        loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                         labels[:, 1:], masks[:, 1:])
    else:
        # Greedy decoding provides the self-critical baseline.
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
                                       mode='sample')
        self.model.train()
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, att_masks,
            opt={'sample_method': 'sample'}, mode='sample')
        # Each ground-truth entry holds the five reference captions.
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result,
                                          self.opt, itera)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out
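# The reward helper is what makes all of these loops "self-critical": the score
# of the greedy caption serves as the baseline for the sampled caption, so no
# learned value function is needed. A simplified sketch of its semantics,
# assuming decoded caption strings and a hypothetical scorer(candidate,
# references) callable returning a CIDEr value (the real helpers work on token
# id arrays and batch the CIDEr computation):
import numpy as np

def get_self_critical_reward_sketch(greedy_caps, gts, sampled_caps, scorer, seq_len):
    # greedy_caps / sampled_caps: lists of caption strings for the batch;
    # gts: list of reference-caption lists; seq_len: decoder unroll length.
    sample_scores = np.array([scorer(c, refs) for c, refs in zip(sampled_caps, gts)])
    greedy_scores = np.array([scorer(c, refs) for c, refs in zip(greedy_caps, gts)])
    scores = sample_scores - greedy_scores  # baseline-subtracted reward
    # One scalar reward per caption, broadcast across all time steps; this is
    # why the callers above log reward[:, 0] as the average reward.
    return np.repeat(scores[:, np.newaxis], seq_len, axis=1)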
def train(dataset, loader, model, rem, crit, optimizer, lr_scheduler, opt,
          rl_crit=None):
    writer = SummaryWriter('./runs/video_caption22')
    model.load_state_dict(
        torch.load('/home/diml/video-caption.pytorch/save/RECON222_model_200.pth'))
    rem.load_state_dict(
        torch.load('/home/diml/video-caption.pytorch/save/RECON222_module_200.pth'))
    model.train()
    rem.train()
    vocab = dataset.get_vocab()

    for epoch in trange(opt["epochs"]):
        t_loss = [0, 0, 0]
        lr_scheduler.step()
        iteration = 0
        # If start self-critical training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for idx, data in enumerate(loader):
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            labels2 = data['labels2'].cuda()
            masks2 = data['masks2'].cuda()
            masks = data['masks'].cuda()

            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, seq_preds, hn, de_hn = model(fc_feats, labels, 'train')
                loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:])
                # Reconstruct the encoder hidden state and caption again from it.
                fake_en_hn = rem(de_hn, seq_probs)
                f_seq_probs, f_seq_preds, hn, de_hn = model(fc_feats, labels2,
                                                            'train', h=fake_en_hn)
                loss_R = crit(f_seq_probs, labels2[:, 1:], masks2[:, 1:])
                loss = loss_R + loss_C
            else:
                seq_probs, seq_preds = model(fc_feats, mode='inference', opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data, seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            t_loss[0] += loss.item()
            if not sc_flag:  # loss_C / loss_R only exist in the XE branch
                t_loss[1] += loss_C.item()
                t_loss[2] += loss_R.item()

            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        writer.add_scalar('training total loss', t_loss[0] / 140, epoch + 200)
        writer.add_scalar('training Caption loss', t_loss[1] / 140, epoch + 200)
        writer.add_scalar('training Reconstruction loss', t_loss[2] / 140, epoch + 200)

        if epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'RECON222_model_%d.pth' % (epoch + 200))
            rem_path = os.path.join(opt["checkpoint_path"],
                                    'RECON222_module_%d.pth' % (epoch + 200))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'RECON222_model_score.txt')
            torch.save(model.state_dict(), model_path)
            torch.save(rem.state_dict(), rem_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
            # Log an original/revised caption pair for inspection
            # (fake_en_hn comes from the XE branch of the last batch).
            with torch.no_grad():
                _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt)
                _, f_seq_preds, __, ___ = model(fc_feats, mode='inference',
                                                h=fake_en_hn, opt=opt)
                origin = utils.decode_sequence(vocab, seq_preds)[0]
                revision = utils.decode_sequence(vocab, f_seq_preds)[0]
                with open('./results/training_versus.txt', 'a') as f:
                    f.write("epoch is %d \n" % epoch)
                    f.write("origin caption: " + origin + "\n")
                    f.write("revision caption: " + revision + "\n")
def forward(self, fc_feats, att_feats, att_masks, seq, masks, data):
    if self.caption_loss_weight > 0 and not self.cider_optimization:
        loss_cap = self.caption_generator(fc_feats, att_feats, att_masks, seq, masks)
    else:
        loss_cap = Variable(torch.cuda.FloatTensor([0]))
    if self.vse_loss_weight > 0:
        loss_vse = self.vse(fc_feats, att_feats, seq, masks,
                            only_one_retrieval=self.only_one_retrieval)
    else:
        loss_vse = Variable(torch.cuda.FloatTensor([0]))
    loss = self.caption_loss_weight * loss_cap + self.vse_loss_weight * loss_vse

    if self.retrieval_reward_weight > 0:
        _seqs, _sampleLogProbs = self.caption_generator.sample(
            fc_feats, att_feats, att_masks,
            {'sample_max': 0, 'temperature': 1})
        gen_result, sample_logprobs = _seqs, _sampleLogProbs
        _masks = torch.cat(
            [Variable(_seqs.data.new(_seqs.size(0), 2).fill_(1).float()),
             (_seqs > 0).float()[:, :-1]], 1)
        gen_masks = _masks
        _seqs = torch.cat(
            [Variable(_seqs.data.new(_seqs.size(0), 1)
                      .fill_(self.caption_generator.vocab_size + 1)),
             _seqs], 1)
        retrieval_loss = self.vse(fc_feats, att_feats, _seqs, _masks, True,
                                  only_one_retrieval=self.only_one_retrieval)
        if self.reinforce_baseline_type == 'greedy':
            _seqs_greedy, _sampleLogProbs_greedy = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats, att_masks], volatile=True),
                opt={'sample_max': 1, 'temperature': 1})
            greedy_res = _seqs_greedy
            # Do we need word weights here?
            if True:  # not self.use_word_weights:
                _masks_greedy = torch.cat(
                    [Variable(_seqs_greedy.data.new(_seqs.size(0), 2).fill_(1).float()),
                     (_seqs_greedy > 0).float()[:, :-1]], 1)
            else:
                _masks_greedy = self.get_word_weights_mask(_seqs_greedy)
            _seqs_greedy = torch.cat(
                [Variable(_seqs_greedy.data.new(_seqs_greedy.size(0), 1)
                          .fill_(self.caption_generator.vocab_size + 1)),
                 _seqs_greedy], 1)
            baseline = self.vse(fc_feats, att_feats, _seqs_greedy, _masks_greedy,
                                True, only_one_retrieval=self.only_one_retrieval)
        elif self.reinforce_baseline_type == 'gt':
            baseline = self.vse(fc_feats, att_feats, seq, masks, True,
                                only_one_retrieval=self.only_one_retrieval)
        else:
            baseline = 0
        sc_loss = _sampleLogProbs * (
            utils.var_wrapper(retrieval_loss) - utils.var_wrapper(baseline)
        ).detach().unsqueeze(1) * (_masks[:, 1:].detach().float())
        sc_loss = sc_loss.sum() / _masks[:, 1:].data.float().sum()
        loss += self.retrieval_reward_weight * sc_loss
        self._loss['retrieval_sc_loss'] = sc_loss.item()
        self._loss['retrieval_loss'] = retrieval_loss.sum().item()
        self._loss['retrieval_loss_greedy'] = baseline.sum().item() \
            if isinstance(baseline, Variable) else baseline

    if self.cider_optimization:
        if 'gen_result' not in locals():
            gen_result, sample_logprobs = self.caption_generator.sample(
                fc_feats, att_feats, att_masks, opt={'sample_max': 0})
            gen_masks = torch.cat(
                [Variable(gen_result.data.new(gen_result.size(0), 2).fill_(1).float()),
                 (gen_result > 0).float()[:, :-1]], 1)
        if 'greedy_res' not in locals():
            greedy_res, _ = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats, att_masks], volatile=True),
                opt={'sample_max': 1})
        reward, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res)
        self._loss['avg_reward'] = reward.mean()
        self._loss['cider_greedy'] = cider_greedy
        loss_cap = sample_logprobs * utils.var_wrapper(
            -reward.astype('float32')).unsqueeze(1) * (gen_masks[:, 1:].detach())
        loss_cap = loss_cap.sum() / gen_masks[:, 1:].data.float().sum()
        loss += self.caption_loss_weight * loss_cap

    self._loss['loss_cap'] = loss_cap.item()
    self._loss['loss_vse'] = loss_vse.item()
    self._loss['loss'] = loss.item()
    return loss
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and \
            os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'],
               data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks,
                                                   opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # epoch += 1  (epoch is advanced in the evaluation block below)
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Make evaluation on validation set, and save model
        # (once per epoch rather than every save_checkpoint_every iterations)
        # if (iteration % opt.save_checkpoint_every == 0):
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': False}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss,
                                             'lang_stats': lang_stats,
                                             'predictions': predictions}

            # Save model if it is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats
                f = open('train_log_%s.txt' % opt.id, 'a')
                f.write('Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}'
                        .format(epoch, str(datetime.now()), str(total_loss / times),
                                str(val_loss), str(current_score)))
                f.write('\n')
                f.close()
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            if opt.save_history_ckpt:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(infos, f)
            if opt.save_history_ckpt:
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-%d.pkl' % (iteration)),
                          'wb') as f:
                    cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
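# utils.clip_gradient, called after loss.backward() in these loops, clips
# gradients by value rather than by norm. A sketch of that helper as it is
# commonly written in these codebases:
def clip_gradient(optimizer, grad_clip):
    # Clamp each parameter's gradient elementwise to [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)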
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    # The discriminator (model_D) used by the GAN branch below is currently
    # disabled; only the distance model (model_E) is loaded.
    model_E = Distance(opt)
    model_E.load_state_dict(
        torch.load('save/model_E_NCE/model_E_10epoch.pthsfdasdfadf'))
    model_E.cuda()
    criterion_E = nn.CosineEmbeddingLoss(margin=0, size_average=True)

    logger = Logger(opt)
    update_lr_flag = True
    # Assure in training mode
    model.train()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_G = optim.Adam(model.parameters(), lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and \
            os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_G.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer_G = update_lr(
                opt, epoch, model, optimizer_G)

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, labels, masks = tmp

        # ---------------------------- REINFORCE TRAINING ----------------------------
        if 1:  # iteration % opt.D_scheduling != 0:
            optimizer_G.zero_grad()
            if not sc_flag:
                loss = crit(model(fc_feats, labels), labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(fc_feats, {'sample_max': 0})
                sc_reward = get_self_critical_reward(model, fc_feats, data,
                                                     gen_result, logger)
                # Rewards from the distance model (criterion_E = CosineEmbeddingLoss)
                # for matched and mismatched image-caption pairs; note that the
                # self-critical reward above is computed but not used here.
                distance_loss_reward1 = get_distance_reward(
                    model, model_E, criterion_E, fc_feats, data, logger,
                    is_mismatched=False)
                distance_loss_reward2 = get_distance_reward(
                    model, model_E, criterion_E, fc_feats, data, logger,
                    is_mismatched=True)
                reward = distance_loss_reward1 + distance_loss_reward2
                loss = rl_crit(sample_logprobs, gen_result,
                               Variable(torch.from_numpy(reward).float().cuda(),
                                        requires_grad=False))

            loss.backward()
            utils.clip_gradient(optimizer_G, opt.grad_clip)
            optimizer_G.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()

            if not sc_flag:
                log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start)
            else:
                log = "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:, 0]), end - start)
            logger.write(log)

        # ------------------------------- GAN TRAINING -------------------------------
        # Unreachable while the guard above is `if 1`; originally
        # `elif iteration % opt.D_scheduling == 0`.
        else:
            model_D.zero_grad()
            optimizer_D.zero_grad()
            fc_feats_temp = Variable(fc_feats.data.cpu(), volatile=True).cuda()
            labels = Variable(labels.data.cpu()).cuda()

            sample_res, sample_logprobs = model.sample(fc_feats_temp, {'sample_max': 0})  # 640, 16
            greedy_res, greedy_logprobs = model.sample(fc_feats_temp, {'sample_max': 1})  # 640, 16
            gt_res = labels  # 640, 18

            sample_res_embed = model.embed(Variable(sample_res))
            greedy_res_embed = model.embed(Variable(greedy_res))
            gt_res_embed = model.embed(gt_res)

            f_label = Variable(torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            r_label = Variable(torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            f_label.data.fill_(0)
            r_label.data.fill_(1)

            f_D_output = model_D(sample_res_embed.detach(), fc_feats.detach())
            f_loss = criterion_D(f_D_output, f_label.long())
            f_loss.backward()
            r_D_output = model_D(gt_res_embed.detach(), fc_feats.detach())
            r_loss = criterion_D(r_D_output, r_label.long())
            r_loss.backward()
            D_loss = f_loss + r_loss
            optimizer_D.step()
            torch.cuda.synchronize()
            log = 'iter {} (epoch {}), Discriminator loss : {}'.format(
                iteration, epoch, D_loss.data.cpu().numpy()[0])
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss,
                                             'lang_stats': lang_stats,
                                             'predictions': predictions}

            # Save model if it is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer_G.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts,
            gt_indices, sc_flag, struc_flag, drop_worst_flag):
    opt = self.opt
    out = {}
    reduction = 'none' if drop_worst_flag else 'mean'
    if struc_flag:
        if opt.structure_loss_weight < 1:
            lm_loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                                labels[:, 1:], masks[:, 1:], reduction=reduction)
        else:
            lm_loss = torch.tensor(0).type_as(fc_feats)
        gen_result, sample_logprobs = self.model(
            fc_feats, att_feats, att_masks,
            opt={'sample_max': 0,
                 'output_logsoftmax': opt.struc_use_logsoftmax
                     or opt.structure_loss_type == 'softmax_margin'
                     or 'margin' not in opt.structure_loss_type,
                 'sample_n': opt.structure_sample_n},
            mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        struc_loss = self.struc_crit(sample_logprobs, gen_result, gts,
                                     reduction=reduction)
        loss = (1 - opt.structure_loss_weight) * lm_loss \
               + opt.structure_loss_weight * struc_loss
        out['lm_loss'] = lm_loss
        out['struc_loss'] = struc_loss
    elif not sc_flag:
        loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks),
                         labels[:, 1:], masks[:, 1:], reduction=reduction)
    else:
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample')
            if self.retrieval_reward_weight > 0:
                _seqs_greedy, _sampleLogProbs_greedy = greedy_res, _
                _masks_greedy = torch.cat(
                    [_seqs_greedy.data.new(_seqs_greedy.size(0), 2).fill_(1).float(),
                     (_seqs_greedy > 0).float()[:, :-1]], 1)
                _seqs_greedy = torch.cat(
                    [_seqs_greedy.data.new(_seqs_greedy.size(0), 1)
                     .fill_(self.model.vocab_size + 1),
                     _seqs_greedy], 1)
                baseline = self.vse(fc_feats, att_feats, att_masks, _seqs_greedy,
                                    _masks_greedy, True, only_one_retrieval='off')
        self.model.train()
        gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                                                 opt={'sample_max': 0}, mode='sample')
        gts = [gts[_] for _ in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
        reward = torch.from_numpy(reward).float().to(gen_result.device)
        out['reward'] = reward[:, 0].mean()
        if self.retrieval_reward_weight > 0:
            _seqs, _sampleLogProbs = gen_result, sample_logprobs
            _masks = torch.cat(
                [_seqs.data.new(_seqs.size(0), 2).fill_(1).float(),
                 (_seqs > 0).float()[:, :-1]], 1)
            gen_masks = _masks
            _seqs = torch.cat(
                [_seqs.data.new(_seqs.size(0), 1).fill_(self.model.vocab_size + 1),
                 _seqs], 1)
            retrieval_loss = self.vse(fc_feats, att_feats, att_masks, _seqs,
                                      _masks, True, only_one_retrieval='off')
            # Fold the retrieval loss (relative to the greedy baseline) into the reward.
            reward -= self.retrieval_reward_weight * (retrieval_loss - baseline).unsqueeze(1)
            out['retrieval_loss'] = retrieval_loss.sum()
            out['retrieval_loss_greedy'] = baseline.sum()
            print(out['retrieval_loss'].item(), out['retrieval_loss_greedy'].item())
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward,
                            reduction=reduction)
    out['loss'] = loss
    return out
def train(opt):
    assert opt.annfile is not None and len(opt.annfile) > 0
    print('Checkpoint path is ' + opt.checkpoint_path)
    print('This program is using GPU ' + str(os.environ['CUDA_VISIBLE_DEVICES']))

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        if opt.load_best:
            info_path = os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(info_path) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    if opt.learning_rate_decay_start is None:
        opt.learning_rate_decay_start = infos.get('opt', None).learning_rate_decay_start
        # if opt.load_best:
        #     opt.self_critical_after = epoch
    elif opt.learning_rate_decay_start == -1 and opt.self_critical_after != -1 \
            and epoch >= opt.self_critical_after:
        opt.learning_rate_decay_start = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_ave_model = infos.get('best_val_score_ave_model', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion(opt.XE_eps)
    rl_crit = utils.RewardCriterion()

    # build_optimizer
    optimizer = build_optimizer(model, opt)
    # Load the optimizer
    if opt.load_opti and vars(opt).get('start_from', None) is not None \
            and opt.load_best == 0 \
            and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Initialize the running average of parameters.
    avg_param = deepcopy(list(p.data for p in model.parameters()))

    # Make evaluation using the original model.
    best_val_score, histories, infos = eva_original_model(
        best_val_score, crit, epoch, histories, infos, iteration, loader,
        loss_history, lr_history, model, opt, optimizer, ss_prob_history,
        tb_summary_writer, val_result_history)

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                if opt.lr_decay == 'exp':
                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                elif opt.lr_decay == 'cosine':
                    lr_epoch = min((epoch - opt.learning_rate_decay_start),
                                   opt.lr_max_epoch)
                    cosine_decay = 0.5 * (1 + math.cos(math.pi * lr_epoch / opt.lr_max_epoch))
                    decay_factor = (1 - opt.lr_cosine_decay_base) * cosine_decay \
                                   + opt.lr_cosine_decay_base
                    opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            lr = [opt.current_lr]
            if opt.att_normalize_method is not None and '6' in opt.att_normalize_method:
                lr = [opt.current_lr, opt.lr_ratio * opt.current_lr]
            utils.set_lr(optimizer, lr)
            print('learning rate is: ' + str(lr))
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            update_lr_flag = False

        # Update the iteration
        iteration += 1
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'],
               data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            output = dp_model(fc_feats, att_feats, labels, att_masks)
            # calculate loss
            loss = crit(output[0], labels[:, 1:], masks[:, 1:])
            # add some middle-variable histograms
            if iteration % (4 * opt.losses_log_every) == 0:
                outputs = [_.data.cpu().numpy() if _ is not None else None
                           for _ in output]
                variables_histogram(data, iteration, outputs, tb_summary_writer, opt)
        else:
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks,
                                                   opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_max_norm)
        # add_summary_value(tb_summary_writer, 'grad_L2_norm', grad_norm, iteration)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        # Compute the running average of parameters.
        for p, avg_p in zip(model.parameters(), avg_param):
            avg_p.mul_(opt.beta).add_((1.0 - opt.beta), p.data)

        if iteration % 10 == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if opt.tensorboard_weights_grads and (iteration % (8 * opt.losses_log_every) == 0):
            # Add weight and gradient histograms to the tensorboard summary.
            for name, param in model.named_parameters():
                if (opt.tensorboard_parameters_name is None
                        or sum([p_name in name
                                for p_name in opt.tensorboard_parameters_name]) > 0) \
                        and param.grad is not None:
                    tb_summary_writer.add_histogram(
                        'Weights_' + name.replace('.', '/'), param, iteration)
                    tb_summary_writer.add_histogram(
                        'Grads_' + name.replace('.', '/'), param.grad, iteration)

        if opt.tensorboard_buffers and (iteration % (opt.losses_log_every) == 0):
            for name, buffer in model.named_buffers():
                if (opt.tensorboard_buffers_name is None
                        or sum([p_name in name
                                for p_name in opt.tensorboard_buffers_name]) > 0) \
                        and buffer is not None:
                    add_summary_value(tb_summary_writer, name.replace('.', '/'),
                                      buffer, iteration)

        if opt.distance_sensitive_coefficient and iteration % (4 * opt.losses_log_every) == 0:
            print('The coefficient in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.coefficient.data.cpu().tolist())
            print('The coefficient in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.coefficient.data.cpu().tolist())

        if opt.distance_sensitive_bias and iteration % (4 * opt.losses_log_every) == 0:
            print('The bias in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.bias.data.cpu().tolist())
            print('The bias in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.bias.data.cpu().tolist())

        # Make evaluation using the original model.
        if (iteration % opt.save_checkpoint_every == 0):
            best_val_score, histories, infos = eva_original_model(
                best_val_score, crit, epoch, histories, infos, iteration, loader,
                loss_history, lr_history, model, opt, optimizer, ss_prob_history,
                tb_summary_writer, val_result_history)

        # Make evaluation with the averaged-parameters model.
        if iteration > opt.ave_threshold and (iteration % opt.save_checkpoint_every == 0):
            best_val_score_ave_model, infos = eva_ave_model(
                avg_param, best_val_score_ave_model, crit, infos, iteration,
                loader, model, opt, tb_summary_writer)

        # # Stop if reaching max epochs
        # if epoch >= opt.max_epochs and opt.max_epochs != -1:
        #     break
        if iteration >= opt.max_iter:
            break
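# The loop above maintains Polyak-averaged weights via
# avg_p.mul_(opt.beta).add_((1.0 - opt.beta), p.data), and eva_ave_model (not
# shown in this collection) evaluates with them. A hypothetical helper showing
# how such an evaluation can swap the averaged weights in and restore the live
# training weights afterwards:
import torch

def evaluate_with_averaged_params(model, avg_param, eval_fn):
    backup = [p.data.clone() for p in model.parameters()]
    with torch.no_grad():
        for p, avg_p in zip(model.parameters(), avg_param):
            p.data.copy_(avg_p)   # swap in the running-average weights
        result = eval_fn(model)
        for p, b in zip(model.parameters(), backup):
            p.data.copy_(b)       # restore the live training weights
    return result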
def train(train_loader, val_loader, model, crit, optimizer, lr_scheduler, opt,
          rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    # lowest val loss
    best_loss = None
    for epoch in range(opt.epochs):
        lr_scheduler.step()
        iteration = 0
        # If start self-critical training
        if opt.self_crit_after != -1 and epoch >= opt.self_crit_after:
            sc_flag = True
            init_cider_scorer(opt.cached_tokens)
        else:
            sc_flag = False

        for data in train_loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()

            if not sc_flag:
                seq_probs, predicts = model(fc_feats, labels)
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(fc_feats, vars(opt))
                reward = get_self_critical_reward(model, fc_feats, data, gen_result)
                loss = rl_crit(sample_logprobs, gen_result,
                               Variable(torch.from_numpy(reward).float().cuda()))

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.3f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        # Checkpoint and track the lowest validation loss.
        if epoch % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_%d.pth' % (epoch))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to %s" % (checkpoint_path))
            val_loss = val(val_loader, model, crit)
            print("Val loss is: %.6f" % (val_loss))
            model.train()
            if best_loss is None or val_loss < best_loss:
                print("(epoch %d), now lowest val loss is %.6f" % (epoch, val_loss))
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("best model saved to %s" % (checkpoint_path))
                best_loss = val_loss
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and \
            os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, labels),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats, att_feats, data,
                                              gen_result)
            loss = rl_crit(sample_logprobs, gen_result,
                           Variable(torch.from_numpy(reward).float().cuda(),
                                    requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss,
                                             'lang_stats': lang_stats,
                                             'predictions': predictions}

            # Save model if it is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def train(opt): # opt.use_att = utils.if_use_att(opt.caption_model) opt.use_att = True if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length print(opt.checkpoint_path) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) critic_loss_history = histories.get('critic_loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) variance_history = histories.get('variance_history', {}) time_history = histories.get('time_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = model target_actor = models.setup(opt).cuda() ####################### Critic pretrain ##################################################################### ##### Critic with state as input # if opt.critic_model == 'state_critic': # critic_model = CriticModel(opt) # else: critic_model = AttCriticModel(opt) target_critic = AttCriticModel(opt) if vars(opt).get('start_from_critic', None) is not None: # check if all necessary files exist assert os.path.isdir(opt.start_from_critic), "%s must be a path" % opt.start_from_critic print( os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')) critic_model.load_state_dict( torch.load( os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))) target_critic.load_state_dict( torch.load( os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))) critic_model = critic_model.cuda() target_critic = target_critic.cuda() critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt) dp_model.eval() critic_iter = 0 init_scorer(opt.cached_tokens) critic_model.train() error_sum = 0 loss_vector_sum = 0 while opt.pretrain_critic == 1: if critic_iter > opt.pretrain_critic_steps: print('****************Finished critic training!') break data = loader.get_batch('train') torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp critic_model.train() critic_optimizer.zero_grad() assert opt.critic_model == 'att_critic_vocab' # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data) crit_loss, reward, std = target_critic_loss_fun_mask( fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data, target_critic,
target_actor) crit_loss.backward() critic_optimizer.step() # Soft-update the target critic towards the online critic (Polyak averaging) for cp, tp in zip(critic_model.parameters(), target_critic.parameters()): tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data) crit_train_loss = crit_loss.item() torch.cuda.synchronize() end = time.time() error_sum += crit_train_loss**0.5 - std if (critic_iter % opt.losses_log_every == 0): print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \ .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start)) print(opt.checkpoint_path) opt.importance_sampling = 1 critic_model.eval() _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True) critic_iter += 1 # periodically save the critic model if (critic_iter % opt.save_checkpoint_every == 0): if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') torch.save(critic_model.state_dict(), checkpoint_path) ######################### Actor-critic Training ##################################################################### update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) first_order = 0 second_order = 0 while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) data = loader.get_batch('train') if data['bounds']['it_pos_now'] > 5000: loader.reset_iterator('train') continue dp_model.train() critic_model.eval() torch.cuda.synchronize() start = time.time() gen_result = None tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) else: if opt.rl_type == 'sc': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) elif opt.rl_type == 'reinforce':
gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_reward(data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) elif opt.rl_type == 'arsm': loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) #print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'rf4': loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) # print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'importance_sampling': opt.importance_sampling = 1 loss, gen_result, reward, sample_logprobs_total = get_rf_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1) std = np.std(reward) elif opt.rl_type == 'importance_sampling_critic': opt.importance_sampling = 1 loss, gen_result, reward, sample_logprobs_total = get_rf_loss( target_actor, fc_feats, att_feats, att_masks, data, opt, loader, target_critic) reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1) std = np.std(reward) elif opt.rl_type == 'ar': loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.zeros([2, 2]) elif opt.rl_type == 'mct_baseline': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit( sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline) elif opt.rl_type == 'arsm_baseline': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0] if opt.arm_step_sample == 'greedy' and False: sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline) elif opt.rl_type == 'ars_indicator': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline) elif opt.rl_type == 'arsm_baseline_critic': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss( dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model) reward, std = get_reward(data, gen_result, opt, critic=True) if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit( sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - arm_baseline) elif opt.rl_type == 'arsm_critic': #print(opt.critic_model) tic = time.time() loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model) #print('arm_loss time', str(time.time()-tic)) reward = np.zeros([2, 2]) elif opt.rl_type == 'critic_vocab_sum': assert opt.critic_model == 
'att_critic_vocab' tic = time.time() gen_result, sample_logprobs_total = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 0}, total_probs=True, mode='sample') #batch, seq, vocab #print('generation time', time.time()-tic) gen_result_pad = torch.cat([ gen_result.new_zeros( gen_result.size(0), 1, dtype=torch.long), gen_result ], 1) tic = time.time() critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks) #batch, seq, vocab #print('critic time', time.time() - tic) probs = torch.sum( F.softmax(sample_logprobs_total, 2) * critic_value.detach(), 2) mask = (gen_result > 0).float() mask = torch.cat( [mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1) loss = -torch.sum(probs * mask) / torch.sum(mask) reward = np.zeros([2, 2]) elif opt.rl_type == 'reinforce_critic': #TODO change the critic to attention if opt.critic_model == 'state_critic': critic_value, gen_result, sample_logprobs = critic_model( dp_model, fc_feats, att_feats, opt, att_masks) reward, std = get_reward(data, gen_result, opt, critic=True) loss = rl_crit( sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:, :-1].data) elif opt.critic_model == 'att_critic': gen_result, sample_logprobs = dp_model( fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') gen_result_pad = torch.cat([ gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result ], 1) critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2) reward, std = get_reward(data, gen_result, opt, critic=True) loss = rl_crit( sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value.data) if opt.mle_weights != 0: loss += opt.mle_weights * crit( dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) #TODO make sure all sampling replaced by greedy for critic #### update the actor loss.backward() # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f) # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f) ## compute variance gradient = torch.zeros([0]).cuda() for i in model.parameters(): gradient = torch.cat((gradient, i.grad.view(-1)), 0) first_order = 0.9999 * first_order + 0.0001 * gradient second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2) # print(torch.max(torch.abs(gradient))) variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item() if opt.rl_type != 'arsm' or not sc_flag: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() # ### update the critic if 'critic' in opt.rl_type: dp_model.eval() critic_model.train() utils.set_lr(critic_optimizer, opt.critic_learning_rate) critic_optimizer.zero_grad() assert opt.critic_model == 'att_critic_vocab' crit_loss, reward, std = target_critic_loss_fun_mask( fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data, target_critic, target_actor, gen_result=gen_result, sample_logprobs_total=sample_logprobs_total, reward=reward) crit_loss.backward() critic_optimizer.step() for cp, tp in zip(critic_model.parameters(), target_critic.parameters()): tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data) for cp, tp in zip(dp_model.parameters(), target_actor.parameters()): tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data) crit_train_loss = crit_loss.item() error_sum += crit_train_loss**0.5 - std 
train_loss = loss.item() torch.cuda.synchronize() end = time.time() if (iteration % opt.losses_log_every == 0): if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, end - start)) print(opt.checkpoint_path) elif 'critic' in opt.rl_type: print( "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, variance = {:g}, time/batch = {:.3f}" \ .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start)) print(opt.checkpoint_path) critic_model.eval() _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True) else: print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration) add_summary_value(tb_summary_writer, 'variance', variance, iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward) critic_loss_history[ iteration] = crit_train_loss if 'critic' in opt.rl_type else 0 lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob variance_history[iteration] = variance time_history[iteration] = end - start # Evaluate on the validation set and save the model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Save model if it is improving on the validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') torch.save(critic_model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history
histories['critic_loss_history'] = critic_loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history histories['variance_history'] = variance_history histories['time_history'] = time_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
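# A helper form (not this repo's code) of the Polyak updates written inline
# above as tp.data = tp.data + gamma * (cp.data - tp.data); tau stands in for
# opt.gamma_critic / opt.gamma_actor.
import torch

@torch.no_grad()
def soft_update(online, target, tau):
    # target <- (1 - tau) * target + tau * online, parameter-wise
    for op, tp in zip(online.parameters(), target.parameters()):
        tp.data.add_(tau * (op.data - tp.data))

# usage sketch: soft_update(critic_model, target_critic, opt.gamma_critic)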
def train(opt): # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') data_time = time.time() - start torch.cuda.synchronize() start = time.time() tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, 
data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() end = time.time() if iteration % opt.print_freq == 0: if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \ .format(iteration, epoch, train_loss, end - start, data_time)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), end - start, data_time)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): if tf is not None: add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration) tf_summary_writer.flush() loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # Evaluate on the validation set and save the model if (iteration % opt.save_checkpoint_every == 0): checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) # MODIFIED (ADDED) # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation result into summary if tf is not None: add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tf_summary_writer, k, v, iteration) tf_summary_writer.flush() val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if it is improving on the validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = - val_loss best_flag = False if True: # if true if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(infos, f) with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best-i{}-score{}.pth'.format(iteration, best_val_score))
torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
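# The stepped learning-rate decay that each train() above recomputes at epoch
# boundaries, written as a pure function for clarity; a sketch, not library
# code.
def decayed_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    # no decay before (or without) a configured start epoch
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * decay_rate ** frac

# e.g. decayed_lr(5e-4, epoch=7, decay_start=0, decay_every=3, decay_rate=0.8)
# returns 5e-4 * 0.8 ** 2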
def train(opt): # Load data loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length # Tensorboard summaries (they're great!) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) # Load pretrained model, info file, histories file infos = {} histories = {} if opt.start_from is not None: with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = ["rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) #ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) # Create model model = models.setup(opt).cuda() #pretrained_dict = torch.load(opt.model) #model.load_state_dict(pretrained_dict, strict=False) num_params = get_n_params(model) print('number of parameters:', num_params) dp_model = torch.nn.DataParallel(model) dp_model.train() # Loss function crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() # Optimizer and learning rate adjustment flag optimizer = utils.build_optimizer(model.parameters(), opt) update_lr_flag = True # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) # Training loop while True: # Update learning rate once per epoch if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every #opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) #model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) start = time.time() data = loader.get_batch('train') data_time = time.time() - start start = time.time() # Unpack data torch.cuda.synchronize() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['dist'], data['masks'], data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp batchsize = fc_feats.size(0) # Forward pass and loss optimizer.zero_grad() if not sc_flag:
wordact, reconstruct = dp_model(fc_feats, att_feats, labels) #loss_dist = F.binary_cross_entropy(dist, dist_label.cpu().float()) fc_feats_max, _ = att_feats.max(1) loss_rec = F.mse_loss(reconstruct.cpu(), fc_feats_max.cpu()) mask = masks[:, 1:].contiguous() wordact = wordact[:, :, :-1] wordact_t = wordact.permute(0, 2, 1).contiguous() wordact_t = wordact_t.view( wordact_t.size(0) * wordact_t.size(1), -1) labels = labels.contiguous().view(-1, 6 * 30).cpu() wordclass_v = labels[:, 1:] wordclass_t = wordclass_v.contiguous().view(\ wordclass_v.size(0) * wordclass_v.size(1), 1) maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1) loss_xe = F.cross_entropy(wordact_t[maskids, ...], \ wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])) loss = 5 * loss_xe + loss_rec else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) # Backward pass loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() # Print total_time = time.time() - start if iteration % opt.print_freq == 1: print('Read data:', time.time() - start) if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, train_loss, data_time, total_time)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr #ss_prob_history[iteration] = model.ss_prob # Validate and save model if (iteration >= 60000 and iteration % opt.save_checkpoint_every == 0): checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Evaluate model eval_kwargs = {'split': 'test', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k, v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } # Our metric is CIDEr if available, otherwise validation loss if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss # Save model in checkpoint path best_flag = False if best_val_score is None or 
current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history #histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) # Save model to a separate file if it is the new best if best_flag: model_fname = 'model-best.pth' infos_fname = 'model-best.pkl' checkpoint_path = os.path.join(opt.checkpoint_path, model_fname) torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
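# The "higher is better" validation score each checkpointing block above
# relies on: CIDEr when language metrics were computed, otherwise negated
# validation loss. A sketch with illustrative names.
def validation_score(lang_stats, val_loss, language_eval):
    return lang_stats['CIDEr'] if language_eval == 1 else -val_loss

def is_new_best(score, best_so_far):
    return best_so_far is None or score > best_so_far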
def eval_split(model, crit, loader, eval_kwargs={}): eval_att = eval_kwargs.get('eval_att',False) gt_grd_eval = eval_kwargs.get('gt_grd_eval',False) eval_scan = eval_kwargs.get('eval_scan',False) verbose = eval_kwargs.get('verbose', True) verbose_beam = eval_kwargs.get('verbose_beam', 1) verbose_loss = eval_kwargs.get('verbose_loss', 1) num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) split = eval_kwargs.get('split', 'val') lang_eval = eval_kwargs.get('language_eval', 0) dataset = eval_kwargs.get('dataset', 'coco') beam_size = eval_kwargs.get('beam_size', 1) remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration # Make sure in the evaluation mode model.eval() loader.reset_iterator(split) n = 0 loss = 0 loss_sum = 0 loss_evals = 1e-8 predictions = [] grd_output = defaultdict(list) while True: data = loader.get_batch(split) n = n + loader.batch_size if data.get('labels', None) is not None and verbose_loss: # forward the model to get loss tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'],data['box_feats']] tmp = [_.cuda() if _ is not None else _ for _ in tmp] fc_feats, att_feats, labels, masks, att_masks, box_feats = tmp with torch.no_grad(): loss = crit(model(fc_feats, att_feats, labels, att_masks)[0], labels[:,1:], masks[:,1:]).item() loss_sum = loss_sum + loss loss_evals = loss_evals + 1 if not gt_grd_eval: # forward the model to also get generated samples for each image # Only leave one feature for each image, in case duplicate sample tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img], data['box_feats'][np.arange(loader.batch_size) * loader.seq_per_img], data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None] tmp = [_.cuda() if _ is not None else _ for _ in tmp] fc_feats, att_feats, box_feats, att_masks = tmp else: tmp = [data['fc_feats'], data['att_feats'], data['box_feats'], data['att_masks'] if data['att_masks'] is not None else None] tmp = [_.cuda() if _ is not None else _ for _ in tmp] fc_feats, att_feats, box_feats, att_masks = tmp # forward the model to also get generated samples for each image with torch.no_grad(): if eval_att: if not gt_grd_eval: assert eval_kwargs['beam_size']==1, 'only support beam_size is 1' seq, _, att_weights = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample') seq=seq.detach() att_weights=att_weights.detach() att_ind = torch.max(att_weights, dim=2)[1] else: if not eval_scan: #==This snippet used for evaluating grounding accuracy of caption model on gt sentence.=====# _, att_weights=model(fc_feats, att_feats, labels, att_masks) seq = labels[:,1:] att_weights=att_weights.detach() att_ind = torch.max(att_weights, dim=2)[1] data['infos'] = data['infos']*5 else: # pdb.set_trace() #====This snippet used for evaluating grounding accuracy of SCAN model on gt sentence.======# gts = data['gts'] reward, att_weights, noun_mask= get_self_critical_reward(model, fc_feats, att_feats, att_masks, gts, labels[:,1:], eval_kwargs) seq = labels[:,1:] att_weights=att_weights.detach() att_ind = torch.max(att_weights, dim=2)[1] data['infos'] = data['infos']*5 for i in range(seq.size(0)): tmp_result = {'clss':[], 'idx_in_sent':[], 'bbox':[]} num_sent = 0 # does not really matter which reference to 
use for j in range(seq.size(1)): if seq[i,j].item() != 0: lemma = loader.wtol[loader.ix_to_word[str(seq[i,j].item())]] if lemma in loader.lemma_det_dict: tmp_result['bbox'].append(box_feats[i, att_ind[i, j], :4].tolist()) tmp_result['clss'].append(loader.itod[loader.lemma_det_dict[lemma]]) tmp_result['idx_in_sent'].append(j) # redundant, for the sake of output format else: break grd_output[str(data['infos'][i]['id'])].append(tmp_result) else: seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data # Print beam search if beam_size > 1 and verbose_beam: for i in range(loader.batch_size): print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) print('--' * 10) sents = utils.decode_sequence(loader.get_vocab(), seq) for k, sent in enumerate(sents): entry = {'image_id': data['infos'][k]['id'], 'caption': sent} if eval_kwargs.get('dump_path', 0) == 1: entry['file_name'] = data['infos'][k]['file_path'] predictions.append(entry) if eval_kwargs.get('dump_images', 0) == 1: # dump the raw image to vis/ folder cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross print(cmd) os.system(cmd) if verbose: print('image %s: %s' %(entry['image_id'], entry['caption'])) # if we wrapped around the split or used up val imgs budget then bail ix0 = data['bounds']['it_pos_now'] ix1 = data['bounds']['it_max'] if num_images != -1: ix1 = min(ix1, num_images) for i in range(n - ix1): predictions.pop() if verbose: print('evaluating validation performance... %d/%d (%f)' %(ix0 - 1, ix1, loss)) if data['bounds']['wrapped']: break if num_images >= 0 and n >= num_images: break lang_stats = None if lang_eval == 1: if not gt_grd_eval: lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split) if eval_att: # write attention results to file attn_file = 'att_results/attn-gen-sent-results-'+split+'-'+eval_kwargs['id']+'.json' with open(attn_file, 'w') as f: json.dump({'results':grd_output, 'eval_mode':'gen', 'external_data':{'used':True, 'details':'Object detector pre-trained on Visual Genome on object detection task.'}}, f) # offline eval evaluator = FlickrGrdEval(reference_file=eval_kwargs['reference'], submission_file=attn_file, split_file=eval_kwargs['split_file'], val_split=[split], iou_thresh=0.5) print('\nResults Summary (generated sent):') print('Printing attention accuracy on generated sentences...') if not gt_grd_eval: prec_all, recall_all, f1_all = evaluator.grd_eval(mode='all') prec_loc, recall_loc, f1_loc = evaluator.grd_eval(mode='loc') else: grd_accu = evaluator.gt_grd_eval() print('\n') # Switch back to training mode model.train() return loss_sum/loss_evals, predictions, lang_stats
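# The grounding step eval_split uses above, isolated: pick the most-attended
# box per generated word via an argmax over the box axis. Shapes are
# illustrative, not tied to any particular model.
import torch

def most_attended_box(att_weights):
    # att_weights: (batch, words, boxes) attention weights
    return att_weights.argmax(dim=2)  # (batch, words): box index per word

# usage sketch: att_ind = most_attended_box(torch.rand(2, 17, 36))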
def train(opt): opt.use_att = utils.if_use_att(opt.caption_model) loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size1", "rnn_size2", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) # loader.iterators = infos.get('iterators', loader.iterators) # loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt) model.cuda() update_lr_flag = True # Assure in training mode model.train() # model.set_mode('train') crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: model.train() if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor utils.set_lr(optimizer, opt.current_lr) # set the decayed rate else: opt.current_lr = opt.learning_rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_cider_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch('train+val') # print('Read data:', time.time() - start) torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['num_bbox'], data['labels'], data['masks'] ] tmp = [ Variable(torch.from_numpy(_).float(), requires_grad=False).cuda() for _ in tmp ] fc_feats, att_feats, num_bbox, labels, masks = tmp labels = labels.long() optimizer.zero_grad() if not sc_flag: loss = crit(model(fc_feats, att_feats, num_bbox, labels), labels[:, 1:], masks[:, 1:]) # loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]) else: gen_result, sample_logprobs = 
model.sample(fc_feats, att_feats, num_bbox, {'sample_max': 0}) reward = get_self_critical_reward(model, fc_feats, att_feats, num_bbox, data, gen_result) loss = rl_crit( sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.data[0] torch.cuda.synchronize() end = time.time() if not sc_flag: if (iteration % 100 == 0): print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f} lr={}" \ .format(iteration, epoch, train_loss, end - start, opt.current_lr )) else: if (iteration % 100 == 0): print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f} lr={}" \ .format(iteration, epoch, np.mean(reward[:,0]), end - start, opt.current_lr )) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): if tf is not None: add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration) tf_summary_writer.flush() loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # Evaluate on the validation set and save the model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = { 'split': 'val', 'dataset': opt.input_json, 'val_ref_path': opt.val_ref_path, 'raw_val_anno_path': opt.raw_val_anno_path } eval_kwargs.update(vars(opt)) # predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs) best_flag = False if True: # if true # if best_val_score is None or current_score > best_val_score: # best_val_score = current_score # best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
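# The Variable(torch.from_numpy(...)) wrapping used above is pre-0.4 PyTorch;
# on current PyTorch the equivalent batch preparation is just this (a sketch,
# mirroring the newer train() variants in this file).
import torch

def numpy_batch_to_cuda(arrays):
    # keep None placeholders (e.g. a missing att_masks) untouched
    return [None if a is None else torch.from_numpy(a).cuda() for a in arrays]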
def train(opt): opt.use_att = utils.if_use_att(opt.caption_model) from dataloader import DataLoader loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.vocab_ccg_size = loader.vocab_ccg_size opt.seq_length = loader.seq_length infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')): with open( os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) cnn_model = utils.build_cnn(opt) cnn_model.cuda() model = models.setup(opt) model.cuda() # model = DataParallel(model) if vars(opt).get('start_from', None) is not None: # check if all necessary files exist assert os.path.isdir(opt.start_from), "%s must be a path" % opt.start_from assert os.path.isfile(os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")), "infos.pkl file does not exist in path %s" % opt.start_from model.load_state_dict( torch.load(os.path.join(opt.start_from, 'model.pth'))) update_lr_flag = True model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda() # optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate) if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: print('finetune mode') cnn_optimizer = optim.Adam([\ {'params': module.parameters()} for module in cnn_model._modules.values()[5:]\ ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay) if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')): cnn_optimizer.load_state_dict( torch.load( os.path.join(opt.start_from, 'optimizer-cnn.pth'))) eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split( cnn_model, model, crit, loader, eval_kwargs, True) epoch_start = time.time() while True: if update_lr_flag: if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor utils.set_lr(optimizer, opt.current_lr) # set the
decayed rate else: opt.current_lr = opt.learning_rate if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob #model.module.ss_prob = opt.ss_prob if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True else: sc_flag = False # Update the training stage of cnn for p in cnn_model.parameters(): p.requires_grad = True # Fix the first few layers: for module in cnn_model._modules.values()[:5]: for p in module.parameters(): p.requires_grad = False cnn_model.train() update_lr_flag = False cnn_model.apply(utils.set_bn_fix) cnn_model.apply(utils.set_bn_eval) start = time.time() torch.cuda.synchronize() data = loader.get_batch('train') if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: multilabels = [ data['detection_infos'][i]['label'] for i in range(len(data['detection_infos'])) ] tmp = [ data['labels'], data['masks'], np.array(multilabels, dtype=np.int16) ] tmp = [ Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp ] labels, masks, multilabels = tmp images = data[ 'images'] # it cannot be turned into tensor since different sizes. _fc_feats_2048 = [] _fc_feats_81 = [] _att_feats = [] for i in range(loader.batch_size): x = Variable(torch.from_numpy(images[i]), requires_grad=False).cuda() x = x.unsqueeze(0) att_feats, fc_feats_81 = cnn_model(x) fc_feats_2048 = att_feats.mean(3).mean(2).squeeze() att_feats = F.adaptive_avg_pool2d(att_feats, [14, 14]).squeeze().permute( 1, 2, 0) #(0, 2, 3, 1) _fc_feats_2048.append(fc_feats_2048) _fc_feats_81.append(fc_feats_81) _att_feats.append(att_feats) _fc_feats_2048 = torch.stack(_fc_feats_2048) _fc_feats_81 = torch.stack(_fc_feats_81) _att_feats = torch.stack(_att_feats) att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \ _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \ _att_feats.size()[1:])) fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \ _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \ _fc_feats_2048.size()[1:])) fc_feats_81 = _fc_feats_81 # cnn_optimizer.zero_grad() else: tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'] ] tmp = [ Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp ] fc_feats, att_feats, labels, masks = tmp optimizer.zero_grad() if not sc_flag: loss1 = crit(model(fc_feats_2048, att_feats, labels), labels[:, 1:], masks[:, 1:]) loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double()) loss = 0.8 * loss1 + 0.2 * loss2.float() else: gen_result, sample_logprobs = model.sample(fc_feats_2048, att_feats, {'sample_max': 0}) reward = get_self_critical_reward(model, fc_feats_2048, att_feats, data, gen_result) loss1 = rl_crit( sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double()) loss3 = crit(model(fc_feats_2048, att_feats, labels), labels[:, 1:], masks[:, 1:]) loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.data[0] mle_loss = loss1.data[0] multilabel_loss = loss2.data[0] torch.cuda.synchronize() end = 
time.time() if not sc_flag and iteration % 2500 == 0: print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start)) if sc_flag and iteration % 2500 == 0: print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \ .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start)) iteration += 1 if (iteration % opt.losses_log_every == 0): loss_history[iteration] = train_loss if not sc_flag else np.mean( reward[:, 0]) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob if (iteration % opt.save_checkpoint_every == 0): eval_kwargs = { 'split': 'val', 'dataset': opt.input_json, 'verbose': True } eval_kwargs.update(vars(opt)) if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: val_loss, predictions, lang_stats = eval_utils.eval_split( cnn_model, model, crit, loader, eval_kwargs, True) else: val_loss, predictions, lang_stats = eval_utils.eval_split( cnn_model, model, crit, loader, eval_kwargs, False) val_result_history[iteration] = { 'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions } if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) cnn_checkpoint_path = os.path.join(opt.checkpoint_path, 'model-cnn.pth') torch.save(cnn_model.state_dict(), cnn_checkpoint_path) print("cnn model saved to {}".format(cnn_checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: cnn_optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer-cnn.pth') torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path) infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) cnn_checkpoint_path = os.path.join(opt.checkpoint_path, 'model-cnn-best.pth') torch.save(cnn_model.state_dict(), cnn_checkpoint_path) print("cnn model saved to {}".format(cnn_checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True print("epoch: " + str(epoch) + " duration: " + str(time.time() - epoch_start)) epoch_start = time.time() if epoch >= opt.max_epochs and opt.max_epochs != -1: break
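# A plausible reading (not the repo's utils code) of the CNN finetuning setup
# above: freeze the first few blocks and keep BatchNorm statistics fixed;
# utils.set_bn_fix / utils.set_bn_eval are assumed to behave like set_bn_eval
# below.
import torch.nn as nn

def freeze_first_blocks(cnn, n_blocks=5):
    for i, module in enumerate(cnn.children()):
        for p in module.parameters():
            p.requires_grad = i >= n_blocks

def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # keep running mean/var fixed while finetuning

# usage sketch: freeze_first_blocks(cnn_model); cnn_model.apply(set_bn_eval)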
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, sc_flag,box_inds): out = {} if not sc_flag: if self.opt.att_supervise: outputs, attn_weights=self.model(fc_feats, att_feats, labels, att_masks) loss1 = self.crit(outputs, labels[:,1:], masks[:,1:]) if self.opt.use_gt_box: box_inds = box_inds[:,1:] if self.opt.att_sup_crit == 'KL' or self.opt.att_sup_crit == 'ExtendNLL': sup_mask = (box_inds != 1e-8* torch.ones(box_inds.size(-1)).type_as(box_inds)).any(dim=-1).view(-1) else: sup_mask = (box_inds>=0).view(-1) else: _, grd_weights,noun_mask= get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, labels[:,1:].detach(), vars(self.opt)) sup_mask = (noun_mask==1).cuda().view(-1) attn_weights = torch.log(torch.clamp(attn_weights,min=self.min_value)).view(-1,attn_weights.size(-1))[sup_mask] if self.opt.use_gt_box: if self.opt.att_sup_crit == 'KL': # Todo grd_target = F.softmax(box_inds/0.5,dim=-1).view(-1, box_inds.size(-1))[sup_mask] loss2 = self.kl_crit(attn_weights, grd_target) elif self.opt.att_sup_crit == 'NLL': grd_target = box_inds.reshape(-1)[sup_mask].long() loss2 = self.nll(attn_weights,grd_target) elif self.opt.att_sup_crit == 'ExtendNLL': grd_target = box_inds.reshape(-1, box_inds.size(-1))[sup_mask] loss2 = self.extendnll(attn_weights, grd_target) else: if self.opt.att_sup_crit == 'KL': grd_target = torch.clamp(grd_weights[:,:17,:],min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask] loss2 = self.kl_crit(attn_weights, grd_target) elif self.opt.att_sup_crit == 'NLL': grd_target = torch.max(grd_weights[:,:17,:],dim=2)[1].view(-1)[sup_mask] loss2 = self.nll(attn_weights,grd_target) elif self.opt.att_sup_crit == 'ExtendNLL': # grd_target = torch.clamp(grd_weights[:,:17,:],min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask] # loss2 = self.extendnll(attn_weights, grd_target) raise NotImplementedError loss=loss1+self.opt.att_supervise_weight*loss2 else: outputs=self.model(fc_feats, att_feats, labels, att_masks)[0] loss = self.crit(outputs, labels[:,1:], masks[:,1:]) else: if self.opt.att_supervise: gen_result, sample_logprobs, attn_weights = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') else: gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') gts = [gts[_] for _ in gt_indices.tolist()] if self.opt.att_supervise: reward, grd_weights, noun_mask= get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt)) else: reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt)) reward = torch.from_numpy(reward).float().to(gen_result.device) if self.opt.att_supervise: loss1=self.rl_crit(sample_logprobs, gen_result.data, reward) sup_mask = (noun_mask==1).cuda().view(-1) attn_weights = torch.log(torch.clamp(attn_weights,min=self.min_value)).view(-1,attn_weights.size(-1))[sup_mask] if self.opt.att_sup_crit == 'KL': grd_target = torch.clamp(grd_weights,min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask] loss2 = self.kl_crit(attn_weights, grd_target) elif self.opt.att_sup_crit == 'NLL': grd_target = torch.max(grd_weights,dim=2)[1].view(-1)[sup_mask] loss2 = self.nll(attn_weights,grd_target) elif self.opt.att_sup_crit == 'ExtendNLL': # grd_target = torch.clamp(grd_weights,min=self.min_value).view(-1,grd_weights.size(-1))[sup_mask] # loss2 = self.extendnll(attn_weights, grd_target) raise NotImplementedError loss=loss1+self.opt.att_supervise_weight*loss2 
else: loss = self.rl_crit(sample_logprobs, gen_result.data, reward) out['reward'] = reward[:,0].mean() out['loss'] = loss return out
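# Every self-critical branch in these snippets ends in
# rl_crit(sample_logprobs, gen_result.data, reward). A minimal sketch of the
# usual RewardCriterion behind that call (assumed here; the project's own
# utils.RewardCriterion is not shown):
import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    def forward(self, logprobs, seq, reward):
        # logprobs: (B, T) log-probs of the sampled tokens
        # seq:      (B, T) sampled token ids, 0 once the caption has ended
        # reward:   (B, T) per-token advantage (or broadcastable to it)
        mask = (seq > 0).float()
        # always count the first step; afterwards count a step only if the
        # previous token was still inside the caption
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        # REINFORCE with baseline: reward-weighted NLL, averaged over real tokens
        output = -logprobs * reward * mask
        return output.sum() / mask.sum()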
def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, sc_flag, box_inds, epoch, sents_mask): out = {} if not sc_flag: if self.opt.cexe and epoch >= self.opt.cexe_after: if self.opt.sup_nde: outputs, outputs_adjust, outputs_nde = self.model(fc_feats, att_feats, labels, att_masks, sents_mask[:,1:]) else: outputs, outputs_adjust = self.model(fc_feats, att_feats, labels, att_masks, sents_mask[:,1:]) # For now, we only consider visual words. adjust_mask = sents_mask[:,1:] == 1 adjust_mask_expand = adjust_mask.unsqueeze(dim=2).expand(outputs.shape) masked_outputs = torch.masked_select(outputs,adjust_mask_expand).view(-1,outputs.shape[2]) masked_outputs_adjust = torch.masked_select(outputs_adjust,adjust_mask_expand).view(-1,outputs.shape[2]) # masked_outputs_nde = torch.masked_select(outputs_nde,adjust_mask_expand).view(-1,outputs.shape[2]) loss1 = self.crit(outputs, labels[:,1:], masks[:,1:]) if self.opt.sup_tie and self.opt.sup_nde: loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean') loss3 = self.nll(masked_outputs_adjust, torch.masked_select(labels[:,1:], adjust_mask)) loss4 = self.crit(outputs_nde, labels[:,1:], masks[:,1:]) loss = loss1 + self.opt.cexe_weight * loss2 + self.opt.tie_weight * loss3 + self.opt.nde_weight * loss4 elif self.opt.sup_tie: loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean') loss3 = self.nll(masked_outputs_adjust, torch.masked_select(labels[:,1:], adjust_mask)) loss = loss1 + self.opt.cexe_weight * loss2 + self.opt.tie_weight * loss3 elif self.opt.sup_nde: loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean') loss4 = self.crit(outputs_nde, labels[:,1:], masks[:,1:]) loss = loss1 + self.opt.cexe_weight * loss2 + self.opt.nde_weight * loss4 else: loss2 = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='batchmean') loss = loss1 + self.opt.cexe_weight * loss2 else: outputs = self.model(fc_feats, att_feats, labels, att_masks)[0] loss = self.crit(outputs, labels[:,1:], masks[:,1:]) else: if self.opt.cec: gen_result, sample_logprobs, outputs, outputs_tie = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') else: gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') gts = [gts[_] for _ in gt_indices.tolist()] reward = get_self_critical_reward(self.model, fc_feats, att_feats, att_masks, gts, gen_result, vars(self.opt)) reward = torch.from_numpy(reward).float().to(gen_result.device) if self.opt.cec: loss1 = self.rl_crit(sample_logprobs, gen_result.data, reward) sents_mask = make_sents_mask(gen_result, self.opt.vocab) adjust_mask = sents_mask == 1 adjust_mask_expand = adjust_mask.unsqueeze(dim=2).expand(outputs.shape) masked_outputs = torch.masked_select(outputs,adjust_mask_expand).view(-1,outputs.shape[2]) masked_outputs_adjust = torch.masked_select(outputs_tie,adjust_mask_expand).view(-1,outputs.shape[2]) batch_div = F.kl_div(masked_outputs, masked_outputs_adjust.detach(), log_target=True, reduction='none').sum(dim=1) masked_reward = torch.masked_select(reward, adjust_mask) masked_reward_positive = (masked_reward>0).float() loss2 = (batch_div * masked_reward * masked_reward_positive).mean() loss = loss1 + self.opt.cec_weight * loss2 else: loss = self.rl_crit(sample_logprobs, gen_result.data, reward) out['reward'] = reward[:,0].mean() out['loss'] = loss return out
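# get_self_critical_reward above scores sampled captions against a greedy
# baseline. A condensed sketch of the usual logic, with cider_score() and
# seq_to_words() as hypothetical helpers standing in for the project's CIDEr-D
# scorer and sequence decoding (the real signatures vary between the snippets
# in this file):
import numpy as np
import torch

def get_self_critical_reward(model, fc_feats, att_feats, att_masks, gts, gen_result, opt):
    batch_size = gen_result.size(0)
    # greedy rollout as the baseline, decoded without gradients
    model.eval()
    with torch.no_grad():
        greedy_res, _ = model(fc_feats, att_feats, att_masks, mode='sample')
    model.train()
    sample_scores = np.array([cider_score(gts[i], seq_to_words(gen_result[i])) for i in range(batch_size)])
    greedy_scores = np.array([cider_score(gts[i], seq_to_words(greedy_res[i])) for i in range(batch_size)])
    # advantage = sampled CIDEr minus greedy CIDEr, broadcast over time steps
    advantage = sample_scores - greedy_scores
    return np.repeat(advantage[:, np.newaxis], gen_result.shape[1], 1)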
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None): model.train() if opt['visdom']: viz = visdom.Visdom(env='train') loss_win = viz.line(np.arange(1), opts={'title': 'loss'}) for epoch in range(opt["epochs"]): lr_scheduler.step() iteration = 0 # If start self crit training if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]: # save every so many epochs sc_flag = True init_cider_scorer(opt["cached_tokens"]) else: sc_flag = False for data in loader: torch.cuda.synchronize() fc_feats = data['fc_feats'].cuda() # voice_feats = data['voice_feats'].cuda() if opt['with_hand'] == 1: hand_feats = data['hand_feats'].cuda() hand_pro = data['hand_pro'].cuda() labels = data['labels'].cuda() masks = data['masks'].cuda() optimizer.zero_grad() if not sc_flag: # seq_probs, _ = model(fc_feats, voice_feats, hand_feats, labels, 'train') if opt['with_hand'] == 1: seq_probs, _ = model(fc_feats, hand_feats, hand_pro, labels, 'train') else: seq_probs, _ = model.forward2(fc_feats, labels, 'train') loss = crit(seq_probs, labels[:, 1:], masks[:, 1:]) # TODO: the else branch below has not been updated to handle the voice and sign-language inputs else: seq_probs, seq_preds = model(fc_feats, mode='inference', opt=opt) reward = get_self_critical_reward(model, fc_feats, data, seq_preds) print(reward.shape) loss = rl_crit(seq_probs, seq_preds, torch.from_numpy(reward).float().cuda()) loss.backward() clip_grad_value_(model.parameters(), opt['grad_clip']) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() iteration += 1 if not sc_flag: print("iter %d (epoch %d), train_loss = %.6f" % (iteration, epoch, train_loss)) if opt['visdom']: viz.line(Y=np.array([train_loss]), X=np.array([epoch]), win=loss_win, update='append') else: print("iter %d (epoch %d), avg_reward = %.6f" % (iteration, epoch, np.mean(reward[:, 0]))) if epoch % opt["save_checkpoint_every"] == 0: model_path = os.path.join(opt["checkpoint_path"], 'model_%d.pth' % (epoch)) model_info_path = os.path.join(opt["checkpoint_path"], 'model_score.txt') torch.save(model.state_dict(), model_path) # print("model saved to %s" % (model_path)) with open(model_info_path, 'a') as f: f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
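# crit(seq_probs, labels[:, 1:], masks[:, 1:]) in the snippets above is the
# standard masked cross-entropy over caption tokens. A minimal sketch of such
# a LanguageModelCriterion (assumed; each project ships its own variant):
import torch.nn as nn

class LanguageModelCriterion(nn.Module):
    def forward(self, logprobs, target, mask):
        # logprobs: (B, T, V) log-softmax outputs; target, mask: (B, T)
        target = target[:, :logprobs.size(1)]
        mask = mask[:, :logprobs.size(1)].float()
        # negative log-likelihood of the gold token at every unmasked step
        nll = -logprobs.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
        return nll.sum() / mask.sum()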
def __init__(self, opt): super(AttModel, self).__init__() self.image_crop_size = opt.image_crop_size self.vocab_size = opt.vocab_size self.detect_size = opt.detect_size self.input_encoding_size = opt.input_encoding_size #self.rnn_type = opt.rnn_type self.rnn_size = opt.rnn_size self.num_layers = opt.num_layers self.drop_prob_lm = opt.drop_prob_lm self.seq_length = opt.seq_length self.fc_feat_size = opt.fc_feat_size self.att_feat_size = opt.att_feat_size self.att_hid_size = opt.att_hid_size self.finetune_cnn = opt.finetune_cnn self.cbs = opt.cbs self.cbs_mode = opt.cbs_mode self.seq_per_img = 5 if opt.cnn_backend == 'vgg16': self.stride = 16 else: self.stride = 32 self.att_size = int(opt.image_crop_size / self.stride) self.tiny_value = 1e-8 self.pool_feat_size = self.att_feat_size + 300 * 2 self.ss_prob = 0.0 # Schedule sampling probability self.min_value = -1e8 opt.beta = 1 self.beta = opt.beta if opt.cnn_backend == 'res101': self.cnn = resnet(opt, _num_layers=101, _fixed_block=opt.fixed_block, pretrained=True) elif opt.cnn_backend == 'res152': self.cnn = resnet(opt, _num_layers=152, _fixed_block=opt.fixed_block, pretrained=True) elif opt.cnn_backend == 'vgg16': self.cnn = vgg16(opt, pretrained=True) self.det_fc = nn.Sequential(nn.Embedding(self.detect_size + 1, 300), nn.ReLU(), nn.Dropout()) self.loc_fc = nn.Sequential(nn.Linear(5, 300), nn.ReLU(), nn.Dropout()) self.embed = nn.Sequential( nn.Embedding(self.vocab_size + self.detect_size + 1, self.input_encoding_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm)) self.fc_embed = nn.Sequential( nn.Linear(self.fc_feat_size, self.rnn_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm)) self.att_embed = nn.Sequential( nn.Linear(self.att_feat_size, self.rnn_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm)) self.pool_embed = nn.Sequential( nn.Linear(self.pool_feat_size, self.rnn_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm)) self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) self.ctx2pool = nn.Linear(self.rnn_size, self.att_hid_size) self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) self.roi_align = RoIAlignAvg(1, 1, 1.0 / self.stride) #self.grid_size = 1 #self.roi_crop = _RoICrop() self.critLM = utils.LMCriterion(opt) self.critBN = utils.BNCriterion(opt) self.critFG = utils.FGCriterion(opt) if opt.self_critical: print("load reward function...") self.get_self_critical_reward = get_self_critical_reward(opt) self.critRL = utils.RewardCriterion(opt) # initialize the glove weight for the labels. self.det_fc[0].weight.data.copy_(opt.glove_clss) for p in self.det_fc[0].parameters(): p.requires_grad = False
def train(dataset, loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None): writer = SummaryWriter('./runs/video_caption_basic') model.load_state_dict( torch.load('/home/diml/video-caption.pytorch/save/new_model_200.pth')) #model = nn.DataParallel(model) model.train() vocab = dataset.get_vocab() for epoch in trange(300): t_loss = 0 # ============================================================================= # model.eval() # ev.demov(model,crit, dataset, dataset.get_vocab(),opt) # ============================================================================= lr_scheduler.step() iteration = 0 # If start self crit training if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]: sc_flag = True init_cider_scorer(opt["cached_tokens"]) else: sc_flag = False for idx, data in enumerate(loader): torch.cuda.synchronize() fc_feats = data['fc_feats'].cuda() labels = data['labels'].cuda() masks = data['masks'].cuda() optimizer.zero_grad() if not sc_flag: seq_probs, seq_preds, hn, de_hn = model( fc_feats, labels, 'train') loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:]) loss = loss_C else: seq_probs, seq_preds = model(fc_feats, mode='inference', opt=opt) reward = get_self_critical_reward(model, fc_feats, data, seq_preds) print(reward.shape) loss = rl_crit(seq_probs, seq_preds, torch.from_numpy(reward).float().cuda()) t_loss += loss.item() loss.backward() clip_grad_value_(model.parameters(), opt['grad_clip']) optimizer.step() train_loss = loss.item() torch.cuda.synchronize() iteration += 1 if not sc_flag: print("iter %d (epoch %d), train_loss = %.6f" % (iteration, epoch, train_loss)) else: print("iter %d (epoch %d), avg_reward = %.6f" % (iteration, epoch + 201, np.mean(reward[:, 0]))) writer.add_scalar('training total loss', t_loss / 140, epoch + 200) if epoch % opt["save_checkpoint_every"] == 0: model_path = os.path.join(opt["checkpoint_path"], 'new_model_%d.pth' % (epoch + 200)) model_info_path = os.path.join(opt["checkpoint_path"], 'Rnew_model_score.txt') torch.save(model.state_dict(), model_path) print("model saved to %s" % (model_path)) with open(model_info_path, 'a') as f: f.write("model_%d, loss: %.6f\n" % (epoch, train_loss)) with torch.no_grad(): _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt) print(utils.decode_sequence(vocab, seq_preds)[0])
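# utils.decode_sequence(vocab, seq_preds) above maps sampled token ids back to
# strings for inspection. A typical sketch (assuming, as in these loaders,
# that index 0 marks padding / end of sentence and vocab maps id -> word):
def decode_sequence(vocab, seq):
    sents = []
    for row in seq:
        words = []
        for ix in row:
            ix = int(ix)
            if ix == 0:
                break
            words.append(vocab[str(ix)] if isinstance(vocab, dict) else vocab[ix])
        sents.append(' '.join(words))
    return sents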
def train(self, data, loader, iteration, epoch, nmt_epoch): nmt_dec_state = None nmt_dec_state_zh = None torch.cuda.synchronize() self.optim.zero_grad() tmp = [ data['fc_feats'], data['attri_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['nmt'] if self.nmt_train_flag else None ] tmp = [ _ if _ is None else (Variable(torch.from_numpy(_), requires_grad=False).cuda() if utils.under_0_4() else torch.from_numpy(_).cuda()) for _ in tmp ] fc_feats, attri_feats, att_feats, labels, masks, att_masks, nmt_batch = tmp if self.i2t_train_flag: if self.update_i2t_lr_flag: self.optim.update_LearningRate( 'i2t', epoch) # Assign the learning rate self.optim.update_ScheduledSampling_prob( self.opt, epoch, self.dp_i2t_model) # Assign the scheduled sampling prob if self.opt.self_critical_after != -1 and epoch >= self.opt.self_critical_after: # If start self critical training self.sc_flag = True init_scorer(self.opt.cached_tokens) else: self.sc_flag = False self.update_i2t_lr_flag = False if not self.sc_flag: i2t_outputs = self.dp_i2t_model(fc_feats, attri_feats, att_feats, labels, att_masks) i2t_loss = self.i2t_crit(i2t_outputs, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = self.dp_i2t_model( fc_feats, attri_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(self.dp_i2t_model, fc_feats, attri_feats, att_feats, att_masks, data, gen_result, self.opt) i2t_loss = self.i2t_rl_crit( sample_logprobs, gen_result.data, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) self.i2t_avg_reward = np.mean(reward[:, 0]) self.i2t_train_loss = i2t_loss.data[0] if utils.under_0_4( ) else i2t_loss.item() i2t_loss.backward(retain_graph=True) if self.nmt_train_flag: if self.update_nmt_lr_flag: self.optim.update_LearningRate( 'nmt', nmt_epoch) # Assign the learning rate outputs, attn, dec_state, upper_bounds = self.dp_nmt_model( nmt_batch.src, nmt_batch.tgt, nmt_batch.lengths, nmt_dec_state) nmt_loss = self.nmt_crit(loader, nmt_batch, outputs, attn) if nmt_dec_state is not None: nmt_dec_state.detach() if nmt_dec_state_zh is not None: nmt_dec_state_zh.detach() self.nmt_crit.report_stats.n_src_words += nmt_batch.lengths.data.sum( ) self.nmt_train_ppl = self.nmt_crit.report_stats.ppl() self.nmt_train_acc = self.nmt_crit.report_stats.accuracy() # Minimize the word embedding weights # wemb_weight_loss = self.weight_trans(self.i2t_model.embed, self.nmt_encoder.embeddings.word_lut) # self.wemb_loss = wemb_weight_loss.data[0] nmt_loss.backward(retain_graph=True) # if self.nmt_train_flag: wemb_weight_loss.backward(retain_graph=True) self.optim.step()
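# Both i2t_loss.backward(retain_graph=True) and nmt_loss.backward(retain_graph=True)
# above keep the autograd graph alive because a later backward pass still needs
# it (e.g. the commented-out word-embedding weight loss shares parameters with
# both models). Toy illustration of why the flag matters:
import torch

x = torch.ones(2, requires_grad=True)
y = (3 * x).sum()
y.backward(retain_graph=True)  # graph preserved for another pass
y.backward()                   # second backward accumulates into x.grad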
def train(opt): # opt.use_att = utils.if_use_att(opt.caption_model) opt.use_att = True if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length print(opt.checkpoint_path) tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) critic_loss_history = histories.get('critic_loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) variance_history = histories.get('variance_history', {}) time_history = histories.get('time_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt).cuda() dp_model = model ######################### Actor-critic Training ##################################################################### update_lr_flag = True # Assure in training mode dp_model.train() #TODO: change this to a flag crit = utils.LanguageModelCriterion_binary() rl_crit = utils.RewardCriterion_binary() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) first_order = 0 second_order = 0 while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor else: opt.current_lr = opt.learning_rate utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False # Load data from train split (0) data = loader.get_batch('train') if data['bounds']['it_pos_now'] > 10000: loader.reset_iterator('train') continue dp_model.train() torch.cuda.synchronize() start = time.time() gen_result = None tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] tmp = [_ if _ is None else 
torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:], dp_model.depth, dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size) else: if opt.rl_type == 'sc': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth) elif opt.rl_type == 'reinforce': gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') reward = get_reward(data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth) elif opt.rl_type == 'arm': loss = dp_model.get_arm_loss_binary_fast(fc_feats, att_feats, att_masks, opt, data, loader) #print(loss) reward = np.zeros([2,2]) elif opt.rl_type == 'rf4': loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) # print(loss) reward = np.zeros([2, 2]) elif opt.rl_type == 'ar': loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = np.zeros([2,2]) elif opt.rl_type =='mct_baseline': opt.rf_demean = 0 gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0] if opt.arm_step_sample == 'greedy': sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline) elif opt.rl_type == 'arsm_baseline': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_reward(data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0] if opt.arm_step_sample == 'greedy' and False: sample_logprobs = sample_logprobs * probs loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline) elif opt.rl_type == 'ars_indicator': opt.arm_as_baseline = 1 opt.rf_demean = 0 gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader) reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) reward_cuda = torch.from_numpy(reward).float().cuda() loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline) if opt.mle_weights != 0: loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) #TODO make sure all sampling replaced by greedy for critic #### update the actor loss.backward() # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f: # cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f) ## compute variance gradient = torch.zeros([0]).cuda() for i in model.parameters(): gradient = torch.cat((gradient, i.grad.view(-1)), 0) first_order = 0.999 * first_order + 0.001 * gradient second_order = 0.999 * second_order + 0.001 * gradient.pow(2) # print(torch.max(torch.abs(gradient))) variance = 
torch.mean(torch.abs(second_order - first_order.pow(2))).item() if opt.rl_type != 'arsm' or not sc_flag: utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() # update the critic train_loss = loss.item() torch.cuda.synchronize() end = time.time() if (iteration % opt.losses_log_every == 0): if not sc_flag: print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, train_loss, end - start)) print(opt.checkpoint_path) else: print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}".format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration) add_summary_value(tb_summary_writer, 'variance', variance, iteration) loss_history[iteration] = train_loss if not sc_flag else np.mean(reward) lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob variance_history[iteration] = variance time_history[iteration] = end - start # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils_binary.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) if lang_stats is not None: for k,v in lang_stats.items(): add_summary_value(tb_summary_writer, k, v, iteration) val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True if not os.path.isdir(opt.checkpoint_path): os.mkdir(opt.checkpoint_path) checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) critic_checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth') optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['critic_loss_history'] = critic_loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history histories['variance_history'] = variance_history histories['time'] = time_history with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(infos, f)
with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break
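# first_order and second_order above are exponential moving averages of the
# flattened gradient g, i.e. running estimates of E[g] and E[g^2]; the logged
# "variance" is mean(|E[g^2] - E[g]^2|), a cheap proxy for per-parameter
# gradient variance of the REINFORCE-style estimators. The same logic as a
# self-contained helper:
import torch

def grad_variance_ema(parameters, first_order, second_order, decay=0.999):
    # concatenate all parameter gradients into one flat vector
    g = torch.cat([p.grad.view(-1) for p in parameters])
    first_order = decay * first_order + (1 - decay) * g
    second_order = decay * second_order + (1 - decay) * g.pow(2)
    variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
    return first_order, second_order, variance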
def train(opt): logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log')) print = logger.info if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length # Print out the option variables print("*" * 20) for k, v in opt.__dict__.items(): print("%r: %r" % (k, v)) print("*" * 20) infos = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f: infos = json.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) else: best_val_score = None model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) update_lr_flag = True # Assure in training mode dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile( os.path.join(opt.start_from, "optimizer.pth")): optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) start_time = time.time() while True: if update_lr_flag: # Assign the learning rate if 0 <= opt.learning_rate_decay_start < epoch: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor utils.set_lr(optimizer, opt.current_lr) # set the decayed rate else: opt.current_lr = opt.learning_rate # Assign the scheduled sampling prob if 0 <= opt.scheduled_sampling_start < epoch: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_scorer() else: sc_flag = False update_lr_flag = False # Load data from train split (0) batch_data = loader.get_batch('train') torch.cuda.synchronize() tmp = [ batch_data['fc_feats'], batch_data['att_feats'], batch_data['labels'], batch_data['masks'], batch_data['att_masks'] ] tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: outputs = dp_model(fc_feats, att_feats, labels, att_masks) loss = crit(outputs, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample') reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, batch_data, gen_result, opt) loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.data torch.cuda.synchronize() # Update the iteration and epoch iteration += 1 if batch_data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Print train loss or avg reward if iteration % opt.losses_print_every == 0: if not sc_flag: print( "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format( iteration, epoch, loss.item(), time.time() - start_time)) else: print("iter {} (epoch {}), avg_reward = {:.3f}, 
time = {:.3f}". format(iteration, epoch, np.mean(reward[:, 0]), time.time() - start_time)) start_time = time.time() # make evaluation on validation set, and save model if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\ or (opt.save_checkpoint_every <= 0 and update_lr_flag): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.simple_eval_split( dp_model, loader, eval_kwargs) # Save model if is improving on validation result if not os.path.exists(opt.checkpoint_path): os.makedirs(opt.checkpoint_path) if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score infos['opt'] = vars(opt) infos['vocab'] = loader.get_vocab() with open(os.path.join(opt.checkpoint_path, 'infos.json'), 'w') as f: json.dump(infos, f, sort_keys=True, indent=4) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open(os.path.join(opt.checkpoint_path, 'infos-best.json'), 'w') as f: json.dump(infos, f, sort_keys=True, indent=4) # Stop if reaching max epochs if opt.max_epochs != -1 and epoch >= opt.max_epochs: break
def train(opt): opt.use_att = utils.if_use_att(opt.caption_model) loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path) infos = {} histories = {} if opt.start_from is not None: # open old infos and check if models are compatible with open(os.path.join(opt.start_from, 'infos.pkl')) as f: infos = cPickle.load(f) saved_model_opt = infos['opt'] need_be_same = [ "caption_model", "rnn_type", "rnn_size", "num_layers" ] for checkme in need_be_same: assert vars(saved_model_opt)[checkme] == vars( opt )[checkme], "Command line argument and saved model disagree on '%s' " % checkme if os.path.isfile(os.path.join(opt.start_from, 'histories.pkl')): with open(os.path.join(opt.start_from, 'histories.pkl')) as f: histories = cPickle.load(f) iteration = infos.get('iter', 0) epoch = infos.get('epoch', 0) val_result_history = histories.get('val_result_history', {}) loss_history = histories.get('loss_history', {}) lr_history = histories.get('lr_history', {}) ss_prob_history = histories.get('ss_prob_history', {}) loader.iterators = infos.get('iterators', loader.iterators) loader.split_ix = infos.get('split_ix', loader.split_ix) if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) model = models.setup(opt) model.cuda() update_lr_flag = True # Assure in training mode model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) # Load the optimizer if vars(opt).get('start_from', None) is not None: optimizer.load_state_dict( torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: if update_lr_flag: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start ) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate**frac opt.current_lr = opt.learning_rate * decay_factor utils.set_lr(optimizer, opt.current_lr) # set the decayed rate else: opt.current_lr = opt.learning_rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start ) // opt.scheduled_sampling_increase_every opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) model.ss_prob = opt.ss_prob # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True init_cider_scorer(opt.cached_tokens) else: sc_flag = False update_lr_flag = False start = time.time() # Load data from train split (0) data = loader.get_batch('train') torch.cuda.synchronize() start = time.time() tmp = [ data['fc_feats'], data['att_feats'], data['labels'], data['masks'] ] tmp = [ Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp ] fc_feats, att_feats, labels, masks = tmp fc_feats = fc_feats.unsqueeze(1).expand(*(( fc_feats.size(0), opt.seq_per_img, ) + fc_feats.size()[1:])).contiguous().view( *((fc_feats.size(0) * opt.seq_per_img, ) + fc_feats.size()[1:])) att_feats = att_feats.unsqueeze(1).expand(*(( att_feats.size(0), opt.seq_per_img, ) + att_feats.size()[1:])).contiguous().view( *((att_feats.size(0) * opt.seq_per_img, ) + att_feats.size()[1:])) optimizer.zero_grad() outputs = model(fc_feats, att_feats, labels) if opt.caption_model == 'stack_cap': loss_coarse = crit(outputs[0], labels[:, 
1:], masks[:, 1:]) loss_fine_0 = crit(outputs[1], labels[:, 1:], masks[:, 1:]) loss_fine_1 = crit(outputs[-1], labels[:, 1:], masks[:, 1:]) loss = loss_fine_1 + loss_coarse + loss_fine_0 else: if not sc_flag: loss = crit(outputs, labels[:, 1:], masks[:, 1:]) else: gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 0}) reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result) loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() train_loss = loss.data[0] torch.cuda.synchronize() end = time.time() if opt.caption_model == 'stack_cap': print("{}|I:{}/E:{}|Tloss_0:{:.3f}/Tloss_1:{:.3f}/Tloss_2:{:.3f}|T/B={:.3f}".format(opt.caption_model, iteration, epoch, loss_coarse.data[0], loss_fine_0.data[0], loss_fine_1.data[0], end - start)) else: if not sc_flag: print("{}|I:{}/E:{}|Train_loss:{:.3f}|T/B={:.3f}".format(opt.caption_model, iteration, epoch, loss.data[0], end - start)) else: print("{}|I:{}/E:{}|Avg_reward:{:.3f}|T/B={:.3f}".format(opt.caption_model, iteration, epoch, np.mean(reward[:, 0]), end - start)) # Update the iteration and epoch iteration += 1 if data['bounds']['wrapped']: epoch += 1 update_lr_flag = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): if tf is not None: if opt.caption_model == 'stack_cap': add_summary_value(tf_summary_writer, 'train_loss_coarse', loss_coarse.data[0], iteration) add_summary_value(tf_summary_writer, 'train_loss_fine_0', loss_fine_0.data[0], iteration) add_summary_value(tf_summary_writer, 'train_loss_fine_1', loss_fine_1.data[0], iteration) else: add_summary_value(tf_summary_writer, 'train_loss', loss.data[0], iteration) add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) tf_summary_writer.flush() loss_history[iteration] = train_loss lr_history[iteration] = opt.current_lr ss_prob_history[iteration] = model.ss_prob # make evaluation on validation set, and save model if (iteration % opt.save_checkpoint_every == 0): # eval model eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split(opt, model, crit, loader, eval_kwargs) # Write validation result into summary if tf is not None: add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration) for k, v in lang_stats.items(): add_summary_value(tf_summary_writer, k, v, iteration) tf_summary_writer.flush() val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} # Save model if is improving on validation result if opt.language_eval == 1: current_score = lang_stats['CIDEr'] else: current_score = -val_loss best_flag = False if True: if best_val_score is None or current_score > best_val_score: best_val_score = current_score best_flag = True checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') torch.save(optimizer.state_dict(), optimizer_path) # Dump miscellaneous information infos['iter'] = iteration infos['epoch'] = epoch infos['iterators'] = loader.iterators infos['split_ix'] = loader.split_ix infos['best_val_score'] = best_val_score
infos['opt'] = opt infos['vocab'] = loader.get_vocab() histories['val_result_history'] = val_result_history histories['loss_history'] = loss_history histories['lr_history'] = lr_history histories['ss_prob_history'] = ss_prob_history with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(infos, f) with open( os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f: cPickle.dump(histories, f) if best_flag: checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') torch.save(model.state_dict(), checkpoint_path) print("model saved to {}".format(checkpoint_path)) with open( os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f: cPickle.dump(infos, f) # Stop if reaching max epochs if epoch >= opt.max_epochs and opt.max_epochs != -1: break