def wavedec2(data: np.array, wavelet: JaxWavelet, level: int = None) -> list:
    """Compute the two-dimensional wavelet analysis transform on the last two
       dimensions of the input data array.

    Args:
        data (np.array): Jax array containing the data to be transformed.
            Assumed shape: [batch size, channels, height, width].
        wavelet (JaxWavelet): A namedtuple containing the filters for the
            transformation.
        level (int, optional): The max level to be used. If not set, as many
            levels as possible will be used. Defaults to None.

    Returns:
        list: The wavelet coefficients, coarsest level first:
            [ll_n, (lh_n, hl_n, hh_n), ..., (lh_1, hl_1, hh_1)].
    """
    dec_lo, dec_hi, _, _ = get_filter_arrays(wavelet, flip=True)
    dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        # The JaxWavelet is handed to pywt as the filter bank of an ad-hoc
        # pywt.Wavelet so that pywt can compute the maximum level.
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]],
                                    pywt.Wavelet('MyWavelet', wavelet))

    result_lst = []
    res_ll = data
    for _ in range(level):
        res_ll = fwt_pad2d(res_ll, wavelet)
        res = jax.lax.conv_general_dilated(
            lhs=res_ll,  # lhs = NCHW image tensor
            rhs=dec_filt,  # rhs = OIHW conv kernel tensor
            padding='VALID', window_strides=[2, 2],
            dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        )

        # The four output channels hold the ll, lh, hl and hh sub-bands.
        res_ll, res_lh, res_hl, res_hh = np.split(res, 4, 1)
        result_lst.append((res_lh, res_hl, res_hh))
    result_lst.append(res_ll)
    # Coefficients were collected fine-to-coarse; return coarsest first.
    result_lst.reverse()
    return result_lst
def conv_fwt_2d(data, wavelet, scales: int = None) -> list:
    """Non-separated two-dimensional wavelet transform via 2D convolution.

    Args:
        data (torch.tensor): Input of shape [batch_size, 1, height, width].
        wavelet (WaveletFilter): The wavelet object to be used.
        scales (int, optional): Number of decomposition scales to compute.
            When None, pywt.dwtn_max_level chooses it. Defaults to None.

    Returns:
        [list]: The wavelet coefficients, coarsest first:
            [ll_n, (lh_n, hl_n, hh_n), ..., (lh_1, hl_1, hh_1)].
    """
    lo_filt, hi_filt, _, _ = get_filter_tensors(wavelet,
                                                flip=True,
                                                device=data.device)
    filt_2d = construct_2d_filt(lo=lo_filt, hi=hi_filt)

    if scales is None:
        scales = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    approx = data
    details = []
    for _ in range(scales):
        padded = fwt_pad2d(approx, wavelet)
        sub_bands = torch.nn.functional.conv2d(padded, filt_2d, stride=2)
        # Channel 0 is the next approximation, channels 1-3 the details.
        approx, lh, hl, hh = torch.split(sub_bands, 1, 1)
        details.append((lh, hl, hh))

    # Details were gathered fine-to-coarse; emit coarse-to-fine after approx.
    return [approx] + details[::-1]
def sep_conv_fwt_2d(data, wavelet, scales: int = None) -> list:
    """Separable two-dimensional wavelet transform built from 1D convolutions.

    Every scale pads the current approximation, filters it along the width
    (last axis) with a stride-2 1D convolution against the stacked
    [dec_lo, dec_hi] filters, and then filters both results along the height
    in the same way.

    Args:
        data (torch.tensor): 4D input of shape
            [batch_size, channels, height, width]; batch and channel
            dimensions are folded together internally.
        wavelet (util.WaveletFilter or pywt.wavelet): The wavelet object.
        scales (int, optional): The number of decomposition scales.
            Defaults to None, in which case pywt.dwtn_max_level is used.

    Returns:
        [list]: List containing the wavelet coefficients, coarsest first:
            [ll_n, (lh_n, hl_n, hh_n), ..., (lh_1, hl_1, hh_1)], each of
            shape [batch_size, channels, h_s, w_s]. With ``scales == 0`` the
            list only holds the unmodified input.
    """
    ds = data.shape
    dec_lo, dec_hi, _, _ = get_filter_tensors(wavelet,
                                              flip=True,
                                              device=data.device)
    # Output channel 0 of the 1D convolutions is low-pass, channel 1 high-pass.
    filt = torch.stack([dec_lo, dec_hi], 0)

    if scales is None:
        scales = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    def _to_output(coeff, n_signals, n_rows, n_cols):
        # [n_signals * n_cols, 1, n_rows] -> [batch, channels, n_rows, n_cols]
        coeff = coeff.reshape(n_signals, n_cols, n_rows).permute([0, 2, 1])
        return coeff.reshape(ds[0], ds[1], n_rows, n_cols)

    result_lst = []
    res_ll = data
    for _ in range(scales):
        res_ll = fwt_pad2d(res_ll, wavelet)
        # Fold batch and channel dims: [batch * channels, height, width].
        res_ll = res_ll.reshape(
            [ds[0] * ds[1], res_ll.shape[2], res_ll.shape[3]])
        rll_s = res_ll.shape
        out_h, out_w = rll_s[1] // 2, rll_s[2] // 2

        # Pass 1: every row becomes a 1-channel signal filtered along width.
        res_llr = res_ll.reshape([rll_s[0] * rll_s[1], rll_s[2]]).unsqueeze(1)
        res = torch.nn.functional.conv1d(res_llr, filt, stride=2)
        res_l, res_h = torch.split(res, 1, 1)
        res_l = res_l.reshape([rll_s[0], rll_s[1], out_w])
        res_h = res_h.reshape([rll_s[0], rll_s[1], out_w])

        # Pass 2: transpose so columns are the last axis, filter along height.
        res_lt = res_l.permute(0, 2, 1)
        res_ht = res_h.permute(0, 2, 1)
        res_ltr = res_lt.reshape([-1, res_lt.shape[-1]]).unsqueeze(1)
        res_htr = res_ht.reshape([-1, res_ht.shape[-1]]).unsqueeze(1)
        res_l2 = torch.nn.functional.conv1d(res_ltr, filt, stride=2)
        res_h2 = torch.nn.functional.conv1d(res_htr, filt, stride=2)
        res_ll, res_lh = torch.split(res_l2, 1, 1)
        res_hl, res_hh = torch.split(res_h2, 1, 1)

        result_lst.append(
            (_to_output(res_lh, rll_s[0], out_h, out_w),
             _to_output(res_hl, rll_s[0], out_h, out_w),
             _to_output(res_hh, rll_s[0], out_h, out_w)))
        res_ll = _to_output(res_ll, rll_s[0], out_h, out_w)
    # res_ll already has the output layout; appending it directly (rather than
    # re-reshaping loop locals) also works when scales == 0.
    result_lst.append(res_ll)
    return result_lst[::-1]
Пример #4
0
 def update_level(self):
     """Refresh the level spin box for the selected wavelet, then recompute.

     The spin box range becomes [1, max_level] and its value max_level // 2.
     Signals are blocked while doing so, so these updates emit nothing;
     compute_dwt() is then called explicitly.
     """
     wavelet_name = self.wavelet_combo.currentText()
     # NOTE(review): shape[:-1] drops the last axis, presumably a colour
     # channel of an (H, W, C) image -- confirm for grayscale inputs.
     top_level = pywt.dwtn_max_level(self.image.shape[:-1], wavelet_name)

     spin = self.level_spin
     spin.blockSignals(True)
     spin.setRange(1, top_level)
     spin.setValue(top_level // 2)
     spin.blockSignals(False)

     self.compute_dwt()
Пример #5
0
def inverse_wavelet_transform(vres,
                              inv_filters=None,
                              output_shape=None,
                              levels=None):
    """Multi-level inverse 3D wavelet transform using Keras backend ops.

    Args:
        vres: 5D coefficient tensor; axes 1-3 are spatial. Its elements are
            regrouped into 8 sub-band channels at half resolution. Presumably
            (batch, t, h, w, 1) as produced by ``wavelet_transform`` --
            confirm.
        inv_filters: Synthesis filter bank for ``K.conv3d_transpose``.
            Defaults to the eight separable 3D db4 reconstruction filters,
            shape (k, k, k, 1, 8).
        output_shape: Optional shape; the result is cropped to
            ``output_shape[1:4]`` along axes 1-3. When None, no cropping
            is done.
        levels: Number of levels to invert. Defaults to
            ``pywt.dwtn_max_level`` of axes 1-3 for 'db4'.

    Returns:
        The reconstructed tensor (cropped when ``output_shape`` is given).
    """
    if inv_filters is None:
        wav = pywt.Wavelet('db4')
        rec_hi = np.array(wav.rec_hi)
        rec_lo = np.array(wav.rec_lo)
        # Eight separable 3D filters, one per sub-band k: rec_hi is used on
        # axis 2 if bit 2 of k is set, on axis 1 if bit 1 is set and on
        # axis 0 if bit 0 is set (rec_lo otherwise).
        inv_filters = np.stack([
            rec_lo[None, None, :] * rec_lo[None, :, None] * rec_lo[:, None, None],
            rec_lo[None, None, :] * rec_lo[None, :, None] * rec_hi[:, None, None],
            rec_lo[None, None, :] * rec_hi[None, :, None] * rec_lo[:, None, None],
            rec_lo[None, None, :] * rec_hi[None, :, None] * rec_hi[:, None, None],
            rec_hi[None, None, :] * rec_lo[None, :, None] * rec_lo[:, None, None],
            rec_hi[None, None, :] * rec_lo[None, :, None] * rec_hi[:, None, None],
            rec_hi[None, None, :] * rec_hi[None, :, None] * rec_lo[:, None, None],
            rec_hi[None, None, :] * rec_hi[None, :, None] * rec_hi[:, None, None],
        ]).transpose((1, 2, 3, 0))[:, :, :, None, :]
        inv_filters = K.constant(inv_filters)
    if levels is None:
        levels = pywt.dwtn_max_level(K.int_shape(vres)[1:4], 'db4')

    t = vres.shape[1]
    h = vres.shape[2]
    w = vres.shape[3]
    # Regroup into 8 sub-band channels at half resolution along axes 1-3.
    res = K.reshape(vres, (-1, t // 2, h // 2, w // 2, 8))
    if levels > 1:
        # Channel 0 is the coarse approximation: reconstruct it recursively
        # before the final synthesis step of this level.
        res = K.concatenate([
            inverse_wavelet_transform(
                res[:, :, :, :, :1],
                inv_filters,
                output_shape=(K.shape(vres)[0], K.shape(vres)[1] // 2,
                              K.shape(vres)[2] // 2, K.shape(vres)[3] // 2,
                              K.shape(vres)[4]),
                levels=(levels - 1)), res[:, :, :, :, 1:]
        ],
                            axis=-1)
    res = K.conv3d_transpose(res,
                             inv_filters,
                             output_shape=K.shape(vres),
                             strides=(2, 2, 2),
                             padding='same')

    if output_shape is None:
        # No target shape given: return the full reconstruction uncropped.
        return res
    return res[:, :output_shape[1], :output_shape[2], :output_shape[3], :]
Пример #6
0
def wavelet_transform(img, filters=None, levels=None):
    """Multi-level forward 3D wavelet transform using Keras backend ops.

    Args:
        img: 5D tensor whose axes 1-3 are transformed with stride-2 3D
            convolutions (presumably (batch, t, h, w, 1), channels last --
            confirm).
        filters: Analysis filter bank for ``K.conv3d``. Defaults to the eight
            separable 3D db4 decomposition filters, shape (k, k, k, 1, 8).
        levels: Number of decomposition levels. When None (the top-level
            call), axes 1-3 are first zero-padded up to the next power of two
            and the level count is taken from ``pywt.dwtn_max_level`` for
            'db4'. Recursive calls pass ``levels`` explicitly and so skip the
            padding.

    Returns:
        Tensor of shape (-1, t, h, w, 1), where t, h and w are the (padded)
        sizes of axes 1-3 and the 8 sub-bands of every level are packed into
        it by reshaping.
    """
    if levels is None:
        in_shape = K.int_shape(img)
        spatial_pads = [(0, 2**int(np.ceil(np.log2(in_shape[i]))) - in_shape[i])
                        for i in (1, 2, 3)]
        vimg = tf.pad(img, [(0, 0)] + spatial_pads + [(0, 0)])
    else:
        vimg = img

    if filters is None:
        wav = pywt.Wavelet('db4')
        # Reversed so that the convolution acts as correlation with the
        # decomposition filters.
        dec_hi = np.array(wav.dec_hi[::-1])
        dec_lo = np.array(wav.dec_lo[::-1])
        # Eight separable 3D filters, one per sub-band k: dec_hi is used on
        # axis 2 if bit 2 of k is set, on axis 1 if bit 1 is set and on
        # axis 0 if bit 0 is set (dec_lo otherwise).
        filters = np.stack([
            dec_lo[None, None, :] * dec_lo[None, :, None] * dec_lo[:, None, None],
            dec_lo[None, None, :] * dec_lo[None, :, None] * dec_hi[:, None, None],
            dec_lo[None, None, :] * dec_hi[None, :, None] * dec_lo[:, None, None],
            dec_lo[None, None, :] * dec_hi[None, :, None] * dec_hi[:, None, None],
            dec_hi[None, None, :] * dec_lo[None, :, None] * dec_lo[:, None, None],
            dec_hi[None, None, :] * dec_lo[None, :, None] * dec_hi[:, None, None],
            dec_hi[None, None, :] * dec_hi[None, :, None] * dec_lo[:, None, None],
            dec_hi[None, None, :] * dec_hi[None, :, None] * dec_hi[:, None, None],
        ]).transpose((1, 2, 3, 0))[:, :, :, None, :]
        filters = K.constant(filters)
    if levels is None:
        levels = pywt.dwtn_max_level(K.int_shape(vimg)[1:4], 'db4')

    t = vimg.shape[1]
    h = vimg.shape[2]
    w = vimg.shape[3]
    res = K.conv3d(vimg, filters, strides=(2, 2, 2), padding='same')
    if levels > 1:
        # Recurse on the approximation (channel 0); details stay as they are.
        res = K.concatenate([
            wavelet_transform(res[:, :, :, :, :1],
                              filters,
                              levels=(levels - 1)), res[:, :, :, :, 1:]
        ],
                            axis=-1)
    res = K.reshape(res, (-1, t, h, w, 1))
    return res
Пример #7
0
def test_dwtn_max_level():
    """dwtn_max_level must predict the number of detail levels of wavedecn."""
    wavelets = [pywt.Wavelet('db2'), 'sym8']
    shapes = [(33, ), (64, 32), (1, 15, 30)]
    axes_options = [None, 0, -1]
    cases = ((wav, shape, ax, mode)
             for wav in wavelets
             for shape in shapes
             for ax in axes_options
             for mode in pywt.Modes.modes)
    for wav, shape, ax, mode in cases:
        decomposition = pywt.wavedecn(np.ones(shape), wav,
                                      mode=mode, axes=ax)
        predicted = pywt.dwtn_max_level(shape, wav, ax)
        assert_equal(len(decomposition[1:]), predicted)
Пример #8
0
def test_dwtn_max_level():
    # For every wavelet / shape / axes / mode combination, the level count
    # predicted by dwtn_max_level must equal the number of detail levels
    # that wavedecn actually produces.
    for wavelet in [pywt.Wavelet('db2'), 'sym8']:
        for shape in [(33, ), (64, 32), (1, 15, 30)]:
            for axis_spec in [None, 0, -1]:
                for ext_mode in pywt.Modes.modes:
                    levels_out = pywt.wavedecn(np.ones(shape), wavelet,
                                               mode=ext_mode, axes=axis_spec)
                    expected = pywt.dwtn_max_level(shape, wavelet, axis_spec)
                    assert_equal(len(levels_out[1:]), expected)
Пример #9
0
def cdf97_2d_forward(x, level, axes=(-2, -1)):
    '''Forward 2D Cohen–Daubechies–Feauveau 9/7 wavelet.

    Parameters
    ----------
    x : array_like
        2D signal.
    level : int or None
        Decomposition level. If None, or larger than the maximum level
        achievable for ``x`` along ``axes``, the maximum level is used (a
        warning is issued in the latter case).
    axes : tuple, optional
        Axes to perform wavelet decomposition across.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
    locations : list
        Indices telling us how we stitched it together so we can take
        it back apart.

    Notes
    -----
    Returns transform, same shape as input, with locations.
    Locations is a list of indices instructing cdf97_2d_inverse where
    the coefficients for each block are located.

    Biorthogonal 4/4 is the same as CDF 9/7 according to wikipedia
    [1]_.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/
           Cohen%E2%80%93Daubechies%E2%80%93Feauveau_wavelet#Numbering
    '''

    # Make sure we don't go too deep
    max_level = pywt.dwtn_max_level(x.shape, 'bior4.4', axes=axes)
    if level is None:
        # `None > int` raises TypeError in Python 3; default to the maximum.
        level = max_level
    elif level > max_level:
        msg = ('Level %d cannot be achieved, using max level=%d!'
               '' % (level, max_level))
        warnings.warn(msg, stacklevel=2)
        level = max_level

    # periodization seems to be the only way to get shapes to line up.
    cdf97 = pywt.wavedec2(x,
                          wavelet='bior4.4',
                          mode='periodization',
                          level=level,
                          axes=axes)

    # Now throw all the chunks together
    return combine_chunks(cdf97, x.shape, x.dtype)
Пример #10
0
def wavelet_forward(x, wavelet, mode='symmetric', level=None, axes=(-2, -1)):
    '''Wrapper for the multilevel 2D discrete wavelet transform.

    Parameters
    ----------
    x : array_like
        Input data.
    wavelet : str
        Wavelet to use.
    mode : str, optional
        Signal extension mode.
    level : int, optional
        Decomposition level (must be >= 0). Values above the maximum level
        achievable along ``axes`` are clipped to it with a warning.
    axes : tuple, optional
        Axes over which to compute the DWT.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
    locations : list
        Indices telling us how we stitched it together so we can take
        it back apart.

    Notes
    -----
    See PyWavelets documentation on pywt.wavedec2() for more
    information.

    If level=None (default), pywt.wavedec2 uses the maximum level
    possible along ``axes``.
    '''

    # Make sure we don't go too deep. Only the transformed axes limit the
    # achievable level; other axes (e.g. a batch axis) must not be counted.
    max_level = pywt.dwtn_max_level(x.shape, wavelet, axes=axes)
    if level is not None and level > max_level:
        msg = ('Level %d cannot be achieved, using max level=%d!'
               '' % (level, max_level))
        warnings.warn(msg, stacklevel=2)
        level = max_level

    # Do the pywavelets thing
    wvlt = pywt.wavedec2(x, wavelet, mode=mode, level=level, axes=axes)

    # But wvlt is a bunch of tuples, and we want them all stitched
    # together:
    return combine_chunks(wvlt, x.shape, x.dtype)
Пример #11
0
def estimate_sparsity(im, wavelet, nres=None, eps=1e-1):
    '''Python implementation of csl_compute_sparsity_of_image from
    https://bitbucket.org/vegarant/cslib.git

    The magnitude of ``im`` is rescaled to [0, 1] before a periodized 2D
    wavelet decomposition is computed.

    Parameters:
        im: 2d numpy array. Can be complex-valued.
        wavelet: Wavelet name or ``pywt.Wavelet`` instance (pywt accepts
                 both).
        nres: Optional. Number of wavelet levels. Default given by
              `pywt.dwtn_max_level`.
        eps: Tolerance. Everything at or below eps in abs. value will be
             treated as zero.

    Returns:
        List with the number of coefficients whose magnitude exceeds eps,
        one entry per level. It begins with the low-resolution approximation
        (the upper left corner) and continues with the detail levels from
        coarse to fine.
    '''

    im = np.abs(im)
    m = im.min()
    span = im.max() - m
    if span > 0:
        im = (im - m) / span
    else:
        # Constant image: (im - m) / 0 would give NaNs and a RuntimeWarning.
        # Every pixel maps to zero instead.
        im = np.zeros_like(im)

    # Default to max number of levels
    if nres is None:
        nres = pywt.dwtn_max_level(im.shape, wavelet)

    coeffs = pywt.wavedec2(im, wavelet, 'periodization', nres)

    sparsities = [np.sum(np.abs(coeffs[0]) > eps)]
    # Each detail level is a tuple of arrays; count over all of them.
    sparsities.extend(
        sum(np.sum(np.abs(detail) > eps) for detail in level)
        for level in coeffs[1:])

    return sparsities
Пример #12
0
    def __init__(self,
                 space,
                 wavelet,
                 nlevels,
                 variant,
                 pad_mode='constant',
                 pad_const=0,
                 impl='pywt',
                 axes=None):
        """Initialize a new instance.

        Parameters
        ----------
        space : `DiscreteLp`
            Domain of the forward wavelet transform (the "image domain").
            In the case of ``variant in ('inverse', 'adjoint')``, this
            space is the range of the operator.
        wavelet : string or `pywt.Wavelet`
            Specification of the wavelet to be used in the transform.
            If a string is given, it is converted to a `pywt.Wavelet`.
            Use `pywt.wavelist` to get a list of available wavelets.

            Possible wavelet families are:

            ``'haar'``: Haar

            ``'db'``: Daubechies

            ``'sym'``: Symlets

            ``'coif'``: Coiflets

            ``'bior'``: Biorthogonal

            ``'rbio'``: Reverse biorthogonal

            ``'dmey'``: Discrete FIR approximation of the Meyer wavelet

        nlevels : positive int or None
            Number of scaling levels to be used in the decomposition. The
            maximum number of levels can be calculated with
            `pywt.dwtn_max_level`.
            If ``None``, the maximum number of levels is used.
        variant : {'forward', 'inverse', 'adjoint'}
            Wavelet transform variant to be created.
        pad_mode : string, optional
            Method to be used to extend the signal.

            ``'constant'``: Fill with ``pad_const``.

            ``'symmetric'``: Reflect at the boundaries, not repeating the
            outmost values.

            ``'periodic'``: Fill in values from the other side, keeping
            the order.

            ``'order0'``: Extend constantly with the outmost values
            (ensures continuity).

            ``'order1'``: Extend with constant slope (ensures continuity of
            the first derivative). This requires at least 2 values along
            each axis where padding is applied.

            ``'pywt_per'``:  like ``'periodic'``-padding but gives the smallest
            possible number of decomposition coefficients.
            Only available with ``impl='pywt'``, See ``pywt.Modes.modes``.

            ``'reflect'``: Reflect at the boundary, without repeating the
            outmost values.

            ``'antisymmetric'``: Anti-symmetric variant of ``symmetric``.

            ``'antireflect'``: Anti-symmetric variant of ``reflect``.

            For reference, the following table compares the naming conventions
            for the modes in ODL vs. PyWavelets::

                ======================= ==================
                          ODL               PyWavelets
                ======================= ==================
                symmetric               symmetric
                reflect                 reflect
                order1                  smooth
                order0                  constant
                constant, pad_const=0   zero
                periodic                periodic
                pywt_per                periodization
                antisymmetric           antisymmetric
                antireflect             antireflect
                ======================= ==================

            See `signal extension modes`_ for an illustration of the modes
            (under the PyWavelets naming conventions).
        pad_const : float, optional
            Constant value to use if ``pad_mode == 'constant'``. Ignored
            otherwise. Constants other than 0 are not supported by the
            ``pywt`` back-end.
        impl : {'pywt'}, optional
            Back-end for the wavelet transform.
        axes : sequence of ints, optional
            Axes over which the DWT is performed. The default value of
            ``None`` corresponds to all axes. When not all axes are included
            this is analogous to a batch transform in ``len(axes)``
            dimensions looped over the non-transformed axes. In other words,
            filtering and decimation does not occur along any axes not in
            ``axes``.

        Raises
        ------
        TypeError
            If ``space`` is not a `DiscreteLp` instance.
        ValueError
            If ``impl`` is not supported, ``axes`` has more entries than
            ``space.ndim``, ``nlevels`` is not an integer, or ``variant`` is
            not understood.

        References
        ----------
        .. _signal extension modes:
           https://pywavelets.readthedocs.io/en/latest/ref/signal-extension-modes.html
        """
        if not isinstance(space, DiscreteLp):
            raise TypeError('`space` {!r} is not a `DiscreteLp` instance.'
                            ''.format(space))

        self.__impl, impl_in = str(impl).lower(), impl
        if self.impl not in _SUPPORTED_WAVELET_IMPLS:
            raise ValueError("`impl` '{}' not supported".format(impl_in))

        if axes is None:
            axes = tuple(range(space.ndim))
        elif np.isscalar(axes):
            axes = (axes, )
        elif len(axes) > space.ndim:
            raise ValueError("Too many axes.")
        self.axes = tuple(axes)

        if nlevels is None:
            nlevels = pywt.dwtn_max_level(space.shape, wavelet, self.axes)
        self.__nlevels, nlevels_in = int(nlevels), nlevels
        # int() truncates, so a mismatch means a non-integral value was given.
        if self.nlevels != nlevels_in:
            raise ValueError('`nlevels` must be integer, got {}'
                             ''.format(nlevels_in))

        self.__wavelet = getattr(wavelet, 'name', str(wavelet).lower())
        self.__pad_mode = str(pad_mode).lower()
        self.__pad_const = space.field.element(pad_const)

        if self.impl == 'pywt':
            self.pywt_pad_mode = pywt_pad_mode(pad_mode, pad_const)
            self.pywt_wavelet = pywt_wavelet(self.wavelet)
            # determine coefficient shapes (without running wavedecn)
            self._coeff_shapes = pywt.wavedecn_shapes(space.shape,
                                                      wavelet,
                                                      mode=self.pywt_pad_mode,
                                                      level=self.nlevels,
                                                      axes=self.axes)
            # precompute slices into the (raveled) coeffs
            self._coeff_slices = precompute_raveled_slices(self._coeff_shapes)
            coeff_size = pywt.wavedecn_size(self._coeff_shapes)
            coeff_space = space.tspace_type(coeff_size, dtype=space.dtype)
        else:
            raise RuntimeError("bad `impl` '{}'".format(self.impl))

        variant, variant_in = str(variant).lower(), variant
        if variant not in ('forward', 'inverse', 'adjoint'):
            raise ValueError("`variant` '{}' not understood"
                             "".format(variant_in))
        self.__variant = variant

        # The forward transform maps image -> coefficients; the inverse and
        # adjoint variants map coefficients -> image.
        if variant == 'forward':
            super(WaveletTransformBase, self).__init__(domain=space,
                                                       range=coeff_space,
                                                       linear=True)
        else:
            super(WaveletTransformBase, self).__init__(domain=coeff_space,
                                                       range=space,
                                                       linear=True)
Пример #13
0
def _wavelet_threshold(image,
                       wavelet,
                       method=None,
                       threshold=None,
                       sigma=None,
                       mode='soft',
                       wavelet_levels=None):
    """Perform wavelet thresholding.

    Parameters
    ----------
    image : ndarray (2d or 3d) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    wavelet : string
        The type of wavelet to perform. Can be any of the options
        pywt.wavelist outputs. For example, this may be any of ``{db1, db2,
        db3, db4, haar}``.
    method : {'BayesShrink', 'VisuShrink'}, optional
        Thresholding method to be used. The currently supported methods are
        "BayesShrink" [1]_ and "VisuShrink" [2]_. If it is set to None, a
        user-specified ``threshold`` must be supplied instead.
    threshold : float, optional
        The thresholding value to apply during wavelet coefficient
        thresholding. Only used when ``method`` is None; if a ``method`` is
        given, the threshold(s) are estimated by that method and this value
        is ignored with a warning.
    sigma : float, optional
        The standard deviation of the noise. The noise is estimated when sigma
        is None (the default) by the method in [2]_.
    mode : {'soft', 'hard'}, optional
        An optional argument to choose the type of denoising performed. Note
        that choosing soft thresholding given additive noise finds the
        best approximation of the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use. The default is
        three less than the maximum number of possible decomposition levels,
        but at least 1.

    Returns
    -------
    out : ndarray
        Denoised image.

    Raises
    ------
    ValueError
        If neither ``method`` nor ``threshold`` is given, or ``method`` is
        not recognized.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image Processing,
           IEEE Transactions on 9.9 (2000): 1532-1546.
           :DOI:`10.1109/83.862633`
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           :DOI:`10.1093/biomet/81.3.425`
    """
    wavelet = pywt.Wavelet(wavelet)
    if not wavelet.orthogonal:
        warn(("Wavelet thresholding was designed for use with orthogonal "
              "wavelets. For nonorthogonal wavelets such as {}, results are "
              "likely to be suboptimal.").format(wavelet.name))

    # original_extent is used to workaround PyWavelets issue #80
    # odd-sized input results in an image with 1 extra sample after waverecn
    original_extent = tuple(slice(s) for s in image.shape)

    # Determine the number of wavelet decomposition levels
    if wavelet_levels is None:
        # Maximum number of possible levels for image, minus the three
        # coarsest scales, but never fewer than one level.
        wavelet_levels = pywt.dwtn_max_level(image.shape, wavelet)
        wavelet_levels = max(wavelet_levels - 3, 1)

    coeffs = pywt.wavedecn(image, wavelet=wavelet, level=wavelet_levels)
    # Detail coefficients at each decomposition level
    dcoeffs = coeffs[1:]

    if sigma is None:
        # Estimate the noise via the method in [2]_ from the finest-scale
        # all-detail ('dd...d') band.
        detail_coeffs = dcoeffs[-1]['d' * image.ndim]
        sigma = _sigma_est_dwt(detail_coeffs, distribution='Gaussian')

    if method is not None and threshold is not None:
        warn(("Thresholding method {} selected.  The user-specified threshold "
              "will be ignored.").format(method))

    if method is not None:
        # A selected method always determines the threshold, as the warning
        # above promises.
        if method == "BayesShrink":
            # The BayesShrink thresholds from [1]_ in docstring
            var = sigma**2
            threshold = [{
                key: _bayes_thresh(level[key], var)
                for key in level
            } for level in dcoeffs]
        elif method == "VisuShrink":
            # The VisuShrink thresholds from [2]_ in docstring
            threshold = _universal_thresh(image, sigma)
        else:
            raise ValueError("Unrecognized method: {}".format(method))
    elif threshold is None:
        raise ValueError(
            "If method is None, a threshold must be provided.")

    if np.isscalar(threshold):
        # A single threshold for all coefficient arrays
        denoised_detail = [{
            key: pywt.threshold(level[key], value=threshold, mode=mode)
            for key in level
        } for level in dcoeffs]
    else:
        # Dict of unique threshold coefficients for each detail coeff. array
        denoised_detail = [{
            key: pywt.threshold(level[key], value=thresh[key], mode=mode)
            for key in level
        } for thresh, level in zip(threshold, dcoeffs)]
    denoised_coeffs = [coeffs[0]] + denoised_detail
    return pywt.waverecn(denoised_coeffs, wavelet)[original_extent]
def wavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,
    mode: str = "reflect",
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor,
                                    torch.Tensor]]]:
    """Non-separated two-dimensional wavelet transform.

    Args:
        data (torch.Tensor): The input data tensor of shape
            [batch_size, 1, height, width].
            2d inputs are interpreted as [height, width],
            3d inputs are interpreted as [batch_size, height, width].
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        level (int): The number of desired scales. When None,
            pywt.dwtn_max_level decides. Defaults to None.
        mode (str): The padding mode, i.e. zero or reflect.
            Defaults to reflect.

    Returns:
        list: A list containing the wavelet coefficients.
              The coefficients are in pywt order. That is:
              [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)] .
              A denotes approximation, H horizontal, V vertical
              and D diagonal coefficients.

    Examples::
        >>> import torch
        >>> import ptwt, pywt
        >>> import numpy as np
        >>> import scipy.misc
        >>> face = np.transpose(scipy.misc.face(),
                                [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face).unsqueeze(1)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                                         level=2, mode="constant")

    """
    # Bring lower-rank inputs into the [batch, 1, height, width] layout.
    if data.dim() == 2:
        data = data[None, None]
    elif data.dim() == 3:
        data = data[:, None]

    wavelet = _as_wavelet(wavelet)
    lo, hi, _, _ = get_filter_tensors(wavelet,
                                      flip=True,
                                      device=data.device,
                                      dtype=data.dtype)
    filt = construct_2d_filt(lo=lo, hi=hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    approx = data
    details: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
    for scale in range(level):
        padded = fwt_pad2d(approx, wavelet, level=scale, mode=mode)
        sub_bands = torch.nn.functional.conv2d(padded, filt, stride=2)
        approx, lh, hl, hh = torch.split(sub_bands, 1, 1)
        details.append((lh, hl, hh))

    # Coarsest approximation first, then detail tuples from coarse to fine.
    result: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor,
                                           torch.Tensor]]] = [approx]
    result.extend(reversed(details))
    return result
Пример #15
0
def main(args):
    """Denoise an image with a thresholded 2D DWT and report quality metrics.

    Reads ``args.image``, decomposes each colour channel with the wavelet named
    by ``args.wavelet``, thresholds the coefficients, reconstructs each channel
    with the inverse DWT, saves the result and prints the Pearson correlation,
    normalised RMSE and compression ratio. If ``args.wavelet == 'test'`` a fixed
    list of wavelets is evaluated with ``test_wavelets`` instead.

    Exits with status 1 when ``args.image`` is not an existing file.

    :param args: Parsed command line arguments; the attributes read here are
        ``image``, ``threshold``, ``mode``, ``levels``, ``wavelets``,
        ``wavelet`` and ``grayscale``.
    """
    # Dirty solution: expose the arguments to helpers through a module global.
    global _args
    _args = args
    # Parameters for compression and denoising
    threshold_value = args.threshold
    threshold_mode = args.mode
    dec_level = args.levels

    if not os.path.isfile(args.image):
        print('Given file <{}> does not exist!'.format(args.image))
        sys.exit(1)

    if args.wavelets:
        print('Available wavelets:')
        print(pywt.wavelist(kind='discrete'))

    image_file = args.image
    if args.wavelet == 'test':
        # Test some discrete wavelets in pywt (dmey excluded because memory error?)
        wavelets = ['bior1.3', 'haar', 'db4', 'coif1', 'sym2']
        test_wavelets(image_file, threshold_value, dec_level, threshold_mode,
                      wavelets)
        return

    # Read the image (via OpenCV). cv2.split/cv2.merge below require the
    # channels on the last axis, i.e. a (height, width[, channels]) layout.
    image_data = read_image(image_file, args.grayscale)
    wavelet = pywt.Wavelet(args.wavelet)
    # If no level is given, use the maximum level pywt allows for the spatial
    # (height, width) dimensions. Note: image_data[0] would be the first image
    # row, whose shape is (width, channels) for colour images.
    dec_level = int(args.levels) if args.levels else pywt.dwtn_max_level(
        image_data.shape[:2], wavelet)
    # Split the image into channels (grayscale has only one) and process each.
    # Note: by default OpenCV reads images as BGR / BGRA, not RGB.
    denoised_image = []
    for channel in cv2.split(image_data):
        # 2D DWT of the colour channel
        dwt_data = dwt(channel, wavelet, dl=dec_level)

        # Threshold the coefficients to remove noise
        denoised_data = threshold(dwt_data,
                                  threshold_value,
                                  mode=threshold_mode)

        # Reconstruct the processed channel with the iDWT and keep it
        denoised_image.append(idwt(denoised_data, wavelet))

    # With several channels (not grayscale) merge them back into one image
    if len(denoised_image) > 1:
        new_image = cv2.merge(denoised_image)
    else:
        new_image = denoised_image[0]

    # Pearson correlation
    pr = pearson(image_data, new_image)

    # Normalised RMSE
    r = calc_nrmse(image_data.ravel(), new_image.ravel())

    # Save the image as <name>_new
    new_file = save_image(image_file, new_image)

    # Compute the compression ratio
    cr = calc_compression(image_file, new_file)

    # Print the results
    print_metrics(pr, r, cr, new_file, threshold_value, threshold_mode,
                  dec_level, wavelet)
Пример #16
0
def compute_depth_map(
    depth_cues,
    iterations=500,
    lambda_tv=2.0,
    lambda_d2=0.05,
    lambda_wl=None,
    use_defocus=1.0,
    use_correspondence=1.0,
    use_xcorrelation=0.0,
):
    """Computes a depth map from the given depth cues.

    This depth map is based on the procedure from:

    M. W. Tao, et al., "Depth from combining defocus and correspondence using
    light-field cameras," in Proceedings of the IEEE International Conference on
    Computer Vision, 2013, pp. 673–680.

    Each iteration updates the dual variables of every active term (projecting
    them back to magnitude <= 1), takes a primal step on the depth with the
    per-pixel step size ``tau``, and then applies an over-relaxation step.

    :param depth_cues: The depth cues. The keys ``confidence_defocus``,
        ``depth_defocus``, ``confidence_correspondence``,
        ``depth_correspondence``, ``confidence_xcorrelation`` and
        ``depth_xcorrelation`` must all be present; a cue whose confidence or
        depth array is empty is disabled with a warning.
    :type depth_cues: dict
    :param iterations: Number of iterations, defaults to 500
    :type iterations: int, optional
    :param lambda_tv: Lambda value of the TV term, defaults to 2.0
    :type lambda_tv: float, optional
    :param lambda_d2: Lambda value of the (Laplacian) smoothing term,
        defaults to 0.05; ``None`` disables the term
    :type lambda_d2: float, optional
    :param lambda_wl: Lambda value of the wavelet term (stationary 'sym4'
        wavelets, at most 2 levels), defaults to None (disabled)
    :type lambda_wl: float, optional
    :param use_defocus: Weight of defocus cues, clamped to [0, 1], defaults to 1.0
    :type use_defocus: float, optional
    :param use_correspondence: Weight of correspondence cues, clamped to [0, 1],
        defaults to 1.0
    :type use_correspondence: float, optional
    :param use_xcorrelation: Weight of the cross-correlation cues, clamped to
        [0, 1], defaults to 0.0
    :type use_xcorrelation: float, optional

    :raises ValueError: In case of requested wavelet regularization but not
        available, or when none of the defocus, correspondence and
        cross-correlation cues can be used

    :returns: The depth map, with the shape and dtype of the depth array of the
        first enabled cue (defocus, then correspondence, then cross-correlation)
    :rtype: `numpy.array_like`
    """
    # This guard also covers a missing pywt: whenever lambda_wl is set, pywt
    # is guaranteed to be available below, so no later fallback is needed.
    if not (lambda_wl is None or (has_pywt and use_swtn)):
        raise ValueError("Wavelet regularization requested but not available")

    # Clamp the cue weights to [0, 1] (fmax/fmin also map NaN to the bound).
    use_defocus = np.fmin(np.fmax(use_defocus, 0.0), 1.0)
    use_correspondence = np.fmin(np.fmax(use_correspondence, 0.0), 1.0)
    use_xcorrelation = np.fmin(np.fmax(use_xcorrelation, 0.0), 1.0)

    W_d = depth_cues["confidence_defocus"]
    a_d = depth_cues["depth_defocus"]

    W_c = depth_cues["confidence_correspondence"]
    a_c = depth_cues["depth_correspondence"]

    W_x = depth_cues["confidence_xcorrelation"]
    a_x = depth_cues["depth_xcorrelation"]

    # A requested cue without data is silently unusable: disable it and warn.
    if use_defocus > 0 and (W_d.size == 0 or a_d.size == 0):
        use_defocus = 0
        warnings.warn("Defocusing parameters were not passed, disabling their use")

    if use_correspondence > 0 and (W_c.size == 0 or a_c.size == 0):
        use_correspondence = 0
        warnings.warn("Correspondence parameters were not passed, disabling their use")

    if use_xcorrelation > 0 and (W_x.size == 0 or a_x.size == 0):
        use_xcorrelation = 0
        warnings.warn("Cross-correlation parameters were not passed, disabling their use")

    # Output geometry/dtype follow the first enabled cue.
    if use_defocus:
        img_size = a_d.shape
        data_type = a_d.dtype
    elif use_correspondence:
        img_size = a_c.shape
        data_type = a_c.dtype
    elif use_xcorrelation:
        img_size = a_x.shape
        data_type = a_x.dtype
    else:
        raise ValueError(
            "Cannot proceed unless at least one of Defocus, Correspondence, "
            "and Cross-correlation cues can be used"
        )

    depth = np.zeros(img_size, dtype=data_type)
    depth_it = depth

    # Dual variables are allocated only for the active terms; tau accumulates
    # the (per-pixel) operator norms of those terms and is inverted at the end.
    q_g = np.zeros(np.concatenate(((2,), img_size)), dtype=data_type)
    tau = 4 * lambda_tv
    if lambda_d2 is not None:
        q_l = np.zeros(img_size, dtype=data_type)
        tau += 8 * lambda_d2
    if use_defocus > 0:
        q_d = np.zeros(img_size, dtype=data_type)
        tau += W_d
    if use_correspondence > 0:
        q_c = np.zeros(img_size, dtype=data_type)
        tau += W_c
    if use_xcorrelation > 0:
        q_x = np.zeros(img_size, dtype=data_type)
        tau += W_x
    if lambda_wl is not None:
        wl_type = "sym4"
        wl_lvl = np.fmin(pywt.dwtn_max_level(img_size, wl_type), 2)
        print("Wavelets selected! Wl type: %s, Wl lvl %d" % (wl_type, wl_lvl))
        q_wl = pywt.swtn(depth, wl_type, wl_lvl)
        tau += lambda_wl * (2 ** wl_lvl)
        sigma_wl = 1 / (2 ** np.arange(wl_lvl, 0, -1))
    tau = 1 / tau

    for _ in range(iterations):
        # TV term: gradient ascent on q_g, then project onto the unit l2 ball.
        (d0, d1) = _gradient2(depth_it)
        d_2 = np.stack((d0, d1)) / 2
        q_g += d_2
        grad_l2_norm = np.fmax(1, np.sqrt(np.sum(q_g ** 2, axis=0)))
        q_g /= grad_l2_norm

        update = -lambda_tv * _divergence2(q_g[0, :, :], q_g[1, :, :])
        if lambda_d2 is not None:
            l_dep = _laplacian2(depth_it)
            q_l += l_dep / 8
            q_l /= np.fmax(1, np.abs(q_l))

            update += lambda_d2 * _laplacian2(q_l)

        if use_defocus > 0:
            q_d += depth_it - a_d
            q_d /= np.fmax(1, np.abs(q_d))

            update += use_defocus * W_d * q_d

        if use_correspondence > 0:
            q_c += depth_it - a_c
            q_c /= np.fmax(1, np.abs(q_c))

            update += use_correspondence * W_c * q_c

        if use_xcorrelation > 0:
            q_x += depth_it - a_x
            q_x /= np.fmax(1, np.abs(q_x))

            update += use_xcorrelation * W_x * q_x

        if lambda_wl is not None:
            d = pywt.swtn(depth_it, wl_type, wl_lvl)
            for ii_l in range(wl_lvl):
                for k in q_wl[ii_l].keys():
                    q_wl[ii_l][k] += d[ii_l][k] * sigma_wl[ii_l]
                    q_wl[ii_l][k] /= np.fmax(1, np.abs(q_wl[ii_l][k]))
            update += lambda_wl * pywt.iswtn(q_wl, wl_type)

        # Primal step followed by over-relaxation (depth_it = 2*new - old).
        depth_new = depth - update * tau
        depth_it = depth_new + (depth_new - depth)
        depth = depth_new

    return depth