示例#1
0
 def save_pipelined_model(cls, job_id, role, party_id):
     """Build and persist the pipeline model of a training job on one party.

     The job's DSL and runtime configurations are loaded, the predict DSL for
     ``role`` is derived from them, and everything is packed into a
     ``pipeline_pb2.Pipeline`` message that is saved through a ``Tracker``.
     Jobs whose ``job_type`` is ``'predict'`` are skipped.

     :param job_id: id of the job whose pipeline is saved.
     :param role: role of this party in the job.
     :param party_id: id of this party.
     :raises KeyError: if ``model_id`` or ``model_version`` is missing from
         the job parameters.
     """
     schedule_logger(job_id).info(
         f'job {job_id} on {role} {party_id} start to save pipeline')
     dsl, runtime_conf, train_runtime_conf = job_utils.get_job_configuration(
         job_id=job_id, role=role, party_id=party_id)
     params = runtime_conf.get('job_parameters', {})
     model_id, model_version = params['model_id'], params['model_version']
     # A prediction job does not produce a new model: nothing to save.
     if params.get('job_type', '') == 'predict':
         return

     parser = schedule_utils.get_job_dsl_parser(
         dsl=dsl,
         runtime_conf=runtime_conf,
         train_runtime_conf=train_runtime_conf)
     inference_dsl = parser.get_predict_dsl(role=role)

     pipeline = pipeline_pb2.Pipeline()
     pipeline.inference_dsl = json_dumps(inference_dsl, byte=True)
     pipeline.train_dsl = json_dumps(dsl, byte=True)
     pipeline.train_runtime_conf = json_dumps(runtime_conf, byte=True)
     pipeline.fate_version = RuntimeConfig.get_env("FATE")
     pipeline.model_id = model_id
     pipeline.model_version = model_version

     tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                       model_id=model_id, model_version=model_version)
     tracker.save_pipelined_model(pipelined_buffer_object=pipeline)
     # Machine-learning model info is only recorded for non-'local' roles.
     if role != 'local':
         tracker.save_machine_learning_model_info()
     schedule_logger(job_id).info(
         f'job {job_id} on {role} {party_id} save pipeline successfully')
示例#2
0
 def test1(self):
     """A DataCache survives a JSON round trip, with and without type info."""
     table = DTable(namespace="test", name="test1")
     cache = DataCache(name="test_cache",
                       data={"t1": table},
                       meta={"t1": {"a": 1}})

     # Plain serialization is read back as nested mappings.
     plain = json_loads(json_dumps(cache))
     self.assertEqual(plain["data"]["t1"]["namespace"], "test")

     # Typed serialization + from_dict_hook gives back attribute-style objects.
     typed = json_loads(json_dumps(cache, with_type=True),
                        object_hook=from_dict_hook)
     self.assertEqual(typed.data["t1"].namespace, "test")
示例#3
0
    def unaryCall(self, _request, context):
        """Handle a proxied gRPC unary call by forwarding it as an HTTP request.

        The packet body carries the endpoint suffix (``key``) and a JSON
        payload (``value``); the header carries the job id, the source and
        destination parties and the HTTP method (``operator``). The payload is
        tagged with ``src_party_id``, forwarded with ``requests`` to
        ``get_url(suffix)``, and the JSON reply is wrapped into a packet
        addressed back to the source party. An HTTP method that ``requests``
        does not provide yields an error packet with ``retcode`` 100.
        """
        packet = _request
        header = packet.header
        _suffix = packet.body.key
        job_id = header.task.taskId
        src = header.src
        dst = header.dst
        method = header.operator

        param_dict = json_loads(bytes.decode(packet.body.value))
        # Tell the receiving endpoint which party sent the request.
        param_dict['src_party_id'] = str(src.partyId)
        source_routing_header = [
            (key, value) for key, value in context.invocation_metadata()
        ]
        stat_logger.info(
            f"grpc request routing header: {source_routing_header}")

        param = json_dumps(param_dict)

        # The previous body was a debugging stub (print + 60 s sleep + canned
        # {"status": "test"} reply) that blocked the handler thread and never
        # forwarded the request; forwarding is restored here.
        action = getattr(requests, method.lower(), None)
        if action is None:
            resp_json = {
                "retcode": 100,
                "retmsg": f"unsupported http method: {method}",
            }
        else:
            resp = action(url=get_url(_suffix), data=param, headers=HEADERS)
            resp_json = resp.json()
        return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                src.partyId, job_id)
示例#4
0
def get_predict_dsl():
    """Return the inference DSL of a stored model, inline or as a download.

    The request JSON is passed to ``model_utils.query_model_info_from_file``
    with ``query_filters`` forced to ``['inference_dsl']``. The first record
    whose ``f_role`` is ``guest`` or ``host`` is used. When the request has a
    ``filename``, the DSL is written under ``TEMP_DIRECTORY`` and sent as an
    attachment; otherwise it is returned in the JSON result.

    Error code 210 is returned when no model is found, when none of the found
    records belongs to a guest or host, or when ``filename`` has no usable
    file-name component.
    """
    request_data = request.json
    request_data['query_filters'] = ['inference_dsl']
    retcode, retmsg, data = model_utils.query_model_info_from_file(
        **request_data)
    if not data:
        return error_response(
            210,
            "No model found, please check if arguments are specified correctly.")

    _data = next((d for d in data if d.get("f_role") in {"guest", "host"}),
                 None)
    if _data is None:
        return error_response(
            210,
            "can not found guest or host model, please get predict dsl on guest or host."
        )

    filename = request_data.get("filename")
    if not filename:
        return get_json_result(data=_data['f_inference_dsl'])

    # `filename` is client-controlled: keep only its last path component so the
    # file cannot be written outside TEMP_DIRECTORY (e.g. "../../etc/x").
    safe_name = os.path.basename(filename)
    if safe_name in ("", ".", ".."):
        return error_response(210, f"invalid filename: {filename}")
    os.makedirs(TEMP_DIRECTORY, exist_ok=True)
    temp_filepath = os.path.join(TEMP_DIRECTORY, safe_name)
    with open(temp_filepath, "w") as fout:
        fout.write(json_dumps(_data['f_inference_dsl'], indent=4))
    return send_file(open(temp_filepath, "rb"),
                     as_attachment=True,
                     attachment_filename=safe_name)
示例#5
0
 def stop_job(cls, job_id, role, party_id, stop_status):
     """Ask every party of a job to stop it, from the initiator side.

     Only a job for which this party is the initiator is handled. When
     ``stop_status`` is ``JobStatus.CANCELED`` the cancel signal is set first.

     :returns: ``(RetCode, message)``. On a federated failure the message is
         the JSON-dumped federated response and the code is
         ``RetCode.FEDERATED_ERROR``.
     """
     schedule_logger(job_id=job_id).info(
         f"request stop job {job_id} with {stop_status}")
     jobs = JobSaver.query_job(job_id=job_id, role=role, party_id=party_id,
                               is_initiator=True)
     if len(jobs) == 0:
         return RetCode.SUCCESS, "can not found job"

     if stop_status == JobStatus.CANCELED:
         schedule_logger(job_id=job_id).info(f"cancel job {job_id}")
         cancel_set = cls.cancel_signal(job_id=job_id, set_or_reset=True)
         schedule_logger(job_id=job_id).info(
             f"set job {job_id} cancel signal {cancel_set}")

     target = jobs[0]
     target.f_status = stop_status
     schedule_logger(job_id=job_id).info(
         f"request stop job {job_id} with {stop_status} to all party")
     status_code, response = FederatedScheduler.stop_job(
         job=jobs[0], stop_status=stop_status)
     if status_code != FederatedSchedulingStatusCode.SUCCESS:
         schedule_logger(job_id=job_id).info(
             f"stop job {job_id} with {stop_status} failed, {response}")
         return RetCode.FEDERATED_ERROR, json_dumps(response)
     schedule_logger(job_id=job_id).info(
         f"stop job {job_id} with {stop_status} successfully")
     return RetCode.SUCCESS, "success"
示例#6
0
def save_job_conf(job_id,
                  role,
                  job_dsl,
                  job_runtime_conf,
                  job_runtime_conf_on_party,
                  train_runtime_conf,
                  pipeline_dsl=None):
    """Write a job's DSL and configuration files as indented JSON.

    Target paths come from ``get_job_conf_path(job_id=job_id, role=role)``.
    Falsy inputs (such as the default ``pipeline_dsl=None``) are written as
    ``{}``. Existing files are overwritten.

    :returns: the path dictionary from ``get_job_conf_path``.
    """
    path_dict = get_job_conf_path(job_id=job_id, role=role)
    for dir_key in ('job_dsl_path', 'job_runtime_conf_on_party_path'):
        os.makedirs(os.path.dirname(path_dict.get(dir_key)), exist_ok=True)

    # Resolve every target path up front, before any file is touched.
    targets = (
        (path_dict['job_dsl_path'], job_dsl),
        (path_dict['job_runtime_conf_path'], job_runtime_conf),
        (path_dict['job_runtime_conf_on_party_path'],
         job_runtime_conf_on_party),
        (path_dict['train_runtime_conf_path'], train_runtime_conf),
        (path_dict['pipeline_dsl_path'], pipeline_dsl),
    )
    for conf_path, content in targets:
        with open(conf_path, 'w') as handle:
            handle.write(json_dumps(content if content else {}, indent=4))
    return path_dict
示例#7
0
    def save(self, model_buffers: Dict[str, Tuple[str, bytes, dict]]):
        """Write a checkpoint of ``model_buffers`` into ``self.directory``.

        For each model, ``<name>.pb`` receives the serialized bytes and
        ``<name>.json`` the JSON-dumped dict. ``self.database`` is then
        rewritten as YAML holding the step index/name, the UTC creation time
        and, per model, the SHA-1 of its serialized bytes and its buffer name.
        All file writes happen while holding ``self.lock``.

        :param model_buffers: maps model name to
            ``(buffer_name, serialized_bytes, json_format_dict)``.
        :returns: ``self.directory``.
        :raises ValueError: if ``model_buffers`` is empty.
        """
        if not model_buffers:
            raise ValueError('model_buffers is empty.')

        self.create_time = datetime.utcnow()
        models_meta = {}
        data = {
            'step_index': self.step_index,
            'step_name': self.step_name,
            'create_time': self.create_time.isoformat(),
            'models': models_meta,
        }

        payloads = {}
        for name, (buffer_name, raw, json_dict) in model_buffers.items():
            payloads[name] = (raw, json_dict)
            models_meta[name] = {
                'sha1': hashlib.sha1(raw).hexdigest(),
                'buffer_name': buffer_name,
            }

        with self.lock:
            for name, (raw, json_dict) in payloads.items():
                (self.directory / f'{name}.pb').write_bytes(raw)
                (self.directory / f'{name}.json').write_text(
                    json_dumps(json_dict), 'utf8')
            self.database.write_text(
                yaml.dump(data, Dumper=yaml.RoundTripDumper), 'utf8')

        stat_logger.info(f'Checkpoint saved. path: {self.directory}')
        return self.directory
示例#8
0
文件: job_app.py 项目: zeta1999/FATE
def stop_job():
    """Stop a job on this party, then ask all parties to stop it.

    Request JSON: ``job_id`` and an optional ``stop_status`` (default
    ``"canceled"``). The job is first stopped locally through
    ``JobController.stop_jobs``, then a federated stop request is sent.
    Responds with ``RetCode.DATA_ERROR`` for an unknown job and
    ``RetCode.OPERATING_ERROR`` when the federated request fails.
    """
    payload = request.json
    job_id = payload.get('job_id')
    stop_status = payload.get("stop_status", "canceled")
    jobs = JobSaver.query_job(job_id=job_id)
    if not jobs:
        schedule_logger(job_id).info(f"can not found job {job_id} to stop")
        return get_json_result(retcode=RetCode.DATA_ERROR,
                               retmsg="can not found job")

    schedule_logger(job_id).info("stop job on this party")
    kill_status, kill_details = JobController.stop_jobs(
        job_id=job_id, stop_status=stop_status)
    schedule_logger(job_id).info(
        f"stop job on this party status {kill_status}")

    job = jobs[0]
    schedule_logger(job_id).info(f"request stop job {job} to {stop_status}")
    status_code, response = FederatedScheduler.request_stop_job(
        job=job, stop_status=stop_status, command_body=job.to_json())
    if status_code == FederatedSchedulingStatusCode.SUCCESS:
        return get_json_result(
            retcode=RetCode.SUCCESS,
            retmsg=f"stop job on this party {kill_status};\n"
            f"stop job on all party success")
    details = json_dumps(response, indent=4)
    return get_json_result(
        retcode=RetCode.OPERATING_ERROR,
        retmsg=f"stop job on this party {kill_status};\n"
        f"stop job failed:\n{details}")
示例#9
0
def wrap_grpc_packet(json_body,
                     http_method,
                     url,
                     src_party_id,
                     dst_party_id,
                     job_id=None,
                     overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT):
    """Wrap an HTTP-style request into a ``proxy_pb2.Packet``.

    :param json_body: payload; JSON-dumped and UTF-8 encoded as the body value.
    :param http_method: stored as the header ``operator``.
    :param url: stored as the body ``key``.
    :param src_party_id: source party; its topic carries an ``IP``/``GRPC_PORT``
        endpoint as callback.
    :param dst_party_id: destination party; its topic has no callback.
    :param job_id: used as both topic names and as the task id.
    :param overall_timeout: value for ``proxy_pb2.Conf.overallTimeout``.
    :returns: the assembled ``proxy_pb2.Packet``.
    """
    def make_topic(party_id, callback):
        return proxy_pb2.Topic(name=job_id,
                               partyId=f"{party_id}",
                               role=FATEFLOW_SERVICE_NAME,
                               callback=callback)

    callback_endpoint = basic_meta_pb2.Endpoint(ip=IP, port=GRPC_PORT)
    header = proxy_pb2.Metadata(
        src=make_topic(src_party_id, callback_endpoint),
        dst=make_topic(dst_party_id, None),
        task=proxy_pb2.Task(taskId=job_id),
        command=proxy_pb2.Command(name=FATEFLOW_SERVICE_NAME),
        operator=http_method,
        conf=proxy_pb2.Conf(overallTimeout=overall_timeout))
    body = proxy_pb2.Data(key=url,
                          value=bytes(json_dumps(json_body), 'utf-8'))
    return proxy_pb2.Packet(header=header, body=body)
示例#10
0
def wrap_grpc_packet(json_body,
                     http_method,
                     url,
                     src_party_id,
                     dst_party_id,
                     job_id=None,
                     overall_timeout=None):
    """Wrap an HTTP-style request into a ``proxy_pb2.Packet``.

    :param json_body: payload; JSON-dumped and UTF-8 encoded as the body value.
    :param http_method: stored as the header ``operator``.
    :param url: stored as the body ``key``.
    :param src_party_id: source party; its topic carries a ``HOST``/``GRPC_PORT``
        endpoint as callback.
    :param dst_party_id: destination party; its topic has no callback.
    :param job_id: used as both topic names and as the task id.
    :param overall_timeout: value for ``proxy_pb2.Conf.overallTimeout``;
        ``None`` means ``JobDefaultConfig.remote_request_timeout``.
    :returns: the assembled ``proxy_pb2.Packet``.
    """
    if overall_timeout is None:
        overall_timeout = JobDefaultConfig.remote_request_timeout

    def make_topic(party_id, callback):
        return proxy_pb2.Topic(name=job_id,
                               partyId=f"{party_id}",
                               role=FATE_FLOW_SERVICE_NAME,
                               callback=callback)

    callback_endpoint = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
    header = proxy_pb2.Metadata(
        src=make_topic(src_party_id, callback_endpoint),
        dst=make_topic(dst_party_id, None),
        task=proxy_pb2.Task(taskId=job_id),
        command=proxy_pb2.Command(name=FATE_FLOW_SERVICE_NAME),
        operator=http_method,
        conf=proxy_pb2.Conf(overallTimeout=overall_timeout))
    body = proxy_pb2.Data(key=url,
                          value=bytes(json_dumps(json_body), 'utf-8'))
    return proxy_pb2.Packet(header=header, body=body)
示例#11
0
def local_api(job_id,
              method,
              endpoint,
              json_body,
              api_version=API_VERSION,
              try_times=3):
    """Send a request to this party's own job server and return its JSON reply.

    The URL is ``http://{JOB_SERVER_HOST}:{HTTP_PORT}/{api_version}{endpoint}``.
    A failed attempt is logged and retried, up to ``try_times`` attempts.

    :param method: HTTP method name available on ``requests`` (e.g. "POST").
    :raises Exception: if ``method`` is not provided by ``requests`` (raised
        immediately, since retrying cannot help), or when every attempt fails;
        in the latter case the last error is chained as ``__cause__``.
    """
    endpoint = f"/{api_version}{endpoint}"
    url = "http://{}:{}{}".format(RuntimeConfig.JOB_SERVER_HOST,
                                  RuntimeConfig.HTTP_PORT, endpoint)
    action = getattr(requests, method.lower(), None)
    if action is None:
        raise Exception(
            'local request error: unsupported http method {}'.format(method))

    last_error = None
    for _ in range(try_times):
        try:
            audit_logger(job_id).info('local api request: {}'.format(url))
            http_response = action(url=url,
                                   data=json_dumps(json_body),
                                   headers=HEADERS)
            audit_logger(job_id).info(http_response.text)
            response = http_response.json()
            audit_logger(job_id).info('local api response: {} {}'.format(
                endpoint, response))
            return response
        except Exception as e:
            schedule_logger(job_id).exception(e)
            last_error = e
    raise Exception(
        'local request error: {}'.format(last_error)) from last_error
示例#12
0
 def db_value(self, value):
     """Serialize ``value`` for storage according to ``self._serialized_type``.

     ``SerializedType.PICKLE`` uses ``serialize_b64(value, to_str=True)``.
     ``SerializedType.JSON`` uses ``json_dumps(value, with_type=True)``, except
     that ``None`` is stored as ``None``.

     :raises ValueError: for any other serialized type.
     """
     kind = self._serialized_type
     if kind == SerializedType.PICKLE:
         return serialize_b64(value, to_str=True)
     if kind == SerializedType.JSON:
         return None if value is None else json_dumps(value, with_type=True)
     raise ValueError(f"the serialized type {kind} is not supported")
示例#13
0
 def write_component_model(self, component_model):
     """Write a component's model buffers, run parameters and metadata.

     ``component_model["buffer"]`` maps a storage path (appended to
     ``get_fate_flow_directory()``) to ``(base64_text, json_object)``. The
     decoded bytes go to that path and the JSON dump to ``<path>.json``.
     Run parameters (``{}`` when missing or falsy) are written to
     ``self.component_run_parameters_path(component_name)``, then the
     component meta is updated. Each file write holds ``self.lock``.
     """
     buffers = component_model.get("buffer")
     for rel_path, (encoded, json_obj) in buffers.items():
         target = get_fate_flow_directory() + rel_path
         os.makedirs(os.path.dirname(target), exist_ok=True)
         with self.lock, open(target, "wb") as out:
             out.write(base64.b64decode(encoded.encode()))
         with self.lock, open(f"{target}.json", "w",
                              encoding="utf8") as out:
             out.write(base_utils.json_dumps(json_obj))

     run_parameters = component_model.get("run_parameters", {}) or {}
     component_name = component_model["component_name"]
     params_path = self.component_run_parameters_path(component_name)
     os.makedirs(os.path.dirname(params_path), exist_ok=True)
     with self.lock, open(params_path, "w", encoding="utf8") as out:
         out.write(base_utils.json_dumps(run_parameters))

     alias = component_model["model_alias"]
     self.update_component_meta(
         component_name=component_name,
         component_module_name=component_model["component_module_name"],
         model_alias=alias,
         model_proto_index=component_model["model_proto_index"])
     stat_logger.info(f"save {component_name} {alias} successfully")
示例#14
0
文件: job_app.py 项目: zeta1999/FATE
def rerun_job():
    """Ask all parties to rerun a job, forwarding the request JSON.

    Responds with ``RetCode.DATA_ERROR`` when the job is unknown and with
    ``RetCode.OPERATING_ERROR`` (including the JSON-dumped federated response)
    when the federated request fails.
    """
    job_id = request.json.get("job_id")
    jobs = JobSaver.query_job(job_id=job_id)
    if not jobs:
        return get_json_result(retcode=RetCode.DATA_ERROR,
                               retmsg="can not found job")

    status_code, response = FederatedScheduler.request_rerun_job(
        job=jobs[0], command_body=request.json)
    if status_code != FederatedSchedulingStatusCode.SUCCESS:
        return get_json_result(
            retcode=RetCode.OPERATING_ERROR,
            retmsg=f"rerun job failed:\n{json_dumps(response)}")
    return get_json_result(retcode=RetCode.SUCCESS,
                           retmsg="rerun job success")
示例#15
0
    def save_pipelined_model(cls, job_id, role, party_id):
        """Build and persist the pipeline model of a training job on one party.

        Roles listed in ``job_parameters["assistant_role"]`` and jobs whose
        ``job_type`` is ``'predict'`` save nothing. Otherwise a
        ``pipeline_pb2.Pipeline`` is filled with the DSLs, the runtime confs,
        the role/initiator layout and ``work_mode``, then saved through a
        ``Tracker``. Model info is also saved for every role except ``'local'``.

        :raises KeyError: if ``model_id``, ``model_version`` or ``work_mode``
            is missing from the job parameters, or ``role``/``initiator`` from
            the runtime conf on party.
        """
        schedule_logger(job_id).info(
            f'job {job_id} on {role} {party_id} start to save pipeline')
        (dsl, runtime_conf, runtime_conf_on_party,
         train_runtime_conf) = job_utils.get_job_configuration(
             job_id=job_id, role=role, party_id=party_id)
        params = runtime_conf_on_party.get('job_parameters', {})
        # Assistant roles do not save a pipeline.
        if role in params.get("assistant_role", []):
            return
        model_id = params['model_id']
        model_version = params['model_version']
        job_type = params.get('job_type', '')
        work_mode = params['work_mode']
        roles = runtime_conf_on_party['role']
        initiator = runtime_conf_on_party['initiator']
        initiator_role = initiator['role']
        initiator_party_id = initiator['party_id']
        if job_type == 'predict':
            return

        parser = schedule_utils.get_job_dsl_parser(
            dsl=dsl,
            runtime_conf=runtime_conf,
            train_runtime_conf=train_runtime_conf)
        inference_dsl = parser.get_predict_dsl(role=role)

        pipeline = pipeline_pb2.Pipeline()
        pipeline.inference_dsl = json_dumps(inference_dsl, byte=True)
        pipeline.train_dsl = json_dumps(dsl, byte=True)
        pipeline.train_runtime_conf = json_dumps(runtime_conf, byte=True)
        pipeline.fate_version = RuntimeConfig.get_env("FATE")
        pipeline.model_id = model_id
        pipeline.model_version = model_version

        # Bookkeeping: flagged as parent, never loaded yet, empty parent_info.
        pipeline.parent = True
        pipeline.loaded_times = 0
        pipeline.roles = json_dumps(roles, byte=True)
        pipeline.work_mode = work_mode
        pipeline.initiator_role = initiator_role
        pipeline.initiator_party_id = initiator_party_id
        pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party,
                                                    byte=True)
        pipeline.parent_info = json_dumps({}, byte=True)

        tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                          model_id=model_id, model_version=model_version)
        tracker.save_pipelined_model(pipelined_buffer_object=pipeline)
        if role != 'local':
            tracker.save_machine_learning_model_info()
        schedule_logger(job_id).info(
            f'job {job_id} on {role} {party_id} save pipeline successfully')
示例#16
0
def get_predict_conf():
    """Generate a predict-conf template for a trained model.

    The model is looked up under ``model_local_cache`` by the pattern
    ``guest#*#<model_id>/<model_version>``. The pipeline's inference DSL and
    train runtime conf of the first match are fed to the DSL parser's
    ``generate_predict_conf_template``. With ``filename`` in the request, the
    conf is written under ``TEMP_DIRECTORY`` and sent as an attachment;
    otherwise it is returned in the JSON result.

    Error code 210 is returned when no model is found or when ``filename`` has
    no usable file-name component.
    """
    request_data = request.json
    required_parameters = ['model_id', 'model_version']
    check_config(request_data, required_parameters)
    model_dir = os.path.join(get_project_base_directory(), 'model_local_cache')
    # model_id / model_version are client input: escape glob metacharacters
    # ("*", "?", "[") so they only ever match literally.
    model_id_pattern = glob.escape(str(request_data['model_id']))
    model_version_pattern = glob.escape(str(request_data['model_version']))
    model_fp_list = glob.glob(
        model_dir + f"/guest#*#{model_id_pattern}/{model_version_pattern}")
    if model_fp_list:
        fp = model_fp_list[0]
        pipeline_model = PipelinedModel(model_id=fp.split('/')[-2],
                                        model_version=fp.split('/')[-1])
        pipeline = pipeline_model.read_component_model('pipeline',
                                                       'pipeline')['Pipeline']
        predict_dsl = json_loads(pipeline.inference_dsl)

        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        parser = schedule_utils.get_dsl_parser_by_version(
            train_runtime_conf.get('dsl_version', '1'))
        predict_conf = parser.generate_predict_conf_template(
            predict_dsl=predict_dsl,
            train_conf=train_runtime_conf,
            model_id=request_data['model_id'],
            model_version=request_data['model_version'])
    else:
        predict_conf = ''

    if not predict_conf:
        return error_response(
            210,
            "No model found, please check if arguments are specified correctly.")

    filename = request_data.get("filename")
    if not filename:
        return get_json_result(data=predict_conf)

    # `filename` is client-controlled: keep only its last path component so the
    # file cannot be written outside TEMP_DIRECTORY (e.g. "../../etc/x").
    safe_name = os.path.basename(filename)
    if safe_name in ("", ".", ".."):
        return error_response(210, f"invalid filename: {filename}")
    os.makedirs(TEMP_DIRECTORY, exist_ok=True)
    temp_filepath = os.path.join(TEMP_DIRECTORY, safe_name)
    with open(temp_filepath, "w") as fout:
        fout.write(json_dumps(predict_conf, indent=4))
    return send_file(open(temp_filepath, "rb"),
                     as_attachment=True,
                     attachment_filename=safe_name)
示例#17
0
def federated_coordination_on_http(
        job_id,
        method,
        host,
        port,
        endpoint,
        src_party_id,
        src_role,
        dest_party_id,
        json_body,
        api_version=API_VERSION,
        overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT,
        try_times=3):
    """Send a federated scheduling request to another party over HTTP.

    ``src_role`` and ``src_party_id`` are added to ``json_body`` (the caller's
    dict is modified in place). They are also sent as the ``src-role`` and
    ``src-party-id`` headers, together with ``dest-party-id``. A failed
    attempt is retried up to ``try_times`` attempts in total, sleeping 2 s,
    4 s, ... between attempts (never after the last one).

    NOTE(review): ``overall_timeout`` is accepted but not used in this body.

    :returns: the decoded JSON response.
    :raises Exception: if ``method`` is not provided by ``requests`` (raised
        immediately), or when every attempt fails; in the latter case the last
        error is chained as ``__cause__``.
    """
    endpoint = f"/{api_version}{endpoint}"
    json_body['src_role'] = src_role
    json_body['src_party_id'] = src_party_id
    url = "http://{}:{}{}".format(host, port, endpoint)
    action = getattr(requests, method.lower(), None)
    if action is None:
        raise Exception(
            'remote http request error: unsupported http method {}'.format(
                method))
    headers = HEADERS.copy()
    headers["dest-party-id"] = str(dest_party_id)
    headers["src-party-id"] = str(src_party_id)
    headers["src-role"] = str(src_role)

    last_error = None
    for attempt in range(try_times):
        try:
            audit_logger(job_id).info(
                'remote http api request: {}'.format(url))
            http_response = action(url=url,
                                   data=json_dumps(json_body),
                                   headers=headers)
            audit_logger(job_id).info(http_response.text)
            response = http_response.json()
            audit_logger(job_id).info('remote http api response: {} {}'.format(
                endpoint, response))
            return response
        except Exception as e:
            last_error = e
            # Sleeping after the final attempt only delays the error.
            if attempt + 1 < try_times:
                schedule_logger(job_id).warning(
                    f"remote http request {endpoint} error, sleep and try again")
                time.sleep(2 * (attempt + 1))
    raise Exception(
        'remote http request error: {}'.format(last_error)) from last_error
示例#18
0
 def stop_job(cls, job_id, role, party_id, stop_status):
     """Cancel a job as initiator and ask every party to stop it.

     When no initiator-side job matches, the job's waiting event is deleted
     from ``JobQueue`` instead and success is reported.

     :returns: ``(RetCode, message)``. On a federated failure the message is
         the JSON-dumped federated response and the code is
         ``RetCode.FEDERATED_ERROR``.
     """
     schedule_logger(job_id=job_id).info(f"request stop job {job_id}")
     jobs = JobSaver.query_job(job_id=job_id, role=role, party_id=party_id,
                               is_initiator=True)
     if len(jobs) == 0:
         schedule_logger(job_id=job_id).info(
             f"can not found job {job_id} to stop, delete event on {role} {party_id}")
         JobQueue.delete_event(job_id=job_id)
         return RetCode.SUCCESS, "can not found job, delete job waiting event"

     schedule_logger(job_id=job_id).info(f"initiator cancel job {job_id}")
     JobController.cancel_job(job_id=job_id, role=role, party_id=party_id)
     target = jobs[0]
     target.f_status = stop_status
     schedule_logger(job_id=job_id).info(
         f"request cancel job {job_id} to all party")
     status_code, response = FederatedScheduler.stop_job(
         job=jobs[0], stop_status=stop_status)
     if status_code != FederatedSchedulingStatusCode.SUCCESS:
         schedule_logger(job_id=job_id).info(
             f"cancel job {job_id} failed, {response}")
         return RetCode.FEDERATED_ERROR, json_dumps(response)
     schedule_logger(job_id=job_id).info(f"cancel job {job_id} successfully")
     return RetCode.SUCCESS, "success"
示例#19
0
文件: job_app.py 项目: zpskt/FATE
def stop_job():
    """Ask all parties of a job to stop it.

    Request JSON: ``job_id`` and an optional ``stop_status`` (default
    ``"canceled"``). Responds with ``RetCode.DATA_ERROR`` when the job is
    unknown and with ``RetCode.OPERATING_ERROR`` (including the JSON-dumped
    federated response) when the federated request fails.
    """
    job_id = request.json.get('job_id')
    stop_status = request.json.get("stop_status", "canceled")
    jobs = JobSaver.query_job(job_id=job_id)
    if not jobs:
        # `jobs` is empty here, so log the requested id; indexing jobs[0]
        # in this branch raised IndexError instead of answering DATA_ERROR.
        stat_logger.info(f"can not found job {job_id} to stop")
        return get_json_result(retcode=RetCode.DATA_ERROR,
                               retmsg="can not found job")

    stat_logger.info(f"request stop job {jobs[0]} to {stop_status}")
    status_code, response = FederatedScheduler.request_stop_job(
        job=jobs[0],
        stop_status=stop_status,
        command_body=jobs[0].to_json())
    if status_code == FederatedSchedulingStatusCode.SUCCESS:
        return get_json_result(retcode=RetCode.SUCCESS,
                               retmsg="stop job success")
    return get_json_result(retcode=RetCode.OPERATING_ERROR,
                           retmsg="stop job failed:\n{}".format(
                               json_dumps(response, indent=4)))
示例#20
0
    def unaryCall(self, _request, context):
        """Handle a proxied gRPC unary call by forwarding it to the HTTP API.

        Routing metadata for the (src, dst) pair is set as trailing metadata.
        The caller is validated with ``nodes_check``. If that fails, or if the
        requested HTTP method is not provided by ``requests``, an error packet
        with ``retcode`` 100 is returned instead of forwarding. Otherwise the
        JSON payload (tagged with ``src_party_id``) is sent to
        ``get_url(suffix)`` and the JSON reply is wrapped into a packet
        addressed back to the source party.
        """
        packet = _request
        header = packet.header
        _suffix = packet.body.key
        job_id = header.task.taskId
        src = header.src
        dst = header.dst
        method = header.operator
        param_dict = json_loads(bytes.decode(packet.body.value))
        param_dict['src_party_id'] = str(src.partyId)

        _routing_metadata = gen_routing_metadata(src_party_id=src.partyId,
                                                 dest_party_id=dst.partyId)
        context.set_trailing_metadata(trailing_metadata=_routing_metadata)
        try:
            nodes_check(param_dict.get('src_party_id'),
                        param_dict.get('_src_role'), param_dict.get('appKey'),
                        param_dict.get('appSecret'), str(dst.partyId))
        except Exception as e:
            resp_json = {"retcode": 100, "retmsg": str(e)}
            return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                    src.partyId, job_id)
        param = json_dumps(param_dict)

        action = getattr(requests, method.lower(), None)
        audit_logger(job_id).info('rpc receive: {}'.format(packet))
        if action is None:
            # Previously `resp` stayed unbound here and `resp.json()` raised
            # UnboundLocalError; answer with an error packet instead.
            resp_json = {
                "retcode": 100,
                "retmsg": f"unsupported http method: {method}",
            }
        else:
            audit_logger(job_id).info("rpc receive: {} {}".format(
                get_url(_suffix), param))
            resp = action(url=get_url(_suffix), data=param, headers=HEADERS)
            resp_json = resp.json()
        return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                src.partyId, job_id)
示例#21
0
def dump_job_conf(path_dict,
                  dsl,
                  runtime_conf,
                  runtime_conf_on_party,
                  train_runtime_conf,
                  pipeline_dsl=None):
    """Write a job's DSL and configuration files to the paths in ``path_dict``.

    Each value is written as indented JSON; falsy inputs (such as the default
    ``pipeline_dsl=None``) are written as ``{}``. Existing files are
    overwritten.

    :returns: ``path_dict`` unchanged.
    """
    for dir_key in ('dsl_path', 'runtime_conf_on_party_path'):
        os.makedirs(os.path.dirname(path_dict.get(dir_key)), exist_ok=True)

    # Resolve every target path up front, before any file is touched.
    targets = (
        (path_dict['dsl_path'], dsl),
        (path_dict['runtime_conf_path'], runtime_conf),
        (path_dict['runtime_conf_on_party_path'], runtime_conf_on_party),
        (path_dict['train_runtime_conf_path'], train_runtime_conf),
        (path_dict['pipeline_dsl_path'], pipeline_dsl),
    )
    for conf_path, content in targets:
        with open(conf_path, 'w') as handle:
            handle.write(json_dumps(content if content else {}, indent=4))
    return path_dict
示例#22
0
    def save_pipelined_model(cls, job_id, role, party_id):
        """Build and persist the pipeline model of a training job on one party.

        Roles listed in ``job_parameters["assistant_role"]`` and jobs whose
        ``job_type`` is ``'predict'`` save nothing. The inference DSL is
        produced by ``schedule_utils.fill_inference_dsl`` from the job DSL and
        the component parameters of this party's latest tasks. The resulting
        ``pipeline_pb2.Pipeline`` is saved through a ``Tracker`` built with the
        job's ``RunParameters``. Model info is also saved for every role except
        ``'local'``.

        :raises KeyError: if ``model_id`` or ``model_version`` is missing from
            the job parameters, or ``role``/``initiator`` from the runtime
            conf on party.
        """
        schedule_logger(job_id).info(
            f"start to save pipeline model on {role} {party_id}")
        conf = job_utils.get_job_configuration(job_id=job_id,
                                               role=role,
                                               party_id=party_id)
        conf_on_party = conf.runtime_conf_on_party
        params = conf_on_party.get('job_parameters', {})
        # Assistant roles do not save a pipeline.
        if role in params.get("assistant_role", []):
            return
        model_id = params['model_id']
        model_version = params['model_version']
        job_type = params.get('job_type', '')
        roles = conf_on_party['role']
        initiator_role = conf_on_party['initiator']['role']
        initiator_party_id = conf_on_party['initiator']['party_id']
        if job_type == 'predict':
            return

        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=conf.dsl,
            runtime_conf=conf.runtime_conf,
            train_runtime_conf=conf.train_runtime_conf)

        latest_tasks = JobSaver.query_task(job_id=job_id,
                                           role=role,
                                           party_id=party_id,
                                           only_latest=True)
        components_parameters = {
            task.f_component_name: task.f_component_parameters
            for task in latest_tasks
        }
        inference_dsl = schedule_utils.fill_inference_dsl(
            dsl_parser,
            origin_inference_dsl=conf.dsl,
            components_parameters=components_parameters)

        pipeline = pipeline_pb2.Pipeline()
        pipeline.inference_dsl = json_dumps(inference_dsl, byte=True)
        pipeline.train_dsl = json_dumps(conf.dsl, byte=True)
        pipeline.train_runtime_conf = json_dumps(conf.runtime_conf, byte=True)
        pipeline.fate_version = RuntimeConfig.get_env("FATE")
        pipeline.model_id = model_id
        pipeline.model_version = model_version

        # Bookkeeping: flagged as parent, never loaded yet, empty parent_info.
        pipeline.parent = True
        pipeline.loaded_times = 0
        pipeline.roles = json_dumps(roles, byte=True)
        pipeline.initiator_role = initiator_role
        pipeline.initiator_party_id = initiator_party_id
        pipeline.runtime_conf_on_party = json_dumps(conf_on_party, byte=True)
        pipeline.parent_info = json_dumps({}, byte=True)

        tracker = Tracker(job_id=job_id,
                          role=role,
                          party_id=party_id,
                          model_id=model_id,
                          model_version=model_version,
                          job_parameters=RunParameters(**params))
        tracker.save_pipeline_model(pipeline_buffer_object=pipeline)
        if role != 'local':
            tracker.save_machine_learning_model_info()
        schedule_logger(job_id).info(
            f"save pipeline on {role} {party_id} successfully")
示例#23
0
def migration(config_data: dict):
    """
    Migrate a locally cached pipeline model to a new set of roles/party ids.

    The source model (``model_id``/``model_version`` of the ``local`` party)
    is copied to a new model keyed by ``migrate_role`` /
    ``local.migrate_party_id`` / ``unify_model_version``. The copy's pipeline
    (train runtime conf, model id and version) and every component model
    buffer listed in ``define/define_meta.yaml`` are rewritten for the new
    roles. The result is packed into an archive and the unpacked copy is
    removed.

    :param config_data: dict read for ``model_id``, ``model_version``,
        ``unify_model_version``, ``role``, ``migrate_role``,
        ``migrate_initiator`` and ``local`` (``role``, ``party_id``,
        ``migrate_party_id``).
    :return: ``(0, message, {"model_id", "model_version", "path"})`` on
        success, ``(100, error message, {})`` on any failure.
    """
    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=config_data["model_id"],
            role=config_data["local"]["role"],
            party_id=config_data["local"]["party_id"])
        model = pipelined_model.PipelinedModel(
            model_id=party_model_id,
            model_version=config_data["model_version"])
        if not model.exists():
            raise Exception("Can not found {} {} model local cache".format(
                config_data["model_id"], config_data["model_version"]))
        with DB.connection_context():
            # The unified version must not already be registered.
            if MLModel.get_or_none(MLModel.f_model_version ==
                                   config_data["unify_model_version"]):
                raise Exception(
                    "Unify model version {} has been occupied in database. "
                    "Please choose another unify model version and try again.".
                    format(config_data["unify_model_version"]))

        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        migrate_model = pipelined_model.PipelinedModel(
            model_id=model_utils.gen_party_model_id(
                model_id=model_utils.gen_model_id(config_data["migrate_role"]),
                role=config_data["local"]["role"],
                party_id=config_data["local"]["migrate_party_id"]),
            model_version=config_data["unify_model_version"])

        shutil.copytree(src=model.model_path, dst=migrate_model.model_path)

        pipeline = migrate_model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']

        # Rewrite the training runtime conf for the new roles/initiator and
        # stamp it with the migrated model id/version.
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        train_runtime_conf["role"] = config_data["migrate_role"]
        train_runtime_conf["initiator"] = config_data["migrate_initiator"]

        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_id=model_utils.gen_model_id(train_runtime_conf["role"]),
            model_version=migrate_model.model_version)
        # ``to_dict`` is a method and must be called before reading keys.
        common_parameters = adapter.get_common_parameters().to_dict()
        new_model_id = common_parameters.get("model_id")

        # update pipeline.pb file
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)
        pipeline.model_id = bytes(new_model_id, "utf-8")
        pipeline.model_version = bytes(common_parameters.get("model_version"),
                                       "utf-8")

        # Save the updated pipeline and mirror pipeline.pb into the
        # variables/data copy of the pipeline component.
        migrate_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(migrate_model.model_path, "pipeline.pb"),
            os.path.join(migrate_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        # Rewrite every non-pipeline component model buffer for the new roles.
        with open(
                os.path.join(migrate_model.model_path, 'define',
                             'define_meta.yaml'), 'r') as fin:
            define_yaml = yaml.safe_load(fin)

        for key, value in define_yaml['model_proto'].items():
            if key == 'pipeline':
                continue
            for v in value.keys():
                buffer_obj = migrate_model.read_component_model(key, v)
                module_name = define_yaml['component_define'].get(
                    key, {}).get('module_name')
                modified_buffer = model_migration(
                    model_contents=buffer_obj,
                    module_name=module_name,
                    old_guest_list=config_data['role']['guest'],
                    new_guest_list=config_data['migrate_role']['guest'],
                    old_host_list=config_data['role']['host'],
                    new_host_list=config_data['migrate_role']['host'],
                    old_arbiter_list=config_data.get('role',
                                                     {}).get('arbiter', None),
                    new_arbiter_list=config_data.get('migrate_role',
                                                     {}).get('arbiter', None))
                migrate_model.save_component_model(
                    component_name=key,
                    component_module_name=module_name,
                    model_alias=v,
                    model_buffers=modified_buffer)

        archive_path = migrate_model.packaging_model()
        shutil.rmtree(os.path.abspath(migrate_model.model_path))
        archive_abspath = os.path.abspath(archive_path)

        return (0,
                "Migrating model successfully. "
                "The configuration of model has been modified automatically. "
                f"New model id is: {new_model_id}, "
                f"model version is: {migrate_model.model_version}. "
                f"Model files can be found at '{archive_abspath}'.",
                {"model_id": migrate_model.model_id,
                 "model_version": migrate_model.model_version,
                 "path": archive_abspath})

    except Exception as e:
        return 100, str(e), {}
示例#24
0
    def start_task(cls, job_id, component_name, task_id, task_version, role,
                   party_id, **kwargs):
        """
        Launch the executor of one task through its computing engine and
        record the resulting status.

        The component is authenticated against the job DSL first. This happens
        outside the error handling below, so an authentication error reaches
        the caller. Next, the job parameters (plus ``src_user``) are written
        to ``task_parameters.json`` in the task directory, and the engine
        built for ``computing_engine`` runs the task. Whatever happens, the
        task record is then updated and its party status set to RUNNING, and
        afterwards to FAILED if any launch step raised.

        :param job_id: id of the job the task belongs to
        :param component_name: name of the component the task runs
        :param task_id: id of the task
        :param task_version: version of the task
        :param role: role of the local party
        :param party_id: id of the local party
        :param kwargs: optional ``src_party_id`` / ``src_role`` (used for
            authentication), ``src_user`` (stored in the run parameters) and
            ``user_id`` (passed to the engine as ``user_name``)
        :return: None
        """
        dsl = job_utils.get_job_dsl(job_id, role, party_id)
        PrivilegeAuth.authentication_component(
            dsl,
            src_party_id=kwargs.get('src_party_id'),
            src_role=kwargs.get('src_role'),
            party_id=party_id,
            component_name=component_name)

        schedule_logger(job_id).info(
            f"try to start task {task_id} {task_version} on {role} {party_id} executor subprocess"
        )
        launched = False
        launch_failed = False
        task_info = dict(job_id=job_id,
                         task_id=task_id,
                         task_version=task_version,
                         role=role,
                         party_id=party_id)
        try:
            matched = JobSaver.query_task(task_id=task_id,
                                          task_version=task_version,
                                          role=role,
                                          party_id=party_id)
            task = matched[0]
            parameters = job_utils.get_job_parameters(job_id, role, party_id)
            parameters["src_user"] = kwargs.get("src_user")
            run_parameters = RunParameters(**parameters)

            task_dir = job_utils.get_task_directory(job_id, role, party_id,
                                                    component_name, task_id,
                                                    task_version)
            os.makedirs(task_dir, exist_ok=True)

            parameters_file = os.path.join(task_dir, 'task_parameters.json')
            with open(parameters_file, 'w') as fw:
                fw.write(json_dumps(parameters))

            engine_name = run_parameters.computing_engine
            schedule_logger(job_id).info(f"use computing engine {engine_name}")
            task_info["engine_conf"] = {"computing_engine": engine_name}

            engine = build_engine(engine_name)
            log_dir = job_utils.get_job_log_directory(job_id, role, party_id,
                                                      component_name)
            cwd_dir = job_utils.get_job_directory(job_id, role, party_id,
                                                  component_name)
            run_info = engine.run(task=task,
                                  run_parameters=run_parameters,
                                  run_parameters_path=parameters_file,
                                  config_dir=task_dir,
                                  log_dir=log_dir,
                                  cwd_dir=cwd_dir,
                                  user_name=kwargs.get("user_id"))
            task_info.update(run_info)
            task_info["start_time"] = current_timestamp()
            launched = True
        except Exception as e:
            schedule_logger(job_id).exception(e)
            launch_failed = True
        finally:
            try:
                cls.update_task(task_info=task_info)
                # The status always moves to RUNNING first; a failed launch
                # then additionally records FAILED.
                task_info["party_status"] = TaskStatus.RUNNING
                cls.update_task_status(task_info=task_info)
                if launch_failed:
                    task_info["party_status"] = TaskStatus.FAILED
                    cls.update_task_status(task_info=task_info)
            except Exception as e:
                schedule_logger(job_id).exception(e)
            outcome = "success" if launched else "failed"
            schedule_logger(job_id).info(
                "task {} {} on {} {} executor subprocess start {}".format(
                    task_id, task_version, role, party_id, outcome))
示例#25
0
def jprint(src: dict, indent: int = 4):
    """Print ``src`` to stdout as JSON rendered by ``json_dumps`` with ``indent``."""
    rendered = json_dumps(src, indent=indent)
    print(rendered)
示例#26
0
    def start_task(cls, job_id, component_name, task_id, task_version, role,
                   party_id):
        """
        Spawn the executor subprocess of one task and record its status.

        The job parameters are written to ``task_parameters.json`` in the task
        directory. An executor command line is built for the job's computing
        engine and launched with ``job_utils.run_subprocess``. In every case
        the task info is then persisted through ``cls.update_task`` and
        ``cls.update_task_status``. Errors raised while launching or while
        persisting are logged, not propagated.

        Supported engines:
          * EGGROLL / STANDALONE: run the TaskExecutor module with the current
            interpreter (``sys.executable``).
          * SPARK: run it through ``$SPARK_HOME/bin/spark-submit`` (``client``
            deploy mode only), forwarding the ``spark_run`` options.

        :param job_id: id of the job the task belongs to
        :param component_name: name of the component the task runs
        :param task_id: id of the task
        :param task_version: version of the task
        :param role: role of the local party
        :param party_id: id of the local party
        :return: None
        """
        schedule_logger(job_id).info(
            'try to start job {} task {} {} on {} {} executor subprocess'.
            format(job_id, task_id, task_version, role, party_id))
        task_executor_process_start_status = False
        task_info = {
            "job_id": job_id,
            "task_id": task_id,
            "task_version": task_version,
            "role": role,
            "party_id": party_id,
        }
        try:
            task_dir = os.path.join(job_utils.get_job_directory(job_id=job_id),
                                    role, party_id, component_name, task_id,
                                    task_version)
            os.makedirs(task_dir, exist_ok=True)
            task_parameters_path = os.path.join(task_dir,
                                                'task_parameters.json')
            run_parameters_dict = job_utils.get_job_parameters(
                job_id, role, party_id)
            with open(task_parameters_path, 'w') as fw:
                fw.write(json_dumps(run_parameters_dict))

            run_parameters = RunParameters(**run_parameters_dict)
            computing_engine = run_parameters.computing_engine

            schedule_logger(job_id=job_id).info(
                f"use computing engine {computing_engine}")

            # Arguments of the TaskExecutor entry point. They are the same
            # for every engine; only the launcher placed in front differs.
            executor_args = [
                sys.modules[TaskExecutor.__module__].__file__,
                '-j', job_id,
                '-n', component_name,
                '-t', task_id,
                '-v', task_version,
                '-r', role,
                '-p', party_id,
                '-c', task_parameters_path,
                '--run_ip', RuntimeConfig.JOB_SERVER_HOST,
                '--job_server', '{}:{}'.format(RuntimeConfig.JOB_SERVER_HOST,
                                               RuntimeConfig.HTTP_PORT),
            ]

            if computing_engine in {
                    ComputingEngine.EGGROLL, ComputingEngine.STANDALONE
            }:
                process_cmd = [sys.executable, *executor_args]
            elif computing_engine == ComputingEngine.SPARK:
                if "SPARK_HOME" not in os.environ:
                    raise EnvironmentError("SPARK_HOME not found")
                spark_home = os.environ["SPARK_HOME"]

                # additional spark-submit options from the job parameters
                spark_submit_config = run_parameters.spark_run

                deploy_mode = spark_submit_config.get("deploy-mode", "client")
                if deploy_mode not in ["client"]:
                    raise ValueError(
                        f"deploy mode {deploy_mode} not supported")

                spark_submit_cmd = os.path.join(spark_home, "bin/spark-submit")
                process_cmd = [spark_submit_cmd, f'--name={task_id}#{role}']
                for k, v in spark_submit_config.items():
                    if k != "conf":
                        process_cmd.append(f'--{k}={v}')
                if "conf" in spark_submit_config:
                    for ck, cv in spark_submit_config["conf"].items():
                        process_cmd.extend(['--conf', f'{ck}={cv}'])
                process_cmd.extend(executor_args)
            else:
                raise ValueError(f"{computing_engine} is not supported")

            task_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role, party_id,
                component_name)
            schedule_logger(job_id).info(
                'job {} task {} {} on {} {} executor subprocess is ready'.
                format(job_id, task_id, task_version, role, party_id))
            p = job_utils.run_subprocess(job_id=job_id,
                                         config_dir=task_dir,
                                         process_cmd=process_cmd,
                                         log_dir=task_log_dir)
            if p:
                task_info["party_status"] = TaskStatus.RUNNING
                task_info["start_time"] = current_timestamp()
                task_executor_process_start_status = True
            else:
                task_info["party_status"] = TaskStatus.FAILED
        except Exception as e:
            schedule_logger(job_id).exception(e)
            task_info["party_status"] = TaskStatus.FAILED
        finally:
            try:
                cls.update_task(task_info=task_info)
                cls.update_task_status(task_info=task_info)
            except Exception as e:
                schedule_logger(job_id).exception(e)
            schedule_logger(job_id).info(
                'job {} task {} {} on {} {} executor subprocess start {}'.
                format(
                    job_id, task_id, task_version, role, party_id, "success"
                    if task_executor_process_start_status else "failed"))
示例#27
0
def deploy(config_data):
    """
    Deploy a child model version from a parent pipeline model for the local party.

    Copies the parent model directory to ``child_model_version``, skipping
    only its top-level ``checkpoint`` directory. Builds the child's inference
    DSL: from ``dsl``/``predict_dsl`` in ``config_data`` when given, otherwise
    from ``cpn_list`` (v2) or by converting the parent's v1 inference DSL.
    For v1 models, the train runtime conf is converted to v2. The steps after
    that:
      * rewrite the pipeline (model version, train runtime conf, inference
        DSL and, for newer models, parent info) and save it;
      * record the model info;
      * deploy the component checkpoints requested in
        ``components_checkpoint`` into the child model.

    :param config_data: dict read for ``model_id``, ``model_version``,
        ``local`` (``role``, ``party_id``), ``child_model_version`` and the
        optional ``components_checkpoint``, ``dsl``, ``predict_dsl`` and
        ``cpn_list`` keys. ``local`` is read outside the error handling, so
        a missing ``local`` raises. NOTE: ``cpn_list`` is popped from the
        caller's dict.
    :return: ``(0, message)`` on success, with any v1->v2 DSL conversion
        warning appended, or ``(100, message)`` on failure. Exceptions raised
        during the deploy are logged to ``stat_logger``, not propagated.
    """
    model_id = config_data.get('model_id')
    model_version = config_data.get('model_version')
    local_role = config_data.get('local').get('role')
    local_party_id = config_data.get('local').get('party_id')
    child_model_version = config_data.get('child_model_version')
    # Maps component name -> {'step_index': ..., 'step_name': ...}.
    components_checkpoint = config_data.get('components_checkpoint', {})
    # Filled by the v1 -> v2 DSL conversion and appended to the success message.
    warning_msg = ""

    try:
        party_model_id = gen_party_model_id(model_id=model_id,
                                            role=local_role,
                                            party_id=local_party_id)
        model = PipelinedModel(model_id=party_model_id,
                               model_version=model_version)
        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        # check_before_deploy decides whether this model may be deployed at
        # all; per the error message below, child models are rejected.
        if not check_before_deploy(model):
            raise Exception('Child model could not be deployed.')

        # Copy the parent model into the child version. Only the top-level
        # 'checkpoint' directory is skipped; checkpoints are deployed per
        # component at the end of this function.
        deploy_model = PipelinedModel(model_id=party_model_id,
                                      model_version=child_model_version)
        shutil.copytree(src=model.model_path,
                        dst=deploy_model.model_path,
                        ignore=lambda src, names: {'checkpoint'}
                        if src == model.model_path else {})
        pipeline_model = deploy_model.read_pipeline_model()

        train_runtime_conf = json_loads(pipeline_model.train_runtime_conf)
        runtime_conf_on_party = json_loads(
            pipeline_model.runtime_conf_on_party)
        dsl_version = train_runtime_conf.get("dsl_version", "1")

        parser = get_dsl_parser_by_version(dsl_version)
        train_dsl = json_loads(pipeline_model.train_dsl)
        parent_predict_dsl = json_loads(pipeline_model.inference_dsl)

        # An explicitly supplied dsl/predict_dsl (dict or JSON string) takes
        # precedence over one derived from the parent model.
        if config_data.get('dsl') or config_data.get('predict_dsl'):
            inference_dsl = config_data.get('dsl') if config_data.get(
                'dsl') else config_data.get('predict_dsl')
            if not isinstance(inference_dsl, dict):
                inference_dsl = json_loads(inference_dsl)
        else:
            # NOTE(review): cpn_list is only used on the v2 path; for v1 the
            # whole parent inference dsl is converted.
            if config_data.get('cpn_list', None):
                cpn_list = config_data.pop('cpn_list')
            else:
                cpn_list = list(train_dsl.get('components', {}).keys())
            if int(dsl_version) == 1:
                # convert v1 dsl to v2 dsl
                inference_dsl, warning_msg = parser.convert_dsl_v1_to_v2(
                    parent_predict_dsl)
            else:
                parser = get_dsl_parser_by_version(dsl_version)
                inference_dsl = parser.deploy_component(cpn_list, train_dsl)

        # Convert a v1 train conf to v2. Each component's role parameters are
        # resolved against the provider registered in ComponentRegistry.
        if int(dsl_version) == 1:
            components = parser.get_components_light_weight(inference_dsl)

            from fate_flow.db.component_registry import ComponentRegistry
            job_providers = parser.get_job_providers(
                dsl=inference_dsl, provider_detail=ComponentRegistry.REGISTRY)
            cpn_role_parameters = dict()
            for cpn in components:
                cpn_name = cpn.get_name()
                role_params = parser.parse_component_role_parameters(
                    component=cpn_name,
                    dsl=inference_dsl,
                    runtime_conf=train_runtime_conf,
                    provider_detail=ComponentRegistry.REGISTRY,
                    provider_name=job_providers[cpn_name]["provider"]["name"],
                    provider_version=job_providers[cpn_name]["provider"]
                    ["version"])
                cpn_role_parameters[cpn_name] = role_params
            train_runtime_conf = parser.convert_conf_v1_to_v2(
                train_runtime_conf, cpn_role_parameters)

        # Point the train runtime conf and the pipeline at the child version.
        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_version=deploy_model.model_version)
        pipeline_model.model_version = child_model_version
        pipeline_model.train_runtime_conf = json_dumps(train_runtime_conf,
                                                       byte=True)

        # Validate the inference dsl with the v2 parser and let JobSaver fill
        # it in. The parent's model_version is passed as the job id.
        parser = get_dsl_parser_by_version(2)
        parser.verify_dsl(inference_dsl, "predict")
        inference_dsl = JobSaver.fill_job_inference_dsl(
            job_id=model_version,
            role=local_role,
            party_id=local_party_id,
            dsl_parser=parser,
            origin_inference_dsl=inference_dsl)
        pipeline_model.inference_dsl = json_dumps(inference_dsl, byte=True)

        # Runs only when compare_version reports 'gt' against 1.5.0
        # (presumably: newer than 1.5.0). Such models record their parent
        # and carry a per-party runtime conf that must name the child version.
        if compare_version(pipeline_model.fate_version, '1.5.0') == 'gt':
            pipeline_model.parent_info = json_dumps(
                {
                    'parent_model_id': model_id,
                    'parent_model_version': model_version
                },
                byte=True)
            pipeline_model.parent = False
            runtime_conf_on_party['job_parameters'][
                'model_version'] = child_model_version
            pipeline_model.runtime_conf_on_party = json_dumps(
                runtime_conf_on_party, byte=True)

        # Save the pipeline and mirror pipeline.pb into the variables/data
        # copy of the pipeline component.
        deploy_model.save_pipeline(pipeline_model)
        shutil.copyfile(
            os.path.join(deploy_model.model_path, "pipeline.pb"),
            os.path.join(deploy_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        model_info = gather_model_info_data(deploy_model)
        model_info['job_id'] = model_info['f_model_version']
        model_info['size'] = deploy_model.calculate_model_file_size()
        model_info['role'] = local_role
        model_info['party_id'] = local_party_id
        # Flagged as a parent exactly when no inference dsl was gathered.
        model_info['parent'] = False if model_info.get(
            'f_inference_dsl') else True
        # For 1.5.0 models, roles and initiator come from the train runtime conf.
        if compare_version(model_info['f_fate_version'], '1.5.0') == 'eq':
            model_info['roles'] = model_info.get('f_train_runtime_conf',
                                                 {}).get('role', {})
            model_info['initiator_role'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator', {}).get('role')
            model_info['initiator_party_id'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator',
                                                {}).get('party_id')
        save_model_info(model_info)

        # Deploy the requested checkpoints from the parent model into the
        # child. step_index wins over step_name; components with neither are
        # skipped.
        for component_name, component in train_dsl.get('components',
                                                       {}).items():
            step_index = components_checkpoint.get(component_name,
                                                   {}).get('step_index')
            step_name = components_checkpoint.get(component_name,
                                                  {}).get('step_name')
            if step_index is not None:
                step_index = int(step_index)
                step_name = None
            elif step_name is None:
                continue

            # Checkpoints are read from the parent's model_id/model_version.
            checkpoint_manager = CheckpointManager(
                role=local_role,
                party_id=local_party_id,
                model_id=model_id,
                model_version=model_version,
                component_name=component_name,
                mkdir=False,
            )
            checkpoint_manager.load_checkpoints_from_disk()
            if checkpoint_manager.latest_checkpoint is not None:
                # The model alias is the component's first output model,
                # or 'default' when it declares none.
                checkpoint_manager.deploy(
                    child_model_version,
                    component['output']['model'][0] if component.get(
                        'output', {}).get('model') else 'default',
                    step_index,
                    step_name,
                )
    except Exception as e:
        stat_logger.exception(e)
        return 100, f"deploy model of role {local_role} {local_party_id} failed, details: {str(e)}"
    else:
        msg = f"deploy model of role {local_role} {local_party_id} success"
        if warning_msg:
            msg = msg + f", warning: {warning_msg}"
        return 0, msg
示例#28
0
 def db_value(self, value):
     """Serialize ``value`` with ``json_dumps`` for storage; ``None`` is stored as an empty list."""
     return json_dumps([] if value is None else value)
示例#29
0
def deploy(config_data):
    """
    Deploy a child model version from a parent pipeline model for the local party.

    Copies the whole parent model directory to ``child_model_version`` and
    rewrites its pipeline: model version, train runtime conf and inference
    DSL, plus parent info for newer models. The pipeline is saved and the
    model info recorded.

    How the inference DSL is chosen:
      * v1 models reuse the parent's inference DSL;
      * v2 models use ``dsl``/``predict_dsl`` from ``config_data`` when given;
      * otherwise they deploy ``cpn_list`` (default: all training components).

    :param config_data: dict read for ``model_id``, ``model_version``,
        ``local`` (``role``, ``party_id``), ``child_model_version`` and the
        optional ``dsl``, ``predict_dsl`` and ``cpn_list`` keys. ``local`` is
        read outside the error handling, so a missing ``local`` raises.
        NOTE: ``cpn_list`` is popped from the caller's dict.
    :return: ``(0, message)`` on success or ``(100, message)`` on failure.
        Exceptions raised during the deploy are logged to ``stat_logger``,
        not propagated.
    """
    model_id = config_data.get('model_id')
    model_version = config_data.get('model_version')
    local_role = config_data.get('local').get('role')
    local_party_id = config_data.get('local').get('party_id')
    child_model_version = config_data.get('child_model_version')

    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=model_id, role=local_role, party_id=local_party_id)
        model = PipelinedModel(model_id=party_model_id,
                               model_version=model_version)
        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        # check_before_deploy decides whether this model may be deployed at
        # all; per the error message below, child models are rejected.
        if not check_before_deploy(model):
            raise Exception('Child model could not be deployed.')

        # copy proto content from parent model and generate a child model
        deploy_model = PipelinedModel(model_id=party_model_id,
                                      model_version=child_model_version)
        shutil.copytree(src=model.model_path, dst=deploy_model.model_path)
        pipeline = deploy_model.read_component_model('pipeline',
                                                     'pipeline')['Pipeline']

        # Point the train runtime conf and the pipeline at the child version.
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_version=deploy_model.model_version)
        pipeline.model_version = child_model_version
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)

        dsl_version = train_runtime_conf.get('dsl_version', '1')
        parser = get_dsl_parser_by_version(dsl_version)
        train_dsl = json_loads(pipeline.train_dsl)

        if str(dsl_version) == '1':
            # v1 models reuse the parent's inference dsl as is.
            predict_dsl = json_loads(pipeline.inference_dsl)
        elif config_data.get('dsl') or config_data.get('predict_dsl'):
            # An explicitly supplied dsl (dict or JSON string) wins.
            predict_dsl = config_data.get('dsl') or config_data.get(
                'predict_dsl')
            if not isinstance(predict_dsl, dict):
                predict_dsl = json_loads(predict_dsl)
        else:
            # Deploy the requested components, defaulting to all of them.
            if config_data.get('cpn_list', None):
                cpn_list = config_data.pop('cpn_list')
            else:
                cpn_list = list(train_dsl.get('components', {}).keys())
            parser = schedule_utils.get_dsl_parser_by_version(dsl_version)
            predict_dsl = parser.deploy_component(cpn_list, train_dsl)

        #  save predict dsl into child model file
        parser.verify_dsl(predict_dsl, "predict")
        inference_dsl = parser.get_predict_dsl(
            role=local_role,
            predict_dsl=predict_dsl,
            setting_conf_prefix=file_utils.
            get_federatedml_setting_conf_directory())
        pipeline.inference_dsl = json_dumps(inference_dsl, byte=True)
        # Runs only when compare_version reports 'gt' against 1.5.0
        # (presumably: newer than 1.5.0). Such models record their parent
        # and carry a per-party runtime conf naming the child version.
        if model_utils.compare_version(pipeline.fate_version, '1.5.0') == 'gt':
            pipeline.parent_info = json_dumps(
                {
                    'parent_model_id': model_id,
                    'parent_model_version': model_version
                },
                byte=True)
            pipeline.parent = False
            runtime_conf_on_party = json_loads(pipeline.runtime_conf_on_party)
            runtime_conf_on_party['job_parameters'][
                'model_version'] = child_model_version
            pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party,
                                                        byte=True)

        # Save the pipeline and mirror pipeline.pb into the variables/data
        # copy of the pipeline component.
        deploy_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(deploy_model.model_path, "pipeline.pb"),
            os.path.join(deploy_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        model_info = model_utils.gather_model_info_data(deploy_model)
        model_info['job_id'] = model_info['f_model_version']
        model_info['size'] = deploy_model.calculate_model_file_size()
        model_info['role'] = local_role
        model_info['party_id'] = local_party_id
        model_info['work_mode'] = adapter.get_job_work_mode()
        # Flagged as a parent exactly when no inference dsl was gathered.
        model_info['parent'] = not model_info.get('f_inference_dsl')
        # For 1.5.0 models, roles and initiator come from the train runtime conf.
        if model_utils.compare_version(model_info['f_fate_version'],
                                       '1.5.0') == 'eq':
            model_info['roles'] = model_info.get('f_train_runtime_conf',
                                                 {}).get('role', {})
            model_info['initiator_role'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator', {}).get('role')
            model_info['initiator_party_id'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator',
                                                {}).get('party_id')
        model_utils.save_model_info(model_info)

    except Exception as e:
        stat_logger.exception(e)
        return 100, f"deploy model of role {local_role} {local_party_id} failed, details: {str(e)}"
    else:
        return 0, f"deploy model of role {local_role} {local_party_id} success"
示例#30
0
 def get_config(cls, config_dir, config, log_dir):
     """
     Write ``config`` as JSON to ``<config_dir>/config.json``.

     :param config_dir: directory to write into (not created here)
     :param config: configuration serialized with ``json_dumps``
     :param log_dir: not used by this implementation
     :return: ``(config_path, result_path)``; ``result_path`` is
         ``<config_dir>/result.json`` and is only computed, not created
     """
     config_path, result_path = (
         os.path.join(config_dir, file_name)
         for file_name in ("config.json", "result.json"))
     with open(config_path, 'w') as config_file:
         config_file.write(json_dumps(config))
     return config_path, result_path