Ejemplo n.º 1
0
class DataBatch(db.Model):
    """A batch of a dataset, identified by (event_time, dataset_id).

    The ``source`` column carries a serialized ``dataset_pb2.DatasetSource``
    proto; use set_source()/get_source() instead of handling raw bytes.
    """
    __tablename__ = 'data_batches_v2'
    __table_args__ = (PrimaryKeyConstraint('event_time', 'dataset_id'), )
    event_time = db.Column(db.TIMESTAMP(timezone=True))
    dataset_id = db.Column(db.Integer, db.ForeignKey(Dataset.id))
    state = db.Column(db.Enum(BatchState))
    # Serialized DatasetSource proto; defaults to an empty proto's bytes.
    # NOTE(review): declared as Text while holding SerializeToString() bytes
    # -- confirm the DB driver round-trips binary content for this column.
    source = db.Column(db.Text(),
                       default=dataset_pb2.DatasetSource().SerializeToString())
    failed_source = db.Column(db.Text())
    file_size = db.Column(db.Integer, default=0)
    imported_file_num = db.Column(db.Integer, default=0)
    num_file = db.Column(db.Integer, default=0)
    comment = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           server_onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    dataset = db.relationship('Dataset', back_populates='data_batches')

    def set_source(self, proto):
        """Serialize ``proto`` and store the resulting bytes in ``source``."""
        serialized = proto.SerializeToString()
        self.source = serialized

    def get_source(self):
        """Parse ``source`` into a ``DatasetSource`` proto.

        Returns:
            The parsed proto, or None when no source has been stored.
        """
        raw = self.source
        if raw is None:
            return None
        parsed = dataset_pb2.DatasetSource()
        parsed.ParseFromString(raw)
        return parsed
Ejemplo n.º 2
0
class DataBatch(db.Model):
    """A batch of a dataset with per-file import bookkeeping.

    ``details`` stores a serialized ``dataset_pb2.DataBatch`` proto; the
    aggregate counters (num_file, num_imported_file, file_size) and the batch
    state are derived from it in set_details().
    """
    __tablename__ = 'data_batches_v2'
    __table_args__ = (UniqueConstraint('event_time', 'dataset_id'), )
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    event_time = db.Column(db.TIMESTAMP(timezone=True), nullable=False)
    dataset_id = db.Column(db.Integer, db.ForeignKey(Dataset.id))
    path = db.Column(db.String(512))
    state = db.Column(db.Enum(BatchState, native_enum=False),
                      default=BatchState.NEW)
    move = db.Column(db.Boolean, default=False)
    # Serialized proto of DatasetBatch
    details = db.Column(db.LargeBinary())
    file_size = db.Column(db.Integer, default=0)
    num_imported_file = db.Column(db.Integer, default=0)
    num_file = db.Column(db.Integer, default=0)
    # Stored in the DB column named 'cmt'; exposed as ``comment``.
    comment = db.Column('cmt', db.Text(), key='comment')
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           server_onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    dataset = db.relationship('Dataset', back_populates='data_batches')

    def set_details(self, proto):
        """Store ``proto`` and refresh the aggregate file statistics.

        Counts COMPLETED and FAILED files, sums the size of COMPLETED files,
        and -- once every file is either COMPLETED or FAILED -- moves the
        batch to FAILED (any failure) or SUCCESS. Otherwise ``state`` is left
        untouched.
        """
        files = proto.files
        total = len(files)
        self.num_file = total
        completed_state = dataset_pb2.File.State.COMPLETED
        failed_state = dataset_pb2.File.State.FAILED
        completed = [f for f in files if f.state == completed_state]
        failed_count = sum(1 for f in files if f.state == failed_state)
        imported_count = len(completed)
        # Only settle the batch state when no file is still pending.
        if imported_count + failed_count == total:
            self.state = (BatchState.FAILED
                          if failed_count > 0 else BatchState.SUCCESS)
        self.num_imported_file = imported_count
        self.file_size = sum(f.size for f in completed)
        self.details = proto.SerializeToString()

    def get_details(self):
        """Parse ``details`` into a ``dataset_pb2.DataBatch`` proto.

        Returns:
            The parsed proto, or None when no details have been stored.
        """
        raw = self.details
        if raw is None:
            return None
        parsed = dataset_pb2.DataBatch()
        parsed.ParseFromString(raw)
        return parsed
Ejemplo n.º 3
0
class Job(db.Model):
    """Database model for a workflow job that runs as a FLApp in Kubernetes.

    While the job is STARTED, FLApp and pod information is fetched live
    through ``_k8s_client``. stop() stores the last observed FLApp/pod JSON in
    ``flapp_snapshot`` / ``pods_snapshot``, which are served afterwards.
    """
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.Enum(JobType, native_enum=False), nullable=False)
    state = db.Column(db.Enum(JobState, native_enum=False),
                      nullable=False,
                      default=JobState.INVALID)
    yaml_template = db.Column(db.Text())
    # Serialized JobDefinition proto (see get_config).
    config = db.Column(db.LargeBinary())

    workflow_id = db.Column(db.Integer,
                            db.ForeignKey('workflow_v2.id'),
                            nullable=False,
                            index=True)
    project_id = db.Column(db.Integer,
                           db.ForeignKey(Project.id),
                           nullable=False)
    # JSON snapshots captured by stop(); used once the job is no longer live.
    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    project = db.relationship(Project)
    workflow = db.relationship('Workflow')
    # NOTE(review): created once at class-definition time and shared by all
    # instances.
    _k8s_client = get_client()

    # FLApp ``status.appState`` values interpreted by the state helpers.
    _COMPLETE_APP_STATE = 'FLStateComplete'
    _FAILED_APP_STATES = ('FLStateFailed', 'FLStateShutDown')

    def get_config(self):
        """Return the parsed JobDefinition proto, or None if config is unset."""
        if self.config is None:
            return None
        proto = JobDefinition()
        proto.ParseFromString(self.config)
        return proto

    def _set_snapshot_flapp(self):
        """Fetch the FLApp from Kubernetes and store it as a JSON snapshot."""
        flapp = self._k8s_client.get_custom_object(
            CrdKind.FLAPP, self.name, self.project.get_namespace())
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch the FLApp's pods from Kubernetes and store them as JSON."""
        pods = self._k8s_client.list_resource_of_custom_object(
            CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        self.pods_snapshot = json.dumps(pods)

    def get_pods(self):
        """Return the ``'pods'`` payload for this job.

        Read live from Kubernetes while STARTED (logs and returns None if the
        client raises RuntimeError); otherwise read from ``pods_snapshot``.
        Returns None when neither source is available.
        """
        if self.state == JobState.STARTED:
            try:
                pods = self._k8s_client.list_resource_of_custom_object(
                    CrdKind.FLAPP, self.name, 'pods',
                    self.project.get_namespace())
                return pods['pods']
            except RuntimeError as e:
                logging.error('Get %d pods error msg: %s', self.id, str(e))
                return None
        if self.pods_snapshot is not None:
            return json.loads(self.pods_snapshot)['pods']
        return None

    def get_flapp(self):
        """Return the ``'flapp'`` payload for this job.

        Read live from Kubernetes while STARTED (logs and returns None if the
        client raises RuntimeError); otherwise read from ``flapp_snapshot``.
        Returns None when neither source is available.
        """
        if self.state == JobState.STARTED:
            try:
                flapp = self._k8s_client.get_custom_object(
                    CrdKind.FLAPP, self.name, self.project.get_namespace())
                return flapp['flapp']
            except RuntimeError as e:
                logging.error('Get %d flapp error msg: %s', self.id, str(e))
                return None
        if self.flapp_snapshot is not None:
            return json.loads(self.flapp_snapshot)['flapp']
        return None

    @staticmethod
    def _flapp_status(flapp):
        """Return ``flapp['status']``, or {} if flapp or its status is empty."""
        if flapp is None:
            return {}
        return flapp.get('status') or {}

    def get_pods_for_frontend(self):
        """Build the pod list shown in the frontend.

        Combines pods recorded as failed/succeeded in the FLApp's
        ``flReplicaStatus`` with the pods returned by get_pods(), then
        deduplicates by pod name (the get_pods() entry wins). Returns [] when
        there is no FLApp.
        """
        flapp = self.get_flapp()
        if flapp is None:
            return []
        result = []
        # Pods already recorded by the FLApp status. A missing/None
        # flReplicaStatus must not hide the live pod list below.
        replicas = self._flapp_status(flapp).get('flReplicaStatus') or {}
        for pod_type, replica_status in replicas.items():
            replica_status = replica_status or {}
            for state in ['failed', 'succeeded']:
                # A replica type may not have an entry for every state yet.
                for pod in replica_status.get(state) or ():
                    result.append({
                        'name': pod,
                        'status': 'Flapp_{}'.format(state),
                        'pod_type': pod_type
                    })
        # msg from pods
        pods = self.get_pods()
        if pods is not None:
            for pod in pods.get('items') or ():
                metadata = pod['metadata']
                status = pod['status']
                # TODO: make this more readable for frontend
                pod_for_front = {
                    'name': metadata['name'],
                    'pod_type': (metadata.get('labels')
                                 or {}).get('fl-replica-type'),
                    'status': status.get('phase'),
                    'conditions': status.get('conditions')
                }
                if 'containerStatuses' in status:
                    pod_for_front['containers_status'] = \
                        status['containerStatuses']
                result.append(pod_for_front)
        # deduplication pods both in pods and flapp
        return list({pod['name']: pod for pod in result}.values())

    def get_state_for_frontend(self):
        """Map the job state to the string shown in the frontend.

        For STARTED jobs the FLApp is fetched once and its appState selects
        COMPLETED, FAILED or RUNNING. STOPPED jobs without an FLApp snapshot
        are reported as NEW; anything else returns the JobState name.
        """
        if self.state == JobState.STARTED:
            # Single fetch so both checks see the same FLApp state.
            app_state = self._flapp_status(self.get_flapp()).get('appState')
            if app_state == self._COMPLETE_APP_STATE:
                return 'COMPLETED'
            if app_state in self._FAILED_APP_STATES:
                return 'FAILED'
            return 'RUNNING'
        if self.state == JobState.STOPPED and self.get_flapp() is None:
            return 'NEW'
        return self.state.name

    def is_failed(self):
        """True if the FLApp appState is FLStateFailed or FLStateShutDown."""
        status = self._flapp_status(self.get_flapp())
        return status.get('appState') in self._FAILED_APP_STATES

    def is_complete(self):
        """True if the FLApp appState is FLStateComplete."""
        status = self._flapp_status(self.get_flapp())
        return status.get('appState') == self._COMPLETE_APP_STATE

    def get_complete_at(self):
        """Return the FLApp's ``status.complete_at`` value, or None."""
        return self._flapp_status(self.get_flapp()).get('complete_at')

    def stop(self):
        """Mark the job STOPPED.

        If the job is STARTED, its FLApp and pods are snapshotted first and
        the FLApp custom object is then deleted from Kubernetes.
        """
        if self.state == JobState.STARTED:
            self._set_snapshot_flapp()
            self._set_snapshot_pods()
            self._k8s_client.delete_custom_object(CrdKind.FLAPP, self.name,
                                                  self.project.get_namespace())
        self.state = JobState.STOPPED

    def schedule(self):
        """Move a STOPPED job to WAITING and clear its previous snapshots."""
        assert self.state == JobState.STOPPED
        self.pods_snapshot = None
        self.flapp_snapshot = None
        self.state = JobState.WAITING

    def start(self):
        """Mark the job as STARTED."""
        self.state = JobState.STARTED

    def set_yaml_template(self, yaml_template):
        """Store the job's YAML template text."""
        self.yaml_template = yaml_template
Ejemplo n.º 4
0
class Workflow(db.Model):
    """Database model for a workflow and its transactional state changes.

    State changes go through a prepare/commit/abort transaction tracked in
    ``transaction_state`` (as coordinator or participant); ``target_state``
    holds the state being moved to and is INVALID when no transition is in
    progress.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id))
    # Serialized WorkflowDefinition proto (see set_config / get_config).
    config = db.Column(db.Text())
    comment = db.Column(db.String(255))

    forkable = db.Column(db.Boolean, default=False)
    # Id of the workflow this one was forked from, if any.
    forked_from = db.Column(db.Integer, default=None)
    # Comma-separated job definition names reused from the forked_from
    # workflow, for this side and for the peer side respectively.
    reuse_job_names = db.Column(db.TEXT())
    peer_reuse_job_names = db.Column(db.TEXT())
    # Serialized WorkflowDefinition proto (see set_fork_proposal_config).
    fork_proposal_config = db.Column(db.TEXT())

    recur_type = db.Column(db.Enum(RecurType), default=RecurType.NONE)
    recur_at = db.Column(db.Interval)
    trigger_dataset = db.Column(db.Integer)
    last_triggered_batch = db.Column(db.Integer)

    # Comma-separated Job ids, in config.job_definitions order.
    job_ids = db.Column(db.TEXT())

    state = db.Column(db.Enum(WorkflowState), default=WorkflowState.INVALID)
    target_state = db.Column(db.Enum(WorkflowState),
                             default=WorkflowState.INVALID)
    transaction_state = db.Column(db.Enum(TransactionState),
                                  default=TransactionState.READY)
    transaction_err = db.Column(db.Text())

    # Unix timestamps (seconds) set in commit().
    start_at = db.Column(db.Integer)
    stop_at = db.Column(db.Integer)

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           onupdate=func.now(),
                           server_default=func.now())

    owned_jobs = db.relationship('Job', back_populates='workflow')
    project = db.relationship(Project)

    def set_config(self, proto):
        """Serialize ``proto`` into ``config`` (None clears it)."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the parsed WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def set_fork_proposal_config(self, proto):
        """Serialize ``proto`` into ``fork_proposal_config`` (None clears)."""
        if proto is not None:
            self.fork_proposal_config = proto.SerializeToString()
        else:
            self.fork_proposal_config = None

    def get_fork_proposal_config(self):
        """Return the parsed fork proposal WorkflowDefinition, or None."""
        if self.fork_proposal_config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.fork_proposal_config)
            return proto
        return None

    def set_job_ids(self, job_ids):
        """Store ``job_ids`` as a comma-separated string."""
        self.job_ids = ','.join([str(i) for i in job_ids])

    def get_job_ids(self):
        """Return the stored job ids as a list of ints ([] if unset)."""
        if not self.job_ids:
            return []
        return [int(i) for i in self.job_ids.split(',')]

    def get_jobs(self):
        """Load the Job rows for get_job_ids(), one query per id."""
        return [Job.query.get(i) for i in self.get_job_ids()]

    def set_reuse_job_names(self, reuse_job_names):
        """Store reused job definition names as a comma-separated string."""
        self.reuse_job_names = ','.join(reuse_job_names)

    def get_reuse_job_names(self):
        """Return reused job definition names ([] if unset)."""
        if not self.reuse_job_names:
            return []
        return self.reuse_job_names.split(',')

    def set_peer_reuse_job_names(self, peer_reuse_job_names):
        """Store the peer's reused job names as a comma-separated string."""
        self.peer_reuse_job_names = ','.join(peer_reuse_job_names)

    def get_peer_reuse_job_names(self):
        """Return the peer's reused job names ([] if unset)."""
        if not self.peer_reuse_job_names:
            return []
        return self.peer_reuse_job_names.split(',')

    def update_target_state(self, target_state):
        """Validate and record ``target_state`` for a pending transition.

        Raises:
            ValueError: if a different transition is already in progress,
                ``target_state`` is not READY/RUNNING/STOPPED, or the
                transition from the current state is not allowed.
        """
        if self.target_state != target_state \
                and self.target_state != WorkflowState.INVALID:
            raise ValueError(f'Another transaction is in progress [{self.id}]')
        if target_state not in [
                WorkflowState.READY, WorkflowState.RUNNING,
                WorkflowState.STOPPED
        ]:
            # Report the rejected value, not the current target_state.
            raise ValueError(f'Invalid target_state {target_state}')
        if (self.state, target_state) not in VALID_TRANSITIONS:
            raise ValueError(
                f'Invalid transition from {self.state} to {target_state}')
        self.target_state = target_state

    def update_state(self, asserted_state, target_state, transaction_state):
        """Advance the transaction state machine and run the matching step.

        Args:
            asserted_state: expected current ``state``, or None to skip.
            target_state: state to prepare for (passed to prepare()).
            transaction_state: requested new transaction state.

        Returns:
            The resulting ``transaction_state``.
        """
        assert asserted_state is None or self.state == asserted_state, \
            'Cannot change current state directly'

        if transaction_state != self.transaction_state:
            if (self.transaction_state, transaction_state) in \
                    IGNORED_TRANSACTION_TRANSITIONS:
                return self.transaction_state
            assert (self.transaction_state, transaction_state) in \
                   VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from {} to {}'.format(
                    self.transaction_state, transaction_state)
            self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            self.rollback()

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            self.prepare(target_state)
        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            self.rollback()
            self.transaction_state = TransactionState.ABORTED
        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self, target_state):
        """Run the prepare phase for ``target_state``.

        On validation failure the transaction is marked ABORTED; on success
        it becomes COORDINATOR_COMMITTABLE or PARTICIPANT_COMMITTABLE.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
            'Workflow not in prepare state'

        # TODO(tjulinfan): remove this
        if target_state is None:
            # No action
            return

        # Validation
        try:
            self.update_target_state(target_state)
        except ValueError as e:
            logging.warning('Error during update target state in prepare: %s',
                            str(e))
            self.transaction_state = TransactionState.ABORTED
            return

        success = True
        if self.target_state == WorkflowState.READY:
            success = self._prepare_for_ready()

        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Abandon the pending transition by clearing ``target_state``."""
        self.target_state = WorkflowState.INVALID

    # TODO: separate this method to another module
    def commit(self):
        """Apply the pending transition and reset the transaction.

        STOPPED stops every owned job; READY creates the workflow's jobs;
        RUNNING schedules every owned job whose config is not manual.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
                'Workflow not in commit state'

        if self.target_state == WorkflowState.STOPPED:
            self.stop_at = int(datetime.utcnow().timestamp())
            for job in self.owned_jobs:
                job.stop()
        elif self.target_state == WorkflowState.READY:
            self._setup_jobs()
            self.fork_proposal_config = None
        elif self.target_state == WorkflowState.RUNNING:
            self.start_at = int(datetime.utcnow().timestamp())
            for job in self.owned_jobs:
                if not job.get_config().is_manual:
                    job.schedule()

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def _setup_jobs(self):
        """Create or reuse one Job per job definition and wire dependencies.

        Job definitions listed in reuse_job_names are mapped to the existing
        Job of the forked_from workflow; every other definition gets a new
        Job. JobDependency rows are created only for the new jobs, and the
        resulting ids are stored via set_job_ids().
        """
        trunk_job_ids = []
        trunk_name2index = {}
        if self.forked_from is not None:
            trunk = Workflow.query.get(self.forked_from)
            assert trunk is not None, \
                'Source workflow %d not found'%self.forked_from
            trunk_job_defs = trunk.get_config().job_definitions
            trunk_name2index = {
                job_def.name: i
                for i, job_def in enumerate(trunk_job_defs)
            }
            # Parsed once instead of on every reused job.
            trunk_job_ids = trunk.get_job_ids()
        else:
            assert not self.get_reuse_job_names()

        job_defs = self.get_config().job_definitions
        jobs = []
        reuse_jobs = set(self.get_reuse_job_names())
        for job_def in job_defs:
            if job_def.name in reuse_jobs:
                assert job_def.name in trunk_name2index, \
                    "Job %s not found in base workflow"%job_def.name
                j = trunk_job_ids[trunk_name2index[job_def.name]]
                job = Job.query.get(j)
                assert job is not None, \
                    'Job %d not found'%j
                # TODO: check forked jobs does not depend on non-forked jobs
            else:
                job = Job(name=f'{self.name}-{job_def.name}',
                          job_type=JobType(job_def.type),
                          config=job_def.SerializeToString(),
                          workflow_id=self.id,
                          project_id=self.project_id,
                          state=JobState.STOPPED)
                job.set_yaml_template(job_def.yaml_template)
                db.session.add(job)
            jobs.append(job)
        # Commit so the new jobs get ids before dependencies reference them.
        db.session.commit()

        name2index = {job_def.name: i for i, job_def in enumerate(job_defs)}
        for job_def, job in zip(job_defs, jobs):
            # Match on the definition name: a Job's own name is prefixed with
            # its workflow name and never equals a reuse entry. Reused jobs
            # already have their dependencies from the base workflow.
            if job_def.name in reuse_jobs:
                continue
            for j, dep_def in enumerate(job.get_config().dependencies):
                dep = JobDependency(
                    src_job_id=jobs[name2index[dep_def.source]].id,
                    dst_job_id=job.id,
                    dep_index=j)
                db.session.add(dep)

        self.set_job_ids([job.id for job in jobs])

        db.session.commit()

    def log_states(self):
        """Log the current state, target state and transaction state."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)

    def _get_peer_workflow(self):
        """Fetch this workflow from the first project participant via RPC."""
        project_config = self.project.get_config()
        # TODO: find coordinator for multiparty
        client = RpcClient(project_config, project_config.participants[0])
        return client.get_workflow(self.name)

    def _prepare_for_ready(self):
        """Check (and for a forked peer, build) the config before READY.

        Returns:
            True if the workflow has a usable config, False otherwise.
        """
        # This is a hack, if config is not set then
        # no action needed
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            # TODO(tjulinfan): validate if the config is legal or not
            return bool(self.config)
        peer_workflow = self._get_peer_workflow()
        if peer_workflow.forked_from:
            base_workflow = Workflow.query.filter(
                Workflow.name == peer_workflow.forked_from).first()
            if base_workflow is None or not base_workflow.forkable:
                return False
            self.forked_from = base_workflow.id
            self.forkable = base_workflow.forkable
            # The peer's reuse lists are mirrored: its "peer" list is ours.
            self.set_reuse_job_names(peer_workflow.peer_reuse_job_names)
            self.set_peer_reuse_job_names(peer_workflow.reuse_job_names)
            config = base_workflow.get_config()
            _merge_workflow_config(config, peer_workflow.fork_proposal_config,
                                   [common_pb2.Variable.PEER_WRITABLE])
            self.set_config(config)
            return True
        return bool(self.config)
Ejemplo n.º 5
0
class Job(db.Model):
    """Database model for a workflow job that runs as a FLApp in Kubernetes.

    While the job is STARTED its FLApp and pods are read live through
    ``_k8s_client``; stop() stores JSON snapshots that are served afterwards.
    """
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.Enum(JobType), nullable=False)
    state = db.Column(db.Enum(JobState),
                      nullable=False,
                      default=JobState.INVALID)
    yaml_template = db.Column(db.Text(), nullable=False)
    # Serialized JobDefinition proto (see get_config).
    config = db.Column(db.Text(), nullable=False)
    workflow_id = db.Column(db.Integer,
                            db.ForeignKey('workflow_v2.id'),
                            nullable=False,
                            index=True)
    project_id = db.Column(db.Integer,
                           db.ForeignKey(Project.id),
                           nullable=False)
    # JSON snapshots captured by stop().
    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))

    project = db.relationship(Project)
    workflow = db.relationship('Workflow')
    # NOTE(review): created once at class-definition time and shared by all
    # instances.
    _k8s_client = get_client()

    # FLApp ``status.appState`` values interpreted by the state helpers.
    _COMPLETE_APP_STATE = 'FLStateComplete'
    _FAILED_APP_STATES = ('FLStateFailed', 'FLStateShutDown')

    def get_config(self):
        """Return the parsed JobDefinition proto, or None if config is unset."""
        if self.config is None:
            return None
        proto = JobDefinition()
        proto.ParseFromString(self.config)
        return proto

    def _set_snapshot_flapp(self):
        """Fetch the FLApp custom object and store it as a JSON snapshot."""
        flapp = self._k8s_client.get_custom_object(
            CrdKind.FLAPP, self.name, self.project.get_namespace())
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch the FLApp's pods and store them as a JSON snapshot."""
        pods = self._k8s_client.list_resource_of_custom_object(
            CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        self.pods_snapshot = json.dumps(pods)

    def get_flapp(self):
        """Return the FLApp: live while STARTED, else the snapshot or None."""
        if self.state == JobState.STARTED:
            # Same call _set_snapshot_flapp() uses, so live and snapshot
            # results have the same shape (previously this listed pods).
            return self._k8s_client.get_custom_object(
                CrdKind.FLAPP, self.name, self.project.get_namespace())
        if self.flapp_snapshot is not None:
            return json.loads(self.flapp_snapshot)
        return None

    def get_pods(self):
        """Return the pods: live while STARTED, else the snapshot or None."""
        if self.state == JobState.STARTED:
            return self._k8s_client.list_resource_of_custom_object(
                CrdKind.FLAPP, self.name, 'pods', self.project.get_namespace())
        if self.pods_snapshot is not None:
            return json.loads(self.pods_snapshot)
        return None

    @staticmethod
    def _flapp_status(flapp):
        """Return ``flapp['status']``, or {} if flapp or its status is empty."""
        if flapp is None:
            return {}
        return flapp.get('status') or {}

    def get_pods_for_front(self):
        """List pods from the FLApp's ``flReplicaStatus`` for the frontend.

        Each entry has the pod name, the replica state key it was listed
        under, and its replica type. Returns [] if no status is available.
        """
        replicas = self._flapp_status(
            self.get_flapp()).get('flReplicaStatus') or {}
        result = []
        for pod_type, states in replicas.items():
            # Replica types or state lists may be present but empty/None.
            for state, pods in (states or {}).items():
                for pod in pods or ():
                    result.append({
                        'name': pod,
                        'state': state,
                        'pod_type': pod_type
                    })
        return result

    def get_state_for_front(self):
        """Map the job state to the string shown in the frontend.

        For STARTED jobs the FLApp is fetched once and its appState selects
        COMPLETE, FAILED or RUNNING. STOPPED jobs without an FLApp snapshot
        are reported as NEW; anything else returns the JobState name.
        """
        if self.state == JobState.STARTED:
            # Single fetch so both checks see the same FLApp state.
            app_state = self._flapp_status(self.get_flapp()).get('appState')
            if app_state == self._COMPLETE_APP_STATE:
                return 'COMPLETE'
            if app_state in self._FAILED_APP_STATES:
                return 'FAILED'
            return 'RUNNING'
        if self.state == JobState.STOPPED and self.get_flapp() is None:
            return 'NEW'
        return self.state.name

    def is_failed(self):
        """True if the FLApp appState is FLStateFailed or FLStateShutDown."""
        status = self._flapp_status(self.get_flapp())
        return status.get('appState') in self._FAILED_APP_STATES

    def is_complete(self):
        """True if the FLApp appState is FLStateComplete."""
        status = self._flapp_status(self.get_flapp())
        return status.get('appState') == self._COMPLETE_APP_STATE

    def get_complete_at(self):
        """Return the FLApp's ``status.complete_at`` value, or None."""
        return self._flapp_status(self.get_flapp()).get('complete_at')

    def stop(self):
        """Mark the job STOPPED, snapshotting and deleting a running FLApp."""
        if self.state == JobState.STARTED:
            self._set_snapshot_flapp()
            self._set_snapshot_pods()
            self._k8s_client.delete_custom_object(CrdKind.FLAPP, self.name,
                                                  self.project.get_namespace())
        self.state = JobState.STOPPED

    def schedule(self):
        """Move a STOPPED job to WAITING and clear its previous snapshots."""
        assert self.state == JobState.STOPPED
        self.pods_snapshot = None
        self.flapp_snapshot = None
        self.state = JobState.WAITING

    def start(self):
        """Mark the job as STARTED."""
        self.state = JobState.STARTED

    def set_yaml_template(self, yaml_template):
        """Store the job's YAML template text."""
        self.yaml_template = yaml_template
Ejemplo n.º 6
0
class Workflow(db.Model):
    """Database model for a workflow and its transactional state changes.

    ``transaction_state`` tracks a prepare/commit/abort transaction (as
    coordinator or participant); ``target_state`` holds the state being moved
    to and is INVALID when no transition is in progress.
    """
    __tablename__ = 'workflow_v2'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(255), unique=True, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id))
    # Serialized WorkflowDefinition proto (see set_config / get_config).
    config = db.Column(db.Text())
    forkable = db.Column(db.Boolean, default=False)
    forked_from = db.Column(db.Integer, default=None)
    comment = db.Column(db.String(255))

    state = db.Column(db.Enum(WorkflowState), default=WorkflowState.INVALID)
    target_state = db.Column(db.Enum(WorkflowState),
                             default=WorkflowState.INVALID)
    transaction_state = db.Column(db.Enum(TransactionState),
                                  default=TransactionState.READY)
    transaction_err = db.Column(db.Text())

    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_onupdate=func.now(),
                           server_default=func.now())

    project = db.relationship(Project)

    def set_config(self, proto):
        """Serialize ``proto`` into ``config`` (None clears it)."""
        if proto is not None:
            self.config = proto.SerializeToString()
        else:
            self.config = None

    def get_config(self):
        """Return the parsed WorkflowDefinition, or None if unset."""
        if self.config is not None:
            proto = workflow_definition_pb2.WorkflowDefinition()
            proto.ParseFromString(self.config)
            return proto
        return None

    def update_state(self, state, target_state, transaction_state):
        """Record a target state, advance the transaction, run its step.

        Args:
            state: expected current ``state``, or None to skip the check.
            target_state: new target state to record, or falsy to keep it.
            transaction_state: requested transaction state, or None.

        Returns:
            The resulting ``transaction_state``.

        Failures in prepare() move the transaction to the matching ABORTING
        state; failures in rollback() are logged and otherwise ignored.
        """
        assert state is None or self.state == state, \
            'Cannot change current state directly'

        if target_state and self.target_state != target_state:
            assert self.target_state == WorkflowState.INVALID, \
                'Another transaction is in progress'
            assert self.transaction_state == TransactionState.READY, \
                'Another transaction is in progress'
            assert (self.state, target_state) in VALID_TRANSITIONS, \
                'Invalid transition from %s to %s'%(self.state, target_state)
            self.target_state = target_state

        if transaction_state is None or \
                transaction_state == self.transaction_state:
            return self.transaction_state

        if (self.transaction_state, transaction_state) in \
                IGNORED_TRANSACTION_TRANSITIONS:
            return self.transaction_state

        assert (self.transaction_state, transaction_state) in \
            VALID_TRANSACTION_TRANSITIONS, \
                'Invalid transaction transition from %s to %s'%(
                    self.transaction_state, transaction_state)
        self.transaction_state = transaction_state

        # coordinator prepare & rollback
        if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
            try:
                self.prepare()
            except Exception as e:  # pylint: disable=broad-except
                # Best effort: abort instead of propagating, but keep a trace.
                logging.warning('workflow %s prepare failed, aborting: %s',
                                self.id, e)
                self.transaction_state = \
                    TransactionState.COORDINATOR_ABORTING

        if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
            try:
                self.rollback()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning('workflow %s rollback failed: %s',
                                self.id, e)

        # participant prepare & rollback & commit
        if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
            try:
                self.prepare()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning('workflow %s prepare failed, aborting: %s',
                                self.id, e)
                self.transaction_state = \
                    TransactionState.PARTICIPANT_ABORTING

        if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
            try:
                self.rollback()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning('workflow %s rollback failed: %s',
                                self.id, e)
            self.target_state = WorkflowState.INVALID
            self.transaction_state = \
                TransactionState.ABORTED

        if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
            self.commit()

        return self.transaction_state

    def prepare(self):
        """Run the prepare phase for the current ``target_state``.

        READY succeeds only when a config is set; RUNNING and STOPPED always
        succeed. On success the transaction becomes COORDINATOR_COMMITTABLE
        or PARTICIPANT_COMMITTABLE; otherwise it is left unchanged.

        Raises:
            RuntimeError: if ``target_state`` is not READY/RUNNING/STOPPED.
        """
        assert self.transaction_state in [
            TransactionState.COORDINATOR_PREPARE,
            TransactionState.PARTICIPANT_PREPARE], \
                "Workflow not in prepare state"

        success = False
        if self.target_state == WorkflowState.READY:
            success = bool(self.config)
        elif self.target_state == WorkflowState.RUNNING:
            success = True
        elif self.target_state == WorkflowState.STOPPED:
            success = True
        else:
            raise RuntimeError("Invalid target_state %s" % self.target_state)
        if success:
            if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
                self.transaction_state = \
                    TransactionState.COORDINATOR_COMMITTABLE
            else:
                self.transaction_state = \
                    TransactionState.PARTICIPANT_COMMITTABLE

    def rollback(self):
        """Undo prepare-phase side effects (currently none)."""
        pass

    def commit(self):
        """Apply ``target_state`` and reset the transaction to READY."""
        assert self.transaction_state in [
            TransactionState.COORDINATOR_COMMITTING,
            TransactionState.PARTICIPANT_COMMITTING], \
                "Workflow not in commit state"

        if self.target_state == WorkflowState.STOPPED:
            # TODO: delete jobs from k8s
            pass
        elif self.target_state == WorkflowState.READY:
            # TODO: create workflow jobs in database according to config
            pass

        self.state = self.target_state
        self.target_state = WorkflowState.INVALID
        self.transaction_state = TransactionState.READY

    def log_states(self):
        """Log the current state, target state and transaction state."""
        logging.debug(
            'workflow %d updated to state=%s, target_state=%s, '
            'transaction_state=%s', self.id, self.state.name,
            self.target_state.name, self.transaction_state.name)
Ejemplo n.º 7
0
class Job(db.Model):
    __tablename__ = 'job_v2'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    name = db.Column(db.String(255), unique=True)
    job_type = db.Column(db.String(16), nullable=False)
    status = db.Column(db.Enum(JobStatus), nullable=False)
    yaml = db.Column(db.Text(), nullable=False)

    workflow_id = db.Column(db.Integer, db.ForeignKey(Workflow.id),
                            nullable=False, index=True)
    project_id = db.Column(db.Integer, db.ForeignKey(Project.id),
                           nullable=False)

    flapp_snapshot = db.Column(db.Text())
    pods_snapshot = db.Column(db.Text())
    created_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now())
    updated_at = db.Column(db.DateTime(timezone=True),
                           server_default=func.now(),
                           server_onupdate=func.now())
    deleted_at = db.Column(db.DateTime(timezone=True))
    _project_adapter = ProjectK8sAdapter(project_id)
    _k8s_client = get_client()

    def _set_snapshot_flapp(self):
        """Fetch this job's FLApp and store it as a JSON snapshot.

        The client result is encoded exactly once, so json.loads() on
        ``flapp_snapshot`` gives back the original object rather than a
        JSON string.
        """
        flapp = self._k8s_client.get_flapp(
            self._project_adapter.get_namespace(), self.name)
        self.flapp_snapshot = json.dumps(flapp)

    def _set_snapshot_pods(self):
        """Fetch this job's pods and store them as a JSON snapshot.

        The result is written to ``pods_snapshot``, the column get_pods()
        reads. Before this fix it overwrote ``flapp_snapshot``. It is also
        encoded only once, so json.loads() returns the original object.
        """
        pods = self._k8s_client.get_pods(
            self._project_adapter.get_namespace(), self.name)
        self.pods_snapshot = json.dumps(pods)

    def get_flapp(self):
        """Return this job's FLApp decoded from JSON, or None if unavailable.

        While the job is STARTED the snapshot is refreshed from Kubernetes
        first; otherwise the last stored snapshot is used.
        """
        if self.status == JobStatus.STARTED:
            self._set_snapshot_flapp()
        if self.flapp_snapshot is None:
            # No snapshot stored; json.loads(None) would raise TypeError.
            return None
        return json.loads(self.flapp_snapshot)

    def get_pods(self):
        """Return this job's pods decoded from JSON, or None if unavailable.

        While the job is STARTED the snapshot is refreshed from Kubernetes
        first; otherwise the last stored snapshot is used.
        """
        if self.status == JobStatus.STARTED:
            self._set_snapshot_pods()
        if self.pods_snapshot is None:
            # No snapshot stored; json.loads(None) would raise TypeError.
            return None
        return json.loads(self.pods_snapshot)

    def run(self):
        """Start the job by creating its FLApp from ``self.yaml``.

        Raises:
            ResourceConflictException: if the job is already STARTED.
        """
        if self.status == JobStatus.STARTED:
            raise ResourceConflictException('Job has been started')
        # Create the FLApp first so a failed creation does not leave the job
        # marked as STARTED.
        self._k8s_client.create_flapp(
            self._project_adapter.get_namespace(), self.yaml)
        self.status = JobStatus.STARTED

    def stop(self):
        """Stop the job: snapshot its FLApp and pods, then delete the FLApp.

        Raises:
            ResourceConflictException: if the job is already STOPPED.
        """
        if self.status == JobStatus.STOPPED:
            raise ResourceConflictException('Job has stopped')
        # Snapshot before deleting so the last state stays queryable, and
        # only mark STOPPED once the snapshot and deletion calls succeeded.
        self._set_snapshot_flapp()
        self._set_snapshot_pods()
        self._k8s_client.delete_flapp(
            self._project_adapter.get_namespace(), self.name)
        self.status = JobStatus.STOPPED

    def set_yaml(self, yaml_template, job_config):
        yaml = merge(yaml_template,
                     self._project_adapter.get_global_job_spec())