def log_likelihood_gradient(
    eta: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    sigma: float,
    fft_centered: bool,
    fft_normalization: str,
    spatial_dims: Sequence[int],
    coil_dim: int,
) -> torch.Tensor:
    """
    Computes the gradient of the log-likelihood function.

    The gradient is ``conj(S) * F^-1(M * (F(S * eta) - y)) / sigma**2``, summed over the coils, where ``S`` are the
    sensitivity maps, ``F`` the 2D FFT and ``M`` the sampling mask. It is returned stacked with the current estimate
    as four real channels: ``[Re(eta), Im(eta), Re(grad), Im(grad)]``.

    Parameters
    ----------
    eta: Current estimate of the reconstruction, real/imaginary parts stacked in the last dimension. Assumed to be
        coil-combined, e.g. [batch_size, n_x, n_y, 2] — TODO confirm against callers.
    masked_kspace: Subsampled k-space data, with a coil dimension at ``coil_dim``.
    sense: Coil sensitivity maps, same layout as ``masked_kspace``.
    mask: Sampling mask, broadcastable to ``masked_kspace``.
    sigma: Noise level; the gradient is divided by ``sigma**2``.
    fft_centered: Whether to center the FFT.
    fft_normalization: Normalization mode passed to ``fft2`` / ``ifft2`` (e.g. "ortho").
    spatial_dims: Spatial dimensions of the data.
    coil_dim: Dimension of the coils in ``sense`` and ``masked_kspace``.

    Returns
    -------
    Tensor with the four channels above at dim 1, e.g. [batch_size, 4, n_x, n_y].
    """
    # Insert a singleton coil axis so eta broadcasts over every coil of its own batch element. The previous
    # implementation inserted it at dim 0, which only lined up for batch_size == 1 (and silently paired batch entries
    # with coils when batch_size == n_coils).
    eta_real, eta_imag = (part.unsqueeze(coil_dim) for part in eta.chunk(2, -1))
    sense_real, sense_imag = sense.chunk(2, -1)

    # S * eta, complex multiplication on the (real, imag) pair.
    re_se = eta_real * sense_real - eta_imag * sense_imag
    im_se = eta_real * sense_imag + eta_imag * sense_real

    # F^-1(M * (F(S * eta) - y))
    residual = mask * (
        fft2(
            torch.cat((re_se, im_se), -1),
            centered=fft_centered,
            normalization=fft_normalization,
            spatial_dims=spatial_dims,
        )
        - masked_kspace
    )
    pred = ifft2(residual, centered=fft_centered, normalization=fft_normalization, spatial_dims=spatial_dims)
    pred_real, pred_imag = pred.chunk(2, -1)

    # Multiply by conj(S) and sum over the coils: (a + ib)(c - id) = (ac + bd) + i(bc - ad).
    re_out = torch.sum(pred_real * sense_real + pred_imag * sense_imag, coil_dim) / (sigma**2.0)
    im_out = torch.sum(pred_imag * sense_real - pred_real * sense_imag, coil_dim) / (sigma**2.0)

    eta_real = eta_real.squeeze(coil_dim)
    eta_imag = eta_imag.squeeze(coil_dim)

    # Stack the four single-channel tensors last, then move that axis to dim 1 (channels-first for the conv RNN).
    return torch.cat((eta_real, eta_imag, re_out, im_out), -1).movedim(-1, 1)
def test_centered_fft2_forward_normalization(shape):
    """
    Compare the centered 2D FFT with "forward" normalization against a NumPy reference.

    Args:
        shape: shape of the input; the trailing real/imaginary dimension is appended here

    Returns:
        None
    """
    data = create_input(shape + [2])

    # NumPy reference for a centered transform: ifftshift -> fft2 -> fftshift over the last two axes.
    reference = tensor_to_complex_np(data)
    reference = np.fft.fftshift(
        np.fft.fft2(np.fft.ifftshift(reference, (-2, -1)), norm="forward"),
        (-2, -1),
    )

    result = fft2(data, centered=True, normalization="forward", spatial_dims=[-2, -1]).numpy()
    result = result[..., 0] + 1j * result[..., 1]

    if not np.allclose(result, reference):
        raise AssertionError
def forward(self, image):
    """
    Forward method for the MultiDomainConvTranspose2d class.

    The input is processed twice. ``self.image_conv`` is applied directly to the image. Separately, the image is
    transformed to k-space two channels at a time, passed through ``self.kspace_conv`` and transformed back. The two
    results are concatenated along ``self.coil_dim``.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    # Channels-last, then transform each 2-channel slice (presumably a real/imag pair) to k-space.
    channels_last = image.permute(0, 2, 3, 1).contiguous()
    kspace_slices = []
    for pair in torch.split(channels_last, 2, -1):
        kspace_slices.append(fft2(pair, **fft_kwargs))
    kspace = self.kspace_conv(torch.cat(kspace_slices, -1).permute(0, 3, 1, 2))

    # Back to image space in float precision, then cast to the input's tensor type.
    kspace_last = kspace.permute(0, 2, 3, 1).contiguous()
    image_slices = []
    for pair in torch.split(kspace_last, 2, -1):
        image_slices.append(ifft2(pair.float(), **fft_kwargs).type(image.type()))
    backward = torch.cat(image_slices, -1).permute(0, 3, 1, 2)

    return torch.cat([self.image_conv(image), backward], dim=self.coil_dim)
def A(x):
    """
    Forward operator for the enclosing scope.

    Multiplies ``x`` (expanded to the shape of ``smaps``) by ``smaps`` and applies the 2D FFT with the settings
    stored on ``ctx``. The result is masked with ``mask`` and summed over dim -4, keeping that dimension.
    """
    coil_images = complex_mul(x.expand_as(smaps), smaps)
    kspace = fft2(
        coil_images,
        centered=ctx.fft_centered,
        normalization=ctx.fft_normalization,
        spatial_dims=ctx.spatial_dims,
    )
    masked = kspace * mask
    return masked.sum(dim=-4, keepdim=True)
def _forward_operator(self, image, sampling_mask, sensitivity_map):
    """
    Forward operator.

    Expands ``image`` to coil images with ``sensitivity_map`` (inserting the coil axis at ``self.coil_dim``) and
    applies the 2D FFT. The result is cast to the tensor type of ``image``, and entries where ``sampling_mask == 0``
    are set to zero.
    """
    coil_kspace = fft2(
        complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    ).type(image.type())

    zero = torch.tensor([0.0], dtype=image.dtype).to(image.device)
    unsampled = sampling_mask == 0
    return torch.where(unsampled, zero, coil_kspace)
def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Map an image to multi-coil k-space.

    Computes ``F(S * x)``: the input is multiplied by the coil sensitivity maps and then transformed with the 2D FFT
    using this module's FFT settings.

    Parameters
    ----------
    x: Input data.
    sens_maps: Coil Sensitivity maps.

    Returns
    -------
    K-space of the coil images, with the broadcast shape of ``x`` and ``sens_maps``.
    """
    coil_images = complex_mul(x, sens_maps)
    coil_kspace = fft2(
        coil_images,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    return coil_kspace
def forward(self, x, y, smaps, mask):
    """
    Forward pass of the data-consistency block.

    The image is first projected to k-space, ``A x = sum_{dim -4}(F(S * x))``. On sampled locations the projection
    is blended with the measured data as ``alpha * A x + (1 - alpha) * y``; elsewhere the projection is kept. The
    result is brought back to image space with ``sum_{dim -5}(F^-1(k) * conj(S))``. Finally it is mixed with the
    input as ``beta * x + (1 - beta) * x_dc``.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    smaps: Coil sensitivity maps.
    mask: Sampling mask.

    Returns
    -------
    Output image.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    coil_images = complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps)
    A_x = fft2(coil_images, **fft_kwargs).sum(-4, keepdim=True)

    k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x + (1 - self.alpha) * y)

    x_dc = complex_mul(ifft2(k_dc, **fft_kwargs), complex_conj(smaps)).sum(dim=-5)
    return self.beta * x + (1 - self.beta) * x_dc
def test_non_centered_fft2(shape):
    """
    Compare the non-centered, orthonormal 2D FFT with ``numpy.fft.fft2``.

    Args:
        shape: shape of the input; the trailing real/imaginary dimension is appended here

    Returns:
        None
    """
    data = create_input(shape + [2])

    expected = np.fft.fft2(tensor_to_complex_np(data), norm="ortho")

    actual = fft2(data, centered=False, normalization="ortho", spatial_dims=[-2, -1]).numpy()
    actual = actual[..., 0] + 1j * actual[..., 1]

    if not np.allclose(actual, expected):
        raise AssertionError
def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Forward-project an image onto the coils in k-space, i.e. ``F(S * x)``.

    Parameters
    ----------
    x: Input data. Documented as torch.Tensor, shape [batch_size, n_coils, height, width, 2]; it only needs to
        broadcast against ``sens_maps`` in ``complex_mul``.
    sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2]

    Returns
    -------
    Per-coil k-space. torch.Tensor, shape [batch_size, n_coils, height, width, 2]
    """
    fft_kwargs = {
        "centered": self.fft_centered,
        "normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }
    return fft2(complex_mul(x, sens_maps), **fft_kwargs)
def forward(self, x, y, smaps, mask):
    """
    Gradient-descent data-consistency step.

    Computes ``x - data_weight * A^H(A x - y)``. Here ``A x = sum_{dim -4}(mask * F(S * x))`` and
    ``A^H r = sum_{dim -5}(F^-1(mask * r) * conj(S))``, with ``S`` the coil sensitivity maps and ``F`` the 2D FFT.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    smaps: Coil sensitivity maps.
    mask: Sampling mask.

    Returns
    -------
    The image after one gradient step on the data term.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    # Residual A x - y.
    coil_kspace = fft2(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), **fft_kwargs)
    A_x_y = torch.sum(coil_kspace * mask, -4, keepdim=True) - y

    # Adjoint A^H applied to the residual. Uses ``ifft2`` to mirror the ``fft2`` above; the previous code called
    # ``ifft2c`` with ``ifft2``'s keyword signature, unlike every other call site.
    gradD_x = torch.sum(
        complex_mul(ifft2(A_x_y * mask, **fft_kwargs), complex_conj(smaps)),
        dim=(-5),
    )
    return x - self.data_weight * gradD_x
def forward(self, x, y, mask):
    """
    Forward pass of the data-consistency block.

    The input is transformed to k-space. On sampled locations it is replaced by
    ``lambda_ * F(x) + (1 - lambda_) * y``, while unsampled locations keep ``F(x)``. The result is transformed back
    to image space.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    mask: Sampling mask.

    Returns
    -------
    Output image.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    predicted = fft2(x, **fft_kwargs)
    blended = self.lambda_ * predicted + (1 - self.lambda_) * y
    k_dc = (1 - mask) * predicted + mask * blended
    return ifft2(k_dc, **fft_kwargs)
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Each of ``self.num_iter`` iterations runs a k-space model and an image model. Unless ``self.no_dc`` is set, a
    soft data-consistency correction follows. The estimate is then re-projected to k-space for the next iteration.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data. Used to cast the k-space model output and to center-crop the result.

    Returns
    -------
    Complex tensor (``torch.view_as_complex`` of the final image estimate), cropped to the smaller of its own and
    ``target``'s size via ``center_crop_to_smallest``.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    def to_image(ksp):
        # Coil combination: inverse FFT, multiply by conj(S), sum over the coils.
        return complex_mul(ifft2(ksp, **fft_kwargs), complex_conj(sensitivity_maps)).sum(self.coil_dim)

    def to_kspace(img):
        # Expand to coils with S, FFT, and keep the tensor type of the incoming image.
        return fft2(complex_mul(img.unsqueeze(self.coil_dim), sensitivity_maps), **fft_kwargs).type(img.type())

    kspace = y.clone()
    zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)

    for idx in range(self.num_iter):
        # Weighted residual on sampled locations, taken before this iteration's k-space model.
        soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

        kspace = self.kspace_model_list[idx](kspace)
        if kspace.shape[-1] != 2:
            # Move the model's channel axis (dim 2) last and keep channels 0/1 as real/imag.
            kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
            # The original author marks this rebuild as necessary without saying why; presumably it is needed
            # because it produces a fresh contiguous real/imag tensor from the permuted view.
            kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])

        image = self.image_model_list[idx](to_image(kspace).unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

        if not self.no_dc:
            image = to_image(kspace - soft_dc - to_kspace(image))

        if idx < self.num_iter - 1:
            kspace = to_kspace(image)

    image = torch.view_as_complex(image)
    _, image = center_crop_to_smallest(target, image)
    return image
def forward(
    self,
    current_kspace: torch.Tensor,
    masked_kspace: torch.Tensor,
    sampling_mask: torch.Tensor,
    sensitivity_map: torch.Tensor,
    hidden_state: Union[None, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes forward pass of RecurrentVarNetBlock.

    The update is ``new_kspace = current_kspace - learning_rate * error + w_t``. Here ``error`` is the k-space
    residual on sampled locations. ``w_t`` is the regularizer's output: it is computed on the coil-combined image
    and projected back to multi-coil k-space.

    Parameters
    ----------
    current_kspace: Current k-space prediction. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    masked_kspace: Subsampled k-space. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    sampling_mask: Sampling mask. torch.Tensor, shape [batch_size, 1, height, width, 1]
    sensitivity_map: Coil sensitivities. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: ConvGRU hidden state. None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]

    Returns
    -------
    new_kspace: New k-space prediction. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: Next hidden state, as returned by ``self.regularizer``.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    zero = torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device)
    kspace_error = torch.where(sampling_mask == 0, zero, current_kspace - masked_kspace)

    # Coil-combined image for every 2-channel slice of the k-space, concatenated and made channels-first.
    combined = [
        complex_mul(ifft2(pair, **fft_kwargs), complex_conj(sensitivity_map)).sum(self.coil_dim)
        for pair in torch.split(current_kspace, 2, -1)
    ]
    recurrent_term = torch.cat(combined, dim=-1).permute(0, 3, 1, 2)

    # :math:`w_t`, :math:`h_{t+1}`
    recurrent_term, hidden_state = self.regularizer(recurrent_term, hidden_state)
    recurrent_term = recurrent_term.permute(0, 2, 3, 1)

    # Project each 2-channel slice of the regularizer output back to multi-coil k-space.
    projected = [
        fft2(complex_mul(pair.unsqueeze(self.coil_dim), sensitivity_map), **fft_kwargs)
        for pair in torch.split(recurrent_term, 2, -1)
    ]
    recurrent_term = torch.cat(projected, dim=-1)

    new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term
    return new_kspace, hidden_state
def update_X(self, idx, image, sensitivity_maps, y, mask):
    r"""
    Update the image with one proximal-gradient step.

    .. math::
        x_{k+1} = (1 - 2 \lambda_{k}^{I} \mu_{k} - 2 \lambda_{k}^{F} \mu_{k}) x_{k}
                  + 2 \mu_{k} (\lambda_{k}^{I} D_I(x_{k}) + \lambda_{k}^{F} F^{-1}(D_F(F x_{k})))
                  - 2 \mu_{k} A^{*}(A(x_{k}) - b),
        \quad A(x) = M F (C x)

    with :math:`\mu_k` = ``self.lr_image[idx]``, :math:`\lambda_k^I` = ``self.reg_param_I[idx]`` and
    :math:`\lambda_k^F` = ``self.reg_param_F[idx]``.

    Parameters
    ----------
    idx: int
        The current iteration index.
    image: torch.Tensor
        The predicted image.
    sensitivity_maps: torch.Tensor
        The coil sensitivity maps.
    y: torch.Tensor
        The subsampled k-space data.
    mask: torch.Tensor
        The subsampled mask.

    Returns
    -------
    image: torch.Tensor
        The updated image.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
    mu = self.lr_image[idx]
    lam_i = self.reg_param_I[idx]
    lam_f = self.reg_param_F[idx]

    # (1 - 2 * lambda_I * mu - 2 * lambda_F * mu) * x_k
    shrunk = (1 - 2 * lam_i * mu - 2 * lam_f * mu) * image

    # D_I(x_k): image-domain model.
    denoised_image = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()
    # F^-1(D_F(F x_k)): k-space model.
    image_kspace = fft2(image, **fft_kwargs).unsqueeze(self.coil_dim)
    denoised_kspace = ifft2(self.kspace_model(image_kspace).squeeze(self.coil_dim).contiguous(), **fft_kwargs)
    prior = 2 * mu * (lam_i * denoised_image + lam_f * denoised_kspace)

    # A(x_k) - b = M * F(C * x_k) - b
    forward_proj = fft2(complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps), **fft_kwargs)
    zero = torch.tensor([0.0], dtype=y.dtype).to(y.device)
    residual = torch.where(mask == 0, zero, forward_proj) - y

    # 2 * mu * A^*(A(x_k) - b)
    back_proj = complex_mul(ifft2(residual, **fft_kwargs), complex_conj(sensitivity_maps)).sum(self.coil_dim)
    data_step = 2 * mu * back_proj

    return shrunk + prior - data_step
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass of the network.

    If ``self.initializer`` is set, it builds the initial recurrent state from an input image chosen by
    ``self.initializer_initialization``: "sense", "input_image" or "zero_filled". The k-space prediction is then
    refined by ``self.num_steps`` recurrent blocks, and the result is coil-combined.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data. Used to center-crop the result.
    **kwargs: ``initial_image`` is required when ``initializer_initialization == "input_image"``.

    Returns
    -------
    Complex tensor (``torch.view_as_complex`` of the coil-combined estimate), cropped to the smaller of its own and
    ``target``'s size via ``center_crop_to_smallest``.

    Raises
    ------
    ValueError: If ``initial_image`` is missing for "input_image" initialization, or if
        ``initializer_initialization`` is not one of the supported values.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
    previous_state: Optional[torch.Tensor] = None

    if self.initializer is not None:
        if self.initializer_initialization == "sense":
            initializer_input_image = (
                complex_mul(ifft2(y, **fft_kwargs), complex_conj(sensitivity_maps))
                .sum(self.coil_dim)
                .unsqueeze(self.coil_dim)
            )
        elif self.initializer_initialization == "input_image":
            if "initial_image" not in kwargs:
                raise ValueError(
                    "`'initial_image` is required as input if initializer_initialization "
                    f"is {self.initializer_initialization}."
                )
            initializer_input_image = kwargs["initial_image"].unsqueeze(self.coil_dim)
        elif self.initializer_initialization == "zero_filled":
            initializer_input_image = ifft2(y, **fft_kwargs)
        else:
            # Previously an unknown value fell through and crashed later with UnboundLocalError.
            raise ValueError(
                f"Unsupported initializer_initialization: {self.initializer_initialization}. "
                "Expected one of 'sense', 'input_image', 'zero_filled'."
            )

        # Sum over the coil axis this model uses (was hard-coded to dim 1), then channels-first for the initializer.
        previous_state = self.initializer(
            fft2(initializer_input_image, **fft_kwargs).sum(self.coil_dim).permute(0, 3, 1, 2)
        )

    kspace_prediction = y.clone()
    for step in range(self.num_steps):
        block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
        kspace_prediction, previous_state = block(
            kspace_prediction,
            y,
            mask,
            sensitivity_maps,
            previous_state,
        )

    eta = ifft2(kspace_prediction, **fft_kwargs)
    eta = coil_combination(eta, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
    eta = torch.view_as_complex(eta)
    _, eta = center_crop_to_smallest(target, eta)
    return eta
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Primal-dual unrolling. The dual buffer starts as ``num_dual`` copies of ``y``. The primal buffer starts as
    ``num_primal`` copies of the coil-combined zero-filled image. Each iteration updates the dual buffer from the
    forward projection of primal channels 2:4. It then updates the primal buffer from the back projection of dual
    channels 0:2.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data. Used to center-crop the result.

    Returns
    -------
    Magnitude (``sqrt`` of the summed squares of the last dimension) of primal channels 0:2, cropped to the smaller
    of its own and ``target``'s size via ``center_crop_to_smallest``.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    def masked(t):
        # Zero the unsampled entries; the zero follows t's dtype and device.
        return torch.where(mask == 0, torch.tensor([0.0], dtype=t.dtype).to(t.device), t)

    def back_project(ksp):
        # Inverse FFT, multiply by conj(S), sum over the coils.
        return complex_mul(ifft2(ksp, **fft_kwargs), complex_conj(sensitivity_maps)).sum(self.coil_dim)

    input_image = back_project(masked(y))
    dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
    primal_buffer = torch.cat([input_image] * self.num_primal, -1).to(y.device)

    for idx in range(self.num_iter):
        # Dual update.
        f_2 = primal_buffer[..., 2:4].clone()
        projected = fft2(complex_mul(f_2.unsqueeze(self.coil_dim), sensitivity_maps), **fft_kwargs).type(f_2.type())
        f_2 = masked(projected)
        dual_buffer = self.dual_net[idx](dual_buffer, f_2, y)

        # Primal update.
        h_1 = back_project(masked(dual_buffer[..., 0:2].clone()))
        primal_buffer = self.primal_net[idx](primal_buffer, h_1)

    output = primal_buffer[..., 0:2]
    output = (output**2).sum(-1).sqrt()
    _, output = center_crop_to_smallest(target, output)
    return output
def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
    r"""
    Update the coil sensitivity maps with one proximal-gradient step.

    .. math::
        C_{k+1} = (1 - 2 \lambda_{k}^{C} \nu_{k}) C_{k}
                  + 2 \lambda_{k}^{C} \nu_{k} D_{C}(F^{-1}(b))
                  - 2 \nu_{k} F^{-1}(M^{T} (M F (C_{k} x_{k}) - b)) x_{k}^{*}

    with :math:`\nu_k` = ``self.lr_sens[idx]`` and :math:`\lambda_k^C` = ``self.reg_param_C[idx]``.

    Parameters
    ----------
    idx: int
        The current iteration index.
    DC_sens: torch.Tensor
        Output of the sensitivity-map model, i.e. the D_C(F^-1(b)) term.
    sensitivity_maps: torch.Tensor
        The coil sensitivity maps C_k.
    image: torch.Tensor
        The predicted image x_k (no coil dimension; it is inserted at ``self.coil_dim``).
    y: torch.Tensor
        The subsampled k-space data b.
    mask: torch.Tensor
        The subsampling mask M.

    Returns
    -------
    sensitivity_maps: torch.Tensor
        The updated coil sensitivity maps.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
    nu = self.lr_sens[idx]
    lam_c = self.reg_param_C[idx]
    zero = torch.tensor([0.0], dtype=y.dtype).to(y.device)

    # (1 - 2 * lambda_C * nu) * C_k
    sense_term_1 = (1 - 2 * lam_c * nu) * sensitivity_maps
    # 2 * lambda_C * nu * D_C(F^-1(b))
    sense_term_2 = 2 * lam_c * nu * DC_sens

    # A(x_k) = M * F(C * x_k)
    sense_term_3_A = fft2(complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps), **fft_kwargs)
    sense_term_3_A = torch.where(mask == 0, zero, sense_term_3_A)

    # M^T (A(x_k) - b): keep the residual on the sampled locations. The previous code zeroed where mask == 1,
    # which discarded exactly the sampled residual and left only -b on the unsampled locations.
    sense_term_3_mask = torch.where(mask == 0, zero, sense_term_3_A - y)

    sense_term_3_backward = ifft2(sense_term_3_mask, **fft_kwargs)
    # Complex product with conj(x_k) on the (real, imag) representation. A plain ``*`` would multiply the real and
    # imaginary channels independently instead.
    sense_term_3 = 2 * nu * complex_mul(sense_term_3_backward, complex_conj(image).unsqueeze(self.coil_dim))

    return sense_term_1 + sense_term_2 - sense_term_3
def forward(
    self,
    pred: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    eta: torch.Tensor = None,
    hx: torch.Tensor = None,
    sigma: float = 1.0,
    keep_eta: bool = False,
) -> Tuple[Any, Union[list, torch.Tensor, None]]:
    """
    Forward pass of the RIMBlock.

    For ``self.time_steps`` steps, the log-likelihood gradient is fed through the recurrent layers and the output of
    ``self.final_layer`` is added to the estimate. If ``self.no_dc`` is set, the list of per-step estimates is
    returned. Otherwise a soft data-consistency correction is applied and per-step k-space estimates are returned.

    Parameters
    ----------
    pred: Predicted k-space (or a list, whose last element is used, detached).
    masked_kspace: Subsampled k-space.
    sense: Coil sensitivity maps.
    mask: Sample mask (used as the condition of ``torch.where``).
    eta: Initial guess for the eta. Recomputed from ``pred`` when None or with fewer than 3 dimensions.
    hx: Initial hidden states, one per non-zero entry of ``self.recurrent_filters``; zeros when None. A provided
        list is updated in place.
    sigma: Noise level.
    keep_eta: If True, ``pred`` itself is used as the initial eta instead of its coil-combined image.

    Returns
    -------
    The list of per-step estimates (``no_dc``) or of per-step k-space estimates, and None.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    if hx is None:
        batch_size = masked_kspace.size(0)
        spatial_shape = masked_kspace.size()[2:-1]
        hx = [masked_kspace.new_zeros((batch_size, f, *spatial_shape)) for f in self.recurrent_filters if f != 0]

    if isinstance(pred, list):
        pred = pred[-1].detach()

    if eta is None or eta.ndim < 3:
        if keep_eta:
            eta = pred
        else:
            eta = torch.sum(complex_mul(ifft2(pred, **fft_kwargs), complex_conj(sense)), self.coil_dim)

    etas = []
    for _ in range(self.time_steps):
        grad_eta = log_likelihood_gradient(
            eta,
            masked_kspace,
            sense,
            mask,
            sigma=sigma,
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
            coil_dim=self.coil_dim,
        ).contiguous()

        # Chain the recurrent layers; each layer's new hidden state is the next layer's input.
        for layer_idx, convrnn in enumerate(self.layers):
            hx[layer_idx] = convrnn(grad_eta, hx[layer_idx])
            grad_eta = hx[layer_idx]

        eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1)
        etas.append(eta)

    if self.no_dc:
        return etas, None

    soft_dc = torch.where(mask, pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight
    current_kspace = [
        masked_kspace - soft_dc - fft2(complex_mul(e.unsqueeze(self.coil_dim), sense), **fft_kwargs)
        for e in etas
    ]
    return current_kspace, None
def __call__(
    self,
    kspace: np.ndarray,
    sensitivity_map: np.ndarray,
    mask: np.ndarray,
    eta: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_idx: int,
) -> Tuple[
    torch.Tensor,
    Union[Union[List, torch.Tensor], torch.Tensor],
    Union[Optional[torch.Tensor], Any],
    Union[List, Any],
    Union[Optional[torch.Tensor], Any],
    Union[torch.Tensor, Any],
    str,
    int,
    Union[Union[List, torch.Tensor], Any],
]:
    """
    Apply the data transform.

    The steps are:
    1. Convert the inputs to tensors and optionally zero-fill k-space (and the sensitivity maps, in k-space).
    2. Build the target.
    3. Optionally crop.
    4. Undersample with the loaded mask or ``self.mask_func``.
    5. Optionally crop after masking, and optionally normalize.

    Parameters
    ----------
    kspace: The kspace.
    sensitivity_map: The sensitivity map. May be None or empty (e.g. when maps are estimated later).
    mask: The mask. Only used when ``self.mask_func`` is None; None means fully sampled (mask of ones, acc = 1).
    eta: The initial estimation. May be None or empty.
    target: The target. Used only when the coil-combination method is neither "RSS" nor "SENSE".
    attrs: The attributes. Optionally provides "padding_left", "padding_right", "target" or "target_rss".
    fname: The file name. Used to derive the masking seed when ``self.use_seed`` is set.
    slice_idx: The slice number.

    Returns
    -------
    ``(kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc)``. When ``self.mask_func``
    is a list, ``masked_kspace``, ``mask`` and ``acc`` are lists with one entry per mask function.

    Raises
    ------
    ValueError: If no target can be found or computed.
    """
    fft_kwargs = dict(centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)

    def _is_set(option) -> bool:
        # Options are treated as disabled when None, "" or the string "None".
        return option is not None and option not in ("", "None")

    def _has_data(x) -> bool:
        # For numpy arrays ``size`` is the element count. For torch tensors it is a bound method (always != 0), so
        # this reduces to a None check for tensors.
        return x is not None and x.size != 0

    def _crop_image_domain(x):
        # Crop an image-domain tensor either in k-space (and transform back) or directly.
        if self.kspace_crop:
            return ifft2(complex_center_crop(fft2(x, **fft_kwargs), self.crop_size), **fft_kwargs)
        return complex_center_crop(x, self.crop_size)

    def _crop_kspace(ksp):
        # Crop a k-space tensor either directly or in image space (and transform back).
        if self.kspace_crop:
            return complex_center_crop(ksp, self.crop_size)
        return fft2(complex_center_crop(ifft2(ksp, **fft_kwargs), self.crop_size), **fft_kwargs)

    def _normalize_kspace(ksp):
        # Divide the image-space data by its maximum absolute value and return to k-space.
        if self.fft_normalization in ("orthogonal", "orthogonal_norm_only", "ortho"):
            imspace = ifft2(ksp, **fft_kwargs)
            return fft2(imspace / torch.max(torch.abs(imspace)), **fft_kwargs)
        if self.fft_normalization in ("fft_norm", "backward"):
            # NOTE(review): these modes only round-trip through the FFT without rescaling; kept as before.
            return fft2(ifft2(ksp, **fft_kwargs), **fft_kwargs)
        imspace = torch.fft.ifftn(torch.view_as_complex(ksp), dim=[-2, -1], norm=None)
        imspace = imspace / torch.max(torch.abs(imspace))
        return torch.view_as_real(torch.fft.fftn(imspace, dim=[-2, -1], norm=None))

    kspace = to_tensor(kspace)

    # This condition is necessary in case of auto estimation of sense maps.
    if _has_data(sensitivity_map):
        sensitivity_map = to_tensor(sensitivity_map)

    # Apply zero-filling on kspace.
    if _is_set(self.kspace_zero_filling_size):
        padding_top = np.floor_divide(abs(int(self.kspace_zero_filling_size[0]) - kspace.shape[1]), 2)
        padding_bottom = padding_top
        padding_left = np.floor_divide(abs(int(self.kspace_zero_filling_size[1]) - kspace.shape[2]), 2)
        padding_right = padding_left
        padding = (padding_left, padding_right, padding_top, padding_bottom)

        kspace = torch.view_as_real(
            torch.nn.functional.pad(torch.view_as_complex(kspace), pad=padding, mode="constant", value=0)
        )

        # Zero-fill the sensitivity maps in k-space too, so they stay on the kspace grid. Skipped when no maps
        # were given (this previously crashed on None / empty input).
        if _has_data(sensitivity_map):
            sensitivity_map = fft2(sensitivity_map, **fft_kwargs)
            sensitivity_map = torch.view_as_real(
                torch.nn.functional.pad(
                    torch.view_as_complex(sensitivity_map),
                    pad=padding,
                    mode="constant",
                    value=0,
                )
            )
            sensitivity_map = ifft2(sensitivity_map, **fft_kwargs)

    # Initial estimation.
    eta = to_tensor(eta) if _has_data(eta) else torch.tensor([])

    # Target: computed from the kspace for "RSS" / "SENSE", otherwise taken from the input or the metadata.
    if self.coil_combination_method.upper() == "RSS":
        target = rss(ifft2(kspace, **fft_kwargs), dim=self.coil_dim)
    elif self.coil_combination_method.upper() == "SENSE":
        # NOTE(review): without sensitivity maps the raw ``target`` input is left unconverted here.
        if _has_data(sensitivity_map):
            target = sense(ifft2(kspace, **fft_kwargs), sensitivity_map, dim=self.coil_dim)
    elif _has_data(target):
        target = to_tensor(target)
    elif "target" in attrs or "target_rss" in attrs:
        # Read whichever key is present (previously "target" was read unconditionally -> KeyError).
        target = torch.tensor(attrs["target"] if "target" in attrs else attrs["target_rss"])
    else:
        raise ValueError("No target found")

    target = torch.view_as_complex(target)
    target = torch.abs(target / torch.max(torch.abs(target)))

    seed = tuple(map(ord, fname)) if self.use_seed else None
    # Each bound is looked up under its own key (acq_end previously tested "padding_left").
    acq_start = attrs["padding_left"] if "padding_left" in attrs else 0
    acq_end = attrs["padding_right"] if "padding_right" in attrs else 0

    if _is_set(self.crop_size):
        # Never crop larger than the target.
        h = min(int(self.crop_size[0]), target.shape[0])
        w = min(int(self.crop_size[1]), target.shape[1])

        # NOTE(review): this overwrites the configured crop size on the instance, so later samples are cropped to
        # at most this size.
        self.crop_size = (int(h), int(w))

        target = center_crop(target, self.crop_size)
        if _has_data(sensitivity_map):
            sensitivity_map = _crop_image_domain(sensitivity_map)
        if eta is not None and eta.ndim > 2:
            eta = _crop_image_domain(eta)

    # Cropping before masking will maintain the shape of original kspace intact for masking.
    if _is_set(self.crop_size) and self.crop_before_masking:
        kspace = _crop_kspace(kspace)

    # Undersample kspace if undersampling is enabled.
    if self.mask_func is None:
        masked_kspace = kspace
        acc = torch.tensor([np.around(mask.size / mask.sum())]) if mask is not None else torch.tensor([1])

        if mask is None:
            mask = torch.ones(
                [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
            )
        else:
            mask = torch.from_numpy(mask)
            if mask.shape[0] == masked_kspace.shape[2]:  # type: ignore
                mask = mask.permute(1, 0)
            elif mask.shape[0] != masked_kspace.shape[1]:  # type: ignore
                mask = torch.ones(
                    [masked_kspace.shape[-3], masked_kspace.shape[-2]], dtype=torch.float32  # type: ignore
                )

        # ``mask`` is a tensor on every path above, so reshape with tensor ops. Calling torch.from_numpy on it
        # raised TypeError for 2D masks with a singleton first dimension.
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)

        if mask.shape[-2] == 1:  # 1D mask
            mask = mask.unsqueeze(0).unsqueeze(-1)
        else:  # 2D mask
            # Crop loaded mask.
            if _is_set(self.crop_size):
                mask = center_crop(mask, self.crop_size)
            mask = mask.unsqueeze(0).unsqueeze(-1)

        if self.shift_mask:
            mask = torch.fft.fftshift(mask, dim=[-3, -2])

        masked_kspace = masked_kspace * mask
        mask = mask.byte()
    elif isinstance(self.mask_func, list):
        masked_kspace, mask, acc = [], [], []
        for mask_func in self.mask_func:
            _masked_kspace, _mask, _acc = apply_mask(
                kspace,
                mask_func,
                seed,
                (acq_start, acq_end),
                shift=self.shift_mask,
                half_scan_percentage=self.half_scan_percentage,
                center_scale=self.mask_center_scale,
            )
            masked_kspace.append(_masked_kspace)
            mask.append(_mask.byte())
            acc.append(_acc)
    else:
        masked_kspace, mask, acc = apply_mask(
            kspace,
            self.mask_func[0],  # type: ignore
            seed,
            (acq_start, acq_end),
            shift=self.shift_mask,
            half_scan_percentage=self.half_scan_percentage,
            center_scale=self.mask_center_scale,
        )
        mask = mask.byte()

    # Cropping after masking.
    if _is_set(self.crop_size) and not self.crop_before_masking:
        if isinstance(masked_kspace, list):
            # One entry per mask function; the lists were previously passed to the crop helpers and crashed.
            masked_kspace = [_crop_kspace(ksp) for ksp in masked_kspace]
            mask = [center_crop(m.squeeze(-1), self.crop_size).unsqueeze(-1) for m in mask]
        else:
            masked_kspace = _crop_kspace(masked_kspace)
            mask = center_crop(mask.squeeze(-1), self.crop_size).unsqueeze(-1)

    # Normalize by the max value.
    if self.normalize_inputs:
        if isinstance(masked_kspace, list):
            masked_kspace = [_normalize_kspace(ksp) for ksp in masked_kspace]
        else:
            masked_kspace = _normalize_kspace(masked_kspace)

        if _has_data(sensitivity_map):
            sensitivity_map = sensitivity_map / torch.max(torch.abs(sensitivity_map))
        if _has_data(eta) and eta.ndim > 2:
            eta = eta / torch.max(torch.abs(eta))
        target = target / torch.max(torch.abs(target))

    return kspace, masked_kspace, sensitivity_map, mask, eta, target, fname, slice_idx, acc