# --- Training script (variant 1): joint XE + self-critical RL with a language-model sampler ---
def train(opt):
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    lang_model = Seq2Seq().cuda()
    model.load_state_dict(torch.load('./log_cvpr_mesh/all2model20000.pth'))
    lang_model.load_state_dict(torch.load('log_cvpr/all2model16000.pth'), strict=False)
    optimizer = utils.build_optimizer_adam(list(models.parameters()) + list(lang_model.parameters()), opt)

    update_lr_flag = True
    while True:
        # Update learning rate once per epoch
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'],
               data['dist'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        batchsize = fc_feats.size(0)

        # Decode the ground-truth labels back to strings for the reward scorer
        labels_decode = labels.view(-1, 180)
        captions = utils.decode_sequence(loader.get_vocab(), labels_decode, None)
        captions_all = []
        for index, caption in enumerate(captions):
            caption = caption.replace('<start>', '').replace(' ,', '').replace('  ', ' ')
            captions_all.append(caption)

        # Forward pass and loss
        d_steps = 1
        g_steps = 1
        #print(torch.sum(labels != 0), torch.sum(masks != 0))
        if 1:
            if 1:
                model.train()
                optimizer.zero_grad()

                # Cross-entropy loss on the ground-truth captions
                wordact, x_all_image = model(att_feats, labels.view(batchsize, -1))
                wordact_t = wordact[:, :-1, :]
                wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
                labels_flat = labels.view(batchsize, -1)
                wordclass_v = labels_flat[:, 1:]
                wordclass_t = wordclass_v.contiguous().view(wordclass_v.size(0) * wordclass_v.size(1), -1)
                loss_xe = F.cross_entropy(wordact_t, wordclass_t.contiguous().view(-1))

                '''
                wordact = lang_model(labels.view(batchsize, -1).transpose(1, 0),
                                     labels.view(batchsize, -1).transpose(1, 0), fc_feats)
                wordact_t = wordact.transpose(1, 0)[:, 1:, :]
                wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
                labels_flat = labels.view(batchsize, -1)
                wordclass_v = labels_flat[:, 1:]
                wordclass_t = wordclass_v.contiguous().view(wordclass_v.size(0) * wordclass_v.size(1), -1)
                loss_xe_lang = F.cross_entropy(wordact_t[...], wordclass_t[...].view(-1))
                '''

                # Sample captions from the language model, then rescore them with
                # the image-grounded model to get the target log-probabilities
                outcap, sampled_ids, sample_logprobs = lang_model.sample(
                    labels.view(batchsize, -1).transpose(1, 0),
                    labels.view(batchsize, -1).transpose(1, 0),
                    fc_feats, loader.get_vocab())
                sampled_ids[:, 0] = 8667  # force <bos> at the first position
                logprobs_input, _ = model(att_feats, sampled_ids.long().cuda())
                log_probs = F.log_softmax(logprobs_input[:, :-1, :], -1)
                sample_logprobs_true = log_probs.gather(2, sampled_ids[:, 1:].cuda().long().unsqueeze(2))

                # Self-critical reward (CIDEr of samples vs. greedy baseline)
                with torch.no_grad():
                    reward, cider_sample, cider_greedy = get_self_critical_reward(
                        batchsize, lang_model, labels.view(batchsize, -1).transpose(1, 0),
                        fc_feats, outcap, captions_all, loader, 180)
                print(np.mean(cider_greedy))

                # Importance-weighted policy gradient: the ratio corrects for
                # sampling from lang_model instead of the captioning model
                loss_rl1 = rl_crit(
                    torch.exp(sample_logprobs_true.squeeze()) / torch.exp(sample_logprobs[:, 1:]).cuda().detach(),
                    sampled_ids[:, 1:].cpu(),
                    torch.from_numpy(reward).float().cuda())
                #loss_rl = rl_crit(sample_logprobs, sampled_ids.cpu(), torch.from_numpy(reward).float()).cuda()
                #x_all_langauge = x_all_langauge.cuda().detach()
                #l2_loss = ((x_all_image.transpose(2, 1).cuda() - x_all_langauge) ** 2).mean().cuda()

                train_loss = loss_xe + loss_rl1  # + loss_xe_lang
                train_loss.backward()
                optimizer.step()
                total_time = time.time() - start  # time/batch for the log below

        if 1:
            if iteration % opt.print_freq == 1:
                print('Read data:', time.time() - start)
                if not sc_flag:
                    print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}"
                          .format(iteration, epoch, loss_xe, data_time))
                else:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                          .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'lang_model{:05d}.pth'.format(iteration))
            torch.save(lang_model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
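# ---------------------------------------------------------------------------
# Hedged sketch: `rl_crit` is called above but never defined in this file. In
# self-critical captioning codebases it is usually a reward-weighted criterion
# of roughly the following shape (masked mean of -score * reward); the authors'
# exact masking and sign conventions may differ, so treat this as an
# assumption, not their code.
import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    def forward(self, logprobs, seq, reward):
        # logprobs: (B, T) per-word scores for the sampled caption
        # seq:      (B, T) sampled word ids, 0 used as padding after <eos>
        # reward:   per-word advantage, broadcastable to (B, T)
        logprobs = logprobs.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        # keep the first step and every step whose previous word is non-pad
        mask = (seq > 0).float().to(logprobs.device)
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1).view(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()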
# --- Meshed-Memory Transformer training setup (fragment) ---
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)

# Create the dataset
dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
train_dataset, val_dataset, test_dataset = dataset.splits

if not os.path.isfile('vocab_%s.pkl' % args.exp_name):
    print("Building vocabulary")
    text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
    pickle.dump(text_field.vocab, open('vocab_%s.pkl' % args.exp_name, 'wb'))
else:
    text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb'))

# Model and dataloaders
encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': args.m})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()})
ref_caps_train = list(train_dataset.text)
cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})

def lambda_lr(s):
    warm_up = args.warmup
    s += 1
    return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
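# Usage sketch (assumption): `lambda_lr` above is the Transformer warm-up
# ("Noam") schedule, increasing linearly for `warmup` steps and then decaying
# as s^-0.5. It is typically attached to the optimizer like this, with lr=1 so
# the lambda supplies the whole scale; step the scheduler once per iteration
# as the training loop dictates.
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
scheduler = LambdaLR(optim, lambda_lr)
# ... inside the loop: forward/backward, optim.step(), then scheduler.step()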
# --- Training script (variant 2): XE + truncated importance-sampled RL (TRIS) with KL regularisation ---
def train(opt):
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = infos.get('best_val_score', None) if opt.load_best_score == 1 else None

    # Create model
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    lang_model = Seq2Seq().cuda()
    model.load_state_dict(torch.load('log_meshed/all2model20000.pth'))
    lang_model.load_state_dict(torch.load('language_model/langmodel06000.pth'))
    optimizer = utils.build_optimizer_adam(list(models.parameters()) + list(lang_model.parameters()), opt)

    update_lr_flag = True
    while True:
        # Update learning rate once per epoch
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'],
               data['dist'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()

        # Decode the ground-truth labels back to strings for the reward scorer
        captions = utils.decode_sequence(loader.get_vocab(), labels.view(fc_feats.size(0), -1), None)
        captions_all = []
        for index, caption in enumerate(captions):
            caption = caption.replace('<start>', '').replace(' ,', '').replace('  ', ' ')
            captions_all.append(caption)
        nd_labels = labels
        batchsize = fc_feats.size(0)

        # Forward pass and loss
        d_steps = 1
        g_steps = 1
        beta = 0.2  # mixing weight for the smoothed importance ratio
        #print(torch.sum(labels != 0), torch.sum(masks != 0))
        if 1:
            if 1:
                model.train()
                optimizer.zero_grad()

                # Cross-entropy loss on the ground-truth captions
                wordact, _ = model(att_feats, labels.view(batchsize, -1))
                wordact_t = wordact[:, :-1, :]
                wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
                labels_flat = labels.view(batchsize, -1)
                wordclass_v = labels_flat[:, 1:]
                wordclass_t = wordclass_v.contiguous().view(wordclass_v.size(0) * wordclass_v.size(1), -1)
                loss_xe = F.cross_entropy(wordact_t, wordclass_t.contiguous().view(-1))

                # Sample captions from the language model (no gradients needed)
                with torch.no_grad():
                    outcap, sampled_ids, sample_logprobs, x_all_langauge, outputs, log_probs_all = \
                        lang_model.sample(labels.view(batchsize, -1).transpose(1, 0),
                                          att_feats.transpose(1, 0), loader.get_vocab())

                # Rescore the samples with the image-grounded model
                logprobs_input, _ = model(att_feats, sampled_ids.cuda().long())
                log_probs = F.log_softmax(logprobs_input, 2)
                sample_logprobs_true = log_probs.gather(2, sampled_ids.cuda().long().unsqueeze(2))

                # Self-critical reward (CIDEr of samples vs. greedy baseline)
                with torch.no_grad():
                    reward, cider_sample, cider_greedy, caps_sample, caps = get_self_critical_reward(
                        batchsize, lang_model, labels.view(batchsize, -1).transpose(1, 0),
                        att_feats.transpose(1, 0), outcap, captions_all, loader, 180)
                reward = torch.tensor(reward)

                # Per-word KL divergence between the two models' distributions
                kl_div = F.kl_div(log_probs.squeeze().cuda().detach(),
                                  torch.exp(log_probs_all.transpose(1, 0)).cuda().detach(),
                                  reduction='none')

                # Smoothed, clamped importance ratio p / ((1-beta)q + beta*p)
                ratio_no = sample_logprobs_true.squeeze().cpu().double()
                ratio_de = sample_logprobs.cpu().double()
                ratio_no_f = torch.exp(ratio_no)
                ratio_de_f = torch.exp(ratio_de)
                ratio = ratio_no_f / ((1 - beta) * ratio_de_f + beta * ratio_no_f)
                ratio = torch.clamp(ratio, min=0.96)
                ratio_prod = ratio.prod(1)

                # KL-regularised reward and importance-weighted RL loss
                reward = reward.cuda() - 0.05 * kl_div.mean()
                loss_rl1 = rl_crit(ratio_prod.cuda().unsqueeze(1).detach() * sample_logprobs_true.squeeze()[:, :-1],
                                   sampled_ids[:, 1:].cpu(),
                                   reward.float().cuda().detach())
                #writer.add_scalar('RL loss', loss_rl1, iteration)
                #writer.add_scalar('TRIS ratio', ratio.mean(), iteration)
                #writer.add_scalar('XE_loss', loss_xe, iteration)
                #writer.add_scalar('KL_div', kl_div.mean(), iteration)

                lamb = 0.5
                train_loss = lamb * loss_rl1 + (1 - lamb) * loss_xe
                train_loss.backward()
                optimizer.step()
                total_time = time.time() - start  # time/batch for the log below

        if 1:
            if iteration % opt.print_freq == 1:
                print('Read data:', time.time() - start)
                if not sc_flag:
                    print(ratio.mean())
                    print(reward.mean())
                    print(kl_div.mean())
                    print("iter {} (epoch {}), train_loss = {:.4f}, xe_loss = {:.3f}, train_time = {:.3f}"
                          .format(iteration, epoch, train_loss.item(), loss_xe, data_time))
                else:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                          .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'lang_model{:05d}.pth'.format(iteration))
            torch.save(lang_model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            crit = utils.LanguageModelCriterion()
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            #add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            #if lang_stats is not None:
            #    for k, v in lang_stats.items():
            #        add_summary_value(tb_summary_writer, k, v, iteration)
            #val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Our metric is CIDEr if available, otherwise validation loss
            #if opt.language_eval == 1:
            current_score = lang_stats['CIDEr']
            #else:
            #    current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
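# ---------------------------------------------------------------------------
# Hedged sketch: the ratio arithmetic above, factored out for readability.
# With mixing weight beta, p / ((1-beta)*q + beta*p) is a smoothed importance
# weight that is bounded above by 1/beta, and the clamp puts a floor under it.
# Names and defaults here are illustrative, not the authors' API.
import torch

def smoothed_importance_ratio(logp_target, logp_behaviour, beta=0.2, floor=0.96):
    p = torch.exp(logp_target.double())     # target-policy word probabilities
    q = torch.exp(logp_behaviour.double())  # behaviour (sampling) policy probabilities
    ratio = p / ((1.0 - beta) * q + beta * p)
    return torch.clamp(ratio, min=floor)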
# --- Evaluation script (test_custom.py) ---
parser.add_argument('--vocab', type=str, default='vocab.pkl')
args = parser.parse_args()

print('Meshed-Memory Transformer Evaluation')

# Pipeline for image regions
image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)

# Pipeline for text
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                       remove_punctuation=True, nopoints=False)

# Create the dataset
dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
_, _, test_dataset = dataset.splits
text_field.vocab = pickle.load(open(args.vocab, 'rb'))

# Model and dataloaders
encoder = MemoryAugmentedEncoder(3, 0, args.d_in, d_ff=args.d_in,
                                 attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40})
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], d_ff=args.d_in)
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

data = torch.load(args.weights)
model.load_state_dict(data['state_dict'])

dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)

scores = predict_captions(model, dict_dataloader_test, text_field)
print(scores)

# Usage: python test_custom.py --features_path <features_path> --weights <weights> [--d_in <d_in>]
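# ---------------------------------------------------------------------------
# Hedged sketch: `predict_captions` is imported from the surrounding project
# and not shown in this fragment. In the meshed-memory-transformer codebase it
# follows roughly this pattern (beam search, decode, corpus-level metrics);
# the beam_search arguments and the `evaluation` module are assumptions about
# that codebase, not verified here.
import itertools
import torch
import evaluation  # metrics helpers assumed from the m2 codebase

def predict_captions(model, dataloader, text_field):
    model.eval()
    gen, gts = {}, {}
    for it, (images, caps_gt) in enumerate(iter(dataloader)):
        images = images.to(device)
        with torch.no_grad():
            out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
        caps_gen = text_field.decode(out, join_words=False)
        for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
            gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])  # drop immediate repeats
            gen['%d_%d' % (it, i)] = [gen_i]
            gts['%d_%d' % (it, i)] = gts_i
    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)
    return scores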
# --- Training script (variant 3): cross-entropy only ---
def train(opt):
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = infos.get('best_val_score', None) if opt.load_best_score == 1 else None

    # Create model
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    # pretrained_dict = torch.load('log_xe_final_before_review/all2model12000.pth')
    # model.load_state_dict(pretrained_dict, strict=False)
    # d_pretrained_dict = torch.load('log_xe_final_before_review/all2d_model12000.pth')
    # back_model.load_state_dict(d_pretrained_dict, strict=False)
    optimizer = utils.build_optimizer_adam(models.parameters(), opt)
    #back_optimizer = utils.build_optimizer(back_model.parameters(), opt)

    update_lr_flag = True
    while True:
        # Update learning rate once per epoch
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'],
               data['dist'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        nd_labels = labels
        batchsize = fc_feats.size(0)

        # Forward pass and loss
        d_steps = 1
        g_steps = 1
        #print(torch.sum(labels != 0), torch.sum(masks != 0))
        if 1:
            if 1:
                model.train()
                optimizer.zero_grad()

                # Cross-entropy loss on the ground-truth captions
                wordact = model(att_feats, labels.view(batchsize, -1))
                #mask = masks.view(batchsize, -1)
                #mask = mask[:, 1:].contiguous()
                wordact_t = wordact[:, :-1, :]
                # wordact_t = wordact.permute(0, 2, 1).contiguous()
                wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
                labels_flat = labels.view(batchsize, -1)
                wordclass_v = labels_flat[:, 1:]
                wordclass_t = wordclass_v.contiguous().view(wordclass_v.size(0) * wordclass_v.size(1), -1)
                #maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
                loss_xe = F.cross_entropy(wordact_t, wordclass_t.contiguous().view(-1))

                train_loss = loss_xe
                train_loss.backward()
                optimizer.step()
                total_time = time.time() - start  # time/batch for the log below

        if 1:
            if iteration % opt.print_freq == 1:
                print('Read data:', time.time() - start)
                if not sc_flag:
                    print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}"
                          .format(iteration, epoch, loss_xe, data_time))
                else:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                          .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

        # Evaluate model
        if (iteration % 20000 == 0):
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            crit = utils.LanguageModelCriterion()
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            # back_model and dis_model are never created in this script, so
            # their checkpoint saves are left disabled:
            #checkpoint_path = os.path.join(opt.checkpoint_path, 'd_model.pth')
            #torch.save(back_model.state_dict(), checkpoint_path)
            #checkpoint_path = os.path.join(opt.checkpoint_path, 'dis_model.pth')
            #torch.save(dis_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            #histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path, model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                # back_model and dis_model do not exist here either; disabled:
                #checkpoint_path = os.path.join(opt.checkpoint_path, 'd_model-best.pth')
                #torch.save(back_model.state_dict(), checkpoint_path)
                #checkpoint_path = os.path.join(opt.checkpoint_path, 'dis_model-best.pth')
                #torch.save(dis_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f:
                    cPickle.dump(infos, f)
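# ---------------------------------------------------------------------------
# Hedged sketch: `utils.set_lr` and `add_summary_value` are used throughout
# these training loops but defined elsewhere. In self-critical.pytorch-style
# codebases they are small helpers of this form; shown here as an assumption.
def set_lr(optimizer, lr):
    # overwrite the learning rate of every parameter group
    for group in optimizer.param_groups:
        group['lr'] = lr

def add_summary_value(writer, key, value, iteration):
    # no-op when tensorboard is unavailable (writer is None/False)
    if writer:
        writer.add_scalar(key, value, iteration)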
# --- ArtEmis evaluation setup (fragment) ---
# Create the dataset
dataset = ArtEmis(image_field, text_field, emotion_field, args.annotation_folder)
_, _, test_dataset = dataset.splits
text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb'))

# Model and dataloaders
emotion_dim = 0
emotion_encoder = None
if args.use_emotion_labels:
    emotion_dim = 10
    emotion_encoder = torch.nn.Sequential(torch.nn.Linear(9, emotion_dim))
    emotion_encoder.to(device)

encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                 attention_module_kwargs={'m': 40}, d_in=2048 + emotion_dim)
decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

fname = 'saved_models/%s_best.pth' % args.exp_name
data = torch.load(fname)
model.load_state_dict(data['state_dict'])

if emotion_encoder is not None:
    emotion_encoder.to(device)
    fname = 'saved_models/%s_emo_best.pth' % args.exp_name
    data = torch.load(fname)
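# ---------------------------------------------------------------------------
# Hedged sketch: with d_in = 2048 + emotion_dim above, the emotion embedding
# presumably gets concatenated onto each 2048-d region feature before the
# encoder. A minimal fusion step consistent with those shapes; the names here
# are illustrative, not the ArtEmis codebase's own.
import torch

def fuse_emotion(detections, emotion_onehot, emotion_encoder):
    # detections:     (B, N, 2048) region features
    # emotion_onehot: (B, 9) one-hot emotion labels
    emo = emotion_encoder(emotion_onehot)                      # (B, emotion_dim)
    emo = emo.unsqueeze(1).expand(-1, detections.size(1), -1)  # (B, N, emotion_dim)
    return torch.cat([detections, emo], dim=-1)                # (B, N, 2048 + emotion_dim)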