    def __init__(self,
                 data_path,
                 dct_n=25,
                 split=0,
                 sets=None,
                 is_cuda=False,
                 add_data=None):
        if sets is None:
            sets = [[0, 1], [2], [3]]

        self.dct_n = dct_n

        correct, other = load_data(data_path, sets[split], add_data=add_data)
        pairs = dtw_pairs(correct, other, is_cuda=is_cuda)

        self.targets_label = [i[1] for i in pairs]
        self.inputs_label = [i[0] for i in pairs]
        self.targets = [correct[i] for i in self.targets_label]
        self.inputs_raw = [other[i] for i in self.inputs_label]
        self.inputs = [
            dct.dct_2d(torch.from_numpy(x))[:, :self.dct_n].numpy()
            if x.shape[1] >= self.dct_n else dct.dct_2d(
                torch.nn.ZeroPad2d((0, self.dct_n - x.shape[1], 0,
                                    0))(torch.from_numpy(x))).numpy()
            for x in self.inputs_raw
        ]

        self.node_n = np.shape(self.inputs_raw[0])[0]
        self.batch_ids = list(range(len(self.inputs_raw)))
Example #2
    def get_dct_init(self, dim_out, dim_in, med_out, med_in, small_out,
                     small_in):

        factor = 1.
        # initialize as if dense matrix
        init = torch.rand([dim_out, dim_in])
        if self.cuda:  # TODO update to device.
            init = init.cuda()
        initrange = 1.0 / math.sqrt(dim_out)
        # initrange = 0.1
        nn.init.uniform_(init, -initrange, initrange)

        # apply DCT
        init_f = dct.dct_2d(init, norm='ortho')
        ind = np.array([[i, j] for i in range(med_out)
                        for j in range(med_in)]).transpose()
        coeffs_init = init_f[tuple(ind)] * factor
        coeffs_init = torch.reshape(coeffs_init, [med_out, med_in])

        # apply further DCT
        coeffs_init = dct.dct_2d(coeffs_init, norm='ortho')
        ind = np.array([[i, j] for i in range(small_out)
                        for j in range(small_in)]).transpose()
        coeffs_init = coeffs_init[tuple(ind)] * factor
        coeffs_init = torch.reshape(coeffs_init, [small_out, small_in])

        return coeffs_init
Example #3
def test_idct_2d():
    for N1 in [2, 5, 32]:
        for N2 in [2, 5, 32]:
            x = np.random.normal(size=(1, N1, N2))
            X = dct.dct_2d(torch.tensor(x))
            y = dct.idct_2d(X).numpy()
            assert np.abs(x - y).max() < EPS, x
Example #4
def dct_cutoff_low(x, bandwidth):
    if len(x.size()) == 2:
        x.unsqueeze_(0)

    mask = torch.ones_like(x)
    mask[:, :bandwidth, :bandwidth] = 0
    return torch_dct.idct_2d(torch_dct.dct_2d(x, norm='ortho') * mask, norm='ortho').squeeze_()
Example #5
    def forward(self, X):
        nComponents = self.num_outputs
        nSamples = X.size(0)
        height = X.size(2)
        width = X.size(3)
        stride = self.decimation_factor
        nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
        ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
        ndecs = stride[0] * stride[1]  #math.prod(stride)
        # Block DCT (nSamples x nComponents x nrows x ncols) x decV x decH
        arrayshape = stride.copy()
        arrayshape.insert(0, -1)
        Y = dct.dct_2d(X.view(arrayshape), norm='ortho')
        # Rearrange the DCT Coefs. (nSamples x nComponents x nrows x ncols) x (decV x decH)
        cee = Y[:, 0::2, 0::2].reshape(Y.size(0), -1)
        coo = Y[:, 1::2, 1::2].reshape(Y.size(0), -1)
        coe = Y[:, 1::2, 0::2].reshape(Y.size(0), -1)
        ceo = Y[:, 0::2, 1::2].reshape(Y.size(0), -1)
        A = torch.cat((cee, coo, coe, ceo), dim=-1)
        Z = A.view(nSamples, nComponents, nrows, ncols, ndecs)

        if nComponents < 2:
            return torch.squeeze(Z, dim=1)
        else:
            return map(lambda x: torch.squeeze(x, dim=1),
                       torch.chunk(Z, nComponents, dim=1))
Example #6
def dct_low_pass(x, bandwidth):
    if len(x.size()) == 2:
        x.unsqueeze_(0)

    mask = torch.zeros_like(x)
    mask[:, :bandwidth, :bandwidth] = 1
    return torch_dct.idct_2d(torch_dct.dct_2d(x, norm='ortho') * mask, norm='ortho').squeeze_()
Example #7
    def forward(self, x):
        if self.downsample_to:
            # Downsample.
            x_orig = x
            x = torch.nn.functional.interpolate(
                x, size=(self.downsample_to, self.downsample_to), mode='bilinear')
        y = x
        
        if self.frequency_domain:
            # Input to viewmaker is in frequency domain, outputs frequency domain perturbation.
            # Uses the Discrete Cosine Transform.
            # shape still [batch_size, C, W, H]
            y = dct.dct_2d(y)

        y_pixels, features = self.basic_net(y, self.num_res_blocks, bound_multiplier=1)
        delta = self.get_delta(y_pixels)
        if self.frequency_domain:
            # Compute inverse DCT from frequency domain to time domain.
            delta = dct.idct_2d(delta)
        if self.downsample_to:
            # Upsample.
            x = x_orig
            delta = torch.nn.functional.interpolate(delta, size=x_orig.shape[-2:], mode='bilinear')

        # Additive perturbation
        result = x + delta
        if self.clamp:
            result = torch.clamp(result, 0, 1.0)

        return result
Example #8
    def __init__(self, originals: ep.Tensor, random_noise: str = "normal", basis_type: str = "dct", **kwargs : Any):
        """
        Args:
            random_noise (str, optional): When the basis is created, noise is added. This noise can be normal or
                                          uniform. Defaults to "normal".
            basis_type (str, optional): Type of the basis: DCT, Random, Genetic. Defaults to "dct".
            device (int, optional): [description]. Defaults to -1.
            args, kwargs: In args and kwargs, there is the basis params:
                    * Random: No parameters                    
                    * DCT:
                            * function (tanh / constant / linear): function applied on the dct
                            * beta
                            * gamma
                            * frequence_range: tuple of 2 float
                            * dct_type: 8x8 or full
        """
        self._originals = originals
        if isinstance(self._originals.raw, torch.Tensor):
            self._f_dct2 = lambda a: torch_dct.dct_2d(a)
            self._f_idct2 = lambda a: torch_dct.idct_2d(a)
        elif isinstance(self._originals.raw, np.ndarray):
            from scipy import fft
            self._f_dct2 = lambda a: fft.dct(fft.dct(a, axis=2, norm='ortho' ), axis=3, norm='ortho')
            self._f_idct2 = lambda a: fft.idct(fft.idct(a, axis=2, norm='ortho'), axis=3, norm='ortho')

        self.basis_type = basis_type
        self._function_generation = getattr(self, "_get_vector_" + self.basis_type)
        self._load_params(**kwargs)

        assert random_noise in ["normal", "uniform"]
        self.random_noise = random_noise
Example #9
    def dct_calculation(self, x):
        N, C, H, W = x.size()
        B = self.block_size
        # # 2. The chrominance channels Cr and Cb are subsampled
        # # if self.sub_sampling == '4:2:0':
        # #     imSub = self.subsample_chrominance(x, 2, 2)
        # # 3. Get the quatisation matrices, which will be applied to the DCT coefficients
        # Q = self.quality_factorize(self.QY, self.QC, self.quality_factor)
        # # 4. Apply DCT algorithm for orignal image
        # TransAll, TransAllThresh ,TransAllQuant = self.dct_encoder(x, Q, self.block_size, self.thresh)
        # # 5. Split the same frequency in each 8x8 blocks to the same channel
        # dct_list = self.split_frequency(TransAll, self.block_size)
        # # 6. upsample the Cr & Cb channel to concatenate with Y channel
        # # dct_coefficients = self.upsample(dct_list)
        
        blocksV = int(H / B)  # number of blocks along the vertical axis
        blocksH = int(W / B)  # number of blocks along the horizontal axis
        # vis0 = torch.zeros_like(x).cuda()
        # for row in range(blocksV):
        #     for col in range(blocksH):
        #         currentblock = torch_dct.dct_2d(x[:,:, row*B:(row+1)*B, col*B:(col+1)*B], norm='ortho')
        #         vis0[:,:, row*B:(row+1)*B, col*B:(col+1)*B] = currentblock

        # vis0 = torch_dct.dct_2d(x, norm='ortho')

        block = x.reshape(N, C, blocksV, B, blocksH, B)
        block1 = block.permute(0, 1, 2, 4, 3, 5)
        block2 = block1.reshape(N, C, -1, B, B)
        dct_block = torch_dct.dct_2d(block2, norm='ortho')
        block3 = dct_block.reshape(N, C, blocksV, blocksH, B, B)
        block4 = block3.permute(0, 1, 2, 4, 3, 5)
        block5 = block4.reshape(N, C, H, W)

        return block5
Example #10
 def forward(self, x):
     X = dct.dct_2d(x)
     mask = torch.zeros_like(X)
     mask[:, :, 0:self.fre, 0:self.fre] = 1
     out = dct.idct_2d(X * mask)
     out = self.model(out)
     return out
Example #11
    def dct_torch(self, x, h, w):
        out = F.interpolate(x,
                            size=(h, w),
                            mode='bilinear',
                            align_corners=True)
        out = out.to(torch.float32)

        return torch_dct.dct_2d(out, norm='ortho')
Example #12
def test_dct_2d():
    for N1 in [2, 5, 32]:
        for N2 in [2, 5, 32]:
            x = np.random.normal(size=(1, N1, N2))
            ref = fftpack.dct(x, axis=2, type=2)
            ref = fftpack.dct(ref, axis=1, type=2)
            act = dct.dct_2d(torch.tensor(x)).numpy()
            assert np.abs(ref - act).max() < EPS, (ref, act)
Example #13
 def forward(self, x):
     batch_size, channels, height, width = x.size()
     dct_vector_coords_all = self.get_dct_vector_coords(r=height)
     dct_vector_coords = dct_vector_coords_all[:self.vec_dim]
     # masks = x.to(dtype=float)
     masks = x.clone()
     dct_all = torch_dct.dct_2d(masks, norm='ortho')
     xs, ys = dct_vector_coords[:, 0], dct_vector_coords[:, 1]
     dct_vectors = dct_all[:, :, xs, ys]  # reshape as vector
     return dct_vectors  # [batch_size, channels, D]
Example #14
    def testPredictRgbColor(self, stride, height, width, datatype):
        rtol, atol = 1e-5, 1e-8

        # Parameters
        nSamples = 8
        nComponents = 3  # RGB
        # Source (nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols))
        X = torch.rand(nSamples,
                       nComponents,
                       height,
                       width,
                       dtype=datatype,
                       requires_grad=True)

        # Expected values
        nrows = int(math.ceil(height /
                              stride[Direction.VERTICAL]))  #.astype(int)
        ncols = int(math.ceil(width /
                              stride[Direction.HORIZONTAL]))  #.astype(int)
        ndecs = stride[0] * stride[1]  # math.prod(stride)

        # Block DCT (nSamples x nComponents x nrows x ncols) x decV x decH
        arrayshape = stride.copy()
        arrayshape.insert(0, -1)
        Y = dct.dct_2d(X.view(arrayshape), norm='ortho')
        # Rearrange the DCT Coefs. (nSamples x nComponents x nrows x ncols) x (decV x decH)
        A = permuteDctCoefs_(Y)
        Z = A.view(nSamples, nComponents, nrows, ncols, ndecs)
        expctdZr = Z[:, 0, :, :, :]
        expctdZg = Z[:, 1, :, :, :]
        expctdZb = Z[:, 2, :, :, :]

        # Instantiation of target class
        layer = NsoltBlockDct2dLayer(decimation_factor=stride,
                                     number_of_components=nComponents,
                                     name='E0')

        # Actual values
        with torch.no_grad():
            actualZr, actualZg, actualZb = layer.forward(X)

        # Evaluation
        self.assertEqual(actualZr.dtype, datatype)
        self.assertEqual(actualZg.dtype, datatype)
        self.assertEqual(actualZb.dtype, datatype)
        self.assertTrue(
            torch.allclose(actualZr, expctdZr, rtol=rtol, atol=atol))
        self.assertTrue(
            torch.allclose(actualZg, expctdZg, rtol=rtol, atol=atol))
        self.assertTrue(
            torch.allclose(actualZb, expctdZb, rtol=rtol, atol=atol))
        self.assertFalse(actualZr.requires_grad)
        self.assertFalse(actualZg.requires_grad)
        self.assertFalse(actualZb.requires_grad)
Example #15
 def encode(self, masks, dim=None):
     """
     Encode the mask to vector of vec_dim or specific dimention.
     """
     if dim is None:
         dct_vector_coords = self.dct_vector_coords[:self.vec_dim]
     else:
         dct_vector_coords = self.dct_vector_coords[:dim]
     masks = masks.view([-1, self.mask_size, self.mask_size]).to(dtype=float)  # [N, H, W]
     dct_all = torch_dct.dct_2d(masks, norm='ortho')
     xs, ys = dct_vector_coords[:, 0], dct_vector_coords[:, 1]
     dct_vectors = dct_all[:, xs, ys]  # reshape as vector
     return dct_vectors  # [N, D]
Example #16
    def get_dct_init(self, len_coeffs, dim_out, dim_in, diag_shift):

        factor = 1.
        init = torch.rand([dim_out, dim_in])
        if self.cuda:  # TODO update to device.
            init = init.cuda()

        initrange = 1.0 / math.sqrt(dim_out)
        nn.init.uniform_(init, -initrange, initrange)
        init_f = torch.fliplr(dct.dct_2d(init, norm='ortho'))
        ind = torch.triu_indices(dim_out, dim_in, diag_shift)
        coeffs_init = init_f[tuple(ind)] * factor

        return coeffs_init
Example #17
def count_dct_zeros(x, qf):
    # Define the Lohscheller luminance quantisation matrix (see JPEG standard)
    lc = loch_matrix()
    lc = torch.floor(q_table(lc, qf))
    # Shift pixel values (as in JPEG)
    aux = x * 255.0 - 128.0
    # Stack all 8x8 blocks of the image
    a = torch.nn.functional.unfold(aux, 8, stride=8)
    # Swap the spatial axes into the correct order to perform the DCT
    b = torch.transpose(a, 1, 2)
    b = b.view([a.shape[0], a.shape[-1], 8, 8])
    # Here b has shape [batch_size, num_8x8_blocks, height (8), width (8)]
    # Now apply the DCT over the last 2 dimensions of b
    b = dct.dct_2d(b, norm='ortho')
    # See the thesis for the rate-estimation method
    b = torch.abs(b) * 2 - lc
    b = torch.relu(b)
    # soft count of the surviving (non-zero) coefficients
    count = (b / (b + 0.00001)).sum(dim=(1, 2, 3))
    return torch.mean(count)  # batch mean
Example #18
 def dct_encoder(self, imSub_list, Q, blocksize=8, thresh=0.05):
     TransAll_list = []
     TransAllThresh_list = []
     TransAllQuant_list = []
     for idx, channel in enumerate(imSub_list):
         channelrows = channel.shape[0]
         channelcols = channel.shape[1]
         vis0 = np.zeros((channelrows, channelcols), np.float32)
         vis0[:channelrows, :channelcols] = channel
         # vis0 = vis0 - 128  # before the DCT, the pixel values of all channels are shifted by -128
         blocks = self.blockshaped(vis0, blocksize, blocksize)
         # dct_blocks = fftpack.dct(fftpack.dct(blocks, axis=1, norm='ortho'), axis=2, norm='ortho')
         dct_blocks = torch_dct.dct_2d(torch.from_numpy(blocks), norm='ortho').numpy()
         thres_blocks = dct_blocks * \
             (abs(dct_blocks) > thresh * np.amax(dct_blocks, axis=(1, 2))[:, np.newaxis, np.newaxis])  # need to broadcast
         quant_blocks = np.round(thres_blocks / Q[idx])

         TransAll_list.append(self.unblockshaped(dct_blocks, channelrows, channelcols))
         TransAllThresh_list.append(self.unblockshaped(thres_blocks, channelrows, channelcols))
         TransAllQuant_list.append(self.unblockshaped(quant_blocks, channelrows, channelcols))
     return TransAll_list, TransAllThresh_list, TransAllQuant_list
Example #19
 def forward(self, img, valid=False):
     batch_size = img.size(0)
     if self.pretrain_config.data_params.spectral_domain:
         img = self.system.normalize(img)
         img = dct.dct_2d(img)
         img = (img - img.mean()) / img.std()
     if not valid and not self.config.optim_params.no_views: 
         img = self.viewmaker(img)
          if isinstance(img, tuple):
             idx = random.randint(0, 1)
             img = img[idx]
     if 'Expert' not in self.pretrain_config.system and not self.pretrain_config.data_params.spectral_domain:
         img = self.system.normalize(img)
     if self.pretrain_config.model_params.resnet_small:
         if self.config.model_params.use_prepool:
             embs = self.encoder(img, layer=5)
         else:
             embs = self.encoder(img, layer=6)
     else:
         embs = self.encoder(img)
     return self.model(embs.view(batch_size, -1))
Example #20
    def testBackwardGrayScale(self, stride, height, width, datatype):
        rtol, atol = 1e-3, 1e-6

        # Parameters
        nSamples = 8
        nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
        ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
        nDecs = stride[0] * stride[1]  # math.prod(stride)
        nComponents = 1
        # Source (nSamples x nRows x nCols x nDecs)
        X = torch.rand(nSamples,
                       nrows,
                       ncols,
                       nDecs,
                       dtype=datatype,
                       requires_grad=True)
        # nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols)
        dLdZ = torch.rand(nSamples, nComponents, height, width, dtype=datatype)

        # Expected values
        arrayshape = stride.copy()
        arrayshape.insert(0, -1)
        Y = dct.dct_2d(dLdZ.view(arrayshape), norm='ortho')
        A = permuteDctCoefs_(Y)
        # Rearrange the DCT Coefs. (nSamples x nComponents x nrows x ncols) x (decV x decH)
        expctddLdX = A.view(nSamples, nrows, ncols, nDecs)

        # Instantiation of target class
        layer = NsoltBlockIdct2dLayer(decimation_factor=stride, name='E0~')

        # Actual values
        Z = layer.forward(X)
        Z.backward(dLdZ)
        actualdLdX = X.grad

        # Evaluation
        self.assertEqual(actualdLdX.dtype, datatype)
        self.assertTrue(
            torch.allclose(actualdLdX, expctddLdX, rtol=rtol, atol=atol))
        self.assertTrue(Z.requires_grad)
Example #21
def main(video_file, pose_dict, model):

    is_cuda = torch.cuda.is_available()

    # ============== 3D pose estimation ============== #
    poses = main_VIBE(video_file, model)
    # with open('PoseCorrection/Results/poses_vibe.pickle', 'rb') as f:
    #     poses = pickle.load(f)

    # ============== Skeleton uniformization ============== #
    poses_uniform = centralize_normalize_rotate_poses(poses, pose_dict)
    joints = list(range(15)) + [19, 21, 22, 24]
    poses_reshaped = poses_uniform[:, :, joints]
    poses_reshaped = poses_reshaped.reshape(
        -1, poses_reshaped.shape[1] * poses_reshaped.shape[2]).T

    frames = poses_reshaped.shape[1]

    # ============== Input ============== #
    dct_n = 25
    if frames >= dct_n:
        inputs = dct.dct_2d(poses_reshaped)[:, :dct_n]
    else:
        inputs = dct.dct_2d(
            torch.nn.ZeroPad2d((0, dct_n - frames, 0, 0))(poses_reshaped))

    if is_cuda:
        inputs = inputs.cuda()

    # ============== Action recognition ============== #
    model_class = GCN_class()
    model_class.load_state_dict(
        torch.load('PoseCorrection/Data/model_class.pt'))

    if is_cuda:
        model_class.cuda()

    model_class.eval()
    with torch.no_grad():
        _, label = torch.max(model_class(inputs).data, 1)

    # ============== Motion correction ============== #
    model_corr = GCN_corr()
    model_corr.load_state_dict(torch.load('PoseCorrection/Data/model_corr.pt'))

    if is_cuda:
        model_corr.cuda()

    with torch.no_grad():
        model_corr.eval()
        deltas_dct, att = model_corr(inputs)

        if frames > dct_n:
            m = torch.nn.ZeroPad2d((0, frames - dct_n, 0, 0))
            deltas = dct.idct_2d(m(deltas_dct).transpose(1, 2))
        else:
            deltas = dct.idct_2d(deltas_dct[:, :frames].transpose(1, 2))

        poses_corrected = poses_reshaped + deltas.squeeze().squeeze().T

    # ============== Action recognition ============== #
    with torch.no_grad():
        _, label_corr = torch.max(model_class(inputs + deltas_dct).data, 1)

    return poses_reshaped, poses_corrected, label, label_corr
Example #22
def to_spectral(x):
    return dct.dct_2d(x)
Example #23
def foolbox_attack(filter=None,
                   filter_preserve='low',
                   free_parm='eps',
                   plot_num=None):
    # get model.
    model = get_model()
    model = nn.DataParallel(model).to(device)
    model = model.eval()

    preprocessing = dict(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225],
                         axis=-3)
    fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)

    if plot_num:
        free_parm = ''
        val_loader = get_val_loader(plot_num)
    else:
        # Load images.
        val_loader = get_val_loader(args.attack_batch_size)

    if 'eps' in free_parm:
        epsilons = [0.001, 0.003, 0.005, 0.008, 0.01, 0.1]
    else:
        epsilons = [0.01]
    if 'step' in free_parm:
        steps = [1, 5, 10, 30, 40, 50]
    else:
        steps = [args.iteration]

    for step in steps:
        # Adversarial attack.
        if args.attack_type == 'LinfPGD':
            attack = LinfPGD(steps=step)
        elif args.attack_type == 'FGSM':
            attack = FGSM()

        clean_acc = 0.0

        for i, data in enumerate(val_loader, 0):

            # Samples (attack_batch_size * attack_epochs) images for adversarial attack.
            if i >= args.attack_epochs:
                break

            images, labels = data[0].to(device), data[1].to(device)
            if step == steps[0]:
                clean_acc += (get_acc(
                    fmodel, images, labels
                )) / args.attack_epochs  # accumulate for attack epochs.

            _images, _labels = ep.astensors(images, labels)
            raw_advs, clipped_advs, success = attack(fmodel,
                                                     _images,
                                                     _labels,
                                                     epsilons=epsilons)

            if plot_num:
                grad = torch.from_numpy(
                    raw_advs[0].numpy()).to(device) - images
                grad = grad.clone().detach_()
                return grad

            if filter:
                robust_accuracy = torch.empty(len(epsilons))
                for eps_id in range(len(epsilons)):
                    grad = torch.from_numpy(
                        raw_advs[eps_id].numpy()).to(device) - images
                    grad = grad.clone().detach_()
                    freq = dct.dct_2d(grad)
                    if filter_preserve == 'low':
                        mask = torch.zeros(freq.size()).to(device)
                        mask[:, :, :filter, :filter] = 1
                    elif filter_preserve == 'high':
                        mask = torch.zeros(freq.size()).to(device)
                        mask[:, :, filter:, filter:] = 1
                    masked_freq = torch.mul(freq, mask)
                    new_grad = dct.idct_2d(masked_freq)
                    x_adv = torch.clamp(images + new_grad, 0, 1).detach_()

                    robust_accuracy[eps_id] = (get_acc(fmodel, x_adv, labels))
            else:
                robust_accuracy = 1 - success.float32().mean(axis=-1)
            if i == 0:
                robust_acc = robust_accuracy / args.attack_epochs
            else:
                robust_acc += robust_accuracy / args.attack_epochs

        if step == steps[0]:
            print("sample size is : ",
                  args.attack_batch_size * args.attack_epochs)
            print(f"clean accuracy:  {clean_acc * 100:.1f} %")
            print(
                f"Model {args.model} robust accuracy for {args.attack_type} perturbations with"
            )
        for eps, acc in zip(epsilons, robust_acc):
            print(
                f"  Step {step}, Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %"
            )
        print('  -------------------')
Example #24
    def forward(self, x):
        y = dct.dct_2d(x, norm='ortho')
        y = y[..., :self.h, :self.w]
        y = dct.idct_2d(y, norm='ortho')

        return y
Example #25
def optimize(ux0, uy0, im1, im2, maxIter, lambda_r, to1, theta, device):
    eps = 0.0000001
    Ix, Iy = centralFiniteDifference(im1)
    It = im1 - im2

    a11 = Ix * Ix
    a12 = Ix * Iy
    a22 = Iy * Iy

    t1 = Ix * (It - Ix * ux0 - Iy * uy0)
    t2 = Iy * (It - Ix * ux0 - Iy * uy0)

    h, w = im1.size()

    vx = torch.zeros((h, w)).to(device).double()
    vy = torch.zeros((h, w)).to(device).double()
    bx = torch.zeros((h, w)).to(device).double()
    by = torch.zeros((h, w)).to(device).double()
    ux = torch.zeros((h, w)).to(device).double()
    uy = torch.zeros((h, w)).to(device).double()

    X, Y = torch.meshgrid(torch.linspace(0, h - 1, h),
                          torch.linspace(0, w - 1, w))
    X, Y = X.to(device).double(), Y.to(device).double()
    G = torch.cos(math.pi * X / h) + torch.cos(math.pi * Y / w) - 2
    # G = G.unsqueeze(0).repeat(N, 1, 1, 1)

    for i in range(maxIter):
        tempx = ux
        tempy = uy

        h1 = theta * (vx - bx) - t1
        h2 = theta * (vy - by) - t2

        ux = ((a22 + theta) * h1 - a12 * h2) / ((a11 + theta) *
                                                (a22 + theta) - a12 * a12)
        uy = ((a11 + theta) * h2 - a12 * h1) / ((a11 + theta) *
                                                (a22 + theta) - a12 * a12)

        # vx = (idct2(dct2(theta * (ux + bx)) / (theta + lambda_r * G * G)))
        # vy = (idct2(dct2(theta * (uy + by)) / (theta + lambda_r * G * G)))

        vx = (tdct.idct_2d(
            tdct.dct_2d(theta * (ux + bx)) / (theta + lambda_r * G * G)))
        vy = (tdct.idct_2d(
            tdct.dct_2d(theta * (uy + by)) / (theta + lambda_r * G * G)))

        bx = bx + ux - vx
        by = by + uy - vy

        # t1 = Ix * (It - Ix * ux - Iy * uy)
        # t2 = Iy * (It - Ix * ux - Iy * uy)

        stopx = torch.sum(
            torch.abs(ux - tempx)) / (torch.sum(torch.abs(tempx)) + eps)
        stopy = torch.sum(
            torch.abs(uy - tempy)) / (torch.sum(torch.abs(tempy)) + eps)
        # print(i, stopx, stopy)
        if stopx < to1 and stopy < to1:
            print('iterated {} times, stopping: converged to tolerance'.format(i))
            break

    if i == maxIter - 1:
        print('iterated {} times, stopping: reached max iterations'.format(i))
    return ux, uy
Example #26
def main():

    # Load dataset
    print('Loading dataset ...\n')
    ## R, M, D ##
    if opt.net_mode == 'R' or opt.net_mode == 'M' or opt.net_mode == 'D':
        if opt.color == 1:
            dataset_train = Dataset(train=True,
                                    aug_times=2,
                                    grayscale=False,
                                    scales=True)
        else:
            dataset_train = Dataset(train=True,
                                    aug_times=2,
                                    grayscale=True,
                                    scales=True)
    else:
        raise NotImplementedError(
            'Supported networks: R (DnCNN), M (MemNet), D (RIDNet) only')

    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model_channels = 1 + 2 * opt.color

    np.random.seed(0)
    torch.manual_seed(0)

    if opt.net_mode == 'R':
        print('** Creating DnCNN RL network: **')
        net = DnCNN_RL(channels=model_channels,
                       num_of_layers=opt.num_of_layers)
    elif opt.net_mode == 'M':
        print('** Creating MemNet network: **')
        net = MemNet(in_channels=model_channels,
                     channels=20,
                     num_memblock=6,
                     num_resblock=4)
    elif opt.net_mode == 'D':
        print('** Creating RIDNet network: **')
        net = RIDNet(in_channels=model_channels)
    print(net)
    net.apply(weights_init_kaiming)

    # Loss
    criterion = nn.MSELoss(reduction='sum')  # size_average=False in the deprecated API

    # Move to GPU
    model = nn.DataParallel(net).cuda()
    criterion.cuda()  # print(model)
    print('Trainable parameters: ',
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    # Training
    noiseL_B = [0, opt.noise_max]

    train_loss_log = np.zeros(opt.epochs)

    for epoch in range(opt.epochs):

        # Learning rate
        factor = epoch // opt.milestone
        current_lr = opt.lr / (10.**factor)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('\nlearning rate %f' % current_lr)

        # Train
        t = time.time()
        for i, data in enumerate(loader_train, 0):
            # Training
            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            # ADD Noise
            img_train = data

            noise = torch.zeros(img_train.size())
            noise_level_train = torch.zeros(img_train.size())
            stdN = np.random.uniform(noiseL_B[0],
                                     noiseL_B[1],
                                     size=noise.size()[0])
            sizeN = noise[0, :, :, :].size()

            # Noise Level map preparation (each step)
            for n in range(noise.size()[0]):
                noise[n, :, :, :] = torch.FloatTensor(sizeN).normal_(
                    mean=0, std=stdN[n] / 255.)
                noise_level_value = stdN[n] / noiseL_B[1]
                noise_level_train[n, :, :, :] = torch.FloatTensor(
                    np.ones(sizeN))
                noise_level_train[n, :, :, :] = noise_level_train[
                    n, :, :, :] * noise_level_value
            noise_level_train = Variable(noise_level_train.cuda())

            # Modifying the frequency content of the added noise (Low or High only)
            if opt.mask_train_noise in (1, 2):
                noise_mask = get_mask_low_high(w=sizeN[1],
                                               h=sizeN[2],
                                               radius_perc=0.5,
                                               mask_mode=opt.mask_train_noise)
                for n in range(noise.size()[0]):
                    noise_dct = dct(dct(noise[n, 0, :, :].data.numpy(),
                                        axis=0,
                                        norm='ortho'),
                                    axis=1,
                                    norm='ortho')
                    noise_dct = noise_dct * noise_mask
                    noise_numpy = idct(idct(noise_dct, axis=0, norm='ortho'),
                                       axis=1,
                                       norm='ortho')
                    noise[n, 0, :, :] = torch.from_numpy(noise_numpy)
            elif opt.mask_train_noise == 3:  #Brownian noise
                for n in range(noise.size()[0]):
                    noise_numpy = gaussian_filter(noise[n,
                                                        0, :, :].data.numpy(),
                                                  sigma=3)
                    noise[n, 0, :, :] = torch.from_numpy(noise_numpy)

            # DCT SFM
            if opt.DCT_DOR > 0:
                img_train_SFM = np.zeros(img_train.size(), dtype='float32')
                noise_SFM = np.zeros(noise.size(), dtype='float32')

                dct_bool = np.random.choice([1, 0],
                                            size=(img_train.size()[0], ),
                                            p=[opt.DCT_DOR, 1 - opt.DCT_DOR])
                for img_idx in range(img_train.size()[0]):
                    if dct_bool[img_idx] == 1:
                        img_numpy, mask = random_drop(
                            img_train[img_idx, :, :, :].data.numpy(),
                            mode=opt.SFM_mode,
                            SFM_center_radius_perc=opt.SFM_rad_perc,
                            SFM_center_sigma_perc=opt.SFM_sigma_perc)
                        img_train_SFM[img_idx, 0, :, :] = img_numpy

                        noise_dct = dct(dct(noise[img_idx,
                                                  0, :, :].data.numpy(),
                                            axis=0,
                                            norm='ortho'),
                                        axis=1,
                                        norm='ortho')
                        noise_dct = noise_dct * mask
                        noise_numpy = idct(idct(noise_dct,
                                                axis=0,
                                                norm='ortho'),
                                           axis=1,
                                           norm='ortho')
                        noise_SFM[img_idx, 0, :, :] = noise_numpy

                if opt.SFM_noise == 0:
                    imgn_train = torch.from_numpy(img_train_SFM) + noise
                elif opt.SFM_noise == 1:
                    imgn_train = torch.from_numpy(
                        img_train_SFM) + torch.from_numpy(noise_SFM)

                if opt.SFM_GT == 1:
                    img_train = torch.from_numpy(img_train_SFM)
            else:
                imgn_train = img_train + noise

            # Training step
            img_train, imgn_train = Variable(img_train.cuda()), Variable(
                imgn_train.cuda())
            noise = Variable(noise.cuda())

            out_train = model(imgn_train)
            OUT_NOISE = imgn_train - out_train
            loss = criterion(OUT_NOISE, noise) / (imgn_train.size()[0] * 2)

            if (opt.DCTloss_weight == 1):
                noise_DCT = torch.zeros(noise.size())
                OUT_NOISE_DCT = torch.zeros(noise.size())
                for img_idx in range(noise.size()[0]):
                    noise_DCT[img_idx,
                              0, :, :] = torch_dct.dct_2d(noise[img_idx,
                                                                0, :, :])
                    OUT_NOISE_DCT[img_idx, 0, :, :] = torch_dct.dct_2d(
                        OUT_NOISE[img_idx, 0, :, :])
                noise_DCT, OUT_NOISE_DCT = Variable(
                    noise_DCT.cuda()), Variable(OUT_NOISE_DCT.cuda())
                loss += criterion(OUT_NOISE_DCT,
                                  noise_DCT) / (imgn_train.size()[0] * 2)
                loss_DCTcomponent = criterion(
                    OUT_NOISE_DCT, noise_DCT) / (imgn_train.size()[0] * 2)

            loss.backward()
            optimizer.step()

            train_loss_log[epoch] += loss.item()

        train_loss_log[epoch] = train_loss_log[epoch] / len(loader_train)

        elapsed = time.time() - t
        if (opt.DCTloss_weight == 1):
            print(
                'Epoch %d: loss=%.4f, lossDCT=%.4f, elapsed time (min):%.2f' %
                (epoch, train_loss_log[epoch], loss_DCTcomponent.item(),
                 elapsed / 60.))
        else:
            print('Epoch %d: loss=%.4f, elapsed time (min):%.2f' %
                  (epoch, train_loss_log[epoch], elapsed / 60.))

        model_name = get_model_name(opt)
        model_dir = os.path.join('saved_models', model_name)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        torch.save(model.state_dict(),
                   os.path.join(model_dir, 'net_%d.pth' % (epoch)))
Example #27
def dct_flip(x):
    return torch_dct.idct_2d(torch.flip(torch_dct.dct_2d(x, norm='ortho'), [-2, -1]), norm='ortho')
Example #28
    def testBackwardRgbColor(self, stride, height, width, datatype):
        rtol, atol = 1e-3, 1e-6

        # Parameters
        nSamples = 8
        nrows = int(math.ceil(height / stride[Direction.VERTICAL]))
        ncols = int(math.ceil(width / stride[Direction.HORIZONTAL]))
        nDecs = stride[0] * stride[1]  # math.prod(stride)
        nComponents = 3  # RGB
        # Source (nSamples x nRows x nCols x nDecs)
        Xr = torch.rand(nSamples,
                        nrows,
                        ncols,
                        nDecs,
                        dtype=datatype,
                        requires_grad=True)
        Xg = torch.rand(nSamples,
                        nrows,
                        ncols,
                        nDecs,
                        dtype=datatype,
                        requires_grad=True)
        Xb = torch.rand(nSamples,
                        nrows,
                        ncols,
                        nDecs,
                        dtype=datatype,
                        requires_grad=True)
        # nSamples x nComponents x (Stride[0]xnRows) x (Stride[1]xnCols)
        dLdZ = torch.rand(nSamples, nComponents, height, width, dtype=datatype)

        # Expected values
        arrayshape = stride.copy()
        arrayshape.insert(0, -1)
        Y = dct.dct_2d(dLdZ.view(arrayshape), norm='ortho')
        A = permuteDctCoefs_(Y)
        # Rearrange the DCT Coefs. (nSamples x nRows x nCols x nDecs)
        Z = A.view(nSamples, nComponents, nrows, ncols, nDecs)
        expctddLdXr, expctddLdXg, expctddLdXb = map(
            lambda x: torch.squeeze(x, dim=1),
            torch.chunk(Z, nComponents, dim=1))

        # Instantiation of target class
        layer = NsoltBlockIdct2dLayer(decimation_factor=stride,
                                      number_of_components=nComponents,
                                      name='E0~')

        # Actual values
        Z = layer.forward(Xr, Xg, Xb)
        Z.backward(dLdZ)
        actualdLdXr = Xr.grad
        actualdLdXg = Xg.grad
        actualdLdXb = Xb.grad

        # Evaluation
        self.assertEqual(actualdLdXr.dtype, datatype)
        self.assertEqual(actualdLdXg.dtype, datatype)
        self.assertEqual(actualdLdXb.dtype, datatype)
        self.assertTrue(
            torch.allclose(actualdLdXr, expctddLdXr, rtol=rtol, atol=atol))
        self.assertTrue(
            torch.allclose(actualdLdXg, expctddLdXg, rtol=rtol, atol=atol))
        self.assertTrue(
            torch.allclose(actualdLdXb, expctddLdXb, rtol=rtol, atol=atol))
        self.assertTrue(Z.requires_grad)