示例#1
0
def test_wavedecn_coeff_ravel():
    """Round-trip check: decompose -> ravel_coeffs -> unravel_coeffs -> reconstruct.

    Exercised for ``wavedec`` (1-D), ``wavedec2`` (2-D) and ``wavedecn``
    (3-D) input, for every extension mode and every wavelet in ``wavelist``.
    """
    rng = np.random.RandomState(1234)
    cases = [
        ('wavedec', 1, pywt.wavedec, pywt.waverec),
        ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
        ('wavedecn', 3, pywt.wavedecn, pywt.waverecn),
    ]
    N = 12
    for fmt, ndim, decompose, reconstruct in cases:
        signal = rng.randn(*([N] * ndim))
        for mode in pywt.Modes.modes:
            for name in wavelist:
                wavelet = pywt.Wavelet(name)
                # Skip wavelets too long for even one decomposition level.
                max_level = pywt.dwt_max_level(np.min(signal.shape),
                                               wavelet.dec_len)
                if max_level == 0:
                    continue

                coeffs = decompose(signal, wavelet, mode=mode)
                flat, slices, shapes = pywt.ravel_coeffs(coeffs)
                restored = pywt.unravel_coeffs(flat, slices, shapes,
                                               output_format=fmt)
                recon = reconstruct(restored, wavelet, mode=mode)

                assert_allclose(signal, recon, rtol=1e-4, atol=1e-4)
示例#2
0
 def hdot(self, x):
     """
     Adjoint operator: map ``x = (beta, alpha)`` to stacked coefficients.

     Row 0 holds ``beta`` flattened per band (``nx * ny`` values). Row
     ``b + 1`` holds, per band, the raveled ``pywt.wavedecn`` coefficients
     of ``alpha`` for basis ``self.bases[b]`` (``mode='zero'``,
     ``level=self.nlevels``) divided by ``self.sqrtP``. Every row is
     zero-padded on the right to ``self.nmax`` entries so the bases can be
     stacked into one array of shape (nbasis + 1, nband, nmax).
     """
     beta, alpha = x[0], x[1]
     n_pix = self.nx * self.ny
     out = np.zeros((self.nbasis + 1, self.nband, self.nmax))
     # Pixel-domain component, padded up to nmax.
     out[0] = np.pad(beta.reshape(self.nband, n_pix),
                     ((0, 0), (0, self.nmax - n_pix)),
                     mode='constant')
     for b in range(self.nbasis):
         wavelet = self.bases[b]
         for band in range(self.nband):
             decomposition = pywt.wavedecn(alpha[band], wavelet,
                                           mode='zero',
                                           level=self.nlevels)
             flat = pywt.ravel_coeffs(decomposition)[0]
             out[b + 1, band] = np.pad(flat / self.sqrtP,
                                       (0, self.nmax - self.ntot[b]),
                                       mode='constant')
     return out
示例#3
0
    def scales(self):
        """Return the scale index of every wavelet coefficient.

        Returns
        -------
        scales : ``range`` element
            Integer scale per coefficient: 0 for the approximation
            (lowest-resolution) band up to ``self.nlevels`` for the
            highest-resolution details.

        Raises
        ------
        RuntimeError
            If ``self.impl`` is not ``'pywt'``.
        """
        if self.impl != 'pywt':
            raise RuntimeError("bad `impl` '{}'".format(self.impl))

        # The coefficient space is the range of the forward transform and
        # the domain of its inverse.
        if self.__variant == 'forward':
            discr_space, wavelet_space = self.domain, self.range
        else:
            discr_space, wavelet_space = self.range, self.domain

        shapes = pywt.wavedecn_shapes(discr_space.shape,
                                      self.pywt_wavelet,
                                      mode=self.pywt_pad_mode,
                                      level=self.nlevels,
                                      axes=self.axes)
        # Fill each coefficient array with its own level index, then ravel
        # it exactly like a real transform output would be.
        coeff_list = [np.full(shapes[0], 0)]
        coeff_list.extend(
            {key: np.full(shape, lvl) for key, shape in detail.items()}
            for lvl, detail in enumerate(shapes[1:], start=1))
        flat = pywt.ravel_coeffs(coeff_list, axes=self.axes)[0]
        return wavelet_space.element(flat)
示例#4
0
def test_wavedecn_coeff_ravel():
    """Round-trip check: decompose -> ravel_coeffs -> unravel_coeffs -> reconstruct.

    Exercised for ``wavedec`` (1-D), ``wavedec2`` (2-D) and ``wavedecn``
    (3-D) input, for every extension mode and every wavelet in ``wavelist``.
    """
    rng = np.random.RandomState(1234)
    cases = [
        ('wavedec', 1, pywt.wavedec, pywt.waverec),
        ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
        ('wavedecn', 3, pywt.wavedecn, pywt.waverecn),
    ]
    N = 12
    for fmt, ndim, decompose, reconstruct in cases:
        signal = rng.randn(*([N] * ndim))
        for mode in pywt.Modes.modes:
            for name in wavelist:
                wavelet = pywt.Wavelet(name)
                # Skip wavelets too long for even one decomposition level.
                max_level = pywt.dwt_max_level(np.min(signal.shape),
                                               wavelet.dec_len)
                if max_level == 0:
                    continue

                coeffs = decompose(signal, wavelet, mode=mode)
                flat, slices, shapes = pywt.ravel_coeffs(coeffs)
                restored = pywt.unravel_coeffs(flat, slices, shapes,
                                               output_format=fmt)
                recon = reconstruct(restored, wavelet, mode=mode)

                assert_allclose(signal, recon, rtol=1e-4, atol=1e-4)
示例#5
0
    def decompose(self, y):
        """
        WaveletTransformer.decompose(self, y)

        Compute the raveled (1-D) wavelet decomposition of ``y`` using the
        wavelet named by ``self.wname``.

        Inputs
        ------
        y : numpy array
            Signal to decompose; must satisfy ``y.ndim == self.dim``.

        Returns
        -------
        tuple ``(coef_ravelled, slices, shapes)``
            Exactly what ``ravel_coeffs`` returns: the 1-D coefficient
            array plus the slice and shape bookkeeping needed to unravel it.

        Raises
        ------
        ValueError
            If ``y.ndim`` differs from ``self.dim``.

        Example
        -------
        >>> # not tested:
        >>> x0 = np.random.randn(128)
        >>> wt = WaveletTransformer('db1', x0.ndim)
        >>> coefs, *aux_data = wt.decompose(x0)
        >>> coefs.ndim
        1
        """
        # Explicit check rather than ``assert``: assertions are stripped
        # under ``python -O``, which would silently skip this validation.
        if y.ndim != self.dim:
            raise ValueError(
                f"Expected y of dimension {self.dim} but got {y.ndim}.")
        return wt.ravel_coeffs(self._wavedec(y, self.wname))
示例#6
0
    def __init__(self, imsize=None,
                 nlevels=2,
                 bases=None):
        """
        Set up the bookkeeping needed to move between the wavelet
        coefficients of each basis and the image x.

        Parameters
        ----------
        imsize - tuple (nband, nx, ny): number of bands and number of
                 pixels in the x- and y-dimensions. Required.
        nlevels - The level of the decomposition. Default=2
        bases - List holding basis names. 'self' is the identity (pixel)
                basis; every other entry is passed to pywt.wavedecn as a
                wavelet name. Default is ['self', 'db1', 'db2', 'db3'].

        Raises
        ------
        ValueError - if imsize is not given.
        """
        if bases is None:
            # Build the default per call: a list literal in the signature is
            # a single object that every instance would share as self.bases.
            bases = ['self', 'db1', 'db2', 'db3']
        self.real_type = np.float64
        if imsize is None:
            raise ValueError("You must initialise imsize")
        self.nband, self.nx, self.ny = imsize
        self.nlevels = nlevels
        self.P = len(bases)
        self.sqrtP = np.sqrt(self.P)
        self.bases = bases
        self.nbasis = len(bases)

        # Mock decomposition of a random nx x ny image to record, per basis,
        # the raveled coefficient count and the slice/shape bookkeeping.
        x = np.random.randn(self.nx, self.ny)
        self.ntot = []
        self.iy = {}
        self.sy = {}
        for b in bases:
            if b == 'self':
                # Identity basis: the coefficients are the flattened pixels.
                y, iy, sy = x.flatten(), 0, 0
            else:
                alpha = pywt.wavedecn(x, b, mode='zero', level=self.nlevels)
                y, iy, sy = pywt.ravel_coeffs(alpha)
            self.iy[b] = iy
            self.sy[b] = sy
            self.ntot.append(y.size)

        # Each basis is zero-padded to the longest coefficient vector;
        # padding[i] selects the meaningful prefix of basis i.
        self.nmax = np.asarray(self.ntot).max()
        self.padding = [slice(0, n) for n in self.ntot]
示例#7
0
 def _call(self, x):
     """Return the raveled (1-D) wavelet transform of ``x``.

     Only the ``'pywt'`` backend is supported; any other ``self.impl``
     raises ``RuntimeError``.
     """
     if self.impl != 'pywt':
         raise RuntimeError("bad `impl` '{}'".format(self.impl))
     decomposition = pywt.wavedecn(x,
                                   wavelet=self.pywt_wavelet,
                                   level=self.nlevels,
                                   mode=self.pywt_pad_mode,
                                   axes=self.axes)
     flat, _, _ = pywt.ravel_coeffs(decomposition, axes=self.axes)
     return flat
示例#8
0
def test_swt_ravel_and_unravel():
    """With ``trim_approx=True`` every SWT variant can use
    ``pywt.ravel_coeffs`` / ``pywt.unravel_coeffs`` for a lossless round trip."""
    cases = [
        (1, pywt.swt, pywt.iswt, 'swt'),
        (2, pywt.swt2, pywt.iswt2, 'swt2'),
        (3, pywt.swtn, pywt.iswtn, 'swtn'),
    ]
    for ndim, forward, inverse, fmt in cases:
        signal = np.ones((16, ) * ndim)
        coeffs = forward(signal, 'sym2', level=3, trim_approx=True)
        flat, slices, shapes = pywt.ravel_coeffs(coeffs)
        restored = pywt.unravel_coeffs(flat, slices, shapes,
                                       output_format=fmt)
        assert_allclose(signal, inverse(restored, 'sym2'))
示例#9
0
    def DWT_analyze(self, x, level=None, wavelet=None):
        """Decompose both columns of ``x`` and return integer coefficients.

        ``x[:, 0]`` and ``x[:, 1]`` are each decomposed with
        ``pywt.wavedec`` (``mode="per"``), raveled with
        ``pywt.ravel_coeffs``, rounded with ``np.rint`` and cast to int32.
        The two vectors are concatenated, column 0 first.

        ``level`` and ``wavelet`` default to ``self.levels`` and
        ``self.wavelet``. ``self.slices``, ``self.shapes`` and ``self.size``
        are filled from column 0's raveling, each only if still ``None``.
        """
        level = self.levels if level is None else level
        wavelet = self.wavelet if wavelet is None else wavelet

        left, right = (
            pywt.wavedec(x[:, col], wavelet=wavelet, level=level, mode="per")
            for col in (0, 1))
        left_flat, slices, shapes = pywt.ravel_coeffs(left)

        # Remember the coefficient layout the first time it is seen.
        if self.slices is None:
            self.slices = slices
        if self.shapes is None:
            self.shapes = shapes
        if self.size is None:
            self.size = len(left_flat)

        right_flat = pywt.ravel_coeffs(right)[0]
        return np.concatenate(
            [np.rint(flat).astype(np.int32) for flat in (left_flat, right_flat)])
示例#10
0
def test_unravel_invalid_inputs():
    """``pywt.unravel_coeffs`` rejects malformed slices/shapes and formats."""
    arr, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(np.ones(2), 'haar'))

    bad_calls = [
        (arr, slices, []),             # empty shapes list
        (arr, [], shapes),             # empty slices list
        (arr, slices[:-1], shapes),    # slices/shapes of unequal length
        (arr, slices, shapes, 'foo'),  # unknown output format name
    ]
    for args in bad_calls:
        assert_raises(ValueError, pywt.unravel_coeffs, *args)


def test_unravel_invalid_inputs():
    """``pywt.unravel_coeffs`` rejects malformed slices/shapes and formats."""
    arr, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(np.ones(2), 'haar'))

    bad_calls = [
        (arr, slices, []),             # empty shapes list
        (arr, [], shapes),             # empty slices list
        (arr, slices[:-1], shapes),    # slices/shapes of unequal length
        (arr, slices, shapes, 'foo'),  # unknown output format name
    ]
    for args in bad_calls:
        assert_raises(ValueError, pywt.unravel_coeffs, *args)
示例#12
0
def _hdot_internal(x, bases, ntot, nmax, nlevels, sqrtP, nx, ny, real_type):
    """Map image bands to stacked, zero-padded wavelet coefficients.

    For each basis ``bases[b]`` and band ``l``, the ``pywt.wavedecn``
    coefficients of ``x[l]`` (``mode='zero'``, ``level=nlevels``) are
    raveled, divided by ``sqrtP`` and zero-padded from ``ntot[b]`` to
    ``nmax`` entries. ``nx`` and ``ny`` are accepted but not used here.

    Returns
    -------
    ndarray of shape (len(bases), x.shape[0], nmax) with dtype ``real_type``.
    """
    n_band = x.shape[0]
    out = np.zeros((len(bases), n_band, nmax), dtype=real_type)
    for b, wavelet in enumerate(bases):
        for band in range(n_band):
            decomposition = pywt.wavedecn(x[band], wavelet, mode='zero',
                                          level=nlevels)
            flat = pywt.ravel_coeffs(decomposition)[0]
            out[b, band] = np.pad(flat / sqrtP, (0, nmax - ntot[b]),
                                  mode='constant')
    return out
示例#13
0
def test_ravel_wavedec2_with_lists():
    """``ravel_coeffs`` accepts wavedec2 detail coefficients given as lists
    (not only tuples) and rejects detail lists of the wrong length."""
    image = np.ones((8, 8))
    wavelet = pywt.Wavelet('haar')
    coeffs = pywt.wavedec2(image, wavelet)

    # A list [cHn, cVn, cDn] in place of the usual tuple must work.
    coeffs[1:] = [list(details) for details in coeffs[1:]]
    flat, slices, shapes = pywt.ravel_coeffs(coeffs)
    restored = pywt.unravel_coeffs(flat, slices, shapes,
                                   output_format='wavedec2')
    assert_allclose(image, pywt.waverec2(restored, wavelet),
                    rtol=1e-4, atol=1e-4)

    # Dropping the diagonal coefficients leaves too few entries per level.
    coeffs[1:] = [list(details[:-1]) for details in coeffs[1:]]
    assert_raises(ValueError, pywt.ravel_coeffs, coeffs)
示例#14
0
def test_ravel_wavedec2_with_lists():
    """``ravel_coeffs`` accepts wavedec2 detail coefficients given as lists
    (not only tuples) and rejects detail lists of the wrong length."""
    image = np.ones((8, 8))
    wavelet = pywt.Wavelet('haar')
    coeffs = pywt.wavedec2(image, wavelet)

    # A list [cHn, cVn, cDn] in place of the usual tuple must work.
    coeffs[1:] = [list(details) for details in coeffs[1:]]
    flat, slices, shapes = pywt.ravel_coeffs(coeffs)
    restored = pywt.unravel_coeffs(flat, slices, shapes,
                                   output_format='wavedec2')
    assert_allclose(image, pywt.waverec2(restored, wavelet),
                    rtol=1e-4, atol=1e-4)

    # Dropping the diagonal coefficients leaves too few entries per level.
    coeffs[1:] = [list(details[:-1]) for details in coeffs[1:]]
    assert_raises(ValueError, pywt.ravel_coeffs, coeffs)
示例#15
0
def test_waverecn_coeff_ravel_odd():
    """Round trip wavedecn -> ravel_coeffs -> unravel_coeffs -> waverecn on an
    odd-sized 2-D input; the reconstruction is cropped back to the input
    shape before comparing."""
    rng = np.random.RandomState(1234)
    data = rng.randn(35, 33)
    crop = tuple([slice(n) for n in data.shape])
    for mode in pywt.Modes.modes:
        for name in ['haar', ]:
            wavelet = pywt.Wavelet(name)
            if pywt.dwt_max_level(np.min(data.shape), wavelet.dec_len) == 0:
                continue
            coeffs = pywt.wavedecn(data, wavelet, mode=mode)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            restored = pywt.unravel_coeffs(flat, slices, shapes)
            recon = pywt.waverecn(restored, wavelet, mode=mode)[crop]
            assert_allclose(data, recon, rtol=1e-4, atol=1e-4)
示例#16
0
def test_waverecn_coeff_ravel_odd():
    """Round trip wavedecn -> ravel_coeffs -> unravel_coeffs -> waverecn on an
    odd-sized 2-D input; the reconstruction is cropped back to the input
    shape before comparing."""
    rng = np.random.RandomState(1234)
    data = rng.randn(35, 33)
    crop = tuple([slice(n) for n in data.shape])
    for mode in pywt.Modes.modes:
        for name in ['haar', ]:
            wavelet = pywt.Wavelet(name)
            if pywt.dwt_max_level(np.min(data.shape), wavelet.dec_len) == 0:
                continue
            coeffs = pywt.wavedecn(data, wavelet, mode=mode)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            restored = pywt.unravel_coeffs(flat, slices, shapes)
            recon = pywt.waverecn(restored, wavelet, mode=mode)[crop]
            assert_allclose(data, recon, rtol=1e-4, atol=1e-4)
示例#17
0
def wave(im, wavelet="bior3.5", level=2):
    """Return per-subband energy features of a 2-D wavelet decomposition.

    Parameters
    ----------
    im : 2-D array
        Image to decompose with ``pywt.wavedec2``.
    wavelet : str, optional
        Wavelet name (default ``"bior3.5"``).
    level : int, optional
        Decomposition level (default 2).

    Returns
    -------
    list
        First entry: mean squared Laplacian of the approximation band.
        Then one entry per detail subband (coarsest level first; within a
        level in the key order of the ``ravel_coeffs`` slice dict): the
        mean squared coefficient value.
    """
    coeffs = pywt.wavedec2(im, wavelet, level=level)
    flat, slices, shapes = pywt.ravel_coeffs(coeffs)

    # ravel_coeffs stores the approximation band 1-D; restore its 2-D shape
    # before applying the Laplacian.
    approx = flat[slices[0]].reshape(shapes[0][0], shapes[0][1])
    energies = [np.mean(ndimage.laplace(approx) ** 2)]

    # Every detail level (previously hard-coded to levels 1 and 2 only).
    for level_slices in slices[1:]:
        for key in level_slices:
            band = flat[level_slices[key]]
            energies.append(np.mean(band ** 2))

    return energies
示例#18
0
    def __init__(self, nband, nx, ny, nlevels=3, bases=None):
        """
        Set up the bookkeeping needed to move between the wavelet
        coefficients of each basis and the image x.

        Parameters
        ----------
        nband - number of bands
        nx - number of pixels in x-dimension
        ny - number of pixels in y-dimension
        nlevels - The level of the decomposition. Default=3
        bases - Sequence of pywt wavelet names, one per basis.
                Default is db1-8 wavelets.
        """
        self.real_type = np.float64
        self.nband = nband
        self.nx = nx
        self.ny = ny
        self.nlevels = nlevels
        if bases is None:
            bases = ['db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8']
        # Own copy so later changes to the caller's sequence do not leak in.
        self.bases = list(bases)
        self.P = len(self.bases)
        self.sqrtP = np.sqrt(self.P)
        self.nbasis = len(self.bases)

        # Mock decomposition of a random nx x ny image to record, per basis,
        # the raveled coefficient count and the slice/shape bookkeeping.
        x = np.random.randn(nx, ny)
        self.ntot = []
        self.iy = {}
        self.sy = {}
        for b in self.bases:
            alpha = pywt.wavedecn(x, b, mode='zero', level=self.nlevels)
            y, iy, sy = pywt.ravel_coeffs(alpha)
            self.iy[b] = iy
            self.sy[b] = sy
            self.ntot.append(y.size)

        # Each basis is zero-padded to the longest coefficient vector;
        # padding[i] selects the meaningful prefix of basis i.
        self.nmax = np.asarray(self.ntot).max()
        self.dpadding = slice(0, nx * ny)  # Dirac slices/padding
        self.padding = [slice(0, n) for n in self.ntot]
示例#19
0
 def hdot(self, x):
     """
     Adjoint of Psi_func: map image bands ``x`` to stacked coefficients.

     For basis ``self.bases[b]`` and band ``l`` the coefficient vector is
     either the flattened band (basis ``'self'``) or the raveled
     ``pywt.wavedecn`` coefficients of ``x[l]`` (``mode='zero'``,
     ``level=self.nlevels``). It is divided by ``self.sqrtP`` and
     zero-padded from ``self.ntot[b]`` to ``self.nmax`` entries so all bases
     stack into one array of shape (nbasis, nband, nmax).
     """
     out = np.zeros((self.nbasis, self.nband, self.nmax))
     for b in range(self.nbasis):
         basis = self.bases[b]
         for band in range(self.nband):
             if basis == 'self':
                 # Identity basis: the flattened pixels are the coefficients.
                 flat = x[band].reshape(self.nx * self.ny)
             else:
                 decomposition = pywt.wavedecn(x[band], basis, mode='zero',
                                               level=self.nlevels)
                 flat, _, _ = pywt.ravel_coeffs(decomposition)
             out[b, band] = np.pad(flat / self.sqrtP,
                                   (0, self.nmax - self.ntot[b]),
                                   mode='constant')
     return out
示例#20
0
def test_wavedecn_coeff_ravel_zero_level():
    """Round trip with ``level=0``: decompose -> ravel_coeffs ->
    unravel_coeffs -> reconstruct, for wavedec (1-D), wavedec2 (2-D) and
    wavedecn (3-D) input under every extension mode."""
    rng = np.random.RandomState(1234)
    cases = (
        ('wavedec', 1, pywt.wavedec, pywt.waverec),
        ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
        ('wavedecn', 3, pywt.wavedecn, pywt.waverecn),
    )
    N = 16
    for fmt, ndim, decompose, reconstruct in cases:
        signal = rng.randn(*([N] * ndim))
        for mode in pywt.Modes.modes:
            wavelet = pywt.Wavelet('db2')
            coeffs = decompose(signal, wavelet, mode=mode, level=0)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            restored = pywt.unravel_coeffs(flat, slices, shapes,
                                           output_format=fmt)
            assert_allclose(signal, reconstruct(restored, wavelet, mode=mode),
                            rtol=1e-4, atol=1e-4)
示例#21
0
def test_wavedecn_coeff_ravel_zero_level():
    """Round trip with ``level=0``: decompose -> ravel_coeffs ->
    unravel_coeffs -> reconstruct, for wavedec (1-D), wavedec2 (2-D) and
    wavedecn (3-D) input under every extension mode."""
    rng = np.random.RandomState(1234)
    cases = (
        ('wavedec', 1, pywt.wavedec, pywt.waverec),
        ('wavedec2', 2, pywt.wavedec2, pywt.waverec2),
        ('wavedecn', 3, pywt.wavedecn, pywt.waverecn),
    )
    N = 16
    for fmt, ndim, decompose, reconstruct in cases:
        signal = rng.randn(*([N] * ndim))
        for mode in pywt.Modes.modes:
            wavelet = pywt.Wavelet('db2')
            coeffs = decompose(signal, wavelet, mode=mode, level=0)
            flat, slices, shapes = pywt.ravel_coeffs(coeffs)
            restored = pywt.unravel_coeffs(flat, slices, shapes,
                                           output_format=fmt)
            assert_allclose(signal, reconstruct(restored, wavelet, mode=mode),
                            rtol=1e-4, atol=1e-4)
    def dir_op(self, x):
        """Forward (analysis) operator: map ``x`` to a 1-D coefficient vector.

        ``self.wav`` selects the transform:

        * ``'dirac'``   -- identity; ``x`` flattened;
        * ``'fourier'`` -- flattened ``np.fft.fftn(x)``;
        * ``'dct'``     -- flattened orthonormal ``scipy.fft.dctn(x)``;
        * otherwise     -- ``pywt.wavedecn`` with that wavelet name
          (``level=self.levels``, ``mode='periodic'``, ``axes=self.axes``),
          raveled with ``pywt.ravel_coeffs``.

        In the wavelet case the slice/shape bookkeeping returned by
        ``ravel_coeffs`` is stored in ``self.coeff_slices`` and
        ``self.coeff_shapes`` (presumably for the matching adjoint).

        Raises
        ------
        ValueError
            Wavelet case only: if ``self.shape[0]`` or, for shapes with
            more than one axis, ``self.shape[1]`` is odd.
        """
        if self.wav == 'dirac':
            return np.ravel(x)
        if self.wav == 'fourier':
            return np.ravel(np.fft.fftn(x))
        if self.wav == "dct":
            return np.ravel(scipy.fft.dctn(x, norm='ortho'))
        # Only the first two axes are validated.
        if (self.shape[0] % 2 == 1
                or (len(self.shape) > 1 and self.shape[1] % 2 == 1)):
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers keep catching it.
            raise ValueError("Signal shape should be even dimensions.")

        coeffs = pywt.wavedecn(x,
                               wavelet=self.wav,
                               level=self.levels,
                               mode='periodic',
                               axes=self.axes)
        # Keep the layout so the raveled vector can be unraveled later.
        arr, self.coeff_slices, self.coeff_shapes = pywt.ravel_coeffs(
            coeffs, axes=self.axes)
        return arr
示例#23
0
    def pansharpenWavelet(self):
        '''Return the wavelet pan-sharpened bands as a list of NumPy arrays.

        Uses ``self.pan`` (2-D panchromatic array) and ``self.red``,
        ``self.green``, ``self.blue`` and, when not ``None``, ``self.NIR``.

        Returns:
          list: 3 (R, G, B) or 4 (R, G, B, NIR) float32 NumPy arrays created
          with the wavelet pan-sharpening method.
        '''
        # float32 rather than float16: float16 overflows to inf above 65504
        # (e.g. 16-bit imagery) and keeps only ~11 bits of precision.
        bands = [self.red, self.green, self.blue]
        if self.NIR is not None:
            bands.append(self.NIR)
        d = len(bands)

        # get 2D dimensions of output imagery
        nrows, ncols = self.pan.shape

        # 3D data cube (rows, cols, band) for wavelet pan-sharpening
        image = np.zeros((nrows, ncols, d), dtype=np.float32)
        for band, values in enumerate(bands):
            image[:, :, band] = values

        level = 0
        wavelet_type = 'haar'

        coeffs = pywt.wavedec2(self.pan, wavelet=wavelet_type, level=level)
        panvec, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs)
        # One column of pan coefficients per output band: shape (n, d).
        reconstvec = np.tile(panvec.T, (d, 1)).T

        n = panvec.shape[0]
        lowresvec = np.zeros((n, d), dtype=np.float32)
        for band in range(d):
            lowresCoeffs = pywt.wavedec2(image[:, :, band],
                                         wavelet=wavelet_type,
                                         level=level)
            lowresArr, _ = pywt.coeffs_to_array(lowresCoeffs)
            lowresvec[:, band] = np.reshape(lowresArr, (nrows * ncols, ))

        # Substitute the approximation coefficients with the multispectral ones.
        n_approx = coeff_shapes[0][0] * coeff_shapes[0][1]
        reconstvec[:n_approx, :] = lowresvec[:n_approx, :]

        sharpened = np.zeros((nrows, ncols, d), dtype=np.float32)
        for band in range(d):
            p = np.reshape(reconstvec[:, band], (nrows, ncols))
            fcoeffs = pywt.wavedec2(p, wavelet_type, level=level)
            sharpened[:, :, band] = pywt.waverec2(fcoeffs, wavelet_type)

        # [red, green, blue] or [red, green, blue, NIR]
        return [sharpened[:, :, band] for band in range(d)]
示例#24
0
# import matlab file using scipy
wheat = scipy.io.loadmat('./wheat.mat')

X = wheat['WHEAT_SPECTRA']
(nVar, nSamp) = (len(X[0]), len(X))

wavelet = pywt.Wavelet('sym2')
Lmax = pywt.dwt_max_level(nVar, wavelet)

# Candidate wavelets and the maximum decomposition level each allows for
# spectra of length nVar.
wavelets = [
    'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8', 'db9', 'db10', 'coif1',
    'coif2', 'coif3', 'coif4', 'coif5', 'sym4', 'sym5', 'sym6', 'sym7', 'sym8',
    'sym9', 'sym10'
]
Lmaxs = []
for wave in wavelets:
    Lmaxs.append(pywt.dwt_max_level(nVar, wave))

# Raveled wavelet coefficients of every spectrum (one entry per sample).
rows = []
for i in range(0, nSamp):
    # coeffs is the list [cA4, cD4, cD3, cD2, cD1] for level=4
    coeffs = wavedec(X[i, :], wavelet, level=4)
    # flatten coefficients into one line array
    C = pywt.ravel_coeffs(coeffs)[0]
    rows.append(C)

# Size the matrix from the data instead of a hard-coded 711 columns, which
# raised a broadcast error whenever a spectrum produced more coefficients.
# Shorter rows (if any) stay zero-padded on the right.
Xwave = np.zeros((nSamp, max((len(r) for r in rows), default=0)))
for i, C in enumerate(rows):
    Xwave[i, 0:len(C)] = C
示例#25
0
def train(args):
    """Adversarially train the mel-conditioned generator ``G`` against the
    multi-scale discriminator ``D``, with audio seen by ``D`` in a raveled
    2-level ``db2`` wavelet representation.

    Every ``args.log_interval`` steps the losses and the average time per
    batch are printed. Every ``args.save_interval`` steps audio is generated
    for the first ``args.test_num`` validation mels (truncated to
    ``args.test_len`` frames) and written to ``args.save_dir``, followed by
    a checkpoint.

    Any exception raised while training is printed with its traceback and
    the function returns.

    NOTE(review): the generator output is converted to NumPy for the wavelet
    transform (``.detach() ... .numpy()``), which cuts the autograd graph.
    ``g_loss`` therefore has no path back to ``G``, and ``G``'s parameters
    receive no gradients. A differentiable wavelet transform would be needed
    for ``G`` to learn.
    """
    mel_list = glob.glob(os.path.join(args.train_dir, '*.mel'))
    trainset = MelDataset(args.seq_len, mel_list, args.hop_length)
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              num_workers=0,
                              shuffle=True,
                              drop_last=True)

    test_mel = glob.glob(os.path.join(args.valid_dir, '*.mel'))
    testset = []
    for i in range(args.test_num):
        mel = torch.load(test_mel[i])
        mel = mel[:, :args.test_len]
        mel = mel.unsqueeze(0)
        testset.append(mel)

    G = Generator(80)
    D = MultiScale()

    G = G.cuda()
    D = D.cuda()

    g_optimizer = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))
    d_optimizer = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.9))

    step, epochs = 0, 0
    if args.load_dir is not None:
        ckpt = torch.load(args.load_dir)
        G.load_state_dict(ckpt['G'])
        g_optimizer.load_state_dict(ckpt['g_optimizer'])
        D.load_state_dict(ckpt['D'])
        d_optimizer.load_state_dict(ckpt['d_optimizer'])
        step = ckpt['step']
        epochs = ckpt['epoch']
        # Previously a bare string expression that was evaluated and dropped.
        print('Load Status: Epochs %d, Step %d' % (epochs, step))

    torch.backends.cudnn.benchmark = True

    start = time.time()
    try:
        for epoch in itertools.count(epochs):
            for (mel, audio) in train_loader:

                mel = mel.cuda()
                # Real audio -> 2-level db2 coefficients raveled to 1-D and
                # shaped (1, 1, n_coeffs) for the discriminator.
                # NOTE(review): presumably assumes args.batch_size == 1; with
                # larger batches squeeze() leaves a (B, T) array and
                # wavedecn transforms the batch axis too — TODO confirm.
                coeffs = pywt.wavedecn(audio.squeeze(), 'db2', level=2)
                arr__, coeff_slices__, coeffs_shapes_ = pywt.ravel_coeffs(
                    coeffs)
                arr__ = numpy.array([[arr__]])
                audio = torch.from_numpy(arr__).float().cuda()

                # Discriminator
                d_real = D(audio)
                d_loss_real = 0
                for scale in d_real:
                    d_loss_real += F.relu(1 - scale[-1]).mean()

                fake_audio = G(mel)

                # Generator output -> the same wavelet representation. The
                # trip through NumPy detaches it from the autograd graph.
                fake_audio = fake_audio.detach().cpu().clone().numpy()
                coeffs2 = pywt.wavedecn(fake_audio.squeeze(), 'db2', level=2)
                arr2, coeff_slices2, coeff_shapes = pywt.ravel_coeffs(coeffs2)
                arr2 = numpy.array([[arr2]])
                fake_audio = torch.from_numpy(arr2).float()

                d_fake = D(fake_audio.cuda().detach())
                d_loss_fake = 0
                for scale in d_fake:
                    d_loss_fake += F.relu(1 + scale[-1]).mean()

                d_loss = d_loss_real + d_loss_fake

                D.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # Generator
                d_fake = D(fake_audio.cuda())
                g_loss = 0
                for scale in d_fake:
                    g_loss += -scale[-1].mean()

                # Feature matching on the first discriminator scale only.
                feature_loss = 0
                for i in range(1):
                    for j in range(len(d_fake[i]) - 1):
                        feature_loss += F.l1_loss(d_fake[i][j],
                                                  d_real[i][j].detach())

                g_loss += args.lambda_feat * feature_loss

                G.zero_grad()
                g_loss.backward()
                g_optimizer.step()

                step += 1
                # Log every log_interval steps so the ms/batch average below
                # (elapsed time / log_interval) matches the steps timed.
                if step % args.log_interval == 0:
                    print(
                        'Epoch: %-5d, Step: %-7d, D_loss: %.05f, G_loss: %.05f, ms/batch: %5.2f'
                        % (epoch, step, d_loss, g_loss, 1000 *
                           (time.time() - start) / args.log_interval))
                    start = time.time()

                if step % args.save_interval == 0:
                    root = Path(args.save_dir)
                    with torch.no_grad():
                        for i, mel_test in enumerate(testset):
                            g_audio = G(mel_test.cuda())
                            g_audio = g_audio.squeeze().cpu()
                            # Clip before the int16 cast: +1.0 * 32768 would
                            # otherwise wrap around to -32768.
                            pcm = numpy.clip(g_audio.numpy() * 32768,
                                             -32768, 32767)
                            scipy.io.wavfile.write(
                                root / ('generated-%d-%dk-%d.wav' %
                                        (epoch, step // 1000, i)), 22050,
                                pcm.astype('int16'))

                    print("Saving checkpoint")
                    torch.save(
                        {
                            'G': G.state_dict(),
                            'g_optimizer': g_optimizer.state_dict(),
                            'D': D.state_dict(),
                            'd_optimizer': d_optimizer.state_dict(),
                            'step': step,
                            'epoch': epoch,
                        }, root / ('ckpt-%dk.pt' % (step // 1000)))

    except Exception:
        # Top-level training boundary: report the failure and stop.
        traceback.print_exc()
  def pansharpenWavelet(self):
    '''Return the wavelet pan-sharpened bands as a list of NumPy arrays.

    The panchromatic and multispectral rasters are opened with
    ``gdal.Open(self.pan)`` and ``gdal.Open(self.multi)``. Multispectral
    bands 1, 2, 3 are used as Red, Green, Blue and, for 4-band rasters,
    band 4 as NIR.

    Returns:
      list: 3 (R, G, B) or 4 (R, G, B, NIR) float32 NumPy arrays; an empty
      list if the multispectral raster has neither 3 nor 4 bands.
    '''
    # read Panchromatic,Multispectral Geotiffs into GDAL objects
    dsPan = gdal.Open(self.pan)
    dsMulti = gdal.Open(self.multi)

    d = dsMulti.RasterCount
    if d not in (3, 4):
      # Checked before any band is read, so rasters with fewer than 3 bands
      # also end up here instead of failing on a missing band.
      dsPan, dsMulti = None, None
      return []

    # Band k of the multispectral raster goes to image[:, :, k - 1]:
    # 1 = red, 2 = green, 3 = blue, 4 = NIR (band 1 is red, not NIR).
    pan = dsPan.GetRasterBand(1).ReadAsArray()
    bands = [dsMulti.GetRasterBand(k).ReadAsArray().astype(float)
             for k in range(1, d + 1)]
    # All pixels are in memory now; drop the GDAL dataset handles.
    dsPan, dsMulti = None, None

    nrows, ncols = pan.shape
    # Built for both the 3- and 4-band case (the 4-band path used to write
    # into an image that was never created).
    image = np.zeros((nrows, ncols, d), dtype=np.float32)
    for band, values in enumerate(bands):
      image[:, :, band] = values

    level = 0
    wavelet_type = 'haar'

    coeffs = pywt.wavedec2(pan, wavelet=wavelet_type, level=level)
    panvec, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs)
    # One column of pan coefficients per output band: shape (n, d).
    reconstvec = np.tile(panvec.T, (d, 1)).T

    n = panvec.shape[0]
    lowresvec = np.zeros((n, d), dtype=np.float32)
    for band in range(d):
      lowresCoeffs = pywt.wavedec2(image[:, :, band], wavelet=wavelet_type,
                                   level=level)
      lowresArr, _ = pywt.coeffs_to_array(lowresCoeffs)
      lowresvec[:, band] = np.reshape(lowresArr, (nrows * ncols,))

    # Substitute the approximation coefficients with the multispectral ones.
    n_approx = coeff_shapes[0][0] * coeff_shapes[0][1]
    reconstvec[:n_approx, :] = lowresvec[:n_approx, :]

    sharpened = np.zeros((nrows, ncols, d), dtype=np.float32)
    for band in range(d):
      p = np.reshape(reconstvec[:, band], (nrows, ncols))
      fcoeffs = pywt.wavedec2(p, wavelet_type, level=level)
      sharpened[:, :, band] = pywt.waverec2(fcoeffs, wavelet_type)

    # [red, green, blue] or [red, green, blue, NIR]
    return [sharpened[:, :, band] for band in range(d)]
+-------------------------------+-------------------------------+
"""

cam = pywt.data.camera()
coeffs = pywt.wavedecn(cam, wavelet="db2", level=3)

# Round trip 1: pack every coefficient into one n-d array, split it back
# into the per-level structure and reconstruct.
arr, coeff_slices = pywt.coeffs_to_array(coeffs)
coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices)
cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2')

# Round trip 2: ravel every coefficient into one 1-D array, unravel it using
# the recorded slices/shapes and reconstruct.
arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs)
coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes)
cam_recon2 = pywt.waverecn(coeffs_from_arr, wavelet='db2')

# Multilevel helpers: per-level n-d coefficient shapes for a (64, 32) input,
# then the total number of coefficients across all of them.
shapes = pywt.wavedecn_shapes((64, 32), 'db2', mode='periodization')
size = pywt.wavedecn_size(shapes)
print(size)

print()