def __init__(self, data_path, dct_n=25, split=0, sets=None, is_cuda=False, add_data=None):
    """Load one split, pair the sequences and precompute truncated DCT inputs.

    Args:
        data_path: Forwarded to ``load_data``.
        dct_n: Number of DCT columns kept (or padded to) per input sequence.
        split: Index into ``sets`` selecting which subset list is loaded.
        sets: List of subset lists; defaults to ``[[0, 1], [2], [3]]``.
        is_cuda: Forwarded to ``dtw_pairs``.
        add_data: Forwarded to ``load_data``.
    """
    if sets is None:
        sets = [[0, 1], [2], [3]]
    self.dct_n = dct_n

    correct, other = load_data(data_path, sets[split], add_data=add_data)
    pairs = dtw_pairs(correct, other, is_cuda=is_cuda)

    # Each pair is (key into `other`, key into `correct`).
    self.targets_label = [pair[1] for pair in pairs]
    self.inputs_label = [pair[0] for pair in pairs]
    self.targets = [correct[key] for key in self.targets_label]
    self.inputs_raw = [other[key] for key in self.inputs_label]

    def _to_dct(seq):
        # Long enough: transform, then keep the first dct_n columns.
        # Too short: zero-pad the last axis on the right up to dct_n columns
        # before transforming.
        as_tensor = torch.from_numpy(seq)
        if seq.shape[1] >= self.dct_n:
            return dct.dct_2d(as_tensor)[:, :self.dct_n].numpy()
        pad = torch.nn.ZeroPad2d((0, self.dct_n - seq.shape[1], 0, 0))
        return dct.dct_2d(pad(as_tensor)).numpy()

    self.inputs = [_to_dct(seq) for seq in self.inputs_raw]

    self.node_n = np.shape(self.inputs_raw[0])[0]
    self.batch_ids = list(range(len(self.inputs_raw)))
def get_dct_init(self, dim_out, dim_in, med_out, med_in, small_out, small_in):
    """Build initial DCT coefficients by two successive DCT-and-crop steps.

    A dense ``dim_out x dim_in`` matrix is drawn uniformly from
    ``[-1/sqrt(dim_out), 1/sqrt(dim_out)]``, transformed with an orthonormal
    2-D DCT and cropped to its leading ``med_out x med_in`` block; that block
    is transformed again and cropped to ``small_out x small_in``.

    Args:
        dim_out, dim_in: Shape of the dense matrix that is sampled.
        med_out, med_in: Shape kept after the first DCT.
        small_out, small_in: Shape kept after the second DCT.

    Returns:
        Tensor of shape ``[small_out, small_in]``.

    Raises:
        RuntimeError: from ``torch.reshape`` if a requested crop is larger
            than the matrix it is taken from.
    """
    factor = 1.
    # initialize as if dense matrix
    init = torch.rand([dim_out, dim_in])
    # NOTE(review): if `self` is an nn.Module, `self.cuda` is a bound method
    # and always truthy -- confirm this is a boolean flag.
    if self.cuda:  # TODO update to device.
        init = init.cuda()
    initrange = 1.0 / math.sqrt(dim_out)
    nn.init.uniform_(init, -initrange, initrange)

    # apply DCT
    init_f = dct.dct_2d(init, norm='ortho')
    # The original indexed with a (2, K) numpy array of all (i, j) pairs, which
    # only works through PyTorch's legacy "non-tuple sequence" indexing. The
    # pairs enumerate the leading block in row-major order, so a plain slice
    # selects exactly the same values.
    coeffs_init = init_f[:med_out, :med_in] * factor
    coeffs_init = torch.reshape(coeffs_init, [med_out, med_in])

    # apply further DCT
    coeffs_init = dct.dct_2d(coeffs_init, norm='ortho')
    coeffs_init = coeffs_init[:small_out, :small_in] * factor
    coeffs_init = torch.reshape(coeffs_init, [small_out, small_in])
    return coeffs_init
def test_idct_2d():
    """Check that ``idct_2d(dct_2d(x))`` reproduces ``x`` to within ``EPS``."""
    sizes = [2, 5, 32]
    for rows in sizes:
        for cols in sizes:
            x = np.random.normal(size=(1, rows, cols))
            spectrum = dct.dct_2d(torch.tensor(x))
            restored = dct.idct_2d(spectrum).numpy()
            worst = np.abs(x - restored).max()
            assert worst < EPS, x
def dct_cutoff_low(x, bandwidth):
    """Zero the leading ``bandwidth x bandwidth`` DCT block of ``x``.

    The input is transformed with an orthonormal 2-D DCT, the coefficients in
    ``[:bandwidth, :bandwidth]`` of the last two axes are set to zero, and the
    result is transformed back.

    Args:
        x: 2-D tensor, or a tensor with at least three dimensions (the mask
            indexes three axes). A 2-D input is treated as a batch of one.
        bandwidth: Side length of the zeroed coefficient block.

    Returns:
        The filtered tensor with all size-1 dimensions squeezed out.
    """
    if len(x.size()) == 2:
        # Out-of-place: the original `unsqueeze_` silently changed the shape
        # of the caller's tensor.
        x = x.unsqueeze(0)
    mask = torch.ones_like(x)
    mask[:, :bandwidth, :bandwidth] = 0
    spectrum = torch_dct.dct_2d(x, norm='ortho')
    return torch_dct.idct_2d(spectrum * mask, norm='ortho').squeeze_()
def forward(self, X):
    """Block-wise DCT followed by an even/odd regrouping of the coefficients.

    ``X`` is viewed as a stack of ``decV x decH`` blocks, each block gets an
    orthonormal 2-D DCT, and the coefficients of every block are concatenated
    in the order (even, even), (odd, odd), (odd, even), (even, odd).

    Returns:
        A single tensor when ``self.num_outputs < 2``; otherwise a ``map``
        iterator yielding one tensor per component.
    """
    stride = self.decimation_factor
    n_samples = X.size(0)
    n_components = self.num_outputs
    height, width = X.size(2), X.size(3)
    nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
    ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
    ndecs = stride[0] * stride[1]

    # Block DCT: (nSamples x nComponents x nrows x ncols) x decV x decH
    block_shape = stride.copy()
    block_shape.insert(0, -1)
    Y = dct.dct_2d(X.view(block_shape), norm='ortho')

    # Rearrange the DCT coefficients per block.
    n_blocks = Y.size(0)

    def _pick(row_start, col_start):
        return Y[:, row_start::2, col_start::2].reshape(n_blocks, -1)

    regrouped = torch.cat(
        (_pick(0, 0), _pick(1, 1), _pick(1, 0), _pick(0, 1)), dim=-1)
    Z = regrouped.view(n_samples, n_components, nrows, ncols, ndecs)

    if n_components < 2:
        return torch.squeeze(Z, dim=1)
    return map(lambda part: torch.squeeze(part, dim=1),
               torch.chunk(Z, n_components, dim=1))
def dct_low_pass(x, bandwidth):
    """Keep only the leading ``bandwidth x bandwidth`` DCT block of ``x``.

    The input is transformed with an orthonormal 2-D DCT, every coefficient
    outside ``[:bandwidth, :bandwidth]`` of the last two axes is set to zero,
    and the result is transformed back.

    Args:
        x: 2-D tensor, or a tensor with at least three dimensions (the mask
            indexes three axes). A 2-D input is treated as a batch of one.
        bandwidth: Side length of the retained coefficient block.

    Returns:
        The filtered tensor with all size-1 dimensions squeezed out.
    """
    if len(x.size()) == 2:
        # Out-of-place: the original `unsqueeze_` silently changed the shape
        # of the caller's tensor.
        x = x.unsqueeze(0)
    mask = torch.zeros_like(x)
    mask[:, :bandwidth, :bandwidth] = 1
    spectrum = torch_dct.dct_2d(x, norm='ortho')
    return torch_dct.idct_2d(spectrum * mask, norm='ortho').squeeze_()
def forward(self, x):
    """Return ``x`` plus a generated perturbation.

    Optionally the perturbation is computed on a bilinearly downsampled copy
    and/or in the DCT domain; it is always added to the full-resolution
    input, and the sum is clamped to [0, 1] when ``self.clamp`` is set.
    """
    interpolate = torch.nn.functional.interpolate

    if self.downsample_to:
        # Downsample, remembering the full-resolution input.
        x_orig = x
        target = (self.downsample_to, self.downsample_to)
        x = interpolate(x, size=target, mode='bilinear')

    net_input = x
    if self.frequency_domain:
        # The network sees DCT coefficients and emits a frequency-domain
        # perturbation; the DCT does not change the tensor shape.
        net_input = dct.dct_2d(net_input)

    y_pixels, features = self.basic_net(
        net_input, self.num_res_blocks, bound_multiplier=1)
    delta = self.get_delta(y_pixels)

    if self.frequency_domain:
        # Inverse DCT: back from the frequency domain.
        delta = dct.idct_2d(delta)

    if self.downsample_to:
        # Upsample the perturbation to the original spatial size.
        x = x_orig
        delta = interpolate(delta, size=x_orig.shape[-2:], mode='bilinear')

    # Additive perturbation
    result = x + delta
    return torch.clamp(result, 0, 1.0) if self.clamp else result
def __init__(self, originals: ep.Tensor, random_noise: str = "normal", basis_type: str = "dct", **kwargs: Any):
    """Set up the basis generator and its backend-specific DCT functions.

    Args:
        originals: Tensor wrapper whose ``.raw`` attribute is either a
            ``torch.Tensor`` or a ``numpy.ndarray``; it selects the DCT
            backend.
        random_noise (str, optional): Noise added when a basis is created,
            ``"normal"`` or ``"uniform"``. Defaults to "normal".
        basis_type (str, optional): Name of the basis; the generator method
            ``_get_vector_<basis_type>`` must exist on the instance.
            Defaults to "dct".
        kwargs: Forwarded unchanged to ``self._load_params``. According to
            the original documentation these carry the basis parameters:
            * Random: no parameters
            * DCT:
                * function (tanh / constant / linear): function applied on the dct
                * beta
                * gamma
                * frequence_range: tuple of 2 floats
                * dct_type: 8x8 or full

    Raises:
        TypeError: if ``originals.raw`` is neither a torch tensor nor a
            numpy array.
        AttributeError: if no ``_get_vector_<basis_type>`` method exists.
        AssertionError: if ``random_noise`` is not "normal" or "uniform".
    """
    self._originals = originals
    raw = self._originals.raw
    if isinstance(raw, torch.Tensor):
        self._f_dct2 = lambda a: torch_dct.dct_2d(a)
        self._f_idct2 = lambda a: torch_dct.idct_2d(a)
    elif isinstance(raw, np.ndarray):
        # The original tested `isinstance(v.raw, np.array)`: `v` is undefined
        # and `np.array` is a function, so this branch could never be taken.
        from scipy import fft
        # NOTE(review): this branch uses norm='ortho' while the torch branch
        # uses torch_dct's default norm -- confirm the scaling difference is
        # intended (each pair is still a forward/inverse match).
        self._f_dct2 = lambda a: fft.dct(fft.dct(a, axis=2, norm='ortho'), axis=3, norm='ortho')
        self._f_idct2 = lambda a: fft.idct(fft.idct(a, axis=2, norm='ortho'), axis=3, norm='ortho')
    else:
        raise TypeError(
            "originals.raw must be a torch.Tensor or numpy.ndarray, got "
            + type(raw).__name__)

    self.basis_type = basis_type
    self._function_generation = getattr(self, "_get_vector_" + self.basis_type)
    self._load_params(**kwargs)

    assert random_noise in ["normal", "uniform"]
    self.random_noise = random_noise
def dct_calculation(self, x):
    """Apply an orthonormal 2-D DCT to every ``B x B`` tile of ``x``.

    ``x`` is a 4-D tensor (N, C, H, W) and ``B`` is ``self.block_size``. The
    image is cut into non-overlapping tiles, each tile is transformed on its
    own, and the tiles are put back in their original positions, so the
    output has the same shape as the input.
    """
    n, c, h, w = x.size()
    b = self.block_size
    tiles_across = int(w / b)
    tiles_down = int(h / b)

    # (N, C, H, W) -> (N, C, tiles, B, B)
    tiles = x.reshape(n, c, tiles_down, b, tiles_across, b)
    tiles = tiles.permute(0, 1, 2, 4, 3, 5).reshape(n, c, -1, b, b)

    coeffs = torch_dct.dct_2d(tiles, norm='ortho')

    # (N, C, tiles, B, B) -> (N, C, H, W), undoing the permutation above.
    coeffs = coeffs.reshape(n, c, tiles_down, tiles_across, b, b)
    return coeffs.permute(0, 1, 2, 4, 3, 5).reshape(n, c, h, w)
def forward(self, x):
    """Low-pass ``x`` in the DCT domain, then run ``self.model`` on it.

    Only the leading ``self.fre x self.fre`` coefficients of the last two
    axes are kept before the inverse transform.
    """
    spectrum = dct.dct_2d(x)
    keep = torch.zeros_like(spectrum)
    keep[:, :, 0:self.fre, 0:self.fre] = 1
    filtered = dct.idct_2d(spectrum * keep)
    return self.model(filtered)
def dct_torch(self, x, h, w):
    """Resize ``x`` to ``(h, w)`` bilinearly and return its orthonormal 2-D DCT.

    The resized tensor is cast to float32 before the transform.
    """
    resized = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
    resized = resized.to(torch.float32)
    coeffs = torch_dct.dct_2d(resized, norm='ortho')
    return coeffs
def test_dct_2d():
    """Compare ``dct.dct_2d`` against two chained ``fftpack.dct`` (type 2) calls."""
    sizes = [2, 5, 32]
    for rows in sizes:
        for cols in sizes:
            x = np.random.normal(size=(1, rows, cols))
            # Reference: 1-D DCT along the last axis, then along axis 1.
            ref = fftpack.dct(fftpack.dct(x, axis=2, type=2), axis=1, type=2)
            act = dct.dct_2d(torch.tensor(x)).numpy()
            worst = np.abs(ref - act).max()
            assert worst < EPS, (ref, act)
def forward(self, x):
    """Return the first ``self.vec_dim`` selected DCT coefficients of ``x``.

    ``x`` must be 4-D (batch, channels, height, width). The coefficient
    coordinates come from ``self.get_dct_vector_coords(r=height)``; the
    output has shape [batch_size, channels, D].
    """
    _, _, height, _ = x.size()
    coords = self.get_dct_vector_coords(r=height)[:self.vec_dim]

    # Transform a copy so the input tensor itself is never handed to the DCT.
    spectrum = torch_dct.dct_2d(x.clone(), norm='ortho')

    rows, cols = coords[:, 0], coords[:, 1]
    return spectrum[:, :, rows, cols]  # [batch_size, channels, D]
def testPredictRgbColor(self, stride, height, width, datatype):
    """Check the block-DCT layer on a 3-component input under ``no_grad``.

    The expected values are computed inline with ``dct.dct_2d`` and
    ``permuteDctCoefs_``; the layer output must match per component, keep
    the dtype, and not require gradients.
    """
    rtol, atol = 1e-5, 1e-8

    # Parameters
    nSamples = 8
    nComponents = 3  # RGB
    # Source (nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols))
    X = torch.rand(nSamples, nComponents, height, width,
                   dtype=datatype, requires_grad=True)

    # Expected values
    nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
    ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
    ndecs = stride[0] * stride[1]
    # Block DCT (nSamples x nComponents x nrows x ncols) x decV x decH
    block_shape = stride.copy()
    block_shape.insert(0, -1)
    Y = dct.dct_2d(X.view(block_shape), norm='ortho')
    # Rearrange the DCT coefficients
    Z = permuteDctCoefs_(Y).view(nSamples, nComponents, nrows, ncols, ndecs)
    expected = (Z[:, 0, :, :, :], Z[:, 1, :, :, :], Z[:, 2, :, :, :])

    # Instantiation of target class
    layer = NsoltBlockDct2dLayer(decimation_factor=stride,
                                 number_of_components=nComponents,
                                 name='E0')

    # Actual values
    with torch.no_grad():
        actualZr, actualZg, actualZb = layer.forward(X)
    actual = (actualZr, actualZg, actualZb)

    # Evaluation (same grouping as before: dtypes, values, grad flags)
    for got in actual:
        self.assertEqual(got.dtype, datatype)
    for got, want in zip(actual, expected):
        self.assertTrue(torch.allclose(got, want, rtol=rtol, atol=atol))
    for got in actual:
        self.assertFalse(got.requires_grad)
def encode(self, masks, dim=None):
    """Encode masks as vectors of selected DCT coefficients.

    Args:
        masks: Tensor viewable as [N, mask_size, mask_size].
        dim: Number of coefficients to keep; ``self.vec_dim`` when None.

    Returns:
        Tensor of shape [N, D], computed in Python ``float`` (double)
        precision.
    """
    length = self.vec_dim if dim is None else dim
    coords = self.dct_vector_coords[:length]

    side = self.mask_size
    masks = masks.view([-1, side, side]).to(dtype=float)  # [N, H, W]
    spectrum = torch_dct.dct_2d(masks, norm='ortho')

    rows, cols = coords[:, 0], coords[:, 1]
    return spectrum[:, rows, cols]  # [N, D]
def get_dct_init(self, len_coeffs, dim_out, dim_in, diag_shift):
    """Sample a dense matrix and return part of its flipped DCT spectrum.

    A ``dim_out x dim_in`` matrix is drawn uniformly from
    ``[-1/sqrt(dim_out), 1/sqrt(dim_out)]``, transformed with an orthonormal
    2-D DCT and mirrored left-to-right. The entries on and above the diagonal
    offset ``diag_shift`` are returned as a 1-D tensor.

    ``len_coeffs`` is accepted but not used here.
    """
    factor = 1.
    init = torch.rand([dim_out, dim_in])
    if self.cuda:  # TODO update to device.
        init = init.cuda()

    bound = 1.0 / math.sqrt(dim_out)
    nn.init.uniform_(init, -bound, bound)

    flipped = torch.fliplr(dct.dct_2d(init, norm='ortho'))
    rows, cols = torch.triu_indices(dim_out, dim_in, diag_shift)
    return flipped[rows, cols] * factor
def count_dct_zeros(x, qf):
    """Return the batch mean of a soft per-image DCT coefficient count.

    For every 8x8 block, ``relu(2 * |DCT| - q)`` is computed against the
    floored quantisation table ``q_table(loch_matrix(), qf)``. Each entry
    contributes ``v / (v + 1e-5)``, i.e. about 1 where it is positive and 0
    otherwise; the contributions are summed per image and averaged.
    """
    # Quantisation table (see JPEG standard), floored.
    quant = torch.floor(q_table(loch_matrix(), qf))

    # Shift pixel values (as in JPEG).
    shifted = x * 255.0 - 128.0

    # Stack all 8x8 blocks, then put the block index before the block data.
    # NOTE(review): the view to (..., 8, 8) only fits when unfold yields 64
    # values per block, i.e. a single input channel -- confirm.
    patches = torch.nn.functional.unfold(shifted, 8, stride=8)
    blocks = torch.transpose(patches, 1, 2)
    blocks = blocks.view([patches.shape[0], patches.shape[-1], 8, 8])

    # DCT over the last two dimensions: [batch, n_blocks, 8, 8]
    coeffs = dct.dct_2d(blocks, norm='ortho')

    surplus = torch.relu(torch.abs(coeffs) * 2 - quant)
    per_image = (surplus / (surplus + 0.00001)).sum(dim=(1, 2, 3))
    return torch.mean(per_image)  # batch mean
def dct_encoder(self, imSub_list, Q, blocksize=8, thresh=0.05):
    """Block-DCT every channel and return raw, thresholded and quantised results.

    Args:
        imSub_list: Iterable of 2-D channels.
        Q: Per-channel quantisation divisors, indexed like ``imSub_list``.
        blocksize: Side length of the square blocks.
        thresh: Coefficients whose magnitude is not above ``thresh`` times
            the block maximum are zeroed in the thresholded output.

    Returns:
        Three lists (raw, thresholded, quantised), one entry per channel,
        each reassembled to the channel's shape with ``self.unblockshaped``.
    """
    TransAll_list, TransAllThresh_list, TransAllQuant_list = [], [], []

    for idx, channel in enumerate(imSub_list):
        rows, cols = channel.shape[0], channel.shape[1]
        plane = np.zeros((rows, cols), np.float32)
        plane[:rows, :cols] = channel

        blocks = self.blockshaped(plane, blocksize, blocksize)
        # NOTE(review): `blocks` looks like a numpy array, but it is passed to
        # torch_dct and the result is then used with numpy functions --
        # confirm what `blockshaped` returns.
        dct_blocks = torch_dct.dct_2d(blocks, norm='ortho')

        # Per-block maximum, shaped for broadcasting over each block.
        block_peak = np.amax(dct_blocks, axis=(1, 2))[:, np.newaxis, np.newaxis]
        thres_blocks = dct_blocks * (abs(dct_blocks) > thresh * block_peak)
        quant_blocks = np.round(thres_blocks / Q[idx])

        TransAll_list.append(self.unblockshaped(dct_blocks, rows, cols))
        TransAllThresh_list.append(self.unblockshaped(thres_blocks, rows, cols))
        TransAllQuant_list.append(self.unblockshaped(quant_blocks, rows, cols))

    return TransAll_list, TransAllThresh_list, TransAllQuant_list
def forward(self, img, valid=False):
    """Encode ``img`` with the pretrained encoder and apply ``self.model``.

    In spectral mode the image is normalised, DCT-transformed and
    standardised first. Unless ``valid`` is set or views are disabled, the
    image goes through ``self.viewmaker``; if that returns a tuple, one of
    its first two entries is chosen at random.
    """
    batch_size = img.size(0)

    if self.pretrain_config.data_params.spectral_domain:
        img = self.system.normalize(img)
        img = dct.dct_2d(img)
        img = (img - img.mean()) / img.std()

    if not valid and not self.config.optim_params.no_views:
        img = self.viewmaker(img)
        if type(img) == tuple:
            img = img[random.randint(0, 1)]

    needs_normalize = ('Expert' not in self.pretrain_config.system
                       and not self.pretrain_config.data_params.spectral_domain)
    if needs_normalize:
        img = self.system.normalize(img)

    if self.pretrain_config.model_params.resnet_small:
        layer = 5 if self.config.model_params.use_prepool else 6
        embs = self.encoder(img, layer=layer)
    else:
        embs = self.encoder(img)

    return self.model(embs.view(batch_size, -1))
def testBackwardGrayScale(self, stride, height, width, datatype):
    """Check the gradient of the block-IDCT layer for a single component.

    The expected input gradient is the permuted block DCT of the upstream
    gradient ``dLdZ``.
    """
    rtol, atol = 1e-3, 1e-6

    # Parameters
    nSamples = 8
    nComponents = 1
    nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
    ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
    nDecs = stride[0] * stride[1]

    # Source (nSamples x nRows x nCols x nDecs)
    X = torch.rand(nSamples, nrows, ncols, nDecs,
                   dtype=datatype, requires_grad=True)
    # nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols)
    dLdZ = torch.rand(nSamples, nComponents, height, width, dtype=datatype)

    # Expected values
    block_shape = stride.copy()
    block_shape.insert(0, -1)
    Y = dct.dct_2d(dLdZ.view(block_shape), norm='ortho')
    expctddLdX = permuteDctCoefs_(Y).view(nSamples, nrows, ncols, nDecs)

    # Instantiation of target class
    layer = NsoltBlockIdct2dLayer(decimation_factor=stride, name='E0~')

    # Actual values
    Z = layer.forward(X)
    Z.backward(dLdZ)
    actualdLdX = X.grad

    # Evaluation
    self.assertEqual(actualdLdX.dtype, datatype)
    values_match = torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol)
    self.assertTrue(values_match)
    self.assertTrue(Z.requires_grad)
def main(video_file, pose_dict, model):
    """Estimate poses from a video, classify the action and correct the motion.

    Returns:
        ``(poses_reshaped, poses_corrected, label, label_corr)``: the
        uniformised poses, the poses after adding the predicted correction,
        and the predicted class before and after correction.
    """
    is_cuda = torch.cuda.is_available()

    # ---- 3D pose estimation ----
    poses = main_VIBE(video_file, model)

    # ---- Skeleton uniformization ----
    poses_uniform = centralize_normalize_rotate_poses(poses, pose_dict)
    joints = list(range(15)) + [19, 21, 22, 24]
    selected = poses_uniform[:, :, joints]
    # Flatten the last two axes, then transpose so frames run along axis 1.
    poses_reshaped = selected.reshape(
        -1, selected.shape[1] * selected.shape[2]).T
    frames = poses_reshaped.shape[1]

    # ---- Input ----
    dct_n = 25
    if frames < dct_n:
        # Too few frames: zero-pad on the right before the transform.
        pad_short = torch.nn.ZeroPad2d((0, dct_n - frames, 0, 0))
        inputs = dct.dct_2d(pad_short(poses_reshaped))
    else:
        inputs = dct.dct_2d(poses_reshaped)[:, :dct_n]
    if is_cuda:
        inputs = inputs.cuda()

    # ---- Action recognition ----
    model_class = GCN_class()
    model_class.load_state_dict(
        torch.load('PoseCorrection/Data/model_class.pt'))
    if is_cuda:
        model_class.cuda()
    model_class.eval()
    with torch.no_grad():
        _, label = torch.max(model_class(inputs).data, 1)

    # ---- Motion correction ----
    model_corr = GCN_corr()
    model_corr.load_state_dict(torch.load('PoseCorrection/Data/model_corr.pt'))
    if is_cuda:
        model_corr.cuda()
    with torch.no_grad():
        model_corr.eval()
        deltas_dct, att = model_corr(inputs)
        if frames > dct_n:
            # Pad the coefficients back out to one per frame.
            pad_long = torch.nn.ZeroPad2d((0, frames - dct_n, 0, 0))
            deltas = dct.idct_2d(pad_long(deltas_dct).transpose(1, 2))
        else:
            deltas = dct.idct_2d(deltas_dct[:, :frames].transpose(1, 2))
    poses_corrected = poses_reshaped + deltas.squeeze().T

    # ---- Action recognition on the corrected input ----
    with torch.no_grad():
        _, label_corr = torch.max(model_class(inputs + deltas_dct).data, 1)

    return poses_reshaped, poses_corrected, label, label_corr
def to_spectral(x):
    """Return the 2-D DCT of ``x`` as computed by ``dct.dct_2d``."""
    spectrum = dct.dct_2d(x)
    return spectrum
def foolbox_attack(filter=None, filter_preserve='low', free_parm='eps', plot_num=None):
    """Run a foolbox attack on the model and report clean and robust accuracy.

    Args:
        filter: When truthy, the adversarial perturbation is masked in the DCT
            domain before being re-applied; the value is the cut-off index
            used for the mask. (The name shadows the builtin ``filter`` but is
            part of the signature.)
        filter_preserve: ``'low'`` keeps coefficients ``[:filter, :filter]``,
            ``'high'`` keeps ``[filter:, filter:]``.
            NOTE(review): any other value leaves ``mask`` undefined.
        free_parm: If it contains ``'eps'`` a list of epsilons is swept, else
            only 0.01; if it contains ``'step'`` a list of step counts is
            swept, else only ``args.iteration``.
        plot_num: When truthy, that many images are loaded and the function
            returns the raw perturbation of the first batch for the first
            epsilon instead of reporting accuracies.

    Returns:
        The perturbation tensor when ``plot_num`` is truthy, otherwise None
        (results are printed).
    """
    # get model.
    model = get_model()
    model = nn.DataParallel(model).to(device)
    model = model.eval()
    preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
    fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)
    if plot_num:
        # Plot mode: no parameter sweep, a single batch of `plot_num` images.
        free_parm = ''
        val_loader = get_val_loader(plot_num)
    else:
        # Load images.
        val_loader = get_val_loader(args.attack_batch_size)
    if 'eps' in free_parm:
        epsilons = [0.001, 0.003, 0.005, 0.008, 0.01, 0.1]
    else:
        epsilons = [0.01]
    if 'step' in free_parm:
        steps = [1, 5, 10, 30, 40, 50]
    else:
        steps = [args.iteration]
    for step in steps:
        # Adversarial attack.
        # NOTE(review): `attack` stays undefined for any other attack_type.
        if args.attack_type == 'LinfPGD':
            attack = LinfPGD(steps=step)
        elif args.attack_type == 'FGSM':
            attack = FGSM()
        clean_acc = 0.0
        for i, data in enumerate(val_loader, 0):
            # Samples (attack_batch_size * attack_epochs) images for adversarial attack.
            if i >= args.attack_epochs:
                break
            images, labels = data[0].to(device), data[1].to(device)
            if step == steps[0]:
                # Clean accuracy is only accumulated during the first sweep step.
                clean_acc += (get_acc(fmodel, images, labels)) / args.attack_epochs  # accumulate for attack epochs.
            _images, _labels = ep.astensors(images, labels)
            raw_advs, clipped_advs, success = attack(fmodel, _images, _labels, epsilons=epsilons)
            if plot_num:
                # Perturbation = adversarial example minus the clean image.
                grad = torch.from_numpy(raw_advs[0].numpy()).to(device) - images
                grad = grad.clone().detach_()
                return grad
            if filter:
                robust_accuracy = torch.empty(len(epsilons))
                for eps_id in range(len(epsilons)):
                    grad = torch.from_numpy(raw_advs[eps_id].numpy()).to(device) - images
                    grad = grad.clone().detach_()
                    # Mask the perturbation in the DCT domain, transform back.
                    freq = dct.dct_2d(grad)
                    if filter_preserve == 'low':
                        mask = torch.zeros(freq.size()).to(device)
                        mask[:, :, :filter, :filter] = 1
                    elif filter_preserve == 'high':
                        mask = torch.zeros(freq.size()).to(device)
                        mask[:, :, filter:, filter:] = 1
                    masked_freq = torch.mul(freq, mask)
                    new_grad = dct.idct_2d(masked_freq)
                    x_adv = torch.clamp(images + new_grad, 0, 1).detach_()
                    robust_accuracy[eps_id] = (get_acc(fmodel, x_adv, labels))
            else:
                robust_accuracy = 1 - success.float32().mean(axis=-1)
            # Running mean over the attack epochs.
            if i == 0:
                robust_acc = robust_accuracy / args.attack_epochs
            else:
                robust_acc += robust_accuracy / args.attack_epochs
        if step == steps[0]:
            print("sample size is : ", args.attack_batch_size * args.attack_epochs)
            print(f"clean accuracy: {clean_acc * 100:.1f} %")
            print(f"Model {args.model} robust accuracy for {args.attack_type} perturbations with")
        for eps, acc in zip(epsilons, robust_acc):
            print(f" Step {step}, Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %")
        print(' -------------------')
def forward(self, x):
    """Crop the orthonormal DCT spectrum of ``x`` and transform it back.

    Only the leading ``self.h x self.w`` coefficients of the last two axes
    are kept, so the output has those spatial dimensions.
    """
    spectrum = dct.dct_2d(x, 'ortho')
    cropped = spectrum[..., :self.h, :self.w]
    return dct.idct_2d(cropped, 'ortho')
def optimize(ux0, uy0, im1, im2, maxIter, lambda_r, to1, theta, device):
    """Iteratively estimate a 2-D flow field ``(ux, uy)`` between two images.

    Each iteration solves a per-pixel 2x2 linear system for ``(ux, uy)``,
    updates the auxiliary fields ``(vx, vy)`` by a division in the DCT domain
    and updates the accumulators ``(bx, by)``. Iteration stops when the
    relative L1 change of both components drops below ``to1`` or after
    ``maxIter`` iterations.

    Args:
        ux0, uy0: Initial flow used to build the right-hand side terms.
        im1, im2: 2-D image tensors of identical shape.
        maxIter: Maximum number of iterations.
        lambda_r: Weight of the ``G * G`` term in the DCT-domain divisor.
        to1: Convergence tolerance on the relative change.
        theta: Coupling weight between ``u`` and ``v``.
        device: Device on which all work tensors are created.

    Returns:
        ``(ux, uy)`` after the last iteration (all zeros if ``maxIter`` is 0).
    """
    eps = 0.0000001
    Ix, Iy = centralFiniteDifference(im1)
    It = im1 - im2
    a11 = Ix * Ix
    a12 = Ix * Iy
    a22 = Iy * Iy
    residual0 = It - Ix * ux0 - Iy * uy0
    t1 = Ix * residual0
    t2 = Iy * residual0

    h, w = im1.size()

    def _zeros():
        return torch.zeros((h, w)).to(device).double()

    vx, vy = _zeros(), _zeros()
    bx, by = _zeros(), _zeros()
    ux, uy = _zeros(), _zeros()

    X, Y = torch.meshgrid(torch.linspace(0, h - 1, h),
                          torch.linspace(0, w - 1, w))
    # Honour the `device` argument; the original called .cuda() unconditionally
    # and therefore failed on CPU-only machines.
    X, Y = X.to(device), Y.to(device)
    G = torch.cos(math.pi * X / h) + torch.cos(math.pi * Y / w) - 2

    # Loop-invariant terms.
    det = (a11 + theta) * (a22 + theta) - a12 * a12
    divisor = theta + lambda_r * G * G

    for i in range(maxIter):
        tempx = ux
        tempy = uy

        h1 = theta * (vx - bx) - t1
        h2 = theta * (vy - by) - t2
        ux = ((a22 + theta) * h1 - a12 * h2) / det
        uy = ((a11 + theta) * h2 - a12 * h1) / det

        vx = tdct.idct_2d(tdct.dct_2d(theta * (ux + bx)) / divisor)
        vy = tdct.idct_2d(tdct.dct_2d(theta * (uy + by)) / divisor)

        bx = bx + ux - vx
        by = by + uy - vy

        stopx = torch.sum(
            torch.abs(ux - tempx)) / (torch.sum(torch.abs(tempx)) + eps)
        stopy = torch.sum(
            torch.abs(uy - tempy)) / (torch.sum(torch.abs(tempy)) + eps)

        if stopx < to1 and stopy < to1:
            print('iterate {} times, stop due to converge to tolerance'.format(
                i))
            break
        # Checked inside the loop so that `i` is always defined and only one
        # of the two stop messages is printed.
        if i == maxIter - 1:
            print('iterate {} times, stop due to reach max iteration'.format(i))

    return ux, uy
def main():
    """Train a denoising network (DnCNN / MemNet / RIDNet) as selected by ``opt``.

    Per batch: Gaussian noise with a random per-image level is drawn,
    optionally filtered in the frequency domain (``opt.mask_train_noise``),
    optionally combined with frequency masking of the images
    (``opt.DCT_DOR``), and the network is trained so that
    ``noisy - model(noisy)`` matches the noise. With
    ``opt.DCTloss_weight == 1`` the same loss is added on the DCT of channel
    0. A checkpoint is written after every epoch.

    Raises:
        NotImplementedError: if ``opt.net_mode`` is not 'R', 'M' or 'D'.
    """
    # Load dataset
    print('Loading dataset ...\n')
    ## R, M, D ##
    if opt.net_mode in ('R', 'M', 'D'):
        dataset_train = Dataset(train=True,
                                aug_times=2,
                                grayscale=(opt.color != 1),
                                scales=True)
    else:
        # `NotImplemented` is a comparison sentinel, not an exception; calling
        # it raised a confusing TypeError instead of this message.
        raise NotImplementedError(
            'Supported networks: R (DnCNN), M (MemNet), D (RIDNet) only')
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model_channels = 1 + 2 * opt.color
    np.random.seed(0)
    torch.manual_seed(0)
    if opt.net_mode == 'R':
        print('** Creating DnCNN RL network: **')
        net = DnCNN_RL(channels=model_channels,
                       num_of_layers=opt.num_of_layers)
    elif opt.net_mode == 'M':
        print('** Creating MemNet network: **')
        net = MemNet(in_channels=model_channels,
                     channels=20,
                     num_memblock=6,
                     num_resblock=4)
    elif opt.net_mode == 'D':
        print('** Creating RIDNet network: **')
        net = RIDNet(in_channels=model_channels)
    print(net)
    net.apply(weights_init_kaiming)

    # Loss
    criterion = nn.MSELoss(size_average=False)

    # Move to GPU
    model = nn.DataParallel(net).cuda()
    criterion.cuda()
    print('Trainable parameters: ',
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    # Training
    noiseL_B = [0, opt.noise_max]
    train_loss_log = np.zeros(opt.epochs)

    for epoch in range(opt.epochs):
        # Learning rate: divided by 10 every `opt.milestone` epochs.
        factor = epoch // opt.milestone
        current_lr = opt.lr / (10.**factor)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('\nlearning rate %f' % current_lr)

        # Train
        t = time.time()
        for i, data in enumerate(loader_train, 0):
            # Training
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # ADD Noise
            img_train = data
            noise = torch.zeros(img_train.size())
            noise_level_train = torch.zeros(img_train.size())
            stdN = np.random.uniform(noiseL_B[0],
                                     noiseL_B[1],
                                     size=noise.size()[0])
            sizeN = noise[0, :, :, :].size()

            # Noise Level map preparation (each step)
            for n in range(noise.size()[0]):
                noise[n, :, :, :] = torch.FloatTensor(sizeN).normal_(
                    mean=0, std=stdN[n] / 255.)
                noise_level_value = stdN[n] / noiseL_B[1]
                noise_level_train[n, :, :, :] = torch.FloatTensor(
                    np.ones(sizeN))
                noise_level_train[n, :, :, :] = noise_level_train[
                    n, :, :, :] * noise_level_value
            noise_level_train = Variable(noise_level_train.cuda())

            # Modifying the frequency content of the added noise (Low or High only)
            # Only channel 0 of the noise is filtered.
            if opt.mask_train_noise in (1, 2):
                noise_mask = get_mask_low_high(w=sizeN[1],
                                               h=sizeN[2],
                                               radius_perc=0.5,
                                               mask_mode=opt.mask_train_noise)
                for n in range(noise.size()[0]):
                    noise_dct = dct(dct(noise[n, 0, :, :].data.numpy(),
                                        axis=0,
                                        norm='ortho'),
                                    axis=1,
                                    norm='ortho')
                    noise_dct = noise_dct * noise_mask
                    noise_numpy = idct(idct(noise_dct, axis=0, norm='ortho'),
                                       axis=1,
                                       norm='ortho')
                    noise[n, 0, :, :] = torch.from_numpy(noise_numpy)
            elif opt.mask_train_noise == 3:  # Brownian noise
                for n in range(noise.size()[0]):
                    noise_numpy = gaussian_filter(noise[n, 0, :, :].data.numpy(),
                                                  sigma=3)
                    noise[n, 0, :, :] = torch.from_numpy(noise_numpy)

            # DCT SFM
            if opt.DCT_DOR > 0:
                img_train_SFM = np.zeros(img_train.size(), dtype='float32')
                noise_SFM = np.zeros(noise.size(), dtype='float32')

                # Each image is masked with probability opt.DCT_DOR.
                dct_bool = np.random.choice([1, 0],
                                            size=(img_train.size()[0], ),
                                            p=[opt.DCT_DOR, 1 - opt.DCT_DOR])
                for img_idx in range(img_train.size()[0]):
                    if dct_bool[img_idx] == 1:
                        img_numpy, mask = random_drop(
                            img_train[img_idx, :, :, :].data.numpy(),
                            mode=opt.SFM_mode,
                            SFM_center_radius_perc=opt.SFM_rad_perc,
                            SFM_center_sigma_perc=opt.SFM_sigma_perc)
                        img_train_SFM[img_idx, 0, :, :] = img_numpy

                        # Apply the same frequency mask to the noise.
                        noise_dct = dct(dct(noise[img_idx,
                                                  0, :, :].data.numpy(),
                                            axis=0,
                                            norm='ortho'),
                                        axis=1,
                                        norm='ortho')
                        noise_dct = noise_dct * mask
                        noise_numpy = idct(idct(noise_dct,
                                                axis=0,
                                                norm='ortho'),
                                           axis=1,
                                           norm='ortho')
                        noise_SFM[img_idx, 0, :, :] = noise_numpy

                if opt.SFM_noise == 0:
                    imgn_train = torch.from_numpy(img_train_SFM) + noise
                elif opt.SFM_noise == 1:
                    imgn_train = torch.from_numpy(
                        img_train_SFM) + torch.from_numpy(noise_SFM)
                if opt.SFM_GT == 1:
                    img_train = torch.from_numpy(img_train_SFM)
            else:
                imgn_train = img_train + noise

            # Training step
            img_train, imgn_train = Variable(img_train.cuda()), Variable(
                imgn_train.cuda())
            noise = Variable(noise.cuda())
            out_train = model(imgn_train)
            OUT_NOISE = imgn_train - out_train
            loss = criterion(OUT_NOISE, noise) / (imgn_train.size()[0] * 2)

            if (opt.DCTloss_weight == 1):
                noise_DCT = torch.zeros(noise.size())
                OUT_NOISE_DCT = torch.zeros(noise.size())
                for img_idx in range(noise.size()[0]):
                    noise_DCT[img_idx, 0, :, :] = torch_dct.dct_2d(
                        noise[img_idx, 0, :, :])
                    OUT_NOISE_DCT[img_idx, 0, :, :] = torch_dct.dct_2d(
                        OUT_NOISE[img_idx, 0, :, :])
                noise_DCT, OUT_NOISE_DCT = Variable(
                    noise_DCT.cuda()), Variable(OUT_NOISE_DCT.cuda())
                # Computed once and reused for both the total loss and the log
                # (the original evaluated the criterion twice).
                loss_DCTcomponent = criterion(
                    OUT_NOISE_DCT, noise_DCT) / (imgn_train.size()[0] * 2)
                loss += loss_DCTcomponent

            loss.backward()
            optimizer.step()

            train_loss_log[epoch] += loss.item()

        train_loss_log[epoch] = train_loss_log[epoch] / len(loader_train)

        elapsed = time.time() - t
        if (opt.DCTloss_weight == 1):
            # Reports the DCT loss of the last batch of the epoch.
            print(
                'Epoch %d: loss=%.4f, lossDCT=%.4f, elapsed time (min):%.2f' %
                (epoch, train_loss_log[epoch], loss_DCTcomponent.item(),
                 elapsed / 60.))
        else:
            print('Epoch %d: loss=%.4f, elapsed time (min):%.2f' %
                  (epoch, train_loss_log[epoch], elapsed / 60.))

        model_name = get_model_name(opt)
        model_dir = os.path.join('saved_models', model_name)
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(model_dir, 'net_%d.pth' % (epoch)))
def dct_flip(x):
    """Reverse the orthonormal DCT spectrum of ``x`` along its last two axes.

    The input is transformed, its coefficients are flipped along axes -2 and
    -1, and the result is transformed back.
    """
    spectrum = torch_dct.dct_2d(x, norm='ortho')
    mirrored = torch.flip(spectrum, [-2, -1])
    return torch_dct.idct_2d(mirrored, norm='ortho')
def testBackwardRgbColor(self, stride, height, width, datatype):
    """Check the gradients of the block-IDCT layer for three components.

    The expected per-component input gradients are the permuted block DCT of
    the upstream gradient ``dLdZ``, split along the component axis.
    """
    rtol, atol = 1e-3, 1e-6

    # Parameters
    nSamples = 8
    nComponents = 3  # RGB
    nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
    ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
    nDecs = stride[0] * stride[1]

    # Source (nSamples x nRows x nCols x nDecs), one tensor per component
    Xr = torch.rand(nSamples, nrows, ncols, nDecs,
                    dtype=datatype, requires_grad=True)
    Xg = torch.rand(nSamples, nrows, ncols, nDecs,
                    dtype=datatype, requires_grad=True)
    Xb = torch.rand(nSamples, nrows, ncols, nDecs,
                    dtype=datatype, requires_grad=True)
    # nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols)
    dLdZ = torch.rand(nSamples, nComponents, height, width, dtype=datatype)

    # Expected values
    block_shape = stride.copy()
    block_shape.insert(0, -1)
    Y = dct.dct_2d(dLdZ.view(block_shape), norm='ortho')
    Z = permuteDctCoefs_(Y).view(nSamples, nComponents, nrows, ncols, nDecs)
    expctddLdXr, expctddLdXg, expctddLdXb = map(
        lambda part: torch.squeeze(part, dim=1),
        torch.chunk(Z, nComponents, dim=1))
    expected = (expctddLdXr, expctddLdXg, expctddLdXb)

    # Instantiation of target class
    layer = NsoltBlockIdct2dLayer(decimation_factor=stride,
                                  number_of_components=nComponents,
                                  name='E0~')

    # Actual values
    Z = layer.forward(Xr, Xg, Xb)
    Z.backward(dLdZ)
    actual = (Xr.grad, Xg.grad, Xb.grad)

    # Evaluation (same grouping as before: dtypes first, then values)
    for got in actual:
        self.assertEqual(got.dtype, datatype)
    for got, want in zip(actual, expected):
        self.assertTrue(torch.allclose(got, want, rtol=rtol, atol=atol))
    self.assertTrue(Z.requires_grad)