def circular_convolution_fft(keys, values, normalized=True, conj=False, cuda=False):
    '''
    Circularly convolve ``keys`` with ``values`` through the FFT.

    Each row of ``keys`` / ``values`` is read as a complex vector whose first
    half holds the real parts and whose second half holds the imaginary
    parts. No zero padding is performed, so the result is the *circular*
    (not linear) convolution of length ``feature_size // 2``; a linear
    convolution would require padding both inputs to at least N + L - 1.

    :param keys: 2-D tensor [batch, feature], feature size even
    :param values: 2-D tensor with the same shape as ``keys``
    :param normalized: orthonormal FFT scaling (norm="ortho") when True,
        otherwise torch's default 1/n scaling on the inverse
    :param conj: replace ``keys`` by ``Complex(keys).conj().unstack()`` first
    :param cuda: unused; kept for backward compatibility
    :return: ``Complex(real, imag).unstack()`` of the result, viewed as
        [batch, -1]
    '''
    assert values.dim() == keys.dim() == 2, "only 2 dims supported"
    assert values.size(-1) % 2 == keys.size(-1) % 2 == 0, "need last dim to be divisible by 2"
    assert values.size(-1) == keys.size(-1), "keys and values need the same feature size"
    batch_size = keys.size(0)

    # conj transpose
    keys = Complex(keys).conj().unstack() if conj else keys

    # [batch, real | imag] -> native complex [batch, feature // 2]
    half = keys.size(-1) // 2
    keys_c = torch.complex(keys[:, :half], keys[:, half:])
    values_c = torch.complex(values[:, :half], values[:, half:])

    norm = "ortho" if normalized else "backward"
    kf = torch.fft.fft(keys_c, norm=norm)
    vf = torch.fft.fft(values_c, norm=norm)
    # With native complex tensors ``*`` is a true complex product, which the
    # convolution theorem requires (the old real/imag layout multiplied the
    # components independently).
    kvif = torch.fft.ifft(kf * vf, norm=norm)
    return Complex(kvif.real, kvif.imag).unstack().view(batch_size, -1)
def so3_ifft(x, for_grad=False, b_out=None):
    '''
    Inverse SO(3) Fourier transform.

    :param x: [l * m * n, ..., complex] spectrum; the leading size must equal
        b_in * (4 * b_in**2 - 1) // 3 for an integer bandwidth b_in
    :param for_grad: forwarded as ``weighted`` to ``_setup_wigner``
    :param b_out: output bandwidth (defaults to b_in)
    :return: [..., beta, alpha, gamma, complex] with shape
        (..., 2 b_out, 2 b_out, 2 b_out, 2)
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round((3 / 4 * nspec)**(1 / 3))
    assert nspec == b_in * (4 * b_in**2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        output.fill_(0)
        for l in range(min(b_in, b_out)):
            start = l * (4 * l**2 - 1) // 3
            s = slice(start, start + (2 * l + 1)**2)
            out = torch.einsum("mnzc,bmn->zbmnc",
                               x[s].view(2 * l + 1, 2 * l + 1, -1, 2),
                               wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1))
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # scatter the (m, n) block of degree l into the four corners
            # (non-negative / negative frequencies) of the output grid
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l:l + l1 + 1, l:l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1:l, l:l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l:l + l1 + 1, l - l1:l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1:l, l - l1:l]

    # Inverse 2-D FFT over the (m, n) axes -> (alpha, gamma); multiplying by
    # the grid size undoes torch's 1/N scaling of the inverse transform.
    n_side = output.size(-2)
    spectrum = torch.complex(output[..., 0], output[..., 1])
    output = torch.view_as_real(torch.fft.ifft2(spectrum)) * n_side**2  # [batch, beta, alpha, gamma, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2)
    return output
def backward(ctx, grad_output):
    """Backward pass: gradient of the loss with respect to the phase mask.

    Reads the tensors stored on ``ctx`` by the forward pass and returns the
    gradient for the first forward input; the other three returned slots
    are ``None``.

    :param ctx: autograd context carrying ``device``, ``fft_res``,
        ``phasemask_term``, ``phase_emitter``, ``normfactor``, ``Nbatch``,
        ``Nemitters``, ``H`` and ``W``
    :param grad_output: gradient w.r.t. the single-emitter images, indexed
        as [batch, emitter, H, W]
    :return: ``(grad_phasemask, None, None, None)``
    """
    # extract saved tensors for gradient update
    device, fft_res, phasemask_term, phase_emitter = ctx.device, ctx.fft_res, ctx.phasemask_term, ctx.phase_emitter
    normfactor, Nbatch, Nemitters, H, W = ctx.normfactor, ctx.Nbatch, ctx.Nemitters, ctx.H, ctx.W

    # Work on a private, graph-detached copy: grad_output may be shared with
    # other consumers of the forward result and must not be scaled in place
    # (the original wrote through ``grad_output.data``).
    grad_input = grad_output.detach().clone()
    # depth-wise normalization factor
    for i in range(Nbatch):
        for j in range(Nemitters):
            grad_input[i, j, :, :] = grad_input[i, j, :, :] * normfactor[i, j]

    # gradient of abs squared: 2 * g * (re, im)
    grad_abs_square = torch.zeros((Nbatch, Nemitters, H, W, 2), device=device)
    grad_abs_square[..., 0] = 2 * grad_input * fft_res[..., 0]
    grad_abs_square[..., 1] = 2 * grad_input * fft_res[..., 1]

    # centered orthonormal inverse Fourier transform on the H, W dims
    shifted = batch_ifftshift2d(grad_abs_square)
    spatial = torch.fft.ifft2(torch.complex(shifted[..., 0], shifted[..., 1]), norm="ortho")
    grad_fft = batch_fftshift2d(torch.view_as_real(spatial))

    # gradient w.r.t phase mask phase_term
    grad_phasemask_term = torch.zeros((Nbatch, Nemitters, H, W, 2), device=device)
    grad_phasemask_term[..., 0] = grad_fft[..., 0] * phase_emitter[..., 0] + grad_fft[..., 1] * phase_emitter[..., 1]
    grad_phasemask_term[..., 1] = -grad_fft[..., 0] * phase_emitter[..., 1] + grad_fft[..., 1] * phase_emitter[..., 0]

    # gradient w.r.t the phasemask 4D
    grad_phasemask4D = -grad_phasemask_term[..., 0] * phasemask_term[..., 1] + grad_phasemask_term[..., 1] * phasemask_term[..., 0]

    # sum over batch and emitters to get the final gradient
    grad_phasemask = grad_phasemask4D.sum(0).sum(0)
    return grad_phasemask, None, None, None
def forward_operator_from_real(x, mask):
    """
    Forward operator for real images: keep the k-space samples selected by
    ``mask`` and transform back.

    :param x: real input image; the FFT runs over its last three dimensions
        (presumably a [D, H, W] volume -- TODO confirm)
    :param mask: mask of radial lines (numpy array or tensor broadcastable
        to the k-space grid); moved to ``x``'s device
    :return: x_new, the complex result in [..., 2] (real, imag) layout
    """
    scale = x.shape[1]
    dims = (-3, -2, -1)
    # Build the mask once, on the input's device (the original forced .cuda()).
    mask_t = torch.as_tensor(mask, device=x.device).float()
    # Dividing and multiplying by x.shape[1] cancels; kept to match the
    # original numerics.
    kspace = torch.fft.fftn(x, dim=dims) / scale
    kspace = kspace * mask_t
    x_new = torch.fft.ifftn(kspace, dim=dims) * scale
    return torch.view_as_real(x_new)
def forward_operator(x, mask):
    """
    Forward operator for complex images: keep the k-space samples selected
    by ``mask`` and transform back.

    :param x: complex input image in [..., 2] (real, imag) layout; the FFT
        runs over the three dimensions before the last (presumably
        [D, H, W, 2] -- TODO confirm)
    :param mask: mask of radial lines (numpy array or tensor broadcastable
        to the k-space grid); moved to ``x``'s device
    :return: x_new, the complex result in [..., 2] (real, imag) layout
    """
    scale = x.shape[1]
    dims = (-3, -2, -1)
    # Build the mask once, on the input's device (the original forced .cuda()).
    mask_t = torch.as_tensor(mask, device=x.device).float()
    image = torch.complex(x[..., 0], x[..., 1])
    # Dividing and multiplying by x.shape[1] cancels; kept to match the
    # original numerics.
    kspace = torch.fft.fftn(image, dim=dims) / scale
    kspace = kspace * mask_t
    x_new = torch.fft.ifftn(kspace, dim=dims) * scale
    return torch.view_as_real(x_new)
def ifft2(tensor_re, tensor_im, shift=False):
    """Applies a 2D ifft to the complex tensor represented by tensor_re and _im.

    The inverse transform is orthonormal (norm="ortho") over the last two
    dimensions. The inputs are stacked on dim 4, so they must be 4-D.

    :param shift: apply ``ifftshift`` to the stacked [..., 2] tensor first
    :return: ``(real, imag)`` tensors shaped like the inputs
    """
    stacked = torch.stack((tensor_re, tensor_im), 4)
    if shift:
        stacked = ifftshift(stacked)
    spectrum = torch.complex(stacked[..., 0], stacked[..., 1])
    spatial = torch.fft.ifft2(spectrum, norm="ortho")
    return spatial.real, spatial.imag
def ifft(var_real, var_imag, axis=-1, backend='autograd', normalize=False):
    """1-D inverse FFT of ``var_real + 1j * var_imag`` along ``axis``.

    :param backend: ``'autograd'`` (uses ``anp``) or ``'pytorch'`` (uses ``tc``)
    :param normalize: orthonormal scaling when True, else 1/n on the inverse
    :return: ``(real, imag)`` of the result
    :raises ValueError: for an unknown ``backend`` (previously ``None`` was
        returned silently)
    """
    if backend == 'autograd':
        var = var_real + 1j * var_imag
        norm = None if not normalize else 'ortho'
        var = anp.fft.ifft(var, axis=axis, norm=norm)
        return anp.real(var), anp.imag(var)
    elif backend == 'pytorch':
        # ``axis`` is honoured here; the old torch.ifft call always
        # transformed the last dimension.
        var = tc.fft.ifft(tc.complex(var_real, var_imag), dim=axis,
                          norm='ortho' if normalize else 'backward')
        return var.real, var.imag
    raise ValueError("unknown backend: {}".format(backend))
def cropMeasurements(self, measurements):
    """Crop every measurement in the centred Fourier domain.

    Each ``measurements[i]`` ((real, imag) layout, presumably [H, W, 2]) is
    2-D FFT'd, shifted with ``utility.fftshift2``, cropped to
    ``self.crops`` = (row_start, row_stop, col_start, col_stop), shifted
    back, inverse-FFT'd and scaled by ``1 / self.scaling``.

    :return: tensor [N, Np_meas[0], Np_meas[1], 2] (default dtype, CPU)
    """
    def _fft2(t):
        # (..., H, W, 2) real/imag layout -> same layout
        return torch.view_as_real(torch.fft.fft2(torch.complex(t[..., 0], t[..., 1])))

    def _ifft2(t):
        return torch.view_as_real(torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])))

    n_images = measurements.shape[0]
    measurements_cropped = torch.zeros(n_images, self.Np_meas[0], self.Np_meas[1], 2)
    row0, row1, col0, col1 = self.crops[0], self.crops[1], self.crops[2], self.crops[3]
    for img_idx in range(n_images):
        fmeas = utility.fftshift2(_fft2(measurements[img_idx, ...]))
        tmp = fmeas[row0:row1, col0:col1, :]
        measurements_cropped[img_idx, ...] = (1 / self.scaling) * _ifft2(utility.ifftshift2(tmp))
    return measurements_cropped
def fft(x):
    """
    Layer that applies a 2-D (inverse) discrete Fourier transform.

    ``x`` is [batch, 2 * S*S]: the first half holds the real parts and the
    second half the imaginary parts of an S x S image. Note that despite
    its name this applies ``ifft2`` (with torch's default 1/(S*S) scaling),
    exactly as the original did.

    :return: tensor [batch, 2, S, S]; channel 0/1 = real/imag part
    :raises ValueError: if the half-width is not a perfect square (the old
        floor-based ``int(sqrt(...))`` could silently mis-reshape)
    """
    import math

    img_size = x.size(1) // 2
    side = math.isqrt(img_size)
    if side * side != img_size:
        raise ValueError("half of x.size(1) must be a perfect square, got {}".format(img_size))
    # sort the incoming tensor in real and imaginary part
    arr_real = x[:, :img_size].reshape(-1, side, side)
    arr_imag = x[:, img_size:].reshape(-1, side, side)
    spatial = torch.fft.ifft2(torch.complex(arr_real, arr_imag))
    # [batch, S, S, 2] -> [batch, 2, S, S]: moves the real/imag axis to dim 1
    return torch.view_as_real(spatial).permute(0, 3, 1, 2)
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    Inverse S^2 Fourier transform.

    :param x: [l * m, ..., complex] spectrum; the leading size must be b_in**2
    :param for_grad: forwarded as ``weighted`` to ``_setup_wigner``
    :param b_out: output bandwidth (defaults to b_in; must be >= b_in)
    :return: [..., beta, alpha, complex] with shape (..., 2 b_out, 2 b_out, 2)
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec**0.5)
    assert nspec == b_in**2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", x[s], wigner[:, s])
            # non-negative m go to the front, negative m wrap to the back
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # Inverse FFT over the m axis -> alpha; multiplying by the axis length
    # undoes torch's 1/n scaling of the inverse transform.
    n_alpha = output.size(-2)
    spectrum = torch.complex(output[..., 0], output[..., 1])
    output = torch.view_as_real(torch.fft.ifft(spectrum, dim=-1)) * n_alpha  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
def test_circulant(self):
    """Circulant multiply agrees between a dense matrix, numpy/torch FFTs and torch_butterfly."""
    batch_size = 10
    n = 13
    for is_complex in [False, True]:
        dtype = torch.float32 if not is_complex else torch.complex64
        col = torch.randn(n, dtype=dtype)
        C = la.circulant(col.numpy())
        inp = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(inp.detach().numpy() @ C.T)
        out_np = torch.tensor(np.fft.ifft(np.fft.fft(inp.numpy()) * np.fft.fft(col.numpy())), dtype=dtype)
        self.assertTrue(torch.allclose(out_torch, out_np, self.rtol, self.atol))
        # Just to show how to implement circulant multiply with FFT
        if is_complex:
            out_fft = torch.fft.ifft(torch.fft.fft(inp) * torch.fft.fft(col))
            self.assertTrue(torch.allclose(out_torch, out_fft, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(col, transposed=False, separate_diagonal=separate_diagonal)
            out = b(inp)
            self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))

        row = torch.randn(n, dtype=dtype)
        C = la.circulant(row.numpy()).T
        inp = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(inp.detach().numpy() @ C.T)
        # row is the reverse of col, except the 0-th element stays put
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        row_f = np.fft.fft(row.numpy())
        row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
        out_np = torch.tensor(np.fft.ifft(np.fft.fft(inp.numpy()) * row_f_reversed), dtype=dtype)
        self.assertTrue(torch.allclose(out_torch, out_np, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(row, transposed=True, separate_diagonal=separate_diagonal)
            out = b(inp)
            self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def forward(self, x, mask):
    """Filter every 2-D slice of ``x`` in the Fourier domain with ``mask``.

    :param x: real tensor [d0, d1, H, W]
    :param mask: multiplier broadcastable against the [d0 * d1, H, W, 2]
        (real, imag) spectrum layout
    :return: real part of the filtered result, shaped like ``x``
    """
    d0, d1, height, width = x.shape
    slices = x.reshape(-1, height, width)
    # Apply the mask in (real, imag) layout so a mask with a trailing
    # dimension of 2 keeps its per-component meaning.
    spectrum = torch.view_as_real(torch.fft.fft2(slices)) * mask
    filtered = torch.fft.ifft2(torch.complex(spectrum[..., 0], spectrum[..., 1]))
    return filtered.real.reshape(d0, d1, height, width)
def ifft2(data: torch.Tensor, dim: Tuple[str, ...] = ('height', 'width'), centered: bool = True) -> torch.Tensor:
    """
    Apply centered two-dimensional Inverse Fast Fourier Transform

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor in (..., 2) real/imag layout.
    dim : tuple, list or int
        Dimensions used for the (i)fftshift. The transform itself always runs
        over the two dimensions before the trailing real/imag dimension.
    centered : bool
        Whether to apply a centered ifft (center of kspace is in the center versus in the corners).
        For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false.

    Returns
    -------
    torch.Tensor: the ifft of the output (orthonormal scaling).
    """
    def _ortho_ifft2(t):
        # (..., H, W, 2) real/imag layout -> same layout; view_as_complex keeps
        # half precision when the dtype check below allows it.
        spatial = torch.fft.ifft2(torch.view_as_complex(t.contiguous()), norm="ortho")
        return torch.view_as_real(spatial)

    assert_complex(data)

    if centered:
        data = ifftshift(data, dim=dim)

    names = data.names
    # TODO: Fix when ifft supports named tensors
    # Verify whether half precision and if ifft is possible in this shape. Else do a typecast.
    if verify_fft_dtype_possible(data, dim):
        data = _ortho_ifft2(data.rename(None))
    else:
        data = _ortho_ifft2(data.rename(None).float()).type(data.type())

    if any(names):
        data = data.refine_names(*names)

    if centered:
        data = fftshift(data, dim=dim)
    return data
def forward(self, x_under, mask):
    """Unrolled reconstruction from undersampled k-space.

    Starts from the zero-filled inverse FFT of ``x_under``. Each entry of
    ``self.layer`` then predicts a residual, followed by the data-consistency
    step ``self.dc(mask, x_res, x_under)``.

    :param x_under: undersampled k-space [batch, 2, H, W]; channel 0/1 =
        real/imag
    :param mask: sampling mask, forwarded to ``self.dc``
    :return: list of ``sigtoimage`` results: the zero-filled image first,
        then one per layer
    """
    kspace = torch.complex(x_under[:, 0], x_under[:, 1])
    # [batch, H, W, 2] -> [batch, 2, H, W]
    x_zf = torch.view_as_real(torch.fft.ifft2(kspace)).permute(0, 3, 1, 2)
    x_rec_dc = x_zf
    recimg = [sigtoimage(x_zf)]
    for layer in self.layer:
        x_res = x_rec_dc + layer(x_rec_dc)
        x_rec_dc = self.dc(mask, x_res, x_under)
        recimg.append(sigtoimage(x_rec_dc))
    return recimg
def ifft2(data):
    """
    Apply an orthonormal 2-dimensional Inverse Fast Fourier Transform.

    No ifftshift/fftshift is applied, so this is *not* a centered transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.

    Returns:
        torch.Tensor: The IFFT of the input, in the same (..., 2) real/imag layout.
    """
    assert data.size(-1) == 2
    spectrum = torch.complex(data[..., 0], data[..., 1])
    return torch.view_as_real(torch.fft.ifft2(spectrum, norm="ortho"))
def ifft(x):
    '''Centered, orthonormal 2-D inverse FFT of a 2-channel (real, imag) image.

    x.size()=[1, 2, h, w]; the result has the same layout and is moved to the
    module-level ``device``.
    '''
    # change to [n, 1, h, w, 2]
    x = complex_split(x)
    # center_to_topleft
    x = cen_cor(x)
    # ifft over the two dimensions before the real/imag axis
    spatial = torch.fft.ifft2(torch.complex(x[..., 0], x[..., 1]), norm="ortho")
    x = torch.view_as_real(spatial)
    # topleft to center
    x = cor_cen(x)
    # merge back to [1, 2, h, w]
    x = complex_merge(x)
    return x.to(device)
def imgFromSubF_pytorch(subF, returnComplex=False):
    """Image from a (sub-)spectrum via an orthonormal 2-D inverse FFT.

    :param subF: spectrum in (real, imag) layout, [B, H, W, 2] or
        [B, D, H, W, 2]; the transform runs over the two dimensions before
        the last
    :param returnComplex: return the complex image with its real/imag axis
        moved to dim 1 ([B, 2, ...]); otherwise return its magnitude
        [B, 1, ...]
    """
    spectrum = torch.complex(subF[..., 0], subF[..., 1])
    subIm = torch.view_as_real(torch.fft.ifft2(spectrum, norm="ortho"))
    # [B, ..., 2] -> [B, 2, ...]; identical to the original permutes for 4-D/5-D
    subIm = subIm.movedim(-1, 1)
    if returnComplex:
        return subIm
    return torch.sqrt(subIm[:, 0:1] * subIm[:, 0:1] + subIm[:, 1:2] * subIm[:, 1:2])
def generateMultiMeas(self, field, device="cpu"):
    """Simulate the multiplexed intensity measurements of ``field``.

    The field is multiplied by ``self.planewaves`` (``mul_c``), transformed
    with a 2-D FFT, multiplied by ``self.P`` and transformed back. The result
    goes through ``abs2_c`` (presumably |.|^2) and is mixed over the first
    axis with ``self.C``. Tensors are moved to ``self.device``; the
    ``device`` argument is unused.
    """
    def _fft2(t):
        # (..., H, W, 2) real/imag layout -> same layout
        return torch.view_as_real(torch.fft.fft2(torch.complex(t[..., 0], t[..., 1])))

    def _ifft2(t):
        return torch.view_as_real(torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])))

    self.pupils = self.pupils.to(self.device)
    self.planewaves = self.planewaves.to(self.device)
    self.P = self.P.to(self.device)
    output = mul_c(self.planewaves, field)
    output = mul_c(self.P, _fft2(output))
    intensity = abs2_c(_ifft2(output))
    # [N, H, W] -> [H, W, N] @ C^T -> [H, W, M] -> [M, H, W]
    return torch.matmul(intensity.permute(1, 2, 0), self.C.permute(1, 0)).permute(2, 0, 1)
def forward(self, x_hat):
    """Scattering-style energy coefficients of a 2-D spectrum.

    :param x_hat: spectrum n * n * 2 in (real, imag) layout
    :return: 1-D tensor: the low-pass energy of ``x_hat``, then for each
        scale i the K first-order energies, followed (for i < J - 1) by the
        flattened second-order energies
    """
    const = 1 / (2 * pi)**2
    phi_power = self.phi_hat_real**2 + self.phi_hat_imag**2  # loop invariant
    x_re, x_im = x_hat[:, :, 0], x_hat[:, :, 1]

    s = const * torch.mean(torch.mean((x_re**2 + x_im**2) * phi_power, 0), 0)
    s = s.unsqueeze(0)
    J = self.psi_hat_real.size()[1]
    for i in range(J):
        # first layer only uses gabor wavelets, which have K angles
        psi_re = self.psi_hat_real[:self.K, i, :, :]
        psi_im = self.psi_hat_imag[:self.K, i, :, :]
        temp_real = x_re * psi_re - x_im * psi_im
        temp_imag = x_re * psi_im + x_im * psi_re
        spatial = torch.fft.ifft2(torch.complex(temp_real, temp_imag))  # K * n * n
        modulus = torch.sqrt(spatial.real**2 + spatial.imag**2 + 1e-8)
        spec = torch.fft.fft2(modulus)  # K * n * n
        power = spec.real**2 + spec.imag**2
        a = const * torch.mean(torch.mean(power * phi_power, 2), 1)
        s = torch.cat((s, a), 0)
        if i < J - 1:
            temp3 = power.unsqueeze(1).unsqueeze(2)
            if self.second_all:
                temp4 = (self.psi_hat_real**2 + self.psi_hat_imag**2).unsqueeze(0)
            else:
                temp4 = (self.psi_hat_real[:, (i + 1):J, :, :]**2 +
                         self.psi_hat_imag[:, (i + 1):J, :, :]**2).unsqueeze(0)
            b = const * torch.mean(torch.mean(temp3 * temp4, 4), 3)
            s = torch.cat((s, b.flatten()), 0)
    return s
def grad(self, field_est, device='cpu'):
    """Gradient of the amplitude-based data-fidelity term w.r.t. ``field_est``.

    Compares sqrt of the simulated measurements (``self.generateMultiMeas``)
    with sqrt of the mixed measured data, then back-propagates the weighted
    residual through the forward model for each of ``self.Nmeas``
    measurements. Tensors are moved to ``self.device``.

    :return: gradient tensor shaped like ``field_est``
    """
    def _fft2(t):
        # (..., H, W, 2) real/imag layout -> same layout
        return torch.view_as_real(torch.fft.fft2(torch.complex(t[..., 0], t[..., 1])))

    def _ifft2(t):
        return torch.view_as_real(torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])))

    self.measurements = self.measurements.to(self.device)
    self.pupils = self.pupils.to(self.device)
    self.planewaves = self.planewaves.to(self.device)
    self.P = self.P.to(self.device)
    multiMeas = torch.matmul(self.measurements.permute(1, 2, 0), self.C.permute(1, 0)).permute(2, 0, 1)
    multiMeas = torch.abs(multiMeas)

    # simulate current estimate of measurements
    y = self.generateMultiMeas(field_est, device=device)

    # compute residual
    sqrty = torch.sqrt(y + EPS)
    residual = sqrty - torch.sqrt(multiMeas + EPS)
    Ajx = residual / (sqrty + 1e-10)
    Ajx_c = torch.stack((Ajx, torch.zeros_like(Ajx)), dim=Ajx.dim())

    # compute gradient
    output = mul_c(self.planewaves, field_est)
    output = _ifft2(mul_c(self.P, _fft2(output)))
    planewaves_conj = conj(self.planewaves)  # loop invariant
    g = field_est * 0.
    for meas_index in range(self.Nmeas):
        output2 = mul_c(Ajx_c[meas_index, ...], output)
        output2 = mul_c(planewaves_conj, output2)
        output2 = _ifft2(mul_c(self.pupils, _fft2(output2)))
        g_tmp = torch.matmul(output2.permute(1, 2, 3, 0), self.C[meas_index, :])
        g = g + g_tmp
    return g
def loss_QSMnet(outputs, QSMs, Masks, D):
    """QSMnet training loss: L1 on the susceptibility maps, L1 on their dipole-convolved fields, plus gradient terms.

    :param outputs: network output; only channel 0 is used
        (presumably [B, C, X, Y, Z])
    :param QSMs: target susceptibility maps, shaped like ``outputs[:, 0:1]``
    :param Masks: mask applied to every compared quantity
    :param D: numpy dipole kernel in k-space (presumably [X, Y, Z])
    :return: errl1 + errModel + 0.1 * (image-gradient + field-gradient losses)
    """
    # l1 loss
    loss = lossL1()
    outputs = outputs[:, 0:1, ...]
    # ``Tensor.device`` works for CPU tensors too; ``get_device()`` returns -1
    # there, and ``.to(-1)`` fails.
    device = outputs.device

    D = np.repeat(D[np.newaxis, np.newaxis, ..., np.newaxis], outputs.size()[0], axis=0)
    D_cplx = np.concatenate((D, np.zeros(D.shape)), axis=-1)
    D_cplx = torch.tensor(D_cplx, device=device).float()

    def _dipole_field(volume):
        # 3-D FFT over the last three dims, multiply by the dipole kernel,
        # inverse FFT, keep the real part
        spectrum = torch.view_as_real(torch.fft.fftn(volume, dim=(-3, -2, -1)))
        filtered = cplx_mlpy(spectrum, D_cplx)
        return torch.fft.ifftn(torch.complex(filtered[..., 0], filtered[..., 1]), dim=(-3, -2, -1)).real

    def _gradient_loss(a, b):
        return (loss(abs(dxp(a)) * Masks, abs(dxp(b)) * Masks) +
                loss(abs(dyp(a)) * Masks, abs(dyp(b)) * Masks) +
                loss(abs(dzp(a)) * Masks, abs(dzp(b)) * Masks))

    RDFs_outputs = _dipole_field(outputs)
    RDFs_QSMs = _dipole_field(QSMs)

    errl1 = loss(outputs * Masks, QSMs * Masks)
    errModel = loss(RDFs_outputs * Masks, RDFs_QSMs * Masks)
    errGrad = _gradient_loss(outputs, QSMs) + _gradient_loss(RDFs_outputs, RDFs_QSMs)
    return errl1 + errModel + 0.1 * errGrad
def Fourier_based_Corruption(dataset, imgsize, position):
    """Add a single-frequency Fourier basis image to every sample of ``dataset``.

    The basis image is the real part of the inverse 2-D DFT of a unit
    impulse at ``position``, normalised to unit L2 norm. For each of the first
    three channels of every image (converted with ``transforms.ToTensor()``),
    the basis image is added scaled by 0.1 times that channel's L2 norm.

    :param dataset: indexable dataset whose items are ``(image, ...)``
    :param imgsize: side length N of the square images
    :param position: ``(i, j)`` frequency index of the impulse
    :return: CPU tensor [len(dataset), 3, N, N]
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    N = imgsize
    i, j = position[0], position[1]
    print("position({},{})".format(i, j))
    loader = transforms.Compose([transforms.ToTensor()])

    # Only the real component of the impulse is set (the original assigned
    # the same real entry twice).
    impulse = torch.zeros((N, N), dtype=torch.complex64, device=device)
    impulse[i, j] = 1.0
    Uij = torch.fft.ifft2(impulse).real
    Uij = Uij / torch.norm(Uij, p=2)

    samples_size = len(dataset)
    result = torch.zeros((samples_size, 3, N, N), device=device)
    for k in range(samples_size):
        img_array = loader(dataset[k][0]).to(device)
        for channel in range(3):
            img_one_channel = img_array[channel, :, :]
            L2norm = torch.norm(img_one_channel, p=2) * 0.1
            result[k, channel] = img_one_channel + Uij * L2norm
    return result.cpu()
def forward(self, x, y):
    """Propagate ``x`` with ``self.prop`` in the Fourier domain; return half the mean absolute error of the real part against ``y``.

    :param x: field with real/imag on dim 1 and a singleton dim 2
        (presumably [B, 2, 1, H, W] -- TODO confirm)
    :param y: same layout as ``x``; only its real channel is compared
    :return: scalar loss
    """
    def _ortho_fft3(t, inverse=False):
        # NOTE(review): like the original signal_ndim=3 call on [B, H, W, 2],
        # this also transforms the batch axis. With orthonormal scaling that
        # axis cancels as long as self.prop does not vary along it.
        c = torch.complex(t[..., 0], t[..., 1])
        transform = torch.fft.ifftn if inverse else torch.fft.fftn
        return torch.view_as_real(transform(c, dim=(-3, -2, -1), norm="ortho"))

    x = x.squeeze(2).permute([0, 2, 3, 1])
    y = y.squeeze(2).permute([0, 2, 3, 1])
    cEs = self.batch_fftshift2d(_ortho_fft3(x))
    cEsp = self.complex_mult(cEs, self.prop)
    S = _ortho_fft3(self.batch_ifftshift2d(cEsp), inverse=True)
    Se = S[:, :, :, 0]
    return torch.mean(torch.abs(Se - y[:, :, :, 0])) / 2
def test_ifft_unitary(self):
    """The butterfly ``ifft_unitary`` matches torch's orthonormal inverse FFT."""
    batch_size = 10
    n = 16
    inp = torch.randn(batch_size, n, dtype=torch.complex64)
    # reference: orthonormal inverse FFT (the old torch.ifft(normalized=True))
    out_torch = torch.fft.ifft(inp, norm="ortho")
    for br_first in [True, False]:
        b = torch_butterfly.special.ifft_unitary(n, br_first=br_first)
        out = b(inp)
        self.assertTrue(torch.allclose(out, out_torch, self.rtol, self.atol))
def ifft(x):
    """Orthonormal 1-D inverse FFT of a batch of complex vectors.

    :param x: pair ``(re, im)`` of tensors, each [mbs, n]
    :return: ``(out_re, out_im)``, each [mbs, n]
    """
    re, im = x[0], x[1]
    result = torch.fft.ifft(torch.complex(re, im), dim=1, norm="ortho")
    return result.real, result.imag
def reverse(self, z, device='cpu'):
    """Invert ``z = x + self.step(x)``-style updates (no autograd).

    With ``self.fullInvFlag`` the inverse is solved in closed form in the
    Fourier domain:
    ``x = Re F^-1[ F(z - alpha * A^H y) / (1 - alpha * conj(fpsf) * fpsf) ]``
    with ``A^H y = Re F^-1[conj(fpsf) * F(y)]``. Otherwise ``self.T``
    fixed-point iterations ``x <- z - self.step(x)`` are run. ``device`` is
    unused.
    """
    def _fft2_of_real(img):
        # real 2-D image -> (H, W, 2) real/imag spectrum
        return torch.view_as_real(torch.fft.fft2(img))

    def _ifft2_real_part(spec):
        return torch.fft.ifft2(torch.complex(spec[..., 0], spec[..., 1])).real

    with torch.no_grad():
        if not self.fullInvFlag:
            x = z
            for _ in range(self.T):
                x = z - self.step(x)
            return x

        fpsf_conj = conj(self.fpsf)
        AHy = _ifft2_real_part(mul_c(fpsf_conj, _fft2_of_real(self.y)))
        Ft = _fft2_of_real(z - self.alpha * AHy)
        AHA = mul_c(fpsf_conj, self.fpsf)
        I = torch.zeros_like(AHA)
        I[..., 0] = 1
        return _ifft2_real_part(div_c(Ft, I - self.alpha * AHA))
def wiener_filt_torch(Y, R, Np, batch_dim=False):
    """Wiener-style deconvolution of the cross term of ``Y`` by the reference ``R``.

    The shifted inverse FFT of ``complexify(Y)`` is cropped to a cross term
    XR. Its spectrum is multiplied by F(R) and divided by ``|F(R)|^2 + 1/SNR``,
    where ``SNR = Np * ||Y||_2^2 / ||Y||_1^2`` per sample.

    :param Y: input; a leading batch dim is added when ``batch_dim`` is False
    :param R: real reference [n, s1, s2] (or [s1, s2] without batch dim)
    :param Np: scalar factor in the SNR estimate
    :return: deconvolved tensor in (..., 2) real/imag layout
    """
    def _fft2(t):
        # (..., H, W, 2) real/imag layout -> same layout
        return torch.view_as_real(torch.fft.fft2(torch.complex(t[..., 0], t[..., 1])))

    def _ifft2(t):
        return torch.view_as_real(torch.fft.ifft2(torch.complex(t[..., 0], t[..., 1])))

    if not batch_dim:
        Y, R = Y.unsqueeze(0), R.unsqueeze(0)
    Yc = complexify(Y)
    W = _ifft2(Yc)
    n, s1, s2 = R.shape
    S_auto = torch_fft_shift(W, dims=(1, 2))
    XR_cross = S_auto[:, :, :2 * s1]
    _, t1, t2, _ = XR_cross.shape
    R_ = torch.zeros((n, t1, t2, 2), dtype=torch.float, device=W.device)
    R_[:, :s1, :s2, 0] = R
    F_R = _fft2(R_)
    F_R_abs = (F_R**2).sum(-1, keepdim=True)  # |F(R)|^2
    F_SNR = Np * torch.norm(Y, dim=(1, 2))**2 / torch.norm(Y, p=1, dim=(1, 2))**2
    F_SNR = F_SNR.reshape(n, 1, 1, 1)
    X_ = _ifft2(complex_mult(_fft2(XR_cross), F_R) / (F_R_abs + 1 / F_SNR))
    X = X_[:, s1:, s2:]
    if not batch_dim:
        X = X[0]
    return X
def forward(self, x):
    """Sketch-based bilinear pooling (appears to be Tensor Sketch / compact bilinear pooling).

    Each pixel's feature vector is projected with ``self.sparseM[0]`` and
    ``self.sparseM[1]``. The two sketches are circularly convolved via the
    FFT and the results are summed over all pixels of each image, followed by
    ``self._signed_sqrt`` and ``self._l2norm``.

    :param x: feature map [batch, dim, h, w] (after ``conv_dr_block`` when
        ``self.dr`` is set)
    :return: [batch, self.output_dim]
    """
    if self.dr is not None:
        x = self.conv_dr_block(x)
    images_per_step = 1
    batch_size, dim, h, w = x.shape
    pixels = x.permute(0, 2, 3, 1).reshape(-1, dim)  # [batch * h * w, dim]
    sketch_m1 = self.sparseM[0].to(x.device)  # loop invariant
    sketch_m2 = self.sparseM[1].to(x.device)
    y = torch.ones(batch_size, self.output_dim, device=x.device)
    for start in range(0, batch_size, images_per_step):
        stop = min(start + images_per_step, batch_size)
        rows = pixels[start * h * w:stop * h * w]
        f1 = torch.fft.fft(rows.mm(sketch_m1), dim=1)
        f2 = torch.fft.fft(rows.mm(sketch_m2), dim=1)
        conv = torch.fft.ifft(f1 * f2, dim=1).real
        y[start:stop] = conv.view(stop - start, h * w, self.output_dim).sum(dim=1)
    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
def phase_correlation(a, b):
    """Estimate the integer circular shift between image batches by phase correlation.

    Computes ``r = Re F^-1[ F(a) conj(F(b)) / |F(a) conj(F(b))| ]`` and
    returns the location of its peak. The original version used a real
    [..., 2] layout with imag == real: there ``torch.conj`` was a no-op and
    ``*`` / ``abs`` acted element-wise instead of as complex operations, and
    ``/`` produced a float row index.

    :param a: [B, H, W] real images
    :param b: [B, H, W] real reference images
    :return: long tensor [B, 2] of ``(row, col)`` shifts d such that
        ``a ~= torch.roll(b, shifts=d, dims=(1, 2))``
    """
    B, H, W = a.size()
    G_a = torch.fft.fft2(a)
    G_b = torch.fft.fft2(b)
    R = G_a * torch.conj(G_b)
    magnitude = R.abs()
    # guard exact zeros in the cross-power spectrum against 0/0
    R = R / magnitude.clamp_min(torch.finfo(magnitude.dtype).tiny)
    r = torch.fft.ifft2(R).real
    shift = r.reshape(B, -1).argmax(dim=1)
    return torch.stack((shift // W, shift % W), dim=1)
def filterProjections(radon_img, filter_mode, d=1.):
    """Filter each projection (column) of a sinogram in the frequency domain.

    Columns are zero-padded to ``len(H)`` with ``H = designFilter(filter_mode,
    length, d)``, FFT'd along dim 0, multiplied by ``H``, inverse-FFT'd and
    truncated back to the original length (real part only).

    :param radon_img: sinogram [length, n_projections]
    :return: filtered sinogram [length, n_projections]
    """
    length = radon_img.size(0)
    H = designFilter(filter_mode, length, d)
    padded = torch.zeros(len(H), radon_img.size(1))  # zero pad
    padded[0:length] = radon_img
    spectrum = torch.fft.fft(padded, dim=0)
    # frequency domain filtering; H is one value per frequency bin
    filtered = torch.fft.ifft(spectrum * H.unsqueeze(1), dim=0).real
    # Truncate the filtered projection; ``contiguous`` returns a compact copy
    # so the result does not keep the larger padded complex buffer alive.
    return filtered[0:length].contiguous()
def forward(self, x):
    """Data-consistency step in k-space.

    ``x`` is a dict with ``'image'`` [B, 2, H, W] (real/imag channels), and
    ``'k'`` / ``'mask'`` broadcastable against the [B, H, W, 2] (real, imag)
    orthonormal spectrum of the image. With ``self.noise`` the spectrum is
    averaged with ``k`` (weight ``noise``); otherwise masked samples are
    replaced: ``(1 - mask) * spec + k``. ``x['image']`` is overwritten with
    the float32 result and ``x`` is returned.
    """
    def _ortho(transform, t):
        # (..., H, W, 2) real/imag layout -> same layout
        return torch.view_as_real(transform(torch.complex(t[..., 0], t[..., 1]), norm="ortho"))

    image = x['image'].permute(0, 2, 3, 1)  # [B, H, W, 2]
    temp = _ortho(torch.fft.fft2, image)
    if self.noise:
        temp = (temp + self.noise * x['k']) / (1 + self.noise)
    else:
        temp = (1 - x['mask']) * temp + x['k']
    temp = _ortho(torch.fft.ifft2, temp)
    x['image'] = temp.permute(0, 3, 1, 2).float()
    return x