def test_ssim_reduction_and_full(reduction: str, full: bool, expectation: Any, prediction: torch.Tensor,
                                 target: torch.Tensor, device: str) -> None:
    """Call ``ssim`` with one ``reduction``/``full`` combination inside ``expectation``.

    ``expectation`` is entered as a context manager around the call, so it decides
    whether the call has to succeed or raise.
    """
    moved_prediction, moved_target = (tensor.to(device) for tensor in (prediction, target))
    call_options = dict(data_range=1., reduction=reduction, full=full)
    with expectation:
        ssim(moved_prediction, moved_target, **call_options)
def test_ssim_fails_for_incorrect_data_range(x: torch.Tensor, y: torch.Tensor, device: str) -> None:
    """Expect ``AssertionError`` when inputs multiplied by 255 are passed with ``data_range=1.0``."""
    # Multiply by 255 and store as uint8, so values no longer match data_range=1.0.
    scaled_first, scaled_second = ((tensor * 255).type(torch.uint8) for tensor in (x, y))
    with pytest.raises(AssertionError):
        ssim(scaled_first.to(device), scaled_second.to(device), data_range=1.0)
def test_ssim_simmular_to_matlab_implementation():
    """Compare ``ssim`` with fixed baseline scores for one greyscale and one RGB image pair."""

    def read_as_tensor(path):
        # Load an image file from disk and wrap the result in a tensor.
        return torch.tensor(imread(path))

    # Greyscale pair: two leading axes are added in front of the loaded image.
    grey_reference = read_as_tensor('tests/assets/goldhill.gif')[None, None, ...]
    grey_distorted = read_as_tensor('tests/assets/goldhill_jpeg.gif')[None, None, ...]
    score = ssim(grey_distorted, grey_reference, data_range=255, reduction='none')
    # Baseline attributed to http://www.cns.nyu.edu/~lcv/ssim/ssim.m
    score_baseline = torch.tensor(0.8202)
    assert torch.isclose(score, score_baseline, atol=1e-4), \
        f'Expected PyTorch score to be equal to MATLAB prediction. Got {score} and {score_baseline}'

    # RGB pair: the last axis is moved to the front, then one leading axis is added.
    rgb_reference = read_as_tensor('tests/assets/I01.BMP').permute(2, 0, 1)[None, ...]
    rgb_distorted = read_as_tensor('tests/assets/i01_01_5.bmp').permute(2, 0, 1)[None, ...]
    score = ssim(rgb_distorted, rgb_reference, data_range=255, reduction='none')
    # Baseline attributed to http://www.cns.nyu.edu/~lcv/ssim/ssim.m
    # NOTE(review): an alternative baseline of 0.7820 was left commented out next to this
    # value; presumably that is the raw MATLAB output - confirm. The tolerance below is
    # looser (1e-2) than for the greyscale pair.
    score_baseline = torch.tensor(0.7842)
    assert torch.isclose(score, score_baseline, atol=1e-2), \
        f'Expected PyTorch score to be equal to MATLAB prediction. Got {score} and {score_baseline}'
def test_ssim_symmetry(x_y_4d_5d, device: str) -> None:
    """Check that swapping the two ``ssim`` arguments does not change the result."""
    first, second = (x_y_4d_5d[index].to(device) for index in (0, 1))
    shared_options = dict(data_range=1., reduction='none')
    measure = ssim(first, second, **shared_options)
    reverse_measure = ssim(second, first, **shared_options)
    assert torch.allclose(measure, reverse_measure), f'Expect: SSIM(a, b) == SSIM(b, a), ' \
                                                     f'got {measure} != {reverse_measure}'
def test_ssim_reduction(x: torch.Tensor, y: torch.Tensor, device: str) -> None:
    """Accepted ``reduction`` values must run; unsupported ones must raise ``KeyError``."""
    accepted_modes = ('mean', 'sum', 'none')
    rejected_modes = (None, 'n', 2)

    for accepted in accepted_modes:
        ssim(x.to(device), y.to(device), reduction=accepted)

    for rejected in rejected_modes:
        with pytest.raises(KeyError):
            ssim(x.to(device), y.to(device), reduction=rejected)
def test_ssim_reduction(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """Accepted ``reduction`` values must run; unsupported ones must raise ``KeyError``."""
    accepted_modes = ('mean', 'sum', 'none')
    rejected_modes = (None, 'n', 2)

    for accepted in accepted_modes:
        ssim(prediction.to(device), target.to(device), reduction=accepted)

    for rejected in rejected_modes:
        with pytest.raises(KeyError):
            ssim(prediction.to(device), target.to(device), reduction=rejected)
def test_ssim_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Check that swapping prediction and target does not change the ``ssim`` result."""
    moved_prediction, moved_target = (prediction_target_4d_5d[index].to(device) for index in (0, 1))
    shared_options = dict(data_range=1., reduction='none')
    measure = ssim(moved_prediction, moved_target, **shared_options)
    reverse_measure = ssim(moved_target, moved_prediction, **shared_options)
    assert torch.allclose(measure, reverse_measure), f'Expect: SSIM(a, b) == SSIM(b, a), ' \
                                                     f'got {measure} != {reverse_measure}'
def test_ssim_raises_if_kernel_size_greater_than_image(x_y_4d_5d, device: str) -> None:
    """Expect ``ValueError`` when the third and fourth axes are one element shorter than the kernel."""
    kernel_size = 11
    # Keep only the first kernel_size - 1 entries along the third and fourth axes.
    too_small = slice(None, kernel_size - 1)
    cropped_first, cropped_second = (
        x_y_4d_5d[index].to(device)[:, :, too_small, too_small] for index in (0, 1)
    )
    with pytest.raises(ValueError):
        ssim(cropped_first, cropped_second, kernel_size=kernel_size)
def test_ssim_fails_for_incorrect_data_range(prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """Expect ``AssertionError`` when inputs multiplied by 255 are passed with ``data_range=1.0``."""
    # Multiply by 255 and store as uint8, so values no longer match data_range=1.0.
    scaled_prediction, scaled_target = (
        (tensor * 255).type(torch.uint8) for tensor in (prediction, target)
    )
    with pytest.raises(AssertionError):
        ssim(scaled_prediction.to(device), scaled_target.to(device), data_range=1.0)
def test_ssim_check_kernel_size_is_passed(x_y_4d_5d, device: str) -> None:
    """Odd kernel sizes in ``[0, 50)`` must run; even ones must raise ``AssertionError``."""
    first, second = (x_y_4d_5d[index].to(device) for index in (0, 1))
    for kernel_size in range(0, 50):
        if kernel_size % 2 == 0:
            with pytest.raises(AssertionError):
                ssim(first, second, kernel_size=kernel_size)
            continue
        ssim(first, second, kernel_size=kernel_size)
def test_ssim_check_kernel_size_is_passed(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                                          device: str) -> None:
    """Odd kernel sizes in ``[0, 50)`` must run; even ones must raise ``AssertionError``."""
    moved_prediction, moved_target = (prediction_target_4d_5d[index].to(device) for index in (0, 1))
    for kernel_size in range(0, 50):
        if kernel_size % 2 == 0:
            with pytest.raises(AssertionError):
                ssim(moved_prediction, moved_target, kernel_size=kernel_size)
            continue
        ssim(moved_prediction, moved_target, kernel_size=kernel_size)
def test_ssim_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Expect ``ValueError`` when the third and fourth axes are one element shorter than the kernel."""
    kernel_size = 11
    # Keep only the first kernel_size - 1 entries along the third and fourth axes.
    too_small = slice(None, kernel_size - 1)
    cropped_prediction, cropped_target = (
        prediction_target_4d_5d[index].to(device)[:, :, too_small, too_small] for index in (0, 1)
    )
    with pytest.raises(ValueError):
        ssim(cropped_prediction, cropped_target, kernel_size=kernel_size)
def test_ssim_check_available_dimensions() -> None:
    """Grow two random tensors one leading axis at a time and check which ranks ``ssim`` accepts.

    Ranks below 5 must run without any exception; rank 5 and above must raise
    ``AssertionError``.
    """
    first = torch.rand(256, 256)
    second = torch.rand(256, 256)
    for _ in range(10):
        if first.dim() >= 5:
            with pytest.raises(AssertionError):
                ssim(first, second)
        else:
            try:
                ssim(first, second)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        # Prepend one axis in place before the next round.
        first.unsqueeze_(0)
        second.unsqueeze_(0)
def test_ssim_supports_different_data_ranges(input_tensors: Tuple[torch.Tensor, torch.Tensor], data_range,
                                             device: str) -> None:
    """``ssim`` on inputs scaled to ``data_range`` must match ``ssim`` on the same inputs rescaled to 1."""
    first, second = input_tensors
    # Quantise to uint8 first so both calls see exactly the same underlying values.
    first_on_device = (first * data_range).type(torch.uint8).to(device)
    second_on_device = (second * data_range).type(torch.uint8).to(device)

    measure_scaled = ssim(first_on_device, second_on_device, data_range=data_range)
    divisor = float(data_range)
    measure = ssim(first_on_device / divisor, second_on_device / divisor, data_range=1.0)

    diff = torch.abs(measure_scaled - measure)
    assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'
def test_ssim_raises_if_tensors_have_different_shapes(x_y_4d_5d, device) -> None:
    """Try every candidate shape against ``y``: equal shapes must run, others raise ``AssertionError``."""
    y = x_y_4d_5d[1].to(device)
    # Candidate sizes per axis; a fifth axis is added when y is 5D.
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if y.dim() == 5:
        dims.append([2, 3])

    for size in itertools.product(*dims):
        candidate = torch.rand(size).to(y)
        if candidate.size() != y.size():
            with pytest.raises(AssertionError):
                ssim(candidate, y)
            continue
        try:
            ssim(candidate, y)
        except Exception as e:
            pytest.fail(f"Unexpected error occurred: {e}")
def eval(self, gt, pred):
    """Compute RMSE, PSNR, SSIM, MS-SSIM and two LPIPS scores between ``gt`` and ``pred``.

    Both inputs are converted to tensors, clamped to [0, 1], permuted with
    ``(0, 3, 1, 2)`` and moved to ``cuda:0`` before the ``piq`` / LPIPS metrics run.
    # assumes gt/pred are (batch, height, width, channels) arrays - TODO confirm
    Every metric is reduced with ``.item()``, which only works when each call yields a
    single value.

    :param gt: ground-truth images
    :param pred: predicted images, same layout as ``gt``
    :return: dict with keys ``rmse``, ``psnr``, ``ssim``, ``msssim``, ``lpipsVGG``, ``lpipsAlex``
    """
    with torch.no_grad():
        gt_tensor = torch.Tensor(gt).clamp(0, 1).permute(0, 3, 1, 2).to('cuda:0')
        pred_tensor = torch.Tensor(pred).clamp(0, 1).permute(0, 3, 1, 2).to('cuda:0')
        psnr_index = piq.psnr(pred_tensor, gt_tensor, data_range=1., reduction='none').item()

        _, _, h, w = gt_tensor.shape
        lpipsAlex = 0
        lpipsVGG = 0
        msssim_index = 0
        ssim_index = 0

        # The image is split into an n x n grid of tiles and tile scores are averaged.
        n = 1
        tile_h, tile_w = h // n, w // n
        for i in range(n):
            for j in range(n):
                rows = slice(tile_h * i, tile_h * (i + 1))
                cols = slice(tile_w * j, tile_w * (j + 1))
                pred_tile = pred_tensor[:, :, rows, cols]
                gt_tile = gt_tensor[:, :, rows, cols]

                ssim_index += piq.ssim(pred_tile, gt_tile, data_range=1., reduction='none').item()
                # Accumulate like the other metrics; plain assignment kept only the last
                # tile while the total was still divided by n * n below.
                msssim_index += piq.multi_scale_ssim(pred_tile, gt_tile, data_range=1.,
                                                     reduction='none').item()
                lpipsVGG += self.lpipsVGG(pred_tile, gt_tile).item()
                lpipsAlex += self.lpipsAlex(pred_tile, gt_tile).item()

        tile_count = n * n
        msssim_index /= tile_count
        ssim_index /= tile_count
        lpipsVGG /= tile_count
        lpipsAlex /= tile_count

    # RMSE is taken on the raw inputs, not on the clamped tensors.
    # (DISTS, piq.LPIPS and a relative-MSE variant were disabled in the original code.)
    rmse = ((gt - pred)**2).mean()**0.5
    return {
        'rmse': rmse,
        'psnr': psnr_index,
        'ssim': ssim_index,
        'msssim': msssim_index,
        'lpipsVGG': lpipsVGG,
        'lpipsAlex': lpipsAlex
    }
def update(self, output):
    """Accumulate the summed SSIM between a target and an upsampled prediction.

    ``output[0]`` is the prediction and ``output[1]`` the target. Both are clamped to be
    non-negative. The prediction is upsampled by 8 and divided by 64, copied into a
    zero buffer with the target's last two dimensions, and both maps are rescaled so
    their maximum is 255 before ``piq.ssim`` is evaluated.

    :param output: indexable holding ``(y_pred, y)``
    """
    y_pred = output[0]
    y = output[1]
    y_pred = torch.clamp_min(y_pred, min=0.0)
    y = torch.clamp_min(y, min=0.0)

    y_pred = F.interpolate(y_pred, scale_factor=8) / 64

    # Allocate the padding buffer on the target's device instead of the hard-coded
    # default CUDA device, so CPU runs and non-default GPUs do not mix devices.
    # NOTE(review): the buffer has a leading shape of (1, 1), so this presumably expects
    # a single one-channel map per call - confirm.
    pad_density_map_tensor = torch.zeros((1, 1, y.shape[2], y.shape[3]), device=y.device)
    pad_density_map_tensor[:, 0, :y_pred.shape[2], :y_pred.shape[3]] = y_pred
    y_pred = pad_density_map_tensor

    # Each map is normalised by its own maximum; an all-zero map divides by zero here.
    y = y / torch.max(y) * 255
    y_pred = y_pred / torch.max(y_pred) * 255

    ssim_metric = piq.ssim(y, y_pred, reduction="sum", data_range=255)

    # reduction="sum" yields a total over the batch, so the running count grows by the
    # batch size.
    self._sum += ssim_metric.item()
    self._num_examples += y.shape[0]
def val_iter(self, final=True):
    """Run one pass over ``self.loader_val`` and return the averaged PSNR and SSIM.

    The model is switched to eval mode and gradients are disabled. Running averages of
    PSNR, SSIM, L1 and the square root of the MSE (logged as ``L2``) are kept; when
    ``self.writer`` is set, all four are written under ``self.epoch``.

    :param final: selects the progress-bar caption ("Validation" vs. an epoch label)
    :return: tuple ``(psnr_average, ssim_average)``
    """
    with torch.no_grad():
        self.model.eval()
        progress = tqdm(self.loader_val)
        progress.set_description("Validation" if final else f"Epoch {self.epoch} val ")

        # Insertion order doubles as the order in which scalars are written below.
        meters = {tag: AverageMeter() for tag in ('PSNR', 'SSIM', 'L1', 'L2')}

        for hr, lr in progress:
            hr = hr.to(self.dtype).to(self.device)
            lr = lr.to(self.dtype).to(self.device)
            sr = self.model(lr).clamp(0, 1)

            l1_value = torch.nn.functional.l1_loss(sr, hr).item()
            l2_value = torch.sqrt(torch.nn.functional.mse_loss(sr, hr)).item()
            psnr_value = piq.psnr(hr, sr)
            ssim_value = piq.ssim(hr, sr)

            meters['L1'].update(l1_value)
            meters['L2'].update(l2_value)
            meters['PSNR'].update(psnr_value)
            meters['SSIM'].update(ssim_value)
            progress.set_postfix(PSNR=f'{meters["PSNR"].get():.2f}', SSIM=f'{meters["SSIM"].get():.4f}')

        if self.writer is not None:
            for tag, meter in meters.items():
                self.writer.add_scalar(tag, meter.get(), self.epoch)

        return meters['PSNR'].get(), meters['SSIM'].get()
def forward(self, reference_observations: torch.Tensor, generated_observations: torch.Tensor,
            range=1.0) -> torch.Tensor:
    '''
    Computes the ssim between the reference and the generated observations

    :param reference_observations: (bs, observations_count, channels, height, width) tensor with reference observations
    :param generated_observations: (bs, observations_count, channels, height, width) tensor with generated observations
    :param range: The maximum value used to represent each pixel
    :return: (bs, observations_count) tensor with ssim for each observation
    '''

    # Flattens observations and then folds the results
    observations_count = reference_observations.size(1)
    flattened_reference_observations = TensorFolder.flatten(reference_observations)
    flattened_generated_observations = TensorFolder.flatten(generated_observations)

    # Pass the pixel range by keyword: as a third positional argument it would land in
    # the kernel_size slot of piq.ssim rather than in data_range.
    # NOTE(review): assumes the imported ssim accepts a data_range keyword - confirm.
    flattened_ssim = ssim(flattened_generated_observations, flattened_reference_observations,
                          data_range=range, reduction="none")
    folded_ssim = TensorFolder.fold(flattened_ssim, observations_count)

    return folded_ssim
def test_ssim_measure_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                                              device: str) -> None:
    """``ssim`` between the two fixture tensors must never exceed 1."""
    # The fixture supplies two maximally different tensors: ones first, zeros second.
    all_ones, all_zeros = (ones_zeros_4d_5d[index].to(device) for index in (0, 1))
    measure = ssim(all_ones, all_zeros, data_range=1., reduction='none')
    assert torch.le(measure, 1).all(), f'SSIM must be <= 1, got {measure}'
def test_ssim_measure_is_one_for_equal_tensors(y: torch.Tensor, device: str) -> None:
    """``ssim`` of a tensor against its own clone must be (numerically) 1."""
    original = y.to(device)
    duplicate = original.clone()
    measure = ssim(duplicate, original, data_range=1., reduction='none')
    expected = torch.ones_like(measure)
    assert torch.allclose(measure, expected), f'If equal tensors are passed SSIM must be equal to 1 ' \
                                              f'(considering floating point error up to 1 * 10^-6), '\
                                              f'got {measure}'
def test_ssim_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device) -> None:
    """Try every candidate shape against the target: equal shapes must run, others raise ``AssertionError``."""
    target = prediction_target_4d_5d[1].to(device)
    # Candidate sizes per axis; a fifth axis is added when the target is 5D.
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if target.dim() == 5:
        dims.append([2, 3])

    for size in itertools.product(*dims):
        candidate = torch.rand(size).to(target)
        if candidate.size() != target.size():
            with pytest.raises(AssertionError):
                ssim(candidate, target)
            continue
        try:
            ssim(candidate, target)
        except Exception as e:
            pytest.fail(f"Unexpected error occurred: {e}")
def eval(self, gt, pred, imformat='BHWC', dtype='jax'):
    """Compute the metrics listed in ``self.metrics`` between ``gt`` and ``pred``.

    Inputs are converted to tensors on ``self.device``, permuted with ``(0, 3, 1, 2)``
    and clamped to [0, 1]. ``mse`` and ``psnr`` are always returned; ``ssim``,
    ``msssim``, ``lpipsVGG``, ``lpipsAlex`` and ``rmse`` only when named in
    ``self.metrics``.

    :param gt: ground-truth image(s)
    :param pred: predicted image(s), same layout as ``gt``
    :param imformat: ``'BHWC'`` for input used as-is, ``'HWC'`` to prepend one axis first
    :param dtype: when ``'jax'``, both inputs are first converted with ``np.array``
    :return: dict mapping metric name to value (``rmse`` stays a tensor, ``mse`` is a float)
    :raises ValueError: if ``imformat`` is neither ``'BHWC'`` nor ``'HWC'``
    """
    if dtype == 'jax':
        gt = np.array(gt)
        pred = np.array(pred)

    if imformat == 'HWC':
        gt, pred = gt[None, ...], pred[None, ...]
    elif imformat != 'BHWC':
        # Previously printed a message and called exit(0), which reported success to
        # the calling process; fail loudly instead.
        raise ValueError(f'Unknown image dimension format: {imformat!r}')

    with torch.no_grad():
        gt_tensor = torch.Tensor(gt).permute(0, 3, 1, 2).to(self.device)
        pred_tensor = torch.Tensor(pred).permute(0, 3, 1, 2).to(self.device)
        pred_tensor = torch.clamp(pred_tensor, 0, 1)
        gt_tensor = torch.clamp(gt_tensor, 0, 1)

        _, _, h, w = gt_tensor.shape
        lpipsAlex = 0
        lpipsVGG = 0
        msssim_index = 0
        ssim_index = 0

        # The image is split into an n x n grid of tiles and tile scores are averaged.
        n = 1
        tile_h, tile_w = h // n, w // n
        for i in range(n):
            for j in range(n):
                rows = slice(tile_h * i, tile_h * (i + 1))
                cols = slice(tile_w * j, tile_w * (j + 1))
                pred_tile = pred_tensor[:, :, rows, cols]
                gt_tile = gt_tensor[:, :, rows, cols]

                if 'ssim' in self.metrics:
                    ssim_index += piq.ssim(pred_tile, gt_tile, data_range=1., reduction='mean').item()
                if 'msssim' in self.metrics:
                    # Accumulate like the other metrics (was a plain assignment).
                    msssim_index += piq.multi_scale_ssim(pred_tile, gt_tile, data_range=1.,
                                                         reduction='mean').item()
                if 'lpipsVGG' in self.metrics:
                    lpipsVGG += self.lpipsVGG(pred_tile, gt_tile).item()
                if 'lpipsAlex' in self.metrics:
                    lpipsAlex += self.lpipsAlex(pred_tile, gt_tile).item()

        # Divide by the tile count; the former `x / n*n` parsed as `(x / n) * n`.
        tile_count = n * n
        res = {}
        if 'ssim' in self.metrics:
            res['ssim'] = ssim_index / tile_count
        if 'msssim' in self.metrics:
            res['msssim'] = msssim_index / tile_count
        if 'lpipsVGG' in self.metrics:
            res['lpipsVGG'] = lpipsVGG / tile_count
        if 'lpipsAlex' in self.metrics:
            res['lpipsAlex'] = lpipsAlex / tile_count

        mean_squared_error = ((gt_tensor - pred_tensor) ** 2).mean()
        if 'rmse' in self.metrics:
            res['rmse'] = mean_squared_error ** 0.5
        res['mse'] = float(mean_squared_error.cpu().numpy())
        # With inputs clamped to [0, 1] the peak value is 1, so PSNR is -10 * log10(mse).
        res['psnr'] = -10. * np.log10(res['mse']) / np.log10(10.)

    return res
def validation_epoch_end(self, outputs: List[Tuple[torch.Tensor, torch.Tensor]]) -> None:
    """Clear the dataset's ``val`` flag and log FID, SSIM and PSNR for the epoch.

    The scores are computed from ``self.forged_images`` and ``self.reference_images``;
    ``outputs`` is not read.
    """
    # The flag lives on the dataset itself when it is an ImageLoader, otherwise on the
    # dataset it wraps.
    is_plain_image_loader = isinstance(self.val_dataloader().dataset, ImageLoader)
    flag_owner = self.val_dataloader().dataset
    if not is_plain_image_loader:
        flag_owner = flag_owner.dataset
    flag_owner.val = False

    named_scores = (
        ('FID_score', fid(self.forged_images, self.reference_images,
                          self.hparams.feature_dimensionality_fid, self.device)),
        ('SSIM', ssim(self.forged_images, self.reference_images, data_range=255)),
        ('PSNR', psnr(self.forged_images, self.reference_images, data_range=255)),
    )
    for tag, value in named_scores:
        self.log(tag, value, on_step=False, on_epoch=True)
def image_metrics_from_dataset(dataset, output_addr='/tmp/psnr.csv'):
    """Write per-sample PSNR and SSIM for ``dataset`` to a CSV file and load it back.

    The dataset is iterated one sample at a time without shuffling; each item is
    unpacked as ``(x, y, file_name)``.

    :param dataset: dataset yielding ``(x, y, file_name)`` items
    :param output_addr: path of the CSV file to (over)write
    :return: the written file read back with ``pd.read_csv``
    """
    import csv

    cols = ['file_name', 'psnr', 'ssim']
    dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    # Keep one handle open for the whole run instead of reopening the file per row, and
    # let csv.writer quote fields so a comma in a file name cannot break the layout.
    with open(output_addr, 'w', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(cols)
        for x, y, file_name in tqdm(dl, total=len(dataset)):
            psnr = piq.psnr(x, y).item()
            ssim = piq.ssim(x, y).item()
            writer.writerow([file_name[0], str(psnr), str(ssim)])

    return pd.read_csv(output_addr)
def usage():
    """Small demonstration of ``piq.ssim`` and ``piq.SSIMLoss`` on random tensors.

    Computes an SSIM index, prints it, then evaluates the SSIM loss of a tensor against
    itself and back-propagates through it.
    """
    import numpy as np
    import torch
    from piq import ssim, SSIMLoss

    x = torch.rand(4, 256, 256, requires_grad=True)
    y = torch.rand(4, 256, 256)

    ssim_index: torch.Tensor = ssim(x, y, data_range=1.)
    ssim_value = ssim_index.detach().numpy()
    assert isinstance(ssim_value, np.ndarray), type(ssim_value)
    # The original print had no placeholder, so the value was never shown.
    print(f"ssim_index: {ssim_value}")

    loss = SSIMLoss(data_range=1.)
    output: torch.Tensor = loss(x, x)
    print(output.item())
    output.backward()
def test_ssim_raise_if_wrong_value_is_estimated(test_images: Tuple[torch.Tensor, torch.Tensor],
                                                device: str) -> None:
    """Compare ``ssim`` with ``tf.image.ssim`` on every pair in ``test_images``."""
    for x, y in test_images:
        piq_ssim = ssim(x.to(device), y.to(device), kernel_size=11, kernel_sigma=1.5, data_range=255,
                        reduction='none')

        # TensorFlow receives the same data with the second axis moved to the end.
        tf_x, tf_y = (
            tf.convert_to_tensor(image.permute(0, 2, 3, 1).numpy()) for image in (x, y)
        )
        with tf.device('/CPU'):
            reference_values = tf.image.ssim(tf_x, tf_y, max_val=255).numpy()
            tf_ssim = torch.tensor(reference_values).to(piq_ssim)

        match_accuracy = 2e-4 + 1e-8
        assert torch.allclose(piq_ssim, tf_ssim, rtol=0, atol=match_accuracy), \
            f'The estimated value must be equal to tensorflow provided one' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(piq_ssim - tf_ssim).abs()}'
def grid_search(x, y, rec_func, grid):
    """Grid search utility for tuning hyper-parameters.

    Every combination of the value lists in ``grid`` is passed to ``rec_func`` as
    keyword arguments together with ``y``. The reconstruction is scored against ``x``
    with a relative L2 error, PSNR and SSIM; the combination with the smallest L2 error
    is selected.

    :param x: reference signal the reconstructions are compared with
    :param y: measurements handed to ``rec_func`` as first argument
    :param rec_func: callable ``rec_func(y, **params)`` returning a reconstruction
    :param grid: dict mapping parameter name to a sequence of candidate values
    :return: tuple ``(grid_param, err_min, err, err_psnr, err_ssim)``; ``grid_param`` is
        ``None`` when no combination produced an error below infinity, and the three
        error tensors have one axis per entry of ``grid``
    """
    err_min = np.inf
    grid_param = None

    grid_shape = [len(val) for val in grid.values()]
    err = torch.zeros(grid_shape)
    err_psnr = torch.zeros(grid_shape)
    err_ssim = torch.zeros(grid_shape)

    for grid_val, nidx in zip(itertools.product(*grid.values()), np.ndindex(*grid_shape)):
        grid_param_cur = dict(zip(grid.keys(), grid_val))
        print(
            f"Current grid parameters ({list(nidx)} / {grid_shape}): {grid_param_cur}",
            flush=True,
        )

        x_rec = rec_func(y, **grid_param_cur)
        err[nidx], _ = l2_error(x_rec, x, relative=True, squared=False)

        # Slice both signals once per grid point instead of re-running rotate_real for
        # every metric argument.
        rec_slice = rotate_real(x_rec)[:, 0:1, ...]
        ref_slice = rotate_real(x)[:, 0:1, ...]
        ref_peak = ref_slice.max()

        err_psnr[nidx] = psnr(
            rec_slice,
            ref_slice,
            data_range=ref_peak,
            reduction="mean",
        )
        # NOTE(review): ssim is called with size_average while psnr uses reduction;
        # which keyword is right depends on the ssim implementation imported here.
        err_ssim[nidx] = ssim(
            rec_slice,
            ref_slice,
            data_range=ref_peak,
            size_average=True,
        )

        print("Rel. recovery error: {:1.2e}".format(err[nidx]), flush=True)
        print("PSNR: {:.2f}".format(err_psnr[nidx]), flush=True)
        print("SSIM: {:.2f}".format(err_ssim[nidx]), flush=True)

        if err[nidx] < err_min:
            grid_param = grid_param_cur
            err_min = err[nidx]

    return grid_param, err_min, err, err_psnr, err_ssim
def forward(self, predict, target):
    """Compute the enabled similarity metrics between ``predict`` and ``target``.

    Each metric is controlled by a flag on ``self`` (``l1_norm``, ``mse``, ``pearsonr``,
    ``cc``, ``psnr``, ``ssim``, ``mssim``). Only enabled metrics appear in the result;
    previously a disabled flag left its variable unbound and building the summary
    raised ``UnboundLocalError``.

    :param predict: predicted tensor
    :param target: reference tensor
    :return: dict mapping metric name to its value, in a fixed key order
    """
    metric_summary = {}

    if self.l1_norm:
        metric_summary['l1_norm'] = nn.functional.l1_loss(predict, target)
    if self.mse:
        metric_summary['mse'] = nn.functional.mse_loss(predict, target)
    if self.pearsonr:
        metric_summary['pearsonr_metric'] = audtorch.metrics.functional.pearsonr(predict, target).mean()
    if self.cc:
        metric_summary['cc'] = audtorch.metrics.functional.concordance_cc(predict, target).mean()
    if self.psnr:
        metric_summary['psnr'] = piq.psnr(predict, target, data_range=1., reduction='none').mean()
    if self.ssim:
        metric_summary['ssim'] = piq.ssim(predict, target, data_range=1.)
    if self.mssim:
        metric_summary['mssim'] = piq.multi_scale_ssim(predict, target, data_range=1.)

    return metric_summary
def main(args): input_shape = (3, 380, 380) if not os.path.exists(args.checkpoints_output): os.makedirs(args.checkpoints_output) if not os.path.exists(args.logs): os.makedirs(args.logs) images_output = os.path.join(args.logs, 'images') if not os.path.exists(images_output): os.makedirs(images_output) if not args.model in models: print(f"Model name {args.model} must be one of: {model_names}") return 1 print(f"Seting up training for model: {args.model}") print(f"Train X Root: {args.train_x_root}") print(f"Train Y Root: {args.train_y_root}") if args.test_x is not None and args.test_y is not None: print(f"Test X Root: {args.test_x}") print(f"Test Y Root: {args.test_y}") normalize_transform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) if args.test_x is None or args.test_y is None: dataset = EnumPairedDataset(args.train_x_root, args.train_y_root, transform=normalize_transform) train_d, test_d = train_val_dataset(dataset) else: train_d = EnumPairedDataset(args.train_x_root, args.train_y_root, transform=normalize_transform) test_d = EnumPairedDataset(args.test_x_root, args.test_y_root, transform=normalize_transform) train_batch_size = args.train_batch_size test_batch_size = args.test_batch_size train_dl = DataLoader(train_d, batch_size=train_batch_size, shuffle=True, num_workers=0) test_dl = DataLoader(test_d, batch_size=test_batch_size, shuffle=True, num_workers=0) if args.show_dataset: x_batch, y_batch, names = next(iter(train_dl)) plt.subplot(2, 1, 1) plt.imshow(torchvision.utils.make_grid(x_batch).permute(1, 2, 0)) plt.subplot(2, 1, 2) plt.imshow(torchvision.utils.make_grid(y_batch).permute(1, 2, 0)) plt.show() model = models[args.model] model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) device = f"cuda:{model.device_ids[0]}" #device = 'cpu' model.to(device) summary(model, input_shape) optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) #criterion = loss_fun() if args.load is not None: try: pretrained_dict = 
torch.load(args.load) model_dict = model.state_dict() model_dict.update(pretrained_dict) model.load_state_dict(model_dict) print(f"Weights loaded from {args.load}") except Exception as e: print(f"Couldn't load weights from {args.load}") print(e) best_training_loss = math.inf test_loss_for_best_training_loss = math.inf cols = [ 'epoch', 'training_loss', 'test_loss', 'train_psnr', 'test_psnr', 'train_ssim', 'test_ssim' ] logs_addr = os.path.join(args.logs, 'logs.csv') add_line_to_csv(logs_addr, cols) print(f"Logs to: {args.logs}") epochs = args.epochs if epochs <= args.from_epoch: epochs += args.from_epoch for epoch in range(args.from_epoch, epochs + 1): print(f"Epoch {epoch}/{epochs}") training_loss = 0.0 test_loss = 0.0 train_psnr = 0.0 test_psnr = 0.0 train_ssim = 0.0 test_ssim = 0.0 if args.load is not None: try: fn, ext = os.path.splitext(os.path.basename(args.load)) loss_vals = fn.split('_') best_training_loss = float(loss_vals[2]) test_loss_for_best_training_loss = float(loss_vals[3]) except Exception as e: print(f"Couldn't load best training loss from {args.load}") print(e) print("Training:") for i, data in tqdm(enumerate(train_dl), total=int(len(train_d) / train_batch_size)): w, m, file_name = data x = w.to(device) y = m.to(device) del w del m optimizer.zero_grad() y_hat = model(x) loss = loss_fun(y_hat, y) loss.backward() optimizer.step() training_loss += float(loss.item()) del x ''' train_psnr = piq.psnr(y_hat[0], y[0],data_range=1., reduction='none') train_ssim = piq.ssim(y_hat[0], y[0], data_range=1., reduction='none') ''' del y del y_hat training_loss /= (i + 1) #train_psnr /= (i+1) #train_ssim /= (i+1) with torch.no_grad(): print("Testing:") for i, data in tqdm(enumerate(test_dl), total=int(len(test_d) / test_batch_size)): w, m, file_name = data x = w.to(device) y = m.to(device) y_hat = model(x) loss = loss_fun(y_hat, y) test_loss += float(loss.item()) del x try: test_psnr += piq.psnr(y_hat, y) test_ssim += piq.ssim(y_hat, y) except: pass if 
args.show_output_images and i < 5: imgs_dir = os.path.join(images_output, f"epoch_{epoch}") if not os.path.exists(imgs_dir): os.makedirs(imgs_dir) for j, y_hat_i in enumerate(y_hat): fn = os.path.splitext(os.path.basename( file_name[j]))[0] y_gt = y[j] img_i_addr = os.path.join( imgs_dir, f'{epoch}_{fn}_{i}_{j}.{dataset.images_extension}') img_i_gt_addr = os.path.join( imgs_dir, f'{epoch}_{fn}_{i}_{j}_gt.{dataset.images_extension}' ) torchvision.utils.save_image(y_hat_i, img_i_addr) torchvision.utils.save_image(y_gt, img_i_gt_addr) del y_gt del y del y_hat test_loss /= (i + 1) test_psnr /= (i + 1) test_ssim /= (i + 1) print(f"Completed Epoch: {epoch}/{args.epochs}") print(f"\tTrain loss: {training_loss}") print(f"\tTest loss: {test_loss}") print(f"\tTrain PSNR: {train_psnr}") print(f"\tTest PSNR: {test_psnr}") print(f"\tTrain SSIM: {train_ssim}") print(f"\tTest SSIM: {test_ssim}") print(f"\tBest training loss so far: {best_training_loss}") print(f"\tTest loss for: {test_loss_for_best_training_loss}") add_line_to_csv(logs_addr, [ str(epoch), str(training_loss), str(test_loss), str(train_psnr), str(test_psnr), str(train_ssim), str(test_ssim) ]) if best_training_loss > training_loss: best_training_loss = training_loss test_loss_for_best_training_loss = test_loss save_file_name = f"{args.model}_epoch_{epoch}_{best_training_loss:.3f}_{test_loss_for_best_training_loss:.3f}.pth" checkpoint_path = os.path.join(args.checkpoints_output, save_file_name) torch.save(model.state_dict(), checkpoint_path)