def test_mri_signal_timecourse(
    t1: float,
    t2: float,
    m0: float,
    t2_star: float,
    acq_contrast: str,
    echo_time: np.ndarray,
    repetition_time: float,
    expected: np.ndarray,
):
    """Tests the MriSignalFilter with timecourse data that is generated at multiple echo times

    The filter is run once per echo time on single-voxel ``(1, 1, 1)`` images
    (with a zero ``mag_enc`` and a 90 degree excitation flip angle). The
    per-echo outputs are collected into an array shaped like ``echo_time`` and
    compared with ``expected``.

    Args:
        t1 (float): longitudinal relaxation time, s
        t2 (float): transverse relaxation time, s
        m0 (float): equilibrium magnetisation
        t2_star (float): transverse relaxation time inc. time invariant fields, s
        acq_contrast (str): signal model to use: 'ge' or 'se'
        echo_time (np.ndarray): array of echo times, s. Any array-like
            (e.g. a list) is accepted.
        repetition_time (float): repeat time, s
        expected (np.ndarray): array of expected values that the
            MriSignalFilter should generate. Should be the same size and
            shape as 'echo_time'
    """
    echo_times = np.asarray(echo_time)
    # Allocated uninitialised; every element is written in the loop below.
    mri_signal_timecourse = np.empty(echo_times.shape)
    for idx, te in np.ndenumerate(echo_times):
        params = {
            "t1": NumpyImageContainer(image=np.full((1, 1, 1), t1)),
            "t2": NumpyImageContainer(image=np.full((1, 1, 1), t2)),
            "t2_star": NumpyImageContainer(image=np.full((1, 1, 1), t2_star)),
            "m0": NumpyImageContainer(image=np.full((1, 1, 1), m0)),
            "mag_enc": NumpyImageContainer(image=np.zeros((1, 1, 1))),
            "acq_contrast": acq_contrast,
            "echo_time": te,
            "repetition_time": repetition_time,
            "excitation_flip_angle": 90.0,
        }

        mri_signal_filter = MriSignalFilter()
        mri_signal_filter = add_multiple_inputs_to_filter(
            mri_signal_filter, params)
        mri_signal_filter.run()
        # The inputs are single-voxel images, so the output should hold one
        # value. Extract it with .item(), because assigning an ndim > 0 array
        # to a single element is deprecated since NumPy 1.25.
        mri_signal_timecourse[idx] = mri_signal_filter.outputs["image"].image.item()
    # arrays should be equal to 9 decimal places
    numpy.testing.assert_array_almost_equal(mri_signal_timecourse, expected, 9)
def test_mri_signal_filter_inversion_recovery(mock_data):
    """Tests the MriSignalFilter for 'acq_contrast' == 'ir': Inversion Recovery.

    The output image is compared exactly with
    ``mri_signal_inversion_recovery_function``. The output metadata is checked
    twice: first as supplied, then after the ``mag_enc`` image's
    "image_flavour" metadata is set to "PERFUSION".
    """
    for key, value in (
        ("acq_contrast", "ir"),
        ("inversion_flip_angle", 180.0),
        ("inversion_time", 1.0),
        ("repetition_time", 1.1),
    ):
        mock_data[key] = value

    def build_and_run():
        """Create a fresh MriSignalFilter from mock_data, run it and return it."""
        signal_filter = add_multiple_inputs_to_filter(MriSignalFilter(), mock_data)
        signal_filter.run()
        return signal_filter

    signal_filter = build_and_run()

    ir_signal = mri_signal_inversion_recovery_function(mock_data)
    numpy.testing.assert_array_equal(ir_signal,
                                     signal_filter.outputs["image"].image)

    expected_metadata = {
        "acq_contrast": mock_data["acq_contrast"],
        "echo_time": mock_data["echo_time"],
        "repetition_time": mock_data["repetition_time"],
        "excitation_flip_angle": mock_data["excitation_flip_angle"],
        "image_flavour": "OTHER",
        "inversion_time": mock_data["inversion_time"],
        "inversion_flip_angle": mock_data["inversion_flip_angle"],
        "mr_acq_type": "3D",
    }
    assert signal_filter.outputs["image"].metadata == expected_metadata

    # the flavour carried by mag_enc is expected to appear in the output
    mock_data["mag_enc"].metadata["image_flavour"] = "PERFUSION"
    signal_filter = build_and_run()
    assert signal_filter.outputs["image"].metadata == {
        **expected_metadata,
        "image_flavour": "PERFUSION",
    }
def test_mri_signal_filter_validate_inputs_ge_no_t2_star():
    """Checks a FilterInputValidationError is raised when
    'acq_contrast' == 'ge' and 't2_star' is not supplied.

    Every other entry of TEST_DATA_DICT_GE is added using its first value
    (presumably the valid one).
    """
    # copy so the shared module-level test data is left untouched
    inputs_without_t2_star = deepcopy(TEST_DATA_DICT_GE)
    del inputs_without_t2_star["t2_star"]

    mri_signal_filter = MriSignalFilter()
    for key, candidate_values in inputs_without_t2_star.items():
        mri_signal_filter.add_input(key, candidate_values[0])

    with pytest.raises(FilterInputValidationError):
        mri_signal_filter.run()
def test_mri_signal_filter_image_flavour(mock_data):
    """Tests the MriSignalFilter when the input "image_flavour" is changed.

    Two cases are checked:

    1. No ``mag_enc`` is supplied: the "image_flavour" input ("ABCD") is
       reported in the output metadata.
    2. A ``mag_enc`` whose metadata "image_flavour" is "PERFUSION" is
       supplied: the "image_flavour" input ("ABCD") overrides it.
    """
    # Case 1: "image_flavour" is used when no mag_enc is supplied.
    # The filter must be given the copy with "mag_enc" removed; passing
    # ``mock_data`` here would just re-test the mag_enc case below.
    test_data = deepcopy(mock_data)
    test_data.pop("mag_enc")
    test_data["image_flavour"] = "ABCD"
    mri_signal_filter = MriSignalFilter()
    mri_signal_filter = add_multiple_inputs_to_filter(mri_signal_filter,
                                                      test_data)
    mri_signal_filter.run()

    assert mri_signal_filter.outputs["image"].metadata == {
        "acq_contrast": test_data["acq_contrast"],
        "echo_time": test_data["echo_time"],
        "repetition_time": test_data["repetition_time"],
        "excitation_flip_angle": test_data["excitation_flip_angle"],
        "image_flavour": "ABCD",
        "mr_acq_type": "3D",
    }

    # Case 2: "image_flavour" overrides the flavour carried by mag_enc.
    mock_data["mag_enc"].metadata["image_flavour"] = "PERFUSION"
    mock_data["image_flavour"] = "ABCD"
    mri_signal_filter = MriSignalFilter()
    mri_signal_filter = add_multiple_inputs_to_filter(mri_signal_filter,
                                                      mock_data)
    mri_signal_filter.run()

    assert mri_signal_filter.outputs["image"].metadata == {
        "acq_contrast": mock_data["acq_contrast"],
        "echo_time": mock_data["echo_time"],
        "repetition_time": mock_data["repetition_time"],
        "excitation_flip_angle": mock_data["excitation_flip_angle"],
        "image_flavour": "ABCD",
        "mr_acq_type": "3D",
    }
def test_mri_signal_filter_spin_echo(mock_data):
    """Tests the MriSignalFilter for 'acq_contrast' == 'se': Spin Echo.

    The output image is compared exactly with
    ``mri_signal_spin_echo_function``. The output metadata is checked twice:
    first as supplied, then after the ``mag_enc`` image's "image_flavour"
    metadata is set to "PERFUSION".
    """
    mock_data["acq_contrast"] = "se"

    def build_and_run():
        """Create a fresh MriSignalFilter from mock_data, run it and return it."""
        signal_filter = add_multiple_inputs_to_filter(MriSignalFilter(), mock_data)
        signal_filter.run()
        return signal_filter

    signal_filter = build_and_run()

    se_signal = mri_signal_spin_echo_function(mock_data)
    numpy.testing.assert_array_equal(se_signal,
                                     signal_filter.outputs["image"].image)

    expected_metadata = {
        "acq_contrast": mock_data["acq_contrast"],
        "echo_time": mock_data["echo_time"],
        "repetition_time": mock_data["repetition_time"],
        "excitation_flip_angle": mock_data["excitation_flip_angle"],
        "image_flavour": "OTHER",
        "mr_acq_type": "3D",
    }
    assert signal_filter.outputs["image"].metadata == expected_metadata

    # the flavour carried by mag_enc is expected to appear in the output
    mock_data["mag_enc"].metadata["image_flavour"] = "PERFUSION"
    signal_filter = build_and_run()
    assert signal_filter.outputs["image"].metadata == {
        **expected_metadata,
        "image_flavour": "PERFUSION",
    }
# Example no. 6
# 0
    def _create_filter_block(self):
        """Create the filter chain and return its last filter.

        The filters are linked with ``add_parent_filter`` but are not run here:

        1. MriSignalFilter: receives the required signal-model inputs, plus
           mag_enc, inversion_flip_angle, inversion_time and image_flavour
           when they are present in ``self.inputs`` and not None.
        2. TransformResampleImageFilter: child of the MriSignalFilter. It
           receives rotation, rotation_origin, target_shape and translation
           only when present and not None.
        3. AddComplexNoiseFilter: child of the TransformResampleImageFilter.
           It receives snr, plus the reference image when present and not
           None.

        Returns:
            AddComplexNoiseFilter: the final filter of the chain.
        """

        def add_present_inputs(target_filter, key_pairs):
            # Forward each self.inputs[source_key] that is present and not
            # None to target_filter under target_key, in the given order.
            for target_key, source_key in key_pairs:
                if self.inputs.get(source_key) is not None:
                    target_filter.add_input(target_key, self.inputs[source_key])

        # MriSignalFilter
        # add required inputs - these should always be present
        mri_signal_filter = MriSignalFilter()
        for target_key, source_key in (
            (MriSignalFilter.KEY_T1, self.KEY_T1),
            (MriSignalFilter.KEY_T2, self.KEY_T2),
            (MriSignalFilter.KEY_T2_STAR, self.KEY_T2_STAR),
            (MriSignalFilter.KEY_M0, self.KEY_M0),
            (MriSignalFilter.KEY_ACQ_CONTRAST, self.KEY_ACQ_CONTRAST),
            (MriSignalFilter.KEY_ECHO_TIME, self.KEY_ECHO_TIME),
            (MriSignalFilter.KEY_REPETITION_TIME, self.KEY_REPETITION_TIME),
            (
                MriSignalFilter.KEY_EXCITATION_FLIP_ANGLE,
                self.KEY_EXCITATION_FLIP_ANGLE,
            ),
        ):
            mri_signal_filter.add_input(target_key, self.inputs[source_key])
        # add optional inputs if present
        add_present_inputs(
            mri_signal_filter,
            (
                (MriSignalFilter.KEY_MAG_ENC, self.KEY_MAG_ENC),
                (MriSignalFilter.KEY_INVERSION_FLIP_ANGLE, self.KEY_INVERSION_FLIP_ANGLE),
                (MriSignalFilter.KEY_INVERSION_TIME, self.KEY_INVERSION_TIME),
                (MriSignalFilter.KEY_IMAGE_FLAVOUR, self.KEY_IMAGE_FLAVOUR),
            ),
        )

        # TransformResampleImageFilter, with mri_signal_filter as parent;
        # all other parameters are optional
        transform_resample_image_filter = TransformResampleImageFilter()
        transform_resample_image_filter.add_parent_filter(mri_signal_filter)
        add_present_inputs(
            transform_resample_image_filter,
            (
                (TransformResampleImageFilter.KEY_ROTATION, self.KEY_ROTATION),
                (
                    TransformResampleImageFilter.KEY_ROTATION_ORIGIN,
                    self.KEY_ROTATION_ORIGIN,
                ),
                (TransformResampleImageFilter.KEY_TARGET_SHAPE, self.KEY_TARGET_SHAPE),
                (TransformResampleImageFilter.KEY_TRANSLATION, self.KEY_TRANSLATION),
            ),
        )

        # AddComplexNoiseFilter, with transform_resample_image_filter as parent
        add_complex_noise_filter = AddComplexNoiseFilter()
        add_complex_noise_filter.add_parent_filter(transform_resample_image_filter)
        # required input - should always be present.
        # NOTE(review): keyed with this class's KEY_SNR rather than an
        # AddComplexNoiseFilter.KEY_* constant; presumably they are equal - confirm.
        add_complex_noise_filter.add_input(self.KEY_SNR, self.inputs[self.KEY_SNR])
        # optional input ref_image
        add_present_inputs(
            add_complex_noise_filter,
            ((AddComplexNoiseFilter.KEY_REF_IMAGE, self.KEY_REF_IMAGE),),
        )

        return add_complex_noise_filter
def test_mri_signal_timecourse_inversion_recovery(
    t1: float,
    t2: float,
    m0: float,
    t2_star: float,
    echo_time: float,
    repetition_time: float,
    flip_angle: float,
    inversion_angle: float,
    inversion_time: np.ndarray,
    expected: np.ndarray,
):
    """Tests the MriSignalFilter inversion recovery signal over a range of TI's.

    The filter is run once per inversion time on single-voxel ``(1, 1, 1)``
    images, with no ``mag_enc`` supplied. The outputs are collected into an
    array shaped like ``inversion_time`` and compared with ``expected``.

    :param t1: longitudinal relaxation time, s
    :type t1: float
    :param t2: transverse relaxation time, s
    :type t2: float
    :param m0: equilibrium magnetisation
    :type m0: float
    :param t2_star: transverse relaxation time inc. time invariant fields, s
    :type t2_star: float
    :param echo_time: the echo time, s
    :type echo_time: float
    :param repetition_time: the repetition time, s
    :type repetition_time: float
    :param flip_angle: the excitation pulse flip angle, degrees
    :type flip_angle: float
    :param inversion_angle: the inversion pulse flip angle, degrees
    :type inversion_angle: float
    :param inversion_time: array of durations between the inversion pulse and
        excitation pulse, s. Any array-like (e.g. a list) is accepted.
    :type inversion_time: np.ndarray
    :param expected: array of expected values, same shape as `inversion_time`
    :type expected: np.ndarray
    """
    inversion_times = np.asarray(inversion_time)
    # Allocated uninitialised; every element is written in the loop below.
    mri_signal_timecourse = np.empty(inversion_times.shape)
    for idx, ti in np.ndenumerate(inversion_times):
        params = {
            "t1": NumpyImageContainer(image=np.full((1, 1, 1), t1)),
            "t2": NumpyImageContainer(image=np.full((1, 1, 1), t2)),
            "t2_star": NumpyImageContainer(image=np.full((1, 1, 1), t2_star)),
            "m0": NumpyImageContainer(image=np.full((1, 1, 1), m0)),
            "acq_contrast": "ir",
            "echo_time": echo_time,
            "repetition_time": repetition_time,
            "excitation_flip_angle": flip_angle,
            "inversion_flip_angle": inversion_angle,
            "inversion_time": ti,
        }

        mri_signal_filter = MriSignalFilter()
        mri_signal_filter = add_multiple_inputs_to_filter(
            mri_signal_filter, params)
        mri_signal_filter.run()
        # The inputs are single-voxel images, so the output should hold one
        # value. Extract it with .item(), because assigning an ndim > 0 array
        # to a single element is deprecated since NumPy 1.25.
        mri_signal_timecourse[idx] = mri_signal_filter.outputs["image"].image.item()
    # arrays should be equal to 9 decimal places
    numpy.testing.assert_array_almost_equal(mri_signal_timecourse, expected, 9)
def test_mri_signal_filter_validate_inputs(validation_data: dict):
    """ Check a FilterInputValidationError is raised when the
    inputs to the MriSignalFilter are incorrect or missing

    ``validation_data`` maps each input name to a sequence. Its first element
    is a passing value and the remaining elements must be rejected.

    For every entry the filter must fail when that input is missing, when it
    is None, and for each of the rejected values. The optional inputs
    'mag_enc' and 'image_flavour' are then checked with both rejected and
    accepted values.
    """

    def filter_with_valid_inputs(data):
        """Return a new MriSignalFilter with the first (passing) value of
        every entry in ``data`` added as an input."""
        new_filter = MriSignalFilter()
        for key, candidates in data.items():
            new_filter.add_input(key, candidates[0])
        return new_filter

    for inputs_key, candidates in validation_data.items():
        # remove the corresponding key from a copy of the data
        test_data = deepcopy(validation_data)
        test_data.pop(inputs_key)
        mri_signal_filter = filter_with_valid_inputs(test_data)

        # key not defined
        with pytest.raises(FilterInputValidationError):
            mri_signal_filter.run()

        # key has wrong data type
        mri_signal_filter.add_input(inputs_key, None)
        with pytest.raises(FilterInputValidationError):
            mri_signal_filter.run()

        # data not in the valid range: fresh filter per invalid value
        for invalid_value in candidates[1:]:
            mri_signal_filter = filter_with_valid_inputs(test_data)
            mri_signal_filter.add_input(inputs_key, invalid_value)
            with pytest.raises(FilterInputValidationError):
                mri_signal_filter.run()

    # Optional parameters: (key, ((value, should_fail), ...)).
    # 'mag_enc': a string is the wrong type; TEST_IMAGE_COMPLEX must be
    # rejected (negative values not allowed); TEST_IMAGE_ONES must run.
    # 'image_flavour': an int is the wrong type; "PERFUSION" must run.
    optional_input_cases = (
        (
            "mag_enc",
            (("str", True), (TEST_IMAGE_COMPLEX, True), (TEST_IMAGE_ONES, False)),
        ),
        ("image_flavour", ((0, True), ("PERFUSION", False))),
    )
    for optional_key, cases in optional_input_cases:
        test_data = deepcopy(validation_data)
        for value, should_fail in cases:
            mri_signal_filter = filter_with_valid_inputs(test_data)
            mri_signal_filter.add_input(optional_key, value)
            if should_fail:
                with pytest.raises(FilterInputValidationError):
                    mri_signal_filter.run()
            else:
                # should run normally
                mri_signal_filter.run()