def test_equal_oddshape(size):
    """Compare DWTForward/DWTInverse with PyWavelets on non-square/odd input sizes.

    Two independently constructed forward/inverse module pairs with identical
    settings are run on the same input; both must reproduce
    ``pywt.wavedec2`` / ``pywt.waverec2``.

    :param size: spatial shape appended to a (5, 4) batch/channel prefix
        (presumably parametrised with odd sizes by the test decorator - verify).
    """
    wave = 'db3'
    J = 3
    mode = 'symmetric'
    x = torch.randn(5, 4, *size).to(dev)
    dwt1 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt1 = DWTInverse(wave=wave, mode=mode).to(dev)
    dwt2 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt2 = DWTInverse(wave=wave, mode=mode).to(dev)

    yl1, yh1 = dwt1(x)
    x1 = iwt1((yl1, yh1))
    yl2, yh2 = dwt2(x)
    x2 = iwt2((yl2, yh2))

    # Test it is the same as doing the PyWavelets wavedec
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    X2 = pywt.waverec2(coeffs, wave, mode=mode)
    # .cpu() is needed before numpy conversion: a CUDA tensor cannot be
    # converted to a numpy array directly.
    np.testing.assert_array_almost_equal(X2, x1.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(X2, x2.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl1.cpu(), coeffs[0], decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl2.cpu(), coeffs[0], decimal=PREC_FLT)
    # pywt lists detail levels coarsest-first after the approximation, while
    # yh[j] is indexed finest-first, hence coeffs[J - j].
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh1[j][:, :, b].cpu(), decimal=PREC_FLT)
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh2[j][:, :, b].cpu(), decimal=PREC_FLT)
def __init__(self, outC1, outC2, outC3):
    """Build three parallel branches, each a Haar DWT plus two conv/BN/ReLU stages.

    Branch ``i`` (1, 2, 3) registers:
      * ``dwt{i}``  - single-level Haar DWTForward (mode 'symmetric');
      * ``conv{i}_1`` / ``bn{i}_1`` / ``relu{i}_1`` - 4 -> ``outC{i} // 2`` channels;
      * ``conv{i}_2`` / ``bn{i}_2`` / ``relu{i}_2`` - ``outC{i} // 2`` -> ``outC{i}``.
    All convolutions are 3x3 with padding 1.
    """
    super(Dwtconv, self).__init__()
    first_conv_in = 4  # every branch's first conv takes 4 input channels
    for branch, out_ch in enumerate((outC1, outC2, outC3), start=1):
        mid_ch = out_ch // 2
        # Registration order per branch: dwt, conv_1, bn_1, relu_1, conv_2, bn_2, relu_2.
        setattr(self, 'dwt%d' % branch, DWTForward(J=1, wave='haar', mode='symmetric'))
        setattr(self, 'conv%d_1' % branch,
                nn.Conv2d(first_conv_in, mid_ch, kernel_size=3, padding=1))
        setattr(self, 'bn%d_1' % branch, nn.BatchNorm2d(mid_ch))
        setattr(self, 'relu%d_1' % branch, nn.ReLU())
        setattr(self, 'conv%d_2' % branch,
                nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1))
        setattr(self, 'bn%d_2' % branch, nn.BatchNorm2d(out_ch))
        setattr(self, 'relu%d_2' % branch, nn.ReLU())
def __init__(self, J=3, wave='haar', window_size=15, stride=8, verbose=False):
    """Configure the wavelet-domain (CW-SSIM style) similarity index.

    :param J: int, number of DWT decomposition levels.
    :param wave: wavelet passed to DWTForward; see
        https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
        for the list of available wavelets.
    :param window_size: int, size of the window in which each local CW-SSIM
        index is calculated.
    :param stride: int, how far the window moves in one step.
    :param verbose: bool, stored as ``self.verbose``.

    The local index inside the window is computed first. The window then
    strides through every image scale to produce a list of local indices,
    each computed with formula (9) of reference [2]. The index returned by
    ``self.forward`` is the mean of those local indices.
    """
    super().__init__()
    self.J = J
    self.window_size = window_size
    self.stride = stride
    self.dwt = DWTForward(J=J, wave=wave)
    self.verbose = verbose
def test_equal_double(wave, J, mode):
    """Double-precision round trip and PyWavelets agreement for DWTForward/DWTInverse.

    All comparisons use ``PREC_DBL``. Previously the lowpass check alone used a
    hard-coded ``decimal=7``, inconsistent with the highpass checks.
    """
    with set_double_precision():
        x = torch.randn(5, 4, 64, 64).to(dev)
        assert x.dtype == torch.float64
        dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
        iwt = DWTInverse(wave=wave, mode=mode).to(dev)
        yl, yh = dwt(x)
        x2 = iwt((yl, yh))

        # Test the forward and inverse worked
        np.testing.assert_array_almost_equal(
            x.cpu(), x2.detach().cpu(), decimal=PREC_DBL)

        # Same coefficients as PyWavelets (pywt detail levels are coarsest-first,
        # yh[j] is finest-first, hence coeffs[J - j]).
        coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
        np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_DBL)
        for j in range(J):
            for b in range(3):
                np.testing.assert_array_almost_equal(
                    coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_DBL)
def test_commutativity(wave, J, j):
    """Check the inverse DWT is additive across the subbands of level ``j``.

    Subband tensors are indexed along dim 2 of each highpass tensor
    (0 = LH, 1 = HL, 2 = HH). With ``a``, ``b`` and ``c`` standing for bands 0, 1
    and 2, the test asserts ``iwt(a) + iwt(b) == iwt(a + b)`` and
    ``iwt(b) + iwt(c) == iwt(b + c)``.
    """
    C = 3
    Y = torch.randn(4, C, 128, 128, requires_grad=True, device=dev)
    dwt = DWTForward(J=J, wave=wave).to(dev)
    iwt = DWTInverse(wave=wave).to(dev)
    coeffs = dwt(Y)
    coeffs_zero = dwt(torch.zeros_like(Y))
    src = coeffs[1][j]       # level-j highpass coefficients of Y
    dst = coeffs_zero[1][j]  # same tensor that coeffs_zero holds; edited in place

    # Only level-j LH nonzero -> a
    dst[:, :, 0] = src[:, :, 0]
    ya = iwt(coeffs_zero)
    # Also set level-j HL nonzero -> a + b
    dst[:, :, 1] = src[:, :, 1]
    yab = iwt(coeffs_zero)
    # Zero level-j LH again, leaving only HL -> b
    dst[:, :, 0] = torch.zeros_like(src[:, :, 0])
    yb = iwt(coeffs_zero)
    # Also set level-j HH nonzero -> b + c
    dst[:, :, 2] = src[:, :, 2]
    ybc = iwt(coeffs_zero)
    # Zero level-j HL, leaving only HH -> c
    dst[:, :, 1] = torch.zeros_like(src[:, :, 1])
    yc = iwt(coeffs_zero)

    np.testing.assert_array_almost_equal(
        (ya + yb).detach().cpu(), yab.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        (yc + yb).detach().cpu(), ybc.detach().cpu(), decimal=PREC_FLT)
def get_wavelets(x, device):
    """Tile a 3-level db4 DWT of ``x`` into a single mosaic tensor.

    The canvas is ``(batch, channels, rows, cols)`` with
    ``rows = nextPowerOf2(2 * h0)`` and ``cols = nextPowerOf2(2 * w0)``, where
    ``(h0, w0)`` is the finest band's spatial size. For level ``i`` (0 = finest),
    with offsets ``r0 = rows // 2**(i+1)`` and ``c0 = cols // 2**(i+1)``:
      * band 0 goes to the top-right block (rows from 0, cols from c0);
      * band 1 goes to the bottom-left block (rows from r0, cols from 0);
      * band 2 goes to the bottom-right block (rows from r0, cols from c0).
    The lowpass ``Yl`` is written last, into the top-left corner.
    """
    # Accepts all wave types available to PyWavelets
    xfm = DWTForward(J=3, mode='zero', wave='db4').to(device)
    Yl, Yh = xfm(x)
    n_batch, n_chan = x.shape[0], x.shape[1]
    rows = nextPowerOf2(Yh[0].shape[-2] * 2)
    cols = nextPowerOf2(Yh[0].shape[-1] * 2)
    wavelets = torch.zeros(n_batch, n_chan, rows, cols).to(device)

    # Yh is ordered finest level first.
    for level, band in enumerate(Yh):
        r0 = rows // 2 ** (level + 1)
        c0 = cols // 2 ** (level + 1)
        h, w = band.shape[-2], band.shape[-1]
        wavelets[:, :, :h, c0:c0 + w] = band[:, :, 0]
        wavelets[:, :, r0:r0 + h, :w] = band[:, :, 1]
        wavelets[:, :, r0:r0 + h, c0:c0 + w] = band[:, :, 2]

    # LL coefficients go in the top-left corner.
    wavelets[:, :, :Yl.shape[-2], :Yl.shape[-1]] = Yl
    return wavelets
def test_fwd(mode):
    """Gradient-check the analysis filter bank ``AFB2D`` in double precision."""
    with set_double_precision():
        x = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        # xfm is only used as a source of filter tensors.
        xfm = DWTForward(J=2).to(dev)
        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        inputs = (x, xfm.h0_row, xfm.h1_row, xfm.h0_col, xfm.h1_col, mode)
        gradcheck(AFB2D.apply, inputs, eps=EPS, atol=ATOL)
def __init__(self, opt):
    """Unpaired HR / LR / real-LR (+ weights) image dataset with a Haar DWT helper.

    :param opt: option dict. Keys read here: 'subset_file', 'phase',
        'dataroot_HR', 'dataroot_LR', 'dataroot_RealLR', 'dataroot_weights',
        'data_type'.
    :raises NotImplementedError: if a subset file is used together with 'dataroot_LR'.
    """
    super(LRHR_wavelet_Mixunpair_Dataset, self).__init__()
    self.opt = opt
    self.paths_LR = None
    self.paths_HR = None
    self.paths_RealLR = None
    # Previously unset on the subset branch, leaving the attribute missing there.
    self.paths_weights = None
    self.LR_env = None  # environment for lmdb
    self.HR_env = None
    self.RealLR_env = None
    self.weights_env = None

    if opt['subset_file'] is not None and opt['phase'] == 'train':
        # read image list from subset list txt
        with open(opt['subset_file']) as f:
            self.paths_HR = sorted(
                os.path.join(opt['dataroot_HR'], line.rstrip('\n')) for line in f)
        if opt['dataroot_LR'] is not None:
            raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
    else:
        # read image list from lmdb or image files
        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        # Each source keeps its own env. Previously the real-LR and weights envs
        # were both written to self.LR_env, clobbering the LR env.
        # NOTE(review): any reader that fetched real-LR/weights through
        # self.LR_env (lmdb mode) should use RealLR_env / weights_env instead.
        self.RealLR_env, self.paths_RealLR = util.get_image_paths(
            opt['data_type'], opt['dataroot_RealLR'])
        self.weights_env, self.paths_weights = util.get_image_paths(
            opt['data_type'], opt['dataroot_weights'])
    assert self.paths_HR, 'Error: HR path is empty.'
    # HR/LR counts are not checked against each other (presumably because the
    # dataset is unpaired).
    self.random_scale_list = [1]
    self.DWT2 = DWTForward(J=1, wave='haar', mode='zero')
def test_gradients_inv(wave, J, mode):
    """Backprop through the inverse DWT must equal the forward DWT with
    time-reversed filters."""
    wavelet = pywt.Wavelet(wave)
    analysis = (wavelet.dec_lo, wavelet.dec_hi)
    synthesis = (wavelet.dec_lo[::-1], wavelet.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=analysis, mode=mode).to(dev)
    iwt = DWTInverse(wave=synthesis, mode=mode).to(dev)

    # Push zeros through the forward transform only to learn the pyramid shapes.
    probe_lo, probe_hi = dwt(torch.zeros(5, 6, 128, 128).to(dev))

    # Random leaf coefficients whose gradients will be inspected.
    yl = torch.randn(*probe_lo.shape, requires_grad=True, device=dev)
    yh = [torch.randn(*probe_hi[level].shape, requires_grad=True, device=dev)
          for level in range(J)]
    y = iwt((yl, yh))

    upstream = torch.randn(*y.shape, device=dev)
    y.backward(upstream, retain_graph=True)
    expected_lo, expected_hi = dwt(upstream)

    # Lowpass gradient
    np.testing.assert_array_almost_equal(
        yl.grad.detach().cpu(), expected_lo.cpu(), decimal=PREC_FLT)
    # Bandpass gradients, level by level
    for level in range(J):
        np.testing.assert_array_almost_equal(
            yh[level].grad.detach().cpu(), expected_hi[level].cpu(), decimal=PREC_FLT)
def __init__(self, C, F, lp_size=3, bp_sizes=(1, ), q=1.0, wave='db2',
             mode='zero', xfm=True, ifm=True):
    """Wavelet-domain gain block: optional DWT -> shrink -> WaveGainLayer -> optional IDWT.

    :param C: input channels, stored and passed to WaveGainLayer.
    :param F: output channels, stored and passed to WaveGainLayer.
    :param lp_size: lowpass gain size for WaveGainLayer.
    :param bp_sizes: bandpass gain sizes; ``len(bp_sizes)`` sets the DWT depth ``J``.
    :param q: a negative value selects ``ReLUWaveCoeffs`` as ``self.shrink``;
        otherwise ``self.shrink`` is the identity.
    :param wave: wavelet for the DWT/IDWT.
    :param mode: padding mode for the DWT/IDWT.
    :param xfm: if False, ``self.XFM`` is the identity instead of a DWTForward.
    :param ifm: if False, ``self.IFM`` is the identity instead of a DWTInverse.
    """
    super().__init__()
    self.C = C
    self.F = F
    self.J = len(bp_sizes)
    self.XFM = DWTForward(J=self.J, mode=mode, wave=wave) if xfm else (lambda x: x)
    self.shrink = ReLUWaveCoeffs() if q < 0 else (lambda x: x)
    self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes)
    self.IFM = DWTInverse(mode=mode, wave=wave) if ifm else (lambda x: x)
def test_gradients_fwd(wave, J, mode):
    """Backprop through the forward DWT must equal the inverse DWT with
    time-reversed filters."""
    im = np.random.randn(5, 6, 128, 128).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    wavelet = pywt.Wavelet(wave)
    analysis = (wavelet.dec_lo, wavelet.dec_hi)
    synthesis = (wavelet.dec_lo[::-1], wavelet.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=analysis, mode=mode).to(dev)
    iwt = DWTInverse(wave=synthesis, mode=mode).to(dev)
    yl, yh = dwt(imt)

    # Lowpass: input gradient == synthesis of (upstream lowpass, zero bands)
    lo_grad = torch.randn(*yl.shape, device=dev)
    yl.backward(lo_grad, retain_graph=True)
    zero_bands = [torch.zeros_like(yh[level]) for level in range(J)]
    expected = iwt((lo_grad, zero_bands))
    np.testing.assert_array_almost_equal(
        imt.grad.detach().cpu(), expected.cpu(), decimal=PREC_FLT)

    # Bandpass: one level at a time, all other coefficients zero
    for level, coeff in enumerate(yh):
        imt.grad.zero_()
        hi_grad = torch.randn(*coeff.shape, device=dev)
        coeff.backward(hi_grad, retain_graph=True)
        bands = list(zero_bands)
        bands[level] = hi_grad
        expected = iwt((torch.zeros_like(yl), bands))
        np.testing.assert_array_almost_equal(
            imt.grad.detach().cpu(), expected.cpu(), decimal=PREC_FLT)
def __init__(self, J=3, wave='db4', device='cpu', dtype=torch.float64):
    """Build periodised DWT/IDWT modules while the default dtype is ``dtype``.

    The global default dtype is switched to ``dtype`` only while the two modules
    are constructed (presumably so their internal tensors are created in that
    dtype). It is always restored, even if construction raises; previously an
    exception left the process-wide default dtype changed.

    :param J: number of decomposition levels.
    :param wave: wavelet name.
    :param device: device the modules are moved to.
    :param dtype: default dtype in effect during construction.
    """
    super().__init__()
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        self.dwt = DWTForward(J=J, wave=wave, mode='per').to(device)
        self.idwt = DWTInverse(wave=wave, mode='per').to(device)
    finally:
        torch.set_default_dtype(previous_dtype)
    self.norm_bound = 1.
def test_ok(wave, J, mode):
    """Round-trip a random batch and require every output tensor to be contiguous."""
    dwt = DWTForward(J=J, wave=wave, mode=mode)
    iwt = DWTInverse(wave=wave, mode=mode)
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt, iwt = dwt.to(dev), iwt.to(dev)

    lowpass, highpass = dwt(x)
    recon = iwt((lowpass, highpass))

    # Non-contiguous outputs have caused data errors before, so check each one.
    assert lowpass.is_contiguous()
    for level in range(J):
        assert highpass[level].is_contiguous()
    assert recon.is_contiguous()
def wavelets(self, x):
    """Single-level Haar DWT of ``x`` with each channel's four subbands stacked.

    For input channel ``c``, output channel ``4c`` holds the LL band and
    ``4c+1 .. 4c+3`` hold the three highpass bands in DWTForward's order.
    For 3-channel input this matches the old fixed layout
    (LL at 0/4/8, bands at 1:4, 5:8, 9:12). It now also works for any channel
    count, for odd spatial sizes (the old ``int(H / 2)`` buffer mismatched the
    DWT output there), and on the input's own device instead of forcing CUDA.

    :param x: tensor ``(N, C, H, W)``.
    :returns: tensor ``(N, 4 * C, h, w)`` where ``(h, w)`` is the DWT subband size.
    """
    xfm = DWTForward(J=1, wave='haar', mode='symmetric').to(x.device)
    Yl, Yh = xfm(x)
    # Yl: (N, C, h, w); Yh[0]: (N, C, 3, h, w). Put LL in front as band 0,
    # then flatten (C, 4) -> 4 * C so each channel's bands are adjacent.
    bands = torch.cat((Yl.unsqueeze(2), Yh[0]), dim=2)
    n, c, k, h, w = bands.shape
    return bands.reshape(n, c * k, h, w)
def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
    """Initialise wavelet-domain gains from a Xavier-initialised ``k x k`` kernel.

    A ``(F, C, k, k)`` kernel is Xavier-initialised and decomposed with a J-level
    DWT. ``self.gl`` takes the lowpass coefficients and ``self.gh`` the coarsest
    highpass level ``yh[J - 1]``.

    Supported ``(k, J)``: (4, 1), (8, 1), (8, 2), (8, 3). Any other combination
    raises ``NotImplementedError``.

    ``self.pad`` is ``(k//2 - 1, k//2, k//2 - 1, k//2)`` when ``right`` is true
    and the mirror image otherwise.

    :param wd: stored as ``self.wd``.
    :param wd1: stored as ``self.wd1``; defaults to ``wd``.
    """
    super().__init__()
    self.wd = wd
    self.wd1 = wd if wd1 is None else wd1
    self.C = C
    self.F = F
    kernel = torch.zeros(F, C, k, k)
    torch.nn.init.xavier_uniform_(kernel)
    xfm = DWTForward(J=J)
    self.ifm = DWTInverse()
    yl, yh = xfm(kernel)
    self.J = J

    if (k, J) not in ((4, 1), (8, 1), (8, 2), (8, 3)):
        raise NotImplementedError

    coarsest = yh[J - 1]
    self.gl = nn.Parameter(torch.zeros_like(yl))
    self.gh = nn.Parameter(torch.zeros_like(coarsest))
    self.gl.data = yl.data
    self.gh.data = coarsest.data

    near, far = k // 2 - 1, k // 2
    self.pad = (near, far, near, far) if right else (far, near, far, near)
def __init__(self, in_planes, lifting_size, kernel_size, no_bottleneck,
             share_weights, simple_lifting, regu_details, regu_approx):
    """Haar ('db1') wavelet stage followed by a BottleneckBlock.

    Only ``in_planes``, ``no_bottleneck`` and ``share_weights`` are used here; the
    remaining parameters are accepted for interface compatibility.
    """
    super(Haar, self).__init__()
    from pytorch_wavelets import DWTForward
    # Single-level Haar DWT, placed on CUDA unconditionally.
    self.wavelet = DWTForward(J=1, mode='zero', wave='db1').cuda()
    self.share_weights = share_weights

    if no_bottleneck:
        # Same width in and out (the original note says BottleneckBlock then
        # applies BN + ReLU without a conv - not verifiable here).
        block_in, block_out = in_planes * 1, in_planes * 1
    else:
        block_in, block_out = in_planes * 4, in_planes * 2
    # Attribute name "bootleneck" (sic) kept: other code presumably refers to it.
    self.bootleneck = BottleneckBlock(block_in, block_out)
def init_dwt(resume=None, shape=None, wave=None, colors=None):
    """Build CUDA DWT/IDWT modules and the starting coefficient list ``Ys``.

    :param resume: selects how ``Ys`` is initialised:
        * ``None``: random normal tensors shaped like the DWT of ``zeros(shape)``;
        * path to a jpg/png/tif/bmp image: coefficients from ``img2dwt``, with
          ``xfm`` rebuilt for that image's size;
        * any other existing file path: tensors loaded via ``torch.load``
          (this unpickles - only load trusted snapshots);
        * anything else: an iterable of tensors, each moved to CUDA.
    :param shape: input shape. ``shape[2:]`` sets the default decomposition depth,
        and it is read unconditionally, so it is required in every case.
    :param wave: wavelet name for DWTForward/DWTInverse.
    :param colors: forwarded to ``img2dwt``.
    :returns: ``(Ys, xfm, ifm, size)``; ``size`` is the loaded image's first two
        dimensions, or ``None``.
    :raises FileNotFoundError: if ``resume`` is a string that is not an existing file.
    """
    size = None
    # Depth = max level of a db1 packet tree of this spatial size (only .maxlevel is used).
    wp_fake = pywt.WaveletPacket2D(data=np.zeros(shape[2:]), wavelet='db1', mode='symmetric')
    xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
    ifm = DWTInverse(wave=wave, mode='symmetric').cuda()

    if resume is None:
        # random init
        Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
        Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
    elif isinstance(resume, str):
        if not os.path.isfile(resume):
            # Previously printed and called exit(), which terminated the whole
            # process with exit status 0 from inside a library function.
            raise FileNotFoundError('Snapshot not found: {}'.format(resume))
        ext = os.path.splitext(resume)[1].lower()[1:]
        if ext in ['jpg', 'png', 'tif', 'bmp']:
            img_in = imread(resume)
            Ys = img2dwt(img_in, wave=wave, colors=colors)
            print(' loaded image', resume, img_in.shape, 'level', len(Ys) - 1)
            size = img_in.shape[:2]
            # Rebuild the forward transform for the loaded image's size.
            wp_fake = pywt.WaveletPacket2D(data=np.zeros(size), wavelet='db1', mode='symmetric')
            xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
        else:
            Ys = torch.load(resume)
            Ys = [y.detach().cuda() for y in Ys]
    else:
        Ys = [y.cuda() for y in resume]
    return Ys, xfm, ifm, size
def filter_wavelet(self, x, norm=True):
    """Split ``x`` into its Haar lowpass band and its highpass information.

    :param x: batch ``(N, C, H, W)``.
    :param norm: if True, scale LL by 0.5 and map each highpass band ``h`` to
        ``0.5 * h + 0.5``.
    :returns: ``(LL, high)``. ``high`` is the mean of the three highpass bands,
        shape ``(N, C, h, w)``, when ``self.cs == 'sum'``. It is their channel
        concatenation, shape ``(N, 3C, h, w)``, when ``self.cs == 'cat'``.
    :raises NotImplementedError: for any other ``self.cs`` (previously returned None).
    """
    # Reuse the DWT module registered in __init__ (presumably moved along with
    # the model); only build one, on x's device, if it is absent.
    dwt = getattr(self, 'DWT2', None)
    if dwt is None:
        dwt = DWTForward(J=1, wave='haar', mode='reflect').to(x.device)
    LL, Hc = dwt(x)
    # Hc[0] is (N, C, 3, h, w). Keep the batch dimension. The old code indexed
    # Hc[0][0, ...], which kept only the first sample and made the 'cat'
    # concatenation run along height instead of channels.
    LH, HL, HH = Hc[0].unbind(dim=2)
    if norm:
        LL = LL * 0.5
        LH, HL, HH = LH * 0.5 + 0.5, HL * 0.5 + 0.5, HH * 0.5 + 0.5
    if self.cs == 'sum':
        return LL, (LH + HL + HH) / 3.
    if self.cs == 'cat':
        return LL, torch.cat((LH, HL, HH), 1)
    raise NotImplementedError('Wavelet combination [{!r}] not recognized'.format(self.cs))
def __init__(
    self,
    num_levels: int = 1,
    wave: str = 'db2',
    mode: str = 'periodization',
    device: Optional[str] = None,
) -> None:
    """Store the wavelet settings and build matching forward/inverse transforms.

    :param num_levels: decomposition depth ``J`` of the forward transform.
    :param wave: wavelet name shared by both transforms.
    :param mode: signal-extension mode shared by both transforms.
    :param device: target device for both modules (``None`` leaves them in place).
    """
    self.num_levels, self.wave, self.mode = num_levels, wave, mode
    forward_tf = DWTForward(J=num_levels, wave=wave, mode=mode)
    self.xfm = forward_tf.to(device)
    inverse_tf = DWTInverse(wave=wave, mode=mode)
    self.ifm = inverse_tf.to(device)
def __init__(self, in_ch, out_ch=None, filter_sz=3, num_conv=3):
    """Wavelet CNN block: a Haar DWT followed by a stack of conv + ReLU layers.

    :param in_ch: channels of the input image.
    :param out_ch: width of every conv; defaults to ``4 * in_ch``.
    :param filter_sz: conv kernel size. Previously this was ignored and 3 was
        hard-coded; the default keeps that behaviour. Padding is
        ``filter_sz // 2``, which preserves spatial size for odd kernel sizes.
    :param num_conv: number of extra ``out_ch -> out_ch`` conv + ReLU layers
        after the first one.
    """
    super(WCNN, self).__init__()
    if out_ch is None:
        out_ch = 4 * in_ch
    self.DwT = DWTForward(J=1, wave='haar', mode='zero')
    padding = filter_sz // 2
    # 4 * input channels since DwT creates 4 outputs per image
    modules = [nn.Conv2d(4 * in_ch, out_ch, kernel_size=filter_sz, padding=padding),
               nn.ReLU()]
    for _ in range(num_conv):
        modules.append(nn.Conv2d(out_ch, out_ch, kernel_size=filter_sz, padding=padding))
        modules.append(nn.ReLU())
    self.conv = nn.Sequential(*modules)
def __init__(self, discriminator_size, perceptual_size=256, generator_weights=None):
    """Generator loss terms: L1, L2, perceptual and GAN, plus db1 DWT/IDWT helpers.

    :param discriminator_size: converted to int and passed to GANLoss as ``image_size``.
    :param perceptual_size: passed to PerceptualLoss.
    :param generator_weights: per-term weights, stored as ``self.generator_weights``.
        Defaults to ``[1.0, 1.0, 1.0, 0.5, 0.1]``. A fresh list is built for each
        instance; the former mutable default list was shared by all instances.
    """
    super().__init__()
    if generator_weights is None:
        generator_weights = [1.0, 1.0, 1.0, 0.5, 0.1]
    # l1/l2 loss
    self.l1_loss = nn.L1Loss()
    self.l2_loss = nn.MSELoss()
    # perceptual_loss
    self.perceptual_loss = PerceptualLoss(perceptual_size)
    # gan loss
    self.gan_loss = GANLoss(image_size=int(discriminator_size))
    self.generator_weights = generator_weights
    self.dwt = DWTForward(J=1, mode='zero', wave='db1')
    self.idwt = DWTInverse(mode="zero", wave="db1")
def img2dwt(img_in, wave='coif2', sharp=0.3, colors=1.):
    """Convert an image into a list of scaled DWT coefficient tensors on CUDA.

    :param img_in: image passed to ``un_rgb`` (with ``colors``).
    :param wave: wavelet for the symmetric-mode DWTForward.
    :param sharp: forwarded to ``dwt_scale``.
    :returns: ``[Yl, Yh_0, Yh_1, ...]``. Each highpass entry ``Ys[i]`` (``i >= 1``)
        is divided in place by ``dwt_scale(Ys, sharp)[i - 1]``.
    """
    image_t = un_rgb(img_in, colors=colors)
    with torch.no_grad():
        # Decomposition depth comes from a throwaway db1 packet tree of the
        # same spatial size; only its .maxlevel is used.
        packet_probe = pywt.WaveletPacket2D(
            data=np.zeros(image_t.shape[2:]), wavelet='db1', mode='zero')
        n_levels = packet_probe.maxlevel
        transform = DWTForward(J=n_levels, wave=wave, mode='symmetric').cuda()
        lowpass, highpass = transform(image_t.cuda())
        Ys = [lowpass, *highpass]
        scale = dwt_scale(Ys, sharp)
        for idx in range(1, len(Ys)):
            Ys[idx] /= scale[idx - 1]
    return Ys
def __init__(self, main_dir, device, transform=None, levels=2):
    """Dataset of ``main_dir/*/*.png`` images with a periodised bior3.3 DWT.

    :param main_dir: root directory; images are globbed one level below it.
    :param device: device the DWTForward module is moved to.
    :param transform: list of torchvision transforms, composed in order.
        Defaults to ``[Resize((768, 1024), interpolation=4), ToTensor()]``.
        A new list is built per instance; the former mutable default list was
        shared by all instances and aliased into ``Compose``.
    :param levels: DWT decomposition depth.
    """
    if transform is None:
        # NOTE(review): interpolation=4 is presumably PIL's BOX filter code;
        # newer torchvision prefers InterpolationMode values - verify.
        transform = [
            transforms.Resize((768, 1024), interpolation=4),
            transforms.ToTensor(),
        ]
    self.main_dir = main_dir
    self.transforms = transforms.Compose(transform)
    self.image_list = glob.glob(main_dir + "/*/*.png")
    self.levels = levels
    self.names = [os.path.basename(x) for x in self.image_list]
    self.xfm = DWTForward(J=self.levels, mode='periodization', wave='bior3.3').to(device)
def test_equal(wave, J, mode):
    """Float32 round trip and PyWavelets agreement for DWTForward/DWTInverse."""
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked (.cpu() needed when dev is a GPU)
    np.testing.assert_array_almost_equal(x.cpu(), x2.detach().cpu(), decimal=PREC_FLT)

    # Test it is the same as doing the PyWavelets wavedec with the same
    # padding mode (pywt detail levels are coarsest-first, yh[j] finest-first).
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_FLT)
def __init__(self, wt_type='DTCWT', biort='near_sym_b', qshift='qshift_b', J=5,
             wave='db3', mode='zero', device='cuda', requires_grad=True):
    """Hold a forward/inverse wavelet-transform pair on ``device``.

    :param wt_type: ``'DTCWT'`` (uses ``biort``/``qshift``) or ``'DWT'``
        (uses ``wave``/``mode``).
    :param J: decomposition depth of the forward transform.
    :param requires_grad: accepted but not used here.
    :raises ValueError: for any other ``wt_type``.
    """
    super().__init__()
    if wt_type == 'DWT':
        self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
        self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
    elif wt_type == 'DTCWT':
        self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
        self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
    else:
        raise ValueError('no such type of wavelet transform is supported')
    self.J = J
    self.wt_type = wt_type
def __init__(self, block, num_blocks, num_classes=10, in_C=12, wave='sym17', mode='append'):
    """ResNet classifier with a single-level wavelet front end.

    :param block: residual block class (must expose ``expansion``).
    :param num_blocks: number of blocks in each of the four stages.
    :param num_classes: output size of the final linear layer.
    :param in_C: input channels of the first 3x3 conv.
    :param wave: wavelet for the front-end DWT.
    :param mode: stored as ``self.FDmode``.
    """
    super(FDResNet, self).__init__()
    self.wave = wave
    # `Requirs_Grad` is passed verbatim; presumably a project-specific
    # DWTForward keyword - verify against the DWTForward in use.
    self.DWT = DWTForward(J=1, wave=self.wave, mode='symmetric', Requirs_Grad=True).cuda()
    self.FDmode = mode
    self.in_planes = 64
    self.conv1 = conv3x3(in_C, 64)
    self.bn1 = nn.BatchNorm2d(64)
    # (planes, stride) per stage; built in order because _make_layer updates self.in_planes.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for idx, (planes, stride) in enumerate(stage_specs):
        setattr(self, 'layer%d' % (idx + 1),
                self._make_layer(block, planes, num_blocks[idx], stride=stride))
    self.linear = nn.Linear(512 * block.expansion, num_classes)
def __init__(self, in_C=3, C_Number=10, wave='haar', mode='append', init_weights=True, batch_norm=True):
    """VGG ('E' config) classifier with a single-level wavelet front end.

    Linear/Conv2d layers from ``features`` and ``classifier`` are recorded in
    ``self.layers``. Each BatchNorm2d found there is reset to weight 1 / bias 0,
    and its affine parameters are recorded in ``self.bn_params``.

    :param in_C: input channels for ``make_layers``.
    :param C_Number: number of output classes.
    :param wave: wavelet for the front-end DWT.
    :param mode: stored as ``self.FDmode``.
    :param init_weights: call ``self._initialize_weights()`` at the end.
    :param batch_norm: forwarded to ``make_layers``.
    """
    super(FDVgg, self).__init__()
    self.wave = wave
    self.DWT = DWTForward(J=1, wave=self.wave, mode='symmetric', Requirs_Grad=True).cuda()
    self.FDmode = mode
    self.features = make_layers(in_C=in_C, cfg=cfg['E'], batch_norm=batch_norm)
    self.layers = []
    # Previously never initialised, so `self.bn_params += ...` raised
    # AttributeError as soon as a BatchNorm2d reached that branch.
    self.bn_params = []

    def _collect(modules):
        # Record Linear/Conv2d layers; reset BatchNorm2d affine params and keep them.
        for m in modules:
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                self.layers.append(m)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                self.bn_params += [m.weight, m.bias]

    _collect(self.features)
    self.classifier = nn.Sequential(nn.Dropout(), nn.Linear(512, 512), nn.ReLU(True),
                                    nn.Dropout(), nn.Linear(512, 512), nn.ReLU(True),
                                    nn.Linear(512, C_Number))
    _collect(self.classifier)
    if init_weights:
        self._initialize_weights()
def __init__(self, recursions=1, stride=1, kernel_size=5, wgan=False, highpass=True,
             D_arch='FSD', norm_layer='Instance', filter_type='gau', cs='cat'):
    """Frequency-separation discriminator.

    :param highpass: if True, install a high-frequency filter chosen by
        ``filter_type`` ('gau', 'avg_pool' or 'wavelet'); otherwise ``self.filter`` is None.
    :param cs: for the wavelet filter, 'cat' gives 9 input channels, anything else 3.
    :param D_arch: 'nld_s1', 'nld_s2' or 'fsd' (case-insensitive).
    :raises NotImplementedError: for an unknown ``filter_type`` or ``D_arch``.
    """
    super(Discriminator, self).__init__()
    self.wgan = wgan
    n_input_channel = 3

    if not highpass:
        self.filter = None
    else:
        ftype = filter_type.lower()
        if ftype in ('gau', 'avg_pool'):
            # The two pooling filters differ only in the `gaussian` flag.
            self.filter = FilterHigh(recursions=recursions, stride=stride,
                                     kernel_size=kernel_size, include_pad=False,
                                     gaussian=(ftype == 'gau'))
        elif ftype == 'wavelet':
            self.DWT2 = DWTForward(J=1, wave='haar', mode='reflect')
            self.filter = self.filter_wavelet
            self.cs = cs
            n_input_channel = 9 if self.cs == 'cat' else 3
        else:
            raise NotImplementedError(
                'Frequency Separation type [{:s}] not recognized'.format(filter_type))
        print('# FS type: {}, kernel size={}'.format(ftype, kernel_size))

    arch = D_arch.lower()
    if arch == 'nld_s1':
        self.net = NLayerDiscriminator(input_nc=n_input_channel, ndf=64, n_layers=2,
                                       norm_layer=norm_layer, stride=1)
        print('# Initializing NLayer-Discriminator-stride-1 with {} norm-layer'.format(
            norm_layer.upper()))
    elif arch == 'nld_s2':
        self.net = NLayerDiscriminator(input_nc=n_input_channel, ndf=64, n_layers=2,
                                       norm_layer=norm_layer, stride=2)
        print('#Initializing NLayer-Discriminator-stride-2 with {} norm-layer'.format(
            norm_layer.upper()))
    elif arch == 'fsd':
        self.net = DiscriminatorBasic(n_input_channels=n_input_channel, norm_layer=norm_layer)
        print('# Initializing FSSR-DiscriminatorBasic with {} norm layer'.format(norm_layer))
    else:
        raise NotImplementedError(
            'Discriminator architecture [{:s}] not recognized'.format(D_arch))
# Scratch script for the KADID-10k image folder: collects the top-level
# 3-character-stem PNGs and prints the base names of the PNGs in its
# subfolders. The DWT pre-computation it was written for is kept, disabled,
# at the bottom.
import glob
import os
import pickle
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from pytorch_wavelets import DWTForward, DWTInverse

IMAGES_DIR = "D:\\upla_sir_stuff\\kadid10k\\images\\"

# height=1024, width=768
levels = 2
xfm = DWTForward(J=levels, mode='periodization', wave='bior3.3')

# Top-level images named with 3 characters + ".png"; keep just those stems.
top_level_imgs = glob.glob(IMAGES_DIR + "???.png")
imgs_n = [path[-7:-4] for path in top_level_imgs]
# for stem in imgs_n:
#     os.mkdir(IMAGES_DIR + stem)

trans = transforms.ToTensor()
imgs = glob.glob(IMAGES_DIR + "*\\*.png")
names = [os.path.basename(path) for path in imgs]
print(names)

# Disabled: resize each image, run the 2-level DWT and pickle the highpass bands.
# for path in imgs:
#     img = Image.open(path)
#     img = img.resize((768, 1024), 3)
#     img = trans(img).unsqueeze(0)
#     Yl, Yh = xfm(img)
#     print(Yh[0].shape)
#     print(Yh[1].shape)
#     with open(path[0:-4] + "dwt.pkl", 'wb') as f:
#         pickle.dump(Yh, f)
# x = torch.stack([emboss_transform(y) for y in x], dim=0) # return x.to(DEVICE).requires_grad_() def greyscale(x): def grey_transform(x): return x.mean(0, keepdim=True) x = torch.stack([grey_transform(y) for y in x], dim=0) return x.to(DEVICE).requires_grad_() from pytorch_wavelets import DWTForward, DWTInverse xfm = DWTForward(J=3, mode='zero', wave='db1').cuda() def wav_kernel(x): def wav_transform(x): x_h, x_l = xfm(x.unsqueeze(0)) x_h = F.interpolate(x_h, size=256, mode='bilinear') x_l_0, x_l_1, x_l_2 = F.interpolate( x_l[0][:, :, 0, :], size=256, mode='bilinear'), F.interpolate(x_l[1][:, :, 0, :], size=256, mode='bilinear'), F.interpolate( x_l[2][:, :, 0, :], size=256, mode='bilinear') x_final = torch.cat(