def test_log_scalar_summary():
    """Write ten scalar summaries (global steps 1..10) under ./experiment/scalar."""
    log_path = './experiment/scalar'
    summary_writer = FileWriter(log_path)
    for step in range(10):
        event = scalar('scalar', step)
        summary_writer.add_summary(event, step + 1)
    summary_writer.flush()
    summary_writer.close()
class GradientMetric:
    """Logs per-network gradient histograms to a TensorBoard event file.

    Fix over the original: `gradient_metric_name` ('{0}_gradient') was defined
    but never used — histograms were tagged with the bare network name, so the
    intended `_gradient` suffix was silently dropped. The template is now
    applied, and an optional `global_step` lets TensorBoard order the points.
    """

    def __init__(self):
        # Directory the event file is written to.
        self.training_log = 'logs/train/network'
        # Tag template, e.g. 'generator_gradient'.
        self.gradient_metric_name = '{0}_gradient'
        self.summary_writer = FileWriter(self.training_log)

    def log_gradient(self, network_name, gradient, global_step=None):
        """Record a histogram of `gradient` under '<network_name>_gradient'.

        Args:
            network_name: name of the network the gradient belongs to.
            gradient: gradient values to histogram (array-like / tensor).
            global_step: optional training step; omitted for backward
                compatibility with existing two-argument callers.
        """
        assert self.summary_writer
        tag = self.gradient_metric_name.format(network_name)
        summary_value = summary.histogram(tag, gradient)
        if global_step is None:
            self.summary_writer.add_summary(summary_value)
        else:
            self.summary_writer.add_summary(summary_value, global_step)
def test_log_histogram_summary():
    """Log ten normal-distribution histograms under ./experiment/histogram."""
    writer = FileWriter('./experiment/histogram')
    for step in range(10):
        mean, std = step * 0.1, 1.0
        # A large sample gives a smoother-looking histogram in TensorBoard.
        samples = np.random.normal(mean, std, 10000)
        hist = summary.histogram('discrete_normal', samples)
        writer.add_summary(hist, step + 1)
    writer.flush()
    writer.close()
def test_log_image_summary():
    """Log ten MNIST digits as image summaries under ./experiment/image."""
    writer = FileWriter('./experiment/image')
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    (train_lbl, train_img) = read_data(
        base_url + 'train-labels-idx1-ubyte.gz',
        base_url + 'train-images-idx3-ubyte.gz')
    for step in range(10):
        pixels = np.reshape(train_img[step], (28, 28, 1))
        # The 'mnist/' prefix groups all ten images under one tag in TensorBoard.
        image_summary = summary.image('mnist/' + str(step), pixels)
        writer.add_summary(image_summary, step + 1)
    writer.flush()
    writer.close()
def test_event_logging():
    """Each FileWriter open/close cycle adds exactly one event file to logdir."""
    logdir = './experiment/'
    # Two identical write/close passes; after each, one more file must exist.
    for expected_file_count in (1, 2):
        summary_writer = FileWriter(logdir)
        scalar_value = 1.0
        s = scalar('test_scalar', scalar_value)
        summary_writer.add_summary(s, global_step=1)
        summary_writer.close()
        assert os.path.isdir(logdir)
        assert len(os.listdir(logdir)) == expected_file_count
    # clean up.
    shutil.rmtree(logdir)
class condGANTrainer(object):
    """Trainer for a multi-stage conditional GAN with an auxiliary
    Show-Attend-Tell (SAT) captioning loss on the generated images.

    `train()` alternates discriminator and generator updates, adds the SAT
    caption loss to the generator objective, keeps an exponential moving
    average of the generator parameters, and periodically logs losses,
    snapshots models/images, and computes an inception score.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Output directories and the summary writer exist only in training mode.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Effective batch size scales with the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Unpack a batch and wrap images/embedding in (optionally CUDA) Variables.

        Returns (cpu_imgs, real_vimgs, wrong_vimgs, vembedding); one image
        Variable per discriminator stage (self.num_Ds).
        """
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """One optimizer step for discriminator `idx`; returns its loss.

        Uses real, mismatched ("wrong") and fake images; logs D_loss<idx>
        every 100 iterations (when count % 100 == 0).
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        # mu is detached so D updates do not backprop into the text encoder/G.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Unconditional (image-only) logits, weighted by UNCOND_LOSS.
            # NOTE(review): wrong images use *real* labels for the
            # unconditional term — they are real photos, just mismatched text.
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Compute the generator loss (adversarial + color-consistency + KL).

        Returns (kl_loss, errG_total). The backward pass and optimizer step
        are deliberately postponed to the caller, which first adds the SAT
        caption loss to errG_total.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        # (mean/covariance statistics of adjacent-resolution branches should agree).
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        # Postpone the backward propagation
        # errG_total.backward()
        # self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Main training loop: GAN updates interleaved with a SAT caption loss."""
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()
        self.SATcriterion = nn.CrossEntropyLoss()

        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = Variable(
            torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        # Data parameters
        data_folder = 'birds_output'  # folder with data files saved by create_input_files.py
        data_name = 'CUB_5_cap_per_img_5_min_word_freq'  # base name shared by data files

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Show, Attend, and Tell Dataloader
        train_loader = torch.utils.data.DataLoader(
            CaptionDataset(data_folder, data_name, 'TRAIN',
                           transform=transforms.Compose([normalize])),
            batch_size=self.batch_size, shuffle=True,
            num_workers=int(cfg.WORKERS), pin_memory=True)

        if cfg.CUDA:
            self.criterion.cuda()
            self.SATcriterion.cuda()  # Compute SATloss
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            # for step, data in enumerate(self.data_loader, 0):
            # GAN batches and caption batches are consumed in lockstep;
            # zip stops at the shorter of the two loaders.
            for step, data in enumerate(zip(self.data_loader, train_loader), 0):
                data_1 = data[0]
                _, caps, caplens = data[1]
                data = data_1

                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                # Testing line for real samples
                if epoch == start_epoch and step == 0:
                    print('Checking real samples at first...')
                    save_real(self.imgs_tcpu, self.image_dir)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)
                # len(self.fake_imgs) = NUM_BRANCHES
                # self.fake_imgs[0].shape = [batch_size, 3, 64, 64]
                # self.fake_imgs[1].shape = [batch_size, 3, 128, 128]
                # self.fake_imgs[2].shape = [batch_size, 3, 256, 256]

                #######################################################
                # (*) Forward fake images to SAT
                ######################################################
                # NOTE(review): the SAT encoder/decoder (and their optimizers)
                # are re-created, and the word map re-read from disk, on EVERY
                # batch — the SAT weights presumably never accumulate training
                # across iterations; confirm this is intended.
                from SATmodels import Encoder, DecoderWithAttention
                from torch.nn.utils.rnn import pack_padded_sequence

                fine_tune_encoder = False

                # Read word map
                word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
                with open(word_map_file, 'r') as j:
                    word_map = json.load(j)

                # Define the encoder/decoder structure for SAT model
                decoder = DecoderWithAttention(attention_dim=512,
                                               embed_dim=512,
                                               decoder_dim=512,
                                               vocab_size=len(word_map),
                                               dropout=0.5).cuda()
                decoder_optimizer = torch.optim.Adam(params=filter(
                    lambda p: p.requires_grad, decoder.parameters()), lr=4e-4)
                encoder = Encoder().cuda()
                encoder.fine_tune(fine_tune_encoder)
                encoder_optimizer = torch.optim.Adam(
                    params=filter(lambda p: p.requires_grad, encoder.parameters()),
                    lr=1e-4) if fine_tune_encoder else None

                SATloss = 0

                # Compute the SAT loss after forwarding the SAT model
                for idx in range(len(self.fake_imgs)):
                    img = encoder(self.fake_imgs[idx])
                    scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(
                        img, caps, caplens)
                    # Drop <start> token: targets are words 1..end.
                    targets = caps_sorted[:, 1:]
                    # NOTE(review): pack_padded_sequence returns a
                    # PackedSequence; unpacking it as `x, _ = ...` and calling
                    # .cuda() on it relies on its tuple behavior — verify this
                    # works on the pinned torch version.
                    scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True).cuda()
                    targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True).cuda()
                    # Cross-entropy plus doubly-stochastic attention regularizer.
                    SATloss += self.SATcriterion(scores, targets) + 1 * (
                        (1. - alphas.sum(dim=1))**2).mean()

                # Set zero_grad for encoder/decoder
                decoder_optimizer.zero_grad()
                if encoder_optimizer is not None:
                    encoder_optimizer.zero_grad()

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of G's parameters (decay 0.999).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Combine with G and SAT first, then back propagation
                errG_total += SATloss
                errG_total.backward()
                self.optimizerG.step()

                #######################################################
                # (*) Update SAT network:
                ######################################################
                # Update weights
                decoder_optimizer.step()
                if encoder_optimizer is not None:
                    encoder_optimizer.step()

                #######################################################
                # (*) Prediction and Inception score:
                ######################################################
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count += 1

                #######################################################
                # (*) Save Images/Log/Model per SNAPSHOT_INTERVAL:
                ######################################################
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images
                    # Sample from the EMA weights, then restore the live weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = negative_log_posterior_probability(
                            predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                  '''  # D(real): %.4f D(wrong):%.4f D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     kl_loss.item(), end_t - start_t))
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir,
                         split_dir, imsize):
        """Save, per sample, all its per-sentence images tiled into one grid PNG."""
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir,
                          split_dir, sentenceID, imsize):
        """Save each generated image as its own PNG, rescaled from [-1,1] to [0,255]."""
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate images from a saved generator checkpoint for every
        embedding in the evaluation split and write them to disk."""
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps CPU-only machines working with GPU checkpoints.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            # Iteration number is parsed from the checkpoint filename.
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                # One generation per sentence embedding of each sample.
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
class condGANTrainer(object):
    """Trainer for the multi-stage conditional GAN (plain variant, no
    captioning branch); uses the legacy `.data[0]` scalar accessor.

    NOTE(review): this re-declares `condGANTrainer`; at import time this
    later definition shadows the earlier one in the same module —
    presumably two versions of the file were concatenated; confirm which
    definition is meant to survive.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Output directories and the summary writer exist only in training mode.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Effective batch size scales with the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Unpack a batch and wrap images/embedding in (optionally CUDA) Variables."""
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """One optimizer step for discriminator `idx`; returns its loss."""
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        # mu is detached so D updates do not backprop into the text encoder/G.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Unconditional (image-only) logits, weighted by UNCOND_LOSS.
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar accessor;
            # on modern torch this would need `.item()` — confirm the pinned
            # torch version before touching it.
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Generator loss (adversarial + color-consistency + KL), with the
        backward pass and optimizer step performed here (unlike the SAT
        variant, which postpones them)."""
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        # (mean/covariance statistics of adjacent-resolution branches should agree).
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Main training loop: alternate D and G updates, EMA of G params,
        periodic logging, snapshots, and inception-score computation."""
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of G's parameters (decay 0.999).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images
                    # Sample from the EMA weights, then restore the live weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                  '''  # D(real): %.4f D(wrong):%.4f D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data[0], errG_total.data[0],
                     kl_loss.data[0], end_t - start_t))
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames,
                         save_dir, split_dir, imsize):
        """Save, per sample, all its per-sentence images tiled into one grid PNG."""
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames,
                          save_dir, split_dir, sentenceID, imsize):
        """Save each generated image as its own PNG, rescaled from [-1,1] to [0,255]."""
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate images from a saved generator checkpoint for every
        embedding in the evaluation split and write them to disk."""
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps CPU-only machines working with GPU checkpoints.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            # Iteration number is parsed from the checkpoint filename.
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                # One generation per sentence embedding of each sample.
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
class GANTrainer(object):
    """Trainer for an unconditional multi-stage GAN (no text conditioning).

    Holds the generator, one discriminator per resolution stage, a
    tensorboard ``FileWriter``, and the training/evaluation loops. All
    hyper-parameters come from the global ``cfg`` object.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Only create output dirs / the summary writer when training;
        # evaluation-only runs skip them.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # effective batch size scales with the number of GPUs
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one multi-resolution image batch in Variables (moved to GPU
        when ``cfg.CUDA``); returns (raw CPU images, wrapped images).
        """
        imgs = data
        vimgs = []
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))
        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """One optimization step for discriminator ``idx``.

        Uses real images vs. detached fakes at the matching resolution; logs
        'D_loss<idx>' every 100 iterations. Returns the discriminator loss.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        #
        netD.zero_grad()
        #
        real_logits = netD(real_imgs)
        # detach fakes so G receives no gradient from the D update
        fake_logits = netD(fake_imgs.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        #
        errD = errD_real + errD_fake
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """One optimization step for the generator against all stage
        discriminators, plus optional color-consistency losses between
        adjacent resolution stages. Returns the total generator loss.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            # generator wants fakes classified as real
            errG = criterion(outputs[0], real_labels)
            # errG = self.stage_coeff[i] * errG
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses: keep mean/covariance of adjacent
        # stage outputs consistent (higher stage vs. detached lower stage).
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1

            if flag == 0:
                sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                self.summary_writer.add_summary(sum_mu, count)
                sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Main training loop: alternates D and G updates, maintains an EMA
        copy of G's parameters, logs losses, periodically snapshots models /
        sample images and computes inception scores on buffered predictions.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        # exponential moving average of G's parameters, used for snapshots
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        # fixed noise gives comparable sample grids across epochs
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        # resume epoch index from the restored global step
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = self.train_Gnet(count)
                # EMA update of G parameters (decay 0.999)
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f'''
                          % (epoch, self.max_epoch, step, self.num_batches,
                             errD_total.data[0], errG_total.data[0]))
                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    # NOTE(review): save_model is called twice back-to-back
                    # here (and again at the end of training) — looks like an
                    # accidental duplication; confirm before removing.
                    save_model(self.netG, avg_param_G, self.netsD,
                               count, self.model_dir)
                    save_model(self.netG, avg_param_G, self.netsD,
                               count, self.model_dir)
                    # Save images: temporarily swap in the EMA parameters,
                    # render from fixed noise, then restore.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score once enough predictions buffered
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))

        # NOTE(review): duplicated final save, same as above — confirm.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save one batch of images as a single normalized grid PNG."""
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save each image of a batch as an individual PNG, numbered from
        ``startID``.
        """
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # range from [-1, 1] to [0, 1]
            # NOTE(review): the next line is dead — its result is immediately
            # overwritten by the line after it.
            img = (images[i] + 1.0) / 2
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # range from [0, 1] to [0, 255]
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Sample ``cfg.TEST.SAMPLE_NUM`` images from the checkpoint at
        ``cfg.TRAIN.NET_G`` and save them as grids or single files depending
        on ``cfg.TEST.B_EXAMPLE``.
        """
        if cfg.TRAIN.NET_G == '':
            # NOTE(review): 'morels' is a typo for 'models' (runtime string,
            # kept as-is).
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images: parse the iteration number
            # out of the checkpoint filename "<...>_<iteration>.<ext>".
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
            if cfg.TEST.B_EXAMPLE:
                folder = '%s/super' % (save_dir)
            else:
                folder = '%s/single' % (save_dir)
            print('Make a new folder: ', folder)
            mkdir_p(folder)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
            cnt = 0
            # NOTE(review): xrange is Python-2 only.
            for step in xrange(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                    # self.save_singleimages(fake_imgs[-2], folder, 128)
                    # self.save_singleimages(fake_imgs[-3], folder, 64)
                cnt += self.batch_size
class GANTrainer(object):
    """Trainer for a Stage-I GAN conditioned on per-object one-hot labels and
    spatial transformation matrices (object layout), with a ``sample`` method
    that renders annotated comparison grids.
    """

    def __init__(self, output_dir):
        # Output dirs and the tensorboard writer exist only in training mode.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        # fixed maximum number of objects per scene
        self.max_objects = 4

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build STAGE1 generator/discriminator, optionally restoring weights
        from ``cfg.NET_G`` / ``cfg.NET_D``; returns (netG, netD).
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        netD = STAGE1_D()
        netD.apply(weights_init)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            # generator checkpoints store the weights under the "netG" key
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Full training loop over ``data_loader``.

        NOTE(review): the ``stage`` parameter is never used — only the
        Stage-I networks are ever loaded here.
        """
        netG, netD = self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # only optimize generator parameters that require gradients
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        print("Start training...")
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # halve both learning rates every lr_decay_step epochs
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, transformation_matrices, label_one_hot, _ = data

                transf_matrices, transf_matrices_inv = tuple(
                    transformation_matrices)
                transf_matrices = transf_matrices.detach()
                transf_matrices_inv = transf_matrices_inv.detach()

                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    label_one_hot = label_one_hot.cuda()
                    transf_matrices = transf_matrices.cuda()
                    transf_matrices_inv = transf_matrices_inv.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, transf_matrices_inv, label_one_hot)
                fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, transf_matrices,
                                               transf_matrices_inv, self.gpus)
                # retain_graph: the same fake_imgs graph is reused for the
                # generator update below
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot, transf_matrices,
                                              transf_matrices_inv, self.gpus)
                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (3) Log results
                ###########################
                count = count + 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        inputs = (noise, transf_matrices_inv, label_one_hot)
                        fake = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)

            # NOTE(review): this end-of-epoch save duplicates the periodic
            # one above (same inputs, last batch) — confirm it is intended.
            with torch.no_grad():
                inputs = (noise, transf_matrices_inv, label_one_hot)
                fake = nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)

            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.item(), errG.item(),
                         errD_real, errD_wrong, errD_fake,
                         (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        #
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, data_loader, num_samples=25, draw_bbox=True,
               max_objects=4):
        """Render ``num_samples`` comparison grids: 1 real image, 9 samples
        from different noise vectors, and a text strip describing the object
        labels; optionally draws the ground-truth bounding boxes.

        NOTE(review): ``ImageFont`` and ``pickle`` are imported but unused.
        """
        from PIL import Image, ImageDraw, ImageFont
        import pickle
        import torchvision
        import torchvision.utils as vutils
        netG, _ = self.load_network_stageI()
        netG.eval()

        # path to save generated samples, derived from the checkpoint name
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_samples_" + str(
            max_objects) + "_objects"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        nz = cfg.Z_DIM
        # 9 noise vectors -> 9 samples per grid
        noise = Variable(torch.FloatTensor(9, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64

        # for count in range(num_samples):
        count = 0
        for i, data in enumerate(data_loader, 0):
            if count == num_samples:
                break
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, transformation_matrices, label_one_hot, bbox = data
            transf_matrices, transf_matrices_inv = tuple(
                transformation_matrices)
            transf_matrices_inv = transf_matrices_inv.detach()

            real_img = Variable(real_img_cpu)
            if cfg.CUDA:
                real_img = real_img.cuda()
                label_one_hot = label_one_hot.cuda()
                transf_matrices_inv = transf_matrices_inv.cuda()

            # replicate the single sample's layout/labels across 9 rows so
            # each noise vector sees the same conditioning
            transf_matrices_inv_batch = transf_matrices_inv.view(
                1, max_objects, 2, 3).repeat(9, 1, 1, 1)
            label_one_hot_batch = label_one_hot.view(1, max_objects,
                                                     13).repeat(9, 1, 1)

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, transf_matrices_inv_batch, label_one_hot_batch)
            with torch.no_grad():
                fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

            # rows 0: real, 1-9: fakes, 10-19: text strip (filled below)
            data_img = torch.FloatTensor(20, 3, imsize, imsize).fill_(0)
            data_img[0] = real_img
            data_img[1:10] = fake_imgs

            if draw_bbox:
                for idx in range(max_objects):
                    x, y, w, h = tuple(
                        [int(imsize * x) for x in bbox[0, idx]])
                    # clamp boxes to the image border
                    w = imsize - 1 if w > imsize - 1 else w
                    h = imsize - 1 if h > imsize - 1 else h
                    # negative coords mark "no more objects"
                    if x <= -1 or y <= -1:
                        break
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y + h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1

            # write caption into image
            shape_dict = {0: "cube", 1: "cylinder", 2: "sphere", 3: "empty"}
            color_dict = {0: "gray", 1: "red", 2: "blue", 3: "green",
                          4: "brown", 5: "purple", 6: "cyan", 7: "yellow",
                          8: "empty"}

            text_img = Image.new('L', (imsize * 10, imsize), color='white')
            d = ImageDraw.Draw(text_img)
            label = label_one_hot_batch[0]
            label = label.cpu().numpy()
            # one-hot layout: first 4 columns = shape, remaining 9 = color
            label_shape = label[:, :4]
            label_color = label[:, 4:]
            label_shape = np.argmax(label_shape, axis=1)
            label_color = np.argmax(label_color, axis=1)
            label_combined = ", ".join([
                color_dict[label_color[_]] + " " + shape_dict[label_shape[_]]
                for _ in range(max_objects)
            ])
            d.text((10, 10), label_combined)
            # slice the wide caption strip into 10 square tiles so it fills
            # the bottom row of the grid
            text_img = torchvision.transforms.functional.to_tensor(text_img)
            text_img = torch.chunk(text_img, 10, 2)
            text_img = torch.cat(
                [text_img[i].view(1, 1, imsize, imsize) for i in range(10)],
                0)
            data_img[10:] = text_img
            vutils.save_image(data_img,
                              '{}/vis_{}.png'.format(save_dir, count),
                              normalize=True, nrow=10)
            count += 1

        print("Saved {} files to {}".format(count, save_dir))
class GANTrainer(object):
    """Trainer for a two-stage text-to-image GAN (StackGAN style), fully
    parameterized through the constructor instead of a global cfg object.
    Supports a KL or JSD regularizer on the conditioning-augmentation
    distribution.
    """

    def __init__(self, output_dir, max_epoch, snapshot_interval, gpu_id,
                 batch_size, train_flag, net_g, net_d, cuda, stage1_g, z_dim,
                 generator_lr, discriminator_lr, lr_decay_epoch, coef_kl,
                 regularizer):
        # Output dirs / writer only when training.
        if train_flag:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = max_epoch
        self.snapshot_interval = snapshot_interval
        self.net_g = net_g          # path to a generator checkpoint ('' = none)
        self.net_d = net_d          # path to a discriminator checkpoint
        self.cuda = cuda
        self.stage1_g = stage1_g    # path to a pretrained Stage-I generator
        self.nz = z_dim
        self.generator_lr = generator_lr
        self.discriminator_lr = discriminator_lr
        self.lr_decay_step = lr_decay_epoch
        self.coef_kl = coef_kl      # weight of the regularizer term in G's loss
        self.regularizer = regularizer  # 'KL' or anything else -> JSD

        s_gpus = gpu_id.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # effective batch size scales with GPU count
        self.batch_size = batch_size * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self, text_dim, gf_dim, condition_dim, z_dim,
                            df_dim):
        """Build Stage-I G/D, restoring weights from ``self.net_g`` /
        ``self.net_d`` when set; returns (netG, netD).
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D(df_dim, condition_dim)
        netD.apply(weights_init)
        print(netD)

        if self.net_g != '':
            # NOTE(review): no map_location here (unlike net_d below), so a
            # GPU-saved checkpoint cannot be loaded on a CPU-only machine.
            state_dict = torch.load(self.net_g)
            #torch.load(self.net_g, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', self.net_g)
        if self.net_d != '':
            state_dict = torch.load(
                self.net_d, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', self.net_d)
        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self, text_dim, gf_dim, condition_dim, z_dim,
                             df_dim, res_num):
        """Build Stage-II G/D; the generator wraps a Stage-I generator whose
        weights come either from a full Stage-II checkpoint (``self.net_g``)
        or a Stage-I-only checkpoint (``self.stage1_g``).

        NOTE(review): when neither path is given this prints a message and
        returns None (callers that unpack two values will then fail).
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D
        Stage1_G = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG = STAGE2_G(Stage1_G, text_dim, gf_dim, condition_dim, z_dim,
                        res_num, self.cuda)
        netG.apply(weights_init)
        print(netG)
        if self.net_g != '':
            state_dict = torch.load(self.net_g)
            #torch.loadself.net_g, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', self.net_g)
        elif self.stage1_g != '':
            state_dict = torch.load(
                self.stage1_g, map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', self.stage1_g)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D(df_dim, condition_dim)
        netD.apply(weights_init)
        if self.net_d != '':
            state_dict = torch.load(
                self.net_d, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', self.net_d)
        print(netD)

        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage, text_dim, gf_dim, condition_dim,
              z_dim, df_dim, res_num):
        """Train the selected stage on (image, text-embedding) batches.

        G's loss is the adversarial term plus ``coef_kl`` times a KL or JSD
        regularizer on the conditioning distribution (mu, logvar).
        """
        if stage == 1:
            netG, netD = self.load_network_stageI(text_dim, gf_dim,
                                                  condition_dim, z_dim,
                                                  df_dim)
        else:
            netG, netD = self.load_network_stageII(text_dim, gf_dim,
                                                   condition_dim, z_dim,
                                                   df_dim, res_num)

        batch_size = self.batch_size
        noise = torch.FloatTensor(batch_size, self.nz)
        # fixed noise makes the periodic sample images comparable over time
        fixed_noise = torch.FloatTensor(batch_size, self.nz).normal_(0, 1)
        real_labels = torch.ones(batch_size)
        fake_labels = torch.zeros(batch_size)
        if self.cuda:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        optimizerD = optim.Adam(netD.parameters(), lr=self.discriminator_lr,
                                betas=(0.5, 0.999))
        # only optimize generator parameters that require gradients
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para, self.generator_lr,
                                betas=(0.5, 0.999))

        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # halve both learning rates every lr_decay_step epochs
            if epoch % self.lr_decay_step == 0 and epoch > 0:
                self.generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = self.generator_lr
                self.discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = self.discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_imgs, txt_embedding = data
                if self.cuda:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                # generator returns (low-res image, image, mu, logvar)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()

                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              mu, self.gpus)
                # regularize the conditioning-augmentation distribution
                if self.regularizer == 'KL':
                    regularizer_loss = KL_loss(mu, logvar)
                else:
                    regularizer_loss = JSD_loss(mu, logvar)
                errG_total = errG + regularizer_loss * self.coef_kl
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('Regularizer_loss',
                                                regularizer_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_imgs.cpu(), fake, epoch,
                                     self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)

            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.item(), errG.item(), regularizer_loss.item(),
                         errD_real, errD_wrong, errD_fake,
                         (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate one image per sentence embedding stored in a torchfile
        at ``datapath`` (capped around 3000) and save them as PNGs.

        NOTE(review): both loader calls below omit the required
        text_dim/gf_dim/... arguments and will raise TypeError as written —
        the signatures changed but this caller was not updated.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = self.net_g[:self.net_g.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        noise = torch.FloatTensor(batch_size, self.nz)
        if self.cuda:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            # shift the final window back so the last batch is always full
            if iend > num_embeddings:
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = torch.FloatTensor(embeddings_batch)
            if self.cuda:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # [-1, 1] -> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size

    def birds_eval(self, data_loader, stage=2):
        """Generate images for every batch of a (image, embedding) loader.

        NOTE(review): as written this method cannot run — the loader calls
        omit their required arguments, ``self.z_dim`` does not exist (the
        attribute is ``self.nz``), ``self.netG`` does not exist (the path is
        ``self.net_g``), ``volatile=True`` is removed in modern PyTorch, and
        the inner ``for i`` shadows the batch-loop ``i``.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()
        netG.eval()
        nz = self.z_dim
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size,
                                                 nz).normal_(0, 1),
                               volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if self.cuda:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        # path to save generated samples
        save_dir = self.netG[:self.netG.find('.pth')]
        print("Save directory", save_dir)
        mkdir_p(save_dir)
        count = 0
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if self.cuda:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            print("Batch Running:", i)
            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # [-1, 1] -> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
def train(iter_cnt, model, domain_d, corpus, args,
          optimizer_encoder, optimizer_domain_d):
    """Run one adversarial (domain-confusion) training pass.

    The similarity model is trained on pos/neg pairs while a domain
    discriminator ``domain_d`` is trained on mixed-domain batches; the joint
    objective for the encoder is ``loss1 - lambda_d * loss2``.

    Args:
        iter_cnt: running global iteration counter.
        model: similarity model exposing ``embedding_layer``,
            ``compute_similarity`` and ``compute_loss``.
        domain_d: domain discriminator exposing ``compute_loss``.
        corpus: corpus object passed through to ``cross_pad_iter``.
        args: namespace with ``train``, ``cross_train``, ``run_path``,
            ``batch_size``, ``use_content``, ``cuda``, ``lambda_d``.
        optimizer_encoder: optimizer for the encoder parameters.
        optimizer_domain_d: optimizer for the discriminator parameters.

    Returns:
        int: the updated iteration counter.
    """
    train_writer = FileWriter(args.run_path + "/train", flush_secs=5)

    pos_file_path = "{}.pos.txt".format(args.train)
    neg_file_path = "{}.neg.txt".format(args.train)
    # for adversarial training just use natural language portions of inputs
    train_corpus_path = os.path.dirname(args.train) + "/nl.tsv.gz"
    cross_train_corpus_path = os.path.dirname(args.cross_train) + "/nl.tsv.gz"

    use_content = bool(args.use_content)

    pos_batch_loader = FileLoader(
        [tuple([pos_file_path, os.path.dirname(args.train)])],
        args.batch_size)
    neg_batch_loader = FileLoader(
        [tuple([neg_file_path, os.path.dirname(args.train)])],
        args.batch_size)
    cross_loader = TwoDomainLoader(
        [tuple([train_corpus_path, os.path.dirname(train_corpus_path)])],
        [tuple([cross_train_corpus_path,
                os.path.dirname(cross_train_corpus_path)])],
        args.batch_size * 2)

    embedding_layer = model.embedding_layer
    criterion1 = model.compute_loss
    criterion2 = domain_d.compute_loss

    start = time.time()
    task_loss = 0.0
    task_cnt = 0
    domain_loss = 0.0
    dom_cnt = 0
    total_loss = 0.0
    total_cnt = 0

    for batch, labels, domain_batch, domain_labels in tqdm(
            cross_pad_iter(corpus, embedding_layer, pos_batch_loader,
                           neg_batch_loader, cross_loader, use_content,
                           pad_left=False)):
        iter_cnt += 1

        if use_content:
            # flatten nested (question, content) groups into one flat list
            batch = [y for x in batch for y in x]
            # content mode carries one sub-batch per side
            domain_batch = list(domain_batch)
        # FIX: domain_batch was previously converted to a list
        # unconditionally, which broke the nl-only path below: a list has no
        # .cuda() and Variable(list) is invalid. Keep it a tensor there.

        if args.cuda:
            batch = [x.cuda() for x in batch]
            labels = labels.cuda()
            if not use_content:
                domain_batch = domain_batch.cuda()
            else:
                domain_batch = [x.cuda() for x in domain_batch]
            domain_labels = domain_labels.cuda()

        # FIX: list comprehension instead of map() — under Python 3, map()
        # returns a lazy, non-subscriptable iterator and batch[0] below
        # would raise TypeError.
        batch = [Variable(x) for x in batch]
        labels = Variable(labels)
        if not use_content:
            domain_batch = Variable(domain_batch)
        else:
            domain_batch = [Variable(x) for x in domain_batch]
        domain_labels = Variable(domain_labels)

        model.zero_grad()
        domain_d.zero_grad()

        # Task loss: similarity between the two encoded sides.
        if not use_content:
            repr_left = model(batch[0])
            repr_right = model(batch[1])
        else:
            repr_left = model(batch[0]) + model(batch[1])
            repr_right = model(batch[2]) + model(batch[3])
        output = model.compute_similarity(repr_left, repr_right)
        loss1 = criterion1(output, labels)
        task_loss += loss1.data[0] * output.size(0)
        task_cnt += output.size(0)

        # Domain loss: discriminator on encoded cross-domain batches.
        if not use_content:
            domain_output = domain_d(model(domain_batch))
        else:
            domain_output = domain_d(model(domain_batch[0])) + domain_d(
                model(domain_batch[1]))
        loss2 = criterion2(domain_output, domain_labels)
        domain_loss += loss2.data[0] * domain_output.size(0)
        dom_cnt += domain_output.size(0)

        # Adversarial objective: minimize task loss while making the
        # encoder's output domain-indistinguishable.
        loss = loss1 - args.lambda_d * loss2
        total_loss += loss.data[0]
        total_cnt += 1
        loss.backward()
        optimizer_encoder.step()
        optimizer_domain_d.step()

        if iter_cnt % 100 == 0:
            outputManager.say("\r" + " " * 50)
            outputManager.say(
                "\r{} tot_loss: {:.4f} task_loss: {:.4f} domain_loss: {:.4f} eps: {:.0f} "
                .format(iter_cnt, total_loss / total_cnt,
                        task_loss / task_cnt, domain_loss / dom_cnt,
                        (task_cnt + dom_cnt) / (time.time() - start)))
            s = summary.scalar('total_loss', total_loss / total_cnt)
            train_writer.add_summary(s, iter_cnt)
            s = summary.scalar('domain_loss', domain_loss / dom_cnt)
            train_writer.add_summary(s, iter_cnt)
            s = summary.scalar('task_loss', task_loss / task_cnt)
            train_writer.add_summary(s, iter_cnt)

    outputManager.say("\n")
    train_writer.close()
    return iter_cnt
def train(iter_cnt, model, corpus, args, optimizer):
    """Train the similarity model for one pass over the pos/neg pair loaders.

    Streams (batch, labels) pairs from ``pad_iter``, computes the model's
    similarity loss, backpropagates, and logs a running average loss to
    TensorBoard every 100 iterations.

    Returns the updated global iteration counter ``iter_cnt``.
    """
    train_writer = FileWriter(args.run_path + "/train", flush_secs=5)
    pos_file_path = "{}.pos.txt".format(args.train)
    neg_file_path = "{}.neg.txt".format(args.train)
    pos_batch_loader = FileLoader(
        [tuple([pos_file_path, os.path.dirname(args.train)])], args.batch_size)
    neg_batch_loader = FileLoader(
        [tuple([neg_file_path, os.path.dirname(args.train)])], args.batch_size)
    use_content = bool(args.use_content)
    embedding_layer = model.embedding_layer
    criterion = model.compute_loss
    start = time.time()
    tot_loss = 0.0
    tot_cnt = 0
    for batch, labels in tqdm(
            pad_iter(corpus, embedding_layer, pos_batch_loader,
                     neg_batch_loader, use_content, pad_left=False)):
        iter_cnt += 1
        model.zero_grad()
        labels = labels.type(torch.LongTensor)
        if use_content:
            # Flatten nested (title, content) pairs into a single list.
            new_batch = []
            for x in batch:
                for y in x:
                    new_batch.append(y)
            batch = new_batch
        if args.cuda:
            batch = [x.cuda() for x in batch]
            labels = labels.cuda()
        # List comprehension instead of map(): equivalent on Python 2 and
        # keeps batch subscriptable on Python 3.
        batch = [Variable(x) for x in batch]
        labels = Variable(labels)
        if not use_content:
            repr_left = model(batch[0])
            repr_right = model(batch[1])
        else:
            # Content mode: sum title and body representations per side.
            repr_left = model(batch[0]) + model(batch[1])
            repr_right = model(batch[2]) + model(batch[3])
        output = model.compute_similarity(repr_left, repr_right)
        loss = criterion(output, labels)
        loss.backward()
        # NOTE: the old debug code that copied the embedding matrix to CPU
        # before/after optimizer.step() to measure the update magnitude was
        # removed — the result was never used and the copies were expensive.
        optimizer.step()
        tot_loss += loss.data[0] * output.size(0)
        tot_cnt += output.size(0)
        if iter_cnt % 100 == 0:
            outputManager.say("\r" + " " * 50)
            outputManager.say("\r{} loss: {:.4f} eps: {:.0f} ".format(
                iter_cnt, tot_loss / tot_cnt, tot_cnt / (time.time() - start)))
            s = summary.scalar('loss', tot_loss / tot_cnt)
            train_writer.add_summary(s, iter_cnt)
    outputManager.say("\n")
    train_writer.close()
    return iter_cnt
class GANTrainer(object):
    """Two-stage text-to-image GAN trainer (StackGAN-style).

    Builds stage-I or stage-II generator/discriminator pairs, trains them
    with alternating D/G updates plus a KL regularizer on the conditioning
    augmentation, and can sample images from pretrained checkpoints.
    """

    def __init__(self, output_dir):
        # Training-only artifacts (checkpoints, images, TF event logs).
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # Effective batch size scales with the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build stage-I G and D, optionally restoring checkpoints from cfg."""
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        """Build stage-II G (wrapping a stage-I G) and stage-II D.

        Requires either a full stage-II checkpoint (cfg.NET_G) or at least a
        pretrained stage-I generator (cfg.STAGE1_G); otherwise returns None.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D
        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            return
        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Alternating D/G training loop for the chosen stage.

        Per batch: generate fakes, update D on real/fake pairs, then update G
        with the generator loss plus cfg.TRAIN.COEFF.KL * KL regularizer.
        Logs scalars and saves sample images every 100 steps; snapshots the
        models every `snapshot_interval` epochs.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()
        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is reused for the periodic sample images so progress is
        # comparable across steps; volatile=True disables grad tracking
        # (legacy PyTorch idiom).
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize G parameters that require grad (stage-II freezes the
        # embedded stage-I generator's weights via requires_grad).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()
                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.data[0])
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.data[0], errG.data[0], kl_loss.data[0],
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate PNG samples from precomputed text embeddings.

        Loads torchfile embeddings from `datapath`, runs the generator in
        eval mode batch by batch (capped at ~3000 samples), and writes each
        image as <save_dir>/<index>.png.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()
        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)
        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                # Final partial batch: slide the window back so it is full.
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()
            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # Map from [-1, 1] to [0, 255] uint8, CHW -> HWC.
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
class GANTrainer(object):
    """Unconditional multi-discriminator GAN trainer (StackGAN-v2-style).

    One generator emits images at several resolutions; one discriminator per
    resolution (num_Ds). Training keeps an exponential moving average of G's
    parameters for sampling, and periodically logs inception-score estimates.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Training-only artifacts (checkpoints, images, TF event logs).
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
        # Effective batch size scales with the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap per-resolution image tensors in Variables (CUDA if enabled).

        Returns (raw CPU tensors, list of Variables) — the raw tensors are
        kept for image dumps.
        """
        imgs = data
        vimgs = []
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))
        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """One optimization step for discriminator `idx`; returns its loss.

        Logs the scalar loss every 100 global steps (flag == 0).
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        #
        netD.zero_grad()
        # Forward: detach fakes so G receives no gradient from D's update.
        real_logits = netD(real_imgs)
        fake_logits = netD(fake_imgs.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        #
        errD = errD_real + errD_fake
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """One generator step against all discriminators; returns total loss.

        Sums the adversarial loss per resolution, plus optional color-
        consistency losses (mean/covariance matching between adjacent
        resolutions) weighted by cfg.TRAIN.COEFF.COLOR_LOSS.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            # G is rewarded when D labels its fakes as real.
            errG = criterion(outputs[0], real_labels)
            # errG = self.stage_coeff[i] * errG
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                # Match color statistics of the highest resolution to the
                # (detached) next-lower resolution.
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
            if flag == 0:
                sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                self.summary_writer.add_summary(sum_mu, count)
                sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)
        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Main training loop: alternate D updates and a G update per batch.

        Maintains an EMA copy of G's parameters (decay 0.999) used when
        dumping sample images; accumulates inception-model predictions and
        logs inception score / NLPP once >500 predictions are collected.
        """
        print("Try to load network")
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        print("Network loaded")
        avg_param_G = copy_G_params(self.netG)
        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)
        self.criterion = nn.BCELoss()
        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
        predictions = []
        count = start_count
        # Resume at the epoch implied by the restored step counter.
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)
                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)
                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = self.train_Gnet(count)
                # EMA of generator weights (decay 0.999).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)
                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())
                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f'''
                          % (epoch, self.max_epoch, step, self.num_batches,
                             errD_total.data[0], errG_total.data[0]))
                count = count + 1
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD,
                               count, self.model_dir)
                    # NOTE(review): duplicate call below — looks redundant;
                    # confirm before removing.
                    save_model(self.netG, avg_param_G, self.netsD,
                               count, self.model_dir)
                    # Save images: temporarily swap in the EMA weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)
                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []
            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        # NOTE(review): duplicate final save — looks redundant; confirm
        # before removing.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save a whole batch as one normalized grid image."""
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save each image in the batch as an individual PNG."""
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # range from [-1, 1] to [0, 1]
            # NOTE(review): the next assignment is immediately overwritten
            # (dead code) — the second line does the whole conversion.
            img = (images[i] + 1.0) / 2
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # range from [0, 1] to [0, 255]
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate cfg.TEST.SAMPLE_NUM images from a trained generator.

        Restores cfg.TRAIN.NET_G, derives the output folder from the
        checkpoint path/iteration, and saves either super-image grids or
        single images depending on cfg.TEST.B_EXAMPLE.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)
            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
            if cfg.TEST.B_EXAMPLE:
                folder = '%s/super' % (save_dir)
            else:
                folder = '%s/single' % (save_dir)
            print('Make a new folder: ', folder)
            mkdir_p(folder)
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()
            # switch to evaluate mode
            netG.eval()
            num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
            cnt = 0
            # NOTE: xrange — this file targets Python 2.
            for step in xrange(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                    # self.save_singleimages(fake_imgs[-2], folder, 128)
                    # self.save_singleimages(fake_imgs[-3], folder, 64)
                cnt += self.batch_size
class QRNN(object):
    """Language-model trainer over indexed sequences.

    Wraps a ``Net`` language model, trains it with per-sequence
    cross-entropy (next-token prediction), logs the loss to TensorBoard,
    and pickles the vocabulary/data object alongside the saved model.
    """

    def __init__(self, data):
        # Training-only artifacts (model path, TF event logs).
        if cfg.TRAIN.FLAG:
            self.model_dir = cfg.NET
            self.log_dir = cfg.TRAIN.LOG_DIR
            self.summary_writer = FileWriter(self.log_dir)
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.gpu = int(cfg.GPU_ID)
            torch.cuda.set_device(self.gpu)
            print("Use CUDA")
        self.epoch = cfg.TRAIN.NUM_EPOCH
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.vocab_size = data.vocab_size
        self.embedding_size = cfg.EMBEDDING_SIZE
        self.index2word = data.index2word
        self.seqs = data.indexed_seqs
        self.net = self.load_network()
        self.data = data

    def load_network(self):
        """Build the network; restore weights from cfg.NET when not training.

        BUGFIX: the checkpoint was previously loaded into a local variable
        but never applied to the network — load_state_dict was missing.
        """
        net = Net(self.vocab_size)
        if self.cuda:
            net = net.cuda()
        if not cfg.TRAIN.FLAG:
            state_dict = torch.load(cfg.NET,
                                    map_location=lambda storage, loc: storage)
            net.load_state_dict(state_dict)
            print("Load from", cfg.NET)
        return net

    def train(self):
        """Train for cfg.TRAIN.NUM_EPOCH epochs and persist model + data."""
        net = self.net
        criterion = nn.CrossEntropyLoss()
        lr = cfg.TRAIN.LEARNING_RATE
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0, 0.99))
        count = 0
        # NOTE(review): start_idx never advances, so prepare_batch is always
        # called with offset 0 — confirm prepare_batch handles its own
        # sampling/rotation.
        start_idx = 0
        iteration = len(self.seqs) // self.batch_size + 1
        for epoch in range(self.epoch):
            for i in tqdm(range(iteration)):
                x, lengths = prepare_batch(self.seqs, start_idx,
                                           self.batch_size)
                if self.cuda:
                    x = Variable(torch.LongTensor(x)).cuda()
                else:
                    x = Variable(torch.LongTensor(x))
                logits = net(x)
                # Per-sequence next-token cross-entropy: position t predicts
                # token t+1, truncated at the sequence's true length.
                loss = []
                for j in range(logits.size(0)):
                    logit = logits[j][:lengths[j] - 1]
                    target = x[j][1:lengths[j]]
                    loss.append(criterion(logit, target))
                loss = sum(loss) / len(loss)
                net.zero_grad()
                loss.backward()
                optimizer.step()
                summary_net = summary.scalar("Loss", loss.data[0])
                self.summary_writer.add_summary(summary_net, count)
                count += 1
            print("epoch %d done. train loss: %f" % (epoch + 1, loss.data[0]))
            if count % cfg.TRAIN.LR_DECAY_INTERVAL == 0:
                # NOTE(review): rebuilding the optimizer discards Adam state;
                # betas also change from (0, 0.99) to (0.9, 0.999) here —
                # confirm this is intentional.
                lr = lr * 0.95
                optimizer = optim.Adam(net.parameters(), lr=lr,
                                       betas=(0.9, 0.999))
                print("decayed learning rate")
            save_model(net, self.model_dir)
        # BUGFIX: was `f.close` (attribute access, never called) — the file
        # was never closed. Use a context manager instead.
        with open("output/Data/data.pkl", "wb") as f:
            pickle.dump(self.data, f)

    def predict(self, seq):
        pass
class GANTrainer(object):
    """Two-stage text-to-image GAN trainer with per-object bounding boxes.

    Extends the StackGAN scheme with spatial transformation matrices derived
    from object bounding boxes and one-hot object labels (80 classes + 1
    padding class), conditioning both generator and discriminator on object
    layout. Stage selection is driven by cfg.STAGE.
    """

    def __init__(self, output_dir):
        # Training-only artifacts (checkpoints, images, TF event logs).
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # NOTE(review): unlike the sibling trainers, batch size is NOT
        # multiplied by num_gpus here — confirm intended.
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build stage-I G and D; checkpoints store G under key "netG"."""
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        """Build stage-II G (wrapping stage-I G) and stage-II D.

        Requires either a full stage-II checkpoint (cfg.NET_G) or a
        pretrained stage-I generator (cfg.STAGE1_G); otherwise returns None.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D
        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = torch.load(cfg.STAGE1_G,
                                    map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            return
        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1, max_objects=3):
        """Alternating D/G training with bbox-conditioned losses.

        Per batch: build forward/inverse spatial transformation matrices
        from bounding boxes (per stage), one-hot encode object labels, then
        update D and G (G loss includes cfg.TRAIN.COEFF.KL * KL term).
        Logs scalars and dumps sample images every 500 steps and at the end
        of each epoch; snapshots every `snapshot_interval` epochs.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()
        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # with torch.no_grad():
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        # Only optimize G parameters that require grad.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        # Stage II carries boxes for both resolutions.
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()
                # Forward/inverse affine matrices from the (N*max_objects, 4)
                # flattened boxes, reshaped back to (N, max_objects, 2, 3).
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(_bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = \
                        compute_transformation_matrix_inverse(_bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting
                _labels[_labels < 0] = 80
                # NOTE(review): torch.cuda.FloatTensor hard-requires CUDA —
                # this path breaks on CPU-only runs.
                label_one_hot = torch.cuda.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                # _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise, transf_matrices_inv, label_one_hot)
                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot,
                                                   transf_matrices,
                                                   transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot,
                                                   transf_matrices_s2,
                                                   transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # retain_graph so the G update below can reuse this graph.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs,
                                                  real_labels, label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv,
                                                  mu, self.gpus)
                elif cfg.STAGE == 2:
                    errG = compute_generator_loss(netD, fake_imgs,
                                                  real_labels, label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2,
                                                  mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()
                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    # save the image result for each epoch
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv,
                                      transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
            # End-of-epoch image dump with the last batch's conditioning.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        #
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, num_samples=25, stage=1, draw_bbox=True,
               max_objects=3):
        """Visualize generated samples next to the ground-truth image.

        For each of `num_samples` random validation captions: generate 9
        samples conditioned on the caption embedding / boxes / labels, stack
        them with the real image into a 10-image grid, optionally draw the
        bounding boxes, and save as <save_dir>/<caption>.png.
        """
        from PIL import Image, ImageDraw, ImageFont
        import cPickle as pickle
        import torchvision
        import torchvision.utils as vutils
        img_dir = cfg.IMG_DIR
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()
        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath + "val_captions.t7")
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        label, bbox = load_validation_data(datapath)
        filepath = os.path.join(datapath, 'filenames.pickle')
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_visualize_bbox"
        print("saving to:", save_dir)
        mkdir_p(save_dir)
        if cfg.CUDA:
            if cfg.STAGE == 1:
                bbox = bbox.cuda()
            elif cfg.STAGE == 2:
                bbox = [bbox.clone().cuda(), bbox.cuda()]
            label = label.cuda()
        #######################################
        # Keep an untransformed copy for drawing boxes later.
        if cfg.STAGE == 1:
            bbox_ = bbox.clone()
        elif cfg.STAGE == 2:
            bbox_ = bbox[0].clone()
        if cfg.STAGE == 1:
            bbox = bbox.view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(bbox)
            transf_matrices_inv = transf_matrices_inv.view(
                num_embeddings, max_objects, 2, 3)
        elif cfg.STAGE == 2:
            _bbox = bbox[0].view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(_bbox)
            transf_matrices_inv = transf_matrices_inv.view(
                num_embeddings, max_objects, 2, 3)
            _bbox = bbox[1].view(-1, 4)
            transf_matrices_inv_s2 = \
                compute_transformation_matrix_inverse(_bbox)
            transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                num_embeddings, max_objects, 2, 3)
            transf_matrices_s2 = compute_transformation_matrix(_bbox)
            transf_matrices_s2 = transf_matrices_s2.view(
                num_embeddings, max_objects, 2, 3)
        # produce one-hot encodings of the labels
        _labels = label.long()
        # remove -1 to enable one-hot converting
        _labels[_labels < 0] = 80
        label_one_hot = torch.cuda.FloatTensor(num_embeddings, max_objects,
                                               81).fill_(0)
        label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()
        #######################################
        nz = cfg.Z_DIM
        # 9 generated samples per caption (+1 real image = 10-image grid).
        noise = Variable(torch.FloatTensor(9, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        imsize = 64 if stage == 1 else 256
        for count in range(num_samples):
            index = int(np.random.randint(0, num_embeddings, 1))
            key = filenames[index]
            img_name = img_dir + "/" + key + ".jpg"
            img = Image.open(img_name).convert('RGB').resize(
                (imsize, imsize), Image.ANTIALIAS)
            val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 3, imsize, imsize)
            # Map [0, 1] -> [-1, 1] to match generator output range.
            val_image = (val_image - 0.5) * 2
            embeddings_batch = embeddings[index]
            transf_matrices_inv_batch = transf_matrices_inv[index]
            label_one_hot_batch = label_one_hot[index]
            # Tile the single caption's conditioning 9x for the sample batch.
            embeddings_batch = np.reshape(embeddings_batch,
                                          (1, 1024)).repeat(9, 0)
            transf_matrices_inv_batch = transf_matrices_inv_batch.view(
                1, 3, 2, 3).repeat(9, 1, 1, 1)
            label_one_hot_batch = label_one_hot_batch.view(1, 3, 81).repeat(
                9, 1, 1)
            if cfg.STAGE == 2:
                transf_matrices_s2_batch = transf_matrices_s2[index]
                transf_matrices_s2_batch = transf_matrices_s2_batch.view(
                    1, 3, 2, 3).repeat(9, 1, 1, 1)
                transf_matrices_inv_s2_batch = transf_matrices_inv_s2[index]
                transf_matrices_inv_s2_batch = \
                    transf_matrices_inv_s2_batch.view(
                        1, 3, 2, 3).repeat(9, 1, 1, 1)
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                label_one_hot_batch = label_one_hot_batch.cuda()
                txt_embedding = txt_embedding.cuda()
            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            # inputs = (txt_embedding, noise, transf_matrices_inv_batch, label_one_hot_batch)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          label_one_hot_batch)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          transf_matrices_s2_batch,
                          transf_matrices_inv_s2_batch, label_one_hot_batch)
            with torch.no_grad():
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
            data_img = torch.FloatTensor(10, 3, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:10] = fake_imgs
            if draw_bbox:
                # Draw each object's box as 1-valued (white) border pixels on
                # all 10 images; x <= -1 marks a padding (absent) object.
                for idx in range(3):
                    x, y, w, h = tuple(
                        [int(imsize * x) for x in bbox_[index, idx]])
                    w = imsize - 1 if w > imsize - 1 else w
                    h = imsize - 1 if h > imsize - 1 else h
                    if x <= -1:
                        break
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y + h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1
            vutils.save_image(data_img,
                              '{}/{}.png'.format(save_dir,
                                                 captions_list[index]),
                              normalize=True, nrow=10)
        print("Saved {} files to {}".format(count + 1, save_dir))
class RecurrentGANTrainer:
    """Trainer for a recurrent text-to-image GAN.

    The generator unrolls over time from an initial hidden state ``h0`` and
    emits a sequence of images; a single discriminator scores the final image.
    Losses are logged to TensorBoard via ``self.summary_writer``.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Output sub-directories are only needed (and created) when training.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'model')
            self.image_dir = os.path.join(output_dir, 'image')
            self.log_dir = os.path.join(output_dir, 'log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Effective batch is per-GPU batch size times the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Unpack one loader batch and move tensors to GPU when cfg.CUDA.

        Returns (cpu imgs, real images, wrong/mismatched images, text
        embedding, caption tensors, caption lengths); the caption entries stay
        empty lists when the loader provides no caption tensors.
        """
        imgs, w_imgs, t_embedding, _, caption_tensors, len_vector = data
        v_caption_tensors = []
        v_len_vector = []
        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
            if caption_tensors is not None:
                v_caption_tensors = Variable(caption_tensors).cuda()
                v_len_vector = len_vector.cuda()
        else:
            vembedding = Variable(t_embedding)
            if caption_tensors is not None:
                v_caption_tensors = Variable(caption_tensors)
                v_len_vector = len_vector
        for i in range(len(imgs)):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding, \
            v_caption_tensors, v_len_vector

    def train_Dnet(self, count):
        """One discriminator update on the final generated frame.

        Returns the (conditional + optional unconditional) discriminator loss.
        Logs 'D_loss' to TensorBoard every 100 iterations.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        netD = self.netD
        optD = self.optimizerD
        real_imgs = self.real_imgs[0]
        wrong_imgs = self.wrong_imgs[0]
        fake_imgs = self.fake_imgs[-1]  # Take only the last image
        netD.zero_grad()
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # Calculating the logits (mu is detached: D must not update the
        # conditioning-augmentation branch of G).
        mu = self.mus[-1]
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        # Calculating the error
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                wrong_logits[1], fake_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                fake_logits[1], fake_labels)
            errD_real += errD_real_uncond
            errD_wrong += errD_wrong_uncond
            errD_fake += errD_fake_uncond
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # Calculating the gradients
        errD.backward()
        # Backproping
        optD.step()
        if flag == 0:
            # BUGFIX: tag was 'D_loss%d' with no % argument, so the literal
            # string "D_loss%d" was logged; there is only one D here.
            summary_D = summary.scalar('D_loss', errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """One generator update across every unrolled time-step (BPTT).

        Returns (kl_loss, total generator loss). Logs per-step 'G_loss<i>'
        every 100 iterations.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        mus, logvars = self.mus, self.logvars
        real_labels = self.real_labels[:batch_size]
        # Looping through each time-step.
        for i in range(len(self.fake_imgs)):
            logits = self.netD(self.fake_imgs[i], mus[i])
            errG = criterion(logits[0], real_labels)
            if len(logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                    logits[1], real_labels)
                errG += errG_uncond
            errG_total += errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)
        # KL term of the conditioning augmentation, summed over time-steps.
        kl_loss = 0
        for i in range(len(self.fake_imgs)):
            kl_loss += KL_loss(mus[i], logvars[i]) * cfg.TRAIN.COEFF.KL
        errG_total += kl_loss
        # Compute the gradients
        errG_total.backward()
        # BPTT
        self.optimizerG.step()
        return kl_loss, errG_total

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize, mean=0):
        """Save each image of a batch as an individual PNG under save_dir."""
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i]+'_'+str(mean))
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def train(self):
        """Main training loop: alternate D/G updates with EMA of G weights,
        periodic snapshots, image dumps and inception-score logging.

        NOTE(review): uses old-PyTorch idioms (`Variable`, `errD.data[0]`) —
        assumes PyTorch < 0.5; confirm the runtime before modernizing.
        """
        self.netG, self.netD, self.inception_model, start_count = load_network(
            self.gpus)
        avg_param_G = copy_G_params(self.netG)
        self.optimizerG, self.optimizerD = define_optimizers(
            self.netG, self.netD)
        self.criterion = nn.BCELoss()
        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        # self.gradient_one = torch.FloatTensor([1.0])
        # self.gradient_half = torch.FloatTensor([0.5])
        # Initial Hidden State: h0 is re-sampled each batch; h0_initalized is
        # a fixed draw used only for the periodic image dumps.
        h0 = Variable(
            torch.FloatTensor(self.batch_size, 1, cfg.HIDDEN_STATE_SIZE,
                              cfg.HIDDEN_STATE_SIZE))
        h0_initalized = Variable(
            torch.FloatTensor(self.batch_size, 1, cfg.HIDDEN_STATE_SIZE,
                              cfg.HIDDEN_STATE_SIZE).normal_(0, 1))
        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            h0 = h0.cuda()
            h0_initalized = h0_initalized.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            for step, data in enumerate(self.data_loader, 0):
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embeddings, self.caption_tensors, \
                    self.len_vector = self.prepare_data(data)

                # 1. Generate Fake Data from Generator
                h0.data.normal_(0, 1)
                self.fake_imgs, self.mus, self.logvars = self.netG(
                    h0, self.txt_embeddings)

                # 2. Update Discriminator
                errD_total = self.train_Dnet(count)

                # 3. Update Generator
                kl_loss, errG_total = self.train_Gnet(count)

                # Exponential moving average of G's parameters (decay 0.999),
                # used for snapshots/visualization.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Collect softmax predictions for the inception score.
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count += 1
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netD, count,
                               self.model_dir)
                    # Save Images (with EMA weights temporarily swapped in)
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    self.fake_imgs, _, _ = self.netG(h0_initalized,
                                                     self.txt_embeddings)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, count,
                                     self.image_dir, self.summary_writer)
                    load_params(self.netG, backup_para)
                    # Compute Inception Score once enough predictions piled up
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        score_summary = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(score_summary, count)
                        mean_nlpp, std_nlpp = negative_log_posterior_probability(
                            predictions, 10)
                        mean_nlpp_summary = summary.scalar(
                            'NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(
                            mean_nlpp_summary, count)
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d--%d] Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs ''' %
                  (epoch, self.max_epoch, self.num_batches, count,
                   errD_total.data[0], errG_total.data[0], kl_loss.data[0],
                   end_t - start_t))

        save_model(self.netG, avg_param_G, self.netD, count, self.model_dir)
        self.summary_writer.close()

    def evaluate(self, split_dir):
        """Load the saved generator and dump one image per batch to disk.

        Images go to '<model dir>/iteration<N>' next to cfg.TRAIN.NET_G.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: Could not find the saved Generator Model.')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = Generator()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            state_dict = torch.load(cfg.TRAIN.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Loaded weights to Generator Network.', cfg.TRAIN.NET_G)

            # the path to save generated images; iteration number is parsed
            # out of the checkpoint filename ('..._<iter>.pth').
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            h0 = Variable(
                torch.FloatTensor(self.batch_size, 1, cfg.INITIAL_IMAGE_SIZE,
                                  cfg.INITIAL_IMAGE_SIZE))
            if cfg.CUDA:
                netG.cuda()
                h0 = h0.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                h0.data.normal_(0, 1)
                fake_imgs, _, _ = netG(h0, t_embeddings)
                self.save_singleimages(fake_imgs[-1], filenames, save_dir,
                                       split_dir, 1, 32)
class GANTrainer(object):
    """Two-stage (StackGAN-style) text-to-image GAN trainer.

    Stage I generates low-resolution images from text embeddings; Stage II
    refines them. ``train`` runs the adversarial loop, ``sample`` generates
    images from pre-computed sentence embeddings (torchfile format).
    """

    def __init__(self, output_dir):
        # Output sub-directories are only needed (and created) when training.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # Effective batch is per-GPU batch size times the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build Stage-I G and D, optionally restoring weights from cfg paths.

        Returns (netG, netD), moved to GPU when cfg.CUDA.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            #state_dict = \
            #torch.load(cfg.NET_G,
            #           map_location=lambda storage, loc: storage)
            # NOTE(review): hard-codes 'cuda:0' for restore — assumes a GPU
            # machine; confirm before running CPU-only.
            state_dict = torch.load(cfg.NET_G, map_location='cuda:0')
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        """Build Stage-II G (wrapping a Stage-I G) and Stage-II D.

        Requires either cfg.NET_G (full Stage-II weights) or cfg.STAGE1_G
        (Stage-I weights only); otherwise prints an error and returns None.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Adversarial training loop for the given stage (1 or 2).

        Alternates one D step and one G step per batch, halves both learning
        rates every cfg.TRAIN.LR_DECAY_EPOCH epochs, logs losses every 100
        batches and snapshots models every ``snapshot_interval`` epochs.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is sampled once and reused for the per-epoch image dump.
        with torch.no_grad():
            #Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
            #volatile=True)
            fixed_noise = \
                torch.FloatTensor(batch_size, nz).normal_(0, 1)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize G parameters that require grad (Stage-II freezes the
        # embedded Stage-I generator).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates on the decay schedule.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            #print('dataLoader, line 156 trainer.py...........')
            #print(data_loader)
            num_batches = len(data_loader)
            print('Number of batches: ' + str(len(data_loader)))
            for i, data in enumerate(data_loader, 0):
                print('Epoch number: ' + str(epoch) + '\tBatches: ' +
                      str(i) + '/' + str(num_batches), end='\r')
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                #print(txt_embedding.shape) #(Batch_size,1024)
                #exit(0)
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()
                #print('train line 170')

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                #print('Fake images generated shape = ' + str(fake_imgs.shape))
                #print('Shape of fake image: ' + str(fake_imgs.shape)) [Batch_size, Channels(3), N, N]
                #print('Fake images: ')
                #Display one image ### Check this line! How to display image?? ##############
                #plt.imshow(fake_imgs[0].permute(1,2,0).cpu().detach().numpy())
                #exit(0)
                ################################################

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                #print('train line 186')
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()

                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()
                #print('train line 203')
                count = count + 1
                if i % 100 == 0:
                    """
                    summary_D = summary.scalar('D_loss', errD.data[0])
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    """
                    ## My lines
                    # Updated for newer PyTorch: .data instead of .data[0].
                    summary_D = summary.scalar('D_loss', errD.data)
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data)
                    summary_KL = summary.scalar('KL_loss', kl_loss.data)
                    #### End of my lines
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

            # save the image result for each epoch (uses the fixed noise so
            # progress is visually comparable across epochs)
            inputs = (txt_embedding, fixed_noise)
            lr_fake, fake, _, _ = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)
            del inputs
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' %
                  (epoch, self.max_epoch, i, len(data_loader),
                   errD.data, errG.data, kl_loss.data,
                   errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            # % (epoch, self.max_epoch, i, len(data_loader),
            #    errD.data[0], errG.data[0], kl_loss.data[0],
            #    errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            print('################EPOCH COMPLETED###########')
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        # save_model(netG, netD, self.max_epoch, self.model_dir)
        # self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate images for pre-computed sentence embeddings.

        Loads a torchfile at ``datapath`` (fields: raw_txt, fea_txt) and
        writes PNGs next to cfg.NET_G; caps generation at ~3000 sentences.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        #with open(datapath, 'rb') as f:
        #    embeddings = pickle.load(f)
        #embeddings = np.array(embeddings)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        #num_embeddings = len(embeddings)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            # Clamp the final (possibly partial) batch to the tail window.
            if iend > num_embeddings:
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            # Convert each image from [-1, 1] to uint8 HWC and save.
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
class condGANTrainer(object):
    """Conditional GAN trainer with one discriminator per image scale
    (StackGAN++-style): self.netsD / self.optimizersD are indexed by scale,
    and G maintains an exponential moving average for snapshots.
    """

    def __init__(self, output_dir, data_loader, imsize):
        # Output sub-directories are only needed (and created) when training.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Effective batch is per-GPU batch size times the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Unpack one batch and move per-scale image lists to GPU when
        cfg.CUDA. Returns (cpu imgs, real images, wrong/mismatched images,
        text embedding). NOTE(review): relies on self.num_Ds, which is set
        in train() — call order matters.
        """
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """One update of discriminator ``idx`` (its own image scale).

        Returns the discriminator loss; logs 'D_loss<idx>' every 100 iters.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real; mu is detached so D does not update G's conditioning.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            # Unconditional branch: wrong images are real photos (only the
            # text is mismatched), hence real_labels here.
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data)
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """One generator update against all discriminators, plus optional
        color-consistency losses between adjacent scales and the KL term.

        Returns (kl_loss, total generator loss).
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data)
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses between adjacent image scales
        # (mean/covariance of the finer image should match the coarser one).
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    # NOTE(review): still uses old-style `.data[0]` while the
                    # rest of this class was migrated to `.data` — this will
                    # raise on 0-dim tensors in PyTorch >= 0.5; confirm the
                    # target PyTorch version.
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    # NOTE(review): same `.data[0]` caveat as above.
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Main loop: per batch, update every D then G, keep an EMA of G's
        weights, and periodically snapshot models, dump images and log the
        inception score.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        print("Models loaded")
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        # fixed_noise is reused for every image dump so snapshots compare.
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        print("Start epoches!")
        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            print("Epoch {:d} started".format(epoch))
            start_t = time.time()
            start_iters = time.time()
            N_TOTAL_BATCH = len(self.data_loader)
            # data_iterator = iter(self.data_loader)
            for step, data in enumerate(self.data_loader, 0):
                # for step in range(N_TOTAL_BATCH): #enumerate(self.data_loader, 0):
                # print("Iter {:d}".format(step))
                # data = next(data_iterator)
                if step % 50 == 0:
                    curr_time = time.time()
                    print("Itteration {:d}/{:d}\nTime for 50 batch steps:{:.2f}\nTotal time:{:.2f}".format(
                        step, N_TOTAL_BATCH, curr_time - start_iters,
                        curr_time-start_t))
                    start_iters = time.time()
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                #print("Data prepeared!")
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                #print("Fake image generated")
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                #print("D network updated")
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of G's parameters (decay 0.999).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())
                #print("G network uddated")
                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data)
                    summary_G = summary.scalar('G_loss', errG_total.data)
                    summary_KL = summary.scalar('KL_loss', kl_loss.data)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images (swap the EMA weights in, render, swap back)
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions,
                                                               10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs '''
                  # D(real): %.4f D(wrong):%.4f D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data, errG_total.data,
                     kl_loss.data, end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir,
                         split_dir, imsize):
        """Tile one image per sentence into a single 'super' PNG per sample."""
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir,
                          split_dir, sentenceID, imsize):
        """Save each image of a batch as an individual PNG under save_dir."""
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Load the saved G_NET and render images per sentence embedding.

        With cfg.TEST.B_EXAMPLE, collects the third (finest listed) output per
        sentence and saves tiled super-images; otherwise saves the finest
        output of each sentence individually at 256px.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            #if split_dir == 'test':
            #    split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images; iteration number is parsed
            # out of the checkpoint filename ('..._<iter>.pth').
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                # One generator pass per sentence embedding of the sample.
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
class GANTrainer(object):
    """Trainer for the stage-I GAN.

    Builds the networks, runs the adversarial training loop, writes
    scalar summaries/snapshot checkpoints, and can dump qualitative
    sample sheets via :meth:`sample`.
    """

    def __init__(self, output_dir):
        # Output sub-directories are only needed when training.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        # Fixed number of objects (digits) per image for this dataset.
        self.max_objects = 3

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build STAGE1 G/D, optionally restore checkpoints, move to GPU.

        Returns the (netG, netD) pair.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            # NOTE(review): the G checkpoint stores a dict with a "netG"
            # key while the D checkpoint below is a raw state dict —
            # looks intentional but confirm against save_model().
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader):
        """Run the full adversarial training loop over `data_loader`."""
        netG, netD = self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is kept for reproducible visualizations.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimize G parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        print("Starting training...")
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label = data
                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    bbox = bbox.cuda()
                    label_one_hot = label.cuda().float()

                # One (2, 3) affine matrix per object per image.
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(bbox).float()
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], self.max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox).float()
                transf_matrices = transf_matrices.view(
                    real_imgs.shape[0], self.max_objects, 2, 3)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, transf_matrices_inv, label_one_hot)
                # _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)
                _, fake_imgs = netG(noise, transf_matrices_inv, label_one_hot)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices,
                                               transf_matrices_inv,
                                               self.gpus)
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices,
                                              transf_matrices_inv,
                                              self.gpus)
                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Log results
                ###########################
                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

            # Save the image result for each epoch.
            # BUG FIX: the original executed this identical block twice in
            # a row, re-generating and re-writing the exact same images;
            # the duplicate has been removed.
            with torch.no_grad():
                inputs = (noise, transf_matrices_inv, label_one_hot)
                lr_fake, fake = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                real_img_cpu = pad_imgs(real_img_cpu)
                fake = pad_imgs(fake)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)

            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD,
                           epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, num_samples=25, draw_bbox=True,
               num_digits_per_img=3, change_bbox_size=False):
        """Generate qualitative sample sheets from a trained generator.

        Each output image holds: the real validation image, nine samples
        drawn with different noise, and a strip rendering the digit
        labels.  Extra digits (num_digits_per_img > 3) get random labels
        and random boxes; change_bbox_size randomly rescales one box.
        """
        from PIL import Image, ImageDraw, ImageFont
        import cPickle as pickle
        import torchvision
        import torchvision.utils as vutils

        img_dir = os.path.join(datapath, "normal", "imgs/")
        netG, _ = self.load_network_stageI()
        netG.eval()
        # NOTE(review): assumes the validation set has exactly 10000
        # entries — confirm against load_validation_data().
        test_set_size = 10000
        label, bbox = load_validation_data(datapath)

        if num_digits_per_img < 3:
            label = label[:, :num_digits_per_img, :]
            bbox = bbox[:, :num_digits_per_img, ...]
        elif num_digits_per_img > 3:
            def get_one_hot(targets, nb_classes):
                # One-hot encode an integer array along a new last axis.
                res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
                return res.reshape(list(targets.shape) + [nb_classes])

            # Random labels for the extra digits.
            labels_sample = np.random.randint(
                0, 10, size=(bbox.shape[0], num_digits_per_img - 3))
            labels_sample = get_one_hot(labels_sample, 10)
            labels_new = np.zeros((label.shape[0], num_digits_per_img, 10))
            labels_new[:, :3, :] = label
            labels_new[:, 3:, :] = labels_sample
            label = torch.from_numpy(labels_new)

            # Random normalized boxes (x, y in [0, 1); w, h in pixels/64).
            bboxes_x = np.random.random(
                (bbox.shape[0], num_digits_per_img - 3, 1))
            bboxes_y = np.random.random(
                (bbox.shape[0], num_digits_per_img - 3, 1))
            bboxes_w = np.random.randint(
                10, 20, size=(bbox.shape[0], num_digits_per_img - 3, 1)) / 64.0
            bboxes_h = np.random.randint(
                16, 20, size=(bbox.shape[0], num_digits_per_img - 3, 1)) / 64.0
            bbox_new_concat = np.concatenate(
                (bboxes_x, bboxes_y, bboxes_w, bboxes_h), axis=2)
            bbox_new = np.zeros([bbox.shape[0], num_digits_per_img, 4])
            bbox_new[:, :3, :] = bbox
            bbox_new[:, 3:, :] = bbox_new_concat
            bbox = torch.from_numpy(bbox_new)

        if change_bbox_size:
            # Shrink one randomly chosen box by a factor in [0.5, 1).
            bbox_idx = np.random.randint(0, bbox.shape[1])
            scale_x = np.random.random(bbox.shape[0])
            scale_x[scale_x < 0.5] = 0.5
            scale_y = np.random.random(bbox.shape[0])
            scale_y[scale_y < 0.5] = 0.5
            bbox[:, bbox_idx, 2] *= torch.from_numpy(scale_x)
            bbox[:, bbox_idx, 3] *= torch.from_numpy(scale_y)

        filepath = os.path.join(datapath, "normal", 'filenames.pickle')
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)

        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + \
            "_samples_" + str(num_digits_per_img) + "_digits"
        if change_bbox_size:
            save_dir += "_change_bbox_size"
        print("Saving {} to {}:".format(num_samples, save_dir))
        mkdir_p(save_dir)

        # NOTE(review): label_one_hot is only defined on the CUDA path;
        # the CPU path would raise NameError below — confirm intended.
        if cfg.CUDA:
            bbox = bbox.cuda()
            label_one_hot = label.cuda().float()

        #######################################
        bbox_ = bbox.clone()
        bbox = bbox.view(-1, 4)
        transf_matrices_inv = \
            compute_transformation_matrix_inverse(bbox).float()
        transf_matrices_inv = transf_matrices_inv.view(
            test_set_size, num_digits_per_img, 2, 3)
        #######################################

        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(9, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64
        for count in range(num_samples):
            index = int(np.random.randint(0, test_set_size, 1))
            key = filenames[index].split("/")[-1]
            img_name = img_dir + key
            img = Image.open(img_name)
            val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 1, imsize, imsize)
            # Map [0, 1] to the generator's [-1, 1] range.
            val_image = (val_image - 0.5) * 2

            transf_matrices_inv_batch = transf_matrices_inv[index]
            label_one_hot_batch = label_one_hot[index]
            # Repeat the conditioning nine times: one per noise sample.
            transf_matrices_inv_batch = transf_matrices_inv_batch.view(
                1, num_digits_per_img, 2, 3).repeat(9, 1, 1, 1)
            label_one_hot_batch = label_one_hot_batch.view(
                1, num_digits_per_img, 10).repeat(9, 1, 1)
            if cfg.CUDA:
                label_one_hot_batch = label_one_hot_batch.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, transf_matrices_inv_batch,
                      label_one_hot_batch, num_digits_per_img)
            _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

            # 20 tiles: real image, 9 fakes, 10 label-text tiles.
            data_img = torch.FloatTensor(20, 1, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:10] = fake_imgs

            if draw_bbox:
                for idx in range(num_digits_per_img):
                    x, y, w, h = tuple(
                        [int(imsize * x) for x in bbox_[index, idx]])
                    w = imsize - 1 if w > imsize - 1 else w
                    h = imsize - 1 if h > imsize - 1 else h
                    # Shift the box left/up until it fits in the image.
                    while x + w >= 64:
                        x -= 1
                        w -= 1
                    while y + h >= 64:
                        y -= 1
                        h -= 1
                    if x <= -1:
                        break
                    # Draw a one-pixel rectangle on the first ten tiles.
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y + h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1

            # write digit identities into image
            text_img = Image.new('L', (imsize * 10, imsize), color='white')
            d = ImageDraw.Draw(text_img)
            label = label_one_hot_batch[0]
            label = label.cpu().numpy()
            label = np.argmax(label, axis=1)
            label = ", ".join(
                [str(label[_]) for _ in range(num_digits_per_img)])
            d.text((10, 10), label)
            text_img = torchvision.transforms.functional.to_tensor(text_img)
            text_img = torch.chunk(text_img, 10, 2)
            text_img = torch.cat(
                [text_img[i].view(1, 1, imsize, imsize) for i in range(10)],
                0)
            data_img[10:] = text_img
            vutils.save_image(data_img,
                              '{}/vis_{}.png'.format(save_dir, count),
                              normalize=True, nrow=10)
            print("Saved {} files to {}".format(count + 1, save_dir))
def evaluate(iter_cnt, filepath, model, corpus, args, logging=True):
    """Score `model` on the positive/negative pair files derived from
    `filepath` and report AUC at several false-positive-rate caps.

    Args:
        iter_cnt: global step used when writing summaries.
        filepath: prefix for "<filepath>.pos.txt" / "<filepath>.neg.txt".
        model: network exposing embedding_layer, compute_similarity and
            a `criterion` string attribute.
        corpus: maps raw loader items to model inputs via corpus.get.
        args: run configuration (run_path, eval, batch_size, cuda, ...).
        logging: when True, write scalar summaries under
            "<run_path>/valid".

    Returns:
        The AUC restricted to FPR < 0.05.
    """
    if logging:
        valid_writer = FileWriter(args.run_path + "/valid", flush_secs=5)
    pos_file_path = "{}.pos.txt".format(filepath)
    neg_file_path = "{}.neg.txt".format(filepath)
    pos_batch_loader = FileLoader(
        [tuple([pos_file_path, os.path.dirname(args.eval)])],
        args.batch_size)
    neg_batch_loader = FileLoader(
        [tuple([neg_file_path, os.path.dirname(args.eval)])],
        args.batch_size)

    # Named function instead of a lambda bound to a name (PEP 8).
    def batchify(bch):
        return make_batch(model.embedding_layer, bch)

    model.eval()
    auc_meter = AUCMeter()
    # scores[0] collects negative-pair scores, scores[1] positive-pair.
    scores = [np.asarray([], dtype='float32') for _ in range(2)]
    # loader_id doubles as the target label: 0 = negative, 1 = positive.
    for loader_id, loader in tqdm(
            enumerate((neg_batch_loader, pos_batch_loader))):
        for data in tqdm(loader):
            data = map(corpus.get, data)
            batch = None
            if not args.eval_use_content:
                batch = (batchify(data[0][0]), batchify(data[1][0]))
            else:
                batch = (map(batchify, data[0]), map(batchify, data[1]))
                # Flatten the nested (left, right) content batches.
                new_batch = []
                for x in batch:
                    for y in x:
                        new_batch.append(y)
                batch = new_batch
            labels = torch.ones(batch[0].size(1)).type(
                torch.LongTensor) * loader_id
            if args.cuda:
                batch = [x.cuda() for x in batch]
                labels = labels.cuda()
            if not args.eval_use_content:
                batch = (Variable(batch[0], volatile=True),
                         Variable(batch[1], volatile=True))
            else:
                batch = (Variable(batch[0], volatile=True),
                         Variable(batch[1], volatile=True),
                         Variable(batch[2], volatile=True),
                         Variable(batch[3], volatile=True))
            labels = Variable(labels)
            if not args.eval_use_content:
                repr_left = model(batch[0])
                repr_right = model(batch[1])
            else:
                # Content mode: sum title and body representations.
                repr_left = model(batch[0]) + model(batch[1])
                repr_right = model(batch[2]) + model(batch[3])
            output = model.compute_similarity(repr_left, repr_right)
            if model.criterion.startswith('classification'):
                assert output.size(1) == 2
                output = nn.functional.log_softmax(output)
                # Negative log-probability of the true class.
                current_scores = -output[:, loader_id].data.cpu().squeeze(
                ).numpy()
                output = output[:, 1]
            else:
                assert output.size(1) == 1
                current_scores = output.data.cpu().squeeze().numpy()
            auc_meter.add(output.data, labels.data)
            scores[loader_id] = np.append(scores[loader_id], current_scores)

    auc_score = auc_meter.value()
    auc10_score = auc_meter.value(0.1)
    auc05_score = auc_meter.value(0.05)
    auc02_score = auc_meter.value(0.02)
    auc01_score = auc_meter.value(0.01)
    if model.criterion.startswith('classification'):
        # Both entries are losses of the true class: average them.
        avg_score = (scores[1].mean() + scores[0].mean()) * 0.5
    else:
        # Similarity scores: positive pairs should outscore negatives.
        avg_score = scores[1].mean() - scores[0].mean()
    outputManager.say(
        "\r[{}] auc(.01): {:.3f} auc(.02): {:.3f} auc(.05): {:.3f}"
        " auc(.1): {:.3f} auc: {:.3f}"
        " scores: {:.2f} ({:.2f} {:.2f})\n".format(
            os.path.basename(filepath).split('.')[0], auc01_score,
            auc02_score, auc05_score, auc10_score, auc_score, avg_score,
            scores[1].mean(), scores[0].mean()))
    if logging:
        s = summary.scalar('auc', auc_score)
        valid_writer.add_summary(s, iter_cnt)
        s = summary.scalar('auc (fpr<0.1)', auc10_score)
        valid_writer.add_summary(s, iter_cnt)
        s = summary.scalar('auc (fpr<0.05)', auc05_score)
        valid_writer.add_summary(s, iter_cnt)
        s = summary.scalar('auc (fpr<0.02)', auc02_score)
        valid_writer.add_summary(s, iter_cnt)
        s = summary.scalar('auc (fpr<0.01)', auc01_score)
        valid_writer.add_summary(s, iter_cnt)
        valid_writer.close()
    return auc05_score
class condGANTrainer(object):
    """Semi-supervised conditional GAN trainer with an auxiliary hashing
    loss, driven by a labelled loader and an unlabelled loader."""

    def __init__(self, output_dir, label_loader, unlabel_loader):
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            # Piece-wise activation half-width; decayed during training.
            self.theta = 0.5
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.label_loader = label_loader
        self.unlabel_loader = unlabel_loader
        self.num_batches = len(self.unlabel_loader)
        self.label_num_batches = len(self.label_loader)

    def prepare_data(self, data):
        """Unpack a loader batch and wrap tensors as (CUDA) Variables.

        Returns (raw imgs, per-branch image Variables, labels,
        label-vector embedding).
        """
        imgs, labels, label_vectors = data
        real_vimgs = []
        if cfg.CUDA:
            vembedding = Variable(label_vectors).cuda()
            labels = Variable(labels).cuda()
        else:
            vembedding = Variable(label_vectors)
            labels = Variable(labels)
        for i in range(self.num_Ds):
            print("images_{}:{}".format(i, imgs[i].size()))
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
        return imgs, real_vimgs, labels, vembedding

    def get_extend(self, labels):
        """Append a zero column to a batch of label rows.

        NOTE(review): the zero column is never moved to CUDA, so this
        would fail on GPU inputs — appears unused; confirm before use.
        """
        ll = torch.zeros(self.batch_size, 1)
        extend_label = torch.cat((labels, Variable(ll)), 1)
        return extend_label

    def piece_wise(self, logits, count):
        """Piece-wise linear squashing of `logits` around 0.5.

        Values below (0.5 - theta) map to 0, above (0.5 + theta) to 1,
        and pass through unchanged in between.
        """
        logits_1 = logits > (0.5 - self.theta)
        logits_1 = logits_1.type(torch.FloatTensor)
        if cfg.CUDA:
            logits_1 = logits_1.cuda()
        logits_2 = logits_1 * logits
        logits_3 = logits > (0.5 + self.theta)
        logits_3 = logits_3.type(torch.FloatTensor)
        if cfg.CUDA:
            logits_3 = logits_3.cuda()
        # Union of the saturated-high mask and the pass-through values.
        logits_4 = (logits_3 + logits_2) - (logits_3 * logits_2)
        return logits_4

    def log_sum_exp(self, value, dim=None, keepdim=True):
        """Numerically stable log(sum(exp(value))) along `dim`
        (or over all elements when dim is None)."""
        if dim is not None:
            m, _ = torch.max(value, dim=dim, keepdim=True)
            value0 = value - m
            if keepdim is False:
                m = m.squeeze(dim)
            return m + torch.log(
                torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
        else:
            m = torch.max(value)
            sum_exp = torch.sum(torch.exp(value - m))
            if isinstance(sum_exp, Number):
                return m + math.log(sum_exp)
            else:
                return m + torch.log(sum_exp)

    def train_Dnet(self, idx, count):
        """One optimization step for discriminator `idx`.

        Combines a supervised classification loss, an unsupervised
        real/fake adversarial loss, and a triplet-style hashing loss.
        Logs scalars every 25 steps; returns the total D loss.
        """
        flag = count % 25
        batch_size = cfg.TRAIN.BATCH_SIZE
        print("batch_size:%d" % batch_size)
        criterion = self.criterion
        netD, optD = self.netsD[idx], self.optimizersD[idx]
        label_imgs = self.label_real_imgs[idx]
        unlabel_imgs = self.unlabel_real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]          # generated w/ true labels
        fake_imgs_2 = self.fake_imgs_2[idx]      # generated w/ wrong labels
        netD.zero_grad()

        lab_labels = self.labels[:batch_size].type(torch.LongTensor)
        if cfg.CUDA:
            lab_labels = lab_labels.cuda()
        error_labels = self.error_labels[:batch_size]
        print("unlabel_imgs:{}".format(unlabel_imgs.size()))

        unlabel_logits, unlabel_softmax_out, unlabel_hash_logits, _ = netD(
            unlabel_imgs)
        label_logits, label_softmax_out, label_hash_logits, _ = netD(
            label_imgs)
        fake_logits, fake_softmax_out, fake_hash_logits, _ = netD(
            fake_imgs.detach())
        fake2_logits, fake2_softmax_out, fake2_hash_logits, _ = netD(
            fake_imgs_2.detach())

        # standard classfication loss
        lab_loss = criterion(label_logits, lab_labels)
        fake_lab_loss = criterion(fake_logits, lab_labels)
        # BUG FIX: originally scored fake_logits against error_labels,
        # but fake_imgs_2 are the images generated from the wrong labels.
        fake2_lab_loss = criterion(fake2_logits, error_labels)
        supvised_loss = (lab_loss + fake_lab_loss + fake2_lab_loss) / 3

        # GAN true-fake loss adversary stream
        unl_logsumexp = self.log_sum_exp(unlabel_logits, 1)
        fake_logsumexp = self.log_sum_exp(fake_logits, 1)
        fake2_logsumexp = self.log_sum_exp(fake2_logits, 1)
        true_loss = -0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(
            F.softplus(unl_logsumexp))
        fake_loss = 0.5 * torch.mean(F.softplus(fake_logsumexp))
        fake2_loss = 0.5 * torch.mean(F.softplus(fake2_logsumexp))
        adversary_loss = (true_loss + fake_loss + fake2_loss) / 3

        # Triplet hash loss with margin 1: real/fake-with-same-label pairs
        # should be closer than real/fake-with-wrong-label pairs.
        print("label_hash:{},fake_hash{}".format(label_hash_logits.size(),
                                                 fake_logits.size()))
        positive = torch.sum((label_hash_logits - fake_hash_logits)**2, 1)
        negtive = torch.sum((label_hash_logits - fake2_hash_logits)**2, 1)
        hash_loss = 1 + positive - negtive
        hash_loss_temp = hash_loss > 0
        hash_loss_temp = hash_loss_temp.type(torch.FloatTensor)
        if cfg.CUDA:
            hash_loss_temp = hash_loss_temp.cuda()
        # Hinge: only positive-margin terms contribute.
        hash_loss = torch.mean(hash_loss * hash_loss_temp)

        d_total_loss = supvised_loss + adversary_loss + hash_loss
        print("d_supervied_loss_{0}:{1}".format(idx, supvised_loss.data[0]))
        print("d_adversary_loss_{0}:{1}".format(idx, adversary_loss.data[0]))
        print("d_hash_loss_{0}:{1}".format(idx, hash_loss.data[0]))
        print("d_total_loss_{0}:{1}".format(idx, d_total_loss.data[0]))

        # update parameters
        d_total_loss.backward()
        optD.step()

        # log
        if flag == 0:
            summary_D = summary.scalar('D_supervised%d' % idx,
                                       supvised_loss.data[0])
            summary_D1 = summary.scalar('D_hash_loss_%d' % idx,
                                        hash_loss.data[0])
            summary_D2 = summary.scalar('D_total_loss_%d' % idx,
                                        d_total_loss.data[0])
            summary_D3 = summary.scalar('D_adversary_loss_%d' % idx,
                                        adversary_loss.data[0])
            self.summary_writer.add_summary(summary_D, count)
            self.summary_writer.add_summary(summary_D1, count)
            self.summary_writer.add_summary(summary_D2, count)
            self.summary_writer.add_summary(summary_D3, count)
        return d_total_loss

    def train_Gnet(self, count):
        """One optimization step for the generator across all branches.

        Mirrors train_Dnet's losses but subtracts the adversary term
        (the generator wants fakes to look real).  Returns the total.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 25
        batch_size = cfg.TRAIN.BATCH_SIZE
        print("batch_size:%d" % batch_size)
        criterion = self.criterion
        for i in range(self.num_Ds):
            netD = self.netsD[i]
            label_imgs = self.label_real_imgs[i]
            unlabel_imgs = self.unlabel_real_imgs[i]
            fake_imgs = self.fake_imgs[i]
            fake_imgs_2 = self.fake_imgs_2[i]

            lab_labels = self.labels[:batch_size].type(torch.LongTensor)
            if cfg.CUDA:
                lab_labels = lab_labels.cuda()
            error_labels = self.error_labels[:batch_size]

            label_logits, label_softmax_out, label_hash_logits, _ = netD(
                label_imgs)
            unlabel_logits, unlabel_softmax_out, unlabel_hash_logits, _ = \
                netD(unlabel_imgs)
            # Fakes are NOT detached here: gradients must reach netG.
            fake_logits, fake_softmax_out, fake_hash_logits, _ = netD(
                fake_imgs)
            fake2_logits, fake2_softmax_out, fake2_hash_logits, _ = netD(
                fake_imgs_2)

            # standard classfication loss
            lab_loss = criterion(label_logits, lab_labels)
            fake_lab_loss = criterion(fake_logits, lab_labels)
            # BUG FIX: originally used fake_logits; fake_imgs_2 carry the
            # wrong labels, so fake2_logits must be scored here.
            fake2_lab_loss = criterion(fake2_logits, error_labels)
            supvised_loss = (lab_loss + fake_lab_loss + fake2_lab_loss) / 3

            # GAN true-fake loss adversary stream
            # NOTE(review): unlike train_Dnet, no dim is passed, so
            # log_sum_exp reduces over the whole tensor — confirm intended.
            unl_logsumexp = self.log_sum_exp(unlabel_logits)
            fake_logsumexp = self.log_sum_exp(fake_logits)
            fake2_logsumexp = self.log_sum_exp(fake2_logits)
            true_loss = -0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(
                F.softplus(unl_logsumexp))
            fake_loss = 0.5 * torch.mean(F.softplus(fake_logsumexp))
            fake2_loss = 0.5 * torch.mean(F.softplus(fake2_logsumexp))
            adversary_loss = (true_loss + fake_loss + fake2_loss) / 3

            # loss for hash (triplet hinge, margin 1)
            positive = torch.sum((label_hash_logits - fake_hash_logits)**2, 1)
            negtive = torch.sum((label_hash_logits - fake2_hash_logits)**2, 1)
            hash_loss = 1 + positive - negtive
            hash_loss_temp = hash_loss > 0
            hash_loss_temp = hash_loss_temp.type(torch.FloatTensor)
            if cfg.CUDA:
                hash_loss_temp = hash_loss_temp.cuda()
            hash_loss = torch.mean(hash_loss * hash_loss_temp)

            # The generator minimizes the supervised and hash losses and
            # maximizes the adversary term, hence the subtraction.
            g_total_loss = supvised_loss - adversary_loss + hash_loss
            print("g_supervied_loss_{0}:{1}".format(i, supvised_loss.data[0]))
            print("g_adversary_loss_{0}:{1}".format(i,
                                                    adversary_loss.data[0]))
            print("g_hash_loss_{0}:{1}".format(i, hash_loss.data[0]))
            print("g_total_loss_{0}:{1}".format(i, g_total_loss.data[0]))

            errG_total = errG_total + g_total_loss
            print("g_loss_%d: %f" % (i, errG_total.data[0]))
            if flag == 0:
                summary_D = summary.scalar('G_supervised%d' % i,
                                           supvised_loss.data[0])
                summary_D1 = summary.scalar('G_hash_loss_%d' % i,
                                            hash_loss.data[0])
                summary_D2 = summary.scalar('G_total_loss_%d' % i,
                                            g_total_loss.data[0])
                summary_D3 = summary.scalar('G_adversary_loss_%d' % i,
                                            adversary_loss.data[0])
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D1, count)
                self.summary_writer.add_summary(summary_D2, count)
                self.summary_writer.add_summary(summary_D3, count)

        errG_total.backward(retain_graph=True)
        self.optimizerG.step()
        return errG_total

    def get_error_label(self, labels):
        """For each item, sample a class different from its true label.

        Returns (one-hot wrong-label matrix, wrong-label index vector),
        both wrapped as Variables (CUDA when configured).
        """
        class_num = cfg.GAN.CLASS_NUM
        batch_size = cfg.TRAIN.BATCH_SIZE
        error_labels_vector = torch.FloatTensor(batch_size, class_num).zero_()
        error_labels = []
        for i in range(batch_size):
            m = randint(0, class_num - 1)
            while m == labels[i]:
                # BUG FIX: the retry draw was hard-coded to randint(0, 9);
                # use the configured class count so class_num != 10 works.
                m = randint(0, class_num - 1)
            error_labels_vector[i][m] = 1
            error_labels.append(m)
        error_labels = torch.LongTensor(error_labels)
        if cfg.CUDA:
            error_labels_vector = Variable(error_labels_vector).cuda()
            error_labels = Variable(error_labels).cuda()
        else:
            error_labels_vector = Variable(error_labels_vector)
            error_labels = Variable(error_labels)
        return error_labels_vector, error_labels

    def adjust_lr(self, optimizer, epoch):
        """Linearly anneal the LR to zero over the last third of training
        (flat at the base rate before that)."""
        epoch_ratio = float(epoch) / float(cfg.TRAIN.MAX_EPOCH)
        lr = max(cfg.TRAIN.DISCRIMINATOR_LR *
                 min(3. * (1 - epoch_ratio), 1.), 0)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train(self):
        """Main semi-supervised training loop."""
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        avg_param_G = copy_G_params(self.netG)
        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)
        self.criterion = nn.CrossEntropyLoss()

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        noise_2 = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz))  # .normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            # BUG FIX: noise_2 was never moved to the GPU, causing a
            # device mismatch when generating fake_imgs_2 under CUDA.
            noise_2 = noise_2.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        # The labelled loader is smaller: cycle it manually.
        label_iter = iter(self.label_loader)
        cnt = 0
        print("label_num_batch:{}".format(self.label_num_batches))
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            for step, unlabel_data in enumerate(self.unlabel_loader, 0):
                print("epoch:%d, step:%d" % (epoch, step))
                #######################################################
                # (0) Prepare training data
                ######################################################
                if cnt == self.label_num_batches:
                    label_iter = iter(self.label_loader)
                    cnt = 0
                cnt += 1
                label_data = label_iter.next()
                self.label_imgs_tcpu, self.label_real_imgs, \
                    self.labels, self.label_vectors = self.prepare_data(
                        label_data)
                self.unlabel_imgs_tcpu, self.unlabel_real_imgs, \
                    _, self.unlabel_vectors = self.prepare_data(unlabel_data)
self.error_label_vector, self.error_labels = self.get_error_label( label_data[1]) ####################################################### # (1) Generate fake images ###################################################### noise.data.normal_(0, 1) noise_2.data.normal_(0, 1) # self.fake_imgs = \ self.netG(noise, self.label_vectors) self.fake_imgs_2 = self.netG(noise_2, self.error_label_vector) ####################################################### # (2) Update D network ###################################################### errD_total = 0 for i in range(self.num_Ds): errD = self.train_Dnet(i, count) errD_total += errD ####################################################### # (3) Update G network: maximize log(D(G(z))) ###################################################### self.train_Gnet(count) for p, avg_p in zip(self.netG.parameters(), avg_param_G): avg_p.mul_(0.999).add_(0.001, p.data) errG_total = self.train_Gnet(count) for p, avg_p in zip(self.netG.parameters(), avg_param_G): avg_p.mul_(0.999).add_(0.001, p.data) """ # for inception score pred = self.inception_model(self.fake_imgs[-1].detach()) predictions.append(pred.data.cpu().numpy()) """ if count % 25 == 0: summary_D = summary.scalar('D_loss', errD_total.data[0]) summary_G = summary.scalar('G_loss', errG_total.data[0]) # summary_KL = summary.scalar('KL_loss', kl_loss.data[0]) self.summary_writer.add_summary(summary_D, count) self.summary_writer.add_summary(summary_G, count) # self.summary_writer.add_summary(summary_KL, count) count = count + 1 if count % 3000 == 0: self.theta = self.theta * 0.8 print("theta:%f, count:%d" % (self.theta, count)) if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir) # Save images backup_para = copy_G_params(self.netG) load_params(self.netG, avg_param_G) # fixed_noise.data.normal_(0, 1) self.fake_imgs = \ self.netG(fixed_noise, self.label_vectors) save_img_results(self.label_imgs_tcpu, self.fake_imgs, self.num_Ds, 
                        count, self.image_dir, self.summary_writer)
                    # NOTE(review): the line above closes a save_img_results(...)
                    # call whose opening line sits above this chunk; indentation
                    # here is reconstructed — confirm against the full file.
                    # load_params(self.netG, backup_para)
                    #
                    # Compute inception score
                    # if len(predictions) > 500:
                    #     predictions = np.concatenate(predictions, 0)
                    #     mean, std = compute_inception_score(predictions, 10)
                    #     # print('mean:', mean, 'std', std)
                    #     m_incep = summary.scalar('Inception_mean', mean)
                    #     self.summary_writer.add_summary(m_incep, count)
                    #     #
                    #     mean_nlpp, std_nlpp = \
                    #         negative_log_posterior_probability(predictions, 10)
                    #     m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                    #     self.summary_writer.add_summary(m_nlpp, count)
                    #     #
                    #     predictions = []
                    print(
                        "____________________________________________________________"
                    )
            end_t = time.time()

            # Decay the learning rate of G and of every D branch once per epoch.
            self.adjust_lr(self.optimizerG, epoch)
            for i in range(cfg.TREE.BRANCH_NUM):
                self.adjust_lr(self.optimizersD[i], epoch)

            # Per-epoch summary. NOTE(review): `.data[0]` is the pre-0.4 PyTorch
            # scalar accessor; modern PyTorch would use `.item()` (other parts of
            # this file already do) — TODO unify when porting.
            print('''[%d/%d][%d]
                     Loss_D: %.2f Loss_G: %.2f Time: %.2fs
                  '''  # D(real): %.4f D(wrong):%.4f D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data[0], errG_total.data[0],
                     end_t - start_t))

        # Final snapshot after the last epoch, then flush TensorBoard events.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        self.summary_writer.close()

    def save_superimages(self, images_list, filenames,
                         save_dir, split_dir, imsize):
        """Save, per sample, one grid image stacking the generated image for
        every sentence/caption under ``save_dir/super/split_dir/``.

        images_list: list (one entry per sentence) of image batch tensors;
        each batch is indexed by sample ``i`` and reshaped to (1, 3, imsize,
        imsize) before being concatenated into a single grid.
        """
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' % \
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            # nrow=10: ten images per row in the saved grid.
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir,
                          split_dir, sentenceID, imsize):
        """Save each image of the batch as an individual PNG under
        ``save_dir/single_samples/split_dir/``, tagged with its size and the
        sentence index it was generated from.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' % \
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # CHW -> HWC for PIL.
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Load the generator checkpoint named by cfg.TRAIN.NET_G and generate
        images for every caption embedding in ``self.data_loader``, saving
        either per-caption grids (cfg.TEST.B_EXAMPLE) or single images.
        """
        if cfg.TRAIN.NET_G == '':
            # NOTE(review): "morels" is presumably a typo for "models" in this
            # user-facing message — left unchanged here.
            print('Error: the path for morels is not found!')
        else:
            # Build and load the generator
            if split_dir == 'image_test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            # map_location keeps the load CPU-safe regardless of where the
            # checkpoint was saved.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images: derive the iteration number
            # from the checkpoint filename (e.g. ".../netG_70000.pth").
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)
                # print(t_embeddings[:, 0, :], t_embeddings.size(1))

                # embedding_dim = number of captions per image; one generator
                # pass is run per caption.
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                # Last batch may be smaller: resize the reusable noise buffer.
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        # fake_img_list.append(fake_imgs[0].data.cpu())
                        # fake_img_list.append(fake_imgs[1].data.cpu())
                        # Only the highest-resolution branch output is kept.
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                        # self.save_singleimages(fake_imgs[-2], filenames,
                        #                        save_dir, split_dir, i, 128)
                        # self.save_singleimages(fake_imgs[-3], filenames,
                        #                        save_dir, split_dir, i, 64)
                    # break
                if cfg.TEST.B_EXAMPLE:
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 64)
                    # self.save_superimages(fake_img_list, filenames,
                    #                       save_dir, split_dir, 128)
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)

    def get_hash(self, dataloader, netsD, dataset_name, save_steps):
        """Run every discriminator branch over ``dataloader`` and periodically
        dump accumulated hash codes, images, and labels as .npy files under
        ``eval/dataset_name``.

        Every ``save_steps`` batches the accumulators are written out
        (numbered chunk ``cnt``) and reset, bounding memory use.
        """
        hash_dict = {}
        #imgs_total = None
        label_total = None
        output_dir = os.path.join("eval", dataset_name)
        # NOTE(review): reads self.unlabel_loader, which is not set in the
        # visible __init__ — presumably assigned elsewhere; verify.
        print("len:%d" % len(self.unlabel_loader))
        #save_steps = len(self.unlabel_loader) / 10
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        for step, data in enumerate(dataloader, 0):
            imgs_tcpu, real_imgs, \
                labels, label_vectors = self.prepare_data(data)
            for i in range(cfg.TREE.BRANCH_NUM):
                net = netsD[i]
                real_img = real_imgs[i]
                # Discriminator returns (real_logits, softmax, hash_logits, _).
                real_logits, softmax_out, hash_logits, _ = net(real_img)
                #hash_logits = hash_logits > 0.5
                if i in hash_dict:
                    hash_dict[i] = torch.cat((hash_dict[i], hash_logits), 0)
                else:
                    hash_dict[i] = hash_logits
            if label_total is None:
                label_total = labels
                imgs_total = real_imgs[0]
            else:
                label_total = torch.cat((label_total, labels), 0)
                imgs_total = torch.cat((imgs_total, real_imgs[0]), 0)
            if (step + 1) % save_steps == 0:
                # NOTE(review): Py2-style "/"; under Py3 cnt is a float and the
                # "%d" formats below would need int(cnt) — TODO when porting.
                cnt = (step + 1) / save_steps
                output_img = os.path.join(
                    output_dir, "%s_images_%d.npy" % (dataset_name, cnt))
                output_label = os.path.join(
                    output_dir, "%s_label_%d.npy" % (dataset_name, cnt))
                if cfg.CUDA:
                    np.save(output_img, imgs_total.cpu().data.numpy())
                    np.save(output_label, label_total.cpu().data.numpy())
                else:
                    np.save(output_img, imgs_total.data.numpy())
                    np.save(output_label, label_total.data.numpy())
                for i in range(cfg.TREE.BRANCH_NUM):
                    output_hash = os.path.join(
                        output_dir,
                        "branch_%d_hash_%s_%d.npy" % (i, dataset_name, cnt))
                    if cfg.CUDA:
                        np.save(output_hash, hash_dict[i].cpu().data.numpy())
                    else:
                        np.save(output_hash, hash_dict[i].data.numpy())
                # Reset accumulators for the next chunk.
                imgs_total = None
                label_total = None
                hash_dict = {}
                print("step %d done!" % step)
            print("step %d done!" % step)

    def load_Dnet(self, gpus):
        """Build one discriminator per tree branch, load each branch's
        snapshot from '<cfg.TRAIN.NET_D>netD<i>_70000.pth', and return the
        list (CUDA-moved when cfg.CUDA). Exits the process if NET_D is unset.
        """
        if cfg.TRAIN.NET_D == '':
            print('Error: the path for morels is not found!')
            sys.exit(-1)
        else:
            netsD = []
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
            if cfg.TREE.BRANCH_NUM > 3:
                netsD.append(D_NET512())
            if cfg.TREE.BRANCH_NUM > 4:
                netsD.append(D_NET1024())
            self.num_Ds = len(netsD)
            for i in range(len(netsD)):
                netsD[i].apply(weights_init)
                netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
            for i in range(cfg.TREE.BRANCH_NUM):
                print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i))
                # NOTE(review): the iteration number 70000 is hard-coded here
                # while the log line above suggests '%s_%d.pth' — confirm the
                # intended checkpoint naming scheme.
                state_dict = torch.load(
                    '%snetD%d_70000.pth' % (cfg.TRAIN.NET_D, i),
                    map_location=lambda storage, loc: storage)
                netsD[i].load_state_dict(state_dict)
        if cfg.CUDA:
            for i in range(len(netsD)):
                netsD[i].cuda()
        return netsD

    def get_numpy(self, files):
        """Load every .npy file in ``files`` and concatenate the arrays along
        axis 0; returns None for an empty file list.
        """
        arr = None
        print(files)
        for f in files:
            a = np.load(f)
            #print("a:{0}".format(a.shape))
            if arr is not None:
                arr = np.concatenate((arr, a), axis=0)
            else:
                arr = a
        return arr

    def compute_MAP_sklearn(self, test_features, db_features, test_label,
                            db_label, metric='euclidean'):
        """Compute retrieval MAP with scikit-learn: rank the database by
        ``metric`` distance to each query, score with
        average_precision_score, and write the MAP to '<metric>_result.txt'.
        """
        # Pairwise distances: rows = queries, cols = database entries.
        Y = cdist(test_features, db_features, metric)
        ind = np.argsort(Y, axis=1)
        prec_total = 0.0
        recall_total = None
        precision_total = None
        for k in range(np.shape(test_features)[0]):
            # Database labels in ascending-distance order for query k.
            class_values = db_label[ind[k, :]]
            y_true = (test_label[k] == class_values)
            # Descending scores encode the ranking for sklearn.
            y_scores = np.arange(y_true.shape[0], 0, -1)
            ap = average_precision_score(y_true, y_scores)
            prec_total += ap
            if recall_total is None:
                precision_total, recall_total, r = precision_recall_curve(
                    y_true, y_scores)
                print("precision_shape:{}".format(precision_total.shape))
                print("recall_shape:{}".format(recall_total.shape))
            else:
                precision, recall, r = precision_recall_curve(y_true, y_scores)
                #if k % 100 == 0:
                #print(r)
                #print("precision_shape:{}".format(precision.shape))
                #print("recall_shape:{}".format(recall.shape))
                #precision_total = precision_total + precision
                #recall_total = recall_total + recall
        # NOTE(review): `precision`/`recall` are only bound inside the `else`
        # branch above, so with fewer than two queries this raises NameError;
        # the magic index 5800 also assumes a particular database size.
        print(precision[5800:])
        print(recall[:20])
        test_num = test_features.shape[0]
        print("test_num:{}".format(test_num))
        MAP = prec_total / test_num
        #recall_total = [i/test_num for i in recall_total]
        #precision_total = [i / test_num for i in precision_total]
        print("MAP: %f" % MAP)
        #print("recall:{0}".format(recall_total[:30]))
        # print("precision:{0}".format(precision_total[:30]))
        with open(metric + "_result.txt", 'w') as f:
            f.write("MAP:%f\n" % MAP)
            #f.write("recall\n:{0}".format(recall_total))
            #f.write("precision\n:{0}".format(precision_total))
        #np.save("recall.npy", recall_total)
        #np.save("precision.npy", precision_total)

    def compute_MAP(self, root, branch, query_labels, db_labels):
        """Evaluate one discriminator branch: load its saved hash chunks for
        the 'test' (query) and 'db' splits, run compute_MAP_sklearn on both
        the raw features and the binarized (>0.5) hashes, and return
        (db_hashs, query_hashs) for later branch averaging.
        """
        query_npys = os.listdir(os.path.join(root, "test"))
        db_npys = os.listdir(os.path.join(root, "db"))
        # Sorting keeps chunk files in a deterministic concatenation order.
        query_npys.sort()
        db_npys.sort()
        query_hash_path = [
            os.path.join(root, "test", i) for i in query_npys
            if "branch_" + str(branch) in i
        ]
        db_hash_path = [
            os.path.join(root, "db", i) for i in db_npys
            if "branch_" + str(branch) in i
        ]
        query_hashs = self.get_numpy(query_hash_path)
        db_hashs = self.get_numpy(db_hash_path)
        assert db_labels.shape[0] == db_hashs.shape[
            0], "db labels num must be equal to db hash num"
        assert query_labels.shape[0] == query_hashs.shape[
            0], "db labels num must be equal to db hash num"
        print(
            "-----------------------------------use features----------------------"
        )
        self.compute_MAP_sklearn(query_hashs, db_hashs, query_labels,
                                 db_labels)
        print(
            "-----------------------------------use hash--------------------------"
        )
        query_h = query_hashs > 0.5
        db_h = db_hashs > 0.5
        self.compute_MAP_sklearn(query_h, db_h, query_labels, db_labels)
        return db_hashs, query_hashs
        # (A long commented-out hand-rolled MAP / precision@k / recall-curve
        # implementation — superseded by compute_MAP_sklearn above — used to
        # follow here and has been dropped as dead code.)

    def compute_MAP_hash(self, root, paths, branch):
        """Hand-rolled MAP / precision@k / recall-curve evaluation over binary
        hash codes, using query-vs-query Hamming distances, writing results to
        'branch_<branch>_eval'.

        NOTE(review): several apparent defects, left untouched here —
        * ``self.get_numpy(root, path)`` passes two args but get_numpy takes
          one (besides self); this would raise TypeError if called.
        * the assert compares query_hash.shape[0] with itself (always true);
          it presumably meant to compare against query_labels.
        * Py2 integer division ("/") is used for list sizes and indices; under
          Py3 these become floats and indexing fails.
        * ``self.write`` is not defined in the visible part of the file.
        """
        query_label_path, query_hash_path = paths
        query_labels = self.get_numpy(root, query_label_path)
        query_hash = self.get_numpy(root, query_hash_path)
        print("total test number:%d" % query_hash.shape[0])
        assert query_hash.shape[0] == query_hash.shape[
            0], "query hash size not equal to query label size"
        # Binarize to {0, 1} integers.
        query_hash = (query_hash > 0.5) + 0
        query_labels = query_labels.tolist()
        recall_num = 100
        # Hamming distance via dot products: r counts agreeing bits
        # (matching 1s plus matching 0s), so distance = HASH_DIM - r.
        r_1 = np.matmul(query_hash, query_hash.T)
        r_2 = np.matmul(1 - query_hash, (1 - query_hash).T)
        r = r_1 + r_2
        hamming_distance = cfg.GAN.HASH_DIM - r
        hamming_distance_list = hamming_distance.tolist()
        query_hamming = zip(query_labels, hamming_distance_list)
        MAP = 0
        cnt = 0
        precision_topk = [0 for i in range(len(query_labels) / 10 + 1)]
        recall_curve = [0 for i in range(len(query_labels) / 25 + 1)]
        precision_curve = [0 for i in range(len(query_labels) / 25 + 1)]
        for q_label, hamming_dis in query_hamming:
            # Rank all queries by Hamming distance to the current query.
            hamming_label = zip(query_labels, hamming_dis)
            sorted_by_hamming = sorted(hamming_label, key=lambda m: m[1])
            smi = 0.0
            count = 0
            ap = 0
            for label, hamming in sorted_by_hamming:
                count += 1
                if label == q_label:
                    smi += 1
                    ap += smi / count
                if count % 10 == 0:
                    precision_topk[count / 10] += smi / count
                if count % 25 == 0:
                    recall_curve[count / 25] += smi / recall_num
                    precision_curve[count / 25] += smi / count
            # print("ap:%f"%(ap / smi))
            MAP += ap / smi
            cnt += 1
            if cnt % 100 == 0:
                print("has process %d queries" % cnt)
        MAP = MAP / len(query_labels)
        precision_topk = [i / len(query_labels) for i in precision_topk]
        recall_curve = [i / len(query_labels) for i in recall_curve]
        precision_curve = [i / len(query_labels) for i in precision_curve]
        print("MAP_BRANCH_%d: %f" % (branch, MAP))
        print("precision_top:{0}".format(precision_topk))
        print("recall curve:{0}".format(recall_curve))
        print("precision curve:{0}".format(precision_curve))
        # NOTE(review): plain open/close with no try/finally — a `with` block
        # would be safer, but is a behavior change left for a separate fix.
        f = open("branch_%d_eval" % branch, 'w')
        f.write("MAP: %f\n" % (MAP))
        f.write("precision_top:\n")
        self.write(f, precision_topk, "precision_top")
        self.write(f, recall_curve, "recall_curve")
        self.write(f, precision_curve, "precision_curve")
        f.close()

    def evaluate_MAP(self, db_dataloader, query_dataloader, root):
        """Full retrieval evaluation: (re)generate hash dumps for the query
        ('test') and database ('db') splits if their directories are empty,
        evaluate each branch via compute_MAP, then evaluate the
        branch-averaged features and their binarized hashes.
        """
        netsD = self.load_Dnet(self.gpus)
        if len(os.listdir(os.path.join(root, "test"))) == 0:
            # NOTE(review): Py2 "/" — save_steps must be an int for the
            # modulo check inside get_hash; Py3 would need "//".
            save_steps = len(query_dataloader) / 10
            self.get_hash(query_dataloader, netsD, "test", save_steps)
            print("get query image hash")
        if len(os.listdir(os.path.join(root, "db"))) == 0:
            save_steps = len(db_dataloader) / 10
            self.get_hash(db_dataloader, netsD, "db", save_steps)
            print("get db image hash!")
        eval_path = root
        db_hashs = None
        query_hashs = None
        query_npys = os.listdir(os.path.join(root, "test"))
        query_npys.sort()
        query_label_path = [
            os.path.join(root, "test", i) for i in query_npys if "label" in i
        ]
        db_npys = os.listdir(os.path.join(root, "db"))
        db_npys.sort()
        db_label_path = [
            os.path.join(root, "db", i) for i in db_npys if "label" in i
        ]
        query_labels = self.get_numpy(query_label_path)
        db_labels = self.get_numpy(db_label_path)
        for i in range(cfg.TREE.BRANCH_NUM):
            print(
                "--------------------------branch %d-------------------------------------------"
                % i)
            db_hash, query_hash = self.compute_MAP(eval_path, i, query_labels,
                                                   db_labels)
            # Accumulate per-branch hashes for averaging below.
            if db_hashs is None:
                db_hashs = db_hash
                query_hashs = query_hash
            else:
                db_hashs += db_hash
                query_hashs += query_hash
        db_hashs /= cfg.TREE.BRANCH_NUM
        query_hashs /= cfg.TREE.BRANCH_NUM
        print(
            "--------------------------------total--------------------------------------------"
        )
        print(
            "-----------------------------------use features----------------------"
        )
        self.compute_MAP_sklearn(query_hashs, db_hashs, query_labels,
                                 db_labels)
        print(
            "-----------------------------------use hash--------------------------"
        )
        db_hashs = db_hashs > 0.5
        query_hashs = query_hashs > 0.5
        self.compute_MAP_sklearn(query_hashs, db_hashs, query_labels,
                                 db_labels, metric="hamming")
class GANTrainer(object):
    """Two-stage StackGAN trainer (stage-I 64x64, stage-II 256x256):
    builds the networks from cfg, trains with TensorBoard summaries, and
    can sample images from precomputed text embeddings.
    """

    def __init__(self, output_dir):
        # Model/Image/Log subdirectories and the summary writer are only
        # created in training mode.
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # Effective batch size scales with the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build stage-I G and D, optionally restoring both from cfg.NET_G /
        cfg.NET_D checkpoints; returns (netG, netD, start_epoch), where
        start_epoch is parsed from the checkpoint filename.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        start_epoch = 0
        if cfg.NET_G != '':
            # Resume epoch is encoded between the last '_' and the extension
            # (e.g. "netG_epoch_120.pth" -> 121).
            istart = cfg.NET_G.rfind('_') + 1
            iend = cfg.NET_G.rfind('.')
            start_epoch = cfg.NET_G[istart:iend]
            start_epoch = int(start_epoch) + 1
        return netG, netD, start_epoch

    # ############# For training stageII GAN #############
    def load_network_stageII(self):
        """Build stage-II G (wrapping a stage-I G) and stage-II D, restoring
        weights from cfg.NET_G or cfg.STAGE1_G; returns
        (netG, netD, start_epoch).

        NOTE(review): when neither NET_G nor STAGE1_G is set this prints an
        error and returns None, which callers unpacking three values will
        crash on — presumably intentional fail-fast, but worth confirming.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            # Only the embedded stage-I sub-network is restored in this case.
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        start_epoch = 0
        if cfg.NET_G != '':
            istart = cfg.NET_G.rfind('_') + 1
            iend = cfg.NET_G.rfind('.')
            start_epoch = cfg.NET_G[istart:iend]
            start_epoch = int(start_epoch) + 1
        return netG, netD, start_epoch

    def load_optimizers(self, netG, netD):
        """Create Adam optimizers for D and for G's trainable parameters;
        when resuming (cfg.NET_G set) also restore optimizer state from
        optimizerG.pth / optimizerD.pth next to the G checkpoint.
        """
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        # Stage-II freezes part of G, so only requires_grad params are
        # handed to the optimizer.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        if cfg.NET_G != '':
            Gname = cfg.NET_G
            s_tmp = Gname[:Gname.rfind('/')]
            oGname = '%s/optimizerG.pth' % (s_tmp)
            state_dict = \
                torch.load(oGname,
                           map_location=lambda storage, loc: storage)
            optimizerG.load_state_dict(state_dict)
            print('Load optimizerG from: ', oGname)
            oDname = '%s/optimizerD.pth' % (s_tmp)
            state_dict = \
                torch.load(oDname,
                           map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict)
            print('Load optimizerD from: ', oDname)
        return optimizerG, optimizerD

    def train(self, data_loader, stage=1):
        """Adversarial training loop for the selected stage: alternates D and
        G updates per batch, halves both learning rates every
        cfg.TRAIN.LR_DECAY_EPOCH epochs, logs scalar summaries and sample
        images every 100 iterations, and snapshots models periodically.
        """
        if stage == 1:
            netG, netD, start_epoch = self.load_network_stageI()
        else:
            netG, netD, start_epoch = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise keeps the periodically-saved sample images comparable
        # across epochs.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerG, optimizerD = self.load_optimizers(netG, netD)
        count = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.float().cuda()

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()

                ######################################################
                # (2) Update G network
                ######################################################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                # KL term regularizes the conditioning-augmentation
                # distribution.
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch (uses fixed_noise
                    # so successive dumps are comparable).
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        # Stage-II also emits the low-resolution stage-I
                        # output.
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
                save_optimizer(optimizerG, optimizerD, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate one image per text embedding in the .npy file at
        ``datapath`` (batched by self.batch_size) and save them as
        '<save_dir>/<index>.png' next to the cfg.NET_G checkpoint.
        """
        if stage == 1:
            netG, _, _ = self.load_network_stageI()
        else:
            netG, _, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        embeddings = np.load(datapath)
        num_embeddings = embeddings.shape[0]
        print('Successfully load sentences from: ', datapath)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            # if count > 3000:
            #     break
            iend = count + batch_size
            if iend > num_embeddings:
                iend = num_embeddings
            embeddings_batch = embeddings[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            #######################################################
            noise.data.normal_(0, 1)
            # Slice noise in case the final batch is short.
            inputs = (txt_embedding, noise[0:embeddings_batch.shape[0], :])
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(embeddings_batch.shape[0]):
                save_name = '%s/%d.png' % (save_dir, count + i + 1)
                im = fake_imgs[i].data.cpu().numpy()
                # [-1, 1] -> [0, 255]
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size

    def sample2(self, datapath, stage=2):
        """Like sample(), but also saves the stage-I low-resolution output;
        filenames are '<save_dir>/<width>_<index>.png', so the hi-res and
        lo-res versions are distinguished by the width prefix (presumably —
        verify the two resolutions always differ).
        """
        if stage == 1:
            netG, _, _ = self.load_network_stageI()
        else:
            netG, _, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        embeddings = np.load(datapath)
        num_embeddings = embeddings.shape[0]
        print('Successfully load sentences from: ', datapath)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            # if count > 3000:
            #     break
            iend = count + batch_size
            if iend > num_embeddings:
                iend = num_embeddings
            embeddings_batch = embeddings[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            #######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise[0:embeddings_batch.shape[0], :])
            lr_fake_imgs, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            # High-resolution (stage-II) outputs.
            for i in range(embeddings_batch.shape[0]):
                im = fake_imgs[i].data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                save_name = '%s/%d_%d.png' % (save_dir, im.size[0],
                                              count + i + 1)
                im.save(save_name)
            # Low-resolution (stage-I) outputs, same numbering scheme.
            for i in range(embeddings_batch.shape[0]):
                im = lr_fake_imgs[i].data.cpu().numpy()
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                save_name = '%s/%d_%d.png' % (save_dir, im.size[0],
                                              count + i + 1)
                im.save(save_name)
            count += batch_size
def train():
    '''train wgan

    WGAN training loop on MXNet Module API: weight-clipped critic (D) and
    generator (G) on a 64x64 DCGAN, alternating one D update and one G
    update per batch, checkpointing and dumping sample canvases every
    epoch, with optional TensorBoard image/scalar logging (use_tb).
    '''
    ctxs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    # hyper-parameters from CLI args
    batch_size = args.batch_size
    z_dim = args.z_dim
    lr = args.lr
    epoches = args.epoches
    wclip = args.wclip
    frequency = args.frequency
    model_prefix = args.model_prefix
    rand_iter = RandIter(batch_size, z_dim)
    image_iter = ImageIter(args.data_path, batch_size, (3, 64, 64))
    # G and D
    symG, symD = dcgan64x64(ngf=args.ngf, ndf=args.ndf, nc=args.nc)
    modG = mx.mod.Module(symbol=symG, data_names=('rand',),
                         label_names=None, context=ctxs)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.002))
    modG.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    # inputs_need_grad=True: D must backpropagate gradients to its input so
    # they can be fed into G's backward pass.
    modD = mx.mod.Module(symbol=symD, data_names=('data',),
                         label_names=None, context=ctxs)
    modD.bind(data_shapes=image_iter.provide_data, inputs_need_grad=True)
    modD.init_params(mx.init.Normal(0.002))
    modD.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    # train
    logging.info('Start training')
    metricD = WGANMetric()
    metricG = WGANMetric()
    # Fixed noise batch so per-epoch sample images are comparable.
    fix_noise_batch = mx.io.DataBatch(
        [mx.random.normal(0, 1, shape=(batch_size, z_dim, 1, 1))], [])
    # visualization with TensorBoard if possible
    if use_tb:
        writer = FileWriter('tmp/exp')
    for epoch in range(epoches):
        image_iter.reset()
        metricD.reset()
        metricG.reset()
        for i, batch in enumerate(image_iter):
            # clip weight (WGAN Lipschitz constraint), done in place on the
            # module's parameter arrays.
            # NOTE(review): reaches into the private _exec_group — fragile
            # across MXNet versions.
            for params in modD._exec_group.param_arrays:
                for param in params:
                    mx.nd.clip(param, -wclip, wclip, out=param)
            # forward G
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            # fake: critic score on generated images; gradient of +mean(score)
            fw_g = None
            modD.forward(mx.io.DataBatch(outG, label=[]), is_train=True)
            fw_g = modD.get_outputs()[0].asnumpy()
            modD.backward([mx.nd.ones((batch_size, 1)) / batch_size])
            # Stash the fake-pass gradients; the real pass below overwrites
            # grad_arrays, and the two are summed manually afterwards.
            gradD = [[grad.copyto(grad.context) for grad in grads]
                     for grads in modD._exec_group.grad_arrays]
            # real: critic score on data; gradient of -mean(score)
            modD.forward(batch, is_train=True)
            fw_r = modD.get_outputs()[0].asnumpy()
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            for grads_real, grads_fake in zip(modD._exec_group.grad_arrays,
                                              gradD):
                for grad_real, grad_fake in zip(grads_real, grads_fake):
                    grad_real += grad_fake
            modD.update()
            # Wasserstein estimate: E[D(real)] - E[D(fake)], negated here.
            errorD = -(fw_r - fw_g) / batch_size
            metricD.update(errorD.mean())
            # update G: fresh noise, push D(G(z)) up via D's input gradients.
            rbatch = rand_iter.next()
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            modD.forward(mx.io.DataBatch(outG, []), is_train=True)
            errorG = -modD.get_outputs()[0] / batch_size
            modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
            modG.backward(modD.get_input_grads())
            modG.update()
            metricG.update(errorG.asnumpy().mean())
            # logging state
            if (i+1) % frequency == 0:
                print("epoch:", epoch+1, "iter:", i+1,
                      "G: ", metricG.get(), "D: ", metricD.get())
        # save checkpoint
        modG.save_checkpoint('model/%s-G' % (model_prefix), epoch+1)
        modD.save_checkpoint('model/%s-D' % (model_prefix), epoch+1)
        # Per-epoch sample canvases: one from fresh noise, one from the
        # fixed batch.
        rbatch = rand_iter.next()
        modG.forward(rbatch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-rand-%d.png' % (epoch+1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]  # BGR -> RGB
            writer.add_summary(
                summary.image('gout-rand-%d' % (epoch+1), canvas))
        modG.forward(fix_noise_batch)
        outG = modG.get_outputs()[0]
        canvas = visual('tmp/gout-fix-%d.png' % (epoch+1), outG.asnumpy())
        if use_tb:
            canvas = canvas[:, :, ::-1]
            writer.add_summary(
                summary.image('gout-fix-%d' % (epoch+1), canvas))
    if use_tb:
        writer.flush()
        writer.close()