def greedy_baseline(self, fc_feats, att_feats, att_masks, retrieval_loss,
                    _seqs, _sampleLogProbs, _masks):
    """Self-critical retrieval loss using the greedy decode's VSE score as baseline.

    The caption generator is decoded greedily (``sample_max=1``). The greedy
    captions get a start token prepended and are scored by ``self.vse``. That
    score is the baseline subtracted from ``retrieval_loss``.

    Args:
        fc_feats, att_feats: image features passed to the generator and VSE.
        att_masks: attention masks, or None.
        retrieval_loss: VSE score of the sampled captions.
        _seqs: sampled captions; only its batch size is used here.
        _sampleLogProbs: log-probabilities of the sampled tokens.
        _masks: masks of the sampled captions including the prepended start
            column (hence the ``[:, 1:]`` slice below).

    Returns:
        (baseline, sc_loss, greedy_res):
            - baseline: VSE output for the greedy captions.
            - sc_loss: the unreduced per-token policy-gradient loss.
            - greedy_res: the raw greedy sequences, without the start token.
    """
    use_cuda = torch.cuda.is_available()
    if att_masks is not None:
        inputs = list(utils.var_wrapper([fc_feats, att_feats, att_masks],
                                        cuda=use_cuda, volatile=True))
    else:
        # The generator still takes att_masks positionally, so pass None.
        inputs = list(utils.var_wrapper([fc_feats, att_feats],
                                        cuda=use_cuda, volatile=True)) + [None]
    _seqs_greedy, _ = self.caption_generator.sample(
        *inputs, opt={'sample_max': 1, 'temperature': 1})
    greedy_res = _seqs_greedy

    # Two leading ones (start token + first generated token). After that, a
    # column is 1 when the preceding generated token is non-zero, so the first
    # zero token after a word is still counted. Indexing with [:, :-1] already
    # requires a >= 2-D tensor, so no 1-D fallback is needed.
    _masks_greedy = torch.cat([
        Variable(_seqs_greedy.data.new(_seqs.size(0), 2).fill_(1).float()),
        (_seqs_greedy > 0).float()[:, :-1],
    ], 1)

    # Prepend token id vocab_size + 1 (presumably the start marker self.vse
    # expects -- confirm against the VSE text encoder).
    start_tokens = Variable(
        _seqs_greedy.data.new(_seqs_greedy.size(0), 1).fill_(
            self.caption_generator.vocab_size + 1))
    _seqs_greedy = torch.cat([start_tokens, _seqs_greedy], 1)

    baseline = self.vse(fc_feats, att_feats, _seqs_greedy, _masks_greedy, True,
                        only_one_retrieval=self.only_one_retrieval)
    # The advantage is detached so gradients flow only through the log-probs.
    advantage = (utils.var_wrapper(retrieval_loss, cuda=use_cuda)
                 - utils.var_wrapper(baseline, cuda=use_cuda)).detach().unsqueeze(1)
    sc_loss = _sampleLogProbs * advantage * (_masks[:, 1:].detach().float())
    return baseline, sc_loss, greedy_res
def gt_baseline(self, fc_feats, att_feats, att_masks, retrieval_loss,
                _seqs, _sampleLogProbs, _masks, seq, masks):
    """Self-critical retrieval loss using the ground-truth captions as baseline.

    ``self.vse`` scores the ground-truth ``seq``/``masks``. The policy-gradient
    advantage is ``retrieval_loss - baseline``, detached so that only
    ``_sampleLogProbs`` receives gradients. Columns are masked by
    ``_masks[:, 1:]``.

    Returns:
        (baseline, sc_loss): the ground-truth VSE output and the unreduced
        per-token loss.
    """
    baseline = self.vse(fc_feats, att_feats, seq, masks, True,
                        only_one_retrieval=self.only_one_retrieval)
    on_gpu = torch.cuda.is_available()
    sampled_score = utils.var_wrapper(retrieval_loss, cuda=on_gpu)
    reference_score = utils.var_wrapper(baseline, cuda=on_gpu)
    advantage = (sampled_score - reference_score).detach().unsqueeze(1)
    token_mask = _masks[:, 1:].detach().float()
    sc_loss = _sampleLogProbs * advantage * token_mask
    return baseline, sc_loss
def traditional_cider(self, fc_feats, att_feats, att_masks, data, loss,
                      gen_result, greedy_res, sample_logprobs, gen_masks):
    """Add a self-critical CIDEr policy-gradient term to ``loss``.

    Rewards come from ``rewards.get_self_critical_reward`` for the sampled
    captions ``gen_result`` against the greedy captions ``greedy_res``.
    ``self.use_gen_cider_scores`` selects between the differenced rewards (0)
    and the variant requested with ``return_gen_scores=True`` (non-zero).

    The loss is ``sum(logprob * -reward * mask) / sum(mask)``, scaled by
    ``self.cider_optimization`` and added to ``loss``.

    Side effects:
        Sets ``self._loss['avg_reward']``, ``['cider_greedy']`` and
        ``['loss_cider']``.

    Returns:
        The updated total loss.
    """
    if self.use_gen_cider_scores == 0:
        # Use the differenced rewards
        reward, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res)
    else:
        # use the original rewards
        reward, _, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res, return_gen_scores=True)
    self._loss['avg_reward'] = reward.mean()
    self._loss['cider_greedy'] = cider_greedy

    neg_reward = utils.var_wrapper(-reward.astype('float32'),
                                   cuda=torch.cuda.is_available()).unsqueeze(1)
    loss_cider = sample_logprobs * neg_reward * (gen_masks[:, 1:].detach())
    # Normalise by the number of unmasked tokens.
    loss_cider = loss_cider.sum() / gen_masks[:, 1:].data.float().sum()
    loss += self.cider_optimization * loss_cider
    # loss_cider is 0-dim here; .item() works where .data[0] raises IndexError.
    self._loss['loss_cider'] = loss_cider.item()
    return loss
def greedy_res_for_cider(self, fc_feats, att_feats, att_masks):
    """Greedily decode captions (``sample_max=1``) for use as a CIDEr baseline.

    When ``att_masks`` is None, only the two feature tensors are wrapped, and
    ``att_masks`` is still passed positionally to the generator.

    Returns:
        The greedy sequences returned by ``self.caption_generator.sample``.
    """
    on_gpu = torch.cuda.is_available()
    if att_masks is None:
        sample_args = list(utils.var_wrapper([fc_feats, att_feats],
                                             cuda=on_gpu, volatile=True))
        sample_args.append(att_masks)
    else:
        sample_args = list(utils.var_wrapper([fc_feats, att_feats, att_masks],
                                             cuda=on_gpu, volatile=True))
    greedy_res, _ = self.caption_generator.sample(*sample_args,
                                                  opt={'sample_max': 1})
    return greedy_res
def encode_data(model, loader, eval_kwargs=None):
    """Encode the images and ground-truth captions of a split with the VSE.

    Each batch of the split goes through ``model.vse.img_enc`` (fc features)
    and ``model.vse.txt_enc`` (labels, masks) under ``torch.no_grad()``.
    ``loader.seq_per_img`` is temporarily forced to 5.

    Args:
        model: joint model exposing ``vse.img_enc`` and ``vse.txt_enc``.
        loader: data loader. Its ``seq_per_img`` is restored before returning,
            including when an exception is raised.
        eval_kwargs: optional dict. Reads 'num_images' (falling back to
            'val_images_use'; -1 means the whole split) and 'split' (default
            'val').

    Returns:
        (img_embs, cap_embs): numpy arrays stacked over batches, each with
        ``ix1 * 5`` rows, where ``ix1`` is the number of images encoded.

    Note:
        Leaves ``model`` in eval mode.
    """
    if eval_kwargs is None:
        eval_kwargs = {}
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')

    # Make sure in the evaluation mode
    model.eval()

    saved_seq_per_img = loader.seq_per_img
    loader.seq_per_img = 5
    try:
        loader.reset_iterator(split)
        n = 0
        img_embs = []
        cap_embs = []
        while True:
            data = loader.get_batch(split)
            n = n + loader.batch_size

            fc_feats, _, labels, masks = utils.var_wrapper([
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ])
            with torch.no_grad():
                img_emb = model.vse.img_enc(fc_feats)
                cap_emb = model.vse.txt_enc(labels, masks)

            # if we wrapped around the split or used up val imgs budget then bail
            ix1 = data['bounds']['it_max']
            if num_images != -1:
                ix1 = min(ix1, num_images)
            if n > ix1:
                # (ix1 - n) is negative: drop the rows of images past ix1
                # (each image contributes seq_per_img rows).
                img_emb = img_emb[:(ix1 - n) * loader.seq_per_img]
                cap_emb = cap_emb[:(ix1 - n) * loader.seq_per_img]

            # preserve the embeddings by copying from gpu and converting to np
            img_embs.append(img_emb.data.cpu().numpy().copy())
            cap_embs.append(cap_emb.data.cpu().numpy().copy())

            print("%d/%d" % (n, ix1))
            if data['bounds']['wrapped']:
                break
            if num_images >= 0 and n >= num_images:
                break

        img_embs = np.vstack(img_embs)
        cap_embs = np.vstack(cap_embs)
        assert img_embs.shape[0] == ix1 * loader.seq_per_img
    finally:
        # Never leave the shared loader in the forced 5-captions state.
        loader.seq_per_img = saved_seq_per_img
    return img_embs, cap_embs
def no_baseline(self, retrieval_loss, _sampleLogProbs, _masks):
    """REINFORCE retrieval loss without a baseline (baseline = 0).

    The detached ``retrieval_loss`` is used directly as the reward, masked by
    ``_masks[:, 1:]``.

    Returns:
        (0, sc_loss): the zero baseline and the unreduced per-token loss.
    """
    baseline = 0
    # Pass `cuda` by keyword, as the other baseline helpers do, so the flag
    # cannot bind to a different positional parameter of var_wrapper.
    reward = utils.var_wrapper(retrieval_loss, cuda=torch.cuda.is_available())
    sc_loss = (_sampleLogProbs * reward.detach().unsqueeze(1)
               * (_masks[:, 1:].detach().float()))
    return baseline, sc_loss
def forward(self, fc_feats, att_feats, att_masks, seq, masks, data):
    """Compute the combined training loss of the joint captioning/VSE model.

    Terms, each enabled by a weight or flag on ``self``:
      * ``caption_loss_weight``: caption-generator loss on ``seq``/``masks``.
        When ``cider_optimization`` is set, it is replaced by a self-critical
        CIDEr loss.
      * ``vse_loss_weight``: the ``self.vse`` loss on the ground truth.
      * ``retrieval_reward_weight``: a REINFORCE term that rewards sampled
        captions by their VSE score. ``reinforce_baseline_type`` picks the
        baseline: 'greedy', 'gt', or anything else for none.

    Per-term scalars are recorded in ``self._loss``.

    Returns:
        The total loss tensor.
    """
    if self.caption_loss_weight > 0 and not self.cider_optimization:
        loss_cap = self.caption_generator(fc_feats, att_feats, att_masks,
                                          seq, masks)
    else:
        loss_cap = Variable(torch.cuda.FloatTensor([0]))
    if self.vse_loss_weight > 0:
        loss_vse = self.vse(fc_feats, att_feats, seq, masks,
                            only_one_retrieval=self.only_one_retrieval)
    else:
        loss_vse = Variable(torch.cuda.FloatTensor([0]))

    loss = (self.caption_loss_weight * loss_cap
            + self.vse_loss_weight * loss_vse)

    # Samples produced by the retrieval branch are reused by the CIDEr branch;
    # None means "not produced yet".
    gen_result = None
    sample_logprobs = None
    gen_masks = None
    greedy_res = None

    if self.retrieval_reward_weight > 0:
        _seqs, _sampleLogProbs = self.caption_generator.sample(
            fc_feats, att_feats, att_masks,
            {'sample_max': 0, 'temperature': 1})
        gen_result, sample_logprobs = _seqs, _sampleLogProbs
        # Two leading ones, then 1 where the previous generated token is
        # non-zero (so the first zero token after a word is included).
        _masks = torch.cat([
            Variable(_seqs.data.new(_seqs.size(0), 2).fill_(1).float()),
            (_seqs > 0).float()[:, :-1]
        ], 1)
        gen_masks = _masks
        # Prepend token id vocab_size + 1 before scoring with the VSE.
        _seqs = torch.cat([
            Variable(_seqs.data.new(_seqs.size(0), 1).fill_(
                self.caption_generator.vocab_size + 1)),
            _seqs
        ], 1)

        retrieval_loss = self.vse(fc_feats, att_feats, _seqs, _masks, True,
                                  only_one_retrieval=self.only_one_retrieval)

        if self.reinforce_baseline_type == 'greedy':
            # NOTE(review): att_masks is wrapped even if None here (unlike
            # greedy_baseline) -- confirm var_wrapper accepts None.
            _seqs_greedy, _ = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats, att_masks],
                                   volatile=True),
                opt={'sample_max': 1, 'temperature': 1})
            greedy_res = _seqs_greedy
            # Word-weighted masks (self.get_word_weights_mask) are disabled;
            # the plain non-pad mask is always used.
            _masks_greedy = torch.cat([
                Variable(_seqs_greedy.data.new(_seqs.size(0), 2).fill_(1).float()),
                (_seqs_greedy > 0).float()[:, :-1]
            ], 1)
            _seqs_greedy = torch.cat([
                Variable(_seqs_greedy.data.new(_seqs_greedy.size(0), 1).fill_(
                    self.caption_generator.vocab_size + 1)),
                _seqs_greedy
            ], 1)
            baseline = self.vse(fc_feats, att_feats, _seqs_greedy,
                                _masks_greedy, True,
                                only_one_retrieval=self.only_one_retrieval)
        elif self.reinforce_baseline_type == 'gt':
            baseline = self.vse(fc_feats, att_feats, seq, masks, True,
                                only_one_retrieval=self.only_one_retrieval)
        else:
            baseline = 0

        sc_loss = _sampleLogProbs * (
            utils.var_wrapper(retrieval_loss)
            - utils.var_wrapper(baseline)).detach().unsqueeze(1) * (
                _masks[:, 1:].detach().float())
        sc_loss = sc_loss.sum() / _masks[:, 1:].data.float().sum()
        loss += self.retrieval_reward_weight * sc_loss

        # These are 0-dim tensors: .item() replaces .data[0], which raises
        # IndexError on PyTorch >= 0.5.
        self._loss['retrieval_sc_loss'] = sc_loss.item()
        self._loss['retrieval_loss'] = retrieval_loss.sum().item()
        self._loss['retrieval_loss_greedy'] = (
            baseline.sum().item() if isinstance(baseline, Variable) else baseline)

    if self.cider_optimization:
        if gen_result is None:
            gen_result, sample_logprobs = self.caption_generator.sample(
                fc_feats, att_feats, att_masks, opt={'sample_max': 0})
            gen_masks = torch.cat([
                Variable(gen_result.data.new(gen_result.size(0), 2).fill_(1).float()),
                (gen_result > 0).float()[:, :-1]
            ], 1)
        if greedy_res is None:
            greedy_res, _ = self.caption_generator.sample(
                *utils.var_wrapper([fc_feats, att_feats, att_masks],
                                   volatile=True),
                opt={'sample_max': 1})
        reward, cider_greedy = rewards.get_self_critical_reward(
            data, gen_result, greedy_res)
        self._loss['avg_reward'] = reward.mean()
        self._loss['cider_greedy'] = cider_greedy
        loss_cap = sample_logprobs * utils.var_wrapper(
            -reward.astype('float32')).unsqueeze(1) * (gen_masks[:, 1:].detach())
        loss_cap = loss_cap.sum() / gen_masks[:, 1:].data.float().sum()
        loss += self.caption_loss_weight * loss_cap

    self._loss['loss_cap'] = loss_cap.item()
    self._loss['loss_vse'] = loss_vse.item()
    self._loss['loss'] = loss.item()
    return loss
def train(opt):
    """Train the joint captioning/VSE model configured by ``opt``.

    Steps:
      1. Optionally resume from ``opt.start_from`` (infos, histories and
         optimizer state).
      2. Loop over training batches. Once per epoch, re-apply the
         learning-rate, scheduled-sampling and retrieval-reward-weight
         schedules.
      3. Log losses every ``opt.losses_log_every`` iterations.
      4. Every ``opt.save_checkpoint_every`` iterations, evaluate on the 'val'
         split and save checkpoints: latest, per-iteration, best language
         score, and best VSE score.
      5. Stop once ``epoch >= opt.max_epochs`` (never when it is -1).
    """
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible.
        # Pickles are written with 'wb' below, so read them in binary mode.
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_file = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_file):
            with open(histories_file, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Bind unconditionally. They are read at every checkpoint, so leaving them
    # unbound when load_best_score != 1 raised UnboundLocalError.
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print('Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.')

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats, att_feats, att_masks, labels, masks = utils.var_wrapper([
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ])

        optimizer.zero_grad()
        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # .item() replaces .data[0], which fails on 0-dim tensors in PyTorch >= 0.5.
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
              .format(iteration, epoch, train_loss, end - start))
        print("".join("{} = {:.3f} ".format(k, v)
                      for k, v in model.loss().items()))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss, iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob, iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight, iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v, iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                    iteration)
                tf_summary_writer.file_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                best_val_score_vse = current_score_vse
                best_flag_vse = True

            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model-%d.pth' % (iteration))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['best_val_score_vse'] = best_val_score_vse
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '-%d.pkl' % (iteration)),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_' + opt.id + '-best.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)

            if best_flag_vse:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_vse-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos_vse_' + opt.id + '-best.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def eval_split(model, loader, eval_kwargs=None):
    """Evaluate ``model`` on a split: average losses, captions and optional metrics.

    For each batch:
      * If labels are present, the model's loss is computed and the
        per-term values from ``model.loss()`` are accumulated.
      * One caption per image is sampled with ``model.sample``, using
        ``eval_kwargs`` as the sampling options.

    Stops when the split wraps or after 'num_images' images (-1 means all).

    Args:
        model: joint model; it is put in eval mode and switched back to train
            mode before returning.
        loader: data loader.
        eval_kwargs: optional dict of options. Keys read: 'verbose',
            'num_images'/'val_images_use', 'split', 'language_eval',
            'rank_eval', 'dataset', 'dump_path', 'dump_images', 'image_root',
            'id'.

    Returns:
        (losses, predictions, lang_stats):
            - losses: averaged loss terms, plus rank metrics when
              'rank_eval' is set.
            - predictions: a list of {'image_id', 'caption'[, 'file_name']}.
            - lang_stats: language metrics, or {} when 'language_eval' is not 1.
    """
    if eval_kwargs is None:
        eval_kwargs = {}
    verbose = eval_kwargs.get('verbose', True)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    rank_eval = eval_kwargs.get('rank_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')

    # Make sure in the evaluation mode
    model.eval()
    np.random.seed(123)
    loader.reset_iterator(split)

    n = 0
    # Bound up front: the verbose progress line prints it even for batches
    # without labels.
    loss = 0
    losses = {}
    loss_evals = 1e-8
    predictions = []  # Save the discriminative results. Used for further html visualization.
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size

        if data.get('labels', None) is not None:
            # forward the model to get loss
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
            # .item() replaces .data[0], which fails on 0-dim tensors.
            loss = loss.item()
            for k, v in model.loss().items():
                if k not in losses:
                    losses[k] = 0
                losses[k] += v
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        first_of_each = np.arange(loader.batch_size) * loader.seq_per_img
        fc_feats, att_feats, att_masks = utils.var_wrapper([
            data['fc_feats'][first_of_each],
            data['att_feats'][first_of_each],
            data['att_masks'][first_of_each]
        ], volatile=True)
        seq, _ = model.sample(fc_feats, att_feats, att_masks, opt=eval_kwargs)

        sents = utils.decode_sequence(loader.get_vocab(), seq.data)
        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                # NOTE(review): shell string built from dataset paths; a path
                # containing '"' would break it -- consider shutil.copy.
                cmd = 'cp "' + os.path.join(
                    eval_kwargs['image_root'],
                    data['infos'][k]['file_path']) + '" vis/imgs/img' + str(
                        len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        # Drop predictions for images past ix1 (wrap-around padding).
        for _ in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation preformance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split)
    else:
        lang_stats = {}

    ranks = evalrank(model, loader, eval_kwargs) if rank_eval else {}
    # Switch back to training mode
    model.train()
    losses = {k: v / loss_evals for k, v in losses.items()}
    losses.update(ranks)
    return losses, predictions, lang_stats
def encode_data_generated(model, loader, captions, eval_kwargs=None):
    """Encode a split's images together with externally supplied captions.

    Like ``encode_data``, but the captions are consumed ``loader.batch_size``
    rows at a time from ``captions`` instead of taken from the loader's
    labels. ``loader.seq_per_img`` is temporarily forced to 5.

    Args:
        model: joint model exposing ``vse.img_enc`` and ``vse.txt_enc``.
        loader: data loader. Its ``seq_per_img`` is restored before returning,
            including when an exception is raised.
        captions: token-id tensor with one caption row per image (assumed to
            be in split order -- confirm with the caller).
        eval_kwargs: optional dict. Reads 'num_images' (falling back to
            'val_images_use'; -1 means the whole split) and 'split'.

    Returns:
        (img_embs, cap_embs): numpy arrays limited to the first ``ix1`` images.
        ``img_embs`` has ``ix1 * 5`` rows and ``cap_embs`` has ``ix1`` rows.
        Previously, ``num_images == -1`` produced ``[:-5]`` / ``[:-1]`` slices
        that silently dropped rows.

    Note:
        Leaves ``model`` in eval mode.
    """
    if eval_kwargs is None:
        eval_kwargs = {}
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')

    # Make sure in the evaluation mode
    model.eval()

    saved_seq_per_img = loader.seq_per_img
    loader.seq_per_img = 5
    try:
        loader.reset_iterator(split)
        n = 0
        img_embs = []
        cap_embs = []
        while True:
            data = loader.get_batch(split)
            labels = captions[n:(n + loader.batch_size)]
            masks = (labels > 0).float()
            # Also unmask the first position where a non-zero token is
            # followed by a zero token, so that zero token is encoded too.
            for i in range(labels.size(0)):
                for j in range(labels.size(1) - 1):
                    if labels[i, j].item() > 0.5 and labels[i, j + 1].item() < 0.5:
                        masks[i, j + 1] = 1.0
                        break
            n = n + loader.batch_size

            fc_feats, _, _, _ = utils.var_wrapper([
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ], volatile=True)
            fc_feats = fc_feats.cuda()
            labels = labels.cuda()
            masks = masks.cuda()
            # Inference only (matches encode_data); `volatile` is a no-op
            # on PyTorch >= 0.4.
            with torch.no_grad():
                img_emb = model.vse.img_enc(fc_feats)
                cap_emb = model.vse.txt_enc(labels, masks)

            # if we wrapped around the split or used up val imgs budget then bail
            ix1 = data['bounds']['it_max']
            if num_images != -1:
                ix1 = min(ix1, num_images)

            # preserve the embeddings by copying from gpu and converting to np
            img_embs.append(img_emb.data.cpu().numpy().copy())
            cap_embs.append(cap_emb.data.cpu().numpy().copy())

            print("%d/%d" % (n, ix1))
            if data['bounds']['wrapped']:
                break
            if num_images >= 0 and n >= num_images:
                break

        # Trim to the first ix1 images: the last batch may run past the
        # budget or wrap around the split.
        img_embs = np.vstack(img_embs)[:ix1 * loader.seq_per_img]
        cap_embs = np.vstack(cap_embs)[:ix1]
    finally:
        loader.seq_per_img = saved_seq_per_img
    return img_embs, cap_embs