def test_greater_than_equal_to_validator_creator():
    """The greater_than_equal_to_validator factory must refuse a
    non-numeric ``start`` argument by raising a TypeError."""
    for non_numeric_start in ("str", []):
        with pytest.raises(TypeError):
            greater_than_equal_to_validator(non_numeric_start)
def test_greater_than_equal_to_validator():
    """Exercise a validator with a threshold of 100 against values above,
    at and below the threshold, plus a non-numeric value."""
    validator = greater_than_equal_to_validator(100)
    assert str(validator) == "Value(s) must be greater than or equal to 100"

    # values that must satisfy the check (the boundary 100 included)
    for accepted in (101, 1000, float("inf"), 100):
        assert validator(accepted)

    # values below the threshold, or not numbers at all
    for rejected in (99, float("-inf"), "not a number"):
        assert not validator(rejected)
def test_greater_than_equal_to_validator_image_container():
    """Check the validator against NumpyImageContainer inputs whose values
    lie below, at and above the 1.5 threshold."""
    validator = greater_than_equal_to_validator(1.5)
    assert str(validator) == "Value(s) must be greater than or equal to 1.5"

    cases = (
        # every value is below 1.5 -> rejected
        ([[-0.5, 0.2], [0.1, -0.9]], False),
        # smallest value equals the threshold -> accepted
        ([[1.5, 2.2], [1.7, 90]], True),
        # every value is strictly above the threshold -> accepted
        ([[1.51, 2.2], [1.7, 90]], True),
    )
    for values, expected in cases:
        image_container = NumpyImageContainer(image=np.array(values))
        assert bool(validator(image_container)) is expected
    def _validate_inputs(self):
        """Check that the filter inputs meet their validation criteria.

        'image' must be derived from BaseImageContainer
        'snr' must be a float and >= 0
        'reference_image' if present must be derived from BaseImageContainer
        image.shape and reference_image.shape must be equal

        :raises FilterInputValidationError: if any of the above is violated
        """
        input_validator = ParameterValidator(
            parameters={
                self.KEY_IMAGE:
                Parameter(validators=isinstance_validator(BaseImageContainer)),
                self.KEY_SNR:
                Parameter(validators=[
                    isinstance_validator(float),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_REF_IMAGE:
                Parameter(validators=isinstance_validator(BaseImageContainer),
                          optional=True),
            })

        input_validator.validate(self.inputs,
                                 error_type=FilterInputValidationError)

        # If 'reference_image' is supplied, check that its dimensions match 'image'
        if self.KEY_REF_IMAGE in self.inputs:
            input_reference_image = self.inputs[self.KEY_REF_IMAGE]
            input_image = self.inputs[self.KEY_IMAGE]
            # Defensive re-check of the type before '.shape' is read below
            if not isinstance(input_reference_image, BaseImageContainer):
                raise FilterInputValidationError(
                    "Input 'reference_image' is not a BaseImageContainer "
                    f"(is {type(input_reference_image)})")
            if input_image.shape != input_reference_image.shape:
                # Separators between the concatenated literals keep the
                # message readable (previously the sentences ran together)
                raise FilterInputValidationError(
                    "Shape of inputs 'image' and 'reference_image' are not equal. "
                    f"Shape of 'image' is {input_image.shape}. "
                    f"Shape of 'reference_image' is {input_reference_image.shape}"
                )
 def _validate_inputs(self):
     """
     Validate ``self.inputs`` against this filter's parameter rules:

     'image' must be derived from BaseImageContainer.
     'snr' must be a float and >= 0
     'reference_image' if present must be derived from BaseImageContainer

     Any violation is raised as a FilterInputValidationError.
     """
     image_parameter = Parameter(
         validators=isinstance_validator(BaseImageContainer)
     )
     snr_parameter = Parameter(
         validators=[
             isinstance_validator(float),
             greater_than_equal_to_validator(0),
         ]
     )
     reference_image_parameter = Parameter(
         validators=isinstance_validator(BaseImageContainer), optional=True
     )
     rules = {
         self.KEY_IMAGE: image_parameter,
         self.KEY_SNR: snr_parameter,
         self.KEY_REF_IMAGE: reference_image_parameter,
     }
     ParameterValidator(parameters=rules).validate(
         self.inputs, error_type=FilterInputValidationError
     )
    def _validate_inputs(self):
        """ Checks that the inputs meet their validation criteria

        'm0' must be derived from BaseImageContainer and >= 0
        't1' must be derived from BaseImageContainer and >= 0
        't2' must be derived from BaseImageContainer and >= 0
        't2_star' (optional) must be derived from BaseImageContainer and >= 0
            Required if 'acq_contrast' is CONTRAST_GE
        'mag_enc' (optional) must be derived from BaseImageContainer
        'acq_contrast' must be a string equal to CONTRAST_GE, CONTRAST_SE or
            CONTRAST_IR (compared case insensitively)
        'echo_time' must be a float and >= 0
        'repetition_time' must be a float and >= 0 and not less than
            'echo_time'. If 'acq_contrast' is CONTRAST_IR it must also not be
            less than 'echo_time' + 'inversion_time'
        'excitation_flip_angle' (optional) must be a float
            Required if 'acq_contrast' is CONTRAST_GE or CONTRAST_IR
        'inversion_flip_angle' (optional) must be a float
            Required if 'acq_contrast' is CONTRAST_IR
        'inversion_time' (optional) must be a float and >= 0
            Required if 'acq_contrast' is CONTRAST_IR
        'image_flavour' (optional) must be a string

        All images must have the same dimensions, and no image may have
        image_type == COMPLEX_IMAGE_TYPE.

        :raises FilterInputValidationError: if any of the above is violated
        """
        input_validator = ParameterValidator(
            parameters={
                self.KEY_M0:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_T1:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_T2:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_T2_STAR:
                Parameter(
                    validators=[
                        isinstance_validator(BaseImageContainer),
                        greater_than_equal_to_validator(0),
                    ],
                    optional=True,
                ),
                self.KEY_MAG_ENC:
                Parameter(
                    validators=[isinstance_validator(BaseImageContainer)],
                    optional=True),
                self.KEY_ACQ_CONTRAST:
                Parameter(validators=[
                    isinstance_validator(str),
                    from_list_validator(
                        [self.CONTRAST_GE, self.CONTRAST_SE, self.CONTRAST_IR],
                        case_insensitive=True,
                    ),
                ]),
                self.KEY_ECHO_TIME:
                Parameter(validators=[
                    isinstance_validator(float),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_REPETITION_TIME:
                Parameter(validators=[
                    isinstance_validator(float),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_EXCITATION_FLIP_ANGLE:
                Parameter(
                    validators=[
                        isinstance_validator(float),
                    ],
                    optional=True,
                ),
                self.KEY_INVERSION_FLIP_ANGLE:
                Parameter(
                    validators=[
                        isinstance_validator(float),
                    ],
                    optional=True,
                ),
                self.KEY_INVERSION_TIME:
                Parameter(
                    validators=[
                        isinstance_validator(float),
                        greater_than_equal_to_validator(0),
                    ],
                    optional=True,
                ),
                self.KEY_IMAGE_FLAVOUR:
                Parameter(validators=[
                    isinstance_validator(str),
                ],
                          optional=True),
            })
        input_validator.validate(self.inputs,
                                 error_type=FilterInputValidationError)

        # 'acq_contrast' is validated as a str above and matched case
        # insensitively, so compare on its lower-cased form; the original
        # spelling is kept for the error messages.
        acq_contrast = self.inputs[self.KEY_ACQ_CONTRAST]
        acq_contrast_lower = acq_contrast.lower()

        # Parameters that are conditionally required based on "acq_contrast":
        # gradient echo ("ge") requires a 't2_star' image
        if acq_contrast_lower == self.CONTRAST_GE:
            if self.inputs.get(self.KEY_T2_STAR) is None:
                raise FilterInputValidationError(
                    "Acquisition contrast is ge, 't2_star' image required")

        # gradient echo ("ge") and inversion recovery ("ir") require an
        # 'excitation_flip_angle'
        if acq_contrast_lower in (self.CONTRAST_GE, self.CONTRAST_IR):
            if self.inputs.get(self.KEY_EXCITATION_FLIP_ANGLE) is None:
                raise FilterInputValidationError(
                    f"Acquisition contrast is {acq_contrast},"
                    " 'excitation_flip_angle' required")

        # inversion recovery ("ir") requires the inversion parameters and a
        # repetition time that is not shorter than echo + inversion time
        if acq_contrast_lower == self.CONTRAST_IR:
            if self.inputs.get(self.KEY_INVERSION_FLIP_ANGLE) is None:
                raise FilterInputValidationError(
                    f"Acquisition contrast is {acq_contrast},"
                    " 'inversion_flip_angle' required")
            if self.inputs.get(self.KEY_INVERSION_TIME) is None:
                raise FilterInputValidationError(
                    f"Acquisition contrast is {acq_contrast},"
                    " 'inversion_time' required")
            if self.inputs.get(self.KEY_REPETITION_TIME) < (
                    self.inputs.get(self.KEY_ECHO_TIME) +
                    self.inputs.get(self.KEY_INVERSION_TIME)):
                raise FilterInputValidationError(
                    "repetition_time must be greater than echo_time + inversion_time"
                )

        # repetition_time may not be shorter than echo_time (all contrasts)
        if self.inputs.get(self.KEY_REPETITION_TIME) < self.inputs.get(
                self.KEY_ECHO_TIME):
            raise FilterInputValidationError(
                "repetition_time must be greater than echo_time")

        # Check that all the input images are all the same dimensions
        keys_of_images = [
            key for key in self.inputs.keys()
            if isinstance(self.inputs[key], BaseImageContainer)
        ]
        list_of_image_shapes = [
            self.inputs[key].shape for key in keys_of_images
        ]
        # 'm0', 't1' and 't2' are required images, so the list is non-empty
        if list_of_image_shapes.count(
                list_of_image_shapes[0]) != len(list_of_image_shapes):
            raise FilterInputValidationError([
                "Input image shapes do not match.",
                [
                    f"{key}: {shape}, "
                    for key, shape in zip(keys_of_images, list_of_image_shapes)
                ],
            ])

        # Check that all the input images are not of image_type == "COMPLEX_IMAGE_TYPE"
        for key in keys_of_images:
            if self.inputs[key].image_type == COMPLEX_IMAGE_TYPE:
                raise FilterInputValidationError(
                    f"{key} has image type {COMPLEX_IMAGE_TYPE}, this is not supported"
                )
Esempio n. 7
0
    def _validate_inputs(self):
        """ Checks that the inputs meet their validation criteria
        'image' must be derived from BaseImageContainer
        'output_directory' must be a string and a path that exists
        'filename_prefix' must be a string and is optional (default "")

        Also checks the input image's metadata:
        'series_type' must be one of SUPPORTED_IMAGE_TYPES
        'series_number' must be an int and >= 0
        'modality' (optional str) is required when 'series_type' is STRUCTURAL
        'quantity' and 'units' (optional str) are required when 'series_type'
            is GROUND_TRUTH
        'asl_context' (optional str or list) is required when 'series_type'
            is ASL; it (or each of its entries) must be in
            SUPPORTED_ASL_CONTEXTS. If the modality label determined from it
            is ASL, 'label_type', 'label_duration', 'post_label_delay' and
            'image_flavour' are also required.
        When present, 'label_type', 'quantity', 'units' and 'image_flavour'
        must be strings and 'label_duration', 'post_label_delay' and
        'label_efficiency' must be floats.

        On success ``self.inputs`` is rebuilt from ``self._i`` merged with the
        parameters returned by the input validator.

        :raises FilterInputValidationError: if any check fails
        """
        input_validator = ParameterValidator(
            parameters={
                self.KEY_IMAGE:
                Parameter(validators=isinstance_validator(BaseImageContainer)),
                self.KEY_OUTPUT_DIRECTORY:
                Parameter(validators=isinstance_validator(str)),
                self.KEY_FILENAME_PREFIX:
                Parameter(
                    validators=isinstance_validator(str),
                    optional=True,
                    default_value="",
                ),
            })
        # validate the inputs; the returned parameters are merged back into
        # self.inputs only after every other check has passed
        new_params = input_validator.validate(
            self.inputs, error_type=FilterInputValidationError)

        metadata_validator = ParameterValidator(
            parameters={
                self.SERIES_TYPE:
                Parameter(
                    validators=from_list_validator(SUPPORTED_IMAGE_TYPES)),
                MODALITY:
                Parameter(validators=isinstance_validator(str), optional=True),
                self.SERIES_NUMBER:
                Parameter(validators=[
                    isinstance_validator(int),
                    greater_than_equal_to_validator(0),
                ]),
                ASL_CONTEXT:
                Parameter(
                    validators=isinstance_validator((str, list)),
                    optional=True,
                ),
                GkmFilter.KEY_LABEL_TYPE:
                Parameter(
                    validators=isinstance_validator(str),
                    optional=True,
                ),
                GkmFilter.KEY_LABEL_DURATION:
                Parameter(validators=isinstance_validator(float),
                          optional=True),
                GkmFilter.KEY_POST_LABEL_DELAY:
                Parameter(validators=isinstance_validator(float),
                          optional=True),
                GkmFilter.KEY_LABEL_EFFICIENCY:
                Parameter(validators=isinstance_validator(float),
                          optional=True),
                GroundTruthLoaderFilter.KEY_QUANTITY:
                Parameter(
                    validators=isinstance_validator(str),
                    optional=True,
                ),
                GroundTruthLoaderFilter.KEY_UNITS:
                Parameter(
                    validators=isinstance_validator(str),
                    optional=True,
                ),
                "image_flavour":
                Parameter(
                    validators=isinstance_validator(str),
                    optional=True,
                ),
            })
        # validate the metadata
        metadata = self.inputs[self.KEY_IMAGE].metadata
        metadata_validator.validate(metadata,
                                    error_type=FilterInputValidationError)
        series_type = metadata[self.SERIES_TYPE]

        # Specific validation when series_type is STRUCTURAL
        if series_type == STRUCTURAL:
            if metadata.get(MODALITY) is None:
                raise FilterInputValidationError(
                    "metadata field 'modality' is required when `series_type` is 'structural'"
                )

        # Specific validation when series_type is GROUND_TRUTH
        if series_type == GROUND_TRUTH:
            if metadata.get(GroundTruthLoaderFilter.KEY_QUANTITY) is None:
                raise FilterInputValidationError(
                    "metadata field 'quantity' is required when `series_type` is 'ground_truth'"
                )
            if metadata.get(GroundTruthLoaderFilter.KEY_UNITS) is None:
                raise FilterInputValidationError(
                    "metadata field 'units' is required when `series_type` is 'ground_truth'"
                )

        # Specific validation when series_type is ASL
        if series_type == ASL:
            # asl_context needs some further validating
            asl_context = metadata.get(ASL_CONTEXT)
            if asl_context is None:
                raise FilterInputValidationError(
                    "metadata field 'asl_context' is required when `series_type` is 'asl'"
                )
            # A single context string, or a list where every entry must be
            # a supported context
            if isinstance(asl_context, str):
                context_validator = from_list_validator(SUPPORTED_ASL_CONTEXTS)
            elif isinstance(asl_context, list):
                context_validator = for_each_validator(
                    from_list_validator(SUPPORTED_ASL_CONTEXTS))
            else:
                # Not reachable once the metadata validator has accepted
                # asl_context as (str, list); guards against an unbound name.
                raise FilterInputValidationError(
                    "metadata field 'asl_context' must be a str or a list")
            asl_context_validator = ParameterValidator(
                parameters={
                    ASL_CONTEXT: Parameter(validators=context_validator),
                })
            # Use the same key the validator is declared with
            asl_context_validator.validate(
                {ASL_CONTEXT: asl_context},
                error_type=FilterInputValidationError)

            # determine the modality_label based on asl_context
            modality_label = self.determine_asl_modality_label(asl_context)

            if modality_label == ASL:
                # fields required when the `modality` is 'asl', paired with
                # the name used in the error message
                required_asl_fields = (
                    (GkmFilter.KEY_LABEL_TYPE, "label_type"),
                    (GkmFilter.KEY_LABEL_DURATION, "label_duration"),
                    (GkmFilter.KEY_POST_LABEL_DELAY, "post_label_delay"),
                    ("image_flavour", "image_flavour"),
                )
                for field_key, field_name in required_asl_fields:
                    if metadata.get(field_key) is None:
                        raise FilterInputValidationError(
                            f"metadata field '{field_name}' is required for 'series_type'"
                            " and 'modality' is 'asl'")

        # Check that self.inputs[self.KEY_OUTPUT_DIRECTORY] is a valid path.
        output_directory = self.inputs[self.KEY_OUTPUT_DIRECTORY]
        if not os.path.exists(output_directory):
            raise FilterInputValidationError(
                f"'output_directory' {output_directory} does not exist"
            )

        # merge the updated parameters from the output with the input parameters
        self.inputs = {**self._i, **new_params}
Esempio n. 8
0
     default_value=[197, 233, 189],
 ),
 ACQ_CONTRAST:
 Parameter(
     validators=from_list_validator(["ge", "se", "ir"],
                                    case_insensitive=True),
     default_value="se",
 ),
 EXCITATION_FLIP_ANGLE:
 Parameter(validators=range_inclusive_validator(-180.0, 180.0),
           default_value=90.0),
 INVERSION_FLIP_ANGLE:
 Parameter(validators=range_inclusive_validator(-180.0, 180.0),
           default_value=180.0),
 INVERSION_TIME:
 Parameter(validators=greater_than_equal_to_validator(0.0),
           default_value=1.0),
 DESIRED_SNR:
 Parameter(validators=greater_than_equal_to_validator(0),
           default_value=50.0),
 RANDOM_SEED:
 Parameter(validators=greater_than_equal_to_validator(0),
           default_value=0),
 OUTPUT_IMAGE_TYPE:
 Parameter(
     validators=from_list_validator(["complex", "magnitude"]),
     default_value="magnitude",
 ),
 MODALITY:
 Parameter(
     validators=from_list_validator(["T1w", "T2w", "FLAIR", "anat"
    def _validate_inputs(self):
        """Checks that the inputs meet their validation criteria
        'perfusion_rate' must be derived from BaseImageContainer and be >= 0
        'transit_time' must be derived from BaseImageContainer and be >= 0
        'm0' must be either a float or derived from BaseImageContainer and be >= 0
        't1_tissue' must be derived from BaseImageContainer and be between
            0 and 100 (inclusive)
        'label_type' must be a string and equal to "CASL" OR "pCASL" OR "PASL"
            (case insensitive)
        'label_duration' must be a float between 0 and 100 (inclusive)
        'signal_time' must be a float between 0 and 100 (inclusive)
        'label_efficiency' must be a float between 0 and 1 (inclusive)
        'lambda_blood_brain' must be a float between 0 and 1 (inclusive)
        't1_arterial_blood' must be a float between 0 and 100 (inclusive)

        all BaseImageContainers supplied should be the same dimensions

        :raises FilterInputValidationError: if any of the above is violated
        """
        # Type checks come first in each list, matching the other filters'
        # validators, so a wrongly-typed input fails on its type rather than
        # on a value/range comparison.
        input_validator = ParameterValidator(
            parameters={
                self.KEY_PERFUSION_RATE:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_TRANSIT_TIME:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_M0:
                Parameter(validators=[
                    isinstance_validator((BaseImageContainer, float)),
                    greater_than_equal_to_validator(0),
                ]),
                self.KEY_T1_TISSUE:
                Parameter(validators=[
                    isinstance_validator(BaseImageContainer),
                    range_inclusive_validator(0, 100),
                ]),
                self.KEY_LABEL_TYPE:
                Parameter(validators=from_list_validator(
                    [self.CASL, self.PCASL, self.PASL],
                    case_insensitive=True)),
                self.KEY_LABEL_DURATION:
                Parameter(validators=[
                    isinstance_validator(float),
                    range_inclusive_validator(0, 100),
                ]),
                self.KEY_SIGNAL_TIME:
                Parameter(validators=[
                    isinstance_validator(float),
                    range_inclusive_validator(0, 100),
                ]),
                self.KEY_LABEL_EFFICIENCY:
                Parameter(validators=[
                    isinstance_validator(float),
                    range_inclusive_validator(0, 1),
                ]),
                self.KEY_LAMBDA_BLOOD_BRAIN:
                Parameter(validators=[
                    isinstance_validator(float),
                    range_inclusive_validator(0, 1),
                ]),
                self.KEY_T1_ARTERIAL_BLOOD:
                Parameter(validators=[
                    isinstance_validator(float),
                    range_inclusive_validator(0, 100),
                ]),
            })

        input_validator.validate(self.inputs,
                                 error_type=FilterInputValidationError)

        # Check that all the input images are all the same dimensions
        keys_of_images = [
            key for key in self.inputs.keys()
            if isinstance(self.inputs[key], BaseImageContainer)
        ]
        list_of_image_shapes = [
            self.inputs[key].shape for key in keys_of_images
        ]
        # 'perfusion_rate' is a required image, so the list is non-empty
        if list_of_image_shapes.count(
                list_of_image_shapes[0]) != len(list_of_image_shapes):
            raise FilterInputValidationError([
                "Input image shapes do not match.",
                [
                    f"{key}: {shape}, "
                    for key, shape in zip(keys_of_images, list_of_image_shapes)
                ],
            ])