Пример #1
0
def test_equal_oddshape(size):
    """Check DWTForward/DWTInverse against PyWavelets on odd-sized inputs.

    Two independently constructed forward/inverse pairs are applied to the
    same random batch; both reconstructions and all coefficients must match
    ``pywt.wavedec2`` / ``pywt.waverec2`` with the 'symmetric' mode.

    :param size: spatial shape appended to a (5, 4) leading shape.
    """
    wave = 'db3'
    J = 3
    mode = 'symmetric'
    x = torch.randn(5, 4, *size).to(dev)
    dwt1 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt1 = DWTInverse(wave=wave, mode=mode).to(dev)
    dwt2 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt2 = DWTInverse(wave=wave, mode=mode).to(dev)

    yl1, yh1 = dwt1(x)
    x1 = iwt1((yl1, yh1))
    yl2, yh2 = dwt2(x)
    x2 = iwt2((yl2, yh2))

    # Test it is the same as doing the PyWavelets wavedec
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    X2 = pywt.waverec2(coeffs, wave, mode=mode)
    # Tensors must be on the host before numpy can convert them; without
    # .cpu() these comparisons fail whenever ``dev`` is a CUDA device.
    np.testing.assert_array_almost_equal(X2, x1.detach().cpu(),
                                         decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(X2, x2.detach().cpu(),
                                         decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl1.cpu(), coeffs[0], decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl2.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        # pywt lists detail tuples coarsest-first, so yh[j] pairs with
        # coeffs[J - j].
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh1[j][:, :, b].cpu(), decimal=PREC_FLT)
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh2[j][:, :, b].cpu(), decimal=PREC_FLT)
Пример #2
0
    def __init__(self, outC1, outC2, outC3):
        """Create three Haar-DWT + double-convolution branches.

        Branch ``i`` (1, 2, 3) registers, in this order, ``dwt{i}``,
        ``conv{i}_1``, ``bn{i}_1``, ``relu{i}_1``, ``conv{i}_2``, ``bn{i}_2``
        and ``relu{i}_2``. The first conv maps 4 channels to
        ``outC{i} // 2``; the second maps that to ``outC{i}``. All convs are
        3x3 with padding 1.

        :param outC1: output channels of branch 1.
        :param outC2: output channels of branch 2.
        :param outC3: output channels of branch 3.
        """
        super(Dwtconv, self).__init__()

        first_in = 4  # input channels of every branch's first conv
        for branch, final_out in enumerate((outC1, outC2, outC3), start=1):
            hidden = final_out // 2
            # Register submodules in the same order as writing them out by
            # hand, so parameter initialisation and state_dict order match.
            setattr(self, 'dwt{}'.format(branch),
                    DWTForward(J=1, wave='haar', mode='symmetric'))
            setattr(self, 'conv{}_1'.format(branch),
                    nn.Conv2d(first_in, hidden, kernel_size=3, padding=1))
            setattr(self, 'bn{}_1'.format(branch), nn.BatchNorm2d(hidden))
            setattr(self, 'relu{}_1'.format(branch), nn.ReLU())
            setattr(self, 'conv{}_2'.format(branch),
                    nn.Conv2d(hidden, final_out, kernel_size=3, padding=1))
            setattr(self, 'bn{}_2'.format(branch), nn.BatchNorm2d(final_out))
            setattr(self, 'relu{}_2'.format(branch), nn.ReLU())
Пример #3
0
    def __init__(self,
                 J=3,
                 wave='haar',
                 window_size=15,
                 stride=8,
                 verbose=False):
        '''
        :param J: int, number of decomposition levels
        :param wave: wavelet name,
                     see https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
                     for the list of available wavelets
        :param window_size: int, size of the window over which each local CW-SSIM index is computed
        :param stride: int, how far the window moves in a single step
        :param verbose: bool, stored as ``self.verbose``

        A local index is first computed inside the window. The window then
        slides across the whole image scale, producing a list of local indices.

        Formula (9) in reference [2] is used to compute the local index.

        The index returned by self.forward is the mean of those local indices.
        '''
        super().__init__()
        self.window_size, self.J, self.stride = window_size, J, stride
        self.dwt = DWTForward(J=J, wave=wave)
        self.verbose = verbose
Пример #4
0
def test_equal_double(wave, J, mode):
    """Float64 round trip and PyWavelets equivalence for the 2-D DWT.

    The input and both transforms are created under
    ``set_double_precision``. The forward/inverse passes and comparisons run
    after that context has exited.
    """
    with set_double_precision():
        inp = torch.randn(5, 4, 64, 64).to(dev)
        assert inp.dtype == torch.float64
        fwd = DWTForward(J=J, wave=wave, mode=mode).to(dev)
        inv = DWTInverse(wave=wave, mode=mode).to(dev)

    lowpass, highpass = fwd(inp)
    recon = inv((lowpass, highpass))

    # Perfect reconstruction of the input.
    np.testing.assert_array_almost_equal(inp.cpu(),
                                         recon.detach().cpu(),
                                         decimal=PREC_DBL)

    # PyWavelets reference over the last two axes, detail tuples coarsest-first.
    ref = pywt.wavedec2(inp.cpu().numpy(), wave, level=J, axes=(-2, -1),
                        mode=mode)
    np.testing.assert_array_almost_equal(lowpass.cpu(), ref[0], decimal=7)
    for level in range(J):
        for band in range(3):
            np.testing.assert_array_almost_equal(
                ref[J - level][band],
                highpass[level][:, :, band].cpu(),
                decimal=PREC_DBL)
Пример #5
0
def test_commutativity(wave, J, j):
    """Check that the inverse DWT is additive over the subbands of level ``j``.

    The level-``j`` bandpass tensor of an all-zero image is edited in place.
    Bands 0 (LH), 1 (HL) and 2 (HH) are copied in from, or reset to zero
    relative to, a random image's decomposition, and the inverse is taken
    after each edit.

    The test then checks both:

    - ``iwt(LH) + iwt(HL) == iwt(LH + HL)``
    - ``iwt(HH) + iwt(HL) == iwt(HL + HH)``
    """
    C = 3
    Y = torch.randn(4, C, 128, 128, requires_grad=True, device=dev)
    dwt = DWTForward(J=J, wave=wave).to(dev)
    iwt = DWTInverse(wave=wave).to(dev)

    coeffs = dwt(Y)
    coeffs_zero = dwt(torch.zeros_like(Y))
    LH, HL, HH = 0, 1, 2
    src = coeffs[1][j]       # level-j bands of the random image
    dst = coeffs_zero[1][j]  # level-j bands of the zero image, edited in place

    dst[:, :, LH] = src[:, :, LH]
    ya = iwt(coeffs_zero)                              # LH only
    dst[:, :, HL] = src[:, :, HL]
    yab = iwt(coeffs_zero)                             # LH + HL
    dst[:, :, LH] = torch.zeros_like(src[:, :, LH])
    yb = iwt(coeffs_zero)                              # HL only
    dst[:, :, HH] = src[:, :, HH]
    ybc = iwt(coeffs_zero)                             # HL + HH
    dst[:, :, HL] = torch.zeros_like(src[:, :, HL])
    yc = iwt(coeffs_zero)                              # HH only

    def _host(t):
        return t.detach().cpu()

    np.testing.assert_array_almost_equal(
        _host(ya + yb), _host(yab), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        _host(yc + yb), _host(ybc), decimal=PREC_FLT)
Пример #6
0
def get_wavelets(x, device):
    """Pack a 3-level db4 ('zero' mode) DWT of ``x`` into one image-like tensor.

    For level ``i`` (``Yh[i]``, finest first), the three detail bands are
    written at row/column offset ``rows // 2**(i+1)`` / ``cols // 2**(i+1)``:

    - band 0 goes to the top-right of that anchor;
    - band 1 goes to the bottom-left;
    - band 2 goes to the bottom-right.

    The lowpass coefficients are written last into the top-left corner.

    :param x: input tensor; ``x.shape[0]`` and ``x.shape[1]`` are used as the
        batch and channel counts (presumably (batch, channels, H, W)).
    :param device: device for the transform and the output tensor.
    :returns: tensor (batch, channels, rows, cols). ``rows``/``cols`` are
        ``nextPowerOf2`` of twice the finest band's height/width.
    """
    xfm = DWTForward(J=3, mode='zero', wave='db4').to(
        device)  # Accepts all wave types available to PyWavelets
    Yl, Yh = xfm(x)

    batch_size, channels = x.shape[0], x.shape[1]
    rows = nextPowerOf2(Yh[0].shape[-2] * 2)
    cols = nextPowerOf2(Yh[0].shape[-1] * 2)
    # Allocate in the coefficients' dtype directly on ``device``. This avoids
    # a CPU allocation plus copy and does not silently cast non-float32
    # coefficients to float32.
    wavelets = torch.zeros(batch_size, channels, rows, cols,
                           dtype=Yl.dtype, device=device)
    # Yl is LL coefficients, Yh is list of higher bands with finest frequency
    # in the beginning. Each band tensor is indexed as band[:, :, k] for k in 0..2.
    # NOTE(review): padded ('zero' mode) bands can be larger than the space
    # between anchors, so neighbouring levels may overlap and later writes
    # overwrite earlier ones -- confirm this is acceptable.
    for i, band in enumerate(Yh):
        irow = rows // 2**(i + 1)
        icol = cols // 2**(i + 1)
        h, w = band.shape[-2], band.shape[-1]
        wavelets[:, :, :h, icol:icol + w] = band[:, :, 0]
        wavelets[:, :, irow:irow + h, :w] = band[:, :, 1]
        wavelets[:, :, irow:irow + h, icol:icol + w] = band[:, :, 2]

    wavelets[:, :, :Yl.shape[-2], :Yl.shape[-1]] = Yl  # Put in LL coefficients
    return wavelets
Пример #7
0
def test_fwd(mode):
    """Run ``gradcheck`` on the 2-D analysis filter bank ``AFB2D`` for ``mode``.

    The input and the DWT filters are created under ``set_double_precision``.
    """
    with set_double_precision():
        x = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        xfm = DWTForward(J=2).to(dev)

    filters = (xfm.h0_row, xfm.h1_row, xfm.h0_col, xfm.h1_col)
    gradcheck(AFB2D.apply, (x, *filters, mode), eps=EPS, atol=ATOL)
Пример #8
0
    def __init__(self, opt):
        """Collect HR / LR / real-LR / weight image paths and set up a Haar DWT.

        :param opt: options dict. Keys read here: 'subset_file', 'phase',
            'dataroot_HR', 'dataroot_LR', 'data_type', and (outside the
            training-subset case) 'dataroot_RealLR' and 'dataroot_weights'.
        :raises NotImplementedError: a subset file is used together with a
            pre-generated LR root.
        :raises AssertionError: no HR paths were found.
        """
        super(LRHR_wavelet_Mixunpair_Dataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.paths_RealLR = None
        self.paths_weights = None  # defined in every branch, not only the lmdb/img one
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.RealLR_env = None
        self.weights_env = None

        if opt['subset_file'] is not None and opt['phase'] == 'train':
            # read image list from subset list txt
            with open(opt['subset_file']) as f:
                self.paths_HR = sorted(
                    os.path.join(opt['dataroot_HR'], line.rstrip('\n'))
                    for line in f)
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
        else:  # read image list from lmdb or image files
            data_type = opt['data_type']
            self.HR_env, self.paths_HR = util.get_image_paths(data_type, opt['dataroot_HR'])
            self.LR_env, self.paths_LR = util.get_image_paths(data_type, opt['dataroot_LR'])
            # Store each environment separately. Writing all of them to
            # self.LR_env left it pointing at the weights environment.
            # NOTE(review): readers of RealLR / weight images should use
            # RealLR_env / weights_env -- verify in __getitem__.
            self.RealLR_env, self.paths_RealLR = util.get_image_paths(data_type, opt['dataroot_RealLR'])
            self.weights_env, self.paths_weights = util.get_image_paths(data_type, opt['dataroot_weights'])

        assert self.paths_HR, 'Error: HR path is empty.'
        # if self.paths_LR and self.paths_HR:
        #     assert len(self.paths_LR) == len(self.paths_HR), \
        #         'HR and LR datasets have different number of images - {}, {}.'.format(\
        #         len(self.paths_LR), len(self.paths_HR))

        self.random_scale_list = [1]
        self.DWT2 = DWTForward(J=1, wave='haar', mode='zero')
Пример #9
0
def test_gradients_inv(wave, J, mode):
    """Backprop through the inverse DWT must equal a forward DWT of the gradient.

    The inverse is built from the time-reversed decomposition filters. Its
    gradient with respect to the lowpass and each bandpass level is compared
    against the forward transform (original decomposition filters) applied
    to the upstream gradient.
    """
    wavelet = pywt.Wavelet(wave)
    analysis = (wavelet.dec_lo, wavelet.dec_hi)
    synthesis = (wavelet.dec_lo[::-1], wavelet.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=analysis, mode=mode).to(dev)
    iwt = DWTInverse(wave=synthesis, mode=mode).to(dev)

    # Decompose a dummy input only to learn the coefficient shapes.
    probe_lo, probe_hi = dwt(torch.zeros(5, 6, 128, 128).to(dev))
    yl = torch.randn(*probe_lo.shape, requires_grad=True, device=dev)
    yh = [torch.randn(*probe_hi[i].shape, requires_grad=True, device=dev)
          for i in range(J)]
    y = iwt((yl, yh))

    grad_out = torch.randn(*y.shape, device=dev)
    y.backward(grad_out, retain_graph=True)
    expect_lo, expect_hi = dwt(grad_out)

    np.testing.assert_array_almost_equal(yl.grad.detach().cpu(),
                                         expect_lo.cpu(),
                                         decimal=PREC_FLT)
    for level in range(J):
        np.testing.assert_array_almost_equal(yh[level].grad.detach().cpu(),
                                             expect_hi[level].cpu(),
                                             decimal=PREC_FLT)
Пример #10
0
    def __init__(self,
                 C,
                 F,
                 lp_size=3,
                 bp_sizes=(1, ),
                 q=1.0,
                 wave='db2',
                 mode='zero',
                 xfm=True,
                 ifm=True):
        """Build the transform / shrink / gain / inverse-transform components.

        :param C: forwarded to ``WaveGainLayer`` (presumably input channels).
        :param F: forwarded to ``WaveGainLayer`` (presumably output channels).
        :param lp_size: lowpass gain size, forwarded to ``WaveGainLayer``.
        :param bp_sizes: bandpass gain sizes; ``len(bp_sizes)`` sets ``self.J``.
        :param q: if negative, ``self.shrink`` is ``ReLUWaveCoeffs()``;
            otherwise ``self.shrink`` is an identity function.
        :param wave: wavelet name for the forward/inverse DWT.
        :param mode: padding mode for the forward/inverse DWT.
        :param xfm: if False, ``self.XFM`` is an identity function.
        :param ifm: if False, ``self.IFM`` is an identity function.
        """
        super().__init__()
        self.C, self.F = C, F
        self.J = len(bp_sizes)

        self.XFM = (DWTForward(J=self.J, mode=mode, wave=wave) if xfm
                    else (lambda x: x))
        self.shrink = ReLUWaveCoeffs() if q < 0 else (lambda x: x)
        self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes)
        self.IFM = DWTInverse(mode=mode, wave=wave) if ifm else (lambda x: x)
Пример #11
0
def test_gradients_fwd(wave, J, mode):
    """Backprop through the forward DWT must equal an inverse DWT of the gradient.

    The forward transform uses the decomposition filters. The reference
    inverse uses the same filters time-reversed.
    """
    im = np.random.randn(5, 6, 128, 128).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)

    wavelet = pywt.Wavelet(wave)
    dwt = DWTForward(J=J, wave=(wavelet.dec_lo, wavelet.dec_hi),
                     mode=mode).to(dev)
    iwt = DWTInverse(wave=(wavelet.dec_lo[::-1], wavelet.dec_hi[::-1]),
                     mode=mode).to(dev)

    yl, yh = dwt(imt)

    # Lowpass: the gradient must be the inverse of (grad, no detail bands).
    ylg = torch.randn(*yl.shape, device=dev)
    yl.backward(ylg, retain_graph=True)
    no_detail = [torch.zeros_like(yh[i]) for i in range(J)]
    expected = iwt((ylg, no_detail))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                         expected.cpu(),
                                         decimal=PREC_FLT)

    # Bandpass: one level at a time, everything else zero.
    for level, band in enumerate(yh):
        imt.grad.zero_()
        g = torch.randn(*band.shape, device=dev)
        band.backward(g, retain_graph=True)
        detail = list(no_detail)
        detail[level] = g
        expected = iwt((torch.zeros_like(yl), detail))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                             expected.cpu(),
                                             decimal=PREC_FLT)
Пример #12
0
 def __init__(self, J=3, wave='db4', device='cpu', dtype=torch.float64):
     """Build periodized ('per') forward/inverse DWT modules under ``dtype``.

     The process-wide default dtype is switched to ``dtype`` while the
     transforms are constructed, so tensors they create then (presumably
     their filter buffers) use it. It is restored afterwards, including when
     construction raises.

     :param J: number of decomposition levels of the forward transform.
     :param wave: wavelet name for both transforms.
     :param device: device both transforms are moved to.
     :param dtype: default dtype in effect during construction.
     """
     super().__init__()
     previous_dtype = torch.get_default_dtype()
     torch.set_default_dtype(dtype)
     try:
         self.dwt = DWTForward(J=J, wave=wave, mode='per').to(device)
         self.idwt = DWTInverse(wave=wave, mode='per').to(device)
     finally:
         # Always restore the global default; otherwise a failure above would
         # leave ``dtype`` as the default for unrelated code.
         torch.set_default_dtype(previous_dtype)
     self.norm_bound = 1.
Пример #13
0
def test_ok(wave, J, mode):
    """Forward and inverse DWT outputs must all be contiguous tensors."""
    inp = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    lo, hi = dwt(inp)
    recon = iwt((lo, hi))
    # Non-contiguous outputs can cause data errors sometimes, so check each.
    assert lo.is_contiguous()
    assert all(hi[level].is_contiguous() for level in range(J))
    assert recon.is_contiguous()
Пример #14
0
 def wavelets(self, x):
     """Return a one-level Haar DWT of ``x`` stacked along the channel axis.

     For each input channel ``c``, output channels ``4c .. 4c+3`` hold that
     channel's lowpass band followed by its three detail bands
     (``Yh[0][:, c, 0..2]``). For 3-channel input this is exactly the
     ``[0, 4, 8]`` / ``1:4`` / ``5:8`` / ``9:12`` layout.

     Works for any channel count and for odd spatial sizes, because the
     output takes the transform's actual band size rather than
     ``int(H / 2)``. Runs on ``x.device``.

     :param x: 4-D tensor, presumably (batch, channels, H, W).
     :returns: tensor of shape (batch, 4 * channels, h, w).
     """
     xfm = DWTForward(J=1, wave='haar', mode='symmetric').to(x.device)
     Yl, Yh = xfm(x)
     # Yl is (B, C, h, w) and Yh[0] is (B, C, 3, h, w). Prepend LL to the
     # three detail bands per channel, then fold (C, 4) into one channel axis
     # so channel c lands at indices 4c..4c+3.
     bands = torch.cat((Yl.unsqueeze(2), Yh[0]), dim=2)
     return bands.reshape(bands.shape[0], -1, bands.shape[-2],
                          bands.shape[-1])
Пример #15
0
 def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
     """Initialise a k x k kernel parameterised by its DWT coefficients.

     A (F, C, k, k) weight is Xavier-initialised and decomposed with a
     ``J``-level DWT. Its lowpass coefficients become ``self.gl`` and the
     bandpass coefficients ``yh[J - 1]`` become ``self.gh``; both are
     learnable parameters.

     Supported (k, J): (4, 1), (8, 1), (8, 2), (8, 3).

     :param C: size of the kernel's second dimension (presumably in channels).
     :param F: size of the kernel's first dimension (presumably filters).
     :param k: spatial kernel size.
     :param stride: accepted for interface compatibility; not used here.
     :param J: number of DWT levels.
     :param wd: stored as ``self.wd`` (presumably a weight-decay factor).
     :param wd1: stored as ``self.wd1``; defaults to ``wd``.
     :param right: selects which side of ``self.pad`` gets the larger
         amount. The tuple is presumably consumed by ``F.pad``.
     :raises NotImplementedError: unsupported (k, J) combination.
     """
     super().__init__()
     # Fail fast, before any random initialisation or transform work.
     if (k, J) not in {(4, 1), (8, 1), (8, 2), (8, 3)}:
         raise NotImplementedError
     self.wd = wd
     self.wd1 = wd if wd1 is None else wd1
     self.C = C
     self.F = F
     x = torch.zeros(F, C, k, k)
     torch.nn.init.xavier_uniform_(x)
     xfm = DWTForward(J=J)
     self.ifm = DWTInverse()
     yl, yh = xfm(x)
     self.J = J
     # Same values as allocating zeros and then overwriting ``.data``.
     self.gl = nn.Parameter(yl.detach())
     self.gh = nn.Parameter(yh[J - 1].detach())
     # (smaller, larger) padding amount per kernel size.
     small, large = {4: (1, 2), 8: (3, 4)}[k]
     if right:
         self.pad = (small, large, small, large)
     else:
         self.pad = (large, small, large, small)
Пример #16
0
 def __init__(self, in_planes, lifting_size, kernel_size, no_bottleneck,
              share_weights, simple_lifting, regu_details, regu_approx):
     """Haar (db1) DWT stage followed by a bottleneck block.

     Only ``in_planes``, ``no_bottleneck`` and ``share_weights`` are used
     here. The remaining arguments are accepted for signature compatibility.
     """
     super(Haar, self).__init__()
     from pytorch_wavelets import DWTForward
     self.wavelet = DWTForward(J=1, mode='zero', wave='db1').cuda()
     self.share_weights = share_weights
     if no_bottleneck:
         # Keep the channel count unchanged. The block is still used so
         # BN + ReLU are applied; presumably it skips its conv when in == out.
         block_channels = (in_planes * 1, in_planes * 1)
     else:
         block_channels = (in_planes * 4, in_planes * 2)
     # The attribute is spelled 'bootleneck' elsewhere (and presumably in
     # saved state_dict keys), so the name is kept as-is.
     self.bootleneck = BottleneckBlock(*block_channels)
Пример #17
0
def init_dwt(resume=None, shape=None, wave=None, colors=None):
    """Create CUDA DWT/IDWT transforms and an initial set of coefficients.

    :param resume: where the coefficients come from.

        - ``None``: random-normal coefficients shaped like the DWT of
          ``torch.zeros(shape)``.
        - ``str`` naming an image (jpg/jpeg/png/tif/tiff/bmp): the image is
          decomposed with ``img2dwt`` and the forward transform is rebuilt
          for the image's size.
        - any other ``str``: a file loaded with ``torch.load``.
        - anything else: an iterable of tensors, each moved to CUDA.
    :param shape: full input shape. ``shape[2:]`` sets the number of levels
        (the max level of a db1 wavelet packet of that size). Required.
    :param wave: wavelet name for both transforms.
    :param colors: forwarded to ``img2dwt``.
    :returns: ``(Ys, xfm, ifm, size)``. ``size`` is the loaded image's
        ``shape[:2]`` or None.
    :raises FileNotFoundError: ``resume`` is a string that is not a file.
    """
    def _make_xfm(spatial_shape):
        # Depth = deepest level a db1 packet tree of this size allows.
        levels = pywt.WaveletPacket2D(data=np.zeros(spatial_shape),
                                      wavelet='db1',
                                      mode='symmetric').maxlevel
        return DWTForward(J=levels, wave=wave, mode='symmetric').cuda()

    size = None
    xfm = _make_xfm(shape[2:])
    ifm = DWTInverse(wave=wave, mode='symmetric').cuda()  # symmetric zero periodization
    if resume is None:  # random init
        Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
        Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
    elif isinstance(resume, str):
        if not os.path.isfile(resume):
            # Raise rather than call exit(): exit() is a site-module helper
            # meant for the interactive shell, and it exited with status 0,
            # reporting a failure as success.
            raise FileNotFoundError('Snapshot not found: {}'.format(resume))
        ext = os.path.splitext(resume)[1].lower()[1:]
        if ext in ('jpg', 'jpeg', 'png', 'tif', 'tiff', 'bmp'):
            img_in = imread(resume)
            Ys = img2dwt(img_in, wave=wave, colors=colors)
            print(' loaded image', resume, img_in.shape, 'level',
                  len(Ys) - 1)
            size = img_in.shape[:2]
            xfm = _make_xfm(size)
        else:
            # NOTE: torch.load unpickles the file; only load trusted snapshots.
            Ys = torch.load(resume)
            Ys = [y.detach().cuda() for y in Ys]
    else:
        Ys = [y.cuda() for y in resume]
    return Ys, xfm, ifm, size
Пример #18
0
    def filter_wavelet(self, x, norm=True):
        """Split ``x`` into its Haar lowpass band and its combined detail bands.

        :param x: image batch, presumably (batch, channels, H, W).
        :param norm: if True, multiply LL by 0.5 and map each detail band
            through ``0.5 * band + 0.5``.
        :returns: ``(LL, high)``. ``high`` is the mean of the three detail
            bands when ``self.cs == 'sum'``, or their concatenation along
            dim 1 when ``self.cs == 'cat'``.
        :raises ValueError: ``self.cs`` is neither 'sum' nor 'cat'.
        """
        # Build the transform on the input's device so GPU inputs work.
        DWT2 = DWTForward(J=1, wave='haar', mode='reflect').to(x.device)
        LL, Hc = DWT2(x)
        # Hc[0] holds the three detail bands along dim 2. Keep the batch
        # dimension. Indexing batch element 0 dropped it: 'cat' then joined
        # along the height axis instead of channels, and only the first
        # sample was used.
        LH, HL, HH = Hc[0].unbind(dim=2)

        if norm:
            LL = LL * 0.5
            LH, HL, HH = LH * 0.5 + 0.5, HL * 0.5 + 0.5, HH * 0.5 + 0.5
        if self.cs == 'sum':
            return LL, (LH + HL + HH) / 3.
        if self.cs == 'cat':
            return LL, torch.cat((LH, HL, HH), 1)
        raise ValueError("cs must be 'sum' or 'cat', got {!r}".format(self.cs))
Пример #19
0
    def __init__(
        self,
        num_levels: int = 1,
        wave: str = 'db2',
        mode: str = 'periodization',
        device: Optional[str] = None,
    ) -> None:
        """Build a matching forward/inverse DWT pair.

        :param num_levels: decomposition depth ``J`` of the forward transform.
        :param wave: wavelet name shared by both transforms.
        :param mode: signal-extension mode shared by both transforms.
        :param device: passed unchanged to ``.to()`` on both transforms.
        """
        self.num_levels, self.wave, self.mode = num_levels, wave, mode

        shared = dict(wave=wave, mode=mode)
        self.xfm = DWTForward(J=num_levels, **shared).to(device)
        self.ifm = DWTInverse(**shared).to(device)
Пример #20
0
 def __init__(self, in_ch, out_ch=None, filter_sz=3, num_conv=3):
     """One-level Haar DWT followed by a stack of 3x3 conv + ReLU layers.

     :param in_ch: channels of the input image.
     :param out_ch: channels produced by every conv; defaults to ``4 * in_ch``.
     :param filter_sz: accepted but not used; every conv is 3x3 with padding 1.
     :param num_conv: number of conv + ReLU pairs after the first one.
     """
     super(WCNN, self).__init__()
     out_ch = 4 * in_ch if out_ch is None else out_ch
     self.DwT = DWTForward(J=1, wave='haar', mode='zero')
     # The first conv takes 4 * in_ch channels because the DWT creates
     # 4 outputs per image channel. (BatchNorm layers are deliberately off.)
     layers = [nn.Conv2d(4 * in_ch, out_ch, kernel_size=3, padding=1),
               nn.ReLU()]
     for _ in range(num_conv):
         layers += [nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                    nn.ReLU()]
     self.conv = nn.Sequential(*layers)
Пример #21
0
 def __init__(self,
              discriminator_size,
              perceptual_size=256,
              generator_weights=None):
     """Set up the L1/L2, perceptual and GAN losses plus a Haar DWT/IDWT.

     :param discriminator_size: passed as ``image_size`` (cast to int) to
         ``GANLoss``.
     :param perceptual_size: passed to ``PerceptualLoss``.
     :param generator_weights: stored as ``self.generator_weights``. Defaults
         to ``[1.0, 1.0, 1.0, 0.5, 0.1]``, built fresh per instance so the
         default list is never shared or mutated across instances.
     """
     super().__init__()
     # l1/l2 loss
     self.l1_loss = nn.L1Loss()
     self.l2_loss = nn.MSELoss()
     # perceptual_loss
     self.perceptual_loss = PerceptualLoss(perceptual_size)
     # gan loss
     self.gan_loss = GANLoss(image_size=int(discriminator_size))
     if generator_weights is None:
         generator_weights = [1.0, 1.0, 1.0, 0.5, 0.1]
     self.generator_weights = generator_weights
     self.dwt = DWTForward(J=1, mode='zero', wave='db1')
     self.idwt = DWTInverse(mode="zero", wave="db1")
Пример #22
0
def img2dwt(img_in, wave='coif2', sharp=0.3, colors=1.):
    """Decompose an image into DWT coefficients with rescaled detail bands.

    The number of levels is the maximum level of a db1 wavelet packet for
    the image's spatial size. The decomposition runs on CUDA without
    autograd. Detail entry ``Ys[i]`` (i >= 1) is then divided in place by
    ``dwt_scale(Ys, sharp)[i - 1]``.

    :returns: list ``[lowpass, *bandpass]``.
    """
    image_t = un_rgb(img_in, colors=colors)
    with torch.no_grad():
        levels = pywt.WaveletPacket2D(data=np.zeros(image_t.shape[2:]),
                                      wavelet='db1',
                                      mode='zero').maxlevel
        xfm = DWTForward(J=levels, wave=wave, mode='symmetric').cuda()
        lowpass, bandpass = xfm(image_t.cuda())
        Ys = [lowpass, *bandpass]
    scale = dwt_scale(Ys, sharp)
    for idx in range(1, len(Ys)):
        Ys[idx] /= scale[idx - 1]
    return Ys
Пример #23
0
 def __init__(self,
              main_dir,
              device,
              transform=None,
              levels=2):
     """Index PNG images one directory below ``main_dir`` and build a DWT.

     :param main_dir: root directory; images are globbed as
         ``main_dir + "/*/*.png"``.
     :param device: device the forward DWT is moved to.
     :param transform: list of transforms composed in order. Defaults to
         ``Resize((768, 1024), interpolation=4)`` followed by ``ToTensor()``.
         The default list is built per instance instead of being one shared
         mutable default argument (which ``Compose`` would keep a
         reference to).
     :param levels: number of DWT levels.
     """
     if transform is None:
         transform = [
             transforms.Resize((768, 1024), interpolation=4),
             transforms.ToTensor()
         ]
     self.main_dir = main_dir
     self.transforms = transforms.Compose(transform)
     self.image_list = glob.glob(main_dir + "/*/*.png")
     self.levels = levels
     self.names = [os.path.basename(x) for x in self.image_list]
     self.xfm = DWTForward(J=self.levels,
                           mode='periodization',
                           wave='bior3.3').to(device)
Пример #24
0
def test_equal(wave, J, mode):
    """Round-trip a random batch and compare its coefficients with PyWavelets."""
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked. Move to host first: numpy cannot
    # convert CUDA tensors.
    np.testing.assert_array_almost_equal(x.cpu(), x2.detach().cpu(),
                                         decimal=PREC_FLT)
    # Test it is the same as PyWavelets' wavedec2 with the same extension
    # mode. pywt lists detail tuples coarsest-first.
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_FLT)
Пример #25
0
 def __init__(self,
              wt_type='DTCWT',
              biort='near_sym_b',
              qshift='qshift_b',
              J=5,
              wave='db3',
              mode='zero',
              device='cuda',
              requires_grad=True):
     """Hold a forward/inverse wavelet transform pair of the chosen type.

     :param wt_type: 'DTCWT' (uses ``biort``/``qshift``) or 'DWT' (uses
         ``wave``/``mode``).
     :param J: number of levels of the forward transform.
     :param device: device both transforms are moved to.
     :param requires_grad: accepted but not used here.
     :raises ValueError: ``wt_type`` is not 'DTCWT' or 'DWT'.
     """
     super().__init__()
     if wt_type not in ('DTCWT', 'DWT'):
         raise ValueError('no such type of wavelet transform is supported')
     if wt_type == 'DTCWT':
         forward = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
         self.xfm = forward
         self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
     else:
         forward = DWTForward(wave=wave, J=J, mode=mode).to(device)
         self.xfm = forward
         self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
     self.J = J
     self.wt_type = wt_type
Пример #26
0
 def __init__(self,
              block,
              num_blocks,
              num_classes=10,
              in_C=12,
              wave='sym17',
              mode='append'):
     """ResNet-style classifier with a one-level DWT front end.

     :param block: residual block class; ``block.expansion`` scales the
         final linear layer's input width.
     :param num_blocks: blocks per stage for the four stages.
     :param num_classes: output classes of the linear head.
     :param in_C: input channels of the first 3x3 conv.
     :param wave: wavelet name for the DWT.
     :param mode: stored as ``self.FDmode``.
     """
     super(FDResNet, self).__init__()
     self.wave = wave
     # NOTE(review): 'Requirs_Grad' is not an argument of upstream
     # pytorch_wavelets' DWTForward; presumably a local fork -- confirm.
     self.DWT = DWTForward(J=1,
                           wave=self.wave,
                           mode='symmetric',
                           Requirs_Grad=True).cuda()
     self.FDmode = mode
     self.in_planes = 64
     self.conv1 = conv3x3(in_C, 64)
     self.bn1 = nn.BatchNorm2d(64)
     # (planes, stride) per stage, built in order so that _make_layer sees
     # the same running self.in_planes as before.
     stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
     for idx, (planes, stage_stride) in enumerate(stage_specs):
         setattr(self, 'layer{}'.format(idx + 1),
                 self._make_layer(block, planes, num_blocks[idx],
                                  stride=stage_stride))
     self.linear = nn.Linear(512 * block.expansion, num_classes)
Пример #27
0
    def __init__(self,
                 in_C=3,
                 C_Number=10,
                 wave='haar',
                 mode='append',
                 init_weights=True,
                 batch_norm=True):
        """VGG-style (cfg 'E') classifier with a one-level DWT front end.

        Every ``nn.Linear`` / ``nn.Conv2d`` in the feature extractor and the
        classifier is appended to ``self.layers``. Every ``nn.BatchNorm2d``
        has its weight set to 1 and bias to 0, and those parameters are
        appended to ``self.bn_params``.

        :param in_C: input channels of the feature extractor.
        :param C_Number: number of output classes.
        :param wave: wavelet name for the DWT.
        :param mode: stored as ``self.FDmode``.
        :param init_weights: if True, call ``self._initialize_weights()``.
        :param batch_norm: forwarded to ``make_layers``.
        """
        super(FDVgg, self).__init__()
        self.wave = wave
        self.DWT = DWTForward(J=1,
                              wave=self.wave,
                              mode='symmetric',
                              Requirs_Grad=True).cuda()
        self.FDmode = mode
        self.features = make_layers(in_C=in_C,
                                    cfg=cfg['E'],
                                    batch_norm=batch_norm)
        self.layers = []
        # Created up front: it used to be extended without ever being
        # initialised, which raises AttributeError on the first BatchNorm.
        self.bn_params = []

        def _collect(modules):
            # Record weight layers; reset and record BatchNorm affine params.
            for m in modules:
                if isinstance(m, (nn.Linear, nn.Conv2d)):
                    self.layers.append(m)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                    self.bn_params += [m.weight, m.bias]

        _collect(self.features)

        self.classifier = nn.Sequential(nn.Dropout(), nn.Linear(512, 512),
                                        nn.ReLU(True), nn.Dropout(),
                                        nn.Linear(512, 512), nn.ReLU(True),
                                        nn.Linear(512, C_Number))
        _collect(self.classifier)

        if init_weights:
            self._initialize_weights()
Пример #28
0
    def __init__(self, recursions=1, stride=1, kernel_size=5, wgan=False, highpass=True, D_arch='FSD',
                 norm_layer='Instance', filter_type='gau', cs='cat'):
        """Frequency-separation discriminator: optional high-pass filter + net.

        :param highpass: if False, ``self.filter`` is None. Otherwise
            ``filter_type`` (case-insensitive) selects it:

            - 'gau' / 'avg_pool': ``FilterHigh`` with ``gaussian`` True/False;
            - 'wavelet': ``self.filter_wavelet``; ``cs`` picks a 9-channel
              ('cat') or 3-channel net input.
        :param D_arch: 'nld_s1', 'nld_s2' or 'fsd' (case-insensitive).
        :raises NotImplementedError: unknown ``filter_type`` or ``D_arch``.
        """
        super(Discriminator, self).__init__()

        self.wgan = wgan
        n_input_channel = 3
        if not highpass:
            self.filter = None
        else:
            ftype = filter_type.lower()
            if ftype in ('gau', 'avg_pool'):
                # Both variants are FilterHigh; only ``gaussian`` differs.
                self.filter = FilterHigh(recursions=recursions, stride=stride, kernel_size=kernel_size,
                                         include_pad=False, gaussian=(ftype == 'gau'))
            elif ftype == 'wavelet':
                self.DWT2 = DWTForward(J=1, wave='haar', mode='reflect')
                self.filter = self.filter_wavelet
                self.cs = cs
                n_input_channel = 9 if self.cs == 'cat' else 3
            else:
                raise NotImplementedError('Frequency Separation type [{:s}] not recognized'.format(filter_type))
            print('# FS type: {}, kernel size={}'.format(ftype, kernel_size))

        arch = D_arch.lower()
        if arch in ('nld_s1', 'nld_s2'):
            nld_stride = 1 if arch == 'nld_s1' else 2
            self.net = NLayerDiscriminator(input_nc=n_input_channel, ndf=64, n_layers=2, norm_layer=norm_layer,
                                           stride=nld_stride)
            if nld_stride == 1:
                banner = '# Initializing NLayer-Discriminator-stride-1 with {} norm-layer'
            else:
                banner = '#Initializing NLayer-Discriminator-stride-2 with {} norm-layer'
            print(banner.format(norm_layer.upper()))
        elif arch == 'fsd':
            self.net = DiscriminatorBasic(n_input_channels=n_input_channel, norm_layer=norm_layer)
            print('# Initializing FSSR-DiscriminatorBasic with {} norm layer'.format(norm_layer))
        else:
            raise NotImplementedError('Discriminator architecture [{:s}] not recognized'.format(D_arch))
Пример #29
0
import torch, torchvision
import torchvision.transforms as transforms
import numpy as np
from pytorch_wavelets import DWTForward, DWTInverse
import glob,os, shutil
import matplotlib.pyplot as plt
from PIL import Image
import pickle
# Target image size: height=1024, width=768 (see the disabled loop below).
levels = 2
xfm = DWTForward(J=levels, mode='periodization', wave='bior3.3')
image_root = "D:\\upla_sir_stuff\\kadid10k\\images\\"

# Top-level images with a 3-character stem; keep just that stem.
imgs = glob.glob(image_root + "???.png")
imgs_n = [path[-7:-4] for path in imgs]
# Disabled: create one sub-directory per image.
# for i in imgs:
#     os.mkdir("D:\\upla_sir_stuff\\kadid10k\\images\\"+i)

trans = transforms.ToTensor()
# Images one directory below the root.
imgs = glob.glob(image_root + "*\\*.png")
names = [os.path.basename(path) for path in imgs]
print(names)

# Disabled: resize each image, convert it to a tensor and inspect its DWT.
# for i in imgs:
#     img = Image.open(i)
#     img = img.resize((768, 1024), 3)
#     img = trans(img).unsqueeze(0)
#     Yl, Yh = xfm(img)
#     print(Yh[0].shape)
#     print(Yh[1].shape)
#     # with open(i[0:-4]+"dwt.pkl", 'wb') as f:
#     #     pickle.dump(Yh, f)
Пример #30
0
#     x = torch.stack([emboss_transform(y) for y in x], dim=0)
#     return x.to(DEVICE).requires_grad_()


def greyscale(x):
    """Average every sample over its axis 0 (presumably channels).

    :param x: iterable of per-sample tensors, e.g. a batched tensor. Each
        sample is reduced with ``mean(0, keepdim=True)``.
    :returns: the stacked results, moved to ``DEVICE`` with
        ``requires_grad`` enabled.
    """
    stacked = torch.stack([sample.mean(0, keepdim=True) for sample in x],
                          dim=0)
    return stacked.to(DEVICE).requires_grad_()


from pytorch_wavelets import DWTForward, DWTInverse

# Shared 3-level db1 (Haar) DWT with zero padding, created once on CUDA.
xfm = DWTForward(wave='db1', mode='zero', J=3).cuda()


def wav_kernel(x):
    def wav_transform(x):
        x_h, x_l = xfm(x.unsqueeze(0))
        x_h = F.interpolate(x_h, size=256, mode='bilinear')
        x_l_0, x_l_1, x_l_2 = F.interpolate(
            x_l[0][:, :, 0, :], size=256,
            mode='bilinear'), F.interpolate(x_l[1][:, :, 0, :],
                                            size=256,
                                            mode='bilinear'), F.interpolate(
                                                x_l[2][:, :, 0, :],
                                                size=256,
                                                mode='bilinear')
        x_final = torch.cat(