def test_vsi_to_be_one_for_identical_inputs(
    input_tensors: Tuple[torch.Tensor, torch.Tensor],
    device: str,
) -> None:
    """Check that ``vsi`` of a tensor compared with itself equals one.

    The comparison is run twice: once on the tensor as provided with
    ``data_range=1.`` and once on the tensor multiplied by 255 with
    ``data_range=255``.

    Args:
        input_tensors: Pair of tensors; only the first one is used here.
        device: Device the tensors are moved to before computing the index.
    """
    image, _ = input_tensors

    # Both scores are computed before any assertion runs.
    index = vsi(image.to(device), image.to(device), data_range=1., reduction='none')
    index_255 = vsi(image.to(device) * 255, image.to(device) * 255, data_range=255, reduction='none')

    expected = torch.ones_like(index, device=device)
    assert torch.allclose(index, expected), f'Expected index to be equal 1, got {index}'

    expected_255 = torch.ones_like(index_255, device=device)
    assert torch.allclose(index_255, expected_255), f'Expected index to be equal 1, got {index_255}'
def test_vsi_zeros_ones_inputs(device: str) -> None:
    """Check that ``vsi`` stays finite on constant all-zero and all-one tensors.

    Three pairs are evaluated in order: zeros vs. zeros, ones vs. ones and
    zeros vs. ones, each with ``data_range=1.``.

    Args:
        device: Device on which the constant tensors are created.
    """
    shape = (1, 3, 128, 128)
    zeros = torch.zeros(shape, device=device)
    ones = torch.ones(shape, device=device)

    # Identical all-zero inputs.
    vsi_zeros = vsi(zeros, zeros, data_range=1.)
    zeros_finite = torch.isfinite(vsi_zeros).all()
    assert zeros_finite, f'Expected finite value for zeros tensors, got {vsi_zeros}'

    # Identical all-one inputs.
    vsi_ones = vsi(ones, ones, data_range=1.)
    ones_finite = torch.isfinite(vsi_ones).all()
    assert ones_finite, f'Expected finite value for ones tensos, got {vsi_ones}'

    # Different constant inputs.
    vsi_zeros_ones = vsi(zeros, ones, data_range=1.)
    mixed_finite = torch.isfinite(vsi_zeros_ones).all()
    assert mixed_finite, f'Expected finite value for zeros and ones tensos, got {vsi_zeros_ones}'
def test_vsi_symmetry(
    input_tensors: Tuple[torch.Tensor, torch.Tensor],
    device: str,
) -> None:
    """Check that swapping the two arguments of ``vsi`` gives the same result.

    Args:
        input_tensors: Pair of tensors to compare.
        device: Device the tensors are moved to before computing the index.
    """
    first, second = input_tensors
    options = {'data_range': 1., 'reduction': 'none'}

    result = vsi(first.to(device), second.to(device), **options)
    # Same call with the argument order reversed.
    result_sym = vsi(second.to(device), first.to(device), **options)

    assert torch.allclose(result_sym, result), (
        f'Expected the same results, got {result} and {result_sym}'
    )
def test_vsi_compare_with_matlab(device: str) -> None:
    """Compare ``vsi`` on two stored images against a fixed target of 0.96405.

    The target is presumably the score produced by the MATLAB reference
    implementation, as the assertion message suggests.

    Args:
        device: Device the image tensors are moved to before computing the index.
    """
    # permute(2, 0, 1) moves the last axis of the loaded array to the front
    # (presumably height-width-channel to channel-height-width).
    first_image = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1)
    second_image = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0, 1)

    predicted_score = vsi(first_image.to(device), second_image.to(device), data_range=255, reduction='none')

    # Give the target the same dtype and device as the prediction.
    target_score = torch.tensor([0.96405]).to(predicted_score)

    assert torch.allclose(predicted_score, target_score), (
        f'Expected result similar to MATLAB,'
        f'got diff{predicted_score - target_score}'
    )
# Fragment of a per-image evaluation step (the enclosing definition is not
# visible here). It computes PSNR and SSIM between img_sharp and img_deblu,
# prints them when name_sharp[-7:-4] equals "001", wraps both arrays as PIL
# images, converts them to batched tensors and evaluates piq.vif_p, piq.vsi
# and piq.haarpsi on the pair.
# NOTE(review): the statements arrive collapsed onto one physical line, so the
# extent of the `if` body cannot be determined from here -- restore the
# original line breaks before running.
# NOTE(review): if TF is torchvision.transforms.functional, to_tensor
# presumably already rescales uint8 PIL images to [0, 1]; the extra `/255.0`
# would then shrink the inputs a second time -- confirm the value range the
# piq metrics are meant to receive.
# The trailing commented-out code looks like a disabled aggregation of the
# PSNR/SSIM values in groups of 198 items -- purpose not confirmable from here.
psnr_n = psnr(img_sharp, img_deblu, data_range=255) ssim_n = ssim(img_deblu / 255, img_sharp / 255, gaussian_weights=True, multichannel=True, use_sample_covariance=False, sigma=1.5) if name_sharp[-7:-4] == "001": print(name_sharp, (psnr_n, ssim_n)) sharp = Image.fromarray(np.uint8(img_sharp)) deblu = Image.fromarray(np.uint8(img_deblu)) sharp_ts = TF.to_tensor(sharp).unsqueeze(0) deblu_ts = TF.to_tensor(deblu).unsqueeze(0) sharp_ts, deblu_ts = sharp_ts/255.0, deblu_ts/255.0 vif_n = piq.vif_p(deblu_ts, sharp_ts) vsi_n = piq.vsi(deblu_ts, sharp_ts) haar_n = piq.haarpsi(deblu_ts, sharp_ts) # # if count_k < 198: # psnr_k.append(psnr_n) # ssim_k.append(ssim_n) # count_k += 1 # elif count_k == 198: # psnr_k.append(psnr_n) # ssim_k.append(ssim_n) # kernel_p.append(max(psnr_k)) # kernel_s.append(max(ssim_k)) # psnr_k = [] # ssim_k = [] # count_k = 1
def main():
    """Compute and print a series of ``piq`` metrics and losses for two stored images.

    Both images are loaded, scaled by 1/255 and moved to the GPU when CUDA is
    available. Each metric is then evaluated and printed in turn: lower-case
    ``piq`` functions are used as measures and the corresponding ``piq``
    PyTorch modules as loss functions.
    """

    def load_image(path):
        # Move the last axis of the loaded array to the front and scale by 1/255.
        return torch.tensor(imread(path)).permute(2, 0, 1) / 255.

    # An RGB image and its noisy version.
    x = load_image('tests/assets/i01_01_5.bmp')
    y = load_image('tests/assets/I01.BMP')

    if torch.cuda.is_available():
        # Run on the GPU to make the computations faster.
        x = x.cuda()
        y = y.cuda()

    # BRISQUE: the lower-case function gives the measure, the module gives the loss.
    # Note: back propagation is not available with torch==1.5.0;
    # update the environment to the latest torch and torchvision.
    brisque_index = piq.brisque(x, data_range=1., reduction='none')
    brisque_criterion = piq.BRISQUELoss(data_range=1., reduction='none')
    brisque_loss = brisque_criterion(x)
    print(
        f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}"
    )

    # Content loss (module only). VGG16 is the default feature extractor, but any
    # feature extractor model is supported; adjust the layer names accordingly.
    # Features from different layers can be weighted through the weights parameter.
    # See the class docstring for other options.
    content_criterion = piq.ContentLoss(feature_extractor="vgg16", layers=("relu3_3", ), reduction='none')
    content_loss = content_criterion(x, y)
    print(f"ContentLoss: {content_loss.item():0.4f}")

    # DISTS loss (module only). By default inputs are normalized with ImageNet
    # statistics before being forwarded through the VGG16 model. To skip the
    # normalization, use mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0].
    dists_criterion = piq.DISTS(reduction='none')
    dists_loss = dists_criterion(x, y)
    print(f"DISTS: {dists_loss.item():0.4f}")

    # FSIM: measure via the function, loss via the module.
    fsim_index = piq.fsim(x, y, data_range=1., reduction='none')
    fsim_criterion = piq.FSIMLoss(data_range=1., reduction='none')
    fsim_loss = fsim_criterion(x, y)
    print(
        f"FSIM index: {fsim_index.item():0.4f}, loss: {fsim_loss.item():0.4f}")

    # GMSD: a port of the MATLAB version from the authors of the original paper.
    # In any case it should be minimized; values usually lie in the [0, 0.35] interval.
    gmsd_index = piq.gmsd(x, y, data_range=1., reduction='none')
    gmsd_criterion = piq.GMSDLoss(data_range=1., reduction='none')
    gmsd_loss = gmsd_criterion(x, y)
    print(
        f"GMSD index: {gmsd_index.item():0.4f}, loss: {gmsd_loss.item():0.4f}")

    # HaarPSI: a port of the MATLAB version from the authors of the original paper.
    haarpsi_index = piq.haarpsi(x, y, data_range=1., reduction='none')
    haarpsi_criterion = piq.HaarPSILoss(data_range=1., reduction='none')
    haarpsi_loss = haarpsi_criterion(x, y)
    print(
        f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}"
    )

    # LPIPS loss (module only).
    lpips_criterion = piq.LPIPS(reduction='none')
    lpips_loss = lpips_criterion(x, y)
    print(f"LPIPS: {lpips_loss.item():0.4f}")

    # MDSI: measure via the function, loss via the module.
    mdsi_index = piq.mdsi(x, y, data_range=1., reduction='none')
    mdsi_criterion = piq.MDSILoss(data_range=1., reduction='none')
    mdsi_loss = mdsi_criterion(x, y)
    print(
        f"MDSI index: {mdsi_index.item():0.4f}, loss: {mdsi_loss.item():0.4f}")

    # MS-SSIM: measure via the function, loss via the module.
    ms_ssim_index = piq.multi_scale_ssim(x, y, data_range=1.)
    ms_ssim_criterion = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')
    ms_ssim_loss = ms_ssim_criterion(x, y)
    print(
        f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}"
    )

    # Multi-Scale GMSD: usable both as a measure and as a loss; in any case it
    # should be minimized. By default the scale weights are initialized with the
    # values from the paper; pass a list of 4 values as scale_weights at
    # initialization to change them. Input images need height and width of at
    # least 2 ** number_of_scales + 1.
    ms_gmsd_index = piq.multi_scale_gmsd(x, y, data_range=1., chromatic=True, reduction='none')
    ms_gmsd_criterion = piq.MultiScaleGMSDLoss(chromatic=True, data_range=1., reduction='none')
    ms_gmsd_loss = ms_gmsd_criterion(x, y)
    print(
        f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}"
    )

    # PSNR (measure only).
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    print(f"PSNR index: {psnr_index.item():0.4f}")

    # PieAPP loss (module only).
    pieapp_criterion = piq.PieAPP(reduction='none', stride=32)
    pieapp_loss = pieapp_criterion(x, y)
    print(f"PieAPP loss: {pieapp_loss.item():0.4f}")

    # SSIM: measure via the function, loss via the module.
    ssim_index = piq.ssim(x, y, data_range=1.)
    ssim_criterion = piq.SSIMLoss(data_range=1.)
    ssim_loss = ssim_criterion(x, y)
    print(
        f"SSIM index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

    # Style loss (module only). The same feature-extractor notes as for the
    # content loss apply: VGG16 by default, layer names must match the model,
    # and layers can be weighted through the weights parameter.
    style_criterion = piq.StyleLoss(feature_extractor="vgg16", layers=("relu3_3", ))
    style_loss = style_criterion(x, y)
    print(f"Style: {style_loss.item():0.4f}")

    # Total variation: computed on the first image only.
    tv_index = piq.total_variation(x)
    tv_criterion = piq.TVLoss(reduction='none')
    tv_loss = tv_criterion(x)
    print(f"TV index: {tv_index.item():0.4f}, loss: {tv_loss.item():0.4f}")

    # VIF: measure via the function, loss via the class.
    vif_index = piq.vif_p(x, y, data_range=1.)
    vif_criterion = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)
    vif_loss = vif_criterion(x, y)
    print(f"VIFp index: {vif_index.item():0.4f}, loss: {vif_loss.item():0.4f}")

    # VSI: measure via the function, loss via the module.
    vsi_index = piq.vsi(x, y, data_range=1.)
    vsi_criterion = piq.VSILoss(data_range=1.)
    vsi_loss = vsi_criterion(x, y)
    print(f"VSI index: {vsi_index.item():0.4f}, loss: {vsi_loss.item():0.4f}")
def test_vsi_preserves_dtype(
    input_tensors: Tuple[torch.Tensor, torch.Tensor],
    dtype,
    device: str,
) -> None:
    """Check that the output of ``vsi`` has the same dtype as its inputs.

    Args:
        input_tensors: Pair of tensors to compare.
        dtype: Dtype both tensors are cast to before the call.
        device: Device both tensors are moved to before the call.
    """
    first, second = input_tensors

    # Cast and move each input in a single ``.to`` call.
    first_cast = first.to(device=device, dtype=dtype)
    second_cast = second.to(device=device, dtype=dtype)

    output = vsi(first_cast, second_cast)
    assert output.dtype == dtype