예제 #1
0
def test_mutual_information():
    """Check that un-normalised mutual information drops when noise is added.

    The score of the camera image against itself must be strictly greater
    than its score against a copy passed through ``add_noise``.
    """
    reference = camera()
    degraded = add_noise(camera())

    self_score = mutual_information(reference, reference, normalised=False)
    noisy_score = mutual_information(
        reference,
        degraded,
        normalised=False,
    )

    assert self_score > noisy_score
예제 #2
0
def test_normalised_mutual_information():
    """Check the normalised mutual information of identical and noisy pairs.

    An image compared with itself (clean or noisy) is expected to score
    approximately 1, and a clean image compared with its noisy copy is
    expected to score strictly below 1.
    """
    camera_image = camera()
    camera_image_with_noise = add_noise(camera())

    # The value must be *compared* to the approx object. Writing
    # `assert pytest.approx(value, 1)` only evaluates the approx object in a
    # boolean context (with 1 taken as the relative tolerance), so the
    # score itself is never checked.
    assert mutual_information(
        camera_image, camera_image, normalised=True) == pytest.approx(1)
    assert mutual_information(
        camera_image_with_noise,
        camera_image_with_noise,
        normalised=True,
    ) == pytest.approx(1)

    assert mutual_information(
        camera_image, camera_image_with_noise, normalised=True) < 1
예제 #3
0
def benchmark_on_image(run_name, folder, image_name, image, methods):
    """Degrade ``image``, restore it with every method, and record scores.

    The image is normalised, blurred and corrupted with noise. The ground
    truth and the degraded versions are saved below ``folder``. Every
    callable in ``methods`` is then applied to the noisy blurred image
    (or its cached result is reused), and PSNR, SSIM, mutual information and
    spectral mutual information against the ground truth are printed and
    appended to ``scores_{run_name}.tsv``. Training/inference times of
    freshly computed restorations are appended to ``timming_{run_name}.tsv``.

    Parameters
    ----------
    run_name : str
        Label used in the names of the timing, score and restored files.
    folder : str
        Output folder. The sub-folders ``gt_numpy``, ``blurrynoisy_numpy``,
        ``blurry``, ``blurrynoisy``, ``restored_cache_numpy`` and
        ``restored`` are written to; this function does not create them
        itself, so they are assumed to exist already.
    image_name : str
        Base name used for every file produced for this image.
    image : array-like
        Ground-truth image; must provide ``astype`` and ``ndim``. A
        3-dimensional image is passed to the blur step as multi-channel.
    methods : sequence of callables
        Each is called as ``restore(noisy_blurred_image, psf_kernel)`` and
        must return ``(restored_image, train_time, inf_time)``. Its
        ``__name__`` is used in file names and reports. The sequence is
        iterated more than once, so it must not be a one-shot iterator.
    """

    def printscore(header, val1, val2, val3, val4):
        print(
            f"{header}: \t {val1:.4f} \t {val2:.4f} \t {val3:.4f} \t {val4:.4f}"
        )

    def restored_cache_path(restore):
        # Single definition of where a method's restored array is cached, so
        # the restoration loop and the scoring loop cannot drift apart.
        return join(
            join(folder, 'restored_cache_numpy'),
            f'{run_name}_{restore.__name__}_' + image_name + '.npy')

    image = normalise(image.astype(numpy.float32))

    gt_numpy_filepath = join(join(folder, 'gt_numpy'),
                             f'{image_name}' + '.npy')
    numpy.save(gt_numpy_filepath, image)

    blurred_image, psf_kernel = add_microscope_blur_2d(
        image, multi_channel=image.ndim == 3)

    noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                     alpha=0.001,
                                                     sigma=0.1,
                                                     sap=0.01,
                                                     quant_bits=10)

    blurrynoisy_numpy_filepath = join(join(folder, 'blurrynoisy_numpy'),
                                      f'{image_name}' + '.npy')
    numpy.save(blurrynoisy_numpy_filepath, noisy_blurred_image)

    blurry_filepath = join(join(folder, 'blurry'), image_name)
    save_png(blurry_filepath, blurred_image)

    blurrynoisy_filepath = join(join(folder, 'blurrynoisy'), image_name)
    save_png(blurrynoisy_filepath, noisy_blurred_image)

    # We restore the images with all methods:
    with open(join(folder, f"timming_{run_name}.tsv"), "a") as timming_file:

        for restore in methods:

            restored_cached_filepath = restored_cache_path(restore)

            if exists(restored_cached_filepath):
                print(
                    f"File: {restored_cached_filepath} does exists: skipping restoration."
                )
                restored_image = numpy.load(restored_cached_filepath)
            else:
                print(
                    f"File: {restored_cached_filepath} does not exists, restoration started."
                )
                restored_image, train_time, inf_time = restore(
                    noisy_blurred_image, psf_kernel)
                numpy.save(restored_cached_filepath, restored_image)
                # Timings are only known (and only logged) for fresh runs.
                timming_file.write(
                    f"{image_name}\t{restore.__name__}\t{train_time}\t{inf_time}\n"
                )

            restored_filepath = join(
                join(folder, 'restored'),
                f'{run_name}_{restore.__name__}_' + image_name)
            save_png(restored_filepath, restored_image)

    # We compute scores:
    with open(join(folder, f"scores_{run_name}.tsv"), "a") as scores_file:

        blurred_psnr_value = psnr(image, blurred_image)
        blurred_ssim_value = ssim(image, blurred_image)
        blurred_mi_value = mutual_information(image, blurred_image)
        blurred_smi_value = spectral_mutual_information(image, blurred_image)

        noisy_blurred_psnr_value = psnr(image, noisy_blurred_image)
        noisy_blurred_ssim_value = ssim(image, noisy_blurred_image)
        noisy_blurred_mi_value = mutual_information(image, noisy_blurred_image)
        noisy_blurred_smi_value = spectral_mutual_information(
            image, noisy_blurred_image)

        scores_file.write(
            f"{image_name}\tblurry\t{blurred_psnr_value}\t{blurred_ssim_value}\t{blurred_mi_value}\t{blurred_smi_value}\n"
        )
        scores_file.write(
            f"{image_name}\tnoisy&blurred\t{noisy_blurred_psnr_value}\t{noisy_blurred_ssim_value}\t{noisy_blurred_mi_value}\t{noisy_blurred_smi_value}\n"
        )

        # The legend lists the columns in the order the values are actually
        # printed and written below: PSNR, SSIM, MI, spectral MI.
        print(
            "Below in order: PSNR, SSIM, norm mutual info, norm spectral mutual info: "
        )
        printscore(
            "blurry image                       \t\t: ",
            blurred_psnr_value,
            blurred_ssim_value,
            blurred_mi_value,
            blurred_smi_value,
        )

        printscore(
            "noisy and blurry image             \t\t: ",
            noisy_blurred_psnr_value,
            noisy_blurred_ssim_value,
            noisy_blurred_mi_value,
            noisy_blurred_smi_value,
        )

        for restore in methods:
            # Scores are computed from the array stored in the cache file.
            restored_image = numpy.load(restored_cache_path(restore))

            psnr_value = psnr(image, restored_image)
            ssim_value = ssim(image, restored_image)
            mi_value = mutual_information(image, restored_image)
            smi_value = spectral_mutual_information(image, restored_image)

            printscore(f"restored with {restore.__name__}  \t\t: ", psnr_value,
                       ssim_value, mi_value, smi_value)

            scores_file.write(
                f"{image_name}\t{restore.__name__}\t{psnr_value}\t{ssim_value}\t{mi_value}\t{smi_value}\n"
            )
예제 #4
0
def demo(image_clipped):
    """Train and apply SSI deconvolution ten times on one degraded image.

    The input is normalised, blurred and corrupted with noise once. A fresh
    ``SSIDeconvolution`` model is then trained on, and applied to, that
    degraded image in each of ten runs. Every run prints its timings and
    scores (PSNR, spectral mutual information, mutual information, SSIM) and
    writes TIFF files into a folder named after the run index (``./0`` to
    ``./9``).

    NOTE(review): ``printscore`` is not defined inside this function; it is
    assumed to be available from the enclosing module.

    Parameters
    ----------
    image_clipped : array-like
        Ground-truth image; must provide ``astype``.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred, psf = add_microscope_blur_2d(image_clipped)
    degraded = add_poisson_gaussian_noise(
        blurred,
        alpha=0.001,
        sigma=0.1,
        sap=0.01,
        quant_bits=10,
    )

    # Identical settings for every run; only the optimisation outcome varies.
    deconv_settings = dict(
        max_epochs=3000,
        patience=300,
        batch_size=8,
        learning_rate=0.01,
        normaliser_type='identity',
        psf_kernel=psf,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss='l2',
    )

    for run_index in range(10):
        model = SSIDeconvolution(**deconv_settings)

        tic = time.time()
        model.train(degraded)
        toc = time.time()
        print(f"Training: elapsed time:  {toc - tic} ")

        tic = time.time()
        restored = model.translate(degraded)
        toc = time.time()
        print(f"inference: elapsed time:  {toc - tic} ")

        image_clipped = numpy.clip(image_clipped, 0, 1)
        restored_clipped = numpy.clip(restored, 0, 1)

        printscore(
            "ssi deconv            : ",
            psnr(image_clipped, restored_clipped),
            spectral_mutual_information(image_clipped, restored_clipped),
            mutual_information(image_clipped, restored_clipped),
            ssim(image_clipped, restored_clipped),
        )

        print(
            "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
        )
        print(
            "      The training is done on the same exact image that we infer on, very few pixels..."
        )
        print("      Training should be more stable given more data...")

        # One output folder per run, named after the run index.
        run_dir = Path(".") / str(run_index)
        run_dir.mkdir(exist_ok=True)

        imwrite(str(run_dir / 'noisyblurred.tif'), degraded)
        imwrite(str(run_dir / 'ssi_deconvolved_clipped_image.tif'),
                restored_clipped)
        imwrite(str(run_dir / 'ssi_deconvolved_image.tif'), restored)
예제 #5
0
def demo(image_clipped):
    """Compare LR deconvolution and SSI deconvolution on a degraded 2D image.

    The input is normalised, blurred and corrupted with noise. It is then
    restored with ``ImageTranslatorLRDeconv`` at 2, 5, 10 and 20 iterations
    and with ``SSIDeconvolution``. Scores against the ground truth are
    printed (PSNR, spectral mutual information, mutual information, SSIM)
    and all images are shown in a napari viewer.

    NOTE(review): ``printscore`` is not defined inside this function; it is
    assumed to be available from the enclosing module.

    Parameters
    ----------
    image_clipped : array-like
        Ground-truth image; must provide ``astype``.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred_image, psf_kernel = add_microscope_blur_2d(image_clipped)
    noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                     alpha=0.001,
                                                     sigma=0.1,
                                                     sap=0.01,
                                                     quant_bits=10)

    # Iteration counts for the LR baseline, paired with the exact row
    # headers used in the score table.
    lr_runs = (
        (2, "lr deconv (n=2)       :    "),
        (5, "lr deconv (n=5)       :    "),
        (10, "lr deconv (n=10)      :    "),
        (20, "lr deconv (n=20)      :    "),
    )

    lr = ImageTranslatorLRDeconv(psf_kernel=psf_kernel, backend="cupy")
    lr.train(noisy_blurred_image)
    lr_deconvolved_images = {}
    for num_iterations, _ in lr_runs:
        lr.max_num_iterations = num_iterations
        lr_deconvolved_images[num_iterations] = lr.translate(
            noisy_blurred_image)

    it_deconv = SSIDeconvolution(
        max_epochs=3000,
        patience=300,
        batch_size=8,
        learning_rate=0.01,
        normaliser_type='identity',
        psf_kernel=psf_kernel,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss='l2',
    )

    start = time.time()
    it_deconv.train(noisy_blurred_image)
    stop = time.time()
    print(f"Training: elapsed time:  {stop - start} ")

    start = time.time()
    deconvolved_image = it_deconv.translate(noisy_blurred_image)
    stop = time.time()
    print(f"inference: elapsed time:  {stop - start} ")

    # Restrict everything that gets scored to [0, 1].
    image_clipped = numpy.clip(image_clipped, 0, 1)
    lr_deconvolved_images_clipped = {
        num_iterations: numpy.clip(restored, 0, 1)
        for num_iterations, restored in lr_deconvolved_images.items()
    }
    deconvolved_image_clipped = numpy.clip(deconvolved_image, 0, 1)

    print(
        "Below in order: PSNR, norm spectral mutual info, norm mutual info, SSIM: "
    )
    printscore(
        "blurry image          :   ",
        psnr(image_clipped, blurred_image),
        spectral_mutual_information(image_clipped, blurred_image),
        mutual_information(image_clipped, blurred_image),
        ssim(image_clipped, blurred_image),
    )

    printscore(
        "noisy and blurry image:   ",
        psnr(image_clipped, noisy_blurred_image),
        spectral_mutual_information(image_clipped, noisy_blurred_image),
        mutual_information(image_clipped, noisy_blurred_image),
        ssim(image_clipped, noisy_blurred_image),
    )

    for num_iterations, header in lr_runs:
        lr_clipped = lr_deconvolved_images_clipped[num_iterations]
        printscore(
            header,
            psnr(image_clipped, lr_clipped),
            spectral_mutual_information(image_clipped, lr_clipped),
            mutual_information(image_clipped, lr_clipped),
            ssim(image_clipped, lr_clipped),
        )

    printscore(
        "ssi deconv            : ",
        psnr(image_clipped, deconvolved_image_clipped),
        spectral_mutual_information(image_clipped, deconvolved_image_clipped),
        mutual_information(image_clipped, deconvolved_image_clipped),
        ssim(image_clipped, deconvolved_image_clipped),
    )

    print(
        "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
    )
    print(
        "      The training is done on the same exact image that we infer on, very few pixels..."
    )
    print("      Training should be more stable given more data...")

    with napari.gui_qt():
        viewer = napari.Viewer()
        # The ground truth is `image_clipped`; no local named `image`
        # exists in this function.
        viewer.add_image(image_clipped, name='image')
        viewer.add_image(blurred_image, name='blurred')
        viewer.add_image(noisy_blurred_image, name='noisy_blurred_image')
        for num_iterations, _ in lr_runs:
            viewer.add_image(lr_deconvolved_images_clipped[num_iterations],
                             name=f'lr_deconvolved_image_{num_iterations}')
        viewer.add_image(deconvolved_image_clipped,
                         name='ssi_deconvolved_image')
예제 #6
0
def demo(image_clipped):
    """Compare LR (5 iterations) and SSI deconvolution on a degraded image.

    The input is normalised, blurred with ``add_microscope_blur_3d`` and
    corrupted with noise. It is then restored with
    ``ImageTranslatorLRDeconv`` (5 iterations) and with
    ``SSIDeconvolution``. A score table (PSNR, spectral mutual information,
    mutual information, SSIM) is printed, and the images are shown in napari
    when ``use_napari`` is truthy.

    NOTE(review): ``use_napari`` is not a parameter; it is read from the
    enclosing scope.

    Parameters
    ----------
    image_clipped : array-like
        Ground-truth image; must provide ``astype``.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred, psf = add_microscope_blur_3d(image_clipped)
    degraded = add_poisson_gaussian_noise(
        blurred,
        alpha=0.001,
        sigma=0.1,
        sap=0.01,
        quant_bits=10,
    )

    # LR baseline: only the 5-iteration variant is run here.
    lr = ImageTranslatorLRDeconv(psf_kernel=psf, backend="cupy")
    lr.train(degraded)
    lr.max_num_iterations = 5
    lr_restored_5 = lr.translate(degraded)

    ssi = SSIDeconvolution(
        max_epochs=3000,
        patience=300,
        batch_size=8,
        learning_rate=0.01,
        normaliser_type="identity",
        psf_kernel=psf,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss="l2",
    )

    tic = time.time()
    ssi.train(degraded)
    toc = time.time()
    print(f"Training: elapsed time:  {toc - tic} ")

    tic = time.time()
    ssi_restored = ssi.translate(degraded)
    toc = time.time()
    print(f"inference: elapsed time:  {toc - tic} ")

    image_clipped = numpy.clip(image_clipped, 0, 1)
    lr_restored_5_clipped = numpy.clip(lr_restored_5, 0, 1)
    ssi_restored_clipped = numpy.clip(ssi_restored, 0, 1)

    print_header(
        ["PSNR", "norm spectral mutual info", "norm mutual info", "SSIM"])

    # One table row per candidate, scored against the clipped ground truth.
    score_rows = (
        ("blurry image", blurred),
        ("noisy and blurry image", degraded),
        ("lr deconv (n=5)", lr_restored_5_clipped),
        ("ssi deconv", ssi_restored_clipped),
    )
    for label, candidate in score_rows:
        print_score(
            label,
            psnr(image_clipped, candidate),
            spectral_mutual_information(image_clipped, candidate),
            mutual_information(image_clipped, candidate),
            ssim(image_clipped, candidate),
        )

    print(
        "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
    )
    print(
        "      The training is done on the same exact image that we infer on, very few pixels..."
    )
    print("      Training should be more stable given more data...")

    if not use_napari:
        return

    with napari.gui_qt():
        viewer = napari.Viewer()
        layers = (
            ("image", image_clipped),
            ("blurred", blurred),
            ("noisy_blurred_image", degraded),
            ("lr_deconvolved_image_5", lr_restored_5_clipped),
            ("ssi_deconvolved_image", ssi_restored_clipped),
        )
        for layer_name, data in layers:
            viewer.add_image(data, name=layer_name)
예제 #7
0
def demo(
    image_clipped: np.ndarray,
    two_pass: bool = False,
    inv_mse_before_forward_model: bool = False,
    inv_mse_lambda: float = 2.0,
    learning_rate: float = 0.01,
    max_epochs: int = 3000,
    patience: int = 1000,
    masking_density: float = 0.01,
    training_noise: float = 0.1,
    output_dir: str = "demo_results",
    loss: str = "l2",
    check: bool = False,
    optimizer: str = "esadam",
    scheduler: str = "plateau",
    clip_before_psf: bool = True,
    fft_psf: Union[str, bool] = "auto",
    standardize: bool = False,
    amp: bool = False,
):
    """Run the configurable LR-vs-SSI deconvolution demo on a 2D image.

    The input is normalised, blurred and corrupted with noise. It is
    restored with ``ImageTranslatorLRDeconv`` at 2, 5, 10 and 20 iterations
    and with an ``SSIDeconvolution`` model built from the keyword arguments.
    A score table (PSNR, spectral mutual information, mutual information,
    SSIM) is printed. Unless ``check`` is true, timings and SSI scores are
    also stored in ``wandb.run.summary``. The images are shown in napari
    when ``use_napari`` is truthy; otherwise they are written as PNG files
    into ``output_dir``.

    NOTE(review): ``use_napari`` is not a parameter; it is read from the
    enclosing scope.

    Parameters
    ----------
    image_clipped : np.ndarray
        Ground-truth image.
    output_dir : str
        Folder for the PNG output; created (with parents) if missing.
    check : bool
        Forwarded to ``SSIDeconvolution``; when true, nothing is written to
        ``wandb.run.summary``.
    standardize : bool
        Forwarded to ``SSIDeconvolution`` as ``standardize_image``.
    two_pass, inv_mse_before_forward_model, inv_mse_lambda, learning_rate,
    max_epochs, patience, masking_density, training_noise, loss, optimizer,
    scheduler, clip_before_psf, fft_psf, amp
        Forwarded unchanged to ``SSIDeconvolution`` under the same names.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred, psf = add_microscope_blur_2d(image_clipped)
    degraded = add_poisson_gaussian_noise(
        blurred,
        alpha=0.001,
        sigma=0.1,
        sap=0.01,
        quant_bits=10,
    )

    # LR baseline at several iteration counts, all from one trained instance.
    lr = ImageTranslatorLRDeconv(psf_kernel=psf, backend="cupy")
    lr.train(degraded)
    lr_restored = {}
    for iterations in (2, 5, 10, 20):
        lr.max_num_iterations = iterations
        lr_restored[iterations] = lr.translate(degraded)

    ssi = SSIDeconvolution(
        max_epochs=max_epochs,
        patience=patience,
        batch_size=8,
        learning_rate=learning_rate,
        normaliser_type="identity",
        psf_kernel=psf,
        model_class=UNet,
        masking=True,
        masking_density=masking_density,
        training_noise=training_noise,
        loss=loss,
        two_pass=two_pass,
        inv_mse_before_forward_model=inv_mse_before_forward_model,
        inv_mse_lambda=inv_mse_lambda,
        check=check,
        optimizer=optimizer,
        scheduler=scheduler,
        clip_before_psf=clip_before_psf,
        fft_psf=fft_psf,
        standardize_image=standardize,
        amp=amp,
    )

    tic = time.time()
    ssi.train(degraded)
    toc = time.time()
    print(f"Training: elapsed time:  {toc - tic} ")

    if not check:
        wandb.run.summary["training_time"] = toc - tic

    tic = time.time()
    ssi_restored = ssi.translate(degraded)
    toc = time.time()
    print(f"inference: elapsed time:  {toc - tic} ")

    if not check:
        wandb.run.summary["inference_time"] = toc - tic

    image_clipped = numpy.clip(image_clipped, 0, 1)
    lr_2 = numpy.clip(lr_restored[2], 0, 1)
    lr_5 = numpy.clip(lr_restored[5], 0, 1)
    lr_10 = numpy.clip(lr_restored[10], 0, 1)
    lr_20 = numpy.clip(lr_restored[20], 0, 1)
    ssi_clipped = numpy.clip(ssi_restored, 0, 1)

    print_header(
        ["PSNR", "norm spectral mutual info", "norm mutual info", "SSIM"])

    # Baseline rows of the score table; the SSI row follows separately
    # because its values are also reported to wandb.
    baseline_rows = (
        ("blurry image", blurred),
        ("noisy and blurry image", degraded),
        ("lr deconv (n=2)", lr_2),
        ("lr deconv (n=5)", lr_5),
        ("lr deconv (n=10)", lr_10),
        ("lr deconv (n=20)", lr_20),
    )
    for label, candidate in baseline_rows:
        print_score(
            label,
            psnr(image_clipped, candidate),
            spectral_mutual_information(image_clipped, candidate),
            mutual_information(image_clipped, candidate),
            ssim(image_clipped, candidate),
        )

    ssi_psnr = psnr(image_clipped, ssi_clipped)
    ssi_smi = spectral_mutual_information(image_clipped, ssi_clipped)
    ssi_mi = mutual_information(image_clipped, ssi_clipped)
    ssi_ssim = ssim(image_clipped, ssi_clipped)

    if not check:
        wandb.run.summary["psnr"] = ssi_psnr
        wandb.run.summary["smi"] = ssi_smi
        wandb.run.summary["mi"] = ssi_mi
        wandb.run.summary["ssim"] = ssi_ssim

    print_score("ssi deconv", ssi_psnr, ssi_smi, ssi_mi, ssi_ssim)

    print(
        "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
    )
    print(
        "      The training is done on the same exact image that we infer on, very few pixels..."
    )
    print("      Training should be more stable given more data...")

    # (napari layer name, PNG file name, image), in display/write order.
    outputs = (
        ("image", "image.png", image_clipped),
        ("blurred", "blurred.png", blurred),
        ("noisy_blurred_image", "noisy_blurred_image.png", degraded),
        ("lr_deconvolved_image_2", "lr_deconvolved_image_2.png", lr_2),
        ("lr_deconvolved_image_5", "lr_deconvolved_image_5.png", lr_5),
        ("lr_deconvolved_image_10", "lr_deconvolved_image_10.png", lr_10),
        ("lr_deconvolved_image_20", "lr_deconvolved_image_20.png", lr_20),
        ("ssi_deconvolved_image", "ssi_deconvolved_image.png", ssi_clipped),
    )

    if use_napari:
        with napari.gui_qt():
            viewer = napari.Viewer()
            for layer_name, _, data in outputs:
                viewer.add_image(data, name=layer_name)
    else:
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        for _, file_name, data in outputs:
            imwrite(target_dir / file_name, data, format="png")