def evaluate(self, split_dir):
    """Sample images from a pretrained generator and write them to disk.

    The generator checkpoint is read from ``cfg.TRAIN.NET_G``. The text
    between the last ``_`` and the last ``.`` of that path is parsed as the
    iteration number and used to name the output folder.

    Args:
        split_dir: Name of the sub-folder (below ``iteration<N>``) that
            receives the generated images.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    # Build and load the generator
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    print(netG)
    state_dict = torch.load(cfg.TRAIN.NET_G,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load ', cfg.TRAIN.NET_G)

    # the path to save generated images
    s_tmp = cfg.TRAIN.NET_G
    istart = s_tmp.rfind('_') + 1
    iend = s_tmp.rfind('.')
    iteration = int(s_tmp[istart:iend])
    s_tmp = s_tmp[:s_tmp.rfind('/')]
    save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
    if cfg.TEST.B_EXAMPLE:
        folder = '%s/super' % (save_dir)
    else:
        folder = '%s/single' % (save_dir)
    print('Make a new folder: ', folder)
    mkdir_p(folder)

    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(self.batch_size, nz))
    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()

    # switch to evaluate mode
    netG.eval()
    num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
    cnt = 0
    # `range` instead of the Python-2-only `xrange`, so the loop also runs
    # on Python 3.
    for step in range(num_batches):
        noise.data.normal_(0, 1)
        fake_imgs, _, _ = netG(noise)
        # Only the last (highest-index) generator output is saved.
        if cfg.TEST.B_EXAMPLE:
            self.save_superimages(fake_imgs[-1], folder, cnt, 256)
        else:
            self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
        cnt += self.batch_size
def load_network(gpus):
    """Build the generator, the discriminators and the inception model.

    Optionally restores weights from ``cfg.TRAIN.NET_G`` and from the
    ``cfg.TRAIN.NET_D`` prefix (one ``<prefix><i>.pth`` file per
    discriminator).

    Args:
        gpus: Device ids handed to ``torch.nn.DataParallel``.

    Returns:
        Tuple ``(netG, netsD, num_Ds, inception_model, count)``. ``count`` is
        0 unless a generator checkpoint is loaded, in which case it is the
        number parsed from the checkpoint file name plus one.
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = []
    if cfg.TREE.BRANCH_NUM > 0:
        netsD.append(D_NET64())
    if cfg.TREE.BRANCH_NUM > 1:
        netsD.append(D_NET128())
    if cfg.TREE.BRANCH_NUM > 2:
        netsD.append(D_NET256())
    if cfg.TREE.BRANCH_NUM > 3:
        netsD.append(D_NET512())
    if cfg.TREE.BRANCH_NUM > 4:
        netsD.append(D_NET1024())
    # TODO: if cfg.TREE.BRANCH_NUM > 5:

    for i, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[i] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # The counter is the text between the last '_' and the last '.'.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        count = int(cfg.TRAIN.NET_G[istart:iend]) + 1

    if cfg.TRAIN.NET_D != '':
        for i, netD in enumerate(netsD):
            # Log exactly the path that is loaded; the previous message
            # inserted an underscore that the real file name does not have.
            d_path = '%s%d.pth' % (cfg.TRAIN.NET_D, i)
            print('Load %s' % d_path)
            state_dict = torch.load(d_path)
            netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    if cfg.CUDA:
        netG.cuda()
        for netD in netsD:
            netD.cuda()
        inception_model = inception_model.cuda()
    inception_model.eval()

    return netG, netsD, len(netsD), inception_model, count
def build_models(self):
    """Create the encoders, the GAN networks and the captioning model.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch, cap_model]``, or
        ``None`` (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
        ``epoch`` is 0 unless ``cfg.TRAIN.NET_G`` names a checkpoint, in
        which case it is the number parsed from that name plus one.
    """
    def keep_storage(storage, loc):
        # map_location hook: return the storage unchanged.
        return storage

    # ------------------------------ encoders ------------------------------
    encoder_ckpt = cfg.TRAIN.NET_E
    if encoder_ckpt == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = encoder_ckpt.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=keep_storage))
    # Turn gradients on for every parameter first ...
    for param in image_encoder.parameters():
        param.requires_grad = True
    # ... then put the children listed in frozen_list_image_encoder into
    # eval mode and switch their gradients off again.
    for child_name, child in image_encoder.named_children():
        if child_name in frozen_list_image_encoder:
            child.train(False)
            child.requires_grad_(False)
    print('Load image encoder from:', img_encoder_path)

    text_encoder = TEXT_TRANSFORMER_ENCODERv2(
        emb=cfg.TEXT.EMBEDDING_DIM,
        heads=8,
        depth=1,
        seq_length=cfg.TEXT.WORDS_NUM,
        num_tokens=self.n_words)
    text_encoder.load_state_dict(
        torch.load(encoder_ckpt, map_location=keep_storage))
    for param in text_encoder.parameters():
        param.requires_grad = True
    print('Load text encoder from:', encoder_ckpt)

    # -------------------- generator and discriminators --------------------
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, added in order of resolution.
        for threshold, d_cls in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > threshold:
                netsD.append(d_cls())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    g_ckpt = cfg.TRAIN.NET_G
    if g_ckpt != '':
        netG.load_state_dict(torch.load(g_ckpt, map_location=keep_storage))
        print('Load G from: ', g_ckpt)
        # The epoch is the text between the last '_' and the last '.'.
        epoch = int(g_ckpt[g_ckpt.rfind('_') + 1:g_ckpt.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_ckpt[:g_ckpt.rfind('/')]
            for idx, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, idx)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=keep_storage))

    # --------------------------- caption model ----------------------------
    # NOTE(review): `config` is taken from the enclosing scope here; the
    # source appears to carry `config = Config()` only as a comment - confirm.
    # config = Config()
    cap_model = caption.build_model_v3(config)
    print("Initializing from Checkpoint...")
    cap_model_path = encoder_ckpt.replace('text_encoder', 'cap_model')
    if os.path.exists(cap_model_path):
        print('Load C from: {0}'.format(cap_model_path))
        cap_state = torch.load(cap_model_path, map_location=keep_storage)
        cap_model.load_state_dict(cap_state['model'])
    else:
        # Fall back to the baseline checkpoint, loaded non-strictly.
        base_line_path = 'catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth'
        print('Load C from: {0}'.format(base_line_path))
        checkv3 = torch.load(base_line_path, map_location='cpu')
        cap_model.load_state_dict(checkv3['model'], strict=False)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        cap_model = cap_model.cuda()  # caption transformer added
        netG.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch, cap_model]
def sampling(self, split_dir, num_samples=30000):
    """Generate images for batches of the data loader and save them as PNG.

    One image (the last generator output) is written per caption, under
    ``../output/<generator checkpoint name>/<split_dir>/``.

    Args:
        split_dir: Output sub-folder name; ``'test'`` is mapped to
            ``'valid'``.
        num_samples: Requested number of images. The number of batches is
            ``num_samples // batch_size``, with a minimum of one.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    # Build and load the generator
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()
    print('Loaded text encoder from:', cfg.TRAIN.NET_E)

    batch_size = self.batch_size[0]
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()
    local_noise = Variable(torch.FloatTensor(batch_size, 32)).cuda()

    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    # The generator weights are stored under the "netG" key of the checkpoint.
    netG.load_state_dict(state_dict["netG"])
    max_objects = 10
    print('Load G from: ', model_dir)

    # the path to save generated images
    s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1]
    save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir)
    mkdir_p(save_dir)
    print("Saving images to: {}".format(save_dir))

    number_batches = num_samples // batch_size
    if number_batches < 1:
        number_batches = 1

    data_iter = iter(self.data_loader)

    for step in tqdm(range(number_batches)):
        # Built-in next() works on every iterator; the `.next()` method is
        # not part of the Python 3 iterator protocol.
        data = next(data_iter)

        (imgs, captions, cap_lens, class_ids, keys, transformation_matrices,
         label_one_hot, _) = prepare_data(data, eval=True)

        transf_matrices = transformation_matrices[0]
        transf_matrices_inv = transformation_matrices[1]

        hidden = text_encoder.init_hidden(batch_size)
        # words_embs: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
        words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
        mask = (captions == 0)
        num_words = words_embs.size(2)
        # Trim the mask to the number of word embeddings returned.
        if mask.size(1) > num_words:
            mask = mask[:, :num_words]

        #######################################################
        # (2) Generate fake images
        ######################################################
        noise.data.normal_(0, 1)
        local_noise.data.normal_(0, 1)
        inputs = (noise, local_noise, sent_emb, words_embs, mask,
                  transf_matrices, transf_matrices_inv, label_one_hot,
                  max_objects)
        with torch.no_grad():
            fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                netG, inputs, self.gpus)

        for j in range(batch_size):
            s_tmp = '%s/%s' % (save_dir, keys[j])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            k = -1  # only the last generator output is saved
            im = fake_imgs[k][j].data.cpu().numpy()
            # [-1, 1] --> [0, 255]
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            # channels-first -> channels-last for PIL
            im = np.transpose(im, (1, 2, 0))
            im = Image.fromarray(im)
            fullpath = '%s_s%d.png' % (s_tmp, step * batch_size + j)
            im.save(fullpath)
def evaluate(self, split_dir):
    """Generate images for every text embedding yielded by the data loader.

    The generator is restored from ``cfg.TRAIN.NET_G``. Results are passed to
    ``self.save_superimages`` when ``cfg.TEST.B_EXAMPLE`` is set, otherwise
    to ``self.save_singleimages``.

    Args:
        split_dir: Split name forwarded to the save helpers; ``'test'`` is
            mapped to ``'valid'``.
    """
    ckpt_path = cfg.TRAIN.NET_G
    if ckpt_path == '':
        print('Error: the path for morels is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    # Generator wrapped for multi-GPU use, then restored from the checkpoint.
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    print(netG)
    netG.load_state_dict(
        torch.load(ckpt_path, map_location=lambda storage, loc: storage))
    print('Load ', ckpt_path)

    # Output folder: <checkpoint dir>/iteration<N>, with N parsed from the
    # text between the last '_' and the last '.' of the checkpoint path.
    iteration = int(ckpt_path[ckpt_path.rfind('_') + 1:ckpt_path.rfind('.')])
    save_dir = '%s/iteration%d' % (ckpt_path[:ckpt_path.rfind('/')], iteration)

    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(self.batch_size, nz))
    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()

    # switch to evaluate mode
    netG.eval()
    for data in self.data_loader:
        imgs, t_embeddings, filenames = data
        t_embeddings = Variable(t_embeddings)
        if cfg.CUDA:
            t_embeddings = t_embeddings.cuda()

        num_embeddings = t_embeddings.size(1)
        cur_batch = imgs[0].size(0)
        # Resize the noise buffer to the size of the current batch.
        noise.data.resize_(cur_batch, nz)
        noise.data.normal_(0, 1)

        collected = []
        for emb_idx in range(num_embeddings):
            fake_imgs, _, _ = netG(noise, t_embeddings[:, emb_idx, :])
            if cfg.TEST.B_EXAMPLE:
                collected.append(fake_imgs[2].data.cpu())
            else:
                self.save_singleimages(fake_imgs[-1], filenames,
                                       save_dir, split_dir, emb_idx, 256)

        if cfg.TEST.B_EXAMPLE:
            self.save_superimages(collected, filenames,
                                  save_dir, split_dir, 256)
def gen_example(self, data_dic):
    """Generate images and attention visualisations for custom captions.

    For every key of ``data_dic`` a folder ``<generator path without
    .pth>/<key>`` is created. It receives one ``*_g<k>.png`` per generator
    output and one ``*_a<k>.png`` per attention map for which
    ``build_super_images2`` returns an image.

    Args:
        data_dic: Mapping from a key to ``(captions, cap_lens,
            sorted_indices)``; ``captions`` and ``cap_lens`` are numpy arrays.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    # Build and load the text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # Build and load the generator
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET()
    # the path to save generated images
    s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)
    netG.cuda()
    netG.eval()

    for key in data_dic:
        save_dir = '%s/%s' % (s_tmp, key)
        mkdir_p(save_dir)
        captions, cap_lens, sorted_indices = data_dic[key]

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM
        captions = torch.from_numpy(captions).cuda()
        cap_lens = torch.from_numpy(cap_lens).cuda()

        # `volatile=True` no longer disables autograd; torch.no_grad() is
        # the supported way to run inference without building a graph.
        with torch.no_grad():
            for i in range(1):  # 16
                noise = torch.FloatTensor(batch_size, nz).cuda()

                # (1) Extract text embeddings
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    hidden)
                mask = (captions == 0)

                # (2) Generate fake images
                noise.normal_(0, 1)
                fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                       words_embs, mask)

                # G attention
                cap_lens_np = cap_lens.cpu().numpy()
                for j in range(batch_size):
                    save_name = '%s/%d_s_%d' % (save_dir, i,
                                                sorted_indices[j])
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        # Attention map k is paired with generator output
                        # k + 1 when more than one output exists.
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = build_super_images2(
                            im[j].unsqueeze(0),
                            captions[j].unsqueeze(0),
                            [cap_lens_np[j]], self.ixtoword,
                            [attn_maps[j]], att_sze)
                        if img_set is not None:
                            im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            im.save(fullpath)
def build_models(self):
    """Create the frozen encoders, the VGG style network and the GAN nets.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch, style_loss]``, or
        ``None`` (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
        ``epoch`` is 0 unless ``cfg.TRAIN.NET_G`` names a checkpoint, in
        which case it is the number parsed from that name plus one.
    """
    def keep_storage(storage, loc):
        # map_location hook: return the storage unchanged.
        return storage

    def freeze(module):
        # Turn gradients off for every parameter of the module.
        for param in module.parameters():
            param.requires_grad = False

    encoder_ckpt = cfg.TRAIN.NET_E
    if encoder_ckpt == '':
        print('Error: no pretrained text-image encoders')
        return

    # vgg16 network
    style_loss = VGGNet()
    freeze(style_loss)
    print("Load the style loss model")
    style_loss.eval()

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = encoder_ckpt.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=keep_storage))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(encoder_ckpt, map_location=keep_storage))
    freeze(text_encoder)
    print('Load text encoder from:', encoder_ckpt)
    text_encoder.eval()

    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, added in order of resolution.
        for threshold, d_cls in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > threshold:
                netsD.append(d_cls())

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    g_ckpt = cfg.TRAIN.NET_G
    if g_ckpt != '':
        netG.load_state_dict(torch.load(g_ckpt, map_location=keep_storage))
        print('Load G from: ', g_ckpt)
        # The epoch is the text between the last '_' and the last '.'.
        epoch = int(g_ckpt[g_ckpt.rfind('_') + 1:g_ckpt.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_ckpt[:g_ckpt.rfind('/')]
            for idx, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, idx)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=keep_storage))

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        style_loss = style_loss.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch, style_loss]
def build_models(self):
    """Create a trainable image encoder plus the generator/discriminators.

    Optional checkpoints are read from ``cfg.PRETRAINED_CNN`` (image
    encoder) and ``cfg.PRETRAINED_G`` (generator). When
    ``cfg.TRAIN.B_NET_D`` is set, the discriminators are loaded from
    ``netD<i>.pth`` files next to the generator checkpoint.

    Returns:
        ``[image_encoder, netG, netsD, epoch]`` where ``epoch`` is always 0.
    """
    def keep_storage(storage, loc):
        # map_location hook: return the storage unchanged.
        return storage

    # ------------------------------ encoders ------------------------------
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.train()

    # -------------------- generator and discriminators --------------------
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, added in order of resolution.
        for threshold, d_cls in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > threshold:
                netsD.append(d_cls())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0

    if cfg.PRETRAINED_CNN:
        image_encoder.load_state_dict(
            torch.load(cfg.PRETRAINED_CNN, map_location=keep_storage))

    g_ckpt = cfg.PRETRAINED_G
    if g_ckpt != '':
        netG.load_state_dict(torch.load(g_ckpt, map_location=keep_storage))
        print('Load G from: ', g_ckpt)
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_ckpt[:g_ckpt.rfind('/')]
            for idx, netD in enumerate(netsD):
                # the name of Ds should be consistent and differ from each
                # other in the index
                Dname = '%s/netD%d.pth' % (ckpt_dir, idx)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=keep_storage))

    if cfg.CUDA:
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()

    return [image_encoder, netG, netsD, epoch]
def gen_example(self, data_dic):
    """Generate images with the main module and the DCM for custom inputs.

    For every key of ``data_dic`` a folder ``<generator path without
    .pth>/<key>`` is created and filled with:

    * ``*_g<k>.png``  - each output of the main generator,
    * ``*_a<k>.png``  - attention visualisations,
    * ``*_SF.png``    - the DCM output,
    * ``1_s_9_SR.png`` - the input image ``imgs[2]``.

    Args:
        data_dic: Mapping from a key to ``(captions, cap_lens,
            sorted_indices, imgs)``; ``captions`` and ``cap_lens`` are numpy
            arrays and ``imgs`` is indexable by branch.
    """
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
        return

    def keep_storage(storage, loc):
        # map_location hook: return the storage unchanged.
        return storage

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E, map_location=keep_storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The image encoder
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    state_dict = torch.load(img_encoder_path, map_location=keep_storage)
    image_encoder.load_state_dict(state_dict)
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # The VGG network
    VGG = VGGNet()
    print("Load the VGG model")
    VGG.cuda()
    VGG.eval()

    # The main module
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET()
    s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir, map_location=keep_storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)
    netG.cuda()
    netG.eval()

    # The DCM (cfg.TRAIN.NET_C is known to be non-empty past the guard above)
    netDCM = DCM_Net()
    state_dict = torch.load(cfg.TRAIN.NET_C, map_location=keep_storage)
    netDCM.load_state_dict(state_dict)
    print('Load DCM from: ', cfg.TRAIN.NET_C)
    netDCM.cuda()
    netDCM.eval()

    for key in data_dic:
        save_dir = '%s/%s' % (s_tmp, key)
        mkdir_p(save_dir)
        captions, cap_lens, sorted_indices, imgs = data_dic[key]

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM
        captions = torch.from_numpy(captions).cuda()
        cap_lens = torch.from_numpy(cap_lens).cuda()

        # `volatile=True` no longer disables autograd; torch.no_grad() is
        # the supported way to run inference without building a graph.
        with torch.no_grad():
            for i in range(1):
                noise = torch.FloatTensor(batch_size, nz).cuda()

                #######################################################
                # (1) Extract text and image embeddings
                ######################################################
                hidden = text_encoder.init_hidden(batch_size)

                # The text embeddings
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    hidden)

                # The image embeddings (a batch axis is added to the image)
                real_img = imgs[cfg.TREE.BRANCH_NUM - 1].unsqueeze(0)
                region_features, cnn_code = image_encoder(real_img)
                mask = (captions == 0)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.normal_(0, 1)
                fake_imgs, attention_maps, mu, logvar, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                real_features = VGG(real_img)[0]
                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                cap_lens_np = cap_lens.cpu().numpy()

                for j in range(batch_size):
                    save_name = '%s/%d_s_%d' % (save_dir, i,
                                                sorted_indices[j])
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        # Attention map k is paired with generator output
                        # k + 1 when more than one output exists.
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = build_super_images2(
                            im[j].unsqueeze(0),
                            captions[j].unsqueeze(0),
                            [cap_lens_np[j]], self.ixtoword,
                            [attn_maps[j]], att_sze)
                        if img_set is not None:
                            im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            im.save(fullpath)

                    # The DCM result for this caption
                    save_name = '%s/%d_sf_%d' % (save_dir, 1,
                                                 sorted_indices[j])
                    im = fake_img[j].cpu().numpy()
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_SF.png' % (save_name)
                    im.save(fullpath)

                # The input image; the file name does not depend on the
                # caption, so it is written once per key.
                save_name = '%s/%d_s_%d' % (save_dir, 1, 9)
                im = imgs[2].data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_SR.png' % (save_name)
                im.save(fullpath)
def gen_example(self, data_dic):
    """Generate images for custom captions in fixed-size batches of 16.

    For every key of ``data_dic`` a folder ``<generator path without
    .pth>/<key>`` is created. Each generator output is saved twice, as
    ``s_<idx>_g<k>.png`` and as ``stage_<k>/<idx>_g<k>.png``, and attention
    visualisations are saved as ``s_<idx>_a<k>.png``.

    NOTE: only ``len(captions) // 16`` full batches are processed; captions
    that do not fill a whole batch are skipped.

    Args:
        data_dic: Mapping from a key to ``(captions, cap_lens,
            sorted_indices)``; ``captions`` and ``cap_lens`` are numpy arrays.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    # Build and load the text encoder
    batch_size = 16
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    # The format argument has to be applied with `%`; passing it as a second
    # print argument leaves the "%d" placeholder unformatted.
    print("=======self.n_words: %d" % self.n_words)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)

    # Customised restore: only entries that also exist in the model are
    # used. The values must be copied *into* the tensors returned by
    # state_dict(); rebinding the dictionary entry (own_state[name] = param)
    # leaves the model weights untouched.
    own_state = text_encoder.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if tuple(own_state[name].size()) != tuple(param.size()):
            print('Skip %s: checkpoint shape %s != model shape %s'
                  % (name, tuple(param.size()),
                     tuple(own_state[name].size())))
            continue
        own_state[name].copy_(param)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # Build and load the generator
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET(text_encoder)
    # the path to save generated images
    s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)
    netG.cuda()
    netG.eval()

    for key in data_dic:
        save_dir = '%s/%s' % (s_tmp, key)
        mkdir_p(save_dir)
        captions, cap_lens, sorted_indices = data_dic[key]

        total_time = len(captions) // batch_size
        nz = cfg.GAN.Z_DIM

        with torch.no_grad():
            for i in range(total_time):  # 16
                lo, hi = i * batch_size, (i + 1) * batch_size
                noise = Variable(torch.FloatTensor(batch_size, nz))
                noise = noise.cuda()
                caption_tmp = Variable(torch.from_numpy(captions[lo:hi]))
                if i < 3:
                    print(caption_tmp.data)
                cap_len_tmp = Variable(torch.from_numpy(cap_lens[lo:hi]))
                caption_tmp = caption_tmp.cuda()
                cap_len_tmp = cap_len_tmp.cuda()

                #######################################################
                # (1) Extract text embeddings
                ######################################################
                # The hidden state is created but the encoder is called
                # with None instead, as before.
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb, _ = text_encoder(caption_tmp,
                                                       cap_len_tmp, None)
                words_embs = words_embs.detach()
                sent_emb = sent_emb.detach()
                mask = (caption_tmp == 0)

                #######################################################
                # (2) Generate fake images
                ######################################################
                # Re-seed from a system source; random.seed() no longer
                # accepts a datetime object on Python 3.11+.
                random.seed()
                rnd = random.randint(0, 1000)
                torch.cuda.manual_seed(rnd)
                noise.data.normal_(0, 1)
                fake_imgs, attention_maps, _, _, _ = netG(
                    noise, sent_emb, words_embs, mask,
                    caption_tmp, cap_len_tmp)

                # G attention
                for j in range(batch_size):
                    sample_idx = sorted_indices[lo + j]
                    save_name = '%s/s_%d' % (save_dir, sample_idx)
                    for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = ((im + 1.0) / 2) * 255.0
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_g%d.png' % (save_name, k)
                        im.save(fullpath)

                        # save to separate directory
                        save_dir2 = '%s/stage_%d' % (save_dir, k)
                        mkdir_p(save_dir2)
                        fullpath = '%s/%d_g%d.png' % (save_dir2,
                                                      sample_idx, k)
                        im.save(fullpath)

                    for k in range(len(attention_maps)):
                        # Attention map k is paired with generator output
                        # k + 1 when more than one output exists.
                        if len(fake_imgs) > 1:
                            im = fake_imgs[k + 1].detach().cpu()
                        else:
                            im = fake_imgs[0].detach().cpu()
                        attn_maps = attention_maps[k]
                        att_sze = attn_maps.size(2)
                        img_set, sentences = build_super_images2(
                            im[j].unsqueeze(0),
                            caption_tmp[j].unsqueeze(0),
                            [cap_len_tmp[j]], self.ixtoword,
                            [attn_maps[j]], att_sze)
                        if img_set is not None:
                            im = Image.fromarray(img_set)
                            fullpath = '%s_a%d.png' % (save_name, k)
                            im.save(fullpath)
def generate_fake_im(self, data_dic):
    """Run the text encoder and generator on custom captions.

    The checkpoint paths come from the module-level globals
    ``text_encoder_path`` and ``net_G_path``. A folder
    ``<net_G_path without .pth>/<key>`` is created for every key.

    Args:
        data_dic: Mapping from a key to ``(captions, cap_lens,
            sorted_indices)``; ``captions`` and ``cap_lens`` are numpy arrays.

    Returns:
        ``(fake_imgs, attention_maps)`` as produced by the generator.
        NOTE(review): the return is placed after the loops, so with several
        keys the outputs of the last key are returned - confirm against the
        original layout.
    """
    global text_encoder_path, net_G_path

    def keep_storage(storage, loc):
        # map_location hook: return the storage unchanged.
        return storage

    # ---- text encoder ----
    text_encoder = BERT_RNN_ENCODER(self.n_words,
                                    nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(text_encoder_path, map_location=keep_storage))
    print('Loaded text encoder from:', text_encoder_path)
    text_encoder.eval()
    text_encoder = text_encoder.cuda()

    # ---- generator ----
    netG = G_NET()
    netG.load_state_dict(torch.load(net_G_path, map_location=keep_storage))
    print('Load Generator from: ', net_G_path)
    out_root = net_G_path[:net_G_path.rfind('.pth')]
    netG.cuda()
    netG.eval()

    for key in data_dic:
        save_dir = '%s/%s' % (out_root, key)
        mkdir_p(save_dir)
        captions, cap_lens, sorted_indices = data_dic[key]

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM
        captions = Variable(torch.from_numpy(captions)).cuda()
        cap_lens = Variable(torch.from_numpy(cap_lens)).cuda()

        for _ in range(1):  # 16
            noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()

            # (1) Extract text embeddings
            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            mask = (captions == 0)

            # (2) Generate fake images
            noise.data.normal_(0, 1)
            fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                   words_embs, mask)

    return fake_imgs, attention_maps
def evaluate(self):
    """Generate images conditioned on text embeddings and random base images.

    The generator is restored from ``cfg.TRAIN.NET_G``. For every batch a
    base image is drawn at random from each directory in the fourth element
    of the batch, resized to ``self.img_size`` and passed to the generator
    together with the noise and each text embedding. Results are handed to
    ``self.save_singleimages``.
    """
    ckpt_path = cfg.TRAIN.NET_G
    if ckpt_path == '':
        print('Error: the path for morels is not found!')
        return

    # Build and load the generator
    self.num_Ds = cfg.TREE.BRANCH_NUM
    self.base_num = 135
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    print(netG)
    netG.load_state_dict(
        torch.load(ckpt_path, map_location=lambda storage, loc: storage))
    print('Load ', ckpt_path)

    # Output folder: <checkpoint dir>/iteration<N>, with N parsed from the
    # text between the last '_' and the last '.' of the checkpoint path.
    iteration = int(ckpt_path[ckpt_path.rfind('_') + 1:ckpt_path.rfind('.')])
    save_dir = '%s/iteration%d' % (ckpt_path[:ckpt_path.rfind('/')], iteration)

    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(self.batch_size, nz))
    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()

    # switch to evaluate mode
    netG.eval()
    is_flower = cfg.DATASET_NAME.find('flower') != -1
    for data in self.data_loader:
        imgs, t_embeddings, filenames, base_dirs = data

        num_embeddings = t_embeddings.size(1)
        cur_batch = imgs[0].size(0)
        # Resize the noise buffer to the size of the current batch.
        noise.data.resize_(cur_batch, nz)
        noise.data.normal_(0, 1)

        crop_vbase = []
        crop_base_imgs = torch.zeros(cur_batch, 3,
                                     self.img_size, self.img_size)
        for row, base_dir in enumerate(base_dirs):
            if is_flower:
                # Index drawn from 1..self.base_num (inclusive).
                base_ix = random.randint(1, self.base_num)
            else:
                # Index drawn from 0..(number of directory entries - 1); it
                # is used as the file stem, not as an index into the listing.
                entries = os.listdir(base_dir)
                base_ix = random.randint(0, len(entries) - 1)
            base_img_name = '%s/%s.jpg' % (base_dir, str(base_ix))

            base_img = Image.open(base_img_name).convert('RGB')
            resized = base_img.resize([self.img_size, self.img_size])
            crop_base_imgs[row, :] = Torchtransform(resized)

        base_var = Variable(crop_base_imgs)
        if cfg.CUDA:
            base_var = base_var.cuda()
        crop_vbase.append(base_var)

        t_embeddings = Variable(t_embeddings)
        if cfg.CUDA:
            t_embeddings = t_embeddings.cuda()

        for emb_idx in range(num_embeddings):
            fake_imgs, fake_segs, _, _ = netG(
                noise, t_embeddings[:, emb_idx, :], crop_vbase)
            self.save_singleimages(fake_imgs, fake_segs[-1], crop_vbase[0],
                                   filenames, save_dir, emb_idx,
                                   self.img_size)
def build_models(self):
    """Build the frozen encoders, the generator, its target copy and the discriminators.

    Returns:
        ``[text_encoder, image_encoder, netG, target_netG, netsD, epoch]``,
        or ``None`` (after printing an error) when ``cfg.TRAIN.NET_E`` is
        empty.  ``epoch`` is 0 unless a generator checkpoint was loaded, in
        which case it is the number parsed from the checkpoint file name
        plus one.
    """
    def count_parameters(model):
        """Print each trainable parameter of ``model`` and return the total element count."""
        total = 0
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            size = np.prod(param.size())
            if param.dim() > 1:
                shape_text = 'x'.join(str(x) for x in list(param.size()))
                print(name, ':', shape_text, '=', size)
            else:
                print(name, ':', size)
            total += size
        return total

    def load_on_cpu(path):
        """Load a checkpoint while keeping every tensor on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    # ---- pretrained encoders (frozen) ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(load_on_cpu(img_encoder_path))
    for weight in image_encoder.parameters():
        weight.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_on_cpu(cfg.TRAIN.NET_E))
    for weight in text_encoder.parameters():
        weight.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- generator and discriminators ----
    if cfg.GAN.B_DCGAN:
        # A single discriminator whose resolution follows the branch count.
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, at most three.
        stages = (D_NET64, D_NET128, D_NET256)
        netsD = [stage() for stage in stages[:max(cfg.TREE.BRANCH_NUM, 0)]]

    print('number of trainable parameters =', count_parameters(netG))
    print('number of trainable parameters =', count_parameters(netsD[-1]))

    netG.apply(weights_init)
    for net_d in netsD:
        net_d.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---- optional checkpoint restore ----
    epoch = 0
    g_path = cfg.TRAIN.NET_G
    if g_path != '':
        netG.load_state_dict(load_on_cpu(g_path))
        print('Load G from: ', g_path)
        # The epoch is the number between the last '_' and the last '.'.
        epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_path[:g_path.rfind('/')]
            for i, net_d in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                print('Load D from: ', Dname)
                net_d.load_state_dict(load_on_cpu(Dname))

    # The target network starts as a copy of the generator.
    target_netG = deepcopy(netG)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        # NOTE(review): .cuda() with no argument uses the current device;
        # nothing here selects a secondary GPU for the target network.
        target_netG.cuda()
        for net_d in netsD:
            net_d.cuda()

    # The target network is never trained directly.
    for weight in target_netG.parameters():
        weight.requires_grad = False

    return [text_encoder, image_encoder, netG, target_netG, netsD, epoch]
def sampling(self, split_dir):
    """Generate images for the data loader and print an R-precision estimate.

    Loads the generator (``cfg.TRAIN.NET_G``) and the text/image encoders
    (``cfg.TRAIN.NET_E``), saves the last-stage image of every sample under
    ``<checkpoint>/<split_dir>/single/``, and scores each generated image
    against its own caption plus the mismatched captions returned by
    ``self.dataset.get_mis_caption``.  A trial counts as a hit when the
    matching caption has the highest cosine similarity.  After 30000 trials
    the hits are shuffled, split into 10 equal parts, and the mean and
    standard deviation of the per-part hit rates are printed.

    Args:
        split_dir: name of the output sub-directory; 'test' is mapped to
            'valid'.

    Prints an error and returns immediately when ``cfg.TRAIN.NET_G`` is
    empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
        return
    if split_dir == 'test':
        split_dir = 'valid'

    # Build the generator.
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # Load the text encoder.
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # Load the image encoder.
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    state_dict = torch.load(img_encoder_path,
                            map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
    noise = noise.cuda()

    # Load the generator weights.
    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)

    # The output directory is derived from the checkpoint path.
    s_tmp = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (s_tmp, split_dir)
    mkdir_p(save_dir)

    num_trials = 30000
    num_splits = 10
    split_size = num_trials // num_splits

    cnt = 0
    R_count = 0
    R = np.zeros(num_trials)
    done = False
    for ii in range(11):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
        if done:
            break
        for step, data in enumerate(self.data_loader, 0):
            cnt += batch_size
            if step % 100 == 0:
                print('cnt: ', cnt)

            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            # Generate fake images.
            noise.data.normal_(0, 1)
            fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask,
                                      cap_lens)

            # Save the last-stage image of every sample.
            k = -1
            for j in range(batch_size):
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    mkdir_p(folder)
                im = fake_imgs[k][j].data.cpu().numpy()
                # [-1, 1] --> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_s%d_%d.png' % (s_tmp, k, ii)
                im.save(fullpath)

            _, cnn_code = image_encoder(fake_imgs[-1])

            for i in range(batch_size):
                # Stop at the trial budget so R is never indexed out of
                # bounds when batch_size does not divide num_trials.
                if R_count >= num_trials:
                    break
                mis_captions, mis_captions_len = \
                    self.dataset.get_mis_caption(class_ids[i])
                hidden = text_encoder.init_hidden(99)
                _, sent_emb_t = text_encoder(mis_captions, mis_captions_len,
                                             hidden)
                # Row 0 is the matching caption; the rest are mismatched.
                rnn_code = torch.cat(
                    (sent_emb[i, :].unsqueeze(0), sent_emb_t), 0)
                img_code = cnn_code[i].unsqueeze(0)
                # Cosine similarity between the image and every caption.
                scores = torch.mm(img_code, rnn_code.transpose(0, 1))
                cnn_code_norm = torch.norm(img_code, 2, dim=1, keepdim=True)
                rnn_code_norm = torch.norm(rnn_code, 2, dim=1, keepdim=True)
                norm = torch.mm(cnn_code_norm, rnn_code_norm.transpose(0, 1))
                scores0 = scores / norm.clamp(min=1e-8)
                if torch.argmax(scores0) == 0:
                    R[R_count] = 1
                R_count += 1

            if R_count >= num_trials:
                np.random.shuffle(R)
                # Every split uses all split_size trials; the previous slice
                # end of "(i + 1) * 3000 - 1" dropped one trial per split.
                split_means = R.reshape(num_splits, split_size).mean(axis=1)
                R_mean = np.average(split_means)
                R_std = np.std(split_means)
                print("R mean:{:.4f} std:{:.4f}".format(R_mean, R_std))
                done = True
                break
def loading_model(dataset_name='bird'):
    """Configure evaluation and load the trainer, text encoder and generator.

    Args:
        dataset_name: 'bird' selects ``cfg/eval_bird.yml``; any other value
            selects ``cfg/eval_coco.yml``.

    Returns:
        ``[algo, text_encoder, netG, dataset]`` with both networks in eval
        mode.
    """
    # Hard-coded run settings.
    gpu_id = -1  # change it to 0 or more when using gpu
    data_dir = ''
    manualSeed = 100  # not used anywhere in this function

    yml_name = "cfg/eval_bird.yml" if dataset_name == 'bird' \
        else "cfg/eval_coco.yml"
    cfg_file = os.path.join(current_dir, yml_name)

    # Apply the config file, then the overrides above.
    if cfg_file is not None:
        cfg_from_file(cfg_file)
    if gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = gpu_id
    if data_dir != '':
        cfg.DATA_DIR = data_dir

    timestamp = datetime.datetime.now(
        dateutil.tz.tzlocal()).strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # Shuffling stays on for both splits.
    bshuffle = True
    split_dir = 'train' if cfg.TRAIN.FLAG else 'test'

    # Data loader.
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
    image_transform = transforms.Compose([
        transforms.Scale(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
    dataset = TextDataset(cfg.DATA_DIR, split_dir,
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
        drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))

    # Trainer object that owns the data loader and vocabulary.
    algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)

    def load_on_cpu(path):
        """Load a checkpoint while keeping every tensor on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    # Text encoder; weights come from cfg.TRAIN.NET_E.
    text_encoder = RNN_ENCODER(algo.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_on_cpu(cfg.TRAIN.NET_E))
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # Generator; weights come from cfg.TRAIN.NET_G.
    netG = G_NET()
    model_dir = cfg.TRAIN.NET_G
    netG.load_state_dict(load_on_cpu(model_dir))
    if cfg.CUDA:
        netG.cuda()
    netG.eval()

    return [algo, text_encoder, netG, dataset]
def build_models(self):
    """Build the frozen encoders, the generator and the discriminators.

    Every network is wrapped in ``nn.DataParallel`` when more than one CUDA
    device is visible, then moved to ``self.device``.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch]``, or ``None``
        (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
        ``epoch`` is 0 unless a generator checkpoint was loaded, in which
        case it is the number parsed from the checkpoint file name plus one.
    """
    def load_on_cpu(path):
        """Load a checkpoint while keeping every tensor on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    # ---- pretrained encoders (frozen) ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(load_on_cpu(img_encoder_path))
    for weight in image_encoder.parameters():
        weight.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_on_cpu(cfg.TRAIN.NET_E))
    for weight in text_encoder.parameters():
        weight.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- generator and discriminators ----
    if cfg.GAN.B_DCGAN:
        # A single discriminator whose resolution follows the branch count.
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, at most three.
        stages = (D_NET64, D_NET128, D_NET256)
        netsD = [stage() for stage in stages[:max(cfg.TREE.BRANCH_NUM, 0)]]

    netG.apply(weights_init)
    for net_d in netsD:
        net_d.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---- optional checkpoint restore ----
    epoch = 0
    g_path = cfg.TRAIN.NET_G
    if g_path != '':
        netG.load_state_dict(load_on_cpu(g_path))
        print('Load G from: ', g_path)
        # The epoch is the number between the last '_' and the last '.'.
        epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_path[:g_path.rfind('/')]
            for i, net_d in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, i)
                print('Load D from: ', Dname)
                net_d.load_state_dict(load_on_cpu(Dname))

    # ---- device placement ----
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        text_encoder = nn.DataParallel(text_encoder)
        image_encoder = nn.DataParallel(image_encoder)
        netG = nn.DataParallel(netG)
        netsD = [nn.DataParallel(net_d) for net_d in netsD]

    image_encoder.to(self.device)
    text_encoder.to(self.device)
    netG.to(self.device)
    for net_d in netsD:
        net_d.to(self.device)

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    """Build the encoders, the generator and the discriminators.

    The text encoder is restored partially: only checkpoint entries whose
    names exist in the freshly built encoder are copied.  The generator is
    constructed around the text encoder (``G_NET(text_encoder)``).

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch]``, or ``None``
        (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
        ``epoch`` is 0 unless a generator checkpoint was loaded, in which
        case it is the number parsed from the checkpoint file name plus one.

    Raises:
        RuntimeError: from ``Tensor.copy_`` when a checkpoint tensor cannot
            be copied into the same-named encoder tensor (e.g. shape
            mismatch).
    """
    # ---- pretrained encoders ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    state_dict = torch.load(img_encoder_path,
                            map_location=lambda storage, loc: storage)
    image_encoder.load_state_dict(state_dict)
    for p in image_encoder.parameters():
        p.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    # Custom restore: copy matching entries into the encoder's own tensors.
    # Assigning into the dict returned by state_dict() only rebinds the dict
    # entry and leaves the module untouched, so copy in place instead.
    own_state = text_encoder.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            own_state[name].copy_(param)
    for p in text_encoder.parameters():
        p.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    # NOTE(review): the encoder is left in train mode although its
    # parameters are frozen; kept as in the original.
    text_encoder.train()

    # ---- generator and discriminators ----
    netsD = []
    if cfg.GAN.B_DCGAN:
        # A single discriminator whose resolution follows the branch count.
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET(text_encoder)
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())

    netG.apply(weights_init)
    for net_d in netsD:
        net_d.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---- optional checkpoint restore ----
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', cfg.TRAIN.NET_G)
        # The epoch is the number between the last '_' and the last '.'.
        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        epoch = int(cfg.TRAIN.NET_G[istart:iend]) + 1
        if cfg.TRAIN.B_NET_D:
            Gname = cfg.TRAIN.NET_G
            s_tmp = Gname[:Gname.rfind('/')]
            for i, net_d in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (s_tmp, i)
                print('Load D from: ', Dname)
                state_dict = torch.load(
                    Dname, map_location=lambda storage, loc: storage)
                net_d.load_state_dict(state_dict)

    # ---- device placement ----
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for net_d in netsD:
            net_d.cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch]
def gen_example(self, data_dic):
    """Generate images conditioned on the image-encoder embedding of each input.

    For every key in ``data_dic`` the stored image is encoded with the
    pretrained image encoder, the second encoder output is fed to the
    generator in the sentence-embedding position, and one PNG per generator
    stage is written to ``<checkpoint>/custom/<key>_g<stage>.png``.

    Args:
        data_dic: mapping from an output name to an image tensor.

    Prints an error and does nothing when ``cfg.TRAIN.NET_G`` is empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
        return

    def load_on_cpu(path):
        """Load a checkpoint while keeping every tensor on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    # Image encoder, loaded from the path derived from cfg.TRAIN.NET_E.
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(load_on_cpu(img_encoder_path))
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # Generator.
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    model_dir = cfg.TRAIN.NET_G
    # Output files are placed next to the checkpoint.
    out_root = model_dir[:model_dir.rfind('.pth')]
    netG.load_state_dict(load_on_cpu(model_dir))
    print('Load G from: ', model_dir)
    netG.cuda()
    netG.eval()

    # The word embeddings and the mask are all zeros; only the image-derived
    # embedding and the noise vary between samples.
    words_embs = Variable(
        torch.zeros(1, cfg.TEXT.EMBEDDING_DIM, cfg.TEXT.WORDS_NUM)).cuda()
    mask = Variable(torch.zeros(1, cfg.TEXT.WORDS_NUM)).cuda()
    noise = Variable(torch.FloatTensor(1, cfg.GAN.Z_DIM),
                     volatile=True).cuda()

    for key, raw_img in data_dic.items():
        # NOTE(review): the '<checkpoint>/custom' directory is not created
        # here - confirm the caller prepares it.
        save_path = '%s/custom/%s' % (out_root, key)
        img = Variable(raw_img).unsqueeze(0).cuda()

        # (1) Extract the image embedding.
        _, sent_emb = image_encoder(img)

        # (2) Generate fake images.
        noise.data.normal_(0, 1)
        fake_imgs, _attention_maps, _, _ = netG(noise, sent_emb,
                                                words_embs, mask)

        # Save the last sample of every generator stage.
        for k, stage_imgs in enumerate(fake_imgs):
            pixels = stage_imgs[-1].data.cpu().numpy()
            # [-1, 1] --> [0, 255], channels last
            pixels = ((pixels + 1.0) * 127.5).astype(np.uint8)
            pixels = np.transpose(pixels, (1, 2, 0))
            fullpath = '%s_g%d.png' % (save_path, k)
            Image.fromarray(pixels).save(fullpath)
def sample(self, split_dir, num_samples=25, draw_bbox=False):
    """Save comparison strips: one real image followed by nine generated ones.

    For each of the first ``num_samples`` batches, the first sample's
    caption, transformation matrix and label are repeated nine times, nine
    images are generated from independent noise vectors, and a 10-image
    strip (real image first) is written to
    ``<checkpoint>_<split_dir>/<caption>_<step>.png``.

    Args:
        split_dir: output suffix; 'test' is mapped to 'valid'.
        num_samples: maximum number of batches to process.
        draw_bbox: when true, draw up to three bounding-box outlines from
            ``bbox`` on every image of the strip.

    Prints an error and returns immediately when ``cfg.TRAIN.NET_G`` is
    empty.
    """
    # Only the image-saving utility is needed; the previously imported
    # PIL, cPickle and torchvision names were unused, and cPickle does not
    # exist on Python 3.
    import torchvision.utils as vutils

    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model NET_G is not found!')
        return
    if split_dir == 'test':
        split_dir = 'valid'

    # Load the text encoder.
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TRAIN.NET_E,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    batch_size = cfg.TRAIN.BATCH_SIZE
    nz = cfg.GAN.Z_DIM

    # Load the generator; its weights are stored under the "netG" key.
    model_dir = cfg.TRAIN.NET_G
    state_dict = torch.load(model_dir,
                            map_location=lambda storage, loc: storage)
    netG = G_NET()
    print('Load G from: ', model_dir)
    netG.apply(weights_init)
    netG.load_state_dict(state_dict["netG"])
    netG.cuda()
    netG.eval()

    # The output directory is derived from the checkpoint path.
    s_tmp = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s_%s' % (s_tmp, split_dir)
    mkdir_p(save_dir)

    num_fakes = 9
    noise = Variable(torch.FloatTensor(num_fakes, nz))
    imsize = 256

    saved = 0
    for step, data in enumerate(self.data_loader, 0):
        if step >= num_samples:
            break

        (imgs, captions, cap_lens, class_ids, keys,
         transformation_matrices, label_one_hot, bbox) = \
            self.prepare_data(data, eval=True)

        # Keep only the first sample of the batch.
        transf_matrices_inv = transformation_matrices[1][0].unsqueeze(0)
        label_one_hot = label_one_hot[0].unsqueeze(0)
        val_image = imgs[-1][0].view(1, 3, imsize, imsize)

        hidden = text_encoder.init_hidden(batch_size)
        # words_embs: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
        words_embs = words_embs[0].unsqueeze(0).detach()
        sent_emb = sent_emb[0].unsqueeze(0).detach()
        words_embs = words_embs.repeat(num_fakes, 1, 1)
        sent_emb = sent_emb.repeat(num_fakes, 1)

        mask = (captions == 0)
        mask = mask[0].unsqueeze(0)
        num_words = words_embs.size(2)
        if mask.size(1) > num_words:
            mask = mask[:, :num_words]
        mask = mask.repeat(num_fakes, 1)
        transf_matrices_inv = transf_matrices_inv.repeat(num_fakes, 1, 1, 1)
        label_one_hot = label_one_hot.repeat(num_fakes, 1, 1)

        # Generate fake images.
        noise.data.normal_(0, 1)
        inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv,
                  label_one_hot)
        with torch.no_grad():
            fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                netG, inputs, self.gpus)

        # Slot 0 holds the real image, slots 1..9 the last-stage fakes.
        data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
        data_img[0] = val_image
        data_img[1:10] = fake_imgs[-1]

        if draw_bbox:
            for idx in range(3):
                x, y, w, h = tuple(
                    [int(imsize * coord) for coord in bbox[0, idx]])
                w = imsize - 1 if w > imsize - 1 else w
                h = imsize - 1 if h > imsize - 1 else h
                # A negative x ends the drawing for this sample.
                if x <= -1:
                    break
                data_img[:10, :, y, x:x + w] = 1
                data_img[:10, :, y:y + h, x] = 1
                data_img[:10, :, y + h, x:x + w] = 1
                data_img[:10, :, y:y + h, x + w] = 1

        # Rebuild the caption text; token id 0 terminates the caption.
        cap = captions[0].data.cpu().numpy()
        words = []
        for token in cap:
            if token == 0:
                break
            words.append(self.ixtoword[token].encode(
                'ascii', 'ignore').decode('ascii'))
        sentence = " ".join(words)

        vutils.save_image(data_img,
                          '{}/{}_{}.png'.format(save_dir, sentence, step),
                          normalize=True, nrow=10)
        saved += 1

    # Report the number of files actually written; the loop index is off by
    # one when the loader runs out and unbound when the loader is empty.
    print("Saved {} files to {}".format(saved, save_dir))
def evaluate(self, split_dir):
    """Generate ``cfg.TEST.NUM_IMAGES`` images per caption and save them with the text.

    Loads the generator from ``cfg.TRAIN.NET_G`` and an embedding model
    built from the dataset dictionary.  For every caption of every batch it
    generates images at the three generator stages and passes them to
    ``save_images_with_text``.

    Args:
        split_dir: accepted for interface compatibility; the body does not
            use it.

    Prints an error and returns immediately when ``cfg.TRAIN.NET_G`` is
    empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for models is not found!')
        return

    # Build and load the generator.
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    print(netG)
    # map_location keeps every tensor on the CPU while loading.
    state_dict = torch.load(cfg.TRAIN.NET_G,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load ', cfg.TRAIN.NET_G)

    netE = load_embedding_model(self.data_loader.dataset.dictionary)
    print(netE)

    nz = cfg.GAN.Z_DIM
    sample_size = cfg.TEST.NUM_IMAGES
    noise = Variable(torch.FloatTensor(sample_size, nz))
    if cfg.CUDA:
        netG.cuda()
        netE.cuda()
        noise = noise.cuda()

    # Switch to evaluate mode.
    netG.eval()

    count = 0
    output_dir = os.path.join(cfg.OUTPUT_DIR, cfg.EXPERIMENT_NAME)
    for data in tqdm(self.data_loader, desc='evaluate'):
        imgs, txt_ids, txts = data
        txt_ids = Variable(txt_ids)
        if cfg.CUDA:
            txt_ids = txt_ids.cuda()

        txts_embeddings = netE(txt_ids)
        batch_size = imgs[0].size(0)

        imgs64, imgs128, imgs256 = [], [], []
        for i in range(batch_size):
            # Fresh noise for each caption; the caption embedding is
            # repeated once per generated sample.
            noise.data.normal_(0, 1)
            txt_embedding = txts_embeddings[i].repeat(sample_size, 1)
            fake_imgs, _, _ = netG(noise, txt_embedding)
            imgs64.append(normalize_(fake_imgs[0]))
            imgs128.append(normalize_(fake_imgs[1]))
            imgs256.append(normalize_(fake_imgs[2]))

        save_images_with_text(imgs64, imgs128, imgs256, imgs, txts,
                              batch_size, cfg.TEXT.MAX_LEN, count,
                              output_dir)
        # NOTE(review): the counter advances by batch_size + 1, not
        # batch_size - confirm this is what save_images_with_text expects.
        count = count + batch_size + 1
def generate_fake_images_with_incremental_noise(self, data_dic, sizeim):
    """Generate ``sizeim`` images per entry from one shared base noise.

    For every key in ``data_dic`` a base noise tensor is sampled once from
    N(0, 1).  Image ``i`` is generated from a copy of it in which element
    ``i % 100`` of the first row is shifted by the mean of the base noise,
    so all images of a key differ from the base in a single coordinate.
    Images are written to ``res/<key>/<i>.png``.

    Previously the base noise was never initialised and the perturbed copy
    was re-sampled just before the generator call, which discarded the
    perturbation.

    Args:
        data_dic: mapping from a name to ``(captions, cap_lens,
            sorted_indices)``; ``captions`` and ``cap_lens`` are NumPy
            arrays.
        sizeim: number of images to create per key.

    Side effects:
        Rewrites the module globals ``text_encoder_path`` and ``net_G_path``
        by joining them onto the current working directory.
    """
    global text_encoder_path, net_G_path
    print(os.getcwd(), os.path.join(os.getcwd(), text_encoder_path))
    text_encoder_path = os.path.join(os.getcwd(), text_encoder_path)
    net_G_path = os.path.join(os.getcwd(), net_G_path)

    # Load the text encoder.
    print('Loading text encoder from:', text_encoder_path)
    text_encoder = BERT_RNN_ENCODER(self.n_words,
                                    nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(text_encoder_path,
                            map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Loaded text encoder from:', text_encoder_path)
    text_encoder.eval()
    text_encoder = text_encoder.cuda()

    # Load the generator.
    netG = G_NET()
    state_dict = torch.load(net_G_path,
                            map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load Generator from: ', net_G_path)
    netG.cuda()
    netG.eval()

    for key in data_dic:
        save_dir = '%s/%s' % ('res', key)
        mkdir_p(save_dir)
        captions, cap_lens, sorted_indices = data_dic[key]

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM
        captions = Variable(torch.from_numpy(captions)).cuda()
        cap_lens = Variable(torch.from_numpy(cap_lens)).cuda()

        # Sample the shared base noise once per key.
        base_noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()
        base_noise.data.normal_(0, 1)
        shift = torch.mean(base_noise)

        # (1) Extract text embeddings; they do not depend on the noise, so
        # they are computed once per key.
        hidden = text_encoder.init_hidden(batch_size)
        # words_embs: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
        mask = (captions == 0)

        for i in range(sizeim):  # number of images to be created
            noise = base_noise.clone()
            noise[0][i % 100] = base_noise[0][i % 100] + shift

            # (2) Generate fake images from the perturbed noise.
            fake_imgs, attention_maps, _, _ = netG(noise, sent_emb,
                                                   words_embs, mask)

            im = fake_imgs[2].squeeze(0).data.cpu().numpy()
            # [-1, 1] --> [0, 255], channels last
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            im = Image.fromarray(im)
            fullpath = os.path.join(save_dir, '{0}.png'.format(i))
            im.save(fullpath)
def sampling(self, split_dir):
    """Generate one image per caption in the data loader and save it as PNG.

    Images from the last generator stage are written to
    ``<checkpoint>/<split_dir>/single/<key>_s-1.png``.

    Args:
        split_dir: name of the output sub-directory; 'test' is mapped to
            'valid'.

    Prints an error and does nothing when ``cfg.TRAIN.NET_G`` is empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model is not found!')
        return
    if split_dir == 'test':
        split_dir = 'valid'

    def load_on_cpu(path):
        """Load a checkpoint while keeping every tensor on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    # Generator (weights are loaded further down).
    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # Text encoder.
    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(load_on_cpu(cfg.TRAIN.NET_E))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz),
                     volatile=True).cuda()

    model_dir = cfg.TRAIN.NET_G
    netG.load_state_dict(load_on_cpu(model_dir))
    print('Load G from: ', model_dir)

    # Output goes next to the checkpoint, under the split name.
    ckpt_stem = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (ckpt_stem, split_dir)
    mkdir_p(save_dir)

    stage = -1  # only the last generator stage is saved
    cnt = 0
    for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
        for step, data in enumerate(self.data_loader, 0):
            cnt += batch_size
            if step % 100 == 0:
                print('step: ', step)

            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            # (1) Text embeddings.
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs = words_embs.detach()
            sent_emb = sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            # (2) Generate fake images.
            noise.data.normal_(0, 1)
            fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

            # (3) Save one image per sample.
            for j in range(batch_size):
                out_stem = '%s/single/%s' % (save_dir, keys[j])
                folder = out_stem[:out_stem.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                pixels = fake_imgs[stage][j].data.cpu().numpy()
                # [-1, 1] --> [0, 255], channels last
                pixels = ((pixels + 1.0) * 127.5).astype(np.uint8)
                pixels = np.transpose(pixels, (1, 2, 0))
                fullpath = '%s_s%d.png' % (out_stem, stage)
                Image.fromarray(pixels).save(fullpath)
def evaluate_finegan(self):
    """Render two sets of image grids from the generator.

    * ``*_pvc``: the noise is fixed; each column uses a different parent
      code and each row a different child code.
    * ``*_zvpc``: each column uses a freshly sampled noise; each row uses a
      different (parent, child) code pair.

    Nine grids (background, parent/child final image, parent/child
    foreground, parent/child mask, parent/child masked foreground) are
    written per set through ``save_image`` into
    ``<cfg.SAVE_DIR>/images``, which is also stored in ``self.save_dir``.

    Prints an error and returns after creating the directory when
    ``cfg.TRAIN.NET_G`` is empty.
    """
    self.save_dir = os.path.join(cfg.SAVE_DIR, 'images')
    mkdir_p(self.save_dir)
    # Seed from the system source; passing a datetime object to
    # random.seed() is rejected by Python 3.11+.
    random.seed()
    depth = cfg.TEST_DEPTH
    res = 32 * 2 ** depth

    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model not found!')
        return

    # Build the generator and load only the checkpoint entries whose names
    # exist in the freshly built model.
    netG = G_NET(depth)
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    model_dict = netG.state_dict()
    state_dict = torch.load(cfg.TRAIN.NET_G,
                            map_location=lambda storage, loc: storage)
    state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
    model_dict.update(state_dict)
    netG.load_state_dict(model_dict)
    print('Load ', cfg.TRAIN.NET_G)

    nrow = 6
    ncol = 4
    z_std = 0.1
    reprod = False
    if not reprod:
        torch.manual_seed(random.randint(-9999, 9999))

    # One background code is shared by every generated image.
    b = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)

    nz = cfg.GAN.Z_DIM
    noise = torch.FloatTensor(1, nz)
    noise.data.normal_(0, z_std)
    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()
    netG.eval()

    # numpy's randint excludes the upper bound, so pass the category count
    # itself; "count - 1" could never select the last category.
    c_li = np.random.randint(0, cfg.FINE_GRAINED_CATEGORIES, size=nrow)
    p_li = np.random.randint(0, cfg.SUPER_CATEGORIES, size=nrow)

    # Grid names, in the order the grids are saved.
    names = ('background', 'parent_final', 'child_final',
             'parent_foreground', 'child_foreground',
             'parent_mask', 'child_mask',
             'parent_foreground_masked', 'child_foreground_masked')

    def one_hot(num_classes, index):
        """Return a (batch_size, num_classes) code with column ``index`` set to 1."""
        code = torch.zeros([self.batch_size, num_classes])
        code[:, index] = 1
        return code

    def generate(p, c):
        """Run the generator once; return the first image of each output, ordered like ``names``."""
        bg_code = one_hot(cfg.FINE_GRAINED_CATEGORIES, b)
        p_code = one_hot(cfg.SUPER_CATEGORIES, p)
        c_code = one_hot(cfg.FINE_GRAINED_CATEGORIES, c)
        fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
            noise, c_code, None, p_code, bg_code)
        return (fake_imgs[3 * depth][0],
                fake_imgs[3 * depth + 1][0],
                fake_imgs[3 * depth + 2][0],
                fg_imgs[2 * depth][0],
                fg_imgs[2 * depth + 1][0],
                mk_imgs[2 * depth][0],
                mk_imgs[2 * depth + 1][0],
                fgmk_imgs[2 * depth][0],
                fgmk_imgs[2 * depth + 1][0])

    def collect(grids, images):
        """Append one generator result to the per-name image lists."""
        for grid, image in zip(grids, images):
            grid.append(image)

    def save_grids(grids, suffix):
        """Write every image list as '<name>_<suffix>'."""
        for name, grid in zip(names, grids):
            save_image(grid, self.save_dir, '%s_%s' % (name, suffix),
                       nrow, res)

    # Set 1: fixed noise, parent varies per column, child varies per row.
    grids = [[] for _ in names]
    for k in range(ncol):
        for i in range(nrow):
            collect(grids, generate(p_li[k], c_li[i]))
    save_grids(grids, 'pvc')

    # Set 2: new noise per column, (parent, child) pair varies per row.
    grids = [[] for _ in names]
    for _ in range(ncol):
        noise.data.normal_(0, z_std)
        for i in range(nrow):
            collect(grids, generate(p_li[i], c_li[i]))
    save_grids(grids, 'zvpc')
def build_models(self):
    """Build the encoders, caption models, generator and discriminators.

    The image/text encoders and both caption models are loaded from the
    checkpoints named in ``cfg`` and have ``requires_grad`` disabled on
    every parameter.  The generator and discriminators are freshly
    initialised with ``weights_init`` and optionally restored from
    ``cfg.TRAIN.NET_G`` (and sibling ``netD<i>.pth`` files when
    ``cfg.TRAIN.B_NET_D`` is set).

    Returns:
        ``[text_encoder, image_encoder, caption_cnn, caption_rnn, netG,
        netsD, epoch]``; ``None`` (after printing an error) when
        ``cfg.TRAIN.NET_E`` is empty.  ``epoch`` is 0 unless a generator
        checkpoint is loaded, in which case it is the integer found
        between the last ``_`` and the last ``.`` of its path, plus one.
    """

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    def freeze(module):
        for weight in module.parameters():
            weight.requires_grad = False

    # ---- text / image encoders ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- caption models: cnn_encoder and rnn_decoder ----
    caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
    cnn_weights = torch.load(cfg.CAP.caption_cnn_path,
                             map_location=cpu_storage)
    caption_cnn.load_state_dict(cnn_weights)
    freeze(caption_cnn)
    print('Load caption model from:', cfg.CAP.caption_cnn_path)
    caption_cnn.eval()

    caption_rnn = CAPTION_RNN(cfg.CAP.embed_size,
                              cfg.CAP.hidden_size * 2,
                              self.n_words,
                              cfg.CAP.num_layers)
    rnn_weights = torch.load(cfg.CAP.caption_rnn_path,
                             map_location=cpu_storage)
    caption_rnn.load_state_dict(rnn_weights)
    freeze(caption_rnn)
    print('Load caption model from:', cfg.CAP.caption_rnn_path)
    # NOTE(review): unlike caption_cnn, caption_rnn is not switched to
    # eval mode here -- confirm whether that is intentional.

    # ---- generator and discriminators ----
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # Stage i gets a discriminator only when BRANCH_NUM > i.
        for level, d_cls in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > level:
                netsD.append(d_cls())

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---- optional restore from checkpoints ----
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(torch.load(g_path, map_location=cpu_storage))
        print('Load G from: ', g_path)
        number_start = g_path.rfind('_') + 1
        number_end = g_path.rfind('.')
        epoch = int(g_path[number_start:number_end]) + 1
        if cfg.TRAIN.B_NET_D:
            # Discriminators are read from the generator's directory.
            ckpt_dir = g_path[:g_path.rfind('/')]
            for index, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, index)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=cpu_storage))

    # ---- device placement ----
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        caption_cnn = caption_cnn.cuda()
        caption_rnn = caption_rnn.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, caption_cnn, caption_rnn,
            netG, netsD, epoch]
def sample_images(self):
    """Generate 24 samples with random class codes and save them to disk.

    Creates ``../sample_images/`` and ``../sample_finals/`` if missing.
    When ``cfg.TRAIN.NET_G`` is empty an error is printed and nothing is
    generated.  Otherwise the generator is built for depth
    ``cfg.TEST_DEPTH``, the checkpoint entries whose keys exist in the
    model are loaded, and for each sample a random background / parent /
    child one-hot code is drawn.  For sample ``i`` the files ``<i>_bg``,
    ``<i>_pf``, ``<i>_cf``, ``<i>_pmk`` and ``<i>_cmk`` go to the first
    directory and ``<i>`` (the child final image) to the second, all via
    ``self.save_image``.

    Returns:
        None.
    """
    sample_size = 24
    save_dir = '../sample_images/'
    save_final = '../sample_finals/'
    for out_dir in (save_dir, save_final):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    # Re-seed from OS entropy / system time.  The previous
    # random.seed(datetime.now()) raises TypeError on Python 3.11+,
    # which only accepts None, int, float, str, bytes or bytearray.
    random.seed()

    depth = cfg.TEST_DEPTH

    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model not found!')
        return

    # Build and load the generator
    netG = G_NET(depth)
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)

    # Only checkpoint entries whose keys exist in the model are loaded;
    # everything else keeps its freshly initialised value.
    model_dict = netG.state_dict()
    checkpoint = torch.load(cfg.TRAIN.NET_G,
                            map_location=lambda storage, loc: storage)
    model_dict.update(
        {k: v for k, v in checkpoint.items() if k in model_dict})
    netG.load_state_dict(model_dict)
    print('Load ', cfg.TRAIN.NET_G)

    # Uncomment this to print Generator layers
    # print(netG)

    nz = cfg.GAN.Z_DIM
    noise = torch.FloatTensor(1, nz)
    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()
    netG.eval()

    for i in tqdm(range(sample_size)):
        noise.data.normal_(0, 1)

        bg_code = torch.zeros([1, cfg.FINE_GRAINED_CATEGORIES])
        p_code = torch.zeros([1, cfg.SUPER_CATEGORIES])
        c_code = torch.zeros([1, cfg.FINE_GRAINED_CATEGORIES])

        # Draw order (background, parent, child) is kept so a seeded RNG
        # yields the same sequence as before.
        b = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)
        p = random.randint(0, cfg.SUPER_CATEGORIES - 1)
        c = random.randint(0, cfg.FINE_GRAINED_CATEGORIES - 1)
        bg_code[0][b] = 1
        p_code[0][p] = 1
        c_code[0][c] = 1

        # Follow the same device policy as the noise tensor instead of
        # calling .cuda() unconditionally.
        if cfg.CUDA:
            bg_code = bg_code.cuda()
            p_code = p_code.cuda()
            c_code = c_code.cuda()

        # Forward pass through the generator
        fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
            noise, c_code, 1, p_code, bg_code)

        self.save_image(fake_imgs[3 * depth + 0][0], save_dir, '%d_bg' % i)
        self.save_image(fake_imgs[3 * depth + 1][0], save_dir, '%d_pf' % i)
        self.save_image(fake_imgs[3 * depth + 2][0], save_dir, '%d_cf' % i)
        self.save_image(fake_imgs[3 * depth + 2][0], save_final, '%d' % i)
        self.save_image(mk_imgs[2 * depth + 0][0], save_dir, '%d_pmk' % i)
        self.save_image(mk_imgs[2 * depth + 1][0], save_dir, '%d_cmk' % i)
def evaluate(self, split_dir, n_samples=4, extractor='googlenet',
             save_dir=None):
    """Generate images per class embedding and store mapped samples.

    Args:
        split_dir: Split name; ``'test'`` is rewritten to ``'valid'``
            (the value is not read again inside this method).
        n_samples: Number of noise vectors drawn per embedding row.
        extractor: Key into ``EXTRACTOR_MAPPING`` selecting the network
            applied to the generated images.
        save_dir: Target directory handed to ``SyntheticDataset``.  When
            ``None`` it becomes ``<dir of NET_G>/iteration<N>`` with
            ``N`` parsed from the generator checkpoint file name.

    Returns:
        None.  Prints an error and does nothing else when
        ``cfg.TRAIN.NET_G`` is empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    # Build the generator and the feature mapper, both without gradients.
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
    mapper = torch.nn.DataParallel(EXTRACTOR_MAPPING[extractor](),
                                   device_ids=self.gpus)
    set_parameter_requires_grad(netG, False)
    set_parameter_requires_grad(mapper, False)
    print(netG)

    generator_weights = torch.load(
        cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
    netG.load_state_dict(generator_weights)
    print('Load ', cfg.TRAIN.NET_G)

    if save_dir is None:
        # Derive the output path from the checkpoint path.
        ckpt_path = cfg.TRAIN.NET_G
        number_start = ckpt_path.rfind('_') + 1
        number_end = ckpt_path.rfind('.')
        iteration = int(ckpt_path[number_start:number_end])
        ckpt_dir = ckpt_path[:ckpt_path.rfind('/')]
        save_dir = '%s/iteration%d' % (ckpt_dir, iteration)

    nz = cfg.GAN.Z_DIM
    if cfg.CUDA:
        netG.cuda()
        mapper.cuda()

    # switch to evaluate mode
    netG.eval()
    mapper.eval()

    synthetic_ds = SyntheticDataset(save_dir)
    class_batches = self.data_loader.dataset.embeddings_by_class()
    for class_embeddings, synthetic_id in class_batches:
        if cfg.CUDA:
            class_embeddings = class_embeddings.cuda()
        # mean of 10 captions per image (original author's note)
        class_embeddings = class_embeddings.mean(dim=1)
        for embedding in class_embeddings:
            image_embeddings = embedding.repeat(n_samples, 1)
            noise = torch.randn(n_samples, nz)
            if cfg.CUDA:
                noise = noise.cuda()
            generated, _, _ = netG(noise, image_embeddings)
            samples = mapper(generated[-1])
            synthetic_ds.save_pairs(samples, synthetic_id)
def build_models(self):
    """Build the encoders, generator and discriminators for training.

    The image and text encoders are loaded from ``cfg.TRAIN.NET_E`` (the
    image encoder path is derived by replacing ``text_encoder`` with
    ``image_encoder``) and have ``requires_grad`` disabled.  The
    generator/discriminator family is chosen from ``cfg.GAN.B_DCGAN``,
    ``cfg.GAN.B_STYLEGEN`` and ``cfg.GAN.B_STYLEDISC``; weights are
    initialised with ``weights_init`` and optionally restored from
    ``cfg.TRAIN.NET_G`` (plus ``netD<i>.pth`` next to it when
    ``cfg.TRAIN.B_NET_D`` is set).

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch]`` where
        ``epoch`` is 0, or the number parsed from the generator
        checkpoint name plus one.

    Raises:
        FileNotFoundError: ``cfg.TRAIN.NET_E`` is empty.
        ValueError: ``self.text_encoder_type`` is neither ``'rnn'`` nor
            ``'transformer'`` (previously this surfaced later as an
            ``UnboundLocalError``).
    """

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    def freeze(module):
        for param in module.parameters():
            param.requires_grad = False

    def make_discriminators(stages):
        # Stage i gets a discriminator only when BRANCH_NUM > i.
        return [stage() for level, stage in enumerate(stages)
                if cfg.TREE.BRANCH_NUM > level]

    # ################### encoders ################################### #
    if cfg.TRAIN.NET_E == '':
        raise FileNotFoundError(
            'No pretrained text encoder found in directory DAMSMencoders/. \n'
            + 'Please train the DAMSM first before training the GAN (see README for details).'
        )

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    if self.text_encoder_type == 'rnn':
        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
    elif self.text_encoder_type == 'transformer':
        text_encoder = GPT2Model.from_pretrained(TRANSFORMER_ENCODER)
    else:
        # Fail with a clear message instead of an unbound local below.
        raise ValueError(
            "Unknown text_encoder_type %r; expected 'rnn' or 'transformer'"
            % (self.text_encoder_type,))
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ################### generator and discriminators ############### #
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    elif cfg.GAN.B_STYLEGEN:
        netG = G_NET_STYLED()
        if cfg.GAN.B_STYLEDISC:
            from model import (D_NET_STYLED64, D_NET_STYLED128,
                               D_NET_STYLED256)
            netsD = make_discriminators(
                (D_NET_STYLED64, D_NET_STYLED128, D_NET_STYLED256))
        else:
            from model import D_NET64, D_NET128, D_NET256
            netsD = make_discriminators((D_NET64, D_NET128, D_NET256))
        # TODO: if cfg.TREE.BRANCH_NUM > 3:
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        netsD = make_discriminators((D_NET64, D_NET128, D_NET256))
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)

    print(netG.__class__)
    for netD in netsD:
        print(netD.__class__)
    print('# of netsD', len(netsD))

    # ################### optional restore ########################### #
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        state_dict = torch.load(g_path, map_location=cpu_storage)
        if cfg.GAN.B_STYLEGEN:
            # In this mode the checkpoint is read as a dict holding
            # 'w_ewma' next to the actual generator state dict.
            netG.w_ewma = state_dict['w_ewma']
            if cfg.CUDA:
                netG.w_ewma = netG.w_ewma.to('cuda:' + str(cfg.GPU_ID))
            netG.load_state_dict(state_dict['netG_state_dict'])
        else:
            netG.load_state_dict(state_dict)
        print('Load G from: ', g_path)

        # The epoch number sits between the last '_' and the last '.'.
        istart = g_path.rfind('_') + 1
        iend = g_path.rfind('.')
        epoch = int(g_path[istart:iend]) + 1

        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_path[:g_path.rfind('/')]
            for index, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, index)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=cpu_storage))

    # ################### device placement ########################### #
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
def build_models(self):
    """Build encoders, generator and discriminators, optionally resuming.

    The image and text encoders are loaded from ``cfg.TRAIN.NET_E`` and
    have ``requires_grad`` disabled.  When ``self.resume`` is true, the
    lexicographically last ``*.pth`` file in ``self.model_dir`` is
    loaded; it is read as a dict with a ``"netG"`` state dict and an
    indexable ``"netD"`` collection, and the epoch is taken from the four
    characters preceding its extension.  A non-empty ``cfg.TRAIN.NET_G``
    is loaded afterwards and overrides both weights and epoch.

    Returns:
        ``[text_encoder, image_encoder, netG, netsD, epoch]``; ``None``
        (after printing an error) when ``cfg.TRAIN.NET_E`` is empty.
    """

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    def freeze(module):
        for weight in module.parameters():
            weight.requires_grad = False

    # ---- encoders ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- generator and discriminators ----
    from model import D_NET64, D_NET128, D_NET256
    netG = G_NET()
    # Stage i gets a discriminator only when BRANCH_NUM > i.
    netsD = [d_cls()
             for level, d_cls in enumerate((D_NET64, D_NET128, D_NET256))
             if cfg.TREE.BRANCH_NUM > level]

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0

    # ---- resume from the newest checkpoint in model_dir ----
    if self.resume:
        checkpoints = sorted(glob.glob(self.model_dir + "/" + '*.pth'))
        latest_checkpoint = checkpoints[-1]
        resumed = torch.load(latest_checkpoint, map_location=cpu_storage)
        netG.load_state_dict(resumed["netG"])
        for index, netD in enumerate(netsD):
            netD.load_state_dict(resumed["netD"][index])
        epoch = int(latest_checkpoint[-8:-4]) + 1
        print("Resuming training from checkpoint {} at epoch {}.".format(
            latest_checkpoint, epoch))

    # ---- explicit generator checkpoint ----
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(torch.load(g_path, map_location=cpu_storage))
        print('Load G from: ', g_path)
        number_start = g_path.rfind('_') + 1
        number_end = g_path.rfind('.')
        epoch = int(g_path[number_start:number_end]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_path[:g_path.rfind('/')]
            for index, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, index)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=cpu_storage))

    # ---- device placement ----
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()

    return [text_encoder, image_encoder, netG, netsD, epoch]
def evaluate_finegan(self):
    """Run the generator once for the configured classes and save images.

    Loads the generator from ``cfg.TRAIN.NET_G`` (only checkpoint keys
    present in the model are applied), builds one-hot background /
    parent / child codes from ``cfg.TEST_BACKGROUND_CLASS``,
    ``cfg.TEST_PARENT_CLASS`` and ``cfg.TEST_CHILD_CLASS``, and writes
    the first image of each generator output to ``self.save_dir`` via
    ``self.save_image``.

    Returns:
        None.  Prints an error and does nothing else when
        ``cfg.TRAIN.NET_G`` is empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for model not found!')
        return

    # Build and load the generator
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=self.gpus)

    merged_weights = netG.state_dict()
    saved_weights = torch.load(cfg.TRAIN.NET_G,
                               map_location=lambda storage, loc: storage)
    # Checkpoint entries unknown to the model are ignored.
    for key, value in saved_weights.items():
        if key in merged_weights:
            merged_weights[key] = value
    netG.load_state_dict(merged_weights)
    print('Load ', cfg.TRAIN.NET_G)

    # Uncomment this to print Generator layers
    # print(netG)

    rows = self.batch_size
    noise = torch.FloatTensor(rows, cfg.GAN.Z_DIM)
    noise.data.normal_(0, 1)

    if cfg.CUDA:
        netG.cuda()
        noise = noise.cuda()
    netG.eval()

    background_class = cfg.TEST_BACKGROUND_CLASS
    parent_class = cfg.TEST_PARENT_CLASS
    child_class = cfg.TEST_CHILD_CLASS

    bg_code = torch.zeros([rows, cfg.FINE_GRAINED_CATEGORIES])
    p_code = torch.zeros([rows, cfg.SUPER_CATEGORIES])
    c_code = torch.zeros([rows, cfg.FINE_GRAINED_CATEGORIES])
    # Every row of the batch uses the same one-hot class selection.
    for row in range(rows):
        bg_code[row][background_class] = 1
        p_code[row][parent_class] = 1
        c_code[row][child_class] = 1

    # Forward pass through the generator
    fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(
        noise, c_code, p_code, bg_code)

    # (generator output, index inside it, file label); outputs are
    # indexed lazily so images are saved one at a time, in this order.
    save_plan = (
        (fake_imgs, 0, 'background'),
        (fake_imgs, 1, 'parent_final'),
        (fake_imgs, 2, 'child_final'),
        (fg_imgs, 0, 'parent_foreground'),
        (fg_imgs, 1, 'child_foreground'),
        (mk_imgs, 0, 'parent_mask'),
        (mk_imgs, 1, 'child_mask'),
        (fgmk_imgs, 0, 'parent_foreground_masked'),
        (fgmk_imgs, 1, 'child_foreground_masked'),
    )
    for outputs, position, label in save_plan:
        self.save_image(outputs[position][0], self.save_dir, label)
def sampling(self, split_dir):
    """Generate modified images with the main module + DCM and save PNGs.

    Loads the generator (``cfg.TRAIN.NET_G``), the text/image encoders
    (``cfg.TRAIN.NET_E``), a VGG network and the DCM
    (``cfg.TRAIN.NET_C``), then iterates five times over
    ``self.data_loader``.  For each batch the mismatched captions
    (``wrong_caps``) are encoded, the generator and DCM are run, and
    every image of the DCM output is written to
    ``<NET_G without '.pth'>/<split_dir>/single_s<idx>.png`` with a
    running index.

    Args:
        split_dir: Output sub-directory name; ``'test'`` is mapped to
            ``'valid'``.

    Returns:
        None.  Prints an error and does nothing else when
        ``cfg.TRAIN.NET_G`` or ``cfg.TRAIN.NET_C`` is empty.
    """
    if cfg.TRAIN.NET_G == '' or cfg.TRAIN.NET_C == '':
        print('Error: the path for main module or DCM is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    # The text encoder
    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    # The image encoder
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    print('Load image encoder from:', img_encoder_path)
    image_encoder = image_encoder.cuda()
    image_encoder.eval()

    # The VGG network
    VGG = VGGNet()
    print("Load the VGG model")
    VGG.cuda()
    VGG.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    # `volatile=True` has had no effect since PyTorch 0.4; gradient
    # tracking is disabled with torch.no_grad() around the loop instead.
    noise = torch.FloatTensor(batch_size, nz).cuda()

    # The DCM (NET_C is known to be non-empty here, see the guard above)
    netDCM = DCM_Net()
    netDCM.load_state_dict(
        torch.load(cfg.TRAIN.NET_C, map_location=cpu_storage))
    print('Load DCM from: ', cfg.TRAIN.NET_C)
    netDCM.cuda()
    netDCM.eval()

    model_dir = cfg.TRAIN.NET_G
    netG.load_state_dict(torch.load(model_dir, map_location=cpu_storage))
    print('Load G from: ', model_dir)

    # the path to save generated images
    save_dir = '%s/%s' % (model_dir[:model_dir.rfind('.pth')], split_dir)
    mkdir_p(save_dir)
    # All files share this prefix and its parent directory is save_dir,
    # created just above, so no per-image directory check is needed.
    file_prefix = '%s/single' % (save_dir)

    idx = 0
    last_branch = cfg.TREE.BRANCH_NUM - 1
    with torch.no_grad():
        for _ in range(5):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                #######################################################
                # (1) Extract text and image embeddings
                #######################################################
                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                words_embs = words_embs.detach()
                sent_emb = sent_emb.detach()

                # Caption positions equal to 0 are masked; the mask is
                # trimmed to the width of the word embeddings.
                mask = (wrong_caps == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                real_img = imgs[last_branch]
                region_features, cnn_code = image_encoder(real_img)

                #######################################################
                # (2) Modify real images
                #######################################################
                noise.normal_(0, 1)
                _, _, _, _, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                real_features = VGG(real_img)[0]
                fake_img = netDCM(h_code, real_features, sent_emb,
                                  words_embs, mask, c_code)

                for j in range(batch_size):
                    # Map from [-1, 1] to [0, 255] and CHW -> HWC.
                    im = fake_img[j].data.cpu().numpy()
                    im = ((im + 1.0) * 127.5).astype(np.uint8)
                    im = Image.fromarray(np.transpose(im, (1, 2, 0)))
                    im.save('%s_s%d.png' % (file_prefix, idx))
                    idx += 1
def build_models(self):
    """Build the frozen networks, the DCM and its discriminator.

    Loads VGG, the image/text encoders (``cfg.TRAIN.NET_E``) and the
    main-module generator (``cfg.TRAIN.NET_G``), creates a fresh DCM and
    a 256-resolution discriminator, and optionally restores them from
    ``cfg.TRAIN.NET_C`` / ``cfg.TRAIN.NET_D``.

    Returns:
        ``[text_encoder, image_encoder, netG, netD, epoch, VGG,
        netDCM]``; ``None`` (after printing an error) when
        ``cfg.TRAIN.NET_E`` or ``cfg.TRAIN.NET_G`` is empty.  ``epoch``
        is 0 unless a DCM checkpoint is loaded, in which case it is the
        number parsed from its file name plus one.
    """

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    def freeze(module):
        for weight in module.parameters():
            weight.requires_grad = False

    # ---- required checkpoints ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return
    if cfg.TRAIN.NET_G == '':
        print('Error: no pretrained main module')
        return

    # ---- frozen feature networks ----
    VGG = VGGNet()
    freeze(VGG)
    print("Load the VGG model")
    VGG.eval()

    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- main-module generator and DCM discriminator ----
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
        from model import D_NET256 as D_NET
        netD = D_NET(b_jcu=False)
    else:
        from model import D_NET256
        netG = G_NET()
        netD = D_NET256()
    netD.apply(weights_init)

    netG.load_state_dict(
        torch.load(cfg.TRAIN.NET_G, map_location=cpu_storage))
    netG.eval()
    print('Load G from: ', cfg.TRAIN.NET_G)

    # ---- DCM, optionally restored ----
    epoch = 0
    netDCM = DCM_Net()
    if cfg.TRAIN.NET_C != '':
        dcm_path = cfg.TRAIN.NET_C
        netDCM.load_state_dict(
            torch.load(dcm_path, map_location=cpu_storage))
        print('Load DCM from: ', dcm_path)
        number_start = dcm_path.rfind('_') + 1
        number_end = dcm_path.rfind('.')
        epoch = int(dcm_path[number_start:number_end]) + 1

    if cfg.TRAIN.NET_D != '':
        netD.load_state_dict(
            torch.load(cfg.TRAIN.NET_D, map_location=cpu_storage))
        print('Load DCM Discriminator from: ', cfg.TRAIN.NET_D)

    # ---- device placement ----
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        netDCM.cuda()
        VGG = VGG.cuda()
        netD.cuda()

    return [text_encoder, image_encoder, netG, netD, epoch, VGG, netDCM]
def build_models(self):
    """Build encoders, caption models, generator and discriminators.

    The image/text encoders come from ``cfg.TRAIN.NET_E``; the caption
    CNN/RNN pair is either ``CAPTION_CNN``/``CAPTION_RNN`` (when
    ``cfg.CAP.USE_ORIGINAL``) or ``Encoder``/``Decoder``, restored from
    the ``'model_state_dict'`` entry of the checkpoints at
    ``cfg.CAP.CAPTION_CNN_PATH`` / ``cfg.CAP.CAPTION_RNN_PATH``.  All
    four have ``requires_grad`` disabled.  Every returned model is moved
    to ``cfg.DEVICE``.

    Returns:
        ``[text_encoder, image_encoder, caption_cnn, caption_rnn, netG,
        netsD, epoch]``; ``None`` (after printing an error) when
        ``cfg.TRAIN.NET_E`` is empty.  ``epoch`` is 0 unless
        ``cfg.TRAIN.NET_G`` is set, in which case it is the number parsed
        from that file name plus one.
    """

    def cpu_storage(storage, loc):
        # map_location hook for torch.load: return the storage unchanged.
        return storage

    def freeze(module):
        for weight in module.parameters():
            weight.requires_grad = False

    print('Building models...')
    print('N_words: ', self.n_words)

    #####################
    ##  TEXT ENCODERS  ##
    #####################
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                               'image_encoder')
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    image_encoder.load_state_dict(
        torch.load(img_encoder_path, map_location=cpu_storage))
    print('Built image encoder: ', image_encoder)
    freeze(image_encoder)
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words,
                               nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E, map_location=cpu_storage))
    print('Built text encoder: ', text_encoder)
    freeze(text_encoder)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    ######################
    ##  CAPTION MODELS  ##
    ######################
    if cfg.CAP.USE_ORIGINAL:
        caption_cnn = CAPTION_CNN(embed_size=cfg.TEXT.EMBEDDING_DIM)
        caption_rnn = CAPTION_RNN(embed_size=cfg.TEXT.EMBEDDING_DIM,
                                  hidden_size=cfg.CAP.HIDDEN_SIZE,
                                  vocab_size=self.n_words,
                                  num_layers=cfg.CAP.NUM_LAYERS)
    else:
        caption_cnn = Encoder()
        caption_rnn = Decoder(idx2word=self.ixtoword)

    # Both checkpoints are read before either model is populated.
    caption_cnn_checkpoint = torch.load(cfg.CAP.CAPTION_CNN_PATH,
                                        map_location=cpu_storage)
    caption_rnn_checkpoint = torch.load(cfg.CAP.CAPTION_RNN_PATH,
                                        map_location=cpu_storage)
    caption_cnn.load_state_dict(caption_cnn_checkpoint['model_state_dict'])
    caption_rnn.load_state_dict(caption_rnn_checkpoint['model_state_dict'])

    freeze(caption_cnn)
    print('Load caption model from: ', cfg.CAP.CAPTION_CNN_PATH)
    caption_cnn.eval()

    freeze(caption_rnn)
    print('Load caption model from: ', cfg.CAP.CAPTION_RNN_PATH)
    # NOTE(review): unlike caption_cnn, caption_rnn is not switched to
    # eval mode here -- confirm whether that is intentional.

    #################################
    ##  GENERATOR & DISCRIMINATOR  ##
    #################################
    netsD = []
    if cfg.GAN.B_DCGAN:
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # Stage i gets a discriminator only when BRANCH_NUM > i.
        for level, d_cls in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > level:
                netsD.append(d_cls())

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(torch.load(g_path, map_location=cpu_storage))
        print('Load G from: ', g_path)
        number_start = g_path.rfind('_') + 1
        number_end = g_path.rfind('.')
        epoch = int(g_path[number_start:number_end]) + 1
        if cfg.TRAIN.B_NET_D:
            ckpt_dir = g_path[:g_path.rfind('/')]
            for index, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, index)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname, map_location=cpu_storage))

    # ---- device placement ----
    device = cfg.DEVICE
    text_encoder = text_encoder.to(device)
    image_encoder = image_encoder.to(device)
    caption_cnn = caption_cnn.to(device)
    caption_rnn = caption_rnn.to(device)
    netG.to(device)
    for netD in netsD:
        netD.to(device)

    return [
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD,
        epoch
    ]