def test_create_ml_object(self):
    """A freshly typed MLObject reports the expected schema version and type."""
    ml_object = MLObject()
    ml_object.set_type("0.0.1", MLSchemaTypes.DATAPATH)

    self.assertIsNotNone(ml_object)
    # schema_type is stored as the lower-cased enum member name.
    self.assertTrue(ml_object.schema_version == "0.0.1")
    self.assertTrue(ml_object.schema_type == MLSchemaTypes.DATAPATH.name.lower())
def test_create_stub_nested_object(self):
    """A stub DATAPATH object exposes nested fields that default to None."""
    stub = MLObject()
    stub.set_type('0.0.1', MLSchemaTypes.DATAPATH)

    # The top-level object carries 13 fields; the nested connection carries 3.
    self.assertTrue(len(stub) == 13)
    self.assertIsNone(stub.run_date)
    self.assertTrue(len(stub.connection) == 3)
    self.assertIsNone(stub.connection.endpoint)
def test_live_interface_samples(self):
    """Each full sample submission parses into an object with zero errors.

    Refactored from three copy-pasted stanzas into one data-driven loop;
    the printed labels and assertions are unchanged.
    """
    MLSchema.populate_registry()

    submissions = [
        ("Keras", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_KERAS),
        ("IBM", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_IBM),
        ("OpenVino", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_OPENVINO),
    ]
    for label, submission in submissions:
        print(f"Testing {label}")
        loaded_object, errors = MLObject.create_object_from_string(submission)
        pprint(errors)  # aid debugging when a submission regresses
        self.assertTrue(len(errors) == 0)
        self.assertIsNotNone(loaded_object.get_schema())
def test_create_bad_semver(self):
    """set_type rejects malformed and missing semver strings with ValueError."""
    for bad_version in ("0.0.x", None):
        with self.assertRaises(ValueError):
            obj = MLObject()
            obj.set_type(schema_version=bad_version, schema_type=MLSchemaTypes.BASE)
def test_load_object_from_disk(self):
    """Loading a datapath yaml from disk yields a populated MLObject."""
    MLSchema.populate_registry()

    yaml_path = Path('tests/data/0/0/1/datapath.yaml')
    loaded, _ = MLObject.create_object_from_file(yaml_path)

    self.assertIsNotNone(loaded.run_date)
    self.assertIsNotNone(loaded.connection.endpoint)
def test_all_data(self):
    """Every 0.0.1 sample yaml under tests/data loads with zero errors."""
    MLSchema.populate_registry()

    sample_files = list(Path('tests').glob('data/0/0/1/*.yaml'))
    # Guard against the glob silently matching nothing.
    self.assertTrue(len(sample_files) > 1)

    for sample in sample_files:
        print(sample)
        obj, errors = MLObject.create_object_from_file(sample)
        self.assertTrue(len(errors) == 0)
        self.assertIsNotNone(obj.get_schema())
def test_code_gen_string(self):
    """_code_gen_string emits 54 prefixed field lines, with hint lines
    toggled by type_hints.

    BUG FIX: the original second assertion re-checked ``fields_re`` instead
    of the freshly compiled ``hints_re``, leaving ``hints_re`` unused — a
    copy-paste error. With type hints enabled (the default) each field line
    is expected to carry a matching ``# prefix`` hint line, mirroring the
    ``== 0`` check in the type_hints=False branch below.
    """
    code_gen_string = MLObject._code_gen_string(
        "0.0.1", MLSchemaTypes.RUNCONFIG, "prefix"
    )
    fields_re = re.compile(r"^prefix.", flags=re.MULTILINE | re.DOTALL)
    self.assertTrue(len(fields_re.findall(code_gen_string)) == 54)
    hints_re = re.compile(r"^# prefix.", flags=re.MULTILINE | re.DOTALL)
    # Was: fields_re.findall — now actually checks the hint lines.
    self.assertTrue(len(hints_re.findall(code_gen_string)) == 54)

    code_gen_string = MLObject._code_gen_string(
        "0.0.1", MLSchemaTypes.RUNCONFIG, "prefix", type_hints=False
    )
    fields_re = re.compile(r"^prefix.", flags=re.MULTILINE | re.DOTALL)
    self.assertTrue(len(fields_re.findall(code_gen_string)) == 54)
    hints_re = re.compile(r"^# prefix.", flags=re.MULTILINE | re.DOTALL)
    self.assertTrue(len(hints_re.findall(code_gen_string)) == 0)
def test_load_live_data(self):
    """All non-0.0.1 sample yamls under tests/data load with zero errors.

    NOTE(review): a second method with this exact name appears later in the
    file; if both live in the same class the later definition shadows this
    one — confirm intent and rename one of them.
    """
    MLSchema.populate_registry()

    for data_file in Path("tests").glob("data/**/*.yaml"):
        parsed = convert_yaml_to_dict(data_file.read_text(encoding="utf-8"))
        if parsed["schema_version"] == "0.0.1":
            continue  # 0.0.1 files are covered by other tests
        loaded_object, errors = MLObject.create_object_from_file(data_file)
        self.assertTrue(len(errors) == 0)
        self.assertIsNotNone(loaded_object.get_schema())
def test_load_live_data(self):
    """All non-0.0.1 sample yamls under tests/data load with zero errors.

    NOTE(review): this duplicates the name of an earlier method; if both
    live in the same class this one silently shadows the other — confirm
    which version is current and drop or rename the duplicate.
    """
    MLSchema.populate_registry()

    candidate_files = list(Path('tests').glob('data/**/*.yaml'))
    for candidate in candidate_files:
        contents = convert_yaml_to_dict(candidate.read_text())
        if contents['schema_version'] == '0.0.1':
            continue  # 0.0.1 files are covered by other tests
        print(candidate)
        loaded_object, errors = MLObject.create_object_from_file(candidate)
        self.assertTrue(len(errors) == 0)
        self.assertIsNotNone(loaded_object.get_schema())
def test_load_object_with_missing_field_from_variable(self):
    """Null-ing required fields yields no object and two targeted errors.

    data_store and connection.endpoint are commented out of the yaml, so
    validation must report exactly those two fields as non-nullable.
    """
    WRONG_DATAPATH = """
schema_version: 0.0.1
schema_type: datapath
run_id: 6a9a5931-1c1d-47cc-aaf3-ad8b03f70575
step_id: 0c98f080-4760-46be-b35f-7dbb5e2a88c2
run_date: 1970-01-01 00:00:00.00000
# data_store: I_am_a_datastore_name
storage_connection_type: AWS_BLOB
connection:
    # endpoint: S3://mybucket/puppy.jpg
    access_key_id: AKIAIOSFODNN7EXAMPLE
    secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"""

    ml_object, errors = MLObject.create_object_from_string(WRONG_DATAPATH)

    self.assertTrue(ml_object is None)
    self.assertTrue(len(errors) == 2)
    self.assertTrue(errors["data_store"][0] == "Field may not be null.")
    self.assertTrue(errors["connection"]["endpoint"][0] == "Field may not be null.")
def attach_step_info(
    self,
    mlobject: MLObject,
    workflow_version,
    workflow_node_id,
    step_name: str,
    step_type: str,
):
    """Persist an MLObject as a run-info vertex and wire it to its workflow node.

    Adds a vertex carrying the object's properties plus its raw encoded
    content, then creates 'results' and 'root' edges between the workflow
    node's step and the new vertex.

    Args:
        mlobject: object to persist; its run_id/run_date feed the vertex id.
        workflow_version: version component of the vertex id.
        workflow_node_id: graph id of the owning workflow node.
        step_name: step this object belongs to.
        step_type: one of 'input', 'execution', 'output', 'log'.

    Returns:
        The generated run_info vertex id.

    Raises:
        ValueError: if step_type is not one of the four allowed values.
    """
    if step_type not in ["input", "execution", "output", "log"]:
        raise ValueError(
            f"Error when saving '{mlobject.get_schema_name()}', the step_type must be from ['input', 'execution', 'output', 'log']."
        )

    # Vertex id is a deterministic composite of step/run/workflow identity.
    run_info_id = build_vertex_id(
        step_name,
        step_type,
        mlobject.run_id,
        mlobject.run_date,
        workflow_version,
        self._workflow_partition_id,
    )
    # Internal bookkeeping fields are excluded from the stored properties.
    mlobject_dict = mlobject.dict_without_internal_variables()
    property_string = convert_to_property_strings(mlobject_dict)
    raw_content = encode_raw_object_for_db(mlobject)
    # NOTE(review): this Gremlin query is assembled via f-string
    # interpolation; if any interpolated value can contain quotes this is a
    # query-injection / syntax-break risk — confirm inputs are sanitized
    # upstream (e.g. inside convert_to_property_strings / the encoders).
    add_run_info_query = f"""g.addV('id', '{run_info_id}'){property_string}.property('raw_content', '{raw_content}').property('workflow_node_id', '{workflow_node_id}').property('workflow_partition_id', '{self._workflow_partition_id}')"""
    self.execute_query(add_run_info_query)

    # Edge: workflow node's step --results--> new run-info vertex.
    self.execute_query(
        sQuery(
            "g.V('id', '%s').out().hasId('%s').addE('results').to(g.V('%s')).executionProfile()",
            [workflow_node_id, step_name, run_info_id],
        ))
    # Edge: new run-info vertex --root--> workflow node's step.
    self.execute_query(
        sQuery(
            "g.V('id', '%s').out().hasId('%s').addE('root').from(g.V('%s')).executionProfile()",
            [workflow_node_id, step_name, run_info_id],
        ))

    return run_info_id
def test_create_stub_base_object(self):
    """A stub BASE object has 10 fields and a None run_date."""
    base_stub = MLObject()
    base_stub.set_type('0.0.1', MLSchemaTypes.BASE)

    self.assertIsNone(base_stub['run_date'])
    self.assertTrue(len(base_stub) == 10)
def test_create_bad_schema_type(self):
    """An unknown schema type string raises KeyError.

    NOTE(review): another method with this exact name (expecting
    RegistryError) appears later in the file; if both are in the same class
    the later definition shadows this one — confirm which exception type is
    current and remove the stale duplicate.
    """
    with self.assertRaises(KeyError):
        obj = MLObject()
        obj.set_type(schema_version='0.0.1', schema_type='foo')
def test_load_and_save_file(self):
    """save() refuses while seeded bugs remain, then round-trips cleanly.

    Three deliberate defects are planted; each fix reduces the error count
    by one and no file may be written until all are resolved.
    """
    run_id = uuid.uuid4()
    target_dir = Path(self.test_dir.name) / str(run_id)
    target_dir.mkdir()

    source_object = MLObject()
    source_object.set_type('0.0.1', MLSchemaTypes.DATAPATH)
    source_object.run_id = run_id
    source_object.step_id = uuid.uuid4()
    source_object.run_date = datetime.datetime.now()
    source_object.data_store = None  # intentional bug #1
    # intentional bug #2 (should be AWS_BLOB)
    source_object.storage_connection_type = 'AWS_BLOB_OBJECT'
    source_object.connection.endpoint = None  # intentional bug #3
    source_object.connection.access_key_id = 'AKIAIOSFODNN7EXAMPLE'
    source_object.connection.secret_access_key = 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY'

    # All three bugs present: save fails, three errors, nothing on disk.
    response, errors = source_object.save(target_dir)
    self.assertFalse(response)
    self.assertTrue(len(errors) == 3)
    self.assertTrue(len(list(Path(target_dir).glob('*'))) == 0)

    # Fix the connection type: two errors remain.
    source_object.storage_connection_type = 'AWS_BLOB'
    response, errors = source_object.save(target_dir)
    self.assertFalse(response)
    self.assertTrue(len(errors) == 2)
    self.assertTrue(len(list(Path(target_dir).glob('*'))) == 0)

    # Fix the endpoint: one error remains.
    source_object.connection.endpoint = 'http://s3.amazon.com/BUCKET'
    response, errors = source_object.save(target_dir)
    self.assertFalse(response)
    self.assertTrue(len(errors) == 1)
    self.assertTrue(len(list(Path(target_dir).glob('*'))) == 0)

    # Final fix: save succeeds and exactly one file is written.
    source_object.data_store = 'BUCKET NAME'
    response, errors = source_object.save(target_dir)
    self.assertTrue(response)
    self.assertTrue(len(errors) == 0)

    saved_files = list(Path(target_dir).glob('*'))
    self.assertTrue(len(saved_files) == 1)

    # Reload and verify the round trip preserved the key fields.
    reloaded, errors = MLObject.create_object_from_file(saved_files[0])
    self.assertTrue(len(reloaded) == 13)
    self.assertTrue(len(errors) == 0)
    self.assertTrue(source_object.data_store == reloaded.data_store)
    self.assertTrue(source_object.storage_connection_type == reloaded.storage_connection_type)
    self.assertTrue(source_object.connection.endpoint == reloaded.connection.endpoint)
def test_create_bad_schema_type(self):
    """An unknown schema type string raises RegistryError.

    NOTE(review): an earlier method with this exact name expects KeyError
    instead; if both are in the same class this definition shadows that one
    — confirm which exception type is current and remove the stale
    duplicate.
    """
    with self.assertRaises(RegistryError):
        obj = MLObject()
        obj.set_type(schema_version="0.0.1", schema_type="foo")
def test_cascading_inheritence(self):
    """A data_version_control object validates once every inherited and
    own field is populated.

    NOTE(review): method name misspells "inheritance"; left as-is since
    test names may be referenced by CI filters.
    """
    MLSchema.populate_registry()

    dvc_object = MLObject()
    dvc_object.set_type("0.0.1", "data_version_control")

    # Fields inherited from the base/datapath schemas.
    dvc_object.run_id = uuid.uuid4()
    dvc_object.step_id = uuid.uuid4()
    dvc_object.run_date = datetime.datetime.now()
    dvc_object.data_store = "I_am_a_datastore"
    dvc_object.storage_connection_type = "AWS_BLOB"
    dvc_object.connection.endpoint = "con_endpoint"
    dvc_object.connection.access_key_id = "AKIAIOSFODNN7EXAMPLE"
    dvc_object.connection.secret_access_key = (
        "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY")

    # Field specific to the data_version_control schema.
    dvc_object.dvc_hash = "923caceea54b38177505632f5612cc569a49b22246e346a7"

    dvc_object.validate()