def __init__(self, height=5, nbands=4, scale_factor=2, extract_level=1, visualize=False, device=None):
    '''Phase_Difference_Extractor: A class to do steerable pyramid computation,
    extract the phase and phase difference

    Usage:
        build_pyramid(): build complex steerable pyramid coefficients
        extract(): extract phase differences

    Parameters:
        height: int, default 5
            The coefficients levels including low-pass and high-pass
        nbands: int, default 4
            The number of orientations of the bandpass filters
        scale_factor: int, default 2
            Spatial resolution reduction scale scale_factor
        extract_level: int, or list of int numbers, default 1
            If extract_level is an int number, build_pyramid() will only return
            the coefficients in one level; if extract_level is a list,
            build_pyramid() will return the coefficients of multiple levels.
        visualize: bool, default False
            If true, the build_pyramid() and extract() will show the processed results.
        device: optional, default None
            Device passed to SCFpyr_PyTorch and stored as ``self.device``.
            When None, the device returned by ``get_device()`` is used, which
            matches the previous hard-wired behaviour.
    '''
    if device is None:
        device = get_device()
    self.pyramid = SCFpyr_PyTorch(height=height, nbands=nbands,
                                  scale_factor=scale_factor, device=device)
    self.height = height
    self.nbands = nbands
    self.scale_factor = scale_factor
    # Stored so methods that read self.device use the same device as the pyramid.
    self.device = device
    self.extract_level = extract_level
    self.visualize = visualize
def STSIM(self, img1, img2, sub_sample=True):
    """Return the STSIM score between two grayscale image batches.

    Both inputs must have the same shape [N, C=1, H, W]. Each is decomposed
    with an ``SCFpyr_PyTorch`` pyramid, ``self.pooling`` is applied to every
    pair of corresponding subbands, and the per-subband results are stacked
    along a new first dimension and averaged over it.
    """
    assert img1.shape == img2.shape
    assert len(img1.shape) == 4  # [N,C,H,W]
    assert img1.shape[1] == 1  # gray image

    pyramid = SCFpyr_PyTorch(sub_sample=sub_sample, device=self.device)
    subbands_1 = pyramid.getlist(pyramid.build(img1))
    subbands_2 = pyramid.getlist(pyramid.build(img2))

    per_band = [self.pooling(band_1, band_2)
                for band_1, band_2 in zip(subbands_1, subbands_2)]
    return torch.stack(per_band).mean(dim=0)
def __init__(self, height=5, nbands=4, scale_factor=2, device=None, extract_level=1, visualize=False):
    """Store the pyramid configuration and build the complex steerable pyramid.

    Args:
        height: number of pyramid levels passed to SCFpyr_PyTorch.
        nbands: number of orientation bands passed to SCFpyr_PyTorch.
        scale_factor: scale factor between levels passed to SCFpyr_PyTorch.
        device: device passed to SCFpyr_PyTorch (may be None).
        extract_level: level index (or list of indices) kept for later extraction.
        visualize: flag kept for later use by the extraction methods.
    """
    self.height, self.nbands, self.scale_factor = height, nbands, scale_factor
    self.device = device
    self.extract_level = extract_level
    self.visualize = visualize
    self.pyramid = SCFpyr_PyTorch(
        height=self.height,
        nbands=self.nbands,
        scale_factor=self.scale_factor,
        device=self.device,
    )
def STSIM_M(self, imgs):
    '''Compute the STSIM-M feature vector for a batch of grayscale images.

    Every statistic is reduced over dims [1, 2, 3], giving one value per image.
    Features are produced in this order:
      1. for each subband from ``getlist`` (after ``self.abs``): mean, variance,
         and the mean product of neighbours along dim 2 and along dim 3, each
         of the latter two divided by the variance;
      2. for every level in ``coeffs[1:-1]``: the mean product of every pair of
         orientation bands (after ``self.abs``);
      3. for every orientation and every pair of adjacent levels: the mean
         product of the two bands (the first resized to the second's spatial
         size) divided by the square roots of both variances.

    :param imgs: [N,C=1,H,W]
    :return: tensor of shape [num_features, N]
    '''
    pyr = SCFpyr_PyTorch(sub_sample=True, device=self.device)
    coeffs = pyr.build(imgs)
    dims = [1, 2, 3]
    features = []

    # 1. single-subband statistics
    for band in pyr.getlist(coeffs):
        band = self.abs(band)
        band_var = torch.var(band, dim=dims)
        features.extend([
            torch.mean(band, dim=dims),
            band_var,
            torch.mean(band[:, :, :-1, :] * band[:, :, 1:, :], dim=dims) / band_var,
            torch.mean(band[:, :, :, :-1] * band[:, :, :, 1:], dim=dims) / band_var,
        ])

    # 2. correlations across orientations within one level
    for level_bands in coeffs[1:-1]:
        for band_a, band_b in itertools.combinations(level_bands, 2):
            features.append(
                torch.mean(self.abs(band_a) * self.abs(band_b), dim=dims))

    # 3. correlations across adjacent levels, same orientation
    for orient in range(len(coeffs[1])):
        for level in range(1, len(coeffs) - 2):
            upper = self.abs(coeffs[level][orient])
            lower = self.abs(coeffs[level + 1][orient])
            upper = F.interpolate(upper, size=lower.shape[2:])
            features.append(
                torch.mean(upper * lower, dim=dims)
                / torch.sqrt(torch.var(upper, dim=dims))
                / torch.sqrt(torch.var(lower, dim=dims)))

    return torch.stack(features)
def STSIM2(self, img1, img2):
    """Return the STSIM-2 score between two image batches of equal shape.

    The score averages three groups of terms, all stacked along a new first dim:
      * ``self.pooling`` of corresponding subbands of a sub-sampled pyramid;
      * cross terms between adjacent levels with the same orientation;
      * cross terms between every pair of orientations on the same level.
    The cross terms come from a non-sub-sampled pyramid via
    ``self.compute_cross_term`` and are averaged over dims [1, 2, 3].
    """
    assert img1.shape == img2.shape
    sub_pyr = SCFpyr_PyTorch(sub_sample=True, device=self.device)
    full_pyr = SCFpyr_PyTorch(sub_sample=False, device=self.device)

    subbands_1 = sub_pyr.getlist(sub_pyr.build(img1))
    subbands_2 = sub_pyr.getlist(sub_pyr.build(img2))
    terms = [self.pooling(b1, b2) for b1, b2 in zip(subbands_1, subbands_2)]

    # Cross terms are computed on the full-resolution (no sub-sampling) pyramid.
    levels_1 = full_pyr.build(img1)
    levels_2 = full_pyr.build(img2)
    n_orient = len(levels_1[1])
    reduce_dims = [1, 2, 3]

    # Across scale, same orientation
    for scale in range(2, len(levels_1) - 1):
        for orient in range(n_orient):
            cross = self.compute_cross_term(
                self.abs(levels_1[scale - 1][orient]),
                self.abs(levels_1[scale][orient]),
                self.abs(levels_2[scale - 1][orient]),
                self.abs(levels_2[scale][orient]))
            terms.append(cross.mean(dim=reduce_dims))

    # Across orientation, same scale
    for scale in range(1, len(levels_1) - 1):
        for orient in range(n_orient - 1):
            ref_1 = self.abs(levels_1[scale][orient])
            ref_2 = self.abs(levels_2[scale][orient])
            for other in range(orient + 1, n_orient):
                cross = self.compute_cross_term(
                    ref_1, self.abs(levels_1[scale][other]),
                    ref_2, self.abs(levels_2[scale][other]))
                terms.append(cross.mean(dim=reduce_dims))

    return torch.stack(terms).mean(dim=0)
# Sanity check: decompose a grayscale image with the complex steerable pyramid,
# reconstruct it, report the MSE and display both images.
# Requires PyTorch with MKL when setting to 'cpu'
device = torch.device('cpu')

image_path = './assets/lena.jpg'
img = cv2.imread(image_path, 0)  # flag 0: load as single-channel grayscale
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError('Could not read image: {}'.format(image_path))
cv2.imshow('yuantu', img)

im_torch = torch.from_numpy(img).to(device)
im_batch_torch = im_torch.unsqueeze(0).unsqueeze(0).float()  # [1, 1, H, W]

# Initialize Complex Steerable Pyramid
height = 12
nbands = 4
scale_factor = 2 ** (1 / 2)
pyr = SCFpyr_PyTorch(height=height, nbands=nbands,
                     scale_factor=scale_factor, device=device)
pyr_type = 1

# Decompose the image batch
coeff = pyr.build(im_batch_torch, pyr_type)

# Reconstruct the image batch again
img_recon = pyr.reconstruct(coeff, pyr_type)

img = im_torch.float()
recon = img_recon.squeeze()
loss = torch.nn.MSELoss()
print('MSE:', loss(img, recon))

# Clip before casting: out-of-range floats cast to uint8 wrap around
# (e.g. slightly negative values become bright pixels).
recon_display = np.clip(recon.detach().cpu().numpy(), 0, 255).astype(np.uint8)
cv2.imshow('recon', recon_display)
# imshow windows are only rendered while the HighGUI event loop runs.
cv2.waitKey(0)
cv2.destroyAllWindows()

# Visualization
# coeff_single = utils.extract_from_batch(coeff, 0)
# coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True)
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is wrong for argparse: bool('False') is True, so any
    non-empty string would enable the flag.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('expected a boolean, got {!r}'.format(value))


parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--image_size', type=int, default=200)
parser.add_argument('--pyr_nlevels', type=int, default=5)
parser.add_argument('--pyr_nbands', type=int, default=4)
parser.add_argument('--pyr_scale_factor', type=int, default=2)
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--visualize', type=_str2bool, default=True)
config = parser.parse_args()

device = utils.get_device(config.device)

############################################################################
# Build the complex steerable pyramid

pyr = SCFpyr_PyTorch(height=config.pyr_nlevels, nbands=config.pyr_nbands,
                     scale_factor=config.pyr_scale_factor, device=device)

############################################################################
# Create a batch and feed-forward

start_time = time.time()

# Load Batch
im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size,
                                        config.image_size)
im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)

# Compute Steerable Pyramid
coeff = pyr.build(im_batch_torch)
config.batch_sizes = list(map(int, config.batch_sizes.split(','))) config.image_sizes = list(map(int, config.image_sizes.split(','))) device = utils.get_device(config.device) ################################################################################ pyr_numpy = SCFpyr_NumPy(height=config.pyr_nlevels, nbands=config.pyr_nbands, scale_factor=config.pyr_scale_factor, precision=config.precision) pyr_torch = SCFpyr_PyTorch(height=config.pyr_nlevels, nbands=config.pyr_nbands, scale_factor=config.pyr_scale_factor, device=device, precision=config.precision) pyr_tf = SCFpyr_TF(height=config.pyr_nlevels, nbands=config.pyr_nbands, scale_factor=config.pyr_scale_factor, precision=config.precision) ############################################################################ # Run Benchmark durations_numpy = np.zeros( (len(config.batch_sizes), len(config.image_sizes), config.num_runs)) durations_torch = np.zeros( (len(config.batch_sizes), len(config.image_sizes), config.num_runs)) durations_tf = np.zeros(
num_epochs = 2 learning_rate = 0.001 batch_size = 8 # pyr parameter height = 12 nbands = 4 scale_factor = 2**(1/2) pyr_type = 1 # Load dataset transform = transforms.Compose( [transforms.Resize((256, 256)), transforms.ToTensor()]) dataset = Triplets( '/home/lj/Documents/code/python/DAVIS/JPEGImages/480p/', transform) pyr = SCFpyr_PyTorch(height=height, nbands=nbands, scale_factor=scale_factor, device=device) # define network model = PhaseNet() criterion = Total_loss(v=1.0) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9,0.999)) # Train the model total_step = 0 for epoch in range(num_epochs): trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) for channel in range(3): for n, Triplets_batch in enumerate(trainloader): # get images_list [[N,C,H,W],[N,C,H,W],...], usually len(images_list)=batch_size if len(dataset)%batch_size==0 images_list = [torch.stack([Triplets_batch['start'][i], Triplets_batch['inter'][i],
pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2) coeff_numpy = pyr_numpy.build(image) reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy) reconstruction_numpy = reconstruction_numpy.astype(np.uint8) print('#' * 60) ################################################################################ # PyTorch device = torch.device('cuda:0') im_batch = torch.from_numpy(image[None, None, :, :]) im_batch = im_batch.to(device).float() pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device) coeff_torch = pyr_torch.build(im_batch) reconstruction_torch = pyr_torch.reconstruct(coeff_torch) reconstruction_torch = reconstruction_torch.cpu().numpy()[0, ] # Extract first example from the batch and move to CPU coeff_torch = utils.extract_from_batch(coeff_torch, 0) ################################################################################ # Check correctness print('#' * 60) assert len(coeff_numpy) == len(coeff_torch) for level, _ in enumerate(coeff_numpy):
import numpy as np import cv2 import torch import sys sys.path.append('..') from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch from perceptual.filterbank import SteerableNoSub device = torch.device('cuda:0') pyr_NoSub = SCFpyr_PyTorch(sub_sample=False, device=device) path = '' img = cv2.imread(path, 0) img_batch = torch.from_numpy(img).to(device) img_batch = img_batch.unsqueeze(0).float().unsqueeze(0) coeffs = pyr_NoSub.build(img_batch) pyr_NoSub_c = SteerableNoSub() coeffs_c = pyr_NoSub_c.buildSCFpyr(img) tolerance = 1e-3 coeff = coeffs[0].cpu().numpy().squeeze(0) all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance) s = np.sum(coeff - coeffs_c[0]) print('Succesful for subband {}: {}, with tolerance of {}'.format( 0, all_close, tolerance)) print('Sum of difference: {}'.format(s)) for i in range(1, len(coeffs) - 1): for j in range(len(coeffs[i])):
import numpy as np import cv2 import torch import sys sys.path.append('..') from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch from perceptual.filterbank import Steerable device = torch.device('cuda:0') pyr = SCFpyr_PyTorch(sub_sample = True, device = device) path = '' img = cv2.imread(path,0) img_batch = torch.from_numpy(img).to(device) img_batch = img_batch.unsqueeze(0).float().unsqueeze(0) coeffs = pyr.build(img_batch) pyr_c = Steerable() coeffs_c = pyr_c.buildSCFpyr(img) tolerance = 1e-3 coeff = coeffs[0].cpu().numpy().squeeze(0) all_close = np.allclose(coeff, coeffs_c[0], atol=tolerance) s = np.sum(coeff-coeffs_c[0]) print('Succesful for subband {}: {}, with tolerance of {}'.format(0,all_close, tolerance)) print('Sum of difference: {}'.format(s)) for i in range(1,len(coeffs)-1):
class Steerable_Pyramid_Phase(object):
    """Extract phase and temporal phase differences with a complex steerable pyramid.

    ``build_pyramid()`` decomposes stacks of grayscale frames with
    ``SCFpyr_PyTorch`` and returns the coefficients of the selected level(s).
    ``extract_phase()`` turns those coefficients into denoised phase and phase
    differences between consecutive frames.

    Parameters:
        height: int, default 5
            The coefficient levels, including low-pass and high-pass.
        nbands: int, default 4
            The number of orientations of the band-pass filters.
        scale_factor: int, default 2
            Spatial resolution reduction between consecutive levels.
        device: optional, default None
            Device passed to ``SCFpyr_PyTorch``.
        extract_level: int or list of int, default 1
            Level(s) returned by ``build_pyramid()``. An int yields one tensor;
            a list yields a list of tensors, one per level.
        visualize: bool, default False
            If True, ``build_pyramid()`` and ``extract_phase()`` display
            intermediate results.
    """

    def __init__(self, height=5, nbands=4, scale_factor=2, device=None,
                 extract_level=1, visualize=False):
        self.pyramid = SCFpyr_PyTorch(height=height, nbands=nbands,
                                      scale_factor=scale_factor, device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.device = device
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Decompose a batch of frame stacks and return the selected level(s).

        Args:
            im_batch: tensor with 4 dims [batch size, number of phase images, W, H].
                Every frame is processed as an independent single-channel image.
            symmetry: if True, frames go through ``symmetric_extension_batch``
                before decomposition, and the returned coefficients are cropped
                to their top-left quarter.

        Returns:
            For an int ``extract_level``: a tensor of shape
            [bs, nbands, num_phase_frames, W', H', 2] (last dim: real/imag).
            For a list ``extract_level``: a list of such tensors, one per level.

        Raises:
            ValueError: if the pyramid does not return a list of coefficients.
            TypeError: if ``extract_level`` is neither an int nor a list.
        """
        bs, num_phase_frames, W, H = im_batch.size()
        # Fold the frames into the batch dim; the channel dim of 1 means grayscale.
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)

        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_pyramid_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(self.extract_level, int):
            return self._level_coefficients(self.extract_level, coeff_batch, bs,
                                            num_phase_frames, symmetry)
        if isinstance(self.extract_level, list):
            return [self._level_coefficients(level, coeff_batch, bs,
                                             num_phase_frames, symmetry)
                    for level in self.extract_level]
        raise TypeError('extract_level must be an int or a list of ints, got {}'.format(
            type(self.extract_level).__name__))

    def _show_pyramid_example(self, trans_im_batch, coeff_batch, symmetry):
        """Show one input frame (full and with the extension cropped) and its coefficient grid."""
        # Index 10 as before, clamped so that small batches do not raise IndexError.
        example_id = min(10, trans_im_batch.size(0) - 1)
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))

        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        Image.fromarray(255 * example_img / example_img.max()).show()

        example_img_remove_symm = 255 * example_img / example_img.max()
        if symmetry:
            W, H = example_img_remove_symm.shape
            example_img_remove_symm = example_img_remove_symm[:W // 2, :H // 2]
        Image.fromarray(example_img_remove_symm).show()
        example_coeff.show()

    def _level_coefficients(self, level, coeff_batch, bs, num_phase_frames, symmetry):
        """Return one pyramid level reshaped to [bs, nbands, num_phase_frames, W, H, 2]."""
        level_coeff = self.extract_coeff_level(level, coeff_batch)
        W, H, _ = level_coeff.size()[-3:]
        nbands = level_coeff.size()[0]
        level_coeff = level_coeff.view(nbands, bs, num_phase_frames, W, H, 2)
        level_coeff = level_coeff.permute(1, 0, 2, 3, 4, 5).contiguous()
        if symmetry:
            # Keep the top-left quarter; presumably the area of the un-extended
            # frame (assumes symmetric_extension_batch doubles W and H - TODO confirm).
            level_coeff = level_coeff[..., :W // 2, :H // 2, :]
        return level_coeff

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the list of orientation bands ``coeff_batch[level]`` along a new dim 0."""
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        return torch.stack(extr_level_coeff_batch, 0)

    def extract_phase(self, coeff_batch, return_phase=False, return_both=False):
        """Compute denoised phase and temporal phase differences.

        Args:
            coeff_batch: tensor of shape [batch size, nbands, number of phase
                frames (e.g. 17), W, H, 2]; the last dim holds real and
                imaginary parts.
            return_phase: return the denoised phase instead of the difference.
            return_both: return phase differences and denoised phase (first
                frame dropped) interleaved along dim 2, moved to CUDA.

        Returns:
            return_both: tensor [bs, nbands, 2 * (frames - 1), W, H] alternating
                difference, phase, difference, phase, ...
            return_phase: denoised phase [bs, nbands, frames, W, H].
            otherwise: phase difference [bs, nbands, frames - 1, W, H].
            Every output is mean-centred per frame over (W, H). The
            differences are also clamped to [-5*pi, 5*pi].
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W, H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))

        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W, H) + EPS  # avoid mag == 0
        assert not (mag_batch <= 0.0).any()

        # Phase unwrap over time (the frame axis).
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)

        # Phase denoising (amplitude-based gaussian blur).
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(bs, n_bands, n_phase_frames, W, H)

        # Phase difference between consecutive frames.
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)

        if self.visualize:
            phase_example = phase_batch.view(bs, n_bands, n_phase_frames, W, H)[0, ...]
            mag_example = mag_batch.view(bs, n_bands, n_phase_frames, W, H)[0, ...]
            denoised_phase_example = denoised_phase_batch.view(
                bs, n_bands, n_phase_frames, W, H)[0, ...]
            phase_diff_example = phase_difference_batch.view(
                bs, n_bands, n_phase_frames - 1, W, H)[0, ...]
            self.show_3D_subplots(phase_example, title="phase example", first_k_frames=2)
            self.show_3D_subplots(mag_example, title="magnitude example", first_k_frames=2)
            self.show_3D_subplots(denoised_phase_example,
                                  title="denoised phase example", first_k_frames=2)
            self.show_3D_subplots(phase_diff_example,
                                  title="phase difference example", first_k_frames=2)

        # Centre each frame on its spatial mean.
        mean = denoised_phase_batch.mean(-1).mean(-1)
        denoised_phase_batch = denoised_phase_batch - mean.unsqueeze(-1).unsqueeze(-1)
        mean = phase_difference_batch.mean(-1).mean(-1)
        phase_difference_batch = phase_difference_batch - mean.unsqueeze(-1).unsqueeze(-1)
        phase_difference_batch = torch.clamp(phase_difference_batch,
                                             -5 * math.pi, 5 * math.pi)

        if return_both:
            # Drop the first phase image so both tensors have frames - 1 entries.
            denoised_phase_batch = denoised_phase_batch[:, :, 1:, :]
            assert phase_difference_batch.size() == denoised_phase_batch.size()
            result = self.insert_tensors(phase_difference_batch,
                                         denoised_phase_batch, dim=2)
            # NOTE(review): hard-coded CUDA transfer kept for backward
            # compatibility; it ignores self.device.
            return result.cuda()
        if return_phase:
            return denoised_phase_batch
        return phase_difference_batch

    def insert_tensors(self, t_a, t_b, dim):
        """Interleave two equally-shaped tensors along ``dim``.

        The result is twice as large along ``dim`` and holds
        t_a[0], t_b[0], t_a[1], t_b[1], ... along that axis. It keeps the
        dtype and device of the inputs. Negative ``dim`` is supported.

        Raises:
            ValueError: if the two tensors differ in shape.
        """
        if t_a.size() != t_b.size():
            raise ValueError('insert_tensors needs equal shapes, got {} and {}'.format(
                tuple(t_a.size()), tuple(t_b.size())))
        dim = dim % t_a.dim()
        # Stack to [..., L, 2, ...], then merge those two axes into [..., 2L, ...].
        return torch.stack((t_a, t_b), dim=dim + 1).flatten(dim, dim + 1)

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot frames as 3-D surfaces, one figure per orientation.

        Args:
            data: tensor with dimensions [nbands, n_phase_frames, W, H].
            title: figure title prefix; the orientation index is appended.
            first_k_frames: number of leading frames to plot (all if None).
        """
        nbands, n_phase_frames, W, H = data.size()
        n = first_k_frames if first_k_frames is not None else n_phase_frames
        # meshgrid(x, y) returns arrays shaped (len(y), len(x)). Build the grid so
        # it matches each frame's (W, H) shape. For square frames this is the
        # same grid as before.
        Xm, Ym = np.meshgrid(np.arange(1, H + 1), np.arange(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps a 2-D axes array even when n == 1.
            fig, ax = plt.subplots(nrows=1, ncols=n, squeeze=False,
                                   subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = ax[0, j].plot_surface(Xm, Ym, img, rstride=1, cstride=1,
                                             cmap=cm.coolwarm, linewidth=0,
                                             antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()
class Phase_Difference_Extractor(object):
    """Compute steerable-pyramid phase differences between consecutive frames.

    ``build_pyramid()`` decomposes stacks of grayscale frames with
    ``SCFpyr_PyTorch`` and returns the coefficients of the selected level(s).
    ``extract()`` turns them into temporal phase differences.
    """

    def __init__(self, height=5, nbands=4, scale_factor=2, extract_level=1,
                 device='cuda:0', visualize=False):
        """
        Phase_Difference_Extractor: A class to do steerable pyramid computation,
        extract the phase and phase difference.

        Parameters:
            height: int, default 5
                The coefficients levels including low-pass and high-pass
            nbands: int, default 4
                The number of orientations of the bandpass filters
            scale_factor: int, default 2
                Spatial resolution reduction scale scale_factor
            extract_level: int, or list of int numbers, default 1
                If extract_level is an int number, build_pyramid() will only
                return the coefficients in one level; if extract_level is a
                list, build_pyramid() will return the coefficients of
                multiple levels.
            device: default 'cuda:0'
                Device passed to SCFpyr_PyTorch and stored as ``self.device``.
            visualize: bool, default False
                If true, the build_pyramid() and extract() will show the
                processed results.
        """
        self.pyramid = SCFpyr_PyTorch(height=height, nbands=nbands,
                                      scale_factor=scale_factor, device=device)
        self.height = height
        self.nbands = nbands
        self.scale_factor = scale_factor
        self.device = device
        self.extract_level = extract_level
        self.visualize = visualize

    def build_pyramid(self, im_batch, symmetry=True):
        """Decompose a batch of frame stacks and return the selected level(s).

        Args:
            im_batch: tensor with 4 dims [batch size, number of phase images, W, H].
                Every frame is processed as an independent single-channel image.
            symmetry: if True, frames go through ``symmetric_extension_batch``
                before decomposition, and the returned coefficients are cropped
                to their top-left quarter.

        Returns:
            For an int ``extract_level``: a tensor of shape
            [bs, nbands, num_phase_frames, W', H', 2] (last dim: real/imag).
            For a list ``extract_level``: a list of such tensors, one per level.

        Raises:
            ValueError: if the pyramid does not return a list of coefficients.
            TypeError: if ``extract_level`` is neither an int nor a list.
        """
        bs, num_phase_frames, W, H = im_batch.size()
        # Fold the frames into the batch dim; the channel dim of 1 means grayscale.
        trans_im_batch = im_batch.view(bs * num_phase_frames, 1, W, H)
        if symmetry:
            trans_im_batch = symmetric_extension_batch(trans_im_batch)

        coeff_batch = self.pyramid.build(trans_im_batch)
        if not isinstance(coeff_batch, list):
            raise ValueError('Batch of coefficients must be a list')

        if self.visualize:
            self._show_pyramid_example(trans_im_batch, coeff_batch, symmetry)

        if isinstance(self.extract_level, int):
            return self._level_coefficients(self.extract_level, coeff_batch, bs,
                                            num_phase_frames, symmetry)
        if isinstance(self.extract_level, list):
            return [self._level_coefficients(level, coeff_batch, bs,
                                             num_phase_frames, symmetry)
                    for level in self.extract_level]
        raise TypeError('extract_level must be an int or a list of ints, got {}'.format(
            type(self.extract_level).__name__))

    def _show_pyramid_example(self, trans_im_batch, coeff_batch, symmetry):
        """Show one input frame (full and with the extension cropped) and its coefficient grid."""
        # Index 10 as before, clamped so that small batches do not raise IndexError.
        example_id = min(10, trans_im_batch.size(0) - 1)
        example_coeff = extract_from_batch(coeff_batch, example_id, symmetry)
        example_coeff = Image.fromarray(make_grid_coeff(example_coeff))

        example_img = trans_im_batch[example_id, 0, ...].cpu().numpy()
        Image.fromarray(255 * example_img / example_img.max()).show()

        example_img_remove_symm = 255 * example_img / example_img.max()
        if symmetry:
            W, H = example_img_remove_symm.shape
            example_img_remove_symm = example_img_remove_symm[:W // 2, :H // 2]
        Image.fromarray(example_img_remove_symm).show()
        example_coeff.show()

    def _level_coefficients(self, level, coeff_batch, bs, num_phase_frames, symmetry):
        """Return one pyramid level reshaped to [bs, nbands, num_phase_frames, W, H, 2]."""
        level_coeff = self.extract_coeff_level(level, coeff_batch)
        W, H, _ = level_coeff.size()[-3:]
        nbands = level_coeff.size()[0]
        level_coeff = level_coeff.view(nbands, bs, num_phase_frames, W, H, 2)
        level_coeff = level_coeff.permute(1, 0, 2, 3, 4, 5).contiguous()
        if symmetry:
            # Keep the top-left quarter; presumably the area of the un-extended
            # frame (assumes symmetric_extension_batch doubles W and H - TODO confirm).
            level_coeff = level_coeff[..., :W // 2, :H // 2, :]
        return level_coeff

    def extract_coeff_level(self, level, coeff_batch):
        """Stack the list of orientation bands ``coeff_batch[level]`` along a new dim 0."""
        extr_level_coeff_batch = coeff_batch[level]
        assert isinstance(extr_level_coeff_batch, list)
        return torch.stack(extr_level_coeff_batch, 0)

    def extract(self, coeff_batch):
        """Compute mean-centred, clamped temporal phase differences.

        Args:
            coeff_batch: tensor of shape [batch size, nbands, number of phase
                frames (e.g. 17), W, H, 2]; the last dim holds real and
                imaginary parts.

        Returns:
            Tensor [bs, nbands, frames - 1, W, H] of differences between the
            denoised phases of consecutive frames. Each frame is centred on its
            spatial mean and clamped to [-5*pi, 5*pi].
        """
        bs, n_bands, n_phase_frames, W, H, _ = coeff_batch.size()
        trans_coeff_batch = coeff_batch.view(bs * n_bands * n_phase_frames, W, H, -1)
        real_coeff_batch, imag_coeff_batch = torch.unbind(trans_coeff_batch, -1)
        phase_batch = torch.atan2(imag_coeff_batch, real_coeff_batch)
        mag_batch = torch.sqrt(
            torch.pow(imag_coeff_batch, 2) + torch.pow(real_coeff_batch, 2))

        phase_batch = phase_batch.view(bs * n_bands, n_phase_frames, W, H)
        EPS = 1e-10
        mag_batch = mag_batch.view(bs * n_bands, n_phase_frames, W, H) + EPS  # avoid mag == 0
        assert not (mag_batch <= 0.0).any()

        # Phase unwrap over time (the frame axis).
        phase_batch = torch_unwrap(phase_batch, discont=math.pi, dim=-3)

        # Phase denoising (amplitude-based gaussian blur).
        g_kernel = torch.from_numpy(gaussian_kernel(std=2, tap=11))
        denoised_phase_batch = amplitude_based_gaussian_blur(mag_batch, phase_batch, g_kernel)
        denoised_phase_batch = denoised_phase_batch.view(bs, n_bands, n_phase_frames, W, H)

        # Phase difference between consecutive frames.
        phase_difference_batch = torch_diff(denoised_phase_batch, dim=2)
        phase_difference_batch = phase_difference_batch.view(
            bs, n_bands, n_phase_frames - 1, W, H)

        if self.visualize:
            phase_example = phase_batch.view(bs, n_bands, n_phase_frames, W, H)[0, ...]
            mag_example = mag_batch.view(bs, n_bands, n_phase_frames, W, H)[0, ...]
            denoised_phase_example = denoised_phase_batch.view(
                bs, n_bands, n_phase_frames, W, H)[0, ...]
            phase_diff_example = phase_difference_batch.view(
                bs, n_bands, n_phase_frames - 1, W, H)[0, ...]
            self.show_3D_subplots(phase_example, title="phase example", first_k_frames=2)
            self.show_3D_subplots(mag_example, title="magnitude example", first_k_frames=2)
            self.show_3D_subplots(denoised_phase_example,
                                  title="denoised phase example", first_k_frames=2)
            self.show_3D_subplots(phase_diff_example,
                                  title="phase difference example", first_k_frames=2)

        # Centre each difference frame on its spatial mean, then clamp outliers.
        # (The previous code also centred the denoised phase but never used it.)
        mean = phase_difference_batch.mean(-1).mean(-1)
        phase_difference_batch = phase_difference_batch - mean.unsqueeze(-1).unsqueeze(-1)
        return torch.clamp(phase_difference_batch, -5 * math.pi, 5 * math.pi)

    def show_3D_subplots(self, data, title, first_k_frames=None):
        """Plot frames as 3-D surfaces, one figure per orientation.

        Args:
            data: tensor with dimensions [nbands, n_phase_frames, W, H].
            title: figure title prefix; the orientation index is appended.
            first_k_frames: number of leading frames to plot (all if None).
        """
        nbands, n_phase_frames, W, H = data.size()
        n = first_k_frames if first_k_frames is not None else n_phase_frames
        # meshgrid(x, y) returns arrays shaped (len(y), len(x)). Build the grid so
        # it matches each frame's (W, H) shape. For square frames this is the
        # same grid as before.
        Xm, Ym = np.meshgrid(np.arange(1, H + 1), np.arange(1, W + 1))
        for i in range(nbands):
            # squeeze=False keeps a 2-D axes array even when n == 1.
            fig, ax = plt.subplots(nrows=1, ncols=n, squeeze=False,
                                   subplot_kw={'projection': "3d"})
            for j in range(n):
                img = data[i, j, ...].cpu().numpy()
                surf = ax[0, j].plot_surface(Xm, Ym, img, rstride=1, cstride=1,
                                             cmap=cm.coolwarm, linewidth=0,
                                             antialiased=False)
            fig.colorbar(surf, shrink=0.5, aspect=10)
            fig.suptitle(title + ": orientation {}".format(i))
        plt.show()