def main(exercise: str = "Aufgabe-2"):
    """Train the configured agent on a gym environment and log the progress.

    Settings are read from ``./config.<exercise>.json``. A ``FileWriter`` is
    created on ``./tmp/<exercise>/<current datetime>`` and receives loss,
    reward statistics and the exploration epsilon.

    Args:
        exercise: Name used to select the config file and to label the
            output directory.
    """
    with open("./config.{}.json".format(exercise), mode="r") as f:
        # Every JSON object becomes a namedtuple so that settings can be read
        # as attributes (args.env, args.episodes, ...).
        args = json.load(
            f, object_hook=lambda d: namedtuple("X", d.keys())(*d.values())
        )

    environment = gym.make(args.env)
    input_shape = environment.observation_space.shape
    num_actions = environment.action_space.n

    output_directory = "./tmp/{}/{}".format(exercise, datetime.datetime.now())
    writer = FileWriter(output_directory)
    try:
        agent = make_agent(args, input_shape, num_actions, output_directory)

        rewards = []
        for episode in range(args.episodes):
            episode_rewards = run_episode(
                environment,
                agent,
                render=episode % args.render_episode_interval == 0,
                max_length=args.max_episode_length,
            )
            rewards.append(episode_rewards)

            if episode % args.training_interval == 0:
                # One training call per episode played since the last round.
                for _ in range(args.training_interval):
                    loss = agent.train()

                # Report every tenth training round, and only when the last
                # train() call produced a truthy loss.
                if loss and episode % (args.training_interval * 10) == 0:
                    mean_rewards = np.mean(rewards)
                    std_rewards = np.std(rewards)
                    epsilon = agent.exploration_strategy.epsilon
                    writer.add_summary(
                        summary=summary.scalar("dqn/loss", loss),
                        global_step=episode,
                    )
                    writer.add_summary(
                        summary=summary.scalar("rewards/mean", mean_rewards),
                        global_step=episode,
                    )
                    writer.add_summary(
                        summary=summary.scalar(
                            "rewards/standard deviation", std_rewards
                        ),
                        global_step=episode,
                    )
                    writer.add_summary(
                        summary=summary.scalar("dqn/epsilon", epsilon),
                        global_step=episode,
                    )
                    print("Episode {}\tMean rewards {:f}\tLoss {:f}\tEpsilon {:f}".
                          format(episode, mean_rewards, loss, epsilon))
                    # Statistics cover only the episodes since the last report.
                    rewards.clear()
    finally:
        # Close the writer even when training is interrupted, so summaries
        # that were already added are not left in an unclosed writer.
        writer.close()
class condGANTrainer(object):
    """Trainer for a conditional GAN with one generator and several discriminators.

    The generator is called with a noise tensor plus a conditioning embedding
    (``txt_embedding`` in the code); discriminator ``i`` in ``self.netsD``
    scores element ``i`` of the generator's image list. Networks, optimizers,
    the loss criterion and the label tensors are created inside :meth:`train`,
    so :meth:`prepare_data`, :meth:`train_Dnet` and :meth:`train_Gnet` rely on
    attributes that only exist once :meth:`train` is running.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Set up output folders, GPU ids and training settings from ``cfg``.

        Args:
            output_dir: Root folder; ``Model``, ``Image`` and ``Log``
                sub-folders are created below it when ``cfg.TRAIN.FLAG`` is set.
            data_loader: Iterable of batches; its length is stored as
                ``self.num_batches``.
            imsize: Accepted but not used by this method.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # Only created in training mode; evaluate() never touches it.
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated string of device ids.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        #torch.cuda.set_device(self.gpus[0])
        #torch._C._cuda_setDevice(-1)
        cudnn.benchmark = True

        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` objects (on the GPU when ``cfg.CUDA``).

        Args:
            data: A 4-tuple ``(imgs, w_imgs, t_embedding, _)``; ``imgs`` and
                ``w_imgs`` are indexed from ``0`` to ``self.num_Ds - 1``.

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, vembedding)`` where ``imgs`` is
            the untouched input and the other three are the wrapped versions.
        """
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        # self.num_Ds is assigned in train(); one image entry per discriminator.
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """Run one optimisation step for discriminator ``idx``.

        Reads the current batch from ``self.real_imgs``, ``self.wrong_imgs``,
        ``self.fake_imgs`` and ``self.mu``, all of which are set by
        :meth:`train`.

        Args:
            idx: Index into ``self.netsD``, ``self.optimizersD`` and the
                per-discriminator image lists.
            count: Global step; the loss is logged when ``count % 100 == 0``.

        Returns:
            The loss ``errD`` that was back-propagated.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        # Forward
        # The label tensors are sized for a full batch; slice them so a
        # smaller batch still matches.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # Inputs that depend on the generator are detached so this backward
        # pass does not propagate into it.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())

        # Element 0 of the discriminator output is used for the base loss.
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Element 1, when present, adds a second term weighted by
            # cfg.TRAIN.COEFF.UNCOND_LOSS.
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            # NOTE(review): wrong images are scored against real_labels here,
            # presumably because they are real images with a mismatched
            # embedding - confirm.
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)

            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one optimisation step for the generator.

        Sums the per-discriminator losses, the optional colour-consistency
        terms and the weighted KL term, back-propagates the total and steps
        ``self.optimizerG``.

        Args:
            count: Global step; individual terms are logged when
                ``count % 100 == 0``.

        Returns:
            ``(kl_loss, errG_total)``.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            # Fake images are not detached here: gradients must reach netG.
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                # Compare the statistics of the last two generator outputs;
                # the second-to-last one is detached.
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                # Same comparison one step earlier in the output list.
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Run the full training loop and save the final model.

        Loads the networks via ``load_network``, creates optimizers, loss and
        label tensors, then alternates discriminator and generator updates for
        every batch. A running average of the generator parameters is kept
        and is used for checkpoints and sample images. Checkpoints, sample
        images and score summaries are written every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        # Constant target tensors: ones for "real", zeros for "fake".
        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        # Not read anywhere else in this class.
        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        # `noise` is refilled in place every step; `fixed_noise` is drawn once
        # and reused for the snapshot images.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        # Resume the step counter (and the epoch derived from it) from the
        # value returned by load_network.
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Running average of the generator parameters:
                # avg = 0.999 * avg + 0.001 * current.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images
                    # Swap the averaged parameters into netG for the sample
                    # images, then restore the live parameters afterwards.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)

                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)

                    load_params(self.netG, backup_para)

                    # Compute inception score
                    # Only once more than 500 prediction batches have been
                    # collected; the buffer is reset afterwards.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)

                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)

                        predictions = []

            end_t = time.time()
            # Reports the losses of the last batch of the epoch.
            print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs '''  # D(real): %.4f D(wrong):%.4f D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data[0], errG_total.data[0],
                     kl_loss.data[0], end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames,
                         save_dir, split_dir, imsize):
        """Save one image grid per sample, combining all entries of ``images_list``.

        Args:
            images_list: Sequence of image batches; entry ``j`` supplies image
                ``j`` of every grid.
            filenames: Per-sample names used to build the output path.
            save_dir: Root output folder.
            split_dir: Sub-folder name placed below ``<save_dir>/super``.
            imsize: Side length each image is viewed as (``1 x 3 x imsize x
                imsize``); also part of the file name.
        """
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            # Everything before the last '/' is the folder to create.
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames,
                          save_dir, split_dir, sentenceID, imsize):
        """Save every image of a batch as its own PNG file.

        Args:
            images: Batch of images; the value mapping below assumes values
                in ``[-1, 1]``.
            filenames: Per-sample names used to build the output path.
            save_dir: Root output folder.
            split_dir: Sub-folder name placed below ``<save_dir>/single_samples``.
            sentenceID: Number appended to the file name.
            imsize: Number appended to the file name.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # Move the first axis to the end before handing the array to PIL.
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate images for every batch of ``self.data_loader`` and save them.

        The generator weights are loaded from ``cfg.TRAIN.NET_G``. When that
        path is empty, only an error message is printed.

        Args:
            split_dir: Sub-folder name for the output; ``'test'`` is replaced
                by ``'valid'``.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps every loaded tensor in the storage it was
            # deserialised into.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            # The iteration number is parsed from between the last '_' and
            # the last '.' of the checkpoint path.
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                # Axis 1 of t_embeddings is iterated below: one generator
                # call per slice t_embeddings[:, i, :].
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                # Resize the noise buffer to the actual batch size.
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
class GANTrainer(object):
    """Trainer for an unconditional GAN with one generator and several discriminators.

    The generator is called with noise only; discriminator ``i`` in
    ``self.netsD`` scores element ``i`` of the generator's image list.
    Networks, optimizers, the loss criterion and the label tensors are created
    inside :meth:`train`, so :meth:`prepare_data`, :meth:`train_Dnet` and
    :meth:`train_Gnet` rely on attributes that only exist once :meth:`train`
    is running.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Set up output folders, GPU ids and training settings from ``cfg``.

        Args:
            output_dir: Root folder; ``Model``, ``Image`` and ``Log``
                sub-folders are created below it when ``cfg.TRAIN.FLAG`` is set.
            data_loader: Iterable of batches; its length is stored as
                ``self.num_batches``.
            imsize: Accepted but not used by this method.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # Only created in training mode; evaluate() never touches it.
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated string of device ids.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        #torch.cuda.set_device(self.gpus[0])
        #torch._C._cuda_setDevice(-1)
        cudnn.benchmark = True

        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` objects (on the GPU when ``cfg.CUDA``).

        Args:
            data: Indexable collection of image tensors, one entry per
                discriminator.

        Returns:
            ``(imgs, vimgs)``: the untouched input and the wrapped list.
        """
        imgs = data

        vimgs = []
        # self.num_Ds is assigned in train(); one image entry per discriminator.
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))
        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """Run one optimisation step for discriminator ``idx``.

        Args:
            idx: Index into ``self.netsD``, ``self.optimizersD`` and the
                per-discriminator image lists.
            count: Global step; the loss is logged when ``count % 100 == 0``.

        Returns:
            The loss ``errD`` that was back-propagated.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        # The label tensors are sized for a full batch; slice them so a
        # smaller batch still matches.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        netD.zero_grad()

        real_logits = netD(real_imgs)
        # Detached so this backward pass does not propagate into the generator.
        fake_logits = netD(fake_imgs.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        errD = errD_real + errD_fake
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one optimisation step for the generator.

        Sums the per-discriminator losses and the optional colour terms,
        back-propagates the total and steps ``self.optimizerG``.

        Args:
            count: Global step; individual terms are logged when
                ``count % 100 == 0``.

        Returns:
            The total generator loss ``errG_total``.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            # Fake images are not detached here: gradients must reach netG.
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            # errG = self.stage_coeff[i] * errG
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                # Compare the statistics of the last two generator outputs;
                # the second-to-last one is detached.
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                # Same comparison one step earlier in the output list.
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1

            # like_mu2 / like_cov2 only exist when there are at least two
            # discriminators; without this guard a single-discriminator run
            # would hit a NameError on the first logged step.
            if flag == 0 and self.num_Ds > 1:
                sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                self.summary_writer.add_summary(sum_mu, count)

                sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)

                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Run the full training loop and save the final model.

        Loads the networks via ``load_network``, creates optimizers, loss and
        label tensors, then alternates discriminator and generator updates for
        every batch. A running average of the generator parameters is kept
        and is used for checkpoints and sample images. Checkpoints, sample
        images and score summaries are written every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        # Constant target tensors: ones for "real", zeros for "fake".
        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        # `noise` is refilled in place every step; `fixed_noise` is drawn once
        # and reused for the snapshot images.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        # Resume the step counter (and the epoch derived from it) from the
        # value returned by load_network.
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = self.train_Gnet(count)
                # Running average of the generator parameters:
                # avg = 0.999 * avg + 0.001 * current.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f'''
                          % (epoch, self.max_epoch, step, self.num_batches,
                             errD_total.data[0], errG_total.data[0]))
                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    # NOTE(review): save_model is called twice in a row with
                    # identical arguments; the second call looks redundant -
                    # confirm against save_model before removing it.
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images
                    # Swap the averaged parameters into netG for the sample
                    # images, then restore the live parameters afterwards.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)

                    self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)

                    load_params(self.netG, backup_para)

                    # Compute inception score
                    # Only once more than 500 prediction batches have been
                    # collected; the buffer is reset afterwards.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)

                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)

                        predictions = []

            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))

        # NOTE(review): duplicated call, same remark as in the snapshot branch.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save a whole batch as one image grid.

        Args:
            images: Batch of images; ``images.data`` is handed to
                ``vutils.save_image``.
            folder: Output folder.
            startID: Number used as the first part of the file name.
            imsize: Number used as the second part of the file name.
        """
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save every image of a batch as its own PNG file.

        Args:
            images: Batch of images; the value mapping below assumes values
                in ``[-1, 1]``.
            folder: Output folder.
            startID: File-name number of the first image; image ``i`` is
                written as ``startID + i``.
            imsize: Number used as the second part of the file name.
        """
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # Move the first axis to the end before handing the array to PIL.
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate ``cfg.TEST.SAMPLE_NUM`` images from noise and save them.

        The generator weights are loaded from ``cfg.TRAIN.NET_G``. When that
        path is empty, only an error message is printed.

        Args:
            split_dir: Name of the sub-folder placed below the iteration
                folder.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps every loaded tensor in the storage it was
            # deserialised into.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            # The iteration number is parsed from between the last '_' and
            # the last '.' of the checkpoint path.
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
            if cfg.TEST.B_EXAMPLE:
                folder = '%s/super' % (save_dir)
            else:
                folder = '%s/single' % (save_dir)
            print('Make a new folder: ', folder)
            mkdir_p(folder)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
            cnt = 0
            # `range` instead of `xrange`: the rest of the class already uses
            # `range`, and `xrange` does not exist on Python 3.
            for step in range(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                    # self.save_singleimages(fake_imgs[-2], folder, 128)
                    # self.save_singleimages(fake_imgs[-3], folder, 64)
                cnt += self.batch_size
class GANTrainer(object):
    """Trainer and sampler for a two-stage conditional GAN.

    ``stage=1`` uses the networks built by :meth:`load_network_stageI`; any
    other value uses :meth:`load_network_stageII`. The generator is called
    with ``(txt_embedding, noise)`` and returns four values, of which the
    second is treated as the generated image batch.
    """

    def __init__(self, output_dir):
        """Set up output folders, GPU ids and training settings from ``cfg``.

        Args:
            output_dir: Root folder; ``Model``, ``Image`` and ``Log``
                sub-folders are created below it when ``cfg.TRAIN.FLAG`` is set.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # Only created in training mode; sample() never touches it.
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID is a comma-separated string of device ids.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the stage-I generator and discriminator.

        Weights are initialised with ``weights_init`` and then overwritten
        from ``cfg.NET_G`` / ``cfg.NET_D`` when those paths are non-empty.

        Returns:
            ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        """Build the stage-II generator (wrapping a stage-I one) and discriminator.

        The generator is loaded from ``cfg.NET_G`` when that is set; otherwise
        only its embedded stage-I generator is loaded from ``cfg.STAGE1_G``.

        Returns:
            ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set.

        Raises:
            ValueError: If neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            # Both callers unpack the result into two names, so returning
            # None here only produced an opaque unpacking TypeError; fail
            # with the actual reason instead.
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Train the selected stage for ``self.max_epoch`` epochs.

        Both learning rates are halved every ``cfg.TRAIN.LR_DECAY_EPOCH``
        epochs. Summaries and sample images are written every 100 batches, a
        checkpoint every ``self.snapshot_interval`` epochs and once more at
        the end.

        Args:
            data_loader: Iterable yielding ``(real_img_cpu, txt_embedding)``
                batches.
            stage: ``1`` for the stage-I networks, anything else for stage II.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # `noise` is refilled in place every step; `fixed_noise` is drawn once
        # and reused for the sample images.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        # Single-element +1 / -1 tensors instead of per-sample 1/0 targets.
        real_labels = torch.FloatTensor([1])
        fake_labels = real_labels * -1
        wrong_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        # real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        # fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels, wrong_labels = \
                real_labels.cuda(), fake_labels.cuda(), wrong_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = optim.RMSprop(netD.parameters(), lr=discriminator_lr)
        # Only parameters that require gradients are handed to the optimizer.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.RMSprop(netG_para, lr=generator_lr)
        # Previous Adam setup, kept for reference:
        # optimizerD = \
        #     optim.Adam(netD.parameters(),
        #                lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # optimizerG = optim.Adam(netG_para,
        #                         lr=cfg.TRAIN.GENERATOR_LR,
        #                         betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                # Halve both learning rates and push the new values into the
                # existing optimizers.
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                # Clamp every discriminator weight to [-0.01, 0.01] before
                # the update.
                for p in netD.parameters():
                    p.data.clamp_(-0.01, 0.01)
                netD.zero_grad()
                # NOTE(review): there is no errD.backward() in this method;
                # presumably compute_discriminator_loss back-propagates
                # internally - confirm, otherwise optimizerD.step() has no
                # gradients to apply.
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               wrong_labels, mu, self.gpus)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                # NOTE(review): errG is only logged; it is not part of
                # errG_total below. Presumably compute_generator_loss
                # back-propagates internally - confirm.
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward(retain_graph=True)
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.data[0])
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    # Files are keyed by epoch, so later batches of the same
                    # epoch presumably overwrite earlier ones - verify
                    # against save_img_results.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            # Reports the losses of the last batch of the epoch.
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.data[0], errG.data[0], kl_loss.data[0],
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        save_model(netG, netD, self.max_epoch, self.model_dir)

        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate one PNG per embedding stored in ``datapath``.

        Images are written to a folder named after ``cfg.NET_G`` with
        everything from ``'.pth'`` onwards removed. Generation stops once more
        than 3000 embeddings have been processed.

        Args:
            datapath: File readable by ``torchfile.load`` that provides
                ``raw_txt`` and ``fea_txt``.
            stage: ``1`` for the stage-I generator, anything else for stage II.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        # Never use a batch larger than the number of embeddings.
        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                # Last, partial batch: shift the window back so it is still
                # a full batch ending at the final embedding.
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # Scale to 0..255, assuming generator output in [-1, 1].
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                # Move the first axis to the end before handing it to PIL.
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
class condGANTrainer(object):
    """Train an encoder, a list of generators and a list of discriminators.

    Every training step also refreshes the ``Phi`` matrices through
    ``PGBN_sampler`` (see ``updatePhi``).  Scalar summaries are written to
    ``self.summary_writer`` whenever ``count % 100 == 0``.
    """

    def __init__(self, output_dir, data_loader):
        """Create the output directories (training only) and cache settings.

        Args:
            output_dir: Root directory.  When ``cfg.TRAIN.FLAG`` is set the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                and a ``FileWriter`` is opened on ``Log``.
            data_loader: Iterable of batches; must support ``len()``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device indices.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` objects (on the GPU if ``cfg.CUDA``).

        Args:
            data: 4-item batch ``(imgs, texts, w_imgs, _)``; ``imgs`` and
                ``w_imgs`` are indexed in lockstep.

        Returns:
            ``(imgs, vtxts, real_vimgs, wrong_vimgs)`` where ``imgs`` is the
            untouched input and the other three are the wrapped versions of
            ``texts``, ``imgs`` and ``w_imgs``.
        """
        imgs, texts, w_imgs, _ = data
        real_vimgs, wrong_vimgs = [], []
        vtxts = Variable(texts)
        if cfg.CUDA:
            vtxts = vtxts.cuda()
        for i, img in enumerate(imgs):
            real_v = Variable(img)
            wrong_v = Variable(w_imgs[i])
            if cfg.CUDA:
                real_v = real_v.cuda()
                wrong_v = wrong_v.cuda()
            real_vimgs.append(real_v)
            wrong_vimgs.append(wrong_v)
        return imgs, vtxts, real_vimgs, wrong_vimgs

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Args:
            idx: Index into ``self.netsD`` / ``self.optimizersD`` /
                ``self.fake_imgs``.  The conditioning code is
                ``self.c_code[idx // 3]`` and the real/wrong images are taken
                at position ``(idx // 3) + idx % 3``.
            count: Global step; a summary is written when it is a multiple
                of 100.

        Returns:
            The discriminator loss as a Python float.
        """
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, c_code = self.criterion, self.c_code[idx // 3]
        netD, optD = self.netsD[idx], self.optimizersD[idx]
        img_idx = int((idx // 3) + idx % 3)
        real_imgs = self.real_tgpu[img_idx]
        wrong_imgs = self.wrong_tgpu[img_idx]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        # The label buffers are sized for a full batch; trim to this batch.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        # Forward: inputs are detached so only netD receives gradients.
        real_logits = netD(real_imgs, c_code.detach())
        wrong_logits = netD(wrong_imgs, c_code.detach())
        fake_logits = netD(fake_imgs.detach(), c_code.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Second logit set: note the "wrong" images are scored against
            # real_labels here, unlike the first logit set above.
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        # Backward and parameter update.
        errD.backward()
        optD.step()

        if flag == 0:
            # .item() instead of .data[0]: indexing a 0-dim tensor raises on
            # current PyTorch releases.
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return float(errD)

    def train_Gnet(self, idx, count):
        """Run one optimisation step of generator ``idx``.

        The loss sums the adversarial terms of discriminators
        ``idx * 3 + i`` for ``i`` in ``range(len(self.netsG))``.

        Args:
            idx: Index into ``self.optimizersG`` and ``self.c_code``.
            count: Global step; summaries are written when it is a multiple
                of 100.

        Returns:
            The summed generator loss as a Python float.
        """
        optG = self.optimizersG[idx]
        optG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, c_code = self.criterion, self.c_code[idx]
        real_labels = self.real_labels[:batch_size]
        for i in range(len(self.netsG)):
            d_idx = idx * 3 + i
            outputs = self.netsD[d_idx](self.fake_imgs[d_idx], c_code)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                # NOTE(review): the tag depends on i only, so every idx
                # writes to the same 'G_loss<i>' series.
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)
        errG_total.backward()
        optG.step()
        return float(errG_total)

    def train_Enet(self, count):
        """Run one optimisation step of the encoder ``self.netEn``.

        Args:
            count: Global step; summaries are written when it is a multiple
                of 100.

        Returns:
            ``(theta1_KL, theta2_KL, theta3_KL, Likelihood, Lowerbound,
            loss)`` exactly as produced by ``self.vae`` (note the KL terms
            are returned in the reverse of the order ``self.vae`` yields
            them).
        """
        errEn_total = 0
        flag = count % 100
        params = [
            self.shape_1, self.scale_1, self.shape_2, self.scale_2,
            self.shape_3, self.scale_3, self.Phi[0], self.theta_1,
            self.Phi[1], self.theta_2, self.Phi[2], self.theta_3,
            self.txtbow
        ]
        criterion, optEn, netEn = self.vae, self.optimizerEn, self.netEn
        loss, theta3_KL, theta2_KL, theta1_KL, Likelihood, Lowerbound = \
            criterion(params)
        errEn_total = errEn_total + loss

        netEn.zero_grad()
        errEn_total.backward()
        optEn.step()

        if flag == 0:
            scalars = (
                ('En_loss', loss),
                ('En_lowerbound', Lowerbound),
                ('En_likelihood', Likelihood),
                ('En_kl1', theta1_KL),
                ('En_kl2', theta2_KL),
                ('En_kl3', theta3_KL),
            )
            # Build every summary first, then write them, in this order.
            built = [summary.scalar(tag, float(value.data.item()))
                     for tag, value in scalars]
            for entry in built:
                self.summary_writer.add_summary(entry, count)
        return theta1_KL, theta2_KL, theta3_KL, Likelihood, Lowerbound, loss

    def updatePhi(self, miniBatch, Phi, Theta, MBratio, MBObserved):
        """Update every matrix in ``Phi`` in place from one mini-batch.

        For each layer ``t`` the sampler output is scaled by ``MBratio``,
        blended into the running statistic ``self.NDot[t]`` using
        ``self.ForgetRate[MBObserved]``, and used to take a noisy step of
        size ``self.epsit[MBObserved]`` that is projected back with
        ``PGBN_sampler.ProjSimplexSpecial``.

        Args:
            miniBatch: Array passed to the sampler for layer 0 (cast to
                ``'double'``).
            Phi: List of arrays, one per layer; modified in place.
            Theta: List of arrays, one per layer.
            MBratio: Scale factor applied to the sampler output.
            MBObserved: Index into ``self.ForgetRate`` and ``self.epsit``;
                ``0`` initialises ``self.NDot`` instead of blending.

        Returns:
            The same ``Phi`` list, updated.
        """
        real_min = 1e-6
        Xt = miniBatch
        for t in range(len(Phi)):
            if t == 0:
                self.Xt_to_t1[t], self.WSZS[t] = PGBN_sampler.Multrnd_Matrix(
                    Xt.astype('double'), Phi[t], Theta[t])
            else:
                # Higher layers consume the previous layer's sampler output.
                self.Xt_to_t1[t], self.WSZS[t] = \
                    PGBN_sampler.Crt_Multirnd_Matrix(
                        self.Xt_to_t1[t - 1], Phi[t], Theta[t])
            self.EWSZS[t] = MBratio * self.WSZS[t]

            if MBObserved == 0:
                self.NDot[t] = self.EWSZS[t].sum(0)
            else:
                forget = self.ForgetRate[MBObserved]
                self.NDot[t] = (1 - forget) * self.NDot[t] + \
                    forget * self.EWSZS[t].sum(0)

            step = self.epsit[MBObserved]
            tmp = self.EWSZS[t] + self.eta[t]
            tmp = (1 / (self.NDot[t] + real_min)) * (tmp - tmp.sum(0) * Phi[t])
            tmp1 = (2 / (self.NDot[t] + real_min)) * Phi[t]
            tmp = Phi[t] + step * tmp + np.sqrt(step * tmp1) * \
                np.random.randn(Phi[t].shape[0], Phi[t].shape[1])
            Phi[t] = PGBN_sampler.ProjSimplexSpecial(tmp, Phi[t], 0)
        return Phi

    def train(self):
        """Run the full training loop.

        Per batch: (0) prepare data, (1) encode the last real image scale,
        (2) generate fake images, (3) update the encoder, (4) update
        ``Phi``, (5) update all discriminators, (6) update all generators
        and their parameter averages.  Snapshots and sample images are
        written every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps, and the models
        are saved once more after the last epoch.
        """
        self.netEn, self.netsG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = [copy_G_params(netG) for netG in self.netsG]

        self.optimizerEn, self.optimizersG, self.optimizersD = \
            define_optimizers(self.netEn, self.netsG, self.netsD)

        self.criterion = nn.BCELoss()
        self.vae = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches
        batch_length = self.num_batches

        # Random initial Phi per layer, each column normalised to sum to 1.
        self.Phi = []
        self.eta = []
        K = [256, 128, 64]
        real_min = np.float64(2.2e-308)
        eta = np.ones(3) * 0.1
        for i in range(3):
            self.eta.append(eta[i])
            rows = 1000 if i == 0 else K[i - 1]
            self.Phi.append(0.2 + 0.8 * np.float64(np.random.rand(rows, K[i])))
            self.Phi[i] = self.Phi[i] / np.maximum(real_min,
                                                   self.Phi[i].sum(0))

        self.NDot = [0] * 3
        self.Xt_to_t1 = [0] * 3
        self.WSZS = [0] * 3
        self.EWSZS = [0] * 3

        # One schedule entry per training step, indexed by `count`.
        total_steps = cfg.TRAIN.MAX_EPOCH * int(batch_length)
        self.ForgetRate = np.power(
            (0 + np.linspace(1, total_steps, total_steps)), -0.7)
        epsit = np.power(
            (20 + np.linspace(1, total_steps, total_steps)), -0.7)
        self.epsit = 1 * epsit / epsit[0]
        num_total_samples = batch_length * self.batch_size

        # Phi is always held as a float tensor (the loop below calls
        # .cpu().numpy() on it); it is moved to the GPU only with cfg.CUDA.
        for i in range(len(self.Phi)):
            self.Phi[i] = Variable(torch.from_numpy(self.Phi[i]).float())
            if cfg.CUDA:
                self.Phi[i] = self.Phi[i].cuda()
        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            LL = 0
            KL1 = 0
            KL2 = 0
            KL3 = 0
            LS = 0
            DL = 0
            GL = 0
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu = \
                    self.prepare_data(data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])
                self.theta_1, self.shape_1, self.scale_1, self.theta_2,\
                    self.shape_2, self.scale_2, self.theta_3, self.shape_3,\
                    self.scale_3 = self.netEn(self.flat)
                # Generators consume the thetas from layer 3 down to layer 1,
                # detached so generator losses do not reach the encoder.
                self.txt_embedding = [
                    self.theta_3.detach(),
                    self.theta_2.detach(),
                    self.theta_1.detach(),
                ]

                #######################################################
                # (2) Generate fake images
                ######################################################
                tmp = []
                self.c_code = []
                x_embedding = None
                for it in range(len(self.netsG)):
                    fake_imgs, c_code, x_embedding = \
                        self.netsG[it](self.txt_embedding[it], x_embedding)
                    tmp.append(fake_imgs)
                    self.c_code.append(c_code)
                # Flatten the per-generator image lists into one list.
                self.fake_imgs = [img for imgs in tmp for img in imgs]

                #######################################################
                # (3) Update En network
                ######################################################
                self.KL1, self.KL2, self.KL3, self.LL, self.LB, self.LS = \
                    self.train_Enet(count)
                # Accumulate plain floats: summing the tensors themselves
                # would keep their autograd history alive for the epoch.
                LL += self.LL.item()
                KL1 += self.KL1.item()
                KL2 += self.KL2.item()
                KL3 += self.KL3.item()
                LS += self.LS.item()
                if count % 100 == 0:
                    print(self.LS)
                    print(self.KL1)
                    print(self.KL2)
                    print(self.KL3)

                #######################################################
                # (4) Update Phi
                #######################################################
                input_txt = np.array(np.transpose(self.txtbow.cpu().numpy()),
                                     order='C').astype('double')
                Phi = []
                theta = []
                self.theta = [self.theta_1, self.theta_2, self.theta_3]
                for i in range(len(self.Phi)):
                    Phi.append(
                        np.array(self.Phi[i].cpu().numpy(),
                                 order='C').astype('double'))
                    theta.append(
                        np.array(np.transpose(
                            self.theta[i].detach().cpu().numpy()),
                            order='C').astype('double'))
                phi = self.updatePhi(input_txt, Phi, theta, int(batch_length),
                                     count)
                for i in range(len(phi)):
                    new_phi = torch.tensor(phi[i], dtype=torch.float32)
                    self.Phi[i] = new_phi.cuda() if cfg.CUDA else new_phi

                #######################################################
                # (5) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                DL += errD_total

                #######################################################
                # (6) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = 0
                for i in range(len(self.netsG)):
                    errG = self.train_Gnet(i, count)
                    errG_total += errG
                    # Running average: avg = 0.999 * avg + 0.001 * current.
                    for p, avg_p in zip(self.netsG[i].parameters(),
                                        avg_param_G[i]):
                        avg_p.mul_(0.999).add_(0.001, p.data)
                GL += errG_total

                # for inception score
                if cfg.INCEPTION:
                    pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                    predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total)
                    summary_G = summary.scalar('G_loss', errG_total)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netEn, self.netsG, avg_param_G,
                               self.netsD, epoch, self.model_dir)
                    # Save images generated with the averaged parameters,
                    # then restore the live training parameters.
                    backup_para = []
                    for i in range(len(self.netsG)):
                        backup_para.append(copy_G_params(self.netsG[i]))
                        load_params(self.netsG[i], avg_param_G[i])
                    x_embedding = None
                    self.fake_imgs = []
                    for it in range(len(self.netsG)):
                        fake_imgs, _, x_embedding = self.netsG[it](
                            self.txt_embedding[it], x_embedding)
                        self.fake_imgs.append(fake_imgs[-1])
                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     len(self.netsG), count, self.image_dir)
                    for i in range(len(self.netsG)):
                        load_params(self.netsG[i], backup_para[i])

                    if cfg.INCEPTION:
                        # Compute inception score once enough batches of
                        # predictions have been collected.
                        if len(predictions) > 500:
                            predictions = np.concatenate(predictions, 0)
                            mean, std = compute_inception_score(
                                predictions, 10)
                            m_incep = summary.scalar('Inception_mean', mean)
                            self.summary_writer.add_summary(m_incep, count)
                            mean_nlpp, std_nlpp = \
                                negative_log_posterior_probability(
                                    predictions, 10)
                            m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                            self.summary_writer.add_summary(m_nlpp, count)
                            predictions = []

                count = count + 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL1 = KL1 / num_total_samples
            KL2 = KL2 / num_total_samples
            KL3 = KL3 / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d, Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f (LL: %.6f, KL1: %.6f, KL2: %.6f,'
                'KL3: %.6f, Loss_D: %.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL1, KL2,
                 KL3, DL, GL))

        save_model(self.netEn, self.netsG, avg_param_G, self.netsD, epoch,
                   self.model_dir)
        self.summary_writer.close()
class FineGAN_trainer(object):
    """Trainer for a three-stage generator (background, parent, child).

    ``train`` first runs the normal training loop for ``max_epoch`` epochs
    and then a single pass of "hard negative" training over the data
    loader.  Discriminators are only optimised for stages 0 and 2.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Create the output directories (training only) and cache settings.

        Args:
            output_dir: Root directory.  When ``cfg.TRAIN.FLAG`` is set the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                and a ``FileWriter`` is opened on ``Log``.
            data_loader: Iterable of batches; must support ``len()``.
            imsize: Accepted for interface compatibility; not used here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device indices.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` objects (on the GPU if ``cfg.CUDA``).

        Args:
            data: 5-item batch ``(fimgs, cimgs, c_code, _, warped_bbox)``.
                ``warped_bbox`` is modified in place.

        Returns:
            ``(fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox)``;
            the two ``real_*`` lists hold only the wrapped first element of
            ``fimgs`` / ``cimgs``.
        """
        fimgs, cimgs, c_code, _, warped_bbox = data
        real_vfimgs, real_vcimgs = [], []
        if cfg.CUDA:
            vc_code = Variable(c_code).cuda()
            for i in range(len(warped_bbox)):
                warped_bbox[i] = Variable(warped_bbox[i]).float().cuda()
            real_vfimgs.append(Variable(fimgs[0]).cuda())
            real_vcimgs.append(Variable(cimgs[0]).cuda())
        else:
            vc_code = Variable(c_code)
            # NOTE(review): unlike the CUDA branch, no .float() cast here.
            for i in range(len(warped_bbox)):
                warped_bbox[i] = Variable(warped_bbox[i])
            real_vfimgs.append(Variable(fimgs[0]))
            real_vcimgs.append(Variable(cimgs[0]))
        return fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Only the background (0) and child (2) discriminators are trained;
        any other ``idx`` is a no-op.

        Args:
            idx: Stage index into ``self.netsD`` / ``self.optimizersD`` /
                ``self.fake_imgs``.
            count: Global step; summaries are written when it is a multiple
                of 100.

        Returns:
            The loss tensor for stages 0 and 2, ``None`` otherwise.
        """
        # Discriminator is only trained in background and child stage.
        # (NOT in parent stage)
        if idx != 0 and idx != 2:
            return None

        flag = count % 100
        batch_size = self.real_fimgs[0].size(0)
        criterion, criterion_one = self.criterion, self.criterion_one
        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_fimgs[0] if idx == 0 else self.real_cimgs[0]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        real_logits = netD(real_imgs)

        if idx == 2:
            fake_labels = torch.zeros_like(real_logits[1])
            real_labels = torch.ones_like(real_logits[1])
        else:
            fake_labels = torch.zeros_like(real_logits[1])
            ext, output = real_logits
            weights_real = torch.ones_like(output)
            real_labels = torch.ones_like(output)

            # Zero the weight of every output unit whose index range maps to
            # the bounding box, so the real/fake loss ignores those units.
            for i in range(batch_size):
                x1 = self.warped_bbox[0][i]
                x2 = self.warped_bbox[2][i]
                y1 = self.warped_bbox[1][i]
                y2 = self.warped_bbox[3][i]

                a1 = max(
                    torch.tensor(0).float().cuda(),
                    torch.ceil((x1 - self.recp_field) / self.patch_stride))
                a2 = min(
                    torch.tensor(self.n_out - 1).float().cuda(),
                    torch.floor((self.n_out - 1) -
                                ((126 - self.recp_field) - x2) /
                                self.patch_stride)) + 1
                b1 = max(
                    torch.tensor(0).float().cuda(),
                    torch.ceil((y1 - self.recp_field) / self.patch_stride))
                b2 = min(
                    torch.tensor(self.n_out - 1).float().cuda(),
                    torch.floor((self.n_out - 1) -
                                ((126 - self.recp_field) - y2) /
                                self.patch_stride)) + 1

                if (x1 != x2 and y1 != y2):
                    weights_real[
                        i, :,
                        a1.type(torch.int):a2.type(torch.int),
                        b1.type(torch.int):b2.type(torch.int)] = 0.0

            norm_fact_real = weights_real.sum()
            norm_fact_fake = weights_real.shape[0] * weights_real.shape[
                1] * weights_real.shape[2] * weights_real.shape[3]
            real_logits = ext, output

        fake_logits = netD(fake_imgs.detach())

        if idx == 0:  # Background stage
            # Real/Fake loss for 'real background' (on patch level)
            errD_real_uncond = criterion(real_logits[1], real_labels)
            # Masking output units which correspond to receptive fields
            # which lie within the bounding box
            errD_real_uncond = torch.mul(errD_real_uncond, weights_real)
            errD_real_uncond = errD_real_uncond.mean()

            # Background/foreground classification loss
            errD_real_uncond_classi = criterion(real_logits[0], weights_real)
            errD_real_uncond_classi = errD_real_uncond_classi.mean()

            # Real/Fake loss for 'fake background' (on patch level)
            errD_fake_uncond = criterion(fake_logits[1], fake_labels)
            errD_fake_uncond = errD_fake_uncond.mean()

            # Normalizing the real/fake loss for background after accounting
            # the number of masked members in the output.
            if norm_fact_real > 0:
                errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                (norm_fact_real * 1.0))
            else:
                errD_real = errD_real_uncond

            errD_fake = errD_fake_uncond
            errD = ((errD_real + errD_fake) *
                    cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi
        else:  # Child stage
            # Real/Fake loss for the real image
            errD_real = criterion_one(real_logits[1], real_labels)
            # Real/Fake loss for the fake image
            errD_fake = criterion_one(fake_logits[1], fake_labels)
            errD = errD_real + errD_fake

        errD.backward()
        optD.step()

        if flag == 0:
            # .item() instead of .data[0]: indexing a 0-dim tensor raises on
            # current PyTorch releases.
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
            summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                            errD_real.item())
            self.summary_writer.add_summary(summary_D_real, count)
            summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                            errD_fake.item())
            self.summary_writer.add_summary(summary_D_fake, count)

        return errD

    def train_Gnet(self, count):
        """Run one optimisation step of the generator.

        The loss sums the real/fake terms of stages 0 and 2 and the
        classification ("information") terms of stages 1 and 2.

        Args:
            count: Global step; summaries are written when it is a multiple
                of 100.

        Returns:
            The summed generator loss tensor.
        """
        self.netG.zero_grad()
        for myit in range(len(self.netsD)):
            self.netsD[myit].zero_grad()

        errG_total = 0
        flag = count % 100
        batch_size = self.real_fimgs[0].size(0)
        criterion_one, criterion_class, c_code, p_code = \
            self.criterion_one, self.criterion_class, self.c_code, self.p_code

        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i])

            # real/fake loss for background (0) and child (2) stage
            if i == 0 or i == 2:
                real_labels = torch.ones_like(outputs[1])
                errG = criterion_one(outputs[1], real_labels)
                if i == 0:
                    errG = errG * cfg.TRAIN.BG_LOSS_WT
                    # Background/Foreground classification loss for the fake
                    # background image (on patch level)
                    errG_classi = criterion_one(outputs[0], real_labels)
                    errG = errG + errG_classi
                errG_total = errG_total + errG

            if i == 1:  # Mutual information loss for the parent stage (1)
                pred_p = self.netsD[i](self.fg_mk[i - 1])
                errG_info = criterion_class(
                    pred_p[0], torch.nonzero(p_code.long())[:, 1])
            elif i == 2:  # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i - 1])
                errG_info = criterion_class(
                    pred_c[0], torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            if flag == 0:
                if i > 0:
                    summary_D_class = summary.scalar(
                        'Information_loss_%d' % i, errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)
                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)

        errG_total.backward()
        # self.optimizerG is indexed once per discriminator.
        for myit in range(len(self.netsD)):
            self.optimizerG[myit].step()
        return errG_total

    def _update_discriminators(self, count):
        """Train the background and child discriminators; return the loss sum."""
        errD_total = 0
        for i in range(self.num_Ds):
            if i == 0 or i == 2:  # only at background and child stage
                errD = self.train_Dnet(i, count)
                errD_total += errD
        return errD_total

    def train(self):
        """Run normal training followed by hard-negative training.

        Normal phase: for every batch, generate fake images, update the
        stage 0/2 discriminators, update the generator and its running
        parameter average; snapshot every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        steps.  Hard-negative phase: a single pass over the data loader
        that alternates normal batches with batches built from the sampled
        (noise, class) pairs that received the lowest softmax score, until
        ``cfg.TRAIN.HARDNEG_MAX_ITER`` steps or the loader is exhausted.
        """
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss(reduce=False)
        self.criterion_one = nn.BCELoss()
        self.criterion_class = nn.CrossEntropyLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))
        # NOTE(review): moved to the GPU regardless of cfg.CUDA, as is every
        # tensor allocated in the hard-negative phase below.
        hard_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)).cuda()

        # Receptive field stride given the current discriminator
        # architecture for background stage
        self.patch_stride = float(4)
        # Output size of the discriminator at the background stage;
        # N X N where N = 24
        self.n_out = 24
        # Receptive field of each of the member of N X N
        self.recp_field = 34

        if cfg.CUDA:
            self.criterion.cuda()
            self.criterion_one.cuda()
            self.criterion_class.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        print("Starting normal FineGAN training..")
        count = start_count
        start_epoch = start_count // (self.num_batches)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
                    self.c_code, self.warped_bbox = self.prepare_data(data)

                # Feedforward through Generator. Obtain stagewise fake images
                noise.data.normal_(0, 1)
                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                # Obtain the parent code given the child code
                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update Discriminator networks
                errD_total = self._update_discriminators(count)

                # Update the Generator networks
                errG_total = self.train_Gnet(count)
                # Running average: avg = 0.999 * avg + 0.001 * current.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    backup_para = copy_G_params(self.netG)
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images generated with the averaged parameters.
                    load_params(self.netG, avg_param_G)
                    self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                        self.netG(fixed_noise, self.c_code)
                    save_img_results(self.imgs_tcpu,
                                     (self.fake_imgs + self.fg_imgs +
                                      self.mk_imgs + self.fg_mk),
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    # NOTE(review): backup_para is never loaded back, so
                    # training continues from the averaged parameters after
                    # each snapshot -- confirm this is intended.

            end_t = time.time()
            # float() instead of .data[0]: indexing a 0-dim tensor raises on
            # current PyTorch releases.
            print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs ''' %
                  (epoch, self.max_epoch, self.num_batches,
                   float(errD_total), float(errG_total), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        print(
            "Done with the normal training. Now performing hard negative training.."
        )
        count = 0
        start_t = time.time()
        for step, data in enumerate(self.data_loader, 0):
            self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
                self.c_code, self.warped_bbox = self.prepare_data(data)

            if (count % 2) == 0:  # Train on normal batch of images
                # Feedforward through Generator. Obtain stagewise fake images
                noise.data.normal_(0, 1)
                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update discriminator networks
                errD_total = self._update_discriminators(count)

                # Update the generator network
                errG_total = self.train_Gnet(count)

            else:  # Train on degenerate images
                repeat_times = 10
                all_hard_z = Variable(
                    torch.zeros(self.batch_size * repeat_times, nz)).cuda()
                all_hard_class = Variable(
                    torch.zeros(self.batch_size * repeat_times,
                                cfg.FINE_GRAINED_CATEGORIES)).cuda()
                all_logits = Variable(
                    torch.zeros(self.batch_size * repeat_times, )).cuda()

                # Sample repeat_times batches of (noise, random class) and
                # record the softmax score of each sample's own class.
                for hard_it in range(repeat_times):
                    lo = self.batch_size * hard_it
                    hi = self.batch_size * (hard_it + 1)
                    hard_noise = hard_noise.data.normal_(0, 1)
                    hard_class = Variable(
                        torch.zeros(
                            [self.batch_size,
                             cfg.FINE_GRAINED_CATEGORIES])).cuda()
                    my_rand_id = []

                    for c_it in range(self.batch_size):
                        # One-element list holding the sampled class index.
                        rand_class = random.sample(
                            range(cfg.FINE_GRAINED_CATEGORIES), 1)
                        hard_class[c_it][rand_class] = 1
                        my_rand_id.append(rand_class)

                    all_hard_z[lo:hi] = hard_noise.data
                    all_hard_class[lo:hi] = hard_class.data
                    self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                        self.netG(hard_noise.detach(), hard_class.detach())

                    fake_logits = self.netsD[2](self.fg_mk[1].detach())
                    smax_class = softmax(fake_logits[0], dim=1)

                    for b_it in range(self.batch_size):
                        all_logits[lo + b_it] = \
                            smax_class[b_it][my_rand_id[b_it]]

                # torch.sort is ascending: keep the batch_size lowest scores.
                sorted_val, indices_hard = torch.sort(all_logits)
                noise = all_hard_z[indices_hard[0:self.batch_size]]
                self.c_code = all_hard_class[indices_hard[0:self.batch_size]]

                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update Discriminator networks
                errD_total = self._update_discriminators(count)

                # Update generator network
                errG_total = self.train_Gnet(count)

            for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(0.001, p.data)
            count = count + 1

            if count % cfg.TRAIN.SNAPSHOT_INTERVAL_HARDNEG == 0:
                backup_para = copy_G_params(self.netG)
                # The +500000 offset is applied to the step passed to
                # save_model for hard-negative snapshots.
                save_model(self.netG, avg_param_G, self.netsD,
                           count + 500000, self.model_dir)
                load_params(self.netG, avg_param_G)
                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(fixed_noise, self.c_code)
                save_img_results(self.imgs_tcpu,
                                 (self.fake_imgs + self.fg_imgs +
                                  self.mk_imgs + self.fg_mk),
                                 self.num_Ds, count, self.image_dir,
                                 self.summary_writer)
                # NOTE(review): backup_para is never loaded back here
                # either -- confirm this is intended.

            end_t = time.time()

            if (count % 100) == 0:
                print(
                    '''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs ''' %
                    (count, cfg.TRAIN.HARDNEG_MAX_ITER, self.num_batches,
                     float(errD_total), float(errG_total), end_t - start_t))

            if count == cfg.TRAIN.HARDNEG_MAX_ITER:
                # Hard negative training complete
                break

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()
def run(args):
    """Build the TF1 graph, then train or evaluate according to ``args``.

    Relies on module-level names that are not parameters: ``dim_G``,
    ``dim_D``, ``data_num``, ``log_path``, ``model_pth`` and ``input_data``.

    Args:
        args: Options object.  Reads ``input_dimension``, ``ac_weight``,
            ``num_predicates``, ``L1_weight``, ``training``, ``max_epoch``
            and, in evaluation mode, ``test_setting`` (``'wholedata'`` or
            ``'lowshot'``).

    Side effects:
        Writes scalar summaries to ``log_path`` and, every 10 epochs,
        checkpoints under ``model_pth``.  The summary writer, the session
        and ``input_data`` are closed when the function leaves, including on
        error.  In evaluation mode the process exits with status 0 when no
        checkpoint can be restored.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_flag = tf.placeholder(tf.bool)
        keep_prob = tf.placeholder(tf.float32)
        feature_1 = tf.placeholder(tf.float32, [None, args.input_dimension],
                                   'feature_1')
        feature_2 = tf.placeholder(tf.float32, [None, args.input_dimension],
                                   'feature_2')
        feature_3 = tf.placeholder(tf.float32, [None, args.input_dimension],
                                   'feature_3')
        real_labels = tf.placeholder(tf.int32, [None, ], 'real_labels')

        ### generated predicate features ###
        bottle_z = ST_encoder(dim_G, feature_1, keep_prob, reuse=False,
                              training=train_flag)
        reconstruction = ST_decoder(dim_D, 1000, bottle_z, feature_2,
                                    keep_prob, reuse=False,
                                    training=train_flag)
        errL1 = tf.reduce_mean(
            tf.losses.absolute_difference(
                reconstruction, feature_3,
                reduction=tf.losses.Reduction.NONE))

        use_ac = args.ac_weight > 0
        if use_ac:
            ac_loss = aux_classifier(reconstruction, real_labels,
                                     args.num_predicates, keep_prob,
                                     reuse=False, training=train_flag)
        else:
            # Placeholder tensor so 'ac_loss' can always be fetched via ops.
            ac_loss = tf.zeros(1, dtype=tf.dtypes.float32)

        # Critic scores; the second and third netD calls reuse the variables
        # created by the first.
        errD_fake = netD(256, reconstruction, n_layers=0, reuse=False)
        errD_real = netD(256, feature_3, n_layers=0, reuse=True)

        # cost functions
        errD = tf.reduce_mean(errD_fake) - tf.reduce_mean(errD_real)
        errG = -tf.reduce_mean(errD_fake)
        errG_total = errG + errL1 * args.L1_weight
        if use_ac:
            errG_total = errG_total + args.ac_weight * ac_loss

        # gradient penalty on a random interpolate of real and generated
        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = feature_3 * (1 - epsilon) + epsilon * reconstruction
        d_hat = netD(256, x_hat, n_layers=0, reuse=True)
        gradients = tf.gradients(d_hat, x_hat)[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = 10 * tf.reduce_mean((slopes - 1.0) ** 2)
        errD_total = errD + gradient_penalty

        # Split trainable variables between the two optimisers by name.
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'Discriminator' in var.name]
        g_vars = [var for var in t_vars if 'Generator' in var.name]

        learning_rate = get_learning_rate(data_num, global_step)
        # NOTE(review): both train ops are given global_step, so it advances
        # once for every op that is run.
        G_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate, beta1=0.5,
            beta2=0.9).minimize(errG_total, global_step, var_list=g_vars)
        D_train_op = tf.train.AdamOptimizer(
            learning_rate=learning_rate, beta1=0.5,
            beta2=0.9).minimize(errD_total, global_step, var_list=d_vars)

        ops = {
            'D_train_op': D_train_op,
            'G_train_op': G_train_op,
            'feature_1': feature_1,
            'feature_2': feature_2,
            'feature_3': feature_3,
            'keep_prob': keep_prob,
            'real_labels': real_labels,
            'train_flag': train_flag,
            'errD': errD,
            'errG': errG,
            'errL1': errL1,
            'ac_loss': ac_loss,
            'reconstruction': reconstruction,
        }
        saver = tf.train.Saver(max_to_keep=None)
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        ### make gpu memory grow according to needed ###
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        sess.run(init)
        summary_writer = FileWriter(log_path, graph=tf.get_default_graph())
        tf.add_to_collection('G_train_op', G_train_op)
        tf.add_to_collection('D_train_op', D_train_op)

        try:
            if args.training:
                start_epoch = 0
                # restore previous model if there is one
                ckpt = tf.train.get_checkpoint_state(model_pth)
                if ckpt and ckpt.model_checkpoint_path:
                    print("Restoring previous model...")
                    try:
                        # Checkpoints are named 'checkpoint-<epoch>'.
                        resume_epoch = int(
                            os.path.basename(
                                ckpt.model_checkpoint_path).split('-')[1]) + 1
                        print(resume_epoch)
                        saver.restore(sess, ckpt.model_checkpoint_path)
                    except Exception:
                        # Keep start_epoch at 0: the weights are still the
                        # fresh initialisation, so no epoch may be skipped.
                        print("Could not restore model")
                    else:
                        start_epoch = resume_epoch
                        print("Model restored")

                ########################################### training portion
                for epoch in range(start_epoch, args.max_epoch):
                    start = time.time()
                    train_loss_d, train_loss_g, train_loss_L1Loss, \
                        train_loss_acLoss = train_one_epoch(
                            sess, input_data, ops, args)
                    print('epoch:', epoch, 'D loss:', train_loss_d.avg,
                          'G_loss:', train_loss_g.avg, 'L1:',
                          train_loss_L1Loss.avg, 'AC:', train_loss_acLoss.avg,
                          'time:', time.time() - start)

                    for tag, meter in (('D_loss', train_loss_d),
                                       ('G_loss', train_loss_g),
                                       ('G_L1', train_loss_L1Loss),
                                       ('G_AC', train_loss_acLoss)):
                        summary_writer.add_summary(
                            summary.scalar(tag, meter.avg), epoch)

                    if (epoch + 1) % 10 == 0:
                        print('save model')
                        if not os.path.exists(model_pth):
                            os.makedirs(model_pth)
                        ckpt_prefix = model_pth + 'checkpoint-' + str(epoch)
                        saver.save(sess, ckpt_prefix)
                        saver.export_meta_graph(ckpt_prefix + '.meta')
            else:
                print('evaluation')
                ckpt = tf.train.get_checkpoint_state(model_pth)
                try:
                    epoch = int(
                        os.path.basename(
                            ckpt.model_checkpoint_path).split('-')[1])
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print("Model restored")
                except Exception:
                    # Also reached when no checkpoint exists (ckpt is None).
                    print("Could not restore model")
                    exit(0)

                ### generate whole data###
                if args.test_setting == 'wholedata':
                    print('generate whole data:', epoch)
                    generate_wholedata(sess, input_data, ops, epoch)
                ### generate lowshot vrd data###
                elif args.test_setting == 'lowshot':
                    print('generate lowshot data:', epoch)
                    generate_lowshot(sess, input_data, ops, args, epoch)
        finally:
            # Close the writer so buffered summaries reach disk, then
            # release the session and the input data handle.
            summary_writer.close()
            sess.close()
            input_data.close()
start_ind:end_ind] gt_concat_fea = np.concatenate( (test_feature_all['sub_fea'][start_ind:end_ind], test_feature_all['obj_fea'][start_ind:end_ind]), axis=1) rd_loss_temp, acc_temp, acc_each = vnet.val_predicate_fea_concate( sess, gt_concat_fea, labels) rd_loss_val = rd_loss_val + rd_loss_temp acc_val += sum(acc_each) print( "whole-val: {0} rd_loss: {1}, acc: {2}, best_acc: {3}". format(step, rd_loss_val / N_val, acc_val / N_val, acc_val_all)) val_loss = summary.scalar('val_loss', rd_loss_val / N_val) summary_writer.add_summary(val_loss, step) val_accuracy = summary.scalar('val_acc', acc_val / N_val) summary_writer.add_summary(val_accuracy, step) if (acc_val / N_val) > acc_val_all: save_path = model_path + '/' + args.data + '_vgg_' + format( int(step), '04') saver.save(sess, save_path) saver.export_meta_graph(save_path + '.meta') acc_val_all = acc_val / N_val best_acc_val = step ###evaluation lowshot dataset ### elif args.mode == "lowshot": rd_loss_val_lowshot = 0.0 acc_val_lowshot = 0
class condGANTrainer(object):
    """Training and evaluation driver for a multi-stage conditional GAN.

    The trainer owns a generator (``netG``), one discriminator per image
    stage (``netsD``) and, when ``cfg.TRAIN.COEFF.MD_LOSS > 0``, an extra
    pairwise network (``netMD``) that classifies pairs of image features.
    All networks are obtained from ``load_network`` inside :meth:`train`.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Record output paths, select the GPU and cache loop settings.

        Args:
            output_dir: Root directory; ``Model``, ``Image``, ``Log`` and
                ``TestImage`` sub-directories are derived from it.
            data_loader: Iterable of batches; ``len(data_loader)`` is used
                as the number of batches per epoch.
            imsize: Accepted for interface compatibility; not used here.
        """
        self.model_dir = os.path.join(output_dir, 'Model')
        self.image_dir = os.path.join(output_dir, 'Image')
        self.log_dir = os.path.join(output_dir, 'Log')
        self.testImage_dir = os.path.join(output_dir, 'TestImage')
        if cfg.TRAIN.FLAG:
            # Directories and the summary writer are only created in
            # training mode.
            for directory in (self.model_dir, self.image_dir, self.log_dir,
                              self.testImage_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        self.gpus = [int(ix) for ix in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` and move it to the GPU if enabled.

        Args:
            data: 6-tuple ``(imgs, w_imgs, s_imgs, t_embedding, class_id, _)``.
                ``imgs`` and ``w_imgs`` are indexed once per discriminator
                stage (``self.num_Ds`` entries are read).

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, same_vimg, vembedding,
            class_id)``; ``imgs`` and ``class_id`` are passed through as-is.
        """
        imgs, w_imgs, s_imgs, t_embedding, class_id, _ = data

        def to_var(tensor):
            var = Variable(tensor)
            return var.cuda() if cfg.CUDA else var

        vembedding = to_var(t_embedding)
        same_vimg = to_var(s_imgs)
        real_vimgs = [to_var(imgs[i]) for i in range(self.num_Ds)]
        wrong_vimgs = [to_var(w_imgs[i]) for i in range(self.num_Ds)]
        return imgs, real_vimgs, wrong_vimgs, same_vimg, vembedding, class_id

    def _embed_detached(self, images):
        """Run images through ``image_cnn`` then ``image_encoder``, detaching
        the input of each stage so no gradient flows back through them."""
        features = self.image_cnn(images.detach())
        return self.image_encoder(features.detach())

    def train_MDnet(self, count):
        """Run one optimisation step of the pairwise network ``netMD``.

        Uses the last-stage real, wrong and fake images plus the "similar"
        images of the current batch. Requires ``netMD``, ``optimizerMD``,
        ``image_cnn`` and ``image_encoder`` to have been set up by
        :meth:`train`.

        Args:
            count: Global step; the loss is logged when ``count % 100 == 0``.

        Returns:
            The total ``netMD`` loss tensor.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        real_imgs = self.real_imgs[-1]
        wrong_imgs = self.wrong_imgs[-1]
        fake_imgs = self.fake_imgs[-1]
        similar_imgs = self.similar_imgs

        netMD = self.netMD
        optMD = self.optimizerMD
        netMD.zero_grad()

        same_labels = self.same_labels[:batch_size]
        real_labels = self.real_labels[:batch_size]
        # NOTE(review): fake pairs are trained with the *real* label here,
        # not ``self.fake_labels``. Kept as in the original; confirm that
        # this is intentional.
        fake_labels = self.real_labels[:batch_size]
        wrong_labels = self.wrong_labels[:batch_size]

        real_feat = self._embed_detached(real_imgs)
        similar_feat = self._embed_detached(similar_imgs)
        fake_feat = self._embed_detached(fake_imgs)
        wrong_feat = self._embed_detached(wrong_imgs)

        same_logits = netMD(real_feat, real_feat)
        real_logits2 = netMD(real_feat, similar_feat)
        fake_logits2 = netMD(real_feat, fake_feat.detach())
        wrong_logits2 = netMD(real_feat, wrong_feat)

        cross_entropy = nn.CrossEntropyLoss()
        weight = cfg.TRAIN.COEFF.MD_LOSS
        errMD_si = weight * cross_entropy(real_logits2, real_labels.long())
        errMD_sa = weight * cross_entropy(same_logits, same_labels.long())
        errMD_fa = weight * cross_entropy(fake_logits2, fake_labels.long())
        errMD_wr = weight * cross_entropy(wrong_logits2, wrong_labels.long())

        # The (real, real) "same" term is only used for these two datasets.
        if cfg.DATASET_NAME == 'birds' or cfg.DATASET_NAME == 'flowers':
            errMD = errMD_si + errMD_sa + errMD_fa + errMD_wr
        else:
            errMD = errMD_si + errMD_fa + errMD_wr

        # backward
        errMD.backward()
        optMD.step()
        # log
        if flag == 0:
            summary_MD = summary.scalar('MD_loss', errMD.item())
            self.summary_writer.add_summary(summary_MD, count)
        return errMD

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Args:
            idx: Index into ``self.netsD`` / ``self.optimizersD`` and into
                the per-stage image lists.
            count: Global step; the loss is logged when ``count % 100 == 0``.

        Returns:
            The discriminator loss tensor.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # The generator outputs are detached so that this step only
        # updates the discriminator.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            uncond_weight = cfg.TRAIN.COEFF.UNCOND_LOSS
            errD_real_uncond = uncond_weight * criterion(
                real_logits[1], real_labels)
            # Second head: "wrong" images are labelled as real.
            errD_wrong_uncond = uncond_weight * criterion(
                wrong_logits[1], real_labels)
            errD_fake_uncond = uncond_weight * criterion(
                fake_logits[1], fake_labels)

            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one optimisation step of the generator.

        Sums the adversarial loss of every stage with the optional
        consistency / MD / colour terms enabled in ``cfg.TRAIN.COEFF`` and
        the KL term, then back-propagates and steps ``optimizerG``.

        Args:
            count: Global step; losses are logged when ``count % 100 == 0``.

        Returns:
            ``(kl_loss, errG_total)`` as tensors.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        coeff = cfg.TRAIN.COEFF

        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and coeff.UNCOND_LOSS > 0:
                errG_patch = coeff.UNCOND_LOSS * criterion(
                    outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG

            # Image features are only computed when a loss consumes them.
            if (coeff.CONTENTCONSIST_LOSS > 0
                    or coeff.SEMANTICONSIST_LOSS > 0
                    or coeff.MD_LOSS > 0):
                fake_feat = self.image_cnn(self.fake_imgs[i])
                fake_feat = self.image_encoder(fake_feat)
            if coeff.CONTENTCONSIST_LOSS > 0 or coeff.MD_LOSS > 0:
                real_feat = self.image_cnn(self.real_imgs[i])
                real_feat = self.image_encoder(real_feat)

            if coeff.CONTENTCONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(real_feat, fake_feat,
                                          self.class_ids)
                errG_CC = loss1 + loss2
                errG_total = errG_total + errG_CC * coeff.CONTENTCONSIST_LOSS

            if coeff.SEMANTICONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(self.txt_embedding, fake_feat,
                                          self.class_ids)
                errG_SC = loss1 + loss2
                errG_total = errG_total + errG_SC * coeff.SEMANTICONSIST_LOSS

            # The MD term is applied to the last stage only.
            if coeff.MD_LOSS > 0 and i == (self.num_Ds - 1):
                outputs2 = self.netMD(real_feat, fake_feat)
                errMG = nn.CrossEntropyLoss()(outputs2, real_labels.long())
                errG_total = errG_total + errMG * coeff.MD_LOSS

            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        if coeff.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-2].detach())
                like_mu2 = coeff.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = coeff.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, global_step=count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov,
                                                    global_step=count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-3].detach())
                like_mu1 = coeff.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = coeff.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * coeff.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Run the full training loop.

        Loads the networks, builds optimizers/labels/noise, then for every
        epoch alternates discriminator, (optional) MD-network and generator
        updates over ``self.data_loader``, keeps a running average of the
        generator parameters, and periodically saves checkpoints and sample
        images. The summary writer is closed at the end.
        """
        coeff = cfg.TRAIN.COEFF

        # ``load_network`` returns an extra ``netMD`` entry when MD loss is on.
        if coeff.MD_LOSS > 0:
            (self.netG, self.netsD, self.netMD, self.num_Ds,
             self.inception_model, start_count) = load_network(
                 self.gpus, self.num_batches)
        else:
            (self.netG, self.netsD, self.num_Ds,
             self.inception_model, start_count) = load_network(
                 self.gpus, self.num_batches)
        avg_param_G = copy_G_params(self.netG)

        if (coeff.CONTENTCONSIST_LOSS > 0 or coeff.SEMANTICONSIST_LOSS > 0
                or coeff.MD_LOSS > 0):
            self.image_cnn = Inception_v3()
            self.image_encoder = LINEAR_ENCODER()

            if not isinstance(self.image_cnn, torch.nn.DataParallel):
                self.image_cnn = nn.DataParallel(self.image_cnn)
            if not isinstance(self.image_encoder, torch.nn.DataParallel):
                self.image_encoder = nn.DataParallel(self.image_encoder)

            if cfg.DATASET_NAME == 'birds':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/birds/models/best_image_model.pth"))
            if cfg.DATASET_NAME == 'flowers':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/flowers/models/best_image_model.pth"
                    ))

            if cfg.CUDA:
                self.image_cnn = self.image_cnn.cuda()
                self.image_encoder = self.image_encoder.cuda()
            # The feature extractors are frozen: eval mode, no gradients
            # on their parameters.
            self.image_cnn.eval()
            self.image_encoder.eval()
            for p in self.image_cnn.parameters():
                p.requires_grad = False
            for p in self.image_encoder.parameters():
                p.requires_grad = False

        self.optimizerG, self.optimizersD = define_optimizers(
            self.netG, self.netsD)
        if coeff.MD_LOSS > 0:
            self.optimizerMD = optim.Adam(self.netMD.parameters(),
                                          lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                          betas=(0.5, 0.999))

        self.criterion = nn.BCELoss()

        # Label tensors; the values 0/1/2 below are the class indices fed
        # to the cross-entropy loss in train_MDnet.
        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.same_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.wrong_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(2))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = Variable(
            torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.same_labels = self.same_labels.cuda()
            self.wrong_labels = self.wrong_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for data in self.data_loader:
                #######################################################
                # (0) Prepare training data
                ######################################################
                (self.imgs_tcpu, self.real_imgs, self.wrong_imgs,
                 self.similar_imgs, self.txt_embedding,
                 self.class_ids) = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = self.netG(
                    noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                # Update the MD network. netMD / optimizerMD only exist
                # when the MD loss is enabled, so the step must be skipped
                # otherwise (it used to raise AttributeError).
                if coeff.MD_LOSS > 0:
                    errMD = self.train_MDnet(count)
                    errD_total += errMD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator parameters.
                # ``add_(p * 0.001)`` avoids the deprecated
                # ``add_(alpha, tensor)`` overload.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data * 0.001)

                # # for inception score
                # pred = self.inception_model(self.fake_imgs[-1].detach())
                # predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

            # NOTE(review): the two blocks below are placed once per epoch
            # because their conditions depend only on ``epoch`` -- confirm.
            if epoch % cfg.TRAIN.SAVE_EPOCH == 0:
                if coeff.MD_LOSS > 0:
                    DIS_NET = [self.netsD, self.netMD]
                else:
                    DIS_NET = self.netsD
                save_model(self.netG, avg_param_G, DIS_NET, epoch,
                           self.model_dir)

            if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0:
                # Save images generated with the averaged parameters, then
                # restore the live training parameters.
                backup_para = copy_G_params(self.netG)
                load_params(self.netG, avg_param_G)

                self.fake_imgs, _, _ = self.netG(fixed_noise,
                                                 self.txt_embedding)
                save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                 epoch, self.image_dir, self.summary_writer)

                load_params(self.netG, backup_para)

            # Inception-score reporting is disabled; kept for reference:
            # if len(predictions) > 500:
            #     predictions = np.concatenate(predictions, 0)
            #     mean, std = compute_inception_score(predictions, 10)
            #     # print('mean:', mean, 'std', std)
            #     m_incep = summary.scalar('Inception_mean', mean)
            #     self.summary_writer.add_summary(m_incep, count)
            #     mean_nlpp, std_nlpp = \
            #         negative_log_posterior_probability(predictions, 10)
            #     m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
            #     self.summary_writer.add_summary(m_nlpp, count)
            #     predictions = []

            end_t = time.time()
            print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs '''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(), kl_loss.item(),
                     end_t - start_t))

        # Final checkpoint after the last epoch.
        if coeff.MD_LOSS > 0:
            DIS_NET = [self.netsD, self.netMD]
        else:
            DIS_NET = self.netsD
        save_model(self.netG, avg_param_G, DIS_NET, epoch, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir, split_dir,
                         imsize):
        """Save, per sample, one grid holding the image of every sentence.

        Args:
            images_list: One image batch per sentence; every batch is
                indexed by sample.
            filenames: Per-sample relative file names (may contain '/').
            save_dir: Root output directory.
            split_dir: Sub-directory name placed under ``super/``.
            imsize: Side length used to reshape each image to
                ``(1, 3, imsize, imsize)`` and embedded in the file name.
        """
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' % (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = [
                images_list[j][i].view(1, 3, imsize, imsize)
                for j in range(num_sentences)
            ]
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize):
        """Save every image of a batch as an individual PNG file.

        Args:
            images: Batch of images with values in ``[-1, 1]``.
            filenames: Per-sample relative file names (may contain '/').
            save_dir: Root output directory.
            split_dir: Accepted for interface compatibility; not used.
            sentenceID: Sentence index embedded in the file name.
            imsize: Image size embedded in the file name.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/%s' % (save_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    @staticmethod
    def _parse_epoch(name):
        """Return the integer between the last '_' and the last '.' of
        ``name`` (e.g. the epoch embedded in a checkpoint file name)."""
        istart = name.rfind('_') + 1
        iend = name.rfind('.')
        return int(name[istart:iend])

    def evaluate(self, split_dir):
        """Generate images with every saved generator in ``self.model_dir``.

        Only files whose name contains ``netG`` and whose parsed epoch lies
        in ``[100, 600]`` are used. For each one, images are generated for
        every batch of ``self.data_loader`` and written below
        ``self.testImage_dir/epoch<N>``.

        Args:
            split_dir: Dataset split name; ``'test'`` is mapped to
                ``'valid'``.
        """
        NET_G_root = self.model_dir
        G_NETS = []
        for net in os.listdir(NET_G_root):
            if net.find('netG') != -1:
                epoch = self._parse_epoch(net)
                if epoch >= 100 and epoch <= 600:
                    G_NETS.append(net)

        if split_dir == 'test':
            split_dir = 'valid'

        for NET_G in G_NETS:
            NET_G_path = os.path.join(NET_G_root, NET_G)
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # Load onto the CPU first; the model is moved to the GPU below.
            state_dict = torch.load(NET_G_path,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', NET_G_path)

            # the path to save generated images
            epoch = self._parse_epoch(NET_G_path)
            save_dir = '%s/epoch%d' % (self.testImage_dir, epoch)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for data in self.data_loader:
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)

                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                # The last batch may be smaller than self.batch_size.
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_img_list, filenames, save_dir,
                                          split_dir, 256)
class FineGAN_trainer(object):
    """Progressive (depth-by-depth) trainer for a FineGAN-style model.

    The generator produces background / parent / child stage images; the
    discriminators in ``self.netsD`` are indexed 0 (background, global),
    1 (parent), 2 (child) and 3 (background, patch level), as used by
    :meth:`train_Dnet` and :meth:`train_Gnet`.
    """

    def __init__(self, output_dir):
        """Create output directories (training mode only) and select the GPU.

        Args:
            output_dir: Root directory; ``Model``, ``Image`` and ``Log``
                sub-directories are created under it when
                ``cfg.TRAIN.FLAG`` is set.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        self.gpus = [int(ix) for ix in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
        # Indices of the training subset; drawn once in get_dataloader and
        # reused for every depth.
        self.subdataset_idx = None

    def prepare_data(self, data):
        """Wrap one batch in ``Variable`` and move it to the GPU if enabled.

        Args:
            data: 6-tuple ``(fimgs, cimgs, c_code, _, masks, aux_masks)``.

        Returns:
            ``(fimgs, real_vfimgs, real_vcimgs, vc_code, masks, aux_masks)``.
        """
        fimgs, cimgs, c_code, _, masks, aux_masks = data
        if cfg.CUDA:
            vc_code = Variable(c_code).cuda()
            masks = Variable(masks).cuda()
            aux_masks = Variable(aux_masks).cuda()
            real_vfimgs = Variable(fimgs).cuda()
            real_vcimgs = Variable(cimgs).cuda()
        else:
            vc_code = Variable(c_code)
            masks = masks.detach()
            aux_masks = aux_masks.detach()
            real_vfimgs = Variable(fimgs)
            real_vcimgs = Variable(cimgs)
        return fimgs, real_vfimgs, real_vcimgs, vc_code, masks, aux_masks

    def train_Dnet(self, idx, count):
        """Run one discriminator step for the background (0) or child (2) stage.

        For ``idx == 0`` both the global background discriminator
        (``netsD[0]``) and the patch-level one (``netsD[3]``) contribute to
        the loss; only ``optimizersD[0]`` is stepped.

        Args:
            idx: Stage index; must be 0 or 2 (the parent stage has no
                discriminator step).
            count: Global step; losses are logged when ``count % 100 == 0``.

        Returns:
            The discriminator loss tensor.

        Raises:
            ValueError: If ``idx`` is neither 0 nor 2.
        """
        # Fail early with a clear message; any other index used to surface
        # later as an UnboundLocalError on ``errD``.
        if idx not in (0, 2):
            raise ValueError(
                'train_Dnet supports idx 0 (background) or 2 (child), '
                'got %r' % (idx,))

        flag = count % 100
        criterion, criterion_one = self.criterion, self.criterion_one

        if idx == 0:
            real_imgs = self.real_fimgs
            fake_imgs = self.fake_imgs[0]
            optD = self.optimizersD[0]

            # --- global real/fake loss on the background image ---
            netD = self.netsD[0]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha, self.masks.detach())
            fake_logits = netD(fake_imgs.detach(), self.alpha, self.aux_masks)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])
            # Real/Fake loss for the real image
            errD_real = criterion_one(real_logits[1], real_labels)
            # Real/Fake loss for the fake image
            errD_fake = criterion_one(fake_logits[1], fake_labels)
            errD0 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_GLB

            # --- patch-level real/fake loss on the background image ---
            netD = self.netsD[3]
            netD.zero_grad()
            # Inverted mask: 1 where ``self.masks`` is 0, else 0.
            _fg = self.masks == 0
            rev_masks = torch.zeros_like(self.masks)
            rev_masks.masked_fill_(_fg, 1.0)

            real_logits = netD(real_imgs, self.alpha, rev_masks)
            fake_labels = torch.zeros_like(real_logits[1])
            ext, output, fnl_masks = real_logits
            weights_real = torch.ones_like(output)
            real_labels = torch.ones_like(output)

            # Zero the weight of every output unit whose returned mask
            # value is non-zero.
            invalid_patch = fnl_masks != 0.0
            weights_real.masked_fill_(invalid_patch, 0.0)

            norm_fact_real = weights_real.sum()
            norm_fact_fake = (weights_real.shape[0] * weights_real.shape[1] *
                              weights_real.shape[2] * weights_real.shape[3])
            real_logits = ext, output
            fake_logits = netD(fake_imgs.detach(), self.alpha)

            # Real/Fake loss for 'real background' (on patch level)
            errD_real_uncond = criterion(real_logits[1], real_labels)
            # Mask out the units zeroed in ``weights_real``.
            errD_real_uncond = torch.mul(errD_real_uncond, weights_real)
            errD_real_uncond = errD_real_uncond.mean()

            # Real/Fake loss for 'fake background' (on patch level)
            errD_fake_uncond = criterion(fake_logits[1], fake_labels)
            errD_fake_uncond = errD_fake_uncond.mean()

            if norm_fact_real > 0:
                # Rescale so the mean is taken over unmasked units only.
                errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                (norm_fact_real * 1.0))
            else:
                errD_real = errD_real_uncond

            errD_fake = errD_fake_uncond
            errD1 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_LCL

            # Background/foreground classification loss
            errD_real_uncond_classi = criterion(real_logits[0], weights_real)
            errD_real_uncond_classi = errD_real_uncond_classi.mean()
            errD_classi = errD_real_uncond_classi * cfg.TRAIN.BG_CLASSI_WT

            errD = errD0 + errD1 + errD_classi
        else:
            # idx == 2. The discriminator is only trained in the background
            # and child stages (NOT in the parent stage).
            netD, optD = self.netsD[2], self.optimizersD[2]
            real_imgs = self.real_cimgs
            fake_imgs = self.fake_imgs[2]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha)
            fake_logits = netD(fake_imgs.detach(), self.alpha)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])
            # Real/Fake loss for the real image
            errD_real = criterion_one(real_logits[1], real_labels)
            # Real/Fake loss for the fake image
            errD_fake = criterion_one(fake_logits[1], fake_labels)
            errD = errD_real + errD_fake

        errD.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
            summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                            errD_real.item())
            self.summary_writer.add_summary(summary_D_real, count)
            summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                            errD_fake.item())
            self.summary_writer.add_summary(summary_D_fake, count)

        return errD

    def train_Gnet(self, count):
        """Run one generator step.

        Combines the real/fake losses of the background (0) and child (2)
        stages with the classification ("information") losses of the parent
        (1) and child (2) stages, back-propagates, and steps the first
        three optimizers in ``self.optimizerG``.

        Args:
            count: Global step; losses are logged when ``count % 100 == 0``.

        Returns:
            The total generator loss tensor.
        """
        self.netG.zero_grad()
        for myit in range(4):
            self.netsD[myit].zero_grad()

        errG_total = 0
        flag = count % 100
        criterion_one = self.criterion_one
        criterion_class = self.criterion_class
        c_code, p_code = self.c_code, self.p_code

        for i in range(3):
            # real/fake loss for background (0) and child (2) stage
            if i == 0:
                outputs = self.netsD[0](self.fake_imgs[0], self.alpha,
                                        self.aux_masks)
                real_labels = torch.ones_like(outputs[1])
                errG0 = criterion_one(outputs[1], real_labels)
                errG0 = errG0 * cfg.TRAIN.BG_LOSS_WT_GLB

                outputs = self.netsD[3](self.fake_imgs[0], self.alpha)
                real_labels = torch.ones_like(outputs[1])
                errG1 = criterion_one(outputs[1], real_labels)
                errG1 = errG1 * cfg.TRAIN.BG_LOSS_WT_LCL
                # Background/Foreground classification loss for the fake
                # background image (on patch level)
                errG_classi = criterion_one(outputs[0], real_labels)
                errG_classi = errG_classi * cfg.TRAIN.BG_CLASSI_WT

                errG = errG0 + errG1 + errG_classi
                errG_total = errG_total + errG
            elif i == 2:
                outputs = self.netsD[2](self.fake_imgs[2], self.alpha)
                real_labels = torch.ones_like(outputs[1])
                errG = criterion_one(outputs[1], real_labels)
                errG_total = errG_total + errG

            if i == 1:
                # Mutual information loss for the parent stage (1).
                # ``nonzero(...)[:, 1]`` reads the column index of each
                # non-zero entry (presumably a one-hot code -- confirm).
                pred_p = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(
                    pred_p[0], torch.nonzero(p_code.long())[:, 1])
            elif i == 2:
                # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(
                    pred_c[0], torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            if flag == 0:
                if i > 0:
                    summary_D_class = summary.scalar(
                        'Information_loss_%d' % i, errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)
                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)

        errG_total.backward()
        for myit in range(3):
            self.optimizerG[myit].step()
        return errG_total

    def get_dataloader(self, cur_depth):
        """Build the shuffled training ``DataLoader`` for one depth.

        The image size is ``32 * 2 ** (cur_depth + 1)``. When
        ``cfg.TRAIN.DATASET_SIZE != -1`` a random subset of that size is
        drawn on the first call and reused afterwards.

        Args:
            cur_depth: Current progressive-growing depth.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``drop_last=True``.
        """
        bshuffle = True
        imsize = 32 * (2**(cur_depth + 1))
        image_transform = transforms.Compose([
            transforms.Resize(int(imsize * 76 / 64)),
            transforms.RandomCrop(imsize),
            transforms.RandomHorizontalFlip()
        ])

        dataset = Dataset(cfg.DATA_DIR,
                          cur_depth=cur_depth,
                          transform=image_transform)
        if cfg.TRAIN.DATASET_SIZE != -1:
            if self.subdataset_idx is None:
                self.subdataset_idx = random.sample(range(0, len(dataset)),
                                                    cfg.TRAIN.DATASET_SIZE)
            dataset = torch.utils.data.Subset(dataset, self.subdataset_idx)
        assert dataset
        print('training dataset size: ', len(dataset))

        num_gpu = len(cfg.GPU_ID.split(','))
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batchsize_per_depth[cur_depth] * num_gpu,
            drop_last=True,
            shuffle=bshuffle,
            num_workers=int(cfg.WORKERS))
        return dataloader

    def train(self):
        """Run progressive training over depths ``start_depth..end_depth``.

        For each depth: builds the data loader and noise tensors, trains
        for ``blend + stable`` epochs while ramping ``self.alpha`` from
        ``1 / blend_epochs`` up to 1, keeps a running average of the
        generator parameters, periodically saves checkpoints and sample
        images, and finally grows the networks via :meth:`update_network`.
        """
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        # Stays True until at least one training step has run, so a model
        # that was only loaded is not saved again.
        newly_loaded = True
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = define_optimizers(
            self.netG, self.netsD)

        # Element-wise BCE (needed for the masked patch loss in train_Dnet);
        # ``reduction='none'`` replaces the deprecated ``reduce=False``.
        self.criterion = nn.BCELoss(reduction='none')
        self.criterion_one = nn.BCELoss()
        self.criterion_class = nn.CrossEntropyLoss()

        nz = cfg.GAN.Z_DIM

        if cfg.CUDA:
            self.criterion.cuda()
            self.criterion_one.cuda()
            self.criterion_class.cuda()

        print("Starting normal FineGAN training..")
        count = start_count

        for cur_depth in range(start_depth, end_depth + 1):
            max_epoch = (blend_epochs_per_depth[cur_depth] +
                         stable_epochs_per_depth[cur_depth])
            dataloader = self.get_dataloader(cur_depth)
            num_batches = len(dataloader)
            depth_ep_ctr = 0  # depth epoch counter

            batch_size = batchsize_per_depth[cur_depth] * self.num_gpus
            noise = Variable(torch.FloatTensor(batch_size, nz))
            fixed_noise = Variable(
                torch.FloatTensor(batch_size, nz).normal_(0, 1))
            if cfg.CUDA:
                noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

            # Only the first depth resumes mid-way; later depths start at 0.
            start_epoch = start_count // (num_batches)
            start_count = 0

            for epoch in range(start_epoch, max_epoch):
                depth_ep_ctr += 1

                # Blend factor: ramps up during the blend epochs, then 1.
                if depth_ep_ctr < blend_epochs_per_depth[cur_depth]:
                    self.alpha = (depth_ep_ctr /
                                  blend_epochs_per_depth[cur_depth])
                else:
                    self.alpha = 1

                start_t = time.time()

                for data in dataloader:
                    count += 1
                    (_, self.real_fimgs, self.real_cimgs, self.c_code,
                     self.masks, self.aux_masks) = self.prepare_data(data)

                    # Feedforward through Generator. Obtain stagewise fake
                    # images; keep only the slices of the current depth.
                    noise.data.normal_(0, 1)
                    fake_imgs, fg_imgs, mk_imgs, fg_mk = self.netG(
                        noise, self.c_code, self.alpha)
                    self.fake_imgs = fake_imgs[cur_depth * 3:cur_depth * 3 + 3]
                    self.fg_imgs = fg_imgs[cur_depth * 2:cur_depth * 2 + 2]
                    self.mk_imgs = mk_imgs[cur_depth * 2:cur_depth * 2 + 2]
                    self.fg_mk = fg_mk[cur_depth * 2:cur_depth * 2 + 2]

                    # Obtain the parent code given the child code
                    self.p_code = child_to_parent(self.c_code,
                                                  cfg.FINE_GRAINED_CATEGORIES,
                                                  cfg.SUPER_CATEGORIES)

                    # Update Discriminator networks: only at the background
                    # (0) and child (2) stages.
                    errD_total = 0
                    for i in (0, 2):
                        errD = self.train_Dnet(i, count)
                        errD_total += errD

                    # Update the Generator networks
                    errG_total = self.train_Gnet(count)
                    # Exponential moving average of the generator
                    # parameters. ``add_(p * 0.001)`` avoids the deprecated
                    # ``add_(alpha, tensor)`` overload.
                    for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                        avg_p.mul_(0.999).add_(p.data * 0.001)

                    newly_loaded = False
                    if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                        backup_para = copy_G_params(self.netG)
                        if count % cfg.TRAIN.SAVEMODEL_INTERVAL == 0:
                            save_model(self.netG, avg_param_G, self.netsD,
                                       count, self.model_dir, cur_depth)
                        # Save images generated with the averaged
                        # parameters, then restore the live ones.
                        load_params(self.netG, avg_param_G)

                        fake_imgs, fg_imgs, mk_imgs, fg_mk = self.netG(
                            fixed_noise, self.c_code, self.alpha)
                        save_img_results(
                            (fake_imgs[cur_depth * 3:cur_depth * 3 + 3] +
                             fg_imgs[cur_depth * 2:cur_depth * 2 + 2] +
                             mk_imgs[cur_depth * 2:cur_depth * 2 + 2] +
                             fg_mk[cur_depth * 2:cur_depth * 2 + 2]),
                            count, self.image_dir, self.summary_writer,
                            cur_depth)

                        load_params(self.netG, backup_para)

                end_t = time.time()
                print('''[%d/%d][%d]Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                      % (epoch, max_epoch, num_batches, errD_total.item(),
                         errG_total.item(), end_t - start_t))

            if not newly_loaded:
                save_model(self.netG, avg_param_G, self.netsD, count,
                           self.model_dir, cur_depth)

            # NOTE(review): the networks are grown after every depth,
            # regardless of ``newly_loaded`` -- confirm this placement.
            self.update_network()
            avg_param_G = copy_G_params(self.netG)

    def update_network(self):
        """Grow all networks by one depth and rebuild every optimizer.

        Calls ``inc_depth()`` on the wrapped module of the generator and of
        each discriminator (the existing wrappers are reused, not
        re-created), moves them to the GPU if enabled, and creates fresh
        Adam optimizers so they track the parameters present after growth.
        """
        self.netG.module.inc_depth()
        print(self.netG)

        for netD in self.netsD:
            netD.module.inc_depth()
            print(netD)

        if cfg.CUDA:
            self.netG.cuda()
            for netD in self.netsD:
                netD.cuda()

        self.optimizersD = [
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR,
                       betas=(0.5, 0.999)) for netD in self.netsD
        ]

        # optimizerG[0]: the generator itself.
        # optimizerG[1]: all of netsD[1], at the generator learning rate.
        # optimizerG[2]: only the jointConv / logits parts of netsD[2].
        self.optimizerG = []
        self.optimizerG.append(
            optim.Adam(self.netG.parameters(),
                       lr=cfg.TRAIN.GENERATOR_LR,
                       betas=(0.5, 0.999)))

        opt = optim.Adam(self.netsD[1].parameters(),
                         lr=cfg.TRAIN.GENERATOR_LR,
                         betas=(0.5, 0.999))
        self.optimizerG.append(opt)

        child_head = self.netsD[2].module.down_net[0]
        opt = optim.Adam([{
            'params': child_head.jointConv.parameters()
        }, {
            'params': child_head.logits.parameters()
        }],
                         lr=cfg.TRAIN.GENERATOR_LR,
                         betas=(0.5, 0.999))
        self.optimizerG.append(opt)
class WHAI_GAN_Trainer(object):
    """Trainer that jointly updates an image encoder, a generator, a set of
    discriminators and a topic matrix ``Phi1``.

    All networks, optimizers and the loss ``myLoss`` come from helpers defined
    elsewhere in the project (``load_network``, ``define_optimizers``, ...);
    this class only orchestrates the training loop, logging and snapshots.
    """

    def __init__(self, output_dir, data_loader):
        """Create the output folders, the summary writer and cache settings.

        Args:
            output_dir: root directory; ``Model``, ``Image`` and ``Log``
                sub-directories are created below it.
            data_loader: sized iterable of batches; its length is used as the
                number of batches per epoch.
        """
        self.model_dir = os.path.join(output_dir, 'Model')
        self.image_dir = os.path.join(output_dir, 'Image')
        self.log_dir = os.path.join(output_dir, 'Log')
        mkdir_p(self.model_dir)
        mkdir_p(self.image_dir)
        mkdir_p(self.log_dir)
        self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma separated list such as "0,1".
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is per GPU.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one batch in ``Variable``s (moved to the GPU if ``cfg.CUDA``).

        Args:
            data: 4-tuple ``(imgs, texts, w_imgs, _)``; ``imgs`` and
                ``w_imgs`` are indexed 0..2 (presumably one entry per image
                resolution -- TODO confirm against the dataset class).

        Returns:
            ``(imgs, vtxts, real_vimgs, wrong_vimgs)`` where ``imgs`` is the
            untouched input and the other three are the wrapped versions.
        """
        real_vimgs, wrong_vimgs = [], []
        imgs, texts, w_imgs, _ = data
        if cfg.CUDA:
            vtxts = Variable(texts).cuda()
        else:
            vtxts = Variable(texts)
        for i in range(3):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, vtxts, real_vimgs, wrong_vimgs

    def train_Dnet(self, idx, count):
        """Run one optimizer step for discriminator ``idx``.

        Args:
            idx: index into ``self.netsD`` / ``self.optimizersD`` and into the
                per-scale image lists.
            count: global step; the loss is logged when ``count % 100 == 0``.

        Returns:
            The discriminator loss as a Python float.
        """
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, mu = self.criterion_1, self.mu_theta1

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_tgpu[idx]
        wrong_imgs = self.wrong_tgpu[idx]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        # Slice the label buffers in case the last batch is smaller.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        # Fake images are detached so that no gradient reaches the generator.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            # The second head scores "wrong" images against real labels.
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)

            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            # .item() instead of .data[0]: indexing a 0-dim tensor fails on
            # the PyTorch versions that provide torch.tensor (used below).
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return float(errD)

    def train_Gnet(self, count):
        """Run one optimizer step with ``self.optimizerEnG``.

        The optimized loss is the sum of the adversarial losses over all
        discriminators plus the first value returned by ``self.criterion_2``.

        Args:
            count: global step; per-scale losses are logged when
                ``count % 100 == 0``.

        Returns:
            ``(total_loss_as_float, theta1_KL, Likelihood, loss, p1, p2, p3,
            shape1, scale1)`` -- everything after the first entry is passed
            through unchanged from ``self.criterion_2``.
        """
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion_1, mu = self.criterion_1, self.mu_theta1
        params = [
            self.shape1, self.scale1, self.Phi1, self.theta1, self.txtbow
        ]
        criterion, optEnG = self.criterion_2, self.optimizerEnG
        loss, theta1_KL, Likelihood, p1, p2, p3, shape1, scale1 = criterion(
            params)
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion_1(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                    criterion_1(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        errG_total += loss
        optEnG.zero_grad()
        errG_total.backward()
        optEnG.step()
        return (float(errG_total), theta1_KL, Likelihood, loss, p1, p2, p3,
                shape1, scale1)

    def updatePhi(self, miniBatch, Phi, Theta, MBratio, MBObserved, NDot):
        """Perform one stochastic update of the topic matrix ``Phi``.

        Args:
            miniBatch: NumPy count matrix for the current mini-batch.
            Phi: current topic matrix (NumPy array).
            Theta: NumPy array of latent weights for the mini-batch.
            MBratio: factor the sampled counts are multiplied by.
            MBObserved: step index into ``self.ForgetRate`` / ``self.epsit``;
                ``0`` means "first mini-batch" and resets ``NDot``.
            NDot: running column statistic from the previous call.

        Returns:
            ``(Phi, NDot)`` -- the updated matrix and running statistic.
        """
        Xt = miniBatch
        Xt_to_t1, WSZS = PGBN_sampler.Multrnd_Matrix(
            Xt.astype('double'), Phi.astype('double'), Theta.astype('double'))
        EWSZS = MBratio * WSZS

        forget = self.ForgetRate[MBObserved]
        if MBObserved == 0:
            NDot = EWSZS.sum(0)
        else:
            # Exponentially weighted running average of the column sums.
            NDot = (1 - forget) * NDot + forget * EWSZS.sum(0)

        step = self.epsit[MBObserved]
        tmp = EWSZS + self.eta
        tmp = (1 / NDot) * (tmp - tmp.sum(0) * Phi)
        tmp1 = (2 / NDot) * Phi
        # Drift term plus Gaussian noise scaled by the step size.
        tmp = Phi + step * tmp + np.sqrt(step * tmp1) * np.random.randn(
            Phi.shape[0], Phi.shape[1])
        Phi = PGBN_sampler.ProjSimplexSpecial(tmp, Phi, 0)
        return Phi, NDot

    def train(self):
        """Run the full training loop.

        Per step: encode the real images, generate fake images, update every
        discriminator, update generator/encoder, update ``Phi1``, and
        periodically log, snapshot and score the model.
        """
        self.netG, self.netsD, self.netIMG, self.inception_model,\
            self.num_Ds, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizersD, self.optimizerEnG = \
            define_optimizers(self.netG, self.netsD, self.netIMG)

        self.criterion_1 = nn.BCELoss()
        self.criterion_2 = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        # Prepare PHI: random positive matrix with columns normalized to 1.
        real_min = np.float64(2.2e-308)
        Phi1 = 0.2 + 0.8 * np.float64(np.random.rand(1000, 256))
        Phi1 = Phi1 / np.maximum(real_min, Phi1.sum(0))
        # Phi1 is a NumPy array, so it must be converted with torch.tensor
        # (Variable() only accepts tensors) and it has to exist on CPU runs
        # too because train_Gnet reads self.Phi1.
        self.Phi1 = torch.tensor(Phi1, dtype=torch.float32)

        if cfg.CUDA:
            self.Phi1 = self.Phi1.cuda()
            self.criterion_1.cuda()
            self.criterion_2.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches
        batch_length = self.num_batches

        # Step-size schedules, one entry per global step.
        total_steps = cfg.TRAIN.MAX_EPOCH * int(batch_length)
        self.NDot = 0
        self.ForgetRate = np.power(
            (0 + np.linspace(1, total_steps, total_steps)), -0.7)
        self.eta = 0.1
        epsit = np.power(
            (20 + np.linspace(1, total_steps, total_steps)), -0.7)
        self.epsit = 1 * epsit / epsit[0]

        num_total_samples = batch_length * self.batch_size
        start_t = time.time()
        for epoch in range(start_epoch, self.max_epoch):
            # Per-epoch accumulators for the printed statistics.
            LL = 0
            KL = 0
            LS = 0
            p1 = 0
            p2 = 0
            p3 = 0
            DL = 0
            GL = 0
            shape = []
            scale = []
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu \
                    = self.prepare_data(data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])

                #######################################################
                # (2) Get shape, scale and sample of theta
                ######################################################
                self.theta1, self.shape1, self.scale1 = \
                    self.netIMG(self.flat)
                shape1 = self.shape1.detach().cpu().numpy()
                scale1 = self.scale1.detach().cpu().numpy()
                mu_theta1 = scale1 * ss.gamma(1 + 1 / shape1)
                # NOTE(review): this tensor is created on the CPU; presumably
                # the networks are DataParallel-wrapped and move it -- confirm.
                self.mu_theta1 = torch.tensor(mu_theta1, dtype=torch.float32)

                #######################################################
                # (3) Generate fake images
                ######################################################
                self.fake_imgs, _ = self.netG(self.mu_theta1.detach())

                #######################################################
                # (4) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (5) Update G network (or En network):
                #     maximize log(D(G(z)))
                ######################################################
                (errG_total, self.KL1, self.LL, self.LS, self.p1, self.p2,
                 self.p3, shape1, scale1) = self.train_Gnet(count)
                LL += self.LL
                KL += self.KL1
                LS += self.LS
                p1 += self.p1
                p2 += self.p2
                p3 += self.p3
                shape.append(shape1)
                scale.append(scale1)
                # Exponential moving average of the generator weights.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)
                DL += errD_total
                GL += errG_total

                #######################################################
                # (6) Update Phi
                #######################################################
                input_txt = np.array(
                    np.transpose(self.txtbow.cpu().numpy()),
                    order='C').astype('double')
                Phi1 = np.array(self.Phi1.cpu().numpy(),
                                order='C').astype('double')
                # theta1 is an encoder output; detach before .numpy(), as is
                # already done for shape1/scale1 above.
                Theta1 = np.array(
                    np.transpose(self.theta1.detach().cpu().numpy()),
                    order='C').astype('double')
                phi1, self.NDot = self.updatePhi(input_txt, Phi1, Theta1,
                                                 int(batch_length), count,
                                                 self.NDot)
                self.Phi1 = torch.tensor(phi1, dtype=torch.float32)
                if cfg.CUDA:
                    self.Phi1 = self.Phi1.cuda()

                # for inception score
                pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total)
                    summary_G = summary.scalar('G_loss', errG_total)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netIMG, self.netG, avg_param_G,
                               self.netsD, epoch, count, self.model_dir)
                    # Save images
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    self.fake_imgs, _ = \
                        self.netG(self.theta1)
                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    # NOTE(review): the restore below is disabled, so netG
                    # keeps the averaged weights after a snapshot and
                    # backup_para is unused -- confirm this is intended.
                    # load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions,
                                                               10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []
                count += 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL = KL / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d, Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f (LL: %.6f, KL: %.6f, Loss_D:'
                '%.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL, DL, GL))
            start_t = time.time()
            if epoch % 50 == 0:
                save_model(self.netIMG, self.netG, avg_param_G, self.netsD,
                           epoch, count, self.model_dir)

        # save the model at the last updating
        # NOTE(review): relies on `epoch` from the loop above, i.e. on at
        # least one epoch having run.
        save_model(self.netIMG, self.netG, avg_param_G, self.netsD, epoch,
                   count, self.model_dir)
        self.summary_writer.close()