예제 #1
0
파일: models.py 프로젝트: guotie/fedlearner
    def _setup_jobs(self):
        """Materialise the jobs described by this workflow's config.

        Jobs named in ``get_reuse_job_names()`` are looked up in the workflow
        this one was forked from; every other job definition gets a new
        ``Job`` row. After the session flush, ``JobDependency`` rows are added
        for the newly created jobs and the resulting job ids are stored via
        ``set_job_ids``.

        Raises:
            AssertionError: if the source workflow or a reused job cannot be
                found, or if reuse is requested without ``forked_from``.
        """
        if self.forked_from is None:
            # Nothing to fork from, so no job may be marked for reuse.
            assert not self.get_reuse_job_names()
        else:
            base_workflow = Workflow.query.get(self.forked_from)
            assert base_workflow is not None, \
                'Source workflow %d not found'%self.forked_from
            # Position of each job definition inside the base workflow; the
            # same position is used to index the base workflow's job ids.
            base_index_by_name = {}
            for position, base_def in enumerate(
                    base_workflow.get_config().job_definitions):
                base_index_by_name[base_def.name] = position

        definitions = self.get_config().job_definitions
        workflow_jobs = []
        reused_names = set(self.get_reuse_job_names())
        for definition in definitions:
            if definition.name not in reused_names:
                # Not reused: create a brand-new job for this definition.
                new_job = Job(name=f'{self.uuid}-{definition.name}',
                              job_type=JobType(definition.type),
                              config=definition.SerializeToString(),
                              workflow_id=self.id,
                              project_id=self.project_id,
                              state=JobState.STOPPED)
                new_job.set_yaml_template(definition.yaml_template)
                db.session.add(new_job)
                workflow_jobs.append(new_job)
                continue

            # Reused: fetch the matching job of the base workflow.
            assert definition.name in base_index_by_name, \
                "Job %s not found in base workflow"%definition.name
            base_job_id = \
                base_workflow.get_job_ids()[base_index_by_name[definition.name]]
            reused_job = Job.query.get(base_job_id)
            assert reused_job is not None, \
                'Job %d not found'%base_job_id
            # TODO: check forked jobs does not depend on non-forked jobs
            workflow_jobs.append(reused_job)
        db.session.flush()

        position_by_name = {
            definition.name: position
            for position, definition in enumerate(definitions)
        }
        for job in workflow_jobs:
            # Reused jobs are skipped; only new jobs get dependency rows.
            if job.get_config().name in reused_names:
                continue
            dependencies = job.get_config().dependencies
            for dep_index, dep_def in enumerate(dependencies):
                source_job = workflow_jobs[position_by_name[dep_def.source]]
                db.session.add(JobDependency(
                    src_job_id=source_job.id,
                    dst_job_id=job.id,
                    dep_index=dep_index))

        self.set_job_ids([job.id for job in workflow_jobs])
예제 #2
0
    def _setup_jobs(self):
        """Create or reuse the jobs of this workflow and link dependencies.

        For every job definition in the workflow config, the matching entry
        of ``get_create_job_flags()`` decides what happens:

        * ``REUSE``: the job with the same definition name is fetched from
          the workflow referenced by ``forked_from``.
        * anything else: a new ``Job`` is added to the session, marked
          disabled when the flag is ``DISABLED``.

        ``JobDependency`` rows are then added for all non-reused jobs, the
        job ids are stored with ``set_job_ids``, and, when
        ``Features.FEATURE_MODEL_WORKFLOW_HOOK`` is set,
        ``ModelService.workflow_hook`` is called for each job.

        Raises:
            AssertionError: if the source workflow is missing, the number of
                flags differs from the number of job definitions, or a job
                flagged ``REUSE`` cannot be found in the base workflow
                (including the case where this workflow is not a fork).
        """
        # Bind these up front so that a REUSE flag on a non-forked workflow
        # fails the "not found in base workflow" assertion below instead of
        # crashing with UnboundLocalError.
        trunk = None
        trunk_name2index = {}
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found' % self.forked_from
            trunk_job_defs = trunk.get_config().job_definitions
            # Definition name -> position; the same position indexes the
            # base workflow's job ids.
            trunk_name2index = {
                job.name: i
                for i, job in enumerate(trunk_job_defs)
            }

        job_defs = self.get_config().job_definitions
        flags = self.get_create_job_flags()
        assert len(job_defs) == len(flags), \
            'Number of job defs does not match number of create_job_flags ' \
            '%d vs %d'%(len(job_defs), len(flags))
        jobs = []
        for job_def, flag in zip(job_defs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                assert job_def.name in trunk_name2index, \
                    f'Job {job_def.name} not found in base workflow'
                j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found' % j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(
                    name=f'{self.uuid}-{job_def.name}',
                    job_type=JobType(job_def.job_type),
                    config=job_def.SerializeToString(),
                    workflow_id=self.id,
                    project_id=self.project_id,
                    state=JobState.NEW,
                    is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
                db.session.add(job)
            jobs.append(job)
        # Presumably flushed so the new Job rows have ids before they are
        # referenced by the dependency rows below - TODO confirm.
        db.session.flush()
        name2index = {job.name: i for i, job in enumerate(job_defs)}
        for job, flag in zip(jobs, flags):
            # Reused jobs are skipped; only new jobs get dependency rows.
            if flag == common_pb2.CreateJobFlag.REUSE:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])
        if Features.FEATURE_MODEL_WORKFLOW_HOOK:
            for job in jobs:
                ModelService(db.session).workflow_hook(job)