def setUp(self):
    """Build a project that carries certificates and attach one workflow.

    The ``test.tar.gz`` archive next to this module is base64-encoded into
    ``self.TEST_CERTIFICATES``.  ``self.default_project`` is configured with
    one participant, one variable and the parsed certificates for
    ``fl-test.com``.  It is then committed together with a workflow named
    ``workflow_key_get1``.
    """
    super().setUp()

    fixture_dir = os.path.dirname(os.path.realpath(__file__))
    archive_path = os.path.join(fixture_dir, 'test.tar.gz')
    with open(archive_path, 'rb') as archive:
        raw_archive = archive.read()
    self.TEST_CERTIFICATES = b64encode(raw_archive).decode('utf-8')

    project_config = {
        'participants': [{
            'name': 'test-participant',
            'domain_name': 'fl-test.com',
            'url': '127.0.0.1:32443',
        }],
        'variables': [{'name': 'test', 'value': 'test'}],
    }

    project = self.default_project = Project()
    # NOTE(review): this name looks like a find/replace artifact of
    # 'test-project'; kept verbatim in case other tests depend on it.
    project.name = 'test-self.default_project'
    project.set_config(ParseDict(project_config, ProjectProto()))
    project.set_certificate(ParseDict(
        {
            'domain_name_to_cert': {
                'fl-test.com': {
                    'certs': parse_certificates(self.TEST_CERTIFICATES),
                },
            },
        },
        CertificateStorage()))
    project.comment = 'test comment'
    db.session.add(project)

    # NOTE(review): project_id=1 assumes the project above receives id 1
    # (i.e. a fresh database) — confirm if fixtures change.
    db.session.add(Workflow(name='workflow_key_get1', project_id=1))
    db.session.commit()
def setUp(self):
    """Insert one project named ``test`` with a serialized follower config.

    The config's own domain is ``fl-follower.com``.  It lists a single
    leader participant and the ``namespace``, ``basic_envs`` and
    ``storage_root_dir`` variables.
    """
    super().setUp()

    leader = {
        'name': 'party_leader',
        'url': '127.0.0.1:5000',
        'domain_name': 'fl-leader.com',
        'grpc_spec': {'peer_url': '127.0.0.1:1991'},
    }
    variables = [
        {'name': 'namespace', 'value': 'leader'},
        {'name': 'basic_envs', 'value': '{}'},
        {'name': 'storage_root_dir', 'value': '/'},
    ]
    project_config = {
        'domain_name': 'fl-follower.com',
        'participants': [leader],
        'variables': variables,
    }

    serialized = ParseDict(project_config,
                           project_pb2.Project()).SerializeToString()
    db.session.add(Project(name='test', config=serialized))
    db.session.commit()
def init_db(port, domain_name):
    """Create all tables and seed a test user plus a single test project.

    Args:
        port: Appended to ``127.0.0.1`` to form both the participant URL and
            the ``EGRESS_URL`` variable.
        domain_name: Used as the participant's name and domain name.  The
            gRPC authority is built from it with its last four characters
            dropped (presumably a ``.com`` suffix — confirm with callers).
    """
    db.create_all()

    admin = User(username='******')
    admin.set_password('ada')
    db.session.add(admin)

    address = f'127.0.0.1:{port}'
    participant_id = f'{domain_name}'
    project_config = {
        'name': 'test',
        'participants': [{
            'name': participant_id,
            'url': address,
            'domain_name': participant_id,
            'grpc_spec': {
                'authority': f'{domain_name[:-4]}-client-auth.com',
            },
        }],
        'variables': [
            {'name': 'namespace', 'value': 'default'},
            {'name': 'storage_root_dir', 'value': '/data'},
            {'name': 'EGRESS_URL', 'value': address},
        ],
    }

    encoded = ParseDict(project_config,
                        project_pb2.Project()).SerializeToString()
    db.session.add(Project(name='test', config=encoded))
    db.session.commit()
def test_is_peer_job_inheritance_matched(self, mock_get_workflow):
    """Check ``is_peer_job_inheritance_matched`` on a forked workflow.

    The mocked peer workflow defines a non-federated ``raw-data-job`` and a
    federated ``train-job``.  The local fork defines only the federated
    ``train-job`` and has peer flags ``[NEW, REUSE]``.  A local ``[REUSE]``
    flag is expected to match; a local ``[NEW]`` flag is expected not to.
    """
    mock_get_workflow.return_value = GetWorkflowResponse(
        config=WorkflowDefinition(job_definitions=[
            JobDefinition(name='raw-data-job'),
            JobDefinition(name='train-job', is_federated=True),
        ]))

    local_config = WorkflowDefinition(
        job_definitions=[JobDefinition(name='train-job', is_federated=True)])

    project = Project()
    project.set_config(
        project_pb2.Project(participants=[project_pb2.Participant()]))

    parent = Workflow(project=project)
    parent.set_config(local_config)
    db.session.add(parent)
    db.session.commit()
    # NOTE(review): a flush right after commit looks redundant; kept as-is.
    db.session.flush()

    fork = Workflow(project=project, forked_from=parent.id)
    fork.set_config(local_config)
    fork.set_create_job_flags([CreateJobFlag.REUSE])
    fork.set_peer_create_job_flags([CreateJobFlag.NEW, CreateJobFlag.REUSE])
    self.assertTrue(is_peer_job_inheritance_matched(fork))

    fork.set_create_job_flags([CreateJobFlag.NEW])
    self.assertFalse(is_peer_job_inheritance_matched(fork))
def test_get_namespace_from_variables(self):
    """A ``namespace`` variable in the project config is what get_namespace returns."""
    project = Project()
    namespace_var = common_pb2.Variable(name='namespace', value='haha')
    project.set_config(project_pb2.Project(variables=[namespace_var]))
    self.assertEqual(project.get_namespace(), 'haha')
def test_get_namespace_fallback(self):
    """``get_namespace`` yields ``'default'`` when no namespace is configured.

    Checked twice: first with no config set, then with a config whose only
    variable is unrelated to the namespace.
    """
    project = Project()
    self.assertEqual(project.get_namespace(), 'default')

    unrelated = common_pb2.Variable(name='test_name', value='test_value')
    project.set_config(project_pb2.Project(variables=[unrelated]))
    self.assertEqual(project.get_namespace(), 'default')
def test_put_successfully(self):
    """PUT on a workflow updates forkable/config/comment.

    The workflow's ``target_state`` is expected to still be ``READY`` after
    the update.
    """
    leader = {
        'name': 'party_leader',
        'url': '127.0.0.1:5000',
        'domain_name': 'fl-leader.com',
    }
    variables = [
        {'name': 'namespace', 'value': 'leader'},
        {'name': 'basic_envs', 'value': '{}'},
        {'name': 'storage_root_dir', 'value': '/'},
        {'name': 'EGRESS_URL', 'value': '127.0.0.1:1991'},
    ]
    serialized_config = ParseDict(
        {'participants': [leader], 'variables': variables},
        project_pb2.Project()).SerializeToString()
    db.session.add(Project(name='test', config=serialized_config))

    workflow = Workflow(
        name='test-workflow',
        project_id=1,
        state=WorkflowState.NEW,
        transaction_state=TransactionState.PARTICIPANT_PREPARE,
        target_state=WorkflowState.READY,
    )
    db.session.add(workflow)
    db.session.commit()
    db.session.refresh(workflow)

    payload = {
        'forkable': True,
        'config': {'group_alias': 'test-template'},
        'comment': 'test comment',
    }
    response = self.put_helper(f'/api/v2/workflows/{workflow.id}',
                               data=payload)
    self.assertEqual(response.status_code, HTTPStatus.OK)

    updated = Workflow.query.get(workflow.id)
    self.assertIsNotNone(updated.config)
    self.assertTrue(updated.forkable)
    self.assertEqual(updated.comment, 'test comment')
    self.assertEqual(updated.target_state, WorkflowState.READY)
def setUp(self):
    """Commit a project with volume-related variables and wrap it in an adapter.

    ``self.variables`` holds the raw variable dicts used for the config.
    The committed project is ``self.default_project``, and
    ``self.project_k8s_adapter`` is a ``ProjectK8sAdapter`` built from its id.
    """
    super().setUp()

    host_volume = {'hostPath': {'path': '/test'}, 'name': 'test'}
    volume_mount = {'mountPath': '/test', 'name': 'test'}
    self.variables = [
        # Variable without a default value.
        {'name': 'TEST_NO_DEFAULT_VARIABLE', 'value': 'test'},
        # Variables with default values (JSON-encoded lists).
        {'name': 'VOLUMES', 'value': json.dumps([host_volume])},
        {'name': 'VOLUME_MOUNTS', 'value': json.dumps([volume_mount])},
    ]

    participant = {
        'name': 'test-participant',
        'domain_name': 'fl-test.com',
        'url': '127.0.0.1:32443',
    }
    project = self.default_project = Project()
    # NOTE(review): likely a find/replace artifact of 'test-project'; kept
    # verbatim in case other tests depend on it.
    project.name = 'test-self.default_project'
    project.set_config(ParseDict(
        {'participants': [participant], 'variables': self.variables},
        ProjectProto()))
    project.comment = 'test comment'
    db.session.add(project)
    db.session.commit()

    self.project_k8s_adapter = ProjectK8sAdapter(project.id)
def post(self):
    """Create a project from the request's ``name``, ``config`` and ``comment``.

    The single participant is validated: it needs ``name``, ``url`` and
    ``domain_name``, its URL must match ``_URL_REGEX``, and any
    certificates must pass ``verify_certificates``.  Its
    ``grpc_spec.authority`` is filled in.  Parsed certificates are stored in
    the project's ``CertificateStorage`` keyed by domain name.  An add-on is
    created for every participant that has certificates, and the project is
    then persisted.

    Returns:
        ``{'data': project.to_dict()}`` for the newly created project.

    Raises:
        InvalidArgumentException: on a name conflict, missing or too many
            participants, a participant missing required keys, a malformed
            URL, invalid certificates, a config that does not parse into
            ``ProjectProto``, or a failed commit.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config', required=True, type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        raise InvalidArgumentException(
            details=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'participants', 'Empty'))
    if len(participants) != 1:
        # TODO: remove limit after operator supports multiple participants
        raise InvalidArgumentException(
            details='Currently not support multiple participants.')

    # Extract configuration from variables; a later CUSTOM_HOST overrides
    # an earlier one.
    # TODO: one custom host for one participant
    custom_host = None
    for variable in config.get('variables', []):
        if variable.get('name') == 'CUSTOM_HOST':
            custom_host = variable.get('value')

    # Parse participants.
    certificates = {}
    for participant in participants:
        if not all(key in participant
                   for key in ('name', 'url', 'domain_name')):
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Participant must have name, '
                    'domain_name and url.'))
        if re.match(_URL_REGEX, participant.get('url')) is None:
            raise InvalidArgumentException('URL pattern is wrong')
        domain_name = participant.get('domain_name')
        # gRPC authority drops the last four characters of the domain name
        # (presumably a '.com' suffix — TODO confirm).
        participant['grpc_spec'] = {
            'authority': '{}-client-auth.com'.format(domain_name[:-4])
        }
        if participant.get('certificates') is not None:
            current_cert = parse_certificates(
                participant.get('certificates'))
            success, err = verify_certificates(current_cert)
            if not success:
                raise InvalidArgumentException(err)
            certificates[domain_name] = {'certs': current_cert}
        # Certificates go to CertificateStorage, not into the project config.
        participant.pop('certificates', None)

    new_project = Project()
    # Generate the token: use the one the user sent, or generate a new one
    # with uuid when it is absent or explicitly None.
    config['name'] = name
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # Check the format of the config.
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'config', e)) from e
    new_project.set_certificate(
        ParseDict({'domain_name_to_cert': certificates},
                  CertificateStorage()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # Create an add-on for each participant that has certificates.
    domain_name_to_cert = new_project.get_certificate().domain_name_to_cert
    for participant in new_project.get_config().participants:
        if participant.domain_name in domain_name_to_cert:
            _create_add_on(participant,
                           domain_name_to_cert[participant.domain_name],
                           custom_host)

    try:
        new_project = db.session.merge(new_project)
        db.session.commit()
    except Exception as e:
        # Discard the failed transaction so the session stays usable.
        db.session.rollback()
        raise InvalidArgumentException(details=str(e)) from e
    return {'data': new_project.to_dict()}
def post(self):
    """Create a project from the request's ``name``, ``config`` and ``comment``.

    The single participant is validated: it needs ``name``, ``url`` and
    ``domain_name``, and any certificates must contain every file in
    ``_CERTIFICATE_FILE_NAMES``.  The config is checked against
    ``ProjectProto``.  An add-on is then created (via ``get_client()`` and
    ``create_add_on``) for each participant with certificates, and the
    project is persisted.

    Returns:
        ``{'data': project.to_dict()}`` for the newly created project.

    Raises:
        InvalidArgumentException: on a name conflict, missing or too many
            participants, missing participant keys or certificate files, a
            config that does not parse into ``ProjectProto``, a
            ``RuntimeError`` while creating add-ons, or a failed commit.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config', required=True, type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        raise InvalidArgumentException(
            details=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'participants', 'Empty'))
    if len(participants) != 1:
        # TODO: remove limit after operator supports multiple participants
        raise InvalidArgumentException(
            details='Currently not support multiple participants.')

    certificates = {}
    # URL of every participant that supplied certificates, keyed by domain,
    # so each add-on gets its own participant's URL.
    urls = {}
    for participant in participants:
        if not all(key in participant
                   for key in ('name', 'url', 'domain_name')):
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Participant must have name, '
                    'domain_name and url.'))
        domain_name = participant.get('domain_name')
        if participant.get('certificates') is not None:
            current_cert = parse_certificates(
                participant.get('certificates'))
            # check validation
            for file_name in _CERTIFICATE_FILE_NAMES:
                if current_cert.get(file_name) is None:
                    raise InvalidArgumentException(
                        details=ErrorMessage.PARAM_FORMAT_ERROR.value.
                        format('certificates', '{} not existed'.format(
                            file_name)))
            certificates[domain_name] = {'certs': current_cert}
            urls[domain_name] = participant.get('url')
        # Certificates go to CertificateStorage, not into the project config.
        participant.pop('certificates', None)

    new_project = Project()
    # Generate the token: use the one the user sent, or generate a new one
    # with uuid when it is absent or explicitly None.
    config['name'] = name
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # Check the format of the config before touching external state.
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        raise InvalidArgumentException(
            details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                'config', e)) from e
    new_project.set_certificate(
        ParseDict({'domain_name_to_cert': certificates},
                  CertificateStorage()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # Create add-ons only once the request is fully validated, using each
    # domain's own URL and certificates.
    try:
        k8s_client = get_client()
        for domain_name, certificate in certificates.items():
            create_add_on(k8s_client, domain_name, urls[domain_name],
                          certificate['certs'])
    except RuntimeError as e:
        raise InvalidArgumentException(details=str(e)) from e

    try:
        new_project = db.session.merge(new_project)
        db.session.commit()
    except Exception as e:
        # Discard the failed transaction so the session stays usable.
        db.session.rollback()
        raise InvalidArgumentException(details=str(e)) from e
    return {'data': new_project.to_dict()}
def post(self):
    """Create a project from the request's ``name``, ``config`` and ``comment``.

    ``config['participants']`` is iterated as a mapping from domain name to
    participant dict.  Each participant needs ``name`` and ``url``, and any
    certificates must contain every file in ``_CERTIFICATE_FILE_NAMES``.
    Each participant is then reshaped for ``ProjectProto``: ``domain_name``
    is set and ``url`` is moved under ``grpc_spec``.  Add-ons are created for
    participants with certificates before the project is committed.

    Returns:
        ``{'data': project.to_dict()}`` for the newly created project.

    Aborts:
        400 on a name conflict, missing or too many participants, missing
        participant keys or certificate files, or a config that does not
        parse into ``ProjectProto``.  500 if add-on creation or the commit
        fails.
    """
    parser = reqparse.RequestParser()
    parser.add_argument('name', required=True, type=str,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'name', 'Empty'))
    parser.add_argument('config', required=True, type=dict,
                        help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'config', 'Empty'))
    parser.add_argument('comment')
    data = parser.parse_args()
    name = data['name']
    config = data['config']
    comment = data['comment']

    if Project.query.filter_by(name=name).first() is not None:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.NAME_CONFLICT.value.format(name))

    participants = config.get('participants')
    if participants is None:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                  'participants', 'Empty'))
    elif len(participants) != 1:
        # TODO: remove limit in schema after operator supports multiple participants
        abort(HTTPStatus.BAD_REQUEST,
              message='Currently not support multiple participants.')

    certificates = {}
    for domain_name, participant in participants.items():
        if 'name' not in participant or 'url' not in participant:
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                      'participants', 'Participant must have name and url.'))
        if participant.get('certificates') is not None:
            current_cert = _parse_certificates(
                participant.get('certificates'))
            # check validation
            for file_name in _CERTIFICATE_FILE_NAMES:
                if current_cert.get(file_name) is None:
                    abort(HTTPStatus.BAD_REQUEST,
                          message=ErrorMessage.PARAM_FORMAT_ERROR.value.
                          format('certificates',
                                 '{} not existed'.format(file_name)))
            certificates[domain_name] = participant.get('certificates')
        # Certificates are stored separately; drop the key whether it was
        # absent, None, or populated.
        participant.pop('certificates', None)
        participant['domain_name'] = domain_name
        # format participant to proto structure
        # TODO: fill other fields
        participant['grpc_spec'] = {'url': participant.pop('url')}

    new_project = Project()
    # Generate the token: use the one the user sent, or generate a new one
    # with uuid when it is absent or explicitly None.
    token = config.get('token')
    if token is None:
        token = uuid4().hex
    config['token'] = token

    # check format of config
    try:
        new_project.set_config(ParseDict(config, ProjectProto()))
    except Exception as e:
        abort(HTTPStatus.BAD_REQUEST,
              message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                  'config', e))
    new_project.set_certificate(
        ParseDict({'certificate': certificates}, Certificate()))
    new_project.name = name
    new_project.token = token
    new_project.comment = comment

    # following operations will change the state of k8s and db
    try:
        # TODO: singleton k8s client
        k8s_client = K8sClient()
        for domain_name, certificate in certificates.items():
            _create_add_on(k8s_client, domain_name, certificate)
        db.session.add(new_project)
        db.session.commit()
    except Exception as e:
        # Discard any pending/failed transaction before reporting the error.
        db.session.rollback()
        # Use the same `message` key as the other aborts and a string so the
        # error body is serializable.
        abort(HTTPStatus.INTERNAL_SERVER_ERROR, message=str(e))
    return {'data': new_project.to_dict()}