def train(self, data_loader, stage=1):
    """Train the conditional GAN on single images with text embeddings.

    Alternates one discriminator and one generator update per batch,
    adds a KL regularizer on the conditioning-augmentation distribution,
    logs scalar summaries and sample images every 100 steps, and
    snapshots the models every ``self.snapshot_interval`` epochs.

    Args:
        data_loader: iterable yielding ``(real_img_cpu, txt_embedding)``
            batches (exact tensor shapes depend on the dataset — not
            visible from this method).
        stage: 1 loads the Stage-I networks, otherwise Stage-II.
    """
    if stage == 1:
        netG, netD, start_epoch = self.load_network_stageI()
    else:
        netG, netD, start_epoch = self.load_network_stageII()
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    # `noise` is re-sampled every step; `fixed_noise` stays constant so the
    # periodically saved sample images are comparable across epochs.
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerG, optimizerD = self.load_optimizers(netG, netD)
    count = 0  # global step counter used as the x-axis for summaries
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every `lr_decay_step` epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.float().cuda()
            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()
            ######################################################
            # (4) Update G network, with the conditioning-augmentation
            #     KL term weighted by cfg.TRAIN.COEFF.KL
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)
                # Save sample images generated from the fixed noise so
                # progress is visually comparable across epochs.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Presumably only Stage-II returns a low-res image too
                    # — TODO confirm against the generator implementation.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print(
            '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' %
            (epoch, self.max_epoch, i, len(data_loader), errD.item(),
             errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
             (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
            save_optimizer(optimizerG, optimizerD, self.model_dir)
    # Final snapshot after the last epoch.
    save_model(netG, netD, self.max_epoch, self.model_dir)
    self.summary_writer.close()
def train(self, imageloader, storyloader, testloader):
    """Train the story GAN with separate image and story discriminators.

    Each step: sample one image batch and one story batch, generate fakes
    with gradients disabled, update both discriminators once, then update
    the generator twice (re-sampling fakes each time). Scalars go to
    TensorBoard (when enabled) and sample story strips are saved every
    100 batches.

    Args:
        imageloader: loader for single-image batches (kept on ``self``
            for ``sample_real_image_batch``).
        storyloader: loader driving the epoch; yields story batches with
            'images', 'description' and 'label' entries.
        testloader: kept on ``self`` for periodic test-sample dumps.
    """
    self.imageloader = imageloader
    self.testloader = testloader
    self.imagedataset = None
    self.testdataset = None
    netG, netD_im, netD_st = self.load_networks()
    # Real/fake target labels, sized per image-batch and per story-batch.
    im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    im_optimizerD = \
        optim.Adam(netD_im.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    st_optimizerD = \
        optim.Adam(netD_st.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimize generator parameters that require gradients.
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    if self.tensorboard:
        self.build_tensorboard()
    loss = {}  # most-recent scalar values, flushed to the writer per step
    step = 0
    # Persist the bare (untrained-state) module objects once up front.
    torch.save({
        'netG': netG,
        'netD_im': netD_im,
        'netD_st': netD_st,
    }, os.path.join(self.model_dir, 'barebone.pth'))
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve all learning rates every `lr_decay_step` epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in st_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
            for param_group in im_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
            loss.update({
                'D/lr': discriminator_lr,
                'G/lr': generator_lr,
            })
        print('Epoch [{}/{}]:'.format(epoch, self.max_epoch))
        with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
            for i, data in enumerate(storyloader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                im_batch = self.sample_real_image_batch()
                st_batch = data
                im_real_cpu = im_batch['images']
                im_motion_input = im_batch['description']
                im_content_input = im_batch['content']
                # Average content over the story axis for the image branch.
                im_content_input = im_content_input.mean(1).squeeze()
                im_catelabel = im_batch['label']
                im_real_imgs = Variable(im_real_cpu)
                im_motion_input = Variable(im_motion_input)
                im_content_input = Variable(im_content_input)
                st_real_cpu = st_batch['images']
                st_motion_input = st_batch['description']
                st_content_input = st_batch['description']
                st_catelabel = st_batch['label']
                st_real_imgs = Variable(st_real_cpu)
                st_motion_input = Variable(st_motion_input)
                st_content_input = Variable(st_content_input)
                if cfg.CUDA:
                    st_real_imgs = st_real_imgs.cuda()
                    im_real_imgs = im_real_imgs.cuda()
                    st_motion_input = st_motion_input.cuda()
                    im_motion_input = im_motion_input.cuda()
                    st_content_input = st_content_input.cuda()
                    im_content_input = im_content_input.cuda()
                    im_catelabel = im_catelabel.cuda()
                    st_catelabel = st_catelabel.cuda()
                ######################################################
                # (2) Generate fake stories and images (no grad: these
                #     fakes are only used for the discriminator step)
                ######################################################
                with torch.no_grad():
                    im_inputs = (im_motion_input, im_content_input)
                    _, im_fake, im_mu, im_logvar = netG.sample_images(*im_inputs)
                    st_inputs = (st_motion_input, st_content_input)
                    _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(*st_inputs)
                ############################
                # (3) Update D networks
                ############################
                netD_im.zero_grad()
                netD_st.zero_grad()
                im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
                    compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels,
                                               im_catelabel, im_mu, self.gpus)
                st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
                    compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels,
                                               st_catelabel, c_mu, self.gpus)
                loss.update({
                    'D/story/loss': st_errD.data,
                    'D/story/real_loss': st_errD_real.data,
                    'D/story/fake_loss': st_errD_fake.data,
                    'D/image/accuracy': accD,
                    'D/image/loss': im_errD.data,
                    'D/image/real_loss': im_errD_real.data,
                    'D/image/fake_loss': im_errD_fake.data,
                })
                im_errD.backward()
                st_errD.backward()
                im_optimizerD.step()
                st_optimizerD.step()
                ############################
                # (4) Update G network — two generator steps per
                #     discriminator step, re-sampling fakes with grad
                ############################
                for g_iter in range(2):
                    netG.zero_grad()
                    _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                        st_motion_input, st_content_input)
                    _, im_fake, im_mu, im_logvar = netG.sample_images(
                        im_motion_input, im_content_input)
                    im_errG, accG = compute_generator_loss(netD_im, im_fake,
                                                           im_real_labels,
                                                           im_catelabel,
                                                           im_mu, self.gpus)
                    st_errG, _ = compute_generator_loss(netD_st, st_fake,
                                                        st_real_labels,
                                                        st_catelabel,
                                                        c_mu, self.gpus)
                    im_kl_loss = KL_loss(im_mu, im_logvar)
                    # NOTE(review): the story KL uses the motion posterior
                    # (m_mu, m_logvar), not (c_mu, c_logvar) — confirm this
                    # is intentional.
                    st_kl_loss = KL_loss(m_mu, m_logvar)
                    errG = im_errG + self.ratio * st_errG
                    kl_loss = im_kl_loss + self.ratio * st_kl_loss
                    loss.update({
                        'G/loss': im_errG.data,
                        'G/kl': kl_loss.data,
                    })
                    errG_total = im_errG + self.ratio * st_errG + kl_loss
                    errG_total.backward()
                    optimizerG.step()
                if self.writer:
                    for key, value in loss.items():
                        self.writer.add_scalar(key, value, step)
                step += 1
                pbar.update(1)
                if i % 100 == 0:
                    # Save a sample story strip from the current inputs.
                    lr_fake, fake, _, _, _, _ = netG.sample_videos(
                        st_motion_input, st_content_input)
                    save_story_results(st_real_cpu, fake, epoch,
                                       self.image_dir, writer=self.writer,
                                       steps=step)
                    if lr_fake is not None:
                        save_story_results(None, lr_fake, epoch,
                                           self.image_dir,
                                           writer=self.writer, steps=step)
        end_t = time.time()
        print(
            '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f accG: %.4f accD: %.4f Total Time: %.2fsec ''' %
            (epoch, self.max_epoch, i, len(storyloader), st_errD.data,
             st_errG.data, st_errD_real, st_errD_wrong, st_errD_fake,
             accG, accD, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, epoch, self.model_dir)
            save_test_samples(netG, self.testloader, self.test_dir,
                              writer=self.writer, steps=step)
    # save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
def train(self, data_loader, stage=1):
    """Train the caption-consistency GAN with a CT (context) model phase.

    Two-phase schedule: for the first ``CT_update`` epochs only the CT
    sentence model is trained (prev/next sentence prediction); from
    ``CT_update`` onward the GAN itself trains, optionally adding a
    cosine-embedding loss between the CT sentence code and a captioning
    model's embedding of the generated image.

    Args:
        data_loader: yields (images, sentences, prev/curr/next padded
            word arrays with masks and lengths) — schema inferred from
            the unpacking below; verify against the dataset class.
        stage: 1 loads Stage-I networks, otherwise Stage-II.
    """
    logger = Logger('./logs_CS_GAN')
    # Re-resize real images to 64x64 to match generator output size
    # — presumably; TODO confirm cfg.FAKEIMSIZE == 64 here.
    image_transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize([64, 64]),
        transforms.ToTensor()
    ])
    # Pretrain the CT model for 35 epochs unless a checkpoint is given.
    CT_update = 35 if cfg.CTModel == '' else 0
    print("Training CT model for ", CT_update)
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    #######
    nz = cfg.Z_DIM if not cfg.CAP.USE else cfg.CAP.Z_DIM
    batch_size = self.batch_size
    # Target -1 for every pair: cosine-embedding loss pushes the two
    # embeddings apart... NOTE(review): -1 means "dissimilar" for
    # nn.CosineEmbeddingLoss — confirm this sign is intended.
    flags = Variable(torch.cuda.FloatTensor([-1.0] * batch_size))
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    fixed_noise_test = \
        Variable(torch.FloatTensor(10, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    # Gaussian noise added to discriminator inputs (instance noise).
    noise_input = Variable(
        torch.zeros(batch_size, 3, cfg.FAKEIMSIZE, cfg.FAKEIMSIZE))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
        noise_input = noise_input.cuda()
        # NOTE(review): Tensor.cuda() is not in-place — this return value
        # is discarded, so `flags` is unchanged here (it was already built
        # from torch.cuda.FloatTensor above). Likely dead code.
        flags.cuda()
    # Instance-noise probability, decayed each time it fires.
    epsilon = 0.999
    epsilon_decay = 0.99
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    netG_para = []
    # self.emb_model=EMB(512,128)
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    # Optimizers for the CT model.
    # TODO: print parameters
    optimizerCTallmodel = optim.Adam(self.CTallmodel.parameters(),
                                     lr=0.0001, weight_decay=0.00001,
                                     betas=(0.5, 0.999))
    # NOTE(review): optimizerCTenc is created but never stepped below.
    optimizerCTenc = optim.Adam(self.CTencoder.parameters(),
                                lr=0.0001, weight_decay=0.00001,
                                betas=(0.5, 0.999))
    count = 0
    len_dataset = len(data_loader)
    for epoch in range(self.max_epoch):
        start_t = time.time()
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        print("Started training for new epoch")
        optimizerCTallmodel.zero_grad()
        ct_epoch_loss = 0
        emb_loss = 0
        epoch_count = 0
        for i, data in enumerate(data_loader):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, sentences, paddedArrayPrev, maskArrayPrev, paddedArrayCurr, Currlenghts, paddedArrayNext, maskArrayNext = data
            # Re-init the CT encoder hidden state per batch (batch dim 1).
            self.CTallmodel.encoder.hidden = self.CTallmodel.encoder.hidden_init(
                paddedArrayCurr.size(1))
            real_imgs = Variable(real_img_cpu)
            paddedArrayCurr = Variable(
                paddedArrayCurr.type(torch.LongTensor))
            # Teacher-forcing inputs: drop the final token row.
            paddedArrayNext_input = Variable(paddedArrayNext[:-1, :].type(
                torch.LongTensor))
            paddedArrayPrev_input = Variable(paddedArrayPrev[:-1, :].type(
                torch.LongTensor))
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                paddedArrayCurr = paddedArrayCurr.cuda()
                paddedArrayNext_input = paddedArrayNext_input.cuda()
                paddedArrayPrev_input = paddedArrayPrev_input.cuda()
            inputs_CT = (paddedArrayCurr, Currlenghts,
                         paddedArrayPrev_input, paddedArrayNext_input)
            # sent_hidden, logits_prev, logits_next = self.CTallmodel(paddedArrayCurr, Currlenghts, paddedArrayPrev_input, paddedArrayNext_input)
            sent_hidden, logits_prev, logits_next = nn.parallel.data_parallel(
                self.CTallmodel, inputs_CT, self.gpus)
            # Phase 1: optimize the CT model only (prev/next prediction).
            if (epoch < CT_update):
                logits_prev = logits_prev.contiguous().view(
                    -1, logits_prev.size()[2])
                logits_next = logits_next.contiguous().view(
                    -1, logits_next.size()[2])
                # Targets: drop the first token row (shifted by one).
                Y_prev = paddedArrayPrev[1:, :]
                Y_prev = Y_prev.contiguous().view(-1)
                Y_next = paddedArrayNext[1:, :]
                Y_next = Y_next.contiguous().view(-1)
                maskArrayPrev = maskArrayPrev[1:, :]
                maskArrayPrev = maskArrayPrev.contiguous().view(-1)
                maskArrayNext = maskArrayNext[1:, :]
                maskArrayNext = maskArrayNext.contiguous().view(-1)
                # Keep only non-padding positions.
                ind_prev = torch.nonzero(maskArrayPrev, out=None).squeeze()
                ind_next = torch.nonzero(maskArrayNext, out=None).squeeze()
                if torch.cuda.is_available():
                    ind_prev = ind_prev.cuda()
                    ind_next = ind_next.cuda()
                valid_target_prev = torch.index_select(
                    Y_prev, 0,
                    ind_prev.type(torch.LongTensor)).type(torch.LongTensor)
                valid_output_prev = torch.index_select(
                    logits_prev, 0, Variable(ind_prev))
                valid_target_next = torch.index_select(
                    Y_next, 0,
                    ind_next.type(torch.LongTensor)).type(torch.LongTensor)
                valid_output_next = torch.index_select(
                    logits_next, 0, Variable(ind_next))
                if torch.cuda.is_available():
                    valid_output_prev = valid_output_prev.cuda()
                    valid_output_next = valid_output_next.cuda()
                    valid_target_prev = valid_target_prev.cuda()
                    valid_target_next = valid_target_next.cuda()
                loss_prev = self.CTloss(valid_output_prev,
                                        Variable(valid_target_prev))
                loss_next = self.CTloss(valid_output_next,
                                        Variable(valid_target_next))
                self.CTallmodel.zero_grad()
                optimizerCTallmodel.zero_grad()
                loss = loss_prev + loss_next
                loss.backward(retain_graph=True)
                ct_epoch_loss += loss.data[0]
                nn.utils.clip_grad_norm(self.CTallmodel.parameters(), 0.25)
                optimizerCTallmodel.step()
            # Phase 2: GAN training, conditioned on the CT sentence code.
            if epoch >= CT_update:
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (sent_hidden, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                #### TODO: Check Shapes->Checked
                # _,fake_imgs,mu,logvar=netG(inputs[0],inputs[1])
                #######################################################
                # (2.1) Generate captions for fake images
                ######################################################
                if self.cap_model_bool:
                    sents, h_sent = self.eval_utils.captioning_model(
                        fake_imgs, self.cap_model, self.vocab_cap,
                        self.my_resnet, self.eval_kwargs)
                    h_sent_var = Variable(torch.FloatTensor(h_sent)).cuda()
                # input_layer = tf.stack([preprocess_for_train(i) for i in real_imgs], axis=0)
                # Resize real images to match the fake image resolution.
                real_imgs = Variable(
                    torch.stack([
                        image_transform_train(img.data.cpu()).cuda()
                        for img in real_imgs
                    ], dim=0))
                ############################
                # (3) Update D network (with decaying instance noise)
                ###########################
                if random.uniform(0, 1) < epsilon and cfg.GAN.ADD_NOISE:
                    epsilon *= epsilon_decay
                    noise_input.data.normal_(0, 1)
                    fake_imgs = fake_imgs + noise_input
                    real_imgs = real_imgs + noise_input
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                # Label Switching
                # Trick as of - https://github.com/soumith/ganhacks/issues/14
                # if random.uniform(0.1)<epsilon:
                #     netD.zero_grad()
                #     errD, errD_real, errD_wrong, errD_fake = \
                #         compute_discriminator_loss(netD, real_imgs, fake_imgs,
                #                                    fake_labels, real_labels,
                #                                    mu, self.gpus)
                #     errD.backward()
                #     optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                if self.cap_model_bool:
                    loss_cos = self.cosEmbLoss(sent_hidden, h_sent_var, flags)
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                if self.cap_model_bool:
                    errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + 10 * loss_cos
                    emb_loss += loss_cos.data[0]
                else:
                    errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()
            count = count + 1
            epoch_count += 1
            if i % 200 == 0:
                print("Loss CT Model: ", ct_epoch_loss / epoch_count)
                # print("Emb Loss: ", emb_loss)
        # Save image results for the epoch, after the embedding model has
        # been trained (GAN phase only — GAN tensors don't exist earlier).
        if epoch >= CT_update:
            inputs = (sent_hidden, fixed_noise)
            lr_fake, fake, _, _ = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            if self.cap_model_bool:
                save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                 sentences, sents)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir,
                                     sentences, sents)
            else:
                save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                 sentences, None)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir,
                                     sentences, None)
            self.test(netG, fixed_noise_test, epoch)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.data[0],
                 errG.data[0], kl_loss.data[0], errD_real, errD_wrong,
                 errD_fake, (end_t - start_t)))
            # logger.scalar_summary('Cosine_loss', emb_loss, epoch+1)
            logger.scalar_summary('errD_loss', errD.data[0] / len_dataset,
                                  epoch + 1)
            logger.scalar_summary('errG_loss', errG.data[0] / len_dataset,
                                  epoch + 1)
            logger.scalar_summary('kl_loss', kl_loss.data[0] / len_dataset,
                                  epoch + 1)
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, self.CTallmodel, epoch, self.model_dir)
        logger.scalar_summary('CT_loss', ct_epoch_loss / len_dataset,
                              epoch + 1)
    save_model(netG, netD, self.CTallmodel, self.max_epoch, self.model_dir)
def train(self, imageloader, storyloader, testloader, stage=1):
    """Train the segment-aware story GAN (image/story/segment critics).

    Per batch: generate fakes without grad, update the image, story and
    (optionally) segment discriminators, then regenerate with grad and
    update the generator with GAN + KL + latent/reconstruction terms.
    Logs scalars/images via ``self._logger``, decays LRs on a doubling
    schedule, and snapshots every ``self.snapshot_interval`` epochs.

    Args:
        imageloader: single-image loader (kept for sample_real_image_batch).
        storyloader: story-batch loader that drives each epoch.
        testloader: used for optional FID evaluation.
        stage: accepted for interface parity; not used below.
    """
    c_time = time.time()
    self.imageloader = imageloader
    self.imagedataset = None
    netG, netD_im, netD_st, netD_se = self.load_network_stageI()
    start = time.time()
    # Initial Labels
    im_real_labels = Variable(
        torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(
        torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(
        torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(
        torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(
        ), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(
        ), st_fake_labels.cuda()
    use_segment = cfg.SEGMENT_LEARNING
    segment_weight = cfg.SEGMENT_RATIO
    image_weight = cfg.IMAGE_RATIO
    # Optimizer and Scheduler
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    im_optimizerD = optim.Adam(netD_im.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR,
                               betas=(0.5, 0.999))
    st_optimizerD = optim.Adam(netD_st.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR,
                               betas=(0.5, 0.999))
    if use_segment:
        se_optimizerD = optim.Adam(netD_se.parameters(),
                                   lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                   betas=(0.5, 0.999))
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    mse_loss = nn.MSELoss()
    # NOTE(review): these plateau schedulers are created but never
    # .step()-ed in this method; LR decay is done manually below.
    scheduler_imD = ReduceLROnPlateau(im_optimizerD, 'min', verbose=True,
                                      factor=0.5, min_lr=1e-7, patience=0)
    scheduler_stD = ReduceLROnPlateau(st_optimizerD, 'min', verbose=True,
                                      factor=0.5, min_lr=1e-7, patience=0)
    if use_segment:
        scheduler_seD = ReduceLROnPlateau(se_optimizerD, 'min', verbose=True,
                                          factor=0.5, min_lr=1e-7, patience=0)
    scheduler_G = ReduceLROnPlateau(optimizerG, 'min', verbose=True,
                                    factor=0.5, min_lr=1e-7, patience=0)
    count = 0
    # Start training (resume from checkpoint epoch if configured).
    if not self.con_ckpt:
        start_epoch = 0
    else:
        start_epoch = int(self.con_ckpt)
    # self.calculate_vfid(netG, 0, testloader)
    print('LR DECAY EPOCH: {}'.format(lr_decay_step))
    for epoch in range(start_epoch, self.max_epoch):
        # Annealing weight (DANN-style schedule); `l` is unused below —
        # NOTE(review): possibly dead code, confirm.
        l = self.ratio * (2. / (1. + np.exp(-10. * epoch)) - 1)
        start_t = time.time()
        # Adjust lr
        num_step = len(storyloader)
        stats = {}  # latest scalar stats, flushed every 20 batches
        with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
            for i, data in enumerate(storyloader):
                ######################################################
                # (1) Prepare training data
                ######################################################
                im_batch = self.sample_real_image_batch()
                st_batch = data
                im_real_cpu = im_batch['images']
                im_motion_input = im_batch[
                    'description'][:, :cfg.TEXT.
                                   DIMENSION]  # description vector and arrtibute (60, 356)
                im_content_input = im_batch[
                    'content'][:, :, :cfg.TEXT.
                               DIMENSION]  # description vector and attribute for every story (60,5,356)
                im_real_imgs = Variable(im_real_cpu)
                im_motion_input = Variable(im_motion_input)
                im_content_input = Variable(im_content_input)
                im_labels = Variable(im_batch['labels'])
                st_real_cpu = st_batch['images']
                st_motion_input = st_batch[
                    'description'][:, :, :cfg.TEXT.DIMENSION]  # (12,5,356)
                st_content_input = st_batch[
                    'description'][:, :, :cfg.TEXT.DIMENSION]  # (12,5,356)
                st_texts = None
                if 'text' in st_batch:
                    st_texts = st_batch['text']
                st_real_imgs = Variable(st_real_cpu)
                st_motion_input = Variable(st_motion_input)
                st_content_input = Variable(st_content_input)
                st_labels = Variable(st_batch['labels'])  # (12,5,9)
                if use_segment:
                    se_real_cpu = im_batch['images_seg']
                    se_real_imgs = Variable(se_real_cpu)
                if cfg.CUDA:
                    st_real_imgs = st_real_imgs.cuda()  # (12,3,5,64,64)
                    im_real_imgs = im_real_imgs.cuda()
                    st_motion_input = st_motion_input.cuda()
                    im_motion_input = im_motion_input.cuda()
                    st_content_input = st_content_input.cuda()
                    im_content_input = im_content_input.cuda()
                    im_labels = im_labels.cuda()
                    st_labels = st_labels.cuda()
                    if use_segment:
                        se_real_imgs = se_real_imgs.cuda()
                # Append character labels to the motion/text input.
                im_motion_input = torch.cat((im_motion_input, im_labels),
                                            1)  # 356+9=365 (60,365)
                st_motion_input = torch.cat((st_motion_input, st_labels),
                                            2)  # (12,5,365)
                #######################################################
                # (2) Generate fake stories and images (no grad — only
                #     used for the discriminator updates)
                ######################################################
                # print(st_motion_input.shape, im_motion_input.shape)
                with torch.no_grad():
                    _, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = \
                        netG.sample_videos(st_motion_input, st_content_input)
                    # m_mu (60,365), c_mu (12,124)
                    _, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = \
                        netG.sample_images(im_motion_input, im_content_input,
                                           seg=use_segment)
                    # im_mu (60,489), cim_mu (60,124)
                    # Which characters appear anywhere in the 5-frame story.
                    characters_mu = (
                        st_labels.mean(1) > 0
                    ).type(torch.FloatTensor).cuda()
                    st_mu = torch.cat(
                        (c_mu,
                         st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(
                             1).squeeze(), characters_mu),
                        1)  # 124 + 356 + 9 = 489 (12,489)
                    im_mu = torch.cat((im_motion_input, cim_mu), 1)  # (60,489)
                ############################
                # (3) Update D networks
                ###########################
                netD_im.zero_grad()
                netD_st.zero_grad()
                se_accD = 0
                if use_segment:
                    netD_se.zero_grad()
                    se_errD, se_errD_real, se_errD_wrong, se_errD_fake, se_accD, _ = \
                        compute_discriminator_loss(netD_se, se_real_imgs, se_fake,
                                                   im_real_labels, im_fake_labels,
                                                   im_labels, im_mu, self.gpus)
                im_errD, im_errD_real, im_errD_wrong, im_errD_fake, im_accD, _ = \
                    compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels,
                                               im_labels, im_mu, self.gpus)
                st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _, order_consistency = \
                    compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels,
                                               st_labels, st_mu, self.gpus)
                if use_segment:
                    se_errD.backward()
                    se_optimizerD.step()
                    stats.update({
                        'seg_D/loss': se_errD.data,
                        'seg_D/real': se_errD_real,
                        'seg_D/fake': se_errD_fake,
                    })
                im_errD.backward()
                st_errD.backward()
                im_optimizerD.step()
                st_optimizerD.step()
                stats.update({
                    'img_D/loss': im_errD.data,
                    'img_D/real': im_errD_real,
                    'img_D/fake': im_errD_fake,
                    'Accuracy/im_D': im_accD,
                    'Accuracy/se_D': se_accD,
                })
                step = i + num_step * epoch
                self._logger.add_scalar('st_D/loss', st_errD.data, step)
                self._logger.add_scalar('st_D/real', st_errD_real, step)
                self._logger.add_scalar('st_D/fake', st_errD_fake, step)
                self._logger.add_scalar('st_D/order', order_consistency, step)
                ############################
                # (4) Update G network (fresh forward pass with grad)
                ###########################
                netG.zero_grad()
                video_latents, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = netG.sample_videos(
                    st_motion_input, st_content_input)
                image_latents, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = netG.sample_images(
                    im_motion_input, im_content_input, seg=use_segment)
                encoder_decoder_loss = 0
                if video_latents is not None:
                    # Match encoder/decoder feature maps at 4 scales.
                    ((h_seg1, h_seg2, h_seg3, h_seg4),
                     (g_seg1, g_seg2, g_seg3, g_seg4)) = video_latents
                    video_latent_loss = mse_loss(
                        g_seg1, h_seg1) + mse_loss(g_seg2, h_seg2) + mse_loss(
                            g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4)
                    ((h_seg1, h_seg2, h_seg3, h_seg4),
                     (g_seg1, g_seg2, g_seg3, g_seg4)) = image_latents
                    image_latent_loss = mse_loss(
                        g_seg1, h_seg1) + mse_loss(g_seg2, h_seg2) + mse_loss(
                            g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4)
                    encoder_decoder_loss = (image_latent_loss +
                                            video_latent_loss) / 2
                    # Autoencoder reconstruction on real and fake segments.
                    reconstruct_img = netG.train_autoencoder(se_real_imgs)
                    reconstruct_fake = netG.train_autoencoder(se_fake)
                    reconstruct_loss = (
                        mse_loss(reconstruct_img, se_real_imgs) +
                        mse_loss(reconstruct_fake, se_fake)) / 2.0
                    self._logger.add_scalar('G/image_vae_loss',
                                            image_latent_loss.data, step)
                    self._logger.add_scalar('G/video_vae_loss',
                                            video_latent_loss.data, step)
                    self._logger.add_scalar('G/reconstruct_loss',
                                            reconstruct_loss.data, step)
                characters_mu = (st_labels.mean(1) > 0).type(
                    torch.FloatTensor).cuda()
                st_mu = torch.cat(
                    (c_mu,
                     st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(
                         1).squeeze(), characters_mu), 1)
                im_mu = torch.cat((im_motion_input, cim_mu), 1)
                # NOTE(review): `se_errG` is assigned twice in this tuple
                # unpacking — the middle target looks like a typo (perhaps
                # meant to default a third name); confirm intent.
                se_errG, se_errG, se_accG = 0, 0, 0
                if use_segment:
                    se_errG, se_accG, _ = compute_generator_loss(
                        netD_se, se_fake, se_real_imgs, im_real_labels,
                        im_labels, im_mu, self.gpus)
                im_errG, im_accG, _ = compute_generator_loss(
                    netD_im, im_fake, im_real_imgs, im_real_labels,
                    im_labels, im_mu, self.gpus)
                st_errG, st_accG, G_consistency = compute_generator_loss(
                    netD_st, st_fake, st_real_imgs, st_real_labels,
                    st_labels, st_mu, self.gpus)
                ######
                # Sample Image Loss and Sample Video Loss
                im_kl_loss = KL_loss(cim_mu, cim_logvar)
                st_kl_loss = KL_loss(c_mu, c_logvar)
                errG = im_errG + self.ratio * (
                    image_weight * st_errG + se_errG * segment_weight
                )  # for record
                kl_loss = im_kl_loss + self.ratio * st_kl_loss  # for record
                # Total Loss
                errG_total = im_errG + im_kl_loss * cfg.TRAIN.COEFF.KL \
                    + self.ratio * (se_errG*segment_weight + st_errG*image_weight
                                    + st_kl_loss * cfg.TRAIN.COEFF.KL)
                if video_latents is not None:
                    errG_total += (video_latent_loss +
                                   reconstruct_loss) * cfg.RECONSTRUCT_LOSS
                errG_total.backward()
                optimizerG.step()
                stats.update({
                    'G/loss': errG_total.data,
                    'G/im_KL': im_kl_loss.data,
                    'G/st_KL': st_kl_loss.data,
                    'G/KL': kl_loss.data,
                    'G/consistency': G_consistency,
                    'Accuracy/im_G': im_accG,
                    'Accuracy/se_G': se_accG,
                    'Accuracy/st_G': st_accG,
                    'G/gan_loss': errG.data,
                })
                count = count + 1
                pbar.update(1)
                if i % 20 == 0:
                    step = i + num_step * epoch
                    for key, value in stats.items():
                        self._logger.add_scalar(key, value, step)
        # End-of-epoch sample images from the last batch's inputs.
        with torch.no_grad():
            lr_fake, fake, _, _, _, _, se_fake = netG.sample_videos(
                st_motion_input, st_content_input, seg=use_segment)
            st_result = save_story_results(st_real_cpu, fake, st_texts,
                                           epoch, self.image_dir, i)
            if use_segment and se_fake is not None:
                se_result = save_image_results(None, se_fake)
        self._logger.add_image("pororo",
                               st_result.transpose(2, 0, 1) / 255, epoch)
        if use_segment:
            self._logger.add_image("segment",
                                   se_result.transpose(2, 0, 1) / 255, epoch)
        # Adjust lr: halve all rates, then double the decay interval.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in st_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
            for param_group in im_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
            lr_decay_step *= 2
        # Log the current learning rates.
        g_lr, im_lr, st_lr = 0, 0, 0
        for param_group in optimizerG.param_groups:
            g_lr = param_group['lr']
        for param_group in st_optimizerD.param_groups:
            st_lr = param_group['lr']
        for param_group in im_optimizerD.param_groups:
            im_lr = param_group['lr']
        self._logger.add_scalar('learning/generator', g_lr, epoch)
        self._logger.add_scalar('learning/st_discriminator', st_lr, epoch)
        self._logger.add_scalar('learning/im_discriminator', im_lr, epoch)
        if cfg.EVALUATE_FID_SCORE:
            self.calculate_vfid(netG, epoch, testloader)
        #self.calculate_ssim(netG, epoch, testloader)
        time_mins = int((time.time() - c_time) / 60)
        time_hours = int(time_mins / 60)
        epoch_mins = int((time.time() - start_t) / 60)
        epoch_hours = int(epoch_mins / 60)
        print(
            "----[{}/{}]Epoch time:{} hours {} mins, Total time:{} hours----"
            .format(epoch, self.max_epoch, epoch_hours, epoch_mins,
                    time_hours))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, netD_se, epoch,
                       self.model_dir)
            #save_test_samples(netG, testloader, self.test_dir)
    save_model(netG, netD_im, netD_st, netD_se, self.max_epoch,
               self.model_dir)
def train(self, imageloader, storyloader, testloader):
    """Jointly train a story GAN: one generator (netG) against an image
    discriminator (netD_im) and a story discriminator (netD_st).

    ``storyloader`` drives the epoch loop; a matching image batch is drawn
    via ``self.sample_real_image_batch()`` for every story batch.
    ``testloader`` is stored and used for periodic test-sample dumps.
    """
    self.imageloader = imageloader
    self.testloader = testloader
    self.imagedataset = None
    self.testdataset = None
    netG, netD_im, netD_st = self.load_networks()

    # Real/fake target labels, sized per branch (image vs. story batch).
    im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    # One Adam optimizer per discriminator; the generator optimizer only
    # covers parameters with requires_grad (frozen submodules are skipped).
    im_optimizerD = \
        optim.Adam(netD_im.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    st_optimizerD = \
        optim.Adam(netD_st.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))

    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in st_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
            for param_group in im_optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(storyloader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            im_batch = self.sample_real_image_batch()
            st_batch = data
            im_real_cpu = im_batch['images']
            im_motion_input = im_batch['description']
            im_content_input = im_batch['content']
            im_content_input = im_content_input.mean(1).squeeze()
            im_catelabel = im_batch['label']
            im_real_imgs = Variable(im_real_cpu)
            im_motion_input = Variable(im_motion_input)
            im_content_input = Variable(im_content_input)

            st_real_cpu = st_batch['images']
            st_motion_input = st_batch['description']
            # NOTE(review): the story content input reuses 'description'
            # here, unlike the image branch which uses 'content' — confirm
            # this is intentional.
            st_content_input = st_batch['description']
            st_catelabel = st_batch['label']
            st_real_imgs = Variable(st_real_cpu)
            st_motion_input = Variable(st_motion_input)
            st_content_input = Variable(st_content_input)
            if cfg.CUDA:
                st_real_imgs = st_real_imgs.cuda()
                im_real_imgs = im_real_imgs.cuda()
                st_motion_input = st_motion_input.cuda()
                im_motion_input = im_motion_input.cuda()
                st_content_input = st_content_input.cuda()
                im_content_input = im_content_input.cuda()
                im_catelabel = im_catelabel.cuda()
                st_catelabel = st_catelabel.cuda()

            #######################################################
            # (2) Generate fake stories and images
            ######################################################
            im_inputs = (im_motion_input, im_content_input)
            _, im_fake, im_mu, im_logvar = netG.sample_images(im_motion_input, im_content_input)
            st_inputs = (st_motion_input, st_content_input)
            _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                st_motion_input, st_content_input)

            ############################
            # (3) Update D network
            ###########################
            netD_im.zero_grad()
            netD_st.zero_grad()
            im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
                compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                           im_real_labels, im_fake_labels,
                                           im_catelabel, im_mu, self.gpus)
            st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
                compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                           st_real_labels, st_fake_labels,
                                           st_catelabel, c_mu, self.gpus)
            im_errD.backward()
            st_errD.backward()
            im_optimizerD.step()
            st_optimizerD.step()

            ############################
            # (2) Update G network
            ###########################
            # The generator takes two updates per discriminator update;
            # fakes are re-sampled each pass so each backward sees a
            # fresh graph.
            for g_iter in range(2):
                netG.zero_grad()
                _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                    st_motion_input, st_content_input)
                _, im_fake, im_mu, im_logvar = netG.sample_images(im_motion_input, im_content_input)
                im_errG, accG = compute_generator_loss(netD_im, im_fake,
                                                       im_real_labels, im_catelabel,
                                                       im_mu, self.gpus)
                st_errG, _ = compute_generator_loss(netD_st, st_fake,
                                                    st_real_labels, st_catelabel,
                                                    c_mu, self.gpus)
                im_kl_loss = KL_loss(im_mu, im_logvar)
                # NOTE(review): the story KL uses the motion posterior
                # (m_mu, m_logvar) while the story GAN losses condition on
                # c_mu — confirm this asymmetry is intended.
                st_kl_loss = KL_loss(m_mu, m_logvar)
                errG = im_errG + self.ratio * st_errG
                kl_loss = im_kl_loss + self.ratio * st_kl_loss
                # NOTE(review): the KL term is added unweighted here (no
                # cfg.TRAIN.COEFF.KL factor, unlike the sibling trainers
                # in this file) — confirm intended.
                errG_total = im_errG + self.ratio * st_errG + kl_loss
                errG_total.backward()
                optimizerG.step()

            if i % 100 == 0:
                # save the image result for each epoch
                lr_fake, fake, _, _, _, _ = netG.sample_videos(st_motion_input, st_content_input)
                save_story_results(st_real_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_story_results(None, lr_fake, epoch, self.image_dir)

        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f accG: %.4f accD: %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(storyloader), st_errD.data, st_errG.data, st_errD_real, st_errD_wrong, st_errD_fake, accG, accD, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, epoch, self.model_dir)
            save_test_samples(netG, self.testloader, self.test_dir)
    # save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
def train(self, data_loader, stage=1, max_objects=3):
    """Train one stage of a layout-conditioned GAN.

    Each batch supplies images, per-object bounding boxes, per-object
    class labels, and a sentence embedding.  Affine transformation
    matrices derived from the boxes are fed to both G and D.

    Args:
        data_loader: yields (images, bbox, label, txt_embedding) batches.
        stage: 1 or 2 — selects which network pair to load.
        max_objects: number of object slots per image.

    Fixes vs. previous revision:
      * the `with torch.no_grad()` image-sampling block was duplicated
        verbatim — one copy removed;
      * `label_one_hot` is now placed on the same device as the labels
        (scatter_ raised a device mismatch under cfg.CUDA);
      * a final model snapshot is saved after the last epoch, matching
        the sibling trainers in this file.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()

    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # NOTE(review): fixed_noise is prepared but the periodic image dump
    # below samples with the current batch `noise` — confirm intended.
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                           requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    # Optimize only trainable generator parameters.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerD = optim.Adam(netD.parameters(),
                            lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, bbox, label, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                if cfg.STAGE == 1:
                    bbox = bbox.cuda()
                elif cfg.STAGE == 2:
                    bbox = [bbox[0].cuda(), bbox[1].cuda()]
                label = label.cuda()
                txt_embedding = txt_embedding.cuda()

            # Derive per-object 2x3 affine transforms from the boxes.
            if cfg.STAGE == 1:
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = compute_transformation_matrix_inverse(bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox)
                transf_matrices = transf_matrices.view(
                    real_imgs.shape[0], max_objects, 2, 3)
            elif cfg.STAGE == 2:
                _bbox = bbox[0].view(-1, 4)
                transf_matrices_inv = compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                _bbox = bbox[1].view(-1, 4)
                transf_matrices_inv_s2 = compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices_s2 = compute_transformation_matrix(_bbox)
                transf_matrices_s2 = transf_matrices_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)

            # produce one-hot encodings of the labels
            _labels = label.long()
            # remove -1 to enable one-hot converting (slot 80 = background)
            _labels[_labels < 0] = 80
            label_one_hot = torch.FloatTensor(noise.shape[0], max_objects, 81).fill_(0)
            # FIX: scatter_ requires self and index on the same device.
            if cfg.CUDA:
                label_one_hot = label_one_hot.cuda()
            label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv, label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2, label_one_hot)
            _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                netG, inputs, self.gpus)

            ############################
            # (3) Update D network
            ###########################
            netD.zero_grad()
            if cfg.STAGE == 1:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices, transf_matrices_inv,
                                               mu, self.gpus)
            elif cfg.STAGE == 2:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices_s2, transf_matrices_inv_s2,
                                               mu, self.gpus)
            # retain_graph: the same fake_imgs graph is reused for the
            # generator update below.
            errD.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (4) Update G network
            ###########################
            netG.zero_grad()
            if cfg.STAGE == 1:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices, transf_matrices_inv,
                                              mu, self.gpus)
            elif cfg.STAGE == 2:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices_s2, transf_matrices_inv_s2,
                                              mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            count += 1
            if i % 500 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)

                # save the image result for each epoch
                # FIX: this block existed twice, back to back; the second
                # verbatim copy was removed.
                with torch.no_grad():
                    if cfg.STAGE == 1:
                        inputs = (txt_embedding, noise,
                                  transf_matrices_inv, label_one_hot)
                    elif cfg.STAGE == 2:
                        inputs = (txt_embedding, noise, transf_matrices_inv,
                                  transf_matrices_s2, transf_matrices_inv_s2,
                                  label_one_hot)
                    lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)

        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.item(), errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)

    # FIX: final snapshot after the last epoch (was commented out),
    # consistent with the other trainers in this file.
    save_model(netG, netD, optimizerG, optimizerD, self.max_epoch, self.model_dir)
def train(self, data_loader, stage=1):
    """Train one StackGAN stage (text-conditioned GAN).

    Uses the legacy pre-0.4 PyTorch API (``volatile``, ``.data[0]``),
    consistently with the rest of this trainer.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # volatile: inference-only noise used for the periodic image dumps.
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 volatile=True)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Optimize only trainable generator parameters.
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    ####
    #netD_std = 0.1
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            #### stand deviation decay
            #netD_std *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding).float()
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            # Stage-I generator returns a 5-tuple; stage-II a 4-tuple.
            if stage == 1:
                _, fake_imgs, mu, logvar, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
            else:
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
            ############################
            # (3) Update D network
            ###########################
            #### A little noise for images passed to discriminator
            #fake_imgs = fake_imgs + torch.cuda.FloatTensor(fake_imgs.size()).normal_(0,netD_std)
            #### update D twice
            for D_update in range(2):
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
            #### update D with reversed labels
            # NOTE(review): this extra D step deliberately swaps the
            # real/fake targets once per batch — presumably a
            # regularization trick; confirm intended.
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           fake_labels, real_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()
            ############################
            # (2) Update G network
            ###########################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                # save the image result for each epoch
                inputs = (txt_embedding, fixed_noise)
                if stage == 1:
                    lr_fake, fake, _, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.data[0], errG.data[0], kl_loss.data[0], errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    #
    # save_model(netG, netD, self.max_epoch, self.model_dir)
    #
    # self.summary_writer.close()
def train(self, data_loader, dataset, stage=1):
    """Adversarial text<->image cycle training.

    Trains an image generator against an image discriminator and a
    latent discriminator, with text/image auto-encoding and cycle
    losses.  ``self.networks`` supplies all six sub-networks.

    Fixes vs. previous revision:
      * restored the ``real_labels`` assignment (it was commented out
        although the name is used below — NameError at runtime);
      * replaced the undefined names ``encoder``, ``netG`` and ``netD``
        with the in-scope ``text_encoder``, ``image_generator`` and
        ``disc_image`` (``image_generator`` returns the images directly,
        matching its other call site, so the 4-way unpack was dropped);
      * renamed the token loop variable so it no longer shadows the
        batch index ``i``;
      * removed a verbatim-duplicated ``save_model`` call.
    """
    image_encoder, image_generator, text_encoder, text_generator, disc_image, disc_latent = self.networks

    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 volatile=True)
    # make labels for real/fake (try discriminator smoothing)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    # Latent-discriminator targets: image codes -> 1, text codes -> 0.
    txt_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    img_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
        txt_enc_labels = txt_enc_labels.cuda()
        img_enc_labels = img_enc_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    optims = self.define_optimizers(image_encoder, image_generator,
                                    text_encoder, text_generator,
                                    disc_image, disc_latent)
    optim_img_enc, optim_img_gen, optim_txt_enc, optim_txt_gen, optim_disc_img, optim_disc_latent = optims

    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Decay image-generator / image-discriminator LRs by 0.75.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.75
            for param_group in optim_img_gen.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.75
            for param_group in optim_disc_img.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            _, real_img_cpu, _, captions, pred_cap = data
            raw_inds, raw_lengths = self.process_captions(captions)
            inds, lengths = raw_inds.data, raw_lengths
            inds = Variable(inds)
            # Sort by caption length (descending) for packed RNN encoding.
            lens_sort, sort_idx = lengths.sort(0, descending=True)
            txt_encoder_output = text_encoder(inds[:, sort_idx],
                                              lens_sort.cpu().numpy(), None)
            encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output

            real_imgs = Variable(real_img_cpu)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()

            #######################################################
            # (2) Generate fake images and their latent codes
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (real_txt_code, noise)
            fake_imgs = \
                nn.parallel.data_parallel(image_generator, inputs, self.gpus)
            fake_img_out = nn.parallel.data_parallel(
                image_encoder, (fake_imgs), self.gpus)
            fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
            fake_img_feats = fake_img_feats.transpose(0, 1)

            #######################################################
            # (2b) Text auto-encoding loss (normalized per token)
            ######################################################
            loss_auto_txt, _ = compute_text_gen_loss(text_generator,
                                                     inds[:, sort_idx],
                                                     real_txt_code.unsqueeze(0),
                                                     encoder_out,
                                                     self.txt_dico)
            loss_auto_txt = loss_auto_txt / lengths.float().sum()

            #######################################################
            # (2c) Image auto-encoding loss from real-image codes
            ######################################################
            real_img_out = nn.parallel.data_parallel(
                image_encoder, (real_imgs[sort_idx]), self.gpus)
            real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out
            noise.data.normal_(0, 1)
            loss_auto_img, _ = compute_image_gen_loss(image_generator,
                                                      real_imgs[sort_idx],
                                                      real_img_code, noise,
                                                      self.gpus)

            # Text cycle loss: caption reconstructed from fake-image code.
            loss_cycle_text, gen_captions = compute_text_gen_loss(text_generator,
                                                                  inds[:, sort_idx],
                                                                  fake_img_code.unsqueeze(0),
                                                                  fake_img_feats,
                                                                  self.txt_dico)
            loss_cycle_text = loss_cycle_text / lengths.float().sum()

            ###############################################################
            # (2d) Generate image from predicted cap, calc img cycle loss
            ###############################################################
            loss_cycle_img = 0
            if len(pred_cap):
                pred_inds, pred_lens = pred_cap
                pred_inds = Variable(pred_inds.transpose(0, 1))
                pred_inds = pred_inds.cuda() if cfg.CUDA else pred_inds
                # FIX: was the undefined name `encoder`.
                pred_output = text_encoder(pred_inds[:, sort_idx],
                                           pred_lens.cpu().numpy(), None)
                pred_txt_out, pred_txt_hidden, pred_txt_code, pred_txt_mu, pred_txt_logvar = pred_output
                noise.data.normal_(0, 1)
                inputs = (pred_txt_code, noise)
                # FIX: was the undefined name `netG` with a 4-way unpack.
                fake_from_fake_img = \
                    nn.parallel.data_parallel(image_generator, inputs, self.gpus)
                pred_img_out = nn.parallel.data_parallel(
                    image_encoder, (fake_from_fake_img), self.gpus)
                pred_img_feats, pred_img_emb, pred_img_code, pred_img_mu, pred_img_logvar = pred_img_out
                semantic_target = Variable(torch.ones(batch_size))
                if cfg.CUDA:
                    semantic_target = semantic_target.cuda()
                loss_cycle_img = cosine_emb_loss(
                    pred_img_feats.contiguous().view(batch_size, -1),
                    real_img_feats.contiguous().view(batch_size, -1),
                    semantic_target)

            ###########################
            # (3) Update D networks
            ###########################
            optim_disc_img.zero_grad()
            optim_disc_latent.zero_grad()
            errD = 0
            errD_fake_imgs = compute_cond_discriminator_loss(disc_image,
                                                             fake_imgs,
                                                             fake_labels,
                                                             encoder_hidden[0],
                                                             self.gpus)
            errD_im, errD_real, errD_fake = \
                compute_uncond_discriminator_loss(disc_image, real_imgs,
                                                  fake_imgs, real_labels,
                                                  fake_labels, self.gpus)
            err_latent_disc = compute_latent_discriminator_loss(disc_latent,
                                                                real_img_emb,
                                                                encoder_hidden[0],
                                                                img_enc_labels,
                                                                txt_enc_labels,
                                                                self.gpus)
            if len(pred_cap):
                # FIX: was the undefined name `netD`.
                errD_fake_from_fake_imgs = compute_cond_disc(disc_image,
                                                             fake_from_fake_img,
                                                             fake_labels,
                                                             pred_txt_hidden[0],
                                                             self.gpus)
                errD += errD_fake_from_fake_imgs
            errD = errD + errD_im + errD_fake_imgs + err_latent_disc
            # check NaN (x != x is true only for NaN)
            if (errD != errD).data.any():
                print("NaN detected (discriminator)")
                pdb.set_trace()
                exit()
            errD.backward()
            optim_disc_img.step()
            optim_disc_latent.step()

            ############################
            # (4) Update G networks
            ###########################
            optim_img_enc.zero_grad()
            optim_img_gen.zero_grad()
            optim_txt_enc.zero_grad()
            optim_txt_gen.zero_grad()
            errG_total = 0
            err_g_uncond_loss = compute_uncond_generator_loss(disc_image,
                                                              fake_imgs,
                                                              real_labels,
                                                              self.gpus)
            err_g_cond_disc_loss = compute_cond_generator_loss(disc_image,
                                                               fake_imgs,
                                                               real_labels,
                                                               encoder_hidden[0],
                                                               self.gpus)
            err_latent_gen = compute_latent_generator_loss(disc_latent,
                                                           real_img_emb,
                                                           encoder_hidden[0],
                                                           img_enc_labels,
                                                           txt_enc_labels,
                                                           self.gpus)
            errG = err_g_uncond_loss + err_g_cond_disc_loss + err_latent_gen + \
                loss_cycle_text + \
                loss_auto_img + \
                loss_auto_txt
            if len(pred_cap):
                # FIX: was the undefined name `netD`.
                errG_fake_from_fake_imgs = compute_cond_disc(disc_image,
                                                             fake_from_fake_img,
                                                             real_labels,
                                                             pred_txt_hidden[0],
                                                             self.gpus)
                errG += errG_fake_from_fake_imgs
            img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
            txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)
            f_img_kl_loss = KL_loss(fake_img_mu, fake_img_logvar)
            kl_loss = img_kl_loss + txt_kl_loss + f_img_kl_loss
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            # check NaN
            if (errG_total != errG_total).data.any():
                print("NaN detected (generator)")
                pdb.set_trace()
                exit()
            errG_total.backward()
            optim_img_enc.step()
            optim_img_gen.step()
            optim_txt_enc.step()
            optim_txt_gen.step()
            count = count + 1

            if i % 100 == 0:
                self.vis.add_to_plot("D_loss", np.asarray([[
                    errD.data[0], errD_im.data[0], errD_fake_imgs.data[0],
                    err_latent_disc.data[0]
                ]]), np.asarray([[count] * 4]))
                self.vis.add_to_plot("G_loss", np.asarray([[
                    errG.data[0], err_g_uncond_loss.data[0],
                    err_g_cond_disc_loss.data[0], err_latent_gen.data[0],
                    loss_cycle_text.data[0], loss_auto_img.data[0],
                    loss_auto_txt.data[0]
                ]]), np.asarray([[count] * 7]))
                self.vis.add_to_plot("KL_loss", np.asarray([[
                    kl_loss.data[0], img_kl_loss.data[0], txt_kl_loss.data[0],
                    f_img_kl_loss.data[0]
                ]]), np.asarray([[count] * 4]))
                self.vis.show_images("real_im", real_imgs[sort_idx].data.cpu().numpy())
                self.vis.show_images("fake_im", fake_imgs.data.cpu().numpy())
                sorted_captions = [captions[idx] for idx in sort_idx.cpu().tolist()]
                # Decode generated token ids to text for display.
                gen_cap_text = []
                for d_i, d in enumerate(gen_captions):
                    s = u""
                    # FIX: loop variable renamed from `i` so the batch
                    # index of the enclosing loop is not shadowed.
                    for tok in d:
                        if tok == self.txt_dico.EOS_TOKEN:
                            break
                        if tok != self.txt_dico.SOS_TOKEN:
                            s += self.txt_dico.id2word[tok] + u" "
                    gen_cap_text.append(s)
                self.vis.show_text("real_captions", sorted_captions)
                self.vis.show_text("genr_captions", gen_cap_text)
                r_precision = self.evaluator.r_precision_score(fake_img_code,
                                                               real_txt_code)
                self.vis.add_to_plot("r_precision",
                                     np.asarray([r_precision.data[0]]),
                                     np.asarray([count]))

        # Per-epoch evaluation on the last batch of the epoch
        # (plotted against `epoch`, and reported in the epoch summary).
        iscore_mu_real, _ = self.evaluator.inception_score(real_imgs[sort_idx])
        iscore_mu_fake, _ = self.evaluator.inception_score(fake_imgs)
        self.vis.add_to_plot("inception_score", np.asarray([[
            iscore_mu_real, iscore_mu_fake
        ]]), np.asarray([[epoch] * 2]))

        end_t = time.time()
        prefix = "Epoch %d; %s, %.1f sec" % (epoch, time.strftime('D%d %X'),
                                             (end_t - start_t))
        gen_str = "G_total: %.3f Gen loss: %.3f KL loss %.3f" % (
            errG_total.data[0], errG.data[0], kl_loss.data[0])
        dis_str = "Img Disc: %.3f Latent Disc: %.3f" % (
            errD.data[0], err_latent_disc.data[0])
        eval_str = "Incep real: %.3f Incep fake: %.3f R prec %.3f" % (
            iscore_mu_real, iscore_mu_fake, r_precision)
        print("%s %s, %s; %s" % (prefix, gen_str, dis_str, eval_str))

        if epoch % self.snapshot_interval == 0:
            # FIX: this snapshot call was duplicated verbatim; one copy
            # removed.
            save_model(image_encoder, image_generator, text_encoder,
                       text_generator, disc_image, disc_latent,
                       epoch, self.model_dir)
    self.summary_writer.close()
def train(self, data_loader, dataset, stage=1):
    """Joint training loop for a text-to-image GAN plus auxiliary models.

    Trains, per batch: the image GAN (netG/netD), a caption encoder/decoder
    (text auto-encoding), an image encoder (image auto-encoding through the
    generator), and a latent-modality discriminator `enc_disc` that tries to
    tell text embeddings from image embeddings.

    Args:
        data_loader: yields (_, real image batch, _, captions, pred_cap).
        dataset: dataset handle; not referenced in this body.
        stage: unused here -- only the stage-I networks are loaded.
    """
    netG, netD, encoder, decoder, image_encoder, enc_disc, clf_model = self.load_network_stageI(
    )
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    # Reusable noise buffer for generation; fixed_noise is only fed forward
    # for periodic visualization (volatile is the pre-0.4 inference flag).
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 volatile=True)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(
        1))  # try discriminator smoothing
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimize generator parameters that are not frozen.
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    # The text encoder/decoder get optimizers built from a spec string.
    optim_fn, optim_params = get_optimizer("adam,lr=0.001")
    enc_params = filter(lambda p: p.requires_grad, encoder.parameters())
    enc_optimizer = optim_fn(enc_params, **optim_params)
    optim_fn, optim_params = get_optimizer("adam,lr=0.001")
    dec_params = filter(lambda p: p.requires_grad, decoder.parameters())
    dec_optimizer = optim_fn(dec_params, **optim_params)
    # image_enc_optimizer = \
    #     optim.Adam(image_encoder.parameters(),
    #                lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    image_enc_optimizer = \
        optim.SGD(image_encoder.parameters(),
                  lr=cfg.TRAIN.DISCRIMINATOR_LR)
    enc_disc_optimizer = \
        optim.Adam(enc_disc.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    count = 0
    criterionCycle = nn.SmoothL1Loss()
    #criterionCycle = torch.nn.BCELoss()
    # NOTE(review): semantic_criterion is defined but never used below.
    semantic_criterion = nn.CosineEmbeddingLoss()
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Step-wise LR decay (factor 0.75) for both G and D optimizers.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.75
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.75
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            _, real_img_cpu, _, captions, pred_cap = data
            raw_inds, raw_lengths = self.process_captions(captions)
            # need to fix noise addition
            #inds, lengths = self.add_noise(raw_inds.data, raw_lengths)
            inds, lengths = raw_inds.data, raw_lengths
            inds = Variable(inds)
            # Sort captions by length (descending) as packed RNNs require;
            # sort_idx is reused below to keep images aligned with captions.
            lens_sort, sort_idx = lengths.sort(0, descending=True)
            # need to dataparallel the encoders?
            txt_encoder_output = encoder(inds[:, sort_idx],
                                         lens_sort.cpu().numpy(), None)
            encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output
            real_imgs = Variable(real_img_cpu)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
            #######################################################
            # (2) Generate fake images from the text code
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (real_txt_code, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            #######################################################
            # (2b) Decode z from txt and calc auto-encoding loss
            ######################################################
            loss_auto = 0
            # Teacher forcing: decoder input at step t+1 is the gold token t.
            auto_dec_inp = Variable(
                torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                 self.batch_size))
            auto_dec_inp = auto_dec_inp.cuda(
            ) if cfg.CUDA else auto_dec_inp
            auto_dec_hidden = real_txt_code.unsqueeze(0)
            max_target_length = inds.size(0)
            for t in range(max_target_length):
                auto_dec_out, auto_dec_hidden, auto_dec_attn = decoder(
                    auto_dec_inp, auto_dec_hidden, encoder_out)
                loss_auto = loss_auto + F.cross_entropy(
                    auto_dec_out, inds[:, sort_idx][t],
                    ignore_index=self.txt_dico.PAD_TOKEN)
                auto_dec_inp = inds[:, sort_idx][t]
            # Normalize by total token count so the loss is per-token.
            loss_auto = loss_auto / lengths.float().sum()
            #######################################################
            # (2c) Decode z from real imgs and calc auto-encoding loss
            ######################################################
            real_img_out = nn.parallel.data_parallel(
                image_encoder, (real_imgs[sort_idx]), self.gpus)
            real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out
            noise.data.normal_(0, 1)
            inputs = (real_img_code, noise)
            _, fake_from_real_img, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            # Image reconstruction loss in sigmoid space.
            loss_img = criterionCycle(F.sigmoid(fake_from_real_img),
                                      F.sigmoid(real_imgs[sort_idx]))
            # loss_img = F.binary_cross_entropy_with_logits(fake_from_real_img.view(batch_size, -1),
            #                                               real_imgs.view(batch_size, -1))
            #######################################################
            # (2c) Decode z from fake imgs and calc cycle loss
            ######################################################
            # NOTE(review): this encodes real_imgs[sort_idx], not fake_imgs,
            # despite the section header -- confirm whether fake_imgs was
            # intended here.
            fake_img_out = nn.parallel.data_parallel(
                image_encoder, (real_imgs[sort_idx]), self.gpus)
            fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
            fake_img_feats = fake_img_feats.transpose(0, 1)
            # Image -> caption cycle: decode the caption back from the image code.
            loss_cd = 0
            cd_dec_inp = Variable(
                torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                 self.batch_size))
            cd_dec_inp = cd_dec_inp.cuda() if cfg.CUDA else cd_dec_inp
            cd_dec_hidden = fake_img_code.unsqueeze(0)
            max_target_length = inds.size(0)
            for t in range(max_target_length):
                cd_dec_out, cd_dec_hidden, cd_dec_attn = decoder(
                    cd_dec_inp, cd_dec_hidden, fake_img_feats)
                loss_cd = loss_cd + F.cross_entropy(
                    cd_dec_out, inds[:, sort_idx][t],
                    ignore_index=self.txt_dico.PAD_TOKEN)
                cd_dec_inp = inds[:, sort_idx][t]
            loss_cd = loss_cd / lengths.float().sum()
            # Direct pixel loss between generated and real images.
            loss_dc = criterionCycle(fake_imgs, real_imgs[sort_idx])
            ############################
            # (3) Update D network
            ###########################
            netD.zero_grad()
            enc_disc.zero_grad()
            errD = 0
            errD_im, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           real_txt_mu, self.gpus)
            # Modality discriminator: text embeddings labeled 0, image
            # embeddings labeled 1.
            txt_enc_labels = Variable(
                torch.FloatTensor(batch_size).fill_(0))
            img_enc_labels = Variable(
                torch.FloatTensor(batch_size).fill_(1))
            if cfg.CUDA:
                txt_enc_labels = txt_enc_labels.cuda()
                img_enc_labels = img_enc_labels.cuda()
            # Detach so the D step does not backprop into the encoders.
            disc_real_txt_emb = encoder_hidden[0].detach()
            disc_real_img_emb = real_img_emb.detach()
            pred_txt = enc_disc(disc_real_txt_emb)
            pred_img = enc_disc(disc_real_img_emb)
            enc_disc_loss_txt = F.binary_cross_entropy_with_logits(
                pred_txt.squeeze(), txt_enc_labels)
            enc_disc_loss_img = F.binary_cross_entropy_with_logits(
                pred_img.squeeze(), img_enc_labels)
            errD = errD + errD_im + enc_disc_loss_txt + enc_disc_loss_img
            # check NaN (x != x is true only for NaN)
            if (errD != errD).data.any():
                print("NaN detected (discriminator)")
                pdb.set_trace()
                exit()
            errD.backward()
            optimizerD.step()
            enc_disc_optimizer.step()
            ############################
            # (2) Update G network
            ###########################
            encoder.zero_grad()
            decoder.zero_grad()
            netG.zero_grad()
            image_encoder.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                          real_txt_mu, self.gpus)
            img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
            txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)
            #f_img_kl_loss = KL_loss(fake_img_mu, fake_img_logvar)
            kl_loss = img_kl_loss + txt_kl_loss  #+ f_img_kl_loss
            #_, disc_hidden_g = encoder(inds[:, sort_idx], lens_sort.cpu().numpy(), None)
            #dg_mu, dg_logvar = nn.parallel.data_parallel(image_encoder, (real_imgs), self.gpus)
            #disc_img_g = torch.cat((dg_mu.unsqueeze(0), dg_logvar.unsqueeze(0)))
            # Adversarial side of the modality discriminator: predictions are
            # paired with the OPPOSITE modality's labels -- presumably so the
            # encoders learn indistinguishable embeddings; confirm intent.
            pred_txt_g = enc_disc(encoder_hidden[0])
            pred_img_g = enc_disc(real_img_emb)
            enc_fake_loss_txt = F.binary_cross_entropy_with_logits(
                pred_img_g.squeeze(), txt_enc_labels)
            enc_fake_loss_img = F.binary_cross_entropy_with_logits(
                pred_txt_g.squeeze(), img_enc_labels)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + loss_cd + loss_dc + loss_img + loss_auto + enc_fake_loss_txt + enc_fake_loss_img
            # check NaN
            if (errG_total != errG_total).data.any():
                print("NaN detected (generator)")
                pdb.set_trace()
                exit()
            errG_total.backward()
            optimizerG.step()
            image_enc_optimizer.step()
            enc_optimizer.step()
            dec_optimizer.step()
            count = count + 1
            if i % 100 == 0:
                # summary_D = summary.scalar('D_loss', errD.data[0])
                # summary_D_r = summary.scalar('D_loss_real', errD_real)
                # #summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                # summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                # summary_G = summary.scalar('G_loss', errG.data[0])
                # #summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                # self.summary_writer.add_summary(summary_D, count)
                # self.summary_writer.add_summary(summary_D_r, count)
                # #self.summary_writer.add_summary(summary_D_w, count)
                # self.summary_writer.add_summary(summary_D_f, count)
                # self.summary_writer.add_summary(summary_G, count)
                # #self.summary_writer.add_summary(summary_KL, count)
                # save the image result for each epoch
                inputs = (real_txt_code, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                self.vis.images(normalize(
                    real_imgs[sort_idx].data.cpu().numpy()),
                    win=self.vis_win1)
                self.vis.images(normalize(fake_imgs.data.cpu().numpy()),
                                win=self.vis_win2)
                self.vis.text("\n*".join(captions), win=self.vis_txt1)
                if (len(pred_cap)):
                    # NOTE(review): fake_from_fake_img is never assigned in
                    # this scope; this branch raises NameError when pred_cap
                    # is non-empty. Possibly fake_from_real_img was meant.
                    self.vis.images(normalize(
                        fake_from_fake_img.data.cpu().numpy()),
                        win=self.vis_win3)
        end_t = time.time()
        prefix = "E%d/%s, %.1fs" % (epoch, time.strftime('D%d %X'),
                                    (end_t - start_t))
        gen_str = "G_all: %.3f Cy_T: %.3f AE_T: %.3f AE_I %.3f KL_T %.3f KL_I %.3f" % (
            errG_total.data[0], loss_cd.data[0], loss_auto.data[0],
            loss_img.data[0], txt_kl_loss.data[0], img_kl_loss.data[0])
        dis_str = "Img Disc: %.3f Latent Disc: %.3f" % (
            errD.data[0], errD_im.data[0]) if False else "D_all: %.3f D_I: %.3f D_zT: %.3f D_zI: %.3f" % (
            errD.data[0], errD_im.data[0], enc_disc_loss_txt.data[0],
            enc_disc_loss_img.data[0])
        print("%s %s, %s" % (prefix, gen_str, dis_str))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, encoder, decoder, image_encoder, epoch,
                       self.model_dir)
    #
    save_model(netG, netD, encoder, decoder, image_encoder, self.max_epoch, self.model_dir) if False else None
    # self.summary_writer.close()
def train(self, data_loader, stage=1):
    """Train the stage-I or stage-II GAN with TensorBoard scalar/image logging.

    Per batch: generate fake images from the text embedding plus noise, take
    one discriminator step, then one generator step (adversarial loss plus a
    KL term regularizing the conditioning augmentation). Every 100 batches,
    scalars and image grids are written to the summary writer and sample
    images from a fixed noise vector are saved.

    Args:
        data_loader: yields (real image batch, text embedding batch).
        stage: 1 loads the stage-I networks, otherwise stage-II.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    # Reusable noise buffer; fixed_noise is only ever fed forward for
    # visualization, so it carries no gradients. (`volatile=True` is a no-op
    # in PyTorch >= 0.4, which this code already targets via `.item()`; use
    # requires_grad=False as the other train() variants in this file do.)
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Optimize only the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()
            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                # .item() already returns a host-side Python number, so the
                # previous extra .cpu() hop was redundant.
                self.summary_writer.add_scalars(
                    main_tag="loss",
                    tag_scalar_dict={'D_loss': errD.item(),
                                     'G_loss': errG_total.item()},
                    global_step=count)
                self.summary_writer.add_scalars(
                    main_tag="D_loss",
                    tag_scalar_dict={"D_loss_real": errD_real,
                                     "D_loss_wrong": errD_wrong,
                                     "D_loss_fake": errD_fake},
                    global_step=count)
                self.summary_writer.add_scalars(
                    main_tag="G_loss",
                    tag_scalar_dict={"G_loss": errG.item(),
                                     "KL_loss": kl_loss.item()},
                    global_step=count)
                # Save generations from the fixed noise vector so progress
                # is comparable across logging steps.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                self.summary_writer.add_image(
                    tag="fake_image",
                    img_tensor=vutils.make_grid(fake_imgs, normalize=True,
                                                range=(-1, 1)),
                    global_step=count)
                self.summary_writer.add_image(
                    tag="real_image",
                    img_tensor=vutils.make_grid(real_img_cpu, normalize=True,
                                                range=(-1, 1)),
                    global_step=count)
                if lr_fake is not None:
                    # Stage-II generators also emit a low-resolution image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                 Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                 Total Time: %.2fsec
              ''' % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    # save_model(netG, netD, self.max_epoch, self.model_dir)
def train(self, data_loader, stage=1):
    """Train the GAN with an additional detection-consistency loss.

    Identical to the vanilla StackGAN loop, but every batch the generated
    images are run through a Detectron model and the one-hot encoding of the
    detected object labels is compared (SmoothL1) against the one-hot
    encoding of the caption labels; that loss is added to the generator
    objective.

    Args:
        data_loader: yields (real image batch, text embedding batch, caption
            label batch).
        stage: 1 loads the stage-I networks, otherwise stage-II.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    nz = cfg.Z_DIM  # 100
    batch_size = self.batch_size
    # Reusable noise buffer; fixed_noise is only used for visualization.
    # (`volatile=True` is a no-op in PyTorch >= 0.4, which this code already
    # targets via `.item()`; requires_grad=False states the intent.)
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Optimize only the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    detectron = Detectron()
    # Hoisted out of the batch loop: the loss module is stateless, so there
    # is no reason to re-instantiate it every iteration.
    det_criterion = nn.SmoothL1Loss()
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding, caption = data
            # Reorder caption axes to (batch, labels).
            caption = np.moveaxis(np.array(caption), 1, 0)
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()
            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            # Detection-consistency loss: compare one-hot detected labels on
            # the generated images with the caption's one-hot labels.
            # NOTE(review): fake_l comes from detached numpy output, so no
            # gradient flows from det_loss back into netG; it only shifts the
            # reported total loss. Confirm whether that is intended.
            fake_img = fake_imgs.cpu().detach().numpy()
            det_obj_list = detectron.get_labels(fake_img)
            fake_l = Variable(get_ohe(det_obj_list))
            real_l = Variable(get_ohe(caption))
            if cfg.CUDA:
                # Guard device moves with cfg.CUDA like everywhere else in
                # this file; the previous unconditional .cuda() crashed
                # CPU-only runs.
                fake_l = fake_l.cuda()
                real_l = real_l.cuda()
            det_loss = det_criterion(fake_l, real_l)
            errG_total = det_loss + errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                summary_DET = summary.scalar('det_loss', det_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)
                self.summary_writer.add_summary(summary_DET, count)
                # Save generations from the fixed noise vector.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Stage-II generators also emit a low-resolution image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print(
            '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                 Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                 Total Time: %.2fsec
              ''' % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    # save_model(netG, netD, self.max_epoch, self.model_dir)
    # self.summary_writer.close()
def train(self, data_loader, stage=1):
    """Train the GAN with VGG19 perceptual (activation) and texture (Gram)
    losses on top of the adversarial + KL objective.

    Args:
        data_loader: yields (real image batch, text embedding batch).
        stage: 1 loads the stage-I networks, otherwise stage-II.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    # Reusable noise buffer; fixed_noise is visualization-only (volatile is
    # the pre-0.4 inference flag, kept since this code uses .data[0]).
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 volatile=True)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    # Optimize only the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    if cfg.TRAIN.ADAM:
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
    else:
        optimizerD = \
            optim.RMSprop(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR)
        optimizerG = \
            optim.RMSprop(netG_para, lr=cfg.TRAIN.GENERATOR_LR)
    # Fixed VGG19 feature extractor (first 28 layers) for perceptual losses.
    cnn = models.vgg19(pretrained=True).features
    cnn = nn.Sequential(*list(cnn.children())[0:28])
    # Freeze VGG: gradients must flow THROUGH it to the generator, but its
    # own weights are never updated and should not accumulate gradients.
    for p in cnn.parameters():
        p.requires_grad = False
    gram = GramMatrix()
    if cfg.CUDA:
        cnn.cuda()
        gram.cuda()
    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            if cfg.CUDA:
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
            else:
                _, fake_imgs, mu, logvar = netG(txt_embedding, noise)
            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus, cfg.CUDA)
            errD.backward()
            optimizerD.step()
            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                          mu, self.gpus, cfg.CUDA)
            kl_loss = KL_loss(mu, logvar)
            pixel_loss = PIXEL_loss(real_imgs, fake_imgs)
            # BUGFIX: do NOT detach fake_imgs before the VGG pass -- the
            # activation/texture losses below are part of errG_total, and
            # detaching severed their gradient path to the generator (the
            # CPU branch never detached, so the two branches also disagreed).
            # The real-image branch is detached in both cases: it is a
            # constant target, so no graph is needed through it.
            if cfg.CUDA:
                fake_features = nn.parallel.data_parallel(
                    cnn, fake_imgs, self.gpus)
                real_features = nn.parallel.data_parallel(
                    cnn, real_imgs.detach(), self.gpus)
            else:
                fake_features = cnn(fake_imgs)
                real_features = cnn(real_imgs.detach())
            active_loss = ACT_loss(fake_features, real_features)
            text_loss = TEXT_loss(gram, fake_features, real_features,
                                  cfg.TRAIN.COEFF.TEXT)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + \
                pixel_loss * cfg.TRAIN.COEFF.PIX + \
                active_loss * cfg.TRAIN.COEFF.ACT + \
                text_loss
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                summary_D = summary.scalar('D_loss', errD.data[0])
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.data[0])
                summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                summary_Pix = summary.scalar('Pixel_loss',
                                             pixel_loss.data[0])
                summary_Act = summary.scalar('Act_loss',
                                             active_loss.data[0])
                summary_Text = summary.scalar('Text_loss',
                                              text_loss.data[0])
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)
                self.summary_writer.add_summary(summary_Pix, count)
                self.summary_writer.add_summary(summary_Act, count)
                self.summary_writer.add_summary(summary_Text, count)
                # Save generations from the fixed noise vector.
                inputs = (txt_embedding, fixed_noise)
                if cfg.CUDA:
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    lr_fake, fake, _, _ = netG(txt_embedding, fixed_noise)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Stage-II generators also emit a low-resolution image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print(
            '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                 Loss_Pixel: %.4f Loss_Activ: %.4f Loss_Text: %.4f
                 Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                 Total Time: %.2fsec
              ''' % (epoch, self.max_epoch, i, len(data_loader),
                     errD.data[0], errG.data[0], kl_loss.data[0],
                     pixel_loss.data[0], active_loss.data[0],
                     text_loss.data[0], errD_real, errD_wrong, errD_fake,
                     (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    # save_model(netG, netD, self.max_epoch, self.model_dir)
    # self.summary_writer.close()
def train(self, data_loader, stage=1):
    """Train the stage-I or stage-II GAN, printing per-batch progress.

    Per batch: generate fake images from the text embedding plus noise, take
    one discriminator step, then one generator step (adversarial loss plus a
    KL term). Every 100 batches, scalar summaries and sample images from a
    fixed noise vector are written out.

    Args:
        data_loader: yields (real image batch, text embedding batch).
        stage: 1 loads the stage-I networks, otherwise stage-II.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()
    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # fixed_noise is visualization-only; created under no_grad so it never
    # requires gradients (replaces the removed volatile=True idiom).
    with torch.no_grad():
        fixed_noise = \
            torch.FloatTensor(batch_size, nz).normal_(0, 1)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Optimize only the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        num_batches = len(data_loader)
        print('Number of batches: ' + str(len(data_loader)))
        for i, data in enumerate(data_loader, 0):
            # In-place progress line (carriage return keeps it on one line).
            print('Epoch number: ' + str(epoch) + '\tBatches: ' + str(i) +
                  '/' + str(num_batches), end='\r')
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data  # txt_embedding: (batch, 1024)
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()
            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()
            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()
            count = count + 1
            if i % 100 == 0:
                # BUGFIX: summary.scalar expects a Python float; passing the
                # raw `.data` tensor is incorrect. `.item()` extracts the
                # host-side number (matches the other train() variants).
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)
                # Save generations from the fixed noise vector.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Stage-II generators also emit a low-resolution image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                 Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                 Total Time: %.2fsec
              ''' % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        print('################EPOCH COMPLETED###########')
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    # save_model(netG, netD, self.max_epoch, self.model_dir)
    # self.summary_writer.close()
def train(self, data_loader, stage=1, max_objects=3):
    """Train the object-conditioned GAN for ``self.max_epoch`` epochs.

    Parameters
    ----------
    data_loader : iterable
        Yields batches of ``(real_img_cpu, bbox, label, txt_embedding)``.
        ``bbox`` is a single tensor for stage 1 and a pair of tensors
        (one per resolution) for stage 2 — TODO confirm against the dataset.
    stage : int
        1 loads the stage-I networks; any other value loads stage-II.
    max_objects : int
        Maximum number of objects per image; sizes the per-object affine
        transformation matrices and the one-hot label tensor.

    Side effects: updates ``netG``/``netD`` in place, writes summaries to
    ``self.summary_writer`` (every 500 batches) and
    ``self.drive_summary_writer`` (every 100 batches), saves sample images,
    and snapshots the models every ``self.snapshot_interval`` epochs.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()

    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # NOTE(review): fixed_noise is created but the sampling blocks below
    # reuse `noise`; kept for interface/behavior parity.
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                           requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    # Optimize only the generator parameters that require gradients.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerD = optim.Adam(netD.parameters(),
                            lr=cfg.TRAIN.DISCRIMINATOR_LR,
                            betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))

    # Resume optimizer state and epoch counter from a checkpoint, if any.
    startpoint = -1
    if cfg.NET_G != '':
        state_dict = torch.load(cfg.NET_G,
                                map_location=lambda storage, loc: storage)
        optimizerD.load_state_dict(state_dict["optimD"])
        optimizerG.load_state_dict(state_dict["optimG"])
        startpoint = state_dict["epoch"]
        print(startpoint)
        print('Load Optim and optimizers as : ', cfg.NET_G)

    count = 0        # summary step for self.summary_writer (every 500 batches)
    drive_count = 0  # summary step for self.drive_summary_writer (every 100)
    for epoch in range(startpoint + 1, self.max_epoch):
        print('epoch : ', epoch, ' drive_count : ', drive_count)
        epoch_start_time = time.time()
        print(epoch)
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        time_to_i = time.time()
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, bbox, label, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                if cfg.STAGE == 1:
                    bbox = bbox.cuda()
                elif cfg.STAGE == 2:
                    # Stage 2 carries one bbox tensor per resolution.
                    bbox = [bbox[0].cuda(), bbox[1].cuda()]
                label = label.cuda()
                txt_embedding = txt_embedding.cuda()

            # Per-object affine (2, 3) matrices mapping between the image
            # and object crops, derived from the bounding boxes.
            if cfg.STAGE == 1:
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox)
                transf_matrices = transf_matrices.view(
                    real_imgs.shape[0], max_objects, 2, 3)
            elif cfg.STAGE == 2:
                _bbox = bbox[0].view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                _bbox = bbox[1].view(-1, 4)
                transf_matrices_inv_s2 = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices_s2 = compute_transformation_matrix(_bbox)
                transf_matrices_s2 = transf_matrices_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)

            # One-hot encode labels; -1 ("no object") is remapped to the
            # extra class index 80 so scatter_ receives a valid index.
            _labels = label.long()
            _labels[_labels < 0] = 80
            if cfg.CUDA:
                label_one_hot = torch.cuda.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
            else:
                label_one_hot = torch.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
            label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            if cfg.CUDA:
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
            else:
                # Bug fix: the CPU path previously hard-coded the stage-1
                # argument list, which broke for cfg.STAGE == 2; unpack the
                # stage-appropriate tuple instead.
                _, fake_imgs, mu, logvar, _ = netG(*inputs)

            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            if cfg.STAGE == 1:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(
                        netD, real_imgs, fake_imgs,
                        real_labels, fake_labels, label_one_hot,
                        transf_matrices, transf_matrices_inv,
                        mu, self.gpus)
            elif cfg.STAGE == 2:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(
                        netD, real_imgs, fake_imgs,
                        real_labels, fake_labels, label_one_hot,
                        transf_matrices_s2, transf_matrices_inv_s2,
                        mu, self.gpus)
            # retain_graph=True: the generator update below backprops
            # through the same generator graph.
            errD.backward(retain_graph=True)
            optimizerD.step()

            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            if cfg.STAGE == 1:
                errG = compute_generator_loss(
                    netD, fake_imgs, real_labels, label_one_hot,
                    transf_matrices, transf_matrices_inv, mu, self.gpus)
            elif cfg.STAGE == 2:
                errG = compute_generator_loss(
                    netD, fake_imgs, real_labels, label_one_hot,
                    transf_matrices_s2, transf_matrices_inv_s2,
                    mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            end_t = time.time()
            if i % 500 == 0:
                count += 1
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                print('epoch : ', epoch)
                print('count : ', count)
                print(' i : ', i)
                print('Time to i : ', time.time() - time_to_i)
                time_to_i = time.time()
                print('D_loss : ', errD.item())
                print('D_loss_real : ', errD_real)
                print('D_loss_wrong : ', errD_wrong)
                print('D_loss_fake : ', errD_fake)
                print('G_loss : ', errG.item())
                print('KL_loss : ', kl_loss.item())
                print('generator_lr : ', generator_lr)
                print('discriminator_lr : ', discriminator_lr)
                print('lr_decay_step : ', lr_decay_step)
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)

                # Save a sample batch of generated images.
                with torch.no_grad():
                    if cfg.STAGE == 1:
                        inputs = (txt_embedding, noise,
                                  transf_matrices_inv, label_one_hot)
                    elif cfg.STAGE == 2:
                        inputs = (txt_embedding, noise,
                                  transf_matrices_inv, transf_matrices_s2,
                                  transf_matrices_inv_s2, label_one_hot)
                    if cfg.CUDA:
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                    else:
                        # Same stage-2 CPU-path fix as in step (2).
                        lr_fake, fake, _, _, _ = netG(*inputs)
                    save_img_results(real_img_cpu, fake, epoch,
                                     self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch,
                                         self.image_dir)

            if i % 100 == 0:
                # NOTE(review): summary_* are refreshed only every 500
                # batches, so the intermediate 100-batch ticks re-log the
                # most recent values (first defined at i == 0).
                drive_count += 1
                self.drive_summary_writer.add_summary(summary_D,
                                                      drive_count)
                self.drive_summary_writer.add_summary(summary_D_r,
                                                      drive_count)
                self.drive_summary_writer.add_summary(summary_D_w,
                                                      drive_count)
                self.drive_summary_writer.add_summary(summary_D_f,
                                                      drive_count)
                self.drive_summary_writer.add_summary(summary_G,
                                                      drive_count)
                self.drive_summary_writer.add_summary(summary_KL,
                                                      drive_count)

        # Save the image result for each epoch, reusing the last batch's
        # conditioning and transformation matrices.
        with torch.no_grad():
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)

        end_t = time.time()
        print(
            '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
               Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
               Total Time: %.2fsec
            ''' % (epoch, self.max_epoch, i, len(data_loader),
                   errD.item(), errG.item(), kl_loss.item(),
                   errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, optimizerG, optimizerD, epoch,
                       self.model_dir)
        print("keyTime |||||||||||||||||||||||||||||||")
        print("epoch_time : ", time.time() - epoch_start_time)
        print("KeyTime |||||||||||||||||||||||||||||||")
    # save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
    # self.summary_writer.close()