def test_create_run_with_parent_id(self):
    """Create a run with two user tags and a parent id, then verify it.

    Checks that every RunInfo field matches the generated run configuration,
    and that the run's tags are the user tags followed by a parent-run-id
    tag and then a run-name tag.
    """
    run_name = 'booyya'
    exp = self._experiment_factory('test_create_run_with_parent_id')
    cfg = self._get_run_configs(run_name, experiment_id=exp)
    user_tags = [RunTag('3', '4'), RunTag('1', '2')]
    src_type = SourceType.from_string(cfg["source_type"])

    run = self.store.create_run(
        cfg["experiment_id"],
        cfg["user_id"],
        cfg["name"],
        src_type,
        cfg["source_name"],
        cfg["entry_point_name"],
        cfg["start_time"],
        cfg["source_version"],
        user_tags,
        "parent_uuid_5",
    )

    info = run.info
    self.assertEqual(info.experiment_id, cfg["experiment_id"])
    self.assertEqual(info.user_id, cfg["user_id"])
    self.assertEqual(info.name, run_name)
    self.assertEqual(info.source_type, src_type)
    self.assertEqual(info.source_name, cfg["source_name"])
    self.assertEqual(info.source_version, cfg["source_version"])
    self.assertEqual(info.entry_point_name, cfg["entry_point_name"])
    self.assertEqual(info.start_time, cfg["start_time"])

    # Two user tags, plus the parent-id tag and the run-name tag.
    self.assertEqual(len(run.data.tags), 4)
    name_tag = models.SqlTag(key='mlflow.runName', value=run_name).to_mlflow_entity()
    parent_id_tag = models.SqlTag(
        key='mlflow.parentRunId', value='parent_uuid_5').to_mlflow_entity()
    self.assertListEqual(run.data.tags, user_tags + [parent_id_tag, name_tag])
def test_create_run_with_parent_id(self):
    """Create a run with a parent run id and verify its info and tags.

    ``expected.source_type`` holds the string form of the source type; it is
    converted with ``SourceType.from_string`` before being handed to
    ``create_run``. The round-tripped ``info.source_type`` is therefore
    compared against that converted value, not against the raw string.
    """
    expected = self._run_factory()
    name = 'booyya'
    parent_run_id = "parent_uuid_5"
    expected.tags.append(models.SqlTag(key='3', value='4'))
    expected.tags.append(models.SqlTag(key='1', value='2'))
    tags = [t.to_mlflow_entity() for t in expected.tags]
    # Convert once: this is exactly the value given to the store, so it is
    # also the value the created run's info should report.
    source_type = entities.SourceType.from_string(expected.source_type)

    actual = self.store.create_run(
        expected.experiment_id,
        expected.user_id,
        name,
        source_type,
        expected.source_name,
        expected.entry_point_name,
        expected.start_time,
        expected.source_version,
        tags,
        parent_run_id,
    )

    self.assertEqual(actual.info.experiment_id, expected.experiment_id)
    self.assertEqual(actual.info.user_id, expected.user_id)
    self.assertEqual(actual.info.name, name)
    self.assertEqual(actual.info.source_type, source_type)
    self.assertEqual(actual.info.source_name, expected.source_name)
    self.assertEqual(actual.info.source_version, expected.source_version)
    self.assertEqual(actual.info.entry_point_name, expected.entry_point_name)
    self.assertEqual(actual.info.start_time, expected.start_time)

    # Two user tags, plus the parent-id tag and the run-name tag, in that order.
    self.assertEqual(len(actual.data.tags), 4)
    name_tag = models.SqlTag(key='mlflow.runName', value=name).to_mlflow_entity()
    parent_id_tag = models.SqlTag(
        key='mlflow.parentRunId', value=parent_run_id).to_mlflow_entity()
    self.assertListEqual(actual.data.tags, tags + [parent_id_tag, name_tag])
def test_create_run(self):
    """Create a run with no parent and verify its info and tags.

    ``None`` is passed as ``parent_run_id`` to state explicitly that the run
    has no parent. A truthy placeholder such as ``-1`` would look like a real
    parent id to the store. The final assertion expects exactly the
    user-supplied tags on the created run.
    """
    expected = self._run_factory()
    name = 'booyya'
    expected.tags.append(models.SqlTag(key='3', value='4'))
    expected.tags.append(models.SqlTag(key='1', value='2'))
    # Persist the fixture run together with its two tags.
    self.session.add_all([expected, expected.tags[0], expected.tags[1]])
    self.session.commit()
    tags = [t.to_mlflow_entity() for t in expected.tags]

    actual = self.store.create_run(
        expected.experiment_id,
        expected.user_id,
        name,
        expected.source_type,
        expected.source_name,
        expected.entry_point_name,
        expected.start_time,
        expected.source_version,
        tags,
        None,  # no parent run
    )

    self.assertEqual(actual.info.experiment_id, expected.experiment_id)
    self.assertEqual(actual.info.user_id, expected.user_id)
    self.assertEqual(actual.info.name, name)
    self.assertEqual(actual.info.source_type, expected.source_type)
    self.assertEqual(actual.info.source_name, expected.source_name)
    self.assertEqual(actual.info.source_version, expected.source_version)
    self.assertEqual(actual.info.entry_point_name, expected.entry_point_name)
    self.assertEqual(actual.info.start_time, expected.start_time)
    self.assertEqual(len(actual.data.tags), 2)
    self.assertListEqual(actual.data.tags, tags)
def test_create_run_with_parent_id(self):
    """Create a run that has a parent run and check its stored info and tags.

    No user tags are supplied, so the created run should carry exactly two
    tags: the parent-run-id tag followed by the run-name tag.
    """
    parent_run_id = "parent_uuid_5"
    run_name = "test-run-1"
    experiment_id = self._experiment_factory('test_create_run')
    expected = self._get_run_configs(
        name=run_name,
        experiment_id=experiment_id,
        parent_run_id=parent_run_id,
    )
    created = self.store.create_run(**expected)

    # Each RunInfo field should mirror the configuration used for creation.
    info = created.info
    for actual_value, expected_value in [
        (info.experiment_id, experiment_id),
        (info.user_id, expected["user_id"]),
        (info.name, run_name),
        (info.source_type, expected["source_type"]),
        (info.source_name, expected["source_name"]),
        (info.source_version, expected["source_version"]),
        (info.entry_point_name, expected["entry_point_name"]),
        (info.start_time, expected["start_time"]),
    ]:
        self.assertEqual(actual_value, expected_value)

    # Creation adds two tags on its own: parent run id first, then run name.
    self.assertEqual(len(created.data.tags), 2)
    expected_tags = [
        models.SqlTag(key=MLFLOW_PARENT_RUN_ID, value=parent_run_id).to_mlflow_entity(),
        models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity(),
    ]
    self.assertListEqual(created.data.tags, expected_tags)
def test_run_tag_model(self):
    """Persist a single SqlTag and check it converts back to a matching entity."""
    # NOTE(review): no run with UUID 'tuuid' is created here; if the tags
    # table enforces a foreign key on run_uuid, this insert may fail — confirm.
    tag_row = models.SqlTag(run_uuid='tuuid', key='test', value='val')
    self.session.add(tag_row)
    self.session.commit()

    stored = self.session.query(models.SqlTag).all()
    # The table should hold only the tag inserted above.
    self.assertEqual(len(stored), 1)

    entity = stored[0].to_mlflow_entity()
    self.assertEqual(entity.value, tag_row.value)
    self.assertEqual(entity.key, tag_row.key)
def test_run_tag_model(self):
    """Round-trip a SqlTag through the database and its entity conversion.

    `run_uuid` is a foreign key in the tags table, so to insert a tag with a
    given run UUID, that UUID must exist in the runs table. A real run is
    therefore created first, and the tag is attached to its UUID.
    """
    run = self._run_factory()
    with self.store.ManagedSessionMaker() as session:
        new_tag = models.SqlTag(run_uuid=run.info.run_uuid, key='test', value='val')
        session.add(new_tag)
        session.commit()
        # Scope the lookup to this run *and* this key so that tags belonging
        # to other runs, which may share the key, cannot match.
        added_tags = session.query(models.SqlTag).filter_by(
            run_uuid=run.info.run_uuid, key=new_tag.key).all()
        self.assertEqual(len(added_tags), 1)
        added_tag = added_tags[0].to_mlflow_entity()
        # The converted entity should carry both the key and the value.
        self.assertEqual(added_tag.key, new_tag.key)
        self.assertEqual(added_tag.value, new_tag.value)
def test_create_run_with_tags(self):
    """Create a run with user tags and verify its info fields and final tags.

    The created run should carry the user tags followed by exactly one extra
    tag holding the run name.
    """
    run_name = "test-run-1"
    experiment_id = self._experiment_factory('test_create_run')
    user_tags = [RunTag('3', '4'), RunTag('1', '2')]
    config = self._get_run_configs(name=run_name, experiment_id=experiment_id, tags=user_tags)
    run = self.store.create_run(**config)

    info = run.info
    self.assertEqual(info.experiment_id, experiment_id)
    self.assertEqual(info.name, run_name)
    # These RunInfo attributes share their names with the config keys.
    for field in ("user_id", "source_type", "source_name",
                  "source_version", "entry_point_name", "start_time"):
        self.assertEqual(getattr(info, field), config[field])

    # Only the run-name tag is added on top of the user tags.
    self.assertEqual(len(run.data.tags), len(user_tags) + 1)
    run_name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity()
    self.assertListEqual(run.data.tags, user_tags + [run_name_tag])
def test_create_run(self):
    """Create a run without a parent and verify its info and tags.

    With ``None`` as the parent run id, the created run should carry the user
    tags followed only by the run-name tag.
    """
    run_name = 'booyya'
    experiment_id = self._experiment_factory('test_create_run')
    config = self._get_run_configs(run_name, experiment_id=experiment_id)
    user_tags = [RunTag('3', '4'), RunTag('1', '2')]
    source_type = SourceType.from_string(config["source_type"])

    run = self.store.create_run(
        config["experiment_id"],
        config["user_id"],
        config["name"],
        source_type,
        config["source_name"],
        config["entry_point_name"],
        config["start_time"],
        config["source_version"],
        user_tags,
        None,
    )

    info = run.info
    self.assertEqual(info.experiment_id, config["experiment_id"])
    self.assertEqual(info.user_id, config["user_id"])
    self.assertEqual(info.name, run_name)
    self.assertEqual(info.source_type, source_type)
    self.assertEqual(info.source_name, config["source_name"])
    self.assertEqual(info.source_version, config["source_version"])
    self.assertEqual(info.entry_point_name, config["entry_point_name"])
    self.assertEqual(info.start_time, config["start_time"])

    # Two user tags plus the run-name tag.
    self.assertEqual(len(run.data.tags), 3)
    run_name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity()
    self.assertListEqual(run.data.tags, user_tags + [run_name_tag])