def test_ssim_loss_grad(x_y_4d_5d, device: str) -> None:
    """Check that back-propagating through ``SSIMLoss`` leaves finite gradients on ``x``.

    Args:
        x_y_4d_5d: Fixture indexable as ``[0]`` (first input) and ``[1]`` (second input).
        device: Device both tensors are moved to before the loss is evaluated.
    """
    x = x_y_4d_5d[0].to(device)
    x.requires_grad_(True)
    y = x_y_4d_5d[1].to(device)

    criterion = SSIMLoss(data_range=1.)
    # Reduce to a scalar so that ``backward`` needs no explicit gradient argument.
    criterion(x, y).mean().backward()

    gradients_are_finite = torch.isfinite(x.grad).all()
    assert gradients_are_finite, f'Expected finite gradient values, got {x.grad}'
def test_ssim_loss_check_kernel_size_is_passed(x: torch.Tensor, y: torch.Tensor) -> None:
    """Sweep kernel sizes 0..49: odd sizes must run, even sizes must raise ``AssertionError``."""
    for kernel_size in range(50):
        is_even = kernel_size % 2 == 0
        if is_even:
            # Construction and call both sit inside the context, so either may raise.
            with pytest.raises(AssertionError):
                SSIMLoss(kernel_size=kernel_size)(x, y)
            continue
        SSIMLoss(kernel_size=kernel_size)(x, y)
def test_ssim_loss_grad(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                        device: str) -> None:
    """Verify that the gradient of ``SSIMLoss`` with respect to the prediction is finite.

    Args:
        prediction_target_4d_5d: Fixture providing the ``(prediction, target)`` pair.
        device: Device both tensors are moved to before the loss is evaluated.
    """
    prediction = prediction_target_4d_5d[0].to(device)
    prediction.requires_grad_(True)
    target = prediction_target_4d_5d[1].to(device)

    criterion = SSIMLoss(data_range=1.)
    scalar_loss = criterion(prediction, target).mean()
    scalar_loss.backward()

    grad_is_finite = torch.isfinite(prediction.grad).all()
    assert grad_is_finite, f'Expected finite gradient values, got {prediction.grad}'
def test_ssim_loss_raises_if_tensors_have_different_shapes(x_y_4d_5d, device) -> None:
    """Feed ``SSIMLoss`` every candidate shape; only a shape equal to ``y``'s may pass.

    Candidate shapes are the Cartesian product of per-axis extents. Any candidate
    whose shape differs from the reference must trigger an ``AssertionError``.
    """
    y = x_y_4d_5d[1].to(device)

    # Allowed extents per axis; a fifth axis is added only for 5D references.
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if y.dim() == 5:
        dims.append([2, 3])

    for size in itertools.product(*dims):
        # ``.to(y)`` matches the reference tensor's dtype and device.
        wrong_shape_x = torch.rand(size).to(y)
        if wrong_shape_x.shape != y.shape:
            with pytest.raises(AssertionError):
                SSIMLoss()(wrong_shape_x, y)
            continue
        SSIMLoss()(wrong_shape_x, y)
def test_ssim_loss_check_available_dimensions() -> None:
    """Grow a pair of random tensors from 2D upwards and check which ranks are accepted.

    Ten ranks are tried, starting from a 256x256 pair. Ranks below 5 must not
    raise anything; ranks of 5 and above must raise ``AssertionError``.
    """
    custom_x = torch.rand(256, 256)
    custom_y = torch.rand(256, 256)

    for _ in range(10):
        rank = custom_x.dim()
        if rank >= 5:
            with pytest.raises(AssertionError):
                SSIMLoss()(custom_x, custom_y)
        else:
            try:
                SSIMLoss()(custom_x, custom_y)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")

        # Prepend a singleton axis so the next iteration tests one more dimension.
        custom_x = custom_x.unsqueeze(0)
        custom_y = custom_y.unsqueeze(0)
def test_ssim_loss_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device) -> None:
    """Ensure ``SSIMLoss`` rejects predictions whose shape differs from the target's.

    All combinations of the per-axis extents below are generated. A combination
    matching the target's shape is expected to run cleanly; every other one
    must raise ``AssertionError``.
    """
    target = prediction_target_4d_5d[1].to(device)

    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if target.dim() == 5:
        # 5D targets get one more axis to vary.
        dims.append([2, 3])

    for size in itertools.product(*dims):
        wrong_shape_prediction = torch.rand(size).to(target)
        shapes_match = wrong_shape_prediction.shape == target.shape
        if not shapes_match:
            with pytest.raises(AssertionError):
                SSIMLoss()(wrong_shape_prediction, target)
            continue
        SSIMLoss()(wrong_shape_prediction, target)
def test_ssim_loss_equality(y: torch.Tensor, device: str) -> None:
    """Check that the SSIM loss between a tensor and its clone is zero within tolerance.

    Args:
        y: Reference tensor fixture.
        device: Device the tensor is moved to before the loss is evaluated.
    """
    y = y.to(device)
    x = y.clone()
    loss = SSIMLoss()(x, y)

    # ``allclose`` checks |loss - 0| <= atol + rtol * |0|, so against a zero reference
    # only the absolute term matters. The tolerance is spelled out (it equals the
    # library default) so that the failure message reports the bound actually enforced.
    atol = 1e-8
    assert torch.allclose(loss, torch.zeros_like(loss), rtol=0., atol=atol), \
        f'If equal tensors are passed SSIM loss must be equal to 0 ' \
        f'(considering floating point operation error up to {atol}), got {loss}'
def test_ssim_loss_is_less_or_equal_to_one(
        ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Check the upper bound of ``SSIMLoss``: no element of the result may exceed 1.

    Args:
        ones_zeros_4d_5d: Fixture providing the two maximally different tensors.
        device: Device both tensors are moved to before the loss is evaluated.
    """
    # The fixture supplies the two maximally different tensors.
    ones, zeros = (ones_zeros_4d_5d[idx].to(device) for idx in (0, 1))

    criterion = SSIMLoss()
    loss = criterion(ones, zeros)

    bounded = torch.all(loss <= 1)
    assert bounded, f'SSIM loss must be <= 1, got {loss}'
def test_ssim_loss_symmetry(x_y_4d_5d, device: str) -> None:
    """Check that swapping the two arguments of ``SSIMLoss`` does not change the result.

    Args:
        x_y_4d_5d: Fixture indexable as ``[0]`` and ``[1]`` for the two inputs.
        device: Device both tensors are moved to before the loss is evaluated.
    """
    first = x_y_4d_5d[0].to(device)
    second = x_y_4d_5d[1].to(device)

    # One module instance is reused for both argument orders.
    criterion = SSIMLoss()
    loss_value = criterion(first, second)
    reverse_loss_value = criterion(second, first)

    is_symmetric = torch.allclose(loss_value, reverse_loss_value)
    assert is_symmetric, \
        f'Expect: SSIMLoss(a, b) == SSIMLoss(b, a), got {loss_value} != {reverse_loss_value}'
def test_ssim_loss_raises_if_kernel_size_greater_than_image(
        x_y_4d_5d, device: str) -> None:
    """Crop both inputs to one pixel less than the kernel size and expect ``ValueError``.

    Args:
        x_y_4d_5d: Fixture indexable as ``[0]`` and ``[1]`` for the two inputs.
        device: Device both tensors are moved to before cropping.
    """
    kernel_size = 11
    extent = kernel_size - 1
    # Keep the first two axes whole and cut axes 2 and 3 down to ``extent`` entries.
    crop = (slice(None), slice(None), slice(None, extent), slice(None, extent))

    wrong_size_x = x_y_4d_5d[0].to(device)[crop]
    wrong_size_y = x_y_4d_5d[1].to(device)[crop]

    with pytest.raises(ValueError):
        SSIMLoss(kernel_size=kernel_size)(wrong_size_x, wrong_size_y)
def test_ssim_loss_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                            device: str) -> None:
    """Verify ``SSIMLoss(a, b)`` and ``SSIMLoss(b, a)`` agree up to ``allclose`` tolerance.

    Args:
        prediction_target_4d_5d: Fixture providing the ``(prediction, target)`` pair.
        device: Device both tensors are moved to before the loss is evaluated.
    """
    prediction, target = (prediction_target_4d_5d[idx].to(device) for idx in (0, 1))

    criterion = SSIMLoss()
    # Forward order first, then swapped, using the same module instance.
    loss_value = criterion(prediction, target)
    reverse_loss_value = criterion(target, prediction)

    values_agree = torch.allclose(loss_value, reverse_loss_value)
    assert values_agree, \
        f'Expect: SSIMLoss(a, b) == SSIMLoss(b, a), got {loss_value} != {reverse_loss_value}'
def test_ssim_loss_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Expect ``ValueError`` when axes 2 and 3 of the inputs are shorter than the kernel.

    Args:
        prediction_target_4d_5d: Fixture providing the ``(prediction, target)`` pair.
        device: Device both tensors are moved to before cropping.
    """
    kernel_size = 11
    too_small = kernel_size - 1

    def _crop(tensor: torch.Tensor) -> torch.Tensor:
        # Axes 0 and 1 are left untouched; axes 2 and 3 are truncated.
        return tensor[:, :, :too_small, :too_small]

    wrong_size_prediction = _crop(prediction_target_4d_5d[0].to(device))
    wrong_size_target = _crop(prediction_target_4d_5d[1].to(device))

    with pytest.raises(ValueError):
        SSIMLoss(kernel_size=kernel_size)(wrong_size_prediction, wrong_size_target)
def test_ssim_loss_raise_if_wrong_value_is_estimated(
        test_images: Tuple[torch.Tensor, torch.Tensor], device: str) -> None:
    """Compare ``SSIMLoss`` against ``1 - tf.image.ssim`` on every image pair.

    Args:
        test_images: Fixture iterated as ``(x, y)`` pairs.
            NOTE(review): the annotation says a single tuple, but the loop unpacks
            pairs from it, so it looks like a sequence of pairs. Confirm against
            the fixture.
        device: Device the PyTorch computation runs on. The TensorFlow reference
            is always evaluated on CPU.
    """
    # Loop-invariant: absolute tolerance of the comparison and the loss module itself.
    match_accuracy = 2e-4 + 1e-8
    criterion = SSIMLoss(kernel_size=11, kernel_sigma=1.5, data_range=255, reduction='mean')

    for x, y in test_images:
        ssim_loss = criterion(x.to(device), y.to(device))

        # Move axis 1 to the end before handing the arrays to TensorFlow.
        tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())
        tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).numpy())
        with tf.device('/CPU'):
            tf_ssim = torch.tensor(
                tf.image.ssim(tf_x, tf_y, max_val=255).numpy()).mean().to(device)

        # The trailing space after "one" was missing, which glued the two sentences together.
        assert torch.isclose(ssim_loss, 1. - tf_ssim, rtol=0, atol=match_accuracy), \
            'The estimated value must be equal to tensorflow provided one ' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(ssim_loss - 1. + tf_ssim).abs()}'
def usage():
    """Demonstrate the functional ``ssim`` metric and the ``SSIMLoss`` module on random data.

    Computes the SSIM index of two random tensors, prints it, then evaluates
    ``SSIMLoss`` of a tensor against itself, prints the scalar and back-propagates it.
    """
    # Imports are local to the demo; numpy is imported here too so that the
    # ``np`` reference below does not depend on a module-level import.
    import numpy as np
    import torch
    from piq import ssim, SSIMLoss

    x = torch.rand(4, 256, 256, requires_grad=True)
    y = torch.rand(4, 256, 256)

    ssim_index: torch.Tensor = ssim(x, y, data_range=1.)
    # Detach once: ``x`` requires grad, so the result cannot be converted to numpy directly.
    ssim_value = ssim_index.detach().numpy()
    assert isinstance(ssim_value, np.ndarray), type(ssim_value)
    # The original f-string had no placeholder and never showed the value.
    print(f"ssim_index: {ssim_value}")

    loss = SSIMLoss(data_range=1.)
    output: torch.Tensor = loss(x, x)
    print(output.item())
    output.backward()
"out_channels": 2, } subnet = Tiramisu it_net_params = { "num_iter": 1, "lam": 1.0, "lam_learnable": False, "final_dc": False, "resnet_factor": 1.0, "concat_mask": False, "multi_slice": False, } # ----- training configuration ----- ssimloss = SSIMLoss(data_range=1e1) maeloss = torch.nn.L1Loss(reduction="mean") def ssim_l1_loss(pred, tar): return 0.7 * ssimloss( rotate_real(pred)[:, 0:1, ...], rotate_real(tar)[:, 0:1, ...], ) + 0.3 * maeloss( rotate_real(pred)[:, 0:1, ...], rotate_real(tar)[:, 0:1, ...], ) train_params = { "num_epochs": [40],
def main():
    """Train an autoencoder with an MSE + 0.1 * SSIM objective and early stopping.

    Reads CLI arguments and the per-encoder JSON config, builds the train and
    validation loaders, instantiates the requested autoencoder, and runs the
    training loop. After every epoch it exports side-by-side input/reconstruction
    images for the first ``args.viz_samples`` validation samples, checkpoints
    the best model by validation MSE, and stops once ``args.patience``
    non-improving epochs have accumulated.

    Raises:
        ValueError: If ``args.encoder_type`` is neither ``'pca'`` nor ``'mlp'``.
    """
    args = parse_args()
    config = load_json('./configs', args.config_name)
    config = config[args.encoder_type]

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        print('CUDA is not available, using CPU instead...')

    # ===================Create data loaders=====================
    data_path = Path(args.data_path)
    transforms = get_transforms(roi=config['roi'], grayscale=config['grayscale'])

    train_set = SuperMarioKartDataset(data_path / 'train_files.txt', transforms=transforms)
    valid_set = SuperMarioKartDataset(data_path / 'valid_files.txt', transforms=transforms)

    train_loader = DataLoader(train_set, batch_size=config['batch_size'], num_workers=4)
    # The validation loader must read the validation split; it previously wrapped train_set.
    valid_loader = DataLoader(valid_set, batch_size=1, num_workers=4)

    # ===================Export some val images=====================
    if not args.save_path:
        save_path = data_path / (args.encoder_type + '_results')
    else:
        save_path = Path(args.save_path)
    save_path.mkdir(exist_ok=True)

    # ===================Instantiate models=====================
    img_shape = train_set.image_shape
    if args.encoder_type == 'pca':
        network = autoencoders.PCA(img_shape, config['latent_dim'])
    elif args.encoder_type == 'mlp':
        network = autoencoders.MLP(img_shape, config['latent_dim'])
    else:
        # Fail with a clear message instead of an unbound-variable error further down.
        raise ValueError(f'Unsupported encoder type: {args.encoder_type!r}')
    # Use the selected device rather than an unconditional .cuda(), so the CPU fallback works.
    network = network.to(device)

    optimizer = torch.optim.Adam(network.parameters())
    ssim_loss = SSIMLoss(data_range=1.)
    mse_loss = nn.MSELoss()

    load_path = Path(args.load_path) if args.load_path else None
    if load_path is not None and load_path.exists():
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        state = torch.load(load_path / "best_encoder.pt", map_location=device)
        network.load_state_dict(state)

    best_val_loss = float('inf')
    cnt_bad_epocs = 0

    for epoch in range(1, config['n_epochs'] + 1):
        train_losses, valid_losses = [], []

        epc_save_path = save_path / f'epoch_{epoch}'
        epc_save_path.mkdir(exist_ok=True)

        # ===================Training=====================
        for x_t in train_loader:
            x_t = x_t.to(device)

            encoding, decoding = network(x_t)
            step_mse = mse_loss(decoding, x_t)
            loss = step_mse + 0.1 * ssim_loss(decoding, x_t)
            # Log the MSE term only, matching the "Avg Train MSE" label and the
            # validation bookkeeping; the optimised objective still includes SSIM.
            train_losses.append(step_mse.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ===================Validation=====================
        for i, x_t in enumerate(valid_loader):
            x_t = x_t.to(device)

            with torch.no_grad():
                encoding, decoding = network(x_t)
                step_mse = mse_loss(decoding, x_t)

            if i < args.viz_samples:
                # Export only the first ``args.viz_samples`` samples: input stacked on top
                # of its reconstruction.
                input_image = np.array(ToPILImage()(x_t.cpu()[0]))
                decoded_image = np.array(ToPILImage()(decoding.cpu()[0]))
                merged_image = np.concatenate((input_image, decoded_image))
                imsave(epc_save_path / f'result_{i}.png', merged_image)

            valid_losses.append(step_mse.item())

        avg_trn_loss = round(np.mean(train_losses), 4)
        avg_val_loss = round(np.mean(valid_losses), 4)
        print(f'Epoch - {epoch} | Avg Train MSE - {avg_trn_loss} | Avg Val MSE - {avg_val_loss}')

        # ===================Checkpointing=====================
        if avg_val_loss < best_val_loss:
            cnt_bad_epocs = 0
            best_val_loss = avg_val_loss
            torch.save(network.state_dict(), save_path / 'best_encoder.pt')
            continue

        # ===================Early stopping=====================
        # Any epoch that does not improve counts as bad, ties included. Losses are rounded
        # to four decimals, so a plateau produces exact ties that were previously ignored.
        cnt_bad_epocs += 1
        if cnt_bad_epocs == args.patience:
            print('Network is not improving. Stopping training...')
            break
def test_ssim_loss_raises_if_tensors_have_different_types(
        y: torch.Tensor) -> None:
    """Pass a plain Python list as the first argument and expect ``AssertionError``."""
    # A list of the integers 0..9 stands in for a non-tensor input.
    wrong_type_x = [*range(10)]
    with pytest.raises(AssertionError):
        SSIMLoss()(wrong_type_x, y)