def build_models(self):
    """Build the encoders, style-loss net, G/D networks and a target generator.

    The text encoder is loaded from ``cfg.TRAIN.NET_E``; the image-encoder
    path is derived from it by replacing ``'text_encoder'`` with
    ``'image_encoder'``. Both encoders and the VGG style-loss network are
    frozen and put in eval mode. The generator (and, when
    ``cfg.TRAIN.B_NET_D`` is set, the discriminators) are resumed from
    ``cfg.TRAIN.NET_G`` when that path is non-empty. The generator is then
    deep-copied into a target network whose parameters are frozen.

    Returns:
        ``[text_encoder, image_encoder, netG, target_netG, netsD, epoch,
        style_loss]``, or ``None`` (after printing an error) when
        ``cfg.TRAIN.NET_E`` is empty.
    """
    encoder_path = cfg.TRAIN.NET_E
    if encoder_path == '':
        print('Error: no pretrained text-image encoders')
        return

    def _keep_on_cpu(storage, loc):
        # map_location callback: leave every storage where it was loaded.
        return storage

    # VGG-based style loss network (frozen).
    style_loss = VGGNet()
    for param in style_loss.parameters():
        param.requires_grad = False
    print("Load the style loss model")
    style_loss.eval()

    # Pretrained image encoder (frozen).
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = encoder_path.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=_keep_on_cpu))
    for param in image_encoder.parameters():
        param.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    # Pretrained text encoder (frozen).
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(encoder_path, map_location=_keep_on_cpu))
    for param in text_encoder.parameters():
        param.requires_grad = False
    print('Load text encoder from:', encoder_path)
    text_encoder.eval()

    # Generator and discriminators.
    from model import D_NET64, D_NET128, D_NET256
    if cfg.GAN.B_DCGAN:
        # One discriminator, chosen by branch count; anything other than
        # 1 or 2 falls through to the 256 variant.
        D_NET = {1: D_NET64, 2: D_NET128}.get(cfg.TREE.BRANCH_NUM, D_NET256)
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        netG = G_NET()
        # One discriminator per enabled branch, in 64/128/256 order.
        netsD = [
            d_cls()
            for level, d_cls in enumerate((D_NET64, D_NET128, D_NET256))
            if cfg.TREE.BRANCH_NUM > level
        ]
    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    # Optionally resume from a generator checkpoint.
    epoch = 0
    g_path = cfg.TRAIN.NET_G
    if g_path != '':
        netG.load_state_dict(torch.load(g_path, map_location=_keep_on_cpu))
        print('Load G from: ', g_path)
        # The text between the last '_' and the last '.' of the path is
        # parsed as the saved epoch; training resumes at the next one.
        epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            # Discriminator checkpoints sit next to the generator file.
            ckpt_dir = g_path[:g_path.rfind('/')]
            for i, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=_keep_on_cpu))

    # Create a target network as a copy of the (possibly resumed) generator.
    target_netG = deepcopy(netG)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        style_loss = style_loss.cuda()
        # The target network is stored on the secondary GPU.
        target_netG.cuda(secondary_device)
        target_netG.ca_net.device = secondary_device
        netG.cuda()
        netsD = [netD.cuda() for netD in netsD]

    # Disable training in the target network.
    for param in target_netG.parameters():
        param.requires_grad = False

    return [
        text_encoder, image_encoder, netG, target_netG, netsD, epoch,
        style_loss
    ]
def gen_example(self, data_dic):
    """Modify real images according to captions and save the results as PNGs.

    Loads the text encoder, image encoder, VGG network, main generator and
    DCM from the paths in ``cfg.TRAIN``, then for every entry of
    ``data_dic`` runs the generator followed by the DCM and writes the
    images under ``<NET_G path without '.pth'>/<key>/``.

    Args:
        data_dic: mapping ``key -> (captions, cap_lens, sorted_indices,
            imgs)``. ``captions`` and ``cap_lens`` are NumPy arrays;
            ``imgs`` is indexed by branch.
            NOTE(review): each ``imgs[b]`` is assumed to be a single CHW
            tensor already on the GPU -- confirm against the caller.

    Returns:
        None. Prints an error and returns early when ``cfg.TRAIN.NET_G`` or
        ``cfg.TRAIN.NET_C`` is empty.
    """
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
        return

    def _load_on_cpu(path):
        # Deserialize to CPU first; the modules are moved to GPU afterwards.
        return torch.load(path, map_location=lambda storage, loc: storage)

    def _save_chw_image(tensor, path):
        # Map a CHW tensor from [-1, 1] to uint8 HWC and write it to `path`.
        arr = tensor.detach().cpu().numpy()
        arr = ((arr + 1.0) * 127.5).astype(np.uint8)
        Image.fromarray(np.transpose(arr, (1, 2, 0))).save(path)

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(_load_on_cpu(cfg.TRAIN.NET_E))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The image encoder
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(_load_on_cpu(img_encoder_path))
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # The VGG network
    VGG = VGGNet()
    print("Load the VGG model")
    VGG.cuda()
    VGG.eval()

    # The main module
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    model_dir = cfg.TRAIN.NET_G
    out_root = model_dir[:model_dir.rfind('.pth')]
    netG.load_state_dict(_load_on_cpu(model_dir))
    print('Load G from: ', model_dir)
    netG.cuda()
    netG.eval()

    # The DCM (the guard above guarantees NET_C is non-empty)
    netDCM = DCM_Net()
    netDCM.load_state_dict(_load_on_cpu(cfg.TRAIN.NET_C))
    print('Load DCM from: ', cfg.TRAIN.NET_C)
    netDCM.cuda()
    netDCM.eval()

    last_branch = cfg.TREE.BRANCH_NUM - 1
    nz = cfg.GAN.Z_DIM
    sample_idx = 0  # the original looped over range(1); one sample per key

    # `Variable(..., volatile=True)` has no effect on PyTorch >= 0.4, so
    # autograd was still recording during inference. no_grad() restores the
    # intended behaviour and avoids keeping the graph in memory.
    with torch.no_grad():
        for key in data_dic:
            save_dir = '%s/%s' % (out_root, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices, imgs = data_dic[key]

            batch_size = captions.shape[0]
            captions = torch.from_numpy(captions).cuda()
            cap_lens = torch.from_numpy(cap_lens).cuda()
            noise = torch.FloatTensor(batch_size, nz).cuda()

            #######################################################
            # (1) Extract text and image embeddings
            ######################################################
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            # The real image is given a leading batch axis of size 1.
            real_img = imgs[last_branch].unsqueeze(0)
            region_features, cnn_code = image_encoder(real_img)
            mask = (captions == 0)

            #######################################################
            # (2) Modify real images
            ######################################################
            noise.normal_(0, 1)
            fake_imgs, attention_maps, mu, logvar, h_code, c_code = netG(
                noise, sent_emb, words_embs, mask, cnn_code, region_features)
            real_features = VGG(real_img)[0]
            fake_img = netDCM(h_code, real_features, sent_emb, words_embs,
                              mask, c_code)

            cap_lens_np = cap_lens.cpu().numpy()
            for j in range(batch_size):
                save_name = '%s/%d_s_%d' % (save_dir, sample_idx,
                                            sorted_indices[j])

                # Generator outputs, one PNG per returned stage.
                for k, stage_imgs in enumerate(fake_imgs):
                    _save_chw_image(stage_imgs[j],
                                    '%s_g%d.png' % (save_name, k))

                # Attention visualisations: map k is drawn over stage k + 1
                # when more than one stage was returned.
                for k, attn_maps in enumerate(attention_maps):
                    if len(fake_imgs) > 1:
                        base = fake_imgs[k + 1].detach().cpu()
                    else:
                        base = fake_imgs[0].detach().cpu()
                    att_sze = attn_maps.size(2)
                    img_set, sentences = build_super_images2(
                        base[j].unsqueeze(0),
                        captions[j].unsqueeze(0),
                        [cap_lens_np[j]], self.ixtoword,
                        [attn_maps[j]], att_sze)
                    if img_set is not None:
                        Image.fromarray(img_set).save(
                            '%s_a%d.png' % (save_name, k))

                # Final image produced by the DCM.
                save_name = '%s/%d_sf_%d' % (save_dir, 1, sorted_indices[j])
                _save_chw_image(fake_img[j], '%s_SF.png' % (save_name))

            # The source image, for side-by-side comparison.
            # NOTE(review): index 2 and the '1_s_9' name are hard-coded in
            # the original; index 2 equals `last_branch` only when
            # cfg.TREE.BRANCH_NUM == 3. Kept as-is -- confirm intent.
            save_name = '%s/%d_s_%d' % (save_dir, 1, 9)
            _save_chw_image(imgs[2], '%s_SR.png' % (save_name))
def sampling(self, split_dir):
    """Modify every image of the data loader with mismatched captions.

    Loads the generator, text/image encoders, VGG network and DCM from the
    paths in ``cfg.TRAIN``, then makes five passes over
    ``self.data_loader``. Each batch of real images is edited using the
    "wrong" captions returned by ``prepare_data`` and every DCM output is
    written to ``<NET_G path without '.pth'>/<split_dir>/single_s<idx>.png``
    with a running index.

    Args:
        split_dir: name of the output sub-directory; ``'test'`` is mapped
            to ``'valid'``.

    Returns:
        None. Prints an error and returns early when ``cfg.TRAIN.NET_G`` or
        ``cfg.TRAIN.NET_C`` is empty.
    """
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
        return

    def _load_on_cpu(path):
        # Deserialize to CPU first; the modules are moved to GPU separately.
        return torch.load(path, map_location=lambda storage, loc: storage)

    if split_dir == 'test':
        split_dir = 'valid'

    # The main module (its weights are loaded further down)
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(_load_on_cpu(cfg.TRAIN.NET_E))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The image encoder
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(_load_on_cpu(img_encoder_path))
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # The VGG network
    VGG = VGGNet()
    print("Load the VGG model")
    VGG.cuda()
    VGG.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    # One noise buffer, re-sampled in place for every batch.
    noise = torch.FloatTensor(batch_size, nz).cuda()

    # The DCM (the guard above guarantees NET_C is non-empty)
    netDCM = DCM_Net()
    netDCM.load_state_dict(_load_on_cpu(cfg.TRAIN.NET_C))
    print('Load DCM from: ', cfg.TRAIN.NET_C)
    netDCM.cuda()
    netDCM.eval()

    model_dir = cfg.TRAIN.NET_G
    netG.load_state_dict(_load_on_cpu(model_dir))
    print('Load G from: ', model_dir)

    # the path to save modified images
    save_dir = '%s/%s' % (model_dir[:model_dir.rfind('.pth')], split_dir)
    mkdir_p(save_dir)
    # Every output is '<save_dir>/single_s<idx>.png'. The original
    # re-checked that directory for each image, but it is exactly
    # `save_dir`, which was just created above.
    out_prefix = '%s/single' % (save_dir)

    last_branch = cfg.TREE.BRANCH_NUM - 1
    idx = 0

    # `Variable(..., volatile=True)` has no effect on PyTorch >= 0.4, so
    # autograd was still recording during sampling. no_grad() restores the
    # intended behaviour and keeps memory use flat.
    with torch.no_grad():
        for _ in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)

                mask = (wrong_caps == 0)
                # Trim the padding mask to the number of encoded words.
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                real_img = imgs[last_branch]
                region_features, cnn_code = image_encoder(real_img)

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.normal_(0, 1)
                fake_imgs, attention_maps, mu, logvar, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)
                real_features = VGG(real_img)[0]
                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                for j in range(batch_size):
                    # Map CHW in [-1, 1] to a uint8 HWC image.
                    im = fake_img[j].detach().cpu().numpy()
                    im = ((im + 1.0) * 127.5).astype(np.uint8)
                    im = Image.fromarray(np.transpose(im, (1, 2, 0)))
                    im.save('%s_s%d.png' % (out_prefix, idx))
                    idx += 1
def build_models(self):
    """Build the VGG net, text/image encoders, generator and discriminators.

    The pretrained encoders are optional: when ``cfg.TRAIN.NET_E`` is empty
    they are left freshly initialised and trainable. Otherwise both are
    loaded, frozen and put in eval mode. Discriminators are only created
    when ``cfg.TRAIN.W_GAN`` is set. The generator (and, with
    ``cfg.TRAIN.B_NET_D``, the discriminators) are resumed from
    ``cfg.TRAIN.NET_G`` when that path is non-empty.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch, VGG]``.
    """
    def _keep_on_cpu(storage, loc):
        # map_location callback: leave every storage where it was loaded.
        return storage

    # ---------------- Text and image encoders ----------------
    VGG = VGGNet()
    for param in VGG.parameters():
        param.requires_grad = False
    print("Load the VGG model")
    VGG.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)

    # when NET_E = '', train the image_encoder and text_encoder jointly
    encoder_path = cfg.TRAIN.NET_E
    if encoder_path != '':
        # The loaded object must expose .state_dict(); presumably these
        # checkpoints hold whole saved modules rather than bare state dicts.
        saved_text = torch.load(encoder_path, map_location=_keep_on_cpu)
        text_encoder.load_state_dict(saved_text.state_dict())
        for param in text_encoder.parameters():
            param.requires_grad = False
        print('Load text encoder from:', encoder_path)
        text_encoder.eval()

        img_encoder_path = encoder_path.replace('text_encoder',
                                                'image_encoder')
        saved_image = torch.load(img_encoder_path, map_location=_keep_on_cpu)
        image_encoder.load_state_dict(saved_image.state_dict())
        for param in image_encoder.parameters():
            param.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

    # ---------------- Generator and discriminators ----------------
    # NOTE(review): the nesting below was reconstructed from
    # whitespace-collapsed source. It is the reading in which no D_NET*
    # name is used outside the branch that imports it -- confirm against
    # the original file.
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        if cfg.TRAIN.W_GAN:
            netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        netG.apply(weights_init)
        if cfg.TRAIN.W_GAN:
            # One discriminator per enabled branch, in 64/128/256 order.
            netsD.extend(
                d_cls()
                for level, d_cls in enumerate((D_NET64, D_NET128, D_NET256))
                if cfg.TREE.BRANCH_NUM > level
            )
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---------------- Optional resume ----------------
    epoch = 0
    g_path = cfg.TRAIN.NET_G
    if g_path != '':
        netG.load_state_dict(torch.load(g_path, map_location=_keep_on_cpu))
        print('Load G from: ', g_path)
        # The text between the last '_' and the last '.' of the path is
        # parsed as the saved epoch; training resumes at the next one.
        epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            # Discriminator checkpoints sit next to the generator file.
            ckpt_dir = g_path[:g_path.rfind('/')]
            for i, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=_keep_on_cpu))

    # ---------------- Device placement ----------------
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        VGG = VGG.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch, VGG]
def build_models(self):
    """Build the networks needed to train the DCM on top of a pretrained G.

    Requires both ``cfg.TRAIN.NET_E`` (text encoder; the image-encoder path
    is derived by replacing ``'text_encoder'`` with ``'image_encoder'``)
    and ``cfg.TRAIN.NET_G`` (main generator). The encoders and VGG are
    frozen; the generator is loaded and put in eval mode. The DCM and its
    discriminator are optionally resumed from ``cfg.TRAIN.NET_C`` and
    ``cfg.TRAIN.NET_D``.

    Returns:
        ``[text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]``,
        or ``None`` (after printing an error) when a required path is
        empty.
    """
    # Both pretrained checkpoints are mandatory for this stage.
    encoder_path = cfg.TRAIN.NET_E
    if encoder_path == '':
        print('Error: no pretrained text-image encoders')
        return
    g_path = cfg.TRAIN.NET_G
    if g_path == '':
        print('Error: no pretrained main module')
        return

    def _keep_on_cpu(storage, loc):
        # map_location callback: leave every storage where it was loaded.
        return storage

    VGG = VGGNet()
    for param in VGG.parameters():
        param.requires_grad = False
    print("Load the VGG model")
    VGG.eval()

    # Pretrained image encoder (frozen).
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = encoder_path.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=_keep_on_cpu))
    for param in image_encoder.parameters():
        param.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    # Pretrained text encoder (frozen).
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(encoder_path, map_location=_keep_on_cpu))
    for param in text_encoder.parameters():
        param.requires_grad = False
    print('Load text encoder from:', encoder_path)
    text_encoder.eval()

    # Generator plus a single 256-resolution discriminator; the DCGAN
    # variant constructs it with b_jcu=False.
    from model import D_NET256
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
        netD = D_NET256(b_jcu=False)
    else:
        netG = G_NET()
        netD = D_NET256()
    netD.apply(weights_init)

    netG.load_state_dict(torch.load(g_path, map_location=_keep_on_cpu))
    netG.eval()
    print('Load G from: ', g_path)

    # Optionally resume the DCM; the epoch counter follows its checkpoint.
    epoch = 0
    netDCM = DCM_Net()
    c_path = cfg.TRAIN.NET_C
    if c_path != '':
        netDCM.load_state_dict(torch.load(c_path, map_location=_keep_on_cpu))
        print('Load DCM from: ', c_path)
        # The text between the last '_' and the last '.' of the path is
        # parsed as the saved epoch; training resumes at the next one.
        epoch = int(c_path[c_path.rfind('_') + 1:c_path.rfind('.')]) + 1

    # NOTE(review): reconstructed from whitespace-collapsed source; the
    # NET_D load is treated as independent of NET_C -- confirm against the
    # original file.
    d_path = cfg.TRAIN.NET_D
    if d_path != '':
        netD.load_state_dict(torch.load(d_path, map_location=_keep_on_cpu))
        print('Load DCM Discriminator from: ', d_path)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        netDCM.cuda()
        VGG = VGG.cuda()
        netD.cuda()

    return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]