Пример #1
0
def test_centered_fft2_forward_normalization(shape):
    """
    Compare the centered 2D FFT with "forward" normalization against NumPy.

    The NumPy reference applies ``ifftshift``, then
    ``np.fft.fft2(norm="forward")``, then ``fftshift``, all over the last two
    axes of the complex-valued input.

    Args:
        shape: shape of the input, without the trailing real/imaginary axis

    Returns:
        None
    """
    x = create_input(shape + [2])
    axes = (-2, -1)

    result = fft2(
        x, centered=True, normalization="forward", spatial_dims=[-2, -1]
    ).numpy()
    # Recombine the stacked (real, imag) channels into a complex array.
    result = result[..., 0] + 1j * result[..., 1]

    shifted_input = np.fft.ifftshift(tensor_to_complex_np(x), axes)
    reference = np.fft.fftshift(np.fft.fft2(shifted_input, norm="forward"), axes)

    if not np.allclose(result, reference):
        raise AssertionError
Пример #2
0
def test_tensor_to_complex_np(x):
    """
    Check the array returned by ``tensor_to_complex_np``.

    The converted result must be 3-D and must not end in an axis of size 2.

    Args:
        x: The input array.

    Returns:
        None
    """
    converted = tensor_to_complex_np(x)
    # The rank check short-circuits first, so shape[-1] is only read on 3-D results.
    if converted.ndim != 3 or converted.shape[-1] == 2:
        raise AssertionError
Пример #3
0
def test_complex_abs(shape):
    """
    Compare ``complex_abs`` against ``np.abs`` on the equivalent complex array.

    Args:
        shape: shape of the input, without the trailing real/imaginary axis

    Returns:
        None
    """
    x = create_input(shape + [2])

    magnitude_torch = complex_abs(x).numpy()
    magnitude_numpy = np.abs(tensor_to_complex_np(x))

    if not np.allclose(magnitude_torch, magnitude_numpy):
        raise AssertionError
Пример #4
0
def test_non_centered_fft2(shape):
    """
    Compare the non-centered, orthonormal 2D FFT against NumPy.

    Neither side applies an (i)fftshift, and both use ``"ortho"`` scaling.

    Args:
        shape: shape of the input, without the trailing real/imaginary axis

    Returns:
        None
    """
    x = create_input(shape + [2])

    spectrum = fft2(
        x, centered=False, normalization="ortho", spatial_dims=[-2, -1]
    ).numpy()
    # Recombine the stacked (real, imag) channels into a complex array.
    real_part, imag_part = spectrum[..., 0], spectrum[..., 1]
    spectrum = real_part + 1j * imag_part

    expected = np.fft.fft2(tensor_to_complex_np(x), norm="ortho")

    if not np.allclose(spectrum, expected):
        raise AssertionError
Пример #5
0
def evaluate(
    arguments,
    reconstruction_key,
    mask_background,
    output_path,
    method,
    acc,
    no_params,
    slice_start,
    slice_end,
    coil_dim,
):
    """
    Evaluates the reconstructions.

    The function processes every file in ``arguments.target_path`` that has a
    same-named file in ``arguments.predictions_path``. For each such file:

    1. A target image is formed by inverse-FFT'ing the stored k-space,
       multiplying by the conjugate coil sensitivities and summing over
       ``coil_dim``.
    2. Target and reconstruction are optionally cropped, background-masked
       and restricted to ``[slice_start:slice_end]``.
    3. Both are normalized per slice by their maximum magnitude.
    4. The two are compared.

    Parameters
    ----------
    arguments: The CLI arguments. Attributes read: ``type``, ``target_path``,
        ``predictions_path``, ``sense_path``, ``crop_size``, ``output_path``.
    reconstruction_key: The key of the reconstruction dataset in each
        prediction file.
    mask_background: If truthy, zero everything outside a convex-hull mask
        built from an Otsu threshold of each target slice.
    output_path: The output path handed to ``Metrics`` (``mean_std`` mode).
    method: The reconstruction method.
    acc: The acceleration factor.
    no_params: The number of parameters.
    slice_start: First slice to keep (inclusive), or None. (optional)
    slice_end: Slice index to stop at (exclusive), or None. It indexes the
        full volume, i.e. ``volume[slice_start:slice_end]``. (optional)
    coil_dim: The coil dimension; summed over to build the target and
        squeezed out of 4-D reconstructions. (optional)

    Returns
    -------
    In ``mean_std`` mode, the ``Metrics`` object every file was pushed into.
    Otherwise, the dict of metrics of the last slice appended to
    ``arguments.output_path`` (empty if nothing was evaluated).

    Raises
    ------
    ValueError: If no sensitivity map is available for a target file.
    """
    _metrics = Metrics(METRIC_FUNCS, output_path,
                       method) if arguments.type == "mean_std" else {}

    for tgt_file in tqdm(arguments.target_path.iterdir()):
        pred_file = arguments.predictions_path / tgt_file.name
        if not exists(pred_file):
            continue

        with h5py.File(tgt_file, "r") as target_h5, h5py.File(
                pred_file, "r") as recons_h5:
            kspace = target_h5["kspace"][()]

            if arguments.sense_path is not None:
                # Use a context manager so the sensitivity file handle is closed.
                with h5py.File(arguments.sense_path / tgt_file.name,
                               "r") as sense_h5:
                    sense = sense_h5["sensitivity_map"][()]
            elif "sensitivity_map" in target_h5:
                sense = target_h5["sensitivity_map"][()]
            else:
                # Without this, `sense` was either unbound (first file) or
                # silently reused from the previous file.
                raise ValueError(
                    f"No sensitivity map found for {tgt_file.name}: pass a "
                    "sense path or store 'sensitivity_map' in the target file.")

            recons = recons_h5[reconstruction_key][()]

        sense = sense.squeeze().astype(np.complex64)

        if sense.shape != kspace.shape:
            sense = np.transpose(sense, (0, 3, 1, 2))

        # Coil-combine: sum_c ifft2(kspace_c) * conj(sense_c).
        target = np.abs(
            tensor_to_complex_np(
                torch.sum(
                    complex_mul(
                        ifft2(to_tensor(kspace),
                              centered="fastmri"
                              in str(arguments.sense_path).lower()),
                        complex_conj(to_tensor(sense)),
                    ),
                    coil_dim,
                )))

        if recons.ndim == 4:
            recons = recons.squeeze(coil_dim)

        if arguments.crop_size is not None:
            # Build a fresh list instead of mutating arguments.crop_size, which
            # would carry this file's clamp over to every later file.
            crop_size = [
                min(target.shape[-2], recons.shape[-2],
                    int(arguments.crop_size[0])),
                min(target.shape[-1], recons.shape[-1],
                    int(arguments.crop_size[1])),
            ]

            target = center_crop(target, crop_size)
            recons = center_crop(recons, crop_size)

        if mask_background:
            for sl in range(target.shape[0]):
                mask = convex_hull_image(
                    np.where(
                        np.abs(target[sl]) > threshold_otsu(
                            np.abs(target[sl])), 1, 0)  # type: ignore
                )
                target[sl] = target[sl] * mask
                recons[sl] = recons[sl] * mask

        # Slice once with both bounds so slice_end indexes the full volume;
        # slicing [slice_start:] and then [:slice_end] shifted the end.
        target = target[slice_start:slice_end]
        recons = recons[slice_start:slice_end]

        for sl in range(target.shape[0]):
            target_max = np.max(np.abs(target[sl]))
            recons_max = np.max(np.abs(recons[sl]))
            # Leave all-zero slices as they are; dividing by 0 would fill them with NaN.
            if target_max > 0:
                target[sl] = target[sl] / target_max
            if recons_max > 0:
                recons[sl] = recons[sl] / recons_max

        target = np.abs(target)
        recons = np.abs(recons)

        if arguments.type == "mean_std":
            _metrics.push(target, recons)
        else:
            # SSIM is computed on arrays with an extra axis inserted at coil_dim.
            _target = np.expand_dims(target, coil_dim)
            _recons = np.expand_dims(recons, coil_dim)
            for sl in range(target.shape[0]):
                _metrics["FNAME"] = tgt_file.name
                _metrics["SLICE"] = sl
                _metrics["ACC"] = acc
                _metrics["METHOD"] = method
                _metrics["MSE"] = [mse(target[sl], recons[sl])]
                _metrics["NMSE"] = [nmse(target[sl], recons[sl])]
                _metrics["PSNR"] = [psnr(target[sl], recons[sl])]
                _metrics["SSIM"] = [ssim(_target[sl], _recons[sl])]
                _metrics["PARAMS"] = no_params

                # Write the CSV header once, then append one row per slice.
                if not exists(arguments.output_path):
                    pd.DataFrame(columns=_metrics.keys()).to_csv(
                        arguments.output_path, index=False, mode="w")
                pd.DataFrame(_metrics).to_csv(arguments.output_path,
                                              index=False,
                                              header=False,
                                              mode="a")

    return _metrics