def rfft(input_tensor, signal_ndim=1, n=None, dim=-1, norm=None) -> torch.Tensor:
    """Real-to-complex FFT that works across PyTorch FFT API generations.

    ``check_fft_version()`` is called first. If the ``torch.fft`` module has not
    been imported (it is absent from ``sys.modules``), the legacy
    ``torch.rfft(input_tensor, signal_ndim=signal_ndim)`` is used and ``n``,
    ``dim`` and ``norm`` are ignored. Otherwise ``torch.fft.rfft(input_tensor,
    n, dim, norm)`` is used and ``signal_ndim`` is ignored.

    NOTE: the two paths return different layouts. The legacy call yields a real
    tensor with a trailing (real, imag) axis; ``torch.fft.rfft`` yields a
    complex tensor.
    """
    check_fft_version()
    legacy_api = "torch.fft" not in sys.modules
    if not legacy_api:
        return torch.fft.rfft(input_tensor, n, dim, norm)
    return torch.rfft(input_tensor, signal_ndim=signal_ndim)
def forward(self, x):
    """Apply ``self.nonlinearity`` in the signal domain to grouped spectral data.

    The original author's shape note for ``x`` is
    ``batch_size x rho[0] + 2(len(rho) - 1)*rho[0]`` (TODO confirm).

    Steps:
      1. Regroup ``x`` with ``scatter_add(x, mask)`` (helper defined elsewhere).
      2. Reshape to ``(batch, len(self.rhos), self.num, 2)`` and swap axes 1 and 2.
      3. Inverse real FFT, apply the nonlinearity, then forward real FFT.
      4. Flatten per sample and keep the columns selected by the mask.
    """
    n_batch = len(x)
    device_mask = self.mask.to(x.device)

    scattered = scatter_add(x, device_mask)
    spectra = scattered.reshape(n_batch, len(self.rhos), self.num, 2)
    spectra = spectra.transpose(1, 2)

    signal = torch.irfft(spectra, 1)
    activated = self.nonlinearity(signal)

    back = torch.rfft(activated, 1, onesided=True)
    back = back.transpose(1, 2)
    flat = back.reshape(n_batch, -1)
    return flat[:, device_mask]
def eval_torch_rfft2d(x, runs, warmup=100):
    """Benchmark legacy 2-D ``torch.rfft`` / ``torch.irfft`` on ``x`` and print ms per call.

    ``torch.cuda.synchronize()`` is called around each timed region, so this is
    intended for CUDA tensors.

    Args:
        x: real tensor whose last two dimensions are transformed.
        runs: number of timed iterations per transform; must be positive.
        warmup: untimed ``rfft`` calls made before timing (default 100, as before).

    Raises:
        ValueError: if ``runs`` is not positive, since the per-call mean is undefined.
    """
    if runs <= 0:
        raise ValueError("runs must be positive, got %r" % (runs,))

    spectrum = None
    for _ in range(warmup):
        spectrum = torch.rfft(x, signal_ndim=2, onesided=True)
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    for _ in range(runs):
        spectrum = torch.rfft(x, signal_ndim=2, onesided=True)
    torch.cuda.synchronize()
    print("torch.rfft2d takes %.7f ms" % ((time.perf_counter() - start) / runs * 1000))

    # signal_sizes must list only the transformed (last two) dims. Passing the
    # full x.shape breaks for batched input; for 2-D x it is the same value.
    sizes = x.shape[-2:]
    torch.irfft(spectrum, signal_ndim=2, onesided=True, signal_sizes=sizes)  # single warm-up call
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        torch.irfft(spectrum, signal_ndim=2, onesided=True, signal_sizes=sizes)
    torch.cuda.synchronize()
    print("torch.irfft2d takes %.7f ms" % ((time.perf_counter() - start) / runs * 1000))
    print("")
def calculate_energy(frames):
    """Calculate the spectral energy of each frame with a real FFT.

    :param frames: (nframes, framelen) real tensor.
    :return: (nframes, framelen // 2) tensor of |X_k|^2 for the non-DC one-sided
        bins. A real FFT of length N has N // 2 + 1 bins, and the DC bin is dropped.
    """
    if hasattr(torch, "rfft"):
        spec = torch.rfft(frames, 1)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        spec = torch.view_as_real(torch.fft.rfft(frames, dim=-1))
    # |X_k|^2 = re^2 + im^2. Summing squares directly avoids a sqrt then a square.
    energy = spec.pow(2).sum(dim=-1)
    return energy[..., 1:]
def apply(self, x):
    """Filter ``x`` in the centered 2-D frequency domain with ``self.mask``.

    Takes the full (two-sided, unnormalized) 2-D FFT, centers it with
    ``batch_fftshift2d``, multiplies by ``self.mask``, un-centers it with
    ``batch_ifftshift2d`` and returns the inverse transform.
    """
    centered = batch_fftshift2d(
        torch.rfft(x, signal_ndim=2, normalized=False, onesided=False))
    uncentered = batch_ifftshift2d(centered * self.mask)
    return torch.irfft(uncentered, signal_ndim=2, normalized=False, onesided=False)
def make_cepts2(X, T_pi):
    """Calculate the squared real cepstral coefficients.

    ``X`` is cut into non-overlapping length-``T_pi`` frames along dim 2 with
    ``F.unfold`` (so ``X`` must be 4-D). Each frame is Hann-windowed and its power
    spectrum is averaged over frames. The spectrum is mirrored, its log taken, and
    a second real FFT gives |c_k|^2.

    Args:
        X: 4-D tensor (N, C, H, W) accepted by ``F.unfold``.
        T_pi: frame length and stride.

    Returns:
        Tensor of squared cepstral magnitudes, one row per batch element.
    """
    def rfft_last(t):
        # One-sided FFT over the last dim in the legacy (..., 2) real/imag layout.
        if hasattr(torch, "rfft"):
            return torch.rfft(t, 1, onesided=True)
        return torch.view_as_real(torch.fft.rfft(t, dim=-1))  # PyTorch >= 1.8

    Y = F.unfold(X, kernel_size=[T_pi, 1], stride=T_pi)
    Y = torch.transpose(Y, 1, 2)
    frame_len = Y.shape[-1]

    # Compute the power spectral density. Build the window on Y's device and
    # dtype: a CPU-only window breaks CUDA inputs.
    window = torch.as_tensor(hann(frame_len), dtype=Y.dtype, device=Y.device)[None, None]
    Yf = rfft_last(Y * window)
    spect = Yf[..., 0] ** 2 + Yf[..., 1] ** 2
    spect = spect.mean(dim=1)
    spect = torch.cat([torch.flip(spect[:, 1:], dims=(1,)), spect], dim=1)

    # Log of the DFT of the autocorrelation
    logspect = torch.log(spect) - np.log(float(frame_len))

    # Compute squared cepstral coefs (b_k^2)
    cepts = rfft_last(logspect) / float(frame_len)
    cepts = torch.sqrt(cepts[..., 0] ** 2 + cepts[..., 1] ** 2)
    return cepts ** 2
def ccorr(a, b):
    """Circular correlation of ``a`` and ``b`` along the last dimension.

    Computed as ``irfft(conj(rfft(a)) * rfft(b))`` with ``conj`` and
    ``com_mult`` (helpers defined elsewhere) acting on the legacy
    (real, imag) layout.

    Parameters
    ----------
    a : Tensor, 1D or 2D
    b : Tensor, 1D or 2D
        Same dimensionality as ``a``; broadcasting is supported.

    Returns
    -------
    Tensor with the same dimension as ``a``; its last dimension is restored to
    ``a.shape[-1]``.
    """
    a_conj_spec = conj(th.rfft(a, 1))
    b_spec = th.rfft(b, 1)
    cross_spec = com_mult(a_conj_spec, b_spec)
    length = a.shape[-1]
    return th.irfft(cross_spec, 1, signal_sizes=(length, ))
def FourierMod2_nopad(a):
    """Full 2-D DFT of a grayscale batch: squared modulus, real part, imaginary part.

    Args:
        a: tensor of shape (n_batch, 1, ha, wa); only one channel is supported.

    Returns:
        Tuple ``(|A|^2, Re A, Im A)``, each reshaped to (n_batch, 1, ha, wa), where
        ``A`` is the unnormalized two-sided 2-D DFT of each image. Unlike the
        definition used in xcorr2, ``|A|^2`` is a real tensor here.
    """
    n_batch, n_c, ha, wa = a.shape
    assert n_c == 1, "Only grayscale currently supported"
    a = a.view(n_batch, ha, wa)
    if hasattr(torch, "rfft"):
        A = torch.rfft(a, signal_ndim=2, onesided=False, normalized=False)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        A = torch.view_as_real(torch.fft.fft2(a))
    Ar = A[..., 0]
    Ai = A[..., 1]
    Aabs2 = Ar ** 2 + Ai ** 2  # squaring already makes the terms non-negative
    out_shape = [n_batch, n_c, ha, wa]
    return Aabs2.reshape(out_shape), Ar.reshape(out_shape), Ai.reshape(out_shape)
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension (same shape as x)
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])
    # Even-symmetric extension [x0, ..., x_{N-1}, x_{N-2}, ..., x1] of length
    # 2N-2. Its DFT is real, and for N >= 2 its N one-sided bins are the DCT-I.
    extended = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    if hasattr(torch, "rfft"):
        spec = torch.rfft(extended, 1)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        spec = torch.view_as_real(torch.fft.rfft(extended, dim=1))
    return spec[:, :, 0].view(*x_shape)
def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf):
    """Closed-form Fourier-domain data step; returns the spatial estimate.

    Spectra use the legacy real layout with a trailing (real, imag) axis.
    ``cmul``, ``cdiv``, ``csum`` and ``splits`` are helpers defined elsewhere
    (presumably complex multiply, complex divide, complex-plus-real and an
    ``sf x sf`` block split — confirm).

    Args:
        x: current spatial estimate; ``alpha * x`` is 2-D FFT'd (two-sided).
        FB, FBC, F2B: kernel spectra (presumably FFT of the blur, its conjugate
            and its squared magnitude — TODO confirm against the caller).
        FBFy: spectrum term added to ``FFT(alpha * x)``.
        alpha: penalty weight, unsqueezed on the last axis for the final division.
        sf: scale factor used by ``splits`` and for re-tiling the block average.
    """
    rhs_spec = FBFy + torch.rfft(alpha*x, 2, onesided=False)
    block_mean = torch.mean(splits(cmul(FB, rhs_spec), sf), dim=-1, keepdim=False)
    inv_w = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    ratio = cdiv(block_mean, csum(inv_w, alpha))
    correction = cmul(FBC, ratio.repeat(1, 1, sf, sf, 1))
    solution_spec = (rhs_spec - correction) / alpha.unsqueeze(-1)
    return torch.irfft(solution_spec, 2, onesided=False)
def fft_old(z, x, label):
    """Correlation-filter response computed in the Fourier domain (legacy FFT API).

    Shape notes from the original author (example sizes):
    spectra ``[batch, 32, 121, 61, 2]``, summed energies ``[batch, 1, 121, 61, 1]``,
    filter ``[batch, 1, 121, 61, 2]``, output ``[batch, 1, 121, 121]``.
    ``cn`` (complex helpers) and ``lambda0`` (regularizer) come from module scope.
    """
    z_hat = torch.rfft(z, signal_ndim=2)
    x_hat = torch.rfft(x, signal_ndim=2)

    # Sum |Z|^2 over the (real, imag) axis, then over channels.
    z_energy = (z_hat**2).sum(dim=4, keepdim=True).sum(dim=1, keepdim=True)
    # Cross-spectrum X * conj(Z), summed over channels.
    cross = cn.mulconj(x_hat, z_hat).sum(dim=1, keepdim=True)

    alpha_hat = label.to(device=z.device) / (z_energy + lambda0)
    return torch.irfft(cn.mul(cross, alpha_hat), signal_ndim=2)
def D(x, Dh_DFT, Dv_DFT):
    """Apply two circular filters, given by their DFTs, to ``x`` along its last axis.

    Args:
        x: real tensor; the 1-D DFT is taken over its last dimension.
        Dh_DFT, Dv_DFT: complex full-length (two-sided) filter spectra that
            broadcast against ``x``'s spectrum.

    Returns:
        ``(Dh_x, Dv_x)``: real inverse transforms of ``Dh_DFT * X`` and ``Dv_DFT * X``.
        As with the legacy ``irfft(onesided=False)``, only the non-redundant half
        of each product spectrum is used (Hermitian symmetry is assumed).
    """
    # Follow the filters' device instead of hard-coding .cuda(), which required a
    # GPU and broke CPU filters.
    device = Dh_DFT.device
    n = x.shape[-1]
    if hasattr(torch, "rfft"):
        x_DFT = torch.view_as_complex(torch.rfft(x, signal_ndim=1, onesided=False)).to(device)
        Dh_x = torch.irfft(torch.view_as_real(Dh_DFT * x_DFT), signal_ndim=1, onesided=False)
        Dv_x = torch.irfft(torch.view_as_real(Dv_DFT * x_DFT), signal_ndim=1, onesided=False)
    else:  # torch.rfft/irfft were removed in PyTorch 1.8
        x_DFT = torch.fft.fft(x, dim=-1).to(device)
        Dh_x = torch.fft.irfft(Dh_DFT * x_DFT, n=n, dim=-1)
        Dv_x = torch.fft.irfft(Dv_DFT * x_DFT, n=n, dim=-1)
    return Dh_x, Dv_x
def forward(self, x, edge_index, edge_weight=None):
    """Chebyshev graph filtering applied to the one-sided spectrum of each node series.

    Self-loops are removed first. A symmetrically normalized Laplacian
    ``-D^{-1/2} A D^{-1/2}`` is then built from ``edge_weight`` (ones if None).
    Axis 1 of ``x`` (``horizon``) is moved last and real-FFT'd (orthonormal).
    The Chebyshev recurrence ``T_k = 2 L T_{k-1} - T_{k-2}`` is applied with one
    weight slice per order (``self.weight[k]``). The result is inverse-transformed
    back to ``horizon`` samples, permuted back, and ``self.bias`` is added if set.
    """
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    row, col = edge_index
    num_nodes, num_edges, order_filter = x.size(0), row.size(0), self.weight.size(0)

    if edge_weight is None:
        edge_weight = x.new_ones((num_edges,))
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    deg = degree(row, num_nodes, dtype=x.dtype)

    # Compute normalized and rescaled Laplacian.
    deg = deg.pow(-0.5)
    deg[deg == float('inf')] = 0
    lap = -deg[row] * edge_weight * deg[col]

    # Densify once and reuse it for every recurrence step. sparse_coo_tensor keeps
    # lap's float dtype; the legacy IntTensor constructor expects integer values.
    # to_dense() sums duplicate entries, as before.
    n_nodes = x.shape[0]
    L = torch.sparse_coo_tensor(edge_index, lap, (n_nodes, n_nodes)).to_dense()

    def weight_mult(v, w):
        # w: (f, g, r, s), v: (i, f, r, s) -> (i, g, r, s)
        return torch.einsum('fgrs,ifrs->igrs', w, v)

    def lap_mult(v):
        # Contracts the first node index of L with v's node axis.
        return torch.einsum('ij,ifrs->jfrs', L, v)

    # Perform filter operation recurrently.
    horizon = x.shape[1]
    x = x.permute(0, 2, 1)
    x_hat = torch.rfft(x, 1, normalized=True, onesided=True)
    Tx_0 = x_hat
    y_hat = weight_mult(Tx_0, self.weight[0, :])

    if order_filter > 1:
        Tx_1 = lap_mult(x_hat)
        y_hat = y_hat + weight_mult(Tx_1, self.weight[1, :])

    for k in range(2, order_filter):
        Tx_2 = 2 * lap_mult(Tx_1) - Tx_0
        y_hat = y_hat + weight_mult(Tx_2, self.weight[k, :])
        Tx_0, Tx_1 = Tx_1, Tx_2

    y = torch.irfft(y_hat, 1, normalized=True, onesided=True, signal_sizes=(horizon,))
    y = y.permute(0, 2, 1)

    if self.bias is not None:
        y = y + self.bias

    return y
def initGT(self, device):
    """Load the ground-truth volume(s) for ``self.args.dataset`` into the generator.

    Chooses an MRC file by substring match on ``self.args.dataset`` ('Ribo',
    'Betagal', 'ABC', 'proteasome', 'serotonin'). It writes the (expanded) volume
    into ``self.gen.X.data`` and sets ``self.GroundTruth``. If
    ``self.args.VolumeDomain == 'fourier'`` and ``self.args.FourierProjector``
    is True, the ground truth is replaced by its shifted full 3-D FFT.
    """

    def load_mrc(path):
        # Close the MRC handle after reading (mrcfile.open leaked it before).
        # torch.Tensor(...) copies the array, so the tensor outlives the file.
        with mrcfile.open(path) as mrc:
            return torch.Tensor(mrc.data).to(device)

    if 'Ribo' in self.args.dataset:
        self.GroundTruth = load_mrc('./Datasets/Ribo/emd_2660.mrc')
        self.gen.X.data = self.GroundTruth
    elif 'Betagal' in self.args.dataset:
        print("Using 2.5 A molmap")
        fittedBetagal = load_mrc('./Datasets/Betagal-Synthetic/fitted_betagal_2.5A.mrc')
        fittedBetagal = torch.nn.functional.avg_pool3d(
            fittedBetagal.unsqueeze(0).unsqueeze(0),
            kernel_size=self.args.DownSampleRate,
            stride=self.args.DownSampleRate,
            padding=0).squeeze()
        n = self.args.VolumeSize
        GroundTruth = self.ExpandVolume(fittedBetagal, n, device)
        self.gen.X.data = GroundTruth.unsqueeze(0)
        self.GroundTruth = self.gen.X.data
    elif 'ABC' in self.args.dataset:
        n = self.args.VolumeSize
        if self.args.VolumeNumbers == 1:
            vol = load_mrc('./Datasets/ABC-Synthetic/fitted_ABC.mrc')
            self.gen.X.data = self.ExpandVolume(vol, n, device).unsqueeze(0)
        else:
            for i in range(self.args.VolumeNumbers):
                num = 4773 + 2 * i
                vol = load_mrc('./Datasets/ABC/fitted_' + str(num) + '.mrc') / 4
                self.gen.X.data[i] = self.ExpandVolume(vol, n, device).unsqueeze(0)
        # NOTE(review): nesting reconstructed from whitespace-collapsed source;
        # set for both the single- and multi-volume cases — confirm.
        self.GroundTruth = self.gen.X.data
    elif 'proteasome' in self.args.dataset:
        n = self.args.VolumeSize
        vol = load_mrc('./Datasets/proteasome-Synthetic/fitted_proteasome.mrc')
        self.gen.X.data[0] = self.ExpandVolume(vol, n, device)
        self.GroundTruth = self.gen.X.data
    elif 'serotonin' in self.args.dataset:
        n = self.args.VolumeSize
        vol = load_mrc('./Datasets/serotonin-Synthetic/fitted_serotonin.mrc')
        self.gen.X.data[0] = self.ExpandVolume(vol, n, device)
        self.GroundTruth = self.gen.X.data

    if self.args.VolumeDomain == 'fourier' and self.args.FourierProjector == True:
        # Centered full (two-sided) 3-D spectrum of the real-space ground truth.
        self.gen.X.data = fftshift(
            torch.rfft(ifftshift(self.GroundTruth, mode='real', signal_dim=3), 3, onesided=False),
            mode='complex', signal_dim=3)
        print("Fourier projector")
def fft_conv2d(x, weight, bias=None):
    """2-D 'same' convolution of ``x`` with ``weight`` via the legacy real FFT.

    Args:
        x: (batch, channels, h, w) input.
        weight: (n_out, n_in, k, k) kernel (square; the last dim is used as ``k``).
        bias: optional bias. A 1-D tensor is treated as one value per output
            channel; anything else is added as-is.

    Returns:
        Contiguous tensor cropped back to the input's spatial size. The
        channel mixing itself is done by ``compl_mul`` (defined elsewhere).
    """
    if x.dim() != 4:
        raise ValueError("expected 4-D input (batch, channels, h, w), got %d-D" % x.dim())
    _, _, _, k = weight.shape

    # 'same' padding for a k x k kernel.
    total_pad = k - 1
    pad_lw = pad_lh = total_pad // 2
    pad_rw = pad_rh = total_pad - pad_lw
    x = F.pad(x, (pad_lw, pad_rw, pad_lh, pad_rh))
    h, w = x.shape[-2:]

    # Extra zero padding so the circular FFT product equals linear convolution.
    fft_pad_lw, fft_pad_lh = (w - 1) // 2, (h - 1) // 2
    fft_pad_rw, fft_pad_rh = w - 1 - fft_pad_lw, h - 1 - fft_pad_lh

    # Center the kernel in a canvas the same size as the padded input.
    weight_pad_lw, weight_pad_lh = (w - k) // 2, (h - k) // 2
    weight_pad_rw, weight_pad_rh = w - k - weight_pad_lw, h - k - weight_pad_lh
    weight_padded = F.pad(
        weight,
        (
            weight_pad_lw + fft_pad_lw,
            weight_pad_rw + fft_pad_rw,
            weight_pad_lh + fft_pad_lh,
            weight_pad_rh + fft_pad_rh,
        ),
    )
    x = F.pad(x, (fft_pad_lw, fft_pad_rw, fft_pad_lh, fft_pad_rh))

    x_fft = torch.rfft(x, 2)
    weight_fft = torch.rfft(weight_padded, 2)
    result_fft = compl_mul(x_fft, weight_fft)
    result = torch.irfft(result_fft, 2, signal_sizes=(x.shape[-2], x.shape[-1]))
    h, w = result.shape[-2:]

    if bias is not None:
        # A per-channel bias must broadcast over the channel axis, not the width axis.
        if torch.is_tensor(bias) and bias.dim() == 1:
            bias = bias.view(1, -1, 1, 1)
        result = result + bias

    # fftshift both spatial axes: rolling by floor(size / 2) is exactly the
    # earlier two-step half swap, for both odd and even sizes.
    res = torch.roll(result, shifts=(h // 2, w // 2), dims=(2, 3))
    res = res[
        :,
        :,
        pad_lh + fft_pad_lh:h - pad_rh - fft_pad_rh,
        pad_lw + fft_pad_lw:w - pad_rw - fft_pad_rw,
    ].contiguous()
    return res
def synthesis(s_target, test_id, ind, scat, n, min_error, err_it, nit, is_complex = False, initial_type = 'gaussian'):
    """Synthesize an n x n image whose ``scat`` features match ``s_target``.

    Starting from random noise, Adam minimizes ``MSE(s_target, scat(FFT2(x0)))``
    until the loss falls to ``min_error`` times its initial value.

    Args:
        s_target: target feature tensor (moved to CUDA when available).
        test_id, ind: used in the ``./result<test_id>/syn_*<ind>.npy`` output paths.
        scat: callable mapping the full 2-D legacy spectrum of x0 to features.
        n: image side length.
        min_error: stopping threshold on loss / initial loss.
        err_it: record the loss every ``err_it`` iterations.
        nit: print and checkpoint every ``nit`` iterations.
        is_complex: unused.
        initial_type: 'gaussian' (randn) or 'uniform' (rand) initialization.

    Raises:
        ValueError: for an unknown ``initial_type`` (previously an UnboundLocalError).

    NOTE(review): the learning rate ``lr`` is not a parameter; it is read from
    module scope at call time.
    """
    print('shape of s:', s_target.size())
    if initial_type == 'gaussian':
        x0 = torch.randn(n, n)
    elif initial_type == 'uniform':
        x0 = torch.rand(n, n)
    else:
        raise ValueError("initial_type must be 'gaussian' or 'uniform', got %r" % (initial_type,))

    if torch.cuda.is_available():
        s_target = s_target.cuda()
        x0 = x0.cuda()
    x0.requires_grad_(True)  # replaces the deprecated Variable wrapper

    def features():
        # scat() applied to the full (two-sided) 2-D spectrum of the current image.
        return scat(torch.rfft(x0, 2, onesided = False))

    loss = nn.MSELoss()
    optimizer = optim.Adam([x0], lr=lr)
    output = loss(s_target, features())
    # Detach the reference loss so its autograd graph is not kept alive.
    l0 = output.detach()
    error = []
    count = 0
    while output / l0 > min_error:
        optimizer.zero_grad()
        output = loss(s_target, features())
        if count % err_it == 0:
            error.append(output.item())
        output.backward()
        optimizer.step()
        # (The old post-step loss(s_target, s0) reused the pre-step features, so it
        # equalled `output`; that redundant recomputation is dropped.)
        if count % nit == 0:
            print(output.data.cpu().numpy())
            np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
            np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
        count += 1

    print('error reduced by: ', output / l0)
    print('error supposed reduced by: ', min_error)
    np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
    np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
def CACF_update(self, inputs, lr=1.):
    """Update the context-aware correlation filter model from five patches.

    ``inputs[0]`` is the target region and ``inputs[1:5]`` are four context
    regions. Each entry of ``inputs`` is replaced in place by its windowed
    feature map (``self.feature(...) * self.config.cos_window``).

    When ``lr > 0.99`` the model is (re)initialized from the target alone.
    Otherwise the new filter also penalizes context energy (weighted by
    ``config.lambda1``) and is blended into the stored model with rate ``lr``.
    """
    for i in range(5):
        inputs[i] = self.feature(inputs[i]) * self.config.cos_window

    def spectral_energy(spec):
        # Sum over the (real, imag) axis (dim 4), then over channels (dim 1).
        return torch.sum(torch.sum(spec**2, dim=4, keepdim=True), dim=1, keepdim=True)

    zf = torch.rfft(inputs[0], signal_ndim=2)  # target region
    kzzf = spectral_energy(zf)

    if lr > 0.99:
        # Context spectra do not enter this branch, so they are not computed.
        self.model_alphaf = self.config.yf / (kzzf + self.config.lambda0)
        self.model_zf = zf
        return

    context_energy = sum(
        spectral_energy(torch.rfft(inputs[i], signal_ndim=2))  # context regions 1-4
        for i in range(1, 5))
    alphaf = self.config.yf / (
        kzzf + self.config.lambda0 + self.config.lambda1 * context_energy)
    self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * alphaf.data
    self.model_zf = (1 - lr) * self.model_zf.data + lr * zf.data
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
    """Compact bilinear product of ``x`` and ``y`` via count sketches and an FFT.

    Saves the hashes, signs, inputs and sizes on ``ctx`` for backward. Each input
    is count-sketched to ``output_size`` (``CountSketchFn_forward``). The two
    sketches are circularly convolved by multiplying their one-sided spectra with
    ``ComplexMultiply_forward`` and inverse-transforming back to ``output_size``
    samples. The imaginary part should be zero, so a real result is returned.
    """
    ctx.save_for_backward(h1, s1, h2, s2, x, y)
    ctx.x_size = tuple(x.size())
    ctx.y_size = tuple(y.size())
    ctx.force_cpu_scatter_add = force_cpu_scatter_add
    ctx.output_size = output_size

    def sketch_spectrum(h, s, v):
        # The sketch itself is dropped when this helper returns; only the
        # (real, imag) views of its spectrum are kept.
        sketch = CountSketchFn_forward(h, s, output_size, v, force_cpu_scatter_add)
        spectrum = torch.rfft(sketch, 1)
        return spectrum.select(-1, 0), spectrum.select(-1, 1)

    re_fx, im_fx = sketch_spectrum(h1, s1, x)
    re_fy, im_fy = sketch_spectrum(h2, s2, y)

    re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)
    packed = torch.stack((re_prod, im_prod), re_prod.dim())
    return torch.irfft(packed, 1, signal_sizes=(output_size, ))
def forward(self, input, FB, FBC, F2B, FBFy, alpha, sf):
    """Closed-form Fourier-domain data step using channel 1 of ``alpha``.

    ``alpha`` is narrowed to ``alpha[:, 1:2, ...]``. The update is then
    ``FX = (FR - FBC * tile(mean_sf(FB*FR) / (mean_sf(F2B) + alpha))) / alpha``
    with ``FR = FBFy + FFT2(alpha * input)``, computed with the complex helpers
    ``cmul``/``cdiv``/``csum``/``splits`` (defined elsewhere) on the legacy
    (real, imag) layout. Returns the real ``irfft`` of ``FX`` (two-sided input).
    """
    weight = alpha[:, 1:2, ...]
    rhs_spec = FBFy + torch.rfft(weight * input, 2, onesided=False)
    block_mean = torch.mean(splits(cmul(FB, rhs_spec), sf), dim=-1, keepdim=False)
    inv_w = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    ratio = cdiv(block_mean, csum(inv_w, weight))
    correction = cmul(FBC, ratio.repeat(1, 1, sf, sf, 1))
    solution_spec = (rhs_spec - correction) / weight.unsqueeze(-1)
    return torch.irfft(solution_spec, 2, onesided=False)
def rfft(t):
    """Full (two-sided) 2-D DFT of the last two dims in the legacy (..., 2) layout.

    On PyTorch >= 1.8 this uses ``torch.fft.fft2`` and stacks the real and
    imaginary parts on a new last axis. Older versions call
    ``torch.rfft(t, 2, onesided=False)``, which returns that layout directly.
    """
    import re

    # Parse only major.minor, tolerating suffixes such as '1.10.0+cu113' or
    # '2.1.0.dev20230101' (the old 3-way unpack crashed on 4-part versions).
    major, minor = (int(re.match(r"\d+", part).group())
                    for part in torch.__version__.split('.')[:2])
    if (major, minor) >= (1, 8):
        ft = torch.fft.fft2(t)
        ft = torch.stack([torch.real(ft), torch.imag(ft)], dim=-1)
    else:
        ft = torch.rfft(t, 2, onesided=False)
    return ft
def forward(self, x):
    """Keep only the lowest ``self.modes1`` Fourier modes of ``x`` (last axis).

    Uses orthonormal one-sided real FFTs, so a signal that already lies within
    the kept modes is returned unchanged. The output has ``x``'s length and dtype.
    """
    n = x.size(-1)
    if hasattr(torch, "rfft"):
        x_ft = torch.rfft(x, 1, normalized=True, onesided=True)
        # zeros_like keeps x_ft's dtype/device (torch.zeros(shape) forced float32).
        out_ft = torch.zeros_like(x_ft)
        # Index the frequency axis (second to last) for any input rank.
        out_ft[..., :self.modes1, :] = x_ft[..., :self.modes1, :]
        return torch.irfft(out_ft, 1, normalized=True, onesided=True, signal_sizes=(n, ))

    # torch.rfft/irfft were removed in PyTorch 1.8.
    x_ft = torch.fft.rfft(x, dim=-1, norm="ortho")
    out_ft = torch.zeros_like(x_ft)
    out_ft[..., :self.modes1] = x_ft[..., :self.modes1]
    return torch.fft.irfft(out_ft, n=n, dim=-1, norm="ortho")
def forward(ctx, params, luts, inverse, mv):
    """Apply the CUDA fluid operator to ``mv`` in the Fourier domain.

    ``mv`` is transformed over its spatial dims (all dims after the first two)
    with a normalized one-sided real FFT. The spectrum is passed to
    ``lagomorph_cuda.fluid_operator`` together with the cos/sin lookup tables and
    ``params``. That call's return value is not used and the spectrum is
    inverse-transformed afterwards, so the extension presumably updates it in
    place (confirm). ``params``, ``luts`` and ``inverse`` are stored on ``ctx``.
    """
    ctx.params, ctx.luts, ctx.inverse = params, luts, inverse
    full_shape = mv.shape
    n_spatial = len(full_shape) - 2
    spectrum = torch.rfft(mv, n_spatial, normalized=True)
    lagomorph_cuda.fluid_operator(spectrum, inverse, luts['cos'], luts['sin'], *params)
    return torch.irfft(spectrum, n_spatial, normalized=True, signal_sizes=full_shape[2:])
def rfft2(data):
    """Centered, orthonormal, two-sided 2-D FFT of a single square image.

    Accepted shapes: (H x W), (1 x H x W), or (B x 1 x H x W), with H == W.
    The input is ``ifftshift``-ed over its last two dims, transformed with
    ``torch.rfft(..., 2, normalized=True, onesided=False)``, and the spectrum is
    ``fftshift``-ed over dims (-3, -2). The last axis is then (real, imaginary).
    """
    shape = data.shape
    rank = len(shape)
    assert shape[-1] == shape[-2]
    assert (rank == 2 or (rank == 3 and shape[0] == 1) or (rank == 4 and shape[1] == 1))
    centered = ifftshift(data, dim=(-2, -1))
    spectrum = torch.rfft(centered, 2, normalized=True, onesided=False)
    return fftshift(spectrum, dim=(-3, -2))
def extract_power(self, frames):
    """Per-frame power spectrum of demeaned, windowed, zero-padded frames.

    Args:
        frames: data accepted by ``torch.Tensor``. Indexing below requires a 3-D
            layout whose last axis holds ``self.opts['win_len']`` samples
            (presumably (batch, n_frames, win_len) — confirm against callers).

    Returns:
        ``re^2 + im^2`` of the one-sided FFT over the last axis after padding to
        ``self.next_power`` samples, i.e. ``self.next_power // 2 + 1`` bins.
        Moved to ``self.opts['device']``.
    """
    t_frames = torch.Tensor(frames).to(self.opts['device'])
    # Out-of-place: torch.Tensor(tensor) can alias the caller's tensor, so an
    # in-place subtraction could silently modify the input.
    t_frames = t_frames - t_frames.mean(2, True)
    t_frames = t_frames * self.window
    t_frames = F.pad(t_frames, (0, self.next_power - self.opts['win_len']))
    if hasattr(torch, "rfft"):
        spect = torch.rfft(t_frames, 1)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        spect = torch.view_as_real(torch.fft.rfft(t_frames, dim=-1))
    power_spect = torch.pow(spect[..., 0], 2.0) + torch.pow(spect[..., 1], 2.0)
    return power_spect
def step(self, closure=None):
    """Performs a single optimization step (SGD with optional Laplacian smoothing).

    For each parameter with a gradient: if the group's ``sigma > 0`` and
    ``self.cs[index]`` is not None, the flattened gradient is divided by
    ``self.cs[index]`` in the Fourier domain (via ``complex_div``) and written
    back. Weight decay, (Nesterov) momentum and the ``-lr`` update then follow,
    as in plain SGD.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The loss returned by ``closure``, or None if no closure was given.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        sigma = group['sigma']
        # Position into self.cs; reset to 0 for every param group.
        # NOTE(review): it is not advanced for params whose grad is None, so
        # alignment depends on how self.cs was built (not visible here) — confirm.
        index = 0

        for p in group['params']:
            if p.grad is None:
                continue

            # Laplacian smoothing gradient
            if sigma > 0:
                if self.cs[index] is not None:
                    g = p.grad.data.view(-1)
                    # Full (two-sided) 1-D spectrum of the flattened gradient.
                    fg = torch.rfft(g, 1, onesided=False)
                    # NOTE(review): dead store, g is overwritten two lines below.
                    g = torch.zeros_like(fg)
                    self.cs[index] = self.cs[index].to(fg.device)
                    g = complex_div(fg, self.cs[index])
                    g = torch.irfft(g, 1, onesided=False).view(
                        p.grad.data.size())
                    p.grad.data = g
                # NOTE(review): nesting of this increment was reconstructed from a
                # whitespace-collapsed source; confirm it belongs at this level.
                index += 1
            d_p = p.grad.data
            if weight_decay != 0:
                # Positional (scalar, tensor) add_ overload, deprecated in newer PyTorch.
                d_p.add_(weight_decay, p.data)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(
                        d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(1 - dampening, d_p)
                if nesterov:
                    d_p = d_p.add(momentum, buf)
                else:
                    d_p = buf

            p.data.add_(-group['lr'], d_p)

    return loss
def test_rfft(signal_ndim, normalized, onesided):
    """``utils._rfft`` matches ``torch.rfft``, and ``utils._irfft`` inverts ``utils._rfft``."""
    signal = torch.randn(4, 3, 8, 8)

    reference = torch.rfft(signal, signal_ndim, normalized, onesided)
    candidate = utils._rfft(signal, signal_ndim, normalized, onesided)
    assert torch.allclose(reference, candidate, atol=1e-4)

    spectrum = utils._rfft(signal, signal_ndim, normalized, onesided)
    restored = utils._irfft(spectrum, signal_ndim, normalized, onesided)
    assert torch.allclose(restored, signal, atol=1e-4)
def extract_fbank(self, wavs):
    """Filter-bank energies: power spectrum of the frames of ``wavs`` times ``self.melbank``.

    The frames come from ``self.enframes(wavs)`` (defined elsewhere; indexing
    below needs a 3-D result whose last axis holds ``self.opts['win_len']``
    samples). They are demeaned along axis 2, multiplied by ``self.window``,
    zero-padded to ``self.next_power`` samples and real-FFT'd. The resulting
    ``re^2 + im^2`` is matrix-multiplied by ``self.melbank``.
    """
    t_frames = self.enframes(wavs)
    # Out-of-place: if enframes returns a view of wavs, an in-place subtraction
    # would modify the caller's data.
    t_frames = t_frames - t_frames.mean(2, True)
    t_frames = t_frames * self.window
    t_frames = F.pad(t_frames, (0, self.next_power - self.opts['win_len']))
    if hasattr(torch, "rfft"):
        spect = torch.rfft(t_frames, 1)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        spect = torch.view_as_real(torch.fft.rfft(t_frames, dim=-1))
    power_spect = torch.pow(spect[..., 0], 2.0) + torch.pow(spect[..., 1], 2.0)
    return torch.matmul(power_spect, self.melbank)
def forward(self, x):
    """Multi-scale frequency-domain decomposition of ``x`` (steerable-pyramid style).

    The one-sided 2-D FFT of ``x`` is masked by ``self.hl0`` into two parts
    (``h0f``, ``l0f``; presumably high-pass and low-pass — confirm). The low part
    is filtered by ``self.b[n]`` into band outputs and, between levels, by
    ``self.l[n]`` after ``central_crop``. If ``self.hilb`` is set, a
    ``self.s[n]``-weighted counterpart is added to each band. Every output is
    inverse-FFT'd, optionally reshaped, and scaled by ``2**(n-1)`` for n > 0.

    Returns:
        A list of real tensors, one per output level.
    """
    fftfull = torch.rfft(x,2)
    xreal = fftfull[... , 0]
    xim = fftfull[... ,1]
    # Real and imaginary parts go on a new dim 1 (size 2); an extra axis is added at -3.
    x = torch.cat((xreal.unsqueeze(1), xim.unsqueeze(1)), 1 ).unsqueeze( -3 )
    x = torch.index_select( x, -2, self.indF[0] )
    x = self.hl0 * x
    # Entries 0 and 1 along dim -3 after applying self.hl0.
    h0f = x.select( -3, 0 ).unsqueeze( -3 )
    l0f = x.select( -3, 1 ).unsqueeze( -3 )
    lf = l0f
    output = []
    for n in range( self.N ):
        bf = self.b[n] * lf
        lf = self.l[n] * central_crop( lf )
        if self.hilb:
            # Swap real/imag along dim 1 with a sign flip on the new imaginary
            # part (multiplication by -i), then weight by self.s[n].
            hbf = self.s[n] * torch.cat( (bf.narrow(1,1,1), -bf.narrow(1,0,1)), 1 )
            bf = torch.cat( ( bf , hbf ), -3 )
        if self.includeHF and n == 0:
            bf = torch.cat( ( h0f, bf ), -3 )
        output.append( bf )
    output.append( lf )
    for n in range( len( output ) ):
        output[n] = torch.index_select( output[n], -2, self.indB[n] )
        # A one-sided spectrum of K bins is treated as an even width of 2*(K-1).
        sig_size = [output[n].shape[-2],(output[n].shape[-1]-1)*2]
        # Move real/imag from dim 1 to the trailing axis expected by torch.irfft.
        output[n] = torch.stack((output[n].select(1,0), output[n].select(1,1)),-1)
        output[n] = torch.irfft( output[n], 2, signal_sizes = sig_size)
    if self.includeHF:
        # Split the h0f slice (entry 0 along dim -3) of the first output into its own entry.
        output.insert( 0, output[0].narrow( -3, 0, 1 ) )
        output[1] = output[1].narrow( -3, 1, output[1].size(-3)-1 )
    for n in range( len( output ) ):
        if self.hilb:
            # Band outputs (not the HF entry, not the final low-pass entry) are
            # split into two halves along dim -3 and re-stacked on a new dim 1.
            if ((not self.includeHF) or 0 < n) and n < len(output)-1:
                nfeat = output[n].size(-3)//2
                o1 = output[n].narrow( -3, 0, nfeat ).unsqueeze(1)
                o2 = -output[n].narrow( -3, nfeat, nfeat ).unsqueeze(1)
                output[n] = torch.cat( (o2, o1), 1 )
            else:
                # NOTE(review): else-nesting reconstructed from collapsed source;
                # assumed to pair with the inner `if` — confirm.
                output[n] = output[n].unsqueeze(1)
    for n in range( len( output ) ):
        if n>0:
            output[n] = output[n]*(2**(n-1))
    return output
def p2o(psf, shape):
    """Convert a point-spread function to an optical transfer function.

    The PSF is zero-padded to ``shape`` and circularly shifted so its center
    lands at the origin. It is then 2-D FFT'd (two-sided), and imaginary parts
    below a round-off threshold are zeroed.

    Args:
        psf: (N, C, h, w) kernel tensor.
        shape: target spatial size (H, W), H >= h and W >= w.

    Returns:
        (N, C, H, W, 2) tensor: OTF with a trailing (real, imag) axis.
    """
    otf = psf.new_zeros(tuple(psf.shape[:-2]) + tuple(shape))
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    if hasattr(torch, "rfft"):
        otf = torch.rfft(otf, 2, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; keep the (real, imag) layout
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # Round-off bound ~ sum(n * log2(n)) * eps over the PSF dims.
    # (Previously: psf.size is a method, torch.log2 rejects a tuple, and
    # torch.asb does not exist, so this line always raised.)
    dims = torch.tensor(psf.shape, dtype=otf.dtype, device=otf.device)
    n_ops = torch.sum(dims * torch.log2(dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = 0
    return otf
def forward(self, x):
    """Filter ``x`` along dim 2 in the Fourier domain, with zero padding.

    Dim 2 is zero-padded to ``max(64, 2**ceil(log2(2 * x.shape[2])))``. A filter
    is built with ``self._get_fourier_filter`` / ``self.create_filter`` and
    multiplied with the full (two-sided) FFT of the padded data along dim 2.
    The result is inverse-transformed and cropped back to the original length.
    (Presumably the filtering step of filtered back-projection — confirm.)
    """
    n_in = x.shape[2]
    padded_len = \
        max(64, int(2 ** (2 * torch.tensor(n_in)).float().log2().ceil()))
    padded = F.pad(x, (0,0,0,padded_len - n_in))

    base_filter = self._get_fourier_filter(padded.shape[2]).to(x.device)
    freq_filter = self.create_filter(base_filter).unsqueeze(-2)

    # torch.rfft transforms the last dim, so dims 2 and 3 are swapped around it.
    spectrum = torch.rfft(padded.transpose(2,3), 1, onesided=False).transpose(2,3)
    filtered = spectrum * freq_filter
    restored = torch.irfft(filtered.transpose(2,3), 1, onesided=False).transpose(2,3)
    return restored[:,:,:n_in,:]