def __init__(self,
                 height=5,
                 nbands=4,
                 scale_factor=2,
                 extract_level=1,
                 visualize=False):
        '''Set up a Phase_Difference_Extractor.

        The object computes a complex steerable pyramid and extracts the
        phase and the phase difference from it.

        Usage:
            build_pyramid(): build complex steerable pyramid coefficients
            extract(): extract phase differences

        Parameters:
            height: int, default 5
                Number of coefficient levels, including low-pass and
                high-pass.
            nbands: int, default 4
                Number of orientations of the band-pass filters.
            scale_factor: int, default 2
                Spatial resolution reduction factor.
            extract_level: int, or list of int, default 1
                With an int, build_pyramid() returns the coefficients of that
                single level; with a list, it returns the coefficients of
                every listed level.
            visualize: bool, default False
                If true, build_pyramid() and extract() show the processed
                results.
        '''
        # The pyramid device is obtained from get_device(); it is not a
        # constructor argument here.
        pyramid_options = dict(height=height,
                               nbands=nbands,
                               scale_factor=scale_factor,
                               device=get_device())
        self.pyramid = SCFpyr_PyTorch(**pyramid_options)

        # Keep the configuration around for the other methods.
        self.visualize = visualize
        self.extract_level = extract_level
        self.scale_factor = scale_factor
        self.nbands = nbands
        self.height = height
    def STSIM(self, img1, img2, sub_sample=True):
        """Return the mean of the per-subband ``self.pooling`` scores.

        Both images are decomposed with a fresh ``SCFpyr_PyTorch`` pyramid,
        the subbands are paired in order and each pair is scored with
        ``self.pooling``.

        :param img1: tensor of shape [N, 1, H, W] (single-channel images)
        :param img2: tensor with the same shape as ``img1``
        :param sub_sample: forwarded to ``SCFpyr_PyTorch`` as ``sub_sample``
        :return: the stacked per-subband scores averaged over dimension 0
        """
        assert img1.shape == img2.shape
        assert len(img1.shape) == 4  # [N,C,H,W]
        assert img1.shape[1] == 1  # gray image

        pyramid = SCFpyr_PyTorch(sub_sample=sub_sample, device=self.device)
        subbands_a = pyramid.getlist(pyramid.build(img1))
        subbands_b = pyramid.getlist(pyramid.build(img2))

        scores = [
            self.pooling(band_a, band_b)
            for band_a, band_b in zip(subbands_a, subbands_b)
        ]
        return torch.stack(scores).mean(dim=0)
Пример #3
0
 def __init__(self, height=5, nbands=4, scale_factor=2, device=None, extract_level=1, visualize=False):
     """Build the wrapped ``SCFpyr_PyTorch`` pyramid and store the settings.

     ``height``, ``nbands``, ``scale_factor`` and ``device`` are forwarded
     unchanged to ``SCFpyr_PyTorch``; every argument is also kept as an
     attribute of the same name.
     """
     pyramid_options = dict(
         height=height,
         nbands=nbands,
         scale_factor=scale_factor,
         device=device,
     )
     self.pyramid = SCFpyr_PyTorch(**pyramid_options)

     # Plain configuration attributes.
     self.visualize = visualize
     self.extract_level = extract_level
     self.device = device
     self.scale_factor = scale_factor
     self.nbands = nbands
     self.height = height
    def STSIM_M(self, imgs):
        '''
        Collect per-image subband statistics into one stacked tensor.

        Three groups of statistics are appended, in this order:
          1. for every subband: mean, variance, and the two neighbouring-pixel
             product means (along dim 2 and along dim 3) divided by the
             variance;
          2. for every level in ``coeffs[1:-1]``: the mean product of every
             pair of orientations;
          3. for every orientation: the normalised mean product of adjacent
             levels, after resizing the first one to the second one's size.

        All statistics are taken on ``self.abs(...)`` of the coefficients
        (presumably the magnitude -- confirm against ``self.abs``).

        :param imgs: [N,C=1,H,W]
        :return: ``torch.stack`` of all statistics (one entry per statistic)
        '''
        pyramid = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        coeffs = pyramid.build(imgs)
        reduce_dims = [1, 2, 3]

        features = []

        # single subband statistics
        for band in pyramid.getlist(coeffs):
            mag = self.abs(band)
            variance = torch.var(mag, dim=reduce_dims)
            shifted_rows = torch.mean(mag[:, :, :-1, :] * mag[:, :, 1:, :],
                                      dim=reduce_dims)
            shifted_cols = torch.mean(mag[:, :, :, :-1] * mag[:, :, :, 1:],
                                      dim=reduce_dims)
            features.extend([
                torch.mean(mag, dim=reduce_dims),
                variance,
                shifted_rows / variance,
                shifted_cols / variance,
            ])

        # correlation statistics: across orientations within one level
        for orients in coeffs[1:-1]:
            for first, second in itertools.combinations(orients, 2):
                product = self.abs(first) * self.abs(second)
                features.append(torch.mean(product, dim=reduce_dims))

        # correlation statistics: same orientation, adjacent levels
        n_orients = len(coeffs[1])
        for orient in range(n_orients):
            for level in range(1, len(coeffs) - 2):
                band_here = self.abs(coeffs[level][orient])
                band_next = self.abs(coeffs[level + 1][orient])

                # Resize before taking statistics so both maps align.
                band_here = F.interpolate(band_here, size=band_next.shape[2:])
                std_here = torch.sqrt(torch.var(band_here, dim=reduce_dims))
                std_next = torch.sqrt(torch.var(band_next, dim=reduce_dims))
                features.append(
                    torch.mean(band_here * band_next, dim=reduce_dims) /
                    std_here / std_next)

        return torch.stack(features)
    def STSIM2(self, img1, img2):
        """Average the per-subband scores together with cross-band terms.

        The per-subband ``self.pooling`` scores come from a sub-sampled
        pyramid. The cross terms come from a non-sub-sampled pyramid and are
        produced by ``self.compute_cross_term`` for (a) adjacent scales at the
        same orientation and (b) every pair of orientations at the same scale.

        :param img1: image batch
        :param img2: image batch with the same shape as ``img1``
        :return: all terms stacked and averaged over dimension 0
        """
        assert img1.shape == img2.shape

        pyr_sub = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        pyr_full = SCFpyr_PyTorch(sub_sample=False, device=self.device)

        subbands_a = pyr_sub.getlist(pyr_sub.build(img1))
        subbands_b = pyr_sub.getlist(pyr_sub.build(img2))
        terms = [
            self.pooling(band_a, band_b)
            for band_a, band_b in zip(subbands_a, subbands_b)
        ]

        # Add cross terms
        bands_a = pyr_full.build(img1)
        bands_b = pyr_full.build(img2)

        n_orients = len(bands_a[1])
        end_scale = len(bands_a) - 1
        reduce_dims = [1, 2, 3]

        # Across scale, same orientation
        for scale in range(2, end_scale):
            for orient in range(n_orients):
                prev_a = self.abs(bands_a[scale - 1][orient])
                this_a = self.abs(bands_a[scale][orient])

                prev_b = self.abs(bands_b[scale - 1][orient])
                this_b = self.abs(bands_b[scale][orient])

                cross = self.compute_cross_term(prev_a, this_a, prev_b, this_b)
                terms.append(cross.mean(dim=reduce_dims))

        # Across orientation, same scale
        for scale in range(1, end_scale):
            for orient in range(n_orients - 1):
                first_a = self.abs(bands_a[scale][orient])
                first_b = self.abs(bands_b[scale][orient])

                for other in range(orient + 1, n_orients):
                    second_a = self.abs(bands_a[scale][other])
                    second_b = self.abs(bands_b[scale][other])
                    cross = self.compute_cross_term(first_a, second_a,
                                                    first_b, second_b)
                    terms.append(cross.mean(dim=reduce_dims))

        return torch.mean(torch.stack(terms), dim=0)
Пример #6
0

# Demo: decompose one grayscale image with SCFpyr_PyTorch, reconstruct it and
# report the mean squared reconstruction error.

# Requires PyTorch with MKL when setting to 'cpu'
device = torch.device('cpu')

# Load batch of images [N,1,H,W]
# NOTE(review): im_batch_numpy is not used again in the visible part of this
# script; the pyramid below is fed from the cv2 image instead.
im_batch_numpy = utils.load_image_batch('./assets/lena.jpg',32,600)
# Flag 0 makes cv2 read the file as a single-channel grayscale image.
img=cv2.imread('./assets/lena.jpg',0)
# Show the input image.
# NOTE(review): no cv2.waitKey() call is visible in this snippet; confirm the
# windows get a chance to render.
cv2.imshow('yuantu',img)
im_torch = torch.from_numpy(img).to(device)
# Add batch and channel axes -> [1,1,H,W], converted to float.
im_batch_torch=im_torch.unsqueeze(0).unsqueeze(0).float()
# Initialize Complex Steerable Pyramid
height = 12
nbands = 4
scale_factor = 2**(1/2)
pyr = SCFpyr_PyTorch(height=height, nbands=nbands, scale_factor=scale_factor, device=device)
# Passed to both build() and reconstruct() below.
# NOTE(review): the meaning of value 1 is not visible here -- confirm against
# SCFpyr_PyTorch.
pyr_type = 1

# Decompose entire batch of images
coeff = pyr.build(im_batch_torch,pyr_type)

# Reconstruct batch of images again
img_recon = pyr.reconstruct(coeff,pyr_type)
# Compare the float input with the reconstruction after dropping size-1 axes.
img=im_torch.float()
recon=img_recon.squeeze()
loss=torch.nn.MSELoss()
print('MSE:',loss(img,recon))
# Cast to uint8 so cv2 can display the reconstruction.
cv2.imshow('recon',recon.numpy().astype(np.uint8))
# Visualization
# coeff_single = utils.extract_from_batch(coeff, 0)
# coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True)
Пример #7
0
    # Numeric options: plain int defaults (argparse would have converted the
    # former string defaults through ``type`` anyway, so values are unchanged).
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--image_size', type=int, default=200)
    parser.add_argument('--pyr_nlevels', type=int, default=5)
    parser.add_argument('--pyr_nbands', type=int, default=4)
    parser.add_argument('--pyr_scale_factor', type=int, default=2)
    parser.add_argument('--device', type=str, default='cuda:0')

    def _parse_bool(text):
        """Turn a command-line word into a bool; reject anything else."""
        word = str(text).strip().lower()
        if word in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if word in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        # argparse reports a ValueError from a ``type`` callable as a normal
        # usage error for this argument.
        raise ValueError('expected a boolean value, got {!r}'.format(text))

    # ``type=bool`` would treat every non-empty string (including "False") as
    # True, so the flag could never be switched off from the command line.
    parser.add_argument('--visualize', type=_parse_bool, default=True)
    config = parser.parse_args()

    device = utils.get_device(config.device)

    ############################################################################
    # Build the complex steerable pyramid

    pyr = SCFpyr_PyTorch(height=config.pyr_nlevels,
                         nbands=config.pyr_nbands,
                         scale_factor=config.pyr_scale_factor,
                         device=device)

    ############################################################################
    # Create a batch and feed-forward

    # Timestamp taken before loading, so it also covers the batch loading.
    start_time = time.time()

    # Load Batch
    im_batch_numpy = utils.load_image_batch(config.image_file,
                                            config.batch_size,
                                            config.image_size)
    im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)

    # Compute Steerable Pyramid
    coeff = pyr.build(im_batch_torch)
Пример #8
0
    # Comma-separated option strings -> lists of ints.
    config.batch_sizes = [int(item) for item in config.batch_sizes.split(',')]
    config.image_sizes = [int(item) for item in config.image_sizes.split(',')]

    device = utils.get_device(config.device)

    ################################################################################

    # Options shared by all three pyramid back ends; only the PyTorch one
    # additionally takes a device.
    shared_pyr_options = dict(height=config.pyr_nlevels,
                              nbands=config.pyr_nbands,
                              scale_factor=config.pyr_scale_factor,
                              precision=config.precision)

    pyr_numpy = SCFpyr_NumPy(**shared_pyr_options)

    pyr_torch = SCFpyr_PyTorch(device=device, **shared_pyr_options)

    pyr_tf = SCFpyr_TF(**shared_pyr_options)
    ############################################################################
    # Run Benchmark

    # One slot per (batch size, image size, run).
    durations_shape = (len(config.batch_sizes), len(config.image_sizes),
                       config.num_runs)
    durations_numpy = np.zeros(durations_shape)
    durations_torch = np.zeros(durations_shape)
    durations_tf = np.zeros(
Пример #9
0
# Training hyper-parameters.
num_epochs = 2
learning_rate = 0.001
batch_size = 8
# Steerable pyramid parameters
height = 12
nbands = 4
scale_factor = 2**(1/2)
# NOTE(review): the meaning of pyr_type = 1 is not visible in this snippet --
# confirm against SCFpyr_PyTorch.
pyr_type = 1
# Load dataset: every image is resized to 256x256 and converted to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)),
     transforms.ToTensor()])
# NOTE(review): absolute, machine-specific dataset path.
dataset = Triplets(
    '/home/lj/Documents/code/python/DAVIS/JPEGImages/480p/', transform)

# NOTE(review): `device` is not defined in the visible part of this script;
# it must be set before this line.
pyr = SCFpyr_PyTorch(height=height, nbands=nbands, scale_factor=scale_factor, device=device)
# Define network, loss and optimizer
model = PhaseNet()
criterion = Total_loss(v=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9,0.999))

# Train the model
total_step = 0
for epoch in range(num_epochs):
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)
    for channel in range(3):
        for n, Triplets_batch in enumerate(trainloader):
            # get images_list [[N,C,H,W],[N,C,H,W],...], usually len(images_list)=batch_size if len(dataset)%batch_size==0
            images_list = [torch.stack([Triplets_batch['start'][i],
                                        Triplets_batch['inter'][i],
Пример #10
0
# NumPy reference: decompose `image`, reconstruct it and cast the result to
# uint8.
pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
coeff_numpy = pyr_numpy.build(image)
reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)
reconstruction_numpy = reconstruction_numpy.astype(np.uint8)

print('#' * 60)

################################################################################
# PyTorch

# NOTE(review): hard-coded CUDA device; this script needs a GPU as written.
device = torch.device('cuda:0')

# Add batch and channel axes in front of the 2-D image -> [1,1,H,W].
im_batch = torch.from_numpy(image[None, None, :, :])
im_batch = im_batch.to(device).float()

# NOTE(review): scale_factor is not passed here, unlike the NumPy pyramid
# above (scale_factor=2) -- confirm the PyTorch default matches.
pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)
reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
# Bring the reconstruction back to the CPU and keep the first batch entry.
reconstruction_torch = reconstruction_torch.cpu().numpy()[0, ]

# Extract first example from the batch and move to CPU
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

################################################################################
# Check correctness

print('#' * 60)
# Both implementations must produce the same number of pyramid entries.
assert len(coeff_numpy) == len(coeff_torch)

for level, _ in enumerate(coeff_numpy):
Пример #11
0
import numpy as np
import cv2
import torch
import sys
sys.path.append('..')
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from perceptual.filterbank import SteerableNoSub

# Compare the non-sub-sampled SCFpyr_PyTorch pyramid against the reference
# SteerableNoSub implementation from perceptual.filterbank.
# NOTE(review): hard-coded CUDA device; this script needs a GPU as written.
device = torch.device('cuda:0')
pyr_NoSub = SCFpyr_PyTorch(sub_sample=False, device=device)

# NOTE(review): `path` is an empty placeholder; set it to an image file
# before running.
path = ''
# Flag 0 makes cv2 read the file as a single-channel grayscale image.
img = cv2.imread(path, 0)
img_batch = torch.from_numpy(img).to(device)
# Add batch and channel axes -> [1,1,H,W], converted to float.
img_batch = img_batch.unsqueeze(0).float().unsqueeze(0)

coeffs = pyr_NoSub.build(img_batch)

# Reference decomposition of the same image.
pyr_NoSub_c = SteerableNoSub()
coeffs_c = pyr_NoSub_c.buildSCFpyr(img)

# Absolute tolerance used by np.allclose below.
tolerance = 1e-3

# Compare the first pyramid entry of both implementations.
coeff = coeffs[0].cpu().numpy().squeeze(0)
all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance)
# Signed sum of the element-wise differences (errors of opposite sign cancel).
s = np.sum(coeff - coeffs_c[0])
print('Succesful for subband {}: {}, with tolerance of {}'.format(
    0, all_close, tolerance))
print('Sum of difference: {}'.format(s))
for i in range(1, len(coeffs) - 1):
    for j in range(len(coeffs[i])):
Пример #12
0
import numpy as np
import cv2
import torch
import sys
sys.path.append('..')
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from perceptual.filterbank import Steerable


# Compare the sub-sampled SCFpyr_PyTorch pyramid against the reference
# Steerable implementation from perceptual.filterbank.
# NOTE(review): hard-coded CUDA device; this script needs a GPU as written.
device = torch.device('cuda:0')
pyr = SCFpyr_PyTorch(sub_sample = True, device = device)

# NOTE(review): `path` is an empty placeholder; set it to an image file
# before running.
path = ''
# Flag 0 makes cv2 read the file as a single-channel grayscale image.
img = cv2.imread(path,0)
img_batch = torch.from_numpy(img).to(device)
# Add batch and channel axes -> [1,1,H,W], converted to float.
img_batch = img_batch.unsqueeze(0).float().unsqueeze(0)


coeffs = pyr.build(img_batch)

# Reference decomposition of the same image.
pyr_c = Steerable()
coeffs_c = pyr_c.buildSCFpyr(img)

# Absolute tolerance used by np.allclose below.
tolerance = 1e-3

# Compare the first pyramid entry of both implementations.
coeff = coeffs[0].cpu().numpy().squeeze(0)
all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance)
# Signed sum of the element-wise differences (errors of opposite sign cancel).
s = np.sum(coeff-coeffs_c[0])
print('Succesful for subband {}: {}, with tolerance of {}'.format(0,all_close, tolerance))
print('Sum of difference: {}'.format(s))
for i in range(1,len(coeffs)-1):
Пример #13
0
class Steerable_Pyramid_Phase(object):
    """Steerable-pyramid front end that turns frame stacks into phase features.

    ``build_pyramid()`` runs the wrapped ``SCFpyr_PyTorch`` pyramid and returns
    the coefficients of the configured level(s); ``extract_phase()`` converts
    such coefficients into denoised phases and/or frame-to-frame phase
    differences.
    """

    def __init__(self,
                 height=5,
                 nbands=4,
                 scale_factor=2,
                 device=None,
                 extract_level=1,
                 visualize=False):
        """Create the pyramid and store the configuration.

        Parameters:
            height, nbands, scale_factor, device:
                Forwarded unchanged to ``SCFpyr_PyTorch``.
            extract_level: int, or list/tuple of int, default 1
                Index (or indices) into the coefficient list returned by the
                pyramid; selects what ``build_pyramid()`` returns.
            visualize: bool, default False
                If true, ``build_pyramid()`` and ``extract_phase()`` show
                example results.
        """
        self.pyramid = SCFpyr_PyTorch(height=height,
                                      nbands=nbands,
                                      scale_factor=scale_factor,
                                      device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.device = device
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Compute the pyramid coefficients of the configured level(s).

        Parameters:
            im_batch: tensor with 4 dimensions:
                batch size, number of phase images, W, H.
            symmetry: bool, default True
                If true, the frames go through ``symmetric_extension_batch``
                before the pyramid is built, and only the first ``W // 2`` by
                ``H // 2`` part of every coefficient map is returned.

        Returns:
            For an int ``extract_level``: one tensor with dimensions
            batch size, nbands, number of phase images, W', H', 2.
            For a list/tuple ``extract_level``: a list with one such tensor
            per requested level.

        Raises:
            TypeError: if ``extract_level`` is neither an int nor a
                list/tuple (previously this surfaced as an unbound-local
                error only after all the work had been done).
            ValueError: if the pyramid does not return a list.
        """
        levels = self.extract_level
        if not isinstance(levels, (int, list, tuple)):
            raise TypeError(
                'extract_level must be an int or a list/tuple of ints, '
                'got {}'.format(type(levels).__name__))

        bs, num_phase_frames, W, H = im_batch.size()
        # the second dim is 1, indicating it's grayscale image
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)
        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_build_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(levels, int):
            return self._level_coeff(levels, coeff_batch, bs,
                                     num_phase_frames, symmetry)
        return [
            self._level_coeff(level, coeff_batch, bs, num_phase_frames,
                              symmetry) for level in levels
        ]

    def _show_build_example(self, trans_im_batch, coeff_batch, symmetry):
        """Display one example frame (and its crop) plus its coefficient grid."""
        # Prefer index 10 as before, but stay inside small batches instead of
        # failing with an index error.
        example_id = min(10, trans_im_batch.size(0) - 1)
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))

        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        example_img = 255 * example_img / example_img.max()
        Image.fromarray(example_img).show()
        if symmetry:
            # Also show the top-left quarter of the extended frame.
            W, H = example_img.shape
            Image.fromarray(example_img[:W // 2, :H // 2]).show()
        example_coeff.show()

    def _level_coeff(self, level, coeff_batch, bs, num_phase_frames,
                     symmetry):
        """Stack one level's bands and reshape to (bs, nbands, frames, W, H, 2)."""
        level_coeff = self.extract_coeff_level(level, coeff_batch)
        W, H, _ = level_coeff.size()[-3:]
        nbands = level_coeff.size()[0]
        level_coeff = level_coeff.view(nbands, bs, num_phase_frames, W, H, 2)
        level_coeff = level_coeff.permute(1, 0, 2, 3, 4, 5).contiguous()
        if symmetry:
            level_coeff = level_coeff[..., :W // 2, :H // 2, :]
        return level_coeff

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the band tensors of ``coeff_batch[level]`` along a new dim 0.

        The selected entry must be a list of tensors (one per band).
        """
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        extr_level_coeff_batch = torch.stack(extr_level_coeff_batch, 0)
        return extr_level_coeff_batch

    def extract_phase(self,
                      coeff_batch,
                      return_phase=False,
                      return_both=False):
        """Derive denoised phase and phase differences from coefficients.

        Parameters:
            coeff_batch: tensor with dimensions batch size, nbands, number of
                phase frames, W, H, 2 (2 is for real part and imaginary part).
            return_phase: bool, default False
                Return the centred denoised phase instead of the differences.
            return_both: bool, default False
                Takes precedence over ``return_phase``. Return phase
                differences and denoised phases interleaved along dim 2
                (difference first); the first phase frame is dropped so both
                have ``n_phase_frames - 1`` frames.

        Returns:
            * default: phase differences, size
              (bs, n_bands, n_phase_frames - 1, W, H), spatially mean-centred
              and clamped to [-5*pi, 5*pi];
            * ``return_phase``: centred denoised phase, size
              (bs, n_bands, n_phase_frames, W, H);
            * ``return_both``: interleaved tensor, size
              (bs, n_bands, 2 * (n_phase_frames - 1), W, H).
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W,
                                             H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(
            trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))
        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W,
                                   H) + EPS  # TO avoid mag==0
        assert not (mag_batch <= 0.0).any()

        # phase unwrap over time (dim -3 is the frame axis here)
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)
        # phase denoising (amplitude-based gaussian blur)
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(
            mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(
            bs, n_bands, n_phase_frames, W, H)
        # phase difference between consecutive frames
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)
        if self.visualize:
            # Show the first batch entry of every intermediate result.
            phase_example = phase_batch.view(bs, n_bands, n_phase_frames, W,
                                             H)[0, ...]
            mag_example = mag_batch.view(bs, n_bands, n_phase_frames, W,
                                         H)[0, ...]
            denoised_phase_example = denoised_phase_batch[0, ...]
            phase_diff_example = phase_difference_batch[0, ...]
            self.show_3D_subplots(phase_example,
                                  title="phase example",
                                  first_k_frames=2)
            self.show_3D_subplots(mag_example,
                                  title="magnitude example",
                                  first_k_frames=2)
            self.show_3D_subplots(denoised_phase_example,
                                  title="denoised phase example",
                                  first_k_frames=2)
            self.show_3D_subplots(phase_diff_example,
                                  title="phase difference example",
                                  first_k_frames=2)
        # centre both results by subtracting their per-frame spatial mean
        mean = denoised_phase_batch.mean(-1).mean(-1)
        mean = mean.unsqueeze(-1).unsqueeze(-1)
        denoised_phase_batch = denoised_phase_batch - mean
        mean = phase_difference_batch.mean(-1).mean(-1)
        mean = mean.unsqueeze(-1).unsqueeze(-1)
        phase_difference_batch = phase_difference_batch - mean
        phase_difference_batch = torch.clamp(phase_difference_batch,
                                             -5 * math.pi, 5 * math.pi)
        if return_both:
            # remove one phase image so both tensors have the same size
            denoised_phase_batch = denoised_phase_batch[:, :, 1:, :]
            assert phase_difference_batch.size() == denoised_phase_batch.size()
            result = self.insert_tensors(phase_difference_batch,
                                         denoised_phase_batch,
                                         dim=2)
            # The result used to be moved to CUDA unconditionally; keep doing
            # so when CUDA exists, but do not crash on CPU-only machines.
            if torch.cuda.is_available():
                result = result.cuda()
            return result
        if return_phase:
            return denoised_phase_batch
        else:
            return phase_difference_batch

    def insert_tensors(self, t_a, t_b, dim):
        """Interleave two equally sized tensors along ``dim``.

        The result has twice the size of the inputs along ``dim`` and
        satisfies ``result[2k] == t_a[k]`` and ``result[2k + 1] == t_b[k]``
        along that dimension. It keeps the dtype and device of the inputs.

        The previous loop only filled the first half of the output and left
        the second half at zero.

        Raises:
            ValueError: if the two tensors differ in size.
        """
        if t_a.size() != t_b.size():
            raise ValueError(
                'insert_tensors needs equally sized tensors, got {} and {}'.
                format(tuple(t_a.size()), tuple(t_b.size())))
        dim = dim % t_a.dim()  # accept negative dims as well
        size = list(t_a.size())
        size[dim] = 2 * size[dim]
        # Stacking right after ``dim`` and merging the two axes yields the
        # a0, b0, a1, b1, ... ordering.
        return torch.stack((t_a, t_b), dim=dim + 1).reshape(size)

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot frames of every band as 3-D surfaces, one figure per band.

        Parameters:
            data: tensor with dimensions nbands, n_phase_frames, W, H.
            title: figure title prefix; the band index is appended.
            first_k_frames: number of leading frames to plot per band
                (all frames when None; capped at ``n_phase_frames``).
        """
        nbands, n_phase_frames, W, H = data.size()
        if first_k_frames is None:
            n = n_phase_frames
        else:
            n = min(first_k_frames, n_phase_frames)

        # Grids must have the same (W, H) shape as each frame; building them
        # the other way round only worked for square frames.
        Xm, Ym = np.meshgrid(range(1, H + 1), range(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps ``axes`` two-dimensional even for one column.
            fig, axes = plt.subplots(nrows=1,
                                     ncols=n,
                                     squeeze=False,
                                     subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = axes[0][j].plot_surface(Xm,
                                               Ym,
                                               img,
                                               rstride=1,
                                               cstride=1,
                                               cmap=cm.coolwarm,
                                               linewidth=0,
                                               antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()
class Phase_Difference_Extractor(object):
    """Steerable pyramid computation plus phase-difference extraction.

    Usage:
        build_pyramid(): build complex steerable pyramid coefficients
        extract(): extract phase differences
    """

    def __init__(self,
                 height=5,
                 nbands=4,
                 scale_factor=2,
                 extract_level=1,
                 device='cuda:0',
                 visualize=False):
        """
        Create the wrapped ``SCFpyr_PyTorch`` pyramid and store the settings.

        Parameters:
            height: int, default 5
                The coefficients levels including low-pass and high-pass
            nbands: int, default 4
                The number of orientations of the bandpass filters
            scale_factor: int, default 2
                Spatial resolution reduction factor
            extract_level: int, or list/tuple of int numbers, default 1
                If extract_level is an int number, build_pyramid() will only
                return the coefficients in one level;
                If extract_level is a list or tuple, build_pyramid() will
                return the coefficients of every listed level.
            device: default 'cuda:0'
                Forwarded to ``SCFpyr_PyTorch`` as ``device``.
            visualize: bool, default False
                If true, the build_pyramid() and extract() will show the
                processed results.
        """
        self.pyramid = SCFpyr_PyTorch(height=height,
                                      nbands=nbands,
                                      scale_factor=scale_factor,
                                      device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Compute the pyramid coefficients of the configured level(s).

        Parameters:
            im_batch: tensor with 4 dimensions:
                batch size, number of phase images, W, H.
            symmetry: bool, default True
                If true, the frames go through ``symmetric_extension_batch``
                before the pyramid is built, and only the first ``W // 2`` by
                ``H // 2`` part of every coefficient map is returned.

        Returns:
            For an int ``extract_level``: one tensor with dimensions
            batch size, nbands, number of phase images, W', H', 2.
            For a list/tuple ``extract_level``: a list with one such tensor
            per requested level.

        Raises:
            TypeError: if ``extract_level`` is neither an int nor a
                list/tuple (previously this surfaced as an unbound-local
                error only after all the work had been done).
            ValueError: if the pyramid does not return a list.
        """
        levels = self.extract_level
        if not isinstance(levels, (int, list, tuple)):
            raise TypeError(
                'extract_level must be an int or a list/tuple of ints, '
                'got {}'.format(type(levels).__name__))

        bs, num_phase_frames, W, H = im_batch.size()
        # the second dim is 1, indicating it's grayscale image
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)
        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_build_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(levels, int):
            return self._level_coeff(levels, coeff_batch, bs,
                                     num_phase_frames, symmetry)
        return [
            self._level_coeff(level, coeff_batch, bs, num_phase_frames,
                              symmetry) for level in levels
        ]

    def _show_build_example(self, trans_im_batch, coeff_batch, symmetry):
        """Display one example frame (and its crop) plus its coefficient grid."""
        # Prefer index 10 as before, but stay inside small batches instead of
        # failing with an index error.
        example_id = min(10, trans_im_batch.size(0) - 1)
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))

        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        example_img = 255 * example_img / example_img.max()
        Image.fromarray(example_img).show()
        if symmetry:
            # Also show the top-left quarter of the extended frame.
            W, H = example_img.shape
            Image.fromarray(example_img[:W // 2, :H // 2]).show()
        example_coeff.show()

    def _level_coeff(self, level, coeff_batch, bs, num_phase_frames,
                     symmetry):
        """Stack one level's bands and reshape to (bs, nbands, frames, W, H, 2)."""
        level_coeff = self.extract_coeff_level(level, coeff_batch)
        W, H, _ = level_coeff.size()[-3:]
        nbands = level_coeff.size()[0]
        level_coeff = level_coeff.view(nbands, bs, num_phase_frames, W, H, 2)
        level_coeff = level_coeff.permute(1, 0, 2, 3, 4, 5).contiguous()
        if symmetry:
            level_coeff = level_coeff[..., :W // 2, :H // 2, :]
        return level_coeff

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the band tensors of ``coeff_batch[level]`` along a new dim 0.

        The selected entry must be a list of tensors (one per band).
        """
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        extr_level_coeff_batch = torch.stack(extr_level_coeff_batch, 0)
        return extr_level_coeff_batch

    def extract(self, coeff_batch):
        """Derive frame-to-frame phase differences from coefficients.

        Parameters:
            coeff_batch: tensor with dimensions batch size, nbands, number of
                phase frames, W, H, 2 (2 is for real part and imaginary part).

        Returns:
            Phase differences of the denoised, time-unwrapped phase, size
            (bs, n_bands, n_phase_frames - 1, W, H), spatially mean-centred
            and clamped to [-5*pi, 5*pi].
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W,
                                             H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(
            trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))
        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W,
                                   H) + EPS  # TO avoid mag==0
        assert not (mag_batch <= 0.0).any()

        # phase unwrap over time (dim -3 is the frame axis here)
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)
        # phase denoising (amplitude-based gaussian blur)
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(
            mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(
            bs, n_bands, n_phase_frames, W, H)
        # phase difference between consecutive frames
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)
        if self.visualize:
            # Show the first batch entry of every intermediate result.
            phase_example = phase_batch.view(bs, n_bands, n_phase_frames, W,
                                             H)[0, ...]
            mag_example = mag_batch.view(bs, n_bands, n_phase_frames, W,
                                         H)[0, ...]
            denoised_phase_example = denoised_phase_batch[0, ...]
            phase_diff_example = phase_difference_batch[0, ...]
            self.show_3D_subplots(phase_example,
                                  title="phase example",
                                  first_k_frames=2)
            self.show_3D_subplots(mag_example,
                                  title="magnitude example",
                                  first_k_frames=2)
            self.show_3D_subplots(denoised_phase_example,
                                  title="denoised phase example",
                                  first_k_frames=2)
            self.show_3D_subplots(phase_diff_example,
                                  title="phase difference example",
                                  first_k_frames=2)
        # The denoised phase is centred here as well, although only the
        # centred phase difference is returned.
        mean = denoised_phase_batch.mean(-1).mean(-1)
        mean = mean.unsqueeze(-1).unsqueeze(-1)
        denoised_phase_batch = denoised_phase_batch - mean
        mean = phase_difference_batch.mean(-1).mean(-1)
        mean = mean.unsqueeze(-1).unsqueeze(-1)
        phase_difference_batch = phase_difference_batch - mean
        phase_difference_batch = torch.clamp(phase_difference_batch,
                                             -5 * math.pi, 5 * math.pi)
        return phase_difference_batch

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot frames of every band as 3-D surfaces, one figure per band.

        Parameters:
            data: tensor with dimensions nbands, n_phase_frames, W, H.
            title: figure title prefix; the band index is appended.
            first_k_frames: number of leading frames to plot per band
                (all frames when None; capped at ``n_phase_frames``).
        """
        nbands, n_phase_frames, W, H = data.size()
        if first_k_frames is None:
            n = n_phase_frames
        else:
            n = min(first_k_frames, n_phase_frames)

        # Grids must have the same (W, H) shape as each frame; building them
        # the other way round only worked for square frames.
        Xm, Ym = np.meshgrid(range(1, H + 1), range(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps ``axes`` two-dimensional even for one column.
            fig, axes = plt.subplots(nrows=1,
                                     ncols=n,
                                     squeeze=False,
                                     subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = axes[0][j].plot_surface(Xm,
                                               Ym,
                                               img,
                                               rstride=1,
                                               cstride=1,
                                               cmap=cm.coolwarm,
                                               linewidth=0,
                                               antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()