示例#1
0
    def test_on_job_update(self, mock_plot_metrics: MagicMock):
        """Walk the train job through its lifecycle and check the model.

        The model linked to ``self.train_job`` must start out COMMITTED.
        The job is then moved to STARTED. Each further job state is applied
        in its own session, and ``ModelService.on_job_update`` must map it
        to the expected model state.
        """
        mock_plot_metrics.return_value = 'plot metrics return'

        # TODO: change get_session to db.session_scope
        with get_session(db.engine) as session:
            model = session.query(Model).filter_by(
                job_name=self.train_job.name).one()
            self.assertEqual(model.state, ModelState.COMMITTED.value)

            job = session.query(Job).filter_by(name='train-job').one()
            job.state = JobState.STARTED
            session.commit()

        # (job state to apply, model state expected afterwards)
        transitions = (
            (JobState.STARTED, ModelState.RUNNING),
            (JobState.COMPLETED, ModelState.SUCCEEDED),
            (JobState.FAILED, ModelState.FAILED),
        )
        for job_state, expected_model_state in transitions:
            # TODO: change get_session to db.session_scope
            with get_session(db.engine) as session:
                job = session.query(Job).filter_by(name='train-job').one()
                job.state = job_state
                model = session.query(Model).filter_by(
                    job_name=self.train_job.name).one()

                ModelService(session).on_job_update(job)
                self.assertEqual(model.state, expected_model_state.value)
                session.commit()
示例#2
0
    def _schedule_job(self, job_id):
        """Try to launch a WAITING job on Kubernetes.

        Jobs that are not WAITING are left alone. Otherwise the job must be
        ready and, for federated jobs, its peer must be ready too. If the
        FLApp creation fails, the error text is stored on the job. If it
        succeeds, the stored error is cleared and the job is started.

        Returns:
            The job's state after this scheduling attempt.
        """
        job = Job.query.get(job_id)
        assert job is not None, f'Job {job_id} not found'
        if job.state != JobState.WAITING:
            return job.state

        with get_session(self._db_engine) as session:
            service = JobService(session)
            if not service.is_ready(job):
                return job.state
            # federated jobs additionally need the peer side to be ready
            if job.get_config().is_federated and \
                    not service.is_peer_ready(job):
                return job.state

        try:
            flapp_yaml = generate_job_run_yaml(job)
            k8s_client.create_flapp(flapp_yaml)
        except Exception as e:
            logging.error(f'Start job {job_id} has error msg: {e.args}')
            job.error_message = str(e)
            db.session.commit()
        else:
            job.error_message = None
            job.start()
            db.session.commit()
        return job.state
示例#3
0
 def _check_items(self):
     """Create a scheduler runner for every enabled item that is due.

     For each item whose status is ON and whose ``need_run()`` is true:

     * ``last_run_at`` is refreshed;
     * run-once items (negative ``interval_time``) are switched OFF;
     * a ``SchedulerRunner`` row is added carrying the item's pipeline
       and a freshly encoded context.

     Changes are committed once per item. A failed commit is logged and
     rolled back, and the loop continues with the next item.
     """
     with get_session(self.db_engine) as session:
         enabled = session.query(SchedulerItem).filter_by(
             status=ItemStatus.ON.value).all()
         for entry in (it for it in enabled if it.need_run()):
             # NOTE: use `func.now()` to let SQLAlchemy handle
             # the timezone.
             entry.last_run_at = func.now()
             if entry.interval_time < 0:
                 # finish run-once item automatically
                 entry.status = ItemStatus.OFF.value
             pipeline = Pipeline(**(json.loads(entry.pipeline)))
             initial_context = Context(data=pipeline.meta,
                                       internal={},
                                       db_engine=self.db_engine)
             session.add(
                 SchedulerRunner(
                     item_id=entry.id,
                     pipeline=entry.pipeline,
                     context=ContextEncoder().encode(initial_context),
                 ))
             try:
                 logging.info(
                     f'[composer] insert runner, item_id: {entry.id}')
                 session.commit()
             except Exception as e:  # pylint: disable=broad-except
                 logging.error(
                     f'[composer] failed to create scheduler_runner, '
                     f'item_id: {entry.id}, exception: {e}')
                 session.rollback()
示例#4
0
    def patch_item_attr(self, name: str, key: str, value: str):
        """Update one mutable attribute of a scheduler item.

        Args:
            name (str): name of the item to patch
            key (str): attribute to update; must be listed in the class's
                ``MUTABLE_ITEM_KEY``
            value (str): new value for the attribute

        Raises:
            ValueError: if ``key`` is not mutable, or no item named ``name``
                exists. A failed commit is logged and rolled back instead
                of being raised.
        """
        mutable_keys = type(self).MUTABLE_ITEM_KEY
        if key not in mutable_keys:
            raise ValueError(f'fail to change attribute {key}')

        with get_session(self.db_engine) as session:
            target: SchedulerItem = session.query(SchedulerItem).filter(
                SchedulerItem.name == name).first()
            if not target:
                raise ValueError(f'cannot find item {name}')
            setattr(target, key, value)
            session.add(target)
            try:
                session.commit()
            except Exception as e:  # pylint: disable=broad-except
                logging.error(f'[composer] failed to patch item attr, '
                              f'name: {name}, exception: {e}')
                session.rollback()
示例#5
0
    def get_item_status(self, name: str) -> Optional[ItemStatus]:
        """Look up the status of a scheduler item by name.

        Args:
            name: item name

        Returns:
            The item's ``ItemStatus``, or ``None`` if no such item exists.
        """
        with get_session(self.db_engine) as session:
            found = session.query(SchedulerItem).filter(
                SchedulerItem.name == name).first()
            return ItemStatus(found.status) if found else None
示例#6
0
    def check_job_ready(self, request, context):
        """Report whether the requested job is ready to run.

        The caller's project is resolved from ``request.auth_info``. The job
        is then looked up by ``request.job_name`` within that project, and
        its readiness comes from ``JobService.is_ready``.
        """
        with self._app.app_context():
            project, _ = self.check_auth_info(request.auth_info, context)
            job = db.session.query(Job).filter_by(
                name=request.job_name, project_id=project.id).first()
            assert job is not None, \
                f'Job {request.job_name} not found'

            with get_session(db.get_engine()) as session:
                ready = JobService(session).is_ready(job)
            success = common_pb2.Status(code=common_pb2.STATUS_SUCCESS)
            return service_pb2.CheckJobReadyResponse(status=success,
                                                     is_ready=ready)
示例#7
0
 def _check_init_runners(self):
     """Start every scheduler runner that is still in INIT status.

     Runners are processed only while the thread reaper has capacity. For
     each INIT runner this:

     * takes the runner's optimistic lock;
     * marks the runner and the first job of its pipeline as RUNNING;
     * records that job as ``current`` in the runner context;
     * enqueues the job on the thread reaper;
     * persists the runner row.

     The row is committed only when the lock is still at the latest
     version and its version bump succeeds. Otherwise the pending changes
     are rolled back, so a later commit in this same session cannot
     persist them by accident.
     """
     with get_session(self.db_engine) as session:
         init_runners = session.query(SchedulerRunner).filter_by(
             status=RunnerStatus.INIT.value).all()
         # TODO: support priority
         for runner in init_runners:
             # if thread_reaper is full, skip this round and
             # wait next checking
             if self.thread_reaper.is_full():
                 return
             lock_name = f'check_init_runner_{runner.id}_lock'
             check_lock = OpLocker(lock_name, self.db_engine).try_lock()
             # NOTE(review): try_lock() returns the locker itself even when
             # acquiring fails, so this branch only triggers if OpLocker
             # defines a falsy __bool__/__len__ -- confirm.
             if not check_lock:
                 logging.error(f'[composer] failed to lock, '
                               f'ignore current init_runner_{runner.id}')
                 continue
             pipeline = Pipeline(**(json.loads(runner.pipeline)))
             context = decode_context(val=runner.context,
                                      db_engine=self.db_engine)
             # find the first job in pipeline
             first = pipeline.deps[0]
             # update status
             runner.start_at = func.now()
             runner.status = RunnerStatus.RUNNING.value
             output = json.loads(runner.output)
             output[first] = {'status': RunnerStatus.RUNNING.value}
             runner.output = json.dumps(output)
             # record current running job
             context.set_internal('current', first)
             runner.context = ContextEncoder().encode(context)
             # start runner
             # NOTE(review): the job is enqueued before we know whether the
             # DB update below will be committed.
             runner_fn = self.runner_cache.find_runner(runner.id, first)
             self.thread_reaper.enqueue(name=lock_name,
                                        fn=runner_fn,
                                        context=context)
             try:
                 logging.info(
                     f'[composer] update runner, status: {runner.status}, '
                     f'pipeline: {runner.pipeline}, '
                     f'output: {output}, context: {runner.context}')
                 if check_lock.is_latest_version() and \
                         check_lock.update_version():
                     session.commit()
                 else:
                     logging.error(f'[composer] {lock_name} is outdated, '
                                   f'ignore updates to database')
                     # Discard this runner's pending changes. Otherwise
                     # the next successful commit in this session would
                     # write them despite the outdated lock.
                     session.rollback()
             except Exception as e:  # pylint: disable=broad-except
                 logging.error(f'[composer] failed to update init runner '
                               f'status, exception: {e}')
                 session.rollback()
示例#8
0
    def update_version(self) -> bool:
        """Bump the stored lock version by one if this locker is current.

        The bump is a single conditional UPDATE
        (``WHERE name = :name AND version = :version``). Two lockers
        holding the same version therefore cannot both succeed: only the
        one whose UPDATE matches a row wins. ``self._version`` itself is
        not advanced.

        Returns:
            True if the stored version was bumped. False if this locker is
            outdated, the lock row is missing, or the database update
            failed.
        """
        # double check
        if not self.is_latest_version():
            return False

        with get_session(self.db_engine) as session:
            try:
                matched = session.query(OptimisticLock).filter_by(
                    name=self._name, version=self._version).update(
                        {OptimisticLock.version: self._version + 1},
                        synchronize_session=False)
                session.commit()
            except Exception as e:  # pylint: disable=broad-except
                logging.error(f'failed to update lock version, exception: {e}')
                session.rollback()
                return False
            # no row matched: another locker bumped the version first
            return matched > 0
示例#9
0
    def is_latest_version(self) -> bool:
        """Check whether this locker's version equals the stored one.

        Returns:
            False if the lock was never acquired, the lock row does not
            exist, or the lookup raises. Otherwise, whether the stored
            version equals ``self._version``.
        """
        if not self._has_lock:
            return False

        with get_session(self.db_engine) as session:
            try:
                stored = session.query(OptimisticLock).filter_by(
                    name=self._name).first()
                if stored:
                    logging.info(f'[op_locker] version, current: {self._version}, '
                                 f'new: {stored.version}')
                    return self._version == stored.version
                return False
            except Exception as e:  # pylint: disable=broad-except
                logging.error(
                    f'failed to check lock is conflict, exception: {e}')
                return False
示例#10
0
 def try_lock(self) -> 'OpLocker':
     """Acquire the named optimistic lock, creating its row if needed.

     If a lock row with this name exists, its version becomes this
     locker's version. Otherwise a new row is inserted with the locker's
     current version. Either way ``_has_lock`` is set. On any error the
     failure is logged and ``_has_lock`` is left as it was.

     Returns:
         This locker, in every case; callers must inspect its state to
         learn whether the lock was obtained.
     """
     with get_session(self.db_engine) as session:
         try:
             existing = session.query(OptimisticLock).filter_by(
                 name=self._name).first()
             if existing:
                 self._has_lock = True
                 self._version = existing.version
             else:
                 session.add(
                     OptimisticLock(name=self._name, version=self._version))
                 session.commit()
                 self._has_lock = True
         except Exception as e:  # pylint: disable=broad-except
             logging.error(f'failed to require lock, exception: {e}')
     return self
示例#11
0
    def get_recent_runners(self,
                           name: str,
                           count: int = 10) -> List[SchedulerRunner]:
        """Get an item's most recent runners, newest first.

        Args:
            name: item name
            count: maximum number of runners to return

        Returns:
            A list of up to ``count`` runners of the item, ordered by
            ``created_at`` descending. The list is empty if there are none.
        """
        with get_session(self.db_engine) as session:
            # `.all()` runs the query while the session is open and yields a
            # real list (matching the annotation). Returning the bare Query
            # would defer execution until after the session block exits.
            return session.query(SchedulerRunner).join(
                SchedulerItem,
                SchedulerItem.id == SchedulerRunner.item_id).filter(
                    SchedulerItem.name == name).order_by(
                        SchedulerRunner.created_at.desc()).limit(count).all()
示例#12
0
    def finish(self, name: str):
        """Switch an enabled scheduler item OFF.

        Does nothing if no item named ``name`` is currently ON. A failed
        commit is logged and rolled back rather than raised.

        Args:
            name: item name
        """
        with get_session(self.db_engine) as session:
            active = session.query(SchedulerItem).filter_by(
                name=name, status=ItemStatus.ON.value).first()
            if active:
                active.status = ItemStatus.OFF.value
                try:
                    session.commit()
                except Exception as e:  # pylint: disable=broad-except
                    logging.error(
                        f'[composer] failed to finish scheduler_item, '
                        f'name: {name}, exception: {e}')
                    session.rollback()
示例#13
0
    def start(self, context: Context):
        with get_session(context.db_engine) as session:
            try:
                workflow: Workflow = session.query(Workflow).filter_by(
                    id=self._workflow_id).one()
                # TODO: This is a hack!!! Templatelly use this method
                # cc @hangweiqiang: Transaction State Refactor
                state = workflow.get_state_for_frontend()
                if state in ('COMPLETED', 'FAILED', 'READY', 'STOPPED', 'NEW'):
                    if state in ('COMPLETED', 'FAILED'):
                        workflow.update_target_state(
                            target_state=WorkflowState.STOPPED)
                        session.commit()
                        # check workflow stopped
                        # TODO: use composer timeout cc @yurunyu
                        for _ in range(24):
                            # use session refresh to get the latest info
                            # otherwise it'll use the indentity map locally
                            session.refresh(workflow)
                            if workflow.state == WorkflowState.STOPPED:
                                break
                            sleep(5)
                        else:
                            self._msg = f'failed to stop \
                                        workflow[{self._workflow_id}]'
                            return
                    workflow.update_target_state(
                        target_state=WorkflowState.RUNNING)
                    session.commit()
                    self._msg = f'restarted workflow[{self._workflow_id}]'
                elif state == 'RUNNING':
                    self._msg = f'skip restarting workflow[{self._workflow_id}]'
                elif state == 'INVALID':
                    self._msg = f'current workflow[{self._workflow_id}] \
                                 is invalid'
                else:
                    self._msg = f'workflow[{self._workflow_id}] \
                                state is {state}, which is out of expection'

            except Exception as err:  # pylint: disable=broad-except
                self._msg = f'exception of workflow[{self._workflow_id}], \
示例#14
0
    def result(self, context: Context) -> Tuple[RunnerStatus, dict]:
        """Mock runner result: fail on a random timeout, else record a count.

        After a fixed 2-second pause, a timeout of 0-10 seconds is drawn at
        random. If the runner has a start time and that timeout has already
        elapsed, the runner is reported as FAILED. Otherwise the number of
        ``SchedulerRunner`` rows is written into the context under
        ``is_done_<task_id>``, and the runner is reported as DONE.
        """
        time.sleep(2)
        current_time = datetime.datetime.utcnow()
        # mock timeout
        mocked_timeout = random.randint(0, 10)
        if self._start_at is not None:
            deadline = self._start_at + datetime.timedelta(
                seconds=mocked_timeout)
            if deadline < current_time:
                # kill runner
                logging.info(f'[memory_runner] {self.task_id} is timeout, '
                             f'start at: {self._start_at}')
                return RunnerStatus.FAILED, {}

        # use `get_session` to query database
        with get_session(context.db_engine) as session:
            runner_total = session.query(SchedulerRunner).count()
            # write data to context
            context.set_data(f'is_done_{self.task_id}', {
                'status': 'OK',
                'count': runner_total
            })
        return RunnerStatus.DONE, {}
示例#15
0
    def collect(self,
                name: str,
                items: List[IItem],
                metadata: dict,
                interval: int = -1):
        """Register a scheduler item unless one with that name exists.

        Args:
             name: item name, should be unique; an empty name is ignored
             items: the item's dependencies, built into its pipeline
             metadata: metadata shared among the item's dependencies
             interval: -1 for a run-once item, otherwise the period in
                seconds between runs (must then be at least 10)

        Raises:
            ValueError: if ``interval`` is neither -1 nor at least 10.
        """
        if len(name) == 0:
            return
        # an interval below 10 seconds is considered meaningless
        if interval != -1 and interval < 10:
            raise ValueError('interval should not less than 10 if not -1')
        with get_session(self.db_engine) as session:
            # an already-registered name is silently skipped
            if session.query(SchedulerItem).filter_by(name=name).first():
                return
            encoded_pipeline = PipelineEncoder().encode(
                self._build_pipeline(name, items, metadata))
            session.add(
                SchedulerItem(
                    name=name,
                    pipeline=encoded_pipeline,
                    interval_time=interval,
                ))
            try:
                session.commit()
            except Exception as e:  # pylint: disable=broad-except
                logging.error(f'[composer] failed to create scheduler_item, '
                              f'name: {name}, exception: {e}')
                session.rollback()
示例#16
0
    def _check_running_runners(self):
        """Advance every RUNNING scheduler runner by polling its current job.

        Runners are processed only while the thread reaper has capacity.
        For each one, the runner's optimistic lock is taken and the result
        of its ``current`` pipeline job is fetched:

        * RUNNING: this runner is skipped for now.
        * DONE: the job's output is recorded. The runner is marked DONE if
          that was the last job. Otherwise the next job becomes
          ``current`` and is enqueued on the thread reaper.
        * FAILED: the job's output is recorded and the runner is marked
          FAILED.

        The runner row is committed only if the lock is still at the latest
        version and its version bump succeeds. Otherwise the pending changes
        are rolled back, so a later commit in this same session cannot
        persist them. Finished runner objects are dropped from the runner
        cache only after a successful commit.
        """
        with get_session(self.db_engine) as session:
            running_runners = session.query(SchedulerRunner).filter_by(
                status=RunnerStatus.RUNNING.value).all()
            for runner in running_runners:
                if self.thread_reaper.is_full():
                    return
                lock_name = f'check_running_runner_{runner.id}_lock'
                check_lock = OpLocker(lock_name, self.db_engine).try_lock()
                # NOTE(review): try_lock() returns the locker itself even
                # when acquiring fails, so this branch only triggers if
                # OpLocker defines a falsy __bool__/__len__ -- confirm.
                if not check_lock:
                    logging.error(f'[composer] failed to lock, '
                                  f'ignore current running_runner_{runner.id}')
                    continue
                # TODO: restart runner if exit unexpectedly
                pipeline = Pipeline(**(json.loads(runner.pipeline)))
                output = json.loads(runner.output)
                context = decode_context(val=runner.context,
                                         db_engine=self.db_engine)
                current = context.internal['current']
                runner_fn = self.runner_cache.find_runner(runner.id, current)
                # check status of current one
                status, current_output = runner_fn.result(context)
                if status == RunnerStatus.RUNNING:
                    continue  # ignore
                if status == RunnerStatus.DONE:
                    output[current] = {'status': RunnerStatus.DONE.value}
                    context.set_internal(f'output_{current}', current_output)
                    current_idx = pipeline.deps.index(current)
                    if current_idx == len(pipeline.deps) - 1:  # all done
                        runner.status = RunnerStatus.DONE.value
                        runner.end_at = func.now()
                    else:  # run next one
                        next_one = pipeline.deps[current_idx + 1]
                        output[next_one] = {
                            'status': RunnerStatus.RUNNING.value
                        }
                        context.set_internal('current', next_one)
                        # NOTE(review): enqueued before we know whether the
                        # DB update below will be committed.
                        next_runner_fn = self.runner_cache.find_runner(
                            runner.id, next_one)
                        self.thread_reaper.enqueue(name=lock_name,
                                                   fn=next_runner_fn,
                                                   context=context)
                elif status == RunnerStatus.FAILED:
                    # TODO: abort now, need retry
                    output[current] = {'status': RunnerStatus.FAILED.value}
                    context.set_internal(f'output_{current}', current_output)
                    runner.status = RunnerStatus.FAILED.value
                    runner.end_at = func.now()

                runner.pipeline = PipelineEncoder().encode(pipeline)
                runner.output = json.dumps(output)
                runner.context = ContextEncoder().encode(context)

                updated_db = False
                try:
                    logging.info(
                        f'[composer] update runner, status: {runner.status}, '
                        f'pipeline: {runner.pipeline}, '
                        f'output: {output}, context: {runner.context}')
                    if check_lock.is_latest_version() and \
                            check_lock.update_version():
                        session.commit()
                        updated_db = True
                    else:
                        logging.error(f'[composer] {lock_name} is outdated, '
                                      f'ignore updates to database')
                        # Discard this runner's pending changes. Otherwise
                        # the next successful commit in this session would
                        # write them despite the outdated lock.
                        session.rollback()
                except Exception as e:  # pylint: disable=broad-except
                    logging.error(f'[composer] failed to update running '
                                  f'runner status, exception: {e}')
                    session.rollback()

                # delete useless runner obj in runner cache
                if status in (RunnerStatus.DONE,
                              RunnerStatus.FAILED) and updated_db:
                    self.runner_cache.del_runner(runner.id, current)