def evaluate(self):
    '''Compute validation word- and sentence-level losses.'''
    self.log.add("| Calculating validation error")
    self.image_encoder.eval()
    self.text_encoder.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(self.data_loader_val, 0):
        real_imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        words_features, sent_code = self.image_encoder(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = self.text_encoder.init_hidden(self.batch_size)
        words_emb, sent_emb = self.text_encoder(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb,
                                            self.labels, cap_lens,
                                            class_ids, self.batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, self.labels,
                                     class_ids, self.batch_size)
        s_total_loss += (s_loss0 + s_loss1).data

        # Validate on at most 51 batches (steps 0..50) to keep evaluation fast.
        if step == 50:
            break

    # Average over the number of batches actually processed (step is zero-based;
    # dividing by `step` would also crash if only one validation batch exists).
    s_cur_loss = s_total_loss.item() / (step + 1)
    w_cur_loss = w_total_loss.item() / (step + 1)
    return s_cur_loss, w_cur_loss
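# A minimal usage sketch for evaluate() (hypothetical; assumes a `trainer`
# instance of this class, and `val_s_loss` / `val_w_loss` are illustrative
# names). Wrapping the call in `torch.no_grad()` avoids building the autograd
# graph during validation:
#
#     with torch.no_grad():
#         val_s_loss, val_w_loss = trainer.evaluate()
#     trainer.log.add('| Validation | S_loss {:2.4f} | W_loss {:2.4f}'
#                     .format(val_s_loss, val_w_loss))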
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens, gen_iterations,
                     name='current'):
    '''Save generator results'''
    self.log.add("| Saving {} result images ... ".format(name))
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)

    # Generator attention: pair each attention map with the image of the
    # next (higher) resolution, using the previous stage as the low-res view.
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = build_super_images(img, captions, self.ixtoword,
                                        attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name,
                                              gen_iterations, i)
            im.save(fullpath)

    # Word-region attention between the highest-resolution fake image
    # and the caption words, as seen by the image encoder.
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(), None,
                                cap_lens, None, self.batch_size)
    img_set, _ = build_super_images(fake_imgs[i].detach().cpu(), captions,
                                    self.ixtoword, att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
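# Sketch of how the inputs to save_img_results() are typically prepared in
# AttnGAN-style pipelines (an assumption, not code defined in this file;
# `nz` is the noise dimensionality and 0 is assumed to be the padding id):
#
#     noise = torch.randn(self.batch_size, nz, device=captions.device)
#     mask = (captions == 0)           # True at padded word positions
#     num_words = words_embs.size(2)
#     if mask.size(1) > num_words:     # keep mask aligned with word embeddings
#         mask = mask[:, :num_words]
#     self.save_img_results(netG, noise, sent_emb, words_embs, mask,
#                           image_encoder, captions, cap_lens,
#                           gen_iterations, name='average')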
def train_step(self, epoch):
    '''Run one training epoch over the image and text encoders.'''
    self.image_encoder.train()
    self.text_encoder.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    s_loss_step = 0.
    w_total_loss0 = 0
    w_total_loss1 = 0
    w_loss_step = 0.
    batch_num = len(self.data_loader)
    count = (epoch + 1) * batch_num
    start_time = time.time()
    for step, data in enumerate(self.data_loader, 0):
        # self.save_raw_data(step, data)  # DEBUG
        self.text_encoder.zero_grad()
        self.image_encoder.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = self.image_encoder(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = self.text_encoder.init_hidden(self.batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = self.text_encoder(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 self.labels, cap_lens,
                                                 class_ids, self.batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
        w_loss_step += w_loss0.data + w_loss1.data

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, self.labels,
                                     class_ids, self.batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        s_loss_step += s_loss0.data + s_loss1.data

        loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient
        # problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(self.text_encoder.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        self.optimizer.step()

        if step % self.update_interval == 0:
            count = epoch * batch_num + step
            s_cur_loss0 = s_total_loss0.item() / self.update_interval
            s_cur_loss1 = s_total_loss1.item() / self.update_interval
            w_cur_loss0 = w_total_loss0.item() / self.update_interval
            w_cur_loss1 = w_total_loss1.item() / self.update_interval
            elapsed = time.time() - start_time
            self.log.add('| Epoch {:3d} | bt {:3d}/{:3d} | ms/bt {:5.2f} '
                         '| S_loss {:2.4f} {:2.4f} | W_loss {:2.4f} {:2.4f} '
                         '| Time {:5.2f}s'
                         .format(epoch, step, batch_num,
                                 elapsed * 1000. / self.update_interval,
                                 s_cur_loss0, s_cur_loss1,
                                 w_cur_loss0, w_cur_loss1, elapsed))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # Attention maps
            img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                            self.ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps_%d_%d.png' % (self.image_dir,
                                                            epoch, step)
                im.save(fullpath)

    s_loss_step /= batch_num
    w_loss_step /= batch_num
    return count, s_loss_step, w_loss_step
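# A hypothetical outer loop tying train_step() and evaluate() together
# (illustrative only; `self.max_epoch` and `self.model_dir` are assumptions,
# not attributes defined in the methods above):
#
#     for epoch in range(self.max_epoch):
#         count, s_loss, w_loss = self.train_step(epoch)
#         with torch.no_grad():
#             val_s, val_w = self.evaluate()
#         self.log.add('| Epoch {:3d} done | train S {:2.4f} W {:2.4f} '
#                      '| val S {:2.4f} W {:2.4f}'
#                      .format(epoch, s_loss, w_loss, val_s, val_w))
#         torch.save(self.text_encoder.state_dict(),
#                    '%s/text_encoder%d.pth' % (self.model_dir, epoch))
#         torch.save(self.image_encoder.state_dict(),
#                    '%s/image_encoder%d.pth' % (self.model_dir, epoch))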