def evaluate(self, cnn_model, trx_model, batch_size):
    """Compute the mean validation losses over ``self.dataloader_val``.

    Both encoders are switched to eval mode and every validation batch is
    scored with the matching losses (``words_loss`` / ``sent_loss``) and the
    triplet losses (``words_triplet_loss`` / ``sent_triplet_loss``). The whole
    pass runs under ``torch.no_grad()`` because nothing is back-propagated.

    Note that the encoders are left in eval mode on return; the caller is
    responsible for switching them back with ``.train()``.

    Args:
        cnn_model: Image encoder, called as ``cnn_model(real_imgs)`` and
            unpacked into ``(words_features, sent_code)``.
        trx_model: Text encoder, called as ``trx_model(captions, masks)`` and
            unpacked into ``(words_emb, sent_emb)``.
        batch_size: Number of samples used to build the matching labels and
            the negative indices. NOTE(review): every batch yielded by the
            loader is assumed to contain exactly ``batch_size`` samples --
            confirm the loader drops incomplete batches.

    Returns:
        Tuple ``(s_cur_loss, w_cur_loss, s_t_cur_loss, w_t_cur_loss)``: the
        per-batch mean of ``loss0 + loss1`` for the sentence, word,
        sentence-triplet and word-triplet losses respectively.

    Raises:
        ValueError: If the validation loader yields no batches, so no mean
            can be computed.
    """
    cnn_model.eval()
    trx_model.eval()

    # labels[i] == i; used for matching loss.
    labels = torch.arange(batch_size, dtype=torch.long)
    if cfg.CUDA:
        labels = labels.cuda()

    # Loop-invariant pool of in-batch indices to draw negatives from.
    ids = np.arange(batch_size)

    s_total_loss = 0
    w_total_loss = 0
    s_t_total_loss = 0
    w_t_total_loss = 0
    num_batches = 0

    # Validation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in tqdm(self.dataloader_val, leave=False):
            real_imgs, captions, masks, class_ids, cap_lens = batch
            class_ids = class_ids.numpy()

            # For every sample pick a random *different* index of the batch
            # (requires batch_size >= 2, otherwise the pool is empty).
            neg_ids = torch.LongTensor(
                [np.random.choice(ids[ids != x]) for x in ids])

            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                captions = captions.cuda()
                masks = masks.cuda()
                cap_lens = cap_lens.cuda()
                neg_ids = neg_ids.cuda()

            words_features, sent_code = cnn_model(real_imgs)
            words_emb, sent_emb = trx_model(captions, masks)

            # Position 0 of the last axis is skipped and cap_lens shortened
            # by one to match -- presumably a leading special token; TODO
            # confirm against the text encoder.
            word_embs = words_emb[:, :, 1:]
            word_lens = cap_lens - 1

            w_loss0, w_loss1, _ = words_loss(words_features, word_embs,
                                             labels, word_lens, class_ids,
                                             batch_size)
            w_total_loss += (w_loss0 + w_loss1).detach()

            s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                         class_ids, batch_size)
            s_total_loss += (s_loss0 + s_loss1).detach()

            w_t_loss0, w_t_loss1, _ = words_triplet_loss(
                words_features, word_embs, labels, neg_ids, word_lens,
                batch_size)
            w_t_total_loss += (w_t_loss0 + w_t_loss1).detach()

            s_t_loss0, s_t_loss1 = sent_triplet_loss(sent_code, sent_emb,
                                                     labels, neg_ids,
                                                     batch_size)
            s_t_total_loss += (s_t_loss0 + s_t_loss1).detach()

            num_batches += 1

    if num_batches == 0:
        raise ValueError(
            'Validation dataloader yielded no batches; '
            'cannot compute average validation losses.')

    s_cur_loss = s_total_loss / num_batches
    w_cur_loss = w_total_loss / num_batches
    s_t_cur_loss = s_t_total_loss / num_batches
    w_t_cur_loss = w_t_total_loss / num_batches

    return s_cur_loss, w_cur_loss, s_t_cur_loss, w_t_cur_loss
def train(self):
    """Train the image and text encoders with matching + triplet losses.

    For every epoch in ``range(start_epoch, self.max_epoch)`` the method:

    1. runs ``self.num_batches`` optimisation steps on ``self.data_loader``,
       minimising ``cfg.LAMBDA_DAMSM * damsm_loss + cfg.LAMBDA_TRIPLET *
       t_loss`` with gradient-norm clipping at ``cfg.clip_max_norm``;
    2. logs each per-step loss to TensorBoard under ``Train_step/...``;
    3. runs ``self.evaluate`` and logs the validation losses under
       ``Val_step/...``;
    4. steps both learning-rate schedulers;
    5. calls ``self.save_model`` whenever the mean of the four validation
       losses improves on the best value seen so far.

    Side effects: creates a TensorBoard directory under ``../tensorboard/``,
    prints learning rates and validation losses to stdout, and writes
    checkpoints through ``self.save_model``. The TensorBoard writer is closed
    on exit, including when an exception interrupts training.

    Returns:
        None.
    """
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    tb_dir = '../tensorboard/{0}_{1}_{2}'.format(cfg.DATASET_NAME,
                                                 cfg.CONFIG_NAME, timestamp)
    mkdir_p(tb_dir)
    tbw = SummaryWriter(log_dir=tb_dir)  # Tensorboard logging

    try:
        # ---- models ----------------------------------------------------
        text_encoder, image_encoder, start_epoch = self.build_models()
        text_encoder.train()
        image_encoder.train()

        # ---- optimizers ------------------------------------------------
        (optimizerI, optimizerT,
         lr_schedulerI, lr_schedulerT) = self.define_optimizers(
             image_encoder, text_encoder)

        # ---- data / labels ---------------------------------------------
        # The returned labels are not used below; the call is kept in case
        # prepare_labels() has side effects -- TODO confirm and drop.
        self.prepare_labels()
        batch_size = self.batch_size

        # labels[i] == i; used for matching loss.
        labels = torch.arange(batch_size, dtype=torch.long)
        if cfg.CUDA:
            labels = labels.cuda()

        # Loop-invariant pool of in-batch indices to draw negatives from.
        ids = np.arange(batch_size)

        # Start at +inf so the first evaluated epoch always checkpoints,
        # whatever the scale of the losses.
        best_val_loss = float('inf')

        for epoch in range(start_epoch, self.max_epoch):
            # evaluate() leaves the encoders in eval mode, so re-enable
            # training mode at the start of every epoch.
            text_encoder.train()
            image_encoder.train()

            total_combo_loss = 0

            # Print the learning rates before the epoch starts to make sure
            # they are what we expect.
            print('Learning rates: lr_i %.7f, lr_t %.7f' %
                  (optimizerI.param_groups[0]['lr'],
                   optimizerT.param_groups[0]['lr']))

            data_iter = iter(self.data_loader)
            pbar = tqdm(range(self.num_batches))
            for step in pbar:
                ######################################################
                # (1) Prepare training data
                ######################################################
                imgs, captions, masks, class_ids, cap_lens = next(data_iter)
                class_ids = class_ids.numpy()

                # For every sample pick a random *different* index of the
                # batch (requires batch_size >= 2).
                neg_ids = torch.LongTensor(
                    [np.random.choice(ids[ids != x]) for x in ids])

                if cfg.CUDA:
                    imgs = imgs.cuda()
                    captions = captions.cuda()
                    masks = masks.cuda()
                    cap_lens = cap_lens.cuda()
                    neg_ids = neg_ids.cuda()

                ######################################################
                # (2) Forward pass through both encoders
                ######################################################
                image_encoder.zero_grad()
                text_encoder.zero_grad()

                words_features, sent_code = image_encoder(imgs)
                # Per the original author's note:
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, masks)

                # Position 0 of the last axis is skipped and cap_lens
                # shortened by one to match -- presumably a leading special
                # token; TODO confirm against the text encoder.
                word_embs = words_embs[:, :, 1:]
                word_lens = cap_lens - 1

                ######################################################
                # (3) Losses
                ######################################################
                # Matching (DAMSM) losses.
                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                w_loss0, w_loss1, _ = words_loss(
                    words_features, word_embs, labels, word_lens,
                    class_ids, batch_size)
                damsm_loss = (s_loss0 + s_loss1) + (w_loss0 + w_loss1)

                # Triplet losses.
                s_t_loss0, s_t_loss1 = sent_triplet_loss(
                    sent_code, sent_emb, labels, neg_ids, batch_size)
                w_t_loss0, w_t_loss1, _ = words_triplet_loss(
                    words_features, word_embs, labels, neg_ids, word_lens,
                    batch_size)
                t_loss = (s_t_loss0 + s_t_loss1) + (w_t_loss0 + w_t_loss1)

                damsm_triplet_combo_loss = (cfg.LAMBDA_DAMSM * damsm_loss +
                                            cfg.LAMBDA_TRIPLET * t_loss)
                total_combo_loss += damsm_triplet_combo_loss.item()

                ######################################################
                # (4) Backward pass and parameter update
                ######################################################
                damsm_triplet_combo_loss.backward()

                torch.nn.utils.clip_grad_norm_(image_encoder.parameters(),
                                               cfg.clip_max_norm)
                optimizerI.step()

                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                               cfg.clip_max_norm)
                optimizerT.step()

                ######################################################
                # (5) Per-step logging
                ######################################################
                global_step = step + epoch * self.num_batches
                step_scalars = (
                    # damsm
                    ('Train_step/train_w_step_loss0', w_loss0),
                    ('Train_step/train_s_step_loss0', s_loss0),
                    ('Train_step/train_w_step_loss1', w_loss1),
                    ('Train_step/train_s_step_loss1', s_loss1),
                    ('Train_step/train_damsm_step_loss', damsm_loss),
                    # triplet
                    ('Train_step/train_w_t_step_loss0', w_t_loss0),
                    ('Train_step/train_s_t_step_loss0', s_t_loss0),
                    ('Train_step/train_w_t_step_loss1', w_t_loss1),
                    ('Train_step/train_s_t_step_loss1', s_t_loss1),
                    ('Train_step/train_t_step_loss', t_loss),
                )
                for tag, value in step_scalars:
                    tbw.add_scalar(tag, float(value.item()), global_step)

                # Running average of the combined loss in the progress bar.
                pbar.set_description(
                    'combo_loss %.5f' %
                    (float(total_combo_loss) / (step + 1)))

            ##########################################################
            # Validation
            ##########################################################
            (v_s_cur_loss, v_w_cur_loss,
             v_s_t_cur_loss, v_w_t_cur_loss) = self.evaluate(
                 image_encoder, text_encoder, self.val_batch_size)
            print(
                '[epoch: %d] val_w_loss: %.4f, val_s_loss: %.4f, val_w_t_loss: %.4f, val_s_t_loss: %.4f'
                % (epoch, v_w_cur_loss, v_s_cur_loss, v_w_t_cur_loss,
                   v_s_t_cur_loss))
            print('-' * 80)

            tbw.add_scalar('Val_step/val_w_loss', float(v_w_cur_loss), epoch)
            tbw.add_scalar('Val_step/val_s_loss', float(v_s_cur_loss), epoch)
            tbw.add_scalar('Val_step/val_w_t_loss', float(v_w_t_cur_loss),
                           epoch)
            tbw.add_scalar('Val_step/val_s_t_loss', float(v_s_t_cur_loss),
                           epoch)

            lr_schedulerI.step()
            lr_schedulerT.step()

            # Checkpoint only when the mean validation loss improves.
            total_val_loss = (float(v_w_cur_loss) + float(v_s_cur_loss) +
                              float(v_w_t_cur_loss) +
                              float(v_s_t_cur_loss)) / 4.0
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                self.save_model(image_encoder, text_encoder, optimizerI,
                                optimizerT, lr_schedulerI, lr_schedulerT,
                                epoch)
    finally:
        # Flush pending events and release the log file even when training
        # is interrupted by an exception.
        tbw.close()