def build_models():
    """Create the text and image encoders, optionally resuming from disk.

    Reads the module-level ``cfg``, ``dataset`` and ``batch_size``. When
    ``cfg.TRAIN.NET_E`` is non-empty, that file is loaded into the text
    encoder and the sibling path obtained by replacing ``'text_encoder'``
    with ``'image_encoder'`` is loaded into the image encoder.

    Returns:
        tuple: ``(text_encoder, image_encoder, labels, start_epoch)``.
        ``labels`` holds ``0 .. batch_size - 1`` as a LongTensor;
        ``start_epoch`` is 0 for a fresh run, otherwise the epoch number
        parsed from the checkpoint file name plus one.

    Raises:
        ValueError: if the epoch number cannot be parsed from the file name.
    """
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0

    if cfg.TRAIN.NET_E != '':
        # Deserialize onto the CPU first so that checkpoints written on a GPU
        # can also be restored on CPU-only hosts; the modules are moved to the
        # GPU below when cfg.CUDA is set.
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)

        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # The epoch is the text between (last '_' + 8 characters) and the last
        # '.'. NOTE(review): the offset 8 presumably skips '_encoder' in names
        # such as 'text_encoder200.pth' -- confirm against the saving code.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = int(cfg.TRAIN.NET_E[istart:iend]) + 1
        print('start_epoch', start_epoch)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
def __init__(self, opt):
    """Set up the text-conditioned pix2pix model.

    Parameters:
        opt (Option class) -- experiment flags; needs to be a subclass of
                              BaseOptions. Besides the usual pix2pix options
                              this reads ``text_words_num``,
                              ``text_embedding_dim`` and ``text_encoder``
                              (path of the text-encoder weights).
    """
    BaseModel.__init__(self, opt)
    self.batch_size = opt.batch_size
    self.img_size = opt.crop_size

    # Names consumed by BaseModel: losses to print, images to save/display,
    # and networks to save/load (only G is needed at test time).
    self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
    self.visual_names = ['real_A', 'fake_B', 'real_B']
    self.model_names = ['G', 'D'] if self.isTrain else ['G']

    # Frozen text encoder: weights are read onto the CPU and copied into the
    # module that already lives on self.device.
    self.text_encoder = RNN_ENCODER(
        opt.text_words_num, nhidden=opt.text_embedding_dim).to(self.device)
    encoder_weights = torch.load(
        opt.text_encoder, map_location=lambda storage, loc: storage)
    self.text_encoder.load_state_dict(encoder_weights)
    for weight in self.text_encoder.parameters():
        weight.requires_grad = False
    print('Load text encoder from:', opt.text_encoder)
    self.text_encoder.eval()

    self.netG = networks.define_G(
        opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

    if not self.isTrain:
        return

    # Training only. The conditional discriminator is built for
    # input_nc + output_nc channels.
    self.netD = networks.define_D(
        opt.input_nc + opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D,
        opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

    # Loss functions.
    self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
    self.criterionL1 = torch.nn.L1Loss()

    # Optimizers; schedulers are created later by <BaseModel.setup>.
    adam_betas = (opt.beta1, 0.999)
    self.optimizer_G = torch.optim.Adam(
        self.netG.parameters(), lr=opt.lr, betas=adam_betas)
    self.optimizer_D = torch.optim.Adam(
        self.netD.parameters(), lr=opt.lr, betas=adam_betas)
    for optimizer in (self.optimizer_G, self.optimizer_D):
        self.optimizers.append(optimizer)
def models(word_len):
    """Return the ``(text_encoder, netG)`` pair, building each on a cache miss.

    The evaluation config is (re)loaded on every call. Each network is looked
    up in ``cache`` under its own key; a missing one is constructed, restored
    from the path given in ``cfg``, moved to the GPU when ``cfg.CUDA`` is set,
    switched to eval mode and cached for 24 hours.

    Args:
        word_len: vocabulary size passed to ``RNN_ENCODER``.
    """
    cfg_from_file('../flask-server/AttnGAN/code/cfg/eval_plans2.yaml')
    one_day = 60 * 60 * 24

    def _restore(net, weights_path):
        # Load weights via the CPU, then apply the device/eval settings.
        weights = torch.load(weights_path,
                             map_location=lambda storage, loc: storage)
        net.load_state_dict(weights)
        if cfg.CUDA:
            net.cuda()
        net.eval()
        return net

    text_encoder = cache.get('text_encoder')
    if text_encoder is None:
        text_encoder = _restore(
            RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM),
            cfg.TRAIN.NET_E)
        cache.set('text_encoder', text_encoder, timeout=one_day)

    netG = cache.get('netG')
    if netG is None:
        netG = _restore(G_NET(), cfg.TRAIN.NET_G)
        cache.set('netG', netG, timeout=one_day)

    return text_encoder, netG
def __init__(self, output_dir, data_loader, n_words, ixtoword):
    """Store training settings and eagerly load the text encoder and generator.

    Args:
        output_dir: root directory; ``Model`` and ``Image`` sub-directories
            are created under it when ``cfg.TRAIN.FLAG`` is set.
        data_loader: iterable of batches; its ``len`` is stored as
            ``num_batches``.
        n_words: vocabulary size passed to ``RNN_ENCODER``.
        ixtoword: index-to-word mapping, stored for later visualisation.

    Both networks are restored from ``cfg.TRAIN.NET_E`` / ``cfg.TRAIN.NET_G``,
    moved to the GPU unconditionally and put in eval mode.
    """
    if cfg.TRAIN.FLAG:
        self.model_dir = os.path.join(output_dir, 'Model')
        self.image_dir = os.path.join(output_dir, 'Image')
        mkdir_p(self.model_dir)
        mkdir_p(self.image_dir)

    torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.max_epoch = cfg.TRAIN.MAX_EPOCH
    self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

    self.n_words = n_words
    self.ixtoword = ixtoword
    self.data_loader = data_loader
    self.num_batches = len(self.data_loader)

    # Text encoder: weights are deserialized on the CPU, then moved to GPU.
    self.text_encoder = \
        RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E,
                   map_location=lambda storage, loc: storage)
    self.text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    self.text_encoder = self.text_encoder.cuda()
    self.text_encoder.eval()

    # Generator: architecture is selected by cfg.GAN.B_DCGAN.
    if cfg.GAN.B_DCGAN:
        self.netG = G_DCGAN()
    else:
        self.netG = G_NET()
    model_dir = cfg.TRAIN.NET_G
    state_dict = \
        torch.load(model_dir, map_location=lambda storage, loc: storage)
    self.netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)
    self.netG.cuda()
    self.netG.eval()
class Pix2PixModel(BaseModel): """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. The model training requires '--dataset_mode aligned' dataset. By default, it uses a '--netG unet256' U-Net generator, a '--netD basic' discriminator (PatchGAN), and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. For pix2pix, we do not use image buffer The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. """ # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') if is_train: parser.set_defaults(pool_size=0, gan_mode='vanilla') parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') return parser def __init__(self, opt): """Initialize the pix2pix class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt) self.batch_size = opt.batch_size self.img_size = opt.crop_size # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] # specify the images you want to save/display. 
The training/test scripts will call <BaseModel.get_current_visuals> self.visual_names = ['real_A', 'fake_B', 'real_B'] # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> if self.isTrain: self.model_names = ['G', 'D'] else: # during test time, only load G self.model_names = ['G'] # define networks (both generator and discriminator) self.text_encoder = RNN_ENCODER(opt.text_words_num, nhidden=opt.text_embedding_dim).to( self.device) state_dict = torch.load(opt.text_encoder, map_location=lambda storage, loc: storage) self.text_encoder.load_state_dict(state_dict) for p in self.text_encoder.parameters(): p.requires_grad = False print('Load text encoder from:', opt.text_encoder) self.text_encoder.eval() self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. 
The option 'direction' can be used to swap images in domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.caption_len, sort_idx = input["caption_len"].sort(descending=True) self.real_A = input['A' if AtoB else 'B'][[sort_idx]].to(self.device) self.real_B = input['B' if AtoB else 'A'][[sort_idx]].to(self.device) self.caption = input["caption"][[sort_idx]].to(self.device) self.image_paths = list( np.array( input['A_paths' if AtoB else 'B_paths'])[[sort_idx.tolist()]]) # Encode text hidden = self.text_encoder.init_hidden(self.batch_size) _, self.sent_emb = self.text_encoder( self.caption, self.caption_len, hidden) # sent_emb: [batch_size(1), sent_dim(128)] self.tiled_sentence = self.sent_emb.unsqueeze(2).unsqueeze(3).repeat( 1, 1, self.img_size, self.img_size) def forward(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" real_A = torch.cat( (self.real_A, self.tiled_sentence), 1 ) # real_A: [batch_size(1), 3+sent_dim(128), crop_size(256), crop_size(256)] self.fake_B = self.netG(real_A) # G(A) def backward_D(self): """Calculate GAN loss for the discriminator""" # Fake; stop backprop to the generator by detaching fake_B fake_AB = torch.cat( (self.real_A, self.fake_B, self.tiled_sentence), 1 ) # we use conditional GANs; we need to feed both input and output to the discriminator pred_fake = self.netD(fake_AB.detach()) self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B, self.tiled_sentence), 1) pred_real = self.netD(real_AB) self.loss_D_real = self.criterionGAN(pred_real, True) # combine loss and calculate gradients self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D.backward() def backward_G(self): """Calculate GAN and L1 loss for the generator""" # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B, self.tiled_sentence), 1) pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # 
Second, G(A) = B self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 # combine loss and calculate gradients self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G.backward() def optimize_parameters(self): self.forward() # compute fake images: G(A) # update D self.set_requires_grad(self.netD, True) # enable backprop for D self.optimizer_D.zero_grad() # set D's gradients to zero self.backward_D() # calculate gradients for D self.optimizer_D.step() # update D's weights # update G self.set_requires_grad( self.netD, False) # D requires no gradients when optimizing G self.optimizer_G.zero_grad() # set G's gradients to zero self.backward_G() # calculate graidents for G self.optimizer_G.step() # udpate G's weights
def build_models(self):
    """Build the frozen encoders, the generator and its discriminators.

    Returns:
        list: ``[text_encoder, image_encoder, netG, netsD, epoch]`` where
        ``epoch`` is 0, or one past the number parsed from
        ``cfg.TRAIN.NET_G`` when a generator checkpoint is given.
        ``None`` is returned (after printing an error) when
        ``cfg.TRAIN.NET_E`` is empty.
    """
    # ---- pretrained encoders (frozen, eval mode) ----
    if cfg.TRAIN.NET_E == '':
        print('Error: no pretrained text-image encoders')
        return

    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
    image_encoder.load_state_dict(
        torch.load(img_encoder_path,
                   map_location=lambda storage, loc: storage))
    for param in image_encoder.parameters():
        param.requires_grad = False
    print('Load image encoder from:', img_encoder_path)
    image_encoder.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    text_encoder.load_state_dict(
        torch.load(cfg.TRAIN.NET_E,
                   map_location=lambda storage, loc: storage))
    for param in text_encoder.parameters():
        param.requires_grad = False
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder.eval()

    # ---- generator and discriminators ----
    netsD = []
    if cfg.GAN.B_DCGAN:
        # A single discriminator, chosen by the number of branches.
        if cfg.TREE.BRANCH_NUM == 1:
            from model import D_NET64 as D_NET
        elif cfg.TREE.BRANCH_NUM == 2:
            from model import D_NET128 as D_NET
        else:  # cfg.TREE.BRANCH_NUM == 3:
            from model import D_NET256 as D_NET
        # TODO: elif cfg.TREE.BRANCH_NUM > 3:
        netG = G_DCGAN()
        netsD = [D_NET(b_jcu=False)]
    else:
        from model import D_NET64, D_NET128, D_NET256
        netG = G_NET()
        # One discriminator per branch, in increasing resolution order.
        for branch, d_class in enumerate((D_NET64, D_NET128, D_NET256)):
            if cfg.TREE.BRANCH_NUM > branch:
                netsD.append(d_class())
        # TODO: if cfg.TREE.BRANCH_NUM > 3:

    netG.apply(weights_init)
    for netD in netsD:
        netD.apply(weights_init)
    print('# of netsD', len(netsD))

    # ---- optional resume from a generator checkpoint ----
    epoch = 0
    if cfg.TRAIN.NET_G != '':
        g_path = cfg.TRAIN.NET_G
        netG.load_state_dict(
            torch.load(g_path, map_location=lambda storage, loc: storage))
        print('Load G from: ', g_path)
        # Epoch number sits between the last '_' and the last '.'.
        epoch = int(g_path[g_path.rfind('_') + 1:g_path.rfind('.')]) + 1
        if cfg.TRAIN.B_NET_D:
            # Discriminator files are looked up next to the generator file.
            ckpt_dir = g_path[:g_path.rfind('/')]
            for idx, netD in enumerate(netsD):
                Dname = '%s/netD%d.pth' % (ckpt_dir, idx)
                print('Load D from: ', Dname)
                netD.load_state_dict(
                    torch.load(Dname,
                               map_location=lambda storage, loc: storage))

    # ---- device placement ----
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        netG.cuda()
        for netD in netsD:
            netD.cuda()
    return [text_encoder, image_encoder, netG, netsD, epoch]
def sampling(self, split_dir):
    """Generate one image per caption in ``self.data_loader`` and save as PNG.

    The generator is restored from ``cfg.TRAIN.NET_G`` and the text encoder
    from ``cfg.TRAIN.NET_E``. Only the last entry of the generator's image
    list is written, to
    ``<NET_G without '.pth'>/<split_dir>/single/<key>_s-1.png``.

    Args:
        split_dir: name of the output sub-directory; ``'test'`` is mapped to
            ``'valid'``.

    Returns:
        None. An error is printed and nothing is generated when
        ``cfg.TRAIN.NET_G`` is empty.
    """
    if cfg.TRAIN.NET_G == '':
        print('Error: the path for morels is not found!')
        return

    if split_dir == 'test':
        split_dir = 'valid'

    # Build and load the generator
    if cfg.GAN.B_DCGAN:
        netG = G_DCGAN()
    else:
        netG = G_NET()
    netG.apply(weights_init)
    netG.cuda()
    netG.eval()

    text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = \
        torch.load(cfg.TRAIN.NET_E,
                   map_location=lambda storage, loc: storage)
    text_encoder.load_state_dict(state_dict)
    print('Load text encoder from:', cfg.TRAIN.NET_E)
    text_encoder = text_encoder.cuda()
    text_encoder.eval()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    # Values are overwritten by normal_() before every use.
    noise = Variable(torch.FloatTensor(batch_size, nz))
    noise = noise.cuda()

    model_dir = cfg.TRAIN.NET_G
    state_dict = \
        torch.load(model_dir, map_location=lambda storage, loc: storage)
    netG.load_state_dict(state_dict)
    print('Load G from: ', model_dir)

    # the path to save generated images
    s_tmp = model_dir[:model_dir.rfind('.pth')]
    save_dir = '%s/%s' % (s_tmp, split_dir)
    mkdir_p(save_dir)

    # `volatile=True` is ignored since PyTorch 0.4; no_grad() is what
    # actually stops autograd from recording the inference passes.
    with torch.no_grad():
        for step, data in enumerate(self.data_loader, 0):
            if step % 100 == 0:
                print('step: ', step)

            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # True where the caption holds index 0; cut to the number of
            # word embeddings returned by the encoder.
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            # Generate fake images
            noise.data.normal_(0, 1)
            fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
            k = -1  # only the last image in the generator's output list
            for j in range(batch_size):
                s_tmp = '%s/single/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                im = fake_imgs[k][j].data.cpu().numpy()
                # [-1, 1] --> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # channels-first -> channels-last for PIL
                im = np.transpose(im, (1, 2, 0))
                im = Image.fromarray(im)
                fullpath = '%s_s%d.png' % (s_tmp, k)
                im.save(fullpath)
class condGANTrainer(object):
    """Trainer / sampler for a text-conditioned GAN with one or more discriminators.

    The constructor eagerly loads a text encoder and a generator (used by
    ``gen_example``); ``train`` and ``sampling`` build their own networks.
    """

    def __init__(self, output_dir, data_loader, n_words, ixtoword):
        """Store training settings and load the text encoder and generator.

        Args:
            output_dir: root directory; ``Model`` and ``Image``
                sub-directories are created when ``cfg.TRAIN.FLAG`` is set.
            data_loader: iterable of batches; its ``len`` becomes
                ``num_batches``.
            n_words: vocabulary size passed to ``RNN_ENCODER``.
            ixtoword: index-to-word mapping used when rendering captions.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)

        torch.cuda.set_device(cfg.GPU_ID)
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.n_words = n_words
        self.ixtoword = ixtoword
        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

        # Text encoder: deserialized on the CPU, then moved to the GPU.
        self.text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        self.text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        self.text_encoder = self.text_encoder.cuda()
        self.text_encoder.eval()

        # Generator: architecture is selected by cfg.GAN.B_DCGAN.
        if cfg.GAN.B_DCGAN:
            self.netG = G_DCGAN()
        else:
            self.netG = G_NET()
        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        self.netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)
        self.netG.cuda()
        self.netG.eval()

    def build_models(self):
        """Build the frozen encoders, the generator and its discriminators.

        Returns:
            list: ``[text_encoder, image_encoder, netG, netsD, epoch]`` where
            ``epoch`` is 0, or one past the number parsed from
            ``cfg.TRAIN.NET_G`` when a generator checkpoint is given.
            ``None`` is returned (after printing an error) when
            ``cfg.TRAIN.NET_E`` is empty.
        """
        # ###################encoders######################################## #
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            return

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # #######################generator and discriminators############## #
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            # TODO: elif cfg.TREE.BRANCH_NUM > 3:
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            # TODO: if cfg.TREE.BRANCH_NUM > 3:
        netG.apply(weights_init)
        for netD in netsD:
            netD.apply(weights_init)
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            # Epoch number sits between the last '_' and the last '.'.
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                # Discriminator files are looked up next to the generator file.
                s_tmp = Gname[:Gname.rfind('/')]
                for i, netD in enumerate(netsD):
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netD.load_state_dict(state_dict)
        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            netG.cuda()
            for netD in netsD:
                netD.cuda()
        return [text_encoder, image_encoder, netG, netsD, epoch]

    def define_optimizers(self, netG, netsD):
        """Return ``(optimizerG, optimizersD)``: one Adam per network, betas (0.5, 0.999)."""
        optimizersD = [
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR,
                       betas=(0.5, 0.999))
            for netD in netsD
        ]

        optimizerG = optim.Adam(netG.parameters(),
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        return optimizerG, optimizersD

    def prepare_labels(self):
        """Return ``(real_labels, fake_labels, match_labels)`` for one batch.

        ``real_labels`` is all ones, ``fake_labels`` all zeros and
        ``match_labels`` is ``0 .. batch_size - 1``; all are moved to the GPU
        when ``cfg.CUDA`` is set.
        """
        batch_size = self.batch_size
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        match_labels = Variable(torch.LongTensor(range(batch_size)))
        if cfg.CUDA:
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()
            match_labels = match_labels.cuda()

        return real_labels, fake_labels, match_labels

    def save_model(self, netG, avg_param_G, netsD, epoch):
        """Save the averaged generator weights and every discriminator.

        The generator's live weights are swapped for ``avg_param_G`` only for
        the duration of the save and restored afterwards.
        """
        backup_para = copy_G_params(netG)
        load_params(netG, avg_param_G)
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
        load_params(netG, backup_para)

        for i, netD in enumerate(netsD):
            torch.save(netD.state_dict(),
                       '%s/netD%d.pth' % (self.model_dir, i))
        print('Save G/Ds models.')

    def set_requires_grad_value(self, models_list, brequires):
        """Set ``requires_grad`` to ``brequires`` on every parameter of every model."""
        for model in models_list:
            for p in model.parameters():
                p.requires_grad = brequires

    def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                         image_encoder, captions, cap_lens,
                         gen_iterations, name='current'):
        """Render and save attention visualisations for the current generator.

        Writes one ``G_<name>_<gen_iterations>_<i>.png`` per generator
        attention map and one ``D_<name>_<gen_iterations>.png`` built from the
        attention maps returned by ``words_loss`` for the last generated image.
        Files are only written when ``build_super_images`` returns an image.
        """
        # Save images
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                # pair attention map i with image i + 1 and pass image i
                # along as lr_imgs
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # Attention maps from the image encoder, for the last image only.
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(),
                                    None, cap_lens,
                                    None, self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)

    def train(self):
        """Run the adversarial training loop from the resume epoch to ``max_epoch``.

        Each step updates every discriminator, then the generator, then an
        exponential moving average of the generator weights
        (0.999 * average + 0.001 * current). Logs are printed every 100
        generator iterations, sample images are saved every 1000, and models
        are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once more
        at the end.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # builtin next() works on every iterator; the .next() method
                # is not available on all DataLoader iterator versions.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    # .item() instead of .data[0]: indexing a 0-dim tensor
                    # raises on current PyTorch releases.
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # moving average of the generator weights
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # temporarily swap in the averaged weights for rendering
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)

    def save_singleimages(self, images, filenames,
                          save_dir, split_dir, sentenceID=0):
        """Save each image of a batch as ``<save_dir>/single_samples/<split_dir>/<filename>_<sentenceID>.jpg``.

        Pixel values are mapped from [-1, 1] to [0, 255] and the channel axis
        is moved last before saving.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d.jpg' % (s_tmp, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def sampling(self, split_dir):
        """Generate one image per caption in ``self.data_loader`` and save as PNG.

        The generator is restored from ``cfg.TRAIN.NET_G`` and the text
        encoder from ``cfg.TRAIN.NET_E``. Only the last entry of the
        generator's image list is written, to
        ``<NET_G without '.pth'>/<split_dir>/single/<key>_s-1.png``.

        Args:
            split_dir: name of the output sub-directory; ``'test'`` is mapped
                to ``'valid'``.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'
        # Build and load the generator
        if cfg.GAN.B_DCGAN:
            netG = G_DCGAN()
        else:
            netG = G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # Values are overwritten by normal_() before every use.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = \
            torch.load(model_dir, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # the path to save generated images
        s_tmp = model_dir[:model_dir.rfind('.pth')]
        save_dir = '%s/%s' % (s_tmp, split_dir)
        mkdir_p(save_dir)

        # `volatile=True` is ignored since PyTorch 0.4; no_grad() is what
        # actually stops autograd from recording the inference passes.
        with torch.no_grad():
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
                k = -1  # only the last image in the generator's output list
                for j in range(batch_size):
                    s_tmp = '%s/single/%s' % (save_dir, keys[j])
                    folder = s_tmp[:s_tmp.rfind('/')]
                    if not os.path.isdir(folder):
                        print('Make a new folder: ', folder)
                        mkdir_p(folder)
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    fullpath = '%s_s%d.png' % (s_tmp, k)
                    im.save(fullpath)

    def gen_example(self, data_dic):
        """Generate images for a batch of captions with the preloaded networks.

        Args:
            data_dic: a 3-tuple ``(captions, cap_lens, sorted_indices)``;
                the first two are NumPy arrays converted with
                ``torch.from_numpy``. ``sorted_indices`` is not used here.

        Returns:
            numpy.ndarray: one uint8, channels-last image per caption, taken
            from the last entry of the generator's image list.
        """
        captions, cap_lens, sorted_indices = data_dic

        batch_size = captions.shape[0]
        nz = cfg.GAN.Z_DIM
        captions = Variable(torch.from_numpy(captions))
        cap_lens = Variable(torch.from_numpy(cap_lens))

        captions = captions.cuda()
        cap_lens = cap_lens.cuda()
        noise = Variable(torch.FloatTensor(batch_size, nz))
        noise = noise.cuda()

        # `volatile=True` is ignored since PyTorch 0.4; no_grad() is what
        # actually stops autograd from recording the inference passes.
        with torch.no_grad():
            #######################################################
            # (1) Extract text embeddings
            ######################################################
            hidden = self.text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = self.text_encoder(captions, cap_lens, hidden)
            mask = (captions == 0)
            # same guard as train()/sampling(): keep the mask no wider than
            # the word embeddings returned by the encoder
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]
            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, _, _ = self.netG(noise, sent_emb, words_embs, mask)

        generated_images = []
        for j in range(batch_size):
            # Last (final) output, as in sampling(); a fixed index of 2 fails
            # when the generator returns fewer than three images.
            im = fake_imgs[-1][j].data.cpu().numpy()
            # [-1, 1] --> [0, 255]
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            generated_images.append(im)
        return np.array(generated_images)