Exemple #1
0
def batch_ISwt(batch):
    '''
    Reconstruct a batch of 3-plane images from single-level Haar SWT
    coefficients.

    Args:
        batch: Array indexed as [n, h, w, c] with at least 12 planes on
            the last axis. Planes are read in three consecutive groups
            (0-3, 4-7, 8-11); in each group the first plane is handed to
            pywt.iswt2 as the approximation and the next three as the
            detail tuple.
    Returns:
        swt_batch: Array of shape [n, h, w, 3] holding one reconstructed
            plane per coefficient group, in group order.
    '''
    n_images, height, width = batch.shape[0], batch.shape[1], batch.shape[2]
    swt_batch = np.zeros([n_images, height, width, 3])

    for idx in range(n_images):
        planes = []
        # Each group of four planes starts at one of these offsets.
        for base in (0, 4, 8):
            approx = batch[idx, :, :, base]
            details = (batch[idx, :, :, base + 1],
                       batch[idx, :, :, base + 2],
                       batch[idx, :, :, base + 3])
            planes.append(pywt.iswt2((approx, details), wavelet='haar'))

        # Interleave the three reconstructed planes into one image.
        swt_batch[idx, :, :, :] = cv2.merge(planes)
    return swt_batch
Exemple #2
0
def iswt2d_rgb(approxs, details, wavelet='bior2.2'):
    """Inverse 2D SWT of a three-channel image from interleaved coefficients.

    `approxs` and `details` are sequences in which the entries of the three
    channels are interleaved (index 0, 3, 6, ... belong to the first
    channel, 1, 4, 7, ... to the second, 2, 5, 8, ... to the third).

    For each channel, level 0 is built as (details[0], details[1:4]) and
    every later level ``i`` as (approxs[i], details[1 + 3*i:1 + 3*(i + 1)]),
    using the de-interleaved per-channel lists.

    Returns the three reconstructed channels stacked along a new last axis
    in first/second/third channel order (presumably B, G, R, going by the
    original variable names -- confirm against the caller).
    """
    offsets = (0, 1, 2)

    # De-interleave: one list per channel, picked with a stride of three.
    approx_by_channel = [
        [approxs[k + off] for k in range(0, len(approxs), 3)]
        for off in offsets
    ]
    detail_by_channel = [
        [details[k + off] for k in range(0, len(details), 3)]
        for off in offsets
    ]

    # Assemble the per-level (head, detail-slice) pairs for every channel.
    coeffs_by_channel = []
    for ch_approx, ch_detail in zip(approx_by_channel, detail_by_channel):
        levels = []
        for lvl in range(len(ch_approx)):
            # Level 0 takes its head from the detail list, not from approxs.
            head = ch_detail[0] if lvl == 0 else ch_approx[lvl]
            levels.append((head, ch_detail[1 + 3 * lvl:1 + 3 * (lvl + 1)]))
        coeffs_by_channel.append(levels)

    channels = [pywt.iswt2(levels, wavelet=wavelet)
                for levels in coeffs_by_channel]

    # No clipping or integer cast is applied; values are returned as computed.
    return np.dstack(tuple(channels))
Exemple #3
0
def batch_ISwt(batch):
    '''
    Reconstruct a batch of 3-plane images from two-level Haar SWT
    coefficients.

    Args:
        batch: Array indexed as [n, h, w, c] with at least 24 planes on
            the last axis. Planes 0-11 are read as the level-1 groups
            (0-3, 4-7, 8-11) and planes 12-23 as the matching level-2
            groups (12-15, 16-19, 20-23). In each group the first plane
            is the approximation and the next three are the details.
    Returns:
        Swt_batch: Array of shape [n, h, w, 3], one reconstructed plane
            per group, in group order.
    '''
    def _group(sample_idx, first):
        # (approximation, (three detail planes)) starting at plane `first`.
        return (batch[sample_idx, :, :, first],
                (batch[sample_idx, :, :, first + 1],
                 batch[sample_idx, :, :, first + 2],
                 batch[sample_idx, :, :, first + 3]))

    Swt_batch = np.zeros([batch.shape[0], batch.shape[1], batch.shape[2], 3])

    for idx in range(batch.shape[0]):
        # Gather every coefficient group before running any inverse.
        per_plane = []
        for first in (0, 4, 8):
            level_1 = _group(idx, first)
            level_2 = _group(idx, first + 12)
            per_plane.append((level_2, level_1))

        # iswt2 receives the level-2 entry first, then level 1.
        planes = [pywt.iswt2([level_2, level_1], wavelet='haar')
                  for level_2, level_1 in per_plane]

        Swt_batch[idx, :, :, :] = cv2.merge(planes)

    return Swt_batch
Exemple #4
0
def test_per_axis_wavelets():
    """Round-trip swtn/iswtn and swt2/iswt2 with a separate wavelet per axis."""
    rng = np.random.RandomState(1234)
    volume = rng.randn(16, 16, 16)
    n_levels = 3

    # A Wavelet instance and plain names are mixed on purpose.
    per_axis = (pywt.Wavelet('haar'), 'sym2', 'db4')

    # First one wavelet per axis, then a 1-tuple.
    for wav in (per_axis, per_axis[:1]):
        coefs = pywt.swtn(volume, wav, level=n_levels)
        assert_allclose(pywt.iswtn(coefs, wav), volume, atol=1e-14)

    # Two wavelets for three axes must be rejected in both directions.
    assert_raises(ValueError, pywt.swtn, volume, per_axis[:2], n_levels)
    assert_raises(ValueError, pywt.iswtn, coefs, per_axis[:2])

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # The 2D wrappers take per-axis wavelets as well.
        plane = volume[..., 0]
        plane_coefs = pywt.swt2(plane, per_axis[:2], n_levels)
        restored = pywt.iswt2(plane_coefs, per_axis[:2])
        assert_allclose(restored, plane, atol=1e-14)
Exemple #5
0
def test_iswt2_mixed_dtypes():
    """iswt2 on coefficients of mixed precision yields double precision."""
    rng = np.random.RandomState(0)
    real_img = rng.randn(8, 8)
    complex_img = real_img + 1j*real_img
    wavelet_name = 'sym2'

    # (dtype for the approximation, dtype for the details)
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    complex_types = (np.complex64, np.complex128)

    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in complex_types
        image = complex_img if is_complex else real_img
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swt2(image, wavelet_name, 2)
        # Recast only the first entry, giving its approximation and its
        # details different precisions.
        first_approx, first_details = coeffs[0]
        coeffs[0] = [first_approx.astype(approx_dtype),
                     tuple([band.astype(detail_dtype)
                            for band in first_details])]
        restored = pywt.iswt2(coeffs, wavelet_name)
        assert_equal(expected_dtype, restored.dtype)
        assert_allclose(restored, image, rtol=1e-3, atol=1e-3)
Exemple #6
0
def test_swt2_iswt2_integration():
    """Round-trip swt2/iswt2 for the discrete wavelets, up to 3 levels.

    'dmey' is left out because it does not give very precise results,
    likely because it is a discrete approximation of a continuous wavelet.
    The test validates neither swt2 nor iswt2 on its own; it only checks
    that they are each other's inverse.
    """
    max_level = 3

    # Only discrete wavelets can be built with pywt.Wavelet; a bare
    # wavelist() may also contain continuous ones.
    try:
        wavelets = pywt.wavelist(kind='discrete')
    except TypeError:
        # Fallback in case the installed PyWavelets predates the `kind`
        # argument.
        wavelets = pywt.wavelist()

    if 'dmey' in wavelets:
        # The 'dmey' wavelet seems to be a bit special - disregard it for now
        wavelets.remove('dmey')

    for current_wavelet_str in wavelets:
        current_wavelet = pywt.Wavelet(current_wavelet_str)
        # Side length: the filter length rounded up to a power of two,
        # then scaled by 2**(max_level - 1).
        input_length_power = int(np.ceil(np.log2(max(
            current_wavelet.dec_len,
            current_wavelet.rec_len))))
        input_length = 2**(input_length_power + max_level - 1)
        X = np.arange(input_length**2).reshape(input_length, input_length)
        coeffs = pywt.swt2(X, current_wavelet, max_level)
        Y = pywt.iswt2(coeffs, current_wavelet)
        assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
 def _transform(self, image, channel=None):
     """Run a one-level SWT on the padded image and package the result.

     Args:
         image: Input passed to ``self._pad`` before the transform.
         channel: Label placed first in the returned channel names for
             every mode except 'features-only'. Defaults to a fresh
             ``['r']`` list on each call, so callers never share one
             default object.

     Returns:
         A ``(channels, arrays)`` pair that depends on ``self._mode``:
           * 'rebuild': the image plus a reconstruction made with the
             approximation zeroed (detail bands only).
           * 'features': the image plus the unpadded cH, cV, cD bands.
           * 'features-all': as above, with cA added.
           * 'features-only': the unpadded cA, cH, cV, cD bands alone.
         Any other mode returns None.
     """
     if channel is None:
         channel = ['r']

     new_image = self._pad(image)
     wave = pywt.Wavelet(self._wavelet)
     cA_1, (cH_1, cV_1, cD_1) = pywt.swt2(new_image,
                                          wave,
                                          level=1,
                                          start_level=1)[0]
     mode = self._mode

     if mode == 'rebuild':
         # Zero the approximation so only the detail bands contribute.
         cA = np.zeros_like(cA_1)
         coeffs = ((cA, (cH_1, cV_1, cD_1)), )
         reb_image = pywt.iswt2(coeffs, wave)
         channels = (channel, 'high_pass')
         return channels, (image, self._unpad(reb_image))

     if mode == 'features':
         channels = (channel, 'cH', 'cV', 'cD')
         return channels, (image, self._unpad(cH_1), self._unpad(cV_1),
                           self._unpad(cD_1))

     if mode == 'features-all':
         channels = (channel, 'cA', 'cH', 'cV', 'cD')
         return channels, (image, self._unpad(cA_1), self._unpad(cH_1),
                           self._unpad(cV_1), self._unpad(cD_1))

     if mode == 'features-only':
         channels = ('cA', 'cH', 'cV', 'cD')
         return channels, (self._unpad(cA_1), self._unpad(cH_1),
                           self._unpad(cV_1), self._unpad(cD_1))

     # Unrecognised mode: same implicit result as before.
     return None
Exemple #8
0
def test_iswt2_mixed_dtypes():
    """Check that iswt2 promotes mixed-precision coefficients to double."""
    random_state = np.random.RandomState(0)
    x_real = random_state.randn(8, 8)
    x_complex = x_real + 1j*x_real
    wav = 'sym2'

    # Each case is (approximation dtype, detail dtype).
    cases = ((np.float64, np.float32),
             (np.float32, np.float64),
             (np.float16, np.float64),
             (np.complex128, np.complex64),
             (np.complex64, np.complex128))

    for approx_type, detail_type in cases:
        if approx_type in (np.complex64, np.complex128):
            signal, wanted_dtype = x_complex, np.complex128
        else:
            signal, wanted_dtype = x_real, np.float64

        coeffs = pywt.swt2(signal, wav, 2)
        # Give the first entry's approximation and details different
        # precisions; the remaining entries are left untouched.
        approx, bands = coeffs[0]
        recast_bands = tuple([band.astype(detail_type) for band in bands])
        coeffs[0] = [approx.astype(approx_type), recast_bands]

        y = pywt.iswt2(coeffs, wav)
        assert_equal(wanted_dtype, y.dtype)
        assert_allclose(y, signal, rtol=1e-3, atol=1e-3)
Exemple #9
0
def wavelet_blend(left, right, shift_x, dx, f_x, wavelet, level):
    """Blend two images in the SWT domain and reconstruct the result.

    Both inputs are decomposed with ``pywt.swt2``. For every coefficient
    band, columns ``dx:`` of the left band are mixed with columns
    ``:shift_x`` of the right band as ``f_x * left + (1 - f_x) * right``,
    and the mix is written into the first ``shift_x`` columns of the right
    band. The modified right-hand coefficients are then inverted with
    ``pywt.iswt2``.
    """
    def _blend_into_right(band_left, band_right):
        # Writes in place into the right-hand band and returns that band.
        mixed = f_x * band_left[:, dx:] + (1 - f_x) * band_right[:, :shift_x]
        band_right[:, :shift_x] = mixed
        return band_right

    left_levels = pywt.swt2(left, wavelet, level)
    right_levels = pywt.swt2(right, wavelet, level)

    blended = []
    for level_left, level_right in zip(left_levels, right_levels):
        approx_left, details_left = level_left
        approx_right, details_right = level_right

        approx = _blend_into_right(approx_left, approx_right)
        details = tuple(_blend_into_right(band_l, band_r)
                        for band_l, band_r in zip(details_left,
                                                  details_right))
        blended.append((approx, details))

    return pywt.iswt2(blended, wavelet)
Exemple #10
0
def test_swt2_iswt2_integration(wavelets=None):
    """Round-trip swt2/iswt2 over a set of wavelets, up to 3 levels.

    When `wavelets` is None, every discrete wavelet reported by PyWavelets
    is used except 'dmey', which is skipped because it does not produce
    very precise results (likely because it is a discrete approximation of
    a continuous wavelet). The test validates neither swt2 nor iswt2 as
    such; it only ensures they are each other's inverse.
    """
    max_level = 3

    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
        # 'dmey' is a special case - leave it out for now.
        if 'dmey' in wavelets:
            wavelets.remove('dmey')

    for name in wavelets:
        wav = pywt.Wavelet(name)
        longest_filter = max(wav.dec_len, wav.rec_len)
        # Round the filter length up to a power of two, then scale by
        # 2**(max_level - 1) to get the side of the square test image.
        exponent = int(np.ceil(np.log2(longest_filter))) + max_level - 1
        side = 2**exponent
        image = np.arange(side**2).reshape(side, side)

        restored = pywt.iswt2(pywt.swt2(image, wav, max_level), wav)
        assert_allclose(restored, image, rtol=1e-5, atol=1e-5)
Exemple #11
0
def test_per_axis_wavelets():
    """Check that the SWT routines accept a separate wavelet for each axis."""
    random_state = np.random.RandomState(1234)
    cube = random_state.randn(16, 16, 16)
    depth = 3

    # Entries may be Wavelet objects or wavelet names.
    axis_wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
    single = axis_wavelets[:1]
    pair = axis_wavelets[:2]

    full_coefs = pywt.swtn(cube, axis_wavelets, level=depth)
    assert_allclose(pywt.iswtn(full_coefs, axis_wavelets), cube, atol=1e-14)

    # A 1-tuple is accepted too.
    coefs = pywt.swtn(cube, single, level=depth)
    assert_allclose(pywt.iswtn(coefs, single), cube, atol=1e-14)

    # A wavelet count that does not match the number of axes is an error.
    assert_raises(ValueError, pywt.swtn, cube, pair, depth)
    assert_raises(ValueError, pywt.iswtn, coefs, pair)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # swt2/iswt2 support per-axis wavelets as well.
        face = cube[..., 0]
        face_coefs = pywt.swt2(face, pair, depth)
        assert_allclose(pywt.iswt2(face_coefs, pair), face, atol=1e-14)
Exemple #12
0
    def apply_wavelet(self,img):
        """Add a background-subtracted wavelet reconstruction to each map.

        Args:
            img: A single map-like object (with ``.data`` and ``.meta``)
                or a list of them.

        Returns:
            A ``sunpy.map.Map`` when a single object was passed, otherwise
            a list with one map per input, in input order.
        """
        # Stationary (undecimated) wavelet transform: swt2 / iswt2
        import pywt

        # median filter used to estimate the background
        from scipy.signal import medfilt

        # Gaussian smoothing of the upsampled background; imported from
        # scipy.ndimage because the scipy.ndimage.filters namespace is
        # deprecated
        from scipy.ndimage import gaussian_filter

        # wrap a single input in a list for convenience
        unlist = False
        if not isinstance(img,list):
            img = [img]
            unlist = True

        # processed image list
        nimg = []
        for i in img:
            wav_img = i.data

            # degrade the image for a faster median filter
            wav_img_low = self.shrink(wav_img,1024,1024)/16.
            d_size = 15
            n_back = medfilt(wav_img_low,kernel_size=d_size)

            # expand the background back by a factor of 4 per axis
            # (presumably 1024 -> 4096, the full image size -- confirm),
            # then smooth away the blockiness
            n_back = np.repeat(n_back,4,axis=1)
            n_back = np.repeat(n_back,4,axis=0)
            n_back = gaussian_filter(n_back,2)

            # subtract the background estimate
            img_sub = wav_img-n_back

            # biorthogonal wavelet, 6 decomposition levels
            wavelet = 'bior2.6'
            n_lev = 6
            o_wav = pywt.swt2(img_sub, wavelet, level=n_lev )
            # reconstruct from the first 4 entries only
            # (will need to switch order sometime)
            f_img = pywt.iswt2(o_wav[0:4],wavelet)
            # add the wavelet reconstruction back onto the original image
            f_img = f_img+wav_img

            # store the new image
            nimg.append(sunpy.map.Map(f_img,i.meta))

        # a single input gets a single map back, not a list
        if unlist:
            nimg = nimg[0]

        return nimg
Exemple #13
0
def test_swt2_iswt2_non_square(wavelets=None):
    """Round-trip swt2/iswt2 on rectangular inputs with 'db1', 2 levels.

    The `wavelets` argument is accepted but not used.
    """
    n_cols = 32
    wavelet_name = 'db1'
    for n_rows in (8, 16, 48):
        image = np.arange(n_rows*n_cols).reshape(n_rows, n_cols)
        with warnings.catch_warnings():
            # Silence FutureWarning raised during the transforms.
            warnings.simplefilter('ignore', FutureWarning)
            decomposition = pywt.swt2(image, wavelet_name, level=2)
            restored = pywt.iswt2(decomposition, wavelet_name)
        assert_allclose(restored, image, rtol=tol_single, atol=tol_single)
Exemple #14
0
def test_swt2_iswt2_non_square(wavelets=None):
    """Check the swt2/iswt2 round trip when rows and columns differ.

    Images with 8, 16 and 48 rows by 32 columns are transformed with
    'db1' at level 2. `wavelets` is not used by this test.
    """
    for height in (8, 16, 48):
        width = 32
        original = np.arange(height*width).reshape(height, width)
        name = 'db1'
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', FutureWarning)
            recovered = pywt.iswt2(pywt.swt2(original, name, level=2), name)
        assert_allclose(recovered, original, rtol=tol_single, atol=tol_single)
Exemple #15
0
def get_swt(shape, wname, levels):
    """Build forward and inverse stationary wavelet transforms for a shape.

    A zero image of `shape` is decomposed once so that the slice layout
    reported by `coeffs_to_array` can be reused by the inverse.

    Returns:
        (swt, iswt): `swt(x)` gives the coefficients of `x` packed into a
        single array; `iswt(a)` unpacks such an array and inverts it.
    """
    # Template decomposition, used only for its slice layout.
    template = pywt.swt2(np.zeros(shape), wname, levels, trim_approx=True)
    _, swt_slices = coeffs_to_array(template)

    def iswt(x):
        # array -> coefficient list -> image
        unpacked = array_to_coeffs(x, swt_slices, 'wavedec2')
        return pywt.iswt2(unpacked, wname)

    def swt(x):
        # image -> coefficient list -> array (the slices are dropped)
        packed = coeffs_to_array(pywt.swt2(x, wname, levels,
                                           trim_approx=True))
        return packed[0]

    return swt, iswt
Exemple #16
0
def iswt2d(approxs, details, wavelet='bior2.2'):
    """Inverse 2D SWT from flat approximation and detail sequences.

    One coefficient pair is built per entry of `approxs`. Level 0 is
    (details[0], details[1:4]); every later level ``i`` is
    (approxs[i], details[1 + 3*i:1 + 3*(i + 1)]). The list is handed to
    ``pywt.iswt2`` and the reconstructed image is returned.
    """
    coeffs = [
        # Level 0 takes its head from `details`, not from `approxs`.
        (details[0] if lvl == 0 else approxs[lvl],
         details[1 + 3 * lvl:1 + 3 * (lvl + 1)])
        for lvl in range(len(approxs))
    ]
    return pywt.iswt2(coeffs, wavelet=wavelet)
Exemple #17
0
def _swt_norm(x, wavelet, level, p=2):
    """Return `p_norm` of the detail-only SWT reconstruction of `x`.

    `x` is padded symmetrically using the widths from `pad_width`, with the
    two trailing axes targeted at the next power of two of the longer one.
    Each slice along the first axis is decomposed with ``pywt.swt2``, its
    approximation coefficients are zeroed, and the inverse transform is
    cropped back to the original extent before the norm is taken.
    """
    longest_side = max(x.shape[1:])
    div = 2**math.ceil(math.log2(longest_side))
    pw = pad_width(x.shape, (1, div, div))
    x_pad = np.pad(x, pw, 'symmetric')

    def _details_only(ch):
        coeffs = pywt.swt2(ch, wavelet, level)
        # Drop the approximations so only the detail bands survive.
        for approx, _ in coeffs:
            approx[...] = 0
        row_start, col_start = pw[1][0], pw[2][0]
        # Crop away the padding added before the transform.
        return pywt.iswt2(coeffs, wavelet)[row_start:row_start + x.shape[1],
                                           col_start:col_start + x.shape[2]]

    return p_norm(np.stack([_details_only(ch) for ch in x_pad]), p)
Exemple #18
0
 def initialize_wl_operators(self):
     """Return the (forward, inverse) wavelet operators for the settings.

     With ``self.use_decimated`` the pair wraps ``pywt.wavedecn`` /
     ``pywt.waverecn``. Otherwise the name ``use_swtn`` (looked up outside
     this method) selects between ``pywt.swtn`` / ``pywt.iswtn`` and the
     2D ``pywt.swt2`` / ``pywt.iswt2`` pair. The 2D forward operator
     squeezes its input and the 2D inverse adds a leading axis to its
     output.
     """
     if self.use_decimated:
         def H(x):
             return pywt.wavedecn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)

         def Ht(x):
             return pywt.waverecn(x, wavelet=self.wl_type, axes=self.axes)
     elif use_swtn:
         def H(x):
             return pywt.swtn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)

         def Ht(x):
             return pywt.iswtn(x, wavelet=self.wl_type, axes=self.axes)
     else:
         def H(x):
             return pywt.swt2(np.squeeze(x), wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)

         def Ht(x):
             # Add a leading axis to the reconstructed image.
             return pywt.iswt2(x, wavelet=self.wl_type)[np.newaxis, ...]
     return (H, Ht)
Exemple #19
0
def _swt_norm(x, wavelet, level, p=2):
    """Compute `p_norm` over the detail-band reconstruction of each slice.

    The input is padded symmetrically (widths from `pad_width`, targeting
    the next power of two of the longest trailing axis). Every slice along
    the first axis goes through ``pywt.swt2``, has its approximation
    coefficients set to zero, is inverted with ``pywt.iswt2`` and is
    cropped back to the unpadded extent.
    """
    div = 2**math.ceil(math.log2(max(x.shape[1:])))
    pw = pad_width(x.shape, (1, div, div))
    padded = np.pad(x, pw, 'symmetric')

    reconstructions = []
    for plane in padded:
        coeffs = pywt.swt2(plane, wavelet, level)
        # Zero every approximation in place; the details are untouched.
        for approx, _details in coeffs:
            approx[:] = 0
        full = pywt.iswt2(coeffs, wavelet)
        top, left = pw[1][0], pw[2][0]
        # Undo the padding.
        reconstructions.append(full[top:top + x.shape[1],
                                    left:left + x.shape[2]])
    return p_norm(np.stack(reconstructions), p)
Exemple #20
0
def sharpenChannelLayers(params):
    """Sharpen/denoise the wavelet layers of ``g.coeffs`` and rebuild.

    `params` holds per-level sequences under 'level', 'radius', 'sharpen'
    and 'denoise'. Layers ``1 .. len(coeffs) - 1`` are processed in place,
    the image is reconstructed with a Haar ``iswt2``, and every modified
    layer is then restored from a deep copy taken beforehand.

    Returns ``(g.channel, img)`` where `img` has ``2**Sharpen.LEVEL`` rows
    and columns cropped from its top-left.
    """
    c = g.coeffs
    bands = (0, 1, 2)

    # One backup slot per processed layer (None when nothing will change).
    backups = []
    for i in range(1, len(c)):
        level = (len(c) - i - 1)
        active = params['level'][level]

        will_change = active and (params['radius'][level] > 0
                                  or params['sharpen'][level] > 0
                                  or params['denoise'][level] > 0)
        backups.append(copy.deepcopy(c[i]) if will_change else None)

        if not active:
            continue

        # Unsharp mask on each band.
        if params['radius'][level] > 0:
            for b in bands:
                unsharp(c[i][b], params['radius'][level], 2)

        # Scale each band up in place to increase intensity.
        if params['sharpen'][level] > 0:
            factor = (100 - 10 * level)
            for b in bands:
                cv2.add(c[i][b], c[i][b] * params['sharpen'][level] * factor,
                        c[i][b])

        # Denoise each band.
        if params['denoise'][level] > 0:
            for b in bands:
                unsharp(c[i][b], params['denoise'][level], -1)

    # Reconstruction
    padding = 2**(Sharpen.LEVEL)
    img = iswt2(c, 'haar')

    # Put the untouched coefficients back for the next call.
    for i, saved in enumerate(backups, start=1):
        if saved is not None:
            c[i] = saved

    # Crop the leading rows/columns before handing the image back.
    img = img[padding:, padding:]

    return (g.channel, img)
def main():
    """Denoise every FITS image in ./resources and write a PDF report.

    For each image, 'multipage.pdf' receives a title page, a plot of the
    SWT coefficients, a comparison of the calibrated, reconstructed and
    truth images, and a pixel histogram.
    """
    noise_sigma = 2
    wavelet = 'bior1.3'

    images = list(io.load_fits_images_in_folder('./resources'))
    with PdfPages('multipage.pdf') as pdf:
        for event_idx, (calibrated, truth) in tqdm(enumerate(images)):

            # Title page carrying only the event number.
            title_fig = plt.figure()
            plt.text(1, 1, 'Event Number: {}'.format(event_idx))
            plt.xlim(0, 2)
            plt.ylim(0, 2)
            plt.axis('off')
            pdf.savefig(title_fig)
            plt.close()

            # Decompose down to the deepest level the row count allows.
            max_level = pywt.swt_max_level(len(calibrated))
            coeff_list = pywt.swt2(calibrated, wavelet, max_level)

            coeff_fig = plot.coefficients(coeff_list)
            pdf.savefig(figure=coeff_fig)
            plt.close()

            # Hard-threshold the coefficients with these sigma levels.
            levels = [0.889, 0.7, 0.586]
            coeff_list = denoise.thresholding(coeff_list,
                                              sigma_d=noise_sigma,
                                              kind='hard',
                                              sigma_levels=levels)

            reconstructed = pywt.iswt2(coeff_list, wavelet)

            result_fig = plot.results(calibrated, reconstructed, truth)
            pdf.savefig(figure=result_fig)
            plt.close()

            hist_fig = plot.pixel_histogram(
                reconstructed,
                calibrated,
                truth,
                labels=['reconstructed', 'calibrated', 'npe mc truth'],
                bins=60)

            pdf.savefig(figure=hist_fig)
            plt.close()
Exemple #22
0
def test_swt_roundtrip_dtypes():
    """Check swt/iswt and swt2/iswt2 round trips for each input dtype."""
    rng = np.random.RandomState(5)
    haar = pywt.Wavelet('haar')

    # (input shape, forward transform, inverse transform)
    round_trips = (
        ((8, ), pywt.swt, pywt.iswt),
        ((8, 8), pywt.swt2, pywt.iswt2),
    )

    # The output dtypes are iterated alongside but not checked here.
    for dt_in, _dt_out in zip(dtypes_in, dtypes_out):
        for shape, forward, inverse in round_trips:
            x = rng.standard_normal(shape).astype(dt_in)
            xr = inverse(forward(x, haar, level=2), haar)
            assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
Exemple #23
0
def test_swt_roundtrip_dtypes():
    """Verify reconstruction through the 1D and 2D SWT for every dtype."""
    random_state = np.random.RandomState(5)
    wav = pywt.Wavelet('haar')

    def _check(shape, dtype, forward, inverse):
        # Draw, transform at level 2, invert and compare.
        signal = random_state.standard_normal(shape).astype(dtype)
        coeffs = forward(signal, wav, level=2)
        restored = inverse(coeffs, wav)
        assert_allclose(signal, restored, rtol=1e-6, atol=1e-7)

    for dt_in, dt_out in zip(dtypes_in, dtypes_out):
        # swt, iswt
        _check((8, ), dt_in, pywt.swt, pywt.iswt)
        # swt2, iswt2
        _check((8, 8), dt_in, pywt.swt2, pywt.iswt2)
def shrink(img, lev=4, shrink_type='bayes', thresh_type='hard', k=1):
    """Denoise an image by shrinking its 'bior4.4' SWT coefficients.

    Args:
        img: Image data; converted to a float32 array.
        lev: Decomposition level passed to ``pywt.swt2``.
        shrink_type: 'visu', 'sure' or 'bayes' selects the shrinkage
            routine; any other value leaves the coefficients unchanged.
        thresh_type: Forwarded to the selected shrinkage routine.
        k: Forwarded to the 'visu' and 'bayes' routines.

    Returns:
        The reconstructed image clipped to the range [0, 255].
    """
    wavelet_name = 'bior4.4'

    # Work in single precision.
    pixels = np.float32(np.asarray(img))

    coeffs = pywt.swt2(pixels, wavelet_name, level=lev, trim_approx=True)

    # Pick the shrinkage rule; unknown names pass the coefficients through.
    if shrink_type == 'visu':
        coeffs = visushrink(coeffs, lev, thresh_type, k)
    elif shrink_type == 'sure':
        coeffs = sureshrink(coeffs, lev, thresh_type)
    elif shrink_type == 'bayes':
        coeffs = bayesshrink(coeffs, lev, thresh_type, k)

    return np.clip(pywt.iswt2(coeffs, wavelet_name), 0, 255)
Exemple #25
0
def fienup(anc_vec, con_upp, con_low):
    """Iteratively refine `anc_vec` with wavelet thresholding and a box
    projection in the `bdct2` domain.

    Each of the ``INN_ITR`` iterations:
      1. thresholds every SWT coefficient band (approximation and
         details, all ``DLV`` levels) with ``REG_PRM``,
      2. adds ``PEN_PRM * anc_vec``,
      3. clamps the `bdct2` transform between `con_low` and `con_upp` and
         transforms back with `bidct2`,
      4. applies a momentum step using the previous iterate.

    Args:
        anc_vec: Starting image; also the anchor added in step 2.
        con_upp: Upper bound applied in the `bdct2` domain.
        con_low: Lower bound applied in the `bdct2` domain.

    Returns:
        The iterate after the final iteration.
    """
    # initial guess
    rec_vec = anc_vec

    for inn_itr in range(INN_ITR):

        # keep the previous iterate for the momentum step
        cpy_vec = cp.deepcopy(rec_vec)

        # thresholding of LL and of LH, HL, HH at every level
        coeffs = pw.swt2(rec_vec, QMF, DLV, norm=True)
        coeffs = [
            (pw.threshold(approx, REG_PRM),
             tuple(pw.threshold(band, REG_PRM) for band in details))
            for approx, details in coeffs
        ]

        # The third positional parameter of iswt2 is `norm`, not a level;
        # pass it by keyword so it matches the normalised forward transform.
        rec_vec = pw.iswt2(coeffs, QMF, norm=True)

        # proximal mapping for inner product
        rec_vec = rec_vec + PEN_PRM * anc_vec

        # projection
        rec_vec = bdct2(rec_vec)
        rec_vec = np.minimum(rec_vec, con_upp)
        rec_vec = np.maximum(rec_vec, con_low)
        rec_vec = bidct2(rec_vec)

        # acceleration
        # NOTE(review): with the 0-based counter the weight is -1/2 on the
        # first pass -- confirm this indexing is intended.
        rec_vec = rec_vec + (inn_itr - 1) * (rec_vec - cpy_vec) / (inn_itr + 2)

    return rec_vec
Exemple #26
0
    def iswt2(self):
        """
        Test pypwt for 2D stationary wavelet reconstruction (iswt2).

        Runs a forward then inverse transform on ``self.W``, optionally
        times the pywt inverse for comparison, and asserts that the
        reconstruction error against ``self.data`` is below ``self.tol``.
        """

        W = self.W
        # inverse transform with pypwt (a forward pass is needed first)
        W.forward()
        logging.info("computing Wavelets.inverse from pypwt")
        t0 = time()
        W.inverse()
        logging.info("Wavelets.inverse took %.3f ms" % elapsed_ms(t0))

        if self.do_pywt:
            # inverse transform with pywt
            Wpy = pywt.swt2(self.data, self.wname, level=self.levels)
            logging.info("computing iswt2 from pywt")
            # Restart the clock so only the pywt inverse is measured, not
            # the pypwt inverse and the pywt forward transform before it.
            t0 = time()
            _ = pywt.iswt2(Wpy, self.wname)
            logging.info("pywt took %.3f ms" % elapsed_ms(t0))

        # Check reconstruction
        W_image = W.image
        maxerr = _calc_errors(self.data, W_image, "[rec]")
        self.assertTrue(maxerr < self.tol, msg="[%s] something wrong with the reconstruction (errmax = %e)" % (self.wname, maxerr))
V2L1 = np.float32(HL)
D2L1 = np.float32(HH)

# Fusion: average the two approximations; for each detail band keep,
# element by element, the coefficient with the larger magnitude (the first
# image wins ties).
AfL1 = 0.5 * (A1L1 + A2L1)
HfL1 = np.where(np.abs(H1L1) >= np.abs(H2L1), H1L1, H2L1)
VfL1 = np.where(np.abs(V1L1) >= np.abs(V2L1), V1L1, V2L1)
DfL1 = np.where(np.abs(D1L1) >= np.abs(D2L1), D1L1, D2L1)

# Inverse swt2 (iswt2) of the fused coefficients
coeffs3 = AfL1, (HfL1, VfL1, DfL1)
# Clip before the 8-bit cast: values outside 0..255 would otherwise wrap.
imf = np.uint8(np.clip(pywt.iswt2(coeffs3, 'db2'), 0, 255))

# Display the two inputs side by side
plt.figure(dpi=200)
plt.subplot(121)
plt.imshow(img1, cmap='gray')
plt.subplot(122)
plt.imshow(img2, cmap='gray')
plt.savefig('inputs.png', dpi=200)
plt.show()

# Display the fused result
plt.figure(dpi=200)
plt.imshow(imf, cmap='gray')
plt.savefig('output.png')
plt.show()
def sw(img_path, wavelet='haar', level=1):
    """Fuse two images channel by channel in the SWT domain.

    Args:
        img_path: Sequence whose first two entries are the image paths.
        wavelet: Wavelet name passed to ``pywt.swt2`` / ``pywt.iswt2``.
        level: Decomposition level passed to ``pywt.swt2``. Only the first
            entry of the returned coefficient list is used.

    Returns:
        uint8 array of shape (h, w, 3) holding the fused image, in the
        same channel order as the BGR-converted inputs.
    """
    first = cv2.cvtColor(np.array(Image.open(img_path[0]).convert('RGB')),
                         cv2.COLOR_RGB2BGR)
    second = cv2.cvtColor(np.array(Image.open(img_path[1]).convert('RGB')),
                          cv2.COLOR_RGB2BGR)

    # Resize the BGR-converted second image (not the raw RGB array) so both
    # inputs share one channel order; cv2.resize takes dsize as
    # (width, height).
    second = cv2.resize(second, (first.shape[1], first.shape[0]))

    fused_planes = []
    for ch in range(3):
        approx_1, details_1 = pywt.swt2(first[:, :, ch],
                                        wavelet=wavelet,
                                        level=level,
                                        norm=True)[0]
        approx_2, details_2 = pywt.swt2(second[:, :, ch],
                                        wavelet=wavelet,
                                        level=level,
                                        norm=True)[0]

        # NOTE(review): approximations are summed while details are
        # averaged, and the inverse below is called without norm=True --
        # confirm this scaling is intended.
        approx = np.add(approx_1, approx_2)
        details = np.add(details_1, details_2) / 2

        fused_planes.append(pywt.iswt2((approx, details), wavelet=wavelet))

    height, width = fused_planes[0].shape[0], fused_planes[0].shape[1]
    imgx = np.zeros([height, width, 3], dtype=np.uint8)
    for ch, plane in enumerate(fused_planes):
        # Clip first: storing out-of-range floats into uint8 would wrap.
        imgx[:, :, ch] = np.clip(plane, 0, 255)
    return imgx
def make_images(f):
    """Assemble the images in ``f`` into one panel PNG with a time stamp.

    The inputs are read with ``sunpy.map.Map(*f)``.  Each map is scaled with
    ``arcsinh``, normalized and colored using the module-level ``img_scale``
    entry for the matching wavelength in ``wav``, resized to a 1024x1024
    tile and pasted onto a ``w0`` x ``h0`` canvas.  A text label with the
    observation time of the first map and the wavelengths is drawn in the
    corner, and the panel is saved to
    ``sdir + '/working/panel_<YYYYmmdd_HHMMSS>.png'``.  Nothing is done if
    that file already exists.

    This is a best-effort routine: any failure is reported on stdout and
    the function returns without raising so the caller can move on.

    Args:
        f: arguments expanded into ``sunpy.map.Map`` (presumably one file
            path per entry of ``wav``, in the same order -- confirm against
            the caller).

    Returns:
        None.
    """
    global wav, img_scale, wx, wy, h0, w0, sdir

    # Switch for the optional wavelet sharpening step (off here).  It is
    # kept apart from the wavelet *name* so the switch is never replaced by
    # a string part-way through the loop.
    use_wavelet = False
    # Edge length in pixels of every tile pasted into the panel.
    tile = 1024

    #try to make the image. If it fails report it and move on
    try:
        #read all images into sunpy maps
        img = sunpy.map.Map(*f)

        #output file
        outfi = sdir + '/working/panel_{0}'.format(
            img[0].date.strftime('%Y%m%d_%H%M%S')) + '.png'

        #skip if file exists
        if os.path.isfile(outfi):
            return

        #create new image
        new_img = Image.new('RGB', (w0, h0))

        for j, i in enumerate(img):

            #paste offset (x, y) of the tile based on index; only four
            #slots are defined, later images reuse the last offset
            if j == 0:
                px, py = 0, 0
            elif j == 1:
                px, py = tile, 0
            elif j == 2:
                px, py = 0, tile
            elif j == 3:
                px, py = tile, tile

            #color mapping
            icmap = img_scale[wav[j]][0]
            ivmin = img_scale[wav[j]][1]
            ivmax = img_scale[wav[j]][2]

            #set up for wavelet analysis
            wav_img = i.data
            f_img = wav_img

            #do wavelet analysis
            if use_wavelet:
                d_size = 15 * 4 + 1
                #get median filter
                n_back = medfilt(wav_img, kernel_size=d_size)
                #subtract median filter
                img_sub = wav_img - n_back

                #Use Biorthogonal Wavelet
                wavelet_name = 'bior2.6'

                #use 6 levels
                n_lev = 6
                o_wav = pywt.swt2(img_sub, wavelet_name, level=n_lev)
                #PyWavelets (>= 0.5) lists the coarsest level first, so
                #the four finest levels are the last four entries
                f_img = pywt.iswt2(o_wav[-4:], wavelet_name)
                #Add  wavelet back into image
                f_img = f_img + wav_img

                #remove negative values
                f_img[f_img < 0.] = 0.

            #do normalization
            b_img = (np.arcsinh(f_img) - ivmin) / (ivmax - ivmin)
            #work on a copy and clamp to [0, 0.99] before color mapping
            img_n = np.array(b_img)
            img_n[img_n > 0.99] = 0.99
            img_n[img_n < 0.] = 0.
            #format image with color scale
            img_n = np.uint8(icmap(img_n) * 255)

            #resize to the tile size and add it to the panel
            tile_img = Image.fromarray(img_n).resize((tile, tile))
            new_img.paste(tile_img, (px, py))

        #observed time
        obs_time = img[0].date

        #text label: observation time followed by the wavelengths
        w_text = '{0:%Y/%m/%d %H:%M:%S} '.format(obs_time)
        w_text += '/'.join(str(int(w)) + u'\u212B' for w in wav)

        #Add text of datetime to image
        draw = ImageDraw.Draw(new_img)
        draw.text((10, 10), w_text, (255, 255, 255), font=font)

        #save image
        new_img.save(outfi)
    except Exception as err:
        #best effort: say what went wrong instead of failing silently
        print('Unable to create panel image: {0}'.format(err))
Exemple #30
0
def _style_dark_inset(axis):
    """Give an inset axis a black face, white ticks/spines and a gray grid."""
    axis.set_facecolor('black')
    axis.tick_params(axis='both', colors='white')
    for side in ('top', 'bottom', 'right', 'left'):
        axis.spines[side].set_color('white')
    axis.grid(color='gray', ls='dashdot')


def format_img(i):
    """Render entry ``i`` of the module-level ``dayarray`` to a PNG frame.

    The output path is the input path with ``'raw'`` replaced by
    ``'working'`` and ``'fits'`` by ``'png'``.  The frame is only produced
    when that file does not exist yet and ``qual_check`` accepts the input.

    The image is sharpened by adding a median-subtracted, wavelet
    reconstructed copy back onto the data, drawn with a 0.25 power scale,
    and labelled with its observation time.  When the module-level ``goes``
    flag is set an X-ray flux inset is added; when ``ace`` and ``goes`` are
    both set two further insets (``Bt`` and ``Speed`` columns of
    ``aceadat``) are added.

    Args:
        i: index into ``dayarray``.

    Returns:
        None.
    """
    global goes, goesdat, sday, eday
    global aceadat, ace

    filep = dayarray[i]
    #output file
    outfi = filep.replace('raw', 'working').replace('fits', 'png')
    test = os.path.isfile(outfi)

    #check image quality
    check, img = qual_check(filep)

    #only build the frame if it is missing and the input passed the check
    #(`&` kept because the type of `check` is decided by qual_check)
    if (not test) & (check):
        print('Modifying file ' + filep)
        img = sunpy.map.Map(filep)
        fig, ax = plt.subplots(figsize=(sc * float(w0) / float(dpi),
                                        sc * float(h0) / float(dpi)))
        fig.set_dpi(dpi)
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
        ax.set_axis_off()
        #Block add J. Prchlik (2016/10/06) to give physical coordinate values
        #return extent of image
        maxx, minx, maxy, miny = img_extent(img)

        #Add wavelet transformation J. Prchlik 2018/01/17
        wav_img = img.data
        d_size = 15
        #get median filter
        n_back = medfilt(wav_img, kernel_size=d_size)
        #subtract median filter
        img_sub = wav_img - n_back

        #Use Biorthogonal Wavelet
        wavelet = 'bior2.6'

        #use 6 levels
        n_lev = 6
        o_wav = pywt.swt2(img_sub, wavelet, level=n_lev)
        #use last four because PyWavelet switched the ordering
        f_img = pywt.iswt2(o_wav[-4:], wavelet)
        #Add  wavelet back into image
        f_img = f_img + wav_img

        #remove negative values
        f_img[f_img < 0.] = 0.

        #plot the image in matplotlib with a 0.25 power scale
        ax.imshow((f_img)**0.25,
                  interpolation='none',
                  cmap=cm.sdoaia193,
                  origin='lower',
                  vmin=(15.)**0.25,
                  vmax=(3500.)**0.25,
                  extent=[minx, maxx, miny, maxy])
        ax.text(-2000,
                -1100,
                'AIA 193 - ' + img.date.strftime('%Y/%m/%d - %H:%M:%S') + 'Z',
                color='white',
                fontsize=36,
                zorder=50,
                fontweight='bold')
        if goes:
            #format string for date on xaxis
            myFmt = mdates.DateFormatter('%m/%d')

            #only use goes data upto observed time (plus 150 minutes for
            #the curve; the marker uses data strictly before the image)
            use, = np.where((goesdat['time_dt'] < img.date + dt(minutes=150))
                            & (goesdat['Long'] > 0.0))
            clos, = np.where((goesdat['time_dt'] < img.date)
                             & (goesdat['Long'] > 0.0))
            ingoes = inset_axes(
                ax, width="27%", height="20%", loc=7,
                borderpad=-27)  #hack so it is outside normal boarders
            ingoes.set_position(Bbox([[0.525, 0.51], [1.5, 1.48]]))
            #black background, white ticks/spines, gray grid
            _style_dark_inset(ingoes)

            ingoes.xaxis.set_major_formatter(myFmt)

            ingoes.set_ylim([1.E-9, 1.E-2])
            ingoes.set_xlim([sday, eday])
            #raw string: the backslashes belong to the mathtext markup
            ingoes.set_ylabel(
                r'X-ray Flux (1-8$\mathrm{\AA}$) [Watts m$^{-2}$]',
                color='white')
            ingoes.set_xlabel('Universal Time', color='white')
            #best effort: e.g. indexing [-1] fails when no sample matches
            try:
                ingoes.plot(goesdat['time_dt'][use],
                            goesdat['Long'][use],
                            color='white')
                ingoes.scatter(goesdat['time_dt'][clos][-1],
                               goesdat['Long'][clos][-1],
                               color='red',
                               s=10,
                               zorder=1000)
            except Exception:
                print('No GOES data')
            ingoes.set_yscale('log')

        #plot ace information (needs the GOES inset as its parent)
        if ((ace) & (goes)):
            use, = np.where((aceadat['time_dt'] < img.date + dt(minutes=150))
                            & (aceadat['S_1'] == 0.0) & (aceadat['S_2'] == 0)
                            & (aceadat['Speed'] > -1000.))
            clos, = np.where((aceadat['time_dt'] < img.date)
                             & (aceadat['S_1'] == 0) & (aceadat['S_2'] == 0)
                             & (aceadat['Speed'] > -1000))

            acetop = inset_axes(ingoes,
                                width='100%',
                                height='100%',
                                loc=9,
                                borderpad=-27)
            acebot = inset_axes(ingoes,
                                width='100%',
                                height='100%',
                                loc=8,
                                borderpad=-27)

            #black background, white ticks/spines, gray grid
            _style_dark_inset(acetop)
            _style_dark_inset(acebot)

            acetop.set_ylim([0., 50.])
            acebot.set_ylim([200., 1000.])

            acetop.set_xlim([sday, eday])
            acebot.set_xlim([sday, eday])

            acetop.set_xlabel('Universal Time', color='white')
            acebot.set_xlabel('Universal Time', color='white')

            #raw string: the backslash belongs to the mathtext markup
            acetop.set_ylabel(r'B$_\mathrm{T}$ [nT]', color='white')
            acebot.set_ylabel('Wind Speed [km/s]', color='white')

            #skip missing ace solar wind data
            try:
                acetop.plot(aceadat['time_dt'][use],
                            aceadat['Bt'][use],
                            color='white')
                acebot.plot(aceadat['time_dt'][use],
                            aceadat['Speed'][use],
                            color='white')

                acetop.scatter(aceadat['time_dt'][clos][-1],
                               aceadat['Bt'][clos][-1],
                               color='red',
                               s=10,
                               zorder=1000)
                acebot.scatter(aceadat['time_dt'][clos][-1],
                               aceadat['Speed'][clos][-1],
                               color='red',
                               s=10,
                               zorder=1000)
            except Exception:
                print('Missing ACE data')

            acebot.xaxis.set_major_formatter(myFmt)
            acetop.xaxis.set_major_formatter(myFmt)

        fig.savefig(outfi, edgecolor='black', facecolor='black', dpi=dpi)
        plt.clf()
        plt.close()
    return
def add_aia_image(stime, ax, tries=4):
    """Download an AIA 193 synoptic image near ``stime`` and draw it on ``ax``.

    Files for ``stime`` through ``stime + tries`` minutes (one per minute)
    are requested with ``gsf.download``; the first one that exists on disk
    is used.  The image is sharpened by adding a median-subtracted, wavelet
    reconstructed copy back onto the data, colored with ``cm.sdoaia193`` on
    a 0.25 power scale, and made transparent inside 0.98 solar radii with a
    linear alpha ramp up to 1.02 solar radii.

    Args:
        stime: start time; must support ``+ datetime.timedelta`` and
            ``str.format`` date codes.
        ax: matplotlib axes to draw on.
        tries: number of one-minute steps after ``stime`` to download and
            to probe for a file (``tries + 1`` candidates are checked).

    Returns:
        The same ``ax`` that was passed in.

    Raises:
        IOError: if none of the candidate files exists after the download.
    """
    file_fmt = '{0:%Y/%m/%d/H%H00/AIA%Y%m%d_%H%M_}'
    wave = [193]

    gsf.download(stime,
                 stime + datetime.timedelta(minutes=tries),
                 datetime.timedelta(minutes=1),
                 '',
                 nproc=1,
                 syn_arch='http://jsoc.stanford.edu/data/aia/synoptic/',
                 f_dir=file_fmt,
                 d_wav=wave)

    #Decide which AIA file to overplot: first candidate that exists
    filep = None
    for run in range(tries + 1):
        testfile = file_fmt.format(stime + datetime.timedelta(
            minutes=run)) + '{0:04d}.fits'.format(wave[0])
        if os.path.isfile(testfile):
            filep = testfile
            break

    if filep is None:
        raise IOError('No AIA {0} file found within {1} minute(s) of {2}'
                      .format(wave[0], tries, stime))

    img = sunpy.map.Map(filep)
    #Block add J. Prchlik (2016/10/06) to give physical coordinate values
    #return extent of image
    maxx, minx, maxy, miny = img_extent(img)

    #Add wavelet transformation J. Prchlik 2018/01/17
    wav_img = img.data
    d_size = 15
    #get median filter
    n_back = medfilt(wav_img, kernel_size=d_size)
    #subtract median filter
    img_sub = wav_img - n_back

    #Use Biorthogonal Wavelet
    wavelet = 'bior2.6'

    #use 6 levels
    n_lev = 6
    o_wav = pywt.swt2(img_sub, wavelet, level=n_lev)
    #PyWavelets (>= 0.5) lists the coarsest level first, so the four finest
    #levels that iswt2 can invert on their own are the last four entries
    f_img = pywt.iswt2(o_wav[-4:], wavelet)
    #Add  wavelet back into image
    f_img = f_img + wav_img

    #remove negative values
    f_img[f_img < 0.] = 0.

    #map the 0.25 power scaled data to RGBA (alpha only works with png)
    colors = Normalize((15.)**0.25, (3500.)**0.25, clip=True)
    img_193 = cm.sdoaia193(colors((f_img)**0.25))

    #get radius values and convert to arcsec
    #NOTE(review): after the transpose "x" runs along array axis 0; this
    #only gives the right radius for square images with matching x/y
    #extents -- confirm if non-square data is ever used
    mesh_x2, mesh_y2 = np.meshgrid(np.arange(img.data.shape[0]),
                                   np.arange(img.data.shape[1]))
    mesh_x2 = mesh_x2.T * (maxx - minx) / f_img.shape[0] + minx
    mesh_y2 = mesh_y2.T * (maxy - miny) / f_img.shape[1] + miny

    #radius of every pixel in units of the solar radius
    rsun = sunpy.sun.solar_semidiameter_angular_size(
        t=img.meta['date-obs'][:-1].replace('T', ' ')).value
    r2 = np.sqrt(mesh_x2**2 + mesh_y2**2) / rsun

    rmin = .98
    rfad = 1.02
    rep2 = (r2 < rmin)
    rep3 = ((r2 > rmin) & (r2 < rfad))
    #set alpha values: transparent inside rmin, linear ramp up to rfad
    img_193[..., 3][rep2] = 0
    img_193[..., 3][rep3] = (r2[rep3] - rmin) / (rfad - rmin)

    #plot the image in matplotlib
    ax.imshow(img_193,
              interpolation='none',
              cmap=cm.sdoaia193,
              origin='lower',
              vmin=(15.)**0.25,
              vmax=(3500.)**0.25,
              extent=[minx, maxx, miny, maxy],
              zorder=0)
    return ax
Exemple #32
0
 def adjoint(self, point: array) -> array:
     """Return ``pywt.iswt2`` of the coefficients built from ``point``.

     ``point`` is converted with ``self.cube2coeffs`` and inverted using the
     wavelet ``self.wlt`` and normalization flag ``self.norm``.
     """
     coeffs = self.cube2coeffs(point)
     return pywt.iswt2(coeffs, self.wlt, norm=self.norm)
Exemple #33
0
 def time_iswt2(self, n, wavelet):
     """Run ``pywt.iswt2`` on ``self.data``; the result is discarded.

     ``n`` is not used in the body (presumably a benchmark parameter that
     shapes ``self.data`` elsewhere -- confirm against the setup code).
     """
     coeffs = self.data
     pywt.iswt2(coeffs, wavelet=wavelet)
def udDTCWTLvl1Inv(coeffsAll, lvl1Filters):
    """Invert the level-1 stage for the four trees AA, AB, BA and BB.

    The six complex subbands in ``coeffsAll[0]`` are recombined into real
    (LH, HL, HH) detail triples for each tree, and every tree is inverted
    with ``pywt.iswt2`` (the undecimated "a trous" transform) using custom
    wavelets built from ``lvl1Filters``.

    Args:
        coeffsAll: ``coeffsAll[0]`` unpacks into exactly six subband arrays;
            ``coeffsAll[1][0..3]`` are the low-pass arrays of the trees in
            AA, AB, BA, BB order.
        lvl1Filters: indexable whose entries 0..3 are 2-D arrays; column 0
            of each supplies the filter taps.

    Returns:
        Tuple ``(AA, AB, BA, BB)`` of the reconstructed arrays.
    """
    scale = np.sqrt(0.5)

    band0, band1, band2, band3, band4, band5 = coeffsAll[0]
    re0, re1, re2 = np.real(band0), np.real(band1), np.real(band2)
    re3, re4, re5 = np.real(band3), np.real(band4), np.real(band5)
    im0, im1, im2 = np.imag(band0), np.imag(band1), np.imag(band2)
    im3, im4, im5 = np.imag(band3), np.imag(band4), np.imag(band5)

    # Detail triples are ordered (LH, HL, HH).  Sums/differences of the
    # real parts feed trees AA/BB, those of the imaginary parts AB/BA.
    details_aa = (scale * (re0 + re5), scale * (re2 + re3),
                  scale * (re4 + re1))
    details_bb = (scale * (re0 - re5), scale * (re2 - re3),
                  scale * (re4 - re1))
    details_ab = (scale * (im5 + im0), scale * (im3 + im2),
                  scale * (im1 + im4))
    details_ba = (scale * (im5 - im0), scale * (im3 - im2),
                  scale * (im1 - im4))

    def _column_taps(index, pad, shift):
        # Column 0 of filter `index`, zero padded by `pad` on both sides,
        # circularly shifted by `shift`, as a plain list.
        column = 1.0 * lvl1Filters[index][:, 0]
        padded = np.pad(column, [pad, pad], mode='constant')
        return np.roll(padded, shift).tolist()

    # pad the lvl 0 filters
    h0o = _column_taps(0, 3, 1)
    g0o = _column_taps(1, 0, 0)
    h1o = _column_taps(2, 0, 1)
    g1o = _column_taps(3, 3, 0)

    # Filter bank order passed to pywt.Wavelet: h0o, h1o, g0o, g1o.
    # Wavelet B is wavelet A with every filter rolled by one sample.
    bank = (h0o, h1o, g0o, g1o)
    lvl1WvtA = pywt.Wavelet('lvl1',
                            filter_bank=[np.roll(taps, 0) for taps in bank])
    lvl1WvtB = pywt.Wavelet('lvl1',
                            filter_bank=[np.roll(taps, 1) for taps in bank])

    def _invert(approx, details, wavelets):
        # Single-level inverse stationary transform of one tree.
        return pywt.iswt2(((approx, details), ), wavelet=wavelets)

    # Undecimated "a trous" transform trees AA, AB, BA, BB
    tree_aa = _invert(coeffsAll[1][0], details_aa, (lvl1WvtA, lvl1WvtA))
    tree_ab = _invert(coeffsAll[1][1], details_ab, (lvl1WvtA, lvl1WvtB))
    tree_ba = _invert(coeffsAll[1][2], details_ba, (lvl1WvtB, lvl1WvtA))
    tree_bb = _invert(coeffsAll[1][3], details_bb, (lvl1WvtB, lvl1WvtB))

    return (tree_aa, tree_ab, tree_ba, tree_bb)
def udDTCWTLvl2Inv(coeffsAll, lvl2Filters):
    """Invert the level >= 2 stages for the four trees AA, AB, BA and BB.

    Starting from the low-pass arrays in ``coeffsAll[1]``, the levels
    ``len(coeffsAll[0]) - 1`` down to ``1`` are undone one at a time (index
    0 belongs to the level-1 part and is not touched here).  At each level
    the six complex subbands are recombined into real (LH, HL, HH) triples,
    each tree is inverted with a single-level ``pywt.iswt2`` and the result
    is circularly shifted by -1 along both axes.

    Args:
        coeffsAll: ``coeffsAll[0][iLev]`` unpacks into exactly six subband
            arrays; ``coeffsAll[1][0..3]`` are the low-pass arrays of the
            trees in AA, AB, BA, BB order.
        lvl2Filters: indexable whose entries 0..7 are 2-D arrays; column 0
            of each supplies the taps, in the order h0a, h0b, g0a, g0b,
            h1a, h1b, g1a, g1b.

    Returns:
        Tuple ``(AA, AB, BA, BB)`` of the reconstructed low-pass arrays
        (the inputs unchanged when there is no level above index 0).
    """
    tree_aa = coeffsAll[1][0]
    tree_ab = coeffsAll[1][1]
    tree_ba = coeffsAll[1][2]
    tree_bb = coeffsAll[1][3]

    scale = np.sqrt(0.5)

    def _invert(approx, details, wavelets):
        # Single-level inverse, then shift by -1 along rows and columns.
        rebuilt = pywt.iswt2(((approx, details), ), wavelet=wavelets)
        return np.roll(np.roll(rebuilt, -1, axis=0), -1, axis=1)

    # Index zero is the lvl1 part of the transform
    # So iterate down to index 1 and no further.
    for iLev in reversed(range(1, len(coeffsAll[0]))):

        band0, band1, band2, band3, band4, band5 = coeffsAll[0][iLev]
        re0, re1, re2 = np.real(band0), np.real(band1), np.real(band2)
        re3, re4, re5 = np.real(band3), np.real(band4), np.real(band5)
        im0, im1, im2 = np.imag(band0), np.imag(band1), np.imag(band2)
        im3, im4, im5 = np.imag(band3), np.imag(band4), np.imag(band5)

        # Detail triples are ordered (LH, HL, HH).
        details_aa = (scale * (re0 + re5), scale * (re2 + re3),
                      scale * (re4 + re1))
        details_bb = (scale * (re0 - re5), scale * (re2 - re3),
                      scale * (re4 - re1))
        details_ab = (scale * (im5 + im0), scale * (im3 + im2),
                      scale * (im1 + im4))
        details_ba = (scale * (im5 - im0), scale * (im3 - im2),
                      scale * (im1 - im4))

        # phi (0) taps come from entries 0..3, psi (1) taps from 4..7
        h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = [
            lvl2Filters[k][:, 0] for k in range(8)
        ]
        # Insert extra zeros between filter coefficients once per level to
        # account for the lack of decimation, a la "a trous".  The
        # pywavelets swt2 already does this but this accounts for previous
        # iterations.
        for _ in range(iLev):
            h0a, h0b, h1a, h1b, g0a, g0b, g1a, g1b = [
                upscaleFilter(taps)
                for taps in (h0a, h0b, h1a, h1b, g0a, g0b, g1a, g1b)
            ]

        # Offset the in-code notes call the half-sample delay.
        hsd = 2**iLev

        # One bank delays the g filters, the other delays the h filters.
        bank_from_a = [
            np.roll(h0a, 0),
            np.roll(h1a, 0),
            np.roll(g0a, hsd),
            np.roll(g1a, hsd)
        ]
        bank_from_b = [
            np.roll(h0b, hsd),
            np.roll(h1b, hsd),
            np.roll(g0b, 0),
            np.roll(g1b, 0)
        ]
        # On odd levels the two banks trade places between wavelets A and B.
        if iLev % 2 == 0:
            bank_for_a, bank_for_b = bank_from_a, bank_from_b
        else:
            bank_for_a, bank_for_b = bank_from_b, bank_from_a

        lvl2WvtA = pywt.Wavelet('lvl2a', filter_bank=bank_for_a)
        lvl2WvtB = pywt.Wavelet('lvl2b', filter_bank=bank_for_b)

        # Undecimated "a trous" transform trees AA, AB, BA, BB
        tree_aa = _invert(tree_aa, details_aa, (lvl2WvtA, lvl2WvtA))
        tree_ab = _invert(tree_ab, details_ab, (lvl2WvtA, lvl2WvtB))
        tree_ba = _invert(tree_ba, details_ba, (lvl2WvtB, lvl2WvtA))
        tree_bb = _invert(tree_bb, details_bb, (lvl2WvtB, lvl2WvtB))

    return (tree_aa, tree_ab, tree_ba, tree_bb)