def spectrum_to_basis(spectrum: torch.Tensor, l2_normalize: bool = True) -> torch.Tensor:
    """Build a 2D Fourier basis from a one-sided spectrum via inverse real 2D FFT.

    Note:
        - Only square bases (H == W) are supported: the output width is always
          taken to be H, so a non-square target produces a wrong basis.
        - The inverse FFT runs over the last two axes, ``dim=(-2, -1)``.

    Args:
        spectrum (torch.Tensor): 2D spectrum matrix of shape (H, W//2+1), where
            (H, W) is the size of the desired Fourier basis.
        l2_normalize (bool): If True, the basis is divided by its L2 (Frobenius) norm.

    Returns:
        torch.Tensor: 2D Fourier basis of shape (H, H).
    """
    assert len(spectrum.size()) == 2
    side = spectrum.size(-2)
    # Square basis only, so the reconstructed signal size is (side, side).
    basis = fft.irfftn(spectrum, s=(side, side), dim=(-2, -1))
    if not l2_normalize:
        return cast(torch.Tensor, basis)
    # The norm over both axes is 0-d; lift it to (1, 1) before dividing.
    scale = basis.norm(dim=(-2, -1))[None, None]
    return cast(torch.Tensor, basis / scale)
def cifft2(a, signal_sizes=None):
    """Do inverse FFT corresponding to cfft2.

    ``a`` is passed through ``irfftshift2`` (presumably undoing the shift that
    cfft2 applied) and then viewed as complex, so its last axis must hold
    (real, imag) pairs. The inverse real FFT runs over the last two axes.

    Args:
        a: Shifted one-sided spectrum with a trailing real/imag axis of size 2.
        signal_sizes: Output sizes of the two transformed axes. When omitted,
            the first axis keeps its length (``-1``) and the last axis is
            reconstructed with the odd length ``2 * n - 1``.
    """
    spectrum = torch.view_as_complex(irfftshift2(a))
    if signal_sizes is not None:
        sizes = signal_sizes
    else:
        sizes = [-1, 2 * spectrum.size(-1) - 1]
    return torch_fft.irfftn(spectrum, s=sizes, dim=[-2, -1])
def tikhonov_filter(s, *, lmbda=1.0, npd=16, dtype=torch.float32):
    r"""Lowpass filter based on Tikhonov regularization.

    Split image(s) into a lowpass component and the highpass residual. The
    lowpass component solves

    .. math::
       \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s}
       \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;,

    where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the
    regularization parameter and :math:`G_i` is the discrete (first
    difference) gradient along image axis :math:`i` (rows and columns). The
    problem is solved in the Fourier domain on a symmetrically padded copy of
    the input; the highpass component is :math:`\mathbf{s} - \mathbf{x}`.

    Parameters
    ----------
    s : torch.Tensor
        Input image (H, W) or stack of images (H, W, ...). Filtering is done
        over the first two axes only; padding goes through NumPy on the CPU.
    lmbda : float
        Regularization parameter controlling lowpass filtering.
    npd : int, optional (default=16)
        Number of samples to pad (symmetrically) at image boundaries.
    dtype : torch.dtype, optional
        Currently unused.

    Returns
    -------
    slp : torch.Tensor
        Lowpass image or array of images, same spatial shape as ``s``.
    shp : torch.Tensor
        Highpass image or array of images, ``s - slp``.
    """
    dev = s.device
    spatial_axes = (0, 1)
    padded_shape = (s.shape[0] + 2 * npd, s.shape[1] + 2 * npd)
    extra_dims = s.ndim - 2

    # First-difference kernels along rows and along columns.
    diff = np.array([-1.0, 1.0])
    row_grad = torch.from_numpy(diff.reshape([2, 1])).to(dev)
    col_grad = torch.from_numpy(diff.reshape([1, 2])).to(dev)
    row_grad_f = tfft.rfftn(row_grad, s=padded_shape, dim=spatial_axes)
    col_grad_f = tfft.rfftn(col_grad, s=padded_shape, dim=spatial_axes)

    # Frequency response of (I + lmbda * sum_i G_i^H G_i).
    gram = row_grad_f.conj() * row_grad_f + col_grad_f.conj() * col_grad_f
    denom = 1.0 + lmbda * gram.real
    if extra_dims > 0:
        # Broadcast over any trailing (non-spatial) axes.
        denom = denom[(slice(None), slice(None)) + (None,) * extra_dims]

    pad_width = [(npd, npd), (npd, npd)] + [(0, 0)] * extra_dims
    padded = torch.from_numpy(np.pad(s.cpu().numpy(), pad_width, 'symmetric')).to(dev)
    full_shape = padded.shape[:2]

    spec = tfft.rfftn(padded, dim=spatial_axes)
    # In-place on purpose: keeps the spectrum's complex dtype.
    spec /= denom
    smooth = tfft.irfftn(spec, s=full_shape, dim=spatial_axes)

    rows, cols = smooth.shape[0], smooth.shape[1]
    slp = smooth[npd:rows - npd, npd:cols - npd]
    shp = s - slp
    return slp, shp
def forward(
    self,
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """FFT-based N-d cross-correlation mirroring PyTorch's direct convolution.

    Args:
        signal: Input tensor; dims 0 and 1 are kept, dims 2.. are convolved.
        kernel: Convolution kernel; dims 2.. are its spatial extent.
        bias: Optional bias, one value per output channel (dim 1).
        padding: Zero samples added on both sides of each spatial dim.
        stride: Step between kept output samples along each spatial dim.
        groups: Number of channel groups, forwarded to ``self.complex_matmul``.

    Returns:
        The "valid" correlation of the padded signal with the kernel,
        subsampled by ``stride``.
    """
    n_spatial = signal.ndim - 2
    pad = self.to_ntuple(padding, n=n_spatial)
    step = self.to_ntuple(stride, n=n_spatial)

    # F.pad wants (lo, hi) pairs beginning with the LAST dimension.
    pad_spec = []
    for amount in pad[::-1]:
        pad_spec.extend((amount, amount))
    signal = f.pad(signal, pad_spec)

    # The one-sided FFT round-trips only for an even last dimension; add a
    # trailing zero when it is odd (the cropped valid region never reaches it).
    work = f.pad(signal, [0, 1]) if signal.size(-1) % 2 != 0 else signal

    kernel_pad = []
    for dim in range(work.ndim - 1, 1, -1):
        kernel_pad += [0, work.size(dim) - kernel.size(dim)]
    kernel_full = f.pad(kernel, kernel_pad)

    fft_dims = tuple(range(2, signal.ndim))
    sig_f = rfftn(work, dim=fft_dims)
    ker_f = rfftn(kernel_full, dim=fft_dims)
    # Conjugating the kernel spectrum turns convolution into cross-correlation.
    ker_f.imag *= -1
    out = irfftn(self.complex_matmul(sig_f, ker_f, groups=groups), dim=fft_dims)

    crop = (slice(0, out.size(0)), slice(0, out.size(1))) + tuple(
        slice(0, signal.size(d) - kernel.size(d) + 1, step[d - 2])
        for d in range(2, signal.ndim)
    )
    out = out[crop].contiguous()

    if bias is not None:
        out += bias.view(tuple([1, -1] + [1] * (signal.ndim - 2)))
    return out
def forward(self, x):
    """Compute the correlation-filter response map for a search patch.

    ``x`` is passed through ``self.feature`` and multiplied by
    ``self.config.cos_window``. The filter is applied in the Fourier domain
    using the stored template spectrum ``self.model_zf`` and the coefficients
    ``self.model_alphaf``.

    Args:
        x: Input accepted by ``self.feature``.

    Returns:
        Real response map, summed over dim 1 (kept as size 1), with the same
        last-two-dims spatial size as the windowed features.
    """
    x = self.feature(x) * self.config.cos_window
    xf = fft.rfftn(x, dim=[-2, -1])
    # Channel-summed cross-spectrum with the stored template.
    kxzf = torch.sum(xf * torch.conj(self.model_zf), dim=1, keepdim=True)
    # Pass the spatial size explicitly: without ``s`` irfftn assumes an even
    # last dimension, returning a wrong, one-column-short map for odd widths.
    response = fft.irfftn(kxzf * self.model_alphaf, s=x.shape[-2:], dim=[-2, -1])
    return response
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: int = 0,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order to mimic the PyTorch direct convolution).

    As with PyTorch's ``convNd``, the kernel is not flipped (cross-correlation).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel; dims 0
            and 1 are kept as-is and dims 2.. are convolved.
        kernel: (Tensor) Convolution kernel; dims 2.. are its spatial extent.
        bias: (Optional, Tensor) Bias tensor to add to the output, one value
            per output channel (dim 1).
        padding: (int) Number of zero samples to pad on both sides of every
            spatial dimension.

    Returns:
        (Tensor) Convolved tensor with ``n + 2 * padding - k + 1`` samples per
        spatial dimension.
    """
    # Zero-pad every spatial dimension by ``padding`` on both sides.
    signal = f.pad(signal, (signal.ndim - 2) * [padding, padding])

    # Zero-pad the kernel at the high end so it matches the padded signal.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Fourier cross-correlation: FFT, multiply by the conjugate kernel spectrum, IFFT.
    fft_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal, dim=fft_dims)
    kernel_fr = rfftn(padded_kernel, dim=fft_dims)
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr)
    # Give irfftn the true sizes: by default it assumes an even last dimension
    # and would reconstruct odd-length signals incorrectly.
    output = irfftn(output_fr, s=signal.shape[2:], dim=fft_dims)

    # Keep only the "valid" (non-wrapped) outputs.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, signal.size(i) - kernel.size(i) + 1) for i in fft_dims
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
def fft_new(z, x, label, reg=None):
    """Closed-form Fourier-domain correlation filter of ``z`` applied to ``x``.

    The filter coefficients are ``label / (sum_c |Z_c|^2 + reg)``; the
    response is the inverse FFT of ``sum_c X_c * conj(Z_c)`` times those
    coefficients.

    Args:
        z: Real template features; channels on dim 1, FFT over the last two dims.
        x: Real search features with the same shape as ``z``.
        label: Real one-sided spectrum of the desired response, broadcastable
            to ``(batch, 1, H, W // 2 + 1)``.
        reg: Regularization added to the template energy. Defaults to the
            module-level ``lambda0`` when omitted.

    Returns:
        Real response of shape ``(batch, 1, H, W)``, where ``(H, W)`` is the
        spatial size of ``x``.
    """
    reg = lambda0 if reg is None else reg
    zf = fft.rfftn(z, dim=[-2, -1])
    xf = fft.rfftn(x, dim=[-2, -1])
    # Template energy summed over channels: R[batch, 1, H, W//2+1]
    kzzf = torch.sum(torch.real(zf)**2 + torch.imag(zf)**2, dim=1, keepdim=True)
    # Cross-spectrum summed over channels: C[batch, 1, H, W//2+1]
    kxzf = torch.sum(xf * torch.conj(zf), dim=1, keepdim=True)
    # Filter coefficients: R[batch, 1, H, W//2+1]
    alphaf = label.to(device=z.device) / (kzzf + reg)
    # R[batch, 1, H, W]; size comes from the input rather than a fixed 121 x 121.
    return fft.irfftn(kxzf * alphaf, s=x.shape[-2:], dim=[-2, -1])
def forward(self, template, search, label):
    """Closed-form Fourier-domain correlation filter of ``template`` applied to ``search``.

    Filter coefficients are ``label / (sum_c |Z_c|^2 + self.lambda0)``; the
    response is the inverse FFT of ``sum_c X_c * conj(Z_c)`` times them.

    Args:
        template: Real template features, e.g. R[batch, C, H, W].
        search: Real search features with the same shape as ``template``.
        label: Real one-sided spectrum of the desired response,
            broadcastable to R[batch, 1, H, W//2+1].

    Returns:
        Real response R[batch, 1, H, W], with (H, W) taken from ``search``.
    """
    # zf & xf shape: C[batch, C, H, W//2+1]
    zf = fft.rfftn(template, dim=[-2, -1])
    xf = fft.rfftn(search, dim=[-2, -1])
    # Template energy summed over channels: R[batch, 1, H, W//2+1]
    kzzf = torch.sum(zf.real**2 + zf.imag**2, dim=1, keepdim=True)
    # Cross-spectrum summed over channels: C[batch, 1, H, W//2+1]
    t = xf * torch.conj(zf)
    kxzf = torch.sum(t, dim=1, keepdim=True)
    # Filter coefficients: R[batch, 1, H, W//2+1]
    alphaf = label.to(device=template.device) / (kzzf + self.lambda0)
    # R[batch, 1, H, W]; size comes from the input rather than a fixed 121 x 121.
    response = fft.irfftn(kxzf * alphaf, s=search.shape[-2:], dim=[-2, -1])
    return response
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int],
    padding: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    dilation: Tuple[int],
) -> Tensor:
    """N-d transposed convolution computed with FFTs.

    Stride is realised by periodically extending the input spectrum
    ``stride`` times per axis, which corresponds to inserting ``stride - 1``
    zeros between input samples. Dilation is realised the same way on the
    weight spectrum.

    Args:
        input: Input tensor; dims 2.. are the spatial dims.
        weight: Weight tensor; dims 2.. are the kernel's spatial dims.
        bias: Optional bias, broadcast along dim 1 of the output.
        stride, padding, output_padding, dilation: Per-spatial-dim tuples;
            presumably with the same meaning as in
            ``torch.nn.functional.conv_transposeNd`` (the output size comes
            from ``_conv_transpose_shape`` — confirm there).
        groups: Number of channel groups, forwarded to ``_complex_matmul``.

    Returns:
        Output tensor whose spatial size is ``_conv_transpose_shape(...)``.
    """
    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:],
                                        stride, padding, output_padding,
                                        dilation)
    # Work on the output *before* the ``padding`` crop at the end.
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    s: List[int] = []  # per-axis FFT length for the input
    weight_s: List[int] = []  # per-axis FFT length for the weight
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        # The full transform length must cover both the padded output and the
        # dilated kernel.
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        # On the last (rfft) axis the factor is doubled so that both
        # s_size // st and s_size // d are even; the half-spectrum
        # concatenation below relies on this.
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size // st)
        weight_s.append(s_size // d)

    # One-sided FFT along the last axis only; the other axes are handled below.
    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    if stride[-1] > 1:
        # Conjugate-mirrored bins (negative frequencies), skipping the first
        # flipped bin.
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)
        # Alternate negative and positive halves to build the one-sided
        # spectrum of the stride-times periodically repeated spectrum.
        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])
        X = torch.cat(tmp, -1)

    if dilation[-1] > 1:
        # Same periodic extension for the weight, realising dilation.
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)
        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])
        W = torch.cat(tmp, -1)

    if len(s) > 1:
        # Full complex FFT over the remaining spatial axes (all but the last).
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        # Tile the spectra along those axes for stride / dilation. Every
        # repeat factor is >= 1, so the sum exceeds ndim only if some factor
        # is > 1.
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)
        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    # NOTE(review): the trailing True presumably selects the transposed
    # channel layout in _complex_matmul — confirm there.
    Y = _complex_matmul(X, W, groups, True)

    # irfftn's default last-axis length 2 * (m - 1) equals the (even) padded
    # length built above.
    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        # Broadcast bias along dim 1 (shape (C, 1, ..., 1)).
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]
    return output
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """FFT-based N-d convolution mimicking PyTorch's direct ``convNd``.

    Fast for large kernels. Like PyTorch, the kernel is not flipped, so this
    is a cross-correlation. An optional bias is added after the convolution.

    Args:
        signal: (Tensor) Input tensor; dims 0 and 1 are kept, dims 2.. are convolved.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output, one value per dim-1 channel.
        padding: (Union[int, Iterable[int]]) Zero samples added on both sides
            of every spatial dimension.
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        groups: (int) Number of channel groups, forwarded to ``complex_matmul``.

    Returns:
        (Tensor) Convolved tensor.
    """
    spatial = signal.ndim - 2
    pads = to_ntuple(padding, n=spatial)
    strides = to_ntuple(stride, n=spatial)

    # F.pad takes (lo, hi) pairs starting from the LAST dimension.
    flat_pads = []
    for amount in pads[::-1]:
        flat_pads += [amount, amount]
    signal = f.pad(signal, flat_pads)

    # rfftn/irfftn round-trip only for an even last dimension, so append one
    # zero when it is odd; the valid region cropped below never reaches it.
    padded = f.pad(signal, [0, 1]) if signal.size(-1) % 2 != 0 else signal

    kernel_pads = []
    for axis in reversed(range(2, padded.ndim)):
        kernel_pads.extend([0, padded.size(axis) - kernel.size(axis)])
    kernel_full = f.pad(kernel, kernel_pads)

    axes = tuple(range(2, signal.ndim))
    sig_spec = rfftn(padded, dim=axes)
    ker_spec = rfftn(kernel_full, dim=axes)
    # Conjugate kernel spectrum -> cross-correlation instead of convolution.
    ker_spec.imag *= -1
    out = irfftn(complex_matmul(sig_spec, ker_spec, groups=groups), dim=axes)

    keep = [slice(0, out.size(0)), slice(0, out.size(1))]
    for axis in axes:
        keep.append(slice(0, signal.size(axis) - kernel.size(axis) + 1, strides[axis - 2]))
    out = out[tuple(keep)].contiguous()

    if bias is not None:
        out += bias.view(tuple([1, -1] + [1] * spatial))
    return out
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
    padding_mode: str = "constant",
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order to mimic the PyTorch direct convolution).

    As with PyTorch's ``convNd``, the kernel is not flipped (cross-correlation).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel; dims 0
            and 1 are kept as-is and dims 2.. are convolved.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output, one value per dim-1 channel.
        padding: (Union[int, Iterable[int], str]) Samples to pad on both sides
            of each spatial dimension, ``"valid"`` (no padding), or ``"same"``
            (``k // 2`` on both sides; with stride 1 the output keeps the
            input's spatial size).
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        groups: (int) Number of channel groups, forwarded to ``complex_matmul``.
        padding_mode: (str) ``mode`` passed to ``torch.nn.functional.pad``
            (e.g. "constant", "reflect", "replicate", "circular").

    Returns:
        (Tensor) Convolved tensor, computed in float32.
    """
    n_spatial = signal.ndim - 2
    # Cast stride to tuple.
    stride_ = to_ntuple(stride, n=n_spatial)
    same = isinstance(padding, str) and padding == "same"
    if same:
        # Symmetric k // 2 padding. For even kernels this yields one output
        # more than the input size; that extra sample is dropped when cropping.
        padding_ = [k // 2 for k in kernel.shape[2:]]
    else:
        if isinstance(padding, str) and padding == "valid":
            padding = 0
        padding_ = to_ntuple(padding, n=n_spatial)
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]

    # Pad the input signal
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # padded size before the even-length fix-up
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)
    # Only spatial dims must match. Under PyTorch's grouped-conv convention
    # (which complex_matmul is assumed to follow) kernel dim 1 holds
    # in_channels // groups, so comparing dim 1 rejected every groups > 1 call.
    assert (
        padded_kernel.shape[2:] == signal.shape[2:]
    ), f"padded kernel shape {padded_kernel.shape} not equal to signal shape {signal.shape}"

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    fft_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal.float(), dim=fft_dims)
    kernel_fr = rfftn(padded_kernel.float(), dim=fft_dims)
    kernel_fr.imag *= -1  # conjugate kernel -> cross-correlation
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=fft_dims)

    # Keep the valid (non-wrapped) outputs: n - k + 1 per dim, as in PyTorch.
    # Only "same" with an even kernel drops one trailing sample so the output
    # size matches the input.
    crop_slices = [slice(None), slice(None)]
    for i in fft_dims:
        valid = signal_size[i] - kernel.size(i) + 1
        if same and kernel.size(i) % 2 == 0:
            valid -= 1
        crop_slices.append(slice(0, valid, stride_[i - 2]))
    output = output[tuple(crop_slices)].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Builds a spatial regularization window
    ``min + (edge - min) * (|r / scale_r|^p + |c / scale_c|^p)`` over a
    row/column grid of size ``sz``. It then zeroes small DFT coefficients,
    corrects the DC term for the window minimum, and returns the trimmed real
    part of the shifted DFT.

    Args:
        sz: Filter size; ``sz[0]`` rows and ``sz[1]`` columns.
        target_sz: Target size; ``0.5 * target_sz`` scales the row/column distances.
        params: Parameter object. Reads ``use_reg_window``, ``reg_window_min``,
            ``reg_window_edge``, ``reg_window_power`` and
            ``reg_sparsity_threshold``, plus optional ``reg_window_square``
            (default False) and ``reg_window_centered`` (default True).

    Returns:
        ``reg_window_min * torch.ones(1, 1, 1, 1)`` when
        ``params.use_reg_window`` is false. Otherwise the trimmed real DFT
        coefficients, with the columns mirrored to form a two-sided filter
        (presumably shape (1, 1, h, w) — confirm ``complex.real`` drops the
        real/imag axis).
    """

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)

    if getattr(params, 'reg_window_square', False):
        # Replace the target size with a square of equal area.
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        # Integer offsets centred on zero.
        wrg = torch.arange(-int((sz[0] - 1) / 2), int(sz[0] / 2 + 1), dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1] - 1) / 2), int(sz[1] / 2 + 1), dtype=torch.float32).view(1, 1, 1, -1)
    else:
        # FFT ordering: non-negative offsets first, then the negative ones.
        wrg = torch.cat([
            torch.arange(0, int(sz[0] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, -1, 1)
        wcg = torch.cat([
            torch.arange(0, int(sz[1] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, 1, -1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
        (torch.abs(wrg / reg_scale[0]) ** params.reg_window_power +
         torch.abs(wcg / reg_scale[1]) ** params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity
    # The DFT is stored as real/imag pairs in a trailing axis, normalised by the
    # number of samples.
    reg_window_dft = torch.view_as_real(
        torch_fft.rfftn(reg_window, dim=[-2, -1])) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    # Zero coefficients whose magnitude is below the threshold times the maximum.
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold *
                   reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum
    reg_window_sparse = torch_fft.irfftn(torch.view_as_complex(reg_window_dft), s=sz.long().tolist(), dim=[-2, -1])
    # Shift the real DC coefficient so that the sparse window, rescaled by
    # sz.prod(), has its minimum at reg_window_min.
    reg_window_dft[
        0, 0, 0, 0, 0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros
    # Largest row/column index holding a non-zero coefficient.
    max_inds, _ = reg_window_dft.nonzero(as_tuple=False).max(dim=0)

    # Crop rows symmetrically about the centre row and keep columns up to the
    # last non-zero one.
    mid_ind = int((reg_window_dft.shape[2] - 1) / 2)
    top = max_inds[-2].item() + 1
    bottom = 2 * mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        # Mirror columns 1.. (flipped in both spatial dims) in front to
        # rebuild the two-sided filter.
        reg_window_dft = torch.cat(
            [reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """FFT-based convolution variant with a custom output layout.

    Fast for large kernels; the kernel spectrum is conjugated, so like PyTorch's
    direct convolution this is a cross-correlation.

    NOTE(review): unlike the other ``fft_conv`` variants, this one
      - ignores ``stride``: ``stride_`` is hard-coded to ``(1, 1)``;
      - requires ``padding`` to be a sequence, because it is sliced with
        ``[::-1]`` and indexed per dimension (a plain int fails);
      - takes the inverse FFT over dims ``4 .. signal.ndim + 1``. It appears to
        expect ``complex_matmul`` to return two extra leading dims — confirm
        against ``complex_matmul``.

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Iterable[int]) Zero samples added on both sides of each
            spatial dimension (must be a sequence, see note).
        stride: Unused (see note).
        groups: (int) Forwarded to ``complex_matmul``.

    Returns:
        (Tensor) Cropped real output; its layout depends on ``complex_matmul``.
    """
    # Padding is used as given; stride is fixed to 1 regardless of ``stride``.
    padding_ = padding
    stride_ = (1, 1)

    # Pad the input signal & kernel tensors
    # F.pad takes (lo, hi) pairs starting from the last dimension.
    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    if signal.size(-1) % 2 != 0:
        signal_ = f.pad(signal, [0, 1])
    else:
        signal_ = signal

    # Zero-pad the kernel at the high end of each spatial dim to the signal size.
    kernel_padding = [
        pad for i in reversed(range(2, signal_.ndim))
        for pad in [0, signal_.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

    # Conjugate kernel spectrum -> cross-correlation.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = output_fr  # immediately overwritten below
    # Inverse FFT over dims 4.. (see the NOTE in the docstring on output layout).
    output = irfftn(output_fr, dim=tuple(range(4, signal.ndim + 2)))

    # Remove extra padded values
    # The first four output dims are kept whole. Each later spatial dim
    # starts at padding_ - 1 and stops at
    # signal.size - kernel.size - padding_ + 2 (signal is the padded signal).
    # NOTE(review): with padding 0 the start is -1, which keeps only the last
    # element — confirm this crop is intended.
    crop_slices = [
        slice(0, output.size(0)),
        slice(0, output.size(1)),
        slice(0, output.size(2)),
        slice(0, output.size(3))
    ] + [
        slice(padding_[i - 3] - 1,
              (signal.size(i - 1) - kernel.size(i - 1) - padding_[i - 3] + 2),
              stride_[i - 3]) for i in range(3, signal.ndim + 1)
    ]
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    # NOTE(review): bias_shape has signal.ndim dims while output has more
    # (see docstring); confirm the broadcast lands on the intended axis.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output