def test_centered_fft2_forward_normalization(shape):
    """
    Compare the centered 2D FFT with "forward" normalization against a NumPy reference.

    The reference applies ``np.fft.ifftshift`` before and ``np.fft.fftshift`` after ``np.fft.fft2`` over the
    last two axes, using ``norm="forward"``.

    Args:
        shape: shape of the input, without the trailing (real, imaginary) axis of size 2

    Returns:
        None
    """
    x = create_input(shape + [2])
    spatial_axes = (-2, -1)

    # Torch-side result: split the trailing axis back into a complex NumPy array.
    result = fft2(x, centered=True, normalization="forward", spatial_dims=[-2, -1]).numpy()
    result = result[..., 0] + 1j * result[..., 1]

    # NumPy-side reference: shift, transform, shift back.
    reference = np.fft.fftshift(
        np.fft.fft2(np.fft.ifftshift(tensor_to_complex_np(x), spatial_axes), norm="forward"),
        spatial_axes,
    )

    if not np.allclose(result, reference):
        raise AssertionError
def test_tensor_to_complex_np(x):
    """
    Check the output layout of ``tensor_to_complex_np``.

    The converted array must be 3-dimensional and must not end in a size-2 axis.

    Args:
        x: The input array.

    Returns:
        None
    """
    converted = tensor_to_complex_np(x)
    # Short-circuit keeps the original check order: rank first, then the last axis.
    if converted.ndim != 3 or converted.shape[-1] == 2:
        raise AssertionError
def test_complex_abs(shape):
    """
    Compare ``complex_abs`` with ``np.abs`` applied to the equivalent complex NumPy array.

    Args:
        shape: shape of the input, without the trailing (real, imaginary) axis of size 2

    Returns:
        None
    """
    x = create_input(shape + [2])

    magnitude = complex_abs(x).numpy()
    expected = np.abs(tensor_to_complex_np(x))

    if not np.allclose(magnitude, expected):
        raise AssertionError
def test_non_centered_fft2(shape):
    """
    Compare the non-centered 2D FFT with "ortho" normalization against ``np.fft.fft2``.

    No shifts are applied on either side, because the transform is not centered.

    Args:
        shape: shape of the input, without the trailing (real, imaginary) axis of size 2

    Returns:
        None
    """
    x = create_input(shape + [2])

    # Torch-side result: recombine the trailing (real, imag) axis into complex values.
    result = fft2(x, centered=False, normalization="ortho", spatial_dims=[-2, -1]).numpy()
    result = result[..., 0] + 1j * result[..., 1]

    reference = np.fft.fft2(tensor_to_complex_np(x), norm="ortho")

    if not np.allclose(result, reference):
        raise AssertionError
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    Only files in ``arguments.target_path`` with a same-named file in ``arguments.predictions_path`` are
    evaluated. For each one, the target image is the magnitude of the sum over ``coil_dim`` of
    ``ifft2(kspace) * conj(sensitivity_map)``. It is compared against the stored reconstruction after optional
    cropping, background masking, slice selection and per-slice peak normalisation.

    Parameters
    ----------
    arguments: The CLI arguments. Attributes read here: ``type``, ``target_path``, ``predictions_path``,
        ``sense_path``, ``crop_size`` and ``output_path``.
    reconstruction_key: The key of the reconstruction dataset in each predictions file.
    mask_background: If truthy, each slice is multiplied by the convex hull of its Otsu-thresholded target.
    output_path: The output path, passed to ``Metrics`` when ``arguments.type == "mean_std"``.
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: The start slice. (optional, may be None)
    slice_end: The end slice. (optional, may be None)
    coil_dim: The coil dimension. It is also the axis inserted before computing SSIM.

    Returns
    -------
    Metrics or dict: With ``arguments.type == "mean_std"``, the ``Metrics`` object every volume was pushed to.
        Otherwise, the dict holding the values of the last slice row appended to the CSV at
        ``arguments.output_path``.

    Raises
    ------
    ValueError: If no sensitivity maps are available for a target file, i.e. ``arguments.sense_path`` is None
        and the target file has no ``"sensitivity_map"`` dataset.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path, method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        pred_file = arguments.predictions_path / tgt_file.name
        if not exists(pred_file):
            continue

        with h5py.File(tgt_file, "r") as target_h5, h5py.File(pred_file, "r") as recons_h5:
            kspace = target_h5["kspace"][()]

            # Resolve the maps for *this* file, so maps from a previous iteration are never reused.
            if arguments.sense_path is not None:
                # Context manager ensures the sensitivity-map file handle is closed after reading.
                with h5py.File(arguments.sense_path / tgt_file.name, "r") as sense_h5:
                    sense = sense_h5["sensitivity_map"][()]
            elif "sensitivity_map" in target_h5:
                sense = target_h5["sensitivity_map"][()]
            else:
                raise ValueError(
                    f"No sensitivity maps found for {tgt_file.name}: set sense_path or store a "
                    "'sensitivity_map' dataset in the target file."
                )

            recons = recons_h5[reconstruction_key][()]

        sense = sense.squeeze().astype(np.complex64)
        if sense.shape != kspace.shape:
            # Move the last axis to position 1.
            # NOTE(review): presumably the maps are stored coil-last here — confirm.
            sense = np.transpose(sense, (0, 3, 1, 2))

        # Coil-combine: sum over coil_dim of ifft2(kspace) * conj(sense), then take the magnitude.
        target = np.abs(
            tensor_to_complex_np(
                torch.sum(
                    complex_mul(
                        ifft2(to_tensor(kspace), centered="fastmri" in str(arguments.sense_path).lower()),
                        complex_conj(to_tensor(sense)),
                    ),
                    coil_dim,
                )
            )
        )

        if recons.ndim == 4:
            recons = recons.squeeze(coil_dim)

        if arguments.crop_size is not None:
            # Build a fresh list instead of clamping arguments.crop_size in place. Otherwise one small volume
            # would shrink the crop used for every later file.
            crop_size = [
                min(target.shape[-2], recons.shape[-2], int(arguments.crop_size[0])),
                min(target.shape[-1], recons.shape[-1], int(arguments.crop_size[1])),
            ]
            target = center_crop(target, crop_size)
            recons = center_crop(recons, crop_size)

        if mask_background:
            for sl in range(target.shape[0]):
                target_abs = np.abs(target[sl])
                mask = convex_hull_image(np.where(target_abs > threshold_otsu(target_abs), 1, 0))  # type: ignore
                target[sl] = target[sl] * mask
                recons[sl] = recons[sl] * mask

        # NOTE(review): slice_end is applied after slice_start. A non-negative slice_end therefore counts from
        # slice_start, not from the first slice — confirm this is the intended CLI semantics.
        if slice_start is not None:
            target = target[slice_start:]
            recons = recons[slice_start:]
        if slice_end is not None:
            target = target[:slice_end]
            recons = recons[:slice_end]

        # Scale each slice to a peak magnitude of 1. All-zero slices are left untouched instead of becoming NaN.
        for sl in range(target.shape[0]):
            target_max = np.max(np.abs(target[sl]))
            if target_max > 0:
                target[sl] = target[sl] / target_max
            recons_max = np.max(np.abs(recons[sl]))
            if recons_max > 0:
                recons[sl] = recons[sl] / recons_max

        target = np.abs(target)
        recons = np.abs(recons)

        if arguments.type == "mean_std":
            _metrics.push(target, recons)
            continue

        # SSIM receives each slice with an extra axis inserted at coil_dim.
        _target = np.expand_dims(target, coil_dim)
        _recons = np.expand_dims(recons, coil_dim)
        for sl in range(target.shape[0]):
            _metrics["FNAME"] = tgt_file.name
            _metrics["SLICE"] = sl
            _metrics["ACC"] = acc
            _metrics["METHOD"] = method
            _metrics["MSE"] = [mse(target[sl], recons[sl])]
            _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
            _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
            _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
            _metrics["PARAMS"] = no_params

            # Write the header once when the CSV does not exist yet, then append one row per slice.
            if not exists(arguments.output_path):
                pd.DataFrame(columns=list(_metrics)).to_csv(arguments.output_path, index=False, mode="w")
            pd.DataFrame(_metrics).to_csv(arguments.output_path, index=False, header=False, mode="a")

    return _metrics