def models(modelname, cfg, word_len):
    # print(word_len)
    text_encoder = cache.get(modelname + '_text_encoder', None)
    if text_encoder is None:
        # print("text_encoder not cached")
        text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache[modelname + '_text_encoder'] = text_encoder

    netG = cache.get(modelname + '_netG', None)
    if netG is None:
        # print("netG not cached")
        netG = G_NET()
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache[modelname + '_netG'] = netG
    return text_encoder, netG

def models(word_len):
    print('Loading Model', word_len)
    text_encoder = cache.get('text_encoder')
    print('Text encoder', text_encoder)
    if text_encoder is None:
        print("text_encoder not cached")
        text_encoder = RNN_ENCODER(word_len, nhidden=256)
        state_dict = torch.load('../DAMSMencoders/coco/text_encoder100.pth',
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('loaded text encoder')
        text_encoder.cuda()
        print('text encoder cuda')
        text_encoder.eval()
        print('text encoder eval')
        # cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)
    print('Got Text Encoder, moving to netG')

    netG = cache.get('netG')
    if netG is None:
        print("netG not cached")
        netG = G_NET()
        state_dict = torch.load('../models/coco_AttnGAN2.pth',
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        # cache.set('netG', netG, timeout=60 * 60 * 24)
    print('Got NetG')
    return text_encoder, netG

def models(word_len):
    # print(word_len)
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        # print("text_encoder not cached")
        text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    netG = cache.get('netG')
    if netG is None:
        # print("netG not cached")
        netG = G_NET()
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache.set('netG', netG, timeout=60 * 60 * 24)
    return text_encoder, netG

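# A hypothetical usage sketch for the cached loaders above (assumes a
# configured `cache` backend, a loaded `cfg`, and the vocabulary size the
# DAMSM text encoder was trained with; the numbers are illustrative only):
#
#   text_encoder, netG = models(word_len=5450)  # first call: loads from disk
#   text_encoder, netG = models(word_len=5450)  # later calls: served from cache
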
def build_models():
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
    print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()
    return text_encoder, image_encoder, labels, start_epoch

def build_models():
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    lr = cfg.TRAIN.ENCODER_LR
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load {}'.format(cfg.TRAIN.NET_E))

        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load {}'.format(name))

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
        # initialize lr with the right value
        # note that the turning point is always epoch 114
        if start_epoch < 114:
            lr = cfg.TRAIN.ENCODER_LR * (0.98 ** start_epoch)
        else:
            lr = cfg.TRAIN.ENCODER_LR / 10
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()
    return text_encoder, image_encoder, labels, start_epoch, lr

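# A minimal sketch of the resumed learning-rate schedule implemented above,
# assuming cfg.TRAIN.ENCODER_LR = 2e-4 (the base value is an assumption, not
# from the source): the lr decays by 0.98 per epoch until the fixed turning
# point at epoch 114, after which it is held at one tenth of the base.
#
#   base = 2e-4
#   lr_at = lambda e: base * 0.98 ** e if e < 114 else base / 10
#   print(lr_at(0), lr_at(50), lr_at(120))  # 2e-04, ~7.3e-05, 2e-05
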
def build_text_encoder(ntokens):
    text_encoder = RNN_ENCODER(ntokens, nhidden=cfg.TEXT.EMBEDDING_DIM)
    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
    return text_encoder

def models(word_len):
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        text_encoder = RNN_ENCODER(word_len, nhidden=256)
        state_dict = torch.load('../DAMSMencoders/coco/text_encoder100.pth',
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        text_encoder.cuda()
        text_encoder.eval()
        # cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    netG = cache.get('netG')
    if netG is None:
        netG = G_NET()
        state_dict = torch.load('../models/coco_AttnGAN2.pth',
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        # cache.set('netG', netG, timeout=60 * 60 * 24)
    return text_encoder, netG

def models(word_len):
    print(word_len)
    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        print("text_encoder not cached")
        if sys.argv[1].casefold() == 'rnn':
            text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
        elif sys.argv[1].casefold() == 'transformer':
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if cfg.CUDA:
            text_encoder.cuda()
        text_encoder.eval()
        cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

    netG = cache.get('netG')
    if netG is None:
        print("netG not cached")
        if cfg.GAN.B_STYLEGEN:
            netG = G_NET_STYLED()
        else:
            netG = G_NET()
        checkpoint = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        if cfg.GAN.B_STYLEGEN:
            netG.w_ewma = checkpoint['w_ewma']
            if cfg.CUDA:
                netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
            netG.load_state_dict(checkpoint['netG_state_dict'])
        else:
            netG.load_state_dict(checkpoint)
        if cfg.CUDA:
            netG.cuda()
        netG.eval()
        cache.set('netG', netG, timeout=60 * 60 * 24)
    return text_encoder, netG

def build_models(text_encoder_type):
    # build model ############################################################
    text_encoder_type = text_encoder_type.casefold()
    if text_encoder_type not in ('rnn', 'transformer'):
        raise ValueError('Unsupported text_encoder_type')

    if text_encoder_type == 'rnn':
        text_encoder = RNN_ENCODER(dataset.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E:
        if text_encoder_type == 'rnn':
            state_dict = torch.load(cfg.TRAIN.NET_E)
            text_encoder.load_state_dict(state_dict)
        elif text_encoder_type == 'transformer':
            text_encoder = GPT2Model.from_pretrained(cfg.TRAIN.NET_E)
            # output_hidden_states = True )
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
    else:
        if text_encoder_type == 'rnn':
            print('Training RNN from scratch')
        elif text_encoder_type == 'transformer':
            # don't initialize the weights of these huge models from scratch...
            print('Training Transformer starting from pretrained model')
            text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
            # output_hidden_states = True )
        print('Training CNN starting from ImageNet pretrained Inception-v3')
    print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()
    return text_encoder, image_encoder, labels, start_epoch

def build_models():
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0

    # MODIFIED
    if cfg.PRETRAINED_RNN:
        text_encoder_params = torch.load(
            cfg.PRETRAINED_RNN, map_location=lambda storage, loc: storage)
        text_encoder.rnn.load_state_dict(text_encoder_params['encoder'])
        pad_idx = text_encoder_params['vocab']['word2id']['<pad>']
        n_words, embed_size = text_encoder.encoder.weight.size()
        text_encoder.encoder = nn.Embedding(n_words, embed_size, pad_idx)
        text_encoder.encoder.load_state_dict(text_encoder_params['embedding'])

    if cfg.PRETRAINED_CNN:
        image_encoder_params = torch.load(
            cfg.PRETRAINED_CNN, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(image_encoder_params)

    if cfg.TRAIN.NET_E != '':
        state_dict = torch.load(cfg.TRAIN.NET_E)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
    print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()
    return text_encoder, image_encoder, labels, start_epoch

def build_models(self):
    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:
    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))
    #
    epoch = 0

    # MODIFIED
    if cfg.PRETRAINED_G != '':
        state_dict = torch.load(cfg.PRETRAINED_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.PRETRAINED_G)
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.PRETRAINED_G
            s_tmp = Gname[:Gname.rfind('/')]
            for i in range(len(netsD)):
                # the name of Ds should be consistent and differ from each
                # other in i
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = torch.load(
                    Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch]

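# Hypothetical checkpoint layout assumed by the loaders above (paths are
# illustrative, not from the source): the generator and the per-branch
# discriminators sit in one directory, and `epoch` is parsed from the digits
# between the last '_' and '.pth' in the generator filename.
#
#   output/Model/netG_epoch_600.pth   -> training resumes at epoch 601
#   output/Model/netD0.pth            # 64x64 branch
#   output/Model/netD1.pth            # 128x128 branch
#   output/Model/netD2.pth            # 256x256 branch
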
def build_models(self):
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    # vgg16 network
    style_loss = VGGNet()
    for p in style_loss.parameters():
        p.requires_grad = False
    print("Load the style loss model")
    style_loss.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
    netG.apply(weights_init)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
    print('# of netsD', len(netsD))
    #
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # Create a target network.
    target_netG = deepcopy(netG)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        style_loss = style_loss.cuda()
        # The target network is stored on the secondary GPU.-----------------
        target_netG.cuda(secondary_device)
        target_netG.ca_net.device = secondary_device
        # --------------------------------------------------------------------
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i] = netsD[i].cuda()

    # Disable training in the target network:
    for p in target_netG.parameters():
        p.requires_grad = False

    return [text_encoder, image_encoder, netG, target_netG, netsD, epoch,
            style_loss]

def sampling(self, split_dir):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'

        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        cnt = 0
        idx = 0
        ###
        avg_ddva = 0
        for _ in range(1):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 100 == 0:
                    print('step: ', step)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text -----------------------------
                data_human = [captions, cap_lens, misc]
                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

                # Generate images for imperfect caption-text ------------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]
                imgs, imperfect_captions, imperfect_cap_lens, class_ids, imperfect_keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data_imperfect)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(imperfect_captions,
                                                    imperfect_cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (imperfect_captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                imperfect_fake_imgs, _, _, _ = netG(noise, sent_emb,
                                                    words_embs, mask)

                # Sort the results by keys to align ---------------------------
                keys, captions, cap_lens, fake_imgs, _, _ = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, None)
                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, true_imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions,
                                 imperfect_cap_lens, imperfect_fake_imgs,
                                 imgs, None)

                # Shift device for the imgs, target_imgs and imperfect_imgs----
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    imperfect_fake_imgs[i] = imperfect_fake_imgs[i].to(
                        secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                for j in range(batch_size):
                    s_tmp = '%s/single' % (save_dir)
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    im = fake_imgs[k][j].data.cpu().numpy()
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))

                    cap_im = imperfect_fake_imgs[k][j].data.cpu().numpy()
                    cap_im = (cap_im + 1.0) * 127.5
                    cap_im = cap_im.astype(np.uint8)
                    cap_im = np.transpose(cap_im, (1, 2, 0))

                    # Uncomment to scale true image
                    true_im = true_imgs[k][j].data.cpu().numpy()
                    true_im = (true_im + 1.0) * 127.5
                    true_im = true_im.astype(np.uint8)
                    true_im = np.transpose(true_im, (1, 2, 0))

                    # Uncomment to save images.
                    # true_im = Image.fromarray(true_im)
                    # fullpath = '%s_true_s%d.png' % (s_tmp, idx)
                    # true_im.save(fullpath)
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, idx)
                    im.save(fullpath)
                    # cap_im = Image.fromarray(cap_im)
                    # fullpath = '%s_imperfect_s%d.png' % (s_tmp, idx)
                    idx = idx + 1
                    # cap_im.save(fullpath)

                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs,
                                         reduce='mean',
                                         final_only=True).data.cpu().numpy()
                avg_ddva += neg_ddva * (-1)

                # text_caps = [[self.ixtoword[word] for word in sent if word != 0]
                #              for sent in captions.tolist()]
                # imperfect_text_caps = [[self.ixtoword[word] for word in sent
                #                         if word != 0]
                #                        for sent in imperfect_captions.tolist()]
                print(step)
        avg_ddva = avg_ddva / (step + 1)
        print('\n\nAvg_DDVA: ', avg_ddva)

def embedding(self, split_dir, model):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        if cfg.GPU_ID != -1:
            netG.cuda()
        netG.eval()
        #
        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        print(img_encoder_path)
        print('Load image encoder from:', img_encoder_path)
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        if cfg.GPU_ID != -1:
            image_encoder = image_encoder.cuda()
        image_encoder.eval()

        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        if cfg.GPU_ID != -1:
            text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        with torch.no_grad():
            noise = Variable(torch.FloatTensor(batch_size, nz))
            if cfg.GPU_ID != -1:
                noise = noise.cuda()

        # the path to save generated images
        save_dir = model_dir[:model_dir.rfind('.pth')]

        cnt = 0

        # new
        if cfg.TRAIN.CLIP_SENTENCODER:
            print("Use CLIP SentEncoder for sampling")

        img_features = dict()
        txt_features = dict()
        with torch.no_grad():
            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, captions, cap_lens, class_ids, keys, texts = \
                        prepare_data(data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                        hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    if cfg.TRAIN.CLIP_SENTENCODER:
                        # randomly select one paragraph for each training example
                        sents = []
                        for idx in range(len(texts)):
                            sents_per_image = texts[idx].split('\n')  # new 3/11
                            if len(sents_per_image) > 1:
                                sent_ix = np.random.randint(
                                    0, len(sents_per_image) - 1)
                            else:
                                sent_ix = 0
                            sents.append(sents_per_image[0])
                        # print('sents: ', sents)

                        sent = clip.tokenize(sents)  # .to(device)
                        # load clip
                        # model = torch.jit.load("model.pt").cuda().eval()
                        sent_input = sent
                        if cfg.GPU_ID != -1:
                            sent_input = sent.cuda()
                        # print("text input", sent_input)
                        sent_emb_clip = model.encode_text(sent_input).float()
                        if CLIP:
                            sent_emb = sent_emb_clip

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

                    if CLIP:
                        images = []
                        for j in range(fake_imgs[-1].shape[0]):
                            image = fake_imgs[-1][j].cpu().clone()
                            image = image.squeeze(0)
                            unloader = transforms.ToPILImage()
                            image = unloader(image)
                            image = preprocess(image.convert("RGB"))  # 256*256 -> 224*224
                            images.append(image)

                        image_mean = torch.tensor(
                            [0.48145466, 0.4578275, 0.40821073]).cuda()
                        image_std = torch.tensor(
                            [0.26862954, 0.26130258, 0.27577711]).cuda()
                        image_input = torch.tensor(np.stack(images)).cuda()
                        image_input -= image_mean[:, None, None]
                        image_input /= image_std[:, None, None]
                        cnn_codes = model.encode_image(image_input).float()
                    else:
                        region_features, cnn_codes = image_encoder(fake_imgs[-1])

                    for j in range(batch_size):
                        cnn_code = cnn_codes[j]
                        temp = keys[j].replace('b', '').replace("'", '')
                        img_features[temp] = cnn_code.cpu().numpy()
                        txt_features[temp] = sent_emb[j].cpu().numpy()

        with open(save_dir + ".pkl", 'wb') as f:
            pickle.dump(img_features, f)
        with open(save_dir + "_text.pkl", 'wb') as f:
            pickle.dump(txt_features, f)

def sampling(self, split_dir, model):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        if cfg.GPU_ID != -1:
            netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        if cfg.GPU_ID != -1:
            text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        with torch.no_grad():
            noise = Variable(torch.FloatTensor(batch_size, nz))
            if cfg.GPU_ID != -1:
                noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        # state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        cnt = 0

        # new
        if cfg.TRAIN.CLIP_SENTENCODER:
            print("Use CLIP SentEncoder for sampling")

        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 100 == 0:
                    print('step: ', step)
                # if step > 50:
                #     break

                # imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
                # new
                imgs, captions, cap_lens, class_ids, keys, texts = \
                    prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # new
                if cfg.TRAIN.CLIP_SENTENCODER:
                    # randomly select one paragraph for each training example
                    sents = []
                    for idx in range(len(texts)):
                        sents_per_image = texts[idx].split('\n')  # new 3/11
                        if len(sents_per_image) > 1:
                            sent_ix = np.random.randint(
                                0, len(sents_per_image) - 1)
                        else:
                            sent_ix = 0
                        sents.append(sents_per_image[sent_ix])
                        with open('%s/%s' % (save_dir, 'eval_sents.txt'), 'a+') as f:
                            f.write(sents_per_image[sent_ix] + '\n')
                    # print('sents: ', sents)

                    sent = clip.tokenize(sents)  # .to(device)
                    # load clip
                    # model = torch.jit.load("model.pt").cuda().eval()
                    sent_input = sent
                    if cfg.GPU_ID != -1:
                        sent_input = sent.cuda()
                    # print("text input", sent_input)
                    with torch.no_grad():
                        sent_emb = model.encode_text(sent_input).float()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

                for j in range(batch_size):
                    s_tmp = '%s/fake/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                        print('Make a new folder: ', f'{save_dir}/real')
                        mkdir_p(f'{save_dir}/real')
                        print('Make a new folder: ', f'{save_dir}/text')
                        mkdir_p(f'{save_dir}/text')
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, k)
                    im.save(fullpath)
                    temp = keys[j].replace('b', '').replace("'", '')
                    shutil.copy(f"../data/Face/images/{temp}.jpg",
                                f"{save_dir}/real/")
                    shutil.copy(f"../data/Face/text/{temp}.txt",
                                f"{save_dir}/text/")

def build_models(self):
    def count_parameters(model):
        total_param = 0
        for name, param in model.named_parameters():
            if param.requires_grad:
                num_param = np.prod(param.size())
                if param.dim() > 1:
                    print(name, ':',
                          'x'.join(str(x) for x in list(param.size())),
                          '=', num_param)
                else:
                    print(name, ':', num_param)
                total_param += num_param
        return total_param

    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:
    print('number of trainable parameters =', count_parameters(netG))
    print('number of trainable parameters =', count_parameters(netsD[-1]))

    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))
    #
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch]

def sampling(self, split_dir):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for main module is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'

        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = EncDecNet()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        # The text encoder
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # The image encoder
        """
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()
        """

        # The VGG network
        VGG = VGG16()
        print("Load the VGG model")
        VGG.cuda()
        VGG.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
        noise = noise.cuda()

        model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                 'Model/netG_epoch_600.pth')
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save modified images
        save_dir_valid = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                      'valid')
        # mkdir_p(save_dir)
        cnt = 0
        idx = 0
        for i in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            # the path to save modified images
            save_dir = os.path.join(save_dir_valid, 'valid_%d' % i)
            save_dir_super = os.path.join(save_dir, 'super')
            save_dir_single = os.path.join(save_dir, 'single')
            mkdir_p(save_dir_super)
            mkdir_p(save_dir_single)
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 100 == 0:
                    print('step: ', step)

                imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(wrong_caps,
                                                    wrong_caps_len, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (wrong_caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                            noise, mask, VGG)

                img_set = build_images(imgs[-1], fake_img, captions,
                                       wrong_caps, self.ixtoword)
                img = Image.fromarray(img_set)
                full_path = '%s/super_step%d.png' % (save_dir_super, step)
                img.save(full_path)

                for j in range(batch_size):
                    s_tmp = '%s/single' % (save_dir_single)
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    im = fake_img[j].data.cpu().numpy()
                    # im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, idx)
                    idx = idx + 1
                    im.save(fullpath)

def sampling(self, split_dir):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        netG_list = ['../models/netG_epoch_50.pth',
                     '../models/netG_epoch_60.pth',
                     '../models/netG_epoch_70.pth',
                     '../models/netG_epoch_80.pth',
                     '../models/netG_epoch_90.pth',
                     '../models/netG_epoch_100.pth',
                     '../models/netG_epoch_110.pth',
                     '../models/netG_epoch_120.pth',
                     '../models/netG_epoch_130.pth',
                     '../models/netG_epoch_140.pth',
                     '../models/netG_epoch_150.pth',
                     '../models/netG_epoch_160.pth']
        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        # state_dict = torch.load(cfg.TRAIN.NET_G)
        print("LINE==380")
        print("-----------------netG------------------------")
        print(netG)
        print("--------------state-dict---------------------")
        # print(state_dict)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        print('save_dir:', save_dir)
        mkdir_p(save_dir)

        cnt = 0
        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 100 == 0:
                    print('step: ', step)
                # if step > 50:
                #     break

                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

                for j in range(batch_size):
                    s_tmp = '%s/single/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, k)
                    im.save(fullpath)

def calc_mp(self):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for main module is not found!')
    else:
        # if split_dir == 'test':
        #     split_dir = 'valid'
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = EncDecNet()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        # The text encoder
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # The image encoder
        # image_encoder = CNN_dummy()
        # print('define image_encoder')
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()

        # The VGG network
        VGG = VGG16()
        print("Load the VGG model")
        # VGG.to(torch.device("cuda:1"))
        VGG.cuda()
        VGG.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
        noise = noise.cuda()

        model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                 'Model', self.args.netG_epoch)
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save modified images
        cnt = 0
        idx = 0
        diffs, sims = [], []
        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 100 == 0:
                    print('step: ', step)

                imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(wrong_caps,
                                                    wrong_caps_len, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (wrong_caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_img, mu, logvar = netG(imgs[-1], sent_emb, words_embs,
                                            noise, mask, VGG)

                diff = F.l1_loss(fake_img, imgs[-1])
                diffs.append(diff.item())

                region_code, cnn_code = image_encoder(fake_img)
                sim = cosine_similarity(sent_emb, cnn_code)
                sim = torch.mean(sim)
                sims.append(sim.item())

        diff = np.sum(diffs) / len(diffs)
        sim = np.sum(sims) / len(sims)
        print('diff: %.3f, sim:%.3f' % (diff, sim))
        print('MP: %.3f' % ((1 - diff) * sim))
        netG_epoch = self.args.netG_epoch[self.args.netG_epoch.find('_') + 1:-4]
        print('model_epoch:%s, diff: %.3f, sim:%.3f, MP:%.3f' %
              (netG_epoch, np.sum(diffs) / len(diffs),
               np.sum(sims) / len(sims), (1 - diff) * sim))

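# A worked example of the MP score printed by calc_mp above (the numbers are
# illustrative, not measured): with an average L1 difference diff = 0.2 and an
# average text-image cosine similarity sim = 0.6,
#
#   MP = (1 - diff) * sim = (1 - 0.2) * 0.6 = 0.48
#
# so the score rewards images that stay close to the input while matching the text.
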
def build_models(self):
    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    netsD = []
    from model import D_NET64, D_NET128, D_NET256
    netG = G_NET()
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    netG.apply(weights_init)
    # print(netG)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
        # print(netsD[i])
    print('# of netsD', len(netsD))

    epoch = 0
    if self.resume:
        checkpoint_list = sorted(
            [ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')])
        latest_checkpoint = checkpoint_list[-1]
        state_dict = torch.load(latest_checkpoint,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        for i in range(len(netsD)):
            netsD[i].load_state_dict(state_dict["netD"][i])
        epoch = int(latest_checkpoint[-8:-4]) + 1
        print("Resuming training from checkpoint {} at epoch {}.".format(
            latest_checkpoint, epoch))
    #
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch]

def sampling(self, split_dir, num_samples=30000):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        # state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict["netG"])
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        cnt = 0
        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if step % 10000 == 0:
                    print('step: ', step)
                if step >= num_samples:
                    break

                imgs, captions, cap_lens, class_ids, keys, \
                    transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask,
                          transf_matrices_inv, label_one_hot)
                with torch.no_grad():
                    fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)

                for j in range(batch_size):
                    s_tmp = '%s/single/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, k)
                    im.save(fullpath)

def build_models(self):
    # ###################encoders######################################## #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # #######################generator and discriminators############## #
    netG = G_NET(len(self.cats_index_dict))
    netsPatD, netsShpD = [], []
    if cfg.TREE.BRANCH_NUM > 0:
        netsPatD.append(PAT_D_NET64())
        netsShpD.append(SHP_D_NET64(len(self.cats_index_dict)))
    if cfg.TREE.BRANCH_NUM > 1:
        netsPatD.append(PAT_D_NET128())
        netsShpD.append(SHP_D_NET128(len(self.cats_index_dict)))
    if cfg.TREE.BRANCH_NUM > 2:
        netsPatD.append(PAT_D_NET256())
        netsShpD.append(SHP_D_NET256(len(self.cats_index_dict)))
    netObjSSD = OBJ_SS_D_NET(len(self.cats_index_dict))
    netObjLSD = OBJ_LS_D_NET(len(self.cats_index_dict))

    netG.apply(weights_init)
    netObjSSD.apply(weights_init)
    netObjLSD.apply(weights_init)
    for i in range(len(netsPatD)):
        netsPatD[i].apply(weights_init)
        netsShpD[i].apply(weights_init)
    print('# of netsPatD', len(netsPatD))

    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        netObjSSD.cuda()
        netObjLSD.cuda()
        for i in range(len(netsPatD)):
            netsPatD[i].cuda()
            netsShpD[i].cuda()

    if len(cfg.GPU_IDS) > 1:
        text_encoder = nn.DataParallel(text_encoder)
        text_encoder.to(self.device)
        image_encoder = nn.DataParallel(image_encoder)
        image_encoder.to(self.device)
        netG = nn.DataParallel(netG)
        netG.to(self.device)
        netObjSSD = nn.DataParallel(netObjSSD)
        netObjSSD.to(self.device)
        netObjLSD = nn.DataParallel(netObjLSD)
        netObjLSD.to(self.device)
        for i in range(len(netsPatD)):
            netsPatD[i] = nn.DataParallel(netsPatD[i])
            netsPatD[i].to(self.device)
            netsShpD[i] = nn.DataParallel(netsShpD[i])
            netsShpD[i].to(self.device)
    #
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1

        Gname = cfg.TRAIN.NET_G
        for i in range(len(netsPatD)):
            s_tmp = Gname[:Gname.rfind('/')]
            Dname = '%s/netPatD%d.pth' % (s_tmp, i)
            print('Load PatD from: ', Dname)
            state_dict = \
                torch.load(Dname, map_location=lambda storage, loc: storage)
            netsPatD[i].load_state_dict(state_dict)

            Dname = '%s/netShpD%d.pth' % (s_tmp, i)
            print('Load ShpD from: ', Dname)
            state_dict = \
                torch.load(Dname, map_location=lambda storage, loc: storage)
            netsShpD[i].load_state_dict(state_dict)

        s_tmp = Gname[:Gname.rfind('/')]
        Dname = '%s/netObjSSD.pth' % (s_tmp)
        print('Load ObjSSD from: ', Dname)
        state_dict = \
            torch.load(Dname, map_location=lambda storage, loc: storage)
        netObjSSD.load_state_dict(state_dict)

        s_tmp = Gname[:Gname.rfind('/')]
        Dname = '%s/netObjLSD.pth' % (s_tmp)
        print('Load ObjLSD from: ', Dname)
        state_dict = \
            torch.load(Dname, map_location=lambda storage, loc: storage)
        netObjLSD.load_state_dict(state_dict)

    return [text_encoder, image_encoder, netG, netsPatD, netsShpD,
            netObjSSD, netObjLSD, epoch]

def gen_example(self, data_dic):
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
    else:
        # The text encoder
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # The image encoder
        """
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()
        """
        """
        image_encoder = CNN_dummy()
        image_encoder = image_encoder.cuda()
        image_encoder.eval()
        """

        # The VGG network
        VGG = VGG16()
        print("Load the VGG model")
        VGG.cuda()
        VGG.eval()

        # The main module
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = EncDecNet()
        s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
        s_tmp = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                             'valid/gen_example')
        model_dir = os.path.join(cfg.DATA_DIR, 'output', self.args.netG,
                                 'Model/netG_epoch_8.pth')
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)
        # netG = nn.DataParallel(netG, device_ids=self.gpus)
        netG.cuda()
        netG.eval()

        for key in data_dic:
            save_dir = '%s/%s' % (s_tmp, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices, imgs = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            captions = Variable(torch.from_numpy(captions), volatile=True)
            cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
            captions = captions.cuda()
            cap_lens = cap_lens.cuda()
            for i in range(1):
                noise = Variable(torch.FloatTensor(batch_size, nz),
                                 volatile=True)
                noise = noise.cuda()

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                # The text embeddings
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # The image embeddings
                mask = (captions == 0)

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                imgs_256 = imgs[-1].unsqueeze(0).repeat(batch_size, 1, 1, 1)
                enc_features = VGG(imgs_256)
                fake_img, mu, logvar = nn.parallel.data_parallel(
                    netG, (imgs[-1], sent_emb, words_embs, noise, mask,
                           enc_features), self.gpus)

                cap_lens_np = cap_lens.cpu().data.numpy()

                one_imgs = []
                for j in range(captions.shape[0]):
                    font = ImageFont.truetype('./FreeMono.ttf', 20)
                    canv = Image.new('RGB', (256, 256), (255, 255, 255))
                    draw = ImageDraw.Draw(canv)
                    sent = []
                    for k in range(len(captions[j])):
                        if captions[j][k] == 0:
                            break
                        word = self.ixtoword[captions[j][k].item()].encode(
                            'ascii', 'ignore').decode('ascii')
                        if k % 2 == 1:
                            word = word + '\n'
                        sent.append(word)
                    fake_sent = ' '.join(sent)
                    draw.text((0, 0), fake_sent, font=font, fill=(0, 0, 0))
                    canv_np = np.asarray(canv)

                    real_im = imgs[-1]
                    real_im = (real_im + 1) * 127.5
                    real_im = real_im.cpu().numpy().astype(np.uint8)
                    real_im = np.transpose(real_im, (1, 2, 0))

                    fake_im = fake_img[j]
                    fake_im = (fake_im + 1.0) * 127.5
                    fake_im = fake_im.detach().cpu().numpy().astype(np.uint8)
                    fake_im = np.transpose(fake_im, (1, 2, 0))

                    one_img = np.concatenate([real_im, canv_np, fake_im],
                                             axis=1)
                    one_imgs.append(one_img)

                img_set = np.concatenate(one_imgs, axis=0)
                super_img = Image.fromarray(img_set)
                full_path = os.path.join(save_dir, 'super.png')
                super_img.save(full_path)

                """
                for j in range(5):  # batch_size
                    save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                """
                """
                img_set, sentences = \
                    build_super_images2(im[j].unsqueeze(0),
                                        captions[j].unsqueeze(0),
                                        [cap_lens_np[j]], self.ixtoword,
                                        [attn_maps[j]], att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s_a%d.png' % (save_name, k)
                    im.save(fullpath)
                """

def sampling(self, split_dir):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'
        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        # load text encoder
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()
        # load image encoder
        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load image encoder from:', img_encoder_path)
        image_encoder = image_encoder.cuda()
        image_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        cnt = 0
        R_count = 0
        R = np.zeros(30000)
        cont = True
        for ii in range(11):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            if not cont:
                break
            for step, data in enumerate(self.data_loader, 0):
                cnt += batch_size
                if not cont:
                    break
                if step % 100 == 0:
                    print('cnt: ', cnt)
                # if step > 50:
                #     break
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                with torch.no_grad():
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask, cap_lens)
                for j in range(batch_size):
                    s_tmp = '%s/single/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        # print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d_%d.png' % (s_tmp, k, ii)
                    im.save(fullpath)

                # R-precision: rank the true caption against 99 mismatched ones
                with torch.no_grad():
                    _, cnn_code = image_encoder(fake_imgs[-1])
                    for i in range(batch_size):
                        mis_captions, mis_captions_len = self.dataset.get_mis_caption(class_ids[i])
                        hidden = text_encoder.init_hidden(99)
                        _, sent_emb_t = text_encoder(mis_captions, mis_captions_len, hidden)
                        rnn_code = torch.cat((sent_emb[i, :].unsqueeze(0), sent_emb_t), 0)
                        # cnn_code = 1 * nef
                        # rnn_code = 100 * nef
                        scores = torch.mm(cnn_code[i].unsqueeze(0), rnn_code.transpose(0, 1))  # 1 * 100
                        cnn_code_norm = torch.norm(cnn_code[i].unsqueeze(0), 2, dim=1, keepdim=True)
                        rnn_code_norm = torch.norm(rnn_code, 2, dim=1, keepdim=True)
                        norm = torch.mm(cnn_code_norm, rnn_code_norm.transpose(0, 1))
                        scores0 = scores / norm.clamp(min=1e-8)
                        if torch.argmax(scores0) == 0:
                            R[R_count] = 1
                        R_count += 1

                if R_count >= 30000:
                    # report R-precision as mean/std over ten folds of 3000 samples
                    fold_means = np.zeros(10)
                    np.random.shuffle(R)
                    for i in range(10):
                        fold_means[i] = np.average(R[i * 3000:(i + 1) * 3000])
                    R_mean = np.average(fold_means)
                    R_std = np.std(fold_means)
                    print("R mean:{:.4f} std:{:.4f}".format(R_mean, R_std))
                    cont = False
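# The ranking step above is plain cosine similarity between one image code and
# 100 candidate sentence codes, counting a hit when the true caption (row 0)
# scores highest. A standalone sketch of that check; the function name is
# hypothetical and the shapes follow the comments above:
import torch

def r_precision_hit(cnn_code, rnn_code, eps=1e-8):
    # cnn_code: (1, nef); rnn_code: (100, nef) with the true caption at row 0
    scores = torch.mm(cnn_code, rnn_code.t())  # (1, 100) dot products
    norm = torch.mm(cnn_code.norm(2, dim=1, keepdim=True),
                    rnn_code.norm(2, dim=1, keepdim=True).t())
    return torch.argmax(scores / norm.clamp(min=eps)).item() == 0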
def gen_samples(self, idx):
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from: {}'.format(cfg.TRAIN.NET_E))
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    netG = G_NET()
    state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: {}'.format(cfg.TRAIN.NET_G))
    netG.cuda()
    netG.eval()

    s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
    save_dir = '%s/samples' % s_tmp
    mkdir_p(save_dir)

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    with torch.no_grad():
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        step = 0
        data_iter = iter(self.data_loader)
        while step < self.num_batches:
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, sorted_cap_indices = self.prepare_data(data)
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            for i in range(10):
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
                cap_lens_np = cap_lens.cpu().data.numpy()
                for j in range(batch_size):
                    # map the sorted batch position back to the dataset order
                    right_idx = step * batch_size + sorted_cap_indices[j]
                    save_name = '%s/%d_s_%d' % (save_dir, i, right_idx)
                    original_idx = idx[right_idx]
                    shutil.copyfile(
                        '/.local/AttnGAN/data/FashionSynthesis/test/original/test128_{}.png'.format(original_idx + 1),
                        save_dir + '/test128_{0}_{1}.png'.format(original_idx + 1, right_idx))
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)
                    for k in range(len(attention_maps)):
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = \
                            build_super_images2(im[j].unsqueeze(0),
                                                captions[j].unsqueeze(0),
                                                [cap_lens_np[j]], self.ixtoword,
                                                [attn_maps[j]], att_sze)
                        if img_set is not None:
                            im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            im.save(fullpath)
            step += 1
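# gen_samples relies on prepare_data returning sorted_cap_indices so each sample
# can be matched back to its dataset position. A sketch of that bookkeeping,
# assuming batches are sorted by caption length for the RNN encoder (the helper
# name is hypothetical):
import torch

def sort_by_caption_length(captions, cap_lens):
    # sort descending by length; keep the permutation to undo the sort later
    sorted_cap_lens, sorted_cap_indices = torch.sort(cap_lens, 0, True)
    return captions[sorted_cap_indices], sorted_cap_lens, sorted_cap_indices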
def build_models(self):
    # ################### Text and Image encoders ################### #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return
    # Image-encoder loading disabled here, kept for reference:
    """
    image_encoder = CNN_dummy()
    image_encoder.cuda()
    image_encoder.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()
    """
    VGG = VGG16()
    VGG.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ####################### Generator and Discriminators ############## #
    from model import D_NET256
    netD = D_NET256()
    netG = EncDecNet()
    netD.apply(weights_init)
    netG.apply(weights_init)

    epoch = 0  # returned below; resume logic is disabled
    """
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)
    """
    # ########################################################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        netG.cuda()
        netD.cuda()
        VGG.cuda()
    return [text_encoder, netG, netD, epoch, VGG]
def sample(self, split_dir, num_samples=25, draw_bbox=False):
    from PIL import Image, ImageDraw, ImageFont
    import pickle
    import torchvision
    import torchvision.utils as vutils

    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model NET_G is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'
        # Build and load the generator
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = cfg.TRAIN.BATCH_SIZE
        nz = cfg.GAN.Z_DIM

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG = G_NET()
        print('Load G from: ', model_dir)
        netG.apply(weights_init)
        netG.load_state_dict(state_dict["netG"])
        netG.cuda()
        netG.eval()

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s_%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)
        #######################################
        noise = Variable(torch.FloatTensor(9, nz))
        imsize = 256

        for step, data in enumerate(self.data_loader, 0):
            if step >= num_samples:
                break
            imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, bbox = \
                prepare_data(data, eval=True)
            transf_matrices_inv = transformation_matrices[1][0].unsqueeze(0)
            label_one_hot = label_one_hot[0].unsqueeze(0)

            img = imgs[-1][0]
            val_image = img.view(1, 3, imsize, imsize)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs[0].unsqueeze(0).detach(), sent_emb[0].unsqueeze(0).detach()
            words_embs = words_embs.repeat(9, 1, 1)
            sent_emb = sent_emb.repeat(9, 1)
            mask = (captions == 0)
            mask = mask[0].unsqueeze(0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]
            mask = mask.repeat(9, 1)
            transf_matrices_inv = transf_matrices_inv.repeat(9, 1, 1, 1)
            label_one_hot = label_one_hot.repeat(9, 1, 1)

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
            with torch.no_grad():
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

            data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:10] = fake_imgs[-1]

            if draw_bbox:
                for idx in range(3):
                    x, y, w, h = tuple([int(imsize * c) for c in bbox[0, idx]])
                    w = imsize - 1 if w > imsize - 1 else w
                    h = imsize - 1 if h > imsize - 1 else h
                    if x <= -1:
                        break
                    # draw a one-pixel white frame around the box
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y + h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1

            # get caption
            cap = captions[0].data.cpu().numpy()
            sentence = ""
            for j in range(len(cap)):
                if cap[j] == 0:
                    break
                word = self.ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
                sentence += word + " "
            sentence = sentence[:-1]
            vutils.save_image(data_img, '{}/{}_{}.png'.format(save_dir, sentence, step),
                              normalize=True, nrow=10)
        print("Saved {} files to {}".format(step, save_dir))
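# The filename caption built above can be factored into a helper; a sketch with
# a hypothetical name, stopping at index 0 (the padding/end token) as the loop does:
def caption_to_sentence(cap_indices, ixtoword):
    words = []
    for ix in cap_indices:
        if ix == 0:
            break
        words.append(ixtoword[ix].encode('ascii', 'ignore').decode('ascii'))
    return " ".join(words)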
def sampling(self, split_dir, num_samples=30000):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        if split_dir == 'test':
            split_dir = 'valid'
        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()
        #
        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()
        print('Loaded text encoder from:', cfg.TRAIN.NET_E)

        batch_size = self.batch_size[0]
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()
        local_noise = Variable(torch.FloatTensor(batch_size, 32)).cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict["netG"])
        max_objects = 10
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
        save_dir = '%s/%s/%s' % ("OP-GAN/output", s_tmp, split_dir)
        mkdir_p(save_dir)
        print("Saving images to: {}".format(save_dir))
        number_batches = num_samples // batch_size
        if number_batches < 1:
            number_batches = 1
        data_iter = iter(self.data_loader)
        for step in tqdm(range(number_batches)):
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = \
                prepare_data(data, eval=True)
            transf_matrices = transformation_matrices[0]
            transf_matrices_inv = transformation_matrices[1]

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            local_noise.data.normal_(0, 1)
            inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices,
                      transf_matrices_inv, label_one_hot, max_objects)
            with torch.no_grad():
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)
            for batch_idx, j in enumerate(range(batch_size)):
                s_tmp = '%s/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                k = -1
                # for k in range(len(fake_imgs)):
                im = fake_imgs[k][j].data.cpu().numpy()
                # [-1, 1] --> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_s%d.png' % (s_tmp, step * batch_size + batch_idx)
                im.save(fullpath)
def gen_example(self, data_dic):
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
    else:
        # Build and load the generator
        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        # the path to save generated images
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)
        netG.cuda()
        netG.eval()
        for key in data_dic:
            save_dir = '%s/%s' % (s_tmp, key)
            mkdir_p(save_dir)
            captions, cap_lens, sorted_indices = data_dic[key]

            batch_size = captions.shape[0]
            nz = cfg.GAN.Z_DIM
            captions = Variable(torch.from_numpy(captions))
            cap_lens = Variable(torch.from_numpy(cap_lens))
            captions = captions.cuda()
            cap_lens = cap_lens.cuda()
            for i in range(1):  # 16
                noise = Variable(torch.FloatTensor(batch_size, nz))
                noise = noise.cuda()
                #######################################################
                # (1) Extract text embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                mask = (captions == 0)
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                with torch.no_grad():
                    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
                # G attention
                cap_lens_np = cap_lens.cpu().data.numpy()
                for j in range(batch_size):
                    save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = \
                            build_super_images2(im[j].unsqueeze(0),
                                                captions[j].unsqueeze(0),
                                                [cap_lens_np[j]], self.ixtoword,
                                                [attn_maps[j]], att_sze)
                        if img_set is not None:
                            im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            im.save(fullpath)
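# gen_example expects data_dic[key] = (captions, cap_lens, sorted_indices):
# an int array padded with zeros and sorted by decreasing caption length.
# A sketch of building one entry from tokenized sentences, assuming a
# word-to-index dict; the helper name is hypothetical:
import numpy as np

def make_data_entry(tokenized_sents, wordtoix):
    caps = [[wordtoix[w] for w in sent if w in wordtoix] for sent in tokenized_sents]
    cap_lens = np.asarray([len(c) for c in caps])
    sorted_indices = np.argsort(cap_lens)[::-1]
    cap_lens = cap_lens[sorted_indices]
    cap_array = np.zeros((len(caps), int(cap_lens[0])), dtype='int64')
    for row, pos in enumerate(sorted_indices):
        cap_array[row, :len(caps[pos])] = caps[pos]
    return cap_array, cap_lens, sorted_indices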
def build_models(self):
    # ###################### Text and image encoders ###################### #
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    state_dict = \
        torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    # self.n_words = 156
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # Caption models - cnn_encoder and rnn_decoder
    caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
    caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path,
                                           map_location=lambda storage, loc: storage))
    for p in caption_cnn.parameters():
        p.requires_grad = False
    print('Load caption model from:', cfg.CAP.caption_cnn_path)
    caption_cnn.eval()

    # self.n_words = 9
    caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2,
                              self.n_words, cfg.CAP.num_layers)
    caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path,
                                           map_location=lambda storage, loc: storage))
    for p in caption_rnn.parameters():
        p.requires_grad = False
    print('Load caption model from:', cfg.CAP.caption_rnn_path)

    # ####################### Generator and discriminators ################ #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
    netG.apply(weights_init)
    for i in range(len(netsD)):
        netsD[i].apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = \
            torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        # resume epoch is parsed from the checkpoint filename
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = cfg.TRAIN.NET_G[istart:iend]
        epoch = int(epoch) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            for i in range(len(netsD)):
                s_tmp = Gname[:Gname.rfind('/')]
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = \
                    torch.load(Dname, map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        caption_cnn = caption_cnn.cuda()
        caption_rnn = caption_rnn.cuda()
        netG.cuda()
        for i in range(len(netsD)):
            netsD[i].cuda()
    return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
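# The resume-epoch parsing above assumes checkpoint names like 'netG_epoch_100.pth',
# taking the digits between the last '_' and the extension. The same logic as a
# small helper (name hypothetical):
def epoch_from_checkpoint(path):
    istart = path.rfind('_') + 1
    iend = path.rfind('.')
    return int(path[istart:iend]) + 1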