def test_run_info(self):
    """Build a SqlRun from a complete config and check each field on the entity.

    ``to_mlflow_entity`` hands ``status`` and ``source_type`` back in a different
    representation than the config strings. Those two are converted with their
    ``to_string`` helpers before comparison; every other field must match as-is.
    """
    exp_id = self._experiment_factory('test exp')
    config = {
        'experiment_id': exp_id,
        'name': 'test run',
        'user_id': 'Anderson',
        'run_uuid': 'test',
        'status': RunStatus.to_string(RunStatus.SCHEDULED),
        'source_type': SourceType.to_string(SourceType.LOCAL),
        'source_name': 'Python application',
        'entry_point_name': 'main.py',
        'start_time': int(time.time()),
        'end_time': int(time.time()),
        'source_version': mlflow.__version__,
        'lifecycle_stage': entities.LifecycleStage.ACTIVE,
        'artifact_uri': '//',
    }
    run = models.SqlRun(**config).to_mlflow_entity()

    # Fields whose entity value must be stringified before comparing with config.
    normalizers = {
        'source_type': SourceType.to_string,
        'status': RunStatus.to_string,
    }
    for field, expected in config.items():
        observed = getattr(run.info, field)
        normalize = normalizers.get(field)
        if normalize is not None:
            observed = normalize(observed)
        self.assertEqual(expected, observed)
def test_run_needs_uuid(self):
    """Committing a SqlRun without a ``run_uuid`` must raise IntegrityError.

    Warnings emitted during the failing commit are silenced only for this test.
    ``simplefilter`` is installed *inside* ``catch_warnings`` so the previous
    filter state is restored when the block exits, even though ``commit``
    raises. (Calling ``simplefilter`` outside it, with a ``resetwarnings`` that
    is never reached, leaked a global "ignore" filter into later tests.)
    """
    run = models.SqlRun()
    self.session.add(run)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            self.session.commit()
def test_run_needs_uuid(self):
    """A SqlRun without a ``run_uuid`` must be rejected with INTERNAL_ERROR.

    Depending on the backend, a NULL identity key may raise different SQLAlchemy
    exceptions, e.g. IntegrityError (SQLite) or FlushError (MySQL). The test
    therefore does not check for a driver-specific error. It expects the store
    to surface an ``MlflowException`` whose error code is INTERNAL_ERROR.
    """
    # Scope the "ignore" filter to this test: catch_warnings restores the prior
    # filter state on exit, even though the body raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with self.assertRaises(MlflowException) as exception_context:
            # The failure is expected when the managed session closes
            # (presumably it commits on exit), not at ``session.add`` itself.
            with self.store.ManagedSessionMaker() as session:
                run = models.SqlRun()
                session.add(run)
    self.assertEqual(exception_context.exception.error_code,
                     ErrorCode.Name(INTERNAL_ERROR))
def _run_factory(self, config=None):
    """Create a SqlRun from ``config`` and add it to the test session.

    :param config: keyword arguments for ``models.SqlRun``. When falsy, it is
        replaced by ``self._get_run_configs()``. If it has no usable
        ``experiment_id`` (key missing or None), a new experiment is created and
        its id is written into ``config``, mutating the dict in place.
    :return: the new SqlRun, added to ``self.session`` but not committed.
    """
    if not config:
        config = self._get_run_configs()

    # Test against None rather than truthiness: an id such as 0 is falsy but
    # may be a real experiment id and must not be swapped for a new experiment.
    if config.get("experiment_id") is None:
        config["experiment_id"] = self._experiment_factory('test exp')

    run = models.SqlRun(**config)
    self.session.add(run)
    return run
def _run_factory(self, name='test', experiment_id=None, config=None):
    """Create a SqlRun with two params and two metrics and add it to the session.

    :param name: run name used in the default config.
    :param experiment_id: experiment for the run. When None, the
        ``experiment_id`` in ``config`` is used if present. Otherwise a new
        experiment is created via ``self._experiment_factory``.
    :param config: optional overrides for the ``models.SqlRun`` keyword
        arguments. Keys given here replace the defaults built below. The
        caller's dict is not modified.
    :return: the new SqlRun, added to ``self.session`` but not committed.
    """
    m1 = models.SqlMetric(key='accuracy', value=0.89)
    m2 = models.SqlMetric(key='recal', value=0.89)
    p1 = models.SqlParam(key='loss', value='test param')
    p2 = models.SqlParam(key='blue', value='test param')

    # Copy so the caller's config is never mutated.
    overrides = dict(config or {})
    override_experiment_id = overrides.pop('experiment_id', None)
    if experiment_id is None:
        experiment_id = override_experiment_id
    # ``is None`` rather than truthiness: a falsy id such as 0 is still an id.
    if experiment_id is None:
        experiment = self._experiment_factory('test exp')
        experiment_id = experiment.experiment_id

    run_config = {
        'experiment_id': experiment_id,
        'name': name,
        'user_id': 'Anderson',
        'run_uuid': uuid.uuid4().hex,
        'status': entities.RunStatus.to_string(entities.RunStatus.SCHEDULED),
        'source_type': entities.SourceType.to_string(entities.SourceType.NOTEBOOK),
        'source_name': 'Python application',
        'entry_point_name': 'main.py',
        'start_time': int(time.time()),
        'end_time': int(time.time()),
        'source_version': mlflow.__version__,
        'lifecycle_stage': entities.RunInfo.ACTIVE_LIFECYCLE,
        'artifact_uri': '//',
    }
    run_config.update(overrides)

    run = models.SqlRun(**run_config)
    run.params.append(p1)
    run.params.append(p2)
    run.metrics.append(m1)
    run.metrics.append(m2)
    self.session.add(run)
    return run
def test_run_info(self):
    """Every field passed to SqlRun must come back unchanged from ``to_mlflow_entity``."""
    experiment = self._experiment_factory('test exp')
    config = {
        'experiment_id': experiment.experiment_id,
        'name': 'test run',
        'user_id': 'Anderson',
        'run_uuid': 'test',
        # ``status`` is a run status, not a lifecycle stage (previously this
        # was set to RunInfo.ACTIVE_LIFECYCLE); encode it as a RunStatus string.
        'status': entities.RunStatus.to_string(entities.RunStatus.SCHEDULED),
        'source_type': entities.SourceType.LOCAL,
        'source_name': 'Python application',
        'entry_point_name': 'main.py',
        'start_time': int(time.time()),
        'end_time': int(time.time()),
        'source_version': mlflow.__version__,
        'lifecycle_stage': entities.RunInfo.ACTIVE_LIFECYCLE,
        'artifact_uri': '//',
    }
    run = models.SqlRun(**config).to_mlflow_entity()
    for k, v in config.items():
        self.assertEqual(v, getattr(run.info, k))
def test_run_data_model(self):
    """Params and metrics appended to a SqlRun are persisted along with it.

    Commits one run carrying two params and two metrics, then queries all runs
    back and checks the counts.
    """
    m1 = models.SqlMetric(key='accuracy', value=0.89)
    m2 = models.SqlMetric(key='recal', value=0.89)
    p1 = models.SqlParam(key='loss', value='test param')
    p2 = models.SqlParam(key='blue', value='test param')
    self.session.add_all([m1, m2, p1, p2])

    run_data = models.SqlRun(run_uuid=uuid.uuid4().hex)
    run_data.params.append(p1)
    run_data.params.append(p2)
    run_data.metrics.append(m1)
    run_data.metrics.append(m2)

    self.session.add(run_data)
    self.session.commit()

    run_datums = self.session.query(models.SqlRun).all()
    # Check the row count before indexing, so an empty result fails with a
    # clear assertion message instead of an IndexError.
    self.assertEqual(len(run_datums), 1)
    actual = run_datums[0]
    self.assertEqual(len(actual.params), 2)
    self.assertEqual(len(actual.metrics), 2)