Exemple #1
0
 def setUp(self):
     """Insert a Project named 'test' configured with one leader participant."""
     super().setUp()
     leader = {
         'name': 'party_leader',
         'url': '127.0.0.1:5000',
         'domain_name': 'fl-leader.com',
         'grpc_spec': {'peer_url': '127.0.0.1:1991'},
     }
     # Project-level variables, kept in the same order as before.
     variables = [
         {'name': var_name, 'value': var_value}
         for var_name, var_value in (
             ('namespace', 'leader'),
             ('basic_envs', '{}'),
             ('storage_root_dir', '/'),
         )
     ]
     project_proto = ParseDict(
         {
             'domain_name': 'fl-follower.com',
             'participants': [leader],
             'variables': variables,
         },
         project_pb2.Project(),
     )
     db.session.add(
         Project(name='test', config=project_proto.SerializeToString()))
     db.session.commit()
Exemple #2
0
    def test_is_peer_job_inheritance_matched(self, mock_get_workflow):
        """Check job-inheritance matching between a forked workflow and its peer.

        The peer workflow (returned by ``mock_get_workflow``) defines a
        ``raw-data-job`` and a federated ``train-job``; the local workflow
        only defines the federated ``train-job``.
        """
        peer_job_0 = JobDefinition(name='raw-data-job')
        peer_job_1 = JobDefinition(name='train-job', is_federated=True)
        peer_config = WorkflowDefinition()
        peer_config.job_definitions.extend([peer_job_0, peer_job_1])
        resp = GetWorkflowResponse(config=peer_config)
        mock_get_workflow.return_value = resp

        job_0 = JobDefinition(name='train-job', is_federated=True)
        config = WorkflowDefinition(job_definitions=[job_0])

        project = Project()
        participant = project_pb2.Participant()
        project.set_config(project_pb2.Project(participants=[participant]))
        workflow0 = Workflow(project=project)
        workflow0.set_config(config)
        db.session.add(workflow0)
        # commit() flushes pending objects itself, so workflow0.id is
        # populated below without a separate flush() call.
        db.session.commit()
        workflow1 = Workflow(project=project, forked_from=workflow0.id)
        workflow1.set_config(config)
        workflow1.set_create_job_flags([CreateJobFlag.REUSE])
        workflow1.set_peer_create_job_flags(
            [CreateJobFlag.NEW, CreateJobFlag.REUSE])

        # Local flags [REUSE] against peer flags [NEW, REUSE]: expected match.
        self.assertTrue(is_peer_job_inheritance_matched(workflow1))

        # Local flags [NEW] against the same peer flags: expected mismatch.
        workflow1.set_create_job_flags([CreateJobFlag.NEW])
        self.assertFalse(is_peer_job_inheritance_matched(workflow1))
Exemple #3
0
def init_db(port, domain_name):
    """Create all tables and seed one user plus a project named 'test'.

    Args:
        port: Port used (on 127.0.0.1) for both the participant URL and the
            ``EGRESS_URL`` variable.
        domain_name: Domain of the single participant. The gRPC authority is
            built by dropping its last four characters and appending
            ``-client-auth.com`` (presumably stripping a ``.com`` suffix —
            confirm callers always pass such a domain).
    """
    db.create_all()

    seed_user = User(username='******')
    seed_user.set_password('ada')
    db.session.add(seed_user)

    participant = {
        'name': f'{domain_name}',
        'url': f'127.0.0.1:{port}',
        'domain_name': f'{domain_name}',
        'grpc_spec': {'authority': f'{domain_name[:-4]}-client-auth.com'},
    }
    variables = [
        {'name': 'namespace', 'value': 'default'},
        {'name': 'storage_root_dir', 'value': '/data'},
        {'name': 'EGRESS_URL', 'value': f'127.0.0.1:{port}'},
    ]
    project_proto = ParseDict(
        {'name': 'test', 'participants': [participant],
         'variables': variables},
        project_pb2.Project())

    db.session.add(
        Project(name='test', config=project_proto.SerializeToString()))
    db.session.commit()
Exemple #4
0
    def test_get_namespace_from_variables(self):
        """get_namespace() returns the value of the 'namespace' variable."""
        project = Project()
        namespace_var = common_pb2.Variable(name='namespace', value='haha')
        project_config = project_pb2.Project(variables=[namespace_var])
        project.set_config(project_config)

        self.assertEqual(project.get_namespace(), 'haha')
Exemple #5
0
    def test_get_namespace_fallback(self):
        """Without a 'namespace' variable, get_namespace() gives 'default'."""
        project = Project()
        # Freshly constructed project, before any set_config() call.
        self.assertEqual(project.get_namespace(), 'default')

        # Config present, but its only variable is not 'namespace'.
        unrelated = common_pb2.Variable(name='test_name', value='test_value')
        project.set_config(project_pb2.Project(variables=[unrelated]))
        self.assertEqual(project.get_namespace(), 'default')
Exemple #6
0
    def test_put_successfully(self):
        """PUT /api/v2/workflows/<id> updates config, forkable and comment.

        The workflow starts in NEW with a PARTICIPANT_PREPARE transaction
        targeting READY; the test also checks that target_state is still
        READY after the update.
        """
        config = {
            'participants': [
                {
                    'name': 'party_leader',
                    'url': '127.0.0.1:5000',
                    'domain_name': 'fl-leader.com'
                }
            ],
            'variables': [
                {
                    'name': 'namespace',
                    'value': 'leader'
                },
                {
                    'name': 'basic_envs',
                    'value': '{}'
                },
                {
                    'name': 'storage_root_dir',
                    'value': '/'
                },
                {
                    'name': 'EGRESS_URL',
                    'value': '127.0.0.1:1991'
                }
            ]
        }
        project = Project(name='test',
                          config=ParseDict(
                              config,
                              project_pb2.Project()).SerializeToString())
        db.session.add(project)
        # Flush so the project receives its primary key, then reference it
        # rather than assuming the new row was assigned id 1.
        db.session.flush()
        workflow = Workflow(
            name='test-workflow',
            project_id=project.id,
            state=WorkflowState.NEW,
            transaction_state=TransactionState.PARTICIPANT_PREPARE,
            target_state=WorkflowState.READY
        )
        db.session.add(workflow)
        db.session.commit()
        db.session.refresh(workflow)

        response = self.put_helper(
            f'/api/v2/workflows/{workflow.id}',
            data={
                'forkable': True,
                'config': {'group_alias': 'test-template'},
                'comment': 'test comment'
            })
        self.assertEqual(response.status_code, HTTPStatus.OK)

        updated_workflow = Workflow.query.get(workflow.id)
        self.assertIsNotNone(updated_workflow.config)
        self.assertTrue(updated_workflow.forkable)
        self.assertEqual(updated_workflow.comment, 'test comment')
        self.assertEqual(updated_workflow.target_state, WorkflowState.READY)
Exemple #7
0
 def get_config(self):
     """Deserialize the stored project config.

     Returns:
         A ``project_pb2.Project`` parsed from the serialized bytes in
         ``self.config``, or ``None`` when no config has been stored.
     """
     if self.config is None:
         # Nothing stored yet; ParseFromString(None) would fail, so report
         # the absence explicitly instead.
         return None
     proto = project_pb2.Project()
     proto.ParseFromString(self.config)
     return proto
Exemple #8
0
 def get_config(self):
     """Return the stored config as a ``project_pb2.Project``, or ``None``.

     ``self.config`` holds the serialized message; when it is ``None`` no
     parsing is attempted and ``None`` is returned.
     """
     serialized = self.config
     if serialized is not None:
         message = project_pb2.Project()
         message.ParseFromString(serialized)
         return message
     return None