예제 #1
0
 def delete(self, model_id):
     """Drop the model ``model_id`` and return its serialized form.

     Returns:
         A ``({'data': <model dict>}, HTTPStatus.OK)`` tuple.

     Raises:
         NotFoundException: if ``ModelService.drop`` returns a falsy value.
     """
     with db_handler.session_scope() as session:
         dropped = ModelService(session).drop(model_id)
         if dropped:
             # Serialize while still inside the session scope.
             return {'data': dropped.to_dict()}, HTTPStatus.OK
         raise NotFoundException(
             f'Failed to find model: {model_id}')
예제 #2
0
 def get(self, model_id):
     """Return the serialized model ``model_id``.

     The optional ``detail_level`` query argument (default ``''``) is
     forwarded unchanged to ``ModelService.query``.

     Raises:
         NotFoundException: if the query result is falsy.
     """
     level = request.args.get('detail_level', '')
     with db_handler.session_scope() as session:
         payload = ModelService(session).query(model_id, level)
     if payload:
         return {'data': payload}, HTTPStatus.OK
     raise NotFoundException(
         f'Failed to find model: {model_id}')
예제 #3
0
 def setUp(self, mock_get_checkpoint_path):
     """Create a training job and an evaluation job, each with a model.

     The evaluation model is created with the training model's job name as
     its parent. ``mock_get_checkpoint_path`` is an injected mock (presumably
     from a ``@patch`` decorator outside this view); it is made to return
     ``'output'`` before any model is created.
     """
     super().setUp()
     self.model_service = ModelService(db.session)
     common = {'workflow_id': 1, 'project_id': 1}
     self.train_job = Job(name='train-job',
                          job_type=JobType.NN_MODEL_TRANINING,
                          **common)
     self.eval_job = Job(name='eval-job',
                         job_type=JobType.NN_MODEL_EVALUATION,
                         **common)
     mock_get_checkpoint_path.return_value = 'output'
     self.model_service.create(job=self.train_job, parent_job_name=None)
     parent = db.session.query(Model).filter_by(
         job_name=self.train_job.name).one()
     self.model_service.create(job=self.eval_job,
                               parent_job_name=parent.job_name)
     for job in (self.train_job, self.eval_job):
         db.session.add(job)
     db.session.commit()
예제 #4
0
    def _setup_jobs(self):
        """Create (or reuse) this workflow's jobs and wire their dependencies.

        Job definitions come from ``self.get_config().job_definitions``. They
        are paired one-to-one with ``self.get_create_job_flags()``:

        * ``CreateJobFlag.REUSE``: reuse the job with the same name from the
          workflow this one was forked from (``self.forked_from``).
        * Any other flag: create a new ``Job`` named
          ``'{self.uuid}-{job_def.name}'`` in state ``JobState.NEW``. The job
          is disabled when the flag is ``CreateJobFlag.DISABLED``.

        After a session flush, one ``JobDependency`` is added for each
        dependency of each newly created job. The ordered job ids are then
        saved with ``self.set_job_ids``. If
        ``Features.FEATURE_MODEL_WORKFLOW_HOOK`` is enabled, every job is
        passed to ``ModelService.workflow_hook``.

        Raises:
            AssertionError: if any of the following holds:
                * the source workflow or a reused job is missing;
                * the number of job definitions and flags differ;
                * a job flagged ``REUSE`` is not in the source workflow,
                  including when this workflow is not a fork.
        """
        # Stays empty when this workflow is not a fork. A REUSE flag then
        # fails the "not found in base workflow" assertion below instead of
        # hitting an unbound local variable.
        trunk_name2index = {}
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found' % self.forked_from
            trunk_job_defs = trunk.get_config().job_definitions
            trunk_name2index = {
                job_def.name: i
                for i, job_def in enumerate(trunk_job_defs)
            }

        job_defs = self.get_config().job_definitions
        flags = self.get_create_job_flags()
        assert len(job_defs) == len(flags), \
            'Number of job defs does not match number of create_job_flags ' \
            '%d vs %d' % (len(job_defs), len(flags))
        jobs = []
        for job_def, flag in zip(job_defs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                # A non-empty trunk_name2index implies `trunk` was assigned.
                assert job_def.name in trunk_name2index, \
                    f'Job {job_def.name} not found in base workflow'
                job_id = trunk.get_job_ids()[trunk_name2index[job_def.name]]
                job = Job.query.get(job_id)
                assert job is not None, \
                    'Job %d not found' % job_id
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(
                    name=f'{self.uuid}-{job_def.name}',
                    job_type=JobType(job_def.job_type),
                    config=job_def.SerializeToString(),
                    workflow_id=self.id,
                    project_id=self.project_id,
                    state=JobState.NEW,
                    is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
                db.session.add(job)
            jobs.append(job)
        # Flush so newly added jobs have ids before dependencies use `job.id`.
        db.session.flush()
        name2index = {job_def.name: i for i, job_def in enumerate(job_defs)}
        for job, flag in zip(jobs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                continue
            for dep_index, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=dep_index)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])
        if Features.FEATURE_MODEL_WORKFLOW_HOOK:
            model_service = ModelService(db.session)
            for job in jobs:
                model_service.workflow_hook(job)
예제 #5
0
 def get(self):
     """List serialized NN and tree models.

     Models whose ``type`` is ``ModelType.NN_MODEL`` or
     ``ModelType.TREE_MODEL`` are loaded through the scoped session. Each one
     is serialized with ``ModelService.query``, forwarding the optional
     ``detail_level`` query argument (default ``''``).

     Returns:
         A ``({'data': [<serialized model>, ...]}, HTTPStatus.OK)`` tuple.
     """
     detail_level = request.args.get('detail_level', '')
     # TODO serialized query may incur performance penalty
     with db_handler.session_scope() as session:
         # Query through the same session the service uses, rather than
         # mixing in the separate `Model.query` session.
         models = session.query(Model).filter(
             Model.type.in_([
                 ModelType.NN_MODEL.value, ModelType.TREE_MODEL.value
             ])).all()
         model_service = ModelService(session)
         model_list = [
             model_service.query(m.id, detail_level) for m in models
         ]
     return {'data': model_list}, HTTPStatus.OK
예제 #6
0
 def _event_consumer(self):
     """Consume k8s events from ``self._queue`` in an endless loop.

     For each event:

     1. Refresh ``k8s_cache``.
     2. Run ``self._update_hook``.
     3. If ``Features.FEATURE_MODEL_K8S_HOOK`` is enabled, run
        ``ModelService.k8s_watcher_hook`` in a fresh session and commit.

     Any exception is logged together with its traceback, and the loop
     continues with the next event.
     """
     # TODO(xiangyuxuan): do more business level operations
     while True:
         try:
             incoming = self._queue.get()
             k8s_cache.update_cache(incoming)
             # Job state has to be refreshed before the model service sees
             # the event.
             self._update_hook(incoming)
             if not Features.FEATURE_MODEL_K8S_HOOK:
                 continue
             with session_context() as session:
                 ModelService(session).k8s_watcher_hook(incoming)
                 session.commit()
         except Exception as err:  # pylint: disable=broad-except
             logging.error(f'K8s event_consumer : {str(err)}. '
                           f'traceback:{traceback.format_exc()}')
예제 #7
0
    def test_on_job_update(self, mock_plot_metrics: MagicMock):
        """Check that ``ModelService.on_job_update`` mirrors job state.

        The training model must start COMMITTED. The ``train-job`` is then
        set to STARTED, COMPLETED and FAILED, each in a fresh session. The
        model state must follow as RUNNING, SUCCEEDED and FAILED.
        """
        mock_plot_metrics.return_value = 'plot metrics return'

        # TODO: change get_session to db.session_scope
        with get_session(db.engine) as session:
            initial = session.query(Model).filter_by(
                job_name=self.train_job.name).one()
            self.assertEqual(initial.state, ModelState.COMMITTED.value)

            job = session.query(Job).filter_by(name='train-job').one()
            job.state = JobState.STARTED
            session.commit()

        transitions = (
            (JobState.STARTED, ModelState.RUNNING),
            (JobState.COMPLETED, ModelState.SUCCEEDED),
            (JobState.FAILED, ModelState.FAILED),
        )
        for job_state, expected in transitions:
            # TODO: change get_session to db.session_scope
            with get_session(db.engine) as session:
                job = session.query(Job).filter_by(name='train-job').one()
                job.state = job_state
                model = session.query(Model).filter_by(
                    job_name=self.train_job.name).one()

                ModelService(session).on_job_update(job)
                self.assertEqual(model.state, expected.value)
                session.commit()
예제 #8
0
class ModelTest(BaseTestCase):
    """Tests for ``ModelService`` hooks and the ``/api/v2/models`` endpoints.

    ``setUp`` creates a training job and an evaluation job, each with a
    model. The evaluation model's parent is the training model.
    """

    @patch(
        'fedlearner_webconsole.mmgr.service.ModelService.get_checkpoint_path')
    def setUp(self, mock_get_checkpoint_path):
        super().setUp()
        self.model_service = ModelService(db.session)
        common = {'workflow_id': 1, 'project_id': 1}
        self.train_job = Job(name='train-job',
                             job_type=JobType.NN_MODEL_TRANINING,
                             **common)
        self.eval_job = Job(name='eval-job',
                            job_type=JobType.NN_MODEL_EVALUATION,
                            **common)
        # Set before any model is created.
        mock_get_checkpoint_path.return_value = 'output'
        self.model_service.create(job=self.train_job, parent_job_name=None)
        parent = db.session.query(Model).filter_by(
            job_name=self.train_job.name).one()
        self.model_service.create(job=self.eval_job,
                                  parent_job_name=parent.job_name)
        for job in (self.train_job, self.eval_job):
            db.session.add(job)
        db.session.commit()

    @patch('fedlearner_webconsole.mmgr.service.ModelService.plot_metrics')
    def test_on_job_update(self, mock_plot_metrics: MagicMock):
        """Check that ``ModelService.on_job_update`` mirrors job state.

        The training model must start COMMITTED. The ``train-job`` is then
        set to STARTED, COMPLETED and FAILED, each in a fresh session, and
        the model must become RUNNING, SUCCEEDED and FAILED.
        """
        mock_plot_metrics.return_value = 'plot metrics return'

        # TODO: change get_session to db.session_scope
        with get_session(db.engine) as session:
            initial = session.query(Model).filter_by(
                job_name=self.train_job.name).one()
            self.assertEqual(initial.state, ModelState.COMMITTED.value)

            job = session.query(Job).filter_by(name='train-job').one()
            job.state = JobState.STARTED
            session.commit()

        transitions = (
            (JobState.STARTED, ModelState.RUNNING),
            (JobState.COMPLETED, ModelState.SUCCEEDED),
            (JobState.FAILED, ModelState.FAILED),
        )
        for job_state, expected in transitions:
            # TODO: change get_session to db.session_scope
            with get_session(db.engine) as session:
                job = session.query(Job).filter_by(name='train-job').one()
                job.state = job_state
                model = session.query(Model).filter_by(
                    job_name=self.train_job.name).one()

                ModelService(session).on_job_update(job)
                self.assertEqual(model.state, expected.value)
                session.commit()

    def test_hook(self):
        """Drive ``k8s_watcher_hook`` through a sequence of train-job states."""
        train_job = Job(id=0,
                        state=JobState.STARTED,
                        name='nn-train',
                        job_type=JobType.NN_MODEL_TRANINING,
                        workflow_id=0,
                        project_id=0)
        db.session.add(train_job)
        db.session.commit()
        event = Event(flapp_name='nn-train',
                      event_type=EventType.ADDED,
                      obj_type=ObjectType.FLAPP,
                      obj_dict={})
        self.model_service.workflow_hook(train_job)
        model = Model.query.filter_by(job_name='nn-train').one()
        self.assertEqual(model.state, ModelState.COMMITTED.value)

        event.event_type = EventType.MODIFIED

        def drive(job_state, expected_model_state):
            # Set the job state, replay the event, and check the model state.
            train_job.state = job_state
            self.model_service.k8s_watcher_hook(event)
            self.assertEqual(model.state, expected_model_state.value)

        drive(JobState.STARTED, ModelState.RUNNING)
        drive(JobState.COMPLETED, ModelState.SUCCEEDED)
        drive(JobState.STARTED, ModelState.RUNNING)
        # The test expects a restart after completion to yield version 2.
        self.assertEqual(model.version, 2)
        drive(JobState.STOPPED, ModelState.PAUSED)
        db.session.rollback()

    def test_api(self):
        """Exercise the get, list and delete calls on ``/api/v2/models``."""
        detail = self.get_response_data(self.get_helper('/api/v2/models/1'))
        self.assertEqual(detail.get('id'), 1)

        listing = self.get_response_data(self.get_helper('/api/v2/models'))
        self.assertEqual(len(listing), 1)

        first = Model.query.first()
        first.state = ModelState.FAILED.value
        db.session.add(first)
        db.session.commit()
        self.delete_helper('/api/v2/models/1')
        dropped = self.get_response_data(self.get_helper('/api/v2/models/1'))
        self.assertEqual(dropped.get('state'), ModelState.DROPPED.value)

    def test_get_eval(self):
        """``get_eval_model()`` on the training model yields one entry."""
        train_model = Model.query.filter_by(job_name=self.train_job.name).one()
        self.assertEqual(len(train_model.get_eval_model()), 1)