def save_img_results(self, fake_imgs, attn_maps, bt_attn_maps, captions,
                     cap_lens, gen_iterations):
    font_max = [50, 50]
    font_size = [30, 50]
    batch_size = fake_imgs[0].size(0)
    # Save images
    for i in range(len(attn_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None

        # use a loop-local name so the attn_maps list is not overwritten
        attn_map = attn_maps[i]
        att_sze = attn_map.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_map, att_sze, lr_imgs=lr_img,
                               font_max=font_max[i], font_size=font_size[i],
                               batch_size=batch_size)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%d_%d.png' \
                % (self.snapshot_dir, gen_iterations, i)
            im.save(fullpath)

        bt_attn_map = bt_attn_maps[i]
        att_sze = bt_attn_map.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               bt_attn_map, att_sze, lr_imgs=lr_img,
                               font_max=font_max[i], font_size=font_size[i],
                               batch_size=batch_size)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/bt_G_%d_%d.png' \
                % (self.snapshot_dir, gen_iterations, i)
            im.save(fullpath)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens,
                     gen_iterations, name='current'):
    # Save images
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            myDriveAttnGanImage = '/content/drive/My Drive/cubImageGAN'
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            fullpathDrive = '%s/G_%s_%d_%d.png' \
                % (myDriveAttnGanImage, name, gen_iterations, i)
            im.save(fullpath)
            im.save(fullpathDrive)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, cap_lens,
                                None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        myDriveAttnGanImage = '/content/drive/My Drive/cubImageGAN'
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        fullpathDrive = '%s/D_%s_%d.png' \
            % (myDriveAttnGanImage, name, gen_iterations)
        im.save(fullpath)
        im.save(fullpathDrive)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens, gen_iterations,
                     transf_matrices_inv, label_one_hot, name='current'):
    # Save images
    # fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    inputs = (noise, sent_emb, words_embs, mask,
              transf_matrices_inv, label_one_hot)
    fake_imgs, attention_maps, _, _ = nn.parallel.data_parallel(
        netG, inputs, self.gpus)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, cap_lens,
                                None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens,
                     gen_iterations, name='current'):
    # Save images
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        # print(img.shape, lr_img.shape, attn_maps.shape)  # debug
        img_set, _ = \
            build_super_images(img[:, :3], captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img[:, :3])
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, cap_lens,
                                None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i][:, :3].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
        self.writer.add_image(tag="image_attn",
                              img_tensor=transforms.ToTensor()(im),
                              global_step=gen_iterations)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, gen_iterations, name='current'):
    # Save images
    if cfg.CUDA:
        caption = Variable(torch.tensor([])).cuda()
    else:
        caption = Variable(torch.tensor([]))
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, caption, {},
                               attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs, None, 0,
                                None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           caption, {}, att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens, gen_iterations,
                     transf_matrices_inv, label_one_hot, local_noise,
                     transf_matrices, max_objects, subset_idx, name='current'):
    # Save images
    inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices,
              transf_matrices_inv, label_one_hot, max_objects)
    fake_imgs, attention_maps, _, _ = netG(*inputs)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = build_super_images(img, captions, self.ixtoword,
                                        attn_maps, att_sze, lr_imgs=lr_img,
                                        batch_size=self.batch_size[0])
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name,
                                              gen_iterations, i)
            im.save(fullpath)

    # for i in range(len(netsD)):
    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(),
                                    None, cap_lens, None,
                                    self.batch_size[subset_idx])
    else:
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(),
                                    None, cap_lens, None,
                                    self.batch_size[0])
    img_set, _ = build_super_images(fake_imgs[i].detach().cpu(),
                                    captions, self.ixtoword,
                                    att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def evaluate(dataloader, cnn_model, rnn_model, batch_size, writer, count,
             ixtoword, labels, image_dir):
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        nef, att_sze = words_features.size(1), words_features.size(2)

        # hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += (s_loss0 + s_loss1).item()

        if step == 50:
            break

    s_cur_loss = s_total_loss / step
    w_cur_loss = w_total_loss / step

    writer.add_scalars(main_tag="eval_loss",
                       tag_scalar_dict={
                           's_loss': s_cur_loss,
                           'w_loss': w_cur_loss
                       },
                       global_step=count)

    # save an image of the attention maps
    img_set, _ = \
        build_super_images(real_imgs[-1][:, :3].cpu(), captions,
                           ixtoword, attn, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/attention_maps_eval_%d.png' % (image_dir, count)
        im.save(fullpath)
        writer.add_image(tag="image_DAMSM_eval",
                         img_tensor=transforms.ToTensor()(im),
                         global_step=count)

    return s_cur_loss, w_cur_loss
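# Hedged sketch (not from the original repository): one way the evaluate()
# helper above might be driven at the end of an epoch. The names val_loader,
# image_encoder, text_encoder, labels, ixtoword and image_dir are assumptions,
# as is the log directory passed to SummaryWriter.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='output/damsm_eval')  # assumed log directory


def run_validation(epoch, val_loader, image_encoder, text_encoder,
                   batch_size, labels, ixtoword, image_dir):
    # evaluate() switches both models to eval mode; gradients are not needed,
    # so the whole pass can run under torch.no_grad().
    with torch.no_grad():
        s_loss, w_loss = evaluate(val_loader, image_encoder, text_encoder,
                                  batch_size, writer, epoch, ixtoword,
                                  labels, image_dir)
    print('epoch %d | eval s_loss %.4f | eval w_loss %.4f'
          % (epoch, s_loss, w_loss))
    return s_loss + w_loss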
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens,
                     gen_iterations, name='current'):
    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_size = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_size, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png' \
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens, gen_iterations,
                     cnn_code, region_features, real_imgs, netDCM,
                     real_features, name='current'):
    # Save images
    fake_imgs, attention_maps, _, _, h_code, c_code = netG(
        noise, sent_emb, words_embs, mask, cnn_code, region_features)
    fake_img = netDCM(h_code, real_features, sent_emb, words_embs,
                      mask, c_code)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, cap_lens,
                                None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)

    img_set, _ = \
        build_super_images(fake_img.detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/C_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
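# Hedged sketch (an assumption, not the original project's driver): a minimal
# DAMSM-style pretraining loop around a train() helper with the signature used
# above. The encoder objects, learning rate, and checkpoint paths below are
# placeholders.
import torch


def pretrain_damsm(train_loader, image_encoder, text_encoder, ixtoword,
                   image_dir, model_dir, batch_size, max_epoch=600, lr=2e-4):
    params = list(text_encoder.parameters()) + \
        [p for p in image_encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.999))
    # sent_loss/words_loss treat the i-th caption as the match for the
    # i-th image, so the target labels are simply 0..batch_size-1.
    labels = torch.arange(batch_size)
    if torch.cuda.is_available():
        labels = labels.cuda()

    count = 0
    for epoch in range(max_epoch):
        count = train(train_loader, image_encoder, text_encoder, batch_size,
                      labels, optimizer, epoch, ixtoword, image_dir)
        torch.save(text_encoder.state_dict(),
                   '%s/text_encoder%d.pth' % (model_dir, epoch))
        torch.save(image_encoder.state_dict(),
                   '%s/image_encoder%d.pth' % (model_dir, epoch))
    return count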
def train(dataloader, cnn_model, rnn_model, d_model, batch_size, labels,
          generator_optimizer, discriminator_optimizer, epoch, ixtoword,
          image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    d_total_loss = 0
    g_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        target_classes = torch.LongTensor(class_ids)
        if cfg.CUDA:
            target_classes = target_classes.cuda()

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        loss += s_loss0 + s_loss1

        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)
        g_loss = (F.binary_cross_entropy_with_logits(is_fake_0,
                                                     torch.zeros_like(is_fake_0))
                  + F.binary_cross_entropy_with_logits(is_fake_1,
                                                       torch.zeros_like(is_fake_1))
                  + F.cross_entropy(pred_class_0, target_classes)
                  + F.cross_entropy(pred_class_1, target_classes))
        loss += g_loss * cfg.TRAIN.SMOOTH.SUPERVISED_COEF

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)

        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        g_total_loss += g_loss.item()
        generator_optimizer.step()

        d_model.zero_grad()
        _, sent_code = cnn_model(imgs[-1])
        hidden = rnn_model.init_hidden(batch_size)
        _, sent_emb = rnn_model(captions, cap_lens, hidden)
        is_fake_0, pred_class_0 = d_model(sent_code)
        is_fake_1, pred_class_1 = d_model(sent_emb)
        d_loss = (F.binary_cross_entropy_with_logits(is_fake_0,
                                                     torch.zeros_like(is_fake_0))
                  + F.binary_cross_entropy_with_logits(is_fake_1,
                                                       torch.ones_like(is_fake_1))
                  + F.cross_entropy(pred_class_0, target_classes)
                  + F.cross_entropy(pred_class_1, target_classes))
        loss = d_loss
        loss.backward()
        discriminator_optimizer.step()
        d_total_loss += d_loss.item()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            d_cur_loss = d_total_loss / UPDATE_INTERVAL
            g_cur_loss = g_total_loss / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | d_loss {:5.2f} | g_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1,
                          d_cur_loss, g_cur_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            d_total_loss = 0
            g_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
def train(dataloader, cnn_model, trx_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    trx_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        print('step:{:6d}|{:3d}'.format(step, len(dataloader)), end='\r')
        trx_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys, _, _, _, _ = prepare_data(
            data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        # print(words_features.shape, sent_code.shape)
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)
        # print('nef:{0},att_sze:{1}'.format(nef, att_sze))

        # hidden = trx_model.init_hidden(batch_size)
        # print('captions:', captions, captions.size())
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        # words_emb, sent_emb = trx_model(captions, cap_lens, hidden)
        words_emb, sent_emb = trx_model(captions)
        # print('words_emb:', words_emb.size(), ', sent_emb:', sent_emb.size())

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        # print(w_loss0.data)
        # print(w_loss1.data)
        # print(attn_maps[0].shape)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
        # print(loss)

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        # print(s_total_loss0[0], s_total_loss1[0])
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(trx_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if (step % UPDATE_INTERVAL == 0 or step == (len(dataloader) - 1)) \
                and step > 0:
            count = epoch * len(dataloader) + step
            # print(count)

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))

            tbw.add_scalar('Birds_Train/train_w_loss0',
                           float(w_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss0',
                           float(s_cur_loss0.item()), epoch)
            tbw.add_scalar('Birds_Train/train_w_loss1',
                           float(w_cur_loss1.item()), epoch)
            tbw.add_scalar('Birds_Train/train_s_loss1',
                           float(s_cur_loss1.item()), epoch)

            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            # print(imgs[-1].cpu().shape, captions.shape, len(attn_maps),
            #       attn_maps[-1].shape, att_sze)
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '{0}/attention_maps_e{1}_s{2}.png'.format(
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
def train(dataloader, cnn_model, nlp_model, text_encoder_type, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    cnn_model.train()
    nlp_model.train()

    text_encoder_type = text_encoder_type.casefold()
    assert text_encoder_type in ('rnn', 'transformer',)

    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        nlp_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # print(words_features.shape, sent_code.shape)
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # Forward prop:
        #   inputs:
        #     captions: torch.LongTensor of ids of size batch x n_steps
        #   outputs:
        #     words_emb: batch_size x nef x seq_len
        #     sent_emb: batch_size x nef
        if text_encoder_type == 'rnn':
            hidden = nlp_model.init_hidden(batch_size)
            words_emb, sent_emb = nlp_model(captions, cap_lens, hidden)
        elif text_encoder_type == 'transformer':
            words_emb = nlp_model(captions)[0].transpose(1, 2).contiguous()
            sent_emb = words_emb[:, :, -1].contiguous()
            # sent_emb = sent_emb.view(batch_size, -1)
        # print(words_emb.shape, sent_emb.shape)

        # Compute loss:
        # NOTE: the ideal loss for a Transformer may be different than
        # that for a bi-directional LSTM.
        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        # Backprop:
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        if text_encoder_type == 'rnn':
            torch.nn.utils.clip_grad_norm(nlp_model.parameters(),
                                          cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # print(s_total_loss0, s_total_loss1)
            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            # print(w_total_loss0, w_total_loss1)
            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.5f} {:5.5f} | '
                  'w_loss {:5.5f} {:5.5f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # Attention maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
def save_img_results(self, netG, noise, sent_emb, words_embs,
                     glove_words_embs, clabels_feat, mask, hmaps,
                     rois, fm_rois, num_rois, bt_masks, fm_bt_masks,
                     image_encoder, captions, cap_lens, gen_iterations,
                     name='current'):
    # Save images
    glb_max_num_roi = int(torch.max(num_rois))
    fake_imgs, _, attention_maps, bt_attention_maps, _, _ = netG(
        noise, sent_emb, words_embs, glove_words_embs, clabels_feat,
        mask, hmaps, rois, fm_rois, num_rois, bt_masks, fm_bt_masks,
        glb_max_num_roi)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None

        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

        bt_attn_maps = bt_attention_maps[i]
        att_sze = bt_attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               bt_attn_maps, att_sze, lr_imgs=lr_img)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/bt_G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps, _ = words_loss(region_features.detach(),
                                   words_embs.detach(),
                                   None, cap_lens,
                                   None, self.batch_size)
    img_set, _ = \
        build_super_images(fake_imgs[i].detach().cpu(),
                           captions, self.ixtoword,
                           att_maps, att_sze)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def save_img_results(self, netG, noise, imgs, bbox_maps_fwd, bbox_maps_bwd,
                     bbox_fmaps, hmaps, rois, num_rois, gen_iterations,
                     name='current'):
    # Save images
    font_max = 20
    font_size = 12

    imgs = imgs.cpu()
    fake_hmaps = netG(noise, bbox_maps_fwd, bbox_maps_bwd, bbox_fmaps)
    fake_hmaps = fake_hmaps.squeeze().detach().cpu()
    hmaps = hmaps.squeeze().cpu()

    # prepare captions
    batch_size = fake_hmaps.size(0)
    captions = Variable(torch.zeros(batch_size, cfg.ROI.BOXES_NUM)).cuda()
    for batch_index in range(self.batch_size):
        for roi_index in range(num_rois[batch_index]):
            rela_cat_id = int(rois[batch_index, roi_index, 4])
            captions[batch_index, roi_index] = self.cats_dict[rela_cat_id][0]

    att_sze = fake_hmaps.size(2)
    img_set, _ = build_super_images(imgs, captions, self.ixtoword,
                                    fake_hmaps, att_sze, lr_imgs=None,
                                    font_max=font_max, font_size=font_size,
                                    max_word_num=cfg.ROI.BOXES_NUM)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/G_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)

    img_set, _ = build_super_images(imgs, captions, self.ixtoword,
                                    hmaps, att_sze, lr_imgs=None,
                                    font_max=font_max, font_size=font_size,
                                    max_word_num=cfg.ROI.BOXES_NUM)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)

    #
    img_set, _ = build_super_images2(imgs, captions, self.ixtoword,
                                     fake_hmaps, att_sze, lr_imgs=None,
                                     font_max=font_max, font_size=font_size,
                                     max_word_num=cfg.ROI.BOXES_NUM)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/G2_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)

    img_set, _ = build_super_images2(imgs, captions, self.ixtoword,
                                     hmaps, att_sze, lr_imgs=None,
                                     font_max=font_max, font_size=font_size,
                                     max_word_num=cfg.ROI.BOXES_NUM)
    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D2_%s_%d.png' % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, exp):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        shape, cap, cap_len, cls_id, key = data

        # sort by caption length
        sorted_cap_lens, sorted_cap_indices = torch.sort(cap_len, 0, True)
        shapes = shape[sorted_cap_indices].squeeze()
        captions = cap[sorted_cap_indices].squeeze()
        cap_len = cap_len[sorted_cap_indices].squeeze()
        class_ids = cls_id[sorted_cap_indices].squeeze().numpy()

        if torch.cuda.is_available():
            shapes = shapes.cuda()
            captions = captions.cuda()

        # model
        words_features, sent_code = cnn_model(shapes)
        nef, att_sze = words_features.size(1), words_features.size(2)
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, sorted_cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, sorted_cap_lens,
                                                 class_ids, batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1 / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0 / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1 / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            exp.log_metric('s_cur_loss0', s_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('s_cur_loss1', s_cur_loss1, step=step, epoch=epoch)
            exp.log_metric('w_cur_loss0', w_cur_loss0, step=step, epoch=epoch)
            exp.log_metric('w_cur_loss1', w_cur_loss1, step=step, epoch=epoch)
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

        if step == 1:
            fullpath = '%s/attention_maps%d' % (image_dir, step)
            build_super_images(shapes.cpu().detach().numpy(), captions,
                               cap_len, ixtoword, attn_maps, att_sze,
                               exp, fullpath, epoch)
    return count
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    train_function_start_time = time.time()
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    # print("len(dataloader) : ", len(dataloader))
    # print("count = ", (epoch + 1) * len(dataloader))
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            print("s_total_loss0.item()/UPDATE_INTERVAL : ",
                  s_total_loss0.item() / UPDATE_INTERVAL)
            print("s_total_loss1.item()/UPDATE_INTERVAL : ",
                  s_total_loss1.item() / UPDATE_INTERVAL)
            print("w_total_loss0.item()/UPDATE_INTERVAL : ",
                  w_total_loss0.item() / UPDATE_INTERVAL)
            print("w_total_loss1.item()/UPDATE_INTERVAL : ",
                  w_total_loss1.item() / UPDATE_INTERVAL)

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # attention maps
            # Save the image only every 8 epochs, and also copy it to Drive.
            if epoch % 8 == 0:
                print("building images")
                img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                                ixtoword, attn_maps, att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                    im.save(fullpath)
                    mydriveimg = '/content/drive/My Drive/cubImage'
                    drivepath = '%s/attention_maps%d.png' % (mydriveimg, epoch)
                    im.save(drivepath)

    print("train_function_time : ", time.time() - train_function_start_time)
    return count
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        """
        Sets gradients of all model parameters to zero.
        Every time a variable is back propagated through, the gradient will be
        accumulated instead of being replaced. (This makes it easier for RNNs,
        because each module will be back propagated through several times.)
        """
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        """Don't understand completely??"""
        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                     class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                            ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps_%d_epoch_%d_step.png' % (
                    image_dir, epoch, step)
                im.save(fullpath)
    return count
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir, writer, logger, update_interval):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = epoch * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        # hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.item()
        w_total_loss1 += w_loss1.item()
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.item()
        s_total_loss1 += s_loss1.item()
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        global_step = epoch * len(dataloader) + step
        writer.add_scalars(main_tag="batch_loss",
                           tag_scalar_dict={
                               "loss": loss.cpu().item(),
                               "w_loss0": w_loss0.cpu().item(),
                               "w_loss1": w_loss1.cpu().item(),
                               "s_loss0": s_loss0.cpu().item(),
                               "s_loss1": s_loss1.cpu().item()
                           },
                           global_step=global_step)

        if step % update_interval == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0 / update_interval
            s_cur_loss1 = s_total_loss1 / update_interval

            w_cur_loss0 = w_total_loss0 / update_interval
            w_cur_loss1 = w_total_loss1 / update_interval

            elapsed = time.time() - start_time
            logger.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                's_loss {:6.4f} {:6.4f} | '
                'w_loss {:6.4f} {:6.4f}'.format(
                    epoch, step, len(dataloader),
                    elapsed * 1000. / update_interval,
                    s_cur_loss0, s_cur_loss1,
                    w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # attention Maps
            if global_step % (10 * update_interval) == 0:
                img_set, _ = \
                    build_super_images(imgs[-1][:, :3].cpu(), captions,
                                       ixtoword, attn_maps, att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps%d.png' % (image_dir, count)
                    im.save(fullpath)
                    writer.add_image(tag="image_DAMSM",
                                     img_tensor=transforms.ToTensor()(im),
                                     global_step=count)
    return count
def save_img_results(self, real_img, netG, noise, sent_emb, words_embs, mask,
                     image_encoder, captions, cap_lens, gen_iterations,
                     transf_matrices_inv, label_one_hot, name='current',
                     num_visualize=8):
    qa_nums = (cap_lens > 0).sum(1)
    real_captions = captions
    captions, _ = make_fake_captions(qa_nums)  # fake captions

    # Save images
    # fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
    inputs = (noise, sent_emb, words_embs, mask,
              transf_matrices_inv, label_one_hot)
    fake_imgs, attention_maps, _, _ = nn.parallel.data_parallel(
        netG, inputs, self.gpus)
    for i in range(len(attention_maps)):
        if len(fake_imgs) > 1:
            img = fake_imgs[i + 1].detach().cpu()
            lr_img = fake_imgs[i].detach().cpu()
        else:
            img = fake_imgs[0].detach().cpu()
            lr_img = None
        attn_maps = attention_maps[i]
        att_sze = attn_maps.size(2)
        img_set, _ = \
            build_super_images(img, captions, self.ixtoword,
                               attn_maps, att_sze, lr_imgs=lr_img,
                               nvis=num_visualize)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/G_%s_%d_%d.png' \
                % (self.image_dir, name, gen_iterations, i)
            im.save(fullpath)

    for i in range(cfg.TREE.BRANCH_NUM):
        save_pure_img_results(real_img[i].detach().cpu(),
                              fake_imgs[i].detach().cpu(),
                              gen_iterations, self.image_dir,
                              token='level%d' % i)

    i = -1
    img = fake_imgs[i].detach()
    region_features, _ = image_encoder(img)
    att_sze = region_features.size(2)
    _, _, att_maps = words_loss(region_features.detach(),
                                words_embs.detach(),
                                None, qa_nums,
                                None, self.batch_size)
    img_set, _ = build_super_images(fake_imgs[i].detach().cpu(),
                                    captions, self.ixtoword,
                                    att_maps, att_sze,
                                    nvis=num_visualize)

    # FIXME: `render_attn_to_html` currently supports only the last level;
    # multi-level rendering still needs to be implemented.
    html_doc = render_attn_to_html([
        real_img[i].detach().cpu(),
        fake_imgs[i].detach().cpu(),
    ], real_captions, self.ixtoword, att_maps, att_sze, None,
        info=['Real Images', 'Fake Images'])
    with open('%s/damsm_attn_%d.html' % (self.image_dir, gen_iterations),
              'w') as html_f:
        html_f.write(str(html_doc))

    if img_set is not None:
        im = Image.fromarray(img_set)
        fullpath = '%s/D_%s_%d.png' \
            % (self.image_dir, name, gen_iterations)
        im.save(fullpath)
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    global global_step
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        global_step += 1
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        # imgs: b x 3 x nbasesize x nbasesize
        imgs, captions, cap_lens, \
            class_ids, _, _, _, keys, _ = prepare_data(data)
        class_ids = None  # Oh. is this ok? FIXME

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)

        # num_caps = (cap_lens > 0).sum(1)
        per_qa_embs, avg_qa_embs, num_caps = Level1RNNEncodeMagic(
            captions, cap_lens, rnn_model)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, per_qa_embs,
                                                 labels, num_caps, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, avg_qa_embs, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # attention Maps
            def make_fake_captions(num_caps):
                caps = torch.zeros(batch_size, cfg.TEXT.MAX_QA_NUM,
                                   dtype=torch.int64)
                ref = torch.arange(0, cfg.TEXT.MAX_QA_NUM).view(1, -1) \
                    .repeat(batch_size, 1).cuda()
                targ = num_caps.view(-1, 1).repeat(1, cfg.TEXT.MAX_QA_NUM)
                caps[ref < targ] = 1
                return caps, {1: 'DUMMY'}

            _captions, _ixtoword = make_fake_captions(num_caps)
            html_doc = render_attn_to_html(imgs[-1].cpu(), captions, ixtoword,
                                           attn_maps, att_sze)
            with open('%s/attn_step%d.html' % (image_dir, global_step),
                      'w') as html_f:
                html_f.write(str(html_doc))

            img_set, _ = \
                build_super_images(imgs[-1].cpu(), _captions, _ixtoword,
                                   attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, global_step)
                im.save(fullpath)
    return count
def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer,
          epoch, ixtoword, image_dir):
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    t_total_loss = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()

    if cfg.LOCAL_PRETRAINED:
        tokenizer = tokenization.FullTokenizer(
            vocab_file=cfg.BERT_ENCODER.VOCAB, do_lower_case=True)
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    debug_flag = False
    # debug_flag = True

    for step, data in enumerate(dataloader, 0):
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        if debug_flag:
            with open('./debug0.pkl', 'wb') as f:
                pickle.dump({'data': data, 'cnn_model': cnn_model,
                             'rnn_model': rnn_model, 'labels': labels}, f)

        # imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
        imgs, captions, cap_lens, class_ids, keys = prepare_data_bert(
            data, tokenizer)
        # imgs, captions, cap_lens, class_ids, keys, \
        #     input_ids, segment_ids, input_mask = prepare_data_bert(data, tokenizer)

        # sent_code: batch_size x nef
        # words_features, sent_code, word_logits = cnn_model(imgs[-1], captions)
        words_features, sent_code, word_logits = cnn_model(
            imgs[-1], captions, cap_lens)
        # words_features, sent_code, word_logits = cnn_model(
        #     imgs[-1], captions, input_ids, segment_ids, input_mask)
        # word_logits: bs x T x vocab_size
        if debug_flag:
            with open('./debug1.pkl', 'wb') as f:
                pickle.dump({'words_features': words_features,
                             'sent_code': sent_code,
                             'word_logits': word_logits}, f)

        nef, att_sze = words_features.size(1), words_features.size(2)
        # words_features = words_features.view(batch_size, nef, -1)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        # words_emb, sent_emb = rnn_model(captions, cap_lens, hidden,
        #                                 input_ids, segment_ids, input_mask)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens, class_ids,
                                                 batch_size)
        if debug_flag:
            with open('./debug2.pkl', 'wb') as f:
                pickle.dump({'words_features': words_features,
                             'words_emb': words_emb, 'labels': labels,
                             'cap_lens': cap_lens, 'class_ids': class_ids,
                             'batch_size': batch_size}, f)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        if debug_flag:
            with open('./debug3.pkl', 'wb') as f:
                pickle.dump({'sent_code': sent_code, 'sent_emb': sent_emb,
                             'labels': labels, 'class_ids': class_ids,
                             'batch_size': batch_size}, f)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data

        # added code: image-to-text loss
        # print(word_logits.shape, captions.shape)
        t_loss = image_to_text_loss(word_logits, captions)
        if debug_flag:
            with open('./debug4.pkl', 'wb') as f:
                pickle.dump({'word_logits': word_logits,
                             'captions': captions}, f)
        loss += t_loss
        t_total_loss += t_loss.data

        loss.backward()
        #
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            t_curr_loss = t_total_loss.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f} | '
                  't_loss {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1, t_curr_loss))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            t_total_loss = 0
            start_time = time.time()
            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count