def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    # Ground-truth match indices: the i-th image in a batch matches the i-th caption.
    labels = torch.arange(batch_size, dtype=torch.long).to(
        next(cnn_model.parameters()).device)
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        if isinstance(rnn_model, torch.nn.parallel.DistributedDataParallel):
            # DistributedDataParallel exposes the wrapped network as .module
            hidden = rnn_model.module.init_hidden(batch_size)
        else:
            hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

    # step is zero-based, so step + 1 batches were processed.
    s_cur_loss = s_total_loss.item() / (step + 1)
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    # Ground-truth match indices: the i-th image in a batch matches the i-th caption.
    labels = torch.arange(batch_size, dtype=torch.long).to(
        next(cnn_model.parameters()).device)
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        hidden = rnn_model.init_hidden()
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        # Only evaluate on the first 51 batches.
        if step == 50:
            break

    s_cur_loss = s_total_loss.item() / (step + 1)
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
def evaluate(self, dataloader, image_encoder, text_encoder):
    image_encoder.eval()
    text_encoder.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, class_ids, input_mask = prepare_data(
            data, self.device)

        words_features, sent_code = image_encoder(real_imgs[-1])
        batch_size = words_features.size(0)

        words_emb, sent_emb = self.text_enc_forward(
            text_encoder, captions, input_mask)

        labels = Variable(torch.LongTensor(range(batch_size))).to(self.device)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

    # step is zero-based, so step + 1 batches were processed.
    s_cur_loss = s_total_loss.item() / (step + 1)
    w_cur_loss = w_total_loss.item() / (step + 1)

    return s_cur_loss, w_cur_loss
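# The evaluate/train functions in this file call `words_loss` and `sent_loss`,
# which are defined elsewhere in the repo. Below is a minimal, self-contained
# sketch of the sentence-level matching loss only, to show what `sent_loss` is
# expected to compute; the smoothing factor `gamma3` and the class-id masking
# follow the AttnGAN-style DAMSM formulation and are assumptions here, not the
# repo's actual implementation.
import torch
import torch.nn.functional as F

def sent_loss_sketch(sent_code, sent_emb, labels, class_ids=None,
                     gamma3=10.0, eps=1e-8):
    # sent_code: batch_size x nef image embeddings; sent_emb: batch_size x nef text embeddings.
    batch_size = sent_code.size(0)
    # Scaled cosine-similarity matrix between every image and every caption.
    img = sent_code / sent_code.norm(dim=1, keepdim=True).clamp(min=eps)
    txt = sent_emb / sent_emb.norm(dim=1, keepdim=True).clamp(min=eps)
    scores = gamma3 * img @ txt.t()  # batch_size x batch_size
    if class_ids is not None:
        # Do not penalise "mismatched" pairs that actually share a class id.
        ids = torch.as_tensor(class_ids).view(-1, 1)
        same_class = (ids == ids.t()) & ~torch.eye(batch_size, dtype=torch.bool)
        scores = scores.masked_fill(same_class.to(scores.device), float('-inf'))
    # Cross-entropy in both directions: image -> caption and caption -> image.
    loss0 = F.cross_entropy(scores, labels)
    loss1 = F.cross_entropy(scores.t(), labels)
    return loss0, loss1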
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, epoch, step, name='current'):
    # Save images
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img,
                               batch_size=self.batch_size,
                               max_word_num=18)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d_%d.png'\
                % (self.image_dir, name, epoch, step, i)
            im.save(fullpath)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze, None,
                           batch_size=self.batch_size,
                           max_word_num=18)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d_%d.png'\
            % (self.image_dir, name, epoch, step)
        im.save(fullpath)
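# A hedged usage sketch for `save_img_results` above, as it might be called from
# inside a GAN training loop; `fixed_noise`, `gen_iterations`, and the snapshot
# interval of 1000 generator steps are placeholder names/values, not taken from
# the original code.
if gen_iterations % 1000 == 0:
    self.save_img_results(netG, fixed_noise, sent_emb, words_embs, mask,
                          image_encoder, captions, epoch, gen_iterations,
                          name='current')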
def test(dataloader, cnn_model, rnn_model, batch_size, labels, ixtoword,
         image_dir, input_channels):
    cnn_model.eval()
    rnn_model.eval()
    # s_total_loss0 = 0
    # s_total_loss1 = 0
    # w_total_loss0 = 0
    # w_total_loss1 = 0
    # count = (epoch + 1) * len(dataloader)
    # start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        # rnn_model.zero_grad()
        # cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_len, att_sze = words_features.size(1), \
            words_features.size(2), words_features.size(3)
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        hidden = rnn_model.init_hidden()
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)

        img_set, _ = \
            build_super_images(imgs[-1].cpu(), captions, ixtoword,
                               attn_maps, att_len, att_sze, input_channels)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/attention_maps%d.png' % (image_dir, step)
            # fullpath = '%s/attention_maps%d.png' % (image_dir, count)
            im.save(fullpath)
    # return count
    return step
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    print("start training...")
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        if isinstance(rnn_model, torch.nn.parallel.DistributedDataParallel):
            # DistributedDataParallel exposes the wrapped network as .module
            hidden = rnn_model.module.init_hidden(batch_size)
        else:
            hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        # Backpropagate before clipping gradients and stepping the optimizer.
        loss.backward()
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions, ixtoword,
                                   attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
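# A hedged sketch of the outer DAMSM pretraining loop that would drive the
# module-level `train`/`evaluate` functions above; the loaders, encoders,
# optimizer, directories, `device`, and `max_epoch` are placeholder names,
# not defined in this file.
labels = torch.arange(batch_size, dtype=torch.long).to(device)  # match indices 0..batch_size-1
for epoch in range(max_epoch):
    count = train(train_loader, cnn_model, rnn_model, batch_size, labels,
                  optimizer, epoch, ixtoword, image_dir)
    s_loss, w_loss = evaluate(val_loader, cnn_model, rnn_model, batch_size)
    print('| end epoch {:3d} | valid s_loss {:5.2f} | valid w_loss {:5.2f} |'
          .format(epoch, s_loss, w_loss))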
def train(self, dataloader, image_encoder, text_encoder, optimizer, epoch,
          ixtoword, image_dir, batch_size):
    image_encoder.train()
    text_encoder.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    s_epoch_loss = 0
    w_epoch_loss = 0
    num_batches = len(dataloader)
    for step, data in enumerate(dataloader, 0):
        print('step', step)
        optimizer.zero_grad()

        imgs, captions, class_ids, input_mask = prepare_data(
            data, self.device)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = image_encoder(imgs[-1])
        # --> batch_size x nef x 17*17
        batch_size, nef, att_size, _ = words_features.shape
        # words_features = words_features.view(batch_size, nef, -1)

        words_emb, sent_emb = self.text_enc_forward(
            text_encoder, captions, input_mask)

        labels = Variable(torch.LongTensor(range(batch_size))).to(self.device)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 class_ids, batch_size)
        w_total_loss0 += w_loss0.data.item()
        w_total_loss1 += w_loss1.data.item()
        loss = w_loss0 + w_loss1
        w_epoch_loss += w_loss0.item() + w_loss1.item()

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data.item()
        s_total_loss1 += s_loss1.data.item()
        s_epoch_loss += s_loss0.item() + s_loss1.item()

        loss.backward(retain_graph=True)
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        if self.opts.TEXT.ENCODER != 'bert':
            torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                           self.opts.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step != 0 and step % self.update_interval == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / self.update_interval
            s_cur_loss1 = s_total_loss1 / self.update_interval

            w_cur_loss0 = w_total_loss0 / self.update_interval
            w_cur_loss1 = w_total_loss1 / self.update_interval

            elapsed = time.time() - start_time
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                's_loss {:5.2f} {:5.2f} | '
                'w_loss {:5.2f} {:5.2f}'.format(
                    epoch, step, len(dataloader),
                    elapsed * 1000. / self.update_interval,
                    s_cur_loss0, s_cur_loss1,
                    w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

        if step == num_batches - 1:
            # attention Maps
            img_set, _ = build_super_images(imgs[-1].cpu(), captions, ixtoword,
                                            attn_maps, att_size, None,
                                            batch_size, max_word_num=18)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)

    s_epoch_loss /= len(dataloader)
    w_epoch_loss /= len(dataloader)

    return count, s_epoch_loss, w_epoch_loss
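# A hedged sketch of how the class-based `train` and `evaluate` methods above
# might be wired together per epoch; the trainer instance, loaders, and
# `num_epochs` are placeholder names, not part of the original code.
for epoch in range(num_epochs):
    count, s_train_loss, w_train_loss = trainer.train(
        train_loader, image_encoder, text_encoder, optimizer,
        epoch, ixtoword, image_dir, batch_size)
    with torch.no_grad():
        s_val_loss, w_val_loss = trainer.evaluate(
            val_loader, image_encoder, text_encoder)
    print('epoch %d | train s/w %.2f / %.2f | val s/w %.2f / %.2f'
          % (epoch, s_train_loss, w_train_loss, s_val_loss, w_val_loss))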