def sampling(self, split_dir):
    """Edit real images with mismatched captions and save each result as a PNG.

    Setup:

    * The generator is ``G_DCGAN`` when ``cfg.GAN.B_DCGAN`` is set, otherwise
      ``G_NET``. Its weights come from ``cfg.TRAIN.NET_G``.
    * The text encoder is loaded from ``cfg.TRAIN.NET_E``.
    * The image encoder is loaded from the same path with ``'text_encoder'``
      replaced by ``'image_encoder'``.
    * A ``VGGNet`` is built.
    * The DCM is loaded from ``cfg.TRAIN.NET_C``.

    Sampling makes five passes over ``self.data_loader``. For each batch:

    * the real images are modified by the generator, conditioned on the
      *wrong* captions returned by ``prepare_data``;
    * the result is passed through the DCM;
    * every output image is written to
      ``<NET_G without '.pth'>/<split_dir>/single_s<idx>.png``, where ``idx``
      is a running counter over all passes.

    Args:
        split_dir: name of the output sub-directory; ``'test'`` is mapped to
            ``'valid'``.

    Nothing is returned. If ``cfg.TRAIN.NET_G`` or ``cfg.TRAIN.NET_C`` is
    empty, an error message is printed and nothing else is done.
    """
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
        return
    if split_dir == 'test':
        split_dir = 'valid'

    def load_cpu(path):
        # Map every tensor of the checkpoint onto the CPU while loading.
        return torch.load(path, map_location=lambda storage, loc: storage)

    # The generator (its trained weights are loaded after the DCM below)
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_cpu(cfg.TRAIN.NET_E))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The image encoder
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(load_cpu(img_encoder_path))
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # The VGG network
    VGG = VGGNet()
    print("Load the VGG model")
    VGG.cuda()
    VGG.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    # A plain tensor: `Variable(..., volatile=True)` is a no-op since
    # PyTorch 0.4. Gradient tracking is disabled by `torch.no_grad()` below.
    noise = torch.FloatTensor(batch_size, nz).cuda()

    # The DCM (cfg.TRAIN.NET_C is guaranteed non-empty by the guard above)
    netDCM = DCM_Net()
    netDCM.load_state_dict(load_cpu(cfg.TRAIN.NET_C))
    print('Load DCM from: ', cfg.TRAIN.NET_C)
    netDCM.cuda()
    netDCM.eval()

    model_dir = cfg.TRAIN.NET_G
    netG.load_state_dict(load_cpu(model_dir))
    print('Load G from: ', model_dir)

    # The path to save generated images
    model_stem = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (model_stem, split_dir)
    mkdir_p(save_dir)
    # Every image file lives directly in save_dir (created just above), so
    # no per-image folder check is needed.
    prefix = '%s/single' % save_dir

    idx = 0
    # Pure inference: no autograd graph is needed for any network.
    with torch.no_grad():
        for _ in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE)
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)
                (imgs, _captions, _cap_lens, _class_ids, _keys, wrong_caps,
                 wrong_caps_len, _wrong_cls_id) = prepare_data(data)

                # (1) Extract text and image embeddings
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                # Token id 0 is masked (presumably padding). The mask is
                # trimmed to the number of word embeddings actually produced.
                mask = (wrong_caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                real_img = imgs[cfg.TREE.BRANCH_NUM - 1]
                region_features, cnn_code = image_encoder(real_img)

                # (2) Modify real images
                noise.normal_(0, 1)
                _, _, _, _, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)
                real_features = VGG(real_img)[0]
                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                # NOTE(review): assumes every batch holds exactly batch_size
                # samples (e.g. drop_last=True in the loader). TODO confirm.
                for j in range(batch_size):
                    im = fake_img[j].detach().cpu().numpy()
                    # Map [-1, 1] to [0, 255], then move axis 0 last
                    # (presumably CHW -> HWC) for PIL.
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    Image.fromarray(im).save('%s_s%d.png' % (prefix, idx))
                    idx += 1
def build_models(self):
    """Build the frozen encoders, the ZSL discriminator, the generator and the Ds.

    Frozen networks:

    * The text encoder is loaded from ``cfg.TRAIN.NET_E``.
    * The image encoder and the ``ZSLD`` discriminator are loaded from the
      same path with ``'text_encoder'`` replaced by ``'image_encoder'`` and
      ``'discriminator'`` respectively.
    * All three get ``requires_grad = False`` on every parameter and are put
      in eval mode.

    Trainable networks:

    * The generator is ``G_DCGAN`` or ``G_NET``, depending on
      ``cfg.GAN.B_DCGAN``.
    * The discriminators depend on ``cfg.TREE.BRANCH_NUM``.
    * All of them are initialised with ``weights_init``.

    When ``cfg.TRAIN.NET_G`` is set, the generator is restored from it. If
    ``cfg.TRAIN.B_NET_D`` is also set, each discriminator is restored from
    ``netD<i>.pth`` in the same directory. Everything is moved to the GPU
    when ``cfg.CUDA`` is set.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, zsl_discriminator,
        epoch]``. ``epoch`` is 0, or, when resuming, the integer between the
        last ``'_'`` and the last ``'.'`` of ``cfg.TRAIN.NET_G`` plus one.
        Returns ``None`` after logging an error when ``cfg.TRAIN.NET_E`` is
        empty.
    """
    def load_frozen(module, path):
        # Load CPU-mapped weights, disable gradients, and switch to eval mode.
        state_dict = torch.load(path, map_location=lambda storage, loc: storage)
        module.load_state_dict(state_dict)
        for p in module.parameters():
            p.requires_grad = False
        module.eval()

    # ################### encoders ######################################## #
    if cfg.TRAIN.NET_E == '':
        LOGGER.error('no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    load_frozen(image_encoder, img_encoder_path)
    LOGGER.info(f'Load image encoder from: {img_encoder_path}')

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    load_frozen(text_encoder, cfg.TRAIN.NET_E)
    LOGGER.info(f'Load text encoder from: {cfg.TRAIN.NET_E}')

    zsl_discriminator_path = cfg.TRAIN.NET_E.replace(
        'text_encoder', 'discriminator')
    zsl_discriminator = ZSLD(cfg.ZSL.NUM_CLASSES)
    load_frozen(zsl_discriminator, zsl_discriminator_path)
    LOGGER.info(f'Load ZSL Discriminator from: {zsl_discriminator_path}')

    # ####################### generator and discriminators ############## #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from src.model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from src.model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from src.model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from src.model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    LOGGER.info(f'# of netsD: {len(netsD)}')

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        LOGGER.info(f'Load G from: {cfg.TRAIN.NET_G}')
        # Resume after the epoch encoded between the last '_' and the last
        # '.' of the checkpoint path.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = int(cfg.TRAIN.NET_G[istart:iend]) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            # The discriminators sit in the same directory as the generator.
            s_tmp = Gname[:Gname.rfind('/')]
            for i, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                LOGGER.info(f'Load D from: {Dname}')
                state_dict = torch.load(
                    Dname, map_location=lambda storage, loc: storage)
                netD.load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        zsl_discriminator = zsl_discriminator.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()
    return [
        text_encoder, image_encoder, netG, netsD, zsl_discriminator, epoch
    ]
def build_models(self):
    """Build the frozen VGG / text / image encoders, the generator and the Ds.

    Frozen networks:

    * A ``VGGNet`` is built.
    * The text encoder is loaded from ``cfg.TRAIN.NET_E``.
    * The image encoder is loaded from the same path with ``'text_encoder'``
      replaced by ``'image_encoder'``.
    * All three get ``requires_grad = False`` on every parameter and are put
      in eval mode.

    Trainable networks:

    * The generator is ``G_DCGAN`` or ``G_NET``, depending on
      ``cfg.GAN.B_DCGAN``.
    * The discriminators depend on ``cfg.TREE.BRANCH_NUM``.
    * All of them are initialised with ``weights_init``.

    When ``cfg.TRAIN.NET_G`` is set, the generator is restored from it. If
    ``cfg.TRAIN.B_NET_D`` is also set, each discriminator is restored from
    ``netD<i>.pth`` in the same directory. Everything is moved to the GPU
    when ``cfg.CUDA`` is set.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch, VGG]``.
        ``epoch`` is 0, or, when resuming, the integer between the last
        ``'_'`` and the last ``'.'`` of ``cfg.TRAIN.NET_G`` plus one.
        Returns ``None`` after printing an error when ``cfg.TRAIN.NET_E`` is
        empty.
    """
    def load_cpu(path):
        # Map every tensor of the checkpoint onto the CPU while loading.
        return torch.load(path, map_location=lambda storage, loc: storage)

    def freeze(module):
        # Disable gradients for all parameters and switch to eval mode.
        for p in module.parameters():
            p.requires_grad = False
        module.eval()

    ################### Text and Image encoders ########################################
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    vgg = VGGNet()
    freeze(vgg)
    print("Load the VGG model")

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(load_cpu(img_encoder_path))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_cpu(cfg.TRAIN.NET_E))
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)

    ####################### Generator and Discriminators ##############
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from src.model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from src.model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from src.model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from src.model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        netG.load_state_dict(load_cpu(cfg.TRAIN.NET_G))
        print('Load G from: ', cfg.TRAIN.NET_G)
        # Resume after the epoch encoded between the last '_' and the last
        # '.' of the checkpoint path.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = int(cfg.TRAIN.NET_G[istart:iend]) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            # The discriminators sit in the same directory as the generator.
            s_tmp = Gname[:Gname.rfind('/')]
            for i, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                netD.load_state_dict(load_cpu(Dname))

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        vgg = vgg.cuda()
        for netD in netsD:
            netD.cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch, vgg]
def sampling(self, split_dir):
    """Generate one image per caption in ``self.data_loader`` and save it as a PNG.

    Setup:

    * The generator is ``G_DCGAN`` when ``cfg.GAN.B_DCGAN`` is set, otherwise
      ``G_NET``. Its weights come from ``cfg.TRAIN.NET_G``.
    * The text encoder is loaded from ``cfg.TRAIN.NET_E``.

    Sampling makes one pass over the loader. For each sample ``j``, the last
    element of the generator's image list is written to
    ``<NET_G without '.pth'>/<split_dir>/single/<keys[j]>_s-1.png``.
    Missing folders are created on the way.

    Args:
        split_dir: name of the output sub-directory; ``'test'`` is mapped to
            ``'valid'``.

    Nothing is returned. If ``cfg.TRAIN.NET_G`` is empty, an error message
    is printed and nothing else is done.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
        return
    if split_dir == 'test':
        split_dir = 'valid'

    # Build the generator (its trained weights are loaded below)
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The generator is called without a noise input. The former
    # `noise = Variable(torch.FloatTensor(batch_size, nz), ...)` referenced an
    # undefined `nz` and was never used, so it has been dropped.
    batch_size = self.batch_size

    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)

    # The path to save generated images
    model_stem = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (model_stem, split_dir)
    mkdir_p(save_dir)

    # Index of the generator output that is saved (the last one). It also
    # appears in the file name, hence the '_s-1.png' suffix.
    k = -1
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for step, data in enumerate(self.data_loader, 0):
            if step % 100 == 0:
                print('step: ', step)
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            # Token id 0 is masked (presumably padding). The mask is trimmed
            # to the number of word embeddings actually produced.
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            fake_imgs, *_ = netG(sent_emb, words_embs, mask)

            # NOTE(review): assumes every batch holds exactly batch_size
            # samples (e.g. drop_last=True in the loader). TODO confirm.
            for j in range(batch_size):
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                # keys[j] may itself contain sub-directories, so the parent
                # folder is checked for every image.
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                im = fake_imgs[k][j].detach().cpu().numpy()
                # [-1, 1] --> [0, 255], then move axis 0 last
                # (presumably CHW -> HWC) for PIL.
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                Image.fromarray(im).save('%s_s%d.png' % (s_tmp, k))