def validate(self) -> None:
    """
    Validates the parameters stored in the present object.

    The parent class validation runs first. After that, the following checks are made, in order:
    architecture name, kernel size, center size versus crop size, inference stride size, overlap
    between image channels and ground truth IDs, normalization method, trim percentiles, class weights,
    lengths of the per-structure lists, the mean teacher setting and, unless extra postprocessing is
    disabled, the slice exclusion and summed probability rules.

    :raises ValueError: If any of the checks performed directly in this method fails. The helper
        functions and rule objects that are called may raise their own errors.
    """
    super().validate()
    check_is_any_of("Architecture", self.architecture, vars(ModelArchitectureConfig).keys())

    def len_or_zero(lst: Optional[List[Any]]) -> int:
        """Returns the length of the list, treating None like an empty list."""
        return 0 if lst is None else len(lst)

    if self.kernel_size % 2 == 0:
        raise ValueError("The kernel size must be an odd number (kernel_size: {})".format(self.kernel_size))

    if self.architecture != ModelArchitectureConfig.UNet3D:
        if any_pairwise_larger(self.center_size, self.crop_size):
            raise ValueError("Each center_size should be less than or equal to the crop_size "
                             "(center_size: {}, crop_size: {})".format(self.center_size, self.crop_size))
    else:
        if self.crop_size != self.center_size:
            raise ValueError("For UNet3D, the center size of each dimension should be equal to the crop size "
                             "(center_size: {}, crop_size: {})".format(self.center_size, self.crop_size))

    self.validate_inference_stride_size(self.inference_stride_size, self.get_output_size())

    # check to make sure there is no overlap between image and ground-truth channels
    image_gt_intersect = np.intersect1d(self.image_channels, self.ground_truth_ids)
    if len(image_gt_intersect) != 0:
        raise ValueError("Channels: {} were found in both image_channels, and ground_truth_ids"
                         .format(image_gt_intersect))

    valid_norm_methods = [method.value for method in PhotometricNormalizationMethod]
    check_is_any_of("norm_method", self.norm_method.value, valid_norm_methods)

    # Both a lower and an upper threshold are required, and the lower one must be strictly smaller.
    if len(self.trim_percentiles) < 2 or self.trim_percentiles[0] >= self.trim_percentiles[1]:
        raise ValueError("Thresholds should contain lower and upper percentile thresholds, but got: {}"
                         .format(self.trim_percentiles))

    # One weight is expected per ground truth ID, plus one extra
    # (presumably for the background class - TODO confirm).
    if len_or_zero(self.class_weights) != (len_or_zero(self.ground_truth_ids) + 1):
        raise ValueError("class_weights needs to be equal to number of ground_truth_ids + 1")
    # The length check above already rejects class_weights == None (0 can never equal len + 1), so this
    # branch is not expected to trigger. It is kept so that the order and type of errors is unchanged
    # (presumably it also narrows the Optional type for static type checkers).
    if self.class_weights is None:
        raise ValueError("class_weights must be set.")
    SegmentationModelBase.validate_class_weights(self.class_weights)

    if self.ground_truth_ids is None:
        raise ValueError("ground_truth_ids is None")
    # All per-structure lists must have one entry per ground truth ID.
    if len(self.ground_truth_ids_display_names) != len(self.ground_truth_ids):
        raise ValueError("len(ground_truth_ids_display_names)!=len(ground_truth_ids)")
    if len(self.ground_truth_ids_display_names) != len(self.colours):
        raise ValueError("len(ground_truth_ids_display_names)!=len(colours)")
    if len(self.ground_truth_ids_display_names) != len(self.fill_holes):
        raise ValueError("len(ground_truth_ids_display_names)!=len(fill_holes)")

    if self.mean_teacher_alpha is not None:
        raise ValueError("Mean teacher model is currently only supported for ScalarModels. "
                         "Please reset mean_teacher_alpha to None.")

    if not self.disable_extra_postprocessing:
        if self.slice_exclusion_rules is not None:
            for rule in self.slice_exclusion_rules:
                rule.validate(self.ground_truth_ids)
        if self.summed_probability_rules is not None:
            for rule in self.summed_probability_rules:
                rule.validate(self.ground_truth_ids)
def test_is_any_of() -> None:
    """
    Tests for check_is_any_of: a value that is contained in the set of valid values is accepted silently,
    any other value raises a ValueError with a descriptive message.
    """
    # Values that are part of the valid set must pass, including None when it is listed as valid.
    accepted_cases = [
        ("foo", ["foo"]),
        ("foo", ["bar", "foo"]),
        (None, ["bar", "foo", None]),
    ]
    for value, valid_values in accepted_cases:
        check_is_any_of("prefix", value, valid_values)

    # When the value is not found, an error message with the valid values should be printed
    with pytest.raises(ValueError) as error_info:
        check_is_any_of("prefix", None, ["bar", "foo"])
    message = error_info.value.args[0]
    for fragment in ("bar", "foo", "prefix"):
        assert fragment in message

    # The error message should also work when one of the valid values is None
    with pytest.raises(ValueError) as error_info:
        check_is_any_of("prefix", "baz", ["bar", None])
    message = error_info.value.args[0]
    for fragment in ("bar", "<None>", "prefix", "baz"):
        assert fragment in message