def delete(self, model_id):
    """Drop the model identified by ``model_id``.

    Args:
        model_id: identifier forwarded to ``ModelService.drop``.

    Returns:
        A ``({'data': <model dict>}, HTTPStatus.OK)`` tuple built from the
        object returned by the service.

    Raises:
        NotFoundException: when the service returns a falsy result.
    """
    with db_handler.session_scope() as session:
        dropped = ModelService(session).drop(model_id)
        if dropped:
            # Serialize while the session is still open.
            payload = {'data': dropped.to_dict()}
            return payload, HTTPStatus.OK
        raise NotFoundException(f'Failed to find model: {model_id}')
def get(self, model_id):
    """Return the JSON description of a single model.

    The optional ``detail_level`` query-string argument (default ``''``) is
    passed through to ``ModelService.query`` unchanged.

    Args:
        model_id: identifier forwarded to ``ModelService.query``.

    Returns:
        A ``({'data': <model json>}, HTTPStatus.OK)`` tuple.

    Raises:
        NotFoundException: when the service returns a falsy result.
    """
    level = request.args.get('detail_level', '')
    with db_handler.session_scope() as session:
        service = ModelService(session)
        found = service.query(model_id, level)
    if found:
        return {'data': found}, HTTPStatus.OK
    raise NotFoundException(f'Failed to find model: {model_id}')
def setUp(self, mock_get_checkpoint_path):
    """Create a training job, an evaluation job and their models.

    The evaluation model is created with the training model's job name as
    its parent. ``mock_get_checkpoint_path`` is configured to return
    ``'output'`` before any model is created.
    """
    super().setUp()
    self.model_service = ModelService(db.session)

    # Both fixture jobs live in the same workflow and project.
    common_fields = {'workflow_id': 1, 'project_id': 1}
    self.train_job = Job(name='train-job',
                         job_type=JobType.NN_MODEL_TRANINING,
                         **common_fields)
    self.eval_job = Job(name='eval-job',
                        job_type=JobType.NN_MODEL_EVALUATION,
                        **common_fields)

    mock_get_checkpoint_path.return_value = 'output'

    # The training model must exist before the evaluation model so that
    # its job name can be used as the parent.
    self.model_service.create(job=self.train_job, parent_job_name=None)
    train_model = db.session.query(Model).filter_by(
        job_name=self.train_job.name).one()
    self.model_service.create(job=self.eval_job,
                              parent_job_name=train_model.job_name)

    for fixture_job in (self.train_job, self.eval_job):
        db.session.add(fixture_job)
    db.session.commit()
def _setup_jobs(self):
    """Create (or reuse) the jobs described by this workflow's config.

    Steps, in order:

    1. If ``self.forked_from`` is set, load the source ("trunk") workflow
       and index its job definitions by name.
    2. For every ``(job definition, create-job flag)`` pair either look up
       the existing job in the trunk workflow (``REUSE``) or add a new
       ``Job`` row in state ``NEW`` (disabled when the flag is
       ``DISABLED``).
    3. Flush the session, then add a ``JobDependency`` row for every
       dependency of each non-reused job.
    4. Record the job ids on the workflow and, when
       ``Features.FEATURE_MODEL_WORKFLOW_HOOK`` is enabled, call
       ``ModelService.workflow_hook`` for each job.

    Raises:
        AssertionError: if the trunk workflow or a reused job cannot be
            found, if the number of flags does not match the number of
            job definitions, or if a job is flagged ``REUSE`` although
            this workflow is not forked.
    """
    # Bind these up front so a REUSE flag on a non-forked workflow fails
    # with a clear assertion below instead of an UnboundLocalError.
    trunk = None
    trunk_name2index = {}
    if self.forked_from is not None:
        trunk = Workflow.query.get(self.forked_from)
        assert trunk is not None, \
            'Source workflow %d not found' % self.forked_from
        trunk_job_defs = trunk.get_config().job_definitions
        trunk_name2index = {
            job.name: i
            for i, job in enumerate(trunk_job_defs)
        }

    job_defs = self.get_config().job_definitions
    flags = self.get_create_job_flags()
    assert len(job_defs) == len(flags), \
        'Number of job defs does not match number of create_job_flags ' \
        '%d vs %d'%(len(job_defs), len(flags))

    jobs = []
    for job_def, flag in zip(job_defs, flags):
        if flag == common_pb2.CreateJobFlag.REUSE:
            assert trunk is not None, \
                f'Job {job_def.name} is flagged REUSE but the workflow ' \
                'is not forked from another workflow'
            assert job_def.name in trunk_name2index, \
                f'Job {job_def.name} not found in base workflow'
            j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
            job = Job.query.get(j)
            assert job is not None, \
                'Job %d not found' % j
            # TODO: check forked jobs does not depend on non-forked jobs
        else:
            job = Job(
                name=f'{self.uuid}-{job_def.name}',
                job_type=JobType(job_def.job_type),
                config=job_def.SerializeToString(),
                workflow_id=self.id,
                project_id=self.project_id,
                state=JobState.NEW,
                is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
            db.session.add(job)
        jobs.append(job)
    # Flush before reading job.id below (presumably so that newly added
    # jobs have their ids assigned -- confirm against the Job model).
    db.session.flush()

    name2index = {job.name: i for i, job in enumerate(job_defs)}
    for job, flag in zip(jobs, flags):
        # Reused jobs keep the dependencies they already have.
        if flag == common_pb2.CreateJobFlag.REUSE:
            continue
        for j, dep_def in enumerate(job.get_config().dependencies):
            dep = JobDependency(
                src_job_id=jobs[name2index[dep_def.source]].id,
                dst_job_id=job.id,
                dep_index=j)
            db.session.add(dep)

    self.set_job_ids([job.id for job in jobs])
    if Features.FEATURE_MODEL_WORKFLOW_HOOK:
        for job in jobs:
            ModelService(db.session).workflow_hook(job)
def get(self):
    """List every NN and tree model as JSON.

    The optional ``detail_level`` query-string argument (default ``''``) is
    passed to ``ModelService.query`` for each model.

    Returns:
        A ``({'data': [<model json>, ...]}, HTTPStatus.OK)`` tuple.
    """
    level = request.args.get('detail_level', '')
    listed_types = [
        ModelType.NN_MODEL.value,
        ModelType.TREE_MODEL.value,
    ]
    # TODO serialized query may incur performance penalty
    with db_handler.session_scope() as session:
        # NOTE(review): rows are fetched through Model.query rather than
        # through `session` -- confirm this is intended.
        rows = Model.query.filter(Model.type.in_(listed_types)).all()
        model_list = []
        for row in rows:
            model_list.append(ModelService(session).query(row.id, level))
    return {'data': model_list}, HTTPStatus.OK
def _event_consumer(self):
    """Process events from ``self._queue`` in an endless loop.

    For each event: update the k8s cache, call ``self._update_hook`` and,
    when ``Features.FEATURE_MODEL_K8S_HOOK`` is enabled, run
    ``ModelService.k8s_watcher_hook`` in a fresh session and commit it.
    Any exception raised while handling an event is logged together with
    its traceback, and the loop moves on to the next event.
    """
    # TODO(xiangyuxuan): do more business level operations
    while True:
        try:
            incoming = self._queue.get()
            k8s_cache.update_cache(incoming)
            # job state must be updated before model service
            self._update_hook(incoming)
            if not Features.FEATURE_MODEL_K8S_HOOK:
                continue
            with session_context() as session:
                ModelService(session).k8s_watcher_hook(incoming)
                session.commit()
        except Exception as err:  # pylint: disable=broad-except
            message = (f'K8s event_consumer : {str(err)}. '
                       f'traceback:{traceback.format_exc()}')
            logging.error(message)
def test_on_job_update(self, mock_plot_metrics: MagicMock):
    """Check that ``on_job_update`` maps job states onto model states.

    The training model starts out COMMITTED; the job is then moved through
    STARTED, COMPLETED and FAILED, and after each ``on_job_update`` call the
    model is expected to be RUNNING, SUCCEEDED and FAILED respectively.
    Every step uses its own session.
    """
    mock_plot_metrics.return_value = 'plot metrics return'

    # TODO: change get_session to db.session_scope
    with get_session(db.engine) as session:
        initial_model = session.query(Model).filter_by(
            job_name=self.train_job.name).one()
        self.assertEqual(initial_model.state, ModelState.COMMITTED.value)
        job_row = session.query(Job).filter_by(name='train-job').one()
        job_row.state = JobState.STARTED
        session.commit()

    # (job state to apply, model state expected afterwards)
    transitions = (
        (JobState.STARTED, ModelState.RUNNING),
        (JobState.COMPLETED, ModelState.SUCCEEDED),
        (JobState.FAILED, ModelState.FAILED),
    )
    for job_state, expected_state in transitions:
        # TODO: change get_session to db.session_scope
        with get_session(db.engine) as session:
            job_row = session.query(Job).filter_by(name='train-job').one()
            job_row.state = job_state
            model_row = session.query(Model).filter_by(
                job_name=self.train_job.name).one()
            ModelService(session).on_job_update(job_row)
            self.assertEqual(model_row.state, expected_state.value)
            session.commit()
class ModelTest(BaseTestCase):
    """Tests for ``ModelService`` and the ``/api/v2/models`` endpoints."""

    @patch(
        'fedlearner_webconsole.mmgr.service.ModelService.get_checkpoint_path')
    def setUp(self, mock_get_checkpoint_path):
        """Create a training job, an evaluation job and their models.

        The evaluation model is created with the training model's job name
        as its parent. The patched ``get_checkpoint_path`` returns
        ``'output'``.
        """
        super().setUp()
        self.model_service = ModelService(db.session)

        # Both fixture jobs live in the same workflow and project.
        common_fields = {'workflow_id': 1, 'project_id': 1}
        self.train_job = Job(name='train-job',
                             job_type=JobType.NN_MODEL_TRANINING,
                             **common_fields)
        self.eval_job = Job(name='eval-job',
                            job_type=JobType.NN_MODEL_EVALUATION,
                            **common_fields)

        mock_get_checkpoint_path.return_value = 'output'

        # The training model must exist before the evaluation model so
        # that its job name can be used as the parent.
        self.model_service.create(job=self.train_job, parent_job_name=None)
        train_model = db.session.query(Model).filter_by(
            job_name=self.train_job.name).one()
        self.model_service.create(job=self.eval_job,
                                  parent_job_name=train_model.job_name)

        for fixture_job in (self.train_job, self.eval_job):
            db.session.add(fixture_job)
        db.session.commit()

    @patch('fedlearner_webconsole.mmgr.service.ModelService.plot_metrics')
    def test_on_job_update(self, mock_plot_metrics: MagicMock):
        """Check that ``on_job_update`` maps job states onto model states.

        The training model starts out COMMITTED; the job is then moved
        through STARTED, COMPLETED and FAILED, and after each
        ``on_job_update`` call the model is expected to be RUNNING,
        SUCCEEDED and FAILED respectively. Every step uses its own session.
        """
        mock_plot_metrics.return_value = 'plot metrics return'

        # TODO: change get_session to db.session_scope
        with get_session(db.engine) as session:
            initial_model = session.query(Model).filter_by(
                job_name=self.train_job.name).one()
            self.assertEqual(initial_model.state,
                             ModelState.COMMITTED.value)
            job_row = session.query(Job).filter_by(name='train-job').one()
            job_row.state = JobState.STARTED
            session.commit()

        # (job state to apply, model state expected afterwards)
        transitions = (
            (JobState.STARTED, ModelState.RUNNING),
            (JobState.COMPLETED, ModelState.SUCCEEDED),
            (JobState.FAILED, ModelState.FAILED),
        )
        for job_state, expected_state in transitions:
            # TODO: change get_session to db.session_scope
            with get_session(db.engine) as session:
                job_row = session.query(Job).filter_by(
                    name='train-job').one()
                job_row.state = job_state
                model_row = session.query(Model).filter_by(
                    job_name=self.train_job.name).one()
                ModelService(session).on_job_update(job_row)
                self.assertEqual(model_row.state, expected_state.value)
                session.commit()

    def test_hook(self):
        """Drive a model through its states via the workflow and k8s hooks.

        ``workflow_hook`` is expected to leave the model COMMITTED; the
        following ``k8s_watcher_hook`` calls follow the job through
        STARTED, COMPLETED, STARTED again (model version 2) and STOPPED.
        """
        train_job = Job(id=0,
                        state=JobState.STARTED,
                        name='nn-train',
                        job_type=JobType.NN_MODEL_TRANINING,
                        workflow_id=0,
                        project_id=0)
        db.session.add(train_job)
        db.session.commit()
        event = Event(flapp_name='nn-train',
                      event_type=EventType.ADDED,
                      obj_type=ObjectType.FLAPP,
                      obj_dict={})

        self.model_service.workflow_hook(train_job)
        model = Model.query.filter_by(job_name='nn-train').one()
        self.assertEqual(model.state, ModelState.COMMITTED.value)

        def fire_and_expect(job_state, expected_state):
            # Apply the job state, replay the event, check the model.
            train_job.state = job_state
            self.model_service.k8s_watcher_hook(event)
            self.assertEqual(model.state, expected_state.value)

        event.event_type = EventType.MODIFIED
        fire_and_expect(JobState.STARTED, ModelState.RUNNING)
        fire_and_expect(JobState.COMPLETED, ModelState.SUCCEEDED)

        # Starting the job again after completion yields version 2.
        fire_and_expect(JobState.STARTED, ModelState.RUNNING)
        self.assertEqual(model.version, 2)

        fire_and_expect(JobState.STOPPED, ModelState.PAUSED)
        db.session.rollback()

    def test_api(self):
        """Exercise the single-model GET, the list GET and DELETE."""
        single = self.get_response_data(self.get_helper('/api/v2/models/1'))
        self.assertEqual(single.get('id'), 1)

        listed = self.get_response_data(self.get_helper('/api/v2/models'))
        self.assertEqual(len(listed), 1)

        # Mark the first model as FAILED before deleting model 1.
        first_model = Model.query.first()
        first_model.state = ModelState.FAILED.value
        db.session.add(first_model)
        db.session.commit()

        self.delete_helper('/api/v2/models/1')
        after_delete = self.get_response_data(
            self.get_helper('/api/v2/models/1'))
        self.assertEqual(after_delete.get('state'),
                         ModelState.DROPPED.value)

    def test_get_eval(self):
        """The training model should report exactly one evaluation model."""
        train_model = Model.query.filter_by(
            job_name=self.train_job.name).one()
        eval_models = train_model.get_eval_model()
        self.assertEqual(len(eval_models), 1)