def update_workflow_state(self, request):
    """Handles a peer's UpdateWorkflowState RPC.

    On first contact for a workflow (not yet in our DB) the peer must be
    proposing a brand-new one (NEW -> READY); we create the row, then apply
    the requested state transition in either case.

    Args:
        request: service_pb2.UpdateWorkflowStateRequest from the peer.

    Returns:
        service_pb2.UpdateWorkflowStateResponse carrying our resulting
        transaction state.
    """
    with self._app.app_context():
        project, party = self.check_auth_info(request.auth_info)
        logging.debug(
            'received update_workflow_state from %s: %s',
            party.domain_name, request)
        state = WorkflowState(request.state)
        target_state = WorkflowState(request.target_state)
        transaction_state = TransactionState(request.transaction_state)
        workflow = Workflow.query.filter_by(
            name=request.workflow_name,
            project_id=project.id).first()
        if workflow is None:
            # First time we hear about this workflow: the peer must be
            # proposing a new one.
            # NOTE(review): asserts vanish under `python -O`; kept for
            # behavior parity with callers that expect AssertionError.
            assert state == WorkflowState.NEW
            assert target_state == WorkflowState.READY
            workflow = Workflow(
                name=request.workflow_name,
                project_id=project.id,
                state=state,
                target_state=target_state,
                transaction_state=TransactionState.READY)
            db.session.add(workflow)
            db.session.commit()
            db.session.refresh(workflow)
        workflow.update_state(
            state, target_state, transaction_state)
        db.session.commit()
        return service_pb2.UpdateWorkflowStateResponse(
            status=common_pb2.Status(
                code=common_pb2.STATUS_SUCCESS),
            transaction_state=workflow.transaction_state.value)
def test_patch_create_job_flags(self):
    """Patching create_job_flags toggles the job's is_disabled flag."""
    definition = WorkflowDefinition()
    job_definition = definition.job_definitions.add()
    wf = Workflow(
        name='test-workflow',
        project_id=123,
        config=definition.SerializeToString(),
        forkable=False,
        state=WorkflowState.READY,
    )
    db.session.add(wf)
    db.session.flush()
    job = Job(name='test_job',
              job_type=JobType(1),
              config=job_definition.SerializeToString(),
              workflow_id=wf.id,
              project_id=123,
              state=JobState.STOPPED,
              is_disabled=False)
    db.session.add(job)
    db.session.flush()
    wf.job_ids = str(job.id)
    db.session.commit()

    # Flag value 3 should disable the job.
    response = self.patch_helper(f'/api/v2/workflows/{wf.id}',
                                 data={'create_job_flags': [3]})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    self.assertEqual(Job.query.get(job.id).is_disabled, True)

    # Flag value 1 should re-enable it.
    response = self.patch_helper(f'/api/v2/workflows/{wf.id}',
                                 data={'create_job_flags': [1]})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    self.assertEqual(Job.query.get(job.id).is_disabled, False)
def test_patch_batch_update_interval(self, mock_collect, mock_finish,
                                     mock_patch_item,
                                     mock_get_item_status):
    """Covers creating, re-patching and stopping the batch-update cronjob."""
    # First status lookup: no item yet (create path);
    # second lookup: item is ON (patch path).
    mock_get_item_status.side_effect = [None, ItemStatus.ON]
    left_workflow = Workflow(
        name='test-workflow-left',
        project_id=123,
        config=WorkflowDefinition(is_left=True).SerializeToString(),
        forkable=False,
        state=WorkflowState.STOPPED,
    )
    db.session.add(left_workflow)
    db.session.commit()
    db.session.refresh(left_workflow)

    # test create cronjob
    interval = 1
    response = self.patch_helper(
        f'/api/v2/workflows/{left_workflow.id}',
        data={'batch_update_interval': interval})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    mock_collect.assert_called_with(
        name=f'workflow_cron_job_{left_workflow.id}',
        items=[WorkflowCronJobItem(left_workflow.id)],
        metadata={},
        interval=interval * 60)

    # patch new interval time for cronjob
    interval = 2
    response = self.patch_helper(
        f'/api/v2/workflows/{left_workflow.id}',
        data={'batch_update_interval': interval})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    mock_patch_item.assert_called_with(
        name=f'workflow_cron_job_{left_workflow.id}',
        key='interval_time',
        value=interval * 60)

    # test stop cronjob
    response = self.patch_helper(
        f'/api/v2/workflows/{left_workflow.id}',
        data={'batch_update_interval': -1})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    mock_finish.assert_called_with(
        name=f'workflow_cron_job_{left_workflow.id}')

    # A right-side (is_left=False) workflow may not operate the cronjob.
    right_workflow = Workflow(
        name='test-workflow-right',
        project_id=456,
        config=WorkflowDefinition(is_left=False).SerializeToString(),
        forkable=False,
        state=WorkflowState.STOPPED,
    )
    db.session.add(right_workflow)
    db.session.commit()
    db.session.refresh(right_workflow)
    response = self.patch_helper(
        f'/api/v2/workflows/{right_workflow.id}',
        data={'batch_update_interval': 1})
    self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
def setUp(self):
    """Seeds two workflows, five jobs and their dependency edges."""
    super().setUp()
    db.session.add_all([
        Workflow(id=0, name='test-workflow-0', project_id=0),
        Workflow(id=1, name='test-workflow-1', project_id=0),
    ])
    config = workflow_definition_pb2.JobDefinition(
        name='test-job').SerializeToString()
    # (id, name, job type, state, owning workflow)
    job_specs = [
        (0, 'raw_data_0', JobType.RAW_DATA, JobState.STARTED, 0),
        (1, 'raw_data_1', JobType.RAW_DATA, JobState.COMPLETED, 0),
        (2, 'data_join_0', JobType.DATA_JOIN, JobState.WAITING, 0),
        (3, 'data_join_1', JobType.DATA_JOIN, JobState.COMPLETED, 1),
        (4, 'train_job_0', JobType.NN_MODEL_TRANINING, JobState.WAITING, 1),
    ]
    db.session.add_all([
        Job(id=job_id, name=job_name, job_type=job_type, state=job_state,
            workflow_id=wf_id, project_id=0, config=config)
        for job_id, job_name, job_type, job_state, wf_id in job_specs
    ])
    # data_join_0 depends on both raw-data jobs;
    # train_job_0 depends on data_join_1.
    db.session.add_all([
        JobDependency(src_job_id=0, dst_job_id=2, dep_index=0),
        JobDependency(src_job_id=1, dst_job_id=2, dep_index=1),
        JobDependency(src_job_id=3, dst_job_id=4, dep_index=0),
    ])
    db.session.commit()
def setUp(self):
    """Seeds three workflows across two projects."""
    self.maxDiff = None
    super().setUp()
    # Inserts data (name of the second row intentionally kept as-is).
    db.session.add_all([
        Workflow(name='workflow_key_get1', project_id=1),
        Workflow(name='workflow_kay_get2', project_id=2),
        Workflow(name='workflow_key_get3', project_id=2),
    ])
    db.session.commit()
def setUp(self):
    """Builds a default project with certificates plus one workflow."""
    super().setUp()
    cert_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'test.tar.gz')
    with open(cert_path, 'rb') as file:
        self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                     encoding='utf-8')
    project_config = {
        'participants': [{
            'name': 'test-participant',
            'domain_name': 'fl-test.com',
            'url': '127.0.0.1:32443'
        }],
        'variables': [{
            'name': 'test',
            'value': 'test'
        }]
    }
    certificate_config = {
        'domain_name_to_cert': {
            'fl-test.com': {
                'certs': parse_certificates(self.TEST_CERTIFICATES)
            }
        },
    }
    self.default_project = Project()
    self.default_project.name = 'test-self.default_project'
    self.default_project.set_config(
        ParseDict(project_config, ProjectProto()))
    self.default_project.set_certificate(
        ParseDict(certificate_config, CertificateStorage()))
    self.default_project.comment = 'test comment'
    db.session.add(self.default_project)
    db.session.add(Workflow(name='workflow_key_get1', project_id=1))
    db.session.commit()
def test_put_successfully(self):
    """PUT on a NEW workflow stores config and moves target to READY."""
    wf = Workflow(
        name='test-workflow',
        project_id=123,
        state=WorkflowState.NEW,
    )
    db.session.add(wf)
    db.session.commit()
    db.session.refresh(wf)

    response = self.put_helper(
        f'/api/v2/workflows/{wf.id}',
        data={
            'forkable': True,
            'config': {'group_alias': 'test-template'},
            'comment': 'test comment'
        })
    self.assertEqual(response.status_code, HTTPStatus.OK)

    updated = Workflow.query.get(wf.id)
    self.assertIsNotNone(updated.config)
    self.assertTrue(updated.forkable)
    self.assertEqual(updated.comment, 'test comment')
    self.assertEqual(updated.target_state, WorkflowState.READY)
def test_patch_invalid_target_state(self, mock_wakeup):
    """PATCH is rejected while another transaction is in flight."""
    wf = Workflow(
        name='test-workflow',
        project_id=123,
        config=WorkflowDefinition().SerializeToString(),
        forkable=False,
        state=WorkflowState.READY,
        target_state=WorkflowState.RUNNING
    )
    db.session.add(wf)
    db.session.commit()
    db.session.refresh(wf)

    response = self.patch_helper(f'/api/v2/workflows/{wf.id}',
                                 data={'target_state': 'READY'})
    self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
    self.assertEqual(json.loads(response.data).get('details'),
                     'Another transaction is in progress [1]')
    # Checks DB: nothing should have changed.
    patched = Workflow.query.get(wf.id)
    self.assertEqual(patched.state, WorkflowState.READY)
    self.assertEqual(patched.target_state, WorkflowState.RUNNING)
    # Checks scheduler: must not have been woken.
    mock_wakeup.assert_not_called()
def test_patch_successfully(self, mock_wakeup):
    """PATCH sets target_state to RUNNING and wakes the scheduler."""
    wf = Workflow(
        name='test-workflow',
        project_id=123,
        config=WorkflowDefinition().SerializeToString(),
        forkable=False,
        state=WorkflowState.READY,
    )
    db.session.add(wf)
    db.session.commit()
    db.session.refresh(wf)

    response = self.patch_helper(f'/api/v2/workflows/{wf.id}',
                                 data={'target_state': 'RUNNING'})
    self.assertEqual(response.status_code, HTTPStatus.OK)
    patched_data = json.loads(response.data).get('data')
    self.assertEqual(patched_data['id'], wf.id)
    self.assertEqual(patched_data['state'], 'READY')
    self.assertEqual(patched_data['target_state'], 'RUNNING')
    # Checks DB
    self.assertEqual(Workflow.query.get(wf.id).target_state,
                     WorkflowState.RUNNING)
    # Checks scheduler
    mock_wakeup.assert_called_once_with(wf.id)
def test_get_workflows(self):
    """The workflow list endpoint returns newest entries first."""
    # Sleep so this row gets a strictly later creation timestamp than
    # any fixture rows.
    time.sleep(1)
    newest = Workflow(name='last', project_id=1)
    db.session.add(newest)
    db.session.flush()
    response = self.get_helper('/api/v2/workflows')
    data = self.get_response_data(response)
    self.assertEqual(data[0]['name'], 'last')
def test_put_successfully(self):
    """PUT configures a NEW workflow that is in PARTICIPANT_PREPARE."""
    config = {
        'participants': [{
            'name': 'party_leader',
            'url': '127.0.0.1:5000',
            'domain_name': 'fl-leader.com'
        }],
        'variables': [
            {'name': 'namespace', 'value': 'leader'},
            {'name': 'basic_envs', 'value': '{}'},
            {'name': 'storage_root_dir', 'value': '/'},
            {'name': 'EGRESS_URL', 'value': '127.0.0.1:1991'}
        ]
    }
    project = Project(
        name='test',
        config=ParseDict(config,
                         project_pb2.Project()).SerializeToString())
    db.session.add(project)
    wf = Workflow(
        name='test-workflow',
        project_id=1,
        state=WorkflowState.NEW,
        transaction_state=TransactionState.PARTICIPANT_PREPARE,
        target_state=WorkflowState.READY
    )
    db.session.add(wf)
    db.session.commit()
    db.session.refresh(wf)

    response = self.put_helper(
        f'/api/v2/workflows/{wf.id}',
        data={
            'forkable': True,
            'config': {'group_alias': 'test-template'},
            'comment': 'test comment'
        })
    self.assertEqual(response.status_code, HTTPStatus.OK)

    updated = Workflow.query.get(wf.id)
    self.assertIsNotNone(updated.config)
    self.assertTrue(updated.forkable)
    self.assertEqual(updated.comment, 'test comment')
    self.assertEqual(updated.target_state, WorkflowState.READY)
def follower_test_peer_metrics(self):
    """Follower half of the peer-metrics test: seeds data, then serves."""
    self.setup_project('follower',
                       JobMetricsBuilderTest.Config.GRPC_LISTEN_PORT)
    wf = Workflow(name='test-workflow',
                  project_id=1,
                  metric_is_public=True)
    wf.set_job_ids([1])
    db.session.add(wf)
    job = Job(name='automl-2782410011',
              job_type=JobType.NN_MODEL_TRANINING,
              workflow_id=1,
              project_id=1,
              config=workflow_definition_pb2.JobDefinition(
                  name='test-job').SerializeToString())
    db.session.add(job)
    db.session.commit()
    # Stay alive so the leader process can query us over gRPC.
    while True:
        time.sleep(1)
def leader_test_peer_metrics(self):
    """Leader half: polls the peer until its metrics endpoint answers."""
    self.setup_project(
        'leader', JobMetricsBuilderTest.FollowerConfig.GRPC_LISTEN_PORT)
    db.session.add(Workflow(name='test-workflow', project_id=1))
    db.session.commit()
    while True:
        resp = self.get_helper('/api/v2/workflows/1/peer_workflows'
                               '/0/jobs/test-job/metrics')
        if resp.status_code == HTTPStatus.OK:
            break
        time.sleep(1)
def start_or_stop_cronjob(batch_update_interval: int, workflow: Workflow):
    """Starts, updates or stops the restart-cronjob of a workflow.

    Args:
        batch_update_interval (int): restart interval in minutes; a
            positive value creates or re-patches the cronjob, a negative
            value stops it, and 0 is a no-op.
        workflow (Workflow): the workflow the cronjob restarts.

    Raises:
        InvalidArgumentException: if the workflow is not the left side,
            if the composer item exists but is switched off, or if
            patching the item's interval fails.
    """
    # Only the left (initiating) side owns the cronjob. Checking this
    # first ensures a right-side workflow always raises, matching the
    # documented contract (previously a negative interval on the right
    # side silently called composer.finish).
    if not workflow.get_config().is_left:
        raise InvalidArgumentException('Only left can operate this')
    item_name = f'workflow_cron_job_{workflow.id}'
    interval_seconds = batch_update_interval * 60
    if interval_seconds > 0:
        status = composer.get_item_status(name=item_name)
        if not status:
            # No item yet: create the cronjob.
            composer.collect(name=item_name,
                             items=[WorkflowCronJobItem(workflow.id)],
                             metadata={},
                             interval=interval_seconds)
            return
        if status == ItemStatus.OFF:
            raise InvalidArgumentException(
                f'cannot set item [{item_name}], since item is off')
        # Item exists and is on: just update its interval.
        try:
            composer.patch_item_attr(name=item_name,
                                     key='interval_time',
                                     value=interval_seconds)
        except ValueError as err:
            raise InvalidArgumentException(details=repr(err))
    elif interval_seconds < 0:
        composer.finish(name=item_name)
    else:
        # Fixed log text: this branch is only reachable at interval 0,
        # not -1.
        logging.info('skip cronjob since batch_update_interval is 0')
def add_fake_workflow(session):
    """Inserts a READY workflow with one STOPPED job and returns both."""
    definition = WorkflowDefinition()
    job_definition = definition.job_definitions.add()
    workflow = Workflow(
        name='test-workflow',
        project_id=123,
        config=definition.SerializeToString(),
        forkable=False,
        state=WorkflowState.READY,
    )
    session.add(workflow)
    session.flush()
    job = Job(name='test_job',
              job_type=JobType(1),
              config=job_definition.SerializeToString(),
              workflow_id=workflow.id,
              project_id=123,
              state=JobState.STOPPED,
              is_disabled=False)
    session.add(job)
    session.flush()
    workflow.job_ids = str(job.id)
    session.commit()
    return workflow, job
def post(self):
    """Creates a workflow from a template config.

    Request JSON body: name, project_id, config (workflow template
    dict), forkable, forked_from (optional base workflow id), comment
    (optional).

    Returns:
        The created workflow dict and HTTP 201.

    Raises:
        ResourceConflictException: if a workflow with the same name
            already exists.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, help='name is empty')
    parser.add_argument('project_id', type=int, required=True,
                        help='project_id is empty')
    # TODO: should verify if the config is compatible with
    # workflow template
    parser.add_argument('config', type=dict, required=True,
                        help='config is empty')
    parser.add_argument('forkable', type=bool, required=True,
                        help='forkable is empty')
    # Fixed copy-paste bug: help text previously read 'forkable is
    # empty' for this argument.
    parser.add_argument('forked_from', type=int, required=False,
                        help='fork from base workflow')
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    if Workflow.query.filter_by(name=name).first() is not None:
        raise ResourceConflictException(
            'Workflow {} already exists.'.format(name))
    # form to proto buffer
    template_proto = dict_to_workflow_definition(data['config'])
    workflow = Workflow(name=name,
                        comment=data['comment'],
                        project_id=data['project_id'],
                        forkable=data['forkable'],
                        forked_from=data['forked_from'],
                        state=WorkflowState.NEW,
                        target_state=WorkflowState.READY,
                        transaction_state=TransactionState.READY)
    workflow.set_config(template_proto)
    db.session.add(workflow)
    db.session.commit()
    logging.info('Inserted a workflow to db')
    scheduler.wakeup(workflow.id)
    return {'data': workflow.to_dict()}, HTTPStatus.CREATED
def test_put_resetting(self):
    """PUT conflicts when the workflow already carries a config."""
    wf = Workflow(
        name='test-workflow',
        project_id=123,
        config=WorkflowDefinition(
            group_alias='test-template').SerializeToString(),
        state=WorkflowState.NEW,
    )
    db.session.add(wf)
    db.session.commit()
    db.session.refresh(wf)
    response = self.put_helper(
        f'/api/v2/workflows/{wf.id}',
        data={
            'forkable': True,
            'config': {'group_alias': 'test-template'},
        })
    self.assertEqual(response.status_code, HTTPStatus.CONFLICT)
def post(self):
    """Creates a workflow (optionally forked) from a template config."""
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, help='name is empty')
    parser.add_argument('project_id', type=int, required=True,
                        help='project_id is empty')
    # TODO: should verify if the config is compatible with
    # workflow template
    parser.add_argument('config', type=dict, required=True,
                        help='config is empty')
    parser.add_argument('forkable', type=bool, required=True,
                        help='forkable is empty')
    parser.add_argument('forked_from', type=int, required=False,
                        help='fork from base workflow')
    parser.add_argument('create_job_flags', type=list, required=False,
                        location='json',
                        help='flags in common.CreateJobFlag')
    parser.add_argument('peer_create_job_flags', type=list,
                        required=False, location='json',
                        help='peer flags in common.CreateJobFlag')
    parser.add_argument('fork_proposal_config', type=dict,
                        required=False,
                        help='fork and edit peer config')
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    if Workflow.query.filter_by(name=name).first() is not None:
        raise ResourceConflictException(
            'Workflow {} already exists.'.format(name))
    # form to proto buffer
    template_proto = dict_to_workflow_definition(data['config'])
    workflow = Workflow(
        name=name,
        # a DNS-1035 label must start with an alphabetic character;
        # substring uuid[:19] has no collision in 10 million draws,
        # giving a 20-byte identifier
        uuid=f'u{uuid4().hex[:19]}',
        comment=data['comment'],
        project_id=data['project_id'],
        forkable=data['forkable'],
        forked_from=data['forked_from'],
        state=WorkflowState.NEW,
        target_state=WorkflowState.READY,
        transaction_state=TransactionState.READY)
    workflow.set_create_job_flags(data['create_job_flags'])
    if workflow.forked_from is not None:
        fork_config = dict_to_workflow_definition(
            data['fork_proposal_config'])
        # TODO: more validations
        if len(fork_config.job_definitions) != \
                len(template_proto.job_definitions):
            raise InvalidArgumentException(
                'Forked workflow\'s template does not match base workflow')
        workflow.set_fork_proposal_config(fork_config)
        # TODO: check that federated jobs have
        # same reuse policy on both sides
        workflow.set_peer_create_job_flags(data['peer_create_job_flags'])
    workflow.set_config(template_proto)
    db.session.add(workflow)
    db.session.commit()
    logging.info('Inserted a workflow to db')
    scheduler.wakeup(workflow.id)
    return {'data': workflow.to_dict()}, HTTPStatus.CREATED
def post(self):
    """Creates a workflow, optionally forked with job-reuse lists."""
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, help='name is empty')
    parser.add_argument('project_id', type=int, required=True,
                        help='project_id is empty')
    # TODO: should verify if the config is compatible with
    # workflow template
    parser.add_argument('config', type=dict, required=True,
                        help='config is empty')
    parser.add_argument('forkable', type=bool, required=True,
                        help='forkable is empty')
    parser.add_argument('forked_from', type=int, required=False,
                        help='fork from base workflow')
    parser.add_argument('reuse_job_names', type=list, required=False,
                        location='json', help='fork and inherit jobs')
    parser.add_argument('peer_reuse_job_names', type=list,
                        required=False, location='json',
                        help='peer fork and inherit jobs')
    parser.add_argument('fork_proposal_config', type=dict,
                        required=False,
                        help='fork and edit peer config')
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    if Workflow.query.filter_by(name=name).first() is not None:
        raise ResourceConflictException(
            'Workflow {} already exists.'.format(name))
    # form to proto buffer
    template_proto = dict_to_workflow_definition(data['config'])
    workflow = Workflow(name=name,
                        comment=data['comment'],
                        project_id=data['project_id'],
                        forkable=data['forkable'],
                        forked_from=data['forked_from'],
                        state=WorkflowState.NEW,
                        target_state=WorkflowState.READY,
                        transaction_state=TransactionState.READY)
    if workflow.forked_from is not None:
        fork_config = dict_to_workflow_definition(
            data['fork_proposal_config'])
        # TODO: more validations
        if len(fork_config.job_definitions) != \
                len(template_proto.job_definitions):
            raise InvalidArgumentException(
                'Forked workflow\'s template does not match base workflow')
        workflow.set_fork_proposal_config(fork_config)
        workflow.set_reuse_job_names(data['reuse_job_names'])
        workflow.set_peer_reuse_job_names(data['peer_reuse_job_names'])
    workflow.set_config(template_proto)
    db.session.add(workflow)
    db.session.commit()
    logging.info('Inserted a workflow to db')
    scheduler.wakeup(workflow.id)
    return {'data': workflow.to_dict()}, HTTPStatus.CREATED
def test_workflow_commit(self):
    """Drives a workflow through READY, RUNNING and STOPPED commits."""
    # test the committing stage for workflow creating
    workflow_def = make_workflow_template()
    workflow = Workflow(
        id=20,
        name='job_test1',
        comment='这是一个测试工作流',
        config=workflow_def.SerializeToString(),
        project_id=1,
        forkable=True,
        state=WorkflowState.NEW,
        target_state=WorkflowState.READY,
        transaction_state=TransactionState.PARTICIPANT_COMMITTING)
    db.session.add(workflow)
    db.session.commit()
    scheduler.wakeup(20)
    self._wait_until(
        lambda: Workflow.query.get(20).state == WorkflowState.READY)
    workflow = Workflow.query.get(20)
    jobs = workflow.get_jobs()
    self.assertEqual(len(jobs), 2)
    self.assertEqual(jobs[0].state, JobState.STOPPED)
    self.assertEqual(jobs[1].state, JobState.STOPPED)

    # test the committing stage for workflow running
    workflow.target_state = WorkflowState.RUNNING
    workflow.transaction_state = TransactionState.PARTICIPANT_COMMITTING
    db.session.commit()
    scheduler.wakeup(20)
    self._wait_until(
        lambda: Workflow.query.get(20).state == WorkflowState.RUNNING)
    workflow = Workflow.query.get(20)
    self._wait_until(
        lambda: workflow.get_jobs()[0].state == JobState.STARTED)
    self.assertEqual(workflow.get_jobs()[1].state, JobState.WAITING)

    # test the committing stage for workflow stopping
    workflow.target_state = WorkflowState.STOPPED
    workflow.transaction_state = TransactionState.PARTICIPANT_COMMITTING
    db.session.commit()
    scheduler.wakeup(20)
    self._wait_until(
        lambda: Workflow.query.get(20).state == WorkflowState.STOPPED)
    workflow = Workflow.query.get(20)
    self._wait_until(
        lambda: workflow.get_jobs()[0].state == JobState.STOPPED)
    self.assertEqual(workflow.get_jobs()[1].state, JobState.STOPPED)
def test_is_peer_job_inheritance_matched(self, mock_get_workflow):
    """Inheritance matches iff our REUSE flags mirror the peer's."""
    # Peer exposes two jobs; only the federated train-job participates
    # in the matching.
    peer_job_0 = JobDefinition(name='raw-data-job')
    peer_job_1 = JobDefinition(name='train-job', is_federated=True)
    peer_config = WorkflowDefinition()
    peer_config.job_definitions.extend([peer_job_0, peer_job_1])
    mock_get_workflow.return_value = GetWorkflowResponse(
        config=peer_config)

    job_0 = JobDefinition(name='train-job', is_federated=True)
    config = WorkflowDefinition(job_definitions=[job_0])

    project = Project()
    participant = project_pb2.Participant()
    project.set_config(project_pb2.Project(participants=[participant]))
    parent = Workflow(project=project)
    parent.set_config(config)
    db.session.add(parent)
    db.session.commit()
    db.session.flush()

    fork = Workflow(project=project, forked_from=parent.id)
    fork.set_config(config)
    fork.set_create_job_flags([CreateJobFlag.REUSE])
    fork.set_peer_create_job_flags(
        [CreateJobFlag.NEW, CreateJobFlag.REUSE])
    self.assertTrue(is_peer_job_inheritance_matched(fork))

    fork.set_create_job_flags([CreateJobFlag.NEW])
    self.assertFalse(is_peer_job_inheritance_matched(fork))
def setUp(self):
    """Seeds one RUNNING workflow with a fixed id for the cronjob tests."""
    super(CronJobTest, self).setUp()
    self.test_id = 8848
    db.session.add(Workflow(id=self.test_id,
                            state=WorkflowState.RUNNING))
    db.session.commit()