def test_contains_missing_output_disallowed_values(self, value):
        """Revalidation must reject ``value`` in an EQUALS output requirement.

        ``value`` is supplied from outside this method (presumably a pytest
        parametrization that is not visible here).
        """
        yaml_text = output_requirements_yaml(self.field, Conditions.EQUALS, [value])
        document = load(yaml_text, get_type_schema_yaml_validator())

        # Only the revalidation step is expected to fail; loading must succeed.
        with pytest.raises(YAMLValidationError):
            revalidate_typeschema(document)
 def yaml_str_to_schema_dict(yaml_str: str) -> dict:
     """Parse a typeschema YAML string, revalidate it, and return its plain data.

     This emulates how a yaml is cast to a dict for validation in
     `datarobot_drum.drum.common.read_model_metadata_yaml`; these assumptions
     are tested in
     `tests.drum.test_units.test_read_model_metadata_properly_casts_typeschema`.
     """
     parsed = load(yaml_str, get_type_schema_yaml_validator())
     # Revalidation is applied to the parsed document before `.data` is read.
     revalidate_typeschema(parsed)
     return parsed.data
    def test_sparsity_output_only_single_value(self):
        """An EQUALS output requirement given every output value must be rejected."""
        yaml_text = output_requirements_yaml(
            self.field, Conditions.EQUALS, Values.output_values()
        )
        document = load(yaml_text, get_type_schema_yaml_validator())

        # Loading is expected to pass; the failure must come from revalidation.
        with pytest.raises(YAMLValidationError):
            revalidate_typeschema(document)
 def test_contains_missing_input_output_disallows_conditions(self, condition):
     """Revalidation must reject ``condition`` for both input and output requirements.

     ``condition`` is supplied from outside this method (presumably a pytest
     parametrization that is not visible here).
     """
     # Both documents are built up front, input first, before either is loaded.
     yaml_docs = (
         input_requirements_yaml(self.field, condition, [Values.REQUIRED]),
         output_requirements_yaml(self.field, condition, [Values.ALWAYS]),
     )
     for yaml_text in yaml_docs:
         document = load(yaml_text, get_type_schema_yaml_validator())
         with pytest.raises(YAMLValidationError):
             revalidate_typeschema(document)
    def test_datatypes_allowed_conditions(self, condition):
        """``condition`` with NUM and TXT values must revalidate for input and output.

        ``condition`` is supplied from outside this method (presumably a pytest
        parametrization that is not visible here).
        """
        allowed = [Values.NUM, Values.TXT]
        builders = (input_requirements_yaml, output_requirements_yaml)
        # Build both YAML strings first, then validate each in turn.
        yaml_docs = [build(self.field, condition, allowed) for build in builders]

        for yaml_text in yaml_docs:
            document = load(yaml_text, get_type_schema_yaml_validator())
            revalidate_typeschema(document)
    def test_datatyped_allowed_values(self, value):
        """A single ``value`` under EQUALS must revalidate for input and output.

        ``value`` is supplied from outside this method (presumably a pytest
        parametrization that is not visible here).
        """
        equals = Conditions.EQUALS
        builders = (input_requirements_yaml, output_requirements_yaml)
        # Each builder receives its own one-element list, as in the original.
        yaml_docs = [build(self.field, equals, [value]) for build in builders]

        for yaml_text in yaml_docs:
            document = load(yaml_text, get_type_schema_yaml_validator())
            revalidate_typeschema(document)
    def test_contains_missing_output_only_single_value(self):
        """An EQUALS output requirement with NEVER and DYNAMIC together must be rejected."""
        two_values = [Values.NEVER, Values.DYNAMIC]
        yaml_text = output_requirements_yaml(self.field, Conditions.EQUALS, two_values)
        document = load(yaml_text, get_type_schema_yaml_validator())

        # Loading is expected to pass; the failure must come from revalidation.
        with pytest.raises(YAMLValidationError):
            revalidate_typeschema(document)
    def test_datatypes_multiple_values(self):
        """The IN condition with every data value must revalidate for input and output."""
        all_data_values = Values.data_values()
        builders = (input_requirements_yaml, output_requirements_yaml)
        # The same list of values is handed to both builders before validation starts.
        yaml_docs = [
            build(self.field, Conditions.IN, all_data_values) for build in builders
        ]

        for yaml_text in yaml_docs:
            document = load(yaml_text, get_type_schema_yaml_validator())
            revalidate_typeschema(document)
    def test_contains_missing_input_only_single_value(self):
        """An EQUALS input requirement with FORBIDDEN and SUPPORTED together must be rejected."""
        two_values = [Values.FORBIDDEN, Values.SUPPORTED]
        yaml_text = input_requirements_yaml(self.field, Conditions.EQUALS, two_values)
        document = load(yaml_text, get_type_schema_yaml_validator())

        # Loading is expected to pass; the failure must come from revalidation.
        with pytest.raises(YAMLValidationError):
            revalidate_typeschema(document)
    def test_datatypes_mix_allowed_and_unallowed_values(self):
        """Mixing NUM with REQUIRED under EQUALS must be rejected for input and output."""
        mixed = [Values.NUM, Values.REQUIRED]
        equals = Conditions.EQUALS
        # Build both YAML strings (input first) before loading either one.
        yaml_docs = (
            input_requirements_yaml(self.field, equals, mixed),
            output_requirements_yaml(self.field, equals, mixed),
        )

        for yaml_text in yaml_docs:
            document = load(yaml_text, get_type_schema_yaml_validator())
            with pytest.raises(YAMLValidationError):
                revalidate_typeschema(document)
    def test_revalidate_typescehma_mutates_yaml_num_columns_to_int(self):
        """Revalidation converts the input requirement value(s) from str to int in place.

        Checked for a single value and for a list of values: before
        ``revalidate_typeschema`` the parsed value data is ``str``; afterwards
        the same parsed documents expose ``int``.
        """

        def first_input_value(document):
            # Raw data of the "value" entry of the first input requirement.
            requirements = document[str(RequirementTypes.INPUT_REQUIREMENTS)]
            return requirements[0]["value"].data

        single_yaml = input_requirements_yaml(self.field, Conditions.EQUALS, [1])
        list_yaml = input_requirements_yaml(self.field, Conditions.EQUALS, [1, 2])
        single_doc = load(single_yaml, get_type_schema_yaml_validator())
        list_doc = load(list_yaml, get_type_schema_yaml_validator())

        # Straight after loading, the values are still strings.
        assert isinstance(first_input_value(single_doc), str)
        assert isinstance(first_input_value(list_doc)[0], str)

        revalidate_typeschema(single_doc)
        revalidate_typeschema(list_doc)

        # The very same documents now hold integers.
        assert isinstance(first_input_value(single_doc), int)
        assert isinstance(first_input_value(list_doc)[0], int)
# Example #12
# 0
def read_model_metadata_yaml(code_dir) -> PythonTypingOptional[dict]:
    """Load and validate the model metadata YAML located in ``code_dir``.

    The file ``MODEL_CONFIG_FILENAME`` inside ``code_dir`` is parsed against
    ``MODEL_CONFIG_SCHEMA``; when it contains a ``typeSchema`` section, that
    section is revalidated before the document is converted to plain data.

    For ``inference`` configs additional checks are applied:

    * binary target: ``validate_config_fields`` is called for the inference
      model section and for its ``positiveClassLabel`` / ``negativeClassLabel``
      fields (presumably raising when one is missing -- the helper is defined
      elsewhere);
    * multiclass target: exactly one of ``classLabels`` / ``classLabelsFile``
      must be present. When a labels file is given, its non-empty lines are
      stored under ``classLabels`` and ``classLabelsFile`` is set to ``None``.

    :param code_dir: directory expected to contain the model metadata file.
    :return: the validated configuration as a dict, or ``None`` when the
        metadata file does not exist.
    :raises SystemExit: with status 1 (after printing the error) when the YAML
        cannot be parsed or validated.
    :raises DrumCommonException: when the multiclass label settings are
        missing, ambiguous, or the labels file has fewer than 2 labels.
    """
    config_path = Path(code_dir).joinpath(MODEL_CONFIG_FILENAME)
    if not config_path.exists():
        return None

    with open(config_path) as f:
        try:
            model_config = load(f.read(), MODEL_CONFIG_SCHEMA)
            if "typeSchema" in model_config:
                revalidate_typeschema(model_config["typeSchema"])
            model_config = model_config.data
        except YAMLError as e:
            print(e)
            raise SystemExit(1)

    target_type = model_config[ModelMetadataKeys.TARGET_TYPE]

    # The model type is only looked up for binary / multiclass targets, so
    # other target types never require that key to be present here.
    if (
        target_type == TargetType.BINARY.value
        and model_config[ModelMetadataKeys.TYPE] == "inference"
    ):
        validate_config_fields(model_config, ModelMetadataKeys.INFERENCE_MODEL)
        validate_config_fields(
            model_config[ModelMetadataKeys.INFERENCE_MODEL],
            "positiveClassLabel",
            "negativeClassLabel",
        )

    if (
        target_type == TargetType.MULTICLASS.value
        and model_config[ModelMetadataKeys.TYPE] == "inference"
    ):
        validate_config_fields(model_config, ModelMetadataKeys.INFERENCE_MODEL)
        inference_config = model_config[ModelMetadataKeys.INFERENCE_MODEL]
        has_class_labels = "classLabels" in inference_config
        has_class_labels_file = "classLabelsFile" in inference_config

        if has_class_labels and has_class_labels_file:
            raise DrumCommonException(
                "\nError - for multiclass classification, either the class labels or "
                "a class labels file should be provided in {} file, but not both.".format(
                    MODEL_CONFIG_FILENAME
                )
            )
        if not (has_class_labels or has_class_labels_file):
            raise DrumCommonException(
                "\nError - for multiclass classification, either the class labels or "
                "a class labels file must be provided in {} file.".format(
                    MODEL_CONFIG_FILENAME
                )
            )

        if has_class_labels_file:
            # NOTE(review): the path is opened exactly as written in the config,
            # i.e. relative to the current working directory, not to code_dir.
            class_labels_file = inference_config["classLabelsFile"]

            with open(class_labels_file) as labels_file:
                # Text mode already translates every line ending to "\n", so
                # splitting on os.linesep ("\r\n" on Windows) would never match
                # there and the whole file would collapse into a single label.
                labels = [
                    label for label in labels_file.read().split("\n") if label
                ]

            if len(labels) < 2:
                raise DrumCommonException(
                    "Multiclass classification requires at least 2 labels."
                )
            inference_config["classLabels"] = labels
            inference_config["classLabelsFile"] = None

    return model_config
 def test_failing_on_bad_value(self, passing_yaml_string):
     """Replacing "NUM" with an unknown value must make revalidation fail.

     ``passing_yaml_string`` is supplied from outside this method (presumably
     a pytest fixture that is not visible here).
     """
     corrupted = passing_yaml_string.replace("NUM", "oooooops")
     document = load(corrupted, get_type_schema_yaml_validator())

     # Loading is expected to pass; the failure must come from revalidation.
     with pytest.raises(YAMLValidationError):
         revalidate_typeschema(document)
 def test_happy_path(self, passing_yaml_string):
     """The passing YAML string must load and revalidate without raising.

     ``passing_yaml_string`` is supplied from outside this method (presumably
     a pytest fixture that is not visible here).
     """
     validator = get_type_schema_yaml_validator()
     document = load(passing_yaml_string, validator)
     revalidate_typeschema(document)
 def test_number_of_columns_cannot_use_other_values(self, value):
     """Revalidation must reject ``value`` in an EQUALS input requirement.

     ``value`` is supplied from outside this method (presumably a pytest
     parametrization that is not visible here).
     """
     yaml_text = input_requirements_yaml(self.field, Conditions.EQUALS, [value])
     document = load(yaml_text, get_type_schema_yaml_validator())

     # Loading is expected to pass; the failure must come from revalidation.
     with pytest.raises(YAMLValidationError):
         revalidate_typeschema(document)
 def test_number_of_columns_can_have_multiple_ints(self):
     """An EQUALS input requirement listing 1, 0 and -1 must revalidate cleanly."""
     several_ints = [1, 0, -1]
     yaml_text = input_requirements_yaml(self.field, Conditions.EQUALS, several_ints)
     document = load(yaml_text, get_type_schema_yaml_validator())
     revalidate_typeschema(document)
 def test_number_of_columns_can_use_all_conditions(self, condition):
     """``condition`` with the value 1 must revalidate for input and output.

     ``condition`` is supplied from outside this method (presumably a pytest
     parametrization that is not visible here).
     """
     builders = (input_requirements_yaml, output_requirements_yaml)
     # Each builder receives its own one-element list; both run before any load.
     yaml_docs = [build(self.field, condition, [1]) for build in builders]

     for yaml_text in yaml_docs:
         document = load(yaml_text, get_type_schema_yaml_validator())
         revalidate_typeschema(document)
 def test_regression_test_datatypes_multi_values(self, permutation):
     """A data-types IN requirement built from ``permutation`` must revalidate.

     ``permutation`` is supplied from outside this method (presumably a pytest
     parametrization that is not visible here).
     """
     yaml_text = input_requirements_yaml(Fields.DATA_TYPES, Conditions.IN, permutation)
     validator = get_type_schema_yaml_validator()
     revalidate_typeschema(load(yaml_text, validator))
    def test_contains_missing_input_allowed_values(self, value):
        """A single ``value`` under EQUALS must revalidate as an input requirement.

        ``value`` is supplied from outside this method (presumably a pytest
        parametrization that is not visible here).
        """
        yaml_text = input_requirements_yaml(self.field, Conditions.EQUALS, [value])
        document = load(yaml_text, get_type_schema_yaml_validator())

        # No exception is expected here.
        revalidate_typeschema(document)