def STSIM2(self, img1, img2):
        """Compute an STSIM-2 style similarity score between two image batches.

        The score averages two groups of terms:
          * ``self.pooling`` applied to each pair of corresponding subbands of
            a subsampled steerable pyramid, and
          * ``self.compute_cross_term`` (spatially averaged) on magnitudes of a
            non-subsampled pyramid, for adjacent scales at the same orientation
            and for every pair of orientations at the same scale.

        :param img1: image batch; must have the same shape as ``img2``
        :param img2: image batch; must have the same shape as ``img1``
        :return: mean of all stacked terms, reduced over the term axis (dim 0)
        """
        assert img1.shape == img2.shape

        pyr_sub = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        pyr_full = SCFpyr_PyTorch(sub_sample=False, device=self.device)

        # Per-subband terms from the subsampled pyramid.
        terms = list(map(self.pooling,
                         pyr_sub.getlist(pyr_sub.build(img1)),
                         pyr_sub.getlist(pyr_sub.build(img2))))

        # Cross terms come from the non-subsampled pyramid.
        bands_a = pyr_full.build(img1)
        bands_b = pyr_full.build(img2)
        n_orient = len(bands_a[1])
        mag = self.abs

        # Across scale, same orientation: pairs (scale - 1, scale).
        for scale in range(2, len(bands_a) - 1):
            finer = scale - 1
            for k in range(n_orient):
                cross = self.compute_cross_term(mag(bands_a[finer][k]),
                                                mag(bands_a[scale][k]),
                                                mag(bands_b[finer][k]),
                                                mag(bands_b[scale][k]))
                terms.append(cross.mean(dim=[1, 2, 3]))

        # Across orientation, same scale: every unordered orientation pair.
        for scale in range(1, len(bands_a) - 1):
            row_a, row_b = bands_a[scale], bands_b[scale]
            for k in range(n_orient - 1):
                a_k = mag(row_a[k])
                b_k = mag(row_b[k])
                for k2 in range(k + 1, n_orient):
                    cross = self.compute_cross_term(a_k, mag(row_a[k2]),
                                                    b_k, mag(row_b[k2]))
                    terms.append(cross.mean(dim=[1, 2, 3]))

        return torch.mean(torch.stack(terms), dim=0)
    def STSIM(self, img1, img2, sub_sample=True):
        """Compute the STSIM-1 similarity between two grayscale image batches.

        :param img1: tensor of shape [N, 1, H, W]
        :param img2: tensor with exactly the same shape as ``img1``
        :param sub_sample: forwarded to ``SCFpyr_PyTorch``
        :return: mean over subbands of ``self.pooling`` applied to each pair of
            corresponding subbands (stacked, then averaged over dim 0)
        """
        assert img1.shape == img2.shape
        assert len(img1.shape) == 4  # [N,C,H,W]
        assert img1.shape[1] == 1  # gray image

        pyramid = SCFpyr_PyTorch(sub_sample=sub_sample, device=self.device)
        bands_1 = pyramid.getlist(pyramid.build(img1))
        bands_2 = pyramid.getlist(pyramid.build(img2))

        scores = [self.pooling(b1, b2) for b1, b2 in zip(bands_1, bands_2)]
        return torch.mean(torch.stack(scores), dim=0)
    def STSIM_M(self, imgs):
        """Compute the STSIM-M feature vector for a batch of grayscale images.

        Features are appended in this order:
          * for every subband of a subsampled pyramid (magnitudes via
            ``self.abs``): mean, variance, and the vertical and horizontal
            lag-1 mean products, each divided by the variance;
          * for each scale in ``coeffs[1:-1]``: the mean product of the
            magnitudes of every pair of orientations;
          * for each orientation and each pair of adjacent scales: the mean
            product of the magnitudes (finer band resized to the coarser band
            with ``F.interpolate``) divided by both standard deviations.

        :param imgs: [N,C=1,H,W]
        :return: the per-feature results stacked along a new leading dimension
        """
        pyramid = SCFpyr_PyTorch(sub_sample=True, device=self.device)
        coeffs = pyramid.build(imgs)
        axes = [1, 2, 3]
        feats = []

        # Single-subband statistics.
        for band in pyramid.getlist(coeffs):
            mag = self.abs(band)
            variance = torch.var(mag, dim=axes)
            vert = torch.mean(mag[:, :, :-1, :] * mag[:, :, 1:, :], dim=axes)
            horiz = torch.mean(mag[:, :, :, :-1] * mag[:, :, :, 1:], dim=axes)
            feats.extend([torch.mean(mag, dim=axes), variance,
                          vert / variance, horiz / variance])

        # Correlations across orientations within each band-pass scale.
        for orients in coeffs[1:-1]:
            for first, second in itertools.combinations(orients, 2):
                feats.append(
                    torch.mean(self.abs(first) * self.abs(second), dim=axes))

        # Correlations across adjacent scales for each orientation.
        for k in range(len(coeffs[1])):
            for lvl in range(1, len(coeffs) - 2):
                finer = self.abs(coeffs[lvl][k])
                coarser = self.abs(coeffs[lvl + 1][k])
                finer = F.interpolate(finer, size=coarser.shape[2:])
                std_f = torch.sqrt(torch.var(finer, dim=axes))
                std_c = torch.sqrt(torch.var(coarser, dim=axes))
                feats.append(
                    torch.mean(finer * coarser, dim=axes) / std_f / std_c)
        return torch.stack(feats)
Ejemplo n.º 4
0
                         scale_factor=config.pyr_scale_factor,
                         device=device)

    ############################################################################
    # Create a batch and feed-forward

    start_time = time.time()

    # Load Batch
    im_batch_numpy = utils.load_image_batch(config.image_file,
                                            config.batch_size,
                                            config.image_size)
    im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)

    # Compute Steerable Pyramid
    coeff = pyr.build(im_batch_torch)

    duration = time.time() - start_time
    print(
        'Finishing decomposing {batch_size} images in {duration:.1f} seconds.'.
        format(batch_size=config.batch_size, duration=duration))

    ############################################################################
    # Visualization

    # Just extract a single example from the batch
    # Also moves the example to CPU and NumPy
    coeff = utils.extract_from_batch(coeff, 0)

    if config.visualize:
        import cv2
Ejemplo n.º 5
0
# Round-trip demo: decompose one grayscale image with SCFpyr_PyTorch,
# reconstruct it, report the MSE and show both images.
# Requires PyTorch with MKL when setting to 'cpu'
device = torch.device('cpu')

# Load the test image as a single-channel (grayscale) array.
img = cv2.imread('./assets/lena.jpg', 0)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError('./assets/lena.jpg could not be read')
cv2.imshow('yuantu', img)  # window title 'yuantu' = "original image"
im_torch = torch.from_numpy(img).to(device)
# Add batch and channel dims: [H, W] -> [1, 1, H, W]
im_batch_torch = im_torch.unsqueeze(0).unsqueeze(0).float()

# Initialize Complex Steerable Pyramid
height = 12
nbands = 4
scale_factor = 2**(1/2)
pyr = SCFpyr_PyTorch(height=height, nbands=nbands, scale_factor=scale_factor, device=device)
pyr_type = 1

# Decompose the image batch
coeff = pyr.build(im_batch_torch, pyr_type)

# Reconstruct the image again
img_recon = pyr.reconstruct(coeff, pyr_type)
img = im_torch.float()
recon = img_recon.squeeze()
loss = torch.nn.MSELoss()
print('MSE:', loss(img, recon))
# Clip before casting: converting out-of-range floats to uint8 is
# platform-dependent (values typically wrap instead of saturating).
recon_display = np.clip(recon.detach().cpu().numpy(), 0, 255).astype(np.uint8)
cv2.imshow('recon', recon_display)
# Visualization
# coeff_single = utils.extract_from_batch(coeff, 0)
# coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True)
# cv2.imshow('Complex Steerable Pyramid', coeff_grid)
cv2.waitKey(0)
Ejemplo n.º 6
0
# Run the trained model on one batch of frame triplets and display the
# predicted middle frame, reconstructed one colour channel at a time.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints from a trusted source.
model = torch.load('./model/2019-04-15 21:46:19_model.pkl')
model.eval()
# images_list: one stacked (start, inter, end) tensor per sample, [3, C, H, W].
# Sized from the batch itself (as in the training loop) so that a final,
# smaller batch does not index out of range.
images_list = [
    torch.stack([
        Triplets_batch['start'][i], Triplets_batch['inter'][i],
        Triplets_batch['end'][i]
    ]) for i in range(len(Triplets_batch['start']))
]

# NOTE(review): output buffer is hard-coded to 256x256 — assumes 256x256 inputs.
img_recon = np.empty(shape=(256, 256, 3))
plt.figure(1)
with torch.no_grad():
    for channel in range(3):
        # Decompose this colour channel of every triplet ([3, 1, H, W] each).
        batch_coeff_list = [
            pyr.build(image[:, channel, :, :].unsqueeze(1).to(device),
                      pyr_type=pyr_type) for image in images_list
        ]
        train_coeff, truth_coeff = get_input(batch_coeff_list)
        pre_coeff = model(train_coeff)

        truth_img = Triplets_batch['inter'][:, channel, :, :]
        pre_img = pyr.reconstruct(output_convert(pre_coeff), pyr_type=pyr_type)
        print(pre_img.shape)
        # .numpy() only works on host memory, so move off the device first.
        img = pre_img[0].cpu().numpy()
        img_recon[:, :, channel] = img
        # Min-max stretch to [0, 255] for display; a constant image has zero
        # range, so show it as black instead of dividing by zero.
        span = img.max() - img.min()
        if span > 0:
            img_disp = 255 * (img - img.min()) / span
        else:
            img_disp = np.zeros_like(img)
        plt.subplot(1, 3, channel + 1)
        plt.imshow(img_disp.astype(np.uint8), 'gray')
        plt.title('channel:{}'.format(channel))
Ejemplo n.º 7
0
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9,0.999))

# The loader is loop-invariant, so build it once; with shuffle=True it still
# reshuffles every time it is iterated.
trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

# Train the model: per epoch, one pass over the data for each colour channel.
total_step = 0
for epoch in range(num_epochs):
    for channel in range(3):
        for n, Triplets_batch in enumerate(trainloader):
            # images_list: one stacked (start, inter, end) tensor per sample;
            # sized from the batch, so a smaller final batch also works.
            images_list = [torch.stack([Triplets_batch['start'][i],
                                        Triplets_batch['inter'][i],
                                        Triplets_batch['end'][i]]) for i in range(len(Triplets_batch['start']))]
            # Decompose this colour channel of every triplet.
            batch_coeff_list = [pyr.build(image[:, channel, :, :].unsqueeze(1).to(device), pyr_type=pyr_type)
                                for image in images_list]
            train_coeff, truth_coeff = get_input(batch_coeff_list)

            # Forward pass
            pre_coeff = model(train_coeff)

            truth_img = Triplets_batch['inter'][:, channel, :, :]
            pre_img = pyr.reconstruct(output_convert(pre_coeff), pyr_type=pyr_type)
            # Loss sees both coefficient-domain and image-domain targets.
            loss = criterion(truth_coeff, pre_coeff, truth_img, pre_img)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Ejemplo n.º 8
0
# Decompose and reconstruct with the NumPy reference implementation.
coeff_numpy = pyr_numpy.build(image)
reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)
# Clip before casting: converting floats outside [0, 255] to uint8 is
# platform-dependent (values typically wrap instead of saturating).
reconstruction_numpy = np.clip(reconstruction_numpy, 0, 255).astype(np.uint8)

print('#' * 60)

################################################################################
# PyTorch

# Fall back to the CPU when no CUDA device is available
# (CPU FFTs require PyTorch built with MKL).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# [H, W] -> [1, 1, H, W] float batch
im_batch = torch.from_numpy(image[None, None, :, :])
im_batch = im_batch.to(device).float()

pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)
reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
reconstruction_torch = reconstruction_torch.cpu().numpy()[0, ]

# Extract first example from the batch and move to CPU
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

################################################################################
# Check correctness

print('#' * 60)
assert len(coeff_numpy) == len(coeff_torch)

for level, _ in enumerate(coeff_numpy):

    print('Pyramid Level {level}'.format(level=level))
Ejemplo n.º 9
0
# Compare the non-subsampled PyTorch pyramid with the reference
# SteerableNoSub implementation on one grayscale image.
import sys

import cv2
import numpy as np
import torch

sys.path.append('..')
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from perceptual.filterbank import SteerableNoSub

device = torch.device('cuda:0')
pyr_NoSub = SCFpyr_PyTorch(sub_sample=False, device=device)

path = ''  # TODO: set to the grayscale test image
img = cv2.imread(path, 0)
if img is None:
    # cv2.imread returns None instead of raising for unreadable paths.
    raise FileNotFoundError('could not read image at {!r}'.format(path))
img_batch = torch.from_numpy(img).to(device)
# [H, W] -> [1, 1, H, W] float batch
img_batch = img_batch.unsqueeze(0).float().unsqueeze(0)

coeffs = pyr_NoSub.build(img_batch)

pyr_NoSub_c = SteerableNoSub()
coeffs_c = pyr_NoSub_c.buildSCFpyr(img)

tolerance = 1e-3

coeff = coeffs[0].cpu().numpy().squeeze(0)
all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance)
# Sum absolute differences: a signed sum lets positive and negative errors cancel.
s = np.sum(np.abs(coeff - coeffs_c[0]))
print('Succesful for subband {}: {}, with tolerance of {}'.format(
    0, all_close, tolerance))
print('Sum of difference: {}'.format(s))
for i in range(1, len(coeffs) - 1):
    for j in range(len(coeffs[i])):
        coeff = coeffs[i][j].cpu().numpy().squeeze(0)
Ejemplo n.º 10
0
# Compare the subsampled PyTorch pyramid with the reference Steerable
# implementation on one grayscale image.
import sys

import cv2
import numpy as np
import torch

sys.path.append('..')
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from perceptual.filterbank import Steerable


device = torch.device('cuda:0')
pyr = SCFpyr_PyTorch(sub_sample=True, device=device)

path = ''  # TODO: set to the grayscale test image
img = cv2.imread(path, 0)
if img is None:
    # cv2.imread returns None instead of raising for unreadable paths.
    raise FileNotFoundError('could not read image at {!r}'.format(path))
img_batch = torch.from_numpy(img).to(device)
# [H, W] -> [1, 1, H, W] float batch
img_batch = img_batch.unsqueeze(0).float().unsqueeze(0)


coeffs = pyr.build(img_batch)

pyr_c = Steerable()
coeffs_c = pyr_c.buildSCFpyr(img)

tolerance = 1e-3

coeff = coeffs[0].cpu().numpy().squeeze(0)
all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance)
# Sum absolute differences: a signed sum lets positive and negative errors cancel.
s = np.sum(np.abs(coeff - coeffs_c[0]))
print('Succesful for subband {}: {}, with tolerance of {}'.format(0,all_close, tolerance))
print('Sum of difference: {}'.format(s))
for i in range(1, len(coeffs) - 1):
    for j in range(len(coeffs[i])):
        coeff = coeffs[i][j].cpu().numpy().squeeze(0)
        # Combine the trailing (real, imaginary) pair into a complex array.
        coeff = coeff[..., 0] + 1j * coeff[..., 1]
Ejemplo n.º 11
0
class Steerable_Pyramid_Phase(object):
    """Complex steerable pyramid decomposition followed by phase extraction.

    ``build_pyramid`` decomposes batches of image sequences with
    ``SCFpyr_PyTorch`` and returns the complex coefficients of the selected
    level(s). ``extract_phase`` turns those coefficients into denoised phase
    and/or phase differences along the frame axis.
    """

    def __init__(self,
                 height=5,
                 nbands=4,
                 scale_factor=2,
                 device=None,
                 extract_level=1,
                 visualize=False):
        """
        :param height: number of pyramid levels, forwarded to SCFpyr_PyTorch
        :param nbands: number of orientations, forwarded to SCFpyr_PyTorch
        :param scale_factor: resolution reduction factor between levels,
            forwarded to SCFpyr_PyTorch
        :param device: device forwarded to SCFpyr_PyTorch
        :param extract_level: int or list of int; level(s) returned by
            ``build_pyramid``
        :param visualize: if True, ``build_pyramid`` and ``extract_phase``
            display intermediate results
        """
        self.pyramid = SCFpyr_PyTorch(height=height,
                                      nbands=nbands,
                                      scale_factor=scale_factor,
                                      device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.device = device
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Decompose a batch of image sequences with the steerable pyramid.

        :param im_batch: tensor of shape [batch size, number of phase images, W, H]
        :param symmetry: if True, images are extended with
            ``symmetric_extension_batch`` before decomposition and the returned
            coefficients are cropped to their top-left quarter
        :return: for an int ``extract_level``, a tensor of shape
            [bs, nbands, num_phase_frames, W', H', 2]; for a list, a list of
            such tensors, one per requested level
        :raises ValueError: if the pyramid does not return a list
        :raises TypeError: if ``extract_level`` is neither an int nor a list
        """
        bs, num_phase_frames, W, H = im_batch.size()
        # Fold the frame axis into the batch; channel dim of 1 = grayscale.
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)
        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_pyramid_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(self.extract_level, int):
            return self._level_to_batch(
                self.extract_coeff_level(self.extract_level, coeff_batch),
                bs, num_phase_frames, symmetry)
        if isinstance(self.extract_level, list):
            return [
                self._level_to_batch(
                    self.extract_coeff_level(level, coeff_batch),
                    bs, num_phase_frames, symmetry)
                for level in self.extract_level
            ]
        # Previously this fell through to an UnboundLocalError.
        raise TypeError('extract_level must be an int or a list of ints, '
                        'got {}'.format(type(self.extract_level).__name__))

    @staticmethod
    def _level_to_batch(level_coeff_batch, bs, num_phase_frames, symmetry):
        """Reshape [nbands, bs*frames, W, H, 2] to [bs, nbands, frames, W, H, 2], optionally cropping."""
        W, H, _ = level_coeff_batch.size()[-3:]
        nbands = level_coeff_batch.size()[0]
        level_coeff_batch = level_coeff_batch.view(nbands, bs,
                                                   num_phase_frames, W, H, 2)
        level_coeff_batch = level_coeff_batch.permute(1, 0, 2, 3, 4,
                                                      5).contiguous()
        if symmetry:
            # Keep the top-left quarter (undo the symmetric extension).
            level_coeff_batch = level_coeff_batch[..., :W // 2, :H // 2, :]
        return level_coeff_batch

    def _show_pyramid_example(self, trans_im_batch, coeff_batch, symmetry):
        """Display one input image and its pyramid coefficients (visualize mode)."""
        example_id = 10  # index 10 in the folded (bs * frames) batch
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))
        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        scaled_img = 255 * example_img / example_img.max()
        Image.fromarray(scaled_img).show()
        if symmetry:
            # Top-left quarter, presumably the region before symmetric extension.
            rows, cols = scaled_img.shape
            Image.fromarray(scaled_img[:rows // 2, :cols // 2]).show()
        example_coeff.show()

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the list of per-orientation tensors at ``coeff_batch[level]`` along a new dim 0."""
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        return torch.stack(extr_level_coeff_batch, 0)

    @staticmethod
    def _subtract_spatial_mean(batch):
        """Subtract the mean over the last two (spatial) axes, per leading index."""
        mean = batch.mean(-1).mean(-1)
        return batch - mean.unsqueeze(-1).unsqueeze(-1)

    def extract_phase(self,
                      coeff_batch,
                      return_phase=False,
                      return_both=False):
        """Turn complex pyramid coefficients into denoised phase / phase differences.

        Steps: phase and magnitude per coefficient; phase unwrapping along the
        frame axis (``torch_unwrap``); amplitude-based Gaussian blur
        (``amplitude_based_gaussian_blur``); differences along the frame axis
        (``torch_diff``); removal of each map's spatial mean; clamping of the
        differences to [-5*pi, 5*pi].

        :param coeff_batch: tensor of shape [batch size, nbands, number of phase
            frames, W, H, 2], last axis = (real part, imaginary part)
        :param return_phase: if True (and ``return_both`` is False), return the
            mean-centred denoised phase, shape [bs, nbands, frames, W, H]
        :param return_both: if True, return phase differences and denoised
            phases (first frame dropped) interleaved along the frame axis,
            shape [bs, nbands, 2 * (frames - 1), W, H], on the input's device
        :return: by default the mean-centred, clamped phase differences,
            shape [bs, nbands, frames - 1, W, H]
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W,
                                             H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(
            trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))
        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        # Offset by EPS so that no magnitude is exactly zero.
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W, H) + EPS
        assert not (mag_batch <= 0.0).any()

        # phase unwrap over the frame axis
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)
        # phase denoising (amplitude-based gaussian blur)
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(
            mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(
            bs, n_bands, n_phase_frames, W, H)
        # phase difference along the frame axis (frames -> frames - 1)
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)
        if self.visualize:
            self._show_phase_examples(phase_batch, mag_batch,
                                      denoised_phase_batch,
                                      phase_difference_batch)

        denoised_phase_batch = self._subtract_spatial_mean(denoised_phase_batch)
        phase_difference_batch = self._subtract_spatial_mean(
            phase_difference_batch)
        phase_difference_batch = torch.clamp(phase_difference_batch,
                                             -5 * math.pi, 5 * math.pi)
        if return_both:
            # Drop the first frame so phases line up with the N-1 differences.
            denoised_phase_batch = denoised_phase_batch[:, :, 1:, :]
            assert phase_difference_batch.size() == denoised_phase_batch.size()
            # insert_tensors keeps the inputs' device and dtype, so no
            # hard-coded .cuda() transfer is needed.
            return self.insert_tensors(phase_difference_batch,
                                       denoised_phase_batch,
                                       dim=2)
        if return_phase:
            return denoised_phase_batch
        return phase_difference_batch

    def _show_phase_examples(self, phase_batch, mag_batch,
                             denoised_phase_batch, phase_difference_batch):
        """Plot the first batch element of each intermediate phase tensor."""
        bs, n_bands, n_phase_frames, W, H = denoised_phase_batch.size()
        full_shape = (bs, n_bands, n_phase_frames, W, H)
        examples = [
            (phase_batch.view(full_shape)[0, ...], "phase example"),
            (mag_batch.view(full_shape)[0, ...], "magnitude example"),
            (denoised_phase_batch[0, ...], "denoised phase example"),
            (phase_difference_batch[0, ...], "phase difference example"),
        ]
        for data, title in examples:
            self.show_3D_subplots(data, title=title, first_k_frames=2)

    def insert_tensors(self, t_a, t_b, dim):
        """Interleave two equally-shaped tensors along ``dim``.

        The result is twice as long as ``t_a`` along ``dim`` and holds
        ``t_a[0], t_b[0], t_a[1], t_b[1], ...`` along that axis. It keeps the
        dtype and device of the inputs. Negative ``dim`` is supported.
        """
        dim = dim % t_a.dim()
        # Stack pairs on a new axis right after ``dim``, then merge the two.
        return torch.stack((t_a, t_b), dim=dim + 1).flatten(dim, dim + 1)

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot one figure per orientation with a 3-D surface per frame.

        :param data: tensor of shape [nbands, n_phase_frames, W, H]
        :param title: figure title prefix
        :param first_k_frames: number of frames to plot (all if None)
        """
        nbands, n_phase_frames, W, H = data.size()
        n = first_k_frames if first_k_frames is not None else n_phase_frames

        # Grids of shape (W, H) to match each frame; x runs along columns.
        Xm, Ym = np.meshgrid(range(1, H + 1), range(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps a 2-D axes array even when n == 1.
            fig, axes = plt.subplots(nrows=1,
                                     ncols=n,
                                     squeeze=False,
                                     subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = axes[0, j].plot_surface(Xm,
                                               Ym,
                                               img,
                                               rstride=1,
                                               cstride=1,
                                               cmap=cm.coolwarm,
                                               linewidth=0,
                                               antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()
class Phase_Difference_Extractor(object):
    """Steerable pyramid decomposition plus phase-difference extraction."""

    def __init__(self,
                 height=5,
                 nbands=4,
                 scale_factor=2,
                 extract_level=1,
                 device='cuda:0',
                 visualize=False):
        """
        Phase_Difference_Extractor: runs the steerable pyramid decomposition
        and extracts the phase and the phase difference.

        Parameters:
            height: int, default 5
                Number of coefficient levels, including the low-pass and
                high-pass residuals (forwarded to SCFpyr_PyTorch).
            nbands: int, default 4
                The number of orientations of the band-pass filters.
            scale_factor: int, default 2
                Spatial resolution reduction factor between levels.
            extract_level: int, or list of int, default 1
                If an int, build_pyramid() returns the coefficients of that
                single level; if a list, build_pyramid() returns a list with
                the coefficients of each listed level.
            device: default 'cuda:0'
                Device forwarded to SCFpyr_PyTorch.
            visualize: bool, default False
                If True, build_pyramid() and extract() show the processed results.
        """
        self.pyramid = SCFpyr_PyTorch(height=height,
                                      nbands=nbands,
                                      scale_factor=scale_factor,
                                      device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.device = device
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Decompose a batch of image sequences with the steerable pyramid.

        :param im_batch: tensor of shape [batch size, number of phase images, W, H]
        :param symmetry: if True, images are extended with
            ``symmetric_extension_batch`` before decomposition and the returned
            coefficients are cropped to their top-left quarter
        :return: for an int ``extract_level``, a tensor of shape
            [bs, nbands, num_phase_frames, W', H', 2]; for a list, a list of
            such tensors, one per requested level
        :raises ValueError: if the pyramid does not return a list
        :raises TypeError: if ``extract_level`` is neither an int nor a list
        """
        bs, num_phase_frames, W, H = im_batch.size()
        # Fold the frame axis into the batch; channel dim of 1 = grayscale.
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)
        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_pyramid_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(self.extract_level, int):
            return self._level_to_batch(
                self.extract_coeff_level(self.extract_level, coeff_batch),
                bs, num_phase_frames, symmetry)
        if isinstance(self.extract_level, list):
            return [
                self._level_to_batch(
                    self.extract_coeff_level(level, coeff_batch),
                    bs, num_phase_frames, symmetry)
                for level in self.extract_level
            ]
        # Previously this fell through to an UnboundLocalError.
        raise TypeError('extract_level must be an int or a list of ints, '
                        'got {}'.format(type(self.extract_level).__name__))

    @staticmethod
    def _level_to_batch(level_coeff_batch, bs, num_phase_frames, symmetry):
        """Reshape [nbands, bs*frames, W, H, 2] to [bs, nbands, frames, W, H, 2], optionally cropping."""
        W, H, _ = level_coeff_batch.size()[-3:]
        nbands = level_coeff_batch.size()[0]
        level_coeff_batch = level_coeff_batch.view(nbands, bs,
                                                   num_phase_frames, W, H, 2)
        level_coeff_batch = level_coeff_batch.permute(1, 0, 2, 3, 4,
                                                      5).contiguous()
        if symmetry:
            # Keep the top-left quarter (undo the symmetric extension).
            level_coeff_batch = level_coeff_batch[..., :W // 2, :H // 2, :]
        return level_coeff_batch

    def _show_pyramid_example(self, trans_im_batch, coeff_batch, symmetry):
        """Display one input image and its pyramid coefficients (visualize mode)."""
        example_id = 10  # index 10 in the folded (bs * frames) batch
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))
        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        scaled_img = 255 * example_img / example_img.max()
        Image.fromarray(scaled_img).show()
        if symmetry:
            # Top-left quarter, presumably the region before symmetric extension.
            rows, cols = scaled_img.shape
            Image.fromarray(scaled_img[:rows // 2, :cols // 2]).show()
        example_coeff.show()

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the list of per-orientation tensors at ``coeff_batch[level]`` along a new dim 0."""
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        return torch.stack(extr_level_coeff_batch, 0)

    @staticmethod
    def _subtract_spatial_mean(batch):
        """Subtract the mean over the last two (spatial) axes, per leading index."""
        mean = batch.mean(-1).mean(-1)
        return batch - mean.unsqueeze(-1).unsqueeze(-1)

    def extract(self, coeff_batch):
        """Compute mean-centred, clamped phase differences from pyramid coefficients.

        Steps: phase and magnitude per coefficient; phase unwrapping along the
        frame axis (``torch_unwrap``); amplitude-based Gaussian blur
        (``amplitude_based_gaussian_blur``); differences along the frame axis
        (``torch_diff``); removal of each map's spatial mean; clamping to
        [-5*pi, 5*pi].

        :param coeff_batch: tensor of shape [batch size, nbands, number of phase
            frames, W, H, 2], last axis = (real part, imaginary part)
        :return: tensor of shape [bs, nbands, frames - 1, W, H]
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W,
                                             H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(
            trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))
        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        # Offset by EPS so that no magnitude is exactly zero.
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W, H) + EPS
        assert not (mag_batch <= 0.0).any()

        # phase unwrap over the frame axis
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)
        # phase denoising (amplitude-based gaussian blur)
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(
            mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(
            bs, n_bands, n_phase_frames, W, H)
        # phase difference along the frame axis (frames -> frames - 1)
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)
        if self.visualize:
            self._show_phase_examples(phase_batch, mag_batch,
                                      denoised_phase_batch,
                                      phase_difference_batch)

        phase_difference_batch = self._subtract_spatial_mean(
            phase_difference_batch)
        return torch.clamp(phase_difference_batch, -5 * math.pi, 5 * math.pi)

    def _show_phase_examples(self, phase_batch, mag_batch,
                             denoised_phase_batch, phase_difference_batch):
        """Plot the first batch element of each intermediate phase tensor."""
        bs, n_bands, n_phase_frames, W, H = denoised_phase_batch.size()
        full_shape = (bs, n_bands, n_phase_frames, W, H)
        examples = [
            (phase_batch.view(full_shape)[0, ...], "phase example"),
            (mag_batch.view(full_shape)[0, ...], "magnitude example"),
            (denoised_phase_batch[0, ...], "denoised phase example"),
            (phase_difference_batch[0, ...], "phase difference example"),
        ]
        for data, title in examples:
            self.show_3D_subplots(data, title=title, first_k_frames=2)

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot one figure per orientation with a 3-D surface per frame.

        :param data: tensor of shape [nbands, n_phase_frames, W, H]
        :param title: figure title prefix
        :param first_k_frames: number of frames to plot (all if None)
        """
        nbands, n_phase_frames, W, H = data.size()
        n = first_k_frames if first_k_frames is not None else n_phase_frames

        # Grids of shape (W, H) to match each frame; x runs along columns.
        Xm, Ym = np.meshgrid(range(1, H + 1), range(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps a 2-D axes array even when n == 1.
            fig, axes = plt.subplots(nrows=1,
                                     ncols=n,
                                     squeeze=False,
                                     subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = axes[0, j].plot_surface(Xm,
                                               Ym,
                                               img,
                                               rstride=1,
                                               cstride=1,
                                               cmap=cm.coolwarm,
                                               linewidth=0,
                                               antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()