Example #1
File: apis.py Project: guotie/fedlearner
    def get(self, participant_id, job_id):
        parser = reqparse.RequestParser()
        parser.add_argument('start_time',
                            type=int,
                            location='args',
                            required=False,
                            help='start_time must be a Unix timestamp')
        parser.add_argument('max_lines',
                            type=int,
                            location='args',
                            required=True,
                            help='max_lines is required')
        data = parser.parse_args()
        start_time = data['start_time']
        max_lines = data['max_lines']
        job = _get_job(job_id)
        if start_time is None:
            start_time = job.workflow.start_at

        workflow = job.workflow
        project_config = workflow.project.get_config()
        party = project_config.participants[participant_id]
        client = RpcClient(project_config, party)
        resp = client.get_job_events(job_name=job.name,
                                     start_time=start_time,
                                     max_lines=max_lines)
        if resp.status.code != common_pb2.STATUS_SUCCESS:
            raise InternalException(resp.status.msg)
        # MessageToDict expects a protobuf Message, so convert the whole
        # response and then pick out the repeated 'logs' field.
        peer_events = MessageToDict(
            resp,
            preserving_proto_field_name=True,
            including_default_value_fields=True)['logs']
        return {'data': peer_events}
Example #2
 def get(self, workflow_uuid, participant_id, job_name):
     parser = reqparse.RequestParser()
     parser.add_argument('start_time', type=int, location='args',
                         required=False,
                         help='start_time must be a Unix timestamp')
     parser.add_argument('max_lines', type=int, location='args',
                         required=True,
                         help='max_lines is required')
     data = parser.parse_args()
     start_time = data['start_time']
     max_lines = data['max_lines']
     workflow = Workflow.query.filter_by(uuid=workflow_uuid).first()
     if workflow is None:
         raise NotFoundException(
             f'Failed to find workflow: {workflow_uuid}')
     if start_time is None:
         start_time = workflow.start_at
     project_config = workflow.project.get_config()
     party = project_config.participants[participant_id]
     client = RpcClient(project_config, party)
     resp = client.get_job_events(job_name=job_name,
                                  start_time=start_time,
                                  max_lines=max_lines)
     if resp.status.code != common_pb2.STATUS_SUCCESS:
         raise InternalException(resp.status.msg)
     peer_events = MessageToDict(
         resp,
         preserving_proto_field_name=True,
         including_default_value_fields=True)['logs']
     return {'data': peer_events}
Example #3
def is_peer_job_inheritance_matched(workflow):
    # TODO: Move it to workflow service
    if workflow.forked_from is None:
        return True
    job_flags = workflow.get_create_job_flags()
    peer_job_flags = workflow.get_peer_create_job_flags()
    job_defs = workflow.get_config().job_definitions
    project = workflow.project
    if project is None:
        return True
    project_config = project.get_config()
    # TODO: Fix for multi-peer
    client = RpcClient(project_config, project_config.participants[0])
    parent_workflow = db.session.query(Workflow).get(workflow.forked_from)
    resp = client.get_workflow(parent_workflow.name)
    if resp.status.code != common_pb2.STATUS_SUCCESS:
        emit_counter('get_workflow_failed', 1)
        raise InternalException(resp.status.msg)
    peer_job_defs = resp.config.job_definitions
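    # A federated job must carry the same create-job flag on both sides;
    # any mismatch means the forked workflow's inheritance does not match.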
    for i, job_def in enumerate(job_defs):
        if job_def.is_federated:
            for j, peer_job_def in enumerate(peer_job_defs):
                if job_def.name == peer_job_def.name:
                    if job_flags[i] != peer_job_flags[j]:
                        return False
    return True
Example #4
 def get(self, workflow_uuid, participant_id, job_name):
     parser = reqparse.RequestParser()
     parser.add_argument('type', type=str, location='args',
                         required=True,
                         choices=('Ratio', 'Numeric'),
                         help='Visualization type is required. '
                              'Choices: Ratio, Numeric.')
     parser.add_argument('interval', type=str, location='args',
                         default='',
                         help='Time bucket interval length; '
                              'defaults to automatic bucketing by Kibana.')
     parser.add_argument('x_axis_field', type=str, location='args',
                         default='tags.event_time',
                         help='Time field (X axis) is required.')
     parser.add_argument('query', type=str, location='args',
                         help='Additional query string to the graph.')
     parser.add_argument('start_time', type=int, location='args',
                         default=-1,
                         help='Earliest <x_axis_field> time of data. '
                              'Unix timestamp in secs.')
     parser.add_argument('end_time', type=int, location='args',
                         default=-1,
                         help='Latest <x_axis_field> time of data. '
                              'Unix timestamp in secs.')
     # Ratio visualization
     parser.add_argument('numerator', type=str, location='args',
                         help='Numerator is required in Ratio '
                              'visualization. '
                              'A query string similar to args::query.')
     parser.add_argument('denominator', type=str, location='args',
                         help='Denominator is required in Ratio '
                              'visualization. '
                              'A query string similar to args::query.')
     # Numeric visualization
     parser.add_argument('aggregator', type=str, location='args',
                         default='Average',
                         choices=('Average', 'Sum', 'Max', 'Min', 'Variance',
                                  'Std. Deviation', 'Sum of Squares'),
                         help='Aggregator type is required in Numeric and '
                              'Timer visualization.')
     parser.add_argument('value_field', type=str, location='args',
                         help='The field to be aggregated on is required '
                              'in Numeric visualization.')
     args = parser.parse_args()
     workflow = Workflow.query.filter_by(uuid=workflow_uuid).first()
     if workflow is None:
         raise NotFoundException(
             f'Failed to find workflow: {workflow_uuid}')
     project_config = workflow.project.get_config()
     party = project_config.participants[participant_id]
     client = RpcClient(project_config, party)
     resp = client.get_job_kibana(job_name, json.dumps(args))
     if resp.status.code != common_pb2.STATUS_SUCCESS:
         raise InternalException(resp.status.msg)
     metrics = json.loads(resp.metrics)
     # metrics is a list of 2-element lists,
     #   each 2-element list is a [x, y] pair.
     return {'data': metrics}
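A minimal sketch of consuming those pairs on the caller side; the sample
values and variable names are hypothetical:

    # `metrics` as returned above: a list of [x, y] pairs.
    metrics = [[1609459200, 0.91], [1609459260, 0.93]]  # hypothetical sample
    xs = [point[0] for point in metrics]  # e.g. timestamps
    ys = [point[1] for point in metrics]  # e.g. metric values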
Example #5
 def is_peer_ready(job: Job) -> bool:
     project_config = job.project.get_config()
     for party in project_config.participants:
         client = RpcClient(project_config, party)
         resp = client.check_job_ready(job.name)
         if resp.status.code != common_pb2.STATUS_SUCCESS:
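             # Fail open: an unreachable peer is counted via the metric
             # below but still treated as ready.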
             emit_counter('check_peer_ready_failed', 1)
             return True
         if not resp.is_ready:
             return False
     return True
Example #6
 def _broadcast_state(self, state, target_state, transaction_state):
     project_config = self._project.get_config()
     states = []
     for party in project_config.participants:
         client = RpcClient(project_config, party)
         resp = client.update_workflow_state(self._workflow.name, state,
                                             target_state,
                                             transaction_state)
         if resp.status.code == common_pb2.STATUS_SUCCESS:
             states.append(TransactionState(resp.transaction_state))
         else:
             states.append(None)
     return states
Example #7
 def get(self, workflow_id):
     workflow = _get_workflow(workflow_id)
     project_config = workflow.project.get_config()
     peer_workflows = {}
     for party in project_config.participants:
         client = RpcClient(project_config, party)
         resp = client.get_workflow(workflow.name)
         if resp.status.code != common_pb2.STATUS_SUCCESS:
             raise InternalException()
         peer_workflows[party.name] = MessageToDict(
             resp,
             preserving_proto_field_name=True,
             including_default_value_fields=True)
     return {'data': peer_workflows}, HTTPStatus.OK
Example #8
File: apis.py Project: guotie/fedlearner
    def get(self, participant_id, job_id):
        job = _get_job(job_id)
        workflow = job.workflow
        project_config = workflow.project.get_config()
        party = project_config.participants[participant_id]
        client = RpcClient(project_config, party)
        resp = client.get_job_metrics(job.name)
        if resp.status.code != common_pb2.STATUS_SUCCESS:
            raise InternalException(resp.status.msg)

        metrics = json.loads(resp.metrics)

        # Metrics is a list of dicts. Each dict can be rendered by the
        #   frontend with mpld3.draw_figure('figure1', json)
        return {'data': metrics}
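For context, a minimal sketch of how a dict in that mpld3 JSON format can be
produced server-side; the figure content here is a hypothetical example:

    import json

    import matplotlib
    matplotlib.use('Agg')  # render headlessly, no display needed
    import matplotlib.pyplot as plt
    import mpld3

    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0.1, 0.5, 0.9])  # hypothetical metric curve
    # fig_to_dict serializes the figure into the JSON structure that
    # mpld3.draw_figure() consumes in the browser.
    metrics = [mpld3.fig_to_dict(fig)]
    payload = json.dumps(metrics)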
Example #9
    def get(self, workflow_id, participant_id, job_name):
        workflow = Workflow.query.filter_by(id=workflow_id).first()
        if workflow is None:
            raise NotFoundException()
        project_config = workflow.project.get_config()
        party = project_config.participants[participant_id]
        client = RpcClient(project_config, party)
        resp = client.get_job_metrics(workflow.name, job_name)
        if resp.status.code != common_pb2.STATUS_SUCCESS:
            raise InternalException(resp.status.msg)

        metrics = json.loads(resp.metrics)

        # Metrics is a list of dicts. Each dict can be rendered by the
        #   frontend with mpld3.draw_figure('figure1', json)
        return {'data': metrics}
Example #10
 def get(self, workflow_id):
     workflow = _get_workflow(workflow_id)
     project_config = workflow.project.get_config()
     peer_workflows = {}
     for party in project_config.participants:
         client = RpcClient(project_config, party)
         # TODO(xiangyxuan): use uuid to identify the workflow
         resp = client.get_workflow(workflow.name)
         if resp.status.code != common_pb2.STATUS_SUCCESS:
             raise InternalException(resp.status.msg)
         peer_workflow = MessageToDict(resp,
                                       preserving_proto_field_name=True,
                                       including_default_value_fields=True)
         for job in peer_workflow['jobs']:
             if 'pods' in job:
                 job['pods'] = json.loads(job['pods'])
         peer_workflows[party.name] = peer_workflow
     return {'data': peer_workflows}, HTTPStatus.OK
Example #11
 def _broadcast_state(self, state, target_state, transaction_state):
     project_config = self._project.get_config()
     states = []
     for party in project_config.participants:
         client = RpcClient(project_config, party)
         forked_from_uuid = Workflow.query.filter_by(
             id=self._workflow.forked_from).first(
             ).uuid if self._workflow.forked_from else None
         resp = client.update_workflow_state(
             self._workflow.name, state, target_state, transaction_state,
             self._workflow.uuid, forked_from_uuid, self._workflow.extra)
         if resp.status.code == common_pb2.STATUS_SUCCESS:
             if resp.state == WorkflowState.INVALID:
                 self._workflow.invalidate()
                 self._reload()
                 raise RuntimeError('Peer workflow invalidated. Abort.')
             states.append(TransactionState(resp.transaction_state))
         else:
             states.append(None)
     return states
Example #12
    def patch(self, workflow_id):
        parser = reqparse.RequestParser()
        parser.add_argument('config', type=dict, required=True,
                            help='new config for peer')
        data = parser.parse_args()
        config_proto = dict_to_workflow_definition(data['config'])

        workflow = _get_workflow(workflow_id)
        project_config = workflow.project.get_config()
        peer_workflows = {}
        for party in project_config.participants:
            client = RpcClient(project_config, party)
            resp = client.update_workflow(
                workflow.name, config_proto)
            if resp.status.code != common_pb2.STATUS_SUCCESS:
                raise InternalException(resp.status.msg)
            peer_workflows[party.name] = MessageToDict(
                resp,
                preserving_proto_field_name=True,
                including_default_value_fields=True)
        return {'data': peer_workflows}, HTTPStatus.OK
Example #13
    def setUp(self):
        self._client_execution_thread_pool = logging_pool.pool(1)

        # Builds a testing channel
        self._fake_channel = grpc_testing.channel(
            DESCRIPTOR.services_by_name.values(),
            grpc_testing.strict_real_time())
        self._build_channel_patcher = patch(
            'fedlearner_webconsole.rpc.client._build_channel')
        self._mock_build_channel = self._build_channel_patcher.start()
        self._mock_build_channel.return_value = self._fake_channel
        self._client = RpcClient(self._project_config, self._participant)

        self._mock_build_channel.assert_called_once_with(
            self._TEST_URL, self._TEST_AUTHORITY)
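A sketch of how a test method could drive this fake channel; the service and
method names, the response type, and the check_connection call are assumptions
for illustration, not taken from the snippet:

    def test_check_connection(self):
        # Run the blocking client call in the pool so this thread stays
        # free to serve the fake channel.
        call = self._client_execution_thread_pool.submit(
            self._client.check_connection)
        # Intercept the outgoing unary-unary RPC (method name assumed).
        _, request, rpc = self._fake_channel.take_unary_unary(
            DESCRIPTOR.services_by_name['WebConsoleV2Service']
            .methods_by_name['CheckConnection'])
        # Terminate the RPC with a fabricated success response (type assumed).
        rpc.terminate(
            service_pb2.CheckConnectionResponse(
                status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS)),
            (), grpc.StatusCode.OK, '')
        resp = call.result()
        self.assertEqual(resp.status.code, common_pb2.STATUS_SUCCESS)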
Example #14
 def check_connection(self, project_config: ProjectProto,
                      participant_proto: ParticipantProto):
     client = RpcClient(project_config, participant_proto)
     return client.check_connection()
Example #15
 def _get_peer_workflow(self):
     project_config = self.project.get_config()
     # TODO: find coordinator for multiparty
     client = RpcClient(project_config, project_config.participants[0])
     return client.get_workflow(self.name)
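Taken together, the examples share one calling pattern. A condensed sketch,
with names taken from the snippets above; treat it as a summary, not a
verbatim API:

    # 1. Resolve the project config and the peer participant.
    project_config = workflow.project.get_config()
    party = project_config.participants[participant_id]
    # 2. Build a client bound to that participant and issue the RPC.
    client = RpcClient(project_config, party)
    resp = client.get_workflow(workflow.name)
    # 3. Check the status code before trusting the payload.
    if resp.status.code != common_pb2.STATUS_SUCCESS:
        raise InternalException(resp.status.msg)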