예제 #1
0
    def test_live_interface_samples(self):
        """Each full component submission sample must load cleanly.

        Runs the Keras, IBM and OpenVino samples from
        ``SampleSubmissions.FULL_SUBMISSIONS`` through
        ``MLObject.create_object_from_string``. Each sample must produce no
        errors and an object whose schema can be retrieved.
        """
        MLSchema.populate_registry()

        samples = [
            ("Keras", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_KERAS),
            ("IBM", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_IBM),
            ("OpenVino", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_OPENVINO),
        ]

        for label, submission in samples:
            print(f"Testing {label}")
            loaded_object, errors = MLObject.create_object_from_string(submission)

            pprint(errors)
            # assertEqual reports the actual count on failure; assertTrue
            # would only say "False is not true".
            self.assertEqual(len(errors), 0, f"{label} sample produced errors")
            self.assertIsNotNone(loaded_object.get_schema())
예제 #2
0
    def test_load_object_from_disk(self):
        """The 0.0.1 datapath sample loads into an object with populated fields."""
        MLSchema.populate_registry()
        loaded, _load_errors = MLObject.create_object_from_file(
            Path('tests/data/0/0/1/datapath.yaml'))

        self.assertIsNotNone(loaded.run_date)
        self.assertIsNotNone(loaded.connection.endpoint)
예제 #3
0
 def test_validate_workflow_no_input(self):
     """A workflow step submitted without ``input`` must fail validation."""
     MLSchema.populate_registry()
     schema_under_test = SampleSchema.TEST.WORKFLOW_STEP
     submission = SampleSubmissions.UNIT_TESTS.WORKFLOW_NO_INPUT
     self.generic_schema_validator(
         schema_under_test, submission, ValidationError, "input"
     )
예제 #4
0
 def test_validate_workflow_bad_semver(self):
     """A workflow step with a malformed semantic version must fail validation."""
     MLSchema.populate_registry()
     schema_under_test = SampleSchema.TEST.WORKFLOW_STEP
     submission = SampleSubmissions.UNIT_TESTS.WORKFLOW_BAD_SEMVER
     self.generic_schema_validator(
         schema_under_test, submission, ValidationError, "semver"
     )
예제 #5
0
    def test_load_file_from_disk(self):
        """Reading the 0.0.1 datapath sample through ``IO`` completes."""
        MLSchema.populate_registry()

        sample_path = Path('tests/data/0/0/1/datapath.yaml')
        all_objects = [IO.get_content_from_path(sample_path)]

        self.assertEqual(len(all_objects), 1)
예제 #6
0
 def test_validate_workflow_valid(self):
     """A valid workflow step validates and keeps all five step names."""
     MLSchema.populate_registry()
     validated = self.generic_schema_validator(
         SampleSchema.TEST.WORKFLOW_STEP,
         SampleSubmissions.UNIT_TESTS.WORKFLOW_VALID,
         None,
         None,
     )
     step_names = validated["steps"]["step_name"]
     self.assertEqual(len(step_names), 5)
예제 #7
0
    def test_all_schemas(self):
        """Every bundled 0.0.1 schema file must parse into a named schema."""
        MLSchema.populate_registry()
        all_001_schemas = list(Path('mlspeclib').glob('schemas/0/0/1/*.yaml'))

        # The glob is relative to the working directory; requiring several
        # hits stops the loop below from silently checking nothing.
        self.assertGreater(len(all_001_schemas), 1)

        for schema in all_001_schemas:
            # Explicit encoding: read_text() otherwise uses the locale default.
            this_text = schema.read_text(encoding="utf-8")
            loaded_schema = MLSchema.create_schema(this_text)
            self.assertIsNotNone(loaded_schema.schema_name)
예제 #8
0
    def test_all_data(self):
        """Every 0.0.1 sample data file must load without errors."""
        MLSchema.populate_registry()
        all_data_files = list(Path('tests').glob('data/0/0/1/*.yaml'))

        # The glob is relative to the working directory; requiring several
        # hits stops the loop below from silently checking nothing.
        self.assertGreater(len(all_data_files), 1)

        for data_file in all_data_files:
            print(data_file)
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            # Name the offending file and its errors if this fails.
            self.assertEqual(len(errors), 0, f"{data_file}: {errors}")
            self.assertIsNotNone(loaded_object.get_schema())
예제 #9
0
    def test_load_live_schemas(self):
        """Parse every non-0.0.1 schema found under ``mlspeclib/schemas``."""
        MLSchema.populate_registry()
        schema_files = list(Path("mlspeclib").glob("schemas/**/*.yaml"))

        for schema_file in schema_files:
            raw_yaml = schema_file.read_text(encoding="utf-8")
            parsed = convert_yaml_to_dict(raw_yaml)
            # Schemas whose meta version is 0.0.1 are skipped by this test.
            if parsed["mlspec_schema_version"]["meta"] != "0.0.1":
                created = MLSchema.create_schema(raw_yaml)
                self.assertIsNotNone(created.schema_name)
예제 #10
0
    def test_load_live_schemas(self):
        """Parse every non-0.0.1 schema found under ``mlspeclib/schemas``."""
        MLSchema.populate_registry()
        all_schema_paths = list(Path('mlspeclib').glob('schemas/**/*.yaml'))

        for schema_path in all_schema_paths:
            # Explicit UTF-8 so the result does not depend on the platform locale.
            this_text = schema_path.read_text(encoding='utf-8')
            this_dict = convert_yaml_to_dict(this_text)
            # Schemas whose meta version is 0.0.1 are skipped by this test.
            if this_dict['mlspec_schema_version']['meta'] == '0.0.1':
                continue
            print(schema_path)
            loaded_schema = MLSchema.create_schema(this_text)
            self.assertIsNotNone(loaded_schema.schema_name)
예제 #11
0
    def test_load_live_data(self):
        """Load every sample data file whose ``schema_version`` is not 0.0.1."""
        MLSchema.populate_registry()
        candidate_files = list(Path("tests").glob("data/**/*.yaml"))

        for sample_file in candidate_files:
            parsed = convert_yaml_to_dict(sample_file.read_text(encoding="utf-8"))
            if parsed["schema_version"] != "0.0.1":
                created, load_errors = MLObject.create_object_from_file(sample_file)
                self.assertTrue(len(load_errors) == 0)
                self.assertIsNotNone(created.get_schema())
예제 #12
0
    def test_load_live_data(self):
        """Load every sample data file whose ``schema_version`` is not 0.0.1."""
        MLSchema.populate_registry()
        all_data_files = list(Path('tests').glob('data/**/*.yaml'))

        for data_file in all_data_files:
            # Explicit UTF-8 so the result does not depend on the platform locale.
            this_text = data_file.read_text(encoding='utf-8')
            this_dict = convert_yaml_to_dict(this_text)
            if this_dict['schema_version'] == '0.0.1':
                continue
            print(data_file)
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            # Name the offending file and its errors if this fails.
            self.assertEqual(len(errors), 0, f"{data_file}: {errors}")
            self.assertIsNotNone(loaded_object.get_schema())
예제 #13
0
    def test_add_schemas_from_url(self):
        """Loading the sample schema repository registers exactly 58 classes.

        The source is a GitHub URL, so this presumably needs network access.
        """
        MLSchema.populate_registry()
        load_url = "https://github.com/mlspec/mlspeclib-action-samples-schemas"

        current_schema_total = len(marshmallow.class_registry._registry)

        # Collect root-logger output in a buffer while loading. Remove the
        # handler afterwards so it does not stay attached for later tests.
        mock_stdout = StringIO()
        root_logger = logging.getLogger()
        capture_handler = logging.StreamHandler(mock_stdout)
        root_logger.addHandler(capture_handler)
        try:
            GitHubSchemas.add_schemas_from_github_url(load_url)
        finally:
            root_logger.removeHandler(capture_handler)

        self.assertEqual(current_schema_total + 58,
                         len(marshmallow.class_registry._registry))
예제 #14
0
    def create_stub_object(self):
        """Populate this object with a key for every schema field, all set to None.

        The schema is resolved from this object's schema version and type.
        The resolved schema name and a schema instance are stored on the
        object, and the key skeleton is merged into ``self``. Nothing is
        returned.

        We do this because our goal is to prevent (eventually) creation of new
        attributes directly by users - only allowing them to use what we
        already provide to them to keep the schema in sync with the object.
        """
        MLSchema.populate_registry()
        version_number = self.get_schema_version()
        self.__schema_name = MLSchema.return_schema_name(
            version_number, self.get_schema_type().name)
        object_schema = marshmallow.class_registry.get_class(self.get_schema_name())
        self.__schema = object_schema()
        self.__schema_object = None
        # Read the fields from the instance just created instead of building a
        # second throwaway instance of the same schema class.
        # NOTE(review): assumes recursive_fromkeys only reads the fields.
        this_key_dict = recursive_fromkeys(self.__schema.fields)
        self.merge_update(this_key_dict)
예제 #15
0
    def test_cascading_inheritence(self):
        """A fully populated data_version_control object must validate."""
        MLSchema.populate_registry()

        mlobject = MLObject()
        mlobject.set_type("0.0.1", "data_version_control")

        # Values are applied in the same order as they are listed here.
        top_level_values = {
            "run_id": uuid.uuid4(),
            "step_id": uuid.uuid4(),
            "run_date": datetime.datetime.now(),
            "data_store": "I_am_a_datastore",
            "storage_connection_type": "AWS_BLOB",
        }
        for attr_name, attr_value in top_level_values.items():
            setattr(mlobject, attr_name, attr_value)

        # Placeholder example credentials, not real secrets.
        mlobject.connection.endpoint = "con_endpoint"
        mlobject.connection.access_key_id = "AKIAIOSFODNN7EXAMPLE"
        mlobject.connection.secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"

        mlobject.dvc_hash = "923caceea54b38177505632f5612cc569a49b22246e346a7"
        mlobject.validate()
예제 #16
0
    def test_add_schema_to_registry(self):
        """Exercise ``MLSchema.append_schema_to_registry`` on fixture folders.

        Checks four things:

        * directories with no schema files raise ``FileNotFoundError``;
        * unparsable or incomplete schema files are reported via the root
          logger;
        * an empty yaml file is reported by its path;
        * a folder of valid schemas grows the marshmallow class registry by
          two entries per file.
        """
        MLSchema.populate_registry()
        load_path = Path(os.path.dirname(__file__)) / "external_schema"

        # No schema files are found at this relative location (presumably it
        # is resolved against the working directory, not this file's folder).
        # It must raise whether passed as a str or as a Path.
        with self.assertRaises(FileNotFoundError) as context:
            MLSchema.append_schema_to_registry("./external_schema")
        self.assertIn("No files ending in", str(context.exception))

        with self.assertRaises(FileNotFoundError) as context:
            MLSchema.append_schema_to_registry(Path("./external_schema"))
        self.assertIn("No files ending in", str(context.exception))

        current_schema_total = len(marshmallow.class_registry._registry)

        with self.assertRaises(FileNotFoundError):
            MLSchema.append_schema_to_registry(load_path / "no_files")

        mock_stdout = StringIO()
        root_logger = logging.getLogger()
        capture_handler = logging.StreamHandler(mock_stdout)
        root_logger.addHandler(capture_handler)

        def _drain_log():
            """Return everything logged since the last call, then empty the buffer."""
            captured = mock_stdout.getvalue()
            # seek() alone does not discard text; truncate so each step's
            # assertions see only that step's log output.
            mock_stdout.seek(0)
            mock_stdout.truncate()
            return captured

        try:
            MLSchema.append_schema_to_registry(load_path / "bad_yaml")
            return_string = _drain_log()
            self.assertIn("bad.yaml", return_string)
            self.assertIn("parsed", return_string)

            MLSchema.append_schema_to_registry(load_path / "missing_fields")
            return_string = _drain_log()
            self.assertIn("missing_fields.yaml", return_string)
            self.assertIn("mlspec", return_string)

            # Use a private, auto-deleted directory so unrelated *.yaml files
            # in the shared system temp dir cannot be scanned and change the
            # registry count asserted below.
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_yaml_file = Path(temp_dir) / f"{uuid.uuid4()}.yaml"
                # write_text returns the number of characters written, so keep
                # the Path in its own variable rather than binding the result.
                temp_yaml_file.write_text("")
                MLSchema.append_schema_to_registry(temp_dir)
                return_string = _drain_log()
                self.assertIn(str(temp_yaml_file), return_string)
                self.assertIn("mlspec", return_string)
        finally:
            root_logger.removeHandler(capture_handler)

        valid_path = load_path / "valid"
        MLSchema.append_schema_to_registry(valid_path)

        new_schemas = os.listdir(valid_path)
        self.assertEqual(
            current_schema_total + 2 * len(new_schemas),
            len(marshmallow.class_registry._registry),
        )