def forward(self, image):
    """
    Forward method for the MultiDomainConvTranspose2d class.

    The channels of ``image`` are taken in consecutive pairs (real, imaginary). The method runs two branches:

    - k-space branch: each pair goes to k-space with ``fft2``, through ``self.kspace_conv``, and back with ``ifft2``.
    - image branch: ``self.image_conv`` is applied to ``image`` directly.

    The two branch outputs are concatenated along ``self.coil_dim``.

    Parameters
    ----------
    image: Input tensor. The permutes imply a 4-D layout with channels on axis 1, presumably
        [batch_size, channels, n_x, n_y] with an even channel count — TODO confirm.

    Returns
    -------
    Concatenation of the image-branch and k-space-branch outputs along ``self.coil_dim``.
    """
    fft_kwargs = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    # Move channels last so that every (real, imag) pair sits in the trailing axis.
    channels_last = image.permute(0, 2, 3, 1).contiguous()
    to_kspace = [fft2(pair, **fft_kwargs) for pair in torch.split(channels_last, 2, -1)]
    kspace = self.kspace_conv(torch.cat(to_kspace, -1).permute(0, 3, 1, 2))

    # The inverse FFT runs on float32 data; each result is cast back to the input's tensor type.
    kspace_last = kspace.permute(0, 2, 3, 1).contiguous()
    to_image = [
        ifft2(pair.float(), **fft_kwargs).type(image.type())
        for pair in torch.split(kspace_last, 2, -1)
    ]
    kspace_branch = torch.cat(to_image, -1).permute(0, 3, 1, 2)

    image_branch = self.image_conv(image)
    return torch.cat([image_branch, kspace_branch], dim=self.coil_dim)
def test_centered_ifft2_forward_normalization(shape):
    """
    Check the centered 2D inverse FFT with "forward" normalization against NumPy.

    Args:
        shape: shape of the input, without the trailing real/imaginary axis (appended here)

    Returns:
        None
    """
    x = create_input(shape + [2])

    actual = ifft2(x, centered=True, normalization="forward", spatial_dims=[-2, -1]).numpy()
    actual = actual[..., 0] + 1j * actual[..., 1]

    # NumPy reference for a centered transform: ifftshift -> ifft2 -> fftshift over the last two axes.
    shifted = np.fft.ifftshift(tensor_to_complex_np(x), (-2, -1))
    expected = np.fft.fftshift(np.fft.ifft2(shifted, norm="forward"), (-2, -1))

    if not np.allclose(actual, expected):
        raise AssertionError
def log_likelihood_gradient(
    eta: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    sigma: float,
    fft_centered: bool,
    fft_normalization: str,
    spatial_dims: Sequence[int],
    coil_dim: int,
) -> torch.Tensor:
    """
    Computes the gradient of the log-likelihood function.

    The image estimate is expanded to coil images (eta * S) and taken to k-space. It is compared with the measured
    k-space under ``mask``, brought back to image space, and combined with conj(S) over ``coil_dim``. The result is
    scaled by 1 / sigma**2. The estimate itself is stacked in front of the gradient.

    Parameters
    ----------
    eta: Current image estimate, real/imaginary stacked in the last axis.
    masked_kspace: Subsampled k-space data, real/imaginary stacked in the last axis.
    sense: Coil sensitivity maps, real/imaginary stacked in the last axis.
    mask: Sampling mask, multiplied into the k-space residual.
    sigma: Noise level; the gradient is divided by sigma**2.
    fft_centered: Whether to use centered FFTs.
    fft_normalization: FFT normalization mode passed to fft2/ifft2.
    spatial_dims: Spatial dimensions of the FFTs.
    coil_dim: Dimension summed over when combining coils.

    Returns
    -------
    torch.cat((eta_real, eta_imag, grad_real, grad_imag), 0).unsqueeze(0).squeeze(-1).
    NOTE(review): eta gets a new leading axis of size 1 and the pieces are concatenated along dim 0. This appears
    to assume a batch size of 1 — TODO confirm.
    """
    fft_kwargs = {"centered": fft_centered, "normalization": fft_normalization, "spatial_dims": spatial_dims}

    # Real/imaginary parts; eta gets a leading singleton axis so it broadcasts against the coil maps.
    eta_re, eta_im = (part.unsqueeze(0) for part in eta.chunk(2, -1))
    sens_re, sens_im = sense.chunk(2, -1)

    # Complex product eta * S for every coil.
    coil_re = eta_re * sens_re - eta_im * sens_im
    coil_im = eta_re * sens_im + eta_im * sens_re
    coil_kspace = fft2(torch.cat((coil_re, coil_im), -1), **fft_kwargs)

    # Masked k-space residual, back in image space.
    back = ifft2(mask * (coil_kspace - masked_kspace), **fft_kwargs)
    back_re, back_im = back.chunk(2, -1)

    # Complex product with conj(S), summed over coils, scaled by 1 / sigma**2.
    noise_var = sigma**2.0
    grad_re = torch.sum(back_re * sens_re + back_im * sens_im, coil_dim) / noise_var
    grad_im = torch.sum(back_im * sens_re - back_re * sens_im, coil_dim) / noise_var

    stacked = torch.cat((eta_re.squeeze(0), eta_im.squeeze(0), grad_re, grad_im), 0)
    return stacked.unsqueeze(0).squeeze(-1)
def process_intermediate_pred(self, pred, sensitivity_maps, target):
    """
    Turn an intermediate multi-coil k-space prediction into a cropped, coil-combined complex image.

    Parameters
    ----------
    pred: Intermediate prediction. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    target: Target data whose spatial size the result is cropped to.

    Returns
    -------
    pred: Complex-valued tensor (after ``torch.view_as_complex``), center-cropped to match ``target``.
    """
    coil_images = ifft2(
        pred, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    combined = coil_combination(
        coil_images, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim
    )
    _, cropped = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return cropped
def AT(x):
    """
    Adjoint operator: masked multi-coil k-space -> coil-combined image.

    Takes ``mask``, ``smaps``, ``fft_centered``, ``fft_normalization`` and ``spatial_dims`` from the enclosing
    scope. It applies the mask, inverse-FFTs, multiplies by conj(smaps), and sums over axis -5.
    """
    coil_images = ifft2(
        x * mask, centered=fft_centered, normalization=fft_normalization, spatial_dims=spatial_dims
    )
    return complex_mul(coil_images, complex_conj(smaps)).sum(dim=-5)
def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]:
    """
    Test step: reconstruct one batch and log target, reconstruction and error images.

    Parameters
    ----------
    batch: Sequence unpacked as (kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _).
    batch_idx: Batch index. int (unused)

    Returns
    -------
    name: ``str(fname[0])``.
    slice_num: ``int(slice_num)``.
    pred: The prediction, detached and moved to CPU, as a NumPy array.
    """
    kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
    y, mask, _ = self.process_inputs(y, mask)

    if self.use_sens_net:
        sensitivity_maps = self.sens_net(kspace, mask)

    if self.coil_combination_method.upper() == "SENSE":
        # Rebuild the reference from the fully sampled k-space with the current sensitivity maps.
        full_coil_images = ifft2(
            kspace,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        target = sense(full_coil_images, sensitivity_maps, dim=self.coil_dim)

    prediction = self.forward(y, sensitivity_maps, mask, target)

    slice_num = int(slice_num)
    name = str(fname[0])  # type: ignore
    key = f"{name}_images_idx_{slice_num}"  # type: ignore

    def _unit_peak_magnitude(tensor):
        # Magnitude on CPU, scaled so that its maximum is 1.
        magnitude = torch.abs(tensor).detach().cpu()
        return magnitude / magnitude.max()

    output = _unit_peak_magnitude(prediction)
    target = _unit_peak_magnitude(target)
    error = torch.abs(target - output)

    self.log_image(f"{key}/target", target)
    self.log_image(f"{key}/reconstruction", output)
    self.log_image(f"{key}/error", error)

    return name, slice_num, prediction.detach().cpu().numpy()
def _backward_operator(self, kspace, sampling_mask, sensitivity_map):
    """
    Backward operator: multi-coil k-space -> coil-combined image.

    The steps are:

    1. Zero the entries where ``sampling_mask == 0``.
    2. Inverse-FFT in float32.
    3. Multiply by conj(sensitivity_map) and sum over ``self.coil_dim``.
    4. Cast back to the (masked) input's tensor type.
    """
    zero = torch.tensor([0.0], dtype=kspace.dtype).to(kspace.device)
    masked_kspace = torch.where(sampling_mask == 0, zero, kspace)

    coil_images = ifft2(
        masked_kspace.float(),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)
    return combined.type(masked_kspace.type())
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Starting from a copy of ``y``, each cascade in ``self.cascades`` refines the k-space estimate. The final
    estimate is inverse-FFT'd, coil-combined, viewed as complex, and center-cropped to ``target``.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction (unused here).
    target: Target data, used only to determine the crop size.

    Returns
    -------
    The final estimate as a single complex-valued tensor (never a list).
    """
    kspace_estimate = y.clone()
    for cascade in self.cascades:
        kspace_estimate = cascade(kspace_estimate, y, sensitivity_maps, mask)

    coil_images = ifft2(
        kspace_estimate,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = coil_combination(
        coil_images, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim
    )
    _, result = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return result
def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Reduce multi-coil k-space to a single image via SENSE combination.

    Parameters
    ----------
    x: Input multi-coil k-space data.
    sens_maps: Coil sensitivity maps.

    Returns
    -------
    sum over ``self.coil_dim`` of ifft2(x) * conj(sens_maps) (the coil axis is removed).
    """
    coil_images = ifft2(
        x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    weighted = complex_mul(coil_images, complex_conj(sens_maps))
    return weighted.sum(self.coil_dim)
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    The steps are:

    1. Inverse-FFT ``y`` to coil images.
    2. Optionally standardize them, if the module has a ``standardization`` attribute.
    3. Run ``self.unet`` per coil.
    4. Coil-combine, view as complex, and center-crop to ``target``.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask (unused here).
    init_pred: Initial prediction (unused here).
    target: Target data, used only to determine the crop size.

    Returns
    -------
    The reconstruction as a single complex-valued tensor.
    """
    coil_images = ifft2(
        y, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    if hasattr(self, "standardization"):
        coil_images = self.standardization(coil_images, sensitivity_maps)

    # The UNet expects the real/imag axis right after the coil axis; move it there and back afterwards.
    unet_input = coil_images.permute(0, 1, 4, 2, 3)
    unet_output = self._compute_model_per_coil(self.unet, unet_input).permute(0, 1, 3, 4, 2)

    combined = coil_combination(
        unet_output, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim
    )
    _, result = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return result
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    When ``self.use_sens_net`` is set, the sensitivity maps are first re-estimated from ``y`` and ``mask``.
    ``self.model`` then produces multi-coil k-space, which is inverse-FFT'd, coil-combined, viewed as complex,
    and center-cropped to ``target``.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction (unused here).
    target: Target data, used only to determine the crop size.

    Returns
    -------
    The reconstruction as a single complex-valued tensor.
    """
    if self.use_sens_net:
        sensitivity_maps = self.sens_net(y, mask)

    model_kspace = self.model(y, sensitivity_maps, mask)
    coil_images = ifft2(
        model_kspace,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    combined = coil_combination(
        coil_images,
        sensitivity_maps,
        method=self.coil_combination_method,
        dim=self.coil_dim,
    )
    _, result = center_crop_to_smallest(target, torch.view_as_complex(combined))
    return result
def forward(
    self,
    masked_kspace: torch.Tensor,
    mask: torch.Tensor,
    num_low_frequencies: Optional[int] = None,
) -> torch.Tensor:
    """
    Forward pass of the model: estimate coil sensitivity maps from subsampled k-space.

    The steps are:

    1. When ``self.mask_center`` is set, keep only the low-frequency centre of the k-space.
    2. Go to image space.
    3. Fold coils into the batch axis, run ``self.norm_unet``, and unfold.
    4. Optionally divide by the root-sum-of-squares over ``self.coil_dim``.

    Parameters
    ----------
    masked_kspace: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask. torch.Tensor, shape [batch_size, 1, n_x, n_y, 1]
    num_low_frequencies: Number of low frequencies to keep. int

    Returns
    -------
    Normalized UNet output tensor. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    """
    if self.mask_center:
        pad, num_low_freqs = self.get_pad_and_num_low_freqs(mask, num_low_frequencies)
        masked_kspace = batched_mask_center(
            masked_kspace, pad, pad + num_low_freqs, mask_type=self.mask_type
        )

    coil_images = ifft2(
        masked_kspace,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    folded, batches = self.chans_to_batch_dim(coil_images)
    sens_maps = self.batch_chans_to_chan_dim(self.norm_unet(folded), batches)

    return self.divide_root_sum_of_squares(sens_maps, self.coil_dim) if self.normalize else sens_maps
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Sensitivity maps are always re-estimated with ``self.sens_net``; the ``sensitivity_maps`` argument is ignored.
    The image is initialized by SENSE-combining ifft2(y). The method then runs ``self.num_iter`` alternating
    updates: maps first (``update_C``), then image (``update_X``).

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps (overwritten by the ``self.sens_net`` estimate).
    mask: Sampling mask. torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction (unused here).
    target: Target data, used only to determine the crop size.

    Returns
    -------
    The reconstruction as a single complex-valued tensor.
    """
    DC_sens = self.sens_net(y, mask)
    sensitivity_maps = DC_sens.clone()

    coil_images = ifft2(
        y, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    image = complex_mul(coil_images, complex_conj(sensitivity_maps)).sum(self.coil_dim)

    for idx in range(self.num_iter):
        sensitivity_maps = self.update_C(idx, DC_sens, sensitivity_maps, image, y, mask)
        image = self.update_X(idx, image, sensitivity_maps, y, mask)

    _, result = center_crop_to_smallest(target, torch.view_as_complex(image))
    return result
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    target: torch.Tensor = None,
) -> Union[list, Any]:
    """
    Forward pass of the zero-filled method.

    The method inverse-FFTs the subsampled k-space and coil-combines it. It then normalizes the result with
    ``check_stacked_complex`` and, when a ``target`` is given, center-crops it to match.

    Parameters
    ----------
    y: Subsampled k-space data. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps. torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask (unused by the zero-filled method). torch.Tensor, shape [1, 1, n_x, n_y, 1]
    target: Optional target data used only to determine the crop size. If None, no cropping is performed.

    Returns
    -------
    pred: The zero-filled reconstruction.
    """
    coil_images = ifft2(
        y, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    pred = coil_combination(
        coil_images,
        sensitivity_maps,
        method=self.coil_combination_method.upper(),
        dim=self.coil_dim,
    )
    pred = check_stacked_complex(pred)

    # ``target`` defaults to None; cropping against None would fail, so only crop when a target is provided.
    if target is not None:
        _, pred = center_crop_to_smallest(target, pred)
    return pred
def forward(self, x, y, smaps, mask):
    """
    Forward pass of the data-consistency block.

    The forward projection of ``x`` is blended with the measured ``y`` at sampled locations, using weight
    ``self.alpha``, and projected back with conj(smaps). The result is mixed with ``x`` using weight ``self.beta``.

    NOTE(review): the axis choices (x broadcast at -5, k-space summed over -4, adjoint summed over -5) suggest
    smaps of shape [batch, coils, maps, n_x, n_y, 2] and x of shape [batch, maps, n_x, n_y, 2] — TODO confirm.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    smaps: Coil sensitivity maps.
    mask: Sampling mask (numeric; used as a blending weight).

    Returns
    -------
    Output image: beta * x + (1 - beta) * x_dc.
    """
    fft_kwargs = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    # Forward projection: spread x over the coil axis, weight by the maps, go to k-space, sum over axis -4.
    weighted_images = complex_mul(x.unsqueeze(-5).expand_as(smaps), smaps)
    A_x = fft2(weighted_images, **fft_kwargs).sum(-4, keepdim=True)

    # Keep the prediction where unsampled; mix prediction and measurement where sampled.
    k_dc = (1 - mask) * A_x + mask * (self.alpha * A_x + (1 - self.alpha) * y)

    # Adjoint: back to image space, multiply by conj(smaps), sum over the coil axis.
    x_dc = complex_mul(ifft2(k_dc, **fft_kwargs), complex_conj(smaps)).sum(dim=-5)
    return self.beta * x + (1 - self.beta) * x_dc
def test_non_centered_ifft2(shape):
    """
    Check the non-centered 2D inverse FFT with "ortho" normalization against NumPy.

    Args:
        shape: shape of the input, without the trailing real/imaginary axis (appended here)

    Returns:
        None
    """
    x = create_input(shape + [2])

    actual = ifft2(x, centered=False, normalization="ortho", spatial_dims=[-2, -1]).numpy()
    actual = actual[..., 0] + 1j * actual[..., 1]

    # Non-centered transform: no fftshift/ifftshift on the NumPy side.
    expected = np.fft.ifft2(tensor_to_complex_np(x), norm="ortho")

    if not np.allclose(actual, expected):
        raise AssertionError
def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    SENSE-combine multi-coil k-space into a single image, keeping a singleton coil axis.

    Parameters
    ----------
    x: Input data. torch.Tensor, shape [batch_size, n_coils, height, width, 2]
    sens_maps: Sensitivity maps. torch.Tensor, shape [batch_size, n_coils, height, width, 2]

    Returns
    -------
    sum over ``self.coil_dim`` (keepdim=True) of ifft2(x) * conj(sens_maps).
    """
    coil_images = ifft2(
        x, centered=self.fft_centered, normalization=self.fft_normalization, spatial_dims=self.spatial_dims
    )
    weighted = complex_mul(coil_images, complex_conj(sens_maps))
    return weighted.sum(dim=self.coil_dim, keepdim=True)
def forward(self, x, y, mask):
    """
    Forward pass of the data-consistency block.

    The method takes ``x`` to k-space. Where ``mask`` is set, it blends the prediction with ``y`` using weight
    ``self.lambda_``; elsewhere it keeps the prediction. The result is returned to image space.

    Parameters
    ----------
    x: Input image.
    y: Subsampled k-space data.
    mask: Sampling mask (numeric; used as a blending weight).

    Returns
    -------
    Output image.
    """
    fft_kwargs = dict(
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    predicted = fft2(x, **fft_kwargs)
    blended = self.lambda_ * predicted + (1 - self.lambda_) * y
    k_dc = (1 - mask) * predicted + mask * blended
    return ifft2(k_dc, **fft_kwargs)
def forward(
    self,
    pred: torch.Tensor,
    masked_kspace: torch.Tensor,
    sense: torch.Tensor,
    mask: torch.Tensor,
    eta: torch.Tensor = None,
    hx: torch.Tensor = None,
    sigma: float = 1.0,
    keep_eta: bool = False,
) -> Tuple[Any, Union[list, torch.Tensor, None]]:
    """
    Forward pass of the RIMBlock.

    For ``self.time_steps`` steps, the log-likelihood gradient of the current estimate ``eta`` is passed through
    the recurrent layers. ``self.final_layer`` turns the last hidden state into an additive update of ``eta``.

    Parameters
    ----------
    pred: Predicted k-space. If a list is given, its last element (detached) is used.
    masked_kspace: Subsampled k-space.
    sense: Coil sensitivity maps.
    mask: Sampling mask. Boolean or numeric (0/1); converted to bool for the soft data-consistency term.
    eta: Initial guess for eta. If None (or fewer than 3 dims), it is either ``pred`` (``keep_eta=True``) or the
        SENSE combination of ifft2(pred).
    hx: Initial hidden states, one per non-zero entry of ``self.recurrent_filters``. If None, zeros are used.
        If a list is passed, its entries are overwritten in place.
    sigma: Noise level.
    keep_eta: Whether to use ``pred`` directly as the initial eta.

    Returns
    -------
    A tuple ``(estimates, None)``:

    - If ``self.no_dc`` is set, ``estimates`` is the list of per-time-step eta estimates.
    - Otherwise, ``estimates`` is the list of per-time-step k-space estimates after soft data consistency.

    The hidden states are not returned.
    """
    if hx is None:
        hx = [
            masked_kspace.new_zeros((masked_kspace.size(0), f, *masked_kspace.size()[2:-1]))
            for f in self.recurrent_filters
            if f != 0
        ]

    if isinstance(pred, list):
        pred = pred[-1].detach()

    if eta is None or eta.ndim < 3:
        if keep_eta:
            eta = pred
        else:
            eta = torch.sum(
                complex_mul(
                    ifft2(
                        pred,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sense),
                ),
                self.coil_dim,
            )

    etas = []
    for _ in range(self.time_steps):
        grad_eta = log_likelihood_gradient(
            eta,
            masked_kspace,
            sense,
            mask,
            sigma=sigma,
            fft_centered=self.fft_centered,
            fft_normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
            coil_dim=self.coil_dim,
        ).contiguous()

        # Each recurrent layer consumes the previous layer's hidden state as its input.
        for h, convrnn in enumerate(self.layers):
            hx[h] = convrnn(grad_eta, hx[h])
            grad_eta = hx[h]

        eta = eta + self.final_layer(grad_eta).permute(0, 2, 3, 1)
        etas.append(eta)

    if self.no_dc:
        return etas, None

    # torch.where needs a boolean condition. Masks elsewhere in this code base are numeric (compared with 0/1,
    # used in arithmetic), so convert explicitly; this is a no-op for masks that are already boolean.
    soft_dc = torch.where(mask.bool(), pred - masked_kspace, self.zero.to(masked_kspace)) * self.dc_weight
    current_kspace = [
        masked_kspace
        - soft_dc
        - fft2(
            complex_mul(e.unsqueeze(self.coil_dim), sense),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        for e in etas
    ]
    return current_kspace, None
def update_X(self, idx, image, sensitivity_maps, y, mask):
    r"""
    Update the image estimate with one unrolled gradient step.

    .. math::
        x_{k+1} = (1 - 2 \mu_k \lambda_{k}^{I} - 2 \mu_k \lambda_{k}^{F}) x_k
                  + 2 \mu_k (\lambda_{k}^{I} D_I(x_k) + \lambda_{k}^{F} F^{-1}(D_F(F x_k)))
                  - 2 \mu_k A^{*}(A x_k - b),
        \qquad A x = M F (C x)

    Here :math:`\mu_k` is ``self.lr_image[idx]``, :math:`\lambda_{k}^{I}` is ``self.reg_param_I[idx]``,
    :math:`\lambda_{k}^{F}` is ``self.reg_param_F[idx]``, :math:`D_I` is ``self.image_model`` and :math:`D_F` is
    ``self.kspace_model``.

    Parameters
    ----------
    idx: int
        The current iteration index.
    image: torch.Tensor
        The current image estimate. Presumably coil-combined, [batch_size, n_x, n_y, 2] — TODO confirm.
    sensitivity_maps: torch.Tensor
        The coil sensitivity maps.
    y: torch.Tensor
        The subsampled k-space data.
    mask: torch.Tensor
        The sampling mask (0 where unsampled).

    Returns
    -------
    image: torch.Tensor
        The updated image estimate.
    """
    # (1 - 2 * lambda_{k}^{I} * mu_{k} - 2 * lambda_{k}^{F} * mu_{k}) * x_{k}
    image_term_1 = (
        1 - 2 * self.reg_param_I[idx] * self.lr_image[idx] - 2 * self.reg_param_F[idx] * self.lr_image[idx]
    ) * image

    # D_I(x_{k}): the image-domain model expects a singleton axis at coil_dim.
    image_term_2_DI = self.image_model(image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim).contiguous()

    # F^{-1}(D_F(F x_{k})): the k-space model is applied between a forward and an inverse FFT.
    image_term_2_DF = ifft2(
        self.kspace_model(
            fft2(
                image,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).unsqueeze(self.coil_dim)
        )
        .squeeze(self.coil_dim)
        .contiguous(),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )

    # 2 * mu_{k} * (lambda_{k}^{I} * D_I(x_{k}) + lambda_{k}^{F} * F^{-1}(D_F(F x_{k})))
    image_term_2 = (
        2
        * self.lr_image[idx]
        * (self.reg_param_I[idx] * image_term_2_DI + self.reg_param_F[idx] * image_term_2_DF)
    )

    # A(x_{k}) - b = M * F * (C * x_{k}) - b
    image_term_3_A = fft2(
        complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    image_term_3_A = (
        torch.where(mask == 0, torch.tensor([0.0], dtype=y.dtype).to(y.device), image_term_3_A) - y
    )

    # A^{*}(A(x_{k}) - b): inverse FFT, multiply by conj(C), sum over coils.
    image_term_3_Aconj = complex_mul(
        ifft2(
            image_term_3_A,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        ),
        complex_conj(sensitivity_maps),
    ).sum(self.coil_dim)

    # 2 * mu_{k} * A^{*}(A(x_{k}) - b)
    image_term_3 = 2 * self.lr_image[idx] * image_term_3_Aconj

    image = image_term_1 + image_term_2 - image_term_3
    return image
def update_C(self, idx, DC_sens, sensitivity_maps, image, y, mask) -> torch.Tensor:
    r"""
    Update the coil sensitivity maps with one unrolled gradient step.

    .. math::
        C_{k+1} = (1 - 2 \lambda_{k}^{C} \nu_k) C_k + 2 \lambda_{k}^{C} \nu_k D_C(F^{-1}(b))
                  - 2 \nu_k F^{-1}(M^{T}(M F (C_k x_k) - b)) \, \bar{x}_k

    Here :math:`\nu_k` is ``self.lr_sens[idx]``, :math:`\lambda_{k}^{C}` is ``self.reg_param_C[idx]`` and
    :math:`D_C(F^{-1}(b))` is ``DC_sens``. The product with :math:`\bar{x}_k` is a complex product.

    Parameters
    ----------
    idx: int
        The current iteration index.
    DC_sens: torch.Tensor
        The initial coil sensitivity maps (regularization target).
    sensitivity_maps: torch.Tensor
        The current coil sensitivity maps.
    image: torch.Tensor
        The current image estimate. Presumably coil-combined, [batch_size, n_x, n_y, 2] — TODO confirm.
    y: torch.Tensor
        The subsampled k-space data.
    mask: torch.Tensor
        The sampling mask (0 where unsampled).

    Returns
    -------
    sensitivity_maps: torch.Tensor
        The updated coil sensitivity maps.
    """
    zero = torch.tensor([0.0], dtype=y.dtype).to(y.device)

    # (1 - 2 * lambda_{k}^{C} * nu_{k}) * C_{k}
    sense_term_1 = (1 - 2 * self.reg_param_C[idx] * self.lr_sens[idx]) * sensitivity_maps
    # 2 * lambda_{k}^{C} * nu_{k} * D_{C}(F^-1(b))
    sense_term_2 = 2 * self.reg_param_C[idx] * self.lr_sens[idx] * DC_sens

    # F * (C * x_{k}) per coil
    sense_term_3_A = fft2(
        complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    # M^T (M * F * (C * x_{k}) - b): the residual exists only at sampled locations, so zero it where mask == 0.
    # (The residual must be kept where mask == 1; zeroing there would discard the data term entirely.)
    sense_term_3_residual = torch.where(mask == 0, zero, sense_term_3_A - y)

    sense_term_3_backward = ifft2(
        sense_term_3_residual,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    # 2 * nu_{k} * F^-1(...) * conj(x_{k}): a complex product, matching the adjoint used in update_X.
    sense_term_3 = (
        2
        * self.lr_sens[idx]
        * complex_mul(sense_term_3_backward, complex_conj(image).unsqueeze(self.coil_dim))
    )

    sensitivity_maps = sense_term_1 + sense_term_2 - sense_term_3
    return sensitivity_maps
def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]:
    """
    Performs a test step: reconstruct one batch and log target, reconstruction and error images.

    Parameters
    ----------
    batch: Sequence unpacked as
        (kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, acc).
    batch_idx: Batch index. int (unused)

    Returns
    -------
    name: ``str(fname[0])``.
    slice_num: ``int(slice_num)``.
    pred: The final prediction, detached and moved to CPU, as a NumPy array.
    """
    kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
    y, mask, init_pred, _acc_from_inputs = self.process_inputs(y, mask, init_pred)

    if self.use_sens_net:
        sensitivity_maps = self.sens_net(kspace, mask)

    if self.coil_combination_method.upper() == "SENSE":
        # Rebuild the reference from the fully sampled k-space with the current sensitivity maps.
        full_coil_images = ifft2(
            kspace,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        target = sense(full_coil_images, sensitivity_maps, dim=self.coil_dim)

    preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

    if self.accumulate_estimates:
        # Here forward is consumed with next(); take its first yielded value if there is one.
        try:
            preds = next(preds)
        except StopIteration:
            pass

    # Up to two levels of nesting (cascades, then time-steps): keep the last entry at each level.
    for _level in range(2):
        if isinstance(preds, list):
            preds = preds[-1]

    slice_num = int(slice_num)
    name = str(fname[0])  # type: ignore
    key = f"{name}_images_idx_{slice_num}"  # type: ignore

    def _unit_peak_magnitude(tensor):
        # Magnitude on CPU, scaled so that its maximum is 1.
        magnitude = torch.abs(tensor).detach().cpu()
        return magnitude / magnitude.max()

    output = _unit_peak_magnitude(preds)
    target = _unit_peak_magnitude(target)
    error = torch.abs(target - output)

    self.log_image(f"{key}/target", target)
    self.log_image(f"{key}/reconstruction", output)
    self.log_image(f"{key}/error", error)

    return name, slice_num, preds.detach().cpu().numpy()
def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dict:
    """
    Performs a validation step.

    The method computes the validation loss and logs target, reconstruction and error images. It stores
    per-slice MSE / NMSE / SSIM / PSNR in ``self.*_vals[fname][slice_num]``.

    Parameters
    ----------
    batch: Sequence unpacked as
        (kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, acc).
    batch_idx: Batch index. int (unused)

    Returns
    -------
    Dict with key ``"val_loss"``.
    """
    kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, _ = batch
    y, mask, init_pred, _acc_from_inputs = self.process_inputs(y, mask, init_pred)

    if self.use_sens_net:
        sensitivity_maps = self.sens_net(kspace, mask)

    if self.coil_combination_method.upper() == "SENSE":
        # Rebuild the reference from the fully sampled k-space with the current sensitivity maps.
        full_coil_images = ifft2(
            kspace,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        target = sense(full_coil_images, sensitivity_maps, dim=self.coil_dim)

    preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

    if not self.accumulate_estimates:
        val_loss = self.process_loss(target, preds, _loss_fn=self.eval_loss_fn)
    else:
        try:
            preds = next(preds)
        except StopIteration:
            pass
        val_loss = sum(self.process_loss(target, preds, _loss_fn=self.eval_loss_fn))

    # Up to two levels of nesting (cascades, then time-steps): keep the last entry at each level.
    for _level in range(2):
        if isinstance(preds, list):
            preds = preds[-1]

    key = f"{fname[0]}_images_idx_{int(slice_num)}"  # type: ignore

    def _unit_peak_magnitude(tensor):
        # Magnitude on CPU, scaled so that its maximum is 1.
        magnitude = torch.abs(tensor).detach().cpu()
        return magnitude / magnitude.max()

    output = _unit_peak_magnitude(preds)
    target = _unit_peak_magnitude(target)
    error = torch.abs(target - output)

    self.log_image(f"{key}/target", target)
    self.log_image(f"{key}/reconstruction", output)
    self.log_image(f"{key}/error", error)

    target_np = target.numpy()  # type: ignore
    output_np = output.numpy()  # type: ignore

    self.mse_vals[fname][slice_num] = torch.tensor(mse(target_np, output_np)).view(1)
    self.nmse_vals[fname][slice_num] = torch.tensor(nmse(target_np, output_np)).view(1)
    # NOTE(review): maxval is the dynamic range of the reconstruction, not of the target — confirm intended.
    self.ssim_vals[fname][slice_num] = torch.tensor(
        ssim(target_np, output_np, maxval=output_np.max() - output_np.min())
    ).view(1)
    self.psnr_vals[fname][slice_num] = torch.tensor(
        psnr(target_np, output_np, maxval=output_np.max() - output_np.min())
    ).view(1)

    return {"val_loss": val_loss}
def training_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
    """
    Performs a training step.

    Parameters
    ----------
    batch: Sequence unpacked as
        (kspace, y, sensitivity_maps, mask, init_pred, target, fname, slice_num, acc).
    batch_idx: Batch index. int (unused)

    Returns
    -------
    Dict with ``"loss"`` (the training loss) and ``"log"``. The ``"log"`` entry holds the loss, keyed by the
    acceleration, and the current learning rate.
    """
    kspace, y, sensitivity_maps, mask, init_pred, target, _, _, acc = batch
    y, mask, init_pred, r = self.process_inputs(y, mask, init_pred)

    if self.use_sens_net:
        sensitivity_maps = self.sens_net(kspace, mask)

    if self.coil_combination_method.upper() == "SENSE":
        # Rebuild the reference from the fully sampled k-space with the current sensitivity maps.
        full_coil_images = ifft2(
            kspace,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        target = sense(full_coil_images, sensitivity_maps, dim=self.coil_dim)

    preds = self.forward(y, sensitivity_maps, mask, init_pred, target)

    if not self.accumulate_estimates:
        train_loss = self.process_loss(target, preds, _loss_fn=self.train_loss_fn)
    else:
        try:
            preds = next(preds)
        except StopIteration:
            pass
        train_loss = sum(self.process_loss(target, preds, _loss_fn=self.train_loss_fn))

    # Prefer the acceleration reported by process_inputs when it is non-zero.
    if r != 0:
        acc = r

    tensorboard_logs = {
        f"train_loss_{acc}x": train_loss.item(),  # type: ignore
        "lr": self._optimizer.param_groups[0]["lr"],  # type: ignore
    }
    return {"loss": train_loss, "log": tensorboard_logs}
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    For every target file that has a matching prediction file, the steps are:

    1. Rebuild the reference image by SENSE-combining ifft2(kspace).
    2. Optionally center-crop and background-mask both images, and select a slice range.
    3. Normalize each slice to unit peak magnitude.
    4. Either push the volume into a ``Metrics`` aggregator (``arguments.type == "mean_std"``), or append
       per-slice MSE / NMSE / PSNR / SSIM rows to the CSV at ``arguments.output_path``.

    Parameters
    ----------
    arguments: The CLI arguments.
    reconstruction_key: The key of the reconstruction to evaluate.
    mask_background: Whether to mask the background (Otsu threshold + convex hull on the target).
    output_path: The output path passed to ``Metrics``.
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: The start slice. (optional)
    slice_end: The end slice. (optional) NOTE(review): it is applied after ``slice_start`` has already been
        removed, so a positive value counts from ``slice_start``; negative values skip trailing slices.
    coil_dim: The coil dimension. (optional)

    Returns
    -------
    The ``Metrics`` aggregator for ``"mean_std"``; otherwise the dict holding the last written row.

    Raises
    ------
    ValueError: If no sensitivity maps are available for a target file, i.e. ``arguments.sense_path`` is None
        and the target file has no ``"sensitivity_map"`` dataset.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path, method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        recons_path = arguments.predictions_path / tgt_file.name
        if not exists(recons_path):
            continue

        with h5py.File(tgt_file, "r") as target_file, h5py.File(recons_path, "r") as recons_file:
            kspace = target_file["kspace"][()]

            # Load the maps fresh for every file. Without the explicit error below, a file lacking maps would
            # silently reuse the maps of the previous file.
            if arguments.sense_path is not None:
                with h5py.File(arguments.sense_path / tgt_file.name, "r") as sense_file:
                    sense = sense_file["sensitivity_map"][()]
            elif "sensitivity_map" in target_file:
                sense = target_file["sensitivity_map"][()]
            else:
                raise ValueError(f"No sensitivity maps found for {tgt_file.name}.")

            recons = recons_file[reconstruction_key][()]

        sense = sense.squeeze().astype(np.complex64)
        if sense.shape != kspace.shape:
            sense = np.transpose(sense, (0, 3, 1, 2))

        target = np.abs(
            tensor_to_complex_np(
                torch.sum(
                    complex_mul(
                        ifft2(to_tensor(kspace), centered="fastmri" in str(arguments.sense_path).lower()),
                        complex_conj(to_tensor(sense)),
                    ),
                    coil_dim,
                )
            )
        )

        if recons.ndim == 4:
            recons = recons.squeeze(coil_dim)

        if arguments.crop_size is not None:
            # Work on a copy: clamping in place would permanently shrink arguments.crop_size for later files.
            crop_size = list(arguments.crop_size)
            crop_size[0] = min(target.shape[-2], recons.shape[-2], int(crop_size[0]))
            crop_size[1] = min(target.shape[-1], recons.shape[-1], int(crop_size[1]))
            target = center_crop(target, crop_size)
            recons = center_crop(recons, crop_size)

        if mask_background:
            for sl in range(target.shape[0]):
                mask = convex_hull_image(
                    np.where(np.abs(target[sl]) > threshold_otsu(np.abs(target[sl])), 1, 0)  # type: ignore
                )
                target[sl] = target[sl] * mask
                recons[sl] = recons[sl] * mask

        if slice_start is not None:
            target = target[slice_start:]
            recons = recons[slice_start:]
        if slice_end is not None:
            target = target[:slice_end]
            recons = recons[:slice_end]

        for sl in range(target.shape[0]):
            target[sl] = target[sl] / np.max(np.abs(target[sl]))
            recons[sl] = recons[sl] / np.max(np.abs(recons[sl]))

        target = np.abs(target)
        recons = np.abs(recons)

        if arguments.type == "mean_std":
            _metrics.push(target, recons)
        else:
            _target = np.expand_dims(target, coil_dim)
            _recons = np.expand_dims(recons, coil_dim)
            for sl in range(target.shape[0]):
                _metrics["FNAME"] = tgt_file.name
                _metrics["SLICE"] = sl
                _metrics["ACC"] = acc
                _metrics["METHOD"] = method
                _metrics["MSE"] = [mse(target[sl], recons[sl])]
                _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
                _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
                _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
                _metrics["PARAMS"] = no_params

                # Write the CSV header once, then append one row per slice.
                if not exists(arguments.output_path):
                    pd.DataFrame(columns=_metrics.keys()).to_csv(arguments.output_path, index=False, mode="w")
                pd.DataFrame(_metrics).to_csv(arguments.output_path, index=False, header=False, mode="a")

    return _metrics
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Runs ``self.num_iter`` unrolled primal-dual iterations. Each iteration first updates
    the dual buffer with ``self.dual_net[i]``, using the masked multi-coil k-space
    projection of primal channels 2:4. It then updates the primal buffer with
    ``self.primal_net[i]``, using the coil-combined image of dual channels 0:2.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data; only used to center-crop the output.

    Returns
    -------
    torch.Tensor
        Square root of the sum of squares over the last axis of the first two channels
        of the final primal buffer, center-cropped against ``target``. Presumably
        shape [batch_size, n_x, n_y] — confirm.
    """
    fft_options = {
        "centered": self.fft_centered,
        "normalization": self.fft_normalization,
        "spatial_dims": self.spatial_dims,
    }

    def _masked(data: torch.Tensor) -> torch.Tensor:
        # Zero every location where the mask is 0, using a zero of data's dtype/device.
        return torch.where(mask == 0, torch.tensor([0.0], dtype=data.dtype).to(data.device), data)

    def _to_image(kspace_data: torch.Tensor) -> torch.Tensor:
        # Masked k-space -> per-coil images -> combine with conj(sensitivity maps).
        coil_images = ifft2(_masked(kspace_data), **fft_options)
        return complex_mul(coil_images, complex_conj(sensitivity_maps)).sum(self.coil_dim)

    def _to_kspace(image: torch.Tensor) -> torch.Tensor:
        # Image -> per-coil images -> k-space cast to image's tensor type -> masked.
        coil_images = complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps)
        return _masked(fft2(coil_images, **fft_options).type(image.type()))

    zero_filled_image = _to_image(y)
    dual_buffer = torch.cat([y] * self.num_dual, -1).to(y.device)
    primal_buffer = torch.cat([zero_filled_image] * self.num_primal, -1).to(y.device)

    for step in range(self.num_iter):
        # Dual update.
        dual_buffer = self.dual_net[step](dual_buffer, _to_kspace(primal_buffer[..., 2:4].clone()), y)
        # Primal update (uses the freshly updated dual buffer).
        primal_buffer = self.primal_net[step](primal_buffer, _to_image(dual_buffer[..., 0:2].clone()))

    magnitude = torch.sqrt((primal_buffer[..., 0:2] ** 2).sum(-1))
    _, magnitude = center_crop_to_smallest(target, magnitude)
    return magnitude
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass of the network.

    Optionally initializes the recurrent hidden state with ``self.initializer``. It then
    refines the k-space prediction for ``self.num_steps`` steps with the recurrent blocks:
    ``self.block_list[step]`` when ``self.no_parameter_sharing`` is set, otherwise
    ``self.block_list[0]`` for every step. Finally, the result is coil-combined in image
    space.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data; only used to center-crop the output.
    **kwargs: ``initial_image`` is required when
        ``self.initializer_initialization == "input_image"``.

    Returns
    -------
    torch.Tensor
        Complex-valued (``torch.view_as_complex``) coil-combined reconstruction,
        center-cropped against ``target``.

    Raises
    ------
    ValueError
        If ``initial_image`` is missing for the ``"input_image"`` initialization, or if
        ``self.initializer_initialization`` is not ``"sense"``, ``"input_image"`` or
        ``"zero_filled"``.
    """
    previous_state: Optional[torch.Tensor] = None

    if self.initializer is not None:
        if self.initializer_initialization == "sense":
            # Coil-combined zero-filled image with a singleton coil axis re-inserted.
            initializer_input_image = (
                complex_mul(
                    ifft2(
                        y,
                        centered=self.fft_centered,
                        normalization=self.fft_normalization,
                        spatial_dims=self.spatial_dims,
                    ),
                    complex_conj(sensitivity_maps),
                )
                .sum(self.coil_dim)
                .unsqueeze(self.coil_dim)
            )
        elif self.initializer_initialization == "input_image":
            if "initial_image" not in kwargs:
                raise ValueError(
                    "`'initial_image` is required as input if initializer_initialization "
                    f"is {self.initializer_initialization}."
                )
            initializer_input_image = kwargs["initial_image"].unsqueeze(self.coil_dim)
        elif self.initializer_initialization == "zero_filled":
            initializer_input_image = ifft2(
                y,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
        else:
            # Previously fell through and failed with an UnboundLocalError below.
            raise ValueError(
                f"Unknown initializer_initialization {self.initializer_initialization!r}; "
                "expected 'sense', 'input_image' or 'zero_filled'."
            )

        # Reduce over the same coil axis used above (instead of a hard-coded 1), then move
        # the last axis to the channel position expected by the initializer.
        previous_state = self.initializer(
            fft2(
                initializer_input_image,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            )
            .sum(self.coil_dim)
            .permute(0, 3, 1, 2)
        )

    kspace_prediction = y.clone()
    for step in range(self.num_steps):
        block = self.block_list[step] if self.no_parameter_sharing else self.block_list[0]
        kspace_prediction, previous_state = block(
            kspace_prediction,
            y,
            mask,
            sensitivity_maps,
            previous_state,
        )

    eta = ifft2(
        kspace_prediction,
        centered=self.fft_centered,
        normalization=self.fft_normalization,
        spatial_dims=self.spatial_dims,
    )
    eta = coil_combination(eta, sensitivity_maps, method=self.coil_combination_method, dim=self.coil_dim)
    eta = torch.view_as_complex(eta)
    _, eta = center_crop_to_smallest(target, eta)
    return eta
def forward(
    self,
    y: torch.Tensor,
    sensitivity_maps: torch.Tensor,
    mask: torch.Tensor,
    init_pred: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the network.

    For ``self.num_iter`` iterations, alternates a k-space model
    (``self.kspace_model_list[idx]``) and an image model
    (``self.image_model_list[idx]``). A soft data-consistency correction is applied
    unless ``self.no_dc`` is set.

    Parameters
    ----------
    y: Subsampled k-space data.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    sensitivity_maps: Coil sensitivity maps.
        torch.Tensor, shape [batch_size, n_coils, n_x, n_y, 2]
    mask: Sampling mask.
        torch.Tensor, shape [1, 1, n_x, n_y, 1]
    init_pred: Initial prediction. Not used by this method.
    target: Target data. Only used to cast the k-space model output (``.to(target)``)
        and to center-crop the returned image.

    Returns
    -------
    torch.Tensor
        Complex-valued (``torch.view_as_complex``) image from the last iteration,
        center-cropped against ``target``.

    Notes
    -----
    ``image`` is only assigned inside the loop, so ``self.num_iter`` must be at least 1;
    otherwise the final ``view_as_complex`` raises ``UnboundLocalError``.
    """
    kspace = y.clone()
    # Broadcastable zero with kspace's dtype/device, used for unsampled locations.
    zero = torch.zeros(1, 1, 1, 1, 1).to(kspace)

    for idx in range(self.num_iter):
        # Soft data-consistency term, taken from the k-space *before* this iteration's
        # k-space model: (kspace - y) where the mask is set, 0 elsewhere, times dc_weight.
        soft_dc = torch.where(mask.bool(), kspace - y, zero) * self.dc_weight

        kspace = self.kspace_model_list[idx](kspace)
        if kspace.shape[-1] != 2:
            # Output lacks a trailing size-2 real/imag axis: move axis 2 to the end and
            # cast to target's dtype/device.
            kspace = kspace.permute(0, 1, 3, 4, 2).to(target)
            # NOTE(review): rebuilding a complex tensor from channels 0/1 and viewing it as
            # real presumably produces a fresh contiguous [..., 2] tensor (the permuted view
            # is non-contiguous), which the FFT helpers may require — confirm.
            kspace = torch.view_as_real(kspace[..., 0] + 1j * kspace[..., 1])

        # Coil-combine: ifft2 of the k-space, times conj(sensitivity maps), summed over coil_dim.
        image = complex_mul(
            ifft2(
                kspace,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            complex_conj(sensitivity_maps),
        ).sum(self.coil_dim)
        # The image model receives a singleton coil axis, which is squeezed away again.
        image = self.image_model_list[idx](image.unsqueeze(self.coil_dim)).squeeze(self.coil_dim)

        if not self.no_dc:
            # `image` temporarily holds multi-coil k-space in the next two statements.
            image = fft2(
                complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type())
            image = kspace - soft_dc - image
            # Back to a coil-combined image.
            image = complex_mul(
                ifft2(
                    image,
                    centered=self.fft_centered,
                    normalization=self.fft_normalization,
                    spatial_dims=self.spatial_dims,
                ),
                complex_conj(sensitivity_maps),
            ).sum(self.coil_dim)

        # Expand to multi-coil k-space for the next iteration; skipped on the last one.
        if idx < self.num_iter - 1:
            kspace = fft2(
                complex_mul(image.unsqueeze(self.coil_dim), sensitivity_maps),
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ).type(image.type())

    image = torch.view_as_complex(image)
    _, image = center_crop_to_smallest(target, image)
    return image
def forward(
    self,
    current_kspace: torch.Tensor,
    masked_kspace: torch.Tensor,
    sampling_mask: torch.Tensor,
    sensitivity_map: torch.Tensor,
    hidden_state: Union[None, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes forward pass of RecurrentVarNetBlock.

    The new k-space is
    ``current_kspace - learning_rate * kspace_error + recurrent_term``, where:

    - ``kspace_error`` is ``current_kspace - masked_kspace`` on sampled locations and 0
      elsewhere;
    - ``recurrent_term`` is the k-space expansion of the regularizer output, computed
      on the coil-combined images of ``current_kspace``.

    Parameters
    ----------
    current_kspace: Current k-space prediction.
        torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    masked_kspace: Subsampled k-space.
        torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    sampling_mask: Sampling mask.
        torch.Tensor, shape [batch_size, 1, height, width, 1]
    sensitivity_map: Coil sensitivities.
        torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: Recurrent hidden state passed to ``self.regularizer`` (None on the first call).

    Returns
    -------
    new_kspace: New k-space prediction.
        torch.Tensor, shape [batch_size, n_coil, height, width, 2]
    hidden_state: Next hidden state, as returned by ``self.regularizer``.
    """

    def coil_combine(kspace_chunk: torch.Tensor) -> torch.Tensor:
        # ifft2 per coil, multiply by conj(sensitivities), sum over the coil axis.
        coil_images = ifft2(
            kspace_chunk,
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )
        return complex_mul(coil_images, complex_conj(sensitivity_map)).sum(self.coil_dim)

    def expand_to_kspace(image_chunk: torch.Tensor) -> torch.Tensor:
        # Multiply by sensitivities per coil, then fft2 back to k-space.
        return fft2(
            complex_mul(image_chunk.unsqueeze(self.coil_dim), sensitivity_map),
            centered=self.fft_centered,
            normalization=self.fft_normalization,
            spatial_dims=self.spatial_dims,
        )

    zero = torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device)
    kspace_error = torch.where(sampling_mask == 0, zero, current_kspace - masked_kspace)

    # Real/imag pairs along the last axis are handled independently, then stacked
    # channels-first for the regularizer.
    images = [coil_combine(chunk) for chunk in torch.split(current_kspace, 2, -1)]
    regularizer_in = torch.cat(images, dim=-1).permute(0, 3, 1, 2)

    regularizer_out, hidden_state = self.regularizer(regularizer_in, hidden_state)
    regularizer_out = regularizer_out.permute(0, 2, 3, 1)

    kspace_chunks = [expand_to_kspace(chunk) for chunk in torch.split(regularizer_out, 2, -1)]
    recurrent_term = torch.cat(kspace_chunks, dim=-1)

    new_kspace = current_kspace - self.learning_rate * kspace_error + recurrent_term
    return new_kspace, hidden_state
def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> Tuple[str, int, torch.Tensor]:
    """
    Test step.

    Parameters
    ----------
    batch: Batch of data, unpacked as the 9-element sequence
        (kspace, y, sensitivity_maps, mask, _, target, fname, slice_num, _).
    batch_idx: Batch index. int (unused).

    Returns
    -------
    name: Name of the volume. str
    slice_num: Slice number. int
    pred: Predicted data, returned as a NumPy array (``prediction.detach().cpu().numpy()``).

    Raises
    ------
    ValueError
        If no sensitivity maps are available, i.e. none in the batch and ``use_sens_net``
        did not produce any.
    """
    kspace, y, sensitivity_maps, mask, _, target, fname, slice_num, _ = batch

    y, mask, _ = self.process_inputs(y, mask)

    if self.use_sens_net:
        sensitivity_maps = self.sens_net(kspace, mask)

    # Validate before the maps are first used (the SENSE target below needs them too).
    if sensitivity_maps is None:
        raise ValueError(
            "Sensitivity maps are required for PICS. "
            "Please set use_sens_net to True if you precomputed sensitivity maps are not available."
        )

    if self.coil_combination_method.upper() == "SENSE":
        target = sense(
            ifft2(
                kspace,
                centered=self.fft_centered,
                normalization=self.fft_normalization,
                spatial_dims=self.spatial_dims,
            ),
            sensitivity_maps,
            dim=self.coil_dim,
        )

    # Complex NumPy input with axis 1 (presumably coils) moved last.
    y = torch.view_as_complex(y).permute(0, 2, 3, 1).detach().cpu().numpy()

    sensitivity_maps = torch.view_as_complex(sensitivity_maps)
    if self.fft_type != "orthogonal":
        sensitivity_maps = torch.fft.fftshift(sensitivity_maps, dim=(-2, -1))
    sensitivity_maps = sensitivity_maps.permute(0, 2, 3, 1).detach().cpu().numpy()  # type: ignore

    prediction = torch.from_numpy(self.forward(y, sensitivity_maps, mask, target)).unsqueeze(0)
    if self.fft_type != "orthogonal":
        prediction = torch.fft.fftshift(prediction, dim=(-2, -1))

    slice_num = int(slice_num)
    name = str(fname[0])  # type: ignore
    key = f"{name}_images_idx_{slice_num}"  # type: ignore

    # Max-normalized magnitudes for logging only.
    output = torch.abs(prediction).detach().cpu()
    target = torch.abs(target).detach().cpu()
    output = output / output.max()  # type: ignore
    target = target / target.max()  # type: ignore
    error = torch.abs(target - output)
    self.log_image(f"{key}/target", target)
    self.log_image(f"{key}/reconstruction", output)
    self.log_image(f"{key}/error", error)

    return name, slice_num, prediction.detach().cpu().numpy()