Example #1
def test_cycle_spinning_num_workers():
    img = astro_gray
    sigma = 0.1
    rstate = np.random.RandomState(1234)
    noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

    denoise_func = restoration.denoise_wavelet
    func_kw = dict(sigma=sigma, multichannel=True)

    # same result whether using 1 worker or multiple workers
    dn_cc1 = restoration.cycle_spin(noisy,
                                    denoise_func,
                                    max_shifts=1,
                                    func_kw=func_kw,
                                    multichannel=False,
                                    num_workers=1)
    dn_cc2 = restoration.cycle_spin(noisy,
                                    denoise_func,
                                    max_shifts=1,
                                    func_kw=func_kw,
                                    multichannel=False,
                                    num_workers=4)
    dn_cc3 = restoration.cycle_spin(noisy,
                                    denoise_func,
                                    max_shifts=1,
                                    func_kw=func_kw,
                                    multichannel=False,
                                    num_workers=None)
    assert_almost_equal(dn_cc1, dn_cc2)
    assert_almost_equal(dn_cc1, dn_cc3)
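Note: newer scikit-image releases replace the `multichannel` argument used in several of these excerpts with `channel_axis`. A minimal sketch of the same cycle-spin denoising call with the current API (the image choice and parameter values below are illustrative assumptions, not taken from the example above):

import numpy as np
from skimage import data, img_as_float, restoration
from skimage.metrics import peak_signal_noise_ratio

# Build a noisy RGB test image (astronaut is just an illustrative choice).
img = img_as_float(data.astronaut())
rng = np.random.default_rng(1234)
noisy = np.clip(img + 0.1 * rng.standard_normal(img.shape), 0, 1)

# cycle_spin averages denoise_wavelet over shifted copies of the input;
# channel_axis=-1 marks the last axis as the color channels.
func_kw = dict(channel_axis=-1, convert2ycbcr=True, rescale_sigma=True)
denoised = restoration.cycle_spin(noisy, func=restoration.denoise_wavelet,
                                  max_shifts=3, func_kw=func_kw,
                                  channel_axis=-1, num_workers=1)
print("PSNR:", peak_signal_noise_ratio(img, denoised))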
Example #2
def test_cycle_spinning_num_workers():
    img = astro_gray
    sigma = 0.1
    rstate = np.random.RandomState(1234)
    noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

    denoise_func = restoration.denoise_wavelet
    func_kw = dict(sigma=sigma, multichannel=True, rescale_sigma=True)

    # same results are expected whether using 1 worker or multiple workers
    with expected_warnings([PYWAVELET_ND_INDEXING_WARNING]):
        dn_cc1 = restoration.cycle_spin(noisy,
                                        denoise_func,
                                        max_shifts=1,
                                        func_kw=func_kw,
                                        multichannel=False,
                                        num_workers=1)
    with expected_warnings(
        [PYWAVELET_ND_INDEXING_WARNING, DASK_NOT_INSTALLED_WARNING]):
        dn_cc2 = restoration.cycle_spin(noisy,
                                        denoise_func,
                                        max_shifts=1,
                                        func_kw=func_kw,
                                        multichannel=False,
                                        num_workers=4)
        dn_cc3 = restoration.cycle_spin(noisy,
                                        denoise_func,
                                        max_shifts=1,
                                        func_kw=func_kw,
                                        multichannel=False,
                                        num_workers=None)
    assert_almost_equal(dn_cc1, dn_cc2)
    assert_almost_equal(dn_cc1, dn_cc3)
def test_cycle_spinning_num_workers_deprecated_multichannel():
    img = astro_gray[:32, :32]
    sigma = 0.1
    rstate = np.random.RandomState(1234)
    noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

    denoise_func = restoration.denoise_wavelet

    func_kw = dict(sigma=sigma, channel_axis=-1, rescale_sigma=True)

    mc_warn_str = "`multichannel` is a deprecated argument"

    # same results are expected whether using 1 worker or multiple workers
    with expected_warnings([mc_warn_str]):
        dn_cc1 = restoration.cycle_spin(noisy,
                                        denoise_func,
                                        max_shifts=1,
                                        func_kw=func_kw,
                                        multichannel=False,
                                        num_workers=1)

    if DASK_NOT_INSTALLED_WARNING is None:
        exp_warn = [mc_warn_str]
    else:
        exp_warn = [mc_warn_str, DASK_NOT_INSTALLED_WARNING]
    with expected_warnings(exp_warn):
        dn_cc2 = restoration.cycle_spin(noisy,
                                        denoise_func,
                                        max_shifts=1,
                                        func_kw=func_kw,
                                        multichannel=False,
                                        num_workers=2)
    assert_almost_equal(dn_cc1, dn_cc2)

    # providing multichannel argument positionally also warns
    mc_warn_str = "Providing the `multichannel` argument"
    if DASK_NOT_INSTALLED_WARNING is None:
        exp_warn = [mc_warn_str]
    else:
        exp_warn = [mc_warn_str, DASK_NOT_INSTALLED_WARNING]

    with expected_warnings(exp_warn):
        restoration.cycle_spin(noisy, denoise_func, 1, 1, None, False)
Example #4
def cspin():
    img = img_as_float(skimage.data.camera())
    imgO = img.copy()
    sigma = 0.1
    img = img + sigma * np.random.standard_normal(img.shape)
    imgN = img.copy()
    denoised = cycle_spin(img, func=denoise_wavelet, max_shifts=3)
    imgR = denoised.copy()

    return [imgO, imgN, imgR]
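A possible way to run the helper above and look at the three images it returns; the imports and matplotlib display below are assumptions added so the snippet can run on its own (cspin itself relies on these names being in scope):

import numpy as np
import matplotlib.pyplot as plt
import skimage.data
from skimage import img_as_float
from skimage.restoration import cycle_spin, denoise_wavelet

original, noisy, restored = cspin()
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax_, im, title in zip(axes, (original, noisy, restored),
                          ('original', 'noisy', 'cycle-spin denoised')):
    ax_.imshow(im, cmap='gray')
    ax_.set_title(title)
    ax_.axis('off')
plt.show()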
Example #5
def test_cycle_spinning_num_workers():
    img = astro_gray
    sigma = 0.1
    rstate = np.random.RandomState(1234)
    noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

    denoise_func = restoration.denoise_wavelet
    func_kw = dict(sigma=sigma, multichannel=True)

    # same result whether using 1 worker or multiple workers
    dn_cc1 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                    func_kw=func_kw, multichannel=False,
                                    num_workers=1)
    dn_cc2 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                    func_kw=func_kw, multichannel=False,
                                    num_workers=4)
    dn_cc3 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                    func_kw=func_kw, multichannel=False,
                                    num_workers=None)
    assert_almost_equal(dn_cc1, dn_cc2)
    assert_almost_equal(dn_cc1, dn_cc3)
Example #6
def test_cycle_spinning_num_workers():
    img = astro_gray
    sigma = 0.1
    rstate = np.random.default_rng(1234)
    noisy = img.copy() + 0.1 * rstate.standard_normal(img.shape)

    denoise_func = restoration.denoise_wavelet
    func_kw = dict(sigma=sigma, channel_axis=-1, rescale_sigma=True)

    # same results are expected whether using 1 worker or multiple workers
    dn_cc1 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                    func_kw=func_kw, channel_axis=None,
                                    num_workers=1)
    with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
        dn_cc2 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                        func_kw=func_kw, channel_axis=None,
                                        num_workers=4)
        dn_cc3 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                        func_kw=func_kw, channel_axis=None,
                                        num_workers=None)
    assert_array_almost_equal(dn_cc1, dn_cc2)
    assert_array_almost_equal(dn_cc1, dn_cc3)
Example #7
def test_cycle_spinning_num_workers():
    img = astro_gray
    sigma = 0.1
    rstate = np.random.RandomState(1234)
    noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

    denoise_func = restoration.denoise_wavelet
    func_kw = dict(sigma=sigma, multichannel=True)

    # same results are expected whether using 1 worker or multiple workers
    with expected_warnings([PYWAVELET_ND_INDEXING_WARNING]):
        dn_cc1 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                        func_kw=func_kw, multichannel=False,
                                        num_workers=1)
    with expected_warnings([PYWAVELET_ND_INDEXING_WARNING,
                            DASK_NOT_INSTALLED_WARNING]):
        dn_cc2 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                        func_kw=func_kw, multichannel=False,
                                        num_workers=4)
        dn_cc3 = restoration.cycle_spin(noisy, denoise_func, max_shifts=1,
                                        func_kw=func_kw, multichannel=False,
                                        num_workers=None)
    assert_almost_equal(dn_cc1, dn_cc2)
    assert_almost_equal(dn_cc1, dn_cc3)
Example #8
def test_cycle_spinning_multichannel(rescale_sigma):
    sigma = 0.1
    rstate = np.random.default_rng(1234)

    for channel_axis in -1, None:
        if channel_axis is not None:
            img = astro
            # can either omit or be 0 along the channels axis
            valid_shifts = [1, (0, 1), (1, 0), (1, 1), (1, 1, 0)]
            # can either omit or be 1 on channels axis.
            valid_steps = [1, 2, (1, 2), (1, 2, 1)]
            # too few or too many shifts or non-zero shift on channels
            invalid_shifts = [(1, 1, 2), (1, ), (1, 1, 0, 1)]
            # too few or too many shifts or any shifts <= 0
            invalid_steps = [(1, ), (1, 1, 1, 1), (0, 1), (-1, -1)]
        else:
            img = astro_gray
            valid_shifts = [1, (0, 1), (1, 0), (1, 1)]
            valid_steps = [1, 2, (1, 2)]
            invalid_shifts = [(1, 1, 2), (1, )]
            invalid_steps = [(1, ), (1, 1, 1), (0, 1), (-1, -1)]

        noisy = img.copy() + 0.1 * rstate.standard_normal(img.shape)

        denoise_func = restoration.denoise_wavelet
        func_kw = dict(sigma=sigma, channel_axis=channel_axis,
                       rescale_sigma=rescale_sigma)

        # max_shifts=0 is equivalent to just calling denoise_func
        with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
            dn_cc = restoration.cycle_spin(noisy, denoise_func, max_shifts=0,
                                           func_kw=func_kw,
                                           channel_axis=channel_axis)
            dn = denoise_func(noisy, **func_kw)
        assert_array_equal(dn, dn_cc)

        # denoising with cycle spinning will give better PSNR than without
        for max_shifts in valid_shifts:
            with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               channel_axis=channel_axis)
            psnr = peak_signal_noise_ratio(img, dn)
            psnr_cc = peak_signal_noise_ratio(img, dn_cc)
            assert psnr_cc > psnr

        for shift_steps in valid_steps:
            with expected_warnings([DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               channel_axis=channel_axis)
            psnr = peak_signal_noise_ratio(img, dn)
            psnr_cc = peak_signal_noise_ratio(img, dn_cc)
            assert psnr_cc > psnr

        for max_shifts in invalid_shifts:
            with pytest.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               channel_axis=channel_axis)
        for shift_steps in invalid_steps:
            with pytest.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               channel_axis=channel_axis)
Example #9
def test_cycle_spinning_multichannel(rescale_sigma):
    sigma = 0.1
    rstate = np.random.RandomState(1234)

    for multichannel in True, False:
        if multichannel:
            img = astro
            # can either omit or be 0 along the channels axis
            valid_shifts = [1, (0, 1), (1, 0), (1, 1), (1, 1, 0)]
            # can either omit or be 1 on channels axis.
            valid_steps = [1, 2, (1, 2), (1, 2, 1)]
            # too few or too many shifts or non-zero shift on channels
            invalid_shifts = [(1, 1, 2), (1, ), (1, 1, 0, 1)]
            # too few or too many shifts or any shifts <= 0
            invalid_steps = [(1, ), (1, 1, 1, 1), (0, 1), (-1, -1)]
        else:
            img = astro_gray
            valid_shifts = [1, (0, 1), (1, 0), (1, 1)]
            valid_steps = [1, 2, (1, 2)]
            invalid_shifts = [(1, 1, 2), (1, )]
            invalid_steps = [(1, ), (1, 1, 1), (0, 1), (-1, -1)]

        noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

        denoise_func = restoration.denoise_wavelet
        func_kw = dict(sigma=sigma,
                       multichannel=multichannel,
                       rescale_sigma=rescale_sigma)

        # max_shifts=0 is equivalent to just calling denoise_func
        with expected_warnings(
            [PYWAVELET_ND_INDEXING_WARNING, DASK_NOT_INSTALLED_WARNING]):
            dn_cc = restoration.cycle_spin(noisy,
                                           denoise_func,
                                           max_shifts=0,
                                           func_kw=func_kw,
                                           multichannel=multichannel)
            dn = denoise_func(noisy, **func_kw)
        assert_equal(dn, dn_cc)

        # denoising with cycle spinning will give better PSNR than without
        for max_shifts in valid_shifts:
            with expected_warnings(
                [PYWAVELET_ND_INDEXING_WARNING, DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
            psnr = peak_signal_noise_ratio(img, dn)
            psnr_cc = peak_signal_noise_ratio(img, dn_cc)
            assert_(psnr_cc > psnr)

        for shift_steps in valid_steps:
            with expected_warnings(
                [PYWAVELET_ND_INDEXING_WARNING, DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
            psnr = peak_signal_noise_ratio(img, dn)
            psnr_cc = peak_signal_noise_ratio(img, dn_cc)
            assert_(psnr_cc > psnr)

        for max_shifts in invalid_shifts:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
        for shift_steps in invalid_steps:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
Example #10
def test_cycle_spinning_multichannel():
    sigma = 0.1
    rstate = np.random.RandomState(1234)

    for multichannel in True, False:
        if multichannel:
            img = astro
            # can either omit or be 0 along the channels axis
            valid_shifts = [1, (0, 1), (1, 0), (1, 1), (1, 1, 0)]
            # can either omit or be 1 on channels axis.
            valid_steps = [1, 2, (1, 2), (1, 2, 1)]
            # too few or too many shifts or non-zero shift on channels
            invalid_shifts = [(1, 1, 2), (1, ), (1, 1, 0, 1)]
            # too few or too many shifts or any shifts <= 0
            invalid_steps = [(1, ), (1, 1, 1, 1), (0, 1), (-1, -1)]
        else:
            img = astro_gray
            valid_shifts = [1, (0, 1), (1, 0), (1, 1)]
            valid_steps = [1, 2, (1, 2)]
            invalid_shifts = [(1, 1, 2), (1, )]
            invalid_steps = [(1, ), (1, 1, 1), (0, 1), (-1, -1)]

        noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

        denoise_func = restoration.denoise_wavelet
        func_kw = dict(sigma=sigma, multichannel=multichannel)

        # max_shifts=0 is equivalent to just calling denoise_func
        with expected_warnings([PYWAVELET_ND_INDEXING_WARNING,
                                DASK_NOT_INSTALLED_WARNING]):
            dn_cc = restoration.cycle_spin(noisy, denoise_func, max_shifts=0,
                                           func_kw=func_kw,
                                           multichannel=multichannel)
            dn = denoise_func(noisy, **func_kw)
        assert_equal(dn, dn_cc)

        # denoising with cycle spinning will give better PSNR than without
        for max_shifts in valid_shifts:
            with expected_warnings([PYWAVELET_ND_INDEXING_WARNING,
                                    DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
            assert_(compare_psnr(img, dn_cc) > compare_psnr(img, dn))

        for shift_steps in valid_steps:
            with expected_warnings([PYWAVELET_ND_INDEXING_WARNING,
                                    DASK_NOT_INSTALLED_WARNING]):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
            assert_(compare_psnr(img, dn_cc) > compare_psnr(img, dn))

        for max_shifts in invalid_shifts:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
        for shift_steps in invalid_steps:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy, denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
def NoiseRemoval(X, algo=0):
    # AlgoToUse

    if algo == 0:
        # Applying the Gaussian filter
        X_g = sp.ndimage.gaussian_filter(X, 1)
        X_2 = X_g

    if algo == 1:
        # Applying the Median filter
        X_m = sp.ndimage.median_filter(X, 3)
        X_2 = X_m

    if algo == 2:
        # Applying the anisotropic filter
        X_an = anisodiff(X,
                         niter=5,
                         kappa=100,
                         gamma=0.15,
                         step=(1., 1.),
                         option=1,
                         ploton=False)
        X_2 = X_an

    if algo == 3:
        # Applying nonlocal means
        X_clip = X[30:180, 150:300]
        sigma_est = np.mean(estimate_sigma(X_clip, multichannel=False))
        print("estimated noise standard deviation = {}".format(sigma_est))
        patch_kw = dict(
            patch_size=5,  # 5x5 patches
            patch_distance=6,  # 13x13 search area
            multichannel=True)
        X_non = denoise_nl_means(X,
                                 h=0.8 * sigma_est,
                                 fast_mode=True,
                                 **patch_kw)

        X_2 = X_non

    if algo == 4:
        # Applying wavelet based
        # Repeat denoising with different amounts of cycle spinning.  e.g.
        # max_shift = 0 -> no cycle spinning
        # max_shift = 1 -> shifts of (0, 1) along each axis
        # max_shift = 3 -> shifts of (0, 1, 2, 3) along each axis
        # etc...
        X = X / 255
        denoise_kwargs = dict(multichannel=False,
                              convert2ycbcr=True,
                              wavelet='db1')
        max_shifts = [0, 1, 3, 5]
        s = 3
        X_w = cycle_spin(X,
                         func=denoise_wavelet,
                         max_shifts=s,
                         func_kw=denoise_kwargs,
                         multichannel=False)

        X_2 = X_w * 255
        X = X * 255

    return X_2
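A hypothetical call of NoiseRemoval above, using only the Gaussian branch (algo=0), which needs just scipy; the synthetic test image below is an assumption for illustration:

import numpy as np
import scipy as sp
import scipy.ndimage

rng = np.random.default_rng(0)
noisy = rng.normal(loc=0.5, scale=0.1, size=(128, 128))
smoothed = NoiseRemoval(noisy, algo=0)   # Gaussian-filter branch
print(smoothed.shape, smoothed.dtype)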
Example #12
def test_cycle_spinning_multichannel():
    sigma = 0.1
    rstate = np.random.RandomState(1234)

    for multichannel in True, False:
        if multichannel:
            img = astro
            # can either omit or be 0 along the channels axis
            valid_shifts = [1, (0, 1), (1, 0), (1, 1), (1, 1, 0)]
            # can either omit or be 1 on channels axis.
            valid_steps = [1, 2, (1, 2), (1, 2, 1)]
            # too few or too many shifts or non-zero shift on channels
            invalid_shifts = [(1, 1, 2), (1, ), (1, 1, 0, 1)]
            # too few or too many shifts or any shifts <= 0
            invalid_steps = [(1, ), (1, 1, 1, 1), (0, 1), (-1, -1)]
        else:
            img = astro_gray
            valid_shifts = [1, (0, 1), (1, 0), (1, 1)]
            valid_steps = [1, 2, (1, 2)]
            invalid_shifts = [(1, 1, 2), (1, )]
            invalid_steps = [(1, ), (1, 1, 1), (0, 1), (-1, -1)]

        noisy = img.copy() + 0.1 * rstate.randn(*(img.shape))

        denoise_func = restoration.denoise_wavelet
        func_kw = dict(sigma=sigma, multichannel=multichannel)

        # max_shifts=0 is equivalent to just calling denoise_func
        dn_cc = restoration.cycle_spin(noisy,
                                       denoise_func,
                                       max_shifts=0,
                                       func_kw=func_kw,
                                       multichannel=multichannel)
        dn = denoise_func(noisy, **func_kw)
        assert_equal(dn, dn_cc)

        # denoising with cycle spinning will give better PSNR than without
        for max_shifts in valid_shifts:
            dn_cc = restoration.cycle_spin(noisy,
                                           denoise_func,
                                           max_shifts=max_shifts,
                                           func_kw=func_kw,
                                           multichannel=multichannel)
            assert_(compare_psnr(img, dn_cc) > compare_psnr(img, dn))

        for shift_steps in valid_steps:
            dn_cc = restoration.cycle_spin(noisy,
                                           denoise_func,
                                           max_shifts=2,
                                           shift_steps=shift_steps,
                                           func_kw=func_kw,
                                           multichannel=multichannel)
            assert_(compare_psnr(img, dn_cc) > compare_psnr(img, dn))

        for max_shifts in invalid_shifts:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=max_shifts,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
        for shift_steps in invalid_steps:
            with testing.raises(ValueError):
                dn_cc = restoration.cycle_spin(noisy,
                                               denoise_func,
                                               max_shifts=2,
                                               shift_steps=shift_steps,
                                               func_kw=func_kw,
                                               multichannel=multichannel)
Example #13
# max_shift = 0 -> no cycle spinning
# max_shift = 1 -> shifts of (0, 1) along each axis
# max_shift = 3 -> shifts of (0, 1, 2, 3) along each axis
# etc...

denoise_kwargs = dict(channel_axis=-1,
                      convert2ycbcr=True,
                      wavelet='db1',
                      rescale_sigma=True)

all_psnr = []
max_shifts = [0, 1, 3, 5]
for n, s in enumerate(max_shifts):
    im_bayescs = cycle_spin(noisy,
                            func=denoise_wavelet,
                            max_shifts=s,
                            func_kw=denoise_kwargs,
                            channel_axis=-1)
    ax[n + 1].imshow(im_bayescs)
    ax[n + 1].axis('off')
    psnr = peak_signal_noise_ratio(original, im_bayescs)
    if s == 0:
        ax[n + 1].set_title(
            "Denoised: no cycle shifts\nPSNR={:0.4g}".format(psnr))
    else:
        ax[n + 1].set_title("Denoised: {0}x{0} shifts\nPSNR={1:0.4g}".format(
            s + 1, psnr))
    all_psnr.append(psnr)

# plot PSNR as a function of the degree of cycle shifting
ax[5].plot(max_shifts, all_psnr, 'k.-')
ax[0].set_title('Noisy\nPSNR={:0.4g}'.format(psnr_noisy))

# Repeat denoising with different amounts of cycle spinning.  e.g.
# max_shift = 0 -> no cycle spinning
# max_shift = 1 -> shifts of (0, 1) along each axis
# max_shift = 3 -> shifts of (0, 1, 2, 3) along each axis
# etc...

denoise_kwargs = dict(multichannel=True, convert2ycbcr=True, wavelet='db1')

all_psnr = []
max_shifts = [0, 1, 3, 5]
for n, s in enumerate(max_shifts):
    im_bayescs = cycle_spin(noisy,
                            func=denoise_wavelet,
                            max_shifts=s,
                            func_kw=denoise_kwargs,
                            multichannel=True)
    ax[n + 1].imshow(im_bayescs)
    ax[n + 1].axis('off')
    psnr = compare_psnr(original, im_bayescs)
    if s == 0:
        ax[n + 1].set_title(
            "Denoised: no cycle shifts\nPSNR={:0.4g}".format(psnr))
    else:
        ax[n + 1].set_title("Denoised: {0}x{0} shifts\nPSNR={1:0.4g}".format(
            s + 1, psnr))
    all_psnr.append(psnr)

# plot PSNR as a function of the degree of cycle shifting
ax[5].plot(max_shifts, all_psnr, 'k.-')
Example #15
ax[0].axis('off')
ax[0].set_title('Noisy\nPSNR={:0.4g}'.format(psnr_noisy))


# Repeat denoising with different amounts of cycle spinning.  e.g.
# max_shift = 0 -> no cycle spinning
# max_shift = 1 -> shifts of (0, 1) along each axis
# max_shift = 3 -> shifts of (0, 1, 2, 3) along each axis
# etc...

denoise_kwargs = dict(multichannel=True, convert2ycbcr=True, wavelet='db1')

all_psnr = []
max_shifts = [0, 1, 3, 5]
for n, s in enumerate(max_shifts):
    im_bayescs = cycle_spin(noisy, func=denoise_wavelet, max_shifts=s,
                            func_kw=denoise_kwargs, multichannel=True)
    ax[n+1].imshow(im_bayescs)
    ax[n+1].axis('off')
    psnr = compare_psnr(original, im_bayescs)
    if s == 0:
        ax[n+1].set_title(
            "Denoised: no cycle shifts\nPSNR={:0.4g}".format(psnr))
    else:
        ax[n+1].set_title(
            "Denoised: {0}x{0} shifts\nPSNR={1:0.4g}".format(s+1, psnr))
    all_psnr.append(psnr)

# plot PSNR as a function of the degree of cycle shifting
ax[5].plot(max_shifts, all_psnr, 'k.-')
ax[5].set_ylabel('PSNR (dB)')
ax[5].set_xlabel('max cycle shift along each axis')
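The gallery-style snippets above rely on names (original, noisy, psnr_noisy, ax) defined earlier in their source script and not shown here. A minimal sketch of the kind of setup they assume (the image choice, noise level and figure layout are assumptions):

import matplotlib.pyplot as plt
import numpy as np
from skimage import data, img_as_float
from skimage.metrics import peak_signal_noise_ratio

original = img_as_float(data.astronaut())
rng = np.random.default_rng(1234)
noisy = np.clip(original + 0.15 * rng.standard_normal(original.shape), 0, 1)
psnr_noisy = peak_signal_noise_ratio(original, noisy)

# 2x3 grid: one panel for the noisy input, four for denoised results, one for the PSNR plot
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
ax = axes.ravel()
ax[0].imshow(noisy)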
Example #16
from skimage import io

noisy_img = img_as_float(io.imread("images/MRI_images/MRI_noisy.tif"))
ref_img = img_as_float(io.imread("images/MRI_images/MRI_clean.tif"))

denoise_kwargs = dict(multichannel=False,
                      wavelet='db1',
                      method='BayesShrink',
                      rescale_sigma=True)

all_psnr = []
max_shifts = 3  #0, 1, 3, 5

Shft_inv_wavelet = cycle_spin(noisy_img,
                              func=denoise_wavelet,
                              max_shifts=max_shifts,
                              func_kw=denoise_kwargs,
                              multichannel=False)

noise_psnr = peak_signal_noise_ratio(ref_img, noisy_img)
shft_cleaned_psnr = peak_signal_noise_ratio(ref_img, Shft_inv_wavelet)
print("PSNR of input noisy image = ", noise_psnr)
print("PSNR of cleaned image = ", shft_cleaned_psnr)

plt.imsave("images/MRI_images/Shift_Inv_wavelet_smoothed.tif",
           Shft_inv_wavelet,
           cmap='gray')

##########################################################################
#Anisotropic Diffusion
Example #17
import time
from HPGe_Calibration import calibrate, fullCalibrate
from mpl_toolkits.axes_grid.anchored_artists import AnchoredText

import datetime
from skimage.restoration import denoise_wavelet, cycle_spin
from skimage import data, img_as_float
from skimage.util import random_noise

dataa = pd.read_csv("Data/LongTrinititeShielded.csv",
                    names=["Energy (KeV)", "Counts (a.u)"])

denoise_kwargs = dict(wavelet='db1', wavelet_levels=7, sigma=9e-14)

smoothed_data = cycle_spin(dataa["Counts (a.u)"],
                           func=denoise_wavelet,
                           max_shifts=15,
                           func_kw=denoise_kwargs)
# smoothed_data = denoise_wavelet(dataa["Counts (a.u)"], sigma=1.4e-13)
smoothedIntegral = np.max(smoothed_data)
originalIntegral = np.max(dataa["Counts (a.u)"])

ratio = smoothed_data / originalIntegral

plt.plot(smoothed_data, label="Smoothed")
plt.plot(dataa["Counts (a.u)"] * ratio * 700, label="Original")
plt.xlim([7100, 7600])
plt.ylim([0, 3e-18])
plt.legend()
plt.show()
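cycle_spin also works on 1-D signals such as the spectrum above. A self-contained sketch of the same idea on synthetic data, since the CSV file is not available here (the signal shape, noise level and wavelet settings are assumptions):

import numpy as np
from skimage.restoration import cycle_spin, denoise_wavelet

x = np.linspace(0, 50, 4096)
rng = np.random.default_rng(0)
# a single Gaussian peak plus white noise stands in for the measured spectrum
spectrum = np.exp(-0.5 * ((x - 25) / 0.5) ** 2) + 0.05 * rng.standard_normal(x.size)

smoothed = cycle_spin(spectrum,
                      func=denoise_wavelet,
                      max_shifts=15,
                      func_kw=dict(wavelet='db1', wavelet_levels=7))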
def Ondelette_raconte(NomDeLImage):
    image=su.PullFromSlicer(NomDeLImage)
    NumpyImage=sitk.GetArrayFromImage(image)
    max_lev = 2
    c = pywt.wavedec2(NumpyImage, 'db2', mode='zero',axes=(-2,-1), level=max_lev)
    c_arr,c_slices= pywt.coeffs_to_array(c, padding=0, axes=(-2,-1))
    aa=c_arr[c_slices[0]]   
    image_ondelette=sitk.GetImageFromArray(aa)
    image_ondelette.SetSpacing(image.GetSpacing())
    image_ondelette.SetDirection(image.GetDirection())
    image_ondelette.SetOrigin(image.GetOrigin())
    su.PushToSlicer(image_ondelette,'image_aa')




def SpatialFrequency(image):
    SizeMatrix=image.GetSize()
    Square_diff_x=0
    Square_diff_y=0
    Square_diff_z=0
    Nvoxel=0
    for x in range(SizeMatrix[0]-1):
        for y in range(SizeMatrix[1]-1):
            for z in range(SizeMatrix[2]-1):
                    Square_diff_x=Square_diff_x+(image.GetPixel(x+1,y,z)-image.GetPixel(x,y,z))**2
                    Square_diff_y=Square_diff_y+(image.GetPixel(x,y+1,z)-image.GetPixel(x,y,z))**2
                    Square_diff_z=Square_diff_z+(image.GetPixel(x,y,z+1)-image.GetPixel(x,y,z))**2
                    Nvoxel=Nvoxel+1
    SF=(Square_diff_x+Square_diff_y+Square_diff_z)**0.5 #for testing
    #SF=(Square_diff_x/(Nvoxel)+Square_diff_y/(Nvoxel)+Square_diff_z/(Nvoxel))**0.5
    return SF


def reechantillonage_translateOnly(image_ref, tranformation,MinimumImage):
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image_ref)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(MinimumImage)
    resampler.SetTransform(tranformation)
    ImageRecaler = resampler.Execute(image_ref)
    return ImageRecaler 

def SpatialFrequencyOptim(image):
    dimension = 3        
    offset =(1,0,0) # offset can be any vector-like data  
    translation_dx = sitk.TranslationTransform(dimension, offset)
    image_dx=reechantillonage_translateOnly(image,translation_dx,0)      
    offset =(0,1,0) # offset can be any vector-like data  
    translation_dy = sitk.TranslationTransform(dimension, offset)
    image_dy=reechantillonage_translateOnly(image,translation_dy,0)
    offset =(0,0,1) # offset can be any vector-like data  
    translation_dz = sitk.TranslationTransform(dimension, offset)
    image_dz=reechantillonage_translateOnly(image,translation_dz,0)
    imageSF=(image-image_dx)**2+(image-image_dy)**2+(image-image_dz)**2
    imageSF=sitk.SumProjection(imageSF,2)
    imageSF=sitk.SumProjection(imageSF,1)
    imageSF=sitk.SumProjection(imageSF,0)
    #SF=imageSF.GetPixel(0,0,0)**0.5 #for testing
    SF=(imageSF.GetPixel(0,0,0)/(image.GetSize()[0]*image.GetSize()[1]*image.GetSize()[2]))**0.5
    return SF

def SpatialFrequencyOptim2(matrix):
    sq_diff = 0.0
    size=matrix.shape
    dim=len(size)  
    for i in range(dim): #iterate over all image dimensions
        slc1 = [slice(None)]*dim
        slc1[i] = slice(0,size[i]-1)
        slc2 = [slice(None)]*dim
        slc2[i] = slice(1,size[i])
        sq_diff+= np.sum((matrix[tuple(slc2)]- matrix[tuple(slc1)])**2)
    return sq_diff/np.prod(size)
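A quick, illustrative check of SpatialFrequencyOptim2 on a small random volume (the synthetic array is an assumption; any n-D numpy array works):

import numpy as np

rng = np.random.default_rng(0)
print(SpatialFrequencyOptim2(rng.random((16, 16, 16))))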

        
dim=3

img = sitk.GaussianSource(outputPixelType=sitk.sitkUInt8, size=[128]*dim, sigma=[20]*dim, mean=[60]*dim)

sitk.Show(img)



res = SpatialFrequencyOptim(img)



NomDeLImage='FOUMA_1m'
Nlevel=2

def wavelet_denoising(NomDeLImage, Nlevel):
    image=su.PullFromSlicer(NomDeLImage)
    NumpyImage=sitk.GetArrayFromImage(image)
    max_lev = 6       # how many levels of decomposition to draw
    coeffs = pywt.wavedecn(NumpyImage, 'db2', mode='zero', level=max_lev) # see https://pywavelets.readthedocs.io/en/latest/ref/nd-dwt-and-idwt.html#pywt.wavedecn
    for i in range(Nlevel-max_lev):
        coeffs[(max_lev-i)] = {k: np.zeros_like(v) for k, v in coeffs[(max_lev-i)].items()} #remove highest frequency
        coeffs[-(max_lev-i)] = {k: np.zeros_like(v) for k, v in coeffs[-(max_lev-i)].items()} #remove highest frequency
    matrice_ondelette=pywt.waverecn(coeffs, 'db2') # mode: periodic or zero
    image_ondelette=sitk.GetImageFromArray(matrice_ondelette)
    image_ondelette.SetSpacing(image.GetSpacing())
    image_ondelette.SetDirection(image.GetDirection())
    image_ondelette.SetOrigin(image.GetOrigin())
    su.PushToSlicer(image_ondelette,'image_DenoisWave_level0-'+str(Nlevel))
  

###########another test

def maddest(d, axis=None):
    """
    Mean Absolute Deviation
    """
    return np.mean(np.absolute(d - np.mean(d, axis)), axis)

def high_pass_filter(x, low_cutoff=1000, sample_rate=sample_rate):
    """
    From @randxie https://github.com/randxie/Kaggle-VSB-Baseline/blob/master/src/utils/util_signal.py
    Modified to work with scipy version 1.1.0 which does not have the fs parameter
    """   
    # nyquist frequency is half the sample rate https://en.wikipedia.org/wiki/Nyquist_frequency
    nyquist = 0.5 * sample_rate
    norm_low_cutoff = low_cutoff / nyquist  
    # Fault pattern usually exists in high frequency band. According to literature, the pattern is visible above 10^4 Hz.
    # scipy version 1.2.0
    #sos = butter(10, low_freq, btype='hp', fs=sample_fs, output='sos')
    
    # scipy version 1.1.0
    sos = butter(10, Wn=[norm_low_cutoff], btype='highpass', output='sos')
    filtered_sig = signal.sosfilt(sos, x)
    return filtered_sig

def denoise_signal( x, wavelet='db4', level=1):
    """
    1. Adapted from waveletSmooth function found here:
    http://connor-johnson.com/2016/01/24/using-pywavelets-to-remove-high-frequency-noise/
    2. Threshold equation and using hard mode in threshold as mentioned
    in section '3.2 denoising based on optimized singular values' from paper by Tomas Vantuch:
    http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf
    """   
    # Decompose to get the wavelet coefficients
    coeff = pywt.wavedec( x, wavelet, mode="per" )   
    # Calculate sigma for threshold as defined in http://dspace.vsb.cz/bitstream/handle/10084/133114/VAN431_FEI_P1807_1801V001_2018.pdf
    # As noted by @harshit92 MAD referred to in the paper is Mean Absolute Deviation not Median Absolute Deviation
    sigma = (1/0.6745) * maddest( coeff[-level] )
    # Calculate the universal threshold
    uthresh = sigma * np.sqrt( 2*np.log( len( x ) ) )
    coeff[1:] = ( pywt.threshold( i, value=uthresh, mode='hard' ) for i in coeff[1:] )
    # Reconstruct the signal using the thresholded coefficients
    return pywt.waverec( coeff, wavelet, mode='per' )
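A hypothetical usage of denoise_signal above on a synthetic noisy sine (the signal itself is an assumption; numpy and pywt are imported here so the call is self-contained):

import numpy as np
import pywt

t = np.linspace(0, 1, 2048)
rng = np.random.default_rng(1)
sig = np.sin(2 * np.pi * 5 * t) + 0.3 * rng.standard_normal(t.size)
clean = denoise_signal(sig, wavelet='db4', level=1)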



def wavelet_denoising2(NomDeLImage, Nlevel):
    image=su.PullFromSlicer(NomDeLImage)
    NumpyImage=sitk.GetArrayFromImage(image)
    max_lev = 6       # how many levels of decomposition to draw
    coeffs = pywt.wavedecn(NumpyImage, 'db2', mode='zero', level=max_lev) # see https://pywavelets.readthedocs.io/en/latest/ref/nd-dwt-and-idwt.html#pywt.wavedecn
    for levelR in range(max_lev - Nlevel):
        # detail coefficients of an n-D decomposition are stored as a dict per level
        detail = coeffs[max_lev - levelR]
        sigma = (1 / 0.6745) * maddest(np.concatenate([v.ravel() for v in detail.values()]))
        uthresh = sigma * np.sqrt(2 * np.log(NumpyImage.size))
        coeffs[max_lev - levelR] = {k: pywt.threshold(v, value=uthresh, mode='hard')
                                    for k, v in detail.items()}
    matrice_ondelette = pywt.waverecn(coeffs, 'db2', mode='per') # mode: periodic or zero
    image_ondelette=sitk.GetImageFromArray(matrice_ondelette)
    image_ondelette.SetSpacing(image.GetSpacing())
    image_ondelette.SetDirection(image.GetDirection())
    image_ondelette.SetOrigin(image.GetOrigin())
    su.PushToSlicer(image_ondelette,'image_DenoisWave_level0-'+str(Nlevel))



from pip._internal import main as pip_main
pip_modules = ['scipy', 'sklearn', 'PyWavelets']
for module_ in pip_modules:
    try:
        module_obj = __import__(module_)
    except ImportError:
        logging.info("{0} was not found.\n Attempting to install {0} . . ."
                     .format(module_))
        pip_main(['install', module_])



pip_main(['install','scikit-image'])


from skimage.restoration import (denoise_wavelet, estimate_sigma)
from skimage import data, img_as_float
from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio


Nom_image="FOUMA_1m"

def denoising_BayesShrinkAndVIsuShrink(Nom_image):
    image=su.PullFromSlicer(Nom_image)
    NumpyImage=sitk.GetArrayFromImage(image)
    # Estimate the average noise standard deviation across color channels.
    sigma_est = estimate_sigma(NumpyImage, multichannel=True, average_sigmas=True)
    # Due to clipping in random_noise, the estimate will be a bit smaller than the
    # specified sigma.
    print(f"Estimated Gaussian noise standard deviation = {sigma_est}")
    im_bayes = denoise_wavelet(NumpyImage, multichannel=True, convert2ycbcr=True, method='BayesShrink', mode='soft',rescale_sigma=True)
    im_visushrink = denoise_wavelet(NumpyImage, multichannel=True, convert2ycbcr=True, method='VisuShrink', mode='soft',sigma=sigma_est, rescale_sigma=True)
    su.PushToSlicer(im_bayes,'image_DenoisWave_level0-'+str(Nlevel))
    su.PushToSlicer(im_visushrink,'image_DenoisWave_level0-'+str(Nlevel))
    # VisuShrink is designed to eliminate noise with high probability, but this
    # results in a visually over-smooth appearance.  Repeat, specifying a reduction
    # in the threshold by factors of 2 and 4.
    #im_visushrink2 = denoise_wavelet(NumpyImage, multichannel=True, convert2ycbcr=True, method='VisuShrink', mode='soft', sigma=sigma_est/2, rescale_sigma=True)
    #im_visushrink4 = denoise_wavelet(NumpyImage, multichannel=True, convert2ycbcr=True,method='VisuShrink', mode='soft', sigma=sigma_est/4, rescale_sigma=True)

# list all installed python packages
import pkg_resources
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(["%s==%s" % (i.key, i.version)
     for i in installed_packages])
print(installed_packages_list)

help('modules')  # to find their corresponding names



# coding: utf-8

import unittest
from slicer.ScriptedLoadableModule import *
import logging
from __main__ import vtk, qt, ctk, slicer
from math import *
import numpy as np
from vtk.util import numpy_support
import SimpleITK as sitk
import sitkUtils as su
import time
import datetime
import sys, time, os

import pywt # https://github.com/PyWavelets/pywt/blob/master/pywt/_multilevel.py

Nom_image="imSh_1"
Nom_label="FOUMA_1m-label"

def denoising_nonlocalmeans(Nom_image, Nom_label):
    image=su.PullFromSlicer(Nom_image)
    image=sitk.Shrink(image, [2,2,2])
    label=su.PullFromSlicer(Nom_label)
    timeRMR1 = time.time()
    DenoiseFilter=sitk.PatchBasedDenoisingImageFilter() #Execute (const Image &image1, double kernelBandwidthSigma, uint32_t patchRadius, 
    #uint32_t numberOfIterations, uint32_t numberOfSamplePatches, double sampleVariance, PatchBasedDenoisingImageFilter::NoiseModelType noiseModel, 
    #double noiseSigma, double noiseModelFidelityWeight, bool alwaysTreatComponentsAsEuclidean, bool kernelBandwidthEstimation, double kernelBandwidthMultiplicationFactor, 
    #uint32_t kernelBandwidthUpdateFrequency, double kernelBandwidthFractionPixelsForEstimation)
    DenoiseFilter.SetAlwaysTreatComponentsAsEuclidean(True)
    DenoiseFilter.SetKernelBandwidthEstimation(True)
    DenoiseFilter.SetKernelBandwidthFractionPixelsForEstimation(0.5) #double KernelBandwidthFractionPixelsForEstimation
    #DenoiseFilter.SetKernelBandwidthMultiplicationFactor() #(double KernelBandwidthMultiplicationFactor)
    #DenoiseFilter.SetKernelBandwidthSigma(400) #(double KernelBandwidthSigma)  # little or no influence
    #DenoiseFilter.SetKernelBandwidthUpdateFrequency() #(uint32_t KernelBandwidthUpdateFrequency, 1 by default)
    DenoiseFilter.SetNoiseModel(3) #(NoiseModelType NoiseModel) #NoiseModelType { NOMODEL:0, GAUSSIAN:1, RICIAN:2,  POISSON:3}
    DenoiseFilter.SetNoiseModelFidelityWeight(0.05) #(double NoiseModelFidelityWeight, between 0 and 1) # This weight controls the balance between the smoothing and the closeness to the noisy data.
    #DenoiseFilter.SetNoiseSigma(0.50) #(double NoiseSigma) # usually 5% of the image min-max range ############## no influence
    #DenoiseFilter.SetNumberOfIterations(1) #(uint32_t NumberOfIterations, 1 by default)
    DenoiseFilter.SetNumberOfSamplePatches(200) #(uint32_t NumberOfSamplePatches) # 200->100: 41 s down to 23 s but filters more
    DenoiseFilter.SetPatchRadius(4) #(uint32_t PatchRadius) # 2->10s 4->41s 6->121s ############## critical parameter
    #DenoiseFilter.SetSampleVariance(400) #(double SampleVariance) # no influence?
    ImageDenoised=DenoiseFilter.Execute(image)
    timeRMR2 = time.time()
    TimeForrunFunctionRMR2 = timeRMR2 - timeRMR1
    print(u"La fonction de traitement s'est executée en " + str(TimeForrunFunctionRMR2) +" secondes")
    print("\n")
    print(DenoiseFilter.GetNumberOfSamplePatches()) #200
    print("\n")
    print (DenoiseFilter.GetSampleVariance()) #400
    print("\n")
    print(DenoiseFilter.GetNoiseSigma()) #0.0
    print("\n")
    print(DenoiseFilter.GetNumberOfIterations()) #1
    print("\n")
    print(DenoiseFilter.GetKernelBandwidthSigma()) #400.0
    print("\n")
    stat_filter=sitk.LabelIntensityStatisticsImageFilter()
    stat_filter.Execute(label,image) # mind the argument order
    print(stat_filter.GetStandardDeviation(1)/stat_filter.GetMean(1))
    print("\n")
    stat_filter.Execute(label,ImageDenoised) # mind the argument order
    print(stat_filter.GetStandardDeviation(1)/stat_filter.GetMean(1))   
    su.PushToSlicer(ImageDenoised,'ImageDenoisedbyPatchBasedDenoisingImageFilter')


denoising_nonlocalmeans(Nom_image, Nom_label)


Nom_image="template"

def denoising_nonlocalmeans2(Nom_image):
    image=su.PullFromSlicer(Nom_image)
    Shrinkfactor=2
    image=sitk.Shrink(image, [Shrinkfactor,Shrinkfactor,Shrinkfactor])
    timeRMR1 = time.time()
    DenoiseFilter_init=sitk.PatchBasedDenoisingImageFilter() #Execute (const Image &image1, double kernelBandwidthSigma, uint32_t patchRadius, 
    #uint32_t numberOfIterations, uint32_t numberOfSamplePatches, double sampleVariance, PatchBasedDenoisingImageFilter::NoiseModelType noiseModel, 
    #double noiseSigma, double noiseModelFidelityWeight, bool alwaysTreatComponentsAsEuclidean, bool kernelBandwidthEstimation, double kernelBandwidthMultiplicationFactor, 
    #uint32_t kernelBandwidthUpdateFrequency, double kernelBandwidthFractionPixelsForEstimation)
    DenoiseFilter_init.SetAlwaysTreatComponentsAsEuclidean(True)
    DenoiseFilter_init.SetKernelBandwidthEstimation(True)
    DenoiseFilter_init.SetKernelBandwidthFractionPixelsForEstimation(0.5) #double KernelBandwidthFractionPixelsForEstimation
    #DenoiseFilter.SetKernelBandwidthMultiplicationFactor() #(double KernelBandwidthMultiplicationFactor)
    #DenoiseFilter.SetKernelBandwidthSigma(400) #(double KernelBandwidthSigma)  # little or no influence
    #DenoiseFilter.SetKernelBandwidthUpdateFrequency() #(uint32_t KernelBandwidthUpdateFrequency, 1 by default)
    DenoiseFilter_init.SetNoiseModel(3) #(NoiseModelType NoiseModel) #NoiseModelType { NOMODEL:0, GAUSSIAN:1, RICIAN:2,  POISSON:3}
    DenoiseFilter_init.SetNoiseModelFidelityWeight(0.05) #(double NoiseModelFidelityWeight, between 0 and 1) # This weight controls the balance between the smoothing and the closeness to the noisy data.
    #DenoiseFilter.SetNoiseSigma(0.50) #(double NoiseSigma) # usually 5% of the image min-max range ############## no influence
    #DenoiseFilter.SetNumberOfIterations(1) #(uint32_t NumberOfIterations, 1 by default)
    #DenoiseFilter.SetNumberOfSamplePatches(200) #(uint32_t NumberOfSamplePatches) # 200->100: 41 s down to 23 s but filters more
    DenoiseFilter_init.SetPatchRadius(2) #(uint32_t PatchRadius) # 2->10s 4->41s 6->121s ############## critical parameter
    #DenoiseFilter.SetSampleVariance(400) #(double SampleVariance) # no influence?
    ImageDenoised_init=DenoiseFilter_init.Execute(image)
    timeRMR2 = time.time()
    TimeForrunFunctionRMR2 = timeRMR2 - timeRMR1
    print(u"La fonction de traitement intiale s'est executée en " + str(TimeForrunFunctionRMR2) +" secondes")
    timeRMR1 = time.time()
    DenoiseFilter=sitk.PatchBasedDenoisingImageFilter() #Execute (const Image &image1, double kernelBandwidthSigma, uint32_t patchRadius, 
    #uint32_t numberOfIterations, uint32_t numberOfSamplePatches, double sampleVariance, PatchBasedDenoisingImageFilter::NoiseModelType noiseModel, 
    #double noiseSigma, double noiseModelFidelityWeight, bool alwaysTreatComponentsAsEuclidean, bool kernelBandwidthEstimation, double kernelBandwidthMultiplicationFactor, 
    #uint32_t kernelBandwidthUpdateFrequency, double kernelBandwidthFractionPixelsForEstimation)
    DenoiseFilter.SetAlwaysTreatComponentsAsEuclidean(True)
    DenoiseFilter.SetKernelBandwidthEstimation(False)
    #DenoiseFilter.SetKernelBandwidthFractionPixelsForEstimation(0.5) #double KernelBandwidthFractionPixelsForEstimation
    #DenoiseFilter.SetKernelBandwidthMultiplicationFactor() #(double KernelBandwidthMultiplicationFactor)
    DenoiseFilter.SetKernelBandwidthSigma(DenoiseFilter_init.GetKernelBandwidthSigma()) #(double KernelBandwidthSigma)  # little or no influence
    #DenoiseFilter.SetKernelBandwidthUpdateFrequency() #(uint32_t KernelBandwidthUpdateFrequency, 1 by default)
    DenoiseFilter.SetNoiseModel(3) #(NoiseModelType NoiseModel) #NoiseModelType { NOMODEL:0, GAUSSIAN:1, RICIAN:2,  POISSON:3}
    DenoiseFilter.SetNoiseModelFidelityWeight(0.05) #(double NoiseModelFidelityWeight, between 0 and 1) # This weight controls the balance between the smoothing and the closeness to the noisy data.
    DenoiseFilter.SetNoiseSigma(DenoiseFilter_init.GetNoiseSigma()) #(double NoiseSigma) # usually 5% of the image min-max range ############## no influence
    #DenoiseFilter.SetNumberOfIterations(1) #(uint32_t NumberOfIterations, 1 by default)
    DenoiseFilter.SetNumberOfSamplePatches(DenoiseFilter_init.GetNumberOfSamplePatches()) #(uint32_t NumberOfSamplePatches) # 200->100: 41 s down to 23 s but filters more
    DenoiseFilter.SetPatchRadius(DenoiseFilter_init.GetPatchRadius()*Shrinkfactor) #(uint32_t PatchRadius) # 2->10s 4->41s 6->121s ############## critical parameter
    DenoiseFilter.SetSampleVariance(DenoiseFilter_init.GetSampleVariance()) #(double SampleVariance) # no influence?
    ImageDenoised=DenoiseFilter.Execute(image)
    timeRMR2 = time.time()
    TimeForrunFunctionRMR2 = timeRMR2 - timeRMR1
    print(u"La fonction de traitement final s'est executée en " + str(TimeForrunFunctionRMR2) +" secondes")
    su.PushToSlicer(ImageDenoised,'ImageDenoisedbyPatchBasedDenoisingImageFilter')


denoising_nonlocalmeans2(Nom_image)


# coding: utf-8

import unittest
from slicer.ScriptedLoadableModule import *
import logging
from __main__ import vtk, qt, ctk, slicer
from math import *
import numpy as np
from vtk.util import numpy_support
import SimpleITK as sitk
import sitkUtils as su
import time
import datetime
import sys, time, os
from itertools import *
import six

import pywt # https://github.com/PyWavelets/pywt/blob/master/pywt/_multilevel.py

Nom_image="FOUMA_1m"

a=2.0
d=0.5
DecompMatrixSpacingfactor={
    'aad':[a,a,d], 
    'ada':[a,d,a], 
    'add':[a,d,d], 
    'daa':[d,a,a], 
    'dad':[d,a,d],
    'dda':[d,d,a],
    'ddd':[d,d,d],
}

def reechantillonage_identity(image_ref,image_to_transform):
    identity = sitk.TranslationTransform(3, (0,0,0))
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image_ref)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(identity)
    ImageRecaler = resampler.Execute(image_to_transform)
    return ImageRecaler 

def realspace_spacing(list_elem, Zoom):
    DMSF=np.asarray(DecompMatrixSpacingfactor[list_elem], dtype=np.float64)
    Zoom=np.asarray(Zoom, dtype=np.float64)
    RSS=DMSF*Zoom
    return RSS

def cropImagefctLabel(image, LowerBondingBox, UpperBondingBox  ):
  crop=sitk.CropImageFilter()
  image_cropper=crop.Execute(image, LowerBondingBox, UpperBondingBox  )
  return image_cropper


def CreateGaussianKernel(RS, matrice_spacing ):  #to modify to mono exponential
    imageGaussian=sitk.GaussianImageSource()
    imageGaussian.SetOutputPixelType(sitk.sitkUInt16)
    size=ceil(3*RS/matrice_spacing)
    if (size % 2)==0 :
        size=size+1
    imageGaussian.SetSize([size,size,size])   # physical extent = size*spacing
    sigma=(RS/2.35)/matrice_spacing
    imageGaussian.SetSigma([sigma,sigma,sigma])  # sigma = FWHM/2.35 (note: support is roughly 4.29*sigma)
    imageGaussian.SetMean([0,0,0])  # image centre = mean/spacing
    imageGaussian.SetScale(100)
    imageGaussian.SetOrigin([-((size-1)/2),-((size-1)/2),-((size-1)/2)])
    imageGaussian.SetSpacing([1,1,1])
    imageGaussian.SetDirection([1,0,0,0,1,0,0,0,1])
    kernel=imageGaussian.Execute()
    #su.PushToSlicer(kernel,"kernel",1)
    return kernel  

def RLdeconvolutionTV(image,kernel,alpha):
    ################initialisation#############
    RL=sitk.RichardsonLucyDeconvolutionImageFilter()
    laplacian= sitk.LaplacianImageFilter()
    normgradient=sitk.GradientMagnitudeImageFilter()
    divide=sitk.DivideImageFilter()
    multiply=sitk.MultiplyImageFilter()
    Substract=sitk.SubtractImageFilter()
    Cast=sitk.CastImageFilter()
    ##########regularisation term##############
    image_cast=sitk.Cast(image,sitk.sitkFloat64)
    L=laplacian.Execute(image_cast)
    NG=normgradient.Execute(image_cast)
    NG=sitk.Cast(NG,sitk.sitkFloat64)
    i1=divide.Execute(L, NG )
    i2=multiply.Execute( i1, alpha) # lambdaTV=0.02
    i3=Substract.Execute(1,i2)
    i4=divide.Execute(1,i3)
    ##############deconvolution#########
    Niteration=1
    Normalise=True
    BoundaryCondition=1 # ZeroFluxNeumannPad
    OutputRegionMode=0 #same
    image_cast=sitk.Cast(image,sitk.sitkUInt16)
    imagedecon=RL.Execute(image_cast,kernel,Niteration, Normalise, BoundaryCondition,OutputRegionMode)
    imagedecon=sitk.Cast(imagedecon,sitk.sitkFloat64)
    imagedeconRLTV=multiply.Execute(imagedecon,i4)
    return imagedeconRLTV


def SpatialFrequencyOptim2(matrix):
    sq_diff = 0.0
    size=matrix.shape
    dim=len(size)  
    for i in range(dim): #iterate over all image dimensions
        slc1 = [slice(None)]*dim
        slc1[i] = slice(0,size[i]-1)
        slc2 = [slice(None)]*dim
        slc2[i] = slice(1,size[i])
        sq_diff+= np.sum((matrix[tuple(slc2)]- matrix[tuple(slc1)])**2)
    return sq_diff/np.prod(size)

def denoising_nonlocalmeans(image, nom, radius, Niteration):
    timeRMR1 = time.time()
    #su.PushToSlicer(image,'image_Origine'+str(nom))
    DenoiseFilter=sitk.PatchBasedDenoisingImageFilter() 
    DenoiseFilter.SetAlwaysTreatComponentsAsEuclidean(True)
    DenoiseFilter.SetKernelBandwidthEstimation(True)
    #DenoiseFilter.SetKernelBandwidthFractionPixelsForEstimation(0.5) #double KernelBandwidthFractionPixelsForEstimation
    #DenoiseFilter.SetKernelBandwidthMultiplicationFactor() #(double KernelBandwidthMultiplicationFactor)
    #DenoiseFilter.SetKernelBandwidthSigma(400) #(double KernelBandwidthSigma)  # little or no influence
    #DenoiseFilter.SetKernelBandwidthUpdateFrequency() #(uint32_t KernelBandwidthUpdateFrequency, 1 by default)
    DenoiseFilter.SetNoiseModel(0) #(NoiseModelType NoiseModel) #NoiseModelType { NOMODEL:0, GAUSSIAN:1, RICIAN:2,  POISSON:3}
    DenoiseFilter.SetNoiseModelFidelityWeight(0.05) #(double NoiseModelFidelityWeight, between 0 and 1) # This weight controls the balance between the smoothing and the closeness to the noisy data.
    #DenoiseFilter.SetNoiseSigma(0.50) #(double NoiseSigma) # usually 5% of the image min-max range ############## no influence
    DenoiseFilter.SetNumberOfIterations(Niteration) #(uint32_t NumberOfIterations, 1 by default)
    #DenoiseFilter.SetNumberOfSamplePatches(200) #(uint32_t NumberOfSamplePatches) # 200->100: 41 s down to 23 s but filters more
    DenoiseFilter.SetPatchRadius(radius) #(uint32_t PatchRadius) # 2->10s 4->41s 6->121s ############## critical parameter
    #DenoiseFilter.SetSampleVariance(400) #(double SampleVariance) # no influence?
    ImageDenoised=DenoiseFilter.Execute(image)
    #su.PushToSlicer(ImageDenoised,'image_Origine_Denoised'+str(nom))
    timeRMR2 = time.time()
    TimeForrunFunctionRMR2 = timeRMR2 - timeRMR1
    print(u"    NLM-denoising of " + str(nom) +" matrix:")
    print(u"    Le rayon analyser est " + str(radius) +" voxel")
    print(u"    La fonction denoising_nonlocalmeans s'est executée en " + str(TimeForrunFunctionRMR2) +" secondes")
    print("\n")
    return ImageDenoised


def ParchBasedandOndeletteDenoising(Nom_image,denoising,correctPVE):
    timeRMR1 = time.time()
    image=su.PullFromSlicer(Nom_image)
    ####crop to speed things up############################################
    label_complet=sitk.BinaryThreshold(image, 0.1, 500, 1,0)
    label_complet=sitk.ConnectedComponent(label_complet, True)
    label_complet=sitk.RelabelComponent(label_complet)
    stats= sitk.LabelIntensityStatisticsImageFilter()
    stats.Execute(label_complet,image)
    delta=0 # extend the label to avoid problems at the edges
    LowerBondingBox=[stats.GetBoundingBox(1)[0]-delta,stats.GetBoundingBox(1)[1]-delta,stats.GetBoundingBox(1)[2]-delta]
    UpperBondingBox=[image.GetSize()[0]-(stats.GetBoundingBox(1)[0]+stats.GetBoundingBox(1)[3]+delta),image.GetSize()[1]-(stats.GetBoundingBox(1)[1]+stats.GetBoundingBox(1)[4]+delta),image.GetSize()[2]-(stats.GetBoundingBox(1)[2]+stats.GetBoundingBox(1)[5]+delta)]
    image=cropImagefctLabel(image, LowerBondingBox, UpperBondingBox  )
    ###############################################################
    ##########################wavelets decomposition#############
    ###########################nlm denoising###############
    image_spacing=image.GetSpacing()
    image_size=image.GetSize()
    NumpyImage=sitk.GetArrayFromImage(image)
    max_lev = 2       # how many levels of decomposition to draw
    #pywt.swt_max_level(input_len) # gives the max level
    radius=16  # in mm; critical for runtime, sets the radius within which similar voxels are searched for
    Niteration=1 # number of denoising iterations
    c = pywt.wavedecn(NumpyImage, 'coif1', mode='zero', level=max_lev) # see https://pywavelets.readthedocs.io/en/latest/ref/nd-dwt-and-idwt.html#pywt.wavedecn
    #c = pywt.swtn(NumpyImage, wavelet='db2', level=max_lev, start_level=0, axes=None, trim_approx=False, norm=False) #voir https://pywavelets.readthedocs.io/en/latest/ref/nd-dwt-and-idwt.html#pywt.wavedecn   
    #c= pywt.swtn(NumpyImage, 'coif1', 0, max_lev, None)
    #c_arr,c_slices=swt3(image, "coif1", max_lev, 0)
    c_arr,c_slices= pywt.coeffs_to_array(c, padding=0, axes=None) #separe les sous matrices aprés decomposition et leur indices
    list_elem=[]
    for row in c_slices:
        for elem in row:
            list_elem.append(elem) #list of every entry in c_slices: first the approximation slices, then the detail keys [aad, add, ..., ddd]; a: approximation, d: detail
    print(list_elem)
    for level in range(1,max_lev+1):
        for keys in range(int((len(list_elem)-3)/max_lev)):
            matrix=c_arr[c_slices[level][list_elem[keys+3]]]
            matrix_size=matrix.shape
            Zoom=[image_size[0]/matrix_size[0],image_size[1]/matrix_size[1], image_size[2]/matrix_size[2]]
            SpacingImageOndelette_realspace=realspace_spacing(list_elem, Zoom)
            SpacingImageOndelette=np.asarray(DecompMatrixSpacingfactor[list_elem[keys+3]], dtype=np.float64) #NOTE: per-sub-band lookup assumed; the original indexed with the whole list
            #radius is given in mm; the image has to be isotropic
            matrice_spacing=max(SpacingImageOndelette_realspace)
            radius_voxels=ceil(radius/matrice_spacing)
            print(u"NLM-denoising of level "+str(level)+", sub-band " + str(list_elem[keys+3]) +" matrix")
            print(u"Matrix size "+str(matrix_size)+", matrix spacing " + str(matrice_spacing) +" mm")
            print(u"The analysis radius is " + str(radius_voxels) +" voxels")
            print(u"The spatial frequency is " + str(SpatialFrequencyOptim2(matrix)) +" per voxel")
            print("\n")
            image_ondelette=sitk.GetImageFromArray(matrix)
            image_ondelette.SetSpacing(SpacingImageOndelette)
            #image_ondelette=reechantillonage_identity(image,image_ondelette)
            ######################################################################################
            #####################################deconvolution###################################
            if (correctPVE==1):
                limitRC=13 #size limit in mm below which RC<0.95
                RS=4 #spatial resolution of the system in mm
                #NOTE: the left-hand side of this condition was missing in the original;
                #comparing the finest real-space spacing to limitRC is an assumption.
                if (min(SpacingImageOndelette_realspace)<limitRC):
                    kernel=CreateGaussianKernel(RS, min(SpacingImageOndelette_realspace))
                    alpha=0.02
                    RC=0.0937*min(SpacingImageOndelette_realspace)
                    iterRC=0
                    #NOTE: the original loop tested image.GetMaximum(), which does not exist on a
                    #SimpleITK image and never changed inside the loop; recomputing the maximum of
                    #the deconvolved sub-band below is an assumed stopping criterion.
                    maxFilter=sitk.MinimumMaximumImageFilter()
                    maxFilter.Execute(image_ondelette)
                    while (maxFilter.GetMaximum()<(np.max(matrix)/(3*RC))):
                        iterRC=iterRC+1
                        image_ondelette=RLdeconvolutionTV(image_ondelette,kernel, alpha)
                        maxFilter.Execute(image_ondelette)
                    print(iterRC)
                    print("deconvolution ok")
            ##################################################################################
            ##################################################################################
            #should an RL deconvolution be run before the denoising, with the recovery coefficients (RC) as the stopping criterion?
            #if (2*np.std(matrix)/(np.max(matrix)+abs(np.min(matrix)))<0.01):
            #####################################Denoising########################################
            ######################################################################################
            if (denoising==1):
                if (SpatialFrequencyOptim2(matrix)>0.1 and (radius>max(SpacingImageOndelette_realspace))):#constraints: enough information for an impact above 0.01 SUV, and a radius not too large relative to the image
                    image_ondelette=denoising_nonlocalmeans(image_ondelette,list_elem[keys+3],ceil(radius/max(SpacingImageOndelette_realspace)), Niteration )
            ########################################################################################
            ######################################################################################
            #write the processed sub-band back into the coefficient array
            #NOTE: the original resampling back to the wavelet transform space was left
            #incomplete here; this assumes the processed sub-band keeps its original shape
            #and can be written back directly.
            matrix_ondelette=sitk.GetArrayFromImage(image_ondelette)
            c_arr[c_slices[level][list_elem[keys+3]]]=matrix_ondelette
    c=pywt.array_to_coeffs(c_arr,c_slices,output_format='wavedecn') #recombines the sub-bands and their slices into a coefficient list
    matrice_ondelette=pywt.waverecn(c, 'coif1') #inverse wavelet reconstruction (same wavelet as the decomposition above)
    #matrice_ondelette=pywt.iswtn(c_arr, 'db2',max_lev)
    image_ondelette=sitk.GetImageFromArray(matrice_ondelette)
    image_ondelette.SetSpacing(image.GetSpacing())
    image_ondelette.SetDirection(image.GetDirection())
    image_ondelette.SetOrigin(image.GetOrigin())
    su.PushToSlicer(image_ondelette,'image_Denoised_final')
    timeRMR2 = time.time()
    TimeForrunFunctionRMR2 = timeRMR2 - timeRMR1
    print(u"La fonction de traitement total s'est executée en " + str(TimeForrunFunctionRMR2) +" secondes")

ParchBasedandOndeletteDenoising(Nom_image,denoising=False,correctPVE=False)
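
# Standalone sketch (not part of the original pipeline) of the pywt coefficient
# bookkeeping used above: wavedecn -> coeffs_to_array gives one flat array plus
# per-sub-band slices, individual sub-bands can then be edited in place, and
# array_to_coeffs -> waverecn rebuilds the volume. It relies on the numpy and
# pywt imports already made earlier in the script; the _demo_* names are
# illustrative only.
_demo_vol = np.random.RandomState(0).rand(32, 32, 32)
_demo_coeffs = pywt.wavedecn(_demo_vol, 'coif1', mode='zero', level=2)
_demo_arr, _demo_slices = pywt.coeffs_to_array(_demo_coeffs)
_demo_arr[_demo_slices[1]['ddd']] *= 0.5  # e.g. attenuate one detail sub-band
_demo_rec = pywt.waverecn(pywt.array_to_coeffs(_demo_arr, _demo_slices,
                                               output_format='wavedecn'), 'coif1')
print("reconstructed shape:", _demo_rec.shape,
      "max abs change:", np.abs(_demo_rec[:32, :32, :32] - _demo_vol).max())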

#################################################deconvolution
#inspired by pyradiomics

def getWaveletImage(inputImage, **kwargs):
  """
  Apply wavelet filter to image and compute signature for each filtered image.

  Following settings are possible:

  - start_level [0]: integer, 0 based level of wavelet which should be used as first set of decompositions
    from which a signature is calculated
  - level [1]: integer, number of levels of wavelet decompositions from which a signature is calculated.
  - wavelet ["coif1"]: string, type of wavelet decomposition. Enumerated value, validated against possible values
    present in the ``pyWavelet.wavelist()``. Current possible values (pywavelet version 0.4.0) (where an
    aditional number is needed, range of values is indicated in []):

    - haar
    - dmey
    - sym[2-20]
    - db[1-20]
    - coif[1-5]
    - bior[1.1, 1.3, 1.5, 2.2, 2.4, 2.6, 2.8, 3.1, 3.3, 3.5, 3.7, 3.9, 4.4, 5.5, 6.8]
    - rbio[1.1, 1.3, 1.5, 2.2, 2.4, 2.6, 2.8, 3.1, 3.3, 3.5, 3.7, 3.9, 4.4, 5.5, 6.8]

  Returned filter name reflects wavelet type:
  wavelet[level]-<decompositionName>

  N.B. only levels greater than the first level are entered into the name.

  :return: Yields each wavelet decomposition and final approximation, corresponding filter name and ``kwargs``
  """
  global logger

  logger.debug("Generating Wavelet images")

  approx, ret = swt3(inputImage, kwargs.get('wavelet', 'coif1'), kwargs.get('level', 1), kwargs.get('start_level', 0))

  for idx, wl in enumerate(ret, start=1):
    for decompositionName, decompositionImage in wl.items():
      print('Computing Wavelet %s' % decompositionName)

      if idx == 1:
        inputImageName = 'wavelet-%s' % (decompositionName)
      else:
        inputImageName = 'wavelet%s-%s' % (idx, decompositionName)
      print('Yielding %s image' % inputImageName)
      yield decompositionImage, inputImageName, kwargs

  if len(ret) == 1:
    inputImageName = 'wavelet-LLL'
  else:
    inputImageName = 'wavelet%s-LLL' % (len(ret))
  print('Yielding approximation (%s) image' % inputImageName)
  yield approx, inputImageName, kwargs



#def _swt3(inputImage, wavelet="coif1", level=1, start_level=0):
def swt3(inputImage, wavelet, level, start_level):
  matrix = sitk.GetArrayFromImage(inputImage)
  matrix = np.asarray(matrix)
  if matrix.ndim != 3:
    raise ValueError("Expected 3D data array")
  original_shape = matrix.shape
  adjusted_shape = tuple([dim + 1 if dim % 2 != 0 else dim for dim in original_shape])
  data = matrix.copy()
  data.resize(adjusted_shape, refcheck=False)
  if not isinstance(wavelet, pywt.Wavelet):
    wavelet = pywt.Wavelet(wavelet)
  for i in range(0, start_level):
    H, L = _decompose_i(data, wavelet)
    LH, LL = _decompose_j(L, wavelet)
    LLH, LLL = _decompose_k(LL, wavelet)
    data = LLL.copy()
  ret = []
  for i in range(start_level, start_level + level):
    H, L = _decompose_i(data, wavelet)
    HH, HL = _decompose_j(H, wavelet)
    LH, LL = _decompose_j(L, wavelet)
    HHH, HHL = _decompose_k(HH, wavelet)
    HLH, HLL = _decompose_k(HL, wavelet)
    LHH, LHL = _decompose_k(LH, wavelet)
    LLH, LLL = _decompose_k(LL, wavelet)
    data = LLL.copy()
    dec = {'HHH': HHH,
           'HHL': HHL,
           'HLH': HLH,
           'HLL': HLL,
           'LHH': LHH,
           'LHL': LHL,
           'LLH': LLH}
    for decName, decImage in six.iteritems(dec):
      decTemp = decImage.copy()
      decTemp = np.resize(decTemp, original_shape)
      sitkImage = sitk.GetImageFromArray(decTemp)
      sitkImage.CopyInformation(inputImage)
      dec[decName] = sitkImage
    ret.append(dec)
  data = np.resize(data, original_shape)
  approximation = sitk.GetImageFromArray(data)
  approximation.CopyInformation(inputImage)
  #return approximation
  return approximation, ret


def _decompose_i(data, wavelet):
  # process in i:
  H, L = [], []
  i_arrays = chain.from_iterable(data)
  for i_array in i_arrays:
    cA, cD = pywt.swt(i_array, wavelet, level=1, start_level=0)[0]
    H.append(cD)
    L.append(cA)
  H = np.hstack(H).reshape(data.shape)
  L = np.hstack(L).reshape(data.shape)
  return H, L


def _decompose_j(data, wavelet):
  # process in j:
  s = data.shape
  H, L = [], []
  j_arrays = chain.from_iterable(np.transpose(data, (0, 2, 1)))
  for j_array in j_arrays:
    cA, cD = pywt.swt(j_array, wavelet, level=1, start_level=0)[0]
    H.append(cD)
    L.append(cA)
  H = np.hstack(H).reshape((s[0], s[2], s[1])).transpose((0, 2, 1))
  L = np.hstack(L).reshape((s[0], s[2], s[1])).transpose((0, 2, 1))
  return H, L


def _decompose_k(data, wavelet):
  # process in k:
  H, L = [], []
  k_arrays = chain.from_iterable(np.transpose(data, (2, 1, 0)))
  for k_array in k_arrays:
    cA, cD = pywt.swt(k_array, wavelet, level=1, start_level=0)[0]
    H.append(cD)
    L.append(cA)
  H = np.asarray([slice for slice in np.split(np.vstack(H), data.shape[2])]).T
  L = np.asarray([slice for slice in np.split(np.vstack(L), data.shape[2])]).T
  return H, L
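
# Small standalone illustration (not part of the original script) of the pywt.swt
# call used by the three _decompose_* helpers above: a single-level stationary
# wavelet transform of a 1D signal returns one (approximation, detail) pair with
# the same length as the input, which is what makes the per-axis reshaping work.
_demo_signal = np.arange(16, dtype=np.float64)
_demo_cA, _demo_cD = pywt.swt(_demo_signal, 'coif1', level=1, start_level=0)[0]
print("swt output lengths:", len(_demo_cA), len(_demo_cD))  # both equal len(_demo_signal)

# Hypothetical usage sketch for getWaveletImage (not in the original script): it
# iterates the generator over a small synthetic SimpleITK volume. The imports and
# the logger defined here only cover names (logging, six, chain, logger) that this
# block expects but that are otherwise imported further down in the script; sitk
# and np are assumed to be imported earlier, as in the pipeline above.
import logging
import six
from itertools import chain
logger = logging.getLogger("wavelet_demo")
_demo_wavelet_input = sitk.GetImageFromArray(np.random.RandomState(0).rand(16, 16, 16))
for _band_image, _band_name, _band_kwargs in getWaveletImage(_demo_wavelet_input, wavelet='coif1', level=1):
    print(_band_name, _band_image.GetSize())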



########################test########################
import unittest
from slicer.ScriptedLoadableModule import *
import logging
from __main__ import vtk, qt, ctk, slicer
from math import *
import numpy as np
from vtk.util import numpy_support
import SimpleITK as sitk
import sitkUtils as su
import time
import datetime
import sys, os
from itertools import *
import six

import pywt # https://github.com/PyWavelets/pywt/blob/master/pywt/_multilevel.py
from skimage.restoration import denoise_wavelet, cycle_spin
from skimage.metrics import peak_signal_noise_ratio
import skimage

Nom_image="FOUMA_1m"

def skimage_shift_invariant_wavelet(Nom_image):
    image=su.PullFromSlicer(Nom_image)
    NumpyImage=sitk.GetArrayFromImage(image)
    # Repeat denoising with different amounts of cycle spinning, e.g.:
    # max_shift = 0 -> no cycle spinning
    # max_shift = 1 -> shifts of (0, 1) along each axis
    # max_shift = 3 -> shifts of (0, 1, 2, 3) along each axis
    # etc...
    denoise_kwargs = dict(multichannel=False, convert2ycbcr=False, wavelet='db2', rescale_sigma=True, method='BayesShrink')
    all_psnr = []
    max_shifts = [0, 1, 3, 5]
    for n, s in enumerate(max_shifts):
        im_bayescs = cycle_spin(NumpyImage, func=denoise_wavelet, max_shifts=s, func_kw=denoise_kwargs, multichannel=False)
        #psnr = peak_signal_noise_ratio(NumpyImage, im_bayescs, )
        #all_psnr.append(psnr)
        #print("shift: "+str(s+1)+" psnr "+str(psnr))
        #print("\n")
        image_ondelette=sitk.GetImageFromArray(im_bayescs)
        image_ondelette.SetSpacing(image.GetSpacing())
        image_ondelette.SetDirection(image.GetDirection())
        image_ondelette.SetOrigin(image.GetOrigin())
        su.PushToSlicer(image_ondelette,'image_Denoised_final_shifts_'+str(s+1))

skimage_shift_invariant_wavelet(Nom_image)
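
# Standalone sketch (not part of the original snippet) showing the effect of the
# max_shifts parameter of cycle_spin on a small synthetic noisy image, measured
# with peak_signal_noise_ratio (already imported above). The _demo_* names are
# illustrative only.
_demo_rng = np.random.RandomState(0)
_demo_clean = _demo_rng.rand(64, 64)
_demo_noisy = _demo_clean + 0.1 * _demo_rng.randn(64, 64)
for _demo_shift in [0, 1, 3]:
    _demo_denoised = cycle_spin(_demo_noisy, func=denoise_wavelet, max_shifts=_demo_shift,
                                func_kw=dict(wavelet='db2', method='BayesShrink', rescale_sigma=True))
    print("max_shifts =", _demo_shift, "PSNR =",
          peak_signal_noise_ratio(_demo_clean, _demo_denoised, data_range=1.0))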

#NOTE: the reminder calls below operated on undefined variables (im_true, im_test,
#im1, im2, noisy); they are kept as comments, updated to the current scikit-image
#API (the skimage.measure.compare_* functions were removed in favour of skimage.metrics):
#skimage.metrics.peak_signal_noise_ratio(im_true, im_test)
#skimage.metrics.mean_squared_error(im1, im2)
#skimage.metrics.structural_similarity(im1, im2)
#sigma_est = skimage.restoration.estimate_sigma(noisy, multichannel=True, average_sigmas=True)  #multichannel is deprecated in newer releases (use channel_axis)
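
# Standalone sketch (not part of the original snippet) of the current scikit-image
# metric and noise-estimation functions referenced in the comments above, applied
# to a small synthetic clean/noisy pair. The _metric_demo_* names are illustrative
# only.
from skimage.metrics import mean_squared_error, structural_similarity
from skimage.restoration import estimate_sigma

_metric_demo_rng = np.random.RandomState(1)
_metric_demo_clean = _metric_demo_rng.rand(64, 64)
_metric_demo_noisy = _metric_demo_clean + 0.1 * _metric_demo_rng.randn(64, 64)
print("estimated sigma:", estimate_sigma(_metric_demo_noisy))
print("PSNR:", peak_signal_noise_ratio(_metric_demo_clean, _metric_demo_noisy, data_range=1.0))
print("MSE :", mean_squared_error(_metric_demo_clean, _metric_demo_noisy))
print("SSIM:", structural_similarity(_metric_demo_clean, _metric_demo_noisy, data_range=1.0))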