Example #1
0
class Workflow(db.Model):
    """A workflow row in ``workflow_v2`` plus its state-machine logic.

    State changes go through a transaction protocol between a coordinator
    and a participant (``update_state`` -> ``prepare`` -> ``commit`` or
    ``rollback``). Committing a transition to READY creates the workflow's
    ``Job`` rows; RUNNING schedules its non-manual jobs and STOPPED stops
    its jobs.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id))
    # Serialized workflow_definition_pb2.WorkflowDefinition; access it via
    # get_config()/set_config().
    config = db.Column(db.Text())
    comment = db.Column(db.String(255))

    forkable = db.Column(db.Boolean, default=False)
    # id of the base Workflow this one was forked from (None if not forked).
    forked_from = db.Column(db.Integer, default=None)
    # Comma-separated job *definition names* (not Job ids) that are reused
    # from the base workflow, for the local party and for the peer party.
    reuse_job_names = db.Column(db.TEXT())
    peer_reuse_job_names = db.Column(db.TEXT())
    # Serialized WorkflowDefinition proposed for a fork; cleared on READY.
    fork_proposal_config = db.Column(db.TEXT())

    recur_type = db.Column(db.Enum(RecurType), default=RecurType.NONE)
    recur_at = db.Column(db.Interval)
    trigger_dataset = db.Column(db.Integer)
    last_triggered_batch = db.Column(db.Integer)

    # Comma-separated Job ids, in the same order as config.job_definitions.
    job_ids = db.Column(db.TEXT())

    state = db.Column(db.Enum(WorkflowState), default=WorkflowState.INVALID)
    # Pending state of an in-flight transaction; INVALID when there is none.
    target_state = db.Column(db.Enum(WorkflowState),
                             default=WorkflowState.INVALID)
    transaction_state = db.Column(db.Enum(TransactionState),
                                  default=TransactionState.READY)
    transaction_err = db.Column(db.Text())

    # Unix timestamps (whole seconds) of the last start / stop, set in
    # commit().
    start_at = db.Column(db.Integer)
    stop_at = db.Column(db.Integer)

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())

    owned_jobs = db.relationship('Job', back_populates='workflow')
    project = db.relationship(Project)

    def set_config(self, proto):
        """Store ``proto`` (a WorkflowDefinition or None) serialized."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the parsed WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def set_fork_proposal_config(self, proto):
        """Store the fork proposal (a WorkflowDefinition or None)."""
        if proto is not None:
            self.fork_proposal_config = proto.SerializeToString()
        else:
            self.fork_proposal_config = None

    def get_fork_proposal_config(self):
        """Return the parsed fork proposal, or None if unset."""
        if self.fork_proposal_config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.fork_proposal_config)
            return proto
        return None

    def set_job_ids(self, job_ids):
        """Store ``job_ids`` (iterable of ints) as a comma-separated string."""
        self.job_ids = ','.join([str(i) for i in job_ids])

    def get_job_ids(self):
        """Return the stored Job ids as a list of ints ([] when unset)."""
        if not self.job_ids:
            return []
        return [int(i) for i in self.job_ids.split(',')]

    def get_jobs(self):
        """Return the Job objects for get_job_ids(), in the same order."""
        return [Job.query.get(i) for i in self.get_job_ids()]

    def set_reuse_job_names(self, reuse_job_names):
        """Store the job definition names this party reuses from the base."""
        self.reuse_job_names = ','.join(reuse_job_names)

    def get_reuse_job_names(self):
        """Return the reused job definition names ([] when unset)."""
        if not self.reuse_job_names:
            return []
        return self.reuse_job_names.split(',')

    def set_peer_reuse_job_names(self, peer_reuse_job_names):
        """Store the job definition names the peer reuses from its base."""
        self.peer_reuse_job_names = ','.join(peer_reuse_job_names)

    def get_peer_reuse_job_names(self):
        """Return the peer's reused job definition names ([] when unset)."""
        if not self.peer_reuse_job_names:
            return []
        return self.peer_reuse_job_names.split(',')

    def update_target_state(self, target_state):
        """Validate ``target_state`` and record it as the pending transition.

        Raises:
            ValueError: if a transaction towards a different target is
                already in progress, if ``target_state`` is not one of
                READY/RUNNING/STOPPED, or if ``(state, target_state)`` is
                not in VALID_TRANSITIONS.
        """
        if self.target_state != target_state \
                and self.target_state != WorkflowState.INVALID:
            raise ValueError(f'Another transaction is in progress [{self.id}]')
        if target_state not in [
                WorkflowState.READY, WorkflowState.RUNNING,
                WorkflowState.STOPPED
        ]:
            # Report the rejected argument, not the unchanged current target.
            raise ValueError(f'Invalid target_state {target_state}')
        if (self.state, target_state) not in VALID_TRANSITIONS:
            raise ValueError(
                f'Invalid transition from {self.state} to {target_state}')
        self.target_state = target_state

    def update_state(self, asserted_state, target_state, transaction_state):
        """Move the transaction to ``transaction_state`` and run its effects.

        Args:
            asserted_state: expected current ``state``; None skips the check.
            target_state: desired WorkflowState, forwarded to ``prepare``.
            transaction_state: requested new TransactionState.

        Returns:
            The resulting ``transaction_state`` (which prepare/rollback/commit
            may have advanced further).
        """
        assert asserted_state is None or self.state == asserted_state, \
            'Cannot change current state directly'

        if transaction_state != self.transaction_state:
            # Some transitions are silently ignored rather than rejected.
            if (self.transaction_state, transaction_state) in \
                    IGNORED_TRANSACTION_TRANSITIONS:
                return self.transaction_state
            assert (self.transaction_state, transaction_state) in \
                   VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from {} to {}'.format(
                    self.transaction_state, transaction_state)
            self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            self.rollback()

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            self.rollback()
            self.transaction_state = TransactionState.ABORTED
        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self, target_state):
        """Prepare phase: validate ``target_state`` and mark committable.

        On a validation error the transaction moves to ABORTED. If the
        READY-specific preparation fails, the transaction state is left
        unchanged.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
            'Workflow not in prepare state'

        # TODO(tjulinfan): remove this
        if target_state is None:
            # No action
            return

        # Validation
        try:
            self.update_target_state(target_state)
        except ValueError as e:
            logging.warning('Error during update target state in prepare: %s',
                            str(e))
            self.transaction_state = TransactionState.ABORTED
            return

        success = True
        if self.target_state == WorkflowState.READY:
            success = self._prepare_for_ready()

        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Abort the pending transition by clearing ``target_state``."""
        self.target_state = WorkflowState.INVALID

    @staticmethod
    def _now_timestamp():
        """Return the current Unix time in whole seconds.

        ``datetime.utcnow().timestamp()`` is wrong on hosts whose local zone
        is not UTC, because ``timestamp()`` interprets a naive datetime as
        local time; an aware datetime avoids that.
        """
        from datetime import timezone  # pylint: disable=import-outside-toplevel
        return int(datetime.now(timezone.utc).timestamp())

    # TODO: separate this method to another module
    def commit(self):
        """Commit phase: apply ``target_state`` and reset the transaction.

        STOPPED stops every owned job, READY creates/reuses the jobs
        (see ``_setup_jobs``), and RUNNING schedules every owned job whose
        config is not ``is_manual``.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
                'Workflow not in commit state'

        if self.target_state == WorkflowState.STOPPED:
            self.stop_at = self._now_timestamp()
            for job in self.owned_jobs:
                job.stop()
        elif self.target_state == WorkflowState.READY:
            self._setup_jobs()
            self.fork_proposal_config = None
        elif self.target_state == WorkflowState.RUNNING:
            self.start_at = self._now_timestamp()
            for job in self.owned_jobs:
                if not job.get_config().is_manual:
                    job.schedule()

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def _setup_jobs(self):
        """Create or reuse one Job per job definition and wire dependencies.

        Definitions named in ``reuse_job_names`` take the Job of the same
        definition name from the base workflow (``forked_from``). All other
        definitions get a new Job in STOPPED state, and JobDependency rows
        are added for those new jobs only. ``job_ids`` is then set in
        ``config.job_definitions`` order and the session is committed.
        """
        trunk = None
        trunk_name2index = {}
        trunk_job_ids = []
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found' % self.forked_from
            trunk_name2index = {
                trunk_def.name: i
                for i, trunk_def in enumerate(
                    trunk.get_config().job_definitions)
            }
            # Parsed once here instead of once per reused job.
            trunk_job_ids = trunk.get_job_ids()
        else:
            assert not self.get_reuse_job_names()

        job_defs = self.get_config().job_definitions
        reuse_jobs = set(self.get_reuse_job_names())
        jobs = []
        for job_def in job_defs:
            if job_def.name in reuse_jobs:
                assert job_def.name in trunk_name2index, \
                    "Job %s not found in base workflow" % job_def.name
                j = trunk_job_ids[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found' % j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(name=f'{self.name}-{job_def.name}',
                          job_type=JobType(job_def.type),
                          config=job_def.SerializeToString(),
                          workflow_id=self.id,
                          project_id=self.project_id,
                          state=JobState.STOPPED)
                job.set_yaml_template(job_def.yaml_template)
                db.session.add(job)
            jobs.append(job)
        # Flush (not commit) so the new jobs get primary keys, while jobs and
        # their dependencies are still committed atomically below.
        db.session.flush()

        name2index = {job_def.name: i for i, job_def in enumerate(job_defs)}
        for job_def, job in zip(job_defs, jobs):
            # Reused jobs already carry their dependencies from the base
            # workflow. Compare the *definition* name: job.name has a
            # workflow-name prefix and would never match reuse_jobs.
            if job_def.name in reuse_jobs:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])

        db.session.commit()

    def log_states(self):
        """Log state, target_state and transaction_state at DEBUG level."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)

    def _get_peer_workflow(self):
        """Fetch the peer's copy of this workflow (by name) over RPC."""
        project_config = self.project.get_config()
        # TODO: find coordinator for multiparty
        client = RpcClient(project_config, project_config.participants[0])
        return client.get_workflow(self.name)

    def _prepare_for_ready(self):
        """Return whether this workflow can commit a transition to READY.

        The coordinator only requires a config. A participant additionally
        adopts the peer's fork: it looks up the base workflow by the peer's
        ``forked_from`` name, swaps the two reuse lists, and merges the
        peer's PEER_WRITABLE variables into the base config.
        """
        # This is a hack, if config is not set then
        # no action needed
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            # TODO(tjulinfan): validate if the config is legal or not
            return bool(self.config)
        peer_workflow = self._get_peer_workflow()
        if peer_workflow.forked_from:
            base_workflow = Workflow.query.filter(
                Workflow.name == peer_workflow.forked_from).first()
            if base_workflow is None or not base_workflow.forkable:
                return False
            self.forked_from = base_workflow.id
            self.forkable = base_workflow.forkable
            # What the peer calls "peer reuse" is what this side reuses.
            self.set_reuse_job_names(peer_workflow.peer_reuse_job_names)
            self.set_peer_reuse_job_names(peer_workflow.reuse_job_names)
            config = base_workflow.get_config()
            _merge_workflow_config(config, peer_workflow.fork_proposal_config,
                                   [common_pb2.Variable.PEER_WRITABLE])
            self.set_config(config)
            return True
        return bool(self.config)
Example #2
0
class Workflow(db.Model):
    """A workflow row in ``workflow_v2`` plus its state-machine logic.

    State changes go through a transaction protocol between a coordinator
    and a participant (``update_state`` -> ``prepare`` -> ``commit`` or
    ``rollback``). Committing a transition to READY creates the workflow's
    ``Job`` rows according to the per-job create flags.
    """
    __tablename__ = 'workflow_v2'
    __table_args__ = (UniqueConstraint('uuid', name='uniq_uuid'),
                      UniqueConstraint('name', name='uniq_name'), {
                          'comment': 'workflow_v2',
                          'mysql_engine': 'innodb',
                          'mysql_charset': 'utf8mb4',
                      })
    id = db.Column(db.Integer, primary_key=True, comment='id')
    uuid = db.Column(db.String(64), comment='uuid')
    name = db.Column(db.String(255), comment='name')
    project_id = db.Column(db.Integer, comment='project_id')
    # Serialized workflow_definition_pb2.WorkflowDefinition.
    # max store 16777215 bytes (16 MB)
    config = db.Column(db.LargeBinary(16777215), comment='config')
    comment = db.Column('cmt',
                        db.String(255),
                        key='comment',
                        comment='comment')

    metric_is_public = db.Column(db.Boolean(),
                                 default=False,
                                 nullable=False,
                                 comment='metric_is_public')
    # Comma-separated common_pb2.CreateJobFlag values, one per entry of
    # config.job_definitions (by index); None means "all NEW".
    create_job_flags = db.Column(db.TEXT(), comment='create_job_flags')

    # Comma-separated Job ids, in the same order as config.job_definitions.
    job_ids = db.Column(db.TEXT(), comment='job_ids')

    forkable = db.Column(db.Boolean, default=False, comment='forkable')
    # id of the base Workflow this one was forked from (None if not forked).
    forked_from = db.Column(db.Integer, default=None, comment='forked_from')
    # Peer's create flags, indexed like config.job_definitions (not Job ids).
    peer_create_job_flags = db.Column(db.TEXT(),
                                      comment='peer_create_job_flags')
    # Serialized WorkflowDefinition proposed for a fork; cleared on READY.
    # max store 16777215 bytes (16 MB)
    fork_proposal_config = db.Column(db.LargeBinary(16777215),
                                     comment='fork_proposal_config')

    recur_type = db.Column(db.Enum(RecurType, native_enum=False),
                           default=RecurType.NONE,
                           comment='recur_type')
    recur_at = db.Column(db.Interval, comment='recur_at')
    trigger_dataset = db.Column(db.Integer, comment='trigger_dataset')
    last_triggered_batch = db.Column(db.Integer,
                                     comment='last_triggered_batch')

    state = db.Column(db.Enum(WorkflowState,
                              native_enum=False,
                              name='workflow_state'),
                      default=WorkflowState.INVALID,
                      comment='state')
    # Pending state of an in-flight transaction; INVALID when there is none.
    target_state = db.Column(db.Enum(WorkflowState,
                                     native_enum=False,
                                     name='workflow_target_state'),
                             default=WorkflowState.INVALID,
                             comment='target_state')
    transaction_state = db.Column(db.Enum(TransactionState, native_enum=False),
                                  default=TransactionState.READY,
                                  comment='transaction_state')
    transaction_err = db.Column(db.Text(), comment='transaction_err')

    # Unix timestamps (whole seconds) of the last start / stop, set in
    # commit().
    start_at = db.Column(db.Integer, comment='start_at')
    stop_at = db.Column(db.Integer, comment='stop_at')

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           comment='created_at')
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now(),
                           comment='update_at')

    owned_jobs = db.relationship(
        'Job', primaryjoin='foreign(Job.workflow_id) == Workflow.id')
    project = db.relationship(
        'Project', primaryjoin='Project.id == foreign(Workflow.project_id)')

    def get_state_for_frontend(self):
        """Return the state name shown in the UI.

        A RUNNING workflow is reported as 'COMPLETED' when every owned job is
        complete (including when it owns no jobs), or 'FAILED' when any owned
        job failed; otherwise the raw state name is returned.
        """
        if self.state == WorkflowState.RUNNING:
            jobs = self.owned_jobs
            # Generators let all()/any() stop at the first decisive job.
            if all(job.is_complete() for job in jobs):
                return 'COMPLETED'
            if any(job.is_failed() for job in jobs):
                return 'FAILED'
        return self.state.name

    def get_transaction_state_for_frontend(self):
        """Return the transaction state name shown in the UI."""
        # TODO(xiangyuxuan): remove this hack by redesign 2pc
        if (self.transaction_state == TransactionState.PARTICIPANT_PREPARE
                and self.config is not None):
            return 'PARTICIPANT_COMMITTABLE'
        return self.transaction_state.name

    def set_config(self, proto):
        """Store ``proto`` (a WorkflowDefinition or None) serialized."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the parsed WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def set_fork_proposal_config(self, proto):
        """Store the fork proposal (a WorkflowDefinition or None)."""
        if proto is not None:
            self.fork_proposal_config = proto.SerializeToString()
        else:
            self.fork_proposal_config = None

    def get_fork_proposal_config(self):
        """Return the parsed fork proposal, or None if unset."""
        if self.fork_proposal_config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.fork_proposal_config)
            return proto
        return None

    def set_job_ids(self, job_ids):
        """Store ``job_ids`` (iterable of ints) as a comma-separated string."""
        self.job_ids = ','.join([str(i) for i in job_ids])

    def get_job_ids(self):
        """Return the stored Job ids as a list of ints ([] when unset)."""
        if not self.job_ids:
            return []
        return [int(i) for i in self.job_ids.split(',')]

    def get_jobs(self):
        """Return the Job objects for get_job_ids(), in the same order."""
        return [Job.query.get(i) for i in self.get_job_ids()]

    def set_create_job_flags(self, create_job_flags):
        """Store one CreateJobFlag per job definition.

        None or an empty sequence (e.g. an unset repeated proto field) is
        stored as None, meaning "default every job to NEW". This matches
        set_peer_create_job_flags. Storing '' would make
        get_create_job_flags fail on int('').
        """
        if not create_job_flags:
            self.create_job_flags = None
        else:
            self.create_job_flags = ','.join(
                [str(i) for i in create_job_flags])

    def get_create_job_flags(self):
        """Return one CreateJobFlag (int) per job definition.

        When no flags are stored, every job defaults to NEW; returns None if
        there is no config either.
        """
        # `not` also covers a legacy empty string, which int() would reject.
        if not self.create_job_flags:
            config = self.get_config()
            if config is None:
                return None
            num_jobs = len(config.job_definitions)
            return [common_pb2.CreateJobFlag.NEW] * num_jobs
        return [int(i) for i in self.create_job_flags.split(',')]

    def set_peer_create_job_flags(self, peer_create_job_flags):
        """Store the peer's create flags; empty/None is stored as None."""
        if not peer_create_job_flags:
            self.peer_create_job_flags = None
        else:
            self.peer_create_job_flags = ','.join(
                [str(i) for i in peer_create_job_flags])

    def get_peer_create_job_flags(self):
        """Return the peer's create flags as ints, or None when unset."""
        # `not` also covers a legacy empty string, which int() would reject.
        if not self.peer_create_job_flags:
            return None
        return [int(i) for i in self.peer_create_job_flags.split(',')]

    def update_target_state(self, target_state):
        """Validate ``target_state`` and record it as the pending transition.

        Raises:
            ValueError: if a transaction towards a different target is
                already in progress, if ``target_state`` is not one of
                READY/RUNNING/STOPPED, or if ``(state, target_state)`` is
                not in VALID_TRANSITIONS.
        """
        if self.target_state != target_state \
            and self.target_state != WorkflowState.INVALID:
            raise ValueError(f'Another transaction is in progress [{self.id}]')
        if target_state not in [
                WorkflowState.READY, WorkflowState.RUNNING,
                WorkflowState.STOPPED
        ]:
            # Report the rejected argument, not the unchanged current target.
            raise ValueError(f'Invalid target_state {target_state}')
        if (self.state, target_state) not in VALID_TRANSITIONS:
            raise ValueError(
                f'Invalid transition from {self.state} to {target_state}')

        self.target_state = target_state

    def update_state(self, asserted_state, target_state, transaction_state):
        """Move the transaction to ``transaction_state`` and run its effects.

        Args:
            asserted_state: expected current ``state``; None skips the check.
            target_state: desired WorkflowState, forwarded to ``prepare``.
            transaction_state: requested new TransactionState.

        Returns:
            The resulting ``transaction_state`` (which prepare/rollback/commit
            may have advanced further).
        """
        assert asserted_state is None or self.state == asserted_state, \
            'Cannot change current state directly'

        if transaction_state != self.transaction_state:
            # Some transitions are silently ignored rather than rejected.
            if (self.transaction_state, transaction_state) in \
                IGNORED_TRANSACTION_TRANSITIONS:
                return self.transaction_state
            assert (self.transaction_state, transaction_state) in \
                   VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from {} to {}'.format(
                    self.transaction_state, transaction_state)
            self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            self.rollback()

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            self.rollback()
            self.transaction_state = TransactionState.ABORTED
        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self, target_state):
        """Prepare phase: validate ``target_state`` and mark committable.

        On a validation error the transaction moves to ABORTED. If the
        READY-specific preparation fails, the transaction state is left
        unchanged.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
            'Workflow not in prepare state'

        # TODO(tjulinfan): remove this
        if target_state is None:
            # No action
            return

        # Validation
        try:
            self.update_target_state(target_state)
        except ValueError as e:
            logging.warning('Error during update target state in prepare: %s',
                            str(e))
            self.transaction_state = TransactionState.ABORTED
            return

        success = True
        if self.target_state == WorkflowState.READY:
            success = self._prepare_for_ready()

        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Abort the pending transition by clearing ``target_state``."""
        self.target_state = WorkflowState.INVALID

    @staticmethod
    def _now_timestamp():
        """Return the current Unix time in whole seconds.

        ``datetime.utcnow().timestamp()`` is wrong on hosts whose local zone
        is not UTC, because ``timestamp()`` interprets a naive datetime as
        local time; an aware datetime avoids that.
        """
        from datetime import timezone  # pylint: disable=import-outside-toplevel
        return int(datetime.now(timezone.utc).timestamp())

    # TODO: separate this method to another module
    def commit(self):
        """Commit phase: apply ``target_state`` and reset the transaction.

        STOPPED stops every owned job. If a job raises RuntimeError, the
        error is logged and the method returns early, leaving state,
        target_state and transaction_state unchanged (stop_at is already
        updated). READY creates/reuses the jobs (see ``_setup_jobs``).
        RUNNING schedules every owned job that is not disabled.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
            'Workflow not in commit state'

        if self.target_state == WorkflowState.STOPPED:
            self.stop_at = self._now_timestamp()
            try:
                for job in self.owned_jobs:
                    job.stop()
            except RuntimeError as e:
                # errors from k8s
                logging.error('Stop workflow %d has Runtime error msg: %s',
                              self.id, e.args)
                return
        elif self.target_state == WorkflowState.READY:
            self._setup_jobs()
            self.fork_proposal_config = None
        elif self.target_state == WorkflowState.RUNNING:
            self.start_at = self._now_timestamp()
            for job in self.owned_jobs:
                if not job.is_disabled:
                    job.schedule()

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def invalidate(self):
        """Force the workflow to INVALID and stop its jobs (best effort).

        Errors from individual job.stop() calls are logged and ignored.
        """
        self.state = WorkflowState.INVALID
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY
        for job in self.owned_jobs:
            try:
                job.stop()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning(
                    'Error while stopping job %s during invalidation: %s',
                    job.name, repr(e))

    def _setup_jobs(self):
        """Create or reuse one Job per job definition and wire dependencies.

        Each definition is paired with its create flag. REUSE takes the Job
        of the same definition name from the base workflow (``forked_from``);
        any other flag creates a new STOPPED Job (disabled when the flag is
        DISABLED). JobDependency rows are added for the new jobs only, and
        ``job_ids`` is set in ``config.job_definitions`` order. The session is
        flushed but not committed.
        """
        trunk = None
        trunk_name2index = {}
        trunk_job_ids = []
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found' % self.forked_from
            trunk_name2index = {
                trunk_def.name: i
                for i, trunk_def in enumerate(
                    trunk.get_config().job_definitions)
            }
            # Parsed once here instead of once per reused job.
            trunk_job_ids = trunk.get_job_ids()

        job_defs = self.get_config().job_definitions
        flags = self.get_create_job_flags()
        assert len(job_defs) == len(flags), \
            'Number of job defs does not match number of create_job_flags ' \
            '%d vs %d'%(len(job_defs), len(flags))
        jobs = []
        for job_def, flag in zip(job_defs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                # Without a base workflow there is nothing to reuse from.
                assert trunk is not None, \
                    'Job %s is marked REUSE but workflow has no base ' \
                    'workflow' % job_def.name
                assert job_def.name in trunk_name2index, \
                    "Job %s not found in base workflow" % job_def.name
                j = trunk_job_ids[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found' % j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(
                    name=f'{self.uuid}-{job_def.name}',
                    job_type=JobType(job_def.job_type),
                    config=job_def.SerializeToString(),
                    workflow_id=self.id,
                    project_id=self.project_id,
                    state=JobState.STOPPED,
                    is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
                job.set_yaml_template(job_def.yaml_template)
                db.session.add(job)
            jobs.append(job)
        # Flush so the new jobs get primary keys for the dependency rows.
        db.session.flush()
        name2index = {job_def.name: i for i, job_def in enumerate(job_defs)}
        for job, flag in zip(jobs, flags):
            # Reused jobs already carry their dependencies from the base.
            if flag == common_pb2.CreateJobFlag.REUSE:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])

    def log_states(self):
        """Log state, target_state and transaction_state at DEBUG level."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)

    def _get_peer_workflow(self):
        """Fetch the peer's copy of this workflow (by name) over RPC."""
        project_config = self.project.get_config()
        # TODO: find coordinator for multiparty
        client = RpcClient(project_config, project_config.participants[0])
        return client.get_workflow(self.name)

    def _prepare_for_ready(self):
        """Return whether this workflow can commit a transition to READY.

        The coordinator only requires a config. A forked participant also
        swaps in the peer's create flags (the peer's "peer" flags are this
        side's own), and merges the peer's PEER_WRITABLE variables from its
        fork proposal into the base workflow's config.
        """
        # This is a hack, if config is not set then
        # no action needed
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            # TODO(tjulinfan): validate if the config is legal or not
            return bool(self.config)
        if self.forked_from:
            peer_workflow = self._get_peer_workflow()
            base_workflow = Workflow.query.get(self.forked_from)
            if base_workflow is None or not base_workflow.forkable:
                return False
            self.forked_from = base_workflow.id
            self.forkable = base_workflow.forkable
            self.set_create_job_flags(peer_workflow.peer_create_job_flags)
            self.set_peer_create_job_flags(peer_workflow.create_job_flags)
            config = base_workflow.get_config()
            _merge_workflow_config(config, peer_workflow.fork_proposal_config,
                                   [common_pb2.Variable.PEER_WRITABLE])
            self.set_config(config)
            return True
        return bool(self.config)