def test_live_interface_samples(self):
    """Load each full sample submission and check it parses cleanly.

    Each submission string is passed to ``MLObject.create_object_from_string``.
    The test requires an empty error collection and a loaded object whose
    ``get_schema()`` is not None.
    """
    MLSchema.populate_registry()
    samples = (
        ("Keras", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_KERAS),
        ("IBM", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_IBM),
        ("OpenVino", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_OPENVINO),
    )
    for label, submission in samples:
        # subTest labels any failure with the sample name (replacing the old
        # "Testing ..." prints) and still checks the remaining samples.
        with self.subTest(sample=label):
            loaded_object, errors = MLObject.create_object_from_string(submission)
            # Passing the errors as the message surfaces them on failure,
            # which the old pprint(errors) call only did on stdout.
            self.assertEqual(len(errors), 0, errors)
            self.assertIsNotNone(loaded_object.get_schema())
def test_load_object_from_disk(self):
    """Build an MLObject from the sample datapath YAML and spot-check two fields.

    The error collection returned by ``create_object_from_file`` is
    deliberately ignored here; only ``run_date`` and ``connection.endpoint``
    are checked for presence.
    """
    MLSchema.populate_registry()
    sample_path = Path("tests/data/0/0/1/datapath.yaml")
    loaded, _ignored_errors = MLObject.create_object_from_file(sample_path)
    self.assertIsNotNone(loaded.run_date)
    self.assertIsNotNone(loaded.connection.endpoint)
def test_validate_workflow_no_input(self):
    """Validate the workflow-step submission that lacks an input.

    Delegates to ``generic_schema_validator`` with ``ValidationError`` and
    ``"input"``. These are presumably the expected exception type and a
    fragment of its message; see that helper to confirm.
    """
    MLSchema.populate_registry()
    step_schema = SampleSchema.TEST.WORKFLOW_STEP
    submission = SampleSubmissions.UNIT_TESTS.WORKFLOW_NO_INPUT
    self.generic_schema_validator(step_schema, submission, ValidationError, "input")
def test_validate_workflow_bad_semver(self):
    """Validate the workflow-step submission carrying a malformed semver.

    Delegates to ``generic_schema_validator`` with ``ValidationError`` and
    ``"semver"``. These are presumably the expected exception type and a
    fragment of its message; see that helper to confirm.
    """
    MLSchema.populate_registry()
    step_schema = SampleSchema.TEST.WORKFLOW_STEP
    submission = SampleSubmissions.UNIT_TESTS.WORKFLOW_BAD_SEMVER
    self.generic_schema_validator(step_schema, submission, ValidationError, "semver")
def test_load_file_from_disk(self):
    """Read the sample datapath YAML through ``IO.get_content_from_path``."""
    MLSchema.populate_registry()
    content = IO.get_content_from_path(Path("tests/data/0/0/1/datapath.yaml"))
    # Appending the result to a fresh list and checking len == 1 can never
    # fail; assert on the returned content itself instead.
    self.assertIsNotNone(content)
def test_validate_workflow_valid(self):
    """Validate a well-formed workflow submission and check the step entries.

    ``generic_schema_validator`` is called with ``None`` for both trailing
    arguments (presumably "no error expected"). Its return value is indexed
    as ``["steps"]["step_name"]``, which must have length 5.
    """
    MLSchema.populate_registry()
    this_schema = self.generic_schema_validator(
        SampleSchema.TEST.WORKFLOW_STEP,
        SampleSubmissions.UNIT_TESTS.WORKFLOW_VALID,
        None,
        None,
    )
    # assertEqual reports the actual length on failure; assertTrue(x == 5)
    # would only say "False is not true".
    self.assertEqual(len(this_schema["steps"]["step_name"]), 5)
def test_all_schemas(self):
    """Every 0.0.1 schema under mlspeclib/schemas must load with a schema name."""
    MLSchema.populate_registry()
    all_001_schemas = list(Path("mlspeclib").glob("schemas/0/0/1/*.yaml"))
    # Guard against a glob that silently matches (almost) nothing.
    self.assertGreater(len(all_001_schemas), 1)
    for schema in all_001_schemas:
        # subTest names the failing file and keeps checking the others.
        with self.subTest(schema=str(schema)):
            # Decode explicitly as UTF-8 (as the live-schema tests do) rather
            # than relying on the platform's locale encoding.
            this_text = schema.read_text(encoding="utf-8")
            loaded_schema = MLSchema.create_schema(this_text)
            self.assertIsNotNone(loaded_schema.schema_name)
def test_all_data(self):
    """Every 0.0.1 sample data file under tests/data must load without errors."""
    MLSchema.populate_registry()
    all_data_files = list(Path("tests").glob("data/0/0/1/*.yaml"))
    # Guard against a glob that silently matches (almost) nothing.
    self.assertGreater(len(all_data_files), 1)
    for data_file in all_data_files:
        # subTest identifies the failing file (replacing the debug print)
        # and lets the remaining files still be checked.
        with self.subTest(data_file=str(data_file)):
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            self.assertEqual(len(errors), 0, errors)
            self.assertIsNotNone(loaded_object.get_schema())
def test_load_live_schemas(self):
    """Schema files whose meta version is not 0.0.1 must load with a schema name."""
    # NOTE(review): another method named ``test_load_live_schemas`` also
    # appears in this file. If both live in the same class, the later
    # definition replaces this one and only that one runs. Confirm which
    # version should survive.
    MLSchema.populate_registry()
    schema_files = list(Path("mlspeclib").glob("schemas/**/*.yaml"))
    for path in schema_files:
        text = path.read_text(encoding="utf-8")
        meta_version = convert_yaml_to_dict(text)["mlspec_schema_version"]["meta"]
        if meta_version != "0.0.1":
            created = MLSchema.create_schema(text)
            self.assertIsNotNone(created.schema_name)
def test_load_live_schemas(self):
    """Schema files whose meta version is not 0.0.1 must load with a schema name."""
    # NOTE(review): another method named ``test_load_live_schemas`` also
    # appears in this file. If both live in the same class, only the later
    # definition runs. Confirm which version should survive.
    MLSchema.populate_registry()
    all_schema_paths = list(Path("mlspeclib").glob("schemas/**/*.yaml"))
    for schema_path in all_schema_paths:
        # Decode explicitly as UTF-8 instead of the locale default.
        this_text = schema_path.read_text(encoding="utf-8")
        this_dict = convert_yaml_to_dict(this_text)
        if this_dict["mlspec_schema_version"]["meta"] == "0.0.1":
            continue
        # subTest names the failing file in place of the old debug print.
        with self.subTest(schema_path=str(schema_path)):
            loaded_schema = MLSchema.create_schema(this_text)
            self.assertIsNotNone(loaded_schema.schema_name)
def test_load_live_data(self):
    """Data files whose schema_version is not 0.0.1 must load without errors."""
    # NOTE(review): another method named ``test_load_live_data`` also appears
    # in this file. If both live in the same class, the later definition
    # replaces this one. Confirm which version should survive.
    MLSchema.populate_registry()
    data_paths = list(Path("tests").glob("data/**/*.yaml"))
    for path in data_paths:
        raw = path.read_text(encoding="utf-8")
        if convert_yaml_to_dict(raw)["schema_version"] != "0.0.1":
            obj, errs = MLObject.create_object_from_file(path)
            self.assertTrue(len(errs) == 0)
            self.assertIsNotNone(obj.get_schema())
def test_load_live_data(self):
    """Data files whose schema_version is not 0.0.1 must load without errors."""
    # NOTE(review): another method named ``test_load_live_data`` also appears
    # in this file. If both live in the same class, only the later definition
    # runs. Confirm which version should survive.
    MLSchema.populate_registry()
    all_data_files = list(Path("tests").glob("data/**/*.yaml"))
    for data_file in all_data_files:
        # Decode explicitly as UTF-8 instead of the locale default.
        this_text = data_file.read_text(encoding="utf-8")
        this_dict = convert_yaml_to_dict(this_text)
        if this_dict["schema_version"] == "0.0.1":
            continue
        # subTest names the failing file in place of the old debug print.
        with self.subTest(data_file=str(data_file)):
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            self.assertEqual(len(errors), 0, errors)
            self.assertIsNotNone(loaded_object.get_schema())
def test_add_schemas_from_url(self):
    """Load schemas from the sample GitHub repository and check the registry grows.

    NOTE(review): this test needs network access, and the expected count is
    tied to the current contents of the remote repository. Confirm that is
    intended.
    """
    MLSchema.populate_registry()
    load_url = "https://github.com/mlspec/mlspeclib-action-samples-schemas"
    # Number of marshmallow registry entries the remote repo is expected to add.
    expected_new_entries = 58
    current_schema_total = len(marshmallow.class_registry._registry)

    # Route root-logger output into a buffer while loading. The handler is
    # removed afterwards so it does not pile up on the root logger across
    # tests.
    mock_stdout = StringIO()
    log_handler = logging.StreamHandler(mock_stdout)
    root_logger = logging.getLogger()
    root_logger.addHandler(log_handler)
    try:
        GitHubSchemas.add_schemas_from_github_url(load_url)
    finally:
        root_logger.removeHandler(log_handler)

    self.assertEqual(
        current_schema_total + expected_new_entries,
        len(marshmallow.class_registry._registry),
    )
def create_stub_object(self):
    """Populate this object with a stub dictionary derived from its schema.

    Resolves the registered marshmallow schema class for this object's schema
    version and type, stores a fresh instance of it, and merges in a key
    dictionary built from that schema's fields by ``recursive_fromkeys``. Per
    the original intent, the values are all None.

    We do this because our goal is to prevent (eventually) creation of new
    attributes directly by users. Users may only use what we already provide,
    which keeps the schema in sync with the object.
    """
    MLSchema.populate_registry()
    version_number = self.get_schema_version()
    self.__schema_name = MLSchema.return_schema_name(
        version_number, self.get_schema_type().name
    )
    object_schema = marshmallow.class_registry.get_class(self.get_schema_name())
    # Instantiate the schema once and reuse it. The previous code built a
    # second instance only to read its ``fields``.
    schema_instance = object_schema()
    self.__schema = schema_instance
    self.__schema_object = None
    this_key_dict = recursive_fromkeys(schema_instance.fields)
    self.merge_update(this_key_dict)
def test_cascading_inheritence(self):
    """Populate a 0.0.1 ``data_version_control`` object field by field, then validate it.

    The fields set here include run identifiers, storage settings, nested
    ``connection`` values and a DVC hash. The test name suggests some of
    these come from parent types; that is unconfirmed from this block.
    """
    MLSchema.populate_registry()
    mlobject = MLObject()
    mlobject.set_type("0.0.1", "data_version_control")

    # Run identity and timestamp.
    mlobject.run_id, mlobject.step_id = uuid.uuid4(), uuid.uuid4()
    mlobject.run_date = datetime.datetime.now()

    # Storage target.
    mlobject.data_store = "I_am_a_datastore"
    mlobject.storage_connection_type = "AWS_BLOB"

    # Nested connection settings. The credentials are placeholders (note the
    # EXAMPLE markers), not real secrets.
    mlobject.connection.endpoint = "con_endpoint"
    mlobject.connection.access_key_id = "AKIAIOSFODNN7EXAMPLE"
    mlobject.connection.secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"

    mlobject.dvc_hash = "923caceea54b38177505632f5612cc569a49b22246e346a7"
    mlobject.validate()
def test_add_schema_to_registry(self):
    """Exercise MLSchema.append_schema_to_registry against several directories.

    Covers:
    - paths with no schema files, which must raise FileNotFoundError;
    - malformed YAML and schemas missing required fields, which are reported
      through the root logger;
    - an empty YAML file;
    - a directory of valid schemas, which must add two registry entries per
      file.
    """
    MLSchema.populate_registry()
    load_path = Path(os.path.dirname(__file__)) / "external_schema"

    # NOTE(review): both the str and the Path form of this relative path are
    # expected to raise. The likely cause is that "./external_schema" is
    # resolved against the current working directory, not this test's
    # directory, and so contains no ".yaml" files. Confirm.
    with self.assertRaises(FileNotFoundError) as context:
        MLSchema.append_schema_to_registry("./external_schema")
    self.assertIn("No files ending in", str(context.exception))

    with self.assertRaises(FileNotFoundError) as context:
        MLSchema.append_schema_to_registry(Path("./external_schema"))
    self.assertIn("No files ending in", str(context.exception))

    current_schema_total = len(marshmallow.class_registry._registry)

    with self.assertRaises(FileNotFoundError):
        MLSchema.append_schema_to_registry(load_path / "no_files")

    log_buffer = StringIO()
    log_handler = logging.StreamHandler(log_buffer)
    root_logger = logging.getLogger()
    root_logger.addHandler(log_handler)

    def drain_log():
        """Return the text logged since the previous call and empty the buffer."""
        # seek(0) + truncate(0) actually clears a StringIO. Seeking alone
        # only moves the write position, and getvalue() still returns all
        # earlier output.
        text = log_buffer.getvalue()
        log_buffer.seek(0)
        log_buffer.truncate(0)
        return text

    try:
        MLSchema.append_schema_to_registry(load_path / "bad_yaml")
        return_string = drain_log()
        self.assertIn("bad.yaml", return_string)
        self.assertIn("parsed", return_string)

        MLSchema.append_schema_to_registry(load_path / "missing_fields")
        return_string = drain_log()
        self.assertIn("missing_fields.yaml", return_string)
        self.assertIn("mlspec", return_string)

        # Use a private temporary directory, for two reasons. The scan then
        # sees only our empty file, not everything in the system temp dir.
        # And the file is deleted afterwards instead of accumulating.
        with tempfile.TemporaryDirectory() as temp_dir:
            file_name = f"{uuid.uuid4()}.yaml"
            # write_text() returns a character count, not the path, so keep
            # the file name separately for the assertion below.
            (Path(temp_dir) / file_name).write_text("")
            MLSchema.append_schema_to_registry(temp_dir)
        return_string = drain_log()
        self.assertIn(file_name, return_string)
        self.assertIn("mlspec", return_string)
    finally:
        # Do not leave the capture handler attached to the root logger.
        root_logger.removeHandler(log_handler)

    valid_path = load_path / "valid"
    MLSchema.append_schema_to_registry(valid_path)
    new_schemas = os.listdir(valid_path)
    self.assertEqual(
        current_schema_total + 2 * len(new_schemas),
        len(marshmallow.class_registry._registry),
    )