def test_loop(self):

        # mse, ssim, and psnr are not available at current settings
        test_results = {'D_G_z': 0, 'n_samples': 0}

        # the naive baseline reuses the sample count accumulated in test_results
        naive_results = {'D_G_z': 0}

        with torch.no_grad():
            self.generator.eval()
            self.discriminator.eval()
            test_images = []
            for idx, (lr_image,
                      naive_hr_image) in enumerate(tqdm(self.test_loader)):
                if idx >= self.args.n_save:
                    break
                cur_batch_size = lr_image.size(0)
                test_results['n_samples'] += cur_batch_size

                # put data to GPU
                if torch.cuda.is_available():
                    lr_image = lr_image.cuda()
                    naive_hr_image = naive_hr_image.cuda()

                sr_image = self.generator(lr_image)
                sr_probs, log_sr_probs = self.discriminator(sr_image)
                test_results['D_G_z'] += sr_probs.data.cpu().sum()

                naive_sr_probs, naive_log_sr_probs = self.discriminator(
                    naive_hr_image)
                naive_results['D_G_z'] += naive_sr_probs.data.cpu().sum()

                lr_image = create_new_lr_image(lr_image, sr_image)
                for image_idx in range(cur_batch_size):
                    test_images.extend([
                        display_transform()(lr_image[image_idx].data.cpu()),
                        display_transform()(
                            naive_hr_image[image_idx].data.cpu()),
                        display_transform()(sr_image[image_idx].data.cpu())
                    ])

                # hard cap on the number of test batches, independent of args.n_save
                if idx == 10:
                    break

            test_results['D_G_z'] /= test_results['n_samples']
            naive_results['D_G_z'] /= test_results['n_samples']

            # write to out file
            result_line = '\tTest\n'
            for k, v in test_results.items():
                result_line += '{} = {}, '.format(k, v)

            result_line += '\n'
            for k, v in naive_results.items():
                result_line += 'naive_{} = {} '.format(k, v)
            print(result_line)
            self.out.write(result_line + '\n')
            self.save_image(test_images)
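The display_transform helper used throughout these snippets is not defined here. In SRGAN-style codebases it is usually a small torchvision pipeline along the lines of the sketch below; the output size is an assumption, not taken from this project.

# Sketch of the assumed display_transform helper (torchvision); the 400-pixel
# resize/crop is a placeholder for whatever size this project actually uses.
from torchvision.transforms import Compose, ToPILImage, Resize, CenterCrop, ToTensor

def display_transform():
    return Compose([
        ToPILImage(),    # CPU tensor (C, H, W) in [0, 1] -> PIL image
        Resize(400),     # assumed display size
        CenterCrop(400),
        ToTensor(),      # back to a tensor so the images can be stacked into a grid
    ])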
Example #2
	def validate(self, epoch):
		with torch.no_grad():
			self.model.eval()
			val_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'val_size': 0}

			if not self.naive_results_computed:
				self.naive_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'val_size': 0}

			val_images = []
			for idx, (lr_image, naive_hr_image, hr_image) in enumerate(tqdm(self.val_loader)):
				# put data to GPU

				cur_batch_size = lr_image.size(0)
				val_results['val_size'] += cur_batch_size

				if torch.cuda.is_available():
					lr_image = lr_image.cuda()
					naive_hr_image = naive_hr_image.cuda()
					hr_image = hr_image.cuda()

				sr_image = self.model(lr_image)

				batch_mse = ((sr_image - hr_image) ** 2).data.mean()
				val_results['mse'] += batch_mse * cur_batch_size
				batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
				val_results['ssims'] += batch_ssim * cur_batch_size
				val_results['psnr'] = 10 * math.log10(1 / (val_results['mse'] / val_results['val_size']))
				val_results['ssim'] = val_results['ssims'] / val_results['val_size']

				if not self.naive_results_computed:
					naive_batch_mse = ((naive_hr_image - hr_image) ** 2).data.mean()
					self.naive_results['mse'] += naive_batch_mse * cur_batch_size
					naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
					self.naive_results['ssims'] += naive_batch_ssim * cur_batch_size
					self.naive_results['psnr'] = 10 * math.log10(1 / (self.naive_results['mse'] / val_results['val_size']))
					self.naive_results['ssim'] = self.naive_results['ssims'] / val_results['val_size']

				lr_image = create_new_lr_image(lr_image, hr_image)

				# only save a certain number of images;
				# the display transform does not support batch processing
				if idx < self.args.n_save:
					for image_idx in range(cur_batch_size):
						val_images.extend(
							[display_transform()(lr_image[image_idx].data.cpu()),
							 display_transform()(naive_hr_image[image_idx].data.cpu()),
							 display_transform()(hr_image[image_idx].data.cpu()),
							 display_transform()(sr_image[image_idx].data.cpu())])

			# write to out file
			result_line = '\tVal\t'
			for k, v in val_results.items():
				result_line += '{} = {} '.format(k, v)

			if not self.naive_results_computed:
				result_line += '\n'
				for k, v in self.naive_results.items():
					result_line += 'naive_{} = {} '.format(k, v)
				self.naive_results_computed = True

			print(result_line)
			self.out.write(result_line+'\n')
			self.out.flush()
			# save model
			torch.save(self.model.state_dict(), os.path.join(self.model_dir, str(epoch)+'.pth'))

			self.save_image(val_images, epoch)
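The running PSNR in validate follows 10 * log10(1 / MSE), which assumes pixel values normalized to [0, 1] (PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1). A minimal standalone version of the same sample-weighted accumulation, with hypothetical names, looks like this:

import math

def running_psnr(sr_batches, hr_batches):
    # Accumulate squared error weighted by batch size, then convert the mean
    # to PSNR assuming a [0, 1] pixel range.
    total_se, n_samples = 0.0, 0
    for sr, hr in zip(sr_batches, hr_batches):
        total_se += ((sr - hr) ** 2).mean().item() * sr.size(0)
        n_samples += sr.size(0)
    mse = total_se / n_samples
    return 10 * math.log10(1.0 / mse)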
Example #3
	def gan_validate(self, epoch):
		with torch.no_grad():
			self.generator.eval()
			self.discriminator.eval()
			val_results = {'mse_loss': 0, 'D_G_z':0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0}

			if not self.naive_results_computed:
				self.naive_results = {'mse_loss': 0, 'D_G_z':0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0}

			# TODO: to finish
			val_images = []
			for idx, (lr_image, naive_hr_image, hr_image) in enumerate(tqdm(self.val_loader)):
				# put data to GPU

				cur_batch_size = lr_image.size(0)
				val_results['n_samples'] += cur_batch_size

				if torch.cuda.is_available():
					lr_image = lr_image.cuda()
					naive_hr_image = naive_hr_image.cuda()
					hr_image = hr_image.cuda()

				sr_image = self.generator(lr_image)
				sr_probs, log_sr_probs = self.discriminator(sr_image)
				val_results['D_G_z'] += sr_probs.data.cpu().sum()

				mse_loss = self.mse_loss(input=sr_image, target=hr_image)
				val_results['mse_loss'] += mse_loss.data.cpu() * cur_batch_size

				batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
				val_results['ssims'] += batch_ssim * cur_batch_size
				val_results['psnr'] = 10 * math.log10(1 / (val_results['mse_loss'] / val_results['n_samples']))
				val_results['ssim'] = val_results['ssims'] / val_results['n_samples']

				# discriminator score on the naively upsampled image is tracked every epoch
				naive_sr_probs, naive_log_sr_probs = self.discriminator(naive_hr_image)
				self.naive_results['D_G_z'] += naive_sr_probs.data.cpu().sum()

				if not self.naive_results_computed:
					naive_mse_loss = self.mse_loss(input=naive_hr_image, target=hr_image).data.cpu()
					self.naive_results['mse_loss'] += naive_mse_loss * cur_batch_size
					naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
					self.naive_results['ssims'] += naive_batch_ssim * cur_batch_size
					self.naive_results['psnr'] = 10 * math.log10(1 / (self.naive_results['mse_loss'] / val_results['n_samples']))
					self.naive_results['ssim'] = self.naive_results['ssims'] / val_results['n_samples']

				# only save certain number of images

				# transform does not support batch processing
				lr_image = create_new_lr_image(lr_image, hr_image)
				if idx < self.args.n_save:
					for image_idx in range(cur_batch_size):
						val_images.extend(
							[display_transform()(lr_image[image_idx].data.cpu()),
							 display_transform()(naive_hr_image[image_idx].data.cpu()),
							 display_transform()(hr_image[image_idx].data.cpu()),
							 display_transform()(sr_image[image_idx].data.cpu())])

				# if idx == 5:
				# 	break

			val_results['D_G_z'] = val_results['D_G_z'] / val_results['n_samples']
			val_results['mse_loss'] = val_results['mse_loss'] / val_results['n_samples']
			# write to out file
			result_line = '\tVal\t'
			for k, v in val_results.items():
				result_line += '{} = {}, '.format(k, v)
				self.writer.add_scalar('val/{}'.format(k), v, epoch)

			if not self.naive_results_computed:
				result_line += '\n'
				self.naive_results['D_G_z'] = self.naive_results['D_G_z'] / val_results['n_samples']
				for k, v in self.naive_results.items():
					result_line += 'naive_{} = {} '.format(k, v)
				self.naive_results_computed = True
			else:
				result_line += '\n\t'
				self.naive_results['D_G_z'] = self.naive_results['D_G_z'] / val_results['n_samples']
				result_line += 'naive D_G_z = {}'.format(self.naive_results['D_G_z'])

			print(result_line)
			self.out.write(result_line+'\n')
			self.out.flush()

			# save model
			torch.save((self.generator.state_dict(), self.discriminator.state_dict()),
			           os.path.join(self.model_dir, str(epoch)+'.pth'))

			self.save_image(val_images, epoch)
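gan_validate saves the generator and discriminator state dicts together as a single tuple, so restoring a checkpoint means unpacking that tuple before calling load_state_dict. Roughly, with a placeholder path:

import torch

# The checkpoint written above holds a (generator_state_dict, discriminator_state_dict)
# tuple; the file name below is a placeholder.
generator_state, discriminator_state = torch.load('model_dir/10.pth', map_location='cpu')
# then, with whatever model classes this project defines:
#   generator.load_state_dict(generator_state)
#   discriminator.load_state_dict(discriminator_state)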
Example #4
	def test_loop(self):

		# track discriminator scores (D_x, D_G_z) together with mse, ssim, and psnr on the test set
		test_results = {'mse_loss': 0, 'D_G_z':0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'n_samples': 0, 'D_x': 0}

		# the naive baseline reuses the sample count accumulated in test_results
		naive_results = {'mse_loss': 0, 'D_G_z': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0}

		with torch.no_grad():
			self.generator.eval()
			self.discriminator.eval()
			test_images = []
			for idx, (lr_image, naive_hr_image, hr_image) in enumerate(tqdm(self.test_loader)):
				# if idx >= self.args.n_save:
				# 	break
				cur_batch_size = lr_image.size(0)
				test_results['n_samples'] += cur_batch_size

				if torch.cuda.is_available():
					lr_image = lr_image.cuda()
					naive_hr_image = naive_hr_image.cuda()
					hr_image = hr_image.cuda()
				hr_probs, log_hr_probs = self.discriminator(hr_image)

				test_results['D_x'] += hr_probs.data.cpu().sum()

				sr_image = self.generator(lr_image)
				sr_probs, log_sr_probs = self.discriminator(sr_image)
				test_results['D_G_z'] += sr_probs.data.cpu().sum()

				lr_image = create_new_lr_image(lr_image, hr_image)
				sr_image = create_new_lr_image(sr_image, hr_image)
				naive_hr_image = create_new_lr_image(naive_hr_image, hr_image)
				naive_sr_probs, naive_log_sr_probs = self.discriminator(naive_hr_image)
				naive_results['D_G_z'] += naive_sr_probs.data.cpu().sum()

				mse_loss = self.mse_loss(input=sr_image, target=hr_image)
				test_results['mse_loss'] += mse_loss.data.cpu() * cur_batch_size

				batch_ssim = pytorch_ssim.ssim(sr_image, hr_image).item()
				test_results['ssims'] += batch_ssim * cur_batch_size
				test_results['psnr'] = 10 * math.log10(1 / (test_results['mse_loss'] / test_results['n_samples']))
				test_results['ssim'] = test_results['ssims'] / test_results['n_samples']

				naive_mse_loss = self.mse_loss(input=naive_hr_image, target=hr_image).data.cpu()
				naive_results['mse_loss'] += naive_mse_loss * cur_batch_size
				naive_batch_ssim = pytorch_ssim.ssim(naive_hr_image, hr_image).item()
				naive_results['ssims'] += naive_batch_ssim * cur_batch_size
				naive_results['psnr'] = 10 * math.log10(1 / (naive_results['mse_loss'] / test_results['n_samples']))
				naive_results['ssim'] = naive_results['ssims'] / test_results['n_samples']

				for image_idx in range(cur_batch_size):
					test_images.extend(
						[display_transform()(lr_image[image_idx].data.cpu()),
						 display_transform()(naive_hr_image[image_idx].data.cpu()),
						 display_transform()(hr_image[image_idx].data.cpu()),
						 display_transform()(sr_image[image_idx].data.cpu())])

				# if idx == 1:
				# 	break

			test_results['D_G_z'] /= test_results['n_samples']
			test_results['D_x'] /= test_results['n_samples']
			naive_results['D_G_z'] /= test_results['n_samples']

			test_results['mse_loss'] /= test_results['n_samples']
			naive_results['mse_loss'] /= test_results['n_samples']

			# write to out file
			result_line = '\tTest\n'
			for k, v in test_results.items():
				result_line += '{} = {}, '.format(k, v)

			result_line += '\n'
			for k, v in naive_results.items():
				result_line += 'naive_{} = {} '.format(k, v)
			print(result_line)
			self.out.write(result_line+'\n')
			self.save_image_single(test_images)
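None of the snippets above show their imports. Based on what they call, a setup along these lines is assumed; create_new_lr_image, display_transform, save_image, and save_image_single are project-local helpers that are not shown here, and pytorch_ssim refers to the pytorch-ssim package.

import math
import os

import torch
from tqdm import tqdm

import pytorch_ssim  # pip install pytorch-ssim (assumed)

# project-local helpers referenced but not defined in these snippets:
# create_new_lr_image, display_transform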