Example #1
def test_brisque_raises_if_wrong_reduction(x_grey: torch.Tensor, device: str) -> None:
    for mode in ['mean', 'sum', 'none']:
        brisque(x_grey.to(device), reduction=mode)

    for mode in [None, 'n', 2]:
        with pytest.raises(KeyError):
            brisque(x_grey.to(device), reduction=mode)
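These tests rely on pytest fixtures (x_grey, x_rgb, device, input_tensors) that are defined outside the snippets, typically in a conftest.py. A minimal sketch of what such fixtures might look like, assuming random inputs and a CPU/CUDA-parametrized device (the shapes and fixture bodies are assumptions, not the library's actual test setup):

# Hypothetical conftest.py sketch; fixture names match the tests, bodies are assumed.
import pytest
import torch


@pytest.fixture(scope='module')
def x_grey() -> torch.Tensor:
    # Random greyscale batch: (N, 1, H, W), values in [0, 1]
    return torch.rand(2, 1, 96, 96)


@pytest.fixture(scope='module')
def x_rgb() -> torch.Tensor:
    # Random RGB batch: (N, 3, H, W), values in [0, 1]
    return torch.rand(2, 3, 96, 96)


@pytest.fixture(params=['cpu', 'cuda'])
def device(request) -> str:
    if request.param == 'cuda' and not torch.cuda.is_available():
        pytest.skip('CUDA is not available')
    return request.param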
Example #2
def test_brisque_supports_different_data_ranges(x_rgb: torch.Tensor,
                                                data_range,
                                                device: str) -> None:
    x_scaled = (x_rgb * data_range).type(torch.uint8)
    loss_scaled = brisque(x_scaled.to(device), data_range=data_range)
    loss = brisque(x_scaled.to(device) / float(data_range), data_range=1.0)
    diff = torch.abs(loss_scaled - loss)
    assert diff <= 1e-5, f'Result for same tensor with different data_range should be the same, got {diff}'
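The data_range argument above is supplied by test parametrization that is not shown in the snippet; presumably something along these lines (the concrete values are an assumption):

# Hypothetical parametrization for the test above; the chosen values are assumed.
import pytest

@pytest.mark.parametrize('data_range', [128, 255])
def test_brisque_supports_different_data_ranges(x_rgb, data_range, device):
    ...  # body as shown above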
Example #3
def main():
    # Read an RGB image and its noisy version
    directory = 'test'
    im = torch.tensor(imread(directory + '/image.bmp')).permute(
        2, 0, 1)[None, ...] / 255.
    im1 = torch.tensor(imread(directory + '/image.jpg')).permute(
        2, 0, 1)[None, ...] / 255.
    #im2 = torch.tensor(imread(directory + '/image2.bmp')).permute(2, 0, 1)[None, ...] / 255.
    #im3 = torch.tensor(imread(directory + '/image3.bmp')).permute(2, 0, 1)[None, ...] / 255.
    #im4 = torch.tensor(imread(directory + '/image4.bmp')).permute(2, 0, 1)[None, ...] / 255.
    #im5 = torch.tensor(imread(directory + '/image5.bmp')).permute(2, 0, 1)[None, ...] / 255.

    im_number = 0
    for image in [im, im1]:  #, im2, im3, im4, im5]:
        print('im' + str(im_number))
        im_number += 1

        # To compute BRISQUE score as a measure, use lower case function from the library
        brisque_index: torch.Tensor = piq.brisque(image,
                                                  data_range=1.,
                                                  reduction='sum')
        # In order to use BRISQUE as a loss function, use corresponding PyTorch module.
        # Note: backpropagation is not available with torch==1.5.0.
        # Update the environment to the latest torch and torchvision.
        brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1.,
                                                     reduction='none')(image)
        print(
            f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}"
        )
        print(
            '-------------------------------------------------------------------------------------'
        )
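As listed, the script omits its imports and entry point. A minimal sketch of the surrounding boilerplate, assuming skimage is used for image loading (the exact import set is an assumption):

# Hypothetical header/footer for the script above; the import choices are assumed.
import torch
import piq
from skimage.io import imread  # any reader returning an HxWxC uint8 array works


if __name__ == '__main__':
    main()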
Example #4
def test_brisque_values_grey(device: str) -> None:
    img = imread('tests/assets/goldhill.gif')
    x_grey = torch.tensor(img).unsqueeze(0).unsqueeze(0)
    score = brisque(x_grey.to(device), reduction='none', data_range=255)
    score_baseline = BRISQUE().get_score(img)
    assert torch.isclose(score, torch.tensor(score_baseline).to(score), rtol=1e-3), \
        f'Expected values to be equal to baseline, got {score.item()} and {score_baseline}'
Example #5
def test_brisque_values_rgb(device: str) -> None:
    img = imread('tests/assets/I01.BMP')
    x_rgb = (torch.tensor(img, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0))
    score = brisque(x_rgb.to(device), reduction='none', data_range=255.)
    score_baseline = BRISQUE().get_score(x_rgb[0].permute(1, 2, 0).numpy()[..., ::-1])
    assert torch.isclose(score, torch.tensor(score_baseline).to(score), rtol=1e-3), \
        f'Expected values to be equal to baseline, got {score.item()} and {score_baseline}'
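The BRISQUE() baseline in the two tests above is a separate reference implementation, presumably the pybrisque package; a hedged sketch of the imports these snippets assume:

# Assumed imports for the baseline tests above; the baseline package choice is an assumption.
import torch
from skimage.io import imread
from brisque import BRISQUE  # third-party reference implementation (pybrisque)
from piq import brisque      # implementation under test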
Example #6
def main():
    # Read an RGB image and its noisy version
    x = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0,
                                                                  1) / 255.
    y = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1) / 255.

    if torch.cuda.is_available():
        # Move to GPU to make computations faster
        x = x.cuda()
        y = y.cuda()

    # To compute BRISQUE score as a measure, use lower case function from the library
    brisque_index: torch.Tensor = piq.brisque(x,
                                              data_range=1.,
                                              reduction='none')
    # In order to use BRISQUE as a loss function, use corresponding PyTorch module.
    # Note: backpropagation is not available with torch==1.5.0.
    # Update the environment to the latest torch and torchvision.
    brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1.,
                                                 reduction='none')(x)
    print(
        f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}"
    )

    # To compute Content score as a loss function, use corresponding PyTorch module
    # By default VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust layer names accordingly. Features from different layers can be weighted
    # differently via the weights parameter. See other options in the class docstring.
    content_loss = piq.ContentLoss(feature_extractor="vgg16",
                                   layers=("relu3_3", ),
                                   reduction='none')(x, y)
    print(f"ContentLoss: {content_loss.item():0.4f}")

    # To compute DISTS as a loss function, use corresponding PyTorch module
    # By default input images are normalized with ImageNet statistics before forwarding through VGG16 model.
    # If there is no need to normalize the data, use mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0].
    dists_loss = piq.DISTS(reduction='none')(x, y)
    print(f"DISTS: {dists_loss.item():0.4f}")

    # To compute FSIM as a measure, use lower case function from the library
    fsim_index: torch.Tensor = piq.fsim(x, y, data_range=1., reduction='none')
    # In order to use FSIM as a loss function, use corresponding PyTorch module
    fsim_loss = piq.FSIMLoss(data_range=1., reduction='none')(x, y)
    print(
        f"FSIM index: {fsim_index.item():0.4f}, loss: {fsim_loss.item():0.4f}")

    # To compute GMSD as a measure, use lower case function from the library
    # This is a port of the MATLAB version from the authors of the original paper.
    # In any case it should be minimized. Usually GMSD values lie in the [0, 0.35] interval.
    gmsd_index: torch.Tensor = piq.gmsd(x, y, data_range=1., reduction='none')
    # In order to use GMSD as a loss function, use corresponding PyTorch module:
    gmsd_loss: torch.Tensor = piq.GMSDLoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"GMSD index: {gmsd_index.item():0.4f}, loss: {gmsd_loss.item():0.4f}")

    # To compute HaarPSI as a measure, use lower case function from the library
    # This is a port of the MATLAB version from the authors of the original paper.
    haarpsi_index: torch.Tensor = piq.haarpsi(x,
                                              y,
                                              data_range=1.,
                                              reduction='none')
    # In order to use HaarPSI as a loss function, use corresponding PyTorch module
    haarpsi_loss: torch.Tensor = piq.HaarPSILoss(data_range=1.,
                                                 reduction='none')(x, y)
    print(
        f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}"
    )

    # To compute LPIPS as a loss function, use corresponding PyTorch module
    lpips_loss: torch.Tensor = piq.LPIPS(reduction='none')(x, y)
    print(f"LPIPS: {lpips_loss.item():0.4f}")

    # To compute MDSI as a measure, use lower case function from the library
    mdsi_index: torch.Tensor = piq.mdsi(x, y, data_range=1., reduction='none')
    # In order to use MDSI as a loss function, use corresponding PyTorch module
    mdsi_loss: torch.Tensor = piq.MDSILoss(data_range=1., reduction='none')(x,
                                                                            y)
    print(
        f"MDSI index: {mdsi_index.item():0.4f}, loss: {mdsi_loss.item():0.4f}")

    # To compute MS-SSIM index as a measure, use lower case function from the library:
    ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(x, y, data_range=1.)
    # In order to use MS-SSIM as a loss function, use corresponding PyTorch module:
    ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')(x,
                                                                           y)
    print(
        f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}"
    )

    # To compute Multi-Scale GMSD as a measure, use lower case function from the library
    # It can be used both as a measure and as a loss function. In any case it should be minimized.
    # By default, scale weights are initialized with values from the paper.
    # You can change them by passing a list of 4 values to the scale_weights argument during initialization.
    # Note that input tensors should contain images with height and width of at least 2 ** number_of_scales + 1.
    ms_gmsd_index: torch.Tensor = piq.multi_scale_gmsd(x,
                                                       y,
                                                       data_range=1.,
                                                       chromatic=True,
                                                       reduction='none')
    # In order to use Multi-Scale GMSD as a loss function, use corresponding PyTorch module
    ms_gmsd_loss: torch.Tensor = piq.MultiScaleGMSDLoss(chromatic=True,
                                                        data_range=1.,
                                                        reduction='none')(x, y)
    print(
        f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}"
    )

    # To compute PSNR as a measure, use lower case function from the library.
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    print(f"PSNR index: {psnr_index.item():0.4f}")

    # To compute PieAPP as a loss function, use corresponding PyTorch module:
    pieapp_loss: torch.Tensor = piq.PieAPP(reduction='none', stride=32)(x, y)
    print(f"PieAPP loss: {pieapp_loss.item():0.4f}")

    # To compute SSIM index as a measure, use lower case function from the library:
    ssim_index = piq.ssim(x, y, data_range=1.)
    # In order to use SSIM as a loss function, use corresponding PyTorch module:
    ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(x, y)
    print(
        f"SSIM index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

    # To compute Style score as a loss function, use corresponding PyTorch module:
    # By default VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust layer names accordingly. Features from different layers can be weighted
    # differently via the weights parameter. See other options in the class docstring.
    style_loss = piq.StyleLoss(feature_extractor="vgg16",
                               layers=("relu3_3", ))(x, y)
    print(f"Style: {style_loss.item():0.4f}")

    # To compute TV as a measure, use lower case function from the library:
    tv_index: torch.Tensor = piq.total_variation(x)
    # In order to use TV as a loss function, use corresponding PyTorch module:
    tv_loss: torch.Tensor = piq.TVLoss(reduction='none')(x)
    print(f"TV index: {tv_index.item():0.4f}, loss: {tv_loss.item():0.4f}")

    # To compute VIF as a measure, use lower case function from the library:
    vif_index: torch.Tensor = piq.vif_p(x, y, data_range=1.)
    # In order to use VIF as a loss function, use corresponding PyTorch class:
    vif_loss: torch.Tensor = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)(x, y)
    print(f"VIFp index: {vif_index.item():0.4f}, loss: {vif_loss.item():0.4f}")

    # To compute VSI score as a measure, use lower case function from the library:
    vsi_index: torch.Tensor = piq.vsi(x, y, data_range=1.)
    # In order to use VSI as a loss function, use corresponding PyTorch module:
    vsi_loss: torch.Tensor = piq.VSILoss(data_range=1.)(x, y)
    print(f"VSI index: {vsi_index.item():0.4f}, loss: {vsi_loss.item():0.4f}")
Example #7
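The fragment below begins mid-script: its imports and the dirname, extensions, badfile, and goodfile variables are defined earlier and not shown. A hedged sketch of that missing preamble (every concrete value here is an assumption):

# Hypothetical preamble for the fragment below; paths, extensions and file names are assumed.
import glob

import PIL.Image
from torchvision import transforms

from piq import brisque

dirname = 'images/'                          # directory to scan (placeholder)
extensions = ['jpg', 'jpeg', 'png', 'bmp']   # file types to collect (placeholder)
badfile = 'bad_files.txt'                    # files with BRISQUE > 70 are logged here (placeholder)
goodfile = 'good_files.txt'                  # files with BRISQUE < 20 are logged here (placeholder)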
files = []
for e in extensions:
    files.extend(glob.glob(dirname + '*.' + e))
print("Found {} image files in {}".format(len(files), dirname))

# Iterate through files
scores = []
confidences = []
counter = 0
for file in files:
    counter += 1
    # Score BRISQUE image quality metric
    img = PIL.Image.open(file)
    trans = transforms.ToTensor()
    # For high-RAM GPUs, you can uncomment the ".to('cuda')" call below for better performance
    gpu_img = trans(img)  # .to('cuda')
    brisk = brisque(gpu_img)
    scores.append(float(brisk))
    if brisk > 70:
        print("Found bad file {} with BRISQUE = {}".format(file, brisk))
        with open(badfile, 'a') as b:
            b.write("{}\n".format(file))
    elif brisk < 20:
        print("Found good file {} with BRISQUE = {}".format(file, brisk))
        with open(goodfile, 'a') as g:
            g.write("{}\n".format(file))
    # Scores in the [20, 70] range are left unclassified.

    # Run classifier
Example #8
def test_brisque_preserves_dtype(input_tensors: torch.Tensor, dtype,
                                 device: str) -> None:
    x, _ = input_tensors
    output = brisque(x.to(device=device, dtype=dtype))
    assert output.dtype == dtype
Example #9
def test_brisque_fails_for_incorrect_data_range(x_rgb: torch.Tensor,
                                                device: str) -> None:
    # Scale to [0, 255]
    x_scaled = (x_rgb * 255).type(torch.uint8)
    with pytest.raises(AssertionError):
        brisque(x_scaled.to(device), data_range=1.0)
Example #10
def test_brisque_for_special_cases(input: torch.Tensor, expectation: Any,
                                   device: str) -> None:
    with expectation:
        brisque(input.to(device), reduction='mean')
Example #11
def test_brisque_works_with_rgb(x_rgb, device: str) -> None:
    brisque(x_rgb.to(device))
Example #12
def test_brisque_works_with_grey(x_grey: torch.Tensor, device: str) -> None:
    brisque(x_grey.to(device))
Example #13
def test_brisque_if_works_with_rgb(prediction_rgb, device: str) -> None:
    brisque(prediction_rgb.to(device))
Example #14
def test_brisque_if_works_with_grey(prediction_grey: torch.Tensor,
                                    device: str) -> None:
    brisque(prediction_grey.to(device))