def test_loop(self):
    """Score generator outputs with the discriminator on the test set.

    The test loader yields ``(lr_image, naive_hr_image)`` pairs without a
    ground-truth HR image, so MSE / SSIM / PSNR cannot be computed here; only
    D_G_z for the generator output and for the naive HR image is reported.
    Each D_G_z is the discriminator's summed output divided by the number of
    processed samples.  At most ``self.args.n_save`` batches are processed.

    Side effects: prints the result line, appends it to ``self.out`` and
    hands the collected (lr, naive_hr, sr) images to ``self.save_image``.
    """
    # mse, ssim, and psnr are not available at current settings
    test_results = {'D_G_z': 0, 'n_samples': 0}
    naive_results = {'D_G_z': 0, 'n_samples': 0}
    test_images = []
    with torch.no_grad():
        self.generator.eval()
        self.discriminator.eval()
        for idx, (lr_image, naive_hr_image) in enumerate(tqdm(self.test_loader)):
            if idx >= self.args.n_save:
                break
            cur_batch_size = lr_image.size(0)
            test_results['n_samples'] += cur_batch_size
            # put to GPU
            if torch.cuda.is_available():
                lr_image = lr_image.cuda()
                naive_hr_image = naive_hr_image.cuda()

            sr_image = self.generator(lr_image)
            sr_probs, _ = self.discriminator(sr_image)
            # .item() keeps plain Python floats in the results instead of tensors.
            test_results['D_G_z'] += sr_probs.sum().item()
            naive_sr_probs, _ = self.discriminator(naive_hr_image)
            naive_results['D_G_z'] += naive_sr_probs.sum().item()

            # transform does not support batch processing, so go image by image
            lr_image = create_new_lr_image(lr_image, sr_image)
            for image_idx in range(cur_batch_size):
                test_images.extend([
                    display_transform()(lr_image[image_idx].data.cpu()),
                    display_transform()(naive_hr_image[image_idx].data.cpu()),
                    display_transform()(sr_image[image_idx].data.cpu()),
                ])

    n_samples = test_results['n_samples']
    # Both result sets are computed over exactly the same samples.
    naive_results['n_samples'] = n_samples
    if n_samples:  # empty loader / n_save == 0 would otherwise divide by zero
        test_results['D_G_z'] /= n_samples
        naive_results['D_G_z'] /= n_samples

    # write to out file
    result_line = '\tTest\n'
    for k, v in test_results.items():
        result_line += '{} = {}, '.format(k, v)
    result_line += '\n'
    for k, v in naive_results.items():
        result_line += 'naive_{} = {} '.format(k, v)
    print(result_line)
    self.out.write(result_line + '\n')
    self.out.flush()
    self.save_image(test_images)
def validate(self, epoch):
    """Evaluate ``self.model`` on the validation set, log the metrics and checkpoint.

    MSE and SSIM of the model output against ``hr_image`` are accumulated
    weighted by batch size ('mse' / 'ssims' hold these weighted sums).
    'ssim' is the per-sample average SSIM, and 'psnr' is derived from the
    per-sample average MSE as ``10 * log10(1 / mse)``. That formula
    presumes a peak pixel value of 1 (TODO confirm the data range).
    The same metrics for the naive HR image are computed only on the first
    call and then cached in ``self.naive_results``.

    Side effects: prints the result line and appends it to ``self.out``,
    saves ``self.model.state_dict()`` to ``<self.model_dir>/<epoch>.pth`` and
    hands the images of the first ``self.args.n_save`` batches to
    ``self.save_image``.
    """
    def _psnr(mean_mse):
        # A perfect reconstruction (mse == 0) yields +inf instead of raising.
        return float('inf') if mean_mse == 0 else 10 * math.log10(1 / mean_mse)

    compute_naive = not self.naive_results_computed
    val_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'val_size': 0}
    if compute_naive:
        self.naive_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'val_size': 0}
    val_images = []
    with torch.no_grad():
        self.model.eval()
        for idx, (lr_image, naive_hr_image, hr_image) in enumerate(tqdm(self.val_loader)):
            cur_batch_size = lr_image.size(0)
            val_results['val_size'] += cur_batch_size
            # put data to GPU
            if torch.cuda.is_available():
                lr_image = lr_image.cuda()
                naive_hr_image = naive_hr_image.cuda()
                hr_image = hr_image.cuda()

            sr_image = self.model(lr_image)
            # .item() accumulates plain floats rather than (GPU) tensors.
            batch_mse = ((sr_image - hr_image) ** 2).mean().item()
            val_results['mse'] += batch_mse * cur_batch_size
            batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
            val_results['ssims'] += batch_ssim * cur_batch_size

            if compute_naive:
                naive_batch_mse = ((naive_hr_image - hr_image) ** 2).mean().item()
                self.naive_results['mse'] += naive_batch_mse * cur_batch_size
                naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
                self.naive_results['ssims'] += naive_batch_ssim * cur_batch_size

            # only save certain number of images;
            # transform does not support batch processing
            if idx < self.args.n_save:
                lr_image = create_new_lr_image(lr_image, hr_image)
                for image_idx in range(cur_batch_size):
                    val_images.extend([
                        display_transform()(lr_image[image_idx].data.cpu()),
                        display_transform()(naive_hr_image[image_idx].data.cpu()),
                        display_transform()(hr_image[image_idx].data.cpu()),
                        display_transform()(sr_image[image_idx].data.cpu()),
                    ])

    # Derived metrics only depend on the final sums, so compute them once.
    n_samples = val_results['val_size']
    if n_samples:
        val_results['psnr'] = _psnr(val_results['mse'] / n_samples)
        val_results['ssim'] = val_results['ssims'] / n_samples
    if compute_naive:
        self.naive_results['val_size'] = n_samples
        if n_samples:
            self.naive_results['psnr'] = _psnr(self.naive_results['mse'] / n_samples)
            self.naive_results['ssim'] = self.naive_results['ssims'] / n_samples

    # write to out file
    result_line = '\tVal\t'
    for k, v in val_results.items():
        result_line += '{} = {} '.format(k, v)
    if compute_naive:
        result_line += '\n'
        for k, v in self.naive_results.items():
            result_line += 'naive_{} = {} '.format(k, v)
        self.naive_results_computed = True
    print(result_line)
    self.out.write(result_line + '\n')
    self.out.flush()
    # save model
    torch.save(self.model.state_dict(), os.path.join(self.model_dir, str(epoch) + '.pth'))
    self.save_image(val_images, epoch)
def gan_validate(self, epoch):
    """Evaluate the GAN on the validation set, log/plot the metrics and checkpoint.

    For the generator output, reports per-sample averages of the MSE loss
    (``self.mse_loss``) against ``hr_image``, SSIM, and D_G_z. D_G_z is the
    discriminator's summed output divided by the number of samples. PSNR is
    derived from the averaged MSE as ``10 * log10(1 / mse)``, which presumes
    a peak pixel value of 1 (TODO confirm).

    For the naive HR image, MSE/SSIM/PSNR are computed only on the first
    call and cached in ``self.naive_results``. Its D_G_z depends on the
    current discriminator, so it is recomputed from zero on every call.

    Side effects:
    - writes every ``val_results`` entry to ``self.writer`` as ``val/<key>``;
    - prints the result line and appends it to ``self.out``;
    - saves ``(generator_state, discriminator_state)`` to
      ``<self.model_dir>/<epoch>.pth``;
    - hands the images of the first ``self.args.n_save`` batches to
      ``self.save_image``.
    """
    def _psnr(mean_mse):
        # A perfect reconstruction (mse == 0) yields +inf instead of raising.
        return float('inf') if mean_mse == 0 else 10 * math.log10(1 / mean_mse)

    compute_naive = not self.naive_results_computed
    val_results = {'mse_loss': 0, 'D_G_z': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0}
    if compute_naive:
        self.naive_results = {'mse_loss': 0, 'D_G_z': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0}
    # Per-call accumulator: adding onto the cached self.naive_results['D_G_z']
    # would mix in the previous epoch's (already normalized) value.
    naive_d_g_z = 0
    val_images = []
    with torch.no_grad():
        self.generator.eval()
        self.discriminator.eval()
        for idx, (lr_image, naive_hr_image, hr_image) in enumerate(tqdm(self.val_loader)):
            cur_batch_size = lr_image.size(0)
            val_results['n_samples'] += cur_batch_size
            # put data to GPU
            if torch.cuda.is_available():
                lr_image = lr_image.cuda()
                naive_hr_image = naive_hr_image.cuda()
                hr_image = hr_image.cuda()

            sr_image = self.generator(lr_image)
            sr_probs, _ = self.discriminator(sr_image)
            val_results['D_G_z'] += sr_probs.sum().item()
            mse_loss = self.mse_loss(input=sr_image, target=hr_image).item()
            val_results['mse_loss'] += mse_loss * cur_batch_size
            batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
            val_results['ssims'] += batch_ssim * cur_batch_size

            naive_sr_probs, _ = self.discriminator(naive_hr_image)
            naive_d_g_z += naive_sr_probs.sum().item()
            if compute_naive:
                naive_mse_loss = self.mse_loss(input=naive_hr_image, target=hr_image).item()
                self.naive_results['mse_loss'] += naive_mse_loss * cur_batch_size
                naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
                self.naive_results['ssims'] += naive_batch_ssim * cur_batch_size

            # only save certain number of images;
            # transform does not support batch processing
            if idx < self.args.n_save:
                lr_image = create_new_lr_image(lr_image, hr_image)
                for image_idx in range(cur_batch_size):
                    val_images.extend([
                        display_transform()(lr_image[image_idx].data.cpu()),
                        display_transform()(naive_hr_image[image_idx].data.cpu()),
                        display_transform()(hr_image[image_idx].data.cpu()),
                        display_transform()(sr_image[image_idx].data.cpu()),
                    ])

    # Normalize every sum exactly once.
    n_samples = val_results['n_samples']
    if n_samples:
        val_results['D_G_z'] /= n_samples
        val_results['mse_loss'] /= n_samples
        val_results['psnr'] = _psnr(val_results['mse_loss'])
        val_results['ssim'] = val_results['ssims'] / n_samples
        naive_d_g_z /= n_samples
    self.naive_results['D_G_z'] = naive_d_g_z
    if compute_naive:
        self.naive_results['n_samples'] = n_samples
        if n_samples:
            # Average like the model's mse_loss so both are comparable.
            self.naive_results['mse_loss'] /= n_samples
            self.naive_results['psnr'] = _psnr(self.naive_results['mse_loss'])
            self.naive_results['ssim'] = self.naive_results['ssims'] / n_samples

    # write to out file
    result_line = '\tVal\t'
    for k, v in val_results.items():
        result_line += '{} = {}, '.format(k, v)
        self.writer.add_scalar('val/{}'.format(k), v, epoch)
    if compute_naive:
        result_line += '\n'
        for k, v in self.naive_results.items():
            result_line += 'naive_{} = {} '.format(k, v)
        self.naive_results_computed = True
    else:
        result_line += '\n\t'
        result_line += 'naive D_G_z = {}'.format(self.naive_results['D_G_z'])
    print(result_line)
    self.out.write(result_line + '\n')
    self.out.flush()
    # save model
    torch.save((self.generator.state_dict(), self.discriminator.state_dict()),
               os.path.join(self.model_dir, str(epoch) + '.pth'))
    self.save_image(val_images, epoch)
def test_loop(self):
    """Evaluate the GAN on a test set that provides ground-truth HR images.

    For each batch:
    - D_x is scored on ``hr_image`` and D_G_z on the raw generator output.
    - ``lr_image``, ``sr_image`` and ``naive_hr_image`` are then passed
      through ``create_new_lr_image(..., hr_image)``.
    - The naive D_G_z, the MSE loss (``self.mse_loss``) and the SSIM are
      computed on those converted images.

    Sums are weighted by batch size and divided by the number of samples.
    PSNR is derived from the averaged MSE as ``10 * log10(1 / mse)``, which
    presumes a peak pixel value of 1 (TODO confirm).

    Side effects: prints the result line, appends it to ``self.out`` and
    hands every (lr, naive_hr, hr, sr) image to ``self.save_image_single``.
    """
    def _psnr(mean_mse):
        # A perfect reconstruction (mse == 0) yields +inf instead of raising.
        return float('inf') if mean_mse == 0 else 10 * math.log10(1 / mean_mse)

    test_results = {'mse_loss': 0, 'D_G_z': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0, 'D_x': 0}
    naive_results = {'mse_loss': 0, 'D_G_z': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0}
    test_images = []
    with torch.no_grad():
        self.generator.eval()
        self.discriminator.eval()
        for lr_image, naive_hr_image, hr_image in tqdm(self.test_loader):
            cur_batch_size = lr_image.size(0)
            test_results['n_samples'] += cur_batch_size
            # put data to GPU
            if torch.cuda.is_available():
                lr_image = lr_image.cuda()
                naive_hr_image = naive_hr_image.cuda()
                hr_image = hr_image.cuda()

            hr_probs, _ = self.discriminator(hr_image)
            test_results['D_x'] += hr_probs.sum().item()
            sr_image = self.generator(lr_image)
            sr_probs, _ = self.discriminator(sr_image)
            test_results['D_G_z'] += sr_probs.sum().item()

            # Everything below (naive D_G_z, MSE, SSIM, saved images) uses
            # these create_new_lr_image outputs.
            lr_image = create_new_lr_image(lr_image, hr_image)
            sr_image = create_new_lr_image(sr_image, hr_image)
            naive_hr_image = create_new_lr_image(naive_hr_image, hr_image)

            naive_sr_probs, _ = self.discriminator(naive_hr_image)
            naive_results['D_G_z'] += naive_sr_probs.sum().item()

            mse_loss = self.mse_loss(input=sr_image, target=hr_image).item()
            test_results['mse_loss'] += mse_loss * cur_batch_size
            batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
            test_results['ssims'] += batch_ssim * cur_batch_size

            naive_mse_loss = self.mse_loss(input=naive_hr_image, target=hr_image).item()
            naive_results['mse_loss'] += naive_mse_loss * cur_batch_size
            naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
            naive_results['ssims'] += naive_batch_ssim * cur_batch_size

            # transform does not support batch processing
            for image_idx in range(cur_batch_size):
                test_images.extend([
                    display_transform()(lr_image[image_idx].data.cpu()),
                    display_transform()(naive_hr_image[image_idx].data.cpu()),
                    display_transform()(hr_image[image_idx].data.cpu()),
                    display_transform()(sr_image[image_idx].data.cpu()),
                ])

    n_samples = test_results['n_samples']
    # Both result sets are computed over exactly the same samples.
    naive_results['n_samples'] = n_samples
    if n_samples:  # an empty loader would otherwise divide by zero
        for results in (test_results, naive_results):
            results['D_G_z'] /= n_samples
            results['mse_loss'] /= n_samples
            results['psnr'] = _psnr(results['mse_loss'])
            results['ssim'] = results['ssims'] / n_samples
        test_results['D_x'] /= n_samples

    # write to out file
    result_line = '\tTest\n'
    for k, v in test_results.items():
        result_line += '{} = {}, '.format(k, v)
    result_line += '\n'
    for k, v in naive_results.items():
        result_line += 'naive_{} = {} '.format(k, v)
    print(result_line)
    self.out.write(result_line + '\n')
    self.out.flush()
    self.save_image_single(test_images)