def fn(a):
    """Reinterpret a real ``(..., 2)`` tensor as complex and flatten it.

    ``a`` is cloned before the complex view is taken, so the returned 1-D
    complex tensor never shares storage with ``a``.
    """
    as_complex = torch.view_as_complex(a.clone())
    return as_complex.flatten()
def table_interp_adjoint(
    data: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    grid_size: Tensor,
) -> Tensor:
    """Table interpolation adjoint backend.

    This interpolates from an off-grid set of data at coordinates given by
    ``omega`` to on-grid locations.

    Args:
        data: Off-grid data to interpolate from. Its first two dimensions
            become the first two dimensions of the output.
        omega: Fourier coordinates to interpolate to (in radians/voxel,
            -pi to pi).
        tables: List of tables for each image dimension.
        n_shift: Size of desired fftshift.
        numpoints: Number of neighbors in each dimension.
        table_oversamp: Size of table in each dimension.
        offsets: A list of offset values for interpolation.
        grid_size: Size of the output grid; the result has shape
            ``(data.shape[0], data.shape[1], *grid_size)``.

    Returns:
        ``data`` interpolated to gridded locations.
    """
    dtype = data.dtype
    device = data.device
    int_type = torch.long
    on_cpu = device == torch.device("cpu")

    # we fork processes for accumulation, so we need to do a bit of thread
    # management for OMP to make sure we don't oversubscribe (management not
    # necessary for non-OMP). Candidates are the divisors of num_threads that
    # are below sqrt(num_threads) (arange excludes the end), largest first.
    num_threads = torch.get_num_threads()
    factors = torch.arange(1, math.sqrt(num_threads)).flip(0)
    factors = factors[torch.remainder(torch.tensor(num_threads), factors) == 0]
    threads_per_fork = 1
    for factor in factors:
        if factor <= num_threads / (data.shape[0] * data.shape[1]):
            threads_per_fork = int(factor)
            break
    num_forks = num_threads // threads_per_fork

    # calculate output size
    output_prod = int(torch.prod(grid_size))
    output_size = [data.shape[0], data.shape[1]]
    output_size.extend(int(el) for el in grid_size)

    # convert to normalized freq locs and sort
    tm = omega / (2 * np.pi / grid_size.to(omega).unsqueeze(-1))
    tm, omega, data = sort_data(tm, omega, data, grid_size)

    # compute interpolation centers
    centers = torch.floor(numpoints * table_oversamp / 2).to(dtype=int_type)

    # offset from k-space to first coef loc
    base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(
        dtype=int_type
    )

    # initialized flattened image
    image = torch.zeros(
        size=(data.shape[0], data.shape[1], output_prod),
        dtype=dtype,
        device=device,
    )

    # phase for fftshift
    data = (
        data
        * imag_exp(
            torch.mv(torch.transpose(omega, 1, 0), n_shift),
            return_complex=True,
        ).conj()
    )

    # necessary for index_add_
    # TODO: change when PyTorch supports complex numbers for index_add_, index_put_
    if not on_cpu:
        image = torch.view_as_real(image)

    # loop over offsets and take advantage of broadcasting
    for offset in offsets:
        coef, arr_ind = calc_coef_and_indices(
            tm=tm,
            base_offset=base_offset,
            offset_increments=offset,
            tables=tables,
            centers=centers,
            table_oversamp=table_oversamp,
            grid_size=grid_size,
            conjcoef=True,
        )

        # we have to fork this multiply ourselves
        tmp = coef * data
        if not on_cpu:
            tmp = torch.view_as_real(tmp)

        if USING_OMP and on_cpu:
            # Lower the per-fork thread count for the accumulation, and always
            # restore the global setting even if the accumulation raises.
            torch.set_num_threads(threads_per_fork)
            try:
                # this is a much faster way of doing index accumulation
                fork_and_accum(image, arr_ind, tmp, num_forks)
            finally:
                torch.set_num_threads(num_threads)
        else:
            fork_and_accum(image, arr_ind, tmp, num_forks)

    if not on_cpu:
        image = torch.view_as_complex(image)

    return image.view(output_size)
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Builds a spatial regularization window
    ``(reg_window_edge - reg_window_min) * (|r / scale_r| ** p + |c / scale_c| ** p)
    + reg_window_min`` over a 2-D grid. It is equal to ``reg_window_min`` at
    grid coordinate (0, 0). The window is then transformed to the Fourier
    domain, small coefficients are zeroed, and the result is cropped to the
    support of the remaining non-zero coefficients.

    Args:
        sz: Filter size; ``sz[0]`` / ``sz[1]`` are the number of grid rows /
            columns, and ``sz.long().tolist()`` is the inverse-FFT signal size.
        target_sz: Target size; ``0.5 * target_sz`` is the per-axis scale of
            the window (indexed as ``[0]`` for rows and ``[1]`` for columns).
        params: Parameter object. Reads ``use_reg_window``,
            ``reg_window_min``, ``reg_window_edge``, ``reg_window_power`` and
            ``reg_sparsity_threshold``, plus the optional flags
            ``reg_window_square`` (default False) and ``reg_window_centered``
            (default True).

    Returns:
        ``params.reg_window_min * torch.ones(1, 1, 1, 1)`` when
        ``params.use_reg_window`` is false. Otherwise, the cropped real part of
        the shifted, sparsified window DFT (4-D). When it is wider than one
        column, the mirrored columns are prepended — presumably to rebuild the
        half of the spectrum dropped by ``rfftn``; confirm against
        ``fourier.rfftshift2``.
    """
    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)
    if getattr(params, 'reg_window_square', False):
        # Use the same (geometric-mean) side length for both axes.
        target_sz = target_sz.prod().sqrt() * torch.ones(2)
    # Normalization factor
    reg_scale = 0.5 * target_sz
    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        # Coordinates run from -floor((n-1)/2) to floor(n/2), so 0 is in the middle.
        wrg = torch.arange(-int((sz[0] - 1) / 2), int(sz[0] / 2 + 1), dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1] - 1) / 2), int(sz[1] / 2 + 1), dtype=torch.float32).view(1, 1, 1, -1)
    else:
        # Same coordinates in FFT order: 0..floor(n/2) followed by the negatives.
        wrg = torch.cat([
            torch.arange(0, int(sz[0] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, -1, 1)
        wcg = torch.cat([
            torch.arange(0, int(sz[1] / 2 + 1), dtype=torch.float32),
            torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)
        ]).view(1, 1, 1, -1)
    # Construct regularization window (wrg/wcg broadcast to shape (1, 1, rows, cols))
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
        (torch.abs(wrg / reg_scale[0]) ** params.reg_window_power + torch.abs(wcg / reg_scale[1]) ** params.reg_window_power) + params.reg_window_min
    # Compute DFT and enforce sparsity
    # Real view: trailing dim of size 2 holds (real, imag); scaled by 1/sz.prod().
    reg_window_dft = torch.view_as_real(
        torch_fft.rfftn(reg_window, dim=[-2, -1])) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    # Zero (both real and imag parts) every coefficient whose magnitude is below
    # reg_sparsity_threshold times the largest magnitude.
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold * reg_window_dft_abs.max(), :] = 0
    # Do the inverse transform to correct for the window minimum
    reg_window_sparse = torch_fft.irfftn(torch.view_as_complex(reg_window_dft), s=sz.long().tolist(), dim=[-2, -1])
    # Shift the real part of coefficient [0, 0] (the DC term of the unshifted
    # rfftn output) so that the sparsified window, rescaled by sz.prod(), has
    # its minimum at reg_window_min again.
    reg_window_dft[
        0, 0, 0, 0, 0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))
    # Remove zeros
    # max_inds holds, per dimension, the largest index of any non-zero entry.
    max_inds, _ = reg_window_dft.nonzero(as_tuple=False).max(dim=0)
    mid_ind = int((reg_window_dft.shape[2] - 1) / 2)
    # Keep rows symmetric around mid_ind and columns up to the last non-zero one.
    top = max_inds[-2].item() + 1
    bottom = 2 * mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        # Prepend the flipped columns (excluding column 0) to the left.
        reg_window_dft = torch.cat(
            [reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)
    return reg_window_dft
def propagation_ASM(u_in, feature_size, wavelength, z, linear_conv=True,
                    padtype='zero', return_H=False, precomped_H=None,
                    return_H_exp=False, precomped_H_exp=None,
                    dtype=torch.float32):
    """Propagates the input field using the angular spectrum method

    Inputs
    ------
    u_in: PyTorch Complex tensor (torch.cfloat) of size
        (num_images, 1, height, width) -- updated with PyTorch 1.7.0
    feature_size: (height, width) of individual holographic features in m
    wavelength: wavelength in m
    z: propagation distance
    linear_conv: if True, pad the input to twice its size to obtain a
        linear convolution (the output is cropped back afterwards)
    padtype: 'zero' to pad with zeros, 'median' to pad with the median of
        u_in's amplitude
    return_H_exp: if True, return the distance-independent phase term H_exp
        (shape (1, 1, H, W) of the possibly padded field, NOT multiplied by z)
        and stop. It can be passed back as precomped_H_exp to propagate over
        other distances.
    return_H: if True, return the complex band-limited transfer function H
        for distance z and stop (checked after return_H_exp).
    precomped_H / precomped_H_exp: previously returned values of H / H_exp
    dtype: torch dtype for computation at different precision

    Output
    ------
    complex tensor of size (num_images, 1, height, width), or H / H_exp
    when requested.

    Raises
    ------
    ValueError: if linear_conv is True and padtype is not 'zero' or 'median'.
    """
    if linear_conv:
        # preprocess with padding for linear conv.
        input_resolution = u_in.size()[-2:]
        conv_size = [i * 2 for i in input_resolution]
        if padtype == 'zero':
            padval = 0
        elif padtype == 'median':
            # torch.median does not support complex input: use the amplitude.
            if torch.is_complex(u_in):
                amplitude = u_in.abs()
            else:
                # legacy stacked (..., 2) real layout
                amplitude = torch.pow((u_in**2).sum(-1), 0.5)
            padval = torch.median(amplitude)
        else:
            raise ValueError(f"Unsupported padtype: {padtype!r}")
        u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False)

    # The frequency grid is needed to build H (band-limit filter) and to
    # build H_exp when it has not been precomputed.
    needs_grid = precomped_H is None or (return_H_exp and precomped_H_exp is None)
    if needs_grid:
        # resolution of input field, should be: (num_images, num_channels, height, width)
        field_resolution = u_in.size()
        # number of pixels
        num_y, num_x = field_resolution[2], field_resolution[3]
        # sampling interval size
        dy, dx = feature_size
        # size of the field
        y, x = (dy * float(num_y), dx * float(num_x))
        # frequency coordinates sampling
        fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y), 1 / (2 * dy) - 0.5 / (2 * y), num_y)
        fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), num_x)
        # momentum/reciprocal space
        FX, FY = np.meshgrid(fx, fy)

    if precomped_H_exp is not None:
        H_exp = precomped_H_exp
    elif needs_grid:
        # transfer function phase in numpy (omit distance)
        HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2))
        # create tensor & upload to device (GPU)
        H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device)
        H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size()))
    else:
        # only reachable when precomped_H is given and H_exp is not requested
        H_exp = None

    # return the distance-independent term so it can be reused for any z
    if return_H_exp:
        return H_exp

    if precomped_H is None:
        # multiply by distance (kept separate so H_exp itself stays z-free)
        H_exp_z = torch.mul(H_exp, z)

        # band-limited ASM - Matsushima et al. (2009)
        fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
        fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
        H_filter = torch.tensor(((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8), dtype=dtype)

        # get real/img components
        H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp_z)

        H = torch.stack((H_real, H_imag), 4)
        H = utils.ifftshift(H)
        H = torch.view_as_complex(H)
    else:
        H = precomped_H

    if return_H:
        return H

    # angular spectrum, convolution with the transfer function, and back
    U1 = torch.fft.fftn(utils.ifftshift(u_in), dim=(-2, -1), norm='ortho')
    U2 = H * U1
    u_out = utils.fftshift(torch.fft.ifftn(U2, dim=(-2, -1), norm='ortho'))

    if linear_conv:
        # crop the padded result back to the input resolution (complex tensor)
        return utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)
    return u_out
def kb_table_nufft_adjoint(
    data: Tensor,
    scaling_coef: Tensor,
    im_size: Tensor,
    grid_size: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    norm: Optional[str] = None,
) -> Tensor:
    """Kaiser-Bessel NUFFT adjoint with table interpolation.

    See :py:class:`~torchkbnufft.KbNufftAdjoint` for an overall description of
    the adjoint NUFFT. The scattered data is first gridded with
    ``kb_table_interp_adjoint`` and then inverse-FFT'd and scaled with
    ``ifft_and_scale``.

    Args:
        data: Scattered data to be iNUFFT'd to an image. Either complex, or
            real with a trailing dimension of size 2 holding (real, imag).
        scaling_coef: Image-domain coefficients to compensate for
            interpolation errors.
        im_size: Size of image with length being the number of dimensions.
        grid_size: Size of grid to use for interpolation, typically 1.25 to 2
            times ``im_size``.
        omega: k-space trajectory (in radians/voxel).
        tables: Interpolation tables (one table for each dimension).
        n_shift: Size for fftshift, usually ``im_size // 2``.
        numpoints: Number of neighbors to use for interpolation.
        table_oversamp: Table oversampling factor.
        offsets: A list of offsets, looping over all possible combinations of
            `numpoints`.
        norm: Whether to apply normalization with the FFT operation. Options
            are ``"ortho"`` or ``None``.

    Returns:
        ``data`` transformed to an image, in the same layout (complex, or real
        with a trailing dimension of 2) as the input.

    Raises:
        ValueError: If ``data`` is real and its last dimension is not 2.
    """
    real_layout = not data.is_complex()
    if real_layout:
        if data.shape[-1] != 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        data = torch.view_as_complex(data)

    gridded = kb_table_interp_adjoint(
        data=data,
        omega=omega,
        tables=tables,
        n_shift=n_shift,
        numpoints=numpoints,
        table_oversamp=table_oversamp,
        offsets=offsets,
        grid_size=grid_size,
    )
    image = ifft_and_scale(
        image=gridded,
        scaling_coef=scaling_coef,
        im_size=im_size,
        grid_size=grid_size,
        norm=norm,
    )
    return torch.view_as_real(image) if real_layout else image
def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
):
    """Adagrad step for a list of parameters using ``torch._foreach_*`` ops.

    ``params``, ``state_sums`` and ``state_steps`` are updated in place. The
    caller's ``grads`` tensors are never modified. Complex parameters are
    handled through their real views. If any gradient is sparse, the work is
    delegated to ``_single_tensor_adagrad``.

    Args:
        params: Parameters to update.
        grads: Gradients matching ``params``.
        state_sums: Running sums of squared gradients (updated in place).
        state_steps: Step counters as tensors (incremented in place).
        lr: Learning rate.
        weight_decay: L2 penalty added to the gradient.
        lr_decay: Learning-rate decay per step.
        eps: Term added to the denominator for numerical stability.
        has_sparse_grad: Whether any gradient is sparse; ``None`` means detect.
        maximize: Maximize the objective instead of minimizing it.
    """
    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(grads)

    if has_sparse_grad is None:
        has_sparse_grad = any(grad.is_sparse for grad in grads)

    if has_sparse_grad:
        return _single_tensor_adagrad(
            params,
            grads,
            state_sums,
            state_steps,
            lr=lr,
            weight_decay=weight_decay,
            lr_decay=lr_decay,
            eps=eps,
            has_sparse_grad=has_sparse_grad,
            # gradients were already negated above when maximizing
            maximize=False,
        )

    # Update steps
    torch._foreach_add_(state_steps, 1)

    if weight_decay != 0:
        # Out-of-place: an in-place add would corrupt the caller's .grad tensors.
        grads = torch._foreach_add(grads, params, alpha=weight_decay)

    # Per-parameter negative step size; plain floats select the scalar-list
    # overload of _foreach_mul.
    minus_clr = [-lr / (1 + (step.item() - 1) * lr_decay) for step in state_steps]

    # foreach kernels operate on real views of complex tensors; the views share
    # storage, so in-place updates reach the original state tensors.
    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    real_state_sums = [
        torch.view_as_real(x) if torch.is_complex(x) else x for x in state_sums
    ]

    torch._foreach_addcmul_(real_state_sums, grads, grads, value=1)
    std = torch._foreach_add(torch._foreach_sqrt(real_state_sums), eps)
    to_add = torch._foreach_div(torch._foreach_mul(grads, minus_clr), std)
    to_add = [
        torch.view_as_complex(x) if torch.is_complex(params[i]) else x
        for i, x in enumerate(to_add)
    ]
    torch._foreach_add_(params, to_add)
def view_as_complex(x):
    """View the last dimension of ``x`` as interleaved (real, imag) pairs.

    A trailing dimension of length ``2k`` becomes ``k`` complex values; all
    leading dimensions are kept. The result is a view of ``x``.
    """
    new_shape = [*x.shape[:-1], x.shape[-1] // 2, 2]
    return torch.view_as_complex(x.view(new_shape))
def _fill_second_order_edge(g_edge, alpha_edge, beta_edge, gamma_edge, u, u_next, diffL):
    """Write g = alpha*u + gamma*du/dn + beta*d2u/dt2 into one edge view in place."""
    g_edge[:] = gamma_edge * (u - u_next) / diffL + alpha_edge * u
    g_edge[1:-1] += beta_edge[1:-1] * (u[:-2] + u[2:] - 2 * u[1:-1]) / diffL**2


def trasmission_2nd_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, index, transmission_func, args, diffL, last_g=None):
    """Build second-order transmission boundary data for one patch.

    On every patch edge the boundary condition is
    ``g = alpha * u + beta * d2u/dt2 + gamma * du/dn``.

    - Edges on the outer domain boundary get Dirichlet data: alpha = 1,
      beta = gamma = 0, and g equals the patch's own edge values.
    - Interior edges take (alpha, beta, gamma) from ``transmission_func``.
      Their g is evaluated from the neighbouring patch's solution at the
      overlap line, using a one-sided normal difference and a centred
      tangential second difference.
    - Edges are processed top, bottom, left, right, so corner entries end up
      with the left/right values.

    Args:
        patched_solved: Real tensor of per-patch solutions whose dim 1 has
            size 2 (it is permuted to the end and viewed as complex). Patches
            are indexed row-major with ``args.y_patches`` patches per row.
        x_batch_train: Per-patch inputs; edge slices
            ``x_batch_train[index, 0, ...]`` are passed to ``transmission_func``.
        index: Flat index of the patch being processed.
        transmission_func: Identifier resolved through
            ``find_transmission_function``. The resolved callable must return
            ``(alpha, beta, gamma)``.
        args: Provides ``x_patches``, ``y_patches``, ``domain_sizex``,
            ``domain_sizey``, ``overlap_pixels`` and ``relaxation``.
        diffL: Step size used in the finite differences.
        last_g: Optional previous g. When given, the result is relaxed as
            ``(1 - args.relaxation) * last_g + args.relaxation * g``.

    Returns:
        Tuple ``(g, alpha, beta, gamma)``. alpha/beta/gamma are complex64
        tensors shaped like one patch. g is built as complex64 but may be
        promoted by the relaxation with ``last_g``.
    """
    transmission_func = find_transmission_function(transmission_func)
    row = int(index / args.y_patches)
    col = index % args.y_patches

    patched_solved_complex = torch.view_as_complex(
        patched_solved.permute(0, 2, 3, 1).contiguous())
    patch_shape = patched_solved_complex[index].shape
    g = torch.zeros(patch_shape, dtype=torch.cfloat)
    alpha = torch.zeros(patch_shape, dtype=torch.cfloat)
    beta = torch.zeros(patch_shape, dtype=torch.cfloat)
    gamma = torch.zeros(patch_shape, dtype=torch.cfloat)

    # top
    if row == 0:
        alpha[0, :] = 1 + 0j
        beta[0, :] = 0 + 0j
        gamma[0, :] = 0 + 0j
        g[0, :] = patched_solved_complex[index, 0, :]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, 0, :], args)
        alpha[0, :] = this_alpha
        beta[0, :] = this_beta
        gamma[0, :] = this_gamma
        upper = patched_solved_complex[index - args.y_patches]
        line = args.domain_sizex - args.overlap_pixels
        _fill_second_order_edge(g[0, :], alpha[0, :], beta[0, :], gamma[0, :],
                                upper[line, :], upper[line + 1, :], diffL)

    # bottom
    if row == args.x_patches - 1:
        alpha[-1, :] = 1 + 0j
        beta[-1, :] = 0 + 0j
        gamma[-1, :] = 0 + 0j
        g[-1, :] = patched_solved_complex[index, -1, :]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, -1, :], args)
        alpha[-1, :] = this_alpha
        beta[-1, :] = this_beta
        gamma[-1, :] = this_gamma
        lower = patched_solved_complex[index + args.y_patches]
        _fill_second_order_edge(g[-1, :], alpha[-1, :], beta[-1, :], gamma[-1, :],
                                lower[args.overlap_pixels - 1, :],
                                lower[args.overlap_pixels - 2, :], diffL)

    # left
    if col == 0:
        alpha[:, 0] = 1 + 0j
        beta[:, 0] = 0 + 0j
        gamma[:, 0] = 0 + 0j
        g[:, 0] = patched_solved_complex[index, :, 0]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, 0], args)
        alpha[:, 0] = this_alpha
        beta[:, 0] = this_beta
        gamma[:, 0] = this_gamma
        left = patched_solved_complex[index - 1]
        line = args.domain_sizey - args.overlap_pixels
        _fill_second_order_edge(g[:, 0], alpha[:, 0], beta[:, 0], gamma[:, 0],
                                left[:, line], left[:, line + 1], diffL)

    # right
    if col == args.y_patches - 1:
        alpha[:, -1] = 1 + 0j
        beta[:, -1] = 0 + 0j
        gamma[:, -1] = 0 + 0j
        g[:, -1] = patched_solved_complex[index, :, -1]
    else:
        this_alpha, this_beta, this_gamma = transmission_func(
            x_batch_train[index, 0, :, -1], args)
        alpha[:, -1] = this_alpha
        beta[:, -1] = this_beta
        gamma[:, -1] = this_gamma
        right = patched_solved_complex[index + 1]
        _fill_second_order_edge(g[:, -1], alpha[:, -1], beta[:, -1], gamma[:, -1],
                                right[:, args.overlap_pixels - 1],
                                right[:, args.overlap_pixels - 2], diffL)

    if last_g is not None:
        g = (1 - args.relaxation) * last_g + args.relaxation * g
    return g, alpha, beta, gamma
def trasmission_pade_g_alpha_beta_gamma_complex(patched_solved, x_batch_train, index, transmission_func, args, operator, last_g=None):
    """Build Pade-type transmission boundary data for one patch (NumPy).

    On every patch edge the boundary condition is
    ``g = alpha * u + beta * (du/dn - gamma * operator * u)``. According to
    the original note, gamma should be 1 if the math is correct.

    - Edges on the outer domain boundary get Dirichlet data: alpha = 1,
      beta = gamma = 0, and g equals the patch's own edge values.
    - Interior edges take (beta, gamma) from ``transmission_func`` and compute
      ``g = beta * (u - u_next - gamma * operator.dot(u))`` from the
      neighbouring patch's overlap line.
    - On interior edges, the two corner entries of g take the neighbour's
      interface value directly.
    - Edges are processed left, right, top, bottom.

    Args:
        patched_solved: Real tensor of per-patch solutions whose dim 1 has
            size 2 (it is permuted to the end, viewed as complex and
            converted to NumPy). Patches are indexed row-major with
            ``args.y_patches`` patches per row.
        x_batch_train: Per-patch inputs; edge slices
            ``x_batch_train[index, 0, ...]`` are passed to ``transmission_func``.
        index: Flat index of the patch being processed.
        transmission_func: Identifier resolved through
            ``find_transmission_function``. The resolved callable must return
            ``(beta, gamma)``.
        args: Provides ``x_patches``, ``y_patches``, ``domain_sizex``,
            ``domain_sizey``, ``overlap_pixels`` and ``relaxation``.
        operator: Object with a ``.dot`` method applied to a 1-D complex line
            of values (e.g. a matrix).
        last_g: Optional previous g. When given, the result is relaxed as
            ``(1 - args.relaxation) * last_g + args.relaxation * g``.

    Returns:
        Tuple ``(g, alpha, beta, gamma)`` of ``np.csingle`` arrays shaped like
        one patch (g may be promoted by the relaxation with ``last_g``).
    """
    transmission_func = find_transmission_function(transmission_func)
    row = int(index / args.y_patches)
    col = index % args.y_patches

    patched_solved_complex = torch.view_as_complex(
        patched_solved.permute(0, 2, 3, 1).contiguous()).numpy()
    patch_shape = patched_solved_complex[index].shape
    g = np.zeros(patch_shape, dtype=np.csingle)
    alpha = np.zeros(patch_shape, dtype=np.csingle)
    beta = np.zeros(patch_shape, dtype=np.csingle)
    gamma = np.zeros(patch_shape, dtype=np.csingle)

    # left
    if col == 0:
        alpha[:, 0] = 1 + 0j
        beta[:, 0] = 0 + 0j
        gamma[:, 0] = 0 + 0j
        g[:, 0] = patched_solved_complex[index, :, 0]
    else:
        this_beta, this_gamma = transmission_func(x_batch_train[index, 0, :, 0], args)
        alpha[1:-1, 0] = 0 + 0j
        beta[:, 0] = this_beta
        gamma[:, 0] = this_gamma
        left = patched_solved_complex[index - 1]
        line = args.domain_sizey - args.overlap_pixels
        u = left[:, line]
        g[:, 0] = beta[:, 0] * (u - left[:, line + 1] + -gamma[:, 0] * operator.dot(u))
        g[0, 0] = left[0, line]
        g[-1, 0] = left[-1, line]

    # right
    if col == args.y_patches - 1:
        alpha[:, -1] = 1 + 0j
        beta[:, -1] = 0 + 0j
        gamma[:, -1] = 0 + 0j
        g[:, -1] = patched_solved_complex[index, :, -1]
    else:
        this_beta, this_gamma = transmission_func(x_batch_train[index, 0, :, -1], args)
        alpha[1:-1, -1] = 0 + 0j
        beta[:, -1] = this_beta
        gamma[:, -1] = this_gamma
        right = patched_solved_complex[index + 1]
        u = right[:, args.overlap_pixels - 1]
        g[:, -1] = beta[:, -1] * (u - right[:, args.overlap_pixels - 2] + -gamma[:, -1] * operator.dot(u))
        g[0, -1] = right[0, args.overlap_pixels - 1]
        g[-1, -1] = right[-1, args.overlap_pixels - 1]

    # top
    if row == 0:
        alpha[0, :] = 1 + 0j
        beta[0, :] = 0 + 0j
        gamma[0, :] = 0 + 0j
        g[0, :] = patched_solved_complex[index, 0, :]
    else:
        this_beta, this_gamma = transmission_func(x_batch_train[index, 0, 0, :], args)
        alpha[0, 1:-1] = 0 + 0j
        beta[0, :] = this_beta
        gamma[0, :] = this_gamma
        upper = patched_solved_complex[index - args.y_patches]
        line = args.domain_sizex - args.overlap_pixels
        u = upper[line, :]
        g[0, :] = beta[0, :] * (u - upper[line + 1, :] + -gamma[0, :] * operator.dot(u))
        g[0, 0] = upper[line, 0]
        g[0, -1] = upper[line, -1]

    # bottom
    if row == args.x_patches - 1:
        alpha[-1, :] = 1 + 0j
        beta[-1, :] = 0 + 0j
        gamma[-1, :] = 0 + 0j
        g[-1, :] = patched_solved_complex[index, -1, :]
    else:
        this_beta, this_gamma = transmission_func(x_batch_train[index, 0, -1, :], args)
        alpha[-1, 1:-1] = 0 + 0j
        beta[-1, :] = this_beta
        gamma[-1, :] = this_gamma
        lower = patched_solved_complex[index + args.y_patches]
        u = lower[args.overlap_pixels - 1, :]
        g[-1, :] = beta[-1, :] * (u - lower[args.overlap_pixels - 2, :] + -gamma[-1, :] * operator.dot(u))
        # Only the two corners take the neighbour value (previously the whole
        # bottom row was overwritten, discarding the Pade data just computed).
        g[-1, 0] = lower[args.overlap_pixels - 1, 0]
        g[-1, -1] = lower[args.overlap_pixels - 1, -1]

    if last_g is not None:
        g = (1 - args.relaxation) * last_g + args.relaxation * g
    return g, alpha, beta, gamma
def forward(
    self,
    data: Tensor,
    omega: Tensor,
    interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
    smaps: Optional[Tensor] = None,
    norm: Optional[str] = None,
) -> Tensor:
    """Interpolate from scattered data to gridded data and then iFFT.

    Input tensors should be of shape ``(N, C) + klength``, where ``N`` is the
    batch size and ``C`` is the number of sensitivity coils. ``omega``, the
    k-space trajectory, should be of size ``(len(grid_size), klength)`` or
    ``(N, len(grid_size), klength)``, where ``klength`` is the length of the
    k-space trajectory.

    Note:
        If the batch dimension is included in ``omega``, the interpolator
        will parallelize over the batch dimension. This is efficient for many
        small trajectories that might occur in dynamic imaging settings.

    If your tensors are real, ensure that 2 is the size of the last
    dimension.

    Args:
        data: Data to be gridded and then inverse FFT'd.
        omega: k-space trajectory (in radians/voxel).
        interp_mats: 2-tuple of real, imaginary sparse matrices to use for
            sparse matrix NUFFT interpolation (overrides default table
            interpolation).
        smaps: Sensitivity maps. If given, the adjoint NUFFT output is
            multiplied by their complex conjugate and summed over dim 1
            (keeping that dimension).
        norm: Whether to apply normalization with the FFT operation. Options
            are ``"ortho"`` or ``None``.

    Returns:
        ``data`` transformed to the image domain, in the same layout (complex,
        or real with a trailing dimension of 2) as the input.

    Raises:
        TypeError: If ``smaps`` and ``data`` have different dtypes.
        ValueError: If a real-valued ``data`` or ``smaps`` does not have a
            trailing dimension of size 2.
    """
    if smaps is not None and smaps.dtype != data.dtype:
        raise TypeError("data dtype does not match smaps dtype.")

    real_layout = not data.is_complex()
    if real_layout:
        if data.shape[-1] != 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        if smaps is not None:
            if smaps.shape[-1] != 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            smaps = torch.view_as_complex(smaps)
        data = torch.view_as_complex(data)

    if interp_mats is None:
        tables = [getattr(self, f"table_{i}") for i in range(len(self.im_size))]  # type: ignore
        assert isinstance(self.scaling_coef, Tensor)
        assert isinstance(self.im_size, Tensor)
        assert isinstance(self.grid_size, Tensor)
        assert isinstance(self.n_shift, Tensor)
        assert isinstance(self.numpoints, Tensor)
        assert isinstance(self.table_oversamp, Tensor)
        assert isinstance(self.offsets, Tensor)
        output = tkbnF.kb_table_nufft_adjoint(
            data=data,
            scaling_coef=self.scaling_coef,
            im_size=self.im_size,
            grid_size=self.grid_size,
            omega=omega,
            tables=tables,
            n_shift=self.n_shift,
            numpoints=self.numpoints,
            table_oversamp=self.table_oversamp,
            offsets=self.offsets.to(torch.long),
            norm=norm,
        )
    else:
        assert isinstance(self.scaling_coef, Tensor)
        assert isinstance(self.im_size, Tensor)
        assert isinstance(self.grid_size, Tensor)
        output = tkbnF.kb_spmat_nufft_adjoint(
            data=data,
            scaling_coef=self.scaling_coef,
            im_size=self.im_size,
            grid_size=self.grid_size,
            interp_mats=interp_mats,
            norm=norm,
        )

    if smaps is not None:
        output = torch.sum(output * smaps.conj(), dim=1, keepdim=True)

    return torch.view_as_real(output) if real_layout else output
def forward(
    self,
    image: Tensor,
    kernel: Tensor,
    smaps: Optional[Tensor] = None,
    norm: Optional[str] = "ortho",
) -> Tensor:
    """Toeplitz NUFFT forward function.

    Args:
        image: The image to apply the forward/backward Toeplitz-embedded NUFFT
            to. Either complex, or real with a trailing dimension of size 2.
        kernel: The filter response taking into account Toeplitz embedding.
            Must have the same dtype as ``image``. An optional leading batch
            dimension must be 1 or equal to ``image.shape[0]``.
        smaps: Optional sensitivity maps (same dtype as ``image``). When
            given, the computation is delegated to ``self.toep_batch_loop``;
            otherwise ``tkbnF.fft_filter`` is applied directly.
        norm: Whether to apply normalization with the FFT operation. Options
            are ``"ortho"`` or ``None``.

    Returns:
        ``image`` after applying the Toeplitz forward/backward NUFFT, in the
        same layout (complex, or real with a trailing dimension of 2) as the
        input.

    Raises:
        TypeError: If ``kernel`` or ``smaps`` dtype differs from ``image``.
        ValueError: If real-valued inputs lack a trailing dimension of 2, or
            the kernel batch size is neither 1 nor ``image.shape[0]``.
    """
    if kernel.dtype != image.dtype:
        raise TypeError("kernel and image must have same dtype.")
    if smaps is not None and smaps.dtype != image.dtype:
        raise TypeError("image dtype does not match smaps dtype.")

    real_layout = not image.is_complex()
    if real_layout:
        for tensor in (image, kernel):
            if tensor.shape[-1] != 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
        if smaps is not None:
            if smaps.shape[-1] != 2:
                raise ValueError("For real inputs, last dimension must be size 2.")
            smaps = torch.view_as_complex(smaps)
        image = torch.view_as_complex(image)
        kernel = torch.view_as_complex(kernel)

    # more kernel dims than image spatial dims means a leading batch dim
    if len(kernel.shape) > len(image.shape[2:]):
        if kernel.shape[0] == 1:
            kernel = kernel[0]
        elif kernel.shape[0] != image.shape[0]:
            raise ValueError("If using batch dimension, "
                             "kernel must have same batch size as image")

    if smaps is not None:
        output = self.toep_batch_loop(image=image, smaps=smaps, kernel=kernel, norm=norm)
    else:
        output = tkbnF.fft_filter(image=image, kernel=kernel, norm=norm)

    return torch.view_as_real(output) if real_layout else output
def matnoise(mat, noise='wgn', snr=30, peak='maxv'):
    """Add noise to a matrix.

    Adds white Gaussian noise (via ``awgn``) to a real or complex tensor.
    Complex tensors, and real tensors whose last dimension has size 2, are
    treated as (real, imag) pairs; noise is added to each part separately.
    The input tensor is not modified.

    Args:
        mat (torch.Tensor): Input tensor, can be real or complex valued.
        noise (str, optional): Type of noise (default: ``'wgn'``). Currently
            not inspected; white Gaussian noise is always added.
        snr (float, optional): Signal-to-noise ratio, passed to ``awgn`` with
            ``pmode='db'`` (default: 30).
        peak (None, float, pair or str, optional): Peak value of the input.
            For complex/paired data use ``peak=[peakr, peaki]``. If ``None``,
            it is derived as ``2 ** nextpow2(max) - 1``. If ``'maxv'``
            (default), the maximum value is used.

    Returns:
        (torch.Tensor): Noisy output tensor (complex when the input was
        complex).
    """
    cplxflag = th.is_complex(mat)
    if cplxflag:
        # clone: writing into the real view would otherwise modify the input
        mat = th.view_as_real(mat).clone()
        paired = True
    else:
        paired = mat.dim() > 0 and mat.shape[-1] == 2
        if paired:
            mat = mat.clone()

    use_max = isinstance(peak, str) and peak == 'maxv'

    if paired:
        if peak is None:
            peakr = 2**nextpow2(mat[..., 0].max()) - 1
            peaki = 2**nextpow2(mat[..., 1].max()) - 1
        elif use_max:
            peakr = mat[..., 0].max()
            peaki = mat[..., 1].max()
        else:
            peakr, peaki = peak
        mat[..., 0] = awgn(mat[..., 0], snr=snr, peak=peakr, pmode='db', measMode='measured')
        mat[..., 1] = awgn(mat[..., 1], snr=snr, peak=peaki, pmode='db', measMode='measured')
    else:
        if peak is None:
            peak = 2**nextpow2(mat.max()) - 1
        elif use_max:
            peak = mat.max()
        mat = awgn(mat, snr=snr, peak=peak, pmode='db', measMode='measured')

    if cplxflag:
        mat = th.view_as_complex(mat)
    return mat
x : tensor Input tensor with least one axis. i0 : int Start of original signal before padding. i1 : int End of original signal before padding. Returns ------- x_unpadded : tensor The tensor x[..., i0:i1]. """ return x[..., i0:i1] fft = FFT(lambda x: torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))), lambda x: torch.view_as_real(torch.fft.ifft(torch.view_as_complex(x))), lambda x: torch.fft.ifft(torch.view_as_complex(x)).real, type_checks) backend = namedtuple('backend', ['name', 'modulus_complex', 'subsample_fourier', 'real', 'unpad', 'fft', 'concatenate']) backend.name = 'torch' backend.modulus_complex = Modulus() backend.subsample_fourier = subsample_fourier backend.real = real backend.unpad = unpad backend.cdgmm = cdgmm backend.pad = pad backend.pad_1d = pad_1d backend.fft = fft
def func(z):
    """Sum the real and imaginary parts of element 0 along the last complex axis.

    ``z`` is a real tensor with a trailing dimension of size 2, read as
    complex values. Index 0 is selected along the last complex dimension, and
    the real/imaginary components of that slice are summed to a scalar.
    """
    first = torch.view_as_complex(z).select(-1, 0)
    return torch.view_as_real(first).sum()
def torch_complex_from_magphase(mag, phase):
    """Build a complex tensor ``mag * exp(1j * phase)`` from magnitude and phase."""
    real_part = mag * torch.cos(phase)
    imag_part = mag * torch.sin(phase)
    pairs = torch.stack((real_part, imag_part), dim=-1)
    return torch.view_as_complex(pairs)
def view_as_complex(data):
    """Named-tensor aware version of `torch.view_as_complex()`.

    ``data`` is validated with ``assert_complex``. Names are stripped for the
    view (which does not support named tensors), and the names of all but the
    trailing (real/imag) dimension are re-applied to the result.
    """
    assert_complex(data)
    kept_names = data.names[:-1]
    unnamed = data.rename(None)
    return torch.view_as_complex(unnamed).refine_names(*kept_names)
def torch_complex_from_reim(re, im):
    """Combine real and imaginary part tensors into one complex tensor."""
    pairs = torch.stack((re, im), -1)
    return torch.view_as_complex(pairs)
def forward(
    self,
    data: Tensor,
    omega: Tensor,
    interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
    grid_size: Optional[Tensor] = None,
) -> Tensor:
    """Interpolate from scattered data to gridded data.

    Input tensors should be of shape ``(N, C) + klength``, where ``N`` is the
    batch size and ``C`` is the number of sensitivity coils. ``omega``, the
    k-space trajectory, should be of size ``(len(im_size), klength)``, where
    ``klength`` is the length of the k-space trajectory.

    If your tensors are real-valued, ensure that 2 is the size of the last
    dimension.

    Args:
        data: Data to be gridded.
        omega: k-space trajectory (in radians/voxel).
        interp_mats: 2-tuple of real, imaginary sparse matrices to use for
            sparse matrix KB interpolation (overrides default table
            interpolation).
        grid_size: Size of the output grid; defaults to ``self.grid_size``.

    Returns:
        ``data`` interpolated to the grid, in the same layout (complex, or
        real with a trailing dimension of 2) as the input.

    Raises:
        ValueError: If real-valued ``data`` does not have a trailing dimension
            of size 2.
    """
    is_complex = True
    if not data.is_complex():
        if not data.shape[-1] == 2:
            raise ValueError("For real inputs, last dimension must be size 2.")
        is_complex = False
        data = torch.view_as_complex(data)

    if grid_size is None:
        assert isinstance(self.grid_size, Tensor)
        grid_size = self.grid_size

    if interp_mats is not None:
        output = tkbnF.kb_spmat_interp_adjoint(
            data=data, interp_mats=interp_mats, grid_size=grid_size
        )
    else:
        tables = [getattr(self, f"table_{i}") for i in range(len(self.im_size))]  # type: ignore
        assert isinstance(self.n_shift, Tensor)
        assert isinstance(self.numpoints, Tensor)
        assert isinstance(self.table_oversamp, Tensor)
        assert isinstance(self.offsets, Tensor)
        output = tkbnF.kb_table_interp_adjoint(
            data=data,
            omega=omega,
            tables=tables,
            n_shift=self.n_shift,
            numpoints=self.numpoints,
            table_oversamp=self.table_oversamp,
            # offsets are used as integer grid increments; cast defensively,
            # matching the adjoint NUFFT module
            offsets=self.offsets.to(torch.long),
            grid_size=grid_size,
        )

    if not is_complex:
        output = torch.view_as_real(output)

    return output
def irfft(input, signal_ndim, normalized=False, onesided=True, signal_sizes=None):
    """Compatibility shim for the legacy ``torch.irfft`` API.

    On torch < 1.8 this forwards directly to ``torch.irfft``. On newer
    releases the legacy real representation of ``input`` (a trailing axis of
    size 2 holding real/imag parts) is viewed as complex and passed to
    ``torch.fft.irfft`` / ``irfft2`` / ``irfftn``.

    Args:
        input: Spectrum in real representation (trailing axis of size 2).
        signal_ndim: Number of trailing signal dimensions to invert (1, 2 or 3).
        normalized: ``True`` selects "ortho" scaling, ``False`` "backward".
        onesided: Legacy flag. On torch >= 1.8 the same ``torch.fft.irfft*``
            call is made either way; it only selects the error-message wording.
        signal_sizes: Output signal sizes; required on torch >= 1.8.

    Returns:
        Contiguous real tensor.

    Raises:
        AssertionError: ``signal_sizes`` is missing, or ``signal_ndim`` is not
            1, 2 or 3. Raised explicitly (not via ``assert``) so the checks
            still run under ``python -O``.
    """
    if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
        return torch.irfft(input, signal_ndim, normalized, onesided, signal_sizes)

    if not signal_sizes:
        raise AssertionError("Parameter signal_sizes is required")

    if signal_ndim not in (1, 2, 3):
        prefix = "Ortho-normalized" if normalized else "Backward-normalized"
        op_name = "irfft()" if onesided else "ifft()"
        raise AssertionError(
            "%s %s has illegal number of dimensions %s" % (prefix, op_name, signal_ndim))

    norm = "ortho" if normalized else "backward"
    spectrum = torch.view_as_complex(input)
    if signal_ndim == 1:
        y = torch.fft.irfft(spectrum, signal_sizes[-1], -1, norm)
    elif signal_ndim == 2:
        y = torch.fft.irfft2(spectrum, signal_sizes, (-2, -1), norm)
    else:
        y = torch.fft.irfftn(spectrum, signal_sizes, (-3, -2, -1), norm)

    # Internal sanity check: the inverse real FFT must yield a real tensor.
    assert not y.is_complex()
    return y.contiguous()
def tocplx(x):
    """Reinterpret a real tensor whose last axis is (real, imag) as complex."""
    complex_view = torch.view_as_complex(x)
    return complex_view
# [batch, 1, 121, 61, 2] alphaf = label.to(device=z.device) / (kzzf + lambda0) # [batch, 1, 121, 121] return torch.irfft(cn.mul(kxzf, alphaf), signal_ndim=2) ############################################## x = torch.rand((42, 32, 121, 121)) a = torch.rfft(x, signal_ndim=2, onesided=False) b = fft.fftn(x, dim=[-2, -1]) ca = torch.view_as_complex(a) print(a.shape) print(b.shape) print(torch.allclose(ca, b)) u = ca - b v = u.abs() h = torch.histc(v) import matplotlib.pyplot as plt plt.hist(v.flatten().numpy(), bins=500, log=True) plt.show() exit() # fft_comparison()
i0 : int Start of original signal before padding. i1 : int End of original signal before padding. Returns ------- x_unpadded : tensor The tensor x[..., i0:i1]. """ return x[..., i0:i1] if version.parse(torch.__version__) >= version.parse('1.8'): fft = FFT( lambda x: torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))), lambda x: torch.view_as_real(torch.fft.ifft(torch.view_as_complex(x))), lambda x: torch.fft.ifft(torch.view_as_complex(x)).real, type_checks) else: fft = FFT(lambda x: torch.fft(x, 1, normalized=False), lambda x: torch.ifft(x, 1, normalized=False), lambda x: torch.irfft(x, 1, normalized=False, onesided=False), type_checks) backend = namedtuple('backend', [ 'name', 'modulus_complex', 'subsample_fourier', 'real', 'unpad', 'fft', 'concatenate' ]) backend.name = 'torch' backend.version = torch.__version__ backend.modulus_complex = Modulus()
def complex_ifft(A, *args, **kwargs):
    """Inverse FFT of a complex tensor using the legacy ``torch.ifft`` signature.

    Extra arguments keep their legacy meaning ``(signal_ndim, normalized=False)``:
    the last ``signal_ndim`` dimensions of ``A`` are transformed, and
    ``normalized=True`` selects ``1/sqrt(n)`` scaling instead of ``1/n``.

    ``torch.ifft`` was removed in PyTorch 1.8. When it is unavailable, the
    call is translated to :func:`torch.fft.ifftn`.

    Args:
        A: Complex input tensor.
        *args, **kwargs: Legacy ``signal_ndim`` / ``normalized`` arguments.

    Returns:
        Complex tensor with the same shape as ``A``.
    """
    legacy_ifft = getattr(torch, "ifft", None)
    if callable(legacy_ifft):
        # Old API operates on a trailing (real, imag) axis.
        return torch.view_as_complex(
            legacy_ifft(torch.view_as_real(A), *args, **kwargs))
    return _complex_ifft_native(A, *args, **kwargs)


def _complex_ifft_native(A, signal_ndim, normalized=False):
    """Legacy ``torch.ifft`` semantics implemented with :mod:`torch.fft`."""
    dims = tuple(range(-signal_ndim, 0))
    norm = "ortho" if normalized else "backward"
    return torch.fft.ifftn(A, dim=dims, norm=norm)
def defocus(x, pa=None, pr=None, isfft=True, ftshift=True):
    r"""Defocus an image with the given phase errors.

    Defocus in azimuth is computed as

    .. math::
        Y(k, n_r)=\sum_{n_a=0}^{N_a-1} X(n_a, n_r) \exp \left(j \varphi_{n_a}\right) \exp \left(-j \frac{2 \pi}{N_a} k n_a\right)

    where :math:`\varphi_{n_a}` is the azimuth phase error of the
    :math:`n_a`-th azimuth line and :math:`Y(k, n_r)` is the resulting pixel.
    Defocus in range is done the same way along the range axis.

    Args:
        x (Tensor): Focused image data, either complex
            :math:`{\mathbf X} \in{\mathbb C}^{N\times N_c\times N_a\times N_r}`
            or in real representation (trailing axis of size 2)
            :math:`{\mathbf X} \in{\mathbb R}^{N\times N_a\times N_r\times 2}` /
            :math:`{\mathbf X} \in{\mathbb R}^{N\times N_c\times N_a\times N_r\times 2}`.
            Non-Tensor inputs are converted with ``th.tensor``.
        pa (Tensor, optional): Azimuth phase errors in radians; ``pa.size(0)``
            and ``pa.size(1)`` are broadcast onto dims 0 and -3 of the
            real-represented ``x``. ``None`` (default) skips azimuth defocus.
        pr (Tensor, optional): Range phase errors in radians; ``pr.size(0)``
            and ``pr.size(1)`` are broadcast onto dims 0 and -2 of the
            real-represented ``x``. ``None`` (default) skips range defocus.
        isfft (bool, optional): Whether to apply the forward FFT before the
            phase multiplication (default True).
        ftshift (bool, optional): Whether to shift the zero frequency to the
            center in the FFT/IFFT calls (default True).

    Returns:
        Tensor: The defocused image, in the same representation (complex or
        real with trailing size-2 axis) as the input.

    Raises:
        TypeError: ``x`` is neither complex nor real with a trailing axis of size 2.
    """
    # NOTE(review): exact type check, so Tensor subclasses (e.g. Parameter)
    # are also routed through th.tensor(x).
    if type(x) is not th.Tensor:
        x = th.tensor(x)

    # Work internally in the real (..., 2) representation; remember whether
    # to convert back to complex at the end.
    if th.is_complex(x):
        # N, Na, Nr = x.size(0), x.size(-2), x.size(-1)
        x = th.view_as_real(x)
        crepflag = True
    elif x.size(-1) == 2:
        # N, Na, Nr = x.size(0), x.size(-3), x.size(-2)
        crepflag = False
    else:
        raise TypeError('x is complex and should be in complex or real represent formation!')

    d = x.dim()
    # Broadcast shapes for the phase factors: all ones except the batch dim,
    # the transformed axis and the trailing (real, imag) axis.
    sizea, sizer = [1] * d, [1] * d

    if pa is not None:
        sizea[0], sizea[-3], sizea[-1] = pa.size(0), pa.size(1), 2
        # exp(j * pa) in (cos, sin) real representation.
        epa = th.stack((th.cos(pa), th.sin(pa)), dim=-1)
        epa = epa.reshape(sizea)

        # NOTE(review): only the forward FFT is gated by isfft; the inverse
        # FFT always runs — confirm this nesting is intended.
        if isfft:
            x = ts.fft(x, axis=-3, shift=ftshift)
        x = ts.ebemulcc(x, epa)
        x = ts.ifft(x, axis=-3, shift=ftshift)

    if pr is not None:
        sizer[0], sizer[-2], sizer[-1] = pr.size(0), pr.size(1), 2
        # exp(j * pr) in (cos, sin) real representation.
        epr = th.stack((th.cos(pr), th.sin(pr)), dim=-1)
        epr = epr.reshape(sizer)

        if isfft:
            x = ts.fft(x, axis=-2, shift=ftshift)
        x = ts.ebemulcc(x, epr)
        x = ts.ifft(x, axis=-2, shift=ftshift)

    if crepflag:
        x = th.view_as_complex(x)

    return x
def view_complex_native(x: torch.FloatTensor) -> torch.Tensor:
    """Turn a PyKEEN complex representation into a native complex tensor.

    The last dimension is split into consecutive (real, imag) pairs, which
    :func:`torch.view_as_complex` then reinterprets as complex numbers.
    """
    lead_dims = x.shape[:-1]
    pairs = x.view(*lead_dims, -1, 2)
    return torch.view_as_complex(pairs)
def test_fn(x):
    """Apply ``torch_fn`` to the complex view of ``x``, forwarding ``args``.

    ``torch_fn`` and ``args`` are free names resolved from an enclosing scope
    that is not visible here.
    """
    as_complex = torch.view_as_complex(x)
    return torch_fn(as_complex, *args)
def meta_tensor(self, t):
    """Return the memoized meta-device counterpart of tensor ``t``.

    On a memo miss, a meta tensor mirroring ``t``'s size, stride, storage
    offset, dtype and conj/neg bits is built and stored with
    ``set_tensor_memo``. The return value always comes from
    ``get_tensor_memo(t)``.

    Views are built by recursively converting ``t._base`` and re-striding it.
    Non-views get an empty meta tensor whose storage is replaced with
    ``self.meta_storage(t.storage())``.
    """
    # see expired-storages
    # Every ``check_expired_frequency`` calls, sweep expired weak storages.
    self.check_expired_count += 1
    if self.check_expired_count >= self.check_expired_frequency:
        self.check_for_expired_weak_storages()
        self.check_expired_count = 0

    if self.get_tensor_memo(t) is None:
        # Mirror t's inference-mode status while constructing the meta tensor.
        with torch.inference_mode(t.is_inference()):
            if t._is_view():
                # Construct views in two steps: recursively meta-fy their
                # base, and then create the view off that. NB: doing it
                # directly from storage is WRONG because this won't cause
                # version counters to get shared.
                assert t._is_view()
                base = self.meta_tensor(t._base)

                # True when complex_dtype is the complex counterpart of real_dtype.
                def is_c_of_r(complex_dtype, real_dtype):
                    return (
                        utils.is_complex_dtype(complex_dtype)
                        and utils.corresponding_real_dtype(complex_dtype) == real_dtype
                    )

                # Bring the base to the view's dtype before re-striding it.
                if base.dtype == t.dtype:
                    pass
                elif is_c_of_r(base.dtype, t.dtype):
                    base = torch.view_as_real(base)
                elif is_c_of_r(t.dtype, base.dtype):
                    base = torch.view_as_complex(base)
                else:
                    # This is not guaranteed to succeed. If it fails, it
                    # means there is another dtype-converting view function
                    # that hasn't been handled here
                    base = base.view(t.dtype)

                # enable_grad so the view records autograd history off base.
                with torch.enable_grad():
                    r = base.as_strided(t.size(), t.stride(), t.storage_offset())
            else:
                is_leaf = safe_is_leaf(t)
                # Fake up some autograd history.
                if t.requires_grad:
                    r = torch.empty(
                        (0,), dtype=t.dtype, device="meta", requires_grad=True
                    )
                    if not is_leaf:
                        with torch.enable_grad():
                            # The backward function here will be wrong, but
                            # that's OK; our goal is just to get the metadata
                            # looking as close as possible; we're not going to
                            # actually try to backward() on these produced
                            # metas. TODO: would be safer to install some
                            # sort of unsupported grad_fn here
                            r = r.clone()
                else:
                    r = torch.empty((0,), dtype=t.dtype, device="meta")
                # As long as meta storage is not supported, need to prevent
                # redispatching on set_(Storage, ...) which will choke with
                # meta storage
                s = self.meta_storage(t.storage())
                with no_dispatch():
                    with torch.no_grad():
                        r.set_(s, t.storage_offset(), t.size(), t.stride())

            # Copy the lazy conjugate / negation bits onto the meta tensor.
            torch._C._set_conj(r, t.is_conj())
            torch._C._set_neg(r, t.is_neg())
        self.set_tensor_memo(t, r)

    return self.get_tensor_memo(t)
def ifft(input, signal_ndim, normalized=True):
    """Inverse FFT on a real-representation tensor, in the style of legacy ``torch.ifft``.

    ``input`` carries (real, imag) parts in a trailing axis of size 2. The
    last ``signal_ndim`` signal dimensions are transformed. The previous
    implementation ignored ``signal_ndim`` and always ran a 2-D inverse FFT.
    Results for ``signal_ndim == 2`` are unchanged.

    Args:
        input: Real tensor whose last dimension has size 2.
        signal_ndim: Number of trailing signal dimensions to transform (>= 1).
        normalized: ``True`` (default here, unlike legacy ``torch.ifft``)
            selects "ortho" scaling; ``False`` selects "backward" (``1/n``).

    Returns:
        Result in the same real representation as ``input``.

    Raises:
        ValueError: ``signal_ndim`` is smaller than 1.
    """
    if signal_ndim < 1:
        raise ValueError("signal_ndim must be >= 1, got %s" % (signal_ndim,))
    dims = tuple(range(-signal_ndim, 0))
    norm = "ortho" if normalized else "backward"
    spectrum = torch.view_as_complex(input)
    return torch.view_as_real(torch.fft.ifftn(spectrum, dim=dims, norm=norm))
def fft(x, n=None, axis=0, norm="backward", shift=False):
    """FFT in torchsar

    Compute the FFT of ``x`` along one axis.

    Parameters
    ----------
    x : {torch array}
        Complex tensor, or real tensor whose last dimension (size 2) holds the
        real and imaginary parts. A real-represented input is viewed as
        complex for the transform and the result is converted back, so
        ``axis`` refers to the real-represented layout.
    n : int, optional
        Number of FFT points (the default is None, meaning the size of ``x``
        along ``axis``). Shorter signals are padded with ``padfft``.
    axis : int, optional
        Axis to transform (the default is 0).
    norm : {None or str}, optional
        Normalization mode; ``None`` is treated as ``"backward"``:
        - "forward" - normalize by ``1/n``
        - "backward" - no normalization (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (orthonormal FFT)
    shift : bool, optional
        Apply ``fftshift`` before and after the transform so the zero
        frequency sits at the center (the default is False).

    Returns
    -------
    y : {torch array}
        FFT result in the same representation as ``x``.

    Raises
    ------
    ValueError
        ``n`` is smaller than the signal length along ``axis``.
    """
    mode = 'backward' if norm is None else norm

    if x.size(-1) == 2 and not th.is_complex(x):
        as_real = True
        x = th.view_as_complex(x)
        # The trailing (real, imag) axis disappears in the complex view,
        # so negative axes move one step closer to the end.
        if axis < 0:
            axis += 1
    else:
        as_real = False

    length = x.size(axis)
    nfft = length if n is None else n
    if length > nfft:
        raise ValueError('nfft is small than signal dimension!')
    if length < nfft:
        x = padfft(x, nfft, axis, shift)

    if not shift:
        y = thfft.fft(x, n=nfft, dim=axis, norm=mode)
    else:
        shifted_in = thfft.fftshift(x, dim=axis)
        spectrum = thfft.fft(shifted_in, n=nfft, dim=axis, norm=mode)
        y = thfft.fftshift(spectrum, dim=axis)

    return th.view_as_real(y) if as_real else y
def _spectral_residual_visual_saliency(
        x: torch.Tensor, scale: float = 0.25, kernel_size: int = 3, sigma: float = 3.8,
        gaussian_size: int = 10) -> torch.Tensor:
    r"""Compute the Spectral Residual visual saliency map.

    Credits X. Hou and L. Zhang, CVPR 07, 2007
    Reference:
        http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.125.5641&rep=rep1&type=pdf

    Steps:
    1. Downscale the image.
    2. Take the log-amplitude and phase of its 2-D FFT.
    3. Subtract an average-filtered log-amplitude (the spectral residual).
    4. Inverse-transform and square the magnitude.
    5. Gaussian-blur the result and min-max normalize it.
    6. Resize back to the input size.

    Args:
        x: Tensor with shape (N, 1, H, W).
        scale: Resizing factor applied before the Fourier transform.
        kernel_size: Kernel size of the average blur filter.
        sigma: Sigma of the gaussian filter applied on the saliency map.
        gaussian_size: Size of the gaussian filter applied on the saliency map.

    Returns:
        saliency_map: Saliency map resized to ``x.size()[-2:]``.
        NOTE(review): F.conv2d with a (1, 1, k, k) kernel keeps a channel
        dimension, so this is presumably (N, 1, H, W) rather than the
        previously documented BxHxW — confirm against imresize.

    Raises:
        ValueError: ``kernel_size`` or ``gaussian_size`` exceeds the downscaled
            input size.
    """
    eps = torch.finfo(x.dtype).eps
    # Both filters are applied to the downscaled image, so check against it.
    for kernel in kernel_size, gaussian_size:
        if x.size(-1) * scale < kernel or x.size(-2) * scale < kernel:
            raise ValueError(
                f'Kernel size can\'t be greater than actual input size. '
                f'Input size: {x.size()} x {scale}. Kernel size: {kernel}')

    # Downsize image
    in_img = imresize(x, scale=scale)

    # Fourier transform (use complex format [a,b] instead of a + ib
    # because torch<1.8.0 autograd does not support the latter)
    recommended_torch_version = _parse_version('1.8.0')
    torch_version = _parse_version(torch.__version__)
    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        imagefft = torch.fft.fft2(in_img)
        log_amplitude = torch.log(imagefft.abs() + eps)
        phase = torch.angle(imagefft)
    else:
        # Legacy API: trailing axis of size 2 holds (real, imag).
        imagefft = torch.rfft(in_img, 2, onesided=False)
        # Compute log of absolute value and angle of fourier transform
        log_amplitude = torch.log(imagefft.pow(2).sum(dim=-1).sqrt() + eps)
        phase = torch.atan2(imagefft[..., 1], imagefft[..., 0] + eps)

    # Compute spectral residual using average filtering
    padding = kernel_size // 2
    if padding:
        up_pad = (kernel_size - 1) // 2
        down_pad = padding
        pad_to_use = [up_pad, down_pad, up_pad, down_pad]
        # replicate padding before average filtering
        spectral_residual = F.pad(log_amplitude, pad=pad_to_use, mode='replicate')
    else:
        spectral_residual = log_amplitude

    spectral_residual = log_amplitude - F.avg_pool2d(
        spectral_residual, kernel_size=kernel_size, stride=1)

    # Saliency map
    # representation of complex exp(spectral_residual + j * phase)
    compx = torch.stack((torch.exp(spectral_residual) * torch.cos(phase),
                         torch.exp(spectral_residual) * torch.sin(phase)), -1)

    # Squared magnitude of the inverse transform.
    if len(torch_version) != 0 and torch_version >= recommended_torch_version:
        saliency_map = torch.abs(torch.fft.ifft2(
            torch.view_as_complex(compx)))**2
    else:
        saliency_map = torch.sum(torch.ifft(compx, 2)**2, dim=-1)

    # After effect for SR-SIM
    # Apply gaussian blur
    kernel = gaussian_filter(gaussian_size, sigma)
    if gaussian_size % 2 == 0:
        # matlab pads upper and lower borders with 0s for even kernels
        kernel = torch.cat((torch.zeros(1, 1, gaussian_size), kernel), 1)
        kernel = torch.cat((torch.zeros(1, gaussian_size + 1, 1), kernel), 2)
        gaussian_size += 1
    kernel = kernel.view(1, 1, gaussian_size, gaussian_size).to(saliency_map)
    saliency_map = F.conv2d(saliency_map, kernel, padding=(gaussian_size - 1) // 2)

    # normalize between [0, 1]
    # NOTE(review): min/max are taken over the whole tensor (all batch
    # items together), not per image.
    min_sal = torch.min(saliency_map[:])
    max_sal = torch.max(saliency_map[:])
    saliency_map = (saliency_map - min_sal) / (max_sal - min_sal + eps)

    # scale to original size
    saliency_map = imresize(saliency_map, sizes=x.size()[-2:])
    return saliency_map