def test_equal_oddshape(size):
    """Compare DWTForward/DWTInverse with PyWavelets for a given input size.

    Two independently constructed transform pairs with identical settings are
    run on the same random input. Both reconstructions and all coefficients
    are checked against ``pywt.wavedec2`` / ``pywt.waverec2``.

    Args:
        size: spatial shape appended to the (5, 4) batch/channel dims of the
            random input (presumably an (H, W) pair, possibly odd, from the
            test parametrisation - confirm against the decorator).
    """
    wave = 'db3'
    J = 3
    mode = 'symmetric'
    x = torch.randn(5, 4, *size).to(dev)
    dwt1 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt1 = DWTInverse(wave=wave, mode=mode).to(dev)
    dwt2 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt2 = DWTInverse(wave=wave, mode=mode).to(dev)

    yl1, yh1 = dwt1(x)
    x1 = iwt1((yl1, yh1))
    yl2, yh2 = dwt2(x)
    x2 = iwt2((yl2, yh2))

    # Test it is the same as doing the PyWavelets wavedec
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    X2 = pywt.waverec2(coeffs, wave, mode=mode)
    # .cpu() is required: numpy cannot convert a tensor that lives on a GPU.
    np.testing.assert_array_almost_equal(X2, x1.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(X2, x2.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl1.cpu(), coeffs[0], decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl2.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            # pywt lists levels coarsest-first, yh[0] is the finest level.
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh1[j][:, :, b].cpu(), decimal=PREC_FLT)
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh2[j][:, :, b].cpu(), decimal=PREC_FLT)
Example #2
0
    def __init__(self, outC1, outC2, outC3):
        """Build three identical DWT + two-conv branches.

        Branch ``i`` (1, 2, 3) with output width ``out = outC{i}`` owns:
        ``dwt{i}`` (one-level Haar DWT, symmetric mode),
        ``conv{i}_1`` (4 -> out // 2, 3x3, padding 1) + ``bn{i}_1`` + ``relu{i}_1``,
        ``conv{i}_2`` (out // 2 -> out, 3x3, padding 1) + ``bn{i}_2`` + ``relu{i}_2``.

        The fixed 4 input channels presumably correspond to the LL band plus
        three highpass bands of a single-channel input (confirm in forward).
        """
        super(Dwtconv, self).__init__()

        for idx, out_ch in enumerate((outC1, outC2, outC3), start=1):
            mid_ch = out_ch // 2
            # Attributes are created in the same order and under the same names
            # as explicit per-branch assignments, so submodule registration
            # order and parameter initialisation order are unchanged.
            setattr(self, f'dwt{idx}', DWTForward(J=1, wave='haar', mode='symmetric'))
            setattr(self, f'conv{idx}_1', nn.Conv2d(4, mid_ch, kernel_size=3, padding=1))
            setattr(self, f'bn{idx}_1', nn.BatchNorm2d(mid_ch))
            setattr(self, f'relu{idx}_1', nn.ReLU())
            setattr(self, f'conv{idx}_2', nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1))
            setattr(self, f'bn{idx}_2', nn.BatchNorm2d(out_ch))
            setattr(self, f'relu{idx}_2', nn.ReLU())
Example #3
0
    def __init__(self,
                 J=3,
                 wave='haar',
                 window_size=15,
                 stride=8,
                 verbose=False):
        '''
        Configure the wavelet transform and the sliding window.

        :param J: int, number of wavelet decomposition levels
        :param wave: wavelet to use; see
                     https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
                     for the list of available wavelets
        :param window_size: int, size of the window over which each local
                            CW-SSIM index is calculated
        :param stride: int, how far the window moves in a single step
        :param verbose: stored as ``self.verbose``

        A local index is first computed inside the window. The window then
        strides across the whole image to produce a list of local indices.

        Formula (9) of reference [2] is used to compute the local index.

        The index returned by self.forward is the mean of those local indices.
        '''
        super().__init__()
        self.J, self.window_size, self.stride = J, window_size, stride
        self.dwt = DWTForward(J=J, wave=wave)
        self.verbose = verbose
def test_equal_double(wave, J, mode):
    """Double-precision round trip and PyWavelets comparison.

    The input and both transforms are created inside ``set_double_precision``
    (the test asserts the input is float64). The test checks two things:
    DWTInverse(DWTForward(x)) reproduces ``x``, and all coefficients match
    ``pywt.wavedec2``. Every comparison uses the double-precision tolerance
    ``PREC_DBL``.
    """
    with set_double_precision():
        x = torch.randn(5, 4, 64, 64).to(dev)
        assert x.dtype == torch.float64
        dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
        iwt = DWTInverse(wave=wave, mode=mode).to(dev)

    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked
    np.testing.assert_array_almost_equal(x.cpu(),
                                         x2.detach().cpu(),
                                         decimal=PREC_DBL)
    coeffs = pywt.wavedec2(x.cpu().numpy(),
                           wave,
                           level=J,
                           axes=(-2, -1),
                           mode=mode)
    # Use the same double-precision tolerance as the highpass checks below
    # (this was previously a hard-coded decimal=7).
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_DBL)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(coeffs[J - j][b],
                                                 yh[j][:, :, b].cpu(),
                                                 decimal=PREC_DBL)
def test_commutativity(wave, J, j):
    """Check that the inverse DWT is additive over the level-``j`` subbands.

    Despite the name, this verifies linearity. Reconstructing with two
    level-``j`` highpass subbands present must equal the sum of the two
    single-subband reconstructions. This is checked for subbands 0+1 and
    for subbands 1+2.
    """
    C = 3
    Y = torch.randn(4, C, 128, 128, requires_grad=True, device=dev)
    dwt = DWTForward(J=J, wave=wave).to(dev)
    iwt = DWTInverse(wave=wave).to(dev)

    coeffs = dwt(Y)
    coeffs_zero = dwt(torch.zeros_like(Y))
    src = coeffs[1][j]
    # ``dst`` aliases the level-j highpass tensor inside coeffs_zero, so the
    # in-place slice writes below change what iwt(coeffs_zero) reconstructs.
    dst = coeffs_zero[1][j]

    # Only level j LH (subband 0) nonzero
    dst[:, :, 0] = src[:, :, 0]
    ya = iwt(coeffs_zero)
    # Level j LH and HL (subbands 0 and 1) nonzero
    dst[:, :, 1] = src[:, :, 1]
    yab = iwt(coeffs_zero)
    # Zero LH again: only HL (subband 1) nonzero
    dst[:, :, 0] = 0
    yb = iwt(coeffs_zero)
    # Level j HL and HH (subbands 1 and 2) nonzero
    dst[:, :, 2] = src[:, :, 2]
    ybc = iwt(coeffs_zero)
    # Zero HL again: only HH (subband 2) nonzero
    dst[:, :, 1] = 0
    yc = iwt(coeffs_zero)

    np.testing.assert_array_almost_equal(
        (ya + yb).detach().cpu(), yab.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        (yc + yb).detach().cpu(), ybc.detach().cpu(), decimal=PREC_FLT)
Example #6
0
def get_wavelets(x, device):
    """Tile a 3-level db4 DWT (zero mode) of ``x`` into one image-like tensor.

    The canvas has height/width ``nextPowerOf2`` of twice the finest
    subband's size. Levels come from ``Yh`` finest-first. For level ``i``,
    with ``r0 = rows // 2**(i+1)`` and ``c0 = cols // 2**(i+1)``, the three
    subbands are written at offsets:
    - subband 0 at (0, c0),
    - subband 1 at (r0, 0),
    - subband 2 at (r0, c0).
    The lowpass ``Yl`` is written last into the top-left corner.

    Args:
        x: batch of shape (N, C, H, W) (assumed from the indexing below).
        device: device for the transform and for the returned tensor.

    Returns:
        Tensor of shape (N, C, rows, cols).
    """
    transform = DWTForward(J=3, mode='zero', wave='db4').to(
        device)  # Accepts all wave types available to PyWavelets
    Yl, Yh = transform(x)

    n, c = x.shape[0], x.shape[1]
    rows = nextPowerOf2(Yh[0].shape[-2] * 2)
    cols = nextPowerOf2(Yh[0].shape[-1] * 2)
    canvas = torch.zeros(n, c, rows, cols).to(device)
    for level, band in enumerate(Yh):
        r0 = rows // 2**(level + 1)
        c0 = cols // 2**(level + 1)
        bh, bw = band.shape[-2], band.shape[-1]
        # (top, left) of the block for subbands 0, 1, 2 of this level
        offsets = ((0, c0), (r0, 0), (r0, c0))
        for sub, (top, left) in enumerate(offsets):
            canvas[:, :, top:top + bh, left:left + bw] = band[:, :, sub]

    canvas[:, :, :Yl.shape[-2], :Yl.shape[-1]] = Yl  # LL coefficients
    return canvas
Example #7
0
def test_fwd(mode):
    """Gradcheck the 2-D analysis filter-bank autograd Function ``AFB2D``.

    The input and a 2-level DWTForward (whose row/column filters feed
    AFB2D) are created under ``set_double_precision``, as gradcheck needs.
    ``mode`` is the padding mode passed through to AFB2D.
    """
    with set_double_precision():
        x = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        xfm = DWTForward(J=2).to(dev)

    # Named ``inputs`` so the builtin ``input`` is not shadowed.
    inputs = (x, xfm.h0_row, xfm.h1_row, xfm.h0_col, xfm.h1_col, mode)
    gradcheck(AFB2D.apply, inputs, eps=EPS, atol=ATOL)
Example #8
0
    def __init__(self, opt):
        """Collect HR/LR/RealLR/weight image paths and create a Haar DWT.

        Args:
            opt: option dict. Reads 'subset_file', 'phase', 'dataroot_HR',
                'dataroot_LR', 'dataroot_RealLR', 'dataroot_weights' and
                'data_type'.

        Raises:
            NotImplementedError: a training subset file is combined with
                'dataroot_LR'.
            AssertionError: no HR paths were found.
        """
        super(LRHR_wavelet_Mixunpair_Dataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.paths_RealLR = None
        # Previously only assigned in the non-subset branch, leaving the
        # attribute undefined when a subset file was used.
        self.paths_weights = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None

        # read image list from subset list txt
        if opt['subset_file'] is not None and opt['phase'] == 'train':
            with open(opt['subset_file']) as f:
                self.paths_HR = sorted(
                    os.path.join(opt['dataroot_HR'], line.rstrip('\n')) for line in f)
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
        else:  # read image list from lmdb or image files
            self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
            # NOTE(review): LR_env is reassigned by each of the next three
            # calls and ends up holding the env returned for
            # 'dataroot_weights'. Kept as-is; confirm which env the readers of
            # LR / RealLR data expect (matters only for lmdb data).
            self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
            self.LR_env, self.paths_RealLR = util.get_image_paths(opt['data_type'], opt['dataroot_RealLR'])
            self.LR_env, self.paths_weights = util.get_image_paths(opt['data_type'], opt['dataroot_weights'])

        assert self.paths_HR, 'Error: HR path is empty.'

        self.random_scale_list = [1]
        self.DWT2 = DWTForward(J=1, wave='haar', mode='zero')
def test_gradients_inv(wave, J, mode):
    """ Gradient of inverse function should be forward function with filters
    swapped """
    w = pywt.Wavelet(wave)
    dwt = DWTForward(J=J, wave=(w.dec_lo, w.dec_hi), mode=mode).to(dev)
    iwt = DWTInverse(wave=(w.dec_lo[::-1], w.dec_hi[::-1]), mode=mode).to(dev)

    # A dummy forward pass only to learn the pyramid's coefficient shapes.
    lo_ref, hi_ref = dwt(torch.zeros(5, 6, 128, 128).to(dev))
    yl = torch.randn(*lo_ref.shape, requires_grad=True, device=dev)
    yh = [torch.randn(*hi_ref[level].shape, requires_grad=True, device=dev)
          for level in range(J)]
    y = iwt((yl, yh))

    # Backpropagate a random upstream gradient through the inverse ...
    upstream = torch.randn(*y.shape, device=dev)
    y.backward(upstream, retain_graph=True)
    # ... and compare against the forward transform of that gradient.
    expected_lo, expected_hi = dwt(upstream)

    np.testing.assert_array_almost_equal(yl.grad.detach().cpu(),
                                         expected_lo.cpu(),
                                         decimal=PREC_FLT)
    for level in range(J):
        np.testing.assert_array_almost_equal(yh[level].grad.detach().cpu(),
                                             expected_hi[level].cpu(),
                                             decimal=PREC_FLT)
Example #10
0
    def __init__(self,
                 C,
                 F,
                 lp_size=3,
                 bp_sizes=(1, ),
                 q=1.0,
                 wave='db2',
                 mode='zero',
                 xfm=True,
                 ifm=True):
        """Chain an optional DWT, an optional shrink, a gain layer and an optional inverse DWT.

        Args:
            C, F, lp_size, bp_sizes: forwarded to ``WaveGainLayer``; the number
                of decomposition levels ``J`` is ``len(bp_sizes)``.
            q: only its sign is used here. ``q < 0`` selects ``ReLUWaveCoeffs``
                as ``self.shrink``, otherwise shrink is the identity.
            wave, mode: wavelet and padding mode for both transforms.
            xfm, ifm: when False, the forward / inverse transform is replaced
                by an identity function.
        """
        super().__init__()
        self.C, self.F = C, F
        self.J = len(bp_sizes)

        self.XFM = (DWTForward(J=self.J, mode=mode, wave=wave)
                    if xfm else (lambda x: x))
        self.shrink = ReLUWaveCoeffs() if q < 0 else (lambda x: x)
        self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes)
        self.IFM = (DWTInverse(mode=mode, wave=wave)
                    if ifm else (lambda x: x))
def test_gradients_fwd(wave, J, mode):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, 128, 128).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)

    w = pywt.Wavelet(wave)
    dwt = DWTForward(J=J, wave=(w.dec_lo, w.dec_hi), mode=mode).to(dev)
    iwt = DWTInverse(wave=(w.dec_lo[::-1], w.dec_hi[::-1]), mode=mode).to(dev)

    yl, yh = dwt(imt)
    zeros = [torch.zeros_like(yh[level]) for level in range(J)]

    # Lowpass: the input gradient must equal the inverse of (grad, 0, ..., 0).
    lo_grad = torch.randn(*yl.shape, device=dev)
    yl.backward(lo_grad, retain_graph=True)
    expected = iwt((lo_grad, zeros))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), expected.cpu(),
                                         decimal=PREC_FLT)

    # Bandpass: one level at a time, all other coefficients zero.
    for j, band in enumerate(yh):
        imt.grad.zero_()
        band_grad = torch.randn(*band.shape, device=dev)
        band.backward(band_grad, retain_graph=True)
        highs = list(zeros)
        highs[j] = band_grad
        expected = iwt((torch.zeros_like(yl), highs))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), expected.cpu(),
                                             decimal=PREC_FLT)
Example #12
0
 def __init__(self, J=3, wave='db4', device='cpu', dtype=torch.float64):
     """Build periodized forward/inverse DWTs while ``dtype`` is the default dtype.

     The process-wide default dtype is switched to ``dtype`` only while the
     transforms are constructed (presumably so their filter buffers are
     created in that dtype). It is always restored afterwards, even if
     construction raises.

     Args:
         J: number of decomposition levels of the forward transform.
         wave: wavelet name for both transforms.
         device: target passed to ``.to()`` for both transforms.
         dtype: temporary default dtype used during construction.
     """
     super().__init__()
     _dtype = torch.get_default_dtype()
     torch.set_default_dtype(dtype)
     try:
         self.dwt = DWTForward(J=J, wave=wave, mode='per').to(device)
         self.idwt = DWTInverse(wave=wave, mode='per').to(device)
     finally:
         # Do not leak the temporary default dtype to the rest of the process.
         torch.set_default_dtype(_dtype)
     self.norm_bound = 1.
def test_ok(wave, J, mode):
    """Forward and inverse DWT outputs must all be contiguous tensors.

    Non-contiguous outputs can cause data errors in downstream code, so the
    lowpass, every highpass level and the reconstruction are checked.
    """
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))
    assert yl.is_contiguous()
    assert all(yh[level].is_contiguous() for level in range(J))
    assert x2.is_contiguous()
Example #14
0
 def wavelets(self, x):
     """Stack a one-level Haar DWT (symmetric mode) of ``x`` by input channel.

     For every input channel ``c``, output channels ``4c .. 4c+3`` hold the
     lowpass band followed by the three highpass subbands of ``Yh[0]`` in
     their native order. This is the layout the previous implementation
     produced for 3-channel inputs. It now also works for any channel count
     and for odd spatial sizes (where the DWT output is not exactly H/2).

     Args:
         x: tensor of shape (N, C, H, W); the transform is placed on
             ``x.device`` (previously always ``.cuda()``).

     Returns:
         Tensor of shape (N, 4*C, H', W'), where (H', W') is the DWT
         coefficient size.
     """
     xfm = DWTForward(J=1, wave='haar', mode='symmetric').to(x.device)
     Yl, Yh = xfm(x)
     # Yl: (N, C, H', W'); Yh[0]: (N, C, 3, H', W') -> bands: (N, C, 4, H', W')
     bands = torch.cat((Yl.unsqueeze(2), Yh[0]), dim=2)
     n, c, _, h, w = bands.shape
     # Merging (C, 4) puts the 4 bands of channel c at indices 4c..4c+3.
     return bands.reshape(n, c * 4, h, w)
Example #15
0
 def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
     """Initialise wavelet-domain gains from a Xavier-initialised k x k kernel.

     A random (F, C, k, k) kernel is passed through a J-level DWTForward.
     Its lowpass coefficients and the coarsest highpass level ``yh[J - 1]``
     become the parameters ``gl`` and ``gh``. ``self.pad`` is (1, 2, 1, 2)
     for k == 4 and (3, 4, 3, 4) for k == 8, reversed when ``right`` is False
     (presumably an F.pad-style (left, right, top, bottom) tuple - confirm).
     ``stride`` is not used by this constructor.

     Supported (k, J): (4, 1), (8, 1), (8, 2), (8, 3).

     Raises:
         NotImplementedError: for any other (k, J) combination.
     """
     super().__init__()
     self.wd = wd
     self.wd1 = wd if wd1 is None else wd1
     self.C = C
     self.F = F
     kernel = torch.zeros(F, C, k, k)
     torch.nn.init.xavier_uniform_(kernel)
     xfm = DWTForward(J=J)
     self.ifm = DWTInverse()
     yl, yh = xfm(kernel)
     self.J = J
     if (k, J) not in ((4, 1), (8, 1), (8, 2), (8, 3)):
         raise NotImplementedError
     coarsest = yh[J - 1]
     self.gl = nn.Parameter(torch.zeros_like(yl))
     self.gh = nn.Parameter(torch.zeros_like(coarsest))
     self.gl.data = yl.data
     self.gh.data = coarsest.data
     right_pad = (1, 2, 1, 2) if k == 4 else (3, 4, 3, 4)
     self.pad = right_pad if right else right_pad[::-1]
Example #16
0
 def __init__(self, in_planes, lifting_size, kernel_size, no_bottleneck,
              share_weights, simple_lifting, regu_details, regu_approx):
     """Set up a one-level 'db1' DWT on CUDA and a bottleneck block.

     Only ``in_planes``, ``no_bottleneck`` and ``share_weights`` are used by
     this constructor; the other arguments are accepted but ignored here.
     """
     super(Haar, self).__init__()
     from pytorch_wavelets import DWTForward
     self.wavelet = DWTForward(J=1, mode='zero', wave='db1').cuda()
     self.share_weights = share_weights
     if no_bottleneck:
         # Keep the channel count: per the original note the block still
         # applies BN and ReLU, but no channel-changing conv is wanted.
         block_in, block_out = in_planes * 1, in_planes * 1
     else:
         block_in, block_out = in_planes * 4, in_planes * 2
     # Attribute name "bootleneck" (sic) is kept: other code and checkpoints may rely on it.
     self.bootleneck = BottleneckBlock(block_in, block_out)
Example #17
0
def _max_level_dwt(size, wave):
    """Return a CUDA symmetric-mode DWTForward using pywt's max level for a 2-D ``size``."""
    # The level count is derived from a 'db1' packet regardless of ``wave``.
    wp_fake = pywt.WaveletPacket2D(data=np.zeros(size),
                                   wavelet='db1',
                                   mode='symmetric')
    return DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()


def init_dwt(resume=None, shape=None, wave=None, colors=None):
    """Create DWT transforms and initial coefficients (random, image or snapshot).

    Args:
        resume: None for random coefficients shaped like ``DWTForward`` of a
            zero tensor of ``shape``; a path to a jpg/png/tif/bmp image
            (decomposed with ``img2dwt``); a path to a ``torch.load``-able
            list of tensors; or an iterable of tensors.
        shape: full input shape; ``shape[2:]`` sizes the forward transform.
        wave: wavelet name for both transforms.
        colors: forwarded to ``img2dwt`` when loading an image.

    Returns:
        ``(Ys, xfm, ifm, size)``. ``Ys`` is the list of coefficient tensors on
        CUDA. ``size`` is the image's (H, W) when an image was loaded,
        otherwise None.

    Raises:
        SystemExit: ``resume`` is a string that is not an existing file.
    """
    size = None
    xfm = _max_level_dwt(shape[2:], wave)
    # A DTCWTForward/DTCWTInverse pair (biort='near_sym_b', qshift='qshift_b')
    # was an alternative here, at about 4x the parameters.
    ifm = DWTInverse(wave=wave,
                     mode='symmetric').cuda()  # symmetric zero periodization
    if resume is None:  # random init
        Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
        Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
    elif isinstance(resume, str):
        if not os.path.isfile(resume):
            print(' Snapshot not found:', resume)
            # ``exit()`` is a site-module helper (absent under ``python -S``
            # or in frozen apps) and reported success; exit with status 1.
            raise SystemExit(1)
        if os.path.splitext(resume)[1].lower()[1:] in ('jpg', 'png', 'tif', 'bmp'):
            img_in = imread(resume)
            Ys = img2dwt(img_in, wave=wave, colors=colors)
            print(' loaded image', resume, img_in.shape, 'level',
                  len(Ys) - 1)
            size = img_in.shape[:2]
            # Re-size the forward transform to the loaded image.
            xfm = _max_level_dwt(size, wave)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects; only resume
            # from trusted snapshot files.
            Ys = [y.detach().cuda() for y in torch.load(resume)]
    else:
        Ys = [y.cuda() for y in resume]
    return Ys, xfm, ifm, size
Example #18
0
    def filter_wavelet(self, x, norm=True):
        """Split ``x`` into one-level Haar (reflect mode) low- and highpass parts.

        Args:
            x: image batch, assumed (N, C, H, W) - TODO confirm with callers.
            norm: if True, scale LL by 0.5 and map each highpass band ``b``
                to ``0.5 * b + 0.5``.

        Returns:
            ``(LL, high)``. ``LL`` is (N, C, H', W'). ``high`` depends on
            ``self.cs``:
            - 'sum': the mean of the three highpass subbands, (N, C, H', W');
            - 'cat': their channel concatenation, (N, 3C, H', W').
            Previously the highpass bands were taken from batch element 0
            only, and 'cat' concatenated along the height axis.

        Raises:
            ValueError: if ``self.cs`` is neither 'sum' nor 'cat'.
        """
        # Built per call; follow the input's device so GPU inputs work.
        DWT2 = DWTForward(J=1, wave='haar', mode='reflect').to(x.device)
        LL, Hc = DWT2(x)
        # Hc[0] is (N, C, 3, H', W'); keep the batch dimension for every band.
        LH, HL, HH = Hc[0][:, :, 0], Hc[0][:, :, 1], Hc[0][:, :, 2]

        if norm:
            LL = LL * 0.5
            LH, HL, HH = LH * 0.5 + 0.5, HL * 0.5 + 0.5, HH * 0.5 + 0.5
        if self.cs == 'sum':
            return LL, (LH + HL + HH) / 3.
        if self.cs == 'cat':
            return LL, torch.cat((LH, HL, HH), 1)
        raise ValueError("unknown cs mode {!r}; expected 'sum' or 'cat'".format(self.cs))
Example #19
0
    def __init__(
        self,
        num_levels: int = 1,
        wave: str = 'db2',
        mode: str = 'periodization',
        device: Optional[str] = None,
    ) -> None:
        """Store the DWT settings and build a matching forward/inverse pair.

        Args:
            num_levels: decomposition levels ``J`` of the forward transform.
            wave: wavelet name shared by both transforms.
            mode: signal-extension mode shared by both transforms.
            device: passed to ``.to()`` for both transforms.
        """
        shared = {'wave': wave, 'mode': mode}
        self.num_levels = num_levels
        self.wave = wave
        self.mode = mode

        self.xfm = DWTForward(J=num_levels, **shared).to(device)
        self.ifm = DWTInverse(**shared).to(device)
Example #20
0
 def __init__(self, in_ch, out_ch=None, filter_sz=3, num_conv=3):
     """One-level Haar DWT (zero mode) followed by a stack of conv + ReLU layers.

     Args:
         in_ch: channels of the input image.
         out_ch: channels produced by every conv; defaults to ``4 * in_ch``.
         filter_sz: convolution kernel size. It was previously ignored (the
             kernel was always 3). Padding is ``filter_sz // 2``, which
             preserves the spatial size for odd kernel sizes.
         num_conv: number of extra ``out_ch -> out_ch`` conv + ReLU layers
             after the first one.
     """
     super(WCNN, self).__init__()
     if out_ch is None:
         out_ch = 4 * in_ch
     self.DwT = DWTForward(J=1, wave='haar', mode='zero')
     pad = filter_sz // 2
     # 4 * input channels since DwT creates 4 outputs per image.
     # (A BatchNorm2d after each conv was previously present but commented out.)
     modules = [nn.Conv2d(4 * in_ch, out_ch, kernel_size=filter_sz, padding=pad),
                nn.ReLU()]
     for _ in range(num_conv):
         modules.append(nn.Conv2d(out_ch, out_ch, kernel_size=filter_sz, padding=pad))
         modules.append(nn.ReLU())
     self.conv = nn.Sequential(*modules)
Example #21
0
 def __init__(self,
              discriminator_size,
              perceptual_size=256,
              generator_weights=(1.0, 1.0, 1.0, 0.5, 0.1)):
     """Set up the individual loss terms and a Haar ('db1') DWT pair.

     Args:
         discriminator_size: image size passed (as int) to ``GANLoss``.
         perceptual_size: size passed to ``PerceptualLoss``.
         generator_weights: per-term generator loss weights. The default is
             an immutable tuple and the value is stored as a fresh list, so
             the default can never be shared or mutated across instances
             (previously a mutable list default).
     """
     super().__init__()
     # l1/l2 loss
     self.l1_loss = nn.L1Loss()
     self.l2_loss = nn.MSELoss()
     # perceptual_loss
     self.perceptual_loss = PerceptualLoss(perceptual_size)
     # gan loss
     self.gan_loss = GANLoss(image_size=int(discriminator_size))
     self.generator_weights = list(generator_weights)
     self.dwt = DWTForward(J=1, mode='zero', wave='db1')
     self.idwt = DWTInverse(mode="zero", wave="db1")
Example #22
0
def img2dwt(img_in, wave='coif2', sharp=0.3, colors=1.):
    """Decompose an image into DWT bands with rescaled highpass levels.

    The image is converted with ``un_rgb`` (using ``colors``). It is then
    transformed on CUDA by a symmetric-mode DWT, using as many levels as
    ``pywt`` reports for a 'db1' packet of the same spatial size. Highpass
    level ``i`` is divided in place by ``dwt_scale(Ys, sharp)[i]``.

    Returns:
        ``[Yl, Yh_0, Yh_1, ...]``.
    """
    image_t = un_rgb(img_in, colors=colors)
    with torch.no_grad():
        spatial = image_t.shape[2:]
        levels = pywt.WaveletPacket2D(data=np.zeros(spatial),
                                      wavelet='db1',
                                      mode='zero').maxlevel
        transform = DWTForward(J=levels, wave=wave, mode='symmetric').cuda()
        low, highs = transform(image_t.cuda())
        Ys = [low]
        Ys.extend(highs)
    scale = dwt_scale(Ys, sharp)
    for i in range(1, len(Ys)):
        Ys[i] /= scale[i - 1]
    return Ys
Example #23
0
 def __init__(self,
              main_dir,
              device,
              transform=None,
              levels=2):
     """Index PNG images under ``main_dir`` and build a bior3.3 DWT.

     Args:
         main_dir: root directory; images are found at ``main_dir/*/*.png``.
         device: device the DWTForward is moved to.
         transform: list of torchvision transforms composed in order. It
             defaults to ``[Resize((768, 1024), interpolation=4), ToTensor()]``,
             built fresh per instance (previously a shared mutable default
             list).
         levels: number of DWT decomposition levels.
     """
     if transform is None:
         # interpolation=4 is kept as-is (an integer PIL filter code).
         transform = [
             transforms.Resize((768, 1024), interpolation=4),
             transforms.ToTensor()
         ]
     self.main_dir = main_dir
     self.transforms = transforms.Compose(transform)
     # Sorted so the dataset index -> file mapping is deterministic
     # (glob order is filesystem-dependent).
     self.image_list = sorted(glob.glob(main_dir + "/*/*.png"))
     self.levels = levels
     self.names = [os.path.basename(x) for x in self.image_list]
     self.xfm = DWTForward(J=self.levels,
                           mode='periodization',
                           wave='bior3.3').to(device)
def test_equal(wave, J, mode):
    """Round trip and PyWavelets comparison for a 64x64 float input.

    Checks that DWTInverse(DWTForward(x)) reproduces ``x``. Also checks that
    the lowpass and every highpass subband match ``pywt.wavedec2`` for the
    same ``wave``, ``J`` and ``mode``.
    """
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked. .cpu() is required because numpy
    # cannot convert a tensor that lives on a GPU.
    np.testing.assert_array_almost_equal(x.cpu(), x2.detach().cpu(), decimal=PREC_FLT)
    # Test it is the same as doing the PyWavelets wavedec with the same
    # padding mode
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            # pywt lists levels coarsest-first, yh[0] is the finest level.
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_FLT)
Example #25
0
 def __init__(self,
              wt_type='DTCWT',
              biort='near_sym_b',
              qshift='qshift_b',
              J=5,
              wave='db3',
              mode='zero',
              device='cuda',
              requires_grad=True):
     """Create a forward/inverse wavelet transform pair of the requested type.

     Args:
         wt_type: 'DTCWT' (uses ``biort``/``qshift``) or 'DWT' (uses
             ``wave``/``mode``).
         J: number of decomposition levels of the forward transform.
         device: target passed to ``.to()`` for both transforms.
         requires_grad: accepted but not used by this constructor.

     Raises:
         ValueError: ``wt_type`` is not 'DTCWT' or 'DWT'.
     """
     super().__init__()
     if wt_type not in ('DTCWT', 'DWT'):
         raise ValueError('no such type of wavelet transform is supported')
     if wt_type == 'DTCWT':
         self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
         self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
     else:
         self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
         self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
     self.J = J
     self.wt_type = wt_type
Example #26
0
 def __init__(self,
              block,
              num_blocks,
              num_classes=10,
              in_C=12,
              wave='sym17',
              mode='append'):
     """ResNet with a one-level DWT front end (symmetric mode, on CUDA).

     Args:
         block: residual block class; ``block.expansion`` sizes the final linear layer.
         num_blocks: blocks per stage; the first four entries are used.
         num_classes: output classes of ``self.linear``.
         in_C: input channels of the first 3x3 conv.
         wave: wavelet for ``self.DWT``.
         mode: stored as ``self.FDmode``.
     """
     super(FDResNet, self).__init__()
     self.wave = wave
     self.DWT = DWTForward(J=1,
                           wave=self.wave,
                           mode='symmetric',
                           Requirs_Grad=True).cuda()
     self.FDmode = mode
     self.in_planes = 64
     self.conv1 = conv3x3(in_C, 64)
     self.bn1 = nn.BatchNorm2d(64)
     # Stages are built in order; _make_layer presumably reads/updates
     # self.in_planes, so the order must not change.
     stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
     for idx, (planes, stride) in enumerate(stage_specs, start=1):
         setattr(self, f'layer{idx}',
                 self._make_layer(block, planes, num_blocks[idx - 1], stride=stride))
     self.linear = nn.Linear(512 * block.expansion, num_classes)
Example #27
0
    def __init__(self,
                 in_C=3,
                 C_Number=10,
                 wave='haar',
                 mode='append',
                 init_weights=True,
                 batch_norm=True):
        """VGG ('E' config) classifier with a one-level DWT front end.

        Args:
            in_C: input channels passed to ``make_layers``.
            C_Number: number of output classes.
            wave: wavelet for ``self.DWT`` (symmetric mode, on CUDA).
            mode: stored as ``self.FDmode``.
            init_weights: call ``self._initialize_weights()`` at the end.
            batch_norm: forwarded to ``make_layers``.

        ``self.layers`` collects every Linear/Conv2d of the features and the
        classifier. Every BatchNorm2d gets weight 1 and bias 0. Classifier
        BatchNorm parameters are also collected in ``self.bn_params``.
        """
        super(FDVgg, self).__init__()
        self.wave = wave
        self.DWT = DWTForward(J=1,
                              wave=self.wave,
                              mode='symmetric',
                              Requirs_Grad=True).cuda()
        self.FDmode = mode
        self.features = make_layers(in_C=in_C,
                                    cfg=cfg['E'],
                                    batch_norm=batch_norm)
        self.layers = []
        # Must exist before the classifier loop's ``+=``; it was never
        # initialised, so any BatchNorm2d there raised AttributeError.
        self.bn_params = []
        for m in self.features:
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                self.layers.append(m)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.classifier = nn.Sequential(nn.Dropout(), nn.Linear(512, 512),
                                        nn.ReLU(True), nn.Dropout(),
                                        nn.Linear(512, 512), nn.ReLU(True),
                                        nn.Linear(512, C_Number))
        for m in self.classifier:
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                self.layers.append(m)
            elif isinstance(m, nn.BatchNorm2d):
                # Not reached with the classifier above; kept for variants that add BN.
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                self.bn_params += [m.weight, m.bias]

        if init_weights:
            self._initialize_weights()
Example #28
0
    def __init__(self, recursions=1, stride=1, kernel_size=5, wgan=False, highpass=True, D_arch='FSD',
                 norm_layer='Instance', filter_type='gau', cs='cat'):
        """Build the optional frequency-separation filter and the discriminator network.

        Args:
            recursions, stride, kernel_size: forwarded to ``FilterHigh`` for
                the 'gau' and 'avg_pool' filter types.
            wgan: stored as ``self.wgan``.
            highpass: if False, no filter is used (``self.filter = None``).
            D_arch: 'nld_s1', 'nld_s2' or 'fsd' (case-insensitive).
            norm_layer: forwarded to the chosen discriminator network.
            filter_type: 'gau', 'avg_pool' or 'wavelet' (case-insensitive).
            cs: for the wavelet filter; 'cat' gives a 9-channel network
                input, anything else 3 channels.

        Raises:
            NotImplementedError: unknown ``filter_type`` (when ``highpass``)
                or unknown ``D_arch``.
        """
        super(Discriminator, self).__init__()

        self.wgan = wgan
        n_input_channel = 3
        if not highpass:
            self.filter = None
        else:
            kind = filter_type.lower()
            if kind in ('gau', 'avg_pool'):
                # Same high-pass filter; only the Gaussian flag differs.
                self.filter = FilterHigh(recursions=recursions, stride=stride, kernel_size=kernel_size,
                                         include_pad=False, gaussian=(kind == 'gau'))
            elif kind == 'wavelet':
                self.DWT2 = DWTForward(J=1, wave='haar', mode='reflect')
                self.filter = self.filter_wavelet
                self.cs = cs
                n_input_channel = 9 if self.cs == 'cat' else 3
            else:
                raise NotImplementedError('Frequency Separation type [{:s}] not recognized'.format(filter_type))
            print('# FS type: {}, kernel size={}'.format(kind, kernel_size))

        arch = D_arch.lower()
        if arch in ('nld_s1', 'nld_s2'):
            nld_stride = 1 if arch == 'nld_s1' else 2
            self.net = NLayerDiscriminator(input_nc=n_input_channel, ndf=64, n_layers=2,
                                           norm_layer=norm_layer, stride=nld_stride)
            banner = ('# Initializing NLayer-Discriminator-stride-1 with {} norm-layer' if nld_stride == 1
                      else '#Initializing NLayer-Discriminator-stride-2 with {} norm-layer')
            print(banner.format(norm_layer.upper()))
        elif arch == 'fsd':
            self.net = DiscriminatorBasic(n_input_channels=n_input_channel, norm_layer=norm_layer)
            print('# Initializing FSSR-DiscriminatorBasic with {} norm layer'.format(norm_layer))
        else:
            raise NotImplementedError('Discriminator architecture [{:s}] not recognized'.format(D_arch))
Example #29
0
import torch, torchvision
import torchvision.transforms as transforms
import numpy as np
from pytorch_wavelets import DWTForward, DWTInverse
import glob,os, shutil
import matplotlib.pyplot as plt
from PIL import Image
import pickle
# Index KADID-10k images and prepare a 2-level bior3.3 DWT.
# The per-image DWT/pickle loop further down is disabled (commented out), so
# ``xfm``, ``trans`` and ``imgs_n`` are currently unused.
# height=1024, width=768
levels = 2
xfm = DWTForward(J=levels, mode='periodization', wave='bior3.3')
IMAGE_ROOT = "D:\\upla_sir_stuff\\kadid10k\\images\\"
imgs = glob.glob(IMAGE_ROOT + "???.png")
# Three-character image ids taken from the file names (chars before ".png").
imgs_n = [path[-7:-4] for path in imgs]
# for i in imgs:
#     os.mkdir("D:\\upla_sir_stuff\\kadid10k\\images\\"+i)
trans = transforms.ToTensor()
imgs = glob.glob(IMAGE_ROOT + "*\\*.png")
names = [os.path.basename(path) for path in imgs]
print(names)
# print(imgs)
# Disabled processing loop:
# for i in imgs:
#     img = Image.open(i)
#     img = img.resize((768, 1024),3)
#     img = trans(img).unsqueeze(0)
#     # img = img.permute([0,3,1,2])
#     # print(img.shape)
#     Yl, Yh = xfm(img)
#     print(Yh[0].shape)
#     print(Yh[1].shape)
#     # with open(i[0:-4]+"dwt.pkl", 'wb') as f:
#         # pickle.dump(Yh, f)
Example #30
0
#     x = torch.stack([emboss_transform(y) for y in x], dim=0)
#     return x.to(DEVICE).requires_grad_()


def greyscale(x):
    """Average every sample over its first axis, keeping that axis.

    Each element of ``x`` (presumably a (C, H, W) image - confirm) is reduced
    with ``mean(0, keepdim=True)``. The results are stacked along a new
    dim 0, moved to ``DEVICE`` and marked as requiring gradients.
    """
    grey = torch.stack([sample.mean(0, keepdim=True) for sample in x], dim=0)
    return grey.to(DEVICE).requires_grad_()


from pytorch_wavelets import DWTForward, DWTInverse

# Module-level 3-level 'db1' DWT (zero mode), moved to CUDA at import time,
# so importing this module requires a CUDA device.
xfm = DWTForward(wave='db1', mode='zero', J=3).cuda()


def wav_kernel(x):
    def wav_transform(x):
        x_h, x_l = xfm(x.unsqueeze(0))
        x_h = F.interpolate(x_h, size=256, mode='bilinear')
        x_l_0, x_l_1, x_l_2 = F.interpolate(
            x_l[0][:, :, 0, :], size=256,
            mode='bilinear'), F.interpolate(x_l[1][:, :, 0, :],
                                            size=256,
                                            mode='bilinear'), F.interpolate(
                                                x_l[2][:, :, 0, :],
                                                size=256,
                                                mode='bilinear')
        x_final = torch.cat(