def interpolate_dft(a: torch.Tensor, interp_fs) -> torch.Tensor:
    """Apply Fourier-domain interpolation filter(s) to the coefficients ``a``.

    Args:
        a: Fourier coefficients to filter.
        interp_fs: Either a single tensor, combined with ``a`` through
            ``complex.mult``, or a tuple/list whose first two entries are
            applied one after the other (``interp_fs[0]`` first, then
            ``interp_fs[1]``). Entries beyond the second are ignored.

    Returns:
        The filtered coefficients.

    Raises:
        ValueError: if ``interp_fs`` is neither a tensor nor a tuple/list.
    """
    if not isinstance(interp_fs, (torch.Tensor, tuple, list)):
        raise ValueError('"interp_fs" must be tensor or tuple of tensors.')

    if isinstance(interp_fs, torch.Tensor):
        return complex.mult(a, interp_fs)

    # Separable case: apply the first filter, then the second.
    partial = complex.mult(a, interp_fs[0])
    return complex.mult(partial, interp_fs[1])
def get_interp_fourier(sz: torch.Tensor, method='ideal', bicubic_param=0.5, centering=True, windowing=False, device='cpu'):
    """Build the per-axis Fourier-domain interpolation filters for a grid of size ``sz``.

    Args:
        sz: Grid size. ``sz[0]`` pairs with the ``ky`` frequency coordinates
            and ``sz[1]`` with ``kx`` (both from ``fourier.get_frequency_coord``).
        method: ``'ideal'`` gives a constant response of ``1 / sz[i]``;
            ``'bicubic'`` evaluates ``cubic_spline_fourier`` at ``k / sz[i]``
            and scales the result by ``1 / sz[i]``.
        bicubic_param: Forwarded to ``cubic_spline_fourier``; only used by
            ``'bicubic'``.
        centering: If true, each filter is multiplied (``complex.mult``) by
            ``complex.exp_imag((-pi / sz[i]) * k)``. Presumably this is a
            half-sample phase shift of the kernel -- TODO confirm.
        windowing: Not implemented. A true value raises
            ``NotImplementedError`` (after the filters have been computed).
        device: Device the returned filters are moved to.

    Returns:
        Tuple ``(interp_y, interp_x)``.

    Raises:
        ValueError: if ``method`` is not ``'ideal'`` or ``'bicubic'``.
        NotImplementedError: if ``windowing`` is true.
    """
    ky, kx = fourier.get_frequency_coord(sz)
    # (axis index into sz, frequency coordinates) -- y axis first, then x.
    axes = ((0, ky), (1, kx))

    if method == 'ideal':
        filters = [torch.ones(k.shape) / sz[i] for i, k in axes]
    elif method == 'bicubic':
        filters = [cubic_spline_fourier(k / sz[i], bicubic_param) / sz[i] for i, k in axes]
    else:
        raise ValueError('Unknown method.')

    if centering:
        filters = [
            complex.mult(filt, complex.exp_imag((-math.pi / sz[i]) * k))
            for filt, (i, k) in zip(filters, axes)
        ]

    if windowing:
        raise NotImplementedError

    interp_y, interp_x = filters
    return interp_y.to(device), interp_x.to(device)
def A(self, hf: TensorList):
    """Apply the linear operator ``A`` to the filter coefficients ``hf``.

    Two terms are computed and summed:

    1. Data term. ``hf`` is combined with ``self.training_samples`` via
       ``complex.mtimes`` (presumably a complex matrix product) and weighted
       per sample by ``self.sample_weights``. The result is then multiplied
       by the conjugate of ``self.training_samples`` (``conj_b=True``).
    2. Regularization term. For each filter ``hfe`` and its matching
       ``self.reg_filter`` entry, ``hfe`` is convolved twice with
       ``reg_filter`` (``F.conv2d``). The result is added in place to the
       corresponding entry of the data-term output.

    NOTE(review): the name ``A`` and the ``X^H W X + R^T R`` structure suggest
    this is the left-hand-side operator of normal equations solved by
    conjugate gradient -- confirm against the enclosing class.

    Args:
        hf: Filter coefficients. Each entry appears to be a 5-D tensor whose
            last dimension holds the (real, imag) pair; the final reshape into
            ``(s0, s1, 2, s2, s3)`` relies on this -- TODO confirm layout.

    Returns:
        ``hf_out``: the data-term product, with the regularization terms
        added in place.
    """
    # Classify
    sh = complex.mtimes(self.training_samples, hf.permute(2,3,1,0,4)) # (h, w, num_samp, num_filt, 2)
    # view(1,1,-1,1) puts one weight on each entry of dim 2 (num_samp per the
    # comment above) so it presumably broadcasts over the other dims.
    sh = complex.mult(self.sample_weights.view(1,1,-1,1), sh)

    # Multiply with transpose
    # conj_b=True: presumably multiplies by the conjugate transpose of the samples.
    hf_out = complex.mtimes(sh.permute(0,1,3,2,4), self.training_samples, conj_b=True).permute(2,3,0,1,4)

    # Add regularization
    for hfe, hfe_out, reg_filter in zip(hf, hf_out, self.reg_filter):
        # Padding along dim -3 (height) and dim -2 (width), capped by the filter size.
        reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
        # Width cap is 2*W-2. The slice below can only take W-1 columns, so
        # the prepended width is min(reg_pad2, W-1); remove_size matches this.
        reg_pad2 = min(reg_filter.shape[-1] - 1, 2*hfe.shape[-2]- 2)

        # Add part needed for convolution
        # Prepend columns 1..reg_pad2 of dim -2, flipped over dims (2, 3) and
        # conjugated. Presumably this rebuilds the missing negative-frequency
        # half from conjugate symmetry of a half spectrum -- TODO confirm.
        if reg_pad2 > 0:
            hfe_conv = torch.cat([complex.conj(hfe[...,1:reg_pad2+1,:].flip((2,3))), hfe], -2)
        else:
            hfe_conv = hfe.clone()

        # Shift data to batch dimension
        # Move the real/imag axis (dim 4) to dim 2, then fold all but the two
        # spatial dims into the batch with 1 channel. The shape[-3]/shape[-2]
        # arguments are read from the tensor *before* the permute: (H, padded W).
        hfe_conv = hfe_conv.permute(0,1,4,2,3).reshape(-1, 1, hfe_conv.shape[-3], hfe_conv.shape[-2])

        # Do first convolution
        # Assumes reg_filter is shaped (1, 1, kh, kw) -- TODO confirm.
        hfe_conv = F.conv2d(hfe_conv, reg_filter, padding=(reg_pad1, reg_pad2))

        # Do second convolution
        # Drop the columns that came from the prepended mirror part before the
        # second (unpadded) convolution.
        remove_size = min(reg_pad2, hfe.shape[-2]-1)
        hfe_conv = F.conv2d(hfe_conv[...,remove_size:], reg_filter)

        # Reshape back and add
        # Undo the batch folding: (s0, s1, 2, s2, s3) -> (s0, s1, s2, s3, 2).
        # This requires the two convolutions together to restore hfe's spatial
        # size. `+=` mutates hfe_out in place, so the result is visible through
        # hf_out.
        hfe_out += hfe_conv.reshape(hfe.shape[0], hfe.shape[1], 2, hfe.shape[2], hfe.shape[3]).permute(0,1,3,4,2)

    return hf_out
def shift_fs(a: torch.Tensor, shift: torch.Tensor):
    """Shift a sample ``a`` in the Fourier domain.

    Params:
        a : The Fourier coefficients of the sample. Must be a 5-dimensional
            tensor.
        shift : The shift to be performed, normalized to the range [-pi, pi].
            ``shift[0]`` scales the ``ky`` coordinates and ``shift[1]`` the
            ``kx`` coordinates. It may be a tensor or any indexable pair of
            numbers (e.g. a tuple of floats).

    Returns:
        ``a`` itself (not a copy) when both shift components are zero.
        Otherwise, ``a`` multiplied (``complex.mult``) by
        ``complex.exp_imag(shift[0] * ky)`` and then by
        ``complex.exp_imag(shift[1] * kx)``.

    Raises:
        ValueError: if ``a`` is not 5-dimensional.
    """
    if a.dim() != 5:
        raise ValueError(
            'a must be the Fourier coefficients, a 5-dimensional tensor.')

    # float() accepts single-element tensors (like .item() did) as well as
    # plain Python numbers, so callers are not forced to wrap the shift in a
    # tensor.
    shift_y = float(shift[0])
    shift_x = float(shift[1])

    if shift_y == 0 and shift_x == 0:
        return a

    # The width is passed as 2 * W - 1. Presumably `a` holds only the
    # non-negative half of the spectrum along dim 3 -- TODO confirm.
    ky, kx = get_frequency_coord((a.shape[2], 2 * a.shape[3] - 1), device=a.device)

    return complex.mult(
        complex.mult(a, complex.exp_imag(shift_y * ky)),
        complex.exp_imag(shift_x * kx))
def __call__(self, x: TensorList):
    """
    Compute residuals.

    Residuals are built in this order:

    1. Data term: ``(training_samples * P) * hf - yf``, with both products
       computed via ``complex.mtimes``. If ``sample_weights_sqrt`` is set,
       the result is multiplied per sample by it.
    2. One spatial-regularization residual per filter: a single
       ``F.conv2d`` with the matching ``self.reg_filter`` entry.
    3. The projection-matrix regularization
       ``sqrt(self.params.projection_reg) * P``.

    :param x: [filters, projection_matrices] -- the first ``len(x) // 2``
        entries are taken as the filters ``hf``; the remaining entries are
        the projection matrices ``P``
    :return: [data_terms, filter_regularizations, proj_mat_regularizations]
    """
    hf = x[:len(x) // 2]
    P = x[len(x) // 2:]

    compressed_samples = complex.mtimes(self.training_samples, P)
    residuals = complex.mtimes(compressed_samples, hf.permute(
        2, 3, 1, 0, 4))  # (h, w, num_samp, num_filt, 2)
    residuals = residuals - self.yf

    if self.sample_weights_sqrt is not None:
        # view(1, 1, -1, 1) puts one weight on each entry of dim 2 (num_samp).
        # Using sqrt(weights) presumably makes the squared residual norm
        # weighted by the sample weights -- TODO confirm.
        residuals = complex.mult(
            self.sample_weights_sqrt.view(1, 1, -1, 1), residuals)

    # Add spatial regularization
    # NOTE(review): `residuals` must support .append/.extend (presumably a
    # TensorList returned by complex.mtimes) -- confirm.
    for hfe, reg_filter in zip(hf, self.reg_filter):
        # Padding for dim -3 (height) and dim -2 (width), capped by both the
        # filter size and the data size.
        reg_pad1 = min(reg_filter.shape[-2] - 1, hfe.shape[-3] - 1)
        reg_pad2 = min(reg_filter.shape[-1] - 1, hfe.shape[-2] - 1)

        # Add part needed for convolution
        # Prepend columns 1..reg_pad2 of dim -2, flipped over dims (2, 3) and
        # conjugated. Presumably this rebuilds the negative-frequency half of
        # a conjugate-symmetric spectrum -- TODO confirm.
        # .detach() blocks autograd through this mirrored copy, so gradients
        # reach hfe only through the un-mirrored part.
        # NOTE(review): confirm that excluding the mirror from autograd is
        # intended.
        if reg_pad2 > 0:
            hfe_left_padd = complex.conj(
                hfe[..., 1:reg_pad2 + 1, :].clone().detach().flip((2, 3)))
            hfe_conv = torch.cat([hfe_left_padd, hfe], -2)
        else:
            hfe_conv = hfe.clone()

        # Shift data to batch dimension
        # Move the real/imag axis (dim 4) to dim 2 and fold all but the two
        # spatial dims into the batch with 1 channel. The shape[-3]/shape[-2]
        # arguments are read from the tensor *before* the permute: (H, padded W).
        hfe_conv = hfe_conv.permute(0, 1, 4, 2, 3).reshape(-1, 1, hfe_conv.shape[-3], hfe_conv.shape[-2])

        # Do first convolution
        # Only one convolution here, and the result is appended in its
        # batch-folded layout (no reshape back).
        hfe_conv = F.conv2d(hfe_conv, reg_filter, padding=(reg_pad1, reg_pad2))

        residuals.append(hfe_conv)

    # Add regularization for projection matrix
    residuals.extend(math.sqrt(self.params.projection_reg) * P)

    return residuals
def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
    """Apply ``self.filter`` to the Fourier-domain sample ``sample_xf``.

    The filter and the sample are combined with ``complex.mult``. The product
    is summed over dimension 1, which is kept as a size-1 dimension
    (``keepdim=True``). Dimension 1 is presumably the feature-channel axis --
    TODO confirm.

    NOTE(review): the annotations say TensorList in / Tensor out. Which one
    actually holds depends on the type of ``self.filter`` -- confirm.
    """
    product = complex.mult(self.filter, sample_xf)
    return product.sum(1, keepdim=True)
def apply_filters(self, sample_xf: TensorList) -> TensorList:
    """Apply every filter in ``self.filters`` to ``sample_xf``.

    Each filter is combined with ``sample_xf`` via ``complex.mult``. Each
    product is then summed over dimension 1, which is kept as size 1
    (``keepdim=True``).

    Returns:
        A TensorList of per-filter responses, in the order of ``self.filters``.
        (The previous ``-> torch.Tensor`` annotation did not match this return
        value.)
    """
    responses = [complex.mult(f, sample_xf).sum(1, keepdim=True) for f in self.filters]
    return TensorList(responses)