Пример #1
0
    def test_load_live_schemas(self):
        """Every shipped schema under ``mlspeclib/schemas`` (except 0.0.1 ones) must load.

        The registry is populated first, then each YAML file is re-created with
        ``MLSchema.create_schema`` and must yield a named schema.
        """
        MLSchema.populate_registry()

        for schema_path in Path('mlspeclib').glob('schemas/**/*.yaml'):
            # Decode explicitly so the test does not depend on the platform's
            # default encoding.
            this_text = schema_path.read_text(encoding='utf-8')
            this_dict = convert_yaml_to_dict(this_text)
            if this_dict['mlspec_schema_version']['meta'] == '0.0.1':
                continue
            # subTest names the offending file in the failure report and keeps
            # checking the remaining files, replacing the old debug print.
            with self.subTest(schema_path=str(schema_path)):
                loaded_schema = MLSchema.create_schema(this_text)
                self.assertIsNotNone(loaded_schema.schema_name)
Пример #2
0
    def test_try_create_missing_mlspec_version_and_type(self):
        """create_schema must raise KeyError when either mlspec header key is absent."""
        header_yaml = """
            mlspec_schema_version:
                # Identifies the version of this schema
                meta: 0.0.1

            mlspec_schema_type:
                # Identifies the type of this schema
                meta: base
            """
        # Remove each required header in turn from a fresh copy and check
        # that schema creation refuses the incomplete definition.
        for missing_key in ("mlspec_schema_version", "mlspec_schema_type"):
            partial_schema = convert_yaml_to_dict(header_yaml)
            del partial_schema[missing_key]

            with self.assertRaises(KeyError):
                MLSchema.create_schema(partial_schema)
Пример #3
0
 def populate_registry():
     """Register the YAML schemas bundled next to this module.

     Every ``schemas/**/*.yaml`` file beside this module is parsed. Any schema
     whose derived name is not yet in ``marshmallow.class_registry`` is built
     with ``MLSchema.create_schema``. Schemas that are already registered are
     left untouched.
     """
     schemas_root = Path(os.path.dirname(__file__))
     for schema_file in schemas_root.glob('schemas/**/*.yaml'):
         # Explicit encoding: the platform default (e.g. cp1252 on Windows)
         # can mis-decode non-ASCII characters in the schema files.
         loaded_schema = convert_yaml_to_dict(
             schema_file.read_text(encoding='utf-8'))
         loaded_schema_name = MLSchema.build_schema_name_for_schema(
             mlspec_schema_type=loaded_schema['mlspec_schema_type'],
             mlspec_schema_version=loaded_schema['mlspec_schema_version'])
         try:
             marshmallow.class_registry.get_class(loaded_schema_name)
         except RegistryError:
             # Not registered yet, so build it.
             MLSchema.create_schema(loaded_schema)
Пример #4
0
    def create_schema_type(raw_string: "str | dict", schema_name: str = None):
        """Create a new schema class from a YAML definition via ``MLSchema.from_dict``.

        The returned class still needs to be instantiated before use, e.g.::

            this_schema = MLSchema.create_schema_type(raw_string)
            schema = this_schema()
            this_object = schema.load(object_submission_dict)

        :param raw_string: the schema definition, either as YAML text or as an
            already-parsed value. Nested schemas are passed back in as the
            value of their ``schema`` entry. The value is normalised by
            ``convert_yaml_to_dict``.
        :param schema_name: name under which the class is registered in
            ``marshmallow.class_registry``. When ``None``, the name is derived
            from the ``mlspec_schema_version`` and ``mlspec_schema_type``
            entries.
        :returns: the generated schema class. Its ``schema_name`` attribute is
            set whenever a non-empty name was available.
        :raises AttributeError: if a field has no attributes (a ``None``
            value), which usually means a YAML indentation error.
        """
        from collections.abc import Mapping

        schema_as_dict = convert_yaml_to_dict(raw_string)
        schema_as_dict = MLSchema._augment_with_base_schema(schema_as_dict)

        if schema_name is None:
            schema_name = build_schema_name_for_schema(
                mlspec_schema_version=schema_as_dict["mlspec_schema_version"],
                mlspec_schema_type=schema_as_dict["mlspec_schema_type"],
            )

        fields_dict = {}

        for field, field_spec in schema_as_dict.items():
            spec_type_name = str(type(field_spec))
            if ("marshmallow.fields" in spec_type_name
                    or "mlschemafields.MLSchemaFields" in spec_type_name):
                # The field has already been created and instantiated properly
                # (it comes from a base schema registered in
                # marshmallow.class_registry), nested fields included, so
                # use it as-is.
                fields_dict[field] = field_spec
            elif field_spec is None:
                raise AttributeError(
                    f"""It appears at the field '{field}' the yaml/dict \
                    is not formatted with attributes. Could it be an indentation error?"""
                )
            elif (isinstance(field_spec, Mapping)
                  and "type" in field_spec
                  and field_spec["type"].lower() == "nested"):
                # The Mapping guard matters: for a plain string, `"type" in`
                # is a substring test, and the following ["type"] index would
                # then raise TypeError.
                nested_schema_type = MLSchema.create_schema_type(
                    field_spec["schema"],
                    schema_name + "_" + field.lower())
                nested_schema_type.name = field
                fields_dict[field] = fields.Nested(nested_schema_type)
            else:
                field_method = MLSchema._field_method_from_dict(
                    field, field_spec)
                if field_method:
                    fields_dict[field] = field_method

        abstract_schema = MLSchema.from_dict(fields_dict)
        if schema_name:
            marshmallow.class_registry.register(schema_name, abstract_schema)
            abstract_schema.schema_name = schema_name
        return abstract_schema
Пример #5
0
    def test_convert_dicts_to_sub_schema(self):
        """Nested marshmallow schemas load parsed YAML and enforce required fields."""
        class UserSchema(Schema):
            name = fields.String(required=True)
            email = fields.Email(required=True)

        class BlogSchema(Schema):
            title = fields.String(required=True)
            year = fields.Int(required=True)
            author = fields.Nested(UserSchema, required=True)

        marshmallow.class_registry.register("blog_author", UserSchema)
        marshmallow.class_registry.register("blog", BlogSchema)

        # The email must be syntactically valid. fields.Email rejects a masked
        # placeholder such as "*****@*****.**", which would make the happy-path
        # load raise and the negative checks below pass for the wrong reason.
        sub_schema_string = """
            title: "Something Completely Different"
            year: 1970
            author:
                name: "Monty"
                email: "monty@example.com"
            """
        full_schema_data = convert_yaml_to_dict(sub_schema_string)
        full_schema_loaded = BlogSchema().load(full_schema_data)

        self.assertEqual(full_schema_loaded['title'], full_schema_data['title'])
        self.assertEqual(full_schema_loaded['author']['name'],
                         full_schema_data['author']['name'])

        # Dropping a required field of the nested schema must fail validation...
        missing_author_name_data = convert_yaml_to_dict(sub_schema_string)
        missing_author_name_data['author'].pop('name', None)

        with self.assertRaises(ValidationError):
            BlogSchema().load(missing_author_name_data)

        # ...and so must dropping a required top-level field.
        missing_year_data = convert_yaml_to_dict(sub_schema_string)
        missing_year_data.pop('year', None)

        with self.assertRaises(ValidationError):
            BlogSchema().load(missing_year_data)
Пример #6
0
    def test_load_live_data(self):
        """Every sample data file under ``tests/data`` (except 0.0.1 ones) must load cleanly."""
        MLSchema.populate_registry()

        for data_file in Path("tests").glob("data/**/*.yaml"):
            this_dict = convert_yaml_to_dict(data_file.read_text(encoding="utf-8"))
            if this_dict["schema_version"] == "0.0.1":
                continue
            # subTest names the failing file. Passing `errors` as the assertion
            # message shows *why* validation failed instead of only
            # "False is not true".
            with self.subTest(data_file=str(data_file)):
                loaded_object, errors = MLObject.create_object_from_file(data_file)
                self.assertEqual(len(errors), 0, errors)
                self.assertIsNotNone(loaded_object.get_schema())
Пример #7
0
    def test_load_live_data(self):
        """Every sample data file under ``tests/data`` (except 0.0.1 ones) must load cleanly."""
        MLSchema.populate_registry()

        for data_file in Path('tests').glob('data/**/*.yaml'):
            # Explicit encoding so the result does not depend on the platform.
            this_dict = convert_yaml_to_dict(data_file.read_text(encoding='utf-8'))
            if this_dict['schema_version'] == '0.0.1':
                continue
            # subTest replaces the debug print: the failing file is named in
            # the report, and `errors` is shown as the assertion message.
            with self.subTest(data_file=str(data_file)):
                loaded_object, errors = MLObject.create_object_from_file(data_file)
                self.assertEqual(len(errors), 0, errors)
                self.assertIsNotNone(loaded_object.get_schema())
Пример #8
0
    def test_interfaces_mismatch_type(self):
        """INTERFACE_INVALID_MISMATCH_TYPE must fail with a 'valid default' error on ``inputs``."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(SampleSchema.TEST.INTERFACE))
        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_MISMATCH_TYPE))

        with self.assertRaises(ValidationError) as context:
            instantiated_schema.load(yaml_submission)

        # assertIn shows the actual message on failure; assertTrue would only
        # report "False is not true".
        self.assertIn('valid default',
                      context.exception.messages['inputs'][0][0])
Пример #9
0
    def test_interfaces_type_unknown(self):
        """INTERFACE_INVALID_TYPE_UNKNOWN_1 must fail with a 'string or a dict' error on ``inputs``."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(SampleSchema.TEST.INTERFACE))
        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_TYPE_UNKNOWN_1))

        with self.assertRaises(ValidationError) as context:
            instantiated_schema.load(yaml_submission)

        # assertIn shows the actual message on failure; assertTrue would only
        # report "False is not true".
        self.assertIn('string or a dict',
                      context.exception.messages['inputs'][0][0])
Пример #10
0
    def test_load_full_datapath_schema(self):
        """The DATAPATH schema loads a full submission and requires ``run_date``."""
        # Presumably BASE must be created (and registered) before DATAPATH,
        # which extends it.
        MLSchema.create_schema(SampleSchema.SCHEMAS.BASE)
        instantiated_schema = MLSchema.create_schema(
            SampleSchema.SCHEMAS.DATAPATH)
        submission_dict = convert_yaml_to_dict(
            SampleSubmissions.FULL_SUBMISSIONS.DATAPATH)
        instantiated_object = instantiated_schema.load(submission_dict)

        # unittest assertions, unlike bare `assert`, still run under
        # `python -O` and report both values on failure.
        self.assertEqual(instantiated_object['run_date'].isoformat(),
                         submission_dict['run_date'].isoformat())
        self.assertEqual(instantiated_object['connection']['endpoint'],
                         submission_dict['connection']['endpoint'])

        submission_dict.pop('run_date', None)
        with self.assertRaises(ValidationError):
            instantiated_schema.load(submission_dict)
Пример #11
0
    def test_merge_two_dicts_with_valid_base(self):
        """DATAPATH objects carry BASE fields; BASE objects do not gain DATAPATH fields."""
        base_schema = MLSchema.create_schema(SampleSchema.SCHEMAS.BASE)
        base_object = base_schema.load(SampleSubmissions.FULL_SUBMISSIONS.BASE)
        datapath_schema = MLSchema.create_schema(SampleSchema.SCHEMAS.DATAPATH)
        datapath_object = datapath_schema.load(
            SampleSubmissions.FULL_SUBMISSIONS.DATAPATH
        )

        # Should not work - BASE did not merge with DATAPATH. Only the key
        # lookup itself is under test here.
        with self.assertRaises(KeyError):
            base_object["data_store"]  # pylint: disable=pointless-statement

        base_object_dict = convert_yaml_to_dict(SampleSubmissions.FULL_SUBMISSIONS.BASE)
        datapath_object_dict = convert_yaml_to_dict(
            SampleSubmissions.FULL_SUBMISSIONS.DATAPATH
        )

        self.assertIsInstance(base_object["run_id"], UUID)
        self.assertEqual(base_object["run_id"], UUID(base_object_dict["run_id"]))

        # Should work - DATAPATH inherited from BASE
        self.assertIsInstance(datapath_object["data_store"], str)
        self.assertEqual(datapath_object["data_store"],
                         datapath_object_dict["data_store"])
        self.assertIsInstance(datapath_object["connection"]["access_key_id"], str)
        self.assertEqual(datapath_object["connection"]["access_key_id"],
                         datapath_object_dict["connection"]["access_key_id"])
        self.assertIsInstance(datapath_object["run_id"], UUID)
        self.assertEqual(datapath_object["run_id"],
                         UUID(datapath_object_dict["run_id"]))
Пример #12
0
    def populate_registry():
        """Load all the base schemas bundled with mlspeclib into the schema registry.

        Every ``schemas/**/*.yaml`` file under the mlspeclib package is parsed
        and bucketed in this order:

        1. schemas without an ``mlspec_base_type`` key;
        2. schemas with one;
        3. schemas whose ``mlspec_schema_type`` contains a ``last`` key.

        The ordering is presumably so that base schemas are registered before
        the schemas that extend them. Schemas whose name is already in
        ``marshmallow.class_registry`` are skipped; the rest are built with
        ``MLSchema.create_schema``.
        """
        root_logger = logging.getLogger()

        load_root = os.path.dirname(mlspeclib.__file__)
        load_pattern = "schemas/**/*.yaml"
        # Lazy %-style args; the path message now shows the actual search
        # location rather than a generator object's repr.
        root_logger.debug("Registry load root: %s", load_root)
        root_logger.debug("Registry load path: %s/%s", load_root, load_pattern)
        load_list = list(Path(load_root).glob(load_pattern))
        root_logger.debug("Registry load list: %s", load_list)

        no_base = []
        has_base = []
        last = []

        for schema_file in load_list:
            schema_dict = convert_yaml_to_dict(schema_file.read_text('utf-8'))

            if "last" in schema_dict["mlspec_schema_type"]:
                last.append(schema_dict)
            elif "mlspec_base_type" not in schema_dict:
                no_base.append(schema_dict)
            else:
                has_base.append(schema_dict)

        for loaded_schema in no_base + has_base + last:
            loaded_schema_name = build_schema_name_for_schema(
                mlspec_schema_type=loaded_schema["mlspec_schema_type"],
                mlspec_schema_version=loaded_schema["mlspec_schema_version"],
            )
            try:
                marshmallow.class_registry.get_class(loaded_schema_name)
            except RegistryError:
                MLSchema.create_schema(loaded_schema)
Пример #13
0
    def create_object_from_string(file_contents: str):
        """Create an MLObject from a YAML string and validate it.

        The parsed ``schema_version`` and ``schema_type`` values select the
        object's type. The full parsed content is then applied with
        ``MLObject.update_tree``, and the object is validated.

        :param file_contents: YAML text containing top-level ``schema_version``
            and ``schema_type`` keys. Indexing raises ``KeyError`` if either
            is missing.
        :returns: tuple ``(ml_object, errors)``, where ``errors`` is the value
            returned by ``ml_object.validate()``.
        """
        contents_as_dict = convert_yaml_to_dict(file_contents)

        ml_object = MLObject()
        ml_object.set_type(schema_version=contents_as_dict["schema_version"],
                           schema_type=contents_as_dict["schema_type"])
        MLObject.update_tree(ml_object, contents_as_dict)
        errors = ml_object.validate()
        return ml_object, errors
Пример #14
0
def return_base_schema_and_submission():
    """Build the BASE sample schema and parse its full sample submission.

    Returns a ``(schema, submission_dict)`` pair.
    """
    schema = MLSchema.create_schema(SampleSchema.SCHEMAS.BASE)
    submission = convert_yaml_to_dict(SampleSubmissions.FULL_SUBMISSIONS.BASE)
    return (schema, submission)
Пример #15
0
class ValidatorsTestSuite(unittest.TestCase):
    """Validators test cases.

    Sample schemas and submissions are combined with a minimal mlspec header
    (``schema_schema_info`` / ``submission_schema_info``) and run through
    ``generic_schema_validator``.
    """

    schema_schema_info = convert_yaml_to_dict("""
mlspec_schema_version:
    # Identifies the version of this schema
    meta: 0.0.1

mlspec_schema_type:
    # Base schema type that this extends
    meta: base

schema_version:
  # Identifies version of MLSpec to use
  type: semver
  required: True
schema_type:
  # Identifies version of MLSpec to use
  type: allowed_schema_types
  required: True
""")
    submission_schema_info = {"schema_version": "0.0.1", "schema_type": "base"}

    def test_semver_found(self):
        self.assertTrue(MLSchemaValidators.validate_type_semver("0.0.1"))

    def test_semver_not_found(self):
        self.assertFalse(MLSchemaValidators.validate_type_semver("x.x.x"))

    def test_uuid_found(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.UUID, SampleSubmissions.UNIT_TESTS.UUID_VALID)
        self.assertTrue(instantiated_object["run_id"])

    def test_uuid_not_found(self):
        self.generic_schema_validator(
            SampleSchema.TEST.UUID,
            SampleSubmissions.UNIT_TESTS.UUID_INVALID,
            ValidationError,
            "UUID",
        )

    def test_uri_valid(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.URI, SampleSubmissions.UNIT_TESTS.URI_VALID_1)
        self.assertTrue(instantiated_object["endpoint"])

        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.URI, SampleSubmissions.UNIT_TESTS.URI_VALID_2)
        self.assertTrue(instantiated_object["endpoint"])

    def test_uri_invalid(self):
        self.generic_schema_validator(
            SampleSchema.TEST.URI,
            SampleSubmissions.UNIT_TESTS.URI_INVALID_1,
            ValidationError,
            "Invalid",
        )
        self.generic_schema_validator(
            SampleSchema.TEST.URI,
            SampleSubmissions.UNIT_TESTS.URI_INVALID_2,
            ValidationError,
            "valid string",
        )

    def test_regex_valid(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.REGEX,
            SampleSubmissions.UNIT_TESTS.REGEX_ALL_LETTERS)
        self.assertTrue(instantiated_object["all_letters"])

        self.generic_schema_validator(
            SampleSchema.TEST.REGEX,
            SampleSubmissions.UNIT_TESTS.REGEX_ALL_NUMBERS,
            ValidationError,
            "No match",
        )

    def test_regex_invalid(self):
        self.generic_schema_validator(SampleSchema.TEST.INVALID_REGEX, None,
                                      AssertionError, "valid regex")

    @unittest.skip("NYI")
    def test_path(self):
        self.fail("NYI")

    @unittest.skip("NYI")
    def test_bucket(self):
        self.fail("NYI")

    def test_interfaces_valid(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.INTERFACE,
            SampleSubmissions.UNIT_TESTS.INTERFACE_VALID_UNNAMED,
        )
        self.assertEqual(len(instantiated_object["inputs"]), 2)

        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.INTERFACE,
            SampleSubmissions.UNIT_TESTS.INTERFACE_VALID_NAMED,
        )
        self.assertEqual(len(instantiated_object["inputs"]), 2)

    @unittest.skip("Type is not required in KFP (but it should be)")
    def test_interfaces_missing_type(self):
        self.generic_schema_validator(
            SampleSchema.TEST.INTERFACE,
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_MISSING_TYPE,
            ValidationError,
            "No type",
        )

    def test_interfaces_mismatch_type(self):
        self.generic_schema_validator(
            SampleSchema.TEST.INTERFACE,
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_MISMATCH_TYPE,
            ValidationError,
            "valid default",
        )

    def test_interfaces_type_unknown(self):
        self.generic_schema_validator(
            SampleSchema.TEST.INTERFACE,
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_TYPE_UNKNOWN_1,
            ValidationError,
            "string or a dict",
        )

    def test_validate_constraints_constraint_valid_greater_equal(self):
        this_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(
                SampleSchema.TEST.OPERATOR_VALID))

        self.assertIsInstance(this_schema.declared_fields["num"], fields.Integer)

    def test_validate_constraints_constraint_valid_modulo(self):
        this_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(
                SampleSchema.TEST.OPERATOR_VALID_MODULO_2))

        self.assertIsInstance(this_schema.declared_fields["num"], fields.Integer)

    def test_validate_constraints_true_value(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_VALID,
            SampleSubmissions.UNIT_TESTS.CONSTRAINT_VALID_MORE_THAN_1000,
            None,
        )
        self.assertIsInstance(instantiated_object["num"], int)

    def test_validate_constraints_false_value(self):
        self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_VALID,
            SampleSubmissions.UNIT_TESTS.CONSTRAINT_VALID_LESS_THAN_1000,
            ValidationError,
            "Invalid value",
        )

    def test_validate_constraints_modulo_true(self):
        instantiated_object = self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_VALID_MODULO_2,
            SampleSubmissions.UNIT_TESTS.CONSTRAINT_VALID_MODULO_2_TRUE,
        )
        self.assertIsInstance(instantiated_object["num"], int)

    def test_validate_constraints_modulo_false(self):
        self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_VALID_MODULO_2,
            SampleSubmissions.UNIT_TESTS.CONSTRAINT_VALID_MODULO_2_FALSE,
            ValidationError,
            "Invalid value",
        )

    def test_validate_constraints_constraint_invalid_type(self):
        self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_INVALID_TYPE,
            None,
            ValueError,
            "Attempting to add",
        )

    def test_validate_constraints_constraint_no_valid_operator(self):
        self.generic_schema_validator(
            SampleSchema.TEST.OPERATOR_INVALID_NO_OPERATOR,
            None,
            ValidationError,
            "No parsable lambda",
        )

    def test_validate_workflow_valid(self):
        MLSchema.populate_registry()
        loaded_workflow = self.generic_schema_validator(
            SampleSchema.TEST.WORKFLOW_STEP,
            SampleSubmissions.UNIT_TESTS.WORKFLOW_VALID,
            None,
            None,
        )
        self.assertEqual(len(loaded_workflow["steps"]["step_name"]), 5)

    def test_validate_workflow_missing_input(self):
        MLSchema.populate_registry()
        self.generic_schema_validator(
            SampleSchema.TEST.WORKFLOW_STEP,
            SampleSubmissions.UNIT_TESTS.WORKFLOW_BAD_INPUT,
            ValidationError,
            "class registry",
        )

    def test_validate_workflow_bad_semver(self):
        MLSchema.populate_registry()
        self.generic_schema_validator(
            SampleSchema.TEST.WORKFLOW_STEP,
            SampleSubmissions.UNIT_TESTS.WORKFLOW_BAD_SEMVER,
            ValidationError,
            "semver",
        )

    def test_validate_workflow_no_input(self):
        MLSchema.populate_registry()
        self.generic_schema_validator(
            SampleSchema.TEST.WORKFLOW_STEP,
            SampleSubmissions.UNIT_TESTS.WORKFLOW_NO_INPUT,
            ValidationError,
            "input",
        )

    def generic_schema_validator(self,
                                 test_schema,
                                 test_submission,
                                 exception_type=None,
                                 exception_string=None):
        """Create a schema from ``test_schema`` and optionally load a submission.

        :param test_schema: sample schema, wrapped with ``schema_schema_info``.
        :param test_submission: sample submission, wrapped with
            ``submission_schema_info``, or ``None`` to only build the schema.
        :param exception_type: exception expected either while creating the
            schema or while loading the submission. ``None`` means both steps
            must succeed.
        :param exception_string: substring that must appear in the raised
            exception's text. When ``None``, the text is only printed.
        :returns: the loaded submission on success, otherwise ``None``.
        """
        instantiated_schema = None
        error_string = None

        try:
            instantiated_schema = MLSchema.create_schema(
                self.wrap_schema_with_mlschema_info(test_schema))
        except Exception as e:  # pylint: disable=broad-except
            # No error was expected: let the real one surface, instead of
            # obscuring it behind isinstance(e, None) raising TypeError.
            if exception_type is None:
                raise
            self.assertIsInstance(e, exception_type)
            error_string = str(e)

        yaml_submission = None
        # Only load when the schema was built. Otherwise the creation error
        # above is the outcome under test (and there is nothing to load with).
        if instantiated_schema is not None and test_submission is not None:
            yaml_submission = convert_yaml_to_dict(
                self.wrap_submission_with_mlschema_info(test_submission))

            if exception_type is not None:
                with self.assertRaises(exception_type) as context:
                    instantiated_schema.load(yaml_submission)
                error_string = str(context.exception)

        # If error_string is set, the expected exception was raised: check it
        # and stop.
        if error_string is not None:
            if exception_string is not None:
                self.assertIn(exception_string, error_string)
            else:
                print(error_string)  # Unexpected error, print it out
            return None

        # An exception was expected at schema creation but never came.
        if exception_type is not None:
            self.fail(
                f"Expected {exception_type.__name__} to be raised, but nothing was.")

        if yaml_submission is None:
            return None  # Schema-only call with no expected error.

        return instantiated_schema.load(yaml_submission)

    def wrap_schema_with_mlschema_info(self, this_dict):
        return merge_two_dicts(self.schema_schema_info,
                               convert_yaml_to_dict(this_dict))

    def wrap_submission_with_mlschema_info(self, this_dict):
        return merge_two_dicts(self.submission_schema_info,
                               convert_yaml_to_dict(this_dict))
Пример #16
0
 def wrap_submission_with_mlschema_info(self, this_dict):
     """Combine ``self.submission_schema_info`` with the parsed submission via ``merge_two_dicts``."""
     header = self.submission_schema_info
     parsed_submission = convert_yaml_to_dict(this_dict)
     return merge_two_dicts(header, parsed_submission)
Пример #17
0
 def create_object(submission_text: str):
     """Parse a submission and load it with the registry schema chosen for it.

     ``MLSchema.load_schema_from_registry`` picks the schema class from the
     parsed submission. That class is then instantiated and used to ``load``
     the same data.
     """
     parsed = convert_yaml_to_dict(submission_text)
     schema_class = MLSchema.load_schema_from_registry(data=parsed)
     schema_instance = schema_class()
     return schema_instance.load(parsed)
Пример #18
0
    def append_schema_to_registry(load_path: Path) -> bool:
        """Register every YAML schema found (recursively) under ``load_path``.

        All files are parsed and checked first. If any file fails, the
        problems are logged at CRITICAL level on the root logger and nothing
        is registered. Otherwise schemas are registered in three groups:

        1. schemas without ``mlspec_base_type``;
        2. schemas with it;
        3. schemas whose ``mlspec_schema_type`` contains ``last``.

        Names already in ``marshmallow.class_registry`` are skipped.

        :param load_path: directory to search, as a ``Path`` or ``str``.
        :returns: ``True`` if every file was valid and the schemas were
            processed, ``False`` if any file failed to parse or lacked the
            minimum schema fields.
        :raises TypeError: if ``load_path`` is neither ``str`` nor ``Path``.
        :raises FileNotFoundError: if no ``*.yaml`` file exists under it.
        """
        if isinstance(load_path, str):
            load_path = Path(load_path)

        if not isinstance(load_path, Path):
            raise TypeError(
                "Appending schemas to a registry expects a Path object.")

        all_found_files = list(load_path.glob("**/*.yaml"))

        if not all_found_files:
            raise FileNotFoundError(
                f"No files ending in '.yaml' were found in the path '{load_path}'"
            )

        files_with_errors = []

        no_base_schemas = []
        schemas_with_base = []
        last_schemas = []

        for putative_schema_file in all_found_files:
            this_text = putative_schema_file.read_text('utf-8')
            try:
                this_dict = convert_yaml_to_dict(this_text)
            except ScannerError as se:
                files_with_errors.append((
                    putative_schema_file.name,
                    f"Yaml could not be parsed. Error details: \n{str(se)}",
                ))
                continue

            if not contains_minimum_fields_for_schema(this_dict):
                files_with_errors.append((
                    putative_schema_file.name,
                    """Does not contain all of the minimum schema necessary as top level fields - list includes: mlspec_schema_version, mlspec_schema_version.meta, mlspec_base_type, mlspec_base_type.meta, mlspec_schema_type, mlspec_schema_type.meta, schema_version and schema_type""",
                ))
                continue

            if "last" in this_dict["mlspec_schema_type"]:
                last_schemas.append(this_dict)
            elif "mlspec_base_type" in this_dict:
                schemas_with_base.append(this_dict)
            else:
                no_base_schemas.append(this_dict)

        if files_with_errors:
            error_string = "".join(
                f"::CRITICAL - {file_name}: {message}\n"
                for file_name, message in files_with_errors
            )
            logging.getLogger().critical(error_string)
            return False

        for schema_dict in no_base_schemas + schemas_with_base + last_schemas:
            schema_name = build_schema_name_for_schema(
                mlspec_schema_type=schema_dict["mlspec_schema_type"],
                mlspec_schema_version=schema_dict["mlspec_schema_version"],
            )
            try:
                marshmallow.class_registry.get_class(schema_name)
            except RegistryError:
                MLSchema.create_schema(schema_dict)

        # Honour the `-> bool` contract. Previously the success path fell off
        # the end and returned None, which callers testing the result would
        # read as failure.
        return True
Пример #19
0
class ValidatorsTestSuite(unittest.TestCase):
    """Validators test cases.

    Each test merges a sample schema / submission with a minimal mlspec header
    (``schema_schema_info`` / ``submission_schema_info``) before use.
    """

    schema_schema_info = convert_yaml_to_dict("""
mlspec_schema_version:
    # Identifies the version of this schema
    meta: 0.0.1

mlspec_schema_type:
    # Base schema type that this extends
    meta: base

schema_version:
  # Identifies version of MLSpec to use
  type: semver
  required: True
schema_type:
  # Identifies version of MLSpec to use
  type: allowed_schema_types
  required: True
""")
    submission_schema_info = {'schema_version': '0.0.1', 'schema_type': 'base'}

    def test_semver_found(self):
        self.assertTrue(MLSchemaValidators.validate_type_semver('0.0.1'))

    def test_semver_not_found(self):
        self.assertFalse(MLSchemaValidators.validate_type_semver('x.x.x'))

    def test_uuid_found(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.UUID)
        instantiated_object = instantiated_schema.load(
            self._parse_submission(SampleSubmissions.UNIT_TESTS.UUID_VALID))

        self.assertTrue(instantiated_object['run_id'])

    def test_uuid_not_found(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.UUID)
        yaml_submission = self._parse_submission(
            SampleSubmissions.UNIT_TESTS.UUID_INVALID)
        with self.assertRaises(ValidationError):
            instantiated_schema.load(yaml_submission)

    def test_uri_valid(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.URI)
        for valid_uri in (SampleSubmissions.UNIT_TESTS.URI_VALID_1,
                          SampleSubmissions.UNIT_TESTS.URI_VALID_2):
            instantiated_object = instantiated_schema.load(
                self._parse_submission(valid_uri))
            self.assertTrue(instantiated_object['endpoint'])

    def test_uri_invalid(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.URI)
        for invalid_uri in (SampleSubmissions.UNIT_TESTS.URI_INVALID_1,
                            SampleSubmissions.UNIT_TESTS.URI_INVALID_2):
            yaml_submission = self._parse_submission(invalid_uri)
            with self.assertRaises(ValidationError):
                instantiated_schema.load(yaml_submission)

    def test_regex_valid(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.REGEX)
        instantiated_object = instantiated_schema.load(
            self._parse_submission(
                SampleSubmissions.UNIT_TESTS.REGEX_ALL_LETTERS))

        self.assertTrue(instantiated_object['all_letters'])

        yaml_submission = self._parse_submission(
            SampleSubmissions.UNIT_TESTS.REGEX_ALL_NUMBERS)
        with self.assertRaises(ValidationError):
            instantiated_schema.load(yaml_submission)

    def test_regex_invalid(self):
        with self.assertRaises(AssertionError):
            self._build_schema(SampleSchema.TEST.INVALID_REGEX)

    @unittest.skip("NYI")
    def test_path(self):
        self.fail("NYI")

    @unittest.skip("NYI")
    def test_bucket(self):
        self.fail("NYI")

    def test_interfaces_valid(self):
        instantiated_schema = self._build_schema(SampleSchema.TEST.INTERFACE)
        for valid_interface in (
                SampleSubmissions.UNIT_TESTS.INTERFACE_VALID_UNNAMED,
                SampleSubmissions.UNIT_TESTS.INTERFACE_VALID_NAMED):
            instantiated_object = instantiated_schema.load(
                self._parse_submission(valid_interface))
            self.assertEqual(len(instantiated_object['inputs']), 2)

    @unittest.skip("Type is not required in KFP (but it should be)")
    def test_interfaces_missing_type(self):
        self._assert_interface_error(
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_MISSING_TYPE,
            'No type')

    def test_interfaces_mismatch_type(self):
        self._assert_interface_error(
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_MISMATCH_TYPE,
            'valid default')

    def test_interfaces_type_unknown(self):
        self._assert_interface_error(
            SampleSubmissions.UNIT_TESTS.INTERFACE_INVALID_TYPE_UNKNOWN_1,
            'string or a dict')

    def _build_schema(self, sample_schema):
        """Create a schema from a sample wrapped with the mlspec schema header."""
        return MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(sample_schema))

    def _parse_submission(self, sample_submission):
        """Parse a sample wrapped with the mlspec submission header."""
        return convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(sample_submission))

    def _assert_interface_error(self, sample_submission, expected_fragment):
        """Loading ``sample_submission`` against INTERFACE must fail on ``inputs``."""
        instantiated_schema = self._build_schema(SampleSchema.TEST.INTERFACE)
        yaml_submission = self._parse_submission(sample_submission)

        with self.assertRaises(ValidationError) as context:
            instantiated_schema.load(yaml_submission)

        # assertIn reports the actual message on failure.
        self.assertIn(expected_fragment,
                      context.exception.messages['inputs'][0][0])

    def wrap_schema_with_mlschema_info(self, this_dict):
        return merge_two_dicts(self.schema_schema_info,
                               convert_yaml_to_dict(this_dict))

    def wrap_submission_with_mlschema_info(self, this_dict):
        return merge_two_dicts(self.submission_schema_info,
                               convert_yaml_to_dict(this_dict))
class ValidatorsTestSuite(unittest.TestCase):
    """Validators test cases.

    Each test builds a schema from a ``SampleSchema.TEST`` fixture wrapped
    with ``schema_schema_info``, then loads ``SampleSubmissions.UNIT_TESTS``
    fixtures (wrapped with ``submission_schema_info``) against it.
    """

    # Header merged into every test schema: the mlspec_schema_* keys that
    # MLSchema.create_schema needs, plus the required schema_version /
    # schema_type field declarations.
    schema_schema_info = convert_yaml_to_dict("""
mlspec_schema_version:
    # Identifies the version of this schema
    meta: 0.0.1

mlspec_schema_type:
    # Base schema type that this extends
    meta: base

schema_version:
  # Identifies version of MLSpec to use
  type: semver
  required: True
schema_type:
  # Identifies the type of this schema
  type: allowed_schema_types
  required: True
""")
    # Values merged into every test submission so the required
    # schema_version / schema_type fields declared above are present.
    submission_schema_info = {'schema_version': '0.0.1', 'schema_type': 'base'}

    def test_semver_found(self):
        """A well-formed semver string is accepted."""
        # unittest assertions, not bare `assert`: the latter is stripped
        # under `python -O` and gives no diagnostic on failure.
        self.assertTrue(MLSchemaValidators.validate_type_semver('0.0.1'))

    def test_semver_not_found(self):
        """A non-numeric version string is rejected."""
        self.assertFalse(MLSchemaValidators.validate_type_semver('x.x.x'))

    def test_uuid_found(self):
        """The UUID_VALID fixture loads and yields a truthy ``run_id``."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(
                SampleSchema.TEST.UUID))  # noqa
        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.UUID_VALID))  # noqa
        instantiated_object = instantiated_schema.load(yaml_submission)

        self.assertTrue(instantiated_object['run_id'])

    def test_uuid_not_found(self):
        """The UUID_INVALID fixture raises ValidationError on load."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(
                SampleSchema.TEST.UUID))  # noqa
        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.UUID_INVALID))  # noqa
        with self.assertRaises(ValidationError):
            instantiated_schema.load(yaml_submission)

    def test_uri_valid(self):
        """Every URI_VALID_* fixture loads and yields a truthy ``endpoint``."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(SampleSchema.TEST.URI))  # noqa
        valid_cases = (
            ('URI_VALID_1', SampleSubmissions.UNIT_TESTS.URI_VALID_1),
            ('URI_VALID_2', SampleSubmissions.UNIT_TESTS.URI_VALID_2),
        )
        # subTest reports each fixture separately, so a failure on the first
        # case no longer hides the result of the second.
        for case_name, submission in valid_cases:
            with self.subTest(case=case_name):
                yaml_submission = convert_yaml_to_dict(
                    self.wrap_submission_with_mlschema_info(submission))
                instantiated_object = instantiated_schema.load(yaml_submission)
                self.assertTrue(instantiated_object['endpoint'])

    def test_uri_invalid(self):
        """Every URI_INVALID_* fixture raises ValidationError on load."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(SampleSchema.TEST.URI))  # noqa
        invalid_cases = (
            ('URI_INVALID_1', SampleSubmissions.UNIT_TESTS.URI_INVALID_1),
            ('URI_INVALID_2', SampleSubmissions.UNIT_TESTS.URI_INVALID_2),
        )
        for case_name, submission in invalid_cases:
            with self.subTest(case=case_name):
                yaml_submission = convert_yaml_to_dict(
                    self.wrap_submission_with_mlschema_info(submission))
                with self.assertRaises(ValidationError):
                    instantiated_schema.load(yaml_submission)

    def test_regex_valid(self):
        """REGEX_ALL_LETTERS loads; REGEX_ALL_NUMBERS raises ValidationError."""
        instantiated_schema = MLSchema.create_schema(
            self.wrap_schema_with_mlschema_info(
                SampleSchema.TEST.REGEX))  # noqa
        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.REGEX_ALL_LETTERS))  # noqa
        instantiated_object = instantiated_schema.load(yaml_submission)

        self.assertTrue(instantiated_object['all_letters'])

        yaml_submission = convert_yaml_to_dict(
            self.wrap_submission_with_mlschema_info(
                SampleSubmissions.UNIT_TESTS.REGEX_ALL_NUMBERS))  # noqa
        with self.assertRaises(ValidationError):
            instantiated_schema.load(yaml_submission)

    def test_regex_invalid(self):
        """Creating a schema from INVALID_REGEX raises AssertionError."""
        # NOTE(review): if create_schema signals this with a bare `assert`,
        # the check disappears under `python -O` — confirm in MLSchema.
        with self.assertRaises(AssertionError):
            MLSchema.create_schema(
                self.wrap_schema_with_mlschema_info(
                    SampleSchema.TEST.INVALID_REGEX))  # noqa

    @unittest.skip("NYI")
    def test_path(self):
        """Placeholder for path-type validation (not yet implemented)."""
        self.fail("NYI")

    @unittest.skip("NYI")
    def test_bucket(self):
        """Placeholder for bucket-type validation (not yet implemented)."""
        self.fail("NYI")

    def wrap_schema_with_mlschema_info(self, this_dict):
        """Merge ``schema_schema_info`` with a parsed schema fixture.

        ``this_dict`` is run through ``convert_yaml_to_dict`` first, so it is
        presumably YAML text despite the parameter name.
        """
        return merge_two_dicts(self.schema_schema_info,
                               convert_yaml_to_dict(this_dict))

    def wrap_submission_with_mlschema_info(self, this_dict):
        """Merge ``submission_schema_info`` with a parsed submission fixture.

        ``this_dict`` is run through ``convert_yaml_to_dict`` first, so it is
        presumably YAML text despite the parameter name.
        """
        return merge_two_dicts(self.submission_schema_info,
                               convert_yaml_to_dict(this_dict))