def create_fft_plots(sample, model, epoch):
    """Plot magnitude spectra of an encoded sample and of the first conv filter.

    The code returned by ``model.encode`` is zero-padded to 100 samples and
    its orthonormal FFT magnitude is plotted. The complex filter built from
    ``model.conv1.conv_real`` / ``model.conv1.conv_imag`` weights is
    transformed (unnormalised) and plotted on the same axes. The figure is
    saved under ``../results/images/fft_none/``.

    Uses the ``torch.fft`` module API: the legacy ``torch.fft(x, 1)`` function
    was removed in PyTorch 1.8. Buffers are created on the device of the
    encoded sample instead of unconditionally on CUDA.

    Args:
        sample: Tensor that is flattened to one row and passed to
            ``model.encode``.
        model: Object exposing ``encode`` and ``conv1.conv_real`` /
            ``conv1.conv_imag`` layers with a ``weight`` tensor.
        epoch: Epoch number, used in the title and the file name.
    """
    train_code = model.encode(sample.view(1, -1))[0]

    fig = plt.figure()

    # Zero-pad the code to a fixed length of 100 before transforming it.
    train_code_pad = torch.zeros(100, device=train_code.device)
    train_code_pad[:len(train_code)] = train_code
    code_spectrum = torch.fft.fft(train_code_pad, norm="ortho")
    plt.plot(code_spectrum.abs().detach().cpu().numpy())

    filter_real = model.conv1.conv_real.weight.data.view(-1)
    filter_imag = model.conv1.conv_imag.weight.data.view(-1)
    filter_complex = torch.stack((filter_real, filter_imag), dim=1)
    print(filter_complex.shape)
    filter_spectrum = torch.fft.fft(torch.view_as_complex(filter_complex))
    plt.plot(filter_spectrum.abs().detach().cpu().numpy())

    plt.title('Epoch ' + str(epoch))
    plt.savefig('../results/images/fft_none/fft_%s.png' % (str(epoch).zfill(4)))
    fig.clf()
    plt.close()
def reconstruct_matrix(self, signal, calibration_matrix, eigval1_reciprocal, eigval2_reciprocal):
    """
    Recovers the random matrix.

    Parameters
    ----------
    signal: ComplexTensor,
        Tensor with the signal values
    calibration_matrix: torch.Tensor,
        calibration matrix (the partial one)
    eigval1_reciprocal: ComplexTensor,
        eigenvalues of the first circulant matrix block of the partial
        calibration matrix.
    eigval2_reciprocal: ComplexTensor,
        eigenvalues of the second circulant matrix block of the partial
        calibration matrix.

    Returns
    -------
    reconstructed_A: ComplexTensor,
        reconstructed transmission matrix. If batch size<rows, it is a batch
        of rows.

    Raises
    ------
    ValueError
        If ``self.solver`` is neither "least-square" nor "fft" (previously
        this surfaced as an UnboundLocalError at the return statement).
    """
    def _fft1(pairs, inverse=False):
        # Replacement for the legacy torch.fft/torch.ifft(x, signal_ndim=1),
        # which were removed in PyTorch 1.8. `pairs` ends in a (real, imag)
        # dimension of size 2, as the legacy functions required.
        spectrum_fn = torch.fft.ifft if inverse else torch.fft.fft
        as_complex = torch.view_as_complex(pairs.contiguous())
        return torch.view_as_real(spectrum_fn(as_complex, dim=-1))

    start = time()
    if self.solver == "least-square":
        inv_calibration_matrix = torch.pinverse(calibration_matrix)
        reconstructed_A = ComplexTensor(
            real=torch.matmul(signal.real, inv_calibration_matrix),
            imag=torch.matmul(signal.imag, inv_calibration_matrix))
    elif self.solver == "fft":
        signal = signal.conj().stack()
        half = self.n_signals // 2
        blocks = []
        for part, eigval_reciprocal in ((signal[:, :half], eigval1_reciprocal),
                                        (signal[:, half:], eigval2_reciprocal)):
            fft_buffer = _fft1(part)
            block = ComplexTensor(real=fft_buffer[:, :, 0], imag=fft_buffer[:, :, 1])
            blocks.append(block.batch_elementwise(eigval_reciprocal))
        block1, block2 = blocks
        recovered = _fft1((block1 + block2).stack(), inverse=True)
        reconstructed_A = ComplexTensor(real=recovered[:, :, 0],
                                        imag=recovered[:, :, 1]).conj()
    else:
        raise ValueError("Unknown solver: {!r}".format(self.solver))
    self.time_logger["solver"] += time() - start
    return reconstructed_A
def one_hot_add(inputs, shift):
    """Performs (inputs + shift) % vocab_size in the one-hot space.

    The modular addition is a circular convolution along the last axis,
    computed as a product in the Fourier domain. Uses the ``torch.fft``
    module API; the legacy ``torch.fft(x, 1)`` function was removed in
    PyTorch 1.8.

    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard
            one-hot Tensor specifying how much to shift the corresponding
            one-hot vector in inputs. Soft values perform a "weighted shift":
            for example, shift=[0.2, 0.3, 0.5] performs a linear combination
            of 0.2 * shifting by zero; 0.3 * shifting by one; and 0.5 *
            shifting by two.

    Returns:
        A 4-tuple ``(shifted, result_fft, inputs_fft, shift_fft)``:
        ``shifted`` is the real part of the inverse transform with the same
        shape as ``inputs``; the three spectra have shape
        `[..., vocab_size, 2]` with (real, imag) in the last dimension.
    """
    inputs_spectrum = torch.fft.fft(inputs, dim=-1)
    shift_spectrum = torch.fft.fft(shift, dim=-1)
    result_spectrum = inputs_spectrum * shift_spectrum
    shifted = torch.fft.ifft(result_spectrum, dim=-1).real  # only the real part
    return (shifted,
            torch.view_as_real(result_spectrum),
            torch.view_as_real(inputs_spectrum),
            torch.view_as_real(shift_spectrum))
def test_fft2d(self):
    """Check torch_butterfly's 2-D FFT against PyTorch's reference FFT.

    Also verifies that a 2-D FFT equals a 1-D FFT along the last dimension
    followed by a 1-D FFT along the second-to-last one. The reference uses the
    ``torch.fft`` module API, since the legacy ``torch.fft(x, signal_ndim)``
    function was removed in PyTorch 1.8.
    """
    batch_size = 10
    n1 = 16
    n2 = 32
    signal = torch.randn(batch_size, n2, n1, dtype=torch.complex64)
    for normalized in [False, True]:
        norm = "ortho" if normalized else "backward"
        out_torch = torch.fft.fft2(signal, norm=norm)
        # Just to show how fft2d is exactly 2 ffts on each dimension
        rows_f = torch.fft.fft(signal, dim=-1, norm=norm)
        out_fft = torch.fft.fft(rows_f, dim=-2, norm=norm)
        self.assertTrue(
            torch.allclose(out_torch, out_fft, self.rtol, self.atol))
        for br_first in [True, False]:
            for flatten in [False, True]:
                b = torch_butterfly.special.fft2d(n1, n2,
                                                  normalized=normalized,
                                                  br_first=br_first,
                                                  flatten=flatten)
                out = b(signal)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))
def cross_corr_and_conv(x, y, pad=False, real=True):
    """Compute the 2-D circular cross-correlation and convolution of x and y.

    Both are evaluated in the Fourier domain: the convolution from the
    product of the spectra, the correlation from the product with the
    conjugated spectrum of ``y``.

    Changes from the previous version: the legacy ``torch.fft``/``torch.ifft``
    functions (removed in PyTorch 1.8) are replaced by the ``torch.fft``
    module, and the conjugate of ``yf`` is built out of place. The old code
    aliased ``yf`` and flipped its imaginary part in place after ``yf`` had
    already been used in a product, which invalidates tensors autograd saved
    for the backward pass.

    Args:
        x, y: Input tensors; passed through ``zero_pad`` (if ``pad``) and
            ``complexify``, which is expected to return a tensor whose last
            dimension holds (real, imag).
        pad: If true, zero-pad both inputs first.
        real: If true, return only the real parts.

    Returns:
        Tuple ``(corr, conv)``.
    """
    def _fft2(pairs, inverse=False):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft2 if inverse else torch.fft.fft2
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous())))

    if pad:
        x = zero_pad(x)
        y = zero_pad(y)
    xf = _fft2(complexify(x))
    yf = _fft2(complexify(y))

    convf = complex_mult(xf, yf)
    yf_conj = torch.stack((yf[..., 0], -yf[..., 1]), dim=-1)
    corrf = complex_mult(xf, yf_conj)

    conv = _fft2(convf, inverse=True)
    corr = _fft2(corrf, inverse=True)
    if real:
        conv = conv[..., 0]
        corr = corr[..., 0]
    return corr, conv
def get_target_tensor(self, input, target_is_real, degree, mask, pred_and_gt=None):
    """Build the discriminator target tensor matching the shape of ``input``.

    For real targets every entry is ``degree``. For fake targets the tensor
    starts at zero and is filled either with ``degree`` (when it differs from
    1) or, if ``self.use_mse_as_energy`` is set, with an energy
    ``exp(-gamma * row-wise k-space MSE)`` computed from ``pred_and_gt``.
    Entries selected by the non-zero positions of ``mask[i, 0, 0, :]`` are
    then forced to 1.

    ``F.mse_loss`` is called with ``reduction='none'`` instead of the
    deprecated ``reduce=False``; the result is identical.

    NOTE(review): the flattened source does not show whether the mask loop
    belongs to the fake-target branch only or runs for both branches; it is
    placed in the fake-target branch here -- confirm against the original.

    Args:
        input: Tensor whose shape/dtype the target copies.
        target_is_real: Whether the target is for real samples.
        degree: Fill value for the target.
        mask: Tensor indexed as ``mask[i, 0, 0, :]``.
        pred_and_gt: Pair ``(pred, gt)``; only read when
            ``self.use_mse_as_energy`` is true.

    Returns:
        The target tensor.
    """
    if target_is_real:
        target_tensor = torch.ones_like(input)
        target_tensor[:] = degree
        return target_tensor

    target_tensor = torch.zeros_like(input)
    if not self.use_mse_as_energy:
        if degree != 1:
            target_tensor[:] = degree
    else:
        pred, gt = pred_and_gt
        if self.options.dataroot == "KNEE_RAW":
            gt = center_crop(gt, [368, 320])
            pred = center_crop(pred, [368, 320])
        w = gt.shape[2]
        ks_gt = fft(gt, normalized=True)
        ks_input = fft(pred, normalized=True)
        squared_error = F.mse_loss(ks_input, ks_gt, reduction='none')
        ks_row_mse = squared_error.sum(1, keepdim=True).sum(
            2, keepdim=True).squeeze() / (2 * w)
        energy = torch.exp(-ks_row_mse * self.gamma)
        target_tensor[:] = energy

    # force the observed part to always be 1
    for i in range(mask.shape[0]):
        idx = torch.nonzero(mask[i, 0, 0, :])
        target_tensor[i, idx] = 1
    return target_tensor
def forward(self, x):
    """Compact bilinear pooling of ``x`` with itself.

    Each spatial feature vector is projected by the two sketch matrices and
    the two sketches are circularly convolved via the FFT. Uses the
    ``torch.fft`` module API; the legacy ``torch.fft(x, signal_ndim=1)``
    function was removed in PyTorch 1.8.

    Args:
        x: input Tensor of shape [batch_size, input_dim, height, width].

    Returns:
        Tensor of shape [batch_size, height, width, output_dim], or
        [batch_size, output_dim] when ``self.sum_pool`` is set.
    """
    batch_size, input_dim, height, width = x.size()
    assert input_dim == self.input_dim

    x_flat = x.permute(0, 2, 3, 1).contiguous().view(
        -1, self.input_dim).to(self.device)
    sketch_1 = x_flat.mm(self.sparse_sketch_matrix1)
    sketch_2 = x_flat.mm(self.sparse_sketch_matrix2)

    # Product of the spectra == circular convolution of the sketches.
    spectrum = torch.fft.fft(sketch_1, dim=-1) * torch.fft.fft(sketch_2, dim=-1)
    cbp_flat = torch.fft.ifft(spectrum, dim=-1).real

    cbp = cbp_flat.reshape(batch_size, height, width, self.output_dim)
    if self.sum_pool:
        cbp = cbp.sum(dim=1).sum(dim=1)
    return cbp
def forward(self, img1, img2):
    """Weighted L1 distance between the 2-D spectra of two real images.

    The L1 loss is taken over the (real, imag) components of the FFTs, i.e.
    it averages ``|re diff|`` and ``|im diff|`` separately.

    Changes from the previous version: uses the ``torch.fft`` module (the
    legacy ``torch.fft(x, 2)`` function was removed in PyTorch 1.8), no
    longer allocates a CUDA-only zero buffer, so CPU inputs work, and
    replaces the deprecated ``size_average=True`` with ``reduction="mean"``.

    Args:
        img1, img2: Real tensors of identical shape; the transform runs over
            the last two dimensions.

    Returns:
        Scalar tensor ``self.loss_weight * L1(FFT(img1), FFT(img2))``.
    """
    spectrum1 = torch.view_as_real(torch.fft.fft2(img1))
    spectrum2 = torch.view_as_real(torch.fft.fft2(img2))
    loss = nn.L1Loss(reduction="mean")(spectrum1, spectrum2)
    return self.loss_weight * loss
def grad(self, field_est, device='cpu'):
    """Compute the gradient of the amplitude-based data-fit cost.

    The measured intensities are multiplexed with ``self.C``, compared (in
    square-root / amplitude space) with the intensities simulated from
    ``field_est``, and the weighted residual is back-propagated through the
    pupil-filtering operator for every measurement.

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; tensors keep the
    (real, imag) last-dimension layout expected by ``mul_c``/``conj``. The
    unused ``cost`` value is no longer computed.

    Args:
        field_est: Current field estimate; last dimension holds (real, imag).
        device: Forwarded to ``self.generateMultiMeas``.

    Returns:
        Gradient tensor accumulated over ``self.Nmeas`` measurements.
    """
    def _fft2(pairs, inverse=False):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft2 if inverse else torch.fft.fft2
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous())))

    self.measurements = self.measurements.to(self.device)
    self.pupils = self.pupils.to(self.device)
    self.planewaves = self.planewaves.to(self.device)
    self.P = self.P.to(self.device)

    multiMeas = torch.matmul(self.measurements.permute(1, 2, 0),
                             self.C.permute(1, 0)).permute(2, 0, 1)
    multiMeas = torch.abs(multiMeas)

    # simulate current estimate of measurements
    y = self.generateMultiMeas(field_est, device=device)

    # compute residual
    sqrty = torch.sqrt(y + EPS)
    residual = sqrty - torch.sqrt(multiMeas + EPS)
    Ajx = residual / (sqrty + 1e-10)
    Ajx_c = torch.stack((Ajx, torch.zeros_like(Ajx)), dim=len(Ajx.shape))

    # compute gradient
    output = mul_c(self.planewaves, field_est)
    output = _fft2(output)
    output = mul_c(self.P, output)
    output = _fft2(output, inverse=True)

    g = field_est * 0.
    for meas_index in range(self.Nmeas):
        output2 = mul_c(Ajx_c[meas_index, ...], output)
        output2 = mul_c(conj(self.planewaves), output2)
        output2 = _fft2(output2)
        output2 = mul_c(self.pupils, output2)
        output2 = _fft2(output2, inverse=True)
        g_tmp = torch.matmul(output2.permute(1, 2, 3, 0), self.C[meas_index, :])
        g = g + g_tmp
    return g
def circular_convolution_fft(keys, values, normalized=True, conj=False, cuda=False):
    '''Circular convolution of complex ``keys`` and ``values`` via the FFT.

    Each input is a 2-D tensor whose last dimension stores the real parts in
    its first half and the imaginary parts in its second half.

    For the circular convolution of x and y to be equivalent to the linear
    one, you must pad the vectors with zeros to length at least N + L - 1
    before you take the DFT. This function does not pad or truncate.

    Changes from the previous version:
      * uses the ``torch.fft`` module; the legacy ``torch.fft``/``torch.ifft``
        functions were removed in PyTorch 1.8;
      * the spectra are multiplied as complex numbers. The old code multiplied
        the (real, imag) pair tensors element-wise, i.e. real*real and
        imag*imag, which is not a complex product and therefore not the
        convolution theorem. NOTE(review): this changes numerical results for
        existing callers -- confirm no trained model relies on the old values;
      * the unused ``required_size`` computation is dropped.

    Args:
        keys: Tensor [batch, 2 * n] (real half, imag half).
        values: Tensor [batch, 2 * n] (real half, imag half).
        normalized: Use the orthonormal transform for FFT and inverse FFT.
        conj: Conjugate the keys first (via the project ``Complex`` helper).
        cuda: Unused; kept for interface compatibility.

    Returns:
        Tensor of shape [batch, -1] produced by ``Complex(...).unstack()``.
    '''
    assert values.dim() == keys.dim() == 2, "only 2 dims supported"
    assert values.size(-1) % 2 == keys.size(-1) % 2 == 0, "need last dim to be divisible by 2"
    batch_size = keys.size(0)

    # conj transpose
    keys = Complex(keys).conj().unstack() if conj else keys

    # [batch, real-half | imag-half] -> complex [batch, n]
    half = keys.size(-1) // 2
    keys_c = torch.complex(keys[:, :half], keys[:, half:])
    values_c = torch.complex(values[:, :half], values[:, half:])

    norm = "ortho" if normalized else "backward"
    kf = torch.fft.fft(keys_c, dim=-1, norm=norm)
    vf = torch.fft.fft(values_c, dim=-1, norm=norm)
    kvif = torch.view_as_real(torch.fft.ifft(kf * vf, dim=-1, norm=norm))

    return Complex(kvif[:, :, 0], kvif[:, :, 1]).unstack().view(batch_size, -1)
def partial_circulant_torch(inputs, filters, indices, sign_pattern):
    '''Apply a partial circulant measurement operator.

    Every input is flattened to length ``n``, multiplied element-wise by
    ``sign_pattern``, circularly convolved with ``filters`` (via the FFT) and
    sub-sampled at ``indices``.

    Uses the ``torch.fft`` module with native complex tensors; the legacy
    ``torch.fft(x, 1)`` function was removed in PyTorch 1.8.

    Args:
        inputs: Tensor of shape [batch, ...]; trailing dims are flattened.
        filters: Tensor holding ``n`` filter taps (or a multiple of ``n``,
            reshaped to rows of length ``n`` that broadcast against the batch).
        indices: Index (list/tensor) of output positions to keep.
        sign_pattern: Tensor broadcastable to [batch, n].

    Returns:
        Real tensor of shape [batch, len(indices)].
    '''
    n = int(np.prod(inputs.shape[1:]))
    bs = inputs.shape[0]
    input_sign = inputs.reshape(bs, n) * sign_pattern

    input_fft = torch.fft.fft(input_sign, dim=-1)
    filter_fft = torch.fft.fft(filters.reshape(-1, n), dim=-1)
    output_real = torch.fft.ifft(input_fft * filter_fft, dim=-1).real
    return output_real[:, indices]
def kspaceFuse(x1, x2):
    """Fuse two images by adding their orthonormal 2-D spectra.

    Each input is [batch, C, ...spatial] with 2 or 3 spatial dims. ``C == 1``
    means a real image (zero imaginary part is added); otherwise the channel
    dimension is interpreted as (real, imag) and must have size 2. Both
    inputs are transformed over their last two spatial dims, summed in
    k-space and transformed back.

    Changes from the previous version: uses the ``torch.fft`` module (the
    legacy ``torch.fft``/``torch.ifft`` functions were removed in PyTorch
    1.8); the rank check raises ``AssertionError`` explicitly instead of
    ``assert False`` (which disappears under ``python -O``); and the final
    magnitude decision reads ``x2`` explicitly instead of relying on the
    leaked loop variable, which was bound to ``x2``.

    Returns:
        [batch, 1, ...] magnitude image if ``x2`` is single-channel, else
        [batch, 2, ...] with (real, imag) channels.

    Raises:
        AssertionError: If an input is neither 4-D nor 5-D.
    """
    def _to_complex(xin):
        if xin.dim() not in (4, 5):
            raise AssertionError("xin shape length has to be 4(2d) or 5(3d)")
        if xin.shape[1] == 1:
            xin = torch.cat([xin, torch.zeros_like(xin)], 1)
        # channel dim -> trailing (real, imag) dim
        return torch.view_as_complex(xin.movedim(1, -1).contiguous())

    x1c = _to_complex(x1)
    x2c = _to_complex(x2)

    xout_f = torch.fft.fft2(x1c, norm="ortho") + torch.fft.fft2(x2c, norm="ortho")
    xout = torch.view_as_real(torch.fft.ifft2(xout_f, norm="ortho")).movedim(-1, 1)

    if x2.shape[1] == 1:
        xout = torch.sqrt(xout[:, 0:1] * xout[:, 0:1] + xout[:, 1:2] * xout[:, 1:2])
    return xout
def forward(self, bottom1, bottom2):
    """Compact bilinear pooling of two feature maps.

    Each spatial feature vector of ``bottom1``/``bottom2`` is projected by
    its sketch matrix and the two sketches are circularly convolved through
    the FFT.

    Changes from the previous version: uses the ``torch.fft`` module (the
    legacy ``torch.fft``/``torch.ifft`` functions were removed in PyTorch
    1.8) and drops the argument-less ``.squeeze()`` calls, which also
    removed the flattened-batch or output dimension whenever it had size 1
    and made the subsequent ``stack``/``view`` fail.

    Args:
        bottom1: Tensor [batch, input_dim1, height, width].
        bottom2: Tensor [batch, input_dim2, height, width].

    Returns:
        Tensor [batch, height, width, output_dim], or [batch, output_dim]
        when ``self.sum_pool`` is set.
    """
    assert bottom1.size(1) == self.input_dim1 and \
        bottom2.size(1) == self.input_dim2

    batch_size, _, height, width = bottom1.size()

    bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(
        -1, self.input_dim1)
    bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(
        -1, self.input_dim2)

    sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
    sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)

    # Product of the spectra == circular convolution of the sketches.
    fft_product = torch.fft.fft(sketch_1, dim=-1) * torch.fft.fft(sketch_2, dim=-1)
    cbp_flat = torch.fft.ifft(fft_product, dim=-1).real

    cbp = cbp_flat.reshape(batch_size, height, width, self.output_dim)
    if self.sum_pool:
        cbp = cbp.sum(dim=1).sum(dim=1)
    return cbp
def inv_filt_torch(Y, R, batch_dim=False):
    """Recover X by inverse filtering with the reference R.

    ``Y`` is inverse-transformed and shifted; a cropped region (the first
    ``2 * R.shape[1]`` columns) is taken as the cross term, multiplied in the
    Fourier domain by the spectrum of the zero-padded reference ``R`` and
    divided by ``|F(R)|^2``. The result is cropped past the extent of ``R``.

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; tensors keep the
    (real, imag) last-dimension layout used by ``complexify`` and
    ``complex_mult``. Unused locals (the tripled ``s2`` and the conjugate
    spectrum) are removed.

    Args:
        Y: Measurement tensor ([h, w], or [n, h, w] when ``batch_dim``).
        R: Reference tensor ([s1, s2], or [n, s1, s2] when ``batch_dim``).
        batch_dim: Whether inputs already carry a leading batch dimension.

    Returns:
        Tensor with trailing (real, imag) dimension; the batch dimension is
        removed again when ``batch_dim`` is false.
    """
    def _fft2(pairs, inverse=False):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft2 if inverse else torch.fft.fft2
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous())))

    if not batch_dim:
        Y, R = Y.unsqueeze(0), R.unsqueeze(0)

    W = _fft2(complexify(Y), inverse=True)
    n, s1, _ = R.shape

    S_auto = torch_fft_shift(W, dims=(1, 2))
    XR_cross = S_auto[:, :, :2 * R.shape[1]]
    _, t1, t2, _ = XR_cross.shape

    # Zero-pad the reference to the size of the cross term.
    R_ = torch.zeros((n, t1, t2, 2), dtype=torch.float, device=W.device)
    R_[:, :s1, :R.shape[2], 0] = R
    F_R = _fft2(R_)
    F_R_abs = (F_R**2).sum(-1, keepdim=True)  # squared magnitude

    X_ = _fft2(complex_mult(_fft2(XR_cross), F_R) / F_R_abs, inverse=True)
    X = X_[:, R.shape[1]:, R.shape[2]:]
    if not batch_dim:
        X = X[0]
    return X
def orth_phase22(im2, loc, devid):
    """
    Given a batch of complex images and a tensor of locations ``loc``,
    return a tuple ``(phase_, orth_ph)``: the unwrapped phase and the
    "orthogonal" (transposed and flipped) phase, both re-centred with
    ``unshift2`` at ``loc``.

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; all other steps are
    unchanged. The function still requires CUDA (``torch.cuda.FloatTensor``).

    NOTE(review): the earlier docstring called ``loc`` both "local maxima"
    and "local minima"; which one is meant is not visible here.
    """
    def _fft2(pairs, inverse=False):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft2 if inverse else torch.fft.fft2
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous())))

    # im2 (1, P_c, N, N, 2)
    size = im2.size(-2)
    phase = torch.atan2(im2[..., 1], im2[..., 0])  # (1, P_c, N, N)

    # unwrapping
    z = torch.arange(-size//2, size//2).unsqueeze(0).unsqueeze(0)
    z = z.repeat(tuple(loc.size()[:2]) + (1,)).type(torch.cuda.FloatTensor)  # (1, P_c, N)
    z1 = z.unsqueeze(-1).repeat(1, 1, 1, size)  # (1, P_c, N, N)
    z2 = z.unsqueeze(-2).repeat(1, 1, size, 1)  # (1, P_c, N, N)
    z = z1**2 + z2**2
    del z1
    del z2
    z = shift2(z, -torch.cuda.FloatTensor([size//2, size//2]).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, loc.size(1), 1, 1), devid)
    z = z.squeeze().unsqueeze(0).unsqueeze(-1)

    cplx_phase = torch.stack((torch.cos(phase), torch.sin(phase)), dim=-1)
    cpi = complex_mul(_fft2(z * _fft2(cplx_phase), inverse=True),
                      conjugate(cplx_phase))[..., 1]
    lin_sp_c_dx = _fft2(
        _fft2(torch.stack((cpi, 0 * cpi.clone()), dim=-1)) / z,
        inverse=True)[..., 0]

    shifted_phase = shift2(phase, loc, devid)
    lin_sp_c_dx = lin_sp_c_dx - lin_sp_c_dx[:, :, 0, 0].unsqueeze(-1).unsqueeze(-1)
    lin_sp_c_dx = lin_sp_c_dx.unsqueeze(2) + shifted_phase[:, :, :, 0, 0].unsqueeze(-1).unsqueeze(-1)
    t_s_phase = torch.transpose(lin_sp_c_dx, -1, -2)  # (1, P_c, m, N, N)
    del shifted_phase

    if loc.size(2) == 1:
        t_s_phase = torch.flip(t_s_phase.unsqueeze(-2), [-2, -3]).squeeze().unsqueeze(0).unsqueeze(0)  # (1, P_c, m, N, N)
    else:
        t_s_phase = torch.flip(t_s_phase.unsqueeze(-2), [-2, -3]).squeeze().unsqueeze(0)  # (1, P_c, m, N, N)

    orth_ph = unshift2(t_s_phase, loc, devid)  # (1, P_c, m, N, N)
    phase_ = unshift2(lin_sp_c_dx, loc, devid)
    del t_s_phase
    return phase_, orth_ph
def propagate_focal_to_back(self, u1):
    """Propagate the field ``u1`` from the focal plane to the back plane.

    Based on the function propFF out of the book "Computational Fourier
    Optics. A MATLAB Tutorial". There you can find more information.

    The centred 2-D FFT of ``u1`` is multiplied by
    ``self.TransferFunctionIncoherent``, scaled by ``dx1 ** 2`` and then
    multiplied (complex) by the precomputed ``self.coefU1minus``.

    The legacy ``torch.fft(x, 2)`` function (removed in PyTorch 1.8) is
    replaced by the ``torch.fft`` module; ``u1`` keeps its trailing
    (real, imag) dimension.

    Args:
        u1: Source field; ``u1.shape[-3:-1]`` are the spatial sizes and the
            last dimension holds (real, imag).

    Returns:
        Tuple ``(u2, L2)``: the propagated field and the observation-plane
        side length ``wavelength * focal_length / dx1``.
    """
    def _fft2(pairs):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        return torch.view_as_real(
            torch.fft.fft2(torch.view_as_complex(pairs.contiguous())))

    wavelength = self.optic_config.PSF_config.wvl
    M = u1.shape[-3]
    # source sample interval
    dx1 = self.sampling_rate
    # obs sidelength
    L2 = wavelength * self.focal_length / dx1

    # output field; the shift order depends on the parity of the grid size
    if M % 2 == 1:
        spectrum = ob.batch_fftshift2d(_fft2(ob.batch_ifftshift2d(u1)))
    else:
        spectrum = ob.batch_ifftshift2d(_fft2(ob.batch_fftshift2d(u1)))
    u2 = torch.mul(self.TransferFunctionIncoherent, spectrum) * dx1 * dx1

    # multiply by precomputed coeff
    u2 = ob.mulComplex(u2, self.coefU1minus)
    return u2, L2
def synthesis(target, test_id, ind, scat, n, min_error, err_it, nit, is_complex = False, initial_type = 'gaussian'):
    """Synthesise an image whose scattering statistics match those of ``target``.

    Starting from random noise, the image is optimised with Adam (learning
    rate taken from the module-level ``lr``) until the MSE between
    ``scat(FFT(x0))`` and ``scat(FFT(target))`` drops below ``min_error``
    times its initial value. Errors and intermediate results are written to
    ``./result<test_id>/``.

    Changes from the previous version: the legacy ``torch.fft`` /
    ``torch.rfft(..., onesided=False)`` functions (removed in PyTorch 1.8)
    are replaced by ``torch.fft.fft2``; an unsupported ``initial_type`` now
    raises ``ValueError`` instead of failing later on an unbound ``x0``; the
    loss is no longer recomputed after the optimiser step (it was evaluated
    on the same stale ``s0`` and therefore had the same value).

    NOTE(review): the flattened source does not show whether ``Dealias`` is
    applied to both the complex and the real initialisation or only to the
    real one; it is applied to both here -- confirm against the original.

    Args:
        target: Target image; [n, n, 2] (real, imag) if ``is_complex`` else
            real [n, n].
        test_id, ind: Integers used to build the output file names.
        scat: Callable applied to the (real, imag) spectrum.
        n: Side length of the synthesised image.
        min_error: Relative loss at which optimisation stops.
        err_it: Record the loss every ``err_it`` iterations.
        nit: Print and checkpoint every ``nit`` iterations.
        is_complex: Whether images carry a trailing (real, imag) dimension.
        initial_type: 'gaussian' or 'uniform' initial noise.

    Raises:
        ValueError: If ``initial_type`` is not supported.
    """
    def _spectrum(field):
        # Full (two-sided) 2-D spectrum with a trailing (real, imag) dim.
        if is_complex:
            field = torch.view_as_complex(field.contiguous())
        return torch.view_as_real(torch.fft.fft2(field))

    if initial_type not in ('gaussian', 'uniform'):
        raise ValueError("Unsupported initial_type: {!r}".format(initial_type))

    if torch.cuda.is_available():
        target = target.cuda()
    print(is_complex)

    # set up target
    s_target = scat(_spectrum(target))

    shape = (n, n, 2) if is_complex else (n, n)
    x0 = torch.randn(*shape) if initial_type == 'gaussian' else torch.rand(*shape)
    x0 = Dealias(x0)

    if torch.cuda.is_available():
        s_target = s_target.cuda()
        x0 = x0.cuda()
    x0 = Variable(x0, requires_grad=True)

    s0 = scat(_spectrum(x0))
    loss = nn.MSELoss()
    optimizer = optim.Adam([x0], lr=lr)
    output = loss(s_target, s0)
    l0 = output
    error = []
    count = 0
    while output / l0 > min_error:
        optimizer.zero_grad()
        s0 = scat(_spectrum(x0))
        output = loss(s_target, s0)
        if count % err_it == 0:
            error.append(output.item())
        output.backward()
        optimizer.step()
        if count % nit == 0:
            print(output.data.cpu().numpy())
            np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
            np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
        count += 1
    print('error reduced by: ', output / l0)
    print('error supposed reduced by: ', min_error)
    np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
    np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
def cross_conv_corr(x, x_hat):
    """Circular 2-D cross-convolution and cross-correlation of image batches.

    Both are computed in the Fourier domain: the convolution from
    ``F(x) * F(x_hat)`` and the correlation from ``conj(F(x)) * F(x_hat)``.

    Changes from the previous version: uses the ``torch.fft`` module (the
    legacy ``torch.fft``/``torch.ifft`` functions were removed in PyTorch
    1.8) and works on the inputs directly, so results stay on the inputs'
    device instead of being computed in freshly allocated CPU buffers.

    Args:
        x: Real tensor of shape (batch_idx, nx, ny).
        x_hat: Real tensor of shape (batch_idx, nx, ny).

    Returns:
        Tuple ``(cross_conv, cross_corr)`` of real tensors.

    Raises:
        RuntimeError: If either input is not 3-dimensional.
    """
    if len(x.shape) != 3 or len(x_hat.shape) != 3:
        raise RuntimeError(
            'Expects images with dimensions (batch_idx, nx, ny)')

    fx = torch.fft.fft2(x)
    fx_hat = torch.fft.fft2(x_hat)

    cross_conv = torch.fft.ifft2(fx * fx_hat).real
    cross_corr = torch.fft.ifft2(torch.conj(fx) * fx_hat).real
    return cross_conv, cross_corr
def __call__(self, x, masks=None): if masks is None: y = self.mask.view(1, *self.mask.shape, 1) * torch.fft( x.permute(0, 2, 3, 1), signal_ndim=2, normalized=True) else: y = masks.view(*masks.shape, 1) * torch.fft( x.permute(0, 2, 3, 1), signal_ndim=2, normalized=True) return y.permute(0, 3, 1, 2)
def calCC(aT, bT):
    """Circular cross-correlation of two complex 1-D signals.

    Computes the real part of ``IFFT(conj(FFT(a)) * FFT(b))``.

    Changes from the previous version: uses the ``torch.fft`` module (the
    legacy ``torch.fft``/``torch.ifft`` functions were removed in PyTorch
    1.8) and no longer allocates the product buffer on the default (CPU)
    device, so inputs on other devices work.

    Args:
        aT: Tensor [N, 2] with (real, imag) in the last dimension.
        bT: Tensor [N, 2] with (real, imag) in the last dimension.

    Returns:
        Real tensor of shape [N].
    """
    a_spectrum = torch.fft.fft(torch.view_as_complex(aT.contiguous()))
    b_spectrum = torch.fft.fft(torch.view_as_complex(bT.contiguous()))
    return torch.fft.ifft(torch.conj(a_spectrum) * b_spectrum).real
def custom_fft(iT, real=True):
    """2-D FFT returning a tensor with a trailing (real, imag) dimension.

    Normalisation follows the module-level ``normalizeFFT`` flag
    (orthonormal when true). Uses the ``torch.fft`` module; the legacy
    ``torch.fft(x, signal_ndim=2)`` function was removed in PyTorch 1.8.

    Args:
        iT: Real tensor (when ``real``) or tensor whose last dimension
            already holds (real, imag).
        real: Whether ``iT`` is real-valued.

    Returns:
        Spectrum over the last two signal dimensions, shape [..., 2].
    """
    norm = "ortho" if normalizeFFT else "backward"
    # A real input has an implicit zero imaginary part.
    signal = iT if real else torch.view_as_complex(iT.contiguous())
    return torch.view_as_real(torch.fft.fft2(signal, norm=norm))
def lower_level_mixed_derivs(x, w, y, S, alpha, eps, A, reg_func):
    """Mixed second derivatives of the lower-level objective (adjoint method).

    Uses the ``torch.fft`` module for the orthonormal 2-D transforms; the
    legacy ``torch.fft(x, 2, normalized=True)`` function was removed in
    PyTorch 1.8.

    Args:
        x: Tensor whose last dimension holds (real, imag); indexed here as
            [batch, height, width, 2].
        w: Adjoint variable with the same layout as ``x``.
        y: Measured k-space data, same layout as the spectrum of ``x``.
        S: Sampling pattern of shape [height, width].
        alpha: Unused here.
        eps: Unused here.
        A: Operator providing ``A(x)`` and ``A.T(...)``.
        reg_func: Regulariser providing ``grad(...)``.

    Returns:
        Tuple ``(DwDy, DwDS, DwDalpha, DwDeps)``.
    """
    def _fft2_ortho(pairs):
        return torch.view_as_real(
            torch.fft.fft2(torch.view_as_complex(pairs.contiguous()),
                           norm="ortho"))

    S_b = S.view(1, *S.shape, 1)  # broadcast over batch and (real, imag)
    Fw = _fft2_ortho(w)
    Fx = _fft2_ortho(x)

    DwDy = -S_b**2 * Fw
    DwDalpha = torch.sum(w * A.T(reg_func.grad(A(x))))
    DwDS = torch.sum(Fw * 2 * S_b * (Fx - y), dim=(0, 3))
    DwDeps = torch.sum(w * x)
    return DwDy, DwDS, DwDalpha, DwDeps
def forward(self, h, r, t):
    """Score triples with holographic (circular-correlation) embeddings.

    The head and tail embeddings are combined by circular correlation,
    ``IFFT(conj(FFT(h)) * FFT(t))``, and the real part is projected onto the
    L2-normalised relation embedding.

    Changes from the previous version:
      * uses the ``torch.fft`` module; the legacy ``torch.fft``/``torch.ifft``
        functions were removed in PyTorch 1.8;
      * the conjugate and the product are now genuinely complex. The old
        code applied ``torch.conj`` and ``*`` to real tensors holding
        (real, imag) pairs, so no conjugation took place and the product was
        element-wise on the pairs. NOTE(review): this changes the scores --
        retrain or re-validate models produced with the old code;
      * ``torch.sigmoid`` replaces the deprecated ``F.sigmoid``.

    Args:
        h, r, t: Head, relation and tail inputs forwarded to ``self.embed``.

    Returns:
        Tensor with one (negative) score per triple.
    """
    h_e, r_e, t_e = self.embed(h, r, t)
    r_e = F.normalize(r_e, p=2, dim=-1)
    correlation_spectrum = (torch.conj(torch.fft.fft(h_e, dim=-1))
                            * torch.fft.fft(t_e, dim=-1))
    e = torch.fft.ifft(correlation_spectrum, dim=-1).real
    return -torch.sigmoid(torch.sum(r_e * e, 1))
def forward(self, input, recon):
    """Loss between the 2-D spectra of ``input`` and ``recon``.

    Each tensor is turned into a complex signal whose real and imaginary
    parts are both equal to the tensor, transformed with a 2-D FFT over the
    last two dimensions, and the (real, imag) components are compared with
    ``self.mse_loss``.

    Uses the ``torch.fft`` module; the legacy ``torch.fft(x, 2)`` function
    was removed in PyTorch 1.8.

    NOTE(review): duplicating the signal into the imaginary part (rather
    than using zeros) is kept from the original; it only scales the spectra
    by (1 + i). Confirm that this is intended.

    Args:
        input: Real tensor, e.g. [batch, channels, height, width].
        recon: Real tensor with the same shape as ``input``.

    Returns:
        The value returned by ``self.mse_loss``.
    """
    input_fft = torch.view_as_real(torch.fft.fft2(torch.complex(input, input)))
    recon_fft = torch.view_as_real(torch.fft.fft2(torch.complex(recon, recon)))
    return self.mse_loss(input_fft, recon_fft)
def forward(self, data_streams):
    """
    Main forward pass of the model: compact bilinear pooling of the image
    and question encodings.

    Both encodings are projected with their sketch matrices and the sketches
    are circularly convolved through the FFT. The result is published under
    ``self.key_outputs``; nothing is returned.

    Uses the ``torch.fft`` module (the legacy ``torch.fft``/``torch.ifft``
    functions were removed in PyTorch 1.8). The explicit zero imaginary
    parts, which were forced to ``app_state.FloatTensor``, are no longer
    needed, so the computation stays on the device of the encodings.

    :param data_streams: DataStreams({'images',**})
    :type data_streams: ``ptp.dadatypes.DataStreams``
    """
    # Unpack DataStreams.
    enc_img = data_streams[self.key_image_encodings]
    enc_q = data_streams[self.key_question_encodings]

    # Project both batches.
    sketch_img = enc_img.mm(self.image_sketch_projection_matrix)
    sketch_q = enc_q.mm(self.question_sketch_projection_matrix)

    # Product in the Fourier domain, then back; keep the real part.
    fft_product = torch.fft.fft(sketch_img, dim=-1) * torch.fft.fft(sketch_q, dim=-1)
    cbp = torch.fft.ifft(fft_product, dim=-1).real

    # Add predictions to datadict.
    data_streams.publish({self.key_outputs: cbp})
def dispersion(xpol, ypol, length):
    """Apply dispersion over ``length`` to both polarisations.

    Each signal is transformed to the frequency domain, passed through
    ``complex_exp`` with the argument ``D * length`` (``D`` is a
    module-level value) and transformed back.

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; signals keep their
    trailing (real, imag) dimension.

    Args:
        xpol: X-polarisation signal, last dimension (real, imag).
        ypol: Y-polarisation signal, last dimension (real, imag).
        length: Propagation length multiplying ``D``.

    Returns:
        Tuple ``(xpol_time_domain, ypol_time_domain)``.
    """
    def _fft1(pairs, inverse=False):
        # 1-D transform over the dim preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft if inverse else torch.fft.fft
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous()), dim=-1))

    frequency_domain_xpol = complex_exp(_fft1(xpol), D * length)
    frequency_domain_ypol = complex_exp(_fft1(ypol), D * length)

    xpol_time_domain = _fft1(frequency_domain_xpol, inverse=True)
    ypol_time_domain = _fft1(frequency_domain_ypol, inverse=True)
    return xpol_time_domain, ypol_time_domain
def perform(self, x, k0, mask, sensitivity):
    """
    transform to x-f space with subtraction of average temporal frame in
    multi-coil setting

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; all tensors keep their
    trailing (real, imag) dimension.

    :param x: input image with shape [nt, nx, ny, 2]
    :param mask: undersampling mask [nt, ns, nx, ny, 2]
    :param k0: undersampled k-space data [nt, ns, nx, ny, 2]
    :param sensitivity: sensitivity maps [nt, ns, nx, ny, 2]
    :return: difference data; DC baseline
    """
    norm = "ortho" if self.normalized else "backward"

    def _fft(pairs, ndim, inverse=False):
        # Transform over the `ndim` dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifftn if inverse else torch.fft.fftn
        transformed = spectrum_fn(torch.view_as_complex(pairs.contiguous()),
                                  dim=tuple(range(-ndim, 0)), norm=norm)
        return torch.view_as_real(transformed)

    def _temporal_fft(coil_combined):
        # [nt, nx, ny, 2] -> centred FFT along time -> [nt, nx, ny, 2]
        frames_last = coil_combined.permute(1, 2, 0, 3)  # [nx, ny, nt, 2]
        x_f = fftshift_pytorch(
            _fft(ifftshift_pytorch(frames_last, axes=[-2]), 1), axes=[-2])
        return x_f.permute(2, 0, 1, 3)

    x = complex_multiply(x[..., 0].unsqueeze(1), x[..., 1].unsqueeze(1),
                         sensitivity[..., 0], sensitivity[..., 1])

    k = _fft(x, 2)
    if self.divide_by_n:
        k_avg = torch.div(torch.sum(k, 0), k.shape[0])
    else:
        k_avg = torch.div(torch.sum(k0, 0), torch.clamp(torch.sum(mask, 0), min=1))

    ns, nx, ny, nc = k_avg.shape
    k_avg = k_avg.view(1, ns, nx, ny, nc)
    k_avg = k_avg.repeat(k.shape[0], 1, 1, 1, 1)

    # subtract the temporal average frame
    k_diff = torch.sub(k, k_avg)
    x_diff = _fft(k_diff, 2, inverse=True)
    Sx_diff = complex_multiply(x_diff[..., 0], x_diff[..., 1],
                               sensitivity[..., 0],
                               -sensitivity[..., 1]).sum(dim=1)  # [nt, nx, ny, 2]

    # transform to x-f space to get the baseline
    x_avg = _fft(k_avg, 2, inverse=True)
    Sx_avg = complex_multiply(x_avg[..., 0], x_avg[..., 1],
                              sensitivity[..., 0],
                              -sensitivity[..., 1]).sum(dim=1)
    x_f_avg = _temporal_fft(Sx_avg)

    # difference data
    x_f_diff = _temporal_fft(Sx_diff)

    return x_f_diff, x_f_avg
def bench(batch_size: int, d: int, hw: int, num_iter: int):
    """Benchmark 3-D FFT and inverse FFT throughput on the first CUDA device.

    Prints a JSON record with the number of transforms per second and the
    elapsed time for ``num_iter`` forward and inverse transforms of a
    ``batch_size x d x hw x hw`` complex tensor. Prints a notice and returns
    immediately when CUDA is unavailable. Gradient tracking is disabled
    globally as a side effect.

    Uses ``torch.fft.fftn``/``ifftn`` on complex tensors; the legacy
    ``torch.fft(x, 3)``/``torch.ifft(x, 3)`` functions were removed in
    PyTorch 1.8. The real-to-complex view is created outside the timed loops.

    Args:
        batch_size: Leading batch dimension.
        d: Depth of the 3-D signal.
        hw: Height and width of the 3-D signal.
        num_iter: Number of timed iterations per direction.
    """
    if not torch.cuda.is_available():
        print("GPU is not available")
        return

    device = torch.device('cuda:0')
    torch.set_grad_enabled(False)
    signal_dims = (-3, -2, -1)

    # BxDxHxW complex (built from BxDxHxWx2 real noise)
    inp = torch.view_as_complex(
        torch.randn(batch_size, d, hw, hw, 2, device=device))

    # warmup
    outp = torch.fft.fftn(inp, dim=signal_dims)
    inp_ = torch.fft.ifftn(outp, dim=signal_dims)

    # fft
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with contexttimer.Timer():
        for _ in range(num_iter):
            outp = torch.fft.fftn(inp, dim=signal_dims)
    end.record()
    torch.cuda.synchronize()
    elapsed = start.elapsed_time(end) / 1e3  # milliseconds -> seconds
    tps = num_iter / elapsed
    fft_time_consume = elapsed

    del outp, inp
    outp = torch.view_as_complex(
        torch.randn(batch_size, d, hw, hw, 2, device=device))

    # ifft
    start.record()
    with contexttimer.Timer():
        for _ in range(num_iter):
            inp_ = torch.fft.ifftn(outp, dim=signal_dims)
    end.record()
    torch.cuda.synchronize()
    elapsed = start.elapsed_time(end) / 1e3
    itps = num_iter / elapsed
    ifft_time_consume = elapsed

    print(
        json.dumps({
            "TPS": tps,
            "fft_elapsed": fft_time_consume,
            "ITPS": itps,
            "ifft_elapsed": ifft_time_consume,
            "n": num_iter,
            "batch_size": batch_size,
            "D_size": d,
            "HW_size": hw,
        }))
def grad(self, x):
    """Gradient of the data-fit term for the convolutional model.

    Computes the real part of ``IFFT(conj(fpsf) * fpsf * FFT(x) -
    conj(fpsf) * FFT(y))`` with ``y = self.y`` and ``fpsf = self.fpsf``.

    The legacy ``torch.fft``/``torch.ifft`` functions (removed in PyTorch
    1.8) are replaced by the ``torch.fft`` module; spectra keep the trailing
    (real, imag) dimension expected by ``mul_c`` and ``conj``.

    Args:
        x: Real 2-D tensor (current estimate).

    Returns:
        Real tensor with the shape of ``x``.
    """
    def _fft2(pairs, inverse=False):
        # 2-D transform over the two dims preceding the (real, imag) dim.
        spectrum_fn = torch.fft.ifft2 if inverse else torch.fft.fft2
        return torch.view_as_real(
            spectrum_fn(torch.view_as_complex(pairs.contiguous())))

    ys = torch.stack((self.y, torch.zeros_like(self.y)), 2)
    Fy = _fft2(ys)
    AHy = mul_c(conj(self.fpsf), Fy)

    xs = torch.stack((x, torch.zeros_like(x)), 2)
    Fx = _fft2(xs)
    AHA = mul_c(conj(self.fpsf), self.fpsf)
    AHAx = mul_c(AHA, Fx)

    return _fft2(AHAx - AHy, inverse=True)[..., 0]
def test_fft_function_clobbered(self, device):
    """Check what still works once ``torch.fft`` resolves to a module.

    The scripted ``Tensor.fft`` method must reproduce the result of
    ``fft_fn`` (presumably a reference to the legacy function captured
    before the ``torch.fft`` module was imported -- defined outside this
    block), while calling ``torch.fft`` directly must fail because a module
    is not callable.
    """
    signal = torch.randn((100, 2), device=device)
    expected = fft_fn(signal, 1)

    def method_fn(t):
        return t.fft(1)

    scripted = torch.jit.script(method_fn)
    self.assertEqual(scripted(signal), expected)

    # The bare name now refers to the module, so calling it is a TypeError.
    with self.assertRaisesRegex(TypeError, "'module' object is not callable"):
        torch.fft(signal, 1)
def test_circulant(self):
    """Check torch_butterfly's circulant multiply against dense references.

    For real and complex data, a circulant matrix built from a random column
    (and, in the second half, from a random row) is applied densely, via
    NumPy FFTs and via ``torch_butterfly.special.circulant``; all results
    must agree. For complex data the FFT formulation is additionally checked
    with PyTorch.

    The PyTorch check uses the ``torch.fft`` module (the legacy
    ``torch.fft``/``torch.ifft`` functions were removed in PyTorch 1.8), and
    locals no longer shadow the builtins ``complex`` and ``input``.
    """
    batch_size = 10
    n = 13
    for is_complex in [False, True]:
        dtype = torch.float32 if not is_complex else torch.complex64
        col = torch.randn(n, dtype=dtype)
        C = la.circulant(col.numpy())
        signal = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(signal.detach().numpy() @ C.T)
        out_np = torch.tensor(np.fft.ifft(
            np.fft.fft(signal.numpy()) * np.fft.fft(col.numpy())),
                              dtype=dtype)
        self.assertTrue(
            torch.allclose(out_torch, out_np, self.rtol, self.atol))
        # Just to show how to implement circulant multiply with FFT
        if is_complex:
            signal_f = torch.fft.fft(signal, dim=-1)
            col_f = torch.fft.fft(col, dim=-1)
            prod_f = complex_mul(signal_f, col_f)
            out_fft = torch.fft.ifft(prod_f, dim=-1)
            self.assertTrue(
                torch.allclose(out_torch, out_fft, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(
                col, transposed=False, separate_diagonal=separate_diagonal)
            out = b(signal)
            self.assertTrue(
                torch.allclose(out, out_torch, self.rtol, self.atol))

        row = torch.randn(n, dtype=dtype)
        C = la.circulant(row.numpy()).T
        signal = torch.randn(batch_size, n, dtype=dtype)
        out_torch = torch.tensor(signal.detach().numpy() @ C.T)
        # row is the reverse of col, except the 0-th element stays put
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        row_f = np.fft.fft(row.numpy())
        row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
        out_np = torch.tensor(np.fft.ifft(
            np.fft.fft(signal.numpy()) * row_f_reversed),
                              dtype=dtype)
        self.assertTrue(
            torch.allclose(out_torch, out_np, self.rtol, self.atol))
        for separate_diagonal in [True, False]:
            b = torch_butterfly.special.circulant(
                row, transposed=True, separate_diagonal=separate_diagonal)
            out = b(signal)
            self.assertTrue(
                torch.allclose(out, out_torch, self.rtol, self.atol))