def _run_factory(self, name='test', experiment_id=None, config=None):
    """Build a ``models.SqlRun`` with two params and two metrics and add it to ``self.session``.

    :param name: Value for the run's ``name`` column.
    :param experiment_id: Experiment to attach the run to. When ``None`` (and
        ``config`` does not supply one), a new experiment named ``'test exp'``
        is created via ``self._experiment_factory``. A falsy id such as ``0``
        is honored rather than treated as missing.
    :param config: Optional mapping of ``SqlRun`` column values merged on top
        of the defaults built here (previously this argument was ignored).
    :return: The ``SqlRun`` instance. It is added to ``self.session`` but not
        committed.
    """
    m1 = models.SqlMetric(key='accuracy', value=0.89)
    m2 = models.SqlMetric(key='recal', value=0.89)
    p1 = models.SqlParam(key='loss', value='test param')
    p2 = models.SqlParam(key='blue', value='test param')

    overrides = dict(config or {})
    # Use `is None` so a legitimate falsy id (e.g. 0) is not replaced.
    if experiment_id is None:
        experiment_id = overrides.get('experiment_id')
    if experiment_id is None:
        experiment = self._experiment_factory('test exp')
        experiment_id = experiment.experiment_id

    run_config = {
        'experiment_id': experiment_id,
        'name': name,
        'user_id': 'Anderson',
        'run_uuid': uuid.uuid4().hex,
        'status': entities.RunStatus.to_string(entities.RunStatus.SCHEDULED),
        'source_type': entities.SourceType.to_string(entities.SourceType.NOTEBOOK),
        'source_name': 'Python application',
        'entry_point_name': 'main.py',
        'start_time': int(time.time()),
        'end_time': int(time.time()),
        'source_version': mlflow.__version__,
        'lifecycle_stage': entities.RunInfo.ACTIVE_LIFECYCLE,
        'artifact_uri': '//',
    }
    # Caller-supplied columns take precedence over the defaults above.
    run_config.update(overrides)

    run = models.SqlRun(**run_config)
    run.params.append(p1)
    run.params.append(p2)
    run.metrics.append(m1)
    run.metrics.append(m2)
    self.session.add(run)
    return run
def test_param_model(self):
    """Persist one ``SqlParam`` and check its entity conversion keeps key and value.

    NOTE(review): a second method named ``test_param_model`` is defined later
    in this class. That later definition replaces this one, so this body never
    runs under the test runner. Consider renaming one of them.
    """
    param = models.SqlParam(run_uuid='test', key='accuracy', value='test param')
    self.session.add(param)
    self.session.commit()

    stored = self.session.query(models.SqlParam).all()
    self.assertEqual(len(stored), 1)

    entity = stored[0].to_mlflow_entity()
    self.assertEqual(entity.value, param.value)
    self.assertEqual(entity.key, param.key)
def test_run_data_model(self):
    """Persist a ``SqlRun`` with two params and two metrics and read it back.

    Checks that exactly one run is stored and that both of its relationship
    collections (``params`` and ``metrics``) contain two items.
    """
    m1 = models.SqlMetric(key='accuracy', value=0.89)
    m2 = models.SqlMetric(key='recal', value=0.89)
    p1 = models.SqlParam(key='loss', value='test param')
    p2 = models.SqlParam(key='blue', value='test param')
    self.session.add_all([m1, m2, p1, p2])

    run_data = models.SqlRun(run_uuid=uuid.uuid4().hex)
    run_data.params.append(p1)
    run_data.params.append(p2)
    run_data.metrics.append(m1)
    run_data.metrics.append(m2)
    self.session.add(run_data)
    self.session.commit()

    run_datums = self.session.query(models.SqlRun).all()
    # Assert the row count before indexing so an empty result is reported as
    # an assertion failure rather than an IndexError.
    self.assertEqual(len(run_datums), 1)
    actual = run_datums[0]
    self.assertEqual(len(actual.params), 2)
    self.assertEqual(len(actual.metrics), 2)
def test_param_model(self):
    """Insert one ``SqlParam`` tied to a real run and check its entity conversion."""
    # Create a run whose UUID we can reference when creating parameter models.
    # `run_uuid` is a foreign key in the params table; therefore, in order
    # to insert a parameter with a given run UUID, the UUID must be present in
    # the runs table.
    # NOTE(review): `_run_factory` as visible in this chunk returns a
    # `models.SqlRun` added (uncommitted) to `self.session`, while this test
    # reads `run.info.run_uuid` and queries through `self.store`'s session.
    # Confirm the factory in use returns an object with `.info` and that the
    # run is visible to this session.
    run = self._run_factory()
    with self.store.ManagedSessionMaker() as session:
        new_param = models.SqlParam(
            run_uuid=run.info.run_uuid, key='accuracy', value='test param')
        session.add(new_param)
        session.commit()

        params = session.query(models.SqlParam).all()
        self.assertEqual(len(params), 1)

        added_param = params[0].to_mlflow_entity()
        self.assertEqual(added_param.value, new_param.value)
        self.assertEqual(added_param.key, new_param.key)