# Example #1
# 0
def test_multi_scale_ssim_symmetry(x_y_4d_5d, device: str) -> None:
    """MS-SSIM must give the same per-sample scores when its two inputs are swapped."""
    first = x_y_4d_5d[0].to(device)
    second = x_y_4d_5d[1].to(device)
    forward_score = multi_scale_ssim(first, second, data_range=1., reduction='none')
    backward_score = multi_scale_ssim(second, first, data_range=1., reduction='none')
    assert torch.allclose(forward_score, backward_score), (
        f'Expect: MS-SSIM(a, b) == MSSSIM(b, a), '
        f'got {forward_score} != {backward_score}'
    )
# Example #2
# 0
def test_multi_scale_ssim_raises_if_tensors_have_different_types(x, y) -> None:
    """A non-tensor input, or non-tensor ``scale_weights``, must raise AssertionError."""
    not_a_tensor = list(range(10))
    with pytest.raises(AssertionError):
        multi_scale_ssim(not_a_tensor, y)

    bad_weights = True
    with pytest.raises(AssertionError):
        multi_scale_ssim(x, y, scale_weights=bad_weights)
# Example #3
# 0
def test_multi_scale_ssim_check_kernel_size_is_passed(x, y) -> None:
    """Odd kernel sizes in [0, 12] are accepted; even ones must raise AssertionError."""
    for size in range(13):
        if size % 2 == 0:
            with pytest.raises(AssertionError):
                multi_scale_ssim(x, y, kernel_size=size)
            continue
        multi_scale_ssim(x, y, kernel_size=size)
# Example #4
# 0
def test_multi_scale_ssim_fails_for_incorrect_data_range(
        prediction: torch.Tensor, target: torch.Tensor, device: str) -> None:
    """uint8 data in [0, 255] combined with ``data_range=1.0`` must raise AssertionError."""
    # Stretch to [0, 255] and truncate to uint8.
    as_uint8 = [(t * 255).type(torch.uint8) for t in (prediction, target)]
    with pytest.raises(AssertionError):
        multi_scale_ssim(as_uint8[0].to(device), as_uint8[1].to(device), data_range=1.0)
# Example #5
# 0
def test_multi_scale_ssim_check_kernel_size_is_passed(
        prediction: torch.Tensor, target: torch.Tensor) -> None:
    """Only odd kernel sizes are valid; even sizes in [0, 12] must raise AssertionError."""
    for size in range(13):
        is_odd = size % 2 == 1
        if is_odd:
            multi_scale_ssim(prediction, target, kernel_size=size)
        else:
            with pytest.raises(AssertionError):
                multi_scale_ssim(prediction, target, kernel_size=size)
# Example #6
# 0
def test_multi_scale_ssim_fails_for_incorrect_data_range(x, y, device: str) -> None:
    """Passing uint8 data in [0, 255] with ``data_range=1.0`` must raise AssertionError."""
    def to_uint8(t):
        # Stretch to [0, 255] and truncate to uint8.
        return (t * 255).type(torch.uint8)

    x_255, y_255 = to_uint8(x), to_uint8(y)
    with pytest.raises(AssertionError):
        multi_scale_ssim(x_255.to(device), y_255.to(device), data_range=1.0)
# Example #7
# 0
def test_ms_ssim_raises_if_kernel_size_greater_than_image(x_y_4d_5d, device: str) -> None:
    """Inputs one pixel smaller than ``min_size`` per side must make MS-SSIM raise ValueError."""
    x, y = (x_y_4d_5d[i].to(device) for i in (0, 1))
    kernel_size, levels = 11, 5
    # Smallest side length this test treats as valid for the kernel/level count.
    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    crop = min_size - 1
    x_small = x[:, :, :crop, :crop]
    y_small = y[:, :, :crop, :crop]
    with pytest.raises(ValueError):
        multi_scale_ssim(x_small, y_small, kernel_size=kernel_size)
# Example #8
# 0
def test_multi_scale_ssim_raises_if_tensors_have_different_types(
        prediction: torch.Tensor, target: torch.Tensor) -> None:
    """A list in place of a tensor, or a bool ``scale_weights``, must raise AssertionError."""
    with pytest.raises(AssertionError):
        multi_scale_ssim(list(range(10)), target)
    with pytest.raises(AssertionError):
        multi_scale_ssim(prediction, target, scale_weights=True)
# Example #9
# 0
def test_multi_scale_ssim_supports_different_data_ranges(x_y_4d_5d, data_range, device: str) -> None:
    """Scores on uint8 data in [0, data_range] must match scores on the same data rescaled to [0, 1]."""
    x, y = x_y_4d_5d
    x_q = (x * data_range).type(torch.uint8)
    y_q = (y * data_range).type(torch.uint8)
    x_dev, y_dev = x_q.to(device), y_q.to(device)

    scale = float(data_range)
    score_native = multi_scale_ssim(x_dev, y_dev, data_range=data_range)
    score_unit = multi_scale_ssim(x_dev / scale, y_dev / scale, data_range=1.0)
    diff = torch.abs(score_native - score_unit)
    assert (diff <= 1e-6).all(), f'Result for same tensor with different data_range should be the same, got {diff}'
# Example #10
# 0
def test_ms_ssim_raises_if_kernel_size_greater_than_image(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device: str) -> None:
    """Cropping to one pixel below the minimum size must make MS-SSIM raise ValueError."""
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    kernel_size = 11
    levels = 5
    # One less than the minimum side (kernel_size - 1) * 2**(levels - 1) + 1.
    side = (kernel_size - 1) * 2 ** (levels - 1)
    too_small = (slice(None), slice(None), slice(None, side), slice(None, side))
    wrong_size_prediction = prediction[too_small]
    wrong_size_target = target[too_small]
    with pytest.raises(ValueError):
        multi_scale_ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size)
# Example #11
# 0
def test_multi_scale_ssim_check_available_dimensions() -> None:
    """Inputs with fewer than 5 dims must be accepted; 5 or more dims must raise AssertionError.

    Starts from 256x256 tensors and prepends a singleton dimension after each
    of the 10 rounds (2-D up to 11-D).
    """
    pred = torch.rand(256, 256)
    tgt = torch.rand(256, 256)
    for _ in range(10):
        if pred.dim() >= 5:
            with pytest.raises(AssertionError):
                multi_scale_ssim(pred, tgt)
        else:
            try:
                multi_scale_ssim(pred, tgt)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        # Grow both tensors by one leading dimension, in place.
        pred.unsqueeze_(0)
        tgt.unsqueeze_(0)
# Example #12
# 0
def test_multi_scale_ssim_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                                   device: str) -> None:
    """Swapping prediction and target must not change the per-sample MS-SSIM scores."""
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    kwargs = {'data_range': 1., 'reduction': 'none'}
    measure = multi_scale_ssim(prediction, target, **kwargs)
    reverse_measure = multi_scale_ssim(target, prediction, **kwargs)
    assert torch.allclose(measure, reverse_measure), (
        f'Expect: MS-SSIM(a, b) == MSSSIM(b, a), '
        f'got {measure} != {reverse_measure}'
    )
# Example #13
# 0
def test_multi_scale_ssim_measure_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor],
                                                          device: str) -> None:
    """MS-SSIM must not exceed 1, even for an all-ones vs all-zeros pair."""
    ones, zeros = (ones_zeros_4d_5d[k].to(device) for k in range(2))
    measure = multi_scale_ssim(ones, zeros, data_range=1.)
    within_bound = measure <= 1
    assert within_bound.all(), f'MS-SSIM must be <= 1, got {measure}'
# Example #14
# 0
def test_multi_scale_ssim_measure_is_one_for_equal_tensors(x: torch.Tensor, device: str) -> None:
    """MS-SSIM of a tensor against an identical copy of itself must be (close to) 1."""
    x = x.to(device)
    y = x.clone()
    measure = multi_scale_ssim(y, x, data_range=1.)
    # The failure message reports the measured value itself (it previously
    # printed `measure + 1`, which is misleading when debugging).
    assert torch.allclose(measure, torch.ones_like(measure)), \
        f'If equal tensors are passed MS-SSIM must be equal to 1 ' \
        f'(considering floating point operation error up to 1 * 10^-6), got {measure}'
def main():
    """Average MS-SSIM, PSNR, L1 and L2 (MSE) over every image pair in ``data_loader``.

    ``data_loader`` is a module-level name defined outside this view; each
    batch is unpacked as ``(img_batch, gt_batch)`` and metrics are computed
    per image with ``data_range=1.`` (presumably inputs are scaled to [0, 1]
    — TODO confirm). Prints the averages, the image count and the average
    processing time per image. If no images are seen, a short message is
    printed instead of dividing by zero.
    """
    psnr_tensor = 0
    ssim_tensor = 0
    l1_tensor = 0
    l2_tensor = 0
    count = 0
    # Loss modules are stateless here, so build them once instead of per image.
    l1_loss = nn.L1Loss(reduction='mean')
    l2_loss = nn.MSELoss(reduction='mean')
    use_cuda = torch.cuda.is_available()
    t0 = time.perf_counter()
    print("Calculating image quality metrics ...")
    for img_batch, gt_batch in data_loader:
        if use_cuda:
            # Move to GPU to make computations faster
            img_batch = img_batch.cuda()
            gt_batch = gt_batch.cuda()

        for i in range(gt_batch.shape[0]):
            gt, img = gt_batch[i], img_batch[i]

            # MS-SSIM
            ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(gt, img, data_range=1.)
            # PSNR
            psnr_index: torch.Tensor = piq.psnr(gt, img, data_range=1., reduction='mean')
            # L1 error
            l1_index = l1_loss(gt, img)
            # L2 error (mean squared error)
            l2_index = l2_loss(gt, img)

            # Running sums; divided by `count` below to get the averages.
            ssim_tensor += ms_ssim_index
            psnr_tensor += psnr_index
            l1_tensor += l1_index
            l2_tensor += l2_index

            count += 1

    t1 = time.perf_counter()

    if count == 0:
        # Avoid ZeroDivisionError when the loader yields nothing.
        print("No images were processed.")
        return

    print(
        "Avg. SSIM: {} \nAvg. PSNR: {} \nAvg. L1: {} \nAvg. L2: {} \n".format(
            ssim_tensor / count, psnr_tensor / count, l1_tensor / count,
            l2_tensor / count))
    print(count)
    print("Average processing time for each image (of total {} images): {} s".
          format(count, (t1 - t0) / count))
# Example #16
# 0
def test_multi_scale_ssim_raise_if_wrong_value_is_estimated(
        test_images: Tuple[torch.Tensor, torch.Tensor],
        scale_weights: torch.Tensor, device: str) -> None:
    """piq's MS-SSIM must match ``tf.image.ssim_multiscale`` on every test image pair.

    Both sides use a data range / ``max_val`` of 255 and the same
    ``scale_weights`` (TensorFlow's ``power_factors``). piq is called with
    ``kernel_size=11`` and ``kernel_sigma=1.5``. The TensorFlow reference is
    computed on CPU, and the results must agree to within ``match_accuracy``
    (absolute tolerance only).
    """
    # NOTE(review): `test_images` is iterated as a sequence of (x, y) pairs, so
    # the `Tuple[torch.Tensor, torch.Tensor]` annotation looks inaccurate — confirm.
    match_accuracy = 1e-4 + 1e-8
    for x, y in test_images:
        piq_ms_ssim = multi_scale_ssim(x.to(device),
                                       y.to(device),
                                       kernel_size=11,
                                       kernel_sigma=1.5,
                                       data_range=255,
                                       reduction='none',
                                       scale_weights=scale_weights)
        # TensorFlow image ops take channels last: move dim 1 to the end.
        tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())
        tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).numpy())
        with tf.device('/CPU'):
            tf_ms_ssim = torch.tensor(
                tf.image.ssim_multiscale(
                    tf_x,
                    tf_y,
                    max_val=255,
                    power_factors=scale_weights.numpy()).numpy()).to(device)
        # Trailing space added after "one": the message previously read "one(considering".
        assert torch.allclose(piq_ms_ssim, tf_ms_ssim, rtol=0, atol=match_accuracy), \
            f'The estimated value must be equal to tensorflow provided one ' \
            f'(considering floating point operation error up to {match_accuracy}), ' \
            f'got difference {(piq_ms_ssim - tf_ms_ssim).abs()}'
# Example #17
# 0
    def eval(self, gt, pred):
        """Compute full-reference quality metrics between ``pred`` and ``gt``.

        Both inputs are wrapped with ``torch.Tensor``, clamped to [0, 1],
        permuted with ``(0, 3, 1, 2)`` and moved to ``'cuda:0'``; presumably
        they arrive as (batch, height, width, channels) arrays — TODO confirm.

        The image is split into an ``n`` x ``n`` grid of tiles (``n = 1`` means
        the whole image). SSIM, MS-SSIM and the two LPIPS networks are
        accumulated per tile and averaged over the tiles.

        Returns:
            dict with keys ``'rmse'``, ``'psnr'``, ``'ssim'``, ``'msssim'``,
            ``'lpipsVGG'`` and ``'lpipsAlex'``. ``rmse`` is computed on the raw,
            unclamped ``gt``/``pred`` inputs.
        """
        with torch.no_grad():
            gt_tensor = torch.Tensor(gt).clamp(0, 1).permute(0, 3, 1, 2).to('cuda:0')
            pred_tensor = torch.Tensor(pred).clamp(0, 1).permute(0, 3, 1, 2).to('cuda:0')
            # NOTE(review): `.item()` on reduction='none' output needs a
            # single-element result, i.e. a batch of one image — confirm.
            psnr_index = piq.psnr(pred_tensor,
                                  gt_tensor,
                                  data_range=1.,
                                  reduction='none').item()
            _, _, h, w = gt_tensor.shape

            lpipsAlex = 0
            lpipsVGG = 0
            msssim_index = 0
            ssim_index = 0
            n = 1
            for i in range(n):
                for j in range(n):
                    xstart = w // n * j
                    ystart = h // n * i
                    xend = w // n * (j + 1)
                    yend = h // n * (i + 1)
                    pred_tile = pred_tensor[:, :, ystart:yend, xstart:xend]
                    gt_tile = gt_tensor[:, :, ystart:yend, xstart:xend]
                    ssim_index += piq.ssim(pred_tile,
                                           gt_tile,
                                           data_range=1.,
                                           reduction='none').item()
                    # Accumulate like the other metrics (was `=`), so the
                    # average over tiles below is correct when n > 1.
                    msssim_index += piq.multi_scale_ssim(pred_tile,
                                                         gt_tile,
                                                         data_range=1.,
                                                         reduction='none').item()
                    lpipsVGG += self.lpipsVGG(pred_tile, gt_tile).item()
                    lpipsAlex += self.lpipsAlex(pred_tile, gt_tile).item()
            num_tiles = n * n
            msssim_index /= num_tiles
            ssim_index /= num_tiles
            lpipsVGG /= num_tiles
            lpipsAlex /= num_tiles

            rmse = ((gt - pred) ** 2).mean() ** 0.5
        return {
            'rmse': rmse,
            'psnr': psnr_index,
            'ssim': ssim_index,
            'msssim': msssim_index,
            'lpipsVGG': lpipsVGG,
            'lpipsAlex': lpipsAlex
        }
# Example #18
# 0
def test_multi_scale_ssim_supports_different_data_ranges(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], data_range,
        device: str) -> None:
    """uint8 data scored with ``data_range`` must match the same data divided down to [0, 1]."""
    prediction, target = prediction_target_4d_5d
    quantized = [(t * data_range).type(torch.uint8).to(device) for t in (prediction, target)]

    divisor = float(data_range)
    measure_scaled = multi_scale_ssim(*quantized, data_range=data_range)
    measure = multi_scale_ssim(*(t / divisor for t in quantized), data_range=1.0)
    diff = (measure_scaled - measure).abs()
    assert (diff <= 1e-6).all(), f'Result for same tensor with different data_range should be the same, got {diff}'
# Example #19
# 0
    def eval(self, gt, pred, imformat='BHWC', dtype='jax'):
        """Compute the image-quality metrics listed in ``self.metrics``.

        Args:
            gt: reference image(s) laid out as described by ``imformat``.
            pred: predicted image(s), same layout as ``gt``.
            imformat: ``'BHWC'`` for a batch, or ``'HWC'`` for a single image
                (a leading batch axis is added).
            dtype: when ``'jax'``, inputs are first converted with ``np.array``;
                any other value hands them to ``torch.Tensor`` unchanged.

        Returns:
            dict with an entry for each requested metric among ``'ssim'``,
            ``'msssim'``, ``'lpipsVGG'``, ``'lpipsAlex'`` and ``'rmse'`` (a 0-dim
            tensor); ``'mse'`` (float) and ``'psnr'`` are always present.
            Both images are clamped to [0, 1] before any metric is computed.

        Raises:
            ValueError: if ``imformat`` is neither ``'BHWC'`` nor ``'HWC'``.
        """
        if dtype == 'jax':
            gt = np.array(gt)
            pred = np.array(pred)

        with torch.no_grad():
            if imformat == 'BHWC':
                gt_tensor = torch.Tensor(gt).permute(0, 3, 1, 2).to(self.device)
                pred_tensor = torch.Tensor(pred).permute(0, 3, 1, 2).to(self.device)
            elif imformat == 'HWC':
                gt_tensor = torch.Tensor(gt[None, ...]).permute(0, 3, 1, 2).to(self.device)
                pred_tensor = torch.Tensor(pred[None, ...]).permute(0, 3, 1, 2).to(self.device)
            else:
                # Previously printed a message and called exit(0), which ended the
                # whole process with a *success* status.
                raise ValueError(f'Unknown image dimension format: {imformat!r}')

            pred_tensor = torch.clamp(pred_tensor, 0, 1)
            gt_tensor = torch.clamp(gt_tensor, 0, 1)
            _, _, h, w = gt_tensor.shape

            lpipsAlex = 0
            lpipsVGG = 0
            msssim_index = 0
            ssim_index = 0
            # Metrics are accumulated over an n x n grid of tiles, then averaged.
            n = 1
            for i in range(n):
                for j in range(n):
                    xstart = w // n * j
                    ystart = h // n * i
                    xend = w // n * (j + 1)
                    yend = h // n * (i + 1)
                    pred_tile = pred_tensor[:, :, ystart:yend, xstart:xend]
                    gt_tile = gt_tensor[:, :, ystart:yend, xstart:xend]
                    if 'ssim' in self.metrics:
                        ssim_index += piq.ssim(pred_tile, gt_tile, data_range=1., reduction='mean').item()
                    if 'msssim' in self.metrics:
                        # Accumulate (was `=`), so tiles are averaged correctly when n > 1.
                        msssim_index += piq.multi_scale_ssim(pred_tile, gt_tile, data_range=1., reduction='mean').item()
                    if 'lpipsVGG' in self.metrics:
                        lpipsVGG += self.lpipsVGG(pred_tile, gt_tile).item()
                    if 'lpipsAlex' in self.metrics:
                        lpipsAlex += self.lpipsAlex(pred_tile, gt_tile).item()

            # Parenthesized divisor: `x / n*n` parsed as `(x / n) * n`.
            num_tiles = n * n
            res = {}
            if 'ssim' in self.metrics:
                res['ssim'] = ssim_index / num_tiles
            if 'msssim' in self.metrics:
                res['msssim'] = msssim_index / num_tiles
            if 'lpipsVGG' in self.metrics:
                res['lpipsVGG'] = lpipsVGG / num_tiles
            if 'lpipsAlex' in self.metrics:
                res['lpipsAlex'] = lpipsAlex / num_tiles
            if 'rmse' in self.metrics:
                res['rmse'] = ((gt_tensor - pred_tensor) ** 2).mean() ** 0.5

            res['mse'] = float(((gt_tensor - pred_tensor) ** 2).mean().cpu().numpy())
            # PSNR for a [0, 1] data range; the former `/ np.log10(10.)` divided by 1.
            res['psnr'] = -10. * np.log10(res['mse'])

        return res
# Example #20
# 0
def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(x_y_4d_5d, device: str) -> None:
    """Mismatched input shapes and a 2-D ``scale_weights`` must raise AssertionError.

    Every shape in the Cartesian product of ``dims`` is tried as the first
    input; only the one equal to ``y``'s shape is expected to succeed.
    """
    # `x` was never bound in this test before, so the final scale_weights check
    # failed with NameError instead of AssertionError.
    x = x_y_4d_5d[0].to(device)
    y = x_y_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if y.dim() == 5:
        dims += [[2, 3]]
    for size in itertools.product(*dims):
        wrong_shape_x = torch.rand(size).to(y)
        if wrong_shape_x.size() == y.size():
            multi_scale_ssim(wrong_shape_x, y)
        else:
            with pytest.raises(AssertionError):
                multi_scale_ssim(wrong_shape_x, y)
    scale_weights = torch.rand(2, 2)
    with pytest.raises(AssertionError):
        multi_scale_ssim(x, y, scale_weights=scale_weights)
# Example #21
# 0
def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(
        prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor],
        device: str) -> None:
    """Mismatched input shapes and a 2-D ``scale_weights`` must raise AssertionError.

    Every shape in the Cartesian product of ``dims`` is tried as the
    prediction; only the one equal to ``target``'s shape is expected to succeed.
    """
    # `prediction` was never bound in this test before, so the final
    # scale_weights check failed with NameError instead of AssertionError.
    prediction = prediction_target_4d_5d[0].to(device)
    target = prediction_target_4d_5d[1].to(device)
    dims = [[3], [2, 3], [161, 162], [161, 162]]
    if target.dim() == 5:
        dims += [[2, 3]]
    for size in itertools.product(*dims):
        wrong_shape_prediction = torch.rand(size).to(target)
        if wrong_shape_prediction.size() == target.size():
            multi_scale_ssim(wrong_shape_prediction, target)
        else:
            with pytest.raises(AssertionError):
                multi_scale_ssim(wrong_shape_prediction, target)
    scale_weights = torch.rand(2, 2)
    with pytest.raises(AssertionError):
        multi_scale_ssim(prediction, target, scale_weights=scale_weights)
# Example #22
# 0
 def forward(self, predict, target):
     """Compute every metric enabled on ``self`` between ``predict`` and ``target``.

     Flags read from ``self``: ``l1_norm``, ``mse``, ``pearsonr``, ``cc``,
     ``psnr``, ``ssim`` and ``mssim``. PSNR, SSIM and MS-SSIM use
     ``data_range=1.``; Pearson r and concordance CC are averaged with ``.mean()``.

     Returns:
         dict with keys ``'l1_norm'``, ``'mse'``, ``'pearsonr_metric'``, ``'cc'``,
         ``'psnr'``, ``'ssim'`` and ``'mssim'``. Disabled metrics map to ``None``.
     """
     # Pre-bind every result: previously a disabled flag left its local
     # unassigned and building the summary dict raised UnboundLocalError.
     l1_norm_metric = None
     mse_norm_metric = None
     pearsonr_metric = None
     cc_metric = None
     psnr_metric = None
     ssim_metric = None
     mssim_metric = None
     if self.l1_norm:
         l1_norm_metric = nn.functional.l1_loss(predict, target)
     if self.mse:
         mse_norm_metric = nn.functional.mse_loss(predict, target)
     if self.pearsonr:
         pearsonr_metric = audtorch.metrics.functional.pearsonr(predict, target).mean()
     if self.cc:
         cc_metric = audtorch.metrics.functional.concordance_cc(predict, target).mean()
     if self.psnr:
         psnr_metric = piq.psnr(predict, target, data_range=1., reduction='none').mean()
     if self.ssim:
         ssim_metric = piq.ssim(predict, target, data_range=1.)
     if self.mssim:
         mssim_metric = piq.multi_scale_ssim(predict, target, data_range=1.)
     return {
         'l1_norm': l1_norm_metric,
         'mse': mse_norm_metric,
         'pearsonr_metric': pearsonr_metric,
         'cc': cc_metric,
         'psnr': psnr_metric,
         'ssim': ssim_metric,
         'mssim': mssim_metric,
     }
# Example #23
# 0
def test_multi_scale_ssim_preserves_dtype(x, y, dtype, device: str) -> None:
    """The MS-SSIM result must have the same dtype as its inputs."""
    placement = dict(device=device, dtype=dtype)
    result = multi_scale_ssim(x.to(**placement), y.to(**placement))
    assert result.dtype == dtype
# Example #24
# 0
def main():
    """Compute and print a range of piq quality indices and losses for one image pair.

    ``x`` is read from 'tests/assets/i01_01_5.bmp' and ``y`` from
    'tests/assets/I01.BMP'. Both are converted to (1, C, H, W) float tensors
    in [0, 1] and moved to the GPU when one is available.
    """
    # Read an RGB image and its noisy version.
    # `[None, ...]` adds the leading batch dimension: piq metrics take
    # (N, C, H, W) input, and the original 3-D (C, H, W) tensors are rejected
    # by the input validation of many metrics.
    x = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0, 1)[None, ...] / 255.
    y = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1)[None, ...] / 255.

    if torch.cuda.is_available():
        # Move to GPU to make computations faster
        x = x.cuda()
        y = y.cuda()

    # To compute BRISQUE score as a measure, use lower case function from the library
    brisque_index: torch.Tensor = piq.brisque(x,
                                              data_range=1.,
                                              reduction='none')
    # In order to use BRISQUE as a loss function, use corresponding PyTorch module.
    # Note: the back propagation is not available using torch==1.5.0.
    # Update the environment with latest torch and torchvision.
    brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1.,
                                                 reduction='none')(x)
    print(
        f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}"
    )

    # To compute Content score as a loss function, use corresponding PyTorch module
    # By default VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust layers names accordingly. Features from different layers can be weighted differently.
    # Use weights parameter. See other options in class docstring.
    content_loss = piq.ContentLoss(feature_extractor="vgg16",
                                   layers=("relu3_3", ),
                                   reduction='none')(x, y)
    print(f"ContentLoss: {content_loss.item():0.4f}")

    # To compute DISTS as a loss function, use corresponding PyTorch module
    # By default input images are normalized with ImageNet statistics before forwarding through VGG16 model.
    # If there is no need to normalize the data, use mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0].
    dists_loss = piq.DISTS(reduction='none')(x, y)
    print(f"DISTS: {dists_loss.item():0.4f}")

    # To compute FSIM as a measure, use lower case function from the library
    fsim_index: torch.Tensor = piq.fsim(x, y, data_range=1., reduction='none')
    # In order to use FSIM as a loss function, use corresponding PyTorch module
    fsim_loss = piq.FSIMLoss(data_range=1., reduction='none')(x, y)
    print(
        f"FSIM index: {fsim_index.item():0.4f}, loss: {fsim_loss.item():0.4f}")

    # To compute GMSD as a measure, use lower case function from the library
    # This is port of MATLAB version from the authors of original paper.
    # In any case it should be minimized. Usually values of GMSD lie in [0, 0.35] interval.
    gmsd_index: torch.Tensor = piq.gmsd(x, y, data_range=1., reduction='none')
    # In order to use GMSD as a loss function, use corresponding PyTorch module:
    gmsd_loss: torch.Tensor = piq.GMSDLoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"GMSD index: {gmsd_index.item():0.4f}, loss: {gmsd_loss.item():0.4f}")

    # To compute HaarPSI as a measure, use lower case function from the library
    # This is port of MATLAB version from the authors of original paper.
    haarpsi_index: torch.Tensor = piq.haarpsi(x,
                                              y,
                                              data_range=1.,
                                              reduction='none')
    # In order to use HaarPSI as a loss function, use corresponding PyTorch module
    haarpsi_loss: torch.Tensor = piq.HaarPSILoss(data_range=1.,
                                                 reduction='none')(x, y)
    print(
        f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}"
    )

    # To compute LPIPS as a loss function, use corresponding PyTorch module
    lpips_loss: torch.Tensor = piq.LPIPS(reduction='none')(x, y)
    print(f"LPIPS: {lpips_loss.item():0.4f}")

    # To compute MDSI as a measure, use lower case function from the library
    mdsi_index: torch.Tensor = piq.mdsi(x, y, data_range=1., reduction='none')
    # In order to use MDSI as a loss function, use corresponding PyTorch module
    mdsi_loss: torch.Tensor = piq.MDSILoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"MDSI index: {mdsi_index.item():0.4f}, loss: {mdsi_loss.item():0.4f}")

    # To compute MS-SSIM index as a measure, use lower case function from the library:
    ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(x, y, data_range=1.)
    # In order to use MS-SSIM as a loss function, use corresponding PyTorch module:
    ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')(x,
                                                                           y)
    print(
        f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}"
    )

    # To compute Multi-Scale GMSD as a measure, use lower case function from the library
    # It can be used both as a measure and as a loss function. In any case it should be minimized.
    # By default scale weights are initialized with values from the paper.
    # You can change them by passing a list of 4 variables to scale_weights argument during initialization
    # Note that input tensors should contain images with height and width equal 2 ** number_of_scales + 1 at least.
    ms_gmsd_index: torch.Tensor = piq.multi_scale_gmsd(x,
                                                       y,
                                                       data_range=1.,
                                                       chromatic=True,
                                                       reduction='none')
    # In order to use Multi-Scale GMSD as a loss function, use corresponding PyTorch module
    ms_gmsd_loss: torch.Tensor = piq.MultiScaleGMSDLoss(chromatic=True,
                                                        data_range=1.,
                                                        reduction='none')(x, y)
    print(
        f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}"
    )

    # To compute PSNR as a measure, use lower case function from the library.
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    print(f"PSNR index: {psnr_index.item():0.4f}")

    # To compute PieAPP as a loss function, use corresponding PyTorch module:
    pieapp_loss: torch.Tensor = piq.PieAPP(reduction='none', stride=32)(x, y)
    print(f"PieAPP loss: {pieapp_loss.item():0.4f}")

    # To compute SSIM index as a measure, use lower case function from the library:
    ssim_index = piq.ssim(x, y, data_range=1.)
    # In order to use SSIM as a loss function, use corresponding PyTorch module:
    ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(x, y)
    print(
        f"SSIM index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

    # To compute Style score as a loss function, use corresponding PyTorch module:
    # By default VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust layers names accordingly. Features from different layers can be weighted differently.
    # Use weights parameter. See other options in class docstring.
    style_loss = piq.StyleLoss(feature_extractor="vgg16",
                               layers=("relu3_3", ))(x, y)
    print(f"Style: {style_loss.item():0.4f}")

    # To compute TV as a measure, use lower case function from the library:
    tv_index: torch.Tensor = piq.total_variation(x)
    # In order to use TV as a loss function, use corresponding PyTorch module:
    tv_loss: torch.Tensor = piq.TVLoss(reduction='none')(x)
    print(f"TV index: {tv_index.item():0.4f}, loss: {tv_loss.item():0.4f}")

    # To compute VIF as a measure, use lower case function from the library:
    vif_index: torch.Tensor = piq.vif_p(x, y, data_range=1.)
    # In order to use VIF as a loss function, use corresponding PyTorch class:
    vif_loss: torch.Tensor = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)(x, y)
    print(f"VIFp index: {vif_index.item():0.4f}, loss: {vif_loss.item():0.4f}")

    # To compute VSI score as a measure, use lower case function from the library:
    vsi_index: torch.Tensor = piq.vsi(x, y, data_range=1.)
    # In order to use VSI as a loss function, use corresponding PyTorch module:
    vsi_loss: torch.Tensor = piq.VSILoss(data_range=1.)(x, y)
    print(f"VSI index: {vsi_index.item():0.4f}, loss: {vsi_loss.item():0.4f}")