def A(x):
    """
    Masked multi-coil forward operator applied to ``x``.

    Relies on ``smaps``, ``mask`` and ``ctx`` being defined in an enclosing scope. ``x`` is expanded to
    the shape of ``smaps`` and multiplied by the maps. The result goes through ``fft2`` with the FFT
    settings stored on ``ctx``, is multiplied by ``mask``, and is summed over dim -4, which is kept as
    a size-1 axis.
    """
    weighted = complex_mul(x.expand_as(smaps), smaps)
    kspace = fft2(
        weighted,
        centered=ctx.fft_centered,
        normalization=ctx.fft_normalization,
        spatial_dims=ctx.spatial_dims,
    )
    return torch.sum(kspace * mask, dim=-4, keepdim=True)
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Steps:
    1. A SENSE estimate is built from ``y``: ``ifft2``, conjugate-sensitivity weighting, and a sum
       over ``self.coil_dim``.
    2. The estimate is passed to ``self.model`` together with ``y``, the maps and the mask.
    3. The model output is coil-combined the same way.
    4. The result is converted to a complex tensor and center-cropped against ``target``.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Ignored: the SENSE estimate computed from ``y`` is used instead.
    target: Target data; only used to center-crop the output.

    Returns
    -------
    Complex-valued, center-cropped reconstruction (a single tensor).
    """
    zero_filled = ifft2(
        y,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    sense_estimate = torch.sum(complex_mul(zero_filled, complex_conj(sensitivity_maps)), self.coil_dim)

    coil_output = self.model(sense_estimate, y, sensitivity_maps, mask)
    combined = torch.sum(complex_mul(coil_output, complex_conj(sensitivity_maps)), self.coil_dim)

    _, cropped = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return cropped
def complexDot(data1, data2):
    """
    Batched complex inner product ``sum(data1 * conj(data2))``.

    Complex values are stored as (real, imag) pairs in the last dimension. The product is summed
    over all non-batch elements, giving a tensor of shape [batch, 2].
    """
    batch = data1.shape[0]
    product = complex_mul(data1, complex_conj(data2))
    real_part, imag_part = torch.unbind(product, dim=-1)
    real_sum = torch.sum(real_part.view(batch, -1), dim=-1)
    imag_sum = torch.sum(imag_part.view(batch, -1), dim=-1)
    return torch.stack([real_sum, imag_sum], -1)
def forward(self, coil_images: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
    """
    Pair the coil-combined image with a per-coil residual.

    ``combined = sum_c(coil_images * conj(S_c))`` over ``self.coil_dim`` and
    ``residual_c = combined - combined * S_c``. For every coil ``c``, ``combined`` and
    ``residual_c`` are concatenated along ``self.channel_dim``. The per-coil results are then
    stacked back along ``self.coil_dim``.

    NOTE(review): the residual subtracts ``S_c * combined`` from ``combined`` rather than from
    ``coil_images`` — confirm this is intended.
    """
    combined = complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)
    combined_per_coil = combined.unsqueeze(self.coil_dim)
    residual = combined_per_coil - complex_mul(combined_per_coil, sensitivity_map)

    stacked = []
    for coil in range(coil_images.size(self.coil_dim)):
        pair = torch.cat([combined, residual.select(self.coil_dim, coil)], self.channel_dim)
        stacked.append(pair.unsqueeze(self.coil_dim))
    return torch.cat(stacked, self.coil_dim)
def AT(x):
    """
    Adjoint of the masked multi-coil forward operator.

    Relies on ``mask``, ``smaps``, ``fft_centered``, ``fft_normalization`` and ``spatial_dims``
    being defined in an enclosing scope. Computes ``ifft2(x * mask)``, multiplies by ``conj(smaps)``
    and sums over dim -5.
    """
    coil_images = ifft2(
        x * mask,
        centered=fft_centered,
        normalization=fft_normalization,
        spatial_dims=spatial_dims,
    )
    return torch.sum(complex_mul(coil_images, complex_conj(smaps)), dim=-5)
def _forward_operator(self, image, sampling_mask, sensitivity_map):
    """
    Forward operator: image -> masked multi-coil k-space.

    Steps:
    1. Multiply ``image`` (unsqueezed at ``self.coil_dim``) by ``sensitivity_map``.
    2. Apply ``fft2`` and cast the result to ``image``'s tensor type.
    3. Zero every location where ``sampling_mask == 0``.
    """
    coil_images = complex_mul(image.unsqueeze(self.coil_dim), sensitivity_map)
    kspace = fft2(
        coil_images,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    ).type(image.type())
    zero = torch.tensor([0.0], dtype=image.dtype).to(image.device)
    return torch.where(sampling_mask == 0, zero, kspace)
def forward(self, x, y, smaps, mask):
    """
    Forward pass of the data-consistency block.

    Steps:
    1. Project ``x`` to coil-summed k-space.
    2. At sampled locations, blend the projection with ``y`` using ``self.alpha``; elsewhere keep
       the projection.
    3. Map the result back to the image domain.
    4. Return a ``self.beta`` blend of ``x`` and that data-consistent image.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    smaps: Coil sensitivity maps.
    mask: Sampling mask.

    Returns
    -------
    Output image.
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    coil_images = complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps)
    predicted_kspace = torch.sum(fft2(coil_images, **fft_opts), -4, keepdim=True)

    blended = self.alpha * predicted_kspace + (1 - self.alpha) * y
    k_dc = (1 - mask) * predicted_kspace + mask * blended

    x_dc = torch.sum(complex_mul(ifft2(k_dc, **fft_opts), complex_conj(smaps)), dim=-5)
    return self.beta * x + (1 - self.beta) * x_dc
def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
    """
    Backward operator: masked multi-coil k-space -> coil-combined image.

    Steps:
    1. Zero ``kspace`` where ``sampling_mask == 0``.
    2. Apply ``ifft2`` to a float copy.
    3. Weight by ``conj(sensitivity_map)`` and sum over ``self.coil_dim``.
    4. Cast back to the k-space tensor type.
    """
    zero = torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device)
    masked = torch.where(sampling_mask == 0, zero, kspace)
    coil_images = ifft2(
        masked.float(),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)
    return combined.type(masked.type())
def solve(x0, M, tol, max_iter):
    """
    Solve the linear system Mx=b using conjugate gradient.

    Parameters
    ----------
    x0: Right-hand side ``b`` (the residual starts at ``x0`` and the solution at zero). Complex values
        are stored as (real, imag) in the last dimension. The reshapes below assume a 5-D tensor
        [batch, ., ., ., 2] — TODO confirm.
    M: Callable applying the system operator (presumably Hermitian positive-definite — confirm).
    tol: Relative residual-energy threshold ``||r||^2 / ||b||^2``.
    max_iter: Maximum number of CG iterations.

    Returns
    -------
    x: Approximate solution, same shape as ``x0``.
    """
    nBatch = x0.shape[0]
    # NOTE(review): created with the default dtype (not x0's dtype) — confirm for non-float32 inputs.
    x = torch.zeros(x0.shape).to(x0.device)
    r = x0.clone()
    p = x0.clone()
    # Per-sample energy of b (real and imaginary parts together).
    x0x0 = (x0.pow(2)).view(nBatch, -1).sum(-1)
    # Residual energy stored as a complex number with zero imaginary part, shape [batch, 2].
    rr = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
    it = 0
    # NOTE(review): torch.min over the batch stops as soon as ANY sample reaches tol; torch.max would
    # wait until all samples converge — confirm which is intended.
    while torch.min(rr[..., 0] / x0x0) > tol and it < max_iter:
        it += 1
        q = M(p)
        data1 = rr
        data2 = ConjugateGradient.complexDot(p, q)
        re1, im1 = torch.unbind(data1, -1)
        re2, im2 = torch.unbind(data2, -1)
        # Complex division alpha = data1 / data2 = data1 * conj(data2) / |data2|^2.
        alpha = torch.stack([re1 * re2 + im1 * im2, im1 * re2 - re1 * im2], -1) / complex_abs(data2)**2
        x += complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), p.clone())
        r -= complex_mul(alpha.reshape(nBatch, 1, 1, 1, -1), q.clone())
        rr_new = torch.stack([(r.pow(2)).view(nBatch, -1).sum(-1), torch.zeros(nBatch).to(x0.device)], dim=-1)
        # beta = ||r_new||^2 / ||r||^2 (real), packed as a complex number with zero imaginary part.
        beta = torch.stack([
            rr_new[..., 0] / rr[..., 0],
            torch.zeros(nBatch).to(x0.device)
        ], dim=-1)
        p = r.clone() + complex_mul(beta.reshape(nBatch, 1, 1, 1, -1), p)
        rr = rr_new.clone()
    return x
def forward(self, x, y, smaps, mask):
    """
    Gradient-descent data-consistency step.

    Computes ``x - self.data_weight * A^H(A(x) - y)``, where:

    - ``A`` expands ``x`` with ``smaps``, applies ``fft2`` and ``mask``, and sums over dim -4 (kept).
    - ``A^H`` re-applies ``mask``, applies ``ifft2``, and combines with ``conj(smaps)`` over dim -5.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    smaps: Coil sensitivity maps.
    mask: Sampling mask.

    Returns
    -------
    The updated image ``x - data_weight * gradD_x``.
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    # A(x) - y
    A_x_y = (
        torch.sum(
            fft2(complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps), **fft_opts) * mask,
            -4,
            keepdim=True,
        )
        - y
    )
    # A^H(A(x) - y): use the same ifft2 helper and FFT settings as the forward fft2 above.
    gradD_x = torch.sum(
        complex_mul(ifft2(A_x_y * mask, **fft_opts), complex_conj(smaps)),
        dim=-5,
    )
    return x - self.data_weight * gradD_x
def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Coil-combine ``x`` with the sensitivity maps.

    Parameters
    ----------
    x: Input data (transformed with ``ifft2``; presumably multi-coil k-space).
    sens_maps: Coil Sensitivity maps.

    Returns
    -------
    SENSE coil-combined reconstruction: ``sum_c(ifft2(x) * conj(sens_maps))`` over
    ``self.coil_dim``, with the coil axis removed.
    """
    fft_opts = {
        "centered": self.fft_centered,
        "normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }
    coil_images = ifft2(x, **fft_opts)
    weighted = complex_mul(coil_images, complex_conj(sens_maps))
    return weighted.sum(self.coil_dim)
def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Expand an image into per-coil k-space using the sensitivity maps.

    Parameters
    ----------
    x: Input data (image; broadcast against ``sens_maps`` by ``complex_mul``).
    sens_maps: Coil Sensitivity maps.

    Returns
    -------
    ``fft2(x * sens_maps)`` — the SENSE expansion with the same layout as ``sens_maps``.
    """
    coil_images = complex_mul(x, sens_maps)
    return fft2(
        coil_images,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Coil-combine multi-coil data with the sensitivity maps, keeping a size-1 coil axis.

    Parameters
    ----------
    x: Input data (transformed with ``ifft2``). torch.Tensor, shape [batch_size, n_coils, height, width, 2]
    sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2]

    Returns
    -------
    SENSE reconstruction ``sum_c(ifft2(x) * conj(sens_maps))`` over ``self.coil_dim``. Because
    ``keepdim=True``, the coil axis is kept with size 1, e.g. [batch_size, 1, height, width, 2]
    when ``self.coil_dim == 1``.
    """
    x = ifft2(x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims)
    return complex_mul(x, complex_conj(sens_maps)).sum(dim=self.coil_dim, keepdim=True)
def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Expand an image into per-coil k-space using the sensitivity maps.

    Parameters
    ----------
    x: Input image; broadcast against ``sens_maps`` by ``complex_mul`` (presumably carries a size-1
        coil axis, e.g. [batch_size, 1, height, width, 2] — confirm).
    sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2]

    Returns
    -------
    ``fft2(x * sens_maps)``, torch.Tensor, shape [batch_size, n_coils, height, width, 2]
    """
    weighted = complex_mul(x, sens_maps)
    return fft2(
        weighted,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
    r"""
    Update the coil sensitivity maps.

    .. math::
        C = (1 - 2 * \lambda_{k}^{C} * ni_{k}) * C_{k}
        C = 2 * \lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b))
        A(x_{k}) = M * F * (C * x_{k})
        C = 2 * ni_{k} * F^-1(M.T * (M * F * (C * x_{k}) - b)) * x_{k}^*

    The returned maps are the first term plus the second term minus the last term.

    Parameters
    ----------
    idx: int
        The current iteration index (selects ``self.reg_param_C[idx]`` and ``self.lr_sens[idx]``).
    DC_sens: torch.Tensor
        The initial coil sensitivity maps (same layout as ``sensitivity_maps``).
    sensitivity_maps: torch.Tensor
        The coil sensitivity maps; complex values stored as (real, imag) in the last dimension.
    image: torch.Tensor
        The predicted image, without a coil axis (it is unsqueezed at ``self.coil_dim``).
    y: torch.Tensor
        The subsampled k-space data.
    mask: torch.Tensor
        The sampling mask; entries equal to 0 mark unsampled locations. Must broadcast against ``y``.

    Returns
    -------
    sensitivity_maps: torch.Tensor
        The updated coil sensitivity maps.
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    zero = torch.tensor([0.0], dtype=y.dtype).to(y.device)

    # (1 - 2 * lambda_{k}^{C} * ni_{k}) * C_{k}
    sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps
    # 2 * lambda_{k}^{C} * ni_{k} * D_{C}(F^-1(b))
    sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens

    # A(x_{k}) = M * F * (C * x_{k})
    sense_term_3_A = fft2(complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps), **fft_opts)
    sense_term_3_A = torch.where(mask == 0, zero, sense_term_3_A)

    # M.T * (A(x_{k}) - b): keep the residual only at sampled locations (mask == 0 -> 0), as in the
    # formula above. Zeroing the sampled locations instead would discard the data term entirely.
    sense_term_3_mask = torch.where(mask == 0, zero, sense_term_3_A - y)
    sense_term_3_backward = ifft2(sense_term_3_mask, **fft_opts)

    # 2 * ni_{k} * F^-1(...) * x_{k}^*: a complex product, so use complex_mul rather than an
    # element-wise ``*`` on the (real, imag) pairs.
    sense_term_3 = 2 * self.lr_sens[idx] * complex_mul(
        sense_term_3_backward, complex_conj(image).unsqueeze(self.coil_dim)
    )

    return sense_term_1 + sense_term_2 - sense_term_3
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network (alternating dual and primal updates).

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data; only used to center-crop the output.

    Returns
    -------
    Magnitude of channels 0:2 (real, imag) of the final primal buffer, center-cropped against
    ``target`` (a single tensor).
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    def zero_unsampled(values: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Replace entries where mask == 0 with a zero of ``like``'s dtype and device.
        zero = torch.tensor([0.0], dtype=like.dtype).to(like.device)
        return torch.where(mask == 0, zero, values)

    def to_image(kspace: torch.Tensor) -> torch.Tensor:
        # Zero unsampled entries, inverse FFT, conj-sensitivity weighting, sum over coils.
        coil_images = ifft2(zero_unsampled(kspace, kspace), **fft_opts)
        return complex_mul(coil_images, complex_conj(sensitivity_maps)).sum(self.coil_dim)

    input_image = to_image(y)
    dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
    primal_buffer = torch.cat([input_image] * self.num_primal, -1).to(y.device)

    for step in range(self.num_iter):
        # Dual update: forward-project channels 2:4 of the primal buffer.
        primal_part = primal_buffer[..., 2:4].clone()
        projected = fft2(
            complex_mul(primal_part.unsqueeze(self.coil_dim), sensitivity_maps),
            **fft_opts,
        ).type(primal_part.type())
        dual_buffer = self.dual_net[step](dual_buffer, zero_unsampled(projected, primal_part), y)

        # Primal update: back-project channels 0:2 of the dual buffer.
        dual_part = dual_buffer[..., 0:2].clone()
        primal_buffer = self.primal_net[step](primal_buffer, to_image(dual_part))

    estimate = primal_buffer[..., 0:2]
    magnitude = (estimate**2).sum(-1).sqrt()
    _, magnitude = center_crop_to_smallest(target, magnitude)
    return magnitude
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data; only used to center-crop the output.
    **kwargs: Must contain ``initial_image`` when ``self.initializer_initialization == "input_image"``.

    Returns
    -------
    Complex-valued, coil-combined, center-cropped reconstruction (a single tensor).

    Raises
    ------
    ValueError
        If ``initial_image`` is missing for the "input_image" initialization, or if
        ``self.initializer_initialization`` is not "sense", "input_image" or "zero_filled".
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    previous_state: Optional[torch.Tensor] = None

    if self.initializer is not None:
        if self.initializer_initialization == "sense":
            initializer_input_image = (
                complex_mul(ifft2(y, **fft_opts), complex_conj(sensitivity_maps))
                .sum(self.coil_dim)
                .unsqueeze(self.coil_dim)
            )
        elif self.initializer_initialization == "input_image":
            if "initial_image" not in kwargs:
                raise ValueError(
                    "`initial_image` is required as input if initializer_initialization "
                    f"is {self.initializer_initialization}."
                )
            initializer_input_image = kwargs["initial_image"].unsqueeze(self.coil_dim)
        elif self.initializer_initialization == "zero_filled":
            initializer_input_image = ifft2(y, **fft_opts)
        else:
            # Without this branch the name below would be unbound (UnboundLocalError).
            raise ValueError(
                f"Unknown initializer_initialization {self.initializer_initialization!r}; "
                "expected one of 'sense', 'input_image', 'zero_filled'."
            )

        # NOTE(review): sums over dim 1 (not self.coil_dim) before moving the last axis to channels.
        previous_state = self.initializer(
            fft2(initializer_input_image, **fft_opts).sum(1).permute(0, 3, 1, 2)
        )

    kspace_prediction = y.clone()
    for step in range(self.num_steps):
        block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
        kspace_prediction, previous_state = block(
            kspace_prediction,
            y,
            mask,
            sensitivity_maps,
            previous_state,
        )

    eta = ifft2(kspace_prediction, **fft_opts)
    eta = coil_combination(eta, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
    eta = torch.view_as_complex(eta)
    _, eta = center_crop_to_smallest(target, eta)
    return eta
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network (alternating k-space and image-domain models).

    Each iteration does the following:
    1. Compute a soft data-consistency term ``(kspace - y) * dc_weight`` at sampled locations.
    2. Refine k-space with ``self.kspace_model_list[i]``.
    3. Coil-combine and refine the image with ``self.image_model_list[i]``.
    4. Unless ``self.no_dc``, replace the image with the coil combination of
       ``kspace - soft_dc - F(S * image)``.
    5. On every iteration except the last, re-expand the image into k-space.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data; used for the dtype cast below and to center-crop the output.

    Returns
    -------
    Complex-valued, center-cropped reconstruction (a single tensor).
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    def to_image(kspace_data: torch.Tensor) -> torch.Tensor:
        # Inverse FFT, conj-sensitivity weighting, sum over coils.
        return complex_mul(ifft2(kspace_data, **fft_opts), complex_conj(sensitivity_maps)).sum(self.coil_dim)

    def to_kspace(img: torch.Tensor) -> torch.Tensor:
        # Sensitivity expansion + FFT, cast back to the image's tensor type.
        return fft2(complex_mul(img.unsqueeze(self.coil_dim), sensitivity_maps), **fft_opts).type(img.type())

    current_kspace = y.clone()
    zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)

    for step in range(self.num_iter):
        soft_dc = torch.where(mask.bool(), current_kspace - y, zero) * self.dc_weight

        current_kspace = self.kspace_model_list[step](current_kspace)
        if current_kspace.shape[-1] != 2:
            current_kspace = current_kspace.permute(0, 1, 3, 4, 2).to(target)
            # After the permute the (real, imag) axis is no longer the innermost contiguous axis.
            # Rebuilding through a complex intermediate gives a fresh (..., 2) tensor with unit stride
            # in the last dim. NOTE(review): presumably required because the FFT helpers use
            # torch.view_as_complex, which needs that layout — confirm.
            current_kspace = torch.view_as_real(current_kspace[..., 0] + 1j * current_kspace[..., 1])

        refined = self.image_model_list[step](to_image(current_kspace).unsqueeze(self.coil_dim)).squeeze(
            self.coil_dim
        )

        if not self.no_dc:
            refined = to_image(current_kspace - soft_dc - to_kspace(refined))

        if step < self.num_iter - 1:
            current_kspace = to_kspace(refined)

    refined = torch.view_as_complex(refined)
    _, refined = center_crop_to_smallest(target, refined)
    return refined
def forward(
    self,
    current_kspace: torch.Tensor,
    masked_kspace: torch.Tensor,
    sampling_mask: torch.Tensor,
    sensitivity_map: torch.Tensor,
    hidden_state: Union[None, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes forward pass of RecurrentVarNetBlock.

    Steps:
    1. ``new_kspace = current_kspace - learning_rate * kspace_error + recurrent_term``, where
       ``kspace_error`` is ``current_kspace - masked_kspace`` at sampled locations (0 elsewhere).
    2. ``recurrent_term`` is produced by ``self.regularizer`` from the coil-combined images of each
       2-channel chunk of ``current_kspace``.
    3. The regularizer output is expanded back to k-space chunk by chunk.

    Parameters
    ----------
    current_kspace: Current k-space prediction. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    masked_kspace: Subsampled k-space. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    sampling_mask: Sampling mask. torch.Tensor, shape [batch_size, 1, height, width, 1]
    sensitivity_map: Coil sensitivities. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: ConvGRU hidden state. None or torch.Tensor, shape [batch_size, n_l, height, width, hidden_channels]

    Returns
    -------
    new_kspace: New k-space prediction. torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: Next hidden state, as returned by ``self.regularizer``.
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    zero = torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device)
    kspace_error = torch.where(sampling_mask == 0, zero, current_kspace - masked_kspace)

    combined_chunks = [
        complex_mul(ifft2(chunk, **fft_opts), complex_conj(sensitivity_map)).sum(self.coil_dim)
        for chunk in torch.split(current_kspace, 2, -1)
    ]
    regularizer_input = torch.cat(combined_chunks, dim=-1).permute(0, 3, 1, 2)

    # w_t and h_{t+1}
    regularizer_output, hidden_state = self.regularizer(regularizer_input, hidden_state)
    regularizer_output = regularizer_output.permute(0, 2, 3, 1)

    expanded_chunks = [
        fft2(complex_mul(img.unsqueeze(self.coil_dim), sensitivity_map), **fft_opts)
        for img in torch.split(regularizer_output, 2, -1)
    ]
    recurrent_term = torch.cat(expanded_chunks, dim=-1)

    new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term
    return new_kspace, hidden_state
def backward(ctx, grad_x):
    """
    Backward pass of the conjugate gradient solver.

    Parameters
    ----------
    ctx: Context object. Its ``saved_tensors`` hold (ATy, rhs, smaps, mask, lambdaa). It also provides
        ``fft_centered``, ``fft_normalization``, ``spatial_dims``, ``tol`` and ``max_iter``.
    grad_x: Gradient of the output image.

    Returns
    -------
    The following tuple:

    - grad_z: Gradient w.r.t. the input image, ``Q e`` with ``Q = (lambdaa * A^H A + I)^-1``.
    - grad_lambdaa: Gradient w.r.t. ``lambdaa``.
    - Six ``None`` entries, one per remaining forward input (confirm against ``forward``).
    """
    ATy, rhs, smaps, mask, lambdaa = ctx.saved_tensors
    # One shared set of FFT options for every transform below.
    fft_opts = dict(
        centered=ctx.fft_centered,
        normalization=ctx.fft_normalization,
        spatial_dims=ctx.spatial_dims,
    )

    def A(x):
        # Masked multi-coil forward operator, summed over dim -4 (kept).
        x = fft2(complex_mul(x.expand_as(smaps), smaps), **fft_opts) * mask
        return torch.sum(x, dim=-4, keepdim=True)

    def AT(x):
        # Adjoint of A: mask, inverse FFT, conj-sensitivity weighting, sum over dim -5.
        return torch.sum(
            complex_mul(ifft2(x * mask, **fft_opts), complex_conj(smaps)),
            dim=-5,
        )

    def M(p):
        return lambdaa * AT(A(p)) + p

    Qe = ConjugateGradient.solve(grad_x, M, ctx.tol, ctx.max_iter)
    QQe = ConjugateGradient.solve(Qe, M, ctx.tol, ctx.max_iter)

    grad_z = Qe
    grad_lambdaa = (
        complex_mul(ifft2(Qe, **fft_opts), complex_conj(ATy)).sum()
        - complex_mul(ifft2(QQe, **fft_opts), complex_conj(rhs)).sum()
    )

    return grad_z, grad_lambdaa, None, None, None, None, None, None
def update_X(self, idx, image, sensitivity_maps, y, mask):
    r"""
    Update the image.

    .. math::
        x_{k} = (1 - 2 * \lambda_{{k}_{I}} * mi_{k} - 2 * \lambda_{{k}_{F}} * mi_{k}) * x_{k}
        x_{k} = 2 * mi_{k} * (\lambda_{{k}_{I}} * D_I(x_{k}) + \lambda_{{k}_{F}} * F^-1(D_F(f)))
        A(x_{k}) - b = M * F * (C * x_{k}) - b
        x_{k} = 2 * mi_{k} * A^* (A(x_{k}) - b)

    The returned image is the first term plus the second term minus the last term.

    Parameters
    ----------
    idx: int
        The current iteration index (selects ``reg_param_I``, ``reg_param_F`` and ``lr_image``).
    image: torch.Tensor
        The predicted image, without a coil axis (it is unsqueezed at ``self.coil_dim`` before use).
    sensitivity_maps: torch.Tensor
        The coil sensitivity maps.
    y: torch.Tensor
        The subsampled k-space data.
    mask: torch.Tensor
        The sampling mask; entries equal to 0 mark unsampled locations.

    Returns
    -------
    image: torch.Tensor
        The updated image (same layout as the input ``image``).
    """
    # (1 - 2 * lambda_{k}_{I} * mi_{k} - 2 * lambda_{k}_{F} * mi_{k}) * x_{k}
    image_term_1 = (
        1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
    ) * image
    # D_I(x_{k})
    image_term_2_DI = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()
    # F^-1(D_F(F(x_{k})))
    image_term_2_DF = ifft2(
        self.kspace_model(
            fft2(
                image,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).unsqueeze(self.coil_dim)
        )
        .squeeze(self.coil_dim)
        .contiguous(),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    # 2 * mi_{k} * (lambda_{k}_{I} * D_I(x_{k}) + lambda_{k}_{F} * F^-1(D_F(f)))
    image_term_2 = (
        2
        * self.lr_image[idx]
        * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
    )
    # A(x_{k}) - b = M * F * (C * x_{k}) - b
    image_term_3_A = fft2(
        complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    image_term_3_A = torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
    # 2 * mi_{k} * A^* (A(x_{k}) - b)
    image_term_3_Aconj = complex_mul(
        ifft2(
            image_term_3_A,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        ),
        complex_conj(sensitivity_maps),
    ).sum(self.coil_dim)
    image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj
    image = image_term_1 + image_term_2 - image_term_3
    return image
def forward(
    self,
    pred: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    eta: torch.Tensor = None,
    hx: torch.Tensor = None,
    sigma: float = 1.0,
    keep_eta: bool = False,
) -> Tuple[Any, Union[list, torch.Tensor, None]]:
    """
    Forward pass of the RIMBlock.

    Runs ``self.time_steps`` recurrent updates of ``eta`` driven by the log-likelihood gradient.

    Parameters
    ----------
    pred: Predicted k-space (if a list, its last element is used, detached).
    masked_kspace: Subsampled k-space.
    sense: Coil sensitivity maps.
    mask: Sample mask (used directly as the ``torch.where`` condition).
    eta: Initial guess for the eta. Recomputed from ``pred`` when None or when ``eta.ndim < 3``.
    hx: Initial hidden states. Zero states are built from ``self.recurrent_filters`` when None; a
        given list is updated in place.
    sigma: Noise level.
    keep_eta: If True, the initial eta is ``pred`` itself instead of its coil-combined image.

    Returns
    -------
    If ``self.no_dc`` is set, the list of per-step eta estimates and None. Otherwise, one
    soft-DC k-space per estimate and None.
    """
    fft_opts = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    if hx is None:
        batch_size = masked_kspace.size(0)
        spatial_shape = masked_kspace.size()[2:-1]
        hx = [
            masked_kspace.new_zeros((batch_size, filters, *spatial_shape))
            for filters in self.recurrent_filters
            if filters != 0
        ]

    if isinstance(pred, list):
        pred = pred[-1].detach()

    if eta is None or eta.ndim < 3:
        if keep_eta:
            eta = pred
        else:
            eta = torch.sum(complex_mul(ifft2(pred, **fft_opts), complex_conj(sense)), self.coil_dim)

    estimates = []
    for _ in range(self.time_steps):
        features = log_likelihood_gradient(
            eta,
            masked_kspace,
            sense,
            mask,
            sigma=sigma,
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
            coil_dim=self.coil_dim,
        ).contiguous()
        for layer_idx, convrnn in enumerate(self.layers):
            hx[layer_idx] = convrnn(features, hx[layer_idx])
            features = hx[layer_idx]
        eta = eta + self.final_layer(features).permute(0, 2, 3, 1)
        estimates.append(eta)

    if self.no_dc:
        return estimates, None

    soft_dc = torch.where(mask, pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight
    current_kspace = [
        masked_kspace - soft_dc - fft2(complex_mul(estimate.unsqueeze(self.coil_dim), sense), **fft_opts)
        for estimate in estimates
    ]
    return current_kspace, None
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    Parameters
    ----------
    arguments: The CLI arguments. Reads ``type``, ``target_path``, ``predictions_path``,
        ``sense_path``, ``crop_size`` and ``output_path``.
    reconstruction_key: The key of the reconstruction dataset inside each prediction file.
    mask_background: If truthy, each target/reconstruction slice is multiplied by the convex hull
        of the Otsu-thresholded target slice.
    output_path: Passed to ``Metrics`` when ``arguments.type == "mean_std"``.
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: First slice to keep, or None. (optional)
    slice_end: Slice bound applied *after* ``slice_start`` slicing, so it counts from
        ``slice_start``; or None. (optional)
    coil_dim: The coil dimension.

    Returns
    -------
    The return value depends on ``arguments.type``:

    - ``"mean_std"``: the ``Metrics`` object the results were pushed to.
    - Otherwise: the dict holding the last written row. Per-slice rows are appended as CSV to
      ``arguments.output_path``.

    Raises
    ------
    ValueError: If no sensitivity map is available for a target file.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path, method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        if not exists(arguments.predictions_path / tgt_file.name):
            continue

        with h5py.File(tgt_file, "r") as target_file, h5py.File(
            arguments.predictions_path / tgt_file.name, "r"
        ) as recons_file:
            kspace = target_file["kspace"][()]

            # Resolve the sensitivity map for *this* file. Never reuse a map left over from a
            # previous iteration, and close the extra file handle after reading.
            if arguments.sense_path is not None:
                with h5py.File(arguments.sense_path / tgt_file.name, "r") as sense_file:
                    sense = sense_file["sensitivity_map"][()]
            elif "sensitivity_map" in target_file:
                sense = target_file["sensitivity_map"][()]
            else:
                raise ValueError(
                    f"No sensitivity map found for {tgt_file.name}: set `sense_path` or store "
                    "`sensitivity_map` in the target file."
                )

            recons = recons_file[reconstruction_key][()]

        sense = sense.squeeze().astype(np.complex64)
        if sense.shape != kspace.shape:
            sense = np.transpose(sense, (0, 3, 1, 2))

        # Coil-combined magnitude target: |sum_c(ifft2(kspace) * conj(sense))|.
        target = np.abs(
            tensor_to_complex_np(
                torch.sum(
                    complex_mul(
                        ifft2(to_tensor(kspace), centered="fastmri" in str(arguments.sense_path).lower()),
                        complex_conj(to_tensor(sense)),
                    ),
                    coil_dim,
                )
            )
        )

        if recons.ndim == 4:
            recons = recons.squeeze(coil_dim)

        if arguments.crop_size is not None:
            # Copy so the per-file clamping does not shrink arguments.crop_size for later files.
            crop_size = list(arguments.crop_size)
            crop_size[0] = min(target.shape[-2], recons.shape[-2], int(crop_size[0]))
            crop_size[1] = min(target.shape[-1], recons.shape[-1], int(crop_size[1]))
            target = center_crop(target, crop_size)
            recons = center_crop(recons, crop_size)

        if mask_background:
            for sl in range(target.shape[0]):
                mask = convex_hull_image(
                    np.where(np.abs(target[sl]) > threshold_otsu(np.abs(target[sl])), 1, 0)  # type: ignore
                )
                target[sl] = target[sl] * mask
                recons[sl] = recons[sl] * mask

        if slice_start is not None:
            target = target[slice_start:]
            recons = recons[slice_start:]
        if slice_end is not None:
            target = target[:slice_end]
            recons = recons[:slice_end]

        # Normalize each slice by its own maximum magnitude.
        for sl in range(target.shape[0]):
            target[sl] = target[sl] / np.max(np.abs(target[sl]))
            recons[sl] = recons[sl] / np.max(np.abs(recons[sl]))

        target = np.abs(target)
        recons = np.abs(recons)

        if arguments.type == "mean_std":
            _metrics.push(target, recons)
        else:
            _target = np.expand_dims(target, coil_dim)
            _recons = np.expand_dims(recons, coil_dim)
            for sl in range(target.shape[0]):
                _metrics["FNAME"] = tgt_file.name
                _metrics["SLICE"] = sl
                _metrics["ACC"] = acc
                _metrics["METHOD"] = method
                _metrics["MSE"] = [mse(target[sl], recons[sl])]
                _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
                _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
                _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
                _metrics["PARAMS"] = no_params

                if not exists(arguments.output_path):
                    pd.DataFrame(columns=_metrics.keys()).to_csv(arguments.output_path, index=False, mode="w")
                pd.DataFrame(_metrics).to_csv(arguments.output_path, index=False, header=False, mode="a")

    return _metrics