def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Magnitude of the ``self.xpdnet`` reconstruction (reduced over the last
        axis), center-cropped against ``target``.
    """
    eta = self.xpdnet(y, sensitivity_maps, mask)
    # Assumes the last axis of the xpdnet output holds the (real, imaginary)
    # pair -- TODO confirm. The modulus is sqrt(re**2 + im**2); taking the
    # square root before summing would yield |re| + |im| instead.
    eta = (eta**2).sum(-1).sqrt()
    _, eta = center_crop_to_smallest(target, eta)
    return eta
def process_intermediate_pred(self, pred, sensitivity_maps, target):
    """
    Convert an intermediate multi-coil estimate into a cropped complex image.

    The estimate is passed through an inverse 2D FFT, combined across coils
    with the sensitivity maps, viewed as a complex tensor and finally
    center-cropped against ``target``.

    Parameters
    ----------
    pred: Intermediate prediction (input to the inverse FFT).
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    target: Reference whose size the output is cropped to.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Coil-combined prediction converted with ``torch.view_as_complex`` and
        center-cropped to ``target``.
    """
    fft_options = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    per_coil = ifft2(pred, **fft_options)
    combined = coil_combination(
        per_coil,
        sensitivity_maps,
        method=self.coil_combination_method,
        dim=self.coil_dim,
    )
    _, cropped = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return cropped
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    target: torch.Tensor = None,
) -> Union[list, Any]:
    """
    Forward pass of PICS.

    The BART ``pics`` reconstruction is currently disabled, so this method
    returns a zero tensor shaped like ``sensitivity_maps`` (center-cropped to
    ``target`` when one is given).

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    target: Optional target data. When provided, the prediction is
        center-cropped to its size; when None, no cropping is done.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Predicted data (currently all zeros).
    """
    # NOTE: the BART-based reconstruction is disabled. It was
    #   bart.bart(1, f"pics -d0 [-g] -S -R W:7:0:{self.reg_wt} -i {self.num_iters}", y, sensitivity_maps)[0]
    # with "-g" added when self._device is CUDA. Until it is restored this
    # method returns zeros shaped like ``sensitivity_maps``.
    pred = torch.zeros_like(sensitivity_maps)
    if target is not None:
        # ``target`` defaults to None; only crop when a size reference exists,
        # since the cropping helper needs a tensor to read a size from.
        _, pred = center_crop_to_smallest(target, pred)
    return pred
def test_center_crop_to_smallest(x, y):
    """
    Verify that center_crop_to_smallest returns two outputs of equal shape.

    Args:
        x: First input array.
        y: Second input array.

    Returns:
        None

    Raises:
        AssertionError: if the two cropped outputs differ in shape.
    """
    cropped_x, cropped_y = center_crop_to_smallest(x, y)
    if cropped_x.shape != cropped_y.shape:
        raise AssertionError
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Runs the measured k-space through every module in ``self.cascades``,
    then inverse-FFTs, coil-combines and crops the result.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Final complex-valued estimate, center-cropped to ``target``.
    """
    kspace_estimate = y.clone()
    # Each cascade refines the running estimate, always seeing the original data.
    for cascade in self.cascades:
        kspace_estimate = cascade(kspace_estimate, y, sensitivity_maps, mask)

    coil_images = ifft2(
        kspace_estimate,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = coil_combination(
        coil_images,
        sensitivity_maps,
        method=self.coil_combination_method,
        dim=self.coil_dim,
    )
    _, final_image = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return final_image
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Optionally re-estimates the sensitivity maps with ``self.sens_net``, runs
    ``self.model``, then inverse-FFTs, coil-combines and crops the result.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. Replaced by the output of
        ``self.sens_net`` when ``self.use_sens_net`` is set.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Final complex-valued estimate, center-cropped to ``target``.
    """
    if self.use_sens_net:
        sensitivity_maps = self.sens_net(y, mask)

    model_output = self.model(y, sensitivity_maps, mask)
    coil_images = ifft2(
        model_output,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = coil_combination(
        coil_images,
        sensitivity_maps,
        method=self.coil_combination_method,
        dim=self.coil_dim,
    )
    image = torch.view_as_complex(combined)
    _, image = center_crop_to_smallest(target, image)
    return image
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Inverse-FFTs the k-space, optionally standardizes it, applies the U-Net to
    each coil, then coil-combines and crops the result.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. Not used by this method.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Final complex-valued estimate, center-cropped to ``target``.
    """
    coil_images = ifft2(
        y,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    if hasattr(self, "standardization"):
        coil_images = self.standardization(coil_images, sensitivity_maps)

    # Move the trailing size-2 axis to position 2 for the U-Net, then move it
    # back to the end afterwards.
    channels_first = coil_images.permute(0, 1, 4, 2, 3)
    unet_result = self._compute_model_per_coil(self.unet, channels_first)
    channels_last = unet_result.permute(0, 1, 3, 4, 2)

    combined = coil_combination(
        channels_last,
        sensitivity_maps,
        method=self.coil_combination_method,
        dim=self.coil_dim,
    )
    _, output_image = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return output_image
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Estimates sensitivity maps with ``self.sens_net``, forms an initial
    coil-combined image, then alternates ``self.update_C`` (sensitivity maps)
    and ``self.update_X`` (image) for ``self.num_iter`` iterations.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. Ignored: overwritten by the
        output of ``self.sens_net``.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Final complex-valued image, center-cropped to ``target``.
    """
    # Keep the network-estimated maps unchanged for update_C; iterate on a copy.
    estimated_maps = self.sens_net(y, mask)
    sensitivity_maps = estimated_maps.clone()

    coil_images = ifft2(
        y,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    image = complex_mul(coil_images, complex_conj(sensitivity_maps)).sum(self.coil_dim)

    for iteration in range(self.num_iter):
        sensitivity_maps = self.update_C(iteration, estimated_maps, sensitivity_maps, image, y, mask)
        image = self.update_X(iteration, image, sensitivity_maps, y, mask)

    _, image = center_crop_to_smallest(target, torch.view_as_complex(image))
    return image
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    target: torch.Tensor = None,
) -> Union[list, Any]:
    """
    Forward pass of the zero-filled method.

    Inverse-FFTs the subsampled k-space and combines coils with the
    sensitivity maps; no learned component is involved.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. Not used by this method.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    target: Optional target data. When provided, the prediction is
        center-cropped to its size; when None, no cropping is done.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    pred: torch.Tensor
        Predicted data, passed through ``check_stacked_complex``.
    """
    coil_images = ifft2(
        y,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    pred = coil_combination(
        coil_images,
        sensitivity_maps,
        method=self.coil_combination_method.upper(),
        dim=self.coil_dim,
    )
    pred = check_stacked_complex(pred)
    if target is not None:
        # ``target`` defaults to None; only crop when a size reference exists,
        # since the cropping helper needs a tensor to read a size from.
        _, pred = center_crop_to_smallest(target, pred)
    return pred
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the primal-dual network.

    Initializes a dual (k-space) buffer from ``y`` and a primal (image) buffer
    from the sensitivity-weighted back-projection of ``y``. It then runs
    ``self.num_iter`` alternating dual/primal updates.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask; locations where it equals 0 are zeroed.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    output: torch.Tensor
        sqrt of the sum of squares over the last axis of the first two primal
        channels, center-cropped to ``target``.
    """
    fft_options = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    def back_project(kspace):
        # Zero unsampled locations, inverse FFT, weight by conj(maps), sum coils.
        zero = torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device)
        masked = torch.where(mask == 0, zero, kspace)
        return complex_mul(ifft2(masked, **fft_options), complex_conj(sensitivity_maps)).sum(self.coil_dim)

    def forward_project(image):
        # Expand to coils with the maps, FFT, then zero unsampled locations.
        zero = torch.tensor([0.0], dtype=image.dtype).to(image.device)
        kspace = fft2(
            complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
            **fft_options,
        ).type(image.type())
        return torch.where(mask == 0, zero, kspace)

    initial_image = back_project(y)
    dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
    primal_buffer = torch.cat([initial_image] * self.num_primal, -1).to(y.device)

    for idx in range(self.num_iter):
        # Dual step: uses the second primal channel pair (2:4).
        dual_buffer = self.dual_net[idx](dual_buffer, forward_project(primal_buffer[..., 2:4].clone()), y)
        # Primal step: uses the first dual channel pair (0:2).
        primal_buffer = self.primal_net[idx](primal_buffer, back_project(dual_buffer[..., 0:2].clone()))

    magnitude = (primal_buffer[..., 0:2] ** 2).sum(-1).sqrt()
    _, magnitude = center_crop_to_smallest(target, magnitude)
    return magnitude
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass of the network.

    If ``self.initializer`` is set, it builds an initial hidden state from an
    initial image chosen by ``self.initializer_initialization``. It then runs
    ``self.num_steps`` recurrent blocks on the k-space estimate. Finally it
    inverse-FFTs, coil-combines and crops the result.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data; the output is center-cropped to match its size.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    kwargs: Extra inputs. ``initial_image`` is required when
        ``self.initializer_initialization == "input_image"``.

    Returns
    -------
    pred: torch.Tensor
        Final complex-valued estimate, center-cropped to ``target``.

    Raises
    ------
    ValueError
        If ``initial_image`` is missing for the "input_image" initialization,
        or if ``self.initializer_initialization`` is not one of "sense",
        "input_image" or "zero_filled" while an initializer is set.
    """
    previous_state: Optional[torch.Tensor] = None

    if self.initializer is not None:
        if self.initializer_initialization == "sense":
            # Sensitivity-weighted coil sum, with a singleton coil axis restored.
            initializer_input_image = (
                complex_mul(
                    ifft2(
                        y,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                )
                .sum(self.coil_dim)
                .unsqueeze(self.coil_dim)
            )
        elif self.initializer_initialization == "input_image":
            if "initial_image" not in kwargs:
                raise ValueError(
                    "`initial_image` is required as input if initializer_initialization "
                    f"is {self.initializer_initialization}."
                )
            initializer_input_image = kwargs["initial_image"].unsqueeze(self.coil_dim)
        elif self.initializer_initialization == "zero_filled":
            initializer_input_image = ifft2(
                y,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        else:
            # Without this branch an unknown mode surfaced as an UnboundLocalError below.
            raise ValueError(
                f"Unknown initializer_initialization {self.initializer_initialization!r}; "
                "expected one of 'sense', 'input_image', 'zero_filled'."
            )

        # NOTE(review): sums over axis 1 rather than self.coil_dim -- confirm these coincide.
        previous_state = self.initializer(
            fft2(
                initializer_input_image,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            .sum(1)
            .permute(0, 3, 1, 2)
        )

    kspace_prediction = y.clone()
    for step in range(self.num_steps):
        # Without parameter sharing each step has its own block; otherwise block 0 is reused.
        block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
        kspace_prediction, previous_state = block(
            kspace_prediction,
            y,
            mask,
            sensitivity_maps,
            previous_state,
        )

    eta = ifft2(
        kspace_prediction,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    eta = coil_combination(eta, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
    eta = torch.view_as_complex(eta)
    _, eta = center_crop_to_smallest(target, eta)
    return eta
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    For each of ``self.num_iter`` iterations it does the following:
    - apply the k-space model;
    - coil-combine into an image;
    - apply the image model;
    - unless ``self.no_dc`` is set, apply a soft data-consistency correction
      weighted by ``self.dc_weight``.
    Except after the last iteration, the image is projected back to
    multi-coil k-space for the next k-space model.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask; nonzero entries mark locations used for the DC term.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]
    target: Target data. Used to center-crop the output and, when the k-space
        model output must be reshaped, as the ``.to(...)`` reference.
        torch.Tensor, shape [batch_size, n_x, n_y, 2]

    Returns
    -------
    image: torch.Tensor
        Final image converted with ``torch.view_as_complex`` and center-cropped
        to ``target``.
    """
    kspace = y.clone()
    zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)
    for idx in range(self.num_iter):
        # Residual to the measured data at sampled locations only, scaled by dc_weight.
        # Computed from the k-space estimate *before* this iteration's k-space model.
        soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

        kspace = self.kspace_model_list[idx](kspace)
        if kspace.shape[-1] != 2:
            # Presumably the model returned a channel-first layout; move axis 2 to the end.
            kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
            # NOTE(review): rebuilds the trailing (real, imag) pair through a complex tensor.
            # For real-valued input this keeps channels 0 and 1 of the last axis as a new
            # contiguous tensor; the original author noted "this is necessary, but why?".
            # Possibly the FFT below needs a contiguous size-2 last axis -- TODO confirm.
            kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])

        # Coil-combine: per-coil inverse FFT, multiply by conj(sensitivity maps), sum over coils.
        image = complex_mul(
            ifft2(
                kspace,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        # The image model sees a singleton coil axis, which is removed again afterwards.
        image = self.image_model_list[idx](image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

        if not self.no_dc:
            # Project the image back to multi-coil k-space (``image`` holds k-space data
            # until the ifft2 below).
            image = fft2(
                complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type())
            image = kspace - soft_dc - image
            # Back to a coil-combined image.
            image = complex_mul(
                ifft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)

        if idx < self.num_iter - 1:
            # Prepare multi-coil k-space input for the next iteration's k-space model.
            kspace = fft2(
                complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type())

    image = torch.view_as_complex(image)
    _, image = center_crop_to_smallest(target, image)
    return image