def tikhonov_filter(s, *, lmbda=1.0, npd=16, dtype=torch.float32):
    r"""Lowpass filter based on Tikhonov regularization.

    Lowpass filter image(s) and return low and high frequency components,
    consisting of the lowpass filtered image and its difference with the
    input image. The lowpass filter is equivalent to Tikhonov regularization
    with `lmbda` as the regularization parameter and a discrete gradient as
    the operator in the regularization term, i.e. the lowpass component is
    the solution to

    .. math::
      \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
      \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the
    regularization parameter, and :math:`G_i` is an operator that computes
    the discrete gradient along image axis :math:`i`. Once the lowpass
    component :math:`\mathbf{x}` has been computed, the highpass component
    is just :math:`\mathbf{s} - \mathbf{x}`.

    Filtering is applied along axes 0 and 1; any further axes are treated as
    a stack of independent images. The input is symmetrically padded by
    `npd` samples on axes 0 and 1 before the FFT-based solve, and the padding
    is removed afterwards. The padding goes through NumPy, so the lowpass
    component is not differentiable with respect to `s`.

    Parameters
    ----------
    s : torch.Tensor
      Input image or array of images (image axes are 0 and 1).
    lmbda : float
      Regularization parameter controlling lowpass filtering.
    npd : int, optional (default=16)
      Number of samples to pad at image boundaries.
    dtype : torch.dtype, optional (default=torch.float32)
      Not used by the current implementation; kept for backward compatibility.

    Returns
    -------
    slp : torch.Tensor
      Lowpass image or array of images.
    shp : torch.Tensor
      Highpass image or array of images (``s - slp``).
    """
    device = s.device

    # Forward-difference kernels along axis 0 (shape 2x1) and axis 1 (shape 1x2).
    grv = torch.tensor([[-1.0], [1.0]], dtype=torch.float64, device=device)
    gcv = torch.tensor([[-1.0, 1.0]], dtype=torch.float64, device=device)

    padded_shape = (s.shape[0] + 2 * npd, s.shape[1] + 2 * npd)
    fftopt = {"s": padded_shape, "dim": (0, 1)}
    Gr = tfft.rfftn(grv, **fftopt)
    Gc = tfft.rfftn(gcv, **fftopt)

    # Frequency response of I + lmbda * sum_i G_i^H G_i (real and >= 1).
    A = 1.0 + lmbda * (torch.conj(Gr) * Gr + torch.conj(Gc) * Gc).real
    if s.ndim > 2:
        # Broadcast the 2-D response over any trailing (stack) axes.
        A = A[(slice(None),) * 2 + (None,) * (s.ndim - 2)]

    # Symmetric boundary padding via NumPy. Detach first so tensors that
    # require grad can be converted; .numpy() refuses them otherwise.
    fill = ((npd, npd),) * 2 + ((0, 0),) * (s.ndim - 2)
    snp = np.pad(s.detach().cpu().numpy(), fill, "symmetric")
    sp = torch.from_numpy(snp).to(device)

    spshp = sp.shape
    sp = tfft.rfftn(sp, dim=(0, 1))
    sp /= A
    sp = tfft.irfftn(sp, s=spshp[0:2], dim=(0, 1))

    # Strip the boundary padding again.
    slp = sp[npd:(sp.shape[0] - npd), npd:(sp.shape[1] - npd)]
    shp = s - slp
    return slp, shp
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: int = 0,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform.

    This is very fast for large kernel sizes. Optionally adds a bias Tensor
    after the convolution (in order to mimic the PyTorch direct
    convolution). As with PyTorch's conv layers, the result is a
    cross-correlation: the kernel spectrum is conjugated before the product.

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel; dims
            0 and 1 are kept, dims 2.. are treated as spatial.
        kernel: (Tensor) Convolution kernel.
        bias: (Optional, Tensor) Bias tensor to add to the output.
        padding: (int) Number of zero samples to pad every spatial dimension
            of the input on both sides.

    Returns:
        (Tensor) Convolved tensor
    """
    spatial_dims = tuple(range(2, signal.ndim))

    # Zero-pad each spatial dimension of the signal on both sides.
    signal_padding = (signal.ndim - 2) * [padding, padding]
    signal = f.pad(signal, signal_padding)

    # Zero-pad the kernel (at the end) up to the padded signal's spatial size.
    kernel_padding = [
        pad
        for i in reversed(spatial_dims)
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Fourier convolution: FFT, (complex) matrix multiply, inverse FFT.
    signal_fr = rfftn(signal, dim=spatial_dims)
    kernel_fr = rfftn(padded_kernel, dim=spatial_dims)
    # Conjugate the kernel spectrum -> cross-correlation instead of convolution.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr)
    # Pass the spatial sizes explicitly: without `s`, irfftn assumes an even
    # last dimension and returns the wrong length and values for odd sizes.
    output = irfftn(output_fr, s=signal.shape[2:], dim=spatial_dims)

    # Keep only the "valid" part of the circular correlation.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, signal.size(i) - kernel.size(i) + 1) for i in spatial_dims
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
def fft_new(z, x, label, *, reg=None, output_size=(121, 121)):
    """Compute a correlation-filter response map in the Fourier domain.

    The filter spectrum is ``label / (sum_c |Z_c|^2 + reg)``, applied to the
    channel-summed cross-spectrum ``sum_c X_c * conj(Z_c)``. The result is
    transformed back with ``irfftn`` over the last two dims.

    Args:
        z: real tensor; last two dims are spatial and dim 1 is summed over
            (keepdim). Per the original comments, typically [batch, C, 121, 121].
        x: real tensor with the same shape as ``z``.
        label: tensor broadcastable against the one-sided spectrum
            ``[batch, 1, H, W // 2 + 1]``; moved to ``z``'s device.
        reg: regularization added to the auto-correlation energy. When None,
            the module-level ``lambda0`` is used (the previous behavior).
        output_size: spatial size passed as ``s`` to ``irfftn``. Defaults to
            (121, 121), the previously hard-coded value.

    Returns:
        Real response tensor whose last two dims are ``output_size``.
    """
    if reg is None:
        reg = lambda0
    zf = fft.rfftn(z, dim=[-2, -1])
    xf = fft.rfftn(x, dim=[-2, -1])
    # Auto-correlation energy summed over channels (real), e.g. R[batch, 1, 121, 61]
    kzzf = torch.sum(torch.real(zf) ** 2 + torch.imag(zf) ** 2, dim=1, keepdim=True)
    # Cross-spectrum summed over channels, e.g. C[batch, 1, 121, 61]
    kxzf = torch.sum(xf * torch.conj(zf), dim=1, keepdim=True)
    # Filter spectrum, same shape as kzzf (complex if label is complex)
    alphaf = label.to(device=z.device) / (kzzf + reg)
    return fft.irfftn(kxzf * alphaf, s=list(output_size), dim=[-2, -1])
def forward(self, template, search, label):
    """Return the correlation-filter response of `search` against `template`.

    Both inputs are transformed with a 2-D real FFT over their last two
    dims. The channel-summed cross-spectrum is divided by the channel-summed
    template energy plus ``self.lambda0``, weighted by `label`, and brought
    back to a real 121x121 map.

    Shapes noted by the original author (not enforced here):
    template/search R[batch, 32, 121, 121], label [batch, 1, 121, 61],
    response R[batch, 1, 121, 121].
    """
    spatial = [-2, -1]
    template_spec = fft.rfftn(template, dim=spatial)
    search_spec = fft.rfftn(search, dim=spatial)

    # Template auto-energy, summed over channels -> R[batch, 1, 121, 61]
    auto_energy = (template_spec.real ** 2 + template_spec.imag ** 2).sum(dim=1, keepdim=True)

    # Search/template cross-spectrum, summed over channels -> C[batch, 1, 121, 61]
    cross_spec = (search_spec * template_spec.conj()).sum(dim=1, keepdim=True)

    denominator = auto_energy + self.lambda0
    filter_spec = label.to(device=template.device) / denominator

    return fft.irfftn(cross_spec * filter_spec, s=[121, 121], dim=spatial)
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
    padding_mode: str = "constant",
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform.

    This is very fast for large kernel sizes. Optionally adds a bias Tensor
    after the convolution (in order to mimic the PyTorch direct
    convolution). As with PyTorch's conv layers, the result is a
    cross-correlation: the kernel spectrum is conjugated before the product.

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel; dims
            2.. are spatial.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int], str]) Number of samples to pad
            each spatial dimension of the input on both sides, or "same".
            "same" pads ``k // 2`` on both sides and crops the output back
            to the input size.
            NOTE(review): for even kernels this appears to align the output
            one sample differently from PyTorch's ``padding="same"`` (which
            pads ``(k - 1) // 2`` on the left) — confirm if that matters.
        stride: (Union[int, Iterable[int]]) Stride size for computing output
            values.
        groups: (int) Number of blocked connections from input channels to
            output channels; forwarded to ``complex_matmul``.
        padding_mode: (str) Mode passed to ``torch.nn.functional.pad``:
            "constant", "reflect", "replicate" or "circular". Not every mode
            is supported for every input rank by PyTorch.

    Returns:
        (Tensor) Convolved tensor. The FFTs are computed in float32.
    """
    spatial_dims = tuple(range(2, signal.ndim))
    # Cast stride to tuple.
    stride_ = to_ntuple(stride, n=signal.ndim - 2)

    same = padding == "same"
    if same:
        padding_ = [k // 2 for k in kernel.shape[2:]]
    else:
        padding_ = to_ntuple(padding, n=signal.ndim - 2)
    # F.pad expects (last_lo, last_hi, second_last_lo, ...).
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # padded signal size, before the even-length pad
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    kernel_padding = [
        pad
        for i in reversed(spatial_dims)
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)
    # Only the spatial sizes must agree: with groups > 1 the kernel's channel
    # dim (in_channels // groups) legitimately differs from the signal's.
    assert (
        padded_kernel.shape[2:] == signal.shape[2:]
    ), f"padded kernel shape {padded_kernel.shape} not equal to signal shape {signal.shape}"

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal.float(), dim=spatial_dims)
    kernel_fr = rfftn(padded_kernel.float(), dim=spatial_dims)
    kernel_fr.imag *= -1  # conjugate -> cross-correlation
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=spatial_dims)

    # Remove extra padded values.
    crop_slices = [slice(None), slice(None)]
    for i in spatial_dims:
        k = kernel.size(i)
        # A valid cross-correlation of the padded signal has n - k + 1 samples.
        # In "same" mode an even kernel gets k // 2 on *both* sides, i.e. one
        # sample too many, so that extra sample is dropped instead.
        length = signal_size[i] - k + (k % 2 if same else 1)
        crop_slices.append(slice(0, length, stride_[i - 2]))
    output = output[tuple(crop_slices)].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Builds a spatial regularization window whose value grows with the
    (power-law) distance from the origin. It then takes its real 2-D FFT,
    zeroes coefficients below a sparsity threshold, and corrects the DC term
    so the sparse window's minimum equals ``params.reg_window_min``. Finally
    it returns the shifted spectrum, cropped to its non-zero support.

    Args:
        sz: filter size; ``sz[0]`` drives the row grid and ``sz[1]`` the
            column grid (assumed to be a 2-element tensor — TODO confirm).
        target_sz: target size; the window is normalized by
            ``0.5 * target_sz`` (indexed as ``[0]`` rows / ``[1]`` columns).
        params: parameter object. Reads ``use_reg_window``,
            ``reg_window_min``, ``reg_window_edge``, ``reg_window_power``,
            ``reg_sparsity_threshold``, and optionally
            ``reg_window_square`` (default False) and
            ``reg_window_centered`` (default True).

    Returns:
        If ``params.use_reg_window`` is false: the constant
        ``params.reg_window_min`` as a (1, 1, 1, 1) tensor. Otherwise the
        cropped spectrum produced via the project helpers
        ``fourier.rfftshift2`` and ``complex.real``. When more than one
        column remains, it is mirrored along the last axis.
    """

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)

    if getattr(params, 'reg_window_square', False):
        # Replace the target size by a square with the same area.
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        # Coordinates run from -floor((n - 1) / 2) to floor(n / 2), origin in the middle.
        wrg = torch.arange(-int((sz[0] - 1) / 2), int(sz[0] / 2 + 1), dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1] - 1) / 2), int(sz[1] / 2 + 1), dtype=torch.float32).view(1, 1, 1, -1)
    else:
        # Same coordinates in FFT order: non-negative first, then the negatives.
        wrg = torch.cat([
            torch.arange(0, int(sz[0] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, -1, 1)
        wcg = torch.cat([
            torch.arange(0, int(sz[1] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, 1, -1)

    # Construct regularization window:
    # min + (edge - min) * (|r / scale_r|^p + |c / scale_c|^p), shape (1, 1, rows, cols)
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg / reg_scale[0]) ** params.reg_window_power +
                  torch.abs(wcg / reg_scale[1]) ** params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity.
    # view_as_real stores (re, im) in a trailing dim of size 2; the DFT is
    # divided by the number of elements.
    reg_window_dft = torch.view_as_real(
        torch_fft.rfftn(reg_window, dim=[-2, -1])) / sz.prod()
    # presumably the magnitude of each (re, im) pair — project helper
    reg_window_dft_abs = complex.abs(reg_window_dft)
    # Zero every coefficient whose magnitude is below threshold * max magnitude.
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold * reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum.
    # irfftn of (DFT / N) yields the sparse window divided by N. Adding
    # delta = min - N * reg_window_sparse.min() to the real DC coefficient
    # shifts N * window by delta, so its minimum becomes reg_window_min.
    reg_window_sparse = torch_fft.irfftn(torch.view_as_complex(reg_window_dft), s=sz.long().tolist(), dim=[-2, -1])
    reg_window_dft[
        0, 0, 0, 0, 0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    # Shift (project helper; presumably centers the zero frequency) and keep the real part.
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros: crop to the rows/columns that still hold non-zero
    # coefficients, with the row range symmetric around the centre row mid_ind.
    max_inds, _ = reg_window_dft.nonzero(as_tuple=False).max(dim=0)
    mid_ind = int((reg_window_dft.shape[2] - 1) / 2)
    top = max_inds[-2].item() + 1
    bottom = 2 * mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        # Prepend the point-mirrored columns (flip rows and columns, skipping
        # column 0) — presumably to rebuild the two-sided spectrum.
        reg_window_dft = torch.cat(
            [reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
def fft_comparison():
    """Compare legacy ``torch.rfft`` with ``torch.fft.rfftn`` on random data.

    Both spectra of one random [42, 32, 121, 121] tensor are stored in the
    module-level globals ``a`` (legacy real layout with a trailing dim of 2)
    and ``b`` (complex tensor) for interactive inspection. They are then
    handed to ``compare_complex``.

    NOTE(review): ``torch.rfft`` was removed in PyTorch 1.8, so this helper
    only runs on older PyTorch versions.
    """
    global a, b

    sample = torch.rand((42, 32, 121, 121))

    # Legacy API first, then the torch.fft module; same order as before so
    # `a` is already set if the second transform fails.
    a = torch.rfft(sample, signal_ndim=2)
    b = fft.rfftn(sample, dim=(-2, -1))

    compare_complex(a, b)
# fft_comparison() ############################################## lambda0 = 1e-4 y = util.gaussian_shaped_labels(4.166666666666667, [121, 121]).astype(np.float32) x = torch.rand((42, 32, 121, 121)) z = torch.rand((42, 32, 121, 121)) fft_label_view = torch.Tensor(y).view(1, 1, 121, 121).cuda() label_old = torch.rfft(fft_label_view, signal_ndim=2).repeat(42, 1, 1, 1, 1).cuda(non_blocking=True) label_new = fft.rfftn(fft_label_view, dim=[-2, -1]).repeat(42, 1, 1, 1).cuda(non_blocking=True) ############################################## zfnew = fft.rfftn(z, dim=[-2, -1]) zfold = torch.rfft(z, signal_ndim=2) xfnew = fft.rfftn(x, dim=[-2, -1]) xfold = torch.rfft(x, signal_ndim=2) tnew = xfnew * torch.conj(zfnew) # told = cn.mulconj(xfold, zfold) # rtold = told[...,0] # itold = told[...,1] # tnew = torch.complex(rtold, itold) # told = torch.view_as_complex(tnew) told1 = torch.view_as_real(tnew)
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform.

    This is very fast for large kernel sizes. Optionally adds a bias Tensor
    after the convolution (in order to mimic the PyTorch direct
    convolution). The kernel spectrum is conjugated, i.e. a cross-correlation
    is computed as in PyTorch's conv layers.

    NOTE(review): this variant expects ``complex_matmul`` to return spatial
    axes starting at dim 4 (irfftn runs over ``range(4, signal.ndim + 2)``).
    It also crops each spatial axis starting at ``padding - 1``. Both
    assumptions come from the original code and should be confirmed against
    ``complex_matmul``.

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int]]) Number of zero samples to pad
            each spatial dimension of the input on both sides. An int is
            applied to every spatial dimension. Every entry must be >= 1,
            because the crop starts at ``padding - 1``.
        stride: (Union[int, Iterable[int]]) Stride size for computing output
            values; an int is applied to every spatial dimension.
        groups: (int) Number of blocked connections, forwarded to
            ``complex_matmul``.

    Returns:
        (Tensor) Convolved tensor

    Raises:
        ValueError: if any padding entry is smaller than 1.
    """
    n_spatial = signal.ndim - 2
    # Accept either a scalar (broadcast to every spatial dim) or a sequence.
    padding_ = (padding,) * n_spatial if isinstance(padding, int) else tuple(padding)
    stride_ = (stride,) * n_spatial if isinstance(stride, int) else tuple(stride)
    if any(p < 1 for p in padding_):
        # The crop below starts at padding - 1; a negative start would silently
        # produce an empty or meaningless output.
        raise ValueError(f"padding must be >= 1 in every dimension, got {padding_}")

    # Pad the input signal (F.pad takes the last dimension first).
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_ = f.pad(signal, [0, 1]) if signal.size(-1) % 2 != 0 else signal

    # Zero-pad the kernel (at the end) up to the even-padded signal size.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal_.ndim))
        for pad in [0, signal_.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    spatial_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal_, dim=spatial_dims)
    kernel_fr = rfftn(padded_kernel, dim=spatial_dims)
    kernel_fr.imag *= -1  # conjugate -> cross-correlation
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(4, signal.ndim + 2)))

    # Remove extra padded values: keep the four leading output dims whole and
    # crop/stride the spatial ones (output dims 4.. map to signal dims 2..).
    crop_slices = [slice(0, output.size(d)) for d in range(4)] + [
        slice(
            padding_[i - 3] - 1,
            signal.size(i - 1) - kernel.size(i - 1) - padding_[i - 3] + 2,
            stride_[i - 3],
        )
        for i in range(3, signal.ndim + 1)
    ]
    output = output[tuple(crop_slices)].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
# # print('t difference: ') # compare_complex(told1, tnew) # compare_complex(told2, tnew) # compare_complex(told3, tnew) # # print('kxzf difference: ') # compare_complex(torch.sum(told1, dim=1, keepdim=True), torch.sum(tnew, dim=1, keepdim=True)) # compare_complex(torch.sum(told2, dim=1, keepdim=True), torch.sum(tnew, dim=1, keepdim=True)) # compare_complex(torch.sum(told3, dim=1, keepdim=True), torch.sum(tnew, dim=1, keepdim=True)) shape = (42, 32, 10, 10) x = torch.randn(shape) z = torch.randn(shape) tnew = fft.rfftn(x, dim=[2, 3]) * torch.conj(fft.rfftn(z, dim=[2, 3])) # tnew2 = torch.randn(shape, dtype=torch.complex64) # tnew2 = torch.empty(shape, dtype=torch.complex64).normal_(mean=0, std=0.00001) # h1 = torch.histc(torch.view_as_real(tnew)) # h2 = torch.histc(torch.view_as_real(tnew2)) # # import matplotlib.pyplot as plt # plt.plot(h1) # plt.plot(h2) # plt.show() sum_float = torch.sum(torch.view_as_real(tnew), dim=1, keepdim=True) sum_complex = torch.sum(tnew, dim=1, keepdim=True) print(