def test_equal_oddshape(size):
    wave = 'db3'
    J = 3
    mode = 'symmetric'
    x = torch.randn(5, 4, *size).to(dev)
    dwt1 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt1 = DWTInverse(wave=wave, mode=mode).to(dev)
    dwt2 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt2 = DWTInverse(wave=wave, mode=mode).to(dev)

    yl1, yh1 = dwt1(x)
    x1 = iwt1((yl1, yh1))
    yl2, yh2 = dwt2(x)
    x2 = iwt2((yl2, yh2))

    # Test it is the same as doing the PyWavelets wavedec
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    X2 = pywt.waverec2(coeffs, wave, mode=mode)
    np.testing.assert_array_almost_equal(X2, x1.detach(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(X2, x2.detach(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl1.cpu(), coeffs[0], decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl2.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh1[j][:, :, b].cpu(), decimal=PREC_FLT)
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh2[j][:, :, b].cpu(), decimal=PREC_FLT)
def test_inv_j2(mode):
    with set_double_precision():
        low = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        high = torch.randn(1, 3, 3, 16, 16, device=dev, requires_grad=True)
        ifm = DWTInverse().to(dev)

    input = (low, high, ifm.g0_row, ifm.g1_row, ifm.g0_col, ifm.g1_col, mode)
    gradcheck(SFB2D.apply, input, eps=EPS, atol=ATOL)
def test_commutativity(wave, J, j):
    # Test the commutativity of the dwt
    C = 3
    Y = torch.randn(4, C, 128, 128, requires_grad=True, device=dev)
    dwt = DWTForward(J=J, wave=wave).to(dev)
    iwt = DWTInverse(wave=wave).to(dev)

    coeffs = dwt(Y)
    coeffs_zero = dwt(torch.zeros_like(Y))
    # Set level j LH to be nonzero
    coeffs_zero[1][j][:, :, 0] = coeffs[1][j][:, :, 0]
    ya = iwt(coeffs_zero)
    # Set level j HL to also be nonzero
    coeffs_zero[1][j][:, :, 1] = coeffs[1][j][:, :, 1]
    yab = iwt(coeffs_zero)
    # Set level j LH back to zero
    coeffs_zero[1][j][:, :, 0] = torch.zeros_like(coeffs[1][j][:, :, 0])
    yb = iwt(coeffs_zero)
    # Set level j HH to also be nonzero
    coeffs_zero[1][j][:, :, 2] = coeffs[1][j][:, :, 2]
    ybc = iwt(coeffs_zero)
    # Set level j HL back to zero
    coeffs_zero[1][j][:, :, 1] = torch.zeros_like(coeffs[1][j][:, :, 1])
    yc = iwt(coeffs_zero)

    np.testing.assert_array_almost_equal(
        (ya + yb).detach().cpu(), yab.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        (yc + yb).detach().cpu(), ybc.detach().cpu(), decimal=PREC_FLT)
def test_gradients_fwd(wave, J, mode):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, 128, 128).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    wave = pywt.Wavelet(wave)
    fwd_filts = (wave.dec_lo, wave.dec_hi)
    inv_filts = (wave.dec_lo[::-1], wave.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=fwd_filts, mode=mode).to(dev)
    iwt = DWTInverse(wave=inv_filts, mode=mode).to(dev)

    yl, yh = dwt(imt)

    # Test the lowpass
    ylg = torch.randn(*yl.shape, device=dev)
    yl.backward(ylg, retain_graph=True)
    zeros = [torch.zeros_like(yh[i]) for i in range(J)]
    ref = iwt((ylg, zeros))
    np.testing.assert_array_almost_equal(
        imt.grad.detach().cpu(), ref.cpu(), decimal=PREC_FLT)

    # Test the bandpass
    for j, y in enumerate(yh):
        imt.grad.zero_()
        g = torch.randn(*y.shape, device=dev)
        y.backward(g, retain_graph=True)
        hps = [zeros[i] for i in range(J)]
        hps[j] = g
        ref = iwt((torch.zeros_like(yl), hps))
        np.testing.assert_array_almost_equal(
            imt.grad.detach().cpu(), ref.cpu(), decimal=PREC_FLT)
def test_gradients_inv(wave, J, mode):
    """ Gradient of inverse function should be forward function with filters
    swapped """
    wave = pywt.Wavelet(wave)
    fwd_filts = (wave.dec_lo, wave.dec_hi)
    inv_filts = (wave.dec_lo[::-1], wave.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=fwd_filts, mode=mode).to(dev)
    iwt = DWTInverse(wave=inv_filts, mode=mode).to(dev)

    # Get the shape of the pyramid
    temp = torch.zeros(5, 6, 128, 128).to(dev)
    l, h = dwt(temp)

    # Create our inputs
    yl = torch.randn(*l.shape, requires_grad=True, device=dev)
    yh = [torch.randn(*h[i].shape, requires_grad=True, device=dev)
          for i in range(J)]
    y = iwt((yl, yh))

    # Test the gradients
    yg = torch.randn(*y.shape, device=dev)
    y.backward(yg, retain_graph=True)
    dyl, dyh = dwt(yg)

    # Test the lowpass
    np.testing.assert_array_almost_equal(
        yl.grad.detach().cpu(), dyl.cpu(), decimal=PREC_FLT)

    # Test the bandpass
    for j in range(J):
        np.testing.assert_array_almost_equal(
            yh[j].grad.detach().cpu(), dyh[j].cpu(), decimal=PREC_FLT)
def __init__(self, C, F, lp_size=3, bp_sizes=(1,), q=1.0, wave='db2',
             mode='zero', xfm=True, ifm=True):
    super().__init__()
    self.C = C
    self.F = F
    self.J = len(bp_sizes)

    if xfm:
        self.XFM = DWTForward(J=self.J, mode=mode, wave=wave)
    else:
        self.XFM = lambda x: x

    if q < 0:
        self.shrink = ReLUWaveCoeffs()
    else:
        self.shrink = lambda x: x

    self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes)

    if ifm:
        self.IFM = DWTInverse(mode=mode, wave=wave)
    else:
        self.IFM = lambda x: x
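# The forward pass of the layer above is not shown here. Below is a minimal
# sketch, assuming it simply chains the transform, the gain layer, the optional
# shrinkage, and the inverse transform; the method name and this exact ordering
# are assumptions, not the author's actual implementation.
def forward(self, x):
    coeffs = self.XFM(x)             # (yl, yh) pyramid, or x unchanged when xfm=False
    coeffs = self.GainLayer(coeffs)  # learned per-subband gains
    coeffs = self.shrink(coeffs)     # ReLUWaveCoeffs when q < 0, identity otherwise
    return self.IFM(coeffs)          # back to the pixel domain, or identity when ifm=False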
def test_equal_double(wave, J, mode):
    with set_double_precision():
        x = torch.randn(5, 4, 64, 64).to(dev)
        assert x.dtype == torch.float64
        dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
        iwt = DWTInverse(wave=wave, mode=mode).to(dev)

    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked
    np.testing.assert_array_almost_equal(
        x.cpu(), x2.detach().cpu(), decimal=PREC_DBL)

    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=7)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_DBL)
def __init__(self, J=3, wave='db4', device='cpu', dtype=torch.float64):
    super().__init__()
    _dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    self.dwt = DWTForward(J=J, wave=wave, mode='per').to(device)
    self.idwt = DWTInverse(wave=wave, mode='per').to(device)
    torch.set_default_dtype(_dtype)
    self.norm_bound = 1.
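# A standalone sanity check (not part of the class above), sketching why a norm
# bound of 1 is plausible here: with an orthogonal wavelet and 'per'
# (periodization) padding the 2-D DWT is an orthonormal transform, so the
# coefficient energy matches the image energy. Sizes below are illustrative.
import torch
from pytorch_wavelets import DWTForward

torch.set_default_dtype(torch.float64)
x = torch.randn(1, 1, 64, 64)
dwt = DWTForward(J=3, wave='db4', mode='per')
yl, yh = dwt(x)
coeff_energy = yl.square().sum() + sum(h.square().sum() for h in yh)
print(torch.allclose(x.square().sum(), coeff_energy))  # expected: True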
def test_ok(wave, J, mode):
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Outputs can sometimes come back non-contiguous, which causes data errors
    # downstream, so check them explicitly
    assert yl.is_contiguous()
    for j in range(J):
        assert yh[j].is_contiguous()
    assert x2.is_contiguous()
def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
    super().__init__()
    self.wd = wd
    if wd1 is None:
        self.wd1 = wd
    else:
        self.wd1 = wd1
    self.C = C
    self.F = F
    x = torch.zeros(F, C, k, k)
    torch.nn.init.xavier_uniform_(x)
    xfm = DWTForward(J=J)
    self.ifm = DWTInverse()
    yl, yh = xfm(x)
    self.J = J
    if k == 4 and J == 1:
        self.gl = nn.Parameter(torch.zeros_like(yl))
        self.gh = nn.Parameter(torch.zeros_like(yh[0]))
        self.gl.data = yl.data
        self.gh.data = yh[0].data
        if right:
            self.pad = (1, 2, 1, 2)
        else:
            self.pad = (2, 1, 2, 1)
    elif k == 8 and J == 1:
        self.gl = nn.Parameter(torch.zeros_like(yl))
        self.gh = nn.Parameter(torch.zeros_like(yh[0]))
        self.gl.data = yl.data
        self.gh.data = yh[0].data
        if right:
            self.pad = (3, 4, 3, 4)
        else:
            self.pad = (4, 3, 4, 3)
    elif k == 8 and J == 2:
        self.gl = nn.Parameter(torch.zeros_like(yl))
        self.gh = nn.Parameter(torch.zeros_like(yh[1]))
        self.gl.data = yl.data
        self.gh.data = yh[1].data
        if right:
            self.pad = (3, 4, 3, 4)
        else:
            self.pad = (4, 3, 4, 3)
    elif k == 8 and J == 3:
        self.gl = nn.Parameter(torch.zeros_like(yl))
        self.gh = nn.Parameter(torch.zeros_like(yh[2]))
        self.gl.data = yl.data
        self.gh.data = yh[2].data
        if right:
            self.pad = (3, 4, 3, 4)
        else:
            self.pad = (4, 3, 4, 3)
    else:
        raise NotImplementedError
def __init__(
    self,
    num_levels: int = 1,
    wave: str = 'db2',
    mode: str = 'periodization',
    device: Optional[str] = None,
) -> None:
    self.num_levels = num_levels
    self.wave = wave
    self.mode = mode
    self.xfm = DWTForward(J=num_levels, wave=wave, mode=mode).to(device)
    self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
def __init__(self, in_ch, internal_ch=None, filter_sz=3, num_conv=3):
    super(IWCNN, self).__init__()
    if internal_ch is None:
        internal_ch = in_ch
    self.IDwT = DWTInverse(wave='haar', mode='zero')

    modules = []
    for i in range(num_conv):
        modules.append(nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1))
        # modules.append(nn.BatchNorm2d(num_features=in_ch))
        modules.append(nn.ReLU())
    modules.append(nn.Conv2d(in_ch, internal_ch, kernel_size=3, padding=1))
    # modules.append(nn.BatchNorm2d(num_features=internal_ch))
    modules.append(nn.ReLU())
    self.conv = nn.Sequential(*modules)
def __init__(self, discriminator_size, perceptual_size=256,
             generator_weights=[1.0, 1.0, 1.0, 0.5, 0.1]):
    super().__init__()
    # l1/l2 loss
    self.l1_loss = nn.L1Loss()
    self.l2_loss = nn.MSELoss()
    # perceptual loss
    self.perceptual_loss = PerceptualLoss(perceptual_size)
    # gan loss
    self.gan_loss = GANLoss(image_size=int(discriminator_size))
    self.generator_weights = generator_weights
    self.dwt = DWTForward(J=1, mode='zero', wave='db1')
    self.idwt = DWTInverse(mode="zero", wave="db1")
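# How self.dwt / self.idwt enter the loss is not shown in this snippet. Below is
# a minimal sketch of one common pattern -- an L1 penalty on the single-level
# Haar subbands of prediction vs. target. The method name `wavelet_loss` and this
# particular combination are assumptions, not the author's actual loss.
def wavelet_loss(self, pred, target):
    pred_l, pred_h = self.dwt(pred)    # lowpass + list with one level of highpass bands
    tgt_l, tgt_h = self.dwt(target)
    loss = self.l1_loss(pred_l, tgt_l)
    for ph, th in zip(pred_h, tgt_h):
        loss = loss + self.l1_loss(ph, th)
    return loss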
def test_equal(wave, J, mode):
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked
    np.testing.assert_array_almost_equal(x.cpu(), x2.detach(), decimal=PREC_FLT)

    # Test it is the same as doing the PyWavelets wavedec with the same
    # padding mode
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J - j][b], yh[j][:, :, b].cpu(), decimal=PREC_FLT)
def __init__(self, wt_type='DTCWT', biort='near_sym_b', qshift='qshift_b',
             J=5, wave='db3', mode='zero', device='cuda', requires_grad=True):
    super().__init__()
    if wt_type == 'DTCWT':
        self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
        self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
    elif wt_type == 'DWT':
        self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
        self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
    else:
        raise ValueError('no such type of wavelet transform is supported')

    self.J = J
    self.wt_type = wt_type
def init_dwt(resume=None, shape=None, wave=None, colors=None):
    size = None
    wp_fake = pywt.WaveletPacket2D(data=np.zeros(shape[2:]), wavelet='db1', mode='symmetric')
    xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
    # xfm = DTCWTForward(J=lvl, biort='near_sym_b', qshift='qshift_b').cuda()  # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b']
    ifm = DWTInverse(wave=wave, mode='symmetric').cuda()  # symmetric zero periodization
    # ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()  # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b']
    if resume is None:  # random init
        Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
        Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
    elif isinstance(resume, str):
        if os.path.isfile(resume):
            if os.path.splitext(resume)[1].lower()[1:] in ['jpg', 'png', 'tif', 'bmp']:
                img_in = imread(resume)
                Ys = img2dwt(img_in, wave=wave, colors=colors)
                print(' loaded image', resume, img_in.shape, 'level', len(Ys) - 1)
                size = img_in.shape[:2]
                wp_fake = pywt.WaveletPacket2D(data=np.zeros(size), wavelet='db1', mode='symmetric')
                xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
            else:
                Ys = torch.load(resume)
                Ys = [y.detach().cuda() for y in Ys]
        else:
            print(' Snapshot not found:', resume)
            exit()
    else:
        Ys = [y.cuda() for y in resume]
    # print('level', len(Ys)-1, 'low freq', Ys[0].cpu().numpy().shape)
    return Ys, xfm, ifm, size
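# Hypothetical usage of init_dwt above -- the shape and wave values are
# illustrative only, and CUDA is assumed since init_dwt moves everything to the GPU.
Ys, xfm, ifm, size = init_dwt(shape=(1, 3, 256, 256), wave='db2')
# Ys holds [lowpass, highpass_level_0, ...]; feeding the (possibly optimised)
# coefficients back through the inverse transform yields an image batch.
img = ifm((Ys[0], Ys[1:]))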
def __init__(self):
    # Pay attention to the upscales
    super(MWCNN, self).__init__()
    self.DWT = DWTForward(J=1, wave="haar").cuda()
    self.IDWT = DWTInverse(wave="haar").cuda()

    # DMT1 operation
    # Because a DWT is applied to the input image first, the channel count becomes
    # 4x the original; the channel-preserving blocks below keep the feature maps
    # at a consistent size.
    self.conv_DMT1 = nn.Conv2d(in_channels=3 * 4, out_channels=160, kernel_size=3, stride=1, padding=1)
    self.bn_DMT1 = nn.BatchNorm2d(160, affine=True)
    self.relu_DMT1 = nn.ReLU()
    # IDMT1 (inverse transform)
    self.conv_IDMT1 = nn.Conv2d(in_channels=160, out_channels=3 * 4, kernel_size=3, stride=1, padding=1)
    # feature maps preserved
    self.blockDMT1 = self.make_layer(Block_of_DMT1, 3)

    # DMT2 operation
    self.conv_DMT2 = nn.Conv2d(in_channels=640, out_channels=256, kernel_size=3, stride=1, padding=1)
    self.bn_DMT2 = nn.BatchNorm2d(256, affine=True)
    self.relu_DMT2 = nn.ReLU()
    # IDMT2
    self.conv_IDMT2 = nn.Conv2d(in_channels=256, out_channels=640, kernel_size=3, stride=1, padding=1)
    self.bn_IDMT2 = nn.BatchNorm2d(640, affine=True)
    self.relu_IDMT2 = nn.ReLU()
    self.blockDMT2 = self.make_layer(Block_of_DMT2, 3)

    # DMT3 operation
    self.conv_DMT3 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, stride=1, padding=1)
    self.bn_DMT3 = nn.BatchNorm2d(256, affine=True)
    self.relu_DMT3 = nn.ReLU()
    # IDMT3
    self.conv_IDMT3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, stride=1, padding=1)
    self.bn_IDMT3 = nn.BatchNorm2d(1024, affine=True)
    self.relu_IDMT3 = nn.ReLU()
    self.blockDMT3 = self.make_layer(Block_of_DMT3, 3)

    self.try_conv1 = nn.Conv2d(160, 80, 3, 1, 1)
    # self.try_bn = nn.BatchNorm2d(20)
    self.try_conv2 = nn.Conv2d(80, 12, 3, 1, 1)
def perturb_img_bands_only(img, label, target, model, attack_type,
                           dataset=None, net_type=None, alpha=0, max_iters=20,
                           epsilon=8.0 / 255.0, random_restarts=False,
                           update_type='gradient', step_size=2.0 / 255.0,
                           bands='ll', lr=5e-2, odi=False, loss_type='xent',
                           class2idx=None):
    if bands:
        print('Using bands only:', bands)
    print('Loss fn:', loss_type)
    print('Update:', update_type)

    # idx2class = dict()
    # for key in class2idx:
    #     idx2class[str(class2idx[key])] = int(key)
    if dataset == 'imagenet':
        for i in range(len(target)):
            # print('before', target[i], idx2class[str(target[i].data.cpu().numpy())])
            target[i] = idx2class[str(target[i].data.cpu().numpy())]
            # print('after', target[i])
        target = target.cuda()
        print(target)

    if random_restarts:
        num_restarts = 20
    else:
        num_restarts = 1
    if odi:
        odi_steps = 2
    else:
        odi_steps = 0

    if dataset == 'multi-pie' and net_type == 'vggface2':
        epsilon = 8.0
        lr = 2.0
    if dataset == 'multi-pie' and net_type == 'vggface2-kernel-def':
        print('adjusted epsilon')
        epsilon = 8.0
        lr = 2.0

    odi_step_size = epsilon
    best_adv = None
    curr_pred = None
    switch_band = bands
    ifm = DWTInverse(mode='zero', wave='haar').cuda()
    xfm = DWTForward(J=1, mode='zero', wave='haar').cuda()
    # avg_Linf_error = 0

    for restart in range(num_restarts):
        random_init = random_restarts
        print('Restart {}'.format(restart + 1))
        adv = Variable(postprocess(img.clone().detach(), dataset, net_type).data,
                       requires_grad=True)

        if loss_type == 'xent':
            loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'margin':
            loss_fn = margin_loss

        if random_init:
            random_noise = torch.FloatTensor(*adv.shape).uniform_(-epsilon, epsilon).cuda()
            adv = Variable(adv.data + random_noise, requires_grad=True)

        for step in range(odi_steps + max_iters):
            LL, Y = xfm(adv)
            LH, HL, HH = torch.unbind(Y[0], dim=2)
            LL = Variable(LL.data, requires_grad=True)
            LH = Variable(LH.data, requires_grad=True)
            HL = Variable(HL.data, requires_grad=True)
            HH = Variable(HH.data, requires_grad=True)

            if bands == 'll':
                band_optim = optim.SGD([LL], lr=lr)
            elif bands == 'lh':
                band_optim = optim.SGD([LH], lr=lr)
            elif bands == 'hl':
                band_optim = optim.SGD([HL], lr=lr)
            elif bands == 'hh':
                band_optim = optim.SGD([HH], lr=lr)
            elif bands == 'high':
                band_optim = optim.SGD([LH, HL, HH], lr=lr)
            elif bands == 'all':
                band_optim = optim.SGD([LL, LH, HL, HH], lr=lr)
            band_optim.zero_grad()

            adv = Variable(ifm((LL, [torch.stack((LH, HL, HH), 2)])).data,
                           requires_grad=True)
            band_loss = loss_fn(
                model(preprocess(ifm((LL, [torch.stack((LH, HL, HH), 2)])),
                                 dataset, net_type)),
                target)
            band_loss.backward()

            temp = []
            if bands == 'll':
                for band in [LL]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = temp[0], LH, HL, HH
            elif bands == 'lh':
                for band in [LH]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = LL, temp[0], HL, HH
            elif bands == 'hl':
                for band in [HL]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = LL, LH, temp[0], HH
            elif bands == 'hh':
                for band in [HH]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = LL, LH, HL, temp[0]
            elif bands == 'high':
                for band in [LH, HL, HH]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = LL, temp[0], temp[1], temp[2]
            elif bands == 'all':
                for band in [LL, LH, HL, HH]:
                    band_eta = lr * band.grad.data.sign()
                    band = Variable(band.data + band_eta, requires_grad=True)
                    temp.append(band)
                LL, LH, HL, HH = temp[0], temp[1], temp[2], temp[3]

            adv = ifm((LL, [torch.stack((LH, HL, HH), 2)]))
            eta = torch.clamp(
                adv.data - postprocess(img.clone().detach(), dataset, net_type),
                -epsilon, epsilon)
            adv = Variable(
                postprocess(img.clone().detach(), dataset, net_type) + eta,
                requires_grad=True)
            if dataset == 'multi-pie' and net_type == 'vggface2':
                adv = Variable(torch.clamp(adv, 0, 255.0), requires_grad=True)
            elif dataset == 'multi-pie' and net_type == 'vggface2-kernel-def':
                adv = Variable(torch.clamp(adv, 0, 255.0), requires_grad=True)
            else:
                adv = Variable(torch.clamp(adv, 0, 1.0), requires_grad=True)

        with torch.no_grad():
            curr_pred = torch.max(model(preprocess(adv, dataset, net_type)), 1)[1]
            if best_adv is None:
                best_adv = adv.clone().detach()
            else:
                for i in range(len(curr_pred)):
                    if curr_pred[i].data != target[i].data:
                        best_adv[i] = adv[i]

    return preprocess(best_adv, dataset, net_type)
    netG.apply(weights_init)
else:
    netG.apply(weights_init)

# setup optimizer
if opt.optimizer == 'adam':
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.optimizer == 'sgd':
    optimizerG = optim.SGD(netG.parameters(), lr=opt.lr, momentum=0.9)

criterion_pre = nn.CrossEntropyLoss()
criterion_pre = criterion_pre.cuda(gpulist[0])

# create forward and inverse wavelet transforms
xfm = DWTForward(J=3, wave='db3', mode='symmetric').cuda(gpulist[0])
ifm = DWTInverse(wave='db3', mode='symmetric').cuda(gpulist[0])

# Penalty coef for conditional loss
dssimCoef = {0.05: 3.0, 0.1: 2.0, 0.2: 1.0, 0.3: 0.5}

#________________________________________________________________________________
#
## @brief conditional loss function with budget control
#
#________________________________________________________________________________
def conditionalBudgetLoss(perturbedX, x, outputLabel, targetLabel):
def __init__(self, opt):
    super(DSWN, self).__init__()
    self.DWT = DWTForward(J=1, wave='haar').cuda()
    self.IDWT = DWTInverse(wave='haar').cuda()
    # The generator is U shaped
    # Encoder
    self.E1 = Conv2dLayer(in_channels=3, out_channels=160, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.E2 = Conv2dLayer(in_channels=3 * 4, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.E3 = Conv2dLayer(in_channels=3 * 4 * 4, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.E4 = Conv2dLayer(in_channels=3 * 4 * 16, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    # Bottle neck
    self.BottleNeck = nn.Sequential(
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm))
    self.blockDMT11 = self.make_layer(_DCR_block, 320)
    self.blockDMT12 = self.make_layer(_DCR_block, 320)
    self.blockDMT13 = self.make_layer(_DCR_block, 320)
    self.blockDMT14 = self.make_layer(_DCR_block, 320)
    self.blockDMT21 = self.make_layer(_DCR_block, 512)
    # self.blockDMT22 = self.make_layer(_DCR_block, 512)
    # self.blockDMT23 = self.make_layer(_DCR_block, 512)
    # self.blockDMT24 = self.make_layer(_DCR_block, 512)
    self.blockDMT31 = self.make_layer(_DCR_block, 512)
    # self.blockDMT32 = self.make_layer(_DCR_block, 512)
    # self.blockDMT33 = self.make_layer(_DCR_block, 512)
    # self.blockDMT34 = self.make_layer(_DCR_block, 512)
    self.blockDMT41 = self.make_layer(_DCR_block, 256)
    # self.blockDMT42 = self.make_layer(_DCR_block, 256)
    # self.blockDMT43 = self.make_layer(_DCR_block, 256)
    # self.blockDMT44 = self.make_layer(_DCR_block, 256)
    # self.DRB11 = ResidualDenseBlock_5C(nf=320, gc=64)
    # self.DRB12 = ResidualDenseBlock_5C(nf=320, gc=64)
    # self.DRB21 = ResidualDenseBlock_5C(nf=512, gc=64)
    # self.DRB31 = ResidualDenseBlock_5C(nf=512, gc=64)
    # self.DRB41 = ResidualDenseBlock_5C(nf=256, gc=64)
    # Decoder
    self.D1 = Conv2dLayer(in_channels=256, out_channels=1024, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.D2 = Conv2dLayer(in_channels=512, out_channels=1024, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.D3 = Conv2dLayer(in_channels=512, out_channels=640, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='prelu', norm=opt.norm)
    self.D4 = Conv2dLayer(in_channels=320, out_channels=3, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, norm='none', activation='tanh')
    self.D5 = Conv2dLayer(in_channels=320, out_channels=3, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, norm='none', activation='tanh')
    # channel shuffle
    self.S1 = Conv2dLayer(in_channels=512, out_channels=512, kernel_size=1, stride=1,
                          padding=0, dilation=1, pad_type=opt.pad, activation='none', norm='none')
    self.S2 = Conv2dLayer(in_channels=512, out_channels=512, kernel_size=1, stride=1,
                          padding=0, dilation=1, pad_type=opt.pad, activation='none', norm='none')
    self.S3 = Conv2dLayer(in_channels=320, out_channels=320, kernel_size=1, stride=1,
                          padding=0, dilation=1, pad_type=opt.pad, activation='none', norm='none')
    self.S4 = Conv2dLayer(in_channels=320, out_channels=320, kernel_size=1, stride=1,
                          padding=0, dilation=1, groups=3 * 320, pad_type=opt.pad,
                          activation='none', norm='none')
def __init__(self, in_channels, out_channels, num_features, num_simdb,
             upscale_factor, act_type='prelu', norm_type=None):
    super(WSR, self).__init__()
    padding = 2
    self.num_features = num_features
    self.upscale_factor = upscale_factor
    self.num_steps = int(np.log2(self.upscale_factor) + 1)

    # LR feature extraction block
    self.conv_in = ConvBlock(in_channels, 4 * num_features, kernel_size=3,
                             act_type=act_type, norm_type=norm_type)
    self.feat_in = ConvBlock(4 * num_features, num_features, kernel_size=1,
                             act_type=act_type, norm_type=norm_type)

    # recurrent block
    self.rb = RecurrentBlock(num_features, num_simdb, act_type, norm_type)

    # reconstruction block
    self.conv_steps = nn.ModuleList([
        nn.Sequential(
            ConvBlock(num_features, num_features, kernel_size=3,
                      act_type=act_type, norm_type=norm_type),
            ConvBlock(num_features, out_channels, kernel_size=3,
                      act_type=None, norm_type=norm_type)),
        nn.Sequential(
            ConvBlock(num_features, num_features, kernel_size=3,
                      act_type=act_type, norm_type=norm_type),
            ConvBlock(num_features, out_channels * 3, kernel_size=3,
                      act_type=None, norm_type=norm_type))
    ])
    for step in range(2, self.num_steps):
        conv_step = nn.Sequential(
            DeconvBlock(num_features, num_features,
                        kernel_size=int(2 ** (step - 1) + 4),
                        stride=int(2 ** (step - 1)),
                        padding=padding, act_type=act_type, norm_type=norm_type),
            ConvBlock(num_features, out_channels * 3, kernel_size=3,
                      act_type=None, norm_type=norm_type))
        self.conv_steps.append(conv_step)

    # inverse wavelet transformation
    self.ifm = DWTInverse(wave='db1', mode='symmetric').eval()
    for k, v in self.ifm.named_parameters():
        v.requires_grad = False
Yl, Yh = DWTForward(J=J_spec, mode='periodization', wave='db3')(img)
print(Yl.shape, [h.shape for h in Yh])
imgLR = F.interpolate(img, scale_factor=.5)
LQYl, LQYh = DWTForward(J=J_spec - 1, mode='periodization', wave='db3')(imgLR)
print(LQYl.shape, [h.shape for h in LQYh])

for i in range(J_spec):
    smd = torch.sum(Yh[i], dim=2).cpu()
    save_img(smd, "high_%i.png" % (i,))
save_img(Yl, "lo.png")

'''
Following code reconstructs the image with different high passes cancelled out.
'''
for i in range(J_spec):
    corrupted_im = [y for y in Yh]
    corrupted_im[i] = torch.zeros_like(corrupted_im[i])
    im = DWTInverse(mode='periodization', wave='db3')((Yl, corrupted_im))
    save_img(im, "corrupt_%i.png" % (i,))
im = DWTInverse(mode='periodization', wave='db3')((torch.full_like(Yl, fill_value=torch.mean(Yl)), Yh))
save_img(im, "corrupt_im.png")

'''
Following code reconstructs a hybrid image with the first high pass from the HR
and the rest of the data from the LR.

highpass = [Yh[0]] + LQYh
im = DWTInverse(mode='periodization', wave='db3')((LQYl, highpass))
save_img(im, "hybrid_lrhr.png")
save_img(F.interpolate(imgLR, scale_factor=2), "upscaled.png")
'''
def __init__(self, opt):
    super(MWCNN, self).__init__()
    self.DWT = DWTForward(J=1, wave='haar').cuda()
    self.IDWT = DWTInverse(wave='haar').cuda()
    # The generator is U shaped
    # Encoder
    self.E1 = Conv2dLayer(in_channels=3 * 4, out_channels=160, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.E2 = Conv2dLayer(in_channels=640, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.E3 = Conv2dLayer(in_channels=1024, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.E4 = Conv2dLayer(in_channels=1024, out_channels=256, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    # Bottle neck
    self.BottleNeck = nn.Sequential(
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
        ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm))
    self.blockDMT1 = self.make_layer(Block_of_DMT1, 3)
    self.blockDMT2 = self.make_layer(Block_of_DMT2, 3)
    self.blockDMT3 = self.make_layer(Block_of_DMT3, 3)
    self.blockDMT4 = self.make_layer(Block_of_DMT4, 3)
    # Decoder
    self.D1 = Conv2dLayer(in_channels=256, out_channels=1024, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.D2 = Conv2dLayer(in_channels=256, out_channels=1024, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.D3 = Conv2dLayer(in_channels=256, out_channels=640, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, activation='relu', norm=opt.norm)
    self.D4 = Conv2dLayer(in_channels=160, out_channels=3 * 4, kernel_size=3, stride=1,
                          padding=1, dilation=1, pad_type=opt.pad, norm='none', activation='tanh')
def mix_bands(img_1, img_2, net, label_1, label_2, dataset, switch_band):
    """
    Parameters:
        img_1: 1x3xHxW RGB image
        img_2: 1x3xHxW RGB image
    """
    img_1 = torch.unsqueeze(img_1.clone(), 0)
    img_2 = torch.unsqueeze(img_2.clone(), 0)
    img_1 = postprocess(img_1, dataset)
    img_2 = postprocess(img_2, dataset)
    img_1_og = img_1.clone()
    img_2_og = img_2.clone()

    ifm = DWTInverse(mode='zero', wave='haar').cuda()
    xfm = DWTForward(J=1, mode='zero', wave='haar').cuda()

    fig, ax = plt.subplots(nrows=2, ncols=10, figsize=(20, 10))
    for i, img in enumerate([img_1, img_2]):
        LL, Y = xfm(img)
        LH, HL, HH = torch.unbind(Y[0], dim=2)
        ax[i, 0].imshow(LL[0].data.cpu().permute(1, 2, 0))
        ax[i, 1].imshow(10 * LH[0].data.cpu().permute(1, 2, 0) /
                        torch.max(LH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 2].imshow(10 * HL[0].data.cpu().permute(1, 2, 0) /
                        torch.max(HL[0].data.cpu().permute(1, 2, 0)))
        ax[i, 3].imshow(10 * HH[0].data.cpu().permute(1, 2, 0) /
                        torch.max(HH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 4].imshow(img[0].data.cpu().permute(1, 2, 0))
        ax[i, 0].set_title('LL' + '_' + str(i))
        ax[i, 1].set_title('LH' + '_' + str(i))
        ax[i, 2].set_title('HL' + '_' + str(i))
        ax[i, 3].set_title('HH' + '_' + str(i))
        ax[i, 4].set_title('normal_img' + '_' + str(i))
        for col in range(5):
            ax[i, col].set_yticks([], [])
            ax[i, col].set_xticks([], [])

    # Reconstruct using new components
    LL_1, Y_1 = xfm(img_1)
    LH_1, HL_1, HH_1 = torch.unbind(Y_1[0], dim=2)
    LL_2, Y_2 = xfm(img_2)
    LH_2, HL_2, HH_2 = torch.unbind(Y_2[0], dim=2)
    if switch_band == 'll':
        img_1 = ifm((LL_2, [torch.stack((LH_1, HL_1, HH_1), 2)]))
        img_2 = ifm((LL_1, [torch.stack((LH_2, HL_2, HH_2), 2)]))
    elif switch_band == 'lh':
        img_1 = ifm((LL_1, [torch.stack((LH_2, HL_1, HH_1), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_1, HL_2, HH_2), 2)]))
    elif switch_band == 'hl':
        img_1 = ifm((LL_1, [torch.stack((LH_1, HL_2, HH_1), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_2, HL_1, HH_2), 2)]))
    elif switch_band == 'hh':
        img_1 = ifm((LL_1, [torch.stack((LH_1, HL_1, HH_2), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_2, HL_2, HH_1), 2)]))
    elif switch_band == 'high':
        img_1 = ifm((LL_1, [torch.stack((LH_2, HL_2, HH_2), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_1, HL_1, HH_1), 2)]))

    epsilon = 0.031
    eta = torch.clamp(img_1.data - img_1_og, -epsilon, epsilon)
    img_1 = img_1_og + eta
    img_1 = torch.clamp(img_1, 0, 1.0)
    eta = torch.clamp(img_2.data - img_2_og, -epsilon, epsilon)
    img_2 = img_2_og + eta
    img_2 = torch.clamp(img_2, 0, 1.0)

    for i, img in enumerate([img_1, img_2]):
        LL, Y = xfm(img)
        LH, HL, HH = torch.unbind(Y[0], dim=2)
        ax[i, 0 + 5].imshow(LL[0].data.cpu().permute(1, 2, 0))
        ax[i, 1 + 5].imshow(10 * LH[0].data.cpu().permute(1, 2, 0) /
                            torch.max(LH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 2 + 5].imshow(10 * HL[0].data.cpu().permute(1, 2, 0) /
                            torch.max(HL[0].data.cpu().permute(1, 2, 0)))
        ax[i, 3 + 5].imshow(10 * HH[0].data.cpu().permute(1, 2, 0) /
                            torch.max(HH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 4 + 5].imshow(img[0].data.cpu().permute(1, 2, 0))
        ax[i, 0 + 5].set_title('LL' + '_' + str(i))
        ax[i, 1 + 5].set_title('LH' + '_' + str(i))
        ax[i, 2 + 5].set_title('HL' + '_' + str(i))
        ax[i, 3 + 5].set_title('HH' + '_' + str(i))
        ax[i, 4 + 5].set_title('perturbed_img' + '_' + str(i))
        for col in range(5, 10):
            ax[i, col].set_yticks([], [])
            ax[i, col].set_xticks([], [])

    plt.savefig('./imgs/mix.png')
    print(['*'] * 20)

    resize = transforms.Resize((32, 32))
    LL_1, Y_1 = xfm(preprocess(img_1, dataset))
    LH_2, HL_2, HH_2 = torch.unbind(Y_1[0], dim=2)
    LL_2, Y_2 = xfm(preprocess(img_2, dataset))
    LH_1, HL_1, HH_1 = torch.unbind(Y_2[0], dim=2)

    print('image 1 label:\t', label_1)
    print('perturbed label:\t', torch.max(net(preprocess(img_1, dataset)), 1)[1])
    print('print LH_1 label:\t', torch.max(net(F.interpolate(LH_1, 32)), 1)[1])
    print('print HL_1 label:\t', torch.max(net(F.interpolate(HL_1, 32)), 1)[1])
    print('print HH_1 label:\t', torch.max(net(F.interpolate(HH_1, 32)), 1)[1])
    print('image 2 label:\t', label_2)
    print('perturbed label:\t', torch.max(net(preprocess(img_2, dataset)), 1)[1])
    print('print LH_2 label:\t', torch.max(net(F.interpolate(LH_2, 32)), 1)[1])
    print('print HL_2 label:\t', torch.max(net(F.interpolate(HL_2, 32)), 1)[1])
    print('print HH_2 label:\t', torch.max(net(F.interpolate(HH_2, 32)), 1)[1])