コード例 #1
0
class Project(db.Model):
    """Project row of the `projects_v2` table.

    `config` and `certificate` hold protobuf messages in serialized form;
    use the `set_*` / `get_*` helpers to convert to and from proto objects.
    """
    __tablename__ = 'projects_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), index=True, unique=True)
    token = db.Column(db.String(64), index=True)
    config = db.Column(db.Text())
    certificate = db.Column(db.Text())
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # `server_onupdate` only tells SQLAlchemy that the database generates a
    # value on UPDATE; it does not generate one by itself. `onupdate` makes
    # SQLAlchemy include `now()` in every UPDATE statement it emits.
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed `project_pb2.Project`, or None if unset."""
        if self.config is None:
            # ParseFromString(None) would raise; an unset column means
            # "no config".
            return None
        proto = project_pb2.Project()
        proto.ParseFromString(self.config)
        return proto

    def set_certificate(self, proto):
        """Store `proto` in the `certificate` column in serialized form."""
        self.certificate = proto.SerializeToString()

    def get_certificate(self):
        """Return the parsed `project_pb2.CertificateStorage`, or None."""
        if self.certificate is None:
            return None
        proto = project_pb2.CertificateStorage()
        proto.ParseFromString(self.certificate)
        return proto
コード例 #2
0
ファイル: models.py プロジェクト: nolanliou/fedlearner
class DataBatch(db.Model):
    """Batch row of the `data_batches_v2` table.

    A row is identified by the composite primary key
    (`event_time`, `dataset_id`). `source` holds a serialized
    `dataset_pb2.DatasetSource` message.
    """
    __tablename__ = 'data_batches_v2'
    __table_args__ = (PrimaryKeyConstraint('event_time', 'dataset_id'), )
    event_time = db.Column(db.TIMESTAMP(timezone=True))
    dataset_id = db.Column(db.Integer, db.ForeignKey(Dataset.id))
    state = db.Column(db.Enum(BatchState))
    # Defaults to an empty serialized DatasetSource message.
    source = db.Column(db.Text(),
                       default=dataset_pb2.DatasetSource().SerializeToString())
    failed_source = db.Column(db.Text())
    file_size = db.Column(db.Integer, default=0)
    imported_file_num = db.Column(db.Integer, default=0)
    num_file = db.Column(db.Integer, default=0)
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # `server_onupdate` only declares that the database produces a value on
    # UPDATE without producing one; `onupdate` makes SQLAlchemy send `now()`
    # with every UPDATE so the timestamp is actually refreshed.
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    dataset = db.relationship('Dataset', back_populates='data_batches')

    def set_source(self, proto):
        """Store `proto` in the `source` column in serialized form."""
        self.source = proto.SerializeToString()

    def get_source(self):
        """Return the parsed `dataset_pb2.DatasetSource`, or None if unset."""
        if self.source is None:
            return None
        proto = dataset_pb2.DatasetSource()
        proto.ParseFromString(self.source)
        return proto
コード例 #3
0
class DataBatch(db.Model):
    """Batch row of the `data_batches_v2` table.

    (`event_time`, `dataset_id`) is unique per row. The per-file details are
    kept as a serialized `dataset_pb2.DataBatch` message in `details`, and
    the aggregate columns are derived from it by `set_details`.
    """
    __tablename__ = 'data_batches_v2'
    __table_args__ = (UniqueConstraint('event_time', 'dataset_id'), )
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    event_time = db.Column(db.TIMESTAMP(timezone=True))
    dataset_id = db.Column(db.Integer, db.ForeignKey(Dataset.id))
    state = db.Column(db.Enum(BatchState), default=BatchState.NEW)
    move = db.Column(db.Boolean, default=False)
    # Serialized proto of DataBatch
    details = db.Column(db.Text())
    file_size = db.Column(db.Integer, default=0)
    num_imported_file = db.Column(db.Integer, default=0)
    num_file = db.Column(db.Integer, default=0)
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # `server_onupdate` only declares that the database produces a value on
    # UPDATE without producing one; `onupdate` makes SQLAlchemy send `now()`
    # with every UPDATE so the timestamp is actually refreshed.
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    dataset = db.relationship('Dataset', back_populates='data_batches')

    def set_details(self, proto):
        """Store `proto` and refresh the aggregate columns derived from it.

        Updates `num_file`, `num_imported_file`, `file_size` and, once every
        file is either completed or failed, `state`.

        Args:
            proto: message exposing a `files` sequence whose items carry
                `state` and `size`.
        """
        self.num_file = len(proto.files)
        num_imported_file = 0
        num_failed_file = 0
        file_size = 0
        # Aggregates stats; only completed files count towards the size.
        for file in proto.files:
            if file.state == dataset_pb2.File.State.COMPLETED:
                num_imported_file += 1
                file_size += file.size
            elif file.state == dataset_pb2.File.State.FAILED:
                num_failed_file += 1
        # The state is only settled when no file is still pending.
        # NOTE: an empty file list also satisfies this check and therefore
        # results in SUCCESS.
        if num_imported_file + num_failed_file == self.num_file:
            if num_failed_file > 0:
                self.state = BatchState.FAILED
            else:
                self.state = BatchState.SUCCESS
        self.num_imported_file = num_imported_file
        self.file_size = file_size
        self.details = proto.SerializeToString()

    def get_details(self):
        """Return the parsed `dataset_pb2.DataBatch`, or None if unset."""
        if self.details is None:
            return None
        proto = dataset_pb2.DataBatch()
        proto.ParseFromString(self.details)
        return proto
コード例 #4
0
ファイル: models.py プロジェクト: nolanliou/fedlearner
class Dataset(db.Model):
    """ORM model mapped to the `datasets_v2` table."""
    __tablename__ = 'datasets_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    # Dataset names are unique across the table.
    name = db.Column(db.String(255), unique=True)
    type = db.Column(db.Enum(DatasetType))
    external_storage_path = db.Column(db.Text())
    comment = db.Column(db.Text())
    # Filled in by the database when the row is inserted.
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # Initialised by the database; SQLAlchemy sends `now()` on every UPDATE.
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    # NOTE(review): presumably a soft-delete marker — confirm with callers.
    deleted_at = db.Column(db.DateTime(timezone=True))

    # Bidirectional link; the other side is `DataBatch.dataset`.
    data_batches = db.relationship('DataBatch', back_populates='dataset')
コード例 #5
0
class Dataset(db.Model):
    """ORM model mapped to the `datasets_v2` table."""
    __tablename__ = 'datasets_v2'
    # Table-level options: a table comment plus MySQL engine/charset settings.
    __table_args__ = ({
        'comment': 'This is webconsole dataset table',
        'mysql_engine': 'innodb',
        'mysql_charset': 'utf8mb4',
    })

    id = db.Column(db.Integer,
                   primary_key=True,
                   autoincrement=True,
                   comment='id')
    name = db.Column(db.String(255), nullable=False, comment='dataset name')
    # `native_enum=False` stores the enum as a string column instead of a
    # database-native ENUM type.
    dataset_type = db.Column(db.Enum(DatasetType, native_enum=False),
                             nullable=False,
                             comment='data type')
    path = db.Column(db.String(512), comment='dataset path')
    # The database column is named `cmt`; `key='comment'` keeps the Python
    # attribute/key name `comment`.
    comment = db.Column('cmt',
                        db.Text(),
                        key='comment',
                        comment='comment of dataset')
    # Filled in by the database when the row is inserted.
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           comment='created time')
    # Initialised by the database; SQLAlchemy sends `now()` on every UPDATE.
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now(),
                           comment='updated time')
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted time')

    # The join condition is spelled out with `foreign()`.
    # NOTE(review): presumably because DataBatch.dataset_id carries no
    # database-level ForeignKey — confirm against the DataBatch model.
    data_batches = db.relationship(
        'DataBatch', primaryjoin='foreign(DataBatch.dataset_id) == Dataset.id')
コード例 #6
0
class Workflow(db.Model):
    """Workflow row of the `workflow_v2` table.

    `config` and `peer_config` hold serialized
    `workflow_definition_pb2.WorkflowDefinition` messages; both columns are
    nullable, so the getters return None when nothing has been stored.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_token = db.Column(db.String(255), nullable=False)
    status = db.Column(db.Enum(WorkflowStatus), nullable=False)
    uid = db.Column(db.String(255), unique=True, nullable=False, index=True)
    forkable = db.Column(db.Boolean, default=True)
    group_alias = db.Column(db.String(255), index=True)
    config = db.Column(db.Text())
    peer_config = db.Column(db.Text())
    comment = db.Column(db.String(255))

    created_at = db.Column(db.DateTime,
                           nullable=False,
                           default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow)
    deleted_at = db.Column(db.DateTime)

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed workflow definition, or None if unset."""
        if self.config is None:
            # ParseFromString(None) would raise; an unset column means
            # "no config".
            return None
        proto = workflow_definition_pb2.WorkflowDefinition()
        proto.ParseFromString(self.config)
        return proto

    def set_peer_config(self, proto):
        """Store `proto` in the `peer_config` column in serialized form."""
        self.peer_config = proto.SerializeToString()

    def get_peer_config(self):
        """Return the parsed peer workflow definition, or None if unset."""
        if self.peer_config is None:
            return None
        proto = workflow_definition_pb2.WorkflowDefinition()
        proto.ParseFromString(self.peer_config)
        return proto

    def to_dict(self):
        """Return the row as a dict keyed by column name.

        `config` and `peer_config` are replaced by their dict form (proto
        field names preserved), or None when the column is unset.
        """
        dic = {
            col.name: getattr(self, col.name)
            for col in self.__table__.columns
        }
        protos = (('config', self.get_config()),
                  ('peer_config', self.get_peer_config()))
        for key, proto in protos:
            # MessageToDict cannot handle None, so unset configs stay None.
            dic[key] = None if proto is None else json_format.MessageToDict(
                proto, preserving_proto_field_name=True)
        return dic
コード例 #7
0
class Project(db.Model):
    """Project row of the `projects_v2` table.

    `config` and `certificate` hold protobuf messages in serialized form;
    use the `set_*` / `get_*` helpers to convert to and from proto objects.
    """
    __tablename__ = 'projects_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), index=True, unique=True)
    token = db.Column(db.String(64), index=True)
    config = db.Column(db.Text())
    certificate = db.Column(db.Text())
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed `project_pb2.Project`, or None if unset."""
        if self.config is None:
            return None
        proto = project_pb2.Project()
        proto.ParseFromString(self.config)
        return proto

    def set_certificate(self, proto):
        """Store `proto` in the `certificate` column in serialized form."""
        self.certificate = proto.SerializeToString()

    def get_certificate(self):
        """Return the parsed `project_pb2.CertificateStorage`, or None."""
        if self.certificate is None:
            return None
        proto = project_pb2.CertificateStorage()
        proto.ParseFromString(self.certificate)
        return proto

    def get_namespace(self):
        """Return the value of the `NAMESPACE` config variable.

        Falls back to `'default'` when there is no config or the config has
        no variable with that name.
        """
        config = self.get_config()
        if config is not None:
            # Reuse the proto parsed above instead of deserializing the
            # column a second time.
            for variable in config.variables:
                if variable.name == 'NAMESPACE':
                    return variable.value
        return 'default'
コード例 #8
0
class SchedulerItem(db.Model):
    """Scheduler item row of the `scheduler_item_v2` table."""
    __tablename__ = 'scheduler_item_v2'
    __table_args__ = (UniqueConstraint('name', name='uniq_name'),
                      default_table_args('scheduler items'))
    id = db.Column(db.Integer,
                   comment='id',
                   primary_key=True,
                   autoincrement=True)
    name = db.Column(db.String(255), comment='item name', nullable=False)
    pipeline = db.Column(db.Text,
                         comment='pipeline',
                         nullable=False,
                         default='{}')
    status = db.Column(db.Integer,
                       comment='item status',
                       nullable=False,
                       default=ItemStatus.ON.value)
    # -1 marks an item that runs once; a positive value is a period in
    # seconds (see `need_run`).
    interval_time = db.Column(db.Integer,
                              comment='item run interval in second',
                              nullable=False,
                              default=-1)
    last_run_at = db.Column(db.DateTime(timezone=True),
                            comment='last runner time')
    retry_cnt = db.Column(db.Integer,
                          comment='retry count when item is failed',
                          nullable=False,
                          default=0)
    extra = db.Column(db.Text(), comment='extra info')
    created_at = db.Column(db.DateTime(timezone=True),
                           comment='created at',
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           comment='updated at',
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')

    def need_run(self) -> bool:
        """Tell whether the item is due for execution.

        Returns:
            True for a one-shot item (`interval_time == -1`) that has never
            run, and for a periodic item (`interval_time > 0`) that has
            never run or whose next run time lies in the past; False in
            every other case.
        """
        interval = self.interval_time
        # One-shot item that has not been executed yet.
        if interval == -1 and self.last_run_at is None:
            return True
        # Everything that is not periodic is never due (again).
        if not interval > 0:
            return False
        # Periodic item without any previous run.
        if self.last_run_at is None:
            return True
        # The stored time is treated as UTC so it can be compared with now.
        last_run_utc = self.last_run_at.replace(tzinfo=datetime.timezone.utc)
        next_run_at = last_run_utc + datetime.timedelta(seconds=interval)
        utc_now = datetime.datetime.now(datetime.timezone.utc)
        logging.info(f'[composer] item id: {self.id}, '
                     f'next_run_at: {next_run_at.timestamp()}, '
                     f'utc_now: {utc_now.timestamp()}')
        return next_run_at.timestamp() < utc_now.timestamp()
コード例 #9
0
ファイル: models.py プロジェクト: Milkve/fedlearner
class Project(db.Model):
    """Project row of the `projects_v2` table.

    `config` and `certificate` hold protobuf messages as binary blobs; use
    the `set_*` / `get_*` helpers to convert to and from proto objects.
    """
    __tablename__ = 'projects_v2'
    __table_args__ = (UniqueConstraint('name', name='idx_name'),
                      Index('idx_token', 'token'), {
                          'comment': 'webconsole projects',
                          'mysql_engine': 'innodb',
                          'mysql_charset': 'utf8mb4',
                      })
    id = db.Column(db.Integer,
                   primary_key=True,
                   autoincrement=True,
                   comment='id')
    name = db.Column(db.String(255), comment='name')
    token = db.Column(db.String(64), comment='token')
    config = db.Column(db.LargeBinary(), comment='config')
    certificate = db.Column(db.LargeBinary(), comment='certificate')
    # The database column is named `cmt`; `key='comment'` keeps the Python
    # attribute/key name `comment`.
    comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           comment='created at')
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now(),
                           comment='updated at')
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed `project_pb2.Project`, or None if unset."""
        if self.config is None:
            return None
        proto = project_pb2.Project()
        proto.ParseFromString(self.config)
        return proto

    def set_certificate(self, proto):
        """Store `proto` in the `certificate` column in serialized form."""
        self.certificate = proto.SerializeToString()

    def get_certificate(self):
        """Return the parsed `project_pb2.CertificateStorage`, or None."""
        if self.certificate is None:
            return None
        proto = project_pb2.CertificateStorage()
        proto.ParseFromString(self.certificate)
        return proto

    def get_namespace(self):
        """Return the value of the `NAMESPACE` config variable.

        Falls back to `'default'` when there is no config or the config has
        no variable with that name.
        """
        config = self.get_config()
        if config is not None:
            # Reuse the proto parsed above instead of deserializing the
            # column a second time.
            for variable in config.variables:
                if variable.name == 'NAMESPACE':
                    return variable.value
        return 'default'
コード例 #10
0
ファイル: models.py プロジェクト: Joejiong/fedlearner
class Project(db.Model):
    """Project row of the `projects_v2` table.

    `config` and `certificate` hold protobuf messages in serialized form;
    use the `set_*` / `get_*` helpers to convert to and from proto objects.
    """
    __tablename__ = 'projects_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), index=True)
    token = db.Column(db.String(64), index=True)
    config = db.Column(db.Text())
    certificate = db.Column(db.Text())
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # `server_onupdate` only tells SQLAlchemy that the database generates a
    # value on UPDATE; it does not generate one by itself. `onupdate` makes
    # SQLAlchemy include `now()` in every UPDATE statement it emits.
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed `project_pb2.Project`, or None if unset."""
        if self.config is None:
            # ParseFromString(None) would raise; an unset column means
            # "no config".
            return None
        proto = project_pb2.Project()
        proto.ParseFromString(self.config)
        return proto

    def set_certificate(self, proto):
        """Store `proto` in the `certificate` column in serialized form."""
        self.certificate = proto.SerializeToString()

    def get_certificate(self):
        """Return the parsed `project_pb2.Certificate`, or None if unset."""
        if self.certificate is None:
            return None
        proto = project_pb2.Certificate()
        proto.ParseFromString(self.certificate)
        return proto

    def to_dict(self):
        """Return a dict view of the project.

        `config` is converted with `MessageToDict` and the timestamps are
        rendered as `YYYY-MM-DD HH:MM:SS` strings. Values that are not set
        (no config stored, timestamps not yet loaded from the database) are
        reported as None instead of raising.
        """
        time_format = "%Y-%m-%d %H:%M:%S"
        config = self.get_config()
        created_at = self.created_at
        updated_at = self.updated_at
        return {
            'id': self.id,
            'name': self.name,
            'token': self.token,
            'config': None if config is None else MessageToDict(config),
            'comment': self.comment,
            'created_at': (None if created_at is None else
                           created_at.strftime(time_format)),
            'updated_at': (None if updated_at is None else
                           updated_at.strftime(time_format)),
        }
コード例 #11
0
class Model(db.Model):
    """Model row of the `models_v2` table.

    Rows form a tree through `parent` / `children` and are linked to a job
    by `job_name`.
    """
    __tablename__ = 'models_v2'
    __table_args__ = (default_table_args('model'))

    id = db.Column(db.Integer, primary_key=True, comment='id')
    name = db.Column(db.String(255), comment='model_name')
    version = db.Column(db.Integer, comment='model_version')
    parent_id = db.Column(db.String(255), comment='parent_id')
    job_name = db.Column(db.String(255), comment='job_name')
    type = db.Column(db.Enum(ModelType, native_enum=False), comment='type')
    state = db.Column(db.Enum(ModelState, native_enum=False), comment='state')
    create_time = db.Column(db.DateTime(timezone=True), comment='create_time')
    params = db.Column(db.Text(), comment='params')
    metrics = db.Column(db.Text(), comment='metrics')
    output_base_dir = db.Column(db.String(255),
                                comment='model checkpoint/export path')
    # Self-referential link; the reverse side is exposed as `children`.
    parent = db.relationship('Model',
                             primaryjoin=remote(id) == foreign(parent_id),
                             backref='children')
    job = db.relationship('Job', primaryjoin=Job.name == foreign(job_name))

    def __init__(self):
        """Create a model stamped with the current time and version 0."""
        self.create_time = datetime.now()
        self.version = 0

    def commit(self):
        """Add this model to the session and commit the session."""
        db.session.add(self)
        db.session.commit()

    def get_eval_model(self):
        """
        Collect the evaluation models derived from this model.

        Returns:
             a list of the children whose type is NN_EVALUATION or
             TREE_EVALUATION
        """
        evaluations = []
        for candidate in self.children:
            if candidate.type in [
                    ModelType.NN_EVALUATION, ModelType.TREE_EVALUATION
            ]:
                evaluations.append(candidate)
        return evaluations
コード例 #12
0
class Model(db.Model):
    """Model row of the `models_v2` table.

    Rows form a tree through `parent` / `children` and are linked to a job
    by `job_name`, which is unique per row.
    """
    __tablename__ = 'models_v2'
    __table_args__ = (Index('idx_job_name', 'job_name'),
                      UniqueConstraint('job_name', name='uniq_job_name'),
                      default_table_args('model'))

    id = db.Column(db.Integer, primary_key=True, comment='id')
    name = db.Column(db.String(255),
                     comment='name')  # can be modified by end-user
    version = db.Column(db.Integer, default=0, comment='version')
    # `type` is stored as a plain integer and compared against
    # `ModelType.<member>.value` in `get_eval_model`.
    type = db.Column(db.Integer, comment='type')
    state = db.Column(db.Integer, comment='state')
    job_name = db.Column(db.String(255), comment='job_name')
    parent_id = db.Column(db.Integer, comment='parent_id')
    params = db.Column(db.Text(), comment='params')
    metrics = db.Column(db.Text(), comment='metrics')
    created_at = db.Column(db.DateTime(timezone=True),
                           comment='created_at',
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           comment='updated_at',
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')

    group_id = db.Column(db.Integer, default=0, comment='group_id')
    # TODO https://code.byted.org/data/fedlearner_web_console_v2/issues/289
    extra = db.Column(db.Text(), comment='extra')  # json string

    # Self-referential link; the reverse side is exposed as `children`.
    parent = db.relationship('Model',
                             primaryjoin=remote(id) == foreign(parent_id),
                             backref='children')
    job = db.relationship('Job', primaryjoin=Job.name == foreign(job_name))

    def get_eval_model(self):
        """Return the children whose type is an evaluation model type."""
        evaluations = []
        for candidate in self.children:
            if candidate.type in [
                    ModelType.NN_EVALUATION.value,
                    ModelType.TREE_EVALUATION.value
            ]:
                evaluations.append(candidate)
        return evaluations
コード例 #13
0
class Project(db.Model):
    """Project row of the `projects_v2` table.

    `config` holds a serialized `project_pb2.Project` message.
    """
    __tablename__ = 'projects_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), index=True)
    config = db.Column(db.Text())

    def set_config(self, proto):
        """Store `proto` in the `config` column in serialized form."""
        self.config = proto.SerializeToString()

    def get_config(self):
        """Return the parsed `project_pb2.Project`, or None if unset."""
        if self.config is None:
            # ParseFromString(None) would raise; an unset column means
            # "no config".
            return None
        proto = project_pb2.Project()
        proto.ParseFromString(self.config)
        return proto
コード例 #14
0
class Dataset(db.Model):
    """ORM model mapped to the `datasets_v2` table."""
    __tablename__ = 'datasets_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    # Dataset names are mandatory and unique across the table.
    name = db.Column(db.String(255), unique=True, nullable=False)
    # `native_enum=False` stores the enum as a string column instead of a
    # database-native ENUM type.
    dataset_type = db.Column(db.Enum(DatasetType, native_enum=False),
                             nullable=False)
    # The database column is named `cmt`; `key='comment'` keeps the Python
    # attribute/key name `comment`.
    comment = db.Column('cmt', db.Text(), key='comment')
    # Filled in by the database when the row is inserted.
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    # Initialised by the database; SQLAlchemy sends `now()` on every UPDATE.
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    # NOTE(review): presumably a soft-delete marker — confirm with callers.
    deleted_at = db.Column(db.DateTime(timezone=True))

    # Bidirectional link; the other side is `DataBatch.dataset`.
    data_batches = db.relationship('DataBatch', back_populates='dataset')
コード例 #15
0
class _TestModel(db.Model):
    """Model mapped to `test_table` with a proto-backed `grpc_spec` column."""
    __tablename__ = 'test_table'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255))
    token = db.Column(db.String(64), index=True)
    created_at = db.Column(db.DateTime(timezone=True))
    grpc_spec = db.Column(db.Text())

    def set_grpc_spec(self, proto):
        """Store `proto` in the `grpc_spec` column in serialized form."""
        self.grpc_spec = proto.SerializeToString()

    def get_grpc_spec(self):
        """Return the parsed `common_pb2.GrpcSpec`, or None if unset."""
        if self.grpc_spec is None:
            # ParseFromString(None) would raise; an unset column means
            # "no spec".
            return None
        proto = common_pb2.GrpcSpec()
        proto.ParseFromString(self.grpc_spec)
        return proto
コード例 #16
0
class SchedulerRunner(db.Model):
    """ORM model mapped to the `scheduler_runner_v2` table."""
    __tablename__ = 'scheduler_runner_v2'
    # No trailing comma: this is the value returned by default_table_args
    # itself, not a one-element tuple.
    __table_args__ = (default_table_args('scheduler runners'))
    id = db.Column(db.Integer,
                   comment='id',
                   primary_key=True,
                   autoincrement=True)
    # NOTE(review): presumably the id of the owning scheduler item; no
    # ForeignKey is declared here — confirm.
    item_id = db.Column(db.Integer, comment='item id', nullable=False)
    # Stored as the integer value of a RunnerStatus member.
    status = db.Column(db.Integer,
                       comment='runner status',
                       nullable=False,
                       default=RunnerStatus.INIT.value)
    start_at = db.Column(db.DateTime(timezone=True),
                         comment='runner start time')
    end_at = db.Column(db.DateTime(timezone=True), comment='runner end time')
    # `pipeline`, `output` and `context` are mandatory text columns that
    # default to the string '{}'.
    pipeline = db.Column(db.Text(),
                         comment='pipeline from scheduler item',
                         nullable=False,
                         default='{}')
    output = db.Column(db.Text(),
                       comment='output',
                       nullable=False,
                       default='{}')
    context = db.Column(db.Text(),
                        comment='context',
                        nullable=False,
                        default='{}')
    extra = db.Column(db.Text(), comment='extra info')
    # Filled in by the database when the row is inserted.
    created_at = db.Column(db.DateTime(timezone=True),
                           comment='created at',
                           server_default=func.now())
    # Initialised by the database; SQLAlchemy sends `now()` on every UPDATE.
    updated_at = db.Column(db.DateTime(timezone=True),
                           comment='updated at',
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
コード例 #17
0
ファイル: models.py プロジェクト: flyfoxCI/fedlearner
class WorkflowTemplate(db.Model):
    """Workflow template row of the `template_v2` table.

    `config` holds a serialized
    `workflow_definition_pb2.WorkflowDefinition` message.
    """
    __tablename__ = 'template_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    comment = db.Column(db.String(255))
    group_alias = db.Column(db.String(255), nullable=False, index=True)
    config = db.Column(db.Text(), nullable=False)

    def set_config(self, proto):
        """Serialize `proto` and keep the result in the `config` column."""
        serialized = proto.SerializeToString()
        self.config = serialized

    def get_config(self):
        """Deserialize the `config` column into a WorkflowDefinition."""
        definition = workflow_definition_pb2.WorkflowDefinition()
        definition.ParseFromString(self.config)
        return definition
コード例 #18
0
class ModelGroup(db.Model):
    """ORM model mapped to the `model_groups_v2` table."""
    __tablename__ = 'model_groups_v2'
    # No trailing comma: this is the value returned by default_table_args
    # itself, not a one-element tuple.
    __table_args__ = (default_table_args('model_groups_v2'))

    id = db.Column(db.Integer, primary_key=True, comment='id')
    name = db.Column(db.String(255),
                     comment='name')  # can be modified by end-user

    # Filled in by the database when the row is inserted.
    created_at = db.Column(db.DateTime(timezone=True),
                           comment='created_at',
                           server_default=func.now())
    # Initialised by the database; SQLAlchemy sends `now()` on every UPDATE.
    updated_at = db.Column(db.DateTime(timezone=True),
                           comment='updated_at',
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')

    # TODO https://code.byted.org/data/fedlearner_web_console_v2/issues/289
    extra = db.Column(db.Text(), comment='extra')  # json string
コード例 #19
0
ファイル: models.py プロジェクト: zhenv5/fedlearner
class Template(db.Model):
    """Workflow template row of the `template_v2` table.

    `config` holds a serialized
    `workflow_definition_pb2.WorkflowDefinition` message.
    """
    __tablename__ = 'template_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    comment = db.Column(db.String(255))
    group_alias = db.Column(db.String(255), nullable=False, index=True)
    config = db.Column(db.Text(), nullable=False)

    def set_config(self, proto):
        """Serialize `proto` and keep the result in the `config` column."""
        serialized = proto.SerializeToString()
        self.config = serialized

    def get_config(self):
        """Deserialize the `config` column into a WorkflowDefinition."""
        definition = workflow_definition_pb2.WorkflowDefinition()
        definition.ParseFromString(self.config)
        return definition

    def to_dict(self):
        """Return the row as a dict keyed by column name.

        The `config` entry is replaced by the dict form of the parsed
        workflow definition, with proto field names preserved.
        """
        dic = {}
        for col in self.__table__.columns:
            dic[col.name] = getattr(self, col.name)
        definition = self.get_config()
        dic['config'] = json_format.MessageToDict(
            definition, preserving_proto_field_name=True)
        return dic
コード例 #20
0
ファイル: models.py プロジェクト: piiswrong/fedlearner
class Job(db.Model):
    """Job row of the `job_v2` table plus helpers around its FLAPP resource.

    While the job is in state STARTED, flapp and pod information is read
    live through the k8s client; otherwise the JSON snapshots stored in
    `flapp_snapshot` / `pods_snapshot` are used.
    """
    __tablename__ = 'job_v2'
    __table_args__ = (Index('idx_workflow_id', 'workflow_id'), {
        'comment': 'webconsole job',
        'mysql_engine': 'innodb',
        'mysql_charset': 'utf8mb4',
    })
    id = db.Column(db.Integer,
                   primary_key=True,
                   autoincrement=True,
                   comment='id')
    name = db.Column(db.String(255), unique=True, comment='name')
    job_type = db.Column(db.Enum(JobType, native_enum=False),
                         nullable=False,
                         comment='job type')
    state = db.Column(db.Enum(JobState, native_enum=False),
                      nullable=False,
                      default=JobState.INVALID,
                      comment='state')
    yaml_template = db.Column(db.Text(), comment='yaml_template')
    # Serialized JobDefinition proto (see `get_config`).
    config = db.Column(db.LargeBinary(), comment='config')

    is_disabled = db.Column(db.Boolean(), default=False, comment='is_disabled')

    workflow_id = db.Column(db.Integer, nullable=False, comment='workflow id')
    project_id = db.Column(db.Integer, nullable=False, comment='project id')
    # JSON dumps of the last k8s responses, written by `stop`.
    flapp_snapshot = db.Column(db.Text(), comment='flapp snapshot')
    pods_snapshot = db.Column(db.Text(), comment='pods snapshot')

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           comment='created at')
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now(),
                           comment='updated at')
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')

    # The join conditions are spelled out with `foreign()` because the id
    # columns above carry no ForeignKey.
    project = db.relationship('Project',
                              primaryjoin='Project.id == '
                              'foreign(Job.project_id)')
    workflow = db.relationship('Workflow',
                               primaryjoin='Workflow.id == '
                               'foreign(Job.workflow_id)')
    # Class attribute: evaluated once when the class body runs and shared by
    # all instances.
    _k8s_client = get_client()

    def get_config(self):
        """Return the parsed `JobDefinition`, or None if `config` is unset."""
        if self.config is not None:
            proto = JobDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def _set_snapshot_flapp(self):
        """Fetch the job's FLAPP object and store it as JSON in
        `flapp_snapshot`."""
        flapp = self._k8s_client.get_custom_object(
            CrdKind.FLAPP, self.name, self.project.get_namespace())
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch the pods of the job's FLAPP object and store them as JSON
        in `pods_snapshot`."""
        pods = self._k8s_client.list_resource_of_custom_object(
            CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        self.pods_snapshot = json.dumps(pods)

    def get_pods(self):
        """Return the `'pods'` entry of the pods response.

        For a STARTED job the response is fetched live; a RuntimeError from
        the client is logged and None is returned. For any other state the
        stored snapshot is used, and None is returned when there is none.
        """
        if self.state == JobState.STARTED:
            try:
                pods = self._k8s_client.list_resource_of_custom_object(
                    CrdKind.FLAPP, self.name, 'pods',
                    self.project.get_namespace())
                return pods['pods']
            except RuntimeError as e:
                logging.error('Get %d pods error msg: %s', self.id, e.args)
                return None
        if self.pods_snapshot is not None:
            return json.loads(self.pods_snapshot)['pods']
        return None

    def get_flapp(self):
        """Return the `'flapp'` entry of the FLAPP response.

        For a STARTED job the response is fetched live; a RuntimeError from
        the client is logged and None is returned. For any other state the
        stored snapshot is used, and None is returned when there is none.
        """
        if self.state == JobState.STARTED:
            try:
                flapp = self._k8s_client.get_custom_object(
                    CrdKind.FLAPP, self.name, self.project.get_namespace())
                return flapp['flapp']
            except RuntimeError as e:
                logging.error('Get %d flapp error msg: %s', self.id, str(e))
                return None
        if self.flapp_snapshot is not None:
            return json.loads(self.flapp_snapshot)['flapp']
        return None

    def get_pods_for_frontend(self, filter_private_info=False):
        """Build a list of pod summaries from the flapp status and the pods.

        Each entry is a dict with the keys `name`, `pod_type`, `status` and
        `message`.

        Args:
            filter_private_info: when true, messages are built from the
                `reason` fields; otherwise from the `message` fields.

        Returns:
            A list of pod dicts. An empty list when no flapp is available;
            only the flapp-derived entries when no pods are available.
        """
        result = []
        flapp = self.get_flapp()
        if flapp is None:
            return result
        # Pods the flapp status lists as failed or succeeded.
        if 'status' in flapp \
            and 'flReplicaStatus' in flapp['status']:
            replicas = flapp['status']['flReplicaStatus']
            if replicas:
                for pod_type in replicas:
                    for state in ['failed', 'succeeded']:
                        for pod in replicas[pod_type][state]:
                            result.append({
                                'name': pod,
                                'pod_type': pod_type,
                                'status': 'Flapp_{}'.format(state),
                                'message': '',
                            })

        # Entries built from the pod objects themselves.
        pods = self.get_pods()
        if pods is None:
            return result
        pods = pods['items']
        for pod in pods:
            status = pod['status']['phase'].lower()
            msgs = []
            # Only the first container status is inspected.
            if 'containerStatuses' in pod['status']:
                state = pod['status']['containerStatuses'][0]['state']
                for key, detail in state.items():
                    if filter_private_info:
                        if 'reason' in detail:
                            msgs.append(key + ':' + detail['reason'])
                    elif 'message' in detail:
                        msgs.append(key + ':' + detail['message'])

            for cond in pod['status']['conditions']:
                if filter_private_info:
                    if 'reason' in cond:
                        msgs.append(cond['type'] + ':' + cond['reason'])
                elif 'message' in cond:
                    msgs.append(cond['type'] + ':' + cond['message'])

            result.append({
                'name':
                pod['metadata']['name'],
                'pod_type':
                pod['metadata']['labels']['fl-replica-type'],
                'status':
                status,
                'message':
                ', '.join(msgs)
            })

        # Deduplicate by pod name: when a pod appears in both sources, the
        # entry appended later (from the pod object) replaces the earlier one.
        result = list({pod['name']: pod for pod in result}.values())
        return result

    def get_state_for_frontend(self):
        """Return the state name to display.

        A STARTED job is reported as 'COMPLETED', 'FAILED' or 'RUNNING'
        depending on the flapp; a STOPPED job without any flapp is reported
        as 'NEW'; everything else is reported as the name of `state`.
        """
        if self.state == JobState.STARTED:
            if self.is_complete():
                return 'COMPLETED'
            if self.is_failed():
                return 'FAILED'
            return 'RUNNING'
        if self.state == JobState.STOPPED:
            if self.get_flapp() is None:
                return 'NEW'
        return self.state.name

    def is_failed(self):
        """Return True if the flapp's appState is 'FLStateFailed' or
        'FLStateShutDown'; False when the state is unavailable."""
        flapp = self.get_flapp()
        if flapp is None \
                or 'status' not in flapp \
                or 'appState' not in flapp['status']:
            return False
        return flapp['status']['appState'] in [
            'FLStateFailed', 'FLStateShutDown'
        ]

    def is_complete(self):
        """Return True if the flapp's appState is 'FLStateComplete'; False
        when the state is unavailable."""
        flapp = self.get_flapp()
        if flapp is None \
                or 'status' not in flapp \
                or 'appState' not in flapp['status']:
            return False
        return flapp['status']['appState'] == 'FLStateComplete'

    def get_complete_at(self):
        """Return the flapp status' `complete_at` value, or None when it is
        unavailable."""
        flapp = self.get_flapp()
        if flapp is None \
                or 'status' not in flapp \
                or 'complete_at' not in flapp['status']:
            return None
        return flapp['status']['complete_at']

    def stop(self):
        """Mark the job as STOPPED.

        For a STARTED job the flapp and pods are snapshotted first and the
        FLAPP object is then deleted through the k8s client.
        """
        if self.state == JobState.STARTED:
            # Snapshot before deleting so the information stays available.
            self._set_snapshot_flapp()
            self._set_snapshot_pods()
            self._k8s_client.delete_custom_object(CrdKind.FLAPP, self.name,
                                                  self.project.get_namespace())
        self.state = JobState.STOPPED

    def schedule(self):
        """Clear both snapshots and move a STOPPED job to WAITING.

        Raises:
            AssertionError: if the job is not in state STOPPED.
        """
        assert self.state == JobState.STOPPED
        self.pods_snapshot = None
        self.flapp_snapshot = None
        self.state = JobState.WAITING

    def start(self):
        """Mark the job as STARTED."""
        self.state = JobState.STARTED

    def set_yaml_template(self, yaml_template):
        """Replace the stored yaml template."""
        self.yaml_template = yaml_template
コード例 #21
0
class Job(db.Model):
    """ORM model of a job that is run as an FLApp through a Kubernetes client.

    ``flapp_snapshot`` and ``pods_snapshot`` hold JSON-encoded copies of what
    the client last returned for this job's FLApp and pods.
    """
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.String(16), nullable=False)
    status = db.Column(db.Enum(JobStatus), nullable=False)
    yaml = db.Column(db.Text(), nullable=False)

    workflow_id = db.Column(db.Integer, db.ForeignKey(Workflow.id),
                            nullable=False, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id),
                           nullable=False)

    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           server_onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))
    # NOTE(review): ``project_id`` here is the class-level Column object, not
    # a row's value -- confirm this is what ProjectK8sAdapter expects.
    _project_adapter = ProjectK8sAdapter(project_id)
    # Created once when the class body runs; shared by every instance
    _k8s_client = get_client()

    def _set_snapshot_flapp(self):
        """Fetch this job's FLApp from the client and store it as JSON."""
        flapp = self._k8s_client.get_flapp(
            self._project_adapter.get_namespace(), self.name)
        # Encode exactly once so get_flapp() decodes back to the object
        # (encoding twice made json.loads return a string).
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch this job's pods from the client and store them as JSON."""
        pods = self._k8s_client.get_pods(
            self._project_adapter.get_namespace(), self.name)
        # Goes into pods_snapshot; writing it to flapp_snapshot would both
        # clobber the FLApp snapshot and leave get_pods() with nothing to read.
        self.pods_snapshot = json.dumps(pods)

    def get_flapp(self):
        """Return the decoded FLApp snapshot, refreshing it first if STARTED.

        Returns:
            The JSON-decoded ``flapp_snapshot``, or None when no snapshot has
            been stored.
        """
        if self.status == JobStatus.STARTED:
            self._set_snapshot_flapp()
        if self.flapp_snapshot is None:
            return None
        return json.loads(self.flapp_snapshot)

    def get_pods(self):
        """Return the decoded pods snapshot, refreshing it first if STARTED.

        Returns:
            The JSON-decoded ``pods_snapshot``, or None when no snapshot has
            been stored.
        """
        if self.status == JobStatus.STARTED:
            self._set_snapshot_pods()
        if self.pods_snapshot is None:
            return None
        return json.loads(self.pods_snapshot)

    def run(self):
        """Mark the job STARTED and create its FLApp from ``self.yaml``.

        Raises:
            ResourceConflictException: if the job is already STARTED.
        """
        if self.status == JobStatus.STARTED:
            raise ResourceConflictException('Job has been started')
        self.status = JobStatus.STARTED
        self._k8s_client.create_flapp(self._project_adapter.
                                      get_namespace(), self.yaml)

    def stop(self):
        """Mark the job STOPPED, snapshot its FLApp and pods, then delete it.

        Raises:
            ResourceConflictException: if the job is already STOPPED.
        """
        if self.status == JobStatus.STOPPED:
            raise ResourceConflictException('Job has stopped')
        self.status = JobStatus.STOPPED
        # Snapshot before deletion so the details stay readable afterwards
        self._set_snapshot_flapp()
        self._set_snapshot_pods()
        self._k8s_client.delete_flapp(self._project_adapter.
                                      get_namespace(), self.name)

    def set_yaml(self, yaml_template, job_config):
        """Merge ``yaml_template`` with the project's global job spec.

        NOTE(review): in the visible code the merged result is neither stored
        nor returned, and ``job_config`` is unused -- the method looks
        unfinished; confirm the intended behaviour before relying on it.
        """
        yaml = merge(yaml_template,
                     self._project_adapter.get_global_job_spec())
コード例 #22
0
class Workflow(db.Model):
    """ORM model of a workflow and its transactional state machine.

    ``state`` is the current workflow state, ``target_state`` the state a
    running transaction is moving towards (INVALID when there is none), and
    ``transaction_state`` the phase of that transaction.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id))
    config = db.Column(db.Text())
    forkable = db.Column(db.Boolean, default=False)
    forked_from = db.Column(db.Integer, default=None)
    comment = db.Column(db.String(255))

    state = db.Column(db.Enum(WorkflowState), default=WorkflowState.INVALID)
    target_state = db.Column(db.Enum(WorkflowState),
                             default=WorkflowState.INVALID)
    transaction_state = db.Column(db.Enum(TransactionState),
                                  default=TransactionState.READY)
    transaction_err = db.Column(db.Text())

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_onupdate=func.now(),
                           server_default=func.now())

    project = db.relationship(Project)

    def set_config(self, proto):
        """Store ``proto`` serialized, or clear the config when it is None."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the config parsed as a WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def update_state(self, state, target_state, transaction_state):
        """Advance the workflow's transaction and run the resulting actions.

        Args:
            state: expected current state, or None to skip that check.
            target_state: state the transaction should move to; a falsy value
                leaves ``target_state`` untouched.
            transaction_state: requested transaction state; None or the
                current value makes the call return without further action.

        Returns:
            The workflow's transaction state after processing.

        Raises:
            AssertionError: if the expected state does not match, another
                transaction is in progress, or a requested transition is not
                a valid one.
        """
        assert state is None or self.state == state, \
            'Cannot change current state directly'

        if target_state and self.target_state != target_state:
            assert self.target_state == WorkflowState.INVALID, \
                'Another transaction is in progress'
            assert self.transaction_state == TransactionState.READY, \
                'Another transaction is in progress'
            assert (self.state, target_state) in VALID_TRANSITIONS, \
                'Invalid transition from %s to %s'%(self.state, target_state)
            self.target_state = target_state

        if transaction_state is None or \
                transaction_state == self.transaction_state:
            return self.transaction_state

        if (self.transaction_state, transaction_state) in \
                IGNORED_TRANSACTION_TRANSITIONS:
            return self.transaction_state

        assert (self.transaction_state, transaction_state) in \
            VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from %s to %s'%(
                    self.transaction_state, transaction_state)
        self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            try:
                self.prepare()
            except Exception:
                # A failed prepare aborts the transaction; keep the traceback
                # in the log instead of dropping it silently.
                logging.exception('workflow %s failed to prepare', self.id)
                self.transaction_state = \
                    TransactionState.COORDINATOR_ABORTING

        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            try:
                self.rollback()
            except Exception:
                # Rollback is best effort, but the failure must be visible
                logging.exception('workflow %s failed to rollback', self.id)

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            try:
                self.prepare()
            except Exception:
                logging.exception('workflow %s failed to prepare', self.id)
                self.transaction_state = \
                    TransactionState.PARTICIPANT_ABORTING

        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            try:
                self.rollback()
            except Exception:
                logging.exception('workflow %s failed to rollback', self.id)
            self.target_state = WorkflowState.INVALID
            self.transaction_state = \
                TransactionState.ABORTED

        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self):
        """Validate the pending transition and mark it committable.

        A READY target is committable only when a config is set; RUNNING and
        STOPPED targets are always committable.

        Raises:
            AssertionError: if the transaction is not in a prepare state.
            RuntimeError: if ``target_state`` is not READY, RUNNING or STOPPED.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
                "Workflow not in prepare state"

        success = False
        if self.target_state == WorkflowState.READY:
            success = bool(self.config)
        elif self.target_state == WorkflowState.RUNNING:
            success = True
        elif self.target_state == WorkflowState.STOPPED:
            success = True
        else:
            raise RuntimeError("Invalid target_state %s" % self.target_state)
        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Undo a prepared transaction; currently nothing needs undoing."""
        pass

    def commit(self):
        """Apply the pending transition and reset the transaction.

        Sets ``state`` to ``target_state``, then resets ``target_state`` to
        INVALID and ``transaction_state`` to READY.

        Raises:
            AssertionError: if the transaction is not in a committing state.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
                "Workflow not in committing state"

        if self.target_state == WorkflowState.STOPPED:
            # TODO: delete jobs from k8s
            pass
        elif self.target_state == WorkflowState.READY:
            # TODO: create workflow jobs in database according to config
            pass

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def log_states(self):
        """Log the current state, target state and transaction state."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)
コード例 #23
0
class DataBatch(db.Model):
    """ORM model of one batch of a dataset, unique per event time and dataset.

    ``details`` stores a serialized proto; the file counters and ``state``
    are derived from that proto by ``set_details``.
    """
    __tablename__ = 'data_batches_v2'
    __table_args__ = (
        UniqueConstraint(
            'event_time', 'dataset_id', name='uniq_event_time_dataset_id'),
        {
            'comment': 'This is webconsole dataset table',
            'mysql_engine': 'innodb',
            'mysql_charset': 'utf8mb4',
        },
    )
    id = db.Column(
        db.Integer, primary_key=True, autoincrement=True, comment='id')
    event_time = db.Column(
        db.TIMESTAMP(timezone=True), nullable=False, comment='event_time')
    dataset_id = db.Column(db.Integer, nullable=False, comment='dataset_id')
    path = db.Column(db.String(512), comment='path')
    state = db.Column(
        db.Enum(BatchState, native_enum=False),
        default=BatchState.NEW,
        comment='state')
    move = db.Column(db.Boolean, default=False, comment='move')
    # Serialized proto of DatasetBatch
    details = db.Column(db.LargeBinary(), comment='details')
    file_size = db.Column(db.Integer, default=0, comment='file_size')
    num_imported_file = db.Column(
        db.Integer, default=0, comment='num_imported_file')
    num_file = db.Column(db.Integer, default=0, comment='num_file')
    # Attribute key is ``comment``; the database column itself is named ``cmt``
    comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
    created_at = db.Column(
        db.DateTime(timezone=True),
        server_default=func.now(),
        comment='created_at')
    updated_at = db.Column(
        db.DateTime(timezone=True),
        server_default=func.now(),
        server_onupdate=func.now(),
        comment='updated_at')
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')

    dataset = db.relationship(
        'Dataset',
        primaryjoin='Dataset.id == foreign(DataBatch.dataset_id)',
        back_populates='data_batches')

    def set_details(self, proto):
        """Store ``proto`` and refresh the counters derived from its files.

        ``num_file`` becomes the number of files, ``num_imported_file`` the
        number of COMPLETED files and ``file_size`` the summed size of those.
        ``state`` changes only once every file is COMPLETED or FAILED: FAILED
        if any file failed, SUCCESS otherwise.
        """
        files = proto.files
        total = len(files)
        self.num_file = total

        imported_sizes = [
            entry.size for entry in files
            if entry.state == dataset_pb2.File.State.COMPLETED
        ]
        failed = sum(
            1 for entry in files
            if entry.state == dataset_pb2.File.State.FAILED)
        imported = len(imported_sizes)

        # Only settle the batch state when no file is still pending
        if imported + failed == total:
            self.state = \
                BatchState.FAILED if failed > 0 else BatchState.SUCCESS
        self.num_imported_file = imported
        self.file_size = sum(imported_sizes)
        self.details = proto.SerializeToString()

    def get_details(self):
        """Return ``details`` parsed as a ``DataBatch`` proto, or None if unset."""
        raw = self.details
        if raw is not None:
            batch = dataset_pb2.DataBatch()
            batch.ParseFromString(raw)
            return batch
        return None
コード例 #24
0
class Job(db.Model):
    """ORM model of a workflow job whose runtime is an FLApp in Kubernetes.

    While the job is STARTED, FLApp details are read live through
    ``k8s_client``; before the FLApp is deleted a JSON snapshot is stored in
    ``flapp_snapshot`` so the details remain available afterwards.
    """
    __tablename__ = 'job_v2'
    __table_args__ = (
        Index('idx_workflow_id', 'workflow_id'),
        {
            'comment': 'webconsole job',
            'mysql_engine': 'innodb',
            'mysql_charset': 'utf8mb4',
        },
    )
    id = db.Column(
        db.Integer, primary_key=True, autoincrement=True, comment='id')
    name = db.Column(db.String(255), unique=True, comment='name')
    job_type = db.Column(
        db.Enum(JobType, native_enum=False),
        nullable=False,
        comment='job type')
    state = db.Column(
        db.Enum(JobState, native_enum=False),
        nullable=False,
        default=JobState.INVALID,
        comment='state')
    config = db.Column(db.LargeBinary(16777215), comment='config')

    is_disabled = db.Column(db.Boolean(), default=False, comment='is_disabled')

    workflow_id = db.Column(db.Integer, nullable=False, comment='workflow id')
    project_id = db.Column(db.Integer, nullable=False, comment='project id')
    flapp_snapshot = db.Column(db.Text(16777215), comment='flapp snapshot')
    pods_snapshot = db.Column(db.Text(16777215), comment='pods snapshot')
    error_message = db.Column(db.Text(), comment='error message')

    created_at = db.Column(
        db.DateTime(timezone=True),
        server_default=func.now(),
        comment='created at')
    updated_at = db.Column(
        db.DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        comment='updated at')
    deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')

    project = db.relationship(
        'Project', primaryjoin='Project.id == foreign(Job.project_id)')
    workflow = db.relationship(
        'Workflow', primaryjoin='Workflow.id == foreign(Job.workflow_id)')

    def get_config(self):
        """Return the config parsed as a JobDefinition, or None if unset."""
        if self.config is None:
            return None
        definition = JobDefinition()
        definition.ParseFromString(self.config)
        return definition

    def set_config(self, proto):
        """Store ``proto`` serialized, or clear the config when it is None."""
        self.config = None if proto is None else proto.SerializeToString()

    def _set_snapshot_flapp(self):
        """Store the live FLApp as JSON; clear the snapshot if there is none."""
        def _encode(value):
            # Fallback for values json cannot encode natively: dates become
            # ISO strings, anything else its str() form.
            is_temporal = isinstance(value,
                                     (datetime.date, datetime.datetime))
            return value.isoformat() if is_temporal else str(value)

        current = k8s_client.get_flapp(self.name)
        self.flapp_snapshot = \
            json.dumps(current, default=_encode) if current else None

    def get_flapp_details(self):
        """Return FLApp details, live when STARTED and from snapshot otherwise.

        With no snapshot stored, an empty structure
        ``{'flapp': None, 'pods': {'items': []}}`` is returned.
        """
        if self.state == JobState.STARTED:
            return k8s_client.get_flapp(self.name)
        if self.flapp_snapshot is None:
            return {'flapp': None, 'pods': {'items': []}}

        details = json.loads(self.flapp_snapshot)
        # aims to support old job
        if 'flapp' not in details:
            details['flapp'] = None
        if 'pods' not in details and self.pods_snapshot:
            details['pods'] = json.loads(self.pods_snapshot)['pods']
        return details

    def get_pods_for_frontend(self, include_private_info=True):
        """Return the job's pods as dicts, one entry per pod name.

        Pods come from both the FLApp and the pod list; when a name appears
        in both, the entry from the pod list is the one kept.
        """
        details = self.get_flapp_details()
        flapp = FlApp.from_json(details.get('flapp', None))

        raw_pods = None
        if 'pods' in details:
            raw_pods = details['pods'].get('items', None)
        listed = [] if raw_pods is None \
            else [Pod.from_json(item) for item in raw_pods]

        # deduplication pods both in pods and flapp
        by_name = {pod.name: pod for pod in flapp.pods}
        by_name.update((pod.name, pod) for pod in listed)
        return [pod.to_dict(include_private_info) for pod in by_name.values()]

    def get_state_for_frontend(self):
        """Return the name of the job's state."""
        return self.state.name

    def is_flapp_failed(self):
        """Report whether the FLApp state is FAILED or SHUTDOWN."""
        # TODO: make the getter more efficient
        raw = self.get_flapp_details()['flapp']
        app_state = FlApp.from_json(raw).state
        return app_state in [FlAppState.FAILED, FlAppState.SHUTDOWN]

    def is_flapp_complete(self):
        """Report whether the FLApp state is COMPLETED."""
        # TODO: make the getter more efficient
        raw = self.get_flapp_details()['flapp']
        return FlApp.from_json(raw).state == FlAppState.COMPLETED

    def get_complete_at(self):
        """Return the FLApp's ``completed_at`` value."""
        # TODO: make the getter more efficient
        raw = self.get_flapp_details()['flapp']
        return FlApp.from_json(raw).completed_at

    def stop(self):
        """Stop the job if its state allows it.

        State change:
            WAITING -> NEW
            STARTED -> STOPPED (after snapshotting and deleting the FLApp)
            COMPLETED/FAILED unchanged
        Any other state only logs a warning.
        """
        stoppable = [JobState.WAITING, JobState.STARTED,
                     JobState.COMPLETED, JobState.FAILED]
        if self.state not in stoppable:
            logging.warning('illegal job state, name: %s, state: %s',
                            self.name, self.state)
            return
        if self.state == JobState.STARTED:
            self._set_snapshot_flapp()
            k8s_client.delete_flapp(self.name)
            self.state = JobState.STOPPED
        elif self.state == JobState.WAITING:
            self.state = JobState.NEW

    def schedule(self):
        """Clear both snapshots and move the job to WAITING.

        Raises:
            AssertionError: unless the job is NEW, STOPPED, COMPLETED or
                FAILED.
        """
        # COMPLETED/FAILED Job State can be scheduled since stop action
        # will not change the state of completed or failed job
        schedulable = [JobState.NEW, JobState.STOPPED,
                       JobState.COMPLETED, JobState.FAILED]
        assert self.state in schedulable
        self.pods_snapshot, self.flapp_snapshot = None, None
        self.state = JobState.WAITING

    def start(self):
        """Move a WAITING job to STARTED; asserts that it is WAITING."""
        assert self.state == JobState.WAITING
        self.state = JobState.STARTED

    def _finish(self, final_state):
        """Snapshot and delete the FLApp of a STARTED job, then set its state."""
        assert self.state == JobState.STARTED, 'Job State is not STARTED'
        self._set_snapshot_flapp()
        k8s_client.delete_flapp(self.name)
        self.state = final_state

    def complete(self):
        """Finish a STARTED job as COMPLETED."""
        self._finish(JobState.COMPLETED)

    def fail(self):
        """Finish a STARTED job as FAILED."""
        self._finish(JobState.FAILED)
コード例 #25
0
ファイル: models.py プロジェクト: GodCedric/fedlearner
class Job(db.Model):
    """ORM model of a workflow job that runs as an FLApp custom object.

    While the job is STARTED its FLApp and pods are read live through the
    Kubernetes client; ``stop`` stores JSON snapshots of both so they stay
    readable afterwards.
    """
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.Enum(JobType, native_enum=False), nullable=False)
    state = db.Column(db.Enum(JobState, native_enum=False),
                      nullable=False,
                      default=JobState.INVALID)
    yaml_template = db.Column(db.Text())
    config = db.Column(db.LargeBinary())

    workflow_id = db.Column(db.Integer,
                            db.ForeignKey('workflow_v2.id'),
                            nullable=False,
                            index=True)
    project_id = db.Column(db.Integer,
                           db.ForeignKey(Project.id),
                           nullable=False)
    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    project = db.relationship(Project)
    workflow = db.relationship('Workflow')
    # Created once when the class body runs; shared by every instance
    _k8s_client = get_client()

    def get_config(self):
        """Return the config parsed as a JobDefinition, or None if unset."""
        if self.config is not None:
            proto = JobDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def _set_snapshot_flapp(self):
        """Fetch the FLApp custom object and store it as JSON."""
        flapp = self._k8s_client.get_custom_object(
            CrdKind.FLAPP, self.name, self.project.get_namespace())
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch the FLApp's pods and store them as JSON."""
        pods = self._k8s_client.list_resource_of_custom_object(
            CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        self.pods_snapshot = json.dumps(pods)

    def get_pods(self):
        """Return the job's pods: live when STARTED, else from the snapshot.

        Returns:
            The ``'pods'`` entry of the client response or of the stored
            snapshot. None when the live query raises RuntimeError or no
            snapshot exists.
        """
        if self.state == JobState.STARTED:
            try:
                pods = self._k8s_client.list_resource_of_custom_object(
                    CrdKind.FLAPP, self.name, 'pods',
                    self.project.get_namespace())
                return pods['pods']
            except RuntimeError as e:
                logging.error('Get %d pods error msg: %s', self.id, e.args)
                return None
        if self.pods_snapshot is not None:
            return json.loads(self.pods_snapshot)['pods']
        return None

    def get_flapp(self):
        """Return the job's FLApp: live when STARTED, else from the snapshot.

        Returns:
            The ``'flapp'`` entry of the client response or of the stored
            snapshot. None when the live query raises RuntimeError or no
            snapshot exists.
        """
        if self.state == JobState.STARTED:
            try:
                flapp = self._k8s_client.get_custom_object(
                    CrdKind.FLAPP, self.name, self.project.get_namespace())
                return flapp['flapp']
            except RuntimeError as e:
                logging.error('Get %d flapp error msg: %s', self.id, str(e))
                return None
        if self.flapp_snapshot is not None:
            return json.loads(self.flapp_snapshot)['flapp']
        return None

    def _get_flapp_status_field(self, key):
        """Return ``flapp['status'][key]``, or None when it cannot be read.

        A missing FLApp, a missing or null ``status`` and a missing ``key``
        all yield None instead of raising.
        """
        flapp = self.get_flapp()
        status = (flapp or {}).get('status') or {}
        return status.get(key)

    def get_pods_for_frontend(self):
        """Return one summary dict per pod, combining FLApp status and pods.

        Pods listed as failed or succeeded in the FLApp's ``flReplicaStatus``
        come first, followed by the pods returned by ``get_pods()``; when a
        name appears more than once, the later entry is the one kept.

        Returns:
            A list of dicts with ``name``, ``status`` and ``pod_type`` keys;
            entries built from ``get_pods()`` also carry ``conditions`` and,
            when present, ``containers_status``.
        """
        result = []
        flapp = self.get_flapp()
        if flapp is None:
            return result
        # ``status`` may be absent or null; both mean no replica information
        status = flapp.get('status') or {}
        if 'flReplicaStatus' in status:
            replicas = status['flReplicaStatus']
            if replicas is None:
                return result
            for pod_type, by_state in replicas.items():
                for state in ['failed', 'succeeded']:
                    # Tolerate a replica type with no (or a null) entry for
                    # this state rather than failing the whole listing.
                    for pod in ((by_state or {}).get(state) or []):
                        result.append({
                            'name': pod,
                            'status': 'Flapp_{}'.format(state),
                            'pod_type': pod_type
                        })
        # msg from pods
        pods = self.get_pods()
        if pods is None:
            return result
        pods = pods['items']
        for pod in pods:
            # TODO: make this more readable for frontend
            pod_for_front = {
                'name': pod['metadata']['name'],
                'pod_type': pod['metadata']['labels']['fl-replica-type'],
                'status': pod['status']['phase'],
                'conditions': pod['status']['conditions']
            }
            if 'containerStatuses' in pod['status']:
                pod_for_front['containers_status'] = \
                    pod['status']['containerStatuses']
            result.append(pod_for_front)
        # deduplication pods both in pods and flapp
        result = list({pod['name']: pod for pod in result}.values())
        return result

    def get_state_for_frontend(self):
        """Return the state label to show for this job.

        A STARTED job is reported as 'COMPLETED', 'FAILED' or 'RUNNING'
        depending on its FLApp; a STOPPED job without any FLApp is reported
        as 'NEW'; otherwise the name of the stored state is returned.
        """
        if self.state == JobState.STARTED:
            if self.is_complete():
                return 'COMPLETED'
            if self.is_failed():
                return 'FAILED'
            return 'RUNNING'
        if self.state == JobState.STOPPED:
            if self.get_flapp() is None:
                return 'NEW'
        return self.state.name

    def is_failed(self):
        """Report whether the FLApp's appState is a failure state."""
        return self._get_flapp_status_field('appState') in [
            'FLStateFailed', 'FLStateShutDown'
        ]

    def is_complete(self):
        """Report whether the FLApp's appState is 'FLStateComplete'."""
        return self._get_flapp_status_field('appState') == 'FLStateComplete'

    def get_complete_at(self):
        """Return the FLApp status' ``complete_at`` value, or None if unset."""
        return self._get_flapp_status_field('complete_at')

    def stop(self):
        """Move the job to STOPPED, tearing down its FLApp if it is running.

        For a STARTED job the FLApp and pod snapshots are stored first, then
        the FLApp custom object is deleted.
        """
        if self.state == JobState.STARTED:
            self._set_snapshot_flapp()
            self._set_snapshot_pods()
            self._k8s_client.delete_custom_object(CrdKind.FLAPP, self.name,
                                                  self.project.get_namespace())
        self.state = JobState.STOPPED

    def schedule(self):
        """Clear both snapshots of a STOPPED job and move it to WAITING.

        Raises:
            AssertionError: if the job is not STOPPED.
        """
        assert self.state == JobState.STOPPED
        self.pods_snapshot = None
        self.flapp_snapshot = None
        self.state = JobState.WAITING

    def start(self):
        """Set the job's state to STARTED (no precondition is checked)."""
        self.state = JobState.STARTED

    def set_yaml_template(self, yaml_template):
        """Store ``yaml_template`` on the job."""
        self.yaml_template = yaml_template
コード例 #26
0
class Workflow(db.Model):
    """ORM model of a workflow, its jobs and its transactional state machine.

    ``state`` is the current workflow state, ``target_state`` the state a
    running transaction is moving towards (INVALID when there is none), and
    ``transaction_state`` the phase of that transaction. Job ids and reused
    job names are stored as comma-separated text.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id))
    config = db.Column(db.Text())
    comment = db.Column(db.String(255))

    forkable = db.Column(db.Boolean, default=False)
    forked_from = db.Column(db.Integer, default=None)
    # index in config.job_defs instead of job's id
    reuse_job_names = db.Column(db.TEXT())
    peer_reuse_job_names = db.Column(db.TEXT())
    fork_proposal_config = db.Column(db.TEXT())

    recur_type = db.Column(db.Enum(RecurType), default=RecurType.NONE)
    recur_at = db.Column(db.Interval)
    trigger_dataset = db.Column(db.Integer)
    last_triggered_batch = db.Column(db.Integer)

    job_ids = db.Column(db.TEXT())

    state = db.Column(db.Enum(WorkflowState), default=WorkflowState.INVALID)
    target_state = db.Column(db.Enum(WorkflowState),
                             default=WorkflowState.INVALID)
    transaction_state = db.Column(db.Enum(TransactionState),
                                  default=TransactionState.READY)
    transaction_err = db.Column(db.Text())

    # Set in commit() as integer UTC timestamps (seconds)
    start_at = db.Column(db.Integer)
    stop_at = db.Column(db.Integer)

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())

    owned_jobs = db.relationship('Job', back_populates='workflow')
    project = db.relationship(Project)

    def set_config(self, proto):
        """Store ``proto`` serialized, or clear the config when it is None."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the config parsed as a WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def set_fork_proposal_config(self, proto):
        """Store the fork proposal serialized, or clear it when None."""
        if proto is not None:
            self.fork_proposal_config = proto.SerializeToString()
        else:
            self.fork_proposal_config = None

    def get_fork_proposal_config(self):
        """Return the fork proposal as a WorkflowDefinition, or None if unset."""
        if self.fork_proposal_config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.fork_proposal_config)
            return proto
        return None

    def set_job_ids(self, job_ids):
        """Store ``job_ids`` as a comma-separated string."""
        self.job_ids = ','.join([str(i) for i in job_ids])

    def get_job_ids(self):
        """Return the stored job ids as a list of ints (empty if unset)."""
        if not self.job_ids:
            return []
        return [int(i) for i in self.job_ids.split(',')]

    def get_jobs(self):
        """Return the Job rows for the stored job ids, in stored order."""
        return [Job.query.get(i) for i in self.get_job_ids()]

    def set_reuse_job_names(self, reuse_job_names):
        """Store ``reuse_job_names`` as a comma-separated string."""
        self.reuse_job_names = ','.join(reuse_job_names)

    def get_reuse_job_names(self):
        """Return the reused job names as a list (empty if unset)."""
        if not self.reuse_job_names:
            return []
        return self.reuse_job_names.split(',')

    def set_peer_reuse_job_names(self, peer_reuse_job_names):
        """Store ``peer_reuse_job_names`` as a comma-separated string."""
        self.peer_reuse_job_names = ','.join(peer_reuse_job_names)

    def get_peer_reuse_job_names(self):
        """Return the peer's reused job names as a list (empty if unset)."""
        if not self.peer_reuse_job_names:
            return []
        return self.peer_reuse_job_names.split(',')

    def update_target_state(self, target_state):
        """Validate ``target_state`` and record it as the pending target.

        Raises:
            ValueError: if a different target is already pending, if
                ``target_state`` is not READY, RUNNING or STOPPED, or if the
                transition from the current state is not a valid one.
        """
        if self.target_state != target_state \
                and self.target_state != WorkflowState.INVALID:
            raise ValueError(f'Another transaction is in progress [{self.id}]')
        if target_state not in [
                WorkflowState.READY, WorkflowState.RUNNING,
                WorkflowState.STOPPED
        ]:
            # Report the rejected value, not the currently stored target
            raise ValueError(f'Invalid target_state {target_state}')
        if (self.state, target_state) not in VALID_TRANSITIONS:
            raise ValueError(
                f'Invalid transition from {self.state} to {target_state}')
        self.target_state = target_state

    def update_state(self, asserted_state, target_state, transaction_state):
        """Advance the workflow's transaction and run the resulting actions.

        Args:
            asserted_state: expected current state, or None to skip the check.
            target_state: target handed to ``prepare`` when the transaction
                is in a prepare state.
            transaction_state: requested transaction state.

        Returns:
            The workflow's transaction state after processing.

        Raises:
            AssertionError: if the expected state does not match or the
                requested transaction transition is not a valid one.
        """
        assert asserted_state is None or self.state == asserted_state, \
            'Cannot change current state directly'

        if transaction_state != self.transaction_state:
            if (self.transaction_state, transaction_state) in \
                    IGNORED_TRANSACTION_TRANSITIONS:
                return self.transaction_state
            assert (self.transaction_state, transaction_state) in \
                   VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from {} to {}'.format(
                    self.transaction_state, transaction_state)
            self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            self.rollback()

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            self.rollback()
            self.transaction_state = TransactionState.ABORTED
        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self, target_state):
        """Validate ``target_state`` and mark the transaction committable.

        A None ``target_state`` is a no-op. An invalid target is logged and
        sets the transaction to ABORTED. A READY target is committable only
        when ``_prepare_for_ready`` succeeds.

        Raises:
            AssertionError: if the transaction is not in a prepare state.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
            'Workflow not in prepare state'

        # TODO(tjulinfan): remove this
        if target_state is None:
            # No action
            return

        # Validation
        try:
            self.update_target_state(target_state)
        except ValueError as e:
            logging.warning('Error during update target state in prepare: %s',
                            str(e))
            self.transaction_state = TransactionState.ABORTED
            return

        success = True
        if self.target_state == WorkflowState.READY:
            success = self._prepare_for_ready()

        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Drop the pending target state."""
        self.target_state = WorkflowState.INVALID

    # TODO: separate this method to another module
    def commit(self):
        """Apply the pending transition and reset the transaction.

        STOPPED stops every owned job, READY creates the jobs and clears the
        fork proposal, RUNNING schedules every owned job that is not manual.
        Afterwards ``state`` takes the target, ``target_state`` becomes
        INVALID and ``transaction_state`` READY.

        Raises:
            AssertionError: if the transaction is not in a committing state.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
                'Workflow not in prepare state'

        if self.target_state == WorkflowState.STOPPED:
            self.stop_at = int(datetime.utcnow().timestamp())
            for job in self.owned_jobs:
                job.stop()
        elif self.target_state == WorkflowState.READY:
            self._setup_jobs()
            self.fork_proposal_config = None
        elif self.target_state == WorkflowState.RUNNING:
            self.start_at = int(datetime.utcnow().timestamp())
            for job in self.owned_jobs:
                if not job.get_config().is_manual:
                    job.schedule()

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def _setup_jobs(self):
        """Create or reuse a Job per job definition and record dependencies.

        Job definitions named in ``reuse_job_names`` are mapped to the job at
        the same definition name in the workflow this one was forked from;
        all others get a new Job row. Dependencies are added only for newly
        created jobs. Commits the database session twice: once after the
        jobs are added, once after the dependencies.
        """
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found'%self.forked_from
            trunk_job_defs = trunk.get_config().job_definitions
            trunk_name2index = {
                job.name: i
                for i, job in enumerate(trunk_job_defs)
            }
        else:
            assert not self.get_reuse_job_names()

        job_defs = self.get_config().job_definitions
        jobs = []
        reuse_jobs = set(self.get_reuse_job_names())
        for job_def in job_defs:
            if job_def.name in reuse_jobs:
                assert job_def.name in trunk_name2index, \
                    "Job %s not found in base workflow"%job_def.name
                j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found'%j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(name=f'{self.name}-{job_def.name}',
                          job_type=JobType(job_def.type),
                          config=job_def.SerializeToString(),
                          workflow_id=self.id,
                          project_id=self.project_id,
                          state=JobState.STOPPED)
                job.set_yaml_template(job_def.yaml_template)
                db.session.add(job)
            jobs.append(job)
        # Commit so that newly added jobs have ids for the dependencies below
        db.session.commit()

        name2index = {job.name: i for i, job in enumerate(job_defs)}
        for job, job_def in zip(jobs, job_defs):
            # ``reuse_jobs`` holds job-definition names, while Job.name is
            # prefixed with a workflow name, so match on the definition name.
            if job_def.name in reuse_jobs:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])

        db.session.commit()

    def log_states(self):
        """Log the current state, target state and transaction state."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)

    def _get_peer_workflow(self):
        """Fetch this workflow, by name, from the project's first participant."""
        project_config = self.project.get_config()
        # TODO: find coordinator for multiparty
        client = RpcClient(project_config, project_config.participants[0])
        return client.get_workflow(self.name)

    def _prepare_for_ready(self):
        """Decide whether the workflow can become READY.

        The coordinator only needs a config. A participant asks the peer for
        its copy of the workflow; if that copy is a fork, the local workflow
        is set up as a fork of the same-named local base workflow, with the
        peer's fork proposal merged into the base config.

        Returns:
            True when the workflow can proceed to READY, False otherwise.
        """
        # This is a hack, if config is not set then
        # no action needed
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            # TODO(tjulinfan): validate if the config is legal or not
            return bool(self.config)
        peer_workflow = self._get_peer_workflow()
        if peer_workflow.forked_from:
            base_workflow = Workflow.query.filter(
                Workflow.name == peer_workflow.forked_from).first()
            if base_workflow is None or not base_workflow.forkable:
                return False
            self.forked_from = base_workflow.id
            self.forkable = base_workflow.forkable
            # The peer's view is mirrored: its "peer" names are ours
            self.set_reuse_job_names(peer_workflow.peer_reuse_job_names)
            self.set_peer_reuse_job_names(peer_workflow.reuse_job_names)
            config = base_workflow.get_config()
            _merge_workflow_config(config, peer_workflow.fork_proposal_config,
                                   [common_pb2.Variable.PEER_WRITABLE])
            self.set_config(config)
            return True
        return bool(self.config)
コード例 #27
0
class Job(db.Model):
    """ORM model of a single job that belongs to a workflow.

    While a job is ``STARTED`` its FLAPP custom object and pods are read live
    through the Kubernetes client. When the job is stopped, the last observed
    FLAPP object and pod listing are stored as JSON in ``flapp_snapshot`` and
    ``pods_snapshot`` so they can still be inspected afterwards.
    """
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.Enum(JobType), nullable=False)
    state = db.Column(db.Enum(JobState),
                      nullable=False,
                      default=JobState.INVALID)
    yaml_template = db.Column(db.Text(), nullable=False)
    config = db.Column(db.Text(), nullable=False)
    workflow_id = db.Column(db.Integer,
                            db.ForeignKey('workflow_v2.id'),
                            nullable=False,
                            index=True)
    project_id = db.Column(db.Integer,
                           db.ForeignKey(Project.id),
                           nullable=False)
    # JSON dumps taken by stop(); cleared again by schedule().
    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    project = db.relationship(Project)
    workflow = db.relationship('Workflow')
    # Shared by all Job instances; created once when the class is defined.
    _k8s_client = get_client()

    def get_config(self):
        """Return the stored config parsed as a ``JobDefinition``.

        Returns:
            The parsed ``JobDefinition`` proto, or None if no config is set.
        """
        if self.config is None:
            return None
        proto = JobDefinition()
        proto.ParseFromString(self.config)
        return proto

    def _set_snapshot_flapp(self):
        """Store the current FLAPP custom object as JSON in flapp_snapshot."""
        flapp = self._k8s_client.get_custom_object(
            CrdKind.FLAPP, self.name, self.project.get_namespace())
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Store the FLAPP's current pod listing as JSON in pods_snapshot."""
        pods = self._k8s_client.list_resource_of_custom_object(
            CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        self.pods_snapshot = json.dumps(pods)

    def get_flapp(self):
        """Return the FLAPP custom object of this job.

        Returns:
            The live custom object while the job is STARTED, otherwise the
            decoded ``flapp_snapshot``, or None if no snapshot exists.
        """
        if self.state == JobState.STARTED:
            # Must fetch the FLAPP object itself (same call that
            # _set_snapshot_flapp uses), not the pod listing; the status
            # helpers below read flapp['status'] from this result.
            return self._k8s_client.get_custom_object(
                CrdKind.FLAPP, self.name, self.project.get_namespace())
        if self.flapp_snapshot is not None:
            return json.loads(self.flapp_snapshot)
        return None

    def get_pods(self):
        """Return the pod listing of this job's FLAPP.

        Returns:
            The live listing while the job is STARTED, otherwise the decoded
            ``pods_snapshot``, or None if no snapshot exists.
        """
        if self.state == JobState.STARTED:
            return self._k8s_client.list_resource_of_custom_object(
                CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        if self.pods_snapshot is not None:
            return json.loads(self.pods_snapshot)
        return None

    def get_pods_for_front(self):
        """Flatten ``status.flReplicaStatus`` of the FLAPP into a list.

        Returns:
            A list of dicts with keys ``name``, ``state`` and ``pod_type``,
            one per pod entry; empty if the FLAPP or the status is missing.
        """
        result = []
        flapp = self.get_flapp()
        if flapp is not None \
                and 'status' in flapp \
                and 'flReplicaStatus' in flapp['status']:
            replicas = flapp['status']['flReplicaStatus']
            for pod_type in replicas:
                for state in replicas[pod_type]:
                    for pod in replicas[pod_type][state]:
                        result.append({
                            'name': pod,
                            'state': state,
                            'pod_type': pod_type
                        })
        return result

    def get_state_for_front(self):
        """Return the state name shown to the frontend.

        A STARTED job is reported as ``COMPLETE``, ``FAILED`` or ``RUNNING``
        depending on the FLAPP status; a STOPPED job that has no FLAPP (live
        or snapshot) is reported as ``NEW``. Any other case yields the name
        of ``self.state``.
        """
        if self.state == JobState.STARTED:
            if self.is_complete():
                return 'COMPLETE'
            if self.is_failed():
                return 'FAILED'
            return 'RUNNING'
        if self.state == JobState.STOPPED:
            if self.get_flapp() is None:
                return 'NEW'
        return self.state.name

    def _get_flapp_status_field(self, field):
        """Return ``flapp['status'][field]``, or None if any part is absent."""
        flapp = self.get_flapp()
        if flapp is None \
                or 'status' not in flapp \
                or field not in flapp['status']:
            return None
        return flapp['status'][field]

    def is_failed(self):
        """Return True if the FLAPP's appState is a failed/shut-down state."""
        return self._get_flapp_status_field('appState') in [
            'FLStateFailed', 'FLStateShutDown'
        ]

    def is_complete(self):
        """Return True if the FLAPP's appState is ``FLStateComplete``."""
        return self._get_flapp_status_field('appState') == 'FLStateComplete'

    def get_complete_at(self):
        """Return ``status.complete_at`` of the FLAPP, or None if absent."""
        return self._get_flapp_status_field('complete_at')

    def stop(self):
        """Stop the job and mark it STOPPED.

        If the job is currently STARTED, the FLAPP and pod listing are
        snapshotted first and the FLAPP custom object is then deleted.
        """
        if self.state == JobState.STARTED:
            # Snapshot before deleting so the data stays available.
            self._set_snapshot_flapp()
            self._set_snapshot_pods()
            self._k8s_client.delete_custom_object(CrdKind.FLAPP, self.name,
                                                  self.project.get_namespace())
        self.state = JobState.STOPPED

    def schedule(self):
        """Move a STOPPED job to WAITING and discard its old snapshots.

        Raises:
            AssertionError: if the job is not currently STOPPED.
        """
        assert self.state == JobState.STOPPED
        self.pods_snapshot = None
        self.flapp_snapshot = None
        self.state = JobState.WAITING

    def start(self):
        """Mark the job as STARTED."""
        self.state = JobState.STARTED

    def set_yaml_template(self, yaml_template):
        """Replace the job's YAML template."""
        self.yaml_template = yaml_template
コード例 #28
0
ファイル: models.py プロジェクト: piiswrong/fedlearner
class Workflow(db.Model):
    """ORM model of a workflow together with its transaction state machine.

    Besides the persisted columns, the class implements the state
    transitions driven by ``update_state``: a target state is validated in
    ``prepare``, applied in ``commit`` and discarded in ``rollback``. Jobs
    owned by the workflow are created in ``_setup_jobs`` when the workflow
    is committed to READY.
    """
    __tablename__ = 'workflow_v2'
    __table_args__ = (UniqueConstraint('uuid', name='uniq_uuid'),
                      UniqueConstraint('name', name='uniq_name'), {
                          'comment': 'workflow_v2',
                          'mysql_engine': 'innodb',
                          'mysql_charset': 'utf8mb4',
                      })
    id = db.Column(db.Integer, primary_key=True, comment='id')
    uuid = db.Column(db.String(64), comment='uuid')
    name = db.Column(db.String(255), comment='name')
    project_id = db.Column(db.Integer, comment='project_id')
    # max store 16777215 bytes (16 MB)
    config = db.Column(db.LargeBinary(16777215), comment='config')
    comment = db.Column('cmt',
                        db.String(255),
                        key='comment',
                        comment='comment')

    metric_is_public = db.Column(db.Boolean(),
                                 default=False,
                                 nullable=False,
                                 comment='metric_is_public')
    # Comma-separated ints, one flag per job definition in config.
    create_job_flags = db.Column(db.TEXT(), comment='create_job_flags')

    # Comma-separated job ids, see set_job_ids()/get_job_ids().
    job_ids = db.Column(db.TEXT(), comment='job_ids')

    forkable = db.Column(db.Boolean, default=False, comment='forkable')
    forked_from = db.Column(db.Integer, default=None, comment='forked_from')
    # index in config.job_defs instead of job's id
    peer_create_job_flags = db.Column(db.TEXT(),
                                      comment='peer_create_job_flags')
    # max store 16777215 bytes (16 MB)
    fork_proposal_config = db.Column(db.LargeBinary(16777215),
                                     comment='fork_proposal_config')

    recur_type = db.Column(db.Enum(RecurType, native_enum=False),
                           default=RecurType.NONE,
                           comment='recur_type')
    recur_at = db.Column(db.Interval, comment='recur_at')
    trigger_dataset = db.Column(db.Integer, comment='trigger_dataset')
    last_triggered_batch = db.Column(db.Integer,
                                     comment='last_triggered_batch')

    state = db.Column(db.Enum(WorkflowState,
                              native_enum=False,
                              name='workflow_state'),
                      default=WorkflowState.INVALID,
                      comment='state')
    target_state = db.Column(db.Enum(WorkflowState,
                                     native_enum=False,
                                     name='workflow_target_state'),
                             default=WorkflowState.INVALID,
                             comment='target_state')
    transaction_state = db.Column(db.Enum(TransactionState, native_enum=False),
                                  default=TransactionState.READY,
                                  comment='transaction_state')
    transaction_err = db.Column(db.Text(), comment='transaction_err')

    # Unix timestamps in seconds, written by commit().
    start_at = db.Column(db.Integer, comment='start_at')
    stop_at = db.Column(db.Integer, comment='stop_at')

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           comment='created_at')
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now(),
                           comment='update_at')

    owned_jobs = db.relationship(
        'Job', primaryjoin='foreign(Job.workflow_id) == Workflow.id')
    project = db.relationship(
        'Project', primaryjoin='Project.id == foreign(Workflow.project_id)')

    @staticmethod
    def _utc_timestamp():
        """Return the current Unix time in whole seconds.

        An aware UTC datetime is used on purpose: calling ``timestamp()`` on
        the naive result of ``datetime.utcnow()`` interprets it as local
        time and yields a shifted epoch on hosts that do not run in UTC.
        """
        # Imported locally because only `datetime` is known to be in scope
        # at file level.
        from datetime import timezone  # pylint: disable=import-outside-toplevel
        return int(datetime.now(timezone.utc).timestamp())

    def get_state_for_frontend(self):
        """Return the state name shown to the frontend.

        A RUNNING workflow is reported as ``COMPLETED`` when every owned job
        is complete, or ``FAILED`` when any owned job has failed; otherwise
        the name of ``self.state`` is returned.
        """
        if self.state == WorkflowState.RUNNING:
            # Generators let all()/any() stop at the first decisive job
            # instead of querying the status of every job.
            if all(job.is_complete() for job in self.owned_jobs):
                return 'COMPLETED'
            if any(job.is_failed() for job in self.owned_jobs):
                return 'FAILED'
        return self.state.name

    def get_transaction_state_for_frontend(self):
        """Return the transaction state name shown to the frontend.

        A participant that is still in PARTICIPANT_PREPARE but already has a
        config is reported as ``PARTICIPANT_COMMITTABLE``.
        """
        # TODO(xiangyuxuan): remove this hack by redesign 2pc
        if (self.transaction_state == TransactionState.PARTICIPANT_PREPARE
                and self.config is not None):
            return 'PARTICIPANT_COMMITTABLE'
        return self.transaction_state.name

    def set_config(self, proto):
        """Store ``proto`` serialized in ``config``; None clears it."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return ``config`` parsed as a WorkflowDefinition, or None."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def set_fork_proposal_config(self, proto):
        """Store ``proto`` serialized in ``fork_proposal_config``.

        Passing None clears the column.
        """
        if proto is not None:
            self.fork_proposal_config = proto.SerializeToString()
        else:
            self.fork_proposal_config = None

    def get_fork_proposal_config(self):
        """Return ``fork_proposal_config`` parsed as a WorkflowDefinition.

        Returns None when the column is not set.
        """
        if self.fork_proposal_config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.fork_proposal_config)
            return proto
        return None

    def set_job_ids(self, job_ids):
        """Store ``job_ids`` as a comma-separated string."""
        self.job_ids = ','.join([str(i) for i in job_ids])

    def get_job_ids(self):
        """Return the stored job ids as a list of ints (empty if unset)."""
        if not self.job_ids:
            return []
        return [int(i) for i in self.job_ids.split(',')]

    def get_jobs(self):
        """Return the Job rows for the stored job ids, in stored order."""
        return [Job.query.get(i) for i in self.get_job_ids()]

    def set_create_job_flags(self, create_job_flags):
        """Store the create-job flags as a comma-separated string.

        None or an empty sequence clears the column, matching
        ``set_peer_create_job_flags``. Storing an empty string instead would
        make ``get_create_job_flags`` fail when parsing it.
        """
        if not create_job_flags:
            self.create_job_flags = None
        else:
            self.create_job_flags = ','.join(
                [str(i) for i in create_job_flags])

    def get_create_job_flags(self):
        """Return the create-job flags as a list of ints.

        When no flags are stored, every job definition in the config
        defaults to ``CreateJobFlag.NEW``.

        Returns:
            A list of int flags, or None when neither flags nor a config are
            stored.
        """
        # `not` also covers an empty string written by older code.
        if not self.create_job_flags:
            config = self.get_config()
            if config is None:
                return None
            num_jobs = len(config.job_definitions)
            return [common_pb2.CreateJobFlag.NEW] * num_jobs
        return [int(i) for i in self.create_job_flags.split(',')]

    def set_peer_create_job_flags(self, peer_create_job_flags):
        """Store the peer's create-job flags as a comma-separated string.

        None or an empty sequence clears the column.
        """
        if not peer_create_job_flags:
            self.peer_create_job_flags = None
        else:
            self.peer_create_job_flags = ','.join(
                [str(i) for i in peer_create_job_flags])

    def get_peer_create_job_flags(self):
        """Return the peer's create-job flags as ints, or None if unset."""
        if not self.peer_create_job_flags:
            return None
        return [int(i) for i in self.peer_create_job_flags.split(',')]

    def update_target_state(self, target_state):
        """Validate ``target_state`` and record it as the pending target.

        Args:
            target_state: the WorkflowState the workflow should move to.

        Raises:
            ValueError: if a different target is already pending, if
                ``target_state`` is not READY/RUNNING/STOPPED, or if the
                transition from the current state is not listed in
                ``VALID_TRANSITIONS``.
        """
        if self.target_state != target_state \
                and self.target_state != WorkflowState.INVALID:
            raise ValueError(f'Another transaction is in progress [{self.id}]')
        if target_state not in [
                WorkflowState.READY, WorkflowState.RUNNING,
                WorkflowState.STOPPED
        ]:
            # Report the rejected argument, not the currently stored target.
            raise ValueError(f'Invalid target_state {target_state}')
        if (self.state, target_state) not in VALID_TRANSITIONS:
            raise ValueError(
                f'Invalid transition from {self.state} to {target_state}')

        self.target_state = target_state

    def update_state(self, asserted_state, target_state, transaction_state):
        """Advance the transaction state machine by one step.

        Args:
            asserted_state: expected current ``state``, or None to skip the
                check.
            target_state: target passed on to ``prepare`` in prepare steps.
            transaction_state: the transaction state to move to.

        Returns:
            The transaction state after the step. Transitions listed in
            ``IGNORED_TRANSACTION_TRANSITIONS`` leave it unchanged.

        Raises:
            AssertionError: if ``asserted_state`` does not match or the
                transaction transition is not valid.
        """
        assert asserted_state is None or self.state == asserted_state, \
            'Cannot change current state directly'

        if transaction_state != self.transaction_state:
            if (self.transaction_state, transaction_state) in \
                    IGNORED_TRANSACTION_TRANSITIONS:
                return self.transaction_state
            assert (self.transaction_state, transaction_state) in \
                   VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from {} to {}'.format(
                    self.transaction_state, transaction_state)
            self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            self.rollback()

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            self.rollback()
            self.transaction_state = TransactionState.ABORTED
        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self, target_state):
        """Run the prepare step for moving to ``target_state``.

        On success the transaction state becomes COORDINATOR_COMMITTABLE or
        PARTICIPANT_COMMITTABLE. If ``target_state`` is rejected by
        ``update_target_state`` the transaction state becomes ABORTED. If
        preparing for READY is unsuccessful the transaction state is left
        unchanged.

        Raises:
            AssertionError: if the workflow is not in a prepare state.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
            'Workflow not in prepare state'

        # TODO(tjulinfan): remove this
        if target_state is None:
            # No action
            return

        # Validation
        try:
            self.update_target_state(target_state)
        except ValueError as e:
            logging.warning('Error during update target state in prepare: %s',
                            str(e))
            self.transaction_state = TransactionState.ABORTED
            return

        success = True
        if self.target_state == WorkflowState.READY:
            success = self._prepare_for_ready()

        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Discard the pending target state."""
        self.target_state = WorkflowState.INVALID

    # TODO: separate this method to another module
    def commit(self):
        """Apply the pending target state.

        STOPPED stops every owned job, READY creates the jobs and clears the
        fork proposal, RUNNING schedules every job that is not disabled.
        Afterwards ``state`` takes the target value and the target and
        transaction states are reset. If stopping a job raises RuntimeError,
        the error is logged and the method returns without changing
        ``state`` (``stop_at`` has already been written at that point).

        Raises:
            AssertionError: if the workflow is not in a committing state.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
            'Workflow not in committing state'

        if self.target_state == WorkflowState.STOPPED:
            self.stop_at = self._utc_timestamp()
            try:
                for job in self.owned_jobs:
                    job.stop()
            except RuntimeError as e:
                # errors from k8s
                logging.error('Stop workflow %d has Runtime error msg: %s',
                              self.id, e.args)
                return
        elif self.target_state == WorkflowState.READY:
            self._setup_jobs()
            self.fork_proposal_config = None
        elif self.target_state == WorkflowState.RUNNING:
            self.start_at = self._utc_timestamp()
            for job in self.owned_jobs:
                if not job.is_disabled:
                    job.schedule()

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def invalidate(self):
        """Reset the workflow to INVALID and stop its jobs best-effort.

        Failures while stopping an individual job are logged and do not
        prevent the remaining jobs from being stopped.
        """
        self.state = WorkflowState.INVALID
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY
        for job in self.owned_jobs:
            try:
                job.stop()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning(
                    'Error while stopping job %s during invalidation: %s',
                    job.name, repr(e))

    def _setup_jobs(self):
        """Create or reuse the jobs described by the config.

        For each job definition the matching create-job flag decides whether
        the job is taken over from the workflow this one was forked from
        (REUSE) or a new ``Job`` row is added. ``JobDependency`` rows are
        then added for every newly created job, and the resulting job ids
        are stored via ``set_job_ids``.

        Raises:
            AssertionError: if the source workflow or a reused job cannot be
                found, or if the number of flags does not match the number
                of job definitions.
        """
        # Defaults keep the REUSE branch well-defined for a workflow that
        # was not forked: its name lookup then fails the assertion below
        # instead of hitting an unbound local.
        trunk = None
        trunk_name2index = {}
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found' % self.forked_from
            trunk_job_defs = trunk.get_config().job_definitions
            trunk_name2index = {
                trunk_def.name: i
                for i, trunk_def in enumerate(trunk_job_defs)
            }

        job_defs = self.get_config().job_definitions
        flags = self.get_create_job_flags()
        assert len(job_defs) == len(flags), \
            'Number of job defs does not match number of create_job_flags ' \
            '%d vs %d'%(len(job_defs), len(flags))
        jobs = []
        for job_def, flag in zip(job_defs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                assert job_def.name in trunk_name2index, \
                    "Job %s not found in base workflow" % job_def.name
                j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found' % j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(
                    name=f'{self.uuid}-{job_def.name}',
                    job_type=JobType(job_def.job_type),
                    config=job_def.SerializeToString(),
                    workflow_id=self.id,
                    project_id=self.project_id,
                    state=JobState.STOPPED,
                    is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
                job.set_yaml_template(job_def.yaml_template)
                db.session.add(job)
            jobs.append(job)
        # Flush so that newly added jobs get their ids assigned before the
        # dependency rows reference them.
        db.session.flush()
        name2index = {job_def.name: i for i, job_def in enumerate(job_defs)}
        for job, flag in zip(jobs, flags):
            if flag == common_pb2.CreateJobFlag.REUSE:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])

    def log_states(self):
        """Write a debug log line with the workflow id and its states."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)

    def _get_peer_workflow(self):
        """Fetch the workflow with the same name from the first participant.

        Returns:
            Whatever ``RpcClient.get_workflow`` returns for ``self.name``.
        """
        project_config = self.project.get_config()
        # TODO: find coordinator for multiparty
        client = RpcClient(project_config, project_config.participants[0])
        return client.get_workflow(self.name)

    def _prepare_for_ready(self):
        """Check, and if needed set up, the transition to the READY state.

        On the coordinator side only a stored config is required. On the
        participant side a forked workflow is populated from its base
        workflow: the create-job flags are taken from the peer (mirrored)
        and the base config is merged with the peer's fork proposal.

        Returns:
            bool: True when the workflow is ready to proceed, False when the
            base workflow is missing or not forkable, or when no config is
            stored.
        """
        # This is a hack, if config is not set then
        # no action needed
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            # TODO(tjulinfan): validate if the config is legal or not
            return bool(self.config)
        if self.forked_from:
            peer_workflow = self._get_peer_workflow()
            base_workflow = Workflow.query.get(self.forked_from)
            if base_workflow is None or not base_workflow.forkable:
                return False
            self.forked_from = base_workflow.id
            self.forkable = base_workflow.forkable
            # The peer's view is mirrored: its "peer" flags are ours and
            # vice versa.
            self.set_create_job_flags(peer_workflow.peer_create_job_flags)
            self.set_peer_create_job_flags(peer_workflow.create_job_flags)
            config = base_workflow.get_config()
            _merge_workflow_config(config, peer_workflow.fork_proposal_config,
                                   [common_pb2.Variable.PEER_WRITABLE])
            self.set_config(config)
            return True
        return bool(self.config)