def train(self):
    """Train the flow-based VAE.

    Runs the optimization loop over ``self.trainloader``; every 20 iterations
    logs the individual loss terms, and every 2000 iterations draws samples
    from the prior and writes a checkpoint to ``self.sampledir``.

    Assumes at least one CUDA device is available (everything is moved with
    ``.cuda()``) — TODO confirm CPU-only runs are out of scope.
    """
    opt = self.opt
    gpu_ids = range(torch.cuda.device_count())
    print('Number of GPUs in use {}'.format(gpu_ids))
    iteration = 0

    vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        vae = nn.DataParallel(vae).cuda()

    # BUG FIX: the original read `vae.module.floww` unconditionally, which
    # raises AttributeError on a single GPU (no DataParallel wrapper).
    flow_warper = vae.module.floww if multi_gpu else vae.floww
    objective_func = losses.losses_multigpu_only_mask(opt, flow_warper)

    print(self.jobname)

    optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)

    if self.load:
        model_name = self.sampledir + '/{:06d}_model.pth.tar'.format(
            self.iter_to_load)
        print("loading model from {}".format(model_name))
        state_dict = torch.load(model_name)
        # The checkpoint always stores the raw (unwrapped) module weights,
        # so unwrap through `.module` only when DataParallel is in use.
        if multi_gpu:
            vae.module.load_state_dict(state_dict['vae'])
        else:
            vae.load_state_dict(state_dict['vae'])
        optimizer.load_state_dict(state_dict['optimizer'])
        iteration = self.iter_to_load + 1

    for epoch in range(opt.num_epochs):
        print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1))
        print('-' * 10)

        for sample, mask in iter(self.trainloader):
            # get the inputs
            data = sample.cuda()
            mask = mask.cuda()
            frame1 = data[:, 0, :, :, :]    # first frame conditions the model
            frame2 = data[:, 1:, :, :, :]   # remaining frames are targets
            noise_bg = torch.randn(frame1.size()).cuda()

            start = time.time()

            # Set train mode
            vae.train()
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            (y_pred_before_refine, y_pred, mu, logvar, flow, flowback,
             mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature) = vae(
                frame1, data, mask, noise_bg)

            # Compute losses
            (flowloss, reconloss, reconloss_back, reconloss_before, kldloss,
             flowcon, sim_loss, vgg_loss, mask_loss) = objective_func(
                frame1, frame2, y_pred, mu, logvar, flow, flowback,
                mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
                y_pred_before_refine=y_pred_before_refine)

            loss = (flowloss + 2. * reconloss + reconloss_back
                    + reconloss_before + kldloss * self.opt.lamda + flowcon
                    + sim_loss + vgg_loss + 0.1 * mask_loss)

            # backward + update
            loss.backward()
            optimizer.step()
            end = time.time()

            # print statistics
            if iteration % 20 == 0:
                print(
                    "iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, recon_loss_before = {:.3f}, "
                    "flow_loss = {:.6f}, flow_consist = {:.3f}, "
                    "kl_loss = {:.6f}, img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}"
                    .format(iteration, epoch, reconloss.item(),
                            reconloss_back.item(), reconloss_before.item(),
                            flowloss.item(), flowcon.item(), kldloss.item(),
                            sim_loss.item(), vgg_loss.item(),
                            mask_loss.item(), end - start))

            if iteration % 2000 == 0:
                # Set to evaluation mode (randomly sample z from the whole
                # distribution) and draw qualitative samples.
                with torch.no_grad():
                    vae.eval()
                    # BUG FIX: `.next()` is Python-2 only; use builtin next().
                    val_sample, val_mask, _ = next(iter(self.testloader))

                    # Read data
                    data = val_sample.cuda()
                    mask = val_mask.cuda()
                    frame1 = data[:, 0, :, :, :]
                    noise_bg = torch.randn(frame1.size()).cuda()

                    (y_pred_before_refine, y_pred, mu, logvar, flow,
                     flowback, mask_fw, mask_bw) = vae(
                        frame1, data, mask, noise_bg)

                utils.save_samples(data, y_pred_before_refine, y_pred, flow,
                                   mask_fw, mask_bw, iteration,
                                   self.sampledir, opt, eval=True,
                                   useMask=True)

                # Save model's parameters.
                # BUG FIX: the two save branches were swapped — the
                # single-GPU path accessed `vae.module` (AttributeError)
                # and the multi-GPU path saved DataParallel-prefixed keys
                # that the loader above could not read back.
                checkpoint_path = self.sampledir + \
                    '/{:06d}_model.pth.tar'.format(iteration)
                print("model saved to {}".format(checkpoint_path))
                if multi_gpu:
                    torch.save(
                        {
                            'vae': vae.module.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, checkpoint_path)
                else:
                    torch.save(
                        {
                            'vae': vae.state_dict(),
                            'optimizer': optimizer.state_dict()
                        }, checkpoint_path)

            iteration += 1
def test(self):
    """Run single-pass inference over the test set.

    Loads pretrained weights, generates one batch of predictions per test
    sample with a fixed latent z, and saves sample grids, predicted images,
    forward/backward flows, and occlusion maps to the output directories.
    """
    opt = self.opt
    gpu_ids = range(torch.cuda.device_count())
    print('Number of GPUs in use {}'.format(gpu_ids))
    iteration = 0

    if torch.cuda.device_count() > 1:
        vae = nn.DataParallel(VAE(hallucination=self.useHallucination,
                                  opt=opt, refine=self.refine,
                                  bg=128, fg=896),
                              device_ids=gpu_ids).cuda()
    else:
        vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()

    print(self.jobname)

    if self.load:
        model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
        # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
        print("loading model from {}".format(model_name))
        state_dict = torch.load(model_name)
        if torch.cuda.device_count() > 1:
            vae.module.load_state_dict(state_dict['vae'])
        else:
            vae.load_state_dict(state_dict['vae'])

    # Fixed latent code reused for every batch so outputs are comparable.
    z_noise = torch.ones(1, 1024).normal_()

    # Set to evaluation mode once (hoisted out of the loop) and disable
    # autograd for the whole inference pass.
    vae.eval()
    with torch.no_grad():
        for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)):
            # If test on generated images
            # data = data.unsqueeze(1)
            # data = data.repeat(1, opt.num_frames, 1, 1, 1)

            frame1 = data[:, 0, :, :, :]
            noise_bg = torch.randn(frame1.size())
            z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))

            # BUG FIX: this forward call was commented out by a stray '#'
            # in the source, leaving y_pred / flow / mask_* undefined for
            # all the save calls below.
            (y_pred_before_refine, y_pred, mu, logvar, flow, flowback,
             mask_fw, mask_bw) = vae(
                frame1, data, bg_mask, fg_mask, noise_bg, z_m)

            utils.save_samples(data, y_pred_before_refine, y_pred, flow,
                               mask_fw, mask_bw, iteration, self.sampledir,
                               opt, eval=True, useMask=True, grid=[4, 4])

            '''save images'''
            utils.save_images(self.output_image_dir, data, y_pred, paths,
                              opt)
            utils.save_images(self.output_image_dir_before, data,
                              y_pred_before_refine, paths, opt)

            data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy()
            utils.save_gif(
                data * 255, opt.num_frames, [4, 4],
                self.sampledir + '/{:06d}_real.gif'.format(iteration))

            '''save flows'''
            utils.save_flows(self.output_fw_flow_dir, flow, paths)
            utils.save_flows(self.output_bw_flow_dir, flowback, paths)

            '''save occlusion maps'''
            utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
            utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)

            iteration += 1
def test(self):
    """Run iterative (long-horizon) inference over the test set.

    After the initial forward pass, feeds the last predicted frame (and the
    warped fg/bg masks) back into the model 5 more times, concatenating the
    predictions, flows, and occlusion maps along their respective dims, then
    saves sample grids and predicted images.
    """
    opt = self.opt
    gpu_ids = range(torch.cuda.device_count())
    print('Number of GPUs in use {}'.format(gpu_ids))
    iteration = 0

    if torch.cuda.device_count() > 1:
        vae = nn.DataParallel(VAE(hallucination=self.useHallucination,
                                  opt=opt, refine=self.refine,
                                  bg=128, fg=896),
                              device_ids=gpu_ids).cuda()
    else:
        vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()

    print(self.jobname)

    if self.load:
        # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
        model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
        print("loading model from {}".format(model_name))
        state_dict = torch.load(model_name)
        if torch.cuda.device_count() > 1:
            vae.module.load_state_dict(state_dict['vae'])
        else:
            vae.load_state_dict(state_dict['vae'])

    # Fixed latent code reused for every batch so outputs are comparable.
    z_noise = torch.ones(1, 1024).normal_()

    # Set to evaluation mode once (hoisted out of the loop).
    vae.eval()
    # BUG FIX: without no_grad() the iterative feedback below keeps every
    # step's autograd graph alive, growing memory with each of the 5 passes.
    with torch.no_grad():
        for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)):
            # If test on generated images
            # data = data.unsqueeze(1)
            # data = data.repeat(1, opt.num_frames, 1, 1, 1)

            frame1 = data[:, 0, :, :, :]
            noise_bg = torch.randn(frame1.size())
            z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))

            (y_pred_before_refine, y_pred, flow, flowback, mask_fw,
             mask_bw, warped_mask_bg, warped_mask_fg) = vae(
                frame1, data, bg_mask, fg_mask, noise_bg, z_m)

            '''iterative generation'''
            for i in range(5):
                noise_bg = torch.randn(frame1.size())
                # Condition the next pass on the last predicted frame and
                # the masks warped forward by the previous pass.
                (y_pred_before_refine_1, y_pred_1, flow_1, flowback_1,
                 mask_fw_1, mask_bw_1, warped_mask_bg,
                 warped_mask_fg) = vae(
                    y_pred[:, -1, ...], y_pred, warped_mask_bg,
                    warped_mask_fg, noise_bg, z_m)

                # Concatenate along the time axis (dim 2 for the flows —
                # presumably (B, C, T, H, W) there; verify against VAE).
                y_pred_before_refine = torch.cat(
                    [y_pred_before_refine, y_pred_before_refine_1], 1)
                y_pred = torch.cat([y_pred, y_pred_1], 1)
                flow = torch.cat([flow, flow_1], 2)
                flowback = torch.cat([flowback, flowback_1], 2)
                mask_fw = torch.cat([mask_fw, mask_fw_1], 1)
                mask_bw = torch.cat([mask_bw, mask_bw_1], 1)

            print(y_pred_before_refine.size())

            utils.save_samples(data, y_pred_before_refine, y_pred, flow,
                               mask_fw, mask_bw, iteration, self.sampledir,
                               opt, eval=True, useMask=True, grid=[4, 4])

            # '''save images'''
            utils.save_images(self.output_image_dir, data, y_pred, paths,
                              opt)
            utils.save_images(self.output_image_dir_before, data,
                              y_pred_before_refine, paths, opt)

            iteration += 1