Exemplo n.º 1
0
    def validate(self) -> None:
        """
        Validates the parameters stored in the present object.

        The checks run in the order written below and stop at the first failure: base class validation,
        architecture name, kernel size, crop/center sizes, inference stride size, ground truth channels,
        normalization method, trim percentiles, class weights, per-structure list lengths, mean teacher
        setting, and finally the optional postprocessing rules.

        :raises ValueError: If one of the checks made directly in this method fails. The helpers called
            from here (super().validate(), check_is_any_of, the stride/class-weight validators and the
            rule validators) may raise errors of their own.
        """
        super().validate()
        # NOTE(review): vars(...) lists every entry of the class namespace, which for a plain class also
        # includes dunder entries such as __module__ - confirm this is the intended set of valid names.
        check_is_any_of("Architecture", self.architecture, vars(ModelArchitectureConfig).keys())

        if self.kernel_size % 2 == 0:
            raise ValueError("The kernel size must be an odd number (kernel_size: {})".format(self.kernel_size))

        if self.architecture != ModelArchitectureConfig.UNet3D:
            if any_pairwise_larger(self.center_size, self.crop_size):
                raise ValueError("Each center_size should be less than or equal to the crop_size "
                                 "(center_size: {}, crop_size: {})".format(self.center_size, self.crop_size))
        else:
            if self.crop_size != self.center_size:
                raise ValueError("For UNet3D, the center size of each dimension should be equal to the crop size "
                                 "(center_size: {}, crop_size: {})".format(self.center_size, self.crop_size))

        self.validate_inference_stride_size(self.inference_stride_size, self.get_output_size())

        # ground_truth_ids is used by every check from here on, so reject None before its first use
        # rather than letting numpy fail on it with an unrelated error.
        if self.ground_truth_ids is None:
            raise ValueError("ground_truth_ids is None")

        # check to make sure there is no overlap between image and ground-truth channels
        image_gt_intersect = np.intersect1d(self.image_channels, self.ground_truth_ids)
        if len(image_gt_intersect) != 0:
            raise ValueError("Channels: {} were found in both image_channels, and ground_truth_ids"
                             .format(image_gt_intersect))

        valid_norm_methods = [method.value for method in PhotometricNormalizationMethod]
        check_is_any_of("norm_method", self.norm_method.value, valid_norm_methods)

        if len(self.trim_percentiles) < 2 or self.trim_percentiles[0] >= self.trim_percentiles[1]:
            raise ValueError("Thresholds should contain lower and upper percentile thresholds, but got: {}"
                             .format(self.trim_percentiles))

        # The None check must precede the length comparison: a missing list can never have the required
        # length, so checking the length first would hide the more specific message.
        if self.class_weights is None:
            raise ValueError("class_weights must be set.")
        # One weight per ground truth structure, plus one extra entry.
        if len(self.class_weights) != len(self.ground_truth_ids) + 1:
            raise ValueError("class_weights needs to be equal to number of ground_truth_ids + 1")
        SegmentationModelBase.validate_class_weights(self.class_weights)

        # All per-structure lists must have one entry per ground truth structure.
        if len(self.ground_truth_ids_display_names) != len(self.ground_truth_ids):
            raise ValueError("len(ground_truth_ids_display_names)!=len(ground_truth_ids)")
        if len(self.ground_truth_ids_display_names) != len(self.colours):
            raise ValueError("len(ground_truth_ids_display_names)!=len(colours)")
        if len(self.ground_truth_ids_display_names) != len(self.fill_holes):
            raise ValueError("len(ground_truth_ids_display_names)!=len(fill_holes)")

        if self.mean_teacher_alpha is not None:
            raise ValueError("Mean teacher model is currently only supported for ScalarModels. "
                             "Please reset mean_teacher_alpha to None.")

        # The postprocessing rules are only validated when extra postprocessing is enabled.
        if not self.disable_extra_postprocessing:
            if self.slice_exclusion_rules is not None:
                for rule in self.slice_exclusion_rules:
                    rule.validate(self.ground_truth_ids)
            if self.summed_probability_rules is not None:
                for rule in self.summed_probability_rules:
                    rule.validate(self.ground_truth_ids)
def test_is_any_of() -> None:
    """
    Tests for check_is_any_of: a value that is in the list of valid values is accepted, and a value that
    is not raises a ValueError whose message names the prefix and the valid values.
    """
    # Values contained in the valid set must be accepted, including None when None is listed as valid.
    accepted_cases = [
        ("foo", ["foo"]),
        ("foo", ["bar", "foo"]),
        (None, ["bar", "foo", None]),
    ]
    for candidate, valid_values in accepted_cases:
        check_is_any_of("prefix", candidate, valid_values)

    # A value outside the valid set must raise, and the message should list the valid values.
    with pytest.raises(ValueError) as not_found:
        check_is_any_of("prefix", None, ["bar", "foo"])
    not_found_message = not_found.value.args[0]
    for expected_fragment in ("bar", "foo", "prefix"):
        assert expected_fragment in not_found_message

    # The message must also be built correctly when None is one of the valid values.
    with pytest.raises(ValueError) as none_is_valid:
        check_is_any_of("prefix", "baz", ["bar", None])
    none_is_valid_message = none_is_valid.value.args[0]
    for expected_fragment in ("bar", "<None>", "prefix", "baz"):
        assert expected_fragment in none_is_valid_message