def general_gaussian_blur_3D_periodic(u, fk, fi, fj):
    """Apply a separable periodic convolution to a 3-D volume in Fourier space.

    The volume is filtered one axis at a time. Each pass takes a full 1-D FFT
    along the last axis, multiplies it by the matching filter with
    ``complex_mult``, transforms back, and permutes the next axis to the end.
    Typically the filters are transfer functions of an anisotropic Gaussian blur.

    Args:
        u: real tensor whose axes are labelled (Z, X, Y).
        fk: Fourier filter applied along Z.
        fi: Fourier filter applied along X.
        fj: Fourier filter applied along Y.
            The filters are passed to ``complex_mult`` together with a spectrum
            in the legacy ``(..., 2)`` real/imag layout. ``complex_mult`` is
            defined elsewhere; the filter layout is assumed to match — confirm.

    Returns:
        The filtered real tensor, back in (Z, X, Y) order.
    """

    def _filter_last_axis(vol, filt):
        # Two-sided FFT along the last axis (the legacy rfft(..., onesided=False)).
        spec = torch.view_as_real(torch.fft.fft(vol, dim=-1))
        spec = complex_mult(spec, filt)
        # The legacy irfft(onesided=False) read only the first n//2+1 bins and
        # did a complex-to-real inverse. irfft with an explicit n does the same.
        return torch.fft.irfft(torch.view_as_complex(spec.contiguous()),
                               n=vol.shape[-1], dim=-1)

    U = _filter_last_axis(u, fj).permute(0, 2, 1)  # (Z, Y, X)
    U = _filter_last_axis(U, fi).permute(2, 1, 0)  # (X, Y, Z)
    U = _filter_last_axis(U, fk).permute(2, 0, 1)  # (Z, X, Y)
    return U
def Decomposition(input, N):
    """Split ``input`` into two spectral components that sum back to ``input``.

    A size-1 axis is inserted at position 3, and a 3-D real FFT is taken over
    the last three axes. This makes the transform one-sided only along that
    dummy axis, so axes 1 and 2 keep their full complex spectrum.

    Coefficients whose indices along both axis 1 and axis 2 lie in
    ``[N // 2, size - N // 2)`` are moved into the second output. Everything
    else stays in the first output.

    NOTE(review): no fftshift is applied, so that index window is the centre
    of the unshifted spectrum, i.e. the highest frequencies. The "High"/"Low"
    naming may be inverted — confirm intent.

    Args:
        input: real tensor; axes 1 and 2 are the axes being filtered
            (presumably (batch, H, W) — confirm against callers).
        N: width, in frequency bins, of the border kept in the first output.

    Returns:
        ``(High_freq, Low_freq)`` as real tensors shaped like ``input``.
    """
    expanded = input.unsqueeze(3)
    dims = (-3, -2, -1)
    high = torch.fft.rfftn(expanded, dim=dims)
    # Must start from zeros: the original allocated uninitialised memory, so
    # every coefficient outside the copied window was garbage.
    low = torch.zeros_like(high)
    width, height = high.shape[1], high.shape[2]
    lo = N // 2
    window = (slice(None), slice(lo, width - lo), slice(lo, height - lo))
    low[window] = high[window]
    high = high - low
    sizes = expanded.shape[-3:]
    return (torch.fft.irfftn(high, s=sizes, dim=dims).squeeze(3),
            torch.fft.irfftn(low, s=sizes, dim=dims).squeeze(3))
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
    """Forward pass of compact bilinear pooling via count sketches.

    Each input is count-sketched to length ``output_size``. The two sketches
    are then circularly convolved by multiplying their one-sided FFTs.

    Saves ``(h1, s1, h2, s2, x, y)`` plus the input sizes, the scatter flag
    and ``output_size`` on ``ctx`` for the backward pass.

    Returns:
        Real tensor whose last axis has length ``output_size``.
    """
    ctx.save_for_backward(h1, s1, h2, s2, x, y)
    ctx.x_size = tuple(x.size())
    ctx.y_size = tuple(y.size())
    ctx.force_cpu_scatter_add = force_cpu_scatter_add
    ctx.output_size = output_size

    # Count sketch of each input, followed by its one-sided FFT.
    px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
    fx = torch.fft.rfft(px, dim=-1)
    del px
    py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
    fy = torch.fft.rfft(py, dim=-1)
    del py

    # Convolution theorem: multiply the spectra, then invert to length output_size.
    re_prod, im_prod = ComplexMultiply_forward(fx.real, fx.imag, fy.real, fy.imag)
    return torch.fft.irfft(torch.complex(re_prod, im_prod), n=output_size, dim=-1)
def inner():
    """Decode the enclosing scope's Fourier-parameterised image.

    Reads ``scale``, ``spectrum_real_imag_t``, ``h``, ``w``, ``batch`` and
    ``channels`` from the enclosing scope. The spectrum may be either a
    complex tensor or the legacy ``(..., 2)`` real/imag layout.

    Returns:
        Real image tensor cropped to ``[:batch, :channels, :h, :w]`` and divided by 4.
    """
    scaled_spectrum_t = scale * spectrum_real_imag_t
    if not torch.is_complex(scaled_spectrum_t):
        scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t.contiguous())
    image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), dim=(-2, -1), norm="ortho")
    image = image[:batch, :channels, :h, :w]
    # Magic constant from the Lucid library; increasing it seems to reduce saturation.
    magic = 4.0
    return image / magic
def synth(self, y_orig, params):
    """Apply a per-example exponentially decaying reverb to ``y_orig``.

    The input is zero-padded by ``self.size`` samples. Its spectrum is then
    multiplied by that of the impulse
    ``sigmoid(wetdry) * identity + sigmoid(1 - wetdry) * impulse * decay_envelope``.

    Args:
        y_orig: (batch, time) signal.
        params: (batch, 2) tensor; column 0 is wet/dry, column 1 is the log decay rate.

    Returns:
        (batch, time + self.size) reverberated signal (not cropped).

    NOTE(review): the original referenced the undefined names ``y`` and
    ``identity`` and used the padded signal as the impulse. This version
    mirrors the sibling reverb ``forward`` (``self.identity`` /
    ``self.impulse``) — confirm this is the intended impulse. The impulse is
    still detached, so no gradient reaches ``params``.
    """
    # (batch, 1) so the parameters broadcast against (size,) buffers.
    wetdry = params[:, 0].unsqueeze(1)
    decay = params[:, 1].unsqueeze(1)
    y = nn.functional.pad(y_orig, (0, self.size), "constant", 0)
    Y_S = torch.fft.rfft(y)

    idx = torch.sigmoid(wetdry) * self.identity
    imp = torch.sigmoid(1 - wetdry) * self.impulse
    dcy = torch.exp(-(torch.exp(decay) + 2) * torch.linspace(0, 1, self.size).to(y.device))
    final_impulse = idx + imp * dcy

    # Pad the impulse response to the padded signal length.
    impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
    if y.shape[-1] > self.size:
        impulse = nn.functional.pad(impulse, (0, y.shape[-1] - impulse.shape[-1]), "constant", 0)
    IR_S = torch.fft.rfft(impulse.detach()).expand_as(Y_S)

    # Complex spectral product is the (circular) convolution with the impulse.
    return torch.fft.irfft(Y_S * IR_S, n=y.shape[-1])
def interaction_function(
    self,
    f: torch.FloatTensor,
    g: torch.FloatTensor,
) -> torch.FloatTensor:
    """Circular correlation of two embedding tensors along the last axis.

    ``out[..., k] = sum_i f[..., i] * g[..., (i + k) % d]``, computed as
    ``irfft(conj(rfft(f)) * rfft(g))``. The inputs must be broadcastable and
    share the same last dimension ``d``.

    The original conjugated the wrong axis and multiplied the (re, im) pairs
    elementwise instead of as complex numbers.

    :param f: shape: (..., d)
        First embedding (e.g. head).
    :param g: shape: (..., d)
        Second embedding (e.g. tail).
    :return: shape: broadcast(f, g) with last dimension d
        The circular correlation of ``f`` and ``g``.
    """
    a_fft = torch.fft.rfft(f, dim=-1)
    b_fft = torch.fft.rfft(g, dim=-1)
    # Complex conjugate of f's spectrum, then a true complex Hadamard product.
    p_fft = torch.conj(a_fft) * b_fft
    return torch.fft.irfft(p_fft, n=f.shape[-1], dim=-1)
def InverseLPFinFreq(input, sigma):
    """Undo a Gaussian low-pass by dividing the spectrum by a centred Gaussian.

    Each frequency (ky, kx), in integer cycles per image, has its magnitude
    multiplied by ``exp((ky**2 + kx**2) / (2 * sigma**2))``; the phase is left
    unchanged. DC therefore has gain 1.

    Fixes relative to the original:
    - The grid was hard-coded for 256x256 inputs via slices at 128.
    - The Gaussian was centred one bin away from DC.
    - The kernel was always placed on CUDA.

    Args:
        input: real tensor (..., H, W); any H, W and device.
        sigma: standard deviation of the Gaussian, in frequency bins.

    Returns:
        Real tensor with the same shape as ``input``.
    """
    h, w = input.shape[-2], input.shape[-1]
    real_dtype = input.dtype if input.is_floating_point() else torch.float32
    # Integer frequencies in unshifted FFT order (what fftshift/ifftshift of a
    # centred grid amounts to), so no explicit shifting is needed.
    ky = torch.fft.ifftshift(torch.arange(h, device=input.device) - h // 2).to(real_dtype)
    kx = torch.arange(w // 2 + 1, device=input.device).to(real_dtype)
    gain = torch.exp((ky[:, None] ** 2 + kx[None, :] ** 2) / (2.0 * sigma ** 2))
    # A real, positive gain scales the magnitude and leaves the phase unchanged.
    spectrum = torch.fft.rfft2(input) * gain
    return torch.fft.irfft2(spectrum, s=(h, w))
def _register_translation_gpu(self, frame, upsample_factor=20):
    """Estimate the sub-pixel translation of ``frame`` relative to the alignment frame.

    Works in two stages, like scikit-image's ``register_translation``:
    1. A whole-pixel estimate from the peak of the FFT cross-correlation.
    2. A refinement with an upsampled DFT (``self._upsampled_dft_cpu``) around that peak.

    Args:
        frame: real tensor with ``self.frame_dims`` spatial axes.
        upsample_factor: the sub-pixel precision is ``1 / upsample_factor``.

    Returns:
        numpy float64 array with one shift per spatial axis.

    NOTE(review): ``self.alignment_frame_fourier`` is multiplied without
    conjugation, as in the original; presumably it is stored pre-conjugated.
    Both the legacy ``(..., 2)`` layout and complex tensors are accepted.
    """
    dims = tuple(range(-self.frame_dims, 0))
    # Whole-pixel shift: compute the cross-correlation by an inverse FFT.
    img_fourier = th.fft.fftn(frame, dim=dims)
    reference = self.alignment_frame_fourier
    if not th.is_complex(reference):
        reference = th.view_as_complex(reference.contiguous())
    # Plain complex product, kept on the frame's device (the original
    # allocated this buffer on the CPU).
    image_product = img_fourier * reference
    cross_correlation = th.fft.irfftn(image_product, s=frame.shape[-self.frame_dims:], dim=dims)

    # Locate the maximum.
    maxima = self._to_cpu(th.argmax(cross_correlation))
    maxima = np.asarray(np.unravel_index(maxima, cross_correlation.size(), order='C'))
    shifts = np.array(maxima, dtype=np.float64)
    # Peaks past the midpoint wrap around the periodic grid: they are negative shifts.
    wrapped = shifts > self.frame_midpoints
    shifts[wrapped] -= np.array(cross_correlation.shape)[wrapped]

    shifts = np.round(shifts * upsample_factor) / upsample_factor
    upsampled_region_size = np.ceil(upsample_factor * 1.5)
    dftshift = np.fix(upsampled_region_size / 2.0)
    upsample_factor = np.array(upsample_factor, dtype=np.float64)
    # Number of spectral samples. The original also counted the re/im axis;
    # that constant factor of 2 does not affect the argmax below.
    normalization = image_product.numel() * upsample_factor ** 2
    sample_region_offset = dftshift - shifts * upsample_factor

    product_cpu = self._to_cpu(th.view_as_real(image_product))
    img_product_cpu = product_cpu[..., 0] + 1j * product_cpu[..., 1]
    cross_correlation = self._upsampled_dft_cpu(
        img_product_cpu.conj(), upsampled_region_size, upsample_factor,
        sample_region_offset).conj()
    cross_correlation /= normalization

    # Locate the maximum and map it back to the original pixel grid.
    maxima = np.array(np.unravel_index(np.argmax(np.abs(cross_correlation)),
                                       cross_correlation.shape, order='C'),
                      dtype=np.float64)
    maxima -= dftshift
    return shifts + maxima / upsample_factor
def compute_amplitude_gradients_for_X(model, X):
    """Gradient of each output filter's mean w.r.t. the Fourier amplitudes of ``X``.

    ``X`` is transformed with a numpy rFFT along axis 2. It is then rebuilt
    from amplitude and phase tensors that require grad, and fed to ``model``.
    For every output filter ``i``, ``mean(outs[:, i])`` is backpropagated to
    the amplitudes.

    Returns:
        float32 array of shape ``(n_filters,) + rfft(X, axis=2).shape``.

    Raises:
        AssertionError: if any gradient is NaN.
    """
    device = next(model.parameters()).device
    ffted = np.fft.rfft(X, axis=2)
    amps = np.abs(ffted)
    phases = np.angle(ffted)
    amps_th = to_tensor(amps.astype(np.float32), device=device).requires_grad_(True)
    phases_th = to_tensor(phases.astype(np.float32), device=device).requires_grad_(True)

    fft_coefs = torch.complex(amps_th * torch.cos(phases_th), amps_th * torch.sin(phases_th))
    # For 4-D X this drops its trailing axis, exactly as the legacy code's squeeze(3) did.
    if fft_coefs.dim() > 3:
        fft_coefs = fft_coefs.squeeze(3)
    iffted = torch.fft.irfft(fft_coefs, n=X.shape[2], dim=-1)

    outs = model(iffted)
    n_filters = outs.shape[1]
    amp_grads_per_filter = np.full((n_filters,) + ffted.shape, np.nan, dtype=np.float32)
    for i_filter in range(n_filters):
        mean_out = torch.mean(outs[:, i_filter])
        mean_out.backward(retain_graph=True)
        amp_grads_per_filter[i_filter] = to_numpy(amps_th.grad.clone())
        amps_th.grad.zero_()
    assert not np.any(np.isnan(amp_grads_per_filter))
    return amp_grads_per_filter
def forward(self, x):
    """3-D spectral convolution (Fourier Neural Operator layer).

    Takes the orthonormal rFFT of ``x`` over its last three axes. The four
    low-frequency corners (``n1`` x ``n2`` x ``n3`` modes) are multiplied by
    ``weights1..4`` with ``compl_mul3d``, and the result is transformed back.

    The output channel count is taken from the weights. The original sized
    the buffer with ``self.in_channels``, which broke whenever in != out.

    Returns:
        Real tensor (batch, out_channels, D, H, W).
    """
    n1, n2, n3 = self.n1, self.n2, self.n3
    batchsize = x.shape[0]
    sizes = (x.size(-3), x.size(-2), x.size(-1))
    # Legacy (..., 2) real/imag layout, as expected by compl_mul3d and the weights.
    x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=(-3, -2, -1), norm="ortho"))

    corner1 = compl_mul3d(x_ft[:, :, :n1, :n2, :n3], self.weights1)
    out_ft = x_ft.new_zeros(batchsize, corner1.shape[1], sizes[0], sizes[1], sizes[2] // 2 + 1, 2)
    out_ft[:, :, :n1, :n2, :n3] = corner1
    out_ft[:, :, -n1:, :n2, :n3] = compl_mul3d(x_ft[:, :, -n1:, :n2, :n3], self.weights2)
    out_ft[:, :, :n1, -n2:, :n3] = compl_mul3d(x_ft[:, :, :n1, -n2:, :n3], self.weights3)
    out_ft[:, :, -n1:, -n2:, :n3] = compl_mul3d(x_ft[:, :, -n1:, -n2:, :n3], self.weights4)

    return torch.fft.irfftn(torch.view_as_complex(out_ft), s=sizes, dim=(-3, -2, -1), norm="ortho")
def get_feature_weights(self, patch):
    """Cut ``patch`` into n x n tiles of k x k and filter each tile in the Fourier domain.

    Args:
        patch: (batch, c, n*k, n*k) tensor.

    Returns:
        (batch * n * n, c, k, k) tensor. Each tile is transformed with an
        orthonormal 2-D rFFT, its (re, im) pairs are multiplied by
        ``self.freq_coeff``, and it is transformed back.
        If ``self.tiled``, tiles are first passed through ``tile2d``.
    """
    n, k = self.n, self.k
    assert patch.ndimension() == 4
    batch_size, c = patch.shape[0], patch.shape[1]
    assert patch.shape[2] == n * k and patch.shape[2] == patch.shape[3]

    # (batch, c, n, k, n, k) -> (batch, n, n, c, k, k): one k x k tile per grid cell.
    w1 = patch.view(batch_size, c, n, k, n, k).permute(0, 2, 4, 1, 3, 5).contiguous()
    w1 = w1.view(batch_size, n * n, c, k, k)
    if self.tiled:
        w1 = tile2d(w1, 1)
    w1 = w1.view(batch_size * n * n, c, k, k)

    # freq_coeff acts on the legacy (..., 2) real/imag layout, so keep it.
    w1_f = torch.view_as_real(torch.fft.rfft2(w1, norm="ortho")) * self.freq_coeff
    return torch.fft.irfft2(torch.view_as_complex(w1_f.contiguous()), s=w1.shape[-2:], norm="ortho")
def root_filter(img, num_filters=2):
    """Progressively whiten an HWC image by root filtering its spectrum.

    The image is min-max normalised. Its 2-D spectrum is then divided by
    ``|I| ** (1 / num_filters)`` once per pass, so after pass ``i`` the
    magnitude is ``|I| ** (1 - (i + 1) / num_filters)``; the last pass is
    phase-only. Bins with zero magnitude (0/0 = NaN) are reset to 0.

    Args:
        img: numpy array (H, W, C).
        num_filters: number of passes; must be > 1.

    Returns:
        List of ``num_filters + 1`` HWC numpy arrays: the input image followed
        by the min-max normalised result of each pass.
    """
    assert num_filters > 1
    img_cu = torch.from_numpy(img.transpose([2, 0, 1])).float().to('cpu')
    img_cu = MinMaxNormalize(img_cu).unsqueeze(0)
    imgs = [img]
    h, w = img_cu.shape[-2], img_cu.shape[-1]

    spectrum = torch.view_as_real(torch.fft.fft2(img_cu))
    magnitude = (spectrum[..., 0] ** 2 + spectrum[..., 1] ** 2) ** 0.5
    divisor = (magnitude ** (1.0 / num_filters)).unsqueeze(-1)
    for _ in range(num_filters):
        spectrum = spectrum / divisor
        spectrum = spectrum.masked_fill(torch.isnan(spectrum), 0.0)
        I_hat = torch.fft.irfft2(torch.view_as_complex(spectrum), s=(h, w))
        I_hat_normalized = MinMaxNormalize(I_hat)
        imgs.append(I_hat_normalized.cpu().numpy()[0].transpose([1, 2, 0]))
    return imgs
def forward(self, z):
    """Apply the learnable exponentially decaying reverb to a signal.

    Args:
        z: pair ``(signal, conditions)``; ``conditions`` is ignored. The signal
            is (batch, time) — or any shape with time on the last axis.

    Returns:
        Reverberated signal, cropped back to the input length.

    The signal is zero-padded by ``self.size`` samples. It is then convolved
    (via the spectral product) with
    ``sigmoid(wetdry) * identity + sigmoid(1 - wetdry) * impulse * decay_envelope``.

    NOTE(review): the impulse is detached, so ``wetdry``/``decay``/``impulse``
    receive no gradient through this path — confirm this is intended.
    """
    z, _ = z
    # Pad the input sequence so the convolution tail is not wrapped onto the start.
    y = nn.functional.pad(z, (0, self.size), "constant", 0)
    Y_S = torch.fft.rfft(y)

    idx = torch.sigmoid(self.wetdry) * self.identity
    imp = torch.sigmoid(1 - self.wetdry) * self.impulse
    dcy = torch.exp(-(torch.exp(self.decay) + 2) * torch.linspace(0, 1, self.size).to(z.device))
    final_impulse = idx + imp * dcy

    # Pad the impulse response to the padded signal length.
    impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
    if y.shape[-1] > self.size:
        impulse = nn.functional.pad(impulse, (0, y.shape[-1] - impulse.shape[-1]), "constant", 0)
    IR_S = torch.fft.rfft(impulse.detach()).expand_as(Y_S)

    y = torch.fft.irfft(Y_S * IR_S, n=y.shape[-1])
    return y[..., :z.size(-1)]
def _exp_decay_reverb(self, y, wetdry, decay): # Compute STFT Y_S = torch.rfft(y, 1) # Compute the current impulse response # idx = torch.sigmoid(wetdry).unsqueeze(1).expand(-1, y.size(1)) # imp = torch.sigmoid(1 - wetdry).unsqueeze(1) * self.impulse.expand(len(wetdry), -1) # dcy = torch.exp(-(torch.exp(decay) + 2).unsqueeze(1) * torch.linspace(0,1, self.size).to(y.device)) idx = wetdry.unsqueeze(1).expand(-1, y.size(1)) imp = (1 - wetdry).unsqueeze(1) * self.impulse.expand(len(wetdry), -1) dcy = torch.exp(-(torch.exp(decay) + 2).unsqueeze(1) * torch.linspace(0, 1, self.size).to(y.device)) final_impulse = idx + imp * dcy # Pad the impulse response impulse = final_impulse IR_S = torch.rfft(impulse, 1).expand_as(Y_S) # IR_S = torch.rfft(impulse.detach(),1).expand_as(Y_S) # Apply the reverb Y_S_CONV = torch.zeros_like(IR_S) Y_S_CONV[:, :, 0] = Y_S[:, :, 0] * IR_S[:, :, 0] - Y_S[:, :, 1] * IR_S[:, :, 1] Y_S_CONV[:, :, 1] = Y_S[:, :, 0] * IR_S[:, :, 1] + Y_S[:, :, 1] * IR_S[:, :, 0] # Invert the reverberated signal y = torch.irfft(Y_S_CONV, 1, signal_sizes=(y.shape[-1], )) return y
def test(device='cuda'):
    """Cross-check four autocorrelation implementations on a small example.

    Compares ``xcorr2_torch`` (two-argument form) against:
    - ``xcorr2_torch_CPU``;
    - the single-argument ``xcorr2_torch``;
    - the inverse FFT of ``FourierMod2``.
    The last two put the 0-lag at index 0, so they are fftshifted first.

    Args:
        device: device for the test tensors (default ``'cuda'``, as before).

    Returns:
        ``(diff_12, diff_13, diff_14)``: L2 norms of the differences from the
        first implementation. The original computed these and discarded them.
    """
    a = np.zeros((2, 1, 5, 5))
    a[0, 0, :, :] = np.tile([1., 2., 3., 4., 5.], (5, 1))
    a_torch = torch.tensor(a, dtype=torch.float32, device=device)

    astara_1 = xcorr2_torch(a_torch, a_torch).to(device=device)
    astara_2 = xcorr2_torch_CPU(a_torch, a_torch).to(device=device)
    astara_3 = xcorr2_torch(a_torch).to(device=device, dtype=torch.float32)
    astara_3 = torch.fft.fftshift(astara_3, dim=(2, 3))

    n_batch, n_c, ha, wa = a.shape
    out_size = (2 * ha - 1, 2 * wa - 1)
    abs_fa2 = torch.zeros((n_batch, n_c) + out_size)
    abs_fa2[...] = FourierMod2(a_torch)
    # |F(a)|^2 is real; its inverse DFT is the autocorrelation with the 0-lag at index 0.
    astara_4 = torch.fft.irfft2(torch.complex(abs_fa2, torch.zeros_like(abs_fa2)), s=out_size)
    astara_4 = torch.fft.fftshift(astara_4, dim=(2, 3)).to(device=device, dtype=torch.float32)

    diff_12 = torch.norm(astara_1 - astara_2, 2)
    diff_13 = torch.norm(astara_1 - astara_3, 2)
    diff_14 = torch.norm(astara_1 - astara_4, 2)
    return diff_12, diff_13, diff_14
def forward(self, x):
    """Learnable, per-frequency reshaping of the 3-D magnitude spectrum of ``x``.

    The new magnitude is ``W2 * activation(W1 * m + B1) + B2``, where
    ``m = |rfftn(x)|`` (or ``log(1 + m)`` when ``self.log`` is set). The
    weights are ReLU-clamped and repeated over the batch. The phase is kept;
    only the magnitude is rescaled, and the result is inverse transformed
    (orthonormal).

    Args:
        x: real tensor (batch, D, H, W).

    Returns:
        Real tensor shaped like ``x``.
    """
    batch = x.shape[0]
    w1 = torch.relu(self.W1.repeat(batch, 1, 1, 1).to(x.device))
    w2 = torch.relu(self.W2.repeat(batch, 1, 1, 1).to(x.device))
    b1 = torch.relu(self.B1.repeat(batch, 1, 1, 1).to(x.device))
    b2 = torch.relu(self.B2.repeat(batch, 1, 1, 1).to(x.device))

    rft_x = torch.fft.rfftn(x, dim=(-3, -2, -1), norm="ortho")
    init_spectrum = rft_x.abs()
    if self.log:
        spectrum = w2 * self.activation(w1 * torch.log(1 + init_spectrum) + b1) + b2
    else:
        spectrum = w2 * self.activation(w1 * init_spectrum + b1) + b2
    # Real gain per bin: rescales the magnitude, keeps the phase. 1e-16 guards empty bins.
    gain = spectrum / (init_spectrum + 1e-16)
    return torch.fft.irfftn(rft_x * gain, s=x.shape[-3:], dim=(-3, -2, -1), norm="ortho")
def mel2mcc(mel):
    """Convert a (batch, D, T) mel log-spectrum to D mel-cepstral coefficients per frame.

    Each frame is treated as a real, zero-phase half spectrum of length D and
    inverted with a real FFT of length ``2 * (D - 1)``. The first D
    coefficients are kept, and c0 is halved.

    Returns:
        (batch, D, T) tensor.
    """
    mel_ndim = mel.shape[1]
    frames = mel.transpose(1, 2)
    spectrum = torch.complex(frames, torch.zeros_like(frames))
    mcc = torch.fft.irfft(spectrum, n=2 * (mel_ndim - 1), dim=-1).transpose(1, 2)[:, :mel_ndim]
    mcc[:, 0] /= 2.
    return mcc
def forward(self, x):
    """3-D spectral convolution (Fourier Neural Operator layer).

    Computes the orthonormal rFFT of ``x`` over its last three axes. The
    four low-frequency corners (``modes1`` x ``modes2`` x ``modes3``) are
    multiplied by ``weights1..4`` with ``compl_mul3d``; all other modes are
    zeroed. The result is transformed back to physical space.

    Returns:
        Real tensor (batch, out_channels, D, H, W).
    """
    batchsize = x.shape[0]
    m1, m2, m3 = self.modes1, self.modes2, self.modes3
    sizes = (x.size(-3), x.size(-2), x.size(-1))
    # Legacy (..., 2) real/imag layout, as expected by compl_mul3d and the weights.
    x_ft = torch.view_as_real(torch.fft.rfftn(x, dim=(-3, -2, -1), norm="ortho"))

    out_ft = x_ft.new_zeros(batchsize, self.out_channels, sizes[0], sizes[1], sizes[2] // 2 + 1, 2)
    out_ft[:, :, :m1, :m2, :m3] = compl_mul3d(x_ft[:, :, :m1, :m2, :m3], self.weights1)
    out_ft[:, :, -m1:, :m2, :m3] = compl_mul3d(x_ft[:, :, -m1:, :m2, :m3], self.weights2)
    out_ft[:, :, :m1, -m2:, :m3] = compl_mul3d(x_ft[:, :, :m1, -m2:, :m3], self.weights3)
    out_ft[:, :, -m1:, -m2:, :m3] = compl_mul3d(x_ft[:, :, -m1:, -m2:, :m3], self.weights4)

    return torch.fft.irfftn(torch.view_as_complex(out_ft), s=sizes, dim=(-3, -2, -1), norm="ortho")
def ifft_image(self, input):
    """Decode a scaled half spectrum into a (self.h, self.w) image, divided by 4.

    Args:
        input: half spectrum, either complex or in the legacy ``(..., 2)``
            real/imag layout. It is multiplied by ``self.scale`` before the
            orthonormal inverse 2-D real FFT.
    """
    scaled = input * self.scale
    if not torch.is_complex(scaled):
        scaled = torch.view_as_complex(scaled.contiguous())
    image = torch.fft.irfftn(scaled, s=(self.h, self.w), dim=(-2, -1), norm="ortho")
    return image / 4
def diagImpConvFFT(x, B, h, mode=None):
    """Apply ``(I + h B^T B)^{-1}`` to ``x`` with periodic boundaries, via the FFT.

    The odd-sized kernel(s) ``B`` (m0, 1, kh, kw) are embedded in an
    (H, W) grid with their centre moved to (0, 0) by a circular shift. In
    Fourier space the operator is the pointwise factor
    ``1 / (h * |B_hat|^2 + 1)``. With ``mode == 'laplacian'`` the factor is
    ``1 / (h * |Re B_hat| + 1)`` instead.

    Args:
        x: real tensor (batch, C, H, W).
        B: kernels (m0, 1, kh, kw) with odd kh, kw; m0 must broadcast against C.
        h: step size.
        mode: ``'laplacian'`` or None.

    Returns:
        Real tensor shaped like ``x``.
    """
    n = x.shape
    m = B.shape
    mid1 = (m[2] - 1) // 2
    mid2 = (m[3] - 1) // 2
    # Start of the wrapped rows/cols. Using n - mid (not -mid) keeps the slices
    # empty when mid == 0, so 1-wide kernels work.
    r1, c1 = n[2] - mid1, n[3] - mid2
    Bp = torch.zeros(m[0], n[2], n[3], dtype=B.dtype, device=B.device)
    Bp[:, :mid1 + 1, :mid2 + 1] = B[:, 0, mid1:, mid2:]
    Bp[:, r1:, :mid2 + 1] = B[:, 0, :mid1, -(mid2 + 1):]
    Bp[:, :mid1 + 1, c1:] = B[:, 0, -(mid1 + 1):, :mid2]
    Bp[:, r1:, c1:] = B[:, 0, :mid1, :mid2]

    xh = torch.fft.fft2(x)
    Bh = torch.fft.fft2(Bp)
    if mode == 'laplacian':
        # A Laplacian symbol is already real and semi-definite.
        t = 1.0 / (h * torch.abs(Bh.real) + 1.0)
    else:
        t = 1.0 / (h * (Bh.real ** 2 + Bh.imag ** 2) + 1.0)
    # t is (m0, H, W) and broadcasts over the batch axis of xh.
    return torch.fft.irfft2(xh * t, s=(n[2], n[3]))
def FFTConv(imgs, filt, plot=False):
    """'Same'-size linear convolution of a batch of images with one filter, via the FFT.

    Args:
        imgs: real tensor (B, C, H, W).
        filt: real tensor (1, 1, h, w) with odd h and w; its centre is at
            (h // 2, w // 2).
        plot: if true, save log-magnitude plots of the spectra with ``save_fig_double``.

    Returns:
        ``(filtered, imgsFourier, filtFourier, filt)``:
        - ``filtered``: (B, C, H, W), the zero-padded linear convolution
          cropped to the input size;
        - ``imgsFourier`` / ``filtFourier``: full spectra of the padded inputs
          in the ``(..., 2)`` real/imag layout;
        - ``filt``: the padded, centre-shifted filter.

    Raises:
        TypeError: if the filter has an even dimension.
    """
    filtSize = np.array(filt.size()[2:])
    if np.any(filtSize % 2 == 0):
        raise TypeError("filter size {} should be odd".format(filtSize))
    imSize = np.array(imgs.size()[2:])

    # Zero-pad to the full linear-convolution size. F.pad takes
    # (last-dim left, last-dim right, second-last left, second-last right);
    # the original passed the H padding for W and vice versa.
    fftSize = imSize + filtSize - 1
    imgs = F.pad(imgs, (0, int(fftSize[1] - imSize[1]), 0, int(fftSize[0] - imSize[0])))
    filt = F.pad(filt, (0, int(fftSize[1] - filtSize[1]), 0, int(fftSize[0] - filtSize[0])))
    # Circularly shift the filter centre to the upper-left corner (index 0, 0).
    filt = torch.roll(filt, shifts=(-int(filtSize[0] // 2), -int(filtSize[1] // 2)), dims=(2, 3))

    imgs_ft = torch.fft.fft2(imgs)
    filt_ft = torch.fft.fft2(filt)
    imgsFourier = torch.view_as_real(imgs_ft)
    filtFourier = torch.view_as_real(filt_ft)

    if plot:
        imgR, imgIm = torch.unbind(imgsFourier, -1)
        filtR, filtIm = torch.unbind(filtFourier, -1)
        save_fig_double(filtR.data.cpu(), filtIm.data.cpu(),
                        './', 'CurrentCTF-Fourier', iteration=None, Title1='Real', Title2='Imag')
        save_fig_double((imgR + 1e-8).abs().log().data.cpu(), (imgIm + 1e-8).abs().log().data.cpu(),
                        './', 'CurrentProj-Fourier', iteration=None, Title1='Real', Title2='Imag')

    imgFiltered = torch.fft.irfft2(imgs_ft * filt_ft, s=(int(fftSize[0]), int(fftSize[1])))
    return imgFiltered[:, :, :imSize[0], :imSize[1]], imgsFourier, filtFourier, filt
def mse_bp(hx, y, scale):
    """Back-projection-weighted spectral MSE between ``hx`` and ``y``.

    The residual's orthonormal 2-D spectrum is weighted per frequency by
    ``1 / sqrt(mul_factor * |H_down| + eps)``. ``H_down`` is the spectrum of
    the bicubic kernel autocorrelation, subsampled by ``scale``. The result
    is the mean of the squared weighted spectrum.

    NOTE(review): ``self`` is not a parameter; this relies on an enclosing
    scope providing it (for ``self.device``). The helpers ``get_bicubic``,
    ``fft2``, ``flip``, ``mul_complex``, ``pad_shift_filter`` and ``abs2``
    are defined elsewhere and use the ``(..., 2)`` real/imag layout.
    """
    # First part: least-squares term in the (orthonormal) Fourier domain.
    dip_loss = torch.view_as_real(torch.fft.fft2(hx - y, norm="ortho")).to(self.device)

    # Second part: back-projection weights.
    eps = 1e-3
    mul_factor = 1e5
    eps_ignored = 0.01
    sigma = 0
    h = torch.from_numpy(get_bicubic(scale)).unsqueeze(0).unsqueeze(0)
    conv_shape = (h.shape[2] + h.shape[2] - 1, h.shape[3] + h.shape[3] - 1)
    H = fft2(h, conv_shape[1], conv_shape[0])
    H_flip = fft2(flip(h), conv_shape[1], conv_shape[0])
    H_mul_H_flip = mul_complex(H, H_flip)
    H_mul_H_flip_ifft = torch.fft.irfft2(torch.view_as_complex(H_mul_H_flip.contiguous()),
                                         s=tuple(H_mul_H_flip.shape[-3:-1]), norm="ortho")
    h_downsampled = H_mul_H_flip_ifft[:, :, 1::scale, 1::scale]
    h_downsampled = pad_shift_filter(y, h_downsampled)
    H_downsampled = torch.view_as_real(torch.fft.fft2(h_downsampled, norm="ortho"))

    bp_loss = torch.sqrt(abs2(H_downsampled)[:, :, :, :, 0:1])
    bp_loss = mul_factor * bp_loss + eps_ignored * (sigma ** 2) + eps
    bp_loss = 1 / torch.sqrt(bp_loss)
    # Same weight for the real and imaginary parts.
    bp_loss = torch.repeat_interleave(bp_loss, 2, -1).to(self.device)
    loss_mat = bp_loss * dip_loss
    return torch.mean(loss_mat ** 2)
def forward(self, z, x):
    """Siamese cross-correlation, plus a DCF response in training mode.

    In training, the template ``z`` is cropped to [56:183] on both spatial
    axes before feature extraction for the cross-correlation kernel. In
    evaluation, the template's own features are used as the kernel.

    NOTE(review): the original left ``z2`` undefined when ``is_train`` is
    False (NameError). Using the full template features is an assumption —
    confirm against the evaluation pipeline.

    Returns:
        ``out`` — the adjusted response map (n, 1, h', w') — in evaluation.
        ``(out, response)`` in training, where ``response`` is the
        correlation-filter response with the spatial size of ``x``'s
        first-stage features.
    """
    if self.is_train:
        z_tmp = z[:, :, 56:183, 56:183]
        _, z2 = self.features(z_tmp)
    z1, z = self.features(z)
    x1, x = self.features(x)
    if not self.is_train:
        z2 = z

    # Fast cross-correlation: one grouped convolution over the whole batch.
    n, c, h, w = x.size()
    x = x.view(1, n * c, h, w)
    out = F.conv2d(x, z2, groups=n)
    out = out.view(n, 1, out.size(-2), out.size(-1))
    out = self.adjust(out)
    if not self.is_train:
        return out

    # Correlation filter in the legacy (..., 2) layout used by complex_mul(conj) and self.yf.
    zf = torch.view_as_real(torch.fft.rfft2(z1))
    xf = torch.view_as_real(torch.fft.rfft2(x1))
    kzzf = torch.sum(torch.sum(zf ** 2, dim=4, keepdim=True), dim=1, keepdim=True)
    kxzf = torch.sum(complex_mulconj(xf, zf), dim=1, keepdim=True)
    alphaf = self.yf / (kzzf + self.lambda0)
    response = torch.fft.irfft2(torch.view_as_complex(complex_mul(kxzf, alphaf).contiguous()),
                                s=x1.shape[-2:])
    return out, response
def idst(x, expkp1=None):
    """Batch inverse Discrete Sine Transform along the last axis (no coefficient normalisation).

    Implements the 2N-padding trick for solving inverse DCT/DST with an
    inverse FFT; see
    https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/spectral_ops.py
    The steps are:
    1. Multiply by the complex phase factors ``expkp1`` (N, 2).
    2. Zero-pad to length 2N.
    3. Apply a complex-to-real inverse FFT of length 2N.
    4. Keep samples 1..N and scale by N.

    NOTE(review): the previous docstring quoted the IDCT (cosine) formula.
    The exact transform depends on ``expkp1``, which comes from
    ``get_expkp1`` (defined elsewhere) — confirm the intended formula there.

    Args:
        x: real tensor (..., N).
        expkp1: optional precomputed phase factors (N, 2) in real/imag layout.

    Returns:
        Real tensor with the same shape as ``x``.
    """
    N = x.size(-1)
    if expkp1 is None:
        expkp1 = get_expkp1(N, dtype=x.dtype, device=x.device)
    x_pad = x.unsqueeze(-1).mul(expkp1)
    # Pad the frequency axis (second to last), not the real/imag axis.
    x_pad = F.pad(x_pad, (0, 0, 0, N), 'constant', 0)
    # irfft with n=2N reads the first N+1 bins, as the legacy onesided=False call did.
    y = torch.fft.irfft(torch.view_as_complex(x_pad.contiguous()), n=2 * N, dim=-1)[..., 1:N + 1]
    y.mul_(N)
    return y
def forward(self, z):
    """
    :param z: the multiscale searching patch. Shape (num_scale, 3, crop_sz, crop_sz)
    :return response: the response of cross correlation. Shape (num_scale, 1, crop_sz, crop_sz)

    Cross-correlates the windowed features of z with the learned filter
    self.wf in the Fourier domain: conj(wf) * zf, summed over channels.
    self.wf may be complex or in the legacy (..., 2) real/imag layout.
    The response keeps the spatial size of the features; the old implicit
    inverse size added a column for even sizes.
    """
    # Obtain the features of z and apply the cosine (Hanning) window.
    z = self.feature(z) * self.config.cos_window
    zf = torch.fft.rfft2(z)
    wf = self.wf if torch.is_complex(self.wf) else torch.view_as_complex(self.wf.contiguous())
    # Conjugate multiplication (wr - wi*i)(zr + zi*i) = (wr*zr + wi*zi) + (wr*zi - wi*zr)i
    phi_w = torch.conj(wf) * zf
    return torch.fft.irfft2(phi_w.sum(dim=1, keepdim=True), s=z.shape[-2:])
def forward(self, x):
    '''
    Fourier unit: convolve in the frequency domain and add a scaled residual.

    :param x: (b, c, h, w)
    :return: (b, c, h, w), equal to alpha * irfft2(relu(bn(conv(rfft2(x))))) + x,
        where the real and imaginary parts are stacked as 2c channels for the conv.
    '''
    batch, c, h, w = x.size()
    ffted = torch.view_as_real(torch.fft.rfft2(x))      # (b, c, h, w//2+1, 2)
    ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()   # (b, c, 2, h, w//2+1)
    ffted = ffted.view(batch, -1, *ffted.shape[3:])     # (b, 2c, h, w//2+1)

    ffted = self.conv_layer(ffted)                      # (b, 2c, h, w//2+1)
    ffted = self.relu(self.bn(ffted))

    ffted = ffted.view(batch, c, 2, *ffted.shape[2:]).permute(0, 1, 3, 4, 2).contiguous()  # (b, c, h, w//2+1, 2)
    output = torch.fft.irfft2(torch.view_as_complex(ffted), s=(h, w))
    return self.alpha * output + x
def idct_1d(X):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III.

    Our definition of idct is that idct(dct(x)) == x, where dct is the
    unnormalised DCT-II (same scaling as scipy.fft.dct(x, type=2)).

    :param X: the input signal
    :return: the inverse DCT-II of the signal over the last dimension
    """
    x_shape = X.shape
    N = x_shape[-1]
    X_v = X.contiguous().view(-1, N) / 2

    k = torch.arange(N, dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    # irfft with n=N reads the first N//2+1 bins, as the legacy onesided=False call did.
    v = torch.fft.irfft(torch.complex(V_r, V_i), n=N, dim=1)
    x = v.new_zeros(v.shape)
    # Undo the even/odd reordering used by the forward DCT.
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]
    return x.view(*x_shape)
def forward(self, z):
    """
    :param z: the multiscale searching patch. Shape (num_scale, 3, crop_sz, crop_sz)
    :return response: the response of cross correlation. Shape (num_scale, 1, crop_sz, crop_sz)

    Cross-correlates the windowed features of z with the first filter in
    self.wf: sum over channels of conj(wf[0, l]) * zf[:, l].

    The original per-scale/per-channel loop always used filter channel 1
    (w_star[0, 1]) instead of channel l, and hard-coded a CUDA buffer.
    NOTE(review): self.imag_mult is assumed to be ordinary complex
    multiplication, as in the sibling DCF implementation.
    """
    # Obtain the features of z and apply the cosine (Hanning) window.
    z = self.feature(z) * self.config.cos_window
    zf = torch.fft.rfft2(z)
    wf = self.wf[:1]
    if not torch.is_complex(wf):
        wf = torch.view_as_complex(wf.contiguous())
    response = torch.fft.irfft2((torch.conj(wf) * zf).sum(dim=1, keepdim=True), s=z.shape[-2:])
    return response
def FDA_source_to_target(src_img, trg_img, L=0.1):
    """Fourier Domain Adaptation: give ``src_img`` the low-frequency amplitude of ``trg_img``.

    The phase of the source spectrum is kept. Its amplitude is passed
    through ``low_freq_mutate`` together with the target amplitude, with
    window parameter ``L``. The image is then recomposed with an inverse FFT.

    Args:
        src_img, trg_img: real tensors (B, C, H, W).
        L: size parameter forwarded to ``low_freq_mutate``.

    Returns:
        Real tensor (B, C, H, W) on ``src_img``'s device and dtype. The
        original always produced a CPU float32 tensor.
    """
    # Full spectra in the (..., 2) real/imag layout expected by extract_ampl_phase.
    fft_src = torch.view_as_real(torch.fft.fft2(src_img))
    fft_trg = torch.view_as_real(torch.fft.fft2(trg_img))

    amp_src, pha_src = extract_ampl_phase(fft_src.clone())
    amp_trg, _ = extract_ampl_phase(fft_trg.clone())

    # Replace the low-frequency amplitude of the source with that of the target.
    amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L)

    # Recompose the source spectrum: source phase, mutated amplitude.
    fft_src_ = torch.complex(torch.cos(pha_src) * amp_src_, torch.sin(pha_src) * amp_src_)
    _, _, imgH, imgW = src_img.size()
    return torch.fft.irfft2(fft_src_, s=(imgH, imgW))
def fft_frequency_decompose(x, min_size):
    """Split ``x`` into octave frequency bands, each resynthesised at its own length.

    The orthonormal rFFT of ``x`` is taken along the last axis. For each size
    ``s = min_size, 2*min_size, ... <= x.shape[-1]``, the first ``s // 2 + 1``
    bins are kept. For every size above ``min_size``, bins below ``s // 4``
    are zeroed, leaving one octave band. Each band is inverted to length
    ``s``, also with orthonormal scaling.

    Args:
        x: real tensor (..., T).
        min_size: length of the lowest band; must be >= 1 (the original looped
            forever for 0).

    Returns:
        dict mapping each band length ``s`` to a real tensor (..., s).

    Raises:
        ValueError: if ``min_size < 1``.
    """
    if min_size < 1:
        raise ValueError("min_size must be >= 1, got {}".format(min_size))
    coeffs = torch.fft.rfft(x, dim=-1, norm="ortho")
    output = {}
    current_size = min_size
    while current_size <= x.shape[-1]:
        n_bins = current_size // 2 + 1
        sl = coeffs[..., :n_bins]
        if current_size > min_size:
            # Keep only the top octave [current_size // 4, current_size // 2].
            mask = torch.zeros(n_bins, device=x.device, dtype=x.dtype)
            mask[current_size // 4:] = 1
            sl = sl * mask
        recon = torch.fft.irfft(sl, n=current_size, dim=-1, norm="ortho")
        output[recon.shape[-1]] = recon
        current_size *= 2
    return output