Beispiel #1
0
    def test_is_peer_job_inheritance_matched(self, mock_get_workflow):
        """Exercise `is_peer_job_inheritance_matched` on a forked workflow.

        With local create-job flags [REUSE] and peer flags [NEW, REUSE] the
        check is expected to pass; after switching the local flags to [NEW]
        it is expected to fail.

        `mock_get_workflow` is presumably injected by a `@patch` decorator
        outside this view; its return value is used here as the peer's
        `GetWorkflowResponse`.
        """
        # Peer side: 'raw-data-job' plus a federated 'train-job'.
        peer_job_0 = JobDefinition(name='raw-data-job')
        peer_job_1 = JobDefinition(name='train-job', is_federated=True)
        peer_config = WorkflowDefinition()
        peer_config.job_definitions.extend([peer_job_0, peer_job_1])
        resp = GetWorkflowResponse(config=peer_config)
        mock_get_workflow.return_value = resp

        # Local side: only the federated 'train-job'.
        job_0 = JobDefinition(name='train-job', is_federated=True)
        config = WorkflowDefinition(job_definitions=[job_0])

        # A project with a single (empty) participant.
        project = Project()
        participant = project_pb2.Participant()
        project.set_config(project_pb2.Project(participants=[participant]))
        workflow0 = Workflow(project=project)
        workflow0.set_config(config)
        db.session.add(workflow0)
        # Persist workflow0 so its id is available for `forked_from` below.
        db.session.commit()
        # NOTE(review): a flush right after commit appears redundant.
        db.session.flush()
        # workflow1 is forked from workflow0 and shares its job config.
        workflow1 = Workflow(project=project, forked_from=workflow0.id)
        workflow1.set_config(config)
        workflow1.set_create_job_flags([CreateJobFlag.REUSE])
        workflow1.set_peer_create_job_flags(
            [CreateJobFlag.NEW, CreateJobFlag.REUSE])

        self.assertTrue(is_peer_job_inheritance_matched(workflow1))

        # Changing only the local flags is expected to break the match.
        workflow1.set_create_job_flags([CreateJobFlag.NEW])
        self.assertFalse(is_peer_job_inheritance_matched(workflow1))
Beispiel #2
0
    def test_get_namespace_from_variables(self):
        """A `namespace` variable in the project config is what
        `get_namespace` returns."""
        namespace_variable = common_pb2.Variable(name='namespace',
                                                 value='haha')
        project_config = project_pb2.Project(variables=[namespace_variable])

        project = Project()
        project.set_config(project_config)

        self.assertEqual(project.get_namespace(), 'haha')
Beispiel #3
0
    def test_get_namespace_fallback(self):
        """`get_namespace` yields 'default' without a `namespace` variable.

        Covers both a freshly constructed project and one whose config only
        carries unrelated variables.
        """
        expected_namespace = 'default'

        project = Project()
        self.assertEqual(project.get_namespace(), expected_namespace)

        unrelated_variable = common_pb2.Variable(name='test_name',
                                                 value='test_value')
        project.set_config(
            project_pb2.Project(variables=[unrelated_variable]))
        self.assertEqual(project.get_namespace(), expected_namespace)
Beispiel #4
0
    def post(self):
        """Create a new project from the request body.

        Request arguments (parsed by ``reqparse``):
            name (str, required): project name; must not already exist.
            config (dict, required): project configuration. It must contain
                a ``participants`` list with exactly one entry; every entry
                needs ``name``, ``url`` (matching ``_URL_REGEX``) and
                ``domain_name``, and may carry ``certificates`` (the payload
                accepted by ``parse_certificates``). An optional
                ``variables`` list may contain a ``CUSTOM_HOST`` entry, and
                an optional ``token`` is used instead of a generated one.
            comment (str, optional): free-form comment.

        Returns:
            dict: ``{'data': <created project as dict>}``.

        Raises:
            InvalidArgumentException: on a name conflict, malformed
                participants, a bad URL, certificates rejected by
                ``verify_certificates``, a config that does not parse into
                the project proto, or a database error.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name',
                            required=True,
                            type=str,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'name', 'Empty'))
        parser.add_argument('config',
                            required=True,
                            type=dict,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'config', 'Empty'))
        parser.add_argument('comment')
        data = parser.parse_args()
        name = data['name']
        config = data['config']
        comment = data['comment']

        if Project.query.filter_by(name=name).first() is not None:
            raise InvalidArgumentException(
                details=ErrorMessage.NAME_CONFLICT.value.format(name))

        participants = config.get('participants')
        if participants is None:
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Empty'))
        if len(participants) != 1:
            # TODO: remove limit after operator supports multiple participants
            raise InvalidArgumentException(
                details='Currently not support multiple participants.')

        # extract configuration from variables
        # TODO: one custom host for one participant
        custom_host = None
        # `or []` also tolerates an explicit `"variables": null`.
        for variable in config.get('variables') or []:
            if variable.get('name') == 'CUSTOM_HOST':
                custom_host = variable.get('value')

        # parse participant
        certificates = {}
        for participant in participants:
            # Non-object entries (e.g. the string keys yielded by a
            # dict-shaped `participants`) are rejected with a 400 instead of
            # crashing on attribute access.
            if not isinstance(participant, dict) or \
                    'name' not in participant or \
                    'url' not in participant or \
                    'domain_name' not in participant:
                raise InvalidArgumentException(
                    details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                        'participants', 'Participant must have name, '
                        'domain_name and url.'))
            url = participant.get('url')
            # `re.match` raises TypeError on non-strings; treat those as a
            # bad URL too.
            if not isinstance(url, str) or re.match(_URL_REGEX, url) is None:
                raise InvalidArgumentException('URL pattern is wrong')
            domain_name = participant.get('domain_name')
            # Grpc spec
            # NOTE(review): `[:-4]` drops the last four characters, i.e. it
            # assumes a 4-character suffix such as ".com" — confirm.
            participant['grpc_spec'] = {
                'authority': '{}-client-auth.com'.format(domain_name[:-4])
            }
            # Certificates are stored via `set_certificate` below, not in the
            # config proto, so they are stripped from the participant here.
            raw_certificates = participant.pop('certificates', None)
            if raw_certificates is not None:
                current_cert = parse_certificates(raw_certificates)
                success, err = verify_certificates(current_cert)
                if not success:
                    raise InvalidArgumentException(err)
                certificates[domain_name] = {'certs': current_cert}

        new_project = Project()
        # generate token
        # If users send a token, then use it instead.
        # If `token` is missing, None or empty, generate a new one by uuid.
        config['name'] = name
        token = config.get('token') or uuid4().hex
        config['token'] = token

        # check format of config
        try:
            new_project.set_config(ParseDict(config, ProjectProto()))
        except Exception as e:
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'config', e)) from e
        new_project.set_certificate(
            ParseDict({'domain_name_to_cert': certificates},
                      CertificateStorage()))
        new_project.name = name
        new_project.token = token
        new_project.comment = comment

        # create add on for every participant that supplied certificates
        domain_name_to_cert = new_project.get_certificate().domain_name_to_cert
        for participant in new_project.get_config().participants:
            if participant.domain_name in domain_name_to_cert.keys():
                _create_add_on(participant,
                               domain_name_to_cert[participant.domain_name],
                               custom_host)
        try:
            new_project = db.session.merge(new_project)
            db.session.commit()
        except Exception as e:
            # Roll back so the session stays usable after a failed commit.
            db.session.rollback()
            raise InvalidArgumentException(details=str(e)) from e

        return {'data': new_project.to_dict()}
Beispiel #5
0
    def post(self):
        """Create a new project from the request body.

        Request arguments (parsed by ``reqparse``):
            name (str, required): project name; must not already exist.
            config (dict, required): project configuration with a
                ``participants`` list of exactly one entry; every entry needs
                ``name``, ``url`` and ``domain_name`` and may carry
                ``certificates`` (the payload accepted by
                ``parse_certificates``). An optional ``token`` is used
                instead of a generated one.
            comment (str, optional): free-form comment.

        For each participant that supplies certificates, a k8s add-on is
        created with ``create_add_on`` once the whole config has been
        validated.

        Returns:
            dict: ``{'data': <created project as dict>}``.

        Raises:
            InvalidArgumentException: on a name conflict, malformed
                participants, missing certificate files, a config that does
                not parse into the project proto, a ``RuntimeError`` from the
                k8s client, or a database error.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name',
                            required=True,
                            type=str,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'name', 'Empty'))
        parser.add_argument('config',
                            required=True,
                            type=dict,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'config', 'Empty'))
        parser.add_argument('comment')
        data = parser.parse_args()
        name = data['name']
        config = data['config']
        comment = data['comment']

        if Project.query.filter_by(name=name).first() is not None:
            raise InvalidArgumentException(
                details=ErrorMessage.NAME_CONFLICT.value.format(name))

        participants = config.get('participants')
        if participants is None:
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'participants', 'Empty'))
        if len(participants) != 1:
            # TODO: remove limit after operator supports multiple participants
            raise InvalidArgumentException(
                details='Currently not support multiple participants.')

        certificates = {}
        # (domain_name, url, certs) for every participant needing a k8s
        # add-on. Creation is deferred until the config has been validated so
        # a rejected request does not leave add-ons behind, and each add-on
        # is created exactly once with its own participant's url and certs.
        add_on_targets = []
        for participant in participants:
            # Non-object entries (e.g. the string keys yielded by a
            # dict-shaped `participants`) are rejected with a 400 instead of
            # crashing on attribute access.
            if not isinstance(participant, dict) or \
                    'name' not in participant or \
                    'url' not in participant or \
                    'domain_name' not in participant:
                raise InvalidArgumentException(
                    details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                        'participants', 'Participant must have name, '
                        'domain_name and url.'))
            domain_name = participant.get('domain_name')
            # Certificates are stored via `set_certificate` below, not in the
            # config proto, so they are stripped from the participant here.
            raw_certificates = participant.pop('certificates', None)
            if raw_certificates is None:
                continue
            current_cert = parse_certificates(raw_certificates)
            # check validation
            for file_name in _CERTIFICATE_FILE_NAMES:
                if current_cert.get(file_name) is None:
                    raise InvalidArgumentException(
                        details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                            'certificates', '{} not existed'.format(file_name)))
            certificates[domain_name] = {'certs': current_cert}
            add_on_targets.append(
                (domain_name, participant.get('url'), current_cert))

        new_project = Project()
        # generate token
        # If users send a token, then use it instead.
        # If `token` is missing, None or empty, generate a new one by uuid.
        config['name'] = name
        token = config.get('token') or uuid4().hex
        config['token'] = token

        # check format of config
        try:
            new_project.set_config(ParseDict(config, ProjectProto()))
        except Exception as e:
            raise InvalidArgumentException(
                details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                    'config', e)) from e
        new_project.set_certificate(
            ParseDict({'domain_name_to_cert': certificates},
                      CertificateStorage()))
        new_project.name = name
        new_project.token = token
        new_project.comment = comment

        # create add on (the k8s client is only requested when needed)
        if add_on_targets:
            try:
                k8s_client = get_client()
                for target_domain, target_url, target_cert in add_on_targets:
                    create_add_on(k8s_client, target_domain, target_url,
                                  target_cert)
            except RuntimeError as e:
                raise InvalidArgumentException(details=str(e)) from e

        try:
            new_project = db.session.merge(new_project)
            db.session.commit()
        except Exception as e:
            # Roll back so the session stays usable after a failed commit.
            db.session.rollback()
            raise InvalidArgumentException(details=str(e)) from e

        return {'data': new_project.to_dict()}
Beispiel #6
0
class ProjectApiTest(BaseTestCase):
    """API tests for the ``/api/v2/projects`` endpoints.

    ``setUp`` stores one project (addressed as id 1 below) with a single
    participant, one variable and a certificate for ``fl-test.com``, plus one
    workflow attached to project id 1.
    """
    def setUp(self):
        super().setUp()
        # Certificate fixture: `test.tar.gz` next to this file, base64-encoded
        # into a UTF-8 string (the form test_post_project sends).
        with open(
                os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'test.tar.gz'), 'rb') as file:
            self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                         encoding='utf-8')
        self.default_project = Project()
        self.default_project.name = 'test-self.default_project'
        self.default_project.set_config(
            ParseDict(
                {
                    'participants': [{
                        'name': 'test-participant',
                        'domain_name': 'fl-test.com',
                        'url': '127.0.0.1:32443'
                    }],
                    'variables': [{
                        'name': 'test',
                        'value': 'test'
                    }]
                }, ProjectProto()))
        self.default_project.set_certificate(
            ParseDict(
                {
                    'domain_name_to_cert': {
                        'fl-test.com': {
                            'certs': parse_certificates(self.TEST_CERTIFICATES)
                        }
                    }
                }, CertificateStorage()))
        self.default_project.comment = 'test comment'
        db.session.add(self.default_project)
        # One workflow under project id 1; test_list_project expects
        # num_workflow == 1.
        workflow = Workflow(name='workflow_key_get1', project_id=1)
        db.session.add(workflow)
        db.session.commit()

    def test_get_project(self):
        """GET /projects/1 returns the stored project's dict."""
        get_response = self.get_helper('/api/v2/projects/{}'.format(1))
        self.assertEqual(get_response.status_code, HTTPStatus.OK)
        queried_project = json.loads(get_response.data).get('data')
        self.assertEqual(queried_project, self.default_project.to_dict())

    def test_get_not_found_project(self):
        """GET on an unknown project id returns 404."""
        get_response = self.get_helper('/api/v2/projects/{}'.format(1000))
        self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)

    @patch('fedlearner_webconsole.project.apis.verify_certificates')
    def test_post_project(self, mock_verify_certificates):
        """POST creates a project identical to what is then stored.

        ``verify_certificates`` is patched to succeed; the test also checks
        it was called once with the parsed fixture certificates.
        """
        mock_verify_certificates.return_value = (True, '')
        name = 'test-post-project'
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        comment = 'test post project'
        create_response = self.post_helper('/api/v2/projects',
                                           data={
                                               'name': name,
                                               'config': config,
                                               'comment': comment
                                           })
        self.assertEqual(create_response.status_code, HTTPStatus.OK)
        created_project = json.loads(create_response.data).get('data')

        queried_project = Project.query.filter_by(name=name).first()
        self.assertEqual(created_project, queried_project.to_dict())

        mock_verify_certificates.assert_called_once_with(
            parse_certificates(self.TEST_CERTIFICATES))

    def test_post_conflict_name_project(self):
        """POST with an already-used project name is rejected with 400.

        The config uses the same list-of-participants shape as
        test_post_project, so the rejection is attributable to the name
        conflict rather than to a malformed config.
        """
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        create_response = self.post_helper('/api/v2/projects',
                                           data={
                                               'name':
                                               self.default_project.name,
                                               'config': config,
                                               'comment': ''
                                           })
        self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)

    def test_list_project(self):
        """GET /projects lists the single project, with its workflow count."""
        list_response = self.get_helper('/api/v2/projects')
        project_list = json.loads(list_response.data).get('data')
        self.assertEqual(len(project_list), 1)
        for project in project_list:
            queried_project = Project.query.filter_by(
                name=project['name']).first()
            result = queried_project.to_dict()
            # The list endpoint additionally reports the workflow count.
            result['num_workflow'] = 1
            self.assertEqual(project, result)

    def test_update_project(self):
        """PATCH updates the first participant's name and the comment."""
        updated_name = 'updated name'
        updated_comment = 'updated comment'
        update_response = self.patch_helper('/api/v2/projects/{}'.format(1),
                                            data={
                                                'participant_name':
                                                updated_name,
                                                'comment': updated_comment
                                            })
        self.assertEqual(update_response.status_code, HTTPStatus.OK)
        queried_project = Project.query.filter_by(id=1).first()
        participant = queried_project.get_config().participants[0]
        self.assertEqual(participant.name, updated_name)
        self.assertEqual(queried_project.comment, updated_comment)

    def test_update_not_found_project(self):
        """PATCH on an unknown project id returns 404."""
        updated_comment = 'updated comment'
        update_response = self.patch_helper('/api/v2/projects/{}'.format(1000),
                                            data={'comment': updated_comment})
        self.assertEqual(update_response.status_code, HTTPStatus.NOT_FOUND)
Beispiel #7
0
class ProjectK8sAdapterTest(BaseTestCase):
    """Tests for ``ProjectK8sAdapter`` built from a stored project's config.

    ``setUp`` stores a project whose variables are ``self.variables`` and
    builds the adapter from that project's id.
    """
    def setUp(self):
        super().setUp()
        self.variables = [
            # no-default-value variable
            {
                'name': 'TEST_NO_DEFAULT_VARIABLE',
                'value': 'test'
            },
            # default-value variable
            {
                'name':
                'VOLUMES',
                'value':
                json.dumps([{
                    'hostPath': {
                        'path': '/test'
                    },
                    'name': 'test'
                }])
            },
            {
                'name': 'VOLUME_MOUNTS',
                'value': json.dumps([{
                    'mountPath': '/test',
                    'name': 'test'
                }])
            }
        ]
        self.default_project = Project()
        self.default_project.name = 'test-self.default_project'
        self.default_project.set_config(
            ParseDict(
                {
                    'participants': [{
                        'name': 'test-participant',
                        'domain_name': 'fl-test.com',
                        'url': '127.0.0.1:32443'
                    }],
                    'variables':
                    self.variables
                }, ProjectProto()))
        self.default_project.comment = 'test comment'
        db.session.add(self.default_project)
        db.session.commit()
        self.project_k8s_adapter = ProjectK8sAdapter(self.default_project.id)

    def test_exact_variable_from_config(self):
        """Variables come from the config, falling back to adapter defaults."""
        # no default value
        self.assertEqual(
            'test',
            self.project_k8s_adapter._exact_variable_from_config(
                'TEST_NO_DEFAULT_VARIABLE'))
        # using default value
        self.assertEqual(
            'default',
            self.project_k8s_adapter._exact_variable_from_config('NAMESPACE'))
        # using new value
        self.assertEqual(
            json.dumps([{
                'hostPath': {
                    'path': '/test'
                },
                'name': 'test'
            }]),
            self.project_k8s_adapter._exact_variable_from_config('VOLUMES'))

    def test_get_global_replica_spec(self):
        """The replica spec embeds the configured volumes, mounts and env."""
        self.assertEqual(
            {
                'global_replica_spec': {
                    'template': {
                        'spec': {
                            'imagePullSecrets': [{
                                'name': 'regcred'
                            }],
                            'volumes': [{
                                'hostPath': {
                                    'path': '/test'
                                },
                                'name': 'test'
                            }],
                            'containers': [{
                                'env':
                                self.variables,
                                'volumeMounts': [{
                                    'mountPath': '/test',
                                    'name': 'test'
                                }]
                            }]
                        }
                    }
                }
            }, self.project_k8s_adapter.get_global_replica_spec())


# The entry-point guard belongs at module level; it was previously nested
# inside test_get_global_replica_spec's body.
if __name__ == '__main__':
    unittest.main()
Beispiel #8
0
class ProjectApiTest(BaseTestCase):
    """API tests for ``/api/v2/projects`` using the Flask test client.

    ``setUp`` stores one project (addressed as id 1 below) with a single
    participant, one variable and a certificate keyed by ``*.fl-test.com``.
    """
    def setUp(self):
        super().setUp()
        # Certificate fixture: `test.tar.gz` next to this file, base64-encoded
        # into a UTF-8 string (the form test_post_project sends).
        with open(
                os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'test.tar.gz'), 'rb') as file:
            self.TEST_CERTIFICATES = str(b64encode(file.read()),
                                         encoding='utf-8')
        self.default_project = Project()
        self.default_project.name = 'test-self.default_project'
        self.default_project.set_config(
            ParseDict(
                {
                    'participants': [{
                        'name': 'test-participant',
                        'domain_name': 'fl-test.com',
                        'url': '127.0.0.1:32443'
                    }],
                    'variables': [{
                        'name': 'test',
                        'value': 'test'
                    }]
                }, ProjectProto()))
        self.default_project.set_certificate(
            ParseDict(
                {
                    'domain_name_to_cert': {
                        '*.fl-test.com': {
                            'certs': parse_certificates(self.TEST_CERTIFICATES)
                        }
                    },
                }, CertificateStorage()))
        self.default_project.comment = 'test comment'
        db.session.add(self.default_project)
        db.session.commit()

    def test_get_project(self):
        """GET /projects/1 returns the stored project's dict."""
        get_response = self.client.get('/api/v2/projects/{}'.format(1))
        self.assertEqual(get_response.status_code, HTTPStatus.OK)
        queried_project = json.loads(get_response.data).get('data')
        self.assertEqual(queried_project, self.default_project.to_dict())

    def test_get_not_found_project(self):
        """GET on an unknown project id returns 404."""
        get_response = self.client.get('/api/v2/projects/{}'.format(1000))
        self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)

    def test_post_project(self):
        """POST creates a project identical to what is then stored."""
        name = 'test-post-project'
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        comment = 'test post project'
        create_response = self.client.post('/api/v2/projects',
                                           data=json.dumps({
                                               'name': name,
                                               'config': config,
                                               'comment': comment
                                           }),
                                           content_type='application/json')
        self.assertEqual(create_response.status_code, HTTPStatus.OK)
        created_project = json.loads(create_response.data).get('data')

        queried_project = Project.query.filter_by(name=name).first()
        self.assertEqual(created_project, queried_project.to_dict())

    def test_post_conflict_name_project(self):
        """POST with an already-used project name is rejected with 400.

        The config uses the same list-of-participants shape as
        test_post_project, so the rejection is attributable to the name
        conflict rather than to a malformed config.
        """
        config = {
            'participants': [{
                'name': 'test-post-participant',
                'domain_name': 'fl-test-post.com',
                'url': '127.0.0.1:32443',
                'certificates': self.TEST_CERTIFICATES
            }],
            'variables': [{
                'name': 'test-post',
                'value': 'test'
            }]
        }
        create_response = self.client.post('/api/v2/projects',
                                           data=json.dumps({
                                               'name':
                                               self.default_project.name,
                                               'config':
                                               config,
                                               'comment':
                                               ''
                                           }),
                                           content_type='application/json')
        self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)

    def test_list_project(self):
        """GET /projects lists exactly the project created in setUp."""
        list_response = self.client.get('/api/v2/projects')
        self.assertEqual(list_response.status_code, HTTPStatus.OK)
        project_list = json.loads(list_response.data).get('data')
        # Without this, an empty list would make the loop below pass
        # vacuously.
        self.assertEqual(len(project_list), 1)
        for project in project_list:
            queried_project = Project.query.filter_by(
                name=project['name']).first()
            self.assertEqual(project, queried_project.to_dict())

    def test_update_project(self):
        """PATCH updates the project's comment."""
        updated_comment = 'updated comment'
        update_response = self.client.patch('/api/v2/projects/{}'.format(1),
                                            data=json.dumps(
                                                {'comment': updated_comment}),
                                            content_type='application/json')
        self.assertEqual(update_response.status_code, HTTPStatus.OK)
        queried_project = Project.query.filter_by(id=1).first()
        self.assertEqual(queried_project.comment, updated_comment)

    def test_update_not_found_project(self):
        """PATCH on an unknown project id returns 404."""
        updated_comment = 'updated comment'
        update_response = self.client.patch('/api/v2/projects/{}'.format(1000),
                                            data=json.dumps(
                                                {'comment': updated_comment}),
                                            content_type='application/json')
        self.assertEqual(update_response.status_code, HTTPStatus.NOT_FOUND)
Beispiel #9
0
    def post(self):
        """Create a new project from the request body.

        Request arguments (parsed by ``reqparse``):
            name (str, required): project name; must not already exist.
            config (dict, required): project configuration whose
                ``participants`` is a mapping with exactly one
                ``domain_name -> participant`` entry; each participant needs
                ``name`` and ``url`` and may carry ``certificates`` (the
                payload accepted by ``_parse_certificates``). An optional
                ``token`` is used instead of a generated one.
            comment (str, optional): free-form comment.

        Returns:
            dict: ``{'data': <created project as dict>}``.

        Aborts (via ``abort``):
            400 on a name conflict, malformed participants, missing
            certificate files or a config that does not parse into the
            project proto; 500 if add-on creation or the database write
            fails.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name',
                            required=True,
                            type=str,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'name', 'Empty'))
        parser.add_argument('config',
                            required=True,
                            type=dict,
                            help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                                'config', 'Empty'))
        parser.add_argument('comment')

        data = parser.parse_args()
        name = data['name']
        config = data['config']
        comment = data['comment']

        if Project.query.filter_by(name=name).first() is not None:
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.NAME_CONFLICT.value.format(name))

        participants = config.get('participants')
        if participants is None:
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                      'participants', 'Empty'))
        elif not isinstance(participants, dict):
            # `.items()` below needs a mapping; anything else would otherwise
            # surface as a 500.
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                      'participants',
                      'Participants must map domain name to participant.'))
        elif len(participants) != 1:
            # TODO: remove limit in schema after operator supports multiple participants
            abort(HTTPStatus.BAD_REQUEST,
                  message='Currently not support multiple participants.')

        certificates = {}
        for domain_name, participant in participants.items():
            if not isinstance(participant, dict) or \
                    'name' not in participant or 'url' not in participant:
                abort(HTTPStatus.BAD_REQUEST,
                      message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                          'participants',
                          'Participant must have name and url.'))
            # The mapping key is the participant's domain name; record it
            # whether or not certificates are supplied.
            participant['domain_name'] = domain_name
            # Certificates are stored via `set_certificate` below, not in the
            # config proto, so they are stripped from the participant here.
            raw_certificates = participant.pop('certificates', None)
            if raw_certificates is not None:
                current_cert = _parse_certificates(raw_certificates)
                # check validation
                for file_name in _CERTIFICATE_FILE_NAMES:
                    if current_cert.get(file_name) is None:
                        abort(HTTPStatus.BAD_REQUEST,
                              message=ErrorMessage.PARAM_FORMAT_ERROR.value.
                              format('certificates',
                                     '{} not existed'.format(file_name)))
                certificates[domain_name] = raw_certificates
            # format participant to proto structure
            # TODO: fill other fields
            participant['grpc_spec'] = {'url': participant.pop('url')}

        new_project = Project()
        # generate token
        # If users send a token, then use it instead.
        # If `token` is missing, None or empty, generate a new one by uuid.
        token = config.get('token') or uuid4().hex
        config['token'] = token

        # check format of config
        try:
            new_project.set_config(ParseDict(config, ProjectProto()))
        except Exception as e:
            abort(HTTPStatus.BAD_REQUEST,
                  message=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
                      'config', e))
        new_project.set_certificate(
            ParseDict({'certificate': certificates}, Certificate()))
        new_project.name = name
        new_project.token = token
        new_project.comment = comment

        # following operations will change the state of k8s and db
        try:
            # TODO: singleton k8s client
            k8s_client = K8sClient()
            for domain_name, certificate in certificates.items():
                _create_add_on(k8s_client, domain_name, certificate)

            db.session.add(new_project)
            db.session.commit()
        except Exception as e:
            # Roll back so the session stays usable after a failed commit.
            db.session.rollback()
            # `message` (as in every other abort here) with a string, since an
            # exception object is not JSON-serializable in the error body.
            abort(HTTPStatus.INTERNAL_SERVER_ERROR, message=str(e))

        return {'data': new_project.to_dict()}