def setUp(self):
    """Create a default project with one participant, plus one workflow.

    Reads ``test.tar.gz`` next to this file, stores it base64-encoded in
    ``self.TEST_CERTIFICATES`` and saves its parsed certificates under the
    participant's domain name ``fl-test.com``.
    """
    super().setUp()
    cert_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'test.tar.gz')
    with open(cert_path, 'rb') as file:
        self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                     encoding='utf-8')
    self.default_project = Project()
    self.default_project.name = 'test-self.default_project'
    self.default_project.set_config(
        ParseDict(
            {
                'participants': [{
                    'name': 'test-participant',
                    'domain_name': 'fl-test.com',
                    'url': '127.0.0.1:32443'
                }],
                'variables': [{
                    'name': 'test',
                    'value': 'test'
                }]
            }, ProjectProto()))
    self.default_project.set_certificate(
        ParseDict(
            {
                'domain_name_to_cert': {
                    'fl-test.com': {
                        'certs': parse_certificates(self.TEST_CERTIFICATES)
                    }
                },
            }, CertificateStorage()))
    self.default_project.comment = 'test comment'
    db.session.add(self.default_project)
    # Flush so the project gets its primary key; the workflow then references
    # the real id instead of assuming the autoincrement value is 1.
    db.session.flush()
    workflow = Workflow(name='workflow_key_get1',
                        project_id=self.default_project.id)
    db.session.add(workflow)
    db.session.commit()
def test_is_peer_job_inheritance_matched(self, mock_get_workflow):
    """Check is_peer_job_inheritance_matched on a forked workflow.

    ``mock_get_workflow`` (patched outside this view) returns a peer config
    with a non-federated 'raw-data-job' and a federated 'train-job'. The
    local fork only has a federated 'train-job'. With local flags [REUSE]
    and peer flags [NEW, REUSE] the check passes. Switching the local flag
    to NEW makes it fail.
    """
    peer_job_0 = JobDefinition(name='raw-data-job')
    peer_job_1 = JobDefinition(name='train-job', is_federated=True)
    peer_config = WorkflowDefinition()
    peer_config.job_definitions.extend([peer_job_0, peer_job_1])
    resp = GetWorkflowResponse(config=peer_config)
    mock_get_workflow.return_value = resp

    job_0 = JobDefinition(name='train-job', is_federated=True)
    config = WorkflowDefinition(job_definitions=[job_0])

    project = Project()
    participant = project_pb2.Participant()
    project.set_config(project_pb2.Project(participants=[participant]))
    workflow0 = Workflow(project=project)
    workflow0.set_config(config)
    db.session.add(workflow0)
    # commit() already flushes, so workflow0.id is assigned here; the extra
    # flush() that used to follow it was a no-op.
    db.session.commit()

    workflow1 = Workflow(project=project, forked_from=workflow0.id)
    workflow1.set_config(config)
    workflow1.set_create_job_flags([CreateJobFlag.REUSE])
    workflow1.set_peer_create_job_flags(
        [CreateJobFlag.NEW, CreateJobFlag.REUSE])
    self.assertTrue(is_peer_job_inheritance_matched(workflow1))

    workflow1.set_create_job_flags([CreateJobFlag.NEW])
    self.assertFalse(is_peer_job_inheritance_matched(workflow1))
def test_get_namespace_from_variables(self):
    """A `namespace` variable in the project config is what get_namespace returns."""
    proj = Project()
    namespace_var = common_pb2.Variable(name='namespace', value='haha')
    proj.set_config(project_pb2.Project(variables=[namespace_var]))
    self.assertEqual(proj.get_namespace(), 'haha')
def setUp(self):
    """Insert a follower-side project whose single participant is the leader."""
    super().setUp()
    leader = {
        'name': 'party_leader',
        'url': '127.0.0.1:5000',
        'domain_name': 'fl-leader.com',
        'grpc_spec': {'peer_url': '127.0.0.1:1991'},
    }
    variables = [
        {'name': 'namespace', 'value': 'leader'},
        {'name': 'basic_envs', 'value': '{}'},
        {'name': 'storage_root_dir', 'value': '/'},
    ]
    project_config = {
        'domain_name': 'fl-follower.com',
        'participants': [leader],
        'variables': variables,
    }
    serialized = ParseDict(project_config,
                           project_pb2.Project()).SerializeToString()
    db.session.add(Project(name='test', config=serialized))
    db.session.commit()
def init_db(port, domain_name):
    """Create all tables and seed one user plus one project named 'test'.

    The project has a single participant named after ``domain_name`` that
    points at ``127.0.0.1:<port>``. Its gRPC authority is ``domain_name``
    without its last four characters, followed by ``-client-auth.com``.
    The user and the project are committed together at the end.
    """
    db.create_all()

    seed_user = User(username='******')
    seed_user.set_password('ada')
    db.session.add(seed_user)

    local_address = f'127.0.0.1:{port}'
    participant = {
        'name': f'{domain_name}',
        'url': local_address,
        'domain_name': f'{domain_name}',
        'grpc_spec': {
            'authority': f'{domain_name[:-4]}-client-auth.com'
        },
    }
    variables = [
        {'name': 'namespace', 'value': 'default'},
        {'name': 'storage_root_dir', 'value': '/data'},
        {'name': 'EGRESS_URL', 'value': local_address},
    ]
    project_config = {
        'name': 'test',
        'participants': [participant],
        'variables': variables,
    }
    serialized = ParseDict(project_config,
                           project_pb2.Project()).SerializeToString()
    db.session.add(Project(name='test', config=serialized))
    db.session.commit()
def setUp(self):
    """Create a project whose variables drive the ProjectK8sAdapter under test."""
    super().setUp()
    volumes = [{'hostPath': {'path': '/test'}, 'name': 'test'}]
    volume_mounts = [{'mountPath': '/test', 'name': 'test'}]
    self.variables = [
        # Variable without a default value.
        {'name': 'TEST_NO_DEFAULT_VARIABLE', 'value': 'test'},
        # Variables that also have default values.
        {'name': 'VOLUMES', 'value': json.dumps(volumes)},
        {'name': 'VOLUME_MOUNTS', 'value': json.dumps(volume_mounts)},
    ]
    project = self.default_project = Project()
    project.name = 'test-self.default_project'
    participant = {
        'name': 'test-participant',
        'domain_name': 'fl-test.com',
        'url': '127.0.0.1:32443',
    }
    project.set_config(
        ParseDict({'participants': [participant],
                   'variables': self.variables}, ProjectProto()))
    project.comment = 'test comment'
    db.session.add(project)
    db.session.commit()
    self.project_k8s_adapter = ProjectK8sAdapter(project.id)
def test_put_successfully(self):
    """PUT on a workflow updates forkable/config/comment and keeps target_state READY."""
    config = {
        'participants': [{
            'name': 'party_leader',
            'url': '127.0.0.1:5000',
            'domain_name': 'fl-leader.com'
        }],
        'variables': [{
            'name': 'namespace',
            'value': 'leader'
        }, {
            'name': 'basic_envs',
            'value': '{}'
        }, {
            'name': 'storage_root_dir',
            'value': '/'
        }, {
            'name': 'EGRESS_URL',
            'value': '127.0.0.1:1991'
        }]
    }
    project = Project(name='test',
                      config=ParseDict(
                          config, project_pb2.Project()).SerializeToString())
    db.session.add(project)
    # Flush so project.id is assigned; the workflow references it instead of
    # assuming the project's autoincrement id is 1.
    db.session.flush()
    workflow = Workflow(
        name='test-workflow',
        project_id=project.id,
        state=WorkflowState.NEW,
        transaction_state=TransactionState.PARTICIPANT_PREPARE,
        target_state=WorkflowState.READY)
    db.session.add(workflow)
    db.session.commit()
    db.session.refresh(workflow)

    response = self.put_helper(
        f'/api/v2/workflows/{workflow.id}',
        data={
            'forkable': True,
            'config': {'group_alias': 'test-template'},
            'comment': 'test comment'
        })
    self.assertEqual(response.status_code, HTTPStatus.OK)

    updated_workflow = Workflow.query.get(workflow.id)
    self.assertIsNotNone(updated_workflow.config)
    self.assertTrue(updated_workflow.forkable)
    self.assertEqual(updated_workflow.comment, 'test comment')
    self.assertEqual(updated_workflow.target_state, WorkflowState.READY)
def test_get_namespace_fallback(self):
    """get_namespace() returns 'default' whenever no `namespace` variable is set."""
    proj = Project()
    # No config has been set at all.
    self.assertEqual(proj.get_namespace(), 'default')

    unrelated = common_pb2.Variable(name='test_name', value='test_value')
    proj.set_config(project_pb2.Project(variables=[unrelated]))
    # A config exists but has no `namespace` variable.
    self.assertEqual(proj.get_namespace(), 'default')
def post(self):
    """Create a project from a ``name``, a ``config`` dict and an optional ``comment``.

    The config must contain exactly one participant dict carrying ``name``,
    ``url`` (matching ``_URL_REGEX``) and ``domain_name``. Optional base64
    ``certificates`` are parsed, verified and stored per domain name, and the
    key is removed from the participant. A ``CUSTOM_HOST`` variable, if
    present, is passed to add-on creation. A ``token`` in the config is
    reused; if it is missing or None a random uuid4 hex token is generated.

    Returns:
        ``{'data': <project dict>}`` for the persisted project.

    Raises:
        InvalidArgumentException: duplicate name, malformed config or
            participant, bad URL, failed certificate verification, or a
            failed database commit.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name',
                        required=True,
                        type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config',
                        required=True,
                        type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        raise InvalidArgumentException(
            details=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'participants', 'Empty'))
    if len(participants) != 1:
        # TODO: remove limit after operator supports multiple participants
        raise InvalidArgumentException(
            details='Currently not support multiple participants.')

    # Extract configuration from variables; the last CUSTOM_HOST wins.
    # TODO: one custom host for one participant
    custom_host = None
    # `or []` also covers an explicit JSON null for `variables`.
    for variable in config.get('variables') or []:
        if variable.get('name') == 'CUSTOM_HOST':
            custom_host = variable.get('value')

    # Validate participants and collect their certificates by domain name.
    certificates = {}
    for participant in participants:
        if not isinstance(participant, dict) or any(
                key not in participant
                for key in ('name', 'url', 'domain_name')):
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Participant must have name, '
                    'domain_name and url.'))
        if re.match(_URL_REGEX, participant.get('url')) is None:
            raise InvalidArgumentException(details='URL pattern is wrong')
        domain_name = participant.get('domain_name')
        # Grpc spec: authority is the domain name minus its last four
        # characters, e.g. 'fl-test.com' -> 'fl-test-client-auth.com'.
        participant['grpc_spec'] = {
            'authority': '{}-client-auth.com'.format(domain_name[:-4])
        }
        # Certificates are stored separately, never inside the config proto.
        raw_certificates = participant.pop('certificates', None)
        if raw_certificates is not None:
            current_cert = parse_certificates(raw_certificates)
            success, err = verify_certificates(current_cert)
            if not success:
                raise InvalidArgumentException(details=err)
            certificates[domain_name] = {'certs': current_cert}

    new_project = Project()
    # Use the token sent by the user; generate one only if it is absent/None.
    config['name'] = name
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # Check the format of the config.
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'config', e)) from e
    new_project.set_certificate(
        ParseDict({'domain_name_to_cert': certificates},
                  CertificateStorage()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # Create add-ons only for participants that uploaded certificates.
    domain_name_to_cert = new_project.get_certificate().domain_name_to_cert
    for participant in new_project.get_config().participants:
        if participant.domain_name in domain_name_to_cert:
            _create_add_on(participant,
                           domain_name_to_cert[participant.domain_name],
                           custom_host)

    try:
        new_project = db.session.merge(new_project)
        db.session.commit()
    except Exception as e:
        # Roll back so the session stays usable after a failed commit.
        db.session.rollback()
        raise InvalidArgumentException(details=str(e)) from e
    return {'data': new_project.to_dict()}
def post(self):
    """Create a project from a ``name``, a ``config`` dict and an optional ``comment``.

    The config must contain exactly one participant dict with ``name``,
    ``url`` and ``domain_name``. Optional base64 ``certificates`` are parsed,
    checked to contain every file in ``_CERTIFICATE_FILE_NAMES``, stored per
    domain name and removed from the participant. After the config parses
    successfully, an add-on is created through the k8s client for every
    participant with certificates, and then the project is committed.

    Returns:
        ``{'data': <project dict>}`` for the persisted project.

    Raises:
        InvalidArgumentException: duplicate name, malformed config,
            participant or certificates, add-on creation failure
            (RuntimeError), or a failed database commit.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name',
                        required=True,
                        type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config',
                        required=True,
                        type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        raise InvalidArgumentException(
            details=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'participants', 'Empty'))
    if len(participants) != 1:
        # TODO: remove limit after operator supports multiple participants
        raise InvalidArgumentException(
            details='Currently not support multiple participants.')

    certificates = {}
    # URL of each participant that uploaded certificates, keyed by domain, so
    # each add-on gets its own participant's URL rather than a leaked loop var.
    participant_urls = {}
    for participant in participants:
        if not isinstance(participant, dict) or any(
                key not in participant
                for key in ('name', 'url', 'domain_name')):
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Participant must have name, '
                    'domain_name and url.'))
        domain_name = participant.get('domain_name')
        raw_certificates = participant.pop('certificates', None)
        if raw_certificates is not None:
            current_cert = parse_certificates(raw_certificates)
            # check validation
            for file_name in _CERTIFICATE_FILE_NAMES:
                if current_cert.get(file_name) is None:
                    raise InvalidArgumentException(
                        details=ErrorMessage.PARAM_FORMAT_ERROR.value.
                        format('certificates',
                               '{} not existed'.format(file_name)))
            certificates[domain_name] = {'certs': current_cert}
            participant_urls[domain_name] = participant.get('url')

    new_project = Project()
    # Use the token sent by the user; generate one only if it is absent/None.
    config['name'] = name
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # Check the format of the config.
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'config', e)) from e
    new_project.set_certificate(
        ParseDict({'domain_name_to_cert': certificates},
                  CertificateStorage()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # Create add-ons only after the config validated, so an invalid request
    # is rejected before any add-on is created.
    try:
        k8s_client = get_client()
        for domain_name, certificate in certificates.items():
            create_add_on(k8s_client, domain_name,
                          participant_urls[domain_name], certificate['certs'])
    except RuntimeError as e:
        raise InvalidArgumentException(details=str(e)) from e

    try:
        new_project = db.session.merge(new_project)
        db.session.commit()
    except Exception as e:
        # Roll back so the session stays usable after a failed commit.
        db.session.rollback()
        raise InvalidArgumentException(details=str(e)) from e
    return {'data': new_project.to_dict()}
class ProjectApiTest(BaseTestCase):
    """API tests for the ``/api/v2/projects`` endpoints."""

    def setUp(self):
        """Create one project (with certificates) and one workflow in it."""
        super().setUp()
        cert_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'test.tar.gz')
        with open(cert_path, 'rb') as file:
            self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                         encoding='utf-8')
        self.default_project = Project()
        self.default_project.name = 'test-self.default_project'
        self.default_project.set_config(
            ParseDict(
                {
                    'participants': [{
                        'name': 'test-participant',
                        'domain_name': 'fl-test.com',
                        'url': '127.0.0.1:32443'
                    }],
                    'variables': [{
                        'name': 'test',
                        'value': 'test'
                    }]
                }, ProjectProto()))
        self.default_project.set_certificate(
            ParseDict(
                {
                    'domain_name_to_cert': {
                        'fl-test.com': {
                            'certs':
                            parse_certificates(self.TEST_CERTIFICATES)
                        }
                    }
                }, CertificateStorage()))
        self.default_project.comment = 'test comment'
        db.session.add(self.default_project)
        # Flush so the project has a primary key; the workflow references it
        # instead of assuming the autoincrement id is 1.
        db.session.flush()
        workflow = Workflow(name='workflow_key_get1',
                            project_id=self.default_project.id)
        db.session.add(workflow)
        db.session.commit()

    def test_get_project(self):
        get_response = self.get_helper('/api/v2/projects/{}'.format(1))
        self.assertEqual(get_response.status_code, HTTPStatus.OK)
        queried_project = json.loads(get_response.data).get('data')
        self.assertEqual(queried_project, self.default_project.to_dict())

    def test_get_not_found_project(self):
        get_response = self.get_helper('/api/v2/projects/{}'.format(1000))
        self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)

    @patch('fedlearner_webconsole.project.apis.verify_certificates')
    def test_post_project(self, mock_verify_certificates):
        mock_verify_certificates.return_value = (True, '')
        name = 'test-post-project'
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        comment = 'test post project'
        create_response = self.post_helper('/api/v2/projects',
                                           data={
                                               'name': name,
                                               'config': config,
                                               'comment': comment
                                           })
        self.assertEqual(create_response.status_code, HTTPStatus.OK)
        created_project = json.loads(create_response.data).get('data')

        queried_project = Project.query.filter_by(name=name).first()
        self.assertEqual(created_project, queried_project.to_dict())

        mock_verify_certificates.assert_called_once_with(
            parse_certificates(self.TEST_CERTIFICATES))

    def test_post_conflict_name_project(self):
        """A well-formed config posted under an existing name is rejected.

        The participants use the same list shape as ``test_post_project`` so
        the only reason for the 400 is the duplicate name.
        """
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        create_response = self.post_helper('/api/v2/projects',
                                           data={
                                               'name':
                                               self.default_project.name,
                                               'config': config,
                                               'comment': ''
                                           })
        self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)

    def test_list_project(self):
        list_response = self.get_helper('/api/v2/projects')
        project_list = json.loads(list_response.data).get('data')
        self.assertEqual(len(project_list), 1)
        for project in project_list:
            queried_project = Project.query.filter_by(
                name=project['name']).first()
            result = queried_project.to_dict()
            # setUp created exactly one workflow in this project.
            result['num_workflow'] = 1
            self.assertEqual(project, result)

    def test_update_project(self):
        updated_name = 'updated name'
        updated_comment = 'updated comment'
        update_response = self.patch_helper('/api/v2/projects/{}'.format(1),
                                            data={
                                                'participant_name':
                                                updated_name,
                                                'comment': updated_comment
                                            })
        self.assertEqual(update_response.status_code, HTTPStatus.OK)
        queried_project = Project.query.filter_by(id=1).first()
        participant = queried_project.get_config().participants[0]
        self.assertEqual(participant.name, updated_name)
        self.assertEqual(queried_project.comment, updated_comment)

    def test_update_not_found_project(self):
        updated_comment = 'updated comment'
        update_response = self.patch_helper(
            '/api/v2/projects/{}'.format(1000),
            data={'comment': updated_comment})
        self.assertEqual(update_response.status_code, HTTPStatus.NOT_FOUND)
class ProjectK8sAdapterTest(BaseTestCase):
    """Tests for ProjectK8sAdapter reading a project's config variables."""

    def setUp(self):
        super().setUp()
        volumes = [{'hostPath': {'path': '/test'}, 'name': 'test'}]
        volume_mounts = [{'mountPath': '/test', 'name': 'test'}]
        self.variables = [
            # Variable without a default value.
            {'name': 'TEST_NO_DEFAULT_VARIABLE', 'value': 'test'},
            # Variables that also have default values.
            {'name': 'VOLUMES', 'value': json.dumps(volumes)},
            {'name': 'VOLUME_MOUNTS', 'value': json.dumps(volume_mounts)},
        ]
        project = self.default_project = Project()
        project.name = 'test-self.default_project'
        participant = {
            'name': 'test-participant',
            'domain_name': 'fl-test.com',
            'url': '127.0.0.1:32443',
        }
        project.set_config(
            ParseDict({'participants': [participant],
                       'variables': self.variables}, ProjectProto()))
        project.comment = 'test comment'
        db.session.add(project)
        db.session.commit()
        self.project_k8s_adapter = ProjectK8sAdapter(project.id)

    def test_exact_variable_from_config(self):
        adapter = self.project_k8s_adapter
        # A variable without a default resolves to its configured value.
        self.assertEqual(
            'test',
            adapter._exact_variable_from_config('TEST_NO_DEFAULT_VARIABLE'))
        # NAMESPACE is not in the config, so its default is used.
        self.assertEqual('default',
                         adapter._exact_variable_from_config('NAMESPACE'))
        # A configured value takes precedence over the default.
        expected_volumes = json.dumps([{
            'hostPath': {'path': '/test'},
            'name': 'test'
        }])
        self.assertEqual(expected_volumes,
                         adapter._exact_variable_from_config('VOLUMES'))

    def test_get_global_replica_spec(self):
        container = {
            'env': self.variables,
            'volumeMounts': [{'mountPath': '/test', 'name': 'test'}],
        }
        pod_spec = {
            'imagePullSecrets': [{'name': 'regcred'}],
            'volumes': [{'hostPath': {'path': '/test'}, 'name': 'test'}],
            'containers': [container],
        }
        expected = {'global_replica_spec': {'template': {'spec': pod_spec}}}
        self.assertEqual(expected,
                         self.project_k8s_adapter.get_global_replica_spec())


if __name__ == '__main__':
    unittest.main()
class ProjectApiTest(BaseTestCase):
    """API tests for ``/api/v2/projects`` using the raw Flask test client."""

    def setUp(self):
        """Create one project whose certificates are keyed by '*.fl-test.com'."""
        super().setUp()
        cert_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'test.tar.gz')
        with open(cert_path, 'rb') as file:
            self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                         encoding='utf-8')
        self.default_project = Project()
        self.default_project.name = 'test-self.default_project'
        self.default_project.set_config(
            ParseDict(
                {
                    'participants': [{
                        'name': 'test-participant',
                        'domain_name': 'fl-test.com',
                        'url': '127.0.0.1:32443'
                    }],
                    'variables': [{
                        'name': 'test',
                        'value': 'test'
                    }]
                }, ProjectProto()))
        self.default_project.set_certificate(
            ParseDict(
                {
                    'domain_name_to_cert': {
                        '*.fl-test.com': {
                            'certs':
                            parse_certificates(self.TEST_CERTIFICATES)
                        }
                    },
                }, CertificateStorage()))
        self.default_project.comment = 'test comment'
        db.session.add(self.default_project)
        db.session.commit()

    def test_get_project(self):
        get_response = self.client.get('/api/v2/projects/{}'.format(1))
        self.assertEqual(get_response.status_code, HTTPStatus.OK)
        queried_project = json.loads(get_response.data).get('data')
        self.assertEqual(queried_project, self.default_project.to_dict())

    def test_get_not_found_project(self):
        get_response = self.client.get('/api/v2/projects/{}'.format(1000))
        self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)

    def test_post_project(self):
        name = 'test-post-project'
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        comment = 'test post project'
        create_response = self.client.post('/api/v2/projects',
                                           data=json.dumps({
                                               'name': name,
                                               'config': config,
                                               'comment': comment
                                           }),
                                           content_type='application/json')
        self.assertEqual(create_response.status_code, HTTPStatus.OK)
        created_project = json.loads(create_response.data).get('data')

        queried_project = Project.query.filter_by(name=name).first()
        self.assertEqual(created_project, queried_project.to_dict())

    def test_post_conflict_name_project(self):
        """A well-formed config posted under an existing name is rejected.

        The participants use the same list shape as ``test_post_project`` so
        the only reason for the 400 is the duplicate name.
        """
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        create_response = self.client.post('/api/v2/projects',
                                           data=json.dumps({
                                               'name':
                                               self.default_project.name,
                                               'config': config,
                                               'comment': ''
                                           }),
                                           content_type='application/json')
        self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)

    def test_list_project(self):
        list_response = self.client.get('/api/v2/projects')
        self.assertEqual(list_response.status_code, HTTPStatus.OK)
        project_list = json.loads(list_response.data).get('data')
        # Without this the loop below would pass vacuously on an empty list.
        self.assertEqual(len(project_list), 1)
        for project in project_list:
            queried_project = Project.query.filter_by(
                name=project['name']).first()
            self.assertEqual(project, queried_project.to_dict())

    def test_update_project(self):
        updated_comment = 'updated comment'
        update_response = self.client.patch(
            '/api/v2/projects/{}'.format(1),
            data=json.dumps({'comment': updated_comment}),
            content_type='application/json')
        self.assertEqual(update_response.status_code, HTTPStatus.OK)

        queried_project = Project.query.filter_by(id=1).first()
        self.assertEqual(queried_project.comment, updated_comment)

    def test_update_not_found_project(self):
        updated_comment = 'updated comment'
        update_response = self.client.patch(
            '/api/v2/projects/{}'.format(1000),
            data=json.dumps({'comment': updated_comment}),
            content_type='application/json')
        self.assertEqual(update_response.status_code, HTTPStatus.NOT_FOUND)
def post(self):
    """Create a project from ``name``, ``config`` and an optional ``comment``.

    ``config['participants']`` must be a dict mapping each domain name to a
    participant dict with ``name`` and ``url`` and optional base64
    ``certificates``. Exactly one participant is accepted. Each participant
    is rewritten in place into the proto layout: ``domain_name`` is added,
    ``url`` is moved into ``grpc_spec`` and ``certificates`` is removed.
    Then the config is parsed into ``ProjectProto``. Errors are reported
    through Flask-RESTful ``abort``.

    Returns:
        ``{'data': <project dict>}`` for the stored project.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name',
                        required=True,
                        type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config',
                        required=True,
                        type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                  'participants', 'Empty'))
    elif not isinstance(participants, dict):
        # `.items()` below needs a dict; anything else used to surface as 500.
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                  'participants', 'Participants must be keyed by domain name.'))
    elif len(participants) != 1:
        # TODO: remove limit in schema after operator supports multiple participants
        abort(HTTPStatus.BAD_REQUEST,
              message='Currently not support multiple participants.')

    certificates = {}
    for domain_name, participant in participants.items():
        if not isinstance(participant, dict) or \
                'name' not in participant or 'url' not in participant:
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                      'participants', 'Participant must have name and url.'))
        # Certificates are stored separately, never inside the config proto.
        raw_certificates = participant.pop('certificates', None)
        if raw_certificates is not None:
            current_cert = _parse_certificates(raw_certificates)
            # check validation
            for file_name in _CERTIFICATE_FILE_NAMES:
                if current_cert.get(file_name) is None:
                    abort(HTTPStatus.BAD_REQUEST,
                          message=ErrorMessage.PARAM_FORMAT_ERROR.value.
                          format('certificates',
                                 '{} not existed'.format(file_name)))
            certificates[domain_name] = raw_certificates
        participant['domain_name'] = domain_name
        # format participant to proto structure
        # TODO: fill other fields
        participant['grpc_spec'] = {'url': participant.pop('url')}

    new_project = Project()
    # Use the token sent by the user; generate one only if it is absent/None.
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # Check the format of the config.
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                  'config', e))
    new_project.set_certificate(
        ParseDict({'certificate': certificates}, Certificate()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # following operations will change the state of k8s and db
    try:
        # TODO: singleton k8s client
        k8s_client = K8sClient()
        for domain_name, certificate in certificates.items():
            _create_add_on(k8s_client, domain_name, certificate)
        db.session.add(new_project)
        db.session.commit()
    except Exception as e:
        # Roll back so the session stays usable after a failed commit.
        db.session.rollback()
        # `message=` matches every other abort here; str() keeps the error
        # payload JSON-serializable (an exception object is not).
        abort(HTTPStatus.INTERNAL_SERVER_ERROR, message=str(e))
    return {'data': new_project.to_dict()}