def wavedec2(data: np.array, wavelet: JaxWavelet, level: int = None) -> list:
    """Compute the two dimensional wavelet analysis transform on the last
    two dimensions of the input data array.

    Args:
        data (np.array): Jax array containing the data to be transformed.
            Assumed shape: [batch size, channels, height, width].
        wavelet (JaxWavelet): A namedtuple containing the filters for the
            transformation.
        level (int, optional): The max level to be used, if not set as many
            levels as possible will be used. Defaults to None.

    Returns:
        list: The wavelet coefficients in a nested list, ordered coarse to
        fine: [ll, (lh, hl, hh) of the coarsest level, ...,
        (lh, hl, hh) of the finest level].
    """
    dec_lo, dec_hi, _, _ = get_filter_arrays(wavelet, flip=True)
    dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        # pywt needs a Wavelet object, so build one from the filter bank.
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]],
                                    pywt.Wavelet('MyWavelet', wavelet))

    result_lst = []
    res_ll = data
    for _ in range(level):
        res_ll = fwt_pad2d(res_ll, wavelet)
        res = jax.lax.conv_general_dilated(
            lhs=res_ll,  # lhs = NCHW image tensor
            rhs=dec_filt,  # rhs = OIHW conv kernel tensor
            padding='VALID',
            window_strides=[2, 2],
            dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        )
        # The four 2D filter responses come out stacked on the channel axis.
        res_ll, res_lh, res_hl, res_hh = np.split(res, 4, 1)
        result_lst.append((res_lh, res_hl, res_hh))
    result_lst.append(res_ll)
    # Collected fine-to-coarse; callers expect the pywt coarse-to-fine order.
    result_lst.reverse()
    return result_lst
def conv_fwt_2d(data, wavelet, scales: int = None) -> list:
    """Multi-scale 2D wavelet analysis using one non-separable 2D conv.

    Args:
        data (torch.tensor): Input of shape [batch_size, 1, height, width].
        wavelet (WaveletFilter): The wavelet object to be used.
        scales (int, optional): Number of scales to compute. When None, the
            value from ``pywt.dwtn_max_level`` is used. Defaults to None.

    Returns:
        [list]: Coefficients ordered coarse to fine:
        [ll, (lh, hl, hh) of the coarsest scale, ...,
        (lh, hl, hh) of the finest scale].
    """
    lo, hi, _, _ = get_filter_tensors(wavelet, flip=True, device=data.device)
    filt_2d = construct_2d_filt(lo=lo, hi=hi)

    n_scales = scales
    if n_scales is None:
        n_scales = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]],
                                       wavelet)

    details = []
    approx = data
    for _ in range(n_scales):
        padded = fwt_pad2d(approx, wavelet)
        stacked = torch.nn.functional.conv2d(padded, filt_2d, stride=2)
        # One output channel per sub-band: ll, lh, hl, hh.
        approx, lh, hl, hh = torch.split(stacked, 1, 1)
        details.append((lh, hl, hh))

    return [approx] + details[::-1]
def sep_conv_fwt_2d(data, wavelet, scales: int = None) -> list:
    """Separable two dimensional wavelet transform.

    Every scale filters with the 1D analysis pair (stride 2) along the last
    axis, then along the second-to-last axis, using ``conv1d``.

    Args:
        data (torch.tensor): Input tensor. The code uses ``data.shape[0]``
            and ``data.shape[1]`` as leading (batch/channel) dimensions and
            the last two axes as height and width, i.e. it indexes a 4D
            tensor [batch_size, channels, height, width].
            NOTE(review): an earlier docstring said [batch_size, height,
            width]; confirm against callers.
        wavelet (util.WaveletFilter or pywt.wavelet): The wavelet object.
        scales (int, optional): The number of decomposition scales.
            Defaults to None, in which case ``pywt.dwtn_max_level`` is used.

    Returns:
        [list]: Coefficients ordered coarse to fine:
        [ll, (lh, hl, hh) of the coarsest scale, ...,
        (lh, hl, hh) of the finest scale].
    """
    ds = data.shape
    dec_lo, dec_hi, _, _ = get_filter_tensors(wavelet, flip=True, device=data.device)
    # Stack low- and high-pass filters as the two output channels of a conv1d kernel.
    filt = torch.stack([dec_lo, dec_hi], 0)
    if scales is None:
        scales = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)
    result_lst = []
    res_ll = data
    for s in range(scales):
        res_ll = fwt_pad2d(res_ll, wavelet)
        # Fold the two leading dims together: [ds[0] * ds[1], H, W].
        res_ll = res_ll.reshape(
            [ds[0] * ds[1], res_ll.shape[2], res_ll.shape[3]])
        rll_s = res_ll.shape
        # Treat every row as an independent 1D signal: [rows, 1, W].
        res_llr = res_ll.reshape([rll_s[0] * rll_s[1], rll_s[2]]).unsqueeze(1)
        # Filter along the last (width) axis; channel 0 is low, 1 is high.
        res = torch.nn.functional.conv1d(res_llr, filt, stride=2)
        res_l, res_h = torch.split(res, 1, 1)
        # NOTE(review): these reshapes assume conv1d yields rll_s[2] // 2
        # samples per row, which depends on the padding added by fwt_pad2d.
        res_l = res_l.reshape([rll_s[0], rll_s[1], rll_s[2] // 2])
        res_h = res_h.reshape([rll_s[0], rll_s[1], rll_s[2] // 2])
        # Swap the last two axes so the height axis becomes the filtered one.
        res_lt = res_l.permute(0, 2, 1)
        res_ht = res_h.permute(0, 2, 1)
        res_ltr = res_lt.reshape([-1, res_lt.shape[-1]]).unsqueeze(1)
        res_htr = res_ht.reshape([-1, res_ht.shape[-1]]).unsqueeze(1)
        # Filter along the height axis.
        res_l2 = torch.nn.functional.conv1d(res_ltr, filt, stride=2)
        res_h2 = torch.nn.functional.conv1d(res_htr, filt, stride=2)
        res_ll, res_lh = torch.split(res_l2, 1, 1)
        res_hl, res_hh = torch.split(res_h2, 1, 1)
        # Still in transposed layout [rll_s[0], W // 2, H // 2] ...
        res_llr = res_ll.reshape(rll_s[0], rll_s[2] // 2, rll_s[1] // 2)
        res_lhr = res_lh.reshape(rll_s[0], rll_s[2] // 2, rll_s[1] // 2)
        res_hlr = res_hl.reshape(rll_s[0], rll_s[2] // 2, rll_s[1] // 2)
        res_hhr = res_hh.reshape(rll_s[0], rll_s[2] // 2, rll_s[1] // 2)
        # ... so transpose back to [rll_s[0], H // 2, W // 2].
        res_llrp = res_llr.permute([0, 2, 1])
        res_lhrp = res_lhr.permute([0, 2, 1])
        res_hlrp = res_hlr.permute([0, 2, 1])
        res_hhrp = res_hhr.permute([0, 2, 1])
        # Separable equivalent of a conv2d with the 2D filter bank followed by
        # a channel split into ll, lh, hl, hh.
        result_lst.append(
            (res_lhrp.reshape(ds[0], ds[1], rll_s[1] // 2, rll_s[2] // 2),
             res_hlrp.reshape(ds[0], ds[1], rll_s[1] // 2, rll_s[2] // 2),
             res_hhrp.reshape(ds[0], ds[1], rll_s[1] // 2, rll_s[2] // 2)))
        # Unfold back to [ds[0], ds[1], H // 2, W // 2] for the next scale.
        res_ll = res_llrp.reshape(ds[0], ds[1], rll_s[1] // 2, rll_s[2] // 2)
    # NOTE(review): if scales == 0 the loop never runs and res_llrp / rll_s
    # are undefined here (NameError); this value equals the final res_ll.
    result_lst.append(
        res_llrp.reshape(ds[0], ds[1], rll_s[1] // 2, rll_s[2] // 2))
    # Collected fine-to-coarse; return the pywt coarse-to-fine order.
    return result_lst[::-1]
def update_level(self):
    """Refresh the level spin box for the selected wavelet, then recompute.

    The spin box range becomes 1..max level for the current image and its
    value is set to half that maximum. Signals are blocked meanwhile so these
    programmatic changes do not fire handlers; ``compute_dwt`` runs once at
    the end.
    """
    name = self.wavelet_combo.currentText()
    # NOTE(review): the last image axis is excluded, presumably the colour
    # channel axis -- confirm for grayscale images.
    top = pywt.dwtn_max_level(self.image.shape[:-1], name)
    spin = self.level_spin
    spin.blockSignals(True)
    spin.setRange(1, top)
    spin.setValue(top // 2)
    spin.blockSignals(False)
    self.compute_dwt()
def inverse_wavelet_transform(vres, inv_filters=None, output_shape=None, levels=None):
    """Multi-level 3D inverse discrete wavelet transform built from Keras ops.

    Args:
        vres: 5D coefficient tensor (presumably the output of
            ``wavelet_transform``). Axes 1-3 are read as (t, h, w) and the
            tensor is reshaped to (-1, t // 2, h // 2, w // 2, 8) sub-bands.
        inv_filters: Keras constant with the 8 separable 3D synthesis
            filters. Defaults to filters built from the ``db4`` wavelet.
        output_shape: Optional shape used to crop the reconstruction along
            axes 1-3 (e.g. to undo power-of-two padding). When None, the
            full, uncropped reconstruction is returned.
        levels: Number of levels to invert. Defaults to
            ``pywt.dwtn_max_level`` of axes 1-3 with ``db4``.

    Returns:
        The reconstructed tensor, cropped to ``output_shape`` if given.
    """
    if inv_filters is None:
        w = pywt.Wavelet('db4')
        rec_hi = np.array(w.rec_hi)
        rec_lo = np.array(w.rec_lo)
        bank = (rec_lo, rec_hi)
        # All 8 products a (x) b (x) c of the 1D filters, ordered lll, llh,
        # lhl, lhh, hll, hlh, hhl, hhh; `a` varies along the last kernel axis.
        inv_filters = np.stack([
            a[None, None, :] * b[None, :, None] * c[:, None, None]
            for a in bank for b in bank for c in bank
        ]).transpose((1, 2, 3, 0))[:, :, :, None, :]
        inv_filters = K.constant(inv_filters)
    if levels is None:
        levels = pywt.dwtn_max_level(K.int_shape(vres)[1:4], 'db4')
    t = vres.shape[1]
    h = vres.shape[2]
    w = vres.shape[3]
    # Unstack the coefficients into 8 sub-band channels at half resolution.
    res = K.reshape(vres, (-1, t // 2, h // 2, w // 2, 8))
    if levels > 1:
        # Sub-band 0 holds the coarser levels: reconstruct it recursively.
        res = K.concatenate([
            inverse_wavelet_transform(
                res[:, :, :, :, :1],
                inv_filters,
                output_shape=(K.shape(vres)[0], K.shape(vres)[1] // 2,
                              K.shape(vres)[2] // 2, K.shape(vres)[3] // 2,
                              K.shape(vres)[4]),
                levels=(levels - 1)),
            res[:, :, :, :, 1:]
        ], axis=-1)
    res = K.conv3d_transpose(res, inv_filters, output_shape=K.shape(vres),
                             strides=(2, 2, 2), padding='same')
    if output_shape is None:
        # Previously this crashed with TypeError (None is not subscriptable).
        return res
    return res[:, :output_shape[1], :output_shape[2], :output_shape[3], :]
def wavelet_transform(img, filters=None, levels=None):
    """Multi-level 3D discrete wavelet transform built from Keras ops.

    Args:
        img: 5D tensor whose axes 1-3 are transformed. The filters have a
            single input channel, so the last axis is expected to be 1.
        filters: Keras constant with the 8 separable 3D analysis filters.
            Defaults to filters built from the ``db4`` wavelet.
        levels: Number of decomposition levels. When None (the top-level
            call), axes 1-3 are zero-padded up to the next power of two and
            the level count is ``pywt.dwtn_max_level`` of the padded shape
            with ``db4``. Recursive calls pass ``levels`` and skip padding.

    Returns:
        Tensor reshaped to (-1, t, h, w, 1), where (t, h, w) are axes 1-3 of
        the (padded) input; sub-band 0 of each level holds the coarser levels.
    """
    if levels is None:
        # Pad so every level can halve each spatial axis exactly.
        shape = K.int_shape(img)
        pads = [(0, 2**int(np.ceil(np.log2(shape[ax]))) - shape[ax])
                for ax in (1, 2, 3)]
        vimg = tf.pad(img, [(0, 0)] + pads + [(0, 0)])
    else:
        vimg = img
    if filters is None:
        w = pywt.Wavelet('db4')
        # Reversed filters, presumably so conv3d's correlation acts as a
        # convolution with the analysis filters.
        dec_hi = np.array(w.dec_hi[::-1])
        dec_lo = np.array(w.dec_lo[::-1])
        bank = (dec_lo, dec_hi)
        # All 8 products a (x) b (x) c of the 1D filters, ordered lll, llh,
        # lhl, lhh, hll, hlh, hhl, hhh; `a` varies along the last kernel axis.
        filters = np.stack([
            a[None, None, :] * b[None, :, None] * c[:, None, None]
            for a in bank for b in bank for c in bank
        ]).transpose((1, 2, 3, 0))[:, :, :, None, :]
        filters = K.constant(filters)
    if levels is None:
        levels = pywt.dwtn_max_level(K.int_shape(vimg)[1:4], 'db4')
    t = vimg.shape[1]
    h = vimg.shape[2]
    w = vimg.shape[3]
    res = K.conv3d(vimg, filters, strides=(2, 2, 2), padding='same')
    if levels > 1:
        # Recurse on the low-pass sub-band only.
        res = K.concatenate([
            wavelet_transform(res[:, :, :, :, :1], filters, levels=(levels - 1)),
            res[:, :, :, :, 1:]
        ], axis=-1)
    res = K.reshape(res, (-1, t, h, w, 1))
    return res
def test_dwtn_max_level():
    """The predicted dwtn_max_level equals the depth wavedecn actually reaches."""
    wavelets = [pywt.Wavelet('db2'), 'sym8']
    shapes = [(33, ), (64, 32), (1, 15, 30)]
    axes_choices = [None, 0, -1]
    for wav in wavelets:
        for shape in shapes:
            for ax in axes_choices:
                for mode in pywt.Modes.modes:
                    decomposition = pywt.wavedecn(np.ones(shape), wav,
                                                  mode=mode, axes=ax)
                    predicted = pywt.dwtn_max_level(shape, wav, ax)
                    # Entry 0 is the approximation; the rest are detail levels.
                    detail_levels = decomposition[1:]
                    assert_equal(len(detail_levels), predicted)
def cdf97_2d_forward(x, level, axes=(-2, -1)):
    '''Forward 2D Cohen–Daubechies–Feauveau 9/7 wavelet.

    Parameters
    ----------
    x : array_like
        2D signal.
    level : int
        Decomposition level. Values above the maximum achievable level for
        ``x`` along ``axes`` are clamped to it, with a warning.
    axes : tuple, optional
        Axes to perform wavelet decomposition across.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
    locations : list
        Indices telling us how we stitched it together so we can take it
        back apart.

    Notes
    -----
    Returns transform, same shape as input, with locations.  Locations is
    a list of indices instructing cdf97_2d_inverse where the coefficients
    for each block are located.

    Biorthogonal 4/4 is the same as CDF 9/7 according to wikipedia [1]_.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/
           Cohen%E2%80%93Daubechies%E2%80%93Feauveau_wavelet#Numbering
    '''
    # Clamp the requested depth to what the input supports.
    deepest = pywt.dwtn_max_level(x.shape, 'bior4.4', axes=axes)
    if level > deepest:
        warnings.warn('Level %d cannot be achieved, using max level=%d!'
                      % (level, deepest))
        level = deepest

    # 'periodization' keeps coefficient shapes aligned for stitching.
    coeffs = pywt.wavedec2(x, wavelet='bior4.4', mode='periodization',
                           level=level, axes=axes)
    return combine_chunks(coeffs, x.shape, x.dtype)
def wavelet_forward(x, wavelet, mode='symmetric', level=None, axes=(-2, -1)):
    '''Wrapper for the multilevel 2D discrete wavelet transform.

    Parameters
    ----------
    x : array_like
        Input data.
    wavelet : str
        Wavelet to use.
    mode : str, optional
        Signal extension mode.
    level : int, optional
        Decomposition level (must be >= 0). Values above the maximum
        achievable level along ``axes`` are clamped to it, with a warning.
    axes : tuple, optional
        Axes over which to compute the DWT.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
    locations : list
        Indices telling us how we stitched it together so we can take it
        back apart.

    Notes
    -----
    See PyWavelets documentation on pywt.wavedec2() for more information.

    If level=None (default) then it will be calculated using the
    dwt_max_level function.
    '''
    # Only the transformed axes limit the depth; without `axes`, a short
    # non-transformed axis (e.g. a batch axis) would wrongly cap the level.
    max_level = pywt.dwtn_max_level(x.shape, wavelet, axes=axes)
    if level is not None and level > max_level:
        warnings.warn('Level %d cannot be achieved, using max level=%d!'
                      % (level, max_level))
        level = max_level

    wvlt = pywt.wavedec2(x, wavelet, mode=mode, level=level, axes=axes)

    # wvlt is a list of arrays/tuples; stitch them into one array.
    return combine_chunks(wvlt, x.shape, x.dtype)
def estimate_sparsity(im, wavelet, nres=None, eps=1e-1):
    '''Python implementation of csl_compute_sparsity_of_image from
    https://bitbucket.org/vegarant/cslib.git

    Parameters:
        im: 2d numpy array. Can be complex-valued; its magnitude is used
            and rescaled to [0, 1] before the transform.
        wavelet: Name of wavelet or a pywt.Wavelet (passed through to pywt).
        nres: Optional. Number of wavelet levels. Default given by
            `pywt.dwtn_max_level`.
        eps: Tolerance. Everything below eps in abs. value will be treated
            as zero.

    Returns:
        List with the number of coefficients above eps per level, beginning
        with the low-resolution approximation (upper left corner).
    '''
    im = np.abs(im)
    M = im.max()
    m = im.min()
    if M == m:
        # Constant image: its normalization is all zeros. Avoids the 0/0
        # NaNs (and RuntimeWarning) of the general formula; counts are the same.
        im = np.zeros(im.shape)
    else:
        im = (im - m) / (M - m)

    # Default to max number of levels
    if nres is None:
        nres = pywt.dwtn_max_level(im.shape, wavelet)

    coeffs = pywt.wavedec2(im, wavelet, 'periodization', nres)

    sparsities = [int(np.count_nonzero(np.abs(coeffs[0]) > eps))]
    for details in coeffs[1:]:
        # Sum over the (horizontal, vertical, diagonal) detail arrays.
        sparsities.append(
            sum(int(np.count_nonzero(np.abs(d) > eps)) for d in details))

    return sparsities
def __init__(self, space, wavelet, nlevels, variant, pad_mode='constant',
             pad_const=0, impl='pywt', axes=None):
    """Initialize a new instance.

    Parameters
    ----------
    space : `DiscreteLp`
        Domain of the forward wavelet transform (the "image domain").
        In the case of ``variant in ('inverse', 'adjoint')``, this
        space is the range of the operator.
    wavelet : string or `pywt.Wavelet`
        Specification of the wavelet to be used in the transform.
        If a string is given, it is converted to a `pywt.Wavelet`.
        Use `pywt.wavelist` to get a list of available wavelets.

        Possible wavelet families are:

        ``'haar'``: Haar

        ``'db'``: Daubechies

        ``'sym'``: Symlets

        ``'coif'``: Coiflets

        ``'bior'``: Biorthogonal

        ``'rbio'``: Reverse biorthogonal

        ``'dmey'``: Discrete FIR approximation of the Meyer wavelet

    nlevels : positive int or None
        Number of scaling levels to be used in the decomposition. The
        maximum number of levels can be calculated with
        `pywt.dwtn_max_level`. ``None`` means: use the maximum number of
        levels for ``space.shape`` along ``axes``.
    variant : {'forward', 'inverse', 'adjoint'}
        Wavelet transform variant to be created.
    pad_mode : string, optional
        Method to be used to extend the signal.

        ``'constant'``: Fill with ``pad_const``.

        ``'symmetric'``: Reflect at the boundaries, not repeating the
        outmost values.

        ``'periodic'``: Fill in values from the other side, keeping the
        order.

        ``'order0'``: Extend constantly with the outmost values
        (ensures continuity).

        ``'order1'``: Extend with constant slope (ensures continuity of
        the first derivative). This requires at least 2 values along
        each axis where padding is applied.

        ``'pywt_per'``: like ``'periodic'``-padding but gives the smallest
        possible number of decomposition coefficients.
        Only available with ``impl='pywt'``, See ``pywt.Modes.modes``.

        ``'reflect'``: Reflect at the boundary, without repeating the
        outmost values.

        ``'antisymmetric'``: Anti-symmetric variant of ``symmetric``.

        ``'antireflect'``: Anti-symmetric variant of ``reflect``.

        For reference, the following table compares the naming conventions
        for the modes in ODL vs. PyWavelets::

            ======================= ==================
             ODL                     PyWavelets
            ======================= ==================
             symmetric               symmetric
             reflect                 reflect
             order1                  smooth
             order0                  constant
             constant, pad_const=0   zero
             periodic                periodic
             pywt_per                periodization
             antisymmetric           antisymmetric
             antireflect             antireflect
            ======================= ==================

        See `signal extension modes`_ for an illustration of the modes
        (under the PyWavelets naming conventions).
    pad_const : float, optional
        Constant value to use if ``pad_mode == 'constant'``. Ignored
        otherwise. Constants other than 0 are not supported by the
        ``pywt`` back-end.
    impl : {'pywt'}, optional
        Back-end for the wavelet transform.
    axes : sequence of ints, optional
        Axes over which the DWT is performed. The default value of
        ``None`` corresponds to all axes. A single int selects one axis.
        When not all axes are included this is analogous to a batch
        transform in ``len(axes)`` dimensions looped over the
        non-transformed axes. In other words, filtering and decimation
        does not occur along any axes not in ``axes``.

    Raises
    ------
    TypeError
        If ``space`` is not a `DiscreteLp`.
    ValueError
        If ``impl`` or ``variant`` is unsupported, ``nlevels`` is not an
        integer, or more axes than ``space.ndim`` are given.

    References
    ----------
    .. _signal extension modes:
       https://pywavelets.readthedocs.io/en/latest/ref/signal-extension-modes.html
    """
    if not isinstance(space, DiscreteLp):
        raise TypeError('`space` {!r} is not a `DiscreteLp` instance.'
                        ''.format(space))

    # Parse and validate the back-end once (this was previously duplicated).
    self.__impl, impl_in = str(impl).lower(), impl
    if self.impl not in _SUPPORTED_WAVELET_IMPLS:
        raise ValueError("`impl` '{}' not supported".format(impl_in))

    if axes is None:
        axes = tuple(range(space.ndim))
    elif np.isscalar(axes):
        axes = (axes, )
    elif len(axes) > space.ndim:
        raise ValueError("Too many axes.")
    self.axes = tuple(axes)

    if nlevels is None:
        nlevels = pywt.dwtn_max_level(space.shape, wavelet, self.axes)
    self.__nlevels, nlevels_in = int(nlevels), nlevels
    # Reject non-integral values such as 2.5 that int() would truncate.
    if self.nlevels != nlevels_in:
        raise ValueError('`nlevels` must be integer, got {}'
                         ''.format(nlevels_in))

    self.__wavelet = getattr(wavelet, 'name', str(wavelet).lower())
    self.__pad_mode = str(pad_mode).lower()
    self.__pad_const = space.field.element(pad_const)

    if self.impl == 'pywt':
        self.pywt_pad_mode = pywt_pad_mode(pad_mode, pad_const)
        self.pywt_wavelet = pywt_wavelet(self.wavelet)
        # determine coefficient shapes (without running wavedecn)
        self._coeff_shapes = pywt.wavedecn_shapes(
            space.shape, wavelet, mode=self.pywt_pad_mode,
            level=self.nlevels, axes=self.axes)
        # precompute slices into the (raveled) coeffs
        self._coeff_slices = precompute_raveled_slices(self._coeff_shapes)
        coeff_size = pywt.wavedecn_size(self._coeff_shapes)
        coeff_space = space.tspace_type(coeff_size, dtype=space.dtype)
    else:
        raise RuntimeError("bad `impl` '{}'".format(self.impl))

    variant, variant_in = str(variant).lower(), variant
    if variant not in ('forward', 'inverse', 'adjoint'):
        raise ValueError("`variant` '{}' not understood"
                         "".format(variant_in))
    self.__variant = variant

    # Forward maps image space -> coefficients; inverse/adjoint the reverse.
    if variant == 'forward':
        super(WaveletTransformBase, self).__init__(
            domain=space, range=coeff_space, linear=True)
    else:
        super(WaveletTransformBase, self).__init__(
            domain=coeff_space, range=space, linear=True)
def _wavelet_threshold(image, wavelet, method=None, threshold=None,
                       sigma=None, mode='soft', wavelet_levels=None):
    """Perform wavelet thresholding.

    Parameters
    ----------
    image : ndarray (2d or 3d) of ints, uints or floats
        Input data to be denoised. `image` can be of any numeric type,
        but it is cast into an ndarray of floats for the computation
        of the denoised image.
    wavelet : string
        The type of wavelet to perform. Can be any of the options
        pywt.wavelist outputs. For example, this may be any of ``{db1, db2,
        db3, db4, haar}``.
    method : {'BayesShrink', 'VisuShrink'}, optional
        Thresholding method to be used. The currently supported methods are
        "BayesShrink" [1]_ and "VisuShrink" [2]_. If it is set to None, a
        user-specified ``threshold`` must be supplied instead. If both
        ``method`` and ``threshold`` are given, ``method`` takes precedence:
        ``threshold`` is ignored and a warning is emitted.
    threshold : float, optional
        The thresholding value to apply during wavelet coefficient
        thresholding. The default value (None) uses the selected ``method``
        to estimate appropriate threshold(s) for noise removal.
    sigma : float, optional
        The standard deviation of the noise. The noise is estimated when
        sigma is None (the default) by the method in [2]_.
    mode : {'soft', 'hard'}, optional
        An optional argument to choose the type of denoising performed. It
        is noted that choosing soft thresholding given additive noise finds
        the best approximation of the original image.
    wavelet_levels : int or None, optional
        The number of wavelet decomposition levels to use. The default is
        three less than the maximum number of possible decomposition levels,
        but at least 1.

    Returns
    -------
    out : ndarray
        Denoised image.

    Raises
    ------
    ValueError
        If neither ``method`` nor ``threshold`` is given, or ``method`` is
        not recognized.

    References
    ----------
    .. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
           thresholding for image denoising and compression." Image
           Processing, IEEE Transactions on 9.9 (2000): 1532-1546.
           :DOI:`10.1109/83.862633`
    .. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
           by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
           :DOI:`10.1093/biomet/81.3.425`
    """
    wavelet = pywt.Wavelet(wavelet)
    if not wavelet.orthogonal:
        warn(("Wavelet thresholding was designed for use with orthogonal "
              "wavelets. For nonorthogonal wavelets such as {}, results are "
              "likely to be suboptimal.").format(wavelet.name))

    # original_extent is used to workaround PyWavelets issue #80
    # odd-sized input results in an image with 1 extra sample after waverecn
    original_extent = tuple(slice(s) for s in image.shape)

    # Determine the number of wavelet decomposition levels
    if wavelet_levels is None:
        # Determine the maximum number of possible levels for image
        wavelet_levels = pywt.dwtn_max_level(image.shape, wavelet)

        # Skip the three coarsest wavelet scales, keeping at least one level.
        wavelet_levels = max(wavelet_levels - 3, 1)

    coeffs = pywt.wavedecn(image, wavelet=wavelet, level=wavelet_levels)
    # Detail coefficients at each decomposition level
    dcoeffs = coeffs[1:]

    if sigma is None:
        # Estimate the noise via the method in [2]_ from the finest level
        detail_coeffs = dcoeffs[-1]['d' * image.ndim]
        sigma = _sigma_est_dwt(detail_coeffs, distribution='Gaussian')

    if method is not None and threshold is not None:
        warn(("Thresholding method {} selected.  The user-specified threshold "
              "will be ignored.").format(method))
        # Honour the warning: previously the user threshold was still used.
        threshold = None

    if threshold is None:
        var = sigma**2
        if method is None:
            raise ValueError(
                "If method is None, a threshold must be provided.")
        elif method == "BayesShrink":
            # The BayesShrink thresholds from [1]_ in docstring
            threshold = [{key: _bayes_thresh(level[key], var) for key in level}
                         for level in dcoeffs]
        elif method == "VisuShrink":
            # The VisuShrink thresholds from [2]_ in docstring
            threshold = _universal_thresh(image, sigma)
        else:
            raise ValueError("Unrecognized method: {}".format(method))

    if np.isscalar(threshold):
        # A single threshold for all coefficient arrays
        denoised_detail = [{key: pywt.threshold(level[key],
                                                value=threshold,
                                                mode=mode) for key in level}
                           for level in dcoeffs]
    else:
        # Dict of unique threshold coefficients for each detail coeff. array
        denoised_detail = [{key: pywt.threshold(level[key],
                                                value=thresh[key],
                                                mode=mode) for key in level}
                           for thresh, level in zip(threshold, dcoeffs)]
    denoised_coeffs = [coeffs[0]] + denoised_detail
    return pywt.waverecn(denoised_coeffs, wavelet)[original_extent]
def wavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,
    mode: str = "reflect",
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Non seperated two dimensional wavelet transform.

    Args:
        data (torch.Tensor): The input data tensor of shape
            [batch_size, 1, height, width].
            2d inputs are interpreted as [height, width],
            3d inputs are interpreted as [batch_size, height, width].
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        level (int): The number of desired scales. When None, the value
            from ``pywt.dwtn_max_level`` is used. Defaults to None.
        mode (str): The padding mode, i.e. zero or reflect.
            Defaults to reflect.

    Returns:
        list: A list containing the wavelet coefficients.
            The coefficients are in pywt order. That is:
            [cAn, (cHn, cVn, cDn), … (cH1, cV1, cD1)] .
            A denotes approximation, H horizontal, V vertical
            and D diagonal coefficients.

    Examples::

        >>> import torch
        >>> import ptwt, pywt
        >>> import numpy as np
        >>> import scipy.misc
        >>> face = np.transpose(scipy.misc.face(),
                                [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face).unsqueeze(1)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                                         level=2, mode="constant")
    """
    # Bring the input to [batch, 1, height, width].
    ndim = data.dim()
    if ndim == 2:
        data = data[None, None]
    elif ndim == 3:
        data = data[:, None]

    wavelet = _as_wavelet(wavelet)
    lo, hi, _, _ = get_filter_tensors(
        wavelet, flip=True, device=data.device, dtype=data.dtype)
    filt_2d = construct_2d_filt(lo=lo, hi=hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    detail_stack: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
    approx = data
    for depth in range(level):
        padded = fwt_pad2d(approx, wavelet, level=depth, mode=mode)
        stacked = torch.nn.functional.conv2d(padded, filt_2d, stride=2)
        # One output channel per sub-band: approximation first.
        approx, lh, hl, hh = torch.split(stacked, 1, 1)
        detail_stack.append((lh, hl, hh))

    # Coarsest first, matching pywt's ordering.
    result: List[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    ] = [approx]
    result.extend(reversed(detail_stack))
    return result
def main(args):
    """Denoise/compress an image with 2D wavelet thresholding and report metrics.

    With ``args.wavelet == 'test'`` several wavelets are compared via
    ``test_wavelets``. Otherwise each image channel is transformed,
    thresholded, reconstructed, saved, and metrics are printed.
    Exits with status 1 if ``args.image`` does not exist.
    """
    # Dirty solution: expose the parsed arguments module-wide.
    global _args
    _args = args

    # Compression and noise-removal parameters
    threshold_value = args.threshold
    threshold_mode = args.mode
    dec_level = args.levels

    if not os.path.isfile(args.image):
        print('Given file <{}> does not exist!'.format(args.image))
        sys.exit(1)

    if args.wavelets:
        print('Available wavelets:')
        print(pywt.wavelist(kind='discrete'))

    image_file = args.image

    if args.wavelet == 'test':
        # Test some discrete wavelets in pywt (dmey excluded because memory error?)
        wavelets = ['bior1.3', 'haar', 'db4', 'coif1', 'sym2']
        test_wavelets(image_file, threshold_value, dec_level,
                      threshold_mode, wavelets)
    else:
        # Read the image with OpenCV
        image_data = read_image(image_file, args.grayscale)
        wavelet = pywt.Wavelet(args.wavelet)

        # If levels is not given, use pywt's max level for the 2D plane of
        # each channel. cv2.split below yields (height, width) channels, so
        # use image_data.shape[:2]; image_data[0] would be a single row.
        dec_level = int(args.levels) if args.levels else pywt.dwtn_max_level(
            image_data.shape[:2], wavelet)

        # Split the image into channels (grayscale has just one) and process
        # each channel.
        # Note: by default OpenCV reads images as BGR or BGRA, not RGB.
        denoised_image = []
        for channel in cv2.split(image_data):
            # Run the 2D DWT on the colour channel
            dwt_data = dwt(channel, wavelet, dl=dec_level)

            # Threshold the coefficients to remove noise
            denoised_data = threshold(dwt_data, threshold_value,
                                      mode=threshold_mode)

            # Reconstruct the processed channel with the iDWT and keep it
            denoised_image.append(idwt(denoised_data, wavelet))

        # With several channels (not grayscale) merge them back into an image
        if len(denoised_image) > 1:
            new_image = cv2.merge(denoised_image)
        else:
            new_image = denoised_image[0]

        # Pearson correlation
        pr = pearson(image_data, new_image)

        # Normalized RMSE
        r = calc_nrmse(image_data.ravel(), new_image.ravel())

        # Save the image as the "_new" file
        new_file = save_image(image_file, new_image)

        # Compute the compression ratio
        cr = calc_compression(image_file, new_file)

        # Show the results
        print_metrics(pr, r, cr, new_file, threshold_value, threshold_mode,
                      dec_level, wavelet)
def compute_depth_map(
    depth_cues,
    iterations=500,
    lambda_tv=2.0,
    lambda_d2=0.05,
    lambda_wl=None,
    use_defocus=1.0,
    use_correspondence=1.0,
    use_xcorrelation=0.0,
):
    """Computes a depth map from the given depth cues.

    This depth map is based on the procedure from:

    M. W. Tao, et al., "Depth from combining defocus and correspondence using
    light-field cameras," in Proceedings of the IEEE International Conference
    on Computer Vision, 2013, pp. 673–680.

    :param depth_cues: The depth cues; must contain the keys
        "confidence_defocus", "depth_defocus", "confidence_correspondence",
        "depth_correspondence", "confidence_xcorrelation" and
        "depth_xcorrelation" (an empty array disables that cue)
    :type depth_cues: dict
    :param iterations: Number of iterations, defaults to 500
    :type iterations: int, optional
    :param lambda_tv: Lambda value of the TV term, defaults to 2.0
    :type lambda_tv: float, optional
    :param lambda_d2: Lambda value of the smoothing term, defaults to 0.05
    :type lambda_d2: float, optional
    :param lambda_wl: Lambda value of the wavelet term, defaults to None
    :type lambda_wl: float, optional
    :param use_defocus: Weight of defocus cues, clipped to [0, 1], defaults to 1.0
    :type use_defocus: float, optional
    :param use_correspondence: Weight of correspondence cues, clipped to
        [0, 1], defaults to 1.0
    :type use_correspondence: float, optional
    :param use_xcorrelation: Weight of the cross-correlation cues, clipped to
        [0, 1], defaults to 0.0
    :type use_xcorrelation: float, optional

    :raises ValueError: In case of requested wavelet regularization but not
        available, or when no depth cue is usable

    :returns: The depth map
    :rtype: `numpy.array_like`
    """
    # Checked up-front, so later code may assume pywt when lambda_wl is set.
    if not (lambda_wl is None or (has_pywt and use_swtn)):
        raise ValueError("Wavelet regularization requested but not available")

    # Clip each cue weight to [0, 1] (fmax/fmin also map NaN to the bound).
    use_defocus = np.fmin(np.fmax(use_defocus, 0.0), 1.0)
    use_correspondence = np.fmin(np.fmax(use_correspondence, 0.0), 1.0)
    use_xcorrelation = np.fmin(np.fmax(use_xcorrelation, 0.0), 1.0)

    W_d = depth_cues["confidence_defocus"]
    a_d = depth_cues["depth_defocus"]

    W_c = depth_cues["confidence_correspondence"]
    a_c = depth_cues["depth_correspondence"]

    W_x = depth_cues["confidence_xcorrelation"]
    a_x = depth_cues["depth_xcorrelation"]

    if use_defocus > 0 and (W_d.size == 0 or a_d.size == 0):
        use_defocus = 0
        warnings.warn("Defocusing parameters were not passed, disabling their use")

    if use_correspondence > 0 and (W_c.size == 0 or a_c.size == 0):
        use_correspondence = 0
        warnings.warn("Correspondence parameters were not passed, disabling their use")

    if use_xcorrelation > 0 and (W_x.size == 0 or a_x.size == 0):
        use_xcorrelation = 0
        warnings.warn("Cross-correlation parameters were not passed, disabling their use")

    # Take output shape and dtype from the first enabled cue.
    if use_defocus:
        img_size = a_d.shape
        data_type = a_d.dtype
    elif use_correspondence:
        img_size = a_c.shape
        data_type = a_c.dtype
    elif use_xcorrelation:
        img_size = a_x.shape
        data_type = a_x.dtype
    else:
        raise ValueError(
            "Cannot proceed unless at least one of Defocus, Correspondence, "
            "and Cross-correlation cues can be used")

    depth = np.zeros(img_size, dtype=data_type)
    depth_it = depth

    # Dual variables; tau accumulates the step-size normalization.
    q_g = np.zeros(np.concatenate(((2,), img_size)), dtype=data_type)
    tau = 4 * lambda_tv
    if lambda_d2 is not None:
        q_l = np.zeros(img_size, dtype=data_type)
        tau += 8 * lambda_d2
    if use_defocus > 0:
        q_d = np.zeros(img_size, dtype=data_type)
        tau += W_d
    if use_correspondence > 0:
        q_c = np.zeros(img_size, dtype=data_type)
        tau += W_c
    if use_xcorrelation > 0:
        q_x = np.zeros(img_size, dtype=data_type)
        tau += W_x
    if lambda_wl is not None:
        wl_type = "sym4"
        wl_lvl = np.fmin(pywt.dwtn_max_level(img_size, wl_type), 2)
        print("Wavelets selected! Wl type: %s, Wl lvl %d" % (wl_type, wl_lvl))
        q_wl = pywt.swtn(depth, wl_type, wl_lvl)
        tau += lambda_wl * (2 ** wl_lvl)
        sigma_wl = 1 / (2 ** np.arange(wl_lvl, 0, -1))
    tau = 1 / tau

    for _ in range(iterations):
        (d0, d1) = _gradient2(depth_it)
        d_2 = np.stack((d0, d1)) / 2
        q_g += d_2
        grad_l2_norm = np.fmax(1, np.sqrt(np.sum(q_g ** 2, axis=0)))
        q_g /= grad_l2_norm

        update = -lambda_tv * _divergence2(q_g[0, :, :], q_g[1, :, :])
        if lambda_d2 is not None:
            l_dep = _laplacian2(depth_it)
            q_l += l_dep / 8
            q_l /= np.fmax(1, np.abs(q_l))

            update += lambda_d2 * _laplacian2(q_l)

        if use_defocus > 0:
            q_d += depth_it - a_d
            q_d /= np.fmax(1, np.abs(q_d))

            update += use_defocus * W_d * q_d

        if use_correspondence > 0:
            q_c += depth_it - a_c
            q_c /= np.fmax(1, np.abs(q_c))

            update += use_correspondence * W_c * q_c

        if use_xcorrelation > 0:
            q_x += depth_it - a_x
            q_x /= np.fmax(1, np.abs(q_x))

            update += use_xcorrelation * W_x * q_x

        if lambda_wl is not None:
            d = pywt.swtn(depth_it, wl_type, wl_lvl)
            for ii_l in range(wl_lvl):
                for k in q_wl[ii_l].keys():
                    q_wl[ii_l][k] += d[ii_l][k] * sigma_wl[ii_l]
                    q_wl[ii_l][k] /= np.fmax(1, np.abs(q_wl[ii_l][k]))
            update += lambda_wl * pywt.iswtn(q_wl, wl_type)

        # Primal step followed by over-relaxation (extrapolated iterate).
        depth_new = depth - update * tau
        depth_it = depth_new + (depth_new - depth)
        depth = depth_new

    return depth