Exemple #1
0
    def test_create_ml_object(self):
        """Check that set_type records the schema version and type on the object."""
        ml_object = MLObject()
        ml_object.set_type("0.0.1", MLSchemaTypes.DATAPATH)

        self.assertIsNotNone(ml_object)
        # assertEqual reports both operands on failure, unlike assertTrue(a == b).
        self.assertEqual(ml_object.schema_version, "0.0.1")
        self.assertEqual(ml_object.schema_type, MLSchemaTypes.DATAPATH.name.lower())
 def test_create_stub_nested_object(self):
     """Check the stub DATAPATH object: field counts and unset (None) values."""
     ml_object = MLObject()
     ml_object.set_type('0.0.1', MLSchemaTypes.DATAPATH)
     # assertEqual shows the actual length when the expectation is not met.
     self.assertEqual(len(ml_object), 13)
     self.assertIsNone(ml_object.run_date)
     self.assertEqual(len(ml_object.connection), 3)
     self.assertIsNone(ml_object.connection.endpoint)
Exemple #3
0
    def test_live_interface_samples(self):
        """Load each full sample submission and expect no errors and a schema."""
        MLSchema.populate_registry()

        samples = (
            ("Keras", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_KERAS),
            ("IBM", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_IBM),
            ("OpenVino", SampleSubmissions.FULL_SUBMISSIONS.COMPONENT_OPENVINO),
        )

        for label, submission in samples:
            # subTest names the failing sample and lets the remaining ones run.
            with self.subTest(sample=label):
                print(f"Testing {label}")
                loaded_object, errors = MLObject.create_object_from_string(submission)

                pprint(errors)
                self.assertEqual(len(errors), 0)
                self.assertIsNotNone(loaded_object.get_schema())
Exemple #4
0
    def test_create_bad_semver(self):
        """Expect set_type to raise ValueError for a malformed or missing version."""
        for bad_version in ("0.0.x", None):
            # Build the object outside the assertRaises block so that only
            # set_type can satisfy the expected ValueError.
            ml_object = MLObject()
            with self.assertRaises(ValueError):
                ml_object.set_type(
                    schema_version=bad_version, schema_type=MLSchemaTypes.BASE
                )
    def test_load_object_from_disk(self):
        """Load the sample datapath YAML and check two fields are populated."""
        MLSchema.populate_registry()
        datapath, _errors = MLObject.create_object_from_file(
            Path('tests/data/0/0/1/datapath.yaml')
        )

        # The returned errors are not inspected; only the loaded values are.
        self.assertIsNotNone(datapath.run_date)
        connection = datapath.connection
        self.assertIsNotNone(connection.endpoint)
Exemple #6
0
    def test_all_data(self):
        """Load every 0.0.1 sample YAML and expect no errors and a schema."""
        MLSchema.populate_registry()
        all_data_files = list(Path('tests').glob('data/0/0/1/*.yaml'))

        # Guard against the glob silently matching nothing (or a single file).
        self.assertGreater(len(all_data_files), 1)

        for data_file in all_data_files:
            print(data_file)
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            # Include the file in the failure message so the culprit is obvious.
            self.assertEqual(len(errors), 0, msg=str(data_file))
            self.assertIsNotNone(loaded_object.get_schema(), msg=str(data_file))
Exemple #7
0
    def test_code_gen_string(self):
        """Check the generated code string with and without type hints.

        Field lines are counted with ``^prefix.`` and hint lines with
        ``^# prefix.``; the same two patterns are applied to both outputs.
        """
        fields_re = re.compile(r"^prefix.", flags=re.MULTILINE | re.DOTALL)
        hints_re = re.compile(r"^# prefix.", flags=re.MULTILINE | re.DOTALL)

        code_gen_string = MLObject._code_gen_string(
            "0.0.1", MLSchemaTypes.RUNCONFIG, "prefix"
        )
        self.assertEqual(len(fields_re.findall(code_gen_string)), 54)
        # Use the hint pattern here (not the field pattern again) so the
        # default output is actually checked for hint lines.
        # NOTE(review): the exact expected hint count is not known from this
        # test; only presence is asserted -- tighten once confirmed.
        self.assertGreater(len(hints_re.findall(code_gen_string)), 0)

        code_gen_string = MLObject._code_gen_string(
            "0.0.1", MLSchemaTypes.RUNCONFIG, "prefix", type_hints=False
        )
        self.assertEqual(len(fields_re.findall(code_gen_string)), 54)
        self.assertEqual(len(hints_re.findall(code_gen_string)), 0)
    def test_load_live_data(self):
        """Load every sample YAML whose schema_version is not 0.0.1."""
        MLSchema.populate_registry()
        all_data_files = list(Path("tests").glob("data/**/*.yaml"))

        for data_file in all_data_files:
            this_text = data_file.read_text(encoding="utf-8")
            this_dict = convert_yaml_to_dict(this_text)
            # Files declaring schema_version 0.0.1 are skipped by this test.
            if this_dict["schema_version"] == "0.0.1":
                continue
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            # Name the file in the message so a failing sample is identifiable.
            self.assertEqual(len(errors), 0, msg=str(data_file))
            self.assertIsNotNone(loaded_object.get_schema(), msg=str(data_file))
    def test_load_live_data(self):
        """Load every sample YAML whose schema_version is not 0.0.1."""
        MLSchema.populate_registry()
        all_data_files = list(Path('tests').glob('data/**/*.yaml'))

        for data_file in all_data_files:
            # Read with an explicit encoding so the result does not depend on
            # the platform's default locale.
            this_text = data_file.read_text(encoding='utf-8')
            this_dict = convert_yaml_to_dict(this_text)
            # Files declaring schema_version 0.0.1 are skipped by this test.
            if this_dict['schema_version'] == '0.0.1':
                continue
            print(data_file)
            loaded_object, errors = MLObject.create_object_from_file(data_file)
            self.assertEqual(len(errors), 0)
            self.assertIsNotNone(loaded_object.get_schema())
Exemple #10
0
    def test_load_object_with_missing_field_from_variable(self):

        WRONG_DATAPATH = """
schema_version: 0.0.1
schema_type: datapath
run_id: 6a9a5931-1c1d-47cc-aaf3-ad8b03f70575
step_id: 0c98f080-4760-46be-b35f-7dbb5e2a88c2
run_date: 1970-01-01 00:00:00.00000
# data_store: I_am_a_datastore_name
storage_connection_type: AWS_BLOB
connection:
    # endpoint: S3://mybucket/puppy.jpg
    access_key_id: AKIAIOSFODNN7EXAMPLE
    secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"""

        ml_object, errors = MLObject.create_object_from_string(WRONG_DATAPATH)

        self.assertTrue(ml_object is None)
        self.assertTrue(len(errors) == 2)
        self.assertTrue(errors["data_store"][0] == "Field may not be null.")
        self.assertTrue(errors["connection"]["endpoint"][0] == "Field may not be null.")
Exemple #11
0
    def attach_step_info(
        self,
        mlobject: MLObject,
        workflow_version,
        workflow_node_id,
        step_name: str,
        step_type: str,
    ):
        """Store *mlobject* as a run-info vertex and connect it to its step.

        Three queries are executed in order: one adding the vertex (with the
        object's properties, its encoded raw content, the workflow node id and
        the workflow partition id), then one adding a 'results' edge and one
        adding a 'root' edge between the step vertex and the new vertex.

        Args:
            mlobject: Object whose run_id, run_date and contents are stored.
            workflow_version: Passed through to build_vertex_id.
            workflow_node_id: Id of the workflow vertex the step hangs off.
            step_name: Id of the step vertex to link to.
            step_type: One of 'input', 'execution', 'output' or 'log'.

        Returns:
            The id built for the new run-info vertex.

        Raises:
            ValueError: If step_type is not one of the four accepted values.
        """
        accepted_step_types = ["input", "execution", "output", "log"]
        if step_type not in accepted_step_types:
            raise ValueError(
                f"Error when saving '{mlobject.get_schema_name()}', the step_type must be from ['input', 'execution', 'output', 'log']."
            )

        run_info_id = build_vertex_id(
            step_name,
            step_type,
            mlobject.run_id,
            mlobject.run_date,
            workflow_version,
            self._workflow_partition_id,
        )

        # Serialise the object twice: as chained property steps and as an
        # encoded raw blob stored under 'raw_content'.
        property_string = convert_to_property_strings(
            mlobject.dict_without_internal_variables()
        )
        raw_content = encode_raw_object_for_db(mlobject)
        vertex_query = f"""g.addV('id', '{run_info_id}'){property_string}.property('raw_content', '{raw_content}').property('workflow_node_id', '{workflow_node_id}').property('workflow_partition_id', '{self._workflow_partition_id}')"""
        self.execute_query(vertex_query)

        # Both edge queries take the same three ids, in the same order.
        edge_templates = (
            "g.V('id', '%s').out().hasId('%s').addE('results').to(g.V('%s')).executionProfile()",
            "g.V('id', '%s').out().hasId('%s').addE('root').from(g.V('%s')).executionProfile()",
        )
        for template in edge_templates:
            self.execute_query(
                sQuery(template, [workflow_node_id, step_name, run_info_id])
            )

        return run_info_id
 def test_create_stub_base_object(self):
     """Check the stub BASE object: run_date is unset and it has ten entries."""
     ml_object = MLObject()
     ml_object.set_type('0.0.1', MLSchemaTypes.BASE)
     self.assertIsNone(ml_object['run_date'])
     # assertEqual shows the actual length when the expectation is not met.
     self.assertEqual(len(ml_object), 10)
 def test_create_bad_schema_type(self):
     """Expect set_type to raise KeyError for an unknown schema type."""
     # Build the object outside the assertRaises block so that only set_type
     # can satisfy the expected KeyError.
     ml_object = MLObject()
     with self.assertRaises(KeyError):
         ml_object.set_type(schema_version='0.0.1', schema_type='foo')
Exemple #14
0
    def test_load_and_save_file(self):
        """Save a DATAPATH object while fixing one invalid field at a time.

        Each save attempt is expected to be rejected with one fewer error and
        to write nothing, until every field is valid. The single saved file is
        then loaded back and compared with the in-memory object.
        """
        run_id = uuid.uuid4()

        save_path = Path(self.test_dir.name) / str(run_id)
        save_path.mkdir()

        datapath_object = MLObject()
        datapath_object.set_type('0.0.1', MLSchemaTypes.DATAPATH)

        datapath_object.run_id = run_id
        datapath_object.step_id = uuid.uuid4()
        datapath_object.run_date = datetime.datetime.now()
        datapath_object.data_store = None  # This is an intentional bug

        # This is an intentional bug (Should be AWS_BLOB)
        datapath_object.storage_connection_type = 'AWS_BLOB_OBJECT'
        datapath_object.connection.endpoint = None  # Another intentional bug
        datapath_object.connection.access_key_id = 'AKIAIOSFODNN7EXAMPLE'
        datapath_object.connection.secret_access_key = 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY'

        def saved_files():
            # save_path is already a Path, so it can be globbed directly.
            return list(save_path.glob('*'))

        def assert_save_rejected(expected_error_count):
            # A rejected save reports the errors and must not write any file.
            response, errors = datapath_object.save(save_path)
            self.assertFalse(response)
            self.assertEqual(len(errors), expected_error_count)
            self.assertEqual(saved_files(), [])

        assert_save_rejected(3)

        datapath_object.storage_connection_type = 'AWS_BLOB'
        assert_save_rejected(2)

        datapath_object.connection.endpoint = 'http://s3.amazon.com/BUCKET'
        assert_save_rejected(1)

        datapath_object.data_store = 'BUCKET NAME'

        response, errors = datapath_object.save(save_path)

        self.assertTrue(response)
        self.assertEqual(len(errors), 0)
        all_files = saved_files()
        self.assertEqual(len(all_files), 1)

        ml_object, errors = MLObject.create_object_from_file(all_files[0])
        self.assertEqual(len(ml_object), 13)
        self.assertEqual(len(errors), 0)

        # The reloaded object must agree with what was saved.
        self.assertEqual(datapath_object.data_store, ml_object.data_store)
        self.assertEqual(datapath_object.storage_connection_type,
                         ml_object.storage_connection_type)
        self.assertEqual(datapath_object.connection.endpoint,
                         ml_object.connection.endpoint)
Exemple #15
0
 def test_create_bad_schema_type(self):
     """Expect set_type to raise RegistryError for an unknown schema type."""
     # Build the object outside the assertRaises block so that only set_type
     # can satisfy the expected RegistryError.
     ml_object = MLObject()
     with self.assertRaises(RegistryError):
         ml_object.set_type(schema_version="0.0.1", schema_type="foo")
Exemple #16
0
    def test_cascading_inheritence(self):
        """Fill in a 'data_version_control' object and call validate() on it.

        The object is given datapath-style fields (data_store, connection,
        storage_connection_type) alongside dvc_hash. Presumably the former
        come from a parent schema, which is what the test name refers to --
        confirm against the schema definitions.
        """
        MLSchema.populate_registry()

        mlobject = MLObject()
        # The schema type is given here as a plain string.
        mlobject.set_type("0.0.1", "data_version_control")
        mlobject.run_id = uuid.uuid4()
        mlobject.step_id = uuid.uuid4()
        mlobject.run_date = datetime.datetime.now()
        mlobject.data_store = "I_am_a_datastore"
        mlobject.storage_connection_type = "AWS_BLOB"
        mlobject.connection.endpoint = "con_endpoint"
        mlobject.connection.access_key_id = "AKIAIOSFODNN7EXAMPLE"
        mlobject.connection.secret_access_key = (
            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY")

        mlobject.dvc_hash = "923caceea54b38177505632f5612cc569a49b22246e346a7"
        # NOTE(review): whatever validate() returns is not inspected in the
        # lines visible here, so this only fails if validate() raises.
        mlobject.validate()