def train(self, imageloader, storyloader, testloader):
    """Train the generator against the image and story discriminators.

    Every batch drawn from ``storyloader`` is paired with an image batch from
    ``self.sample_real_image_batch()``.  Per iteration both discriminators
    are updated once and the generator twice.  Learning rates are halved
    every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs, scalars are forwarded to
    ``self.writer`` when it is truthy, preview stories are written every 100
    iterations and checkpoints every ``self.snapshot_interval`` epochs.

    Args:
        imageloader: Stored on ``self.imageloader``; presumably the source
            behind ``self.sample_real_image_batch()`` -- confirm there.
        storyloader: Sized iterable of story batches indexable by
            ``'images'``, ``'description'`` and ``'label'``.
        testloader: Stored on ``self.testloader`` and handed to
            ``save_test_samples`` at every snapshot.
    """
    self.imageloader = imageloader
    self.testloader = testloader
    self.imagedataset = None
    self.testdataset = None
    netG, netD_im, netD_st = self.load_networks()

    # Constant real/fake targets, one entry per sample of each batch type.
    im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    im_optimizerD = optim.Adam(netD_im.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    st_optimizerD = optim.Adam(netD_st.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimise the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

    if self.tensorboard:
        self.build_tensorboard()
    loss = {}
    step = 0

    # Snapshot of the freshly loaded networks, written before any update.
    torch.save({
        'netG': netG,
        'netD_im': netD_im,
        'netD_st': netD_st,
    }, os.path.join(self.model_dir, 'barebone.pth'))

    for epoch in range(self.max_epoch):
        start_t = time.time()
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for optimizer_d in (st_optimizerD, im_optimizerD):
                for param_group in optimizer_d.param_groups:
                    param_group['lr'] = discriminator_lr

        loss.update({
            'D/lr': discriminator_lr,
            'G/lr': generator_lr,
        })

        print('Epoch [{}/{}]:'.format(epoch, self.max_epoch))
        with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
            for i, data in enumerate(storyloader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                im_batch = self.sample_real_image_batch()
                st_batch = data

                im_real_cpu = im_batch['images']
                im_motion_input = im_batch['description']
                # Collapse dimension 1 of the content tensor to its mean.
                im_content_input = im_batch['content'].mean(1).squeeze()
                im_catelabel = im_batch['label']
                im_real_imgs = Variable(im_real_cpu)
                im_motion_input = Variable(im_motion_input)
                im_content_input = Variable(im_content_input)

                st_real_cpu = st_batch['images']
                # Stories use the same description tensor as motion and content input.
                st_motion_input = st_batch['description']
                st_content_input = st_batch['description']
                st_catelabel = st_batch['label']
                st_real_imgs = Variable(st_real_cpu)
                st_motion_input = Variable(st_motion_input)
                st_content_input = Variable(st_content_input)

                if cfg.CUDA:
                    st_real_imgs = st_real_imgs.cuda()
                    im_real_imgs = im_real_imgs.cuda()
                    st_motion_input = st_motion_input.cuda()
                    im_motion_input = im_motion_input.cuda()
                    st_content_input = st_content_input.cuda()
                    im_content_input = im_content_input.cuda()
                    im_catelabel = im_catelabel.cuda()
                    st_catelabel = st_catelabel.cuda()

                ######################################################
                # (2) Generate fake stories and images
                ######################################################
                # The discriminator step needs no generator gradients.
                with torch.no_grad():
                    _, im_fake, im_mu, im_logvar = netG.sample_images(
                        im_motion_input, im_content_input)
                    _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                        st_motion_input, st_content_input)

                ############################
                # (3) Update D network
                ############################
                netD_im.zero_grad()
                netD_st.zero_grad()

                im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
                    compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels,
                                               im_catelabel, im_mu, self.gpus)
                st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
                    compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels,
                                               st_catelabel, c_mu, self.gpus)

                loss.update({
                    'D/story/loss': st_errD.data,
                    'D/story/real_loss': st_errD_real.data,
                    'D/story/fake_loss': st_errD_fake.data,
                    'D/image/accuracy': accD,
                    'D/image/loss': im_errD.data,
                    'D/image/real_loss': im_errD_real.data,
                    'D/image/fake_loss': im_errD_fake.data,
                })

                im_errD.backward()
                st_errD.backward()
                im_optimizerD.step()
                st_optimizerD.step()

                ############################
                # (4) Update G network (two generator steps per D step)
                ############################
                for _ in range(2):
                    netG.zero_grad()
                    _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                        st_motion_input, st_content_input)
                    _, im_fake, im_mu, im_logvar = netG.sample_images(
                        im_motion_input, im_content_input)

                    im_errG, accG = compute_generator_loss(
                        netD_im, im_fake, im_real_labels, im_catelabel, im_mu, self.gpus)
                    st_errG, _ = compute_generator_loss(
                        netD_st, st_fake, st_real_labels, st_catelabel, c_mu, self.gpus)

                    im_kl_loss = KL_loss(im_mu, im_logvar)
                    st_kl_loss = KL_loss(m_mu, m_logvar)
                    kl_loss = im_kl_loss + self.ratio * st_kl_loss

                    loss.update({
                        'G/loss': im_errG.data,
                        'G/kl': kl_loss.data,
                    })

                    errG_total = im_errG + self.ratio * st_errG + kl_loss
                    errG_total.backward()
                    optimizerG.step()

                if self.writer:
                    for key, value in loss.items():
                        self.writer.add_scalar(key, value, step)

                step += 1
                pbar.update(1)

                if i % 100 == 0:
                    # Preview only: no graph is needed, so skip autograd
                    # bookkeeping instead of building an unused graph.
                    with torch.no_grad():
                        lr_fake, fake, _, _, _, _ = netG.sample_videos(
                            st_motion_input, st_content_input)
                    save_story_results(st_real_cpu, fake, epoch, self.image_dir,
                                       writer=self.writer, steps=step)
                    if lr_fake is not None:
                        save_story_results(None, lr_fake, epoch, self.image_dir,
                                           writer=self.writer, steps=step)

        end_t = time.time()
        # Reports the values left over from the last iteration of the epoch.
        print(
            '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f '
            'Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f '
            'accG: %.4f accD: %.4f Total Time: %.2fsec '
            % (epoch, self.max_epoch, i, len(storyloader),
               st_errD.data, st_errG.data,
               st_errD_real, st_errD_wrong, st_errD_fake,
               accG, accD, (end_t - start_t)))

        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, epoch, self.model_dir)
            save_test_samples(netG, self.testloader, self.test_dir,
                              writer=self.writer, steps=step)

    # Always persist the final weights, tagged with the epoch budget.
    save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
def train(self, imageloader, storyloader, testloader, stage=1):
    """Train the generator with image, story and optional segment discriminators.

    Every batch drawn from ``storyloader`` is paired with an image batch from
    ``self.sample_real_image_batch()``.  Per iteration each discriminator and
    the generator are updated once.  Scalars and preview images go to
    ``self._logger``; the generator and the image/story discriminator
    learning rates are halved on a schedule whose interval doubles after
    every decay.  Shape notes in the comments are carried over from the
    original author's annotations and are not checked by the code.

    Args:
        imageloader: Stored on ``self.imageloader``; presumably the source
            behind ``self.sample_real_image_batch()`` -- confirm there.
        storyloader: Sized iterable of story batches indexable by
            ``'images'``, ``'description'``, ``'labels'`` and optionally
            ``'text'``.
        testloader: Passed to ``self.calculate_vfid`` when
            ``cfg.EVALUATE_FID_SCORE`` is set.
        stage: Accepted for interface compatibility; not read here.
    """
    c_time = time.time()
    self.imageloader = imageloader
    self.imagedataset = None
    netG, netD_im, netD_st, netD_se = self.load_network_stageI()

    # Constant real/fake targets, one entry per sample of each batch type.
    im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

    use_segment = cfg.SEGMENT_LEARNING
    segment_weight = cfg.SEGMENT_RATIO
    image_weight = cfg.IMAGE_RATIO

    # Optimizers
    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    im_optimizerD = optim.Adam(netD_im.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    st_optimizerD = optim.Adam(netD_st.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    if use_segment:
        se_optimizerD = optim.Adam(netD_se.parameters(),
                                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimise the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

    mse_loss = nn.MSELoss()

    # Resume from the epoch recorded in ``con_ckpt`` when one is given.
    start_epoch = int(self.con_ckpt) if self.con_ckpt else 0

    print('LR DECAY EPOCH: {}'.format(lr_decay_step))

    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()
        num_step = len(storyloader)
        stats = {}

        with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
            for i, data in enumerate(storyloader):
                ######################################################
                # (1) Prepare training data
                ######################################################
                im_batch = self.sample_real_image_batch()
                st_batch = data

                im_real_cpu = im_batch['images']
                # description vector and attribute (60, 356)
                im_motion_input = im_batch['description'][:, :cfg.TEXT.DIMENSION]
                # description vector and attribute for every story (60, 5, 356)
                im_content_input = im_batch['content'][:, :, :cfg.TEXT.DIMENSION]
                im_real_imgs = Variable(im_real_cpu)
                im_motion_input = Variable(im_motion_input)
                im_content_input = Variable(im_content_input)
                im_labels = Variable(im_batch['labels'])

                st_real_cpu = st_batch['images']
                # Both story inputs are the same slice of the description: (12, 5, 356)
                st_motion_input = st_batch['description'][:, :, :cfg.TEXT.DIMENSION]
                st_content_input = st_batch['description'][:, :, :cfg.TEXT.DIMENSION]
                st_texts = st_batch['text'] if 'text' in st_batch else None
                st_real_imgs = Variable(st_real_cpu)
                st_motion_input = Variable(st_motion_input)
                st_content_input = Variable(st_content_input)
                st_labels = Variable(st_batch['labels'])  # (12, 5, 9)

                if use_segment:
                    se_real_cpu = im_batch['images_seg']
                    se_real_imgs = Variable(se_real_cpu)

                if cfg.CUDA:
                    st_real_imgs = st_real_imgs.cuda()  # (12, 3, 5, 64, 64)
                    im_real_imgs = im_real_imgs.cuda()
                    st_motion_input = st_motion_input.cuda()
                    im_motion_input = im_motion_input.cuda()
                    st_content_input = st_content_input.cuda()
                    im_content_input = im_content_input.cuda()
                    im_labels = im_labels.cuda()
                    st_labels = st_labels.cuda()
                    if use_segment:
                        se_real_imgs = se_real_imgs.cuda()

                # Append the label vector to the motion input.
                im_motion_input = torch.cat((im_motion_input, im_labels), 1)  # 356 + 9 = 365 -> (60, 365)
                st_motion_input = torch.cat((st_motion_input, st_labels), 2)  # (12, 5, 365)

                ######################################################
                # (2) Generate fake stories and images
                ######################################################
                # The discriminator step needs no generator gradients.
                with torch.no_grad():
                    # m_mu (60, 365), c_mu (12, 124)
                    _, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = \
                        netG.sample_videos(st_motion_input, st_content_input)
                    # im_mu (60, 489), cim_mu (60, 124)
                    _, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = \
                        netG.sample_images(im_motion_input, im_content_input,
                                           seg=use_segment)

                # Which characters appear anywhere in the story.  ``.float()``
                # keeps the tensor on its current device, so this also works
                # when ``cfg.CUDA`` is off (the former unconditional
                # ``.cuda()`` did not).
                characters_mu = (st_labels.mean(1) > 0).float()
                # 124 + 356 + 9 = 489 -> (12, 489)
                st_mu = torch.cat(
                    (c_mu,
                     st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(1).squeeze(),
                     characters_mu), 1)
                im_mu = torch.cat((im_motion_input, cim_mu), 1)  # (60, 489)

                ############################
                # (3) Update D network
                ############################
                netD_im.zero_grad()
                netD_st.zero_grad()
                se_accD = 0
                if use_segment:
                    netD_se.zero_grad()
                    se_errD, se_errD_real, _, se_errD_fake, se_accD, _ = \
                        compute_discriminator_loss(netD_se, se_real_imgs, se_fake,
                                                   im_real_labels, im_fake_labels,
                                                   im_labels, im_mu, self.gpus)

                im_errD, im_errD_real, _, im_errD_fake, im_accD, _ = \
                    compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels,
                                               im_labels, im_mu, self.gpus)
                st_errD, st_errD_real, _, st_errD_fake, _, order_consistency = \
                    compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels,
                                               st_labels, st_mu, self.gpus)

                if use_segment:
                    se_errD.backward()
                    se_optimizerD.step()
                    stats.update({
                        'seg_D/loss': se_errD.data,
                        'seg_D/real': se_errD_real,
                        'seg_D/fake': se_errD_fake,
                    })

                im_errD.backward()
                st_errD.backward()
                im_optimizerD.step()
                st_optimizerD.step()

                stats.update({
                    'img_D/loss': im_errD.data,
                    'img_D/real': im_errD_real,
                    'img_D/fake': im_errD_fake,
                    'Accuracy/im_D': im_accD,
                    'Accuracy/se_D': se_accD,
                })

                # Story-discriminator scalars are logged on every iteration.
                step = i + num_step * epoch
                self._logger.add_scalar('st_D/loss', st_errD.data, step)
                self._logger.add_scalar('st_D/real', st_errD_real, step)
                self._logger.add_scalar('st_D/fake', st_errD_fake, step)
                self._logger.add_scalar('st_D/order', order_consistency, step)

                ############################
                # (4) Update G network
                ############################
                netG.zero_grad()
                video_latents, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = \
                    netG.sample_videos(st_motion_input, st_content_input)
                image_latents, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = \
                    netG.sample_images(im_motion_input, im_content_input,
                                       seg=use_segment)

                if video_latents is not None:
                    # Each latents value is a pair of 4-tuples; matching
                    # entries are compared with MSE.
                    h_segs, g_segs = video_latents
                    (h_seg1, h_seg2, h_seg3, h_seg4) = h_segs
                    (g_seg1, g_seg2, g_seg3, g_seg4) = g_segs
                    video_latent_loss = (mse_loss(g_seg1, h_seg1) + mse_loss(g_seg2, h_seg2)
                                         + mse_loss(g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4))

                    h_segs, g_segs = image_latents
                    (h_seg1, h_seg2, h_seg3, h_seg4) = h_segs
                    (g_seg1, g_seg2, g_seg3, g_seg4) = g_segs
                    image_latent_loss = (mse_loss(g_seg1, h_seg1) + mse_loss(g_seg2, h_seg2)
                                         + mse_loss(g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4))

                    # NOTE(review): ``se_real_imgs`` only exists when
                    # ``use_segment`` is true -- confirm latents are only
                    # returned in that configuration.
                    reconstruct_img = netG.train_autoencoder(se_real_imgs)
                    reconstruct_fake = netG.train_autoencoder(se_fake)
                    reconstruct_loss = (mse_loss(reconstruct_img, se_real_imgs)
                                        + mse_loss(reconstruct_fake, se_fake)) / 2.0

                    self._logger.add_scalar('G/image_vae_loss', image_latent_loss.data, step)
                    self._logger.add_scalar('G/video_vae_loss', video_latent_loss.data, step)
                    self._logger.add_scalar('G/reconstruct_loss', reconstruct_loss.data, step)

                # Rebuild the conditioning vectors from the fresh forward pass.
                characters_mu = (st_labels.mean(1) > 0).float()
                st_mu = torch.cat(
                    (c_mu,
                     st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(1).squeeze(),
                     characters_mu), 1)
                im_mu = torch.cat((im_motion_input, cim_mu), 1)

                se_errG, se_accG = 0, 0
                if use_segment:
                    se_errG, se_accG, _ = compute_generator_loss(
                        netD_se, se_fake, se_real_imgs, im_real_labels,
                        im_labels, im_mu, self.gpus)
                im_errG, im_accG, _ = compute_generator_loss(
                    netD_im, im_fake, im_real_imgs, im_real_labels,
                    im_labels, im_mu, self.gpus)
                st_errG, st_accG, G_consistency = compute_generator_loss(
                    netD_st, st_fake, st_real_imgs, st_real_labels,
                    st_labels, st_mu, self.gpus)

                # Sample-image and sample-video KL terms.
                im_kl_loss = KL_loss(cim_mu, cim_logvar)
                st_kl_loss = KL_loss(c_mu, c_logvar)

                # The next two values are only recorded, not optimised directly.
                errG = im_errG + self.ratio * (
                    image_weight * st_errG + se_errG * segment_weight)
                kl_loss = im_kl_loss + self.ratio * st_kl_loss

                # Total loss
                errG_total = im_errG + im_kl_loss * cfg.TRAIN.COEFF.KL \
                    + self.ratio * (se_errG * segment_weight
                                    + st_errG * image_weight
                                    + st_kl_loss * cfg.TRAIN.COEFF.KL)
                if video_latents is not None:
                    # Only the video latent term joins the objective; the
                    # image latent term above is logged but not optimised.
                    errG_total += (video_latent_loss + reconstruct_loss) * cfg.RECONSTRUCT_LOSS

                errG_total.backward()
                optimizerG.step()

                stats.update({
                    'G/loss': errG_total.data,
                    'G/im_KL': im_kl_loss.data,
                    'G/st_KL': st_kl_loss.data,
                    'G/KL': kl_loss.data,
                    'G/consistency': G_consistency,
                    'Accuracy/im_G': im_accG,
                    'Accuracy/se_G': se_accG,
                    'Accuracy/st_G': st_accG,
                    'G/gan_loss': errG.data,
                })
                pbar.update(1)

                if i % 20 == 0:
                    step = i + num_step * epoch
                    for key, value in stats.items():
                        self._logger.add_scalar(key, value, step)

                    # Reset per preview so a stale or missing segment image
                    # is never logged.
                    se_result = None
                    with torch.no_grad():
                        _, fake, _, _, _, _, se_fake = netG.sample_videos(
                            st_motion_input, st_content_input, seg=use_segment)
                        st_result = save_story_results(st_real_cpu, fake, st_texts,
                                                       epoch, self.image_dir, i)
                        if use_segment and se_fake is not None:
                            se_result = save_image_results(None, se_fake)

                    self._logger.add_image("pororo",
                                           st_result.transpose(2, 0, 1) / 255, epoch)
                    if se_result is not None:
                        self._logger.add_image("segment",
                                               se_result.transpose(2, 0, 1) / 255, epoch)

        # Adjust lr: halve, then double the interval until the next decay.
        # NOTE(review): ``se_optimizerD`` is never decayed -- confirm intended.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for optimizer_d in (st_optimizerD, im_optimizerD):
                for param_group in optimizer_d.param_groups:
                    param_group['lr'] = discriminator_lr
            lr_decay_step *= 2

        # Log the learning rate of the last parameter group of each optimizer.
        g_lr = optimizerG.param_groups[-1]['lr']
        st_lr = st_optimizerD.param_groups[-1]['lr']
        im_lr = im_optimizerD.param_groups[-1]['lr']
        self._logger.add_scalar('learning/generator', g_lr, epoch)
        self._logger.add_scalar('learning/st_discriminator', st_lr, epoch)
        self._logger.add_scalar('learning/im_discriminator', im_lr, epoch)

        if cfg.EVALUATE_FID_SCORE:
            self.calculate_vfid(netG, epoch, testloader)

        time_hours = int((time.time() - c_time) / 60) // 60
        # Split the epoch duration so minutes are the remainder after hours.
        epoch_hours, epoch_mins = divmod(int((time.time() - start_t) / 60), 60)
        print(
            "----[{}/{}]Epoch time:{} hours {} mins, Total time:{} hours----"
            .format(epoch, self.max_epoch, epoch_hours, epoch_mins, time_hours))

        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, netD_se, epoch, self.model_dir)

    # Always persist the final weights, tagged with the epoch budget.
    save_model(netG, netD_im, netD_st, netD_se, self.max_epoch, self.model_dir)
def train(self, imageloader, storyloader, testloader):
    """Train the generator against the image and story discriminators.

    Every batch drawn from ``storyloader`` is paired with an image batch from
    ``self.sample_real_image_batch()``.  Per iteration both discriminators
    are updated once and the generator twice.  Learning rates are halved
    every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs, preview stories are written
    every 100 iterations and checkpoints every ``self.snapshot_interval``
    epochs.

    Args:
        imageloader: Stored on ``self.imageloader``; presumably the source
            behind ``self.sample_real_image_batch()`` -- confirm there.
        storyloader: Sized iterable of story batches indexable by
            ``'images'``, ``'description'`` and ``'label'``.
        testloader: Stored on ``self.testloader`` and handed to
            ``save_test_samples`` at every snapshot.
    """
    self.imageloader = imageloader
    self.testloader = testloader
    self.imagedataset = None
    self.testdataset = None
    netG, netD_im, netD_st = self.load_networks()

    # Constant real/fake targets, one entry per sample of each batch type.
    im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
    im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
    st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
    st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
    if cfg.CUDA:
        im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
        st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    im_optimizerD = optim.Adam(netD_im.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    st_optimizerD = optim.Adam(netD_st.parameters(),
                               lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimise the generator parameters that are not frozen.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

    for epoch in range(self.max_epoch):
        start_t = time.time()
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for optimizer_d in (st_optimizerD, im_optimizerD):
                for param_group in optimizer_d.param_groups:
                    param_group['lr'] = discriminator_lr

        for i, data in enumerate(storyloader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            im_batch = self.sample_real_image_batch()
            st_batch = data

            im_real_cpu = im_batch['images']
            im_motion_input = im_batch['description']
            # Collapse dimension 1 of the content tensor to its mean.
            im_content_input = im_batch['content'].mean(1).squeeze()
            im_catelabel = im_batch['label']
            im_real_imgs = Variable(im_real_cpu)
            im_motion_input = Variable(im_motion_input)
            im_content_input = Variable(im_content_input)

            st_real_cpu = st_batch['images']
            # Stories use the same description tensor as motion and content input.
            st_motion_input = st_batch['description']
            st_content_input = st_batch['description']
            st_catelabel = st_batch['label']
            st_real_imgs = Variable(st_real_cpu)
            st_motion_input = Variable(st_motion_input)
            st_content_input = Variable(st_content_input)

            if cfg.CUDA:
                st_real_imgs = st_real_imgs.cuda()
                im_real_imgs = im_real_imgs.cuda()
                st_motion_input = st_motion_input.cuda()
                im_motion_input = im_motion_input.cuda()
                st_content_input = st_content_input.cuda()
                im_content_input = im_content_input.cuda()
                im_catelabel = im_catelabel.cuda()
                st_catelabel = st_catelabel.cuda()

            ######################################################
            # (2) Generate fake stories and images
            ######################################################
            # The discriminator step needs no generator gradients, so the
            # fakes are produced without recording an autograd graph.
            with torch.no_grad():
                _, im_fake, im_mu, im_logvar = netG.sample_images(
                    im_motion_input, im_content_input)
                _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                    st_motion_input, st_content_input)

            ############################
            # (3) Update D network
            ############################
            netD_im.zero_grad()
            netD_st.zero_grad()

            im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
                compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                           im_real_labels, im_fake_labels,
                                           im_catelabel, im_mu, self.gpus)
            st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
                compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                           st_real_labels, st_fake_labels,
                                           st_catelabel, c_mu, self.gpus)

            im_errD.backward()
            st_errD.backward()
            im_optimizerD.step()
            st_optimizerD.step()

            ############################
            # (4) Update G network (two generator steps per D step)
            ############################
            for _ in range(2):
                netG.zero_grad()
                _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                    st_motion_input, st_content_input)
                _, im_fake, im_mu, im_logvar = netG.sample_images(
                    im_motion_input, im_content_input)

                im_errG, accG = compute_generator_loss(
                    netD_im, im_fake, im_real_labels, im_catelabel, im_mu, self.gpus)
                st_errG, _ = compute_generator_loss(
                    netD_st, st_fake, st_real_labels, st_catelabel, c_mu, self.gpus)

                im_kl_loss = KL_loss(im_mu, im_logvar)
                st_kl_loss = KL_loss(m_mu, m_logvar)
                kl_loss = im_kl_loss + self.ratio * st_kl_loss

                errG_total = im_errG + self.ratio * st_errG + kl_loss
                errG_total.backward()
                optimizerG.step()

            if i % 100 == 0:
                # Preview only: no graph is needed for saving images.
                with torch.no_grad():
                    lr_fake, fake, _, _, _, _ = netG.sample_videos(
                        st_motion_input, st_content_input)
                save_story_results(st_real_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_story_results(None, lr_fake, epoch, self.image_dir)

        end_t = time.time()
        # Reports the values left over from the last iteration of the epoch.
        print(
            '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f '
            'Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f '
            'accG: %.4f accD: %.4f Total Time: %.2fsec '
            % (epoch, self.max_epoch, i, len(storyloader),
               st_errD.data, st_errG.data,
               st_errD_real, st_errD_wrong, st_errD_fake,
               accG, accD, (end_t - start_t)))

        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD_im, netD_st, epoch, self.model_dir)
            save_test_samples(netG, self.testloader, self.test_dir)

    # Always persist the final weights, tagged with the epoch budget.
    save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
def _save_story_results(self, st_real_cpu, lr_fake, st_fake, num, output_dir):
    """Save generated stories through ``save_story_results`` with ``test=True``.

    The generated story ``st_fake`` is always saved together with
    ``st_real_cpu``.  When ``lr_fake`` is not ``None`` it is saved in a
    second call, with ``None`` in place of the real images.  ``num`` and
    ``output_dir`` are forwarded unchanged to every call.
    """
    pending = [(st_real_cpu, st_fake)]
    if lr_fake is not None:
        pending.append((None, lr_fake))
    for real_images, generated in pending:
        save_story_results(real_images, generated, num, output_dir, test=True)