Code Example #1
def test1(self):
    cache = DataCache(name="test_cache",
                      data={"t1": DTable(namespace="test", name="test1")},
                      meta={"t1": {"a": 1}})
    # a plain round trip loses the types: everything comes back as dicts
    a = json_loads(json_dumps(cache))
    self.assertEqual(a["data"]["t1"]["namespace"], "test")
    # with_type=True plus from_dict_hook restores the original objects
    b = json_loads(json_dumps(cache, with_type=True),
                   object_hook=from_dict_hook)
    self.assertEqual(b.data["t1"].namespace, "test")
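
For context, here is a minimal sketch of how a type-preserving dumps/loads pair like the one exercised above could work. FATE's real json_dumps/json_loads live in fate_arch.common.base_utils and are more involved; the registry and hook below are assumptions for illustration only, and DataCache/DTable would have to be registered for the round trip to succeed.

import json

_TYPE_REGISTRY = {}  # hypothetical registry: type name -> class

def json_dumps(src, byte=False, indent=None, with_type=False):
    # with_type=True wraps objects as {"type": ..., "data": ...} so they
    # can be rebuilt later; otherwise only the raw attribute dict is kept
    def default(obj):
        if with_type:
            return {"type": type(obj).__name__, "data": obj.__dict__}
        return obj.__dict__
    dest = json.dumps(src, indent=indent, default=default)
    return dest.encode("utf-8") if byte else dest

def json_loads(src, object_hook=None, object_pairs_hook=None):
    if isinstance(src, (bytes, bytearray)):
        src = src.decode("utf-8")
    return json.loads(src, object_hook=object_hook,
                      object_pairs_hook=object_pairs_hook)

def from_dict_hook(in_dict):
    # rebuild a registered object from its {"type", "data"} encoding
    cls = _TYPE_REGISTRY.get(in_dict.get("type"))
    if cls is None or "data" not in in_dict:
        return in_dict
    obj = cls.__new__(cls)
    obj.__dict__.update(in_dict["data"])
    return obj
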
Code Example #2
File: job_tracker.py Project: FederatedAI/FATE-Flow
    def save_machine_learning_model_info(self):
        try:
            record = MLModel.get_or_none(MLModel.f_model_version == self.job_id,
                                         MLModel.f_role == self.role,
                                         MLModel.f_model_id == self.model_id,
                                         MLModel.f_party_id == self.party_id)
            if not record:
                job = Job.get_or_none(Job.f_job_id == self.job_id)
                pipeline = self.pipelined_model.read_pipeline_model()
                if job:
                    job_data = job.to_dict()
                    model_info = {
                        'job_id': job_data.get("f_job_id"),
                        'role': self.role,
                        'party_id': self.party_id,
                        'roles': job_data.get("f_roles"),
                        'model_id': self.model_id,
                        'model_version': self.model_version,
                        'initiator_role': job_data.get('f_initiator_role'),
                        'initiator_party_id': job_data.get('f_initiator_party_id'),
                        'runtime_conf': job_data.get('f_runtime_conf'),
                        'work_mode': job_data.get('f_work_mode'),
                        'train_dsl': job_data.get('f_dsl'),
                        'train_runtime_conf': job_data.get('f_train_runtime_conf'),
                        'size': self.get_model_size(),
                        'job_status': job_data.get('f_status'),
                        'parent': pipeline.parent,
                        'fate_version': pipeline.fate_version,
                        'runtime_conf_on_party': json_loads(pipeline.runtime_conf_on_party),
                        'parent_info': json_loads(pipeline.parent_info),
                        'inference_dsl': json_loads(pipeline.inference_dsl)
                    }
                    model_utils.save_model_info(model_info)

                    schedule_logger(self.job_id).info(
                        'saved model info of job {}. model id: {}, model version: {}.'.format(self.job_id,
                                                                                              self.model_id,
                                                                                              self.model_version))
                else:
                    schedule_logger(self.job_id).info(
                        'failed to save model info of job {}: no job found in db. '
                        'model id: {}, model version: {}.'.format(self.job_id,
                                                                  self.model_id,
                                                                  self.model_version))
            else:
                schedule_logger(self.job_id).info('model info of job {} already exists in database.'.format(self.job_id))
        except Exception as e:
            schedule_logger(self.job_id).exception(e)
Code Example #3
    def unaryCall(self, _request, context):
        packet = _request
        header = packet.header
        _suffix = packet.body.key
        param_bytes = packet.body.value
        param = bytes.decode(param_bytes)
        job_id = header.task.taskId
        src = header.src
        dst = header.dst
        method = header.operator
        param_dict = json_loads(param)
        param_dict['src_party_id'] = str(src.partyId)
        source_routing_header = []
        for key, value in context.invocation_metadata():
            source_routing_header.append((key, value))
        stat_logger.info(
            f"grpc request routing header: {source_routing_header}")

        param = bytes.decode(bytes(json_dumps(param_dict), 'utf-8'))

        # NOTE: this variant stubs out the actual HTTP forwarding: the real
        # request is commented out and a canned response is returned instead
        # (compare the live version in Code Example #17)
        action = getattr(requests, method.lower(), None)
        if action:
            print(_suffix)
            # resp = action(url=get_url(_suffix), data=param, headers=HEADERS)
        else:
            pass
        # resp_json = resp.json()
        resp_json = {"status": "test"}
        import time
        print("sleep")
        time.sleep(60)  # artificial one-minute delay, presumably for timeout testing
        return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                src.partyId, job_id)
Code Example #4
def proxy_api(role, _job_id, request_config):
    job_id = request_config.get('header').get('job_id', _job_id)
    method = request_config.get('header').get('method', 'POST')
    endpoint = request_config.get('header').get('endpoint')
    src_party_id = request_config.get('header').get('src_party_id')
    dest_party_id = request_config.get('header').get('dest_party_id')
    json_body = request_config.get('body')
    _packet = forward_grpc_packet(
        json_body,
        method,
        endpoint,
        src_party_id,
        dest_party_id,
        job_id=job_id,
        role=role,
        overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT)
    _routing_metadata = gen_routing_metadata(src_party_id=src_party_id,
                                             dest_party_id=dest_party_id)
    host, port, protocol = get_federated_proxy_address(src_party_id,
                                                       dest_party_id)
    channel, stub = get_command_federation_channel(host, port)
    _return, _call = stub.unaryCall.with_call(_packet,
                                              metadata=_routing_metadata)
    channel.close()
    json_body = json_loads(_return.body.value)
    return json_body
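
For reference, a hypothetical call showing the header/body layout proxy_api reads; the party ids, endpoint, and job id below are invented for illustration.

# forward a POST to /v1/version/get on party 10000 on behalf of party 9999
request_config = {
    "header": {
        "job_id": "202201010000000000",
        "method": "POST",
        "endpoint": "/v1/version/get",
        "src_party_id": "9999",
        "dest_party_id": "10000",
    },
    "body": {},
}
response = proxy_api("guest", "202201010000000000", request_config)
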
Code Example #5
File: job_app.py Project: zeta1999/FATE
def dsl_generator():
    data = request.json
    cpn_str = data.get("cpn_str", "")
    try:
        if not cpn_str:
            raise Exception("Component list should not be empty.")
        if isinstance(cpn_str, list):
            cpn_list = cpn_str
        else:
            if cpn_str.find("/") != -1 or cpn_str.find("\\") != -1:
                raise Exception(
                    "Component list string should not contain '/' or '\\'.")
            cpn_str = cpn_str.replace(" ", "").replace("\n", "").strip(",[]")
            cpn_list = cpn_str.split(",")
        train_dsl = json_loads(data.get("train_dsl"))
        parser = schedule_utils.get_dsl_parser_by_version(
            data.get("version", "2"))
        predict_dsl = parser.deploy_component(cpn_list, train_dsl)

        if data.get("filename"):
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY, data.get("filename"))
            with open(temp_filepath, "w") as fout:
                fout.write(json.dumps(predict_dsl, indent=4))
            return send_file(open(temp_filepath, 'rb'),
                             as_attachment=True,
                             attachment_filename=data.get("filename"))
        return get_json_result(data=predict_dsl)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(
            210, "DSL generating failed. For more details, "
            "please check logs/fate_flow/fate_flow_stat.log.")
Code Example #6
File: api_utils.py Project: zeta1999/FATE
def federated_coordination_on_grpc(job_id, method, host, port, endpoint, src_party_id, src_role, dest_party_id, json_body, api_version=API_VERSION,
                                   overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT, try_times=3):
    endpoint = f"/{api_version}{endpoint}"
    json_body['src_role'] = src_role
    json_body['src_party_id'] = src_party_id
    if CHECK_NODES_IDENTITY:
        get_node_identity(json_body, src_party_id)
    _packet = wrap_grpc_packet(json_body, method, endpoint, src_party_id, dest_party_id, job_id,
                               overall_timeout=overall_timeout)
    _routing_metadata = gen_routing_metadata(src_party_id=src_party_id, dest_party_id=dest_party_id)
    exception = None
    for t in range(try_times):
        try:
            channel, stub = get_command_federation_channel(host, port)
            _return, _call = stub.unaryCall.with_call(_packet, metadata=_routing_metadata, timeout=(overall_timeout/1000))
            audit_logger(job_id).info("grpc api response: {}".format(_return))
            channel.close()
            response = json_loads(_return.body.value)
            return response
        except Exception as e:
            exception = e
            schedule_logger(job_id).warning(f"remote request {endpoint} error, sleep and try again")
            time.sleep(2 * (t+1))
    else:
        tips = 'Please check rollSite and fateflow network connectivity'
        """
        if 'Error received from peer' in str(exception):
            tips = 'Please check if the fate flow server of the other party is started. '
        if 'failed to connect to all addresses' in str(exception):
            tips = 'Please check whether the rollsite service(port: 9370) is started. '
        """
        raise Exception('{}rpc request error: {}'.format(tips, exception))
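
The loop above retries with a linear back-off (2s, 4s, 6s) and only raises once every attempt has failed. Extracted as a standalone helper, the same pattern might look like this; retry_with_backoff is a sketch for illustration, not FATE code.

import time

def retry_with_backoff(fn, try_times=3, base_delay=2):
    """Call fn() up to try_times times, sleeping base_delay * attempt between tries."""
    exception = None
    for t in range(try_times):
        try:
            return fn()
        except Exception as e:
            exception = e
            time.sleep(base_delay * (t + 1))
    raise Exception(f"rpc request error after {try_times} tries: {exception}")
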
Code Example #7
File: base_model.py Project: yubo1993/FATE
def python_value(self, value):
    if self._serialized_type == SerializedType.PICKLE:
        return deserialize_b64(value)
    elif self._serialized_type == SerializedType.JSON:
        if value is None:
            return {}
        return json_loads(value, object_hook=self._object_hook,
                          object_pairs_hook=self._object_pairs_hook)
    else:
        raise ValueError(f"the serialized type {self._serialized_type} is not supported")
Code Example #8
def read_model_run_parameters(self):
    if not os.path.exists(self.run_parameters_path):
        return {}
    components_run_parameters = {}
    for component_name in os.listdir(self.run_parameters_path):
        p = self.component_run_parameters_path(component_name)
        with open(p, encoding="utf8") as fr:
            components_run_parameters[component_name] = base_utils.json_loads(fr.read())
    return components_run_parameters
Code Example #9
def deploy_homo_model(request_data):
    party_model_id = model_utils.gen_party_model_id(
        model_id=request_data["model_id"],
        role=request_data["role"],
        party_id=request_data["party_id"])
    model_version = request_data["model_version"]
    component_name = request_data['component_name']
    service_id = request_data['service_id']
    framework_name = request_data.get('framework_name')
    model = pipelined_model.PipelinedModel(model_id=party_model_id,
                                           model_version=model_version)
    if not model.exists():
        return 100, 'Model {} {} does not exist'.format(
            party_model_id, model_version), None

    # get the model alias from the dsl saved with the pipeline
    pipeline = model.read_pipeline_model()
    train_dsl = json_loads(pipeline.train_dsl)
    if component_name not in train_dsl.get('components', {}):
        return 100, 'Model {} {} does not contain component {}'.\
            format(party_model_id, model_version, component_name), None

    model_alias_list = train_dsl['components'][component_name].get(
        'output', {}).get('model')
    if not model_alias_list:
        return 100, 'Component {} in Model {} {} does not have output model'. \
            format(component_name, party_model_id, model_version), None

    # currently there is only one model output
    model_alias = model_alias_list[0]
    converted_model_dir = os.path.join(model.variables_data_path,
                                       component_name, model_alias,
                                       "converted_model")
    if not os.path.isdir(converted_model_dir):
        return 100, "Component {} in Model {} {} isn't converted.".\
            format(component_name, party_model_id, model_version), None

    # todo: use subprocess?
    convert_tool = model.get_homo_model_convert_tool()
    if not framework_name:
        module_name = train_dsl['components'][component_name].get('module')
        buffer_obj = model.read_component_model(component_name, model_alias)
        framework_name = convert_tool.get_default_target_framework(
            model_contents=buffer_obj, module_name=module_name)

    model_object = convert_tool.load_converted_model(
        base_dir=converted_model_dir, framework_name=framework_name)
    deployed_service = model_deploy(party_model_id, model_version,
                                    model_object, framework_name, service_id,
                                    request_data['deployment_type'],
                                    request_data['deployment_parameters'])
    return (
        0,
        f"An online serving service is started in the {request_data['deployment_type']} system.",
        deployed_service)
Code Example #10
File: data_access_app.py Project: zark7777/FATE
def download_upload(access_module):
    job_id = job_utils.generate_job_id()
    if access_module == "upload" and UPLOAD_DATA_FROM_CLIENT and not (request.json and request.json.get("use_local_data") == 0):
        file = request.files['file']
        filename = os.path.join(job_utils.get_job_directory(job_id), 'fate_upload_tmp', file.filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        try:
            file.save(filename)
        except Exception as e:
            shutil.rmtree(os.path.join(job_utils.get_job_directory(job_id), 'fate_upload_tmp'))
            raise e
        job_config = request.args.to_dict()
        if "namespace" in job_config and "table_name" in job_config:
            pass
        else:
            # versions above 1.5.1 support eggroll run parameters, so the
            # whole config may arrive as a single JSON-encoded query key
            job_config = json_loads(list(job_config.keys())[0])
        job_config['file'] = filename
    else:
        job_config = request.json
    required_arguments = ['work_mode', 'namespace', 'table_name']
    if access_module == 'upload':
        required_arguments.extend(['file', 'head', 'partition'])
    elif access_module == 'download':
        required_arguments.extend(['output_path'])
    else:
        raise Exception('unsupported access module: {}'.format(access_module))
    detect_utils.check_config(job_config, required_arguments=required_arguments)
    data = {}
    # compatibility
    if "table_name" in job_config:
        job_config["name"] = job_config["table_name"]
    if "backend" not in job_config:
        job_config["backend"] = 0
    for _ in ["work_mode", "backend", "head", "partition", "drop"]:
        if _ in job_config:
            job_config[_] = int(job_config[_])
    if access_module == "upload":
        if job_config.get('drop', 0) == 1:
            job_config["destroy"] = True
        else:
            job_config["destroy"] = False
        data['table_name'] = job_config["table_name"]
        data['namespace'] = job_config["namespace"]
        data_table_meta = storage.StorageTableMeta(name=job_config["table_name"], namespace=job_config["namespace"])
        if data_table_meta and not job_config["destroy"]:
            return get_json_result(retcode=100,
                                   retmsg='The data table already exists. '
                                          'If you still want to continue uploading, please add the parameter -drop: '
                                          '0 means keep the table and continue uploading, '
                                          '1 means delete the table and upload again')
    job_dsl, job_runtime_conf = gen_data_access_job_config(job_config, access_module)
    submit_result = DAGScheduler.submit({'job_dsl': job_dsl, 'job_runtime_conf': job_runtime_conf}, job_id=job_id)
    data.update(submit_result)
    return get_json_result(job_id=job_id, data=data)
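
A hypothetical upload job_config matching the required_arguments checked above; all values are invented for illustration.

job_config = {
    "work_mode": 1,          # coerced to int by the loop above
    "backend": 0,            # defaulted to 0 when missing
    "namespace": "experiment",
    "table_name": "breast_hetero_guest",
    "file": "/data/projects/fate/examples/data/breast_hetero_guest.csv",
    "head": 1,               # first row is a header
    "partition": 8,
    "drop": 1,               # translated to destroy=True
}
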
Code Example #11
File: model_app.py Project: zeta1999/FATE
def get_predict_conf():
    request_data = request.json
    required_parameters = ['model_id', 'model_version']
    check_config(request_data, required_parameters)
    model_dir = os.path.join(get_project_base_directory(), 'model_local_cache')
    model_fp_list = glob.glob(
        model_dir +
        f"/guest#*#{request_data['model_id']}/{request_data['model_version']}")
    if model_fp_list:
        fp = model_fp_list[0]
        pipeline_model = PipelinedModel(model_id=fp.split('/')[-2],
                                        model_version=fp.split('/')[-1])
        pipeline = pipeline_model.read_component_model('pipeline',
                                                       'pipeline')['Pipeline']
        predict_dsl = json_loads(pipeline.inference_dsl)

        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        parser = schedule_utils.get_dsl_parser_by_version(
            train_runtime_conf.get('dsl_version', '1'))
        predict_conf = parser.generate_predict_conf_template(
            predict_dsl=predict_dsl,
            train_conf=train_runtime_conf,
            model_id=request_data['model_id'],
            model_version=request_data['model_version'])
    else:
        predict_conf = ''
    if predict_conf:
        if request_data.get("filename"):
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY,
                                         request_data.get("filename"))
            with open(temp_filepath, "w") as fout:
                fout.write(json_dumps(predict_conf, indent=4))
            return send_file(open(temp_filepath, "rb"),
                             as_attachment=True,
                             attachment_filename=request_data.get("filename"))
        else:
            return get_json_result(data=predict_conf)
    return error_response(
        210,
        "No model found, please check if arguments are specified correctly.")
Code Example #12
def gather_model_info_data(model: PipelinedModel, query_filters=None):
    if model.exists():
        pipeline = model.read_pipeline_model()
        model_info = OrderedDict()
        if query_filters and isinstance(query_filters, list):
            for attr, field in pipeline.ListFields():
                if attr.name in query_filters:
                    if isinstance(field, bytes):
                        model_info["f_" + attr.name] = json_loads(
                            field, OrderedDict)
                    else:
                        model_info["f_" + attr.name] = field
        else:
            for attr, field in pipeline.ListFields():
                if isinstance(field, bytes):
                    model_info["f_" + attr.name] = json_loads(
                        field, OrderedDict)
                else:
                    model_info["f_" + attr.name] = field
        return model_info
    return []
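
The two branches above duplicate the bytes-decoding logic and differ only in the query_filters check; a behavior-preserving, deduplicated sketch could read:

def gather_model_info_data(model: PipelinedModel, query_filters=None):
    if not model.exists():
        return []
    use_filters = bool(query_filters) and isinstance(query_filters, list)
    model_info = OrderedDict()
    for attr, field in model.read_pipeline_model().ListFields():
        if use_filters and attr.name not in query_filters:
            continue
        # bytes fields hold serialized JSON; decode them, keep the rest as-is
        model_info["f_" + attr.name] = (json_loads(field, OrderedDict)
                                        if isinstance(field, bytes) else field)
    return model_info
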
Code Example #13
def federation_cleanup(job, task):
    from fate_arch.common import Backend
    from fate_arch.common import Party
    # StoreEngine is used below but was not imported in the original snippet;
    # assuming it is exported from fate_arch.common alongside Backend
    from fate_arch.common import StoreEngine

    runtime_conf = json_loads(job.f_runtime_conf_on_party)
    job_parameters = runtime_conf['job_parameters']
    backend = Backend(job_parameters.get('backend', 0))
    store_engine = StoreEngine(job_parameters.get('store_engine', 0))

    if backend.is_spark() and store_engine.is_hdfs():
        runtime_conf['local'] = {'role': job.f_role, 'party_id': job.f_party_id}
        parties = [Party(k, p) for k, v in runtime_conf['role'].items() for p in v]
        from fate_arch.session.spark import Session
        ssn = Session(session_id=task.f_task_id)
        ssn.init_federation(federation_session_id=task.f_task_id, runtime_conf=runtime_conf)
        ssn._get_federation().generate_mq_names(parties=parties)
        ssn._get_federation().cleanup()
Code Example #14
def check_if_deployed(role, party_id, model_id, model_version):
    party_model_id = gen_party_model_id(model_id=model_id,
                                        role=role,
                                        party_id=party_id)
    pipeline_model = PipelinedModel(model_id=party_model_id,
                                    model_version=model_version)
    if not pipeline_model.exists():
        raise Exception(
            f"Model {party_model_id} {model_version} does not exist in the local model cache."
        )

    pipeline = pipeline_model.read_pipeline_model()
    if compare_version(pipeline.fate_version, '1.5.0') == 'gt':
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        if str(train_runtime_conf.get('dsl_version', '1')) != '1':
            if pipeline.parent:
                return False
    return True
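
A hypothetical caller guarding a predict job with the check above; the ids are made up for the example.

if not check_if_deployed(role="guest", party_id=9999,
                         model_id="guest-9999#host-10000#model",
                         model_version="202201010000000000"):
    raise Exception("deploy the model before submitting a predict job")
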
Code Example #15
def remote_api(job_id,
               method,
               endpoint,
               src_party_id,
               dest_party_id,
               src_role,
               json_body,
               api_version=API_VERSION,
               overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT,
               try_times=3):
    endpoint = f"/{api_version}{endpoint}"
    json_body['src_role'] = src_role
    if CHECK_NODES_IDENTITY:
        get_node_identity(json_body, src_party_id)
    _packet = wrap_grpc_packet(json_body,
                               method,
                               endpoint,
                               src_party_id,
                               dest_party_id,
                               job_id,
                               overall_timeout=overall_timeout)
    _routing_metadata = get_routing_metadata(src_party_id=src_party_id,
                                             dest_party_id=dest_party_id)
    exception = None
    for t in range(try_times):
        try:
            channel, stub = get_command_federation_channel()
            _return, _call = stub.unaryCall.with_call(
                _packet,
                metadata=_routing_metadata,
                timeout=(overall_timeout / 1000))
            audit_logger(job_id).info("grpc api response: {}".format(_return))
            channel.close()
            response = json_loads(_return.body.value)
            return response
        except Exception as e:
            exception = e
    else:
        tips = ''
        if 'Error received from peer' in str(exception):
            tips = 'Please check if the fate flow server of the other party is started. '
        if 'failed to connect to all addresses' in str(exception):
            tips = 'Please check whether the rollsite service(port: 9370) is started. '
        raise Exception('{}rpc request error: {}'.format(tips, exception))
Code Example #16
    def read(self, parse_models: bool = True, include_database: bool = False):
        data = self.read_database()

        with self.lock:
            for model_name, model in data['models'].items():
                model['filepath_pb'] = self.directory / f'{model_name}.pb'
                model['filepath_json'] = self.directory / f'{model_name}.json'
                if not model['filepath_pb'].exists() \
                        or not model['filepath_json'].exists():
                    raise FileNotFoundError(
                        'Checkpoint is incorrect: protobuf file or json file not found. '
                        f'protobuf filepath: {model["filepath_pb"]} json filepath: {model["filepath_json"]}'
                    )

            model_data = {
                model_name: (
                    model['filepath_pb'].read_bytes(),
                    json_loads(model['filepath_json'].read_text('utf8')),
                )
                for model_name, model in data['models'].items()
            }

        for model_name, model in data['models'].items():
            serialized_string, json_format_dict = model_data[model_name]

            sha1 = hashlib.sha1(serialized_string).hexdigest()
            if sha1 != model['sha1']:
                raise ValueError(
                    'Checkpoint may be incorrect: hash does not match. '
                    f'filepath: {model["filepath_pb"]} expected: {model["sha1"]} actual: {sha1}'
                )

        data['models'] = {
            model_name: (
                model['buffer_name'],
                *model_data[model_name],
            ) if parse_models else b64encode(
                model_data[model_name][0]).decode('ascii')
            for model_name, model in data['models'].items()
        }
        return data if include_database else data['models']
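
The integrity check above relies on a SHA-1 digest of the protobuf file recorded at save time. Below is a sketch of the matching write-side hashing, assuming the database stores the hex digest under the 'sha1' key.

import hashlib
from pathlib import Path

def checkpoint_sha1(filepath_pb: Path) -> str:
    # hex digest that read() compares against model['sha1']
    return hashlib.sha1(filepath_pb.read_bytes()).hexdigest()
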
Code Example #17
File: grpc_utils.py Project: zeta1999/FATE
    def unaryCall(self, _request, context):
        packet = _request
        header = packet.header
        _suffix = packet.body.key
        param_bytes = packet.body.value
        param = bytes.decode(param_bytes)
        job_id = header.task.taskId
        src = header.src
        dst = header.dst
        method = header.operator
        param_dict = json_loads(param)
        param_dict['src_party_id'] = str(src.partyId)
        source_routing_header = []
        for key, value in context.invocation_metadata():
            source_routing_header.append((key, value))

        _routing_metadata = gen_routing_metadata(src_party_id=src.partyId,
                                                 dest_party_id=dst.partyId)
        context.set_trailing_metadata(trailing_metadata=_routing_metadata)
        try:
            nodes_check(param_dict.get('src_party_id'),
                        param_dict.get('_src_role'), param_dict.get('appKey'),
                        param_dict.get('appSecret'), str(dst.partyId))
        except Exception as e:
            resp_json = {"retcode": 100, "retmsg": str(e)}
            return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                    src.partyId, job_id)
        param = bytes.decode(bytes(json_dumps(param_dict), 'utf-8'))

        action = getattr(requests, method.lower(), None)
        audit_logger(job_id).info('rpc receive: {}'.format(packet))
        if action:
            audit_logger(job_id).info("rpc receive: {} {}".format(
                get_url(_suffix), param))
            resp = action(url=get_url(_suffix), data=param, headers=HEADERS)
        else:
            # an unknown HTTP verb would otherwise fall through and crash on
            # the unbound `resp` below; fail fast with an explicit error
            raise ValueError(f"unsupported HTTP method: {method}")
        resp_json = resp.json()
        return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId,
                                src.partyId, job_id)
Code Example #18
    def _read_component_model(self, component_name, model_alias):
        component_model_storage_path = os.path.join(self.variables_data_path,
                                                    component_name,
                                                    model_alias)
        model_proto_index = self.get_model_proto_index(
            component_name=component_name, model_alias=model_alias)

        model_buffers = {}
        for model_name, buffer_name in model_proto_index.items():
            storage_path = os.path.join(component_model_storage_path,
                                        model_name)

            with open(storage_path, "rb") as f:
                buffer_object_serialized_string = f.read()

            try:
                with open(f"{storage_path}.json", encoding="utf8") as f:
                    buffer_object_json_format = base_utils.json_loads(f.read())
            except FileNotFoundError:
                buffer_object_json_format = ""
                # todo: should be running in worker
                """
                buffer_object_json_format = json_format.MessageToDict(
                    parse_proto_object(buffer_name, buffer_object_serialized_string),
                    including_default_value_fields=True
                )
                with self.lock, open(f"{storage_path}.json", "w", encoding="utf8") as f:
                    f.write(base_utils.json_dumps(buffer_object_json_format))
                """

            model_buffers[model_name] = (
                buffer_name,
                buffer_object_serialized_string,
                buffer_object_json_format,
            )

        return model_buffers
Code Example #19
File: base_model.py Project: yubo1993/FATE
def python_value(self, value):
    if value is None:
        value = "[]"
    return json_loads(value, object_hook=self._object_hook,
                      object_pairs_hook=self._object_pairs_hook)
Code Example #20
File: dag_scheduler.py Project: tarada/FATE
    def submit(cls, job_data, job_id=None):
        if not job_id:
            job_id = job_utils.generate_job_id()
        schedule_logger(job_id).info('submit job, job_id {}, body {}'.format(job_id, job_data))
        job_dsl = job_data.get('job_dsl', {})
        job_runtime_conf = job_data.get('job_runtime_conf', {})
        job_initiator = job_runtime_conf['initiator']
        job_parameters = RunParameters(**job_runtime_conf['job_parameters'])
        cls.backend_compatibility(job_parameters=job_parameters)

        job_utils.check_job_runtime_conf(job_runtime_conf)
        if job_parameters.job_type != 'predict':
            # generate job model info
            job_parameters.model_id = model_utils.gen_model_id(job_runtime_conf['role'])
            job_parameters.model_version = job_id
            train_runtime_conf = {}
        else:
            detect_utils.check_config(job_parameters.to_dict(), ['model_id', 'model_version'])
            # get inference dsl from pipeline model as job dsl
            tracker = Tracker(job_id=job_id, role=job_initiator['role'], party_id=job_initiator['party_id'],
                              model_id=job_parameters.model_id, model_version=job_parameters.model_version)
            pipeline_model = tracker.get_output_model('pipeline')
            if not job_dsl:
                job_dsl = json_loads(pipeline_model['Pipeline'].inference_dsl)
            train_runtime_conf = json_loads(pipeline_model['Pipeline'].train_runtime_conf)

        path_dict = job_utils.save_job_conf(job_id=job_id,
                                            job_dsl=job_dsl,
                                            job_runtime_conf=job_runtime_conf,
                                            train_runtime_conf=train_runtime_conf,
                                            pipeline_dsl=None)

        job = Job()
        job.f_job_id = job_id
        job.f_dsl = job_dsl
        job_runtime_conf["job_parameters"] = job_parameters.to_dict()
        job.f_runtime_conf = job_runtime_conf
        job.f_train_runtime_conf = train_runtime_conf
        job.f_roles = job_runtime_conf['role']
        job.f_work_mode = job_parameters.work_mode
        job.f_initiator_role = job_initiator['role']
        job.f_initiator_party_id = job_initiator['party_id']

        initiator_role = job_initiator['role']
        initiator_party_id = job_initiator['party_id']
        if initiator_party_id not in job_runtime_conf['role'][initiator_role]:
            schedule_logger(job_id).info("initiator party id error:{}".format(initiator_party_id))
            raise Exception("initiator party id error {}".format(initiator_party_id))

        dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_dsl,
                                                       runtime_conf=job_runtime_conf,
                                                       train_runtime_conf=train_runtime_conf)

        cls.adapt_job_parameters(job_parameters=job_parameters)

        # update runtime conf
        job_runtime_conf["job_parameters"] = job_parameters.to_dict()
        job.f_runtime_conf = job_runtime_conf

        status_code, response = FederatedScheduler.create_job(job=job)
        if status_code != FederatedSchedulingStatusCode.SUCCESS:
            raise Exception("create job failed: {}".format(response))

        if job_parameters.work_mode == WorkMode.CLUSTER:
            # Save the status information of all participants in the initiator for scheduling
            for role, party_ids in job_runtime_conf["role"].items():
                for party_id in party_ids:
                    if role == job_initiator['role'] and party_id == job_initiator['party_id']:
                        continue
                    JobController.initialize_tasks(job_id, role, party_id, False, job_initiator, job_parameters, dsl_parser)

        # push into queue
        try:
            JobQueue.create_event(job_id=job_id, initiator_role=initiator_role, initiator_party_id=initiator_party_id)
        except Exception as e:
            raise Exception(f'push job into queue failed:\n{e}')

        schedule_logger(job_id).info(
            'submit job successfully, job id is {}, model id is {}'.format(job.f_job_id, job_parameters.model_id))
        board_url = "http://{}:{}{}".format(
            ServiceUtils.get_item("fateboard", "host"),
            ServiceUtils.get_item("fateboard", "port"),
            FATE_BOARD_DASHBOARD_ENDPOINT).format(job_id, job_initiator['role'], job_initiator['party_id'])
        logs_directory = job_utils.get_job_log_directory(job_id)
        return job_id, path_dict['job_dsl_path'], path_dict['job_runtime_conf_path'], logs_directory, \
               {'model_id': job_parameters.model_id, 'model_version': job_parameters.model_version}, board_url
Code Example #21
File: model_app.py Project: zeta1999/FATE
def operate_model(model_operation):
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [
            ModelOperation.STORE, ModelOperation.RESTORE,
            ModelOperation.EXPORT, ModelOperation.IMPORT
    ]:
        raise Exception(
            'Unsupported model operation: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(
        model_id=request_config["model_id"],
        role=request_config["role"],
        party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                file = request.files.get('file')
                file_path = os.path.join(TEMP_DIRECTORY, file.filename)
                # if not os.path.exists(file_path):
                #     raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                #                     'please check the path: {}'.format(file_path))
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    file.save(file_path)
                except Exception as e:
                    shutil.rmtree(file_path)
                    raise e
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                model.unpack_model(file_path)

                pipeline = model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                permitted_party_id = []
                for key, value in train_runtime_conf.get('role', {}).items():
                    for v in value:
                        permitted_party_id.extend([v, str(v)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception(
                        "party id {} is not in model roles, please check if the party id is valid."
                    )
                try:
                    adapter = JobRuntimeConfigAdapter(train_runtime_conf)
                    job_parameters = adapter.get_common_parameters().to_dict()
                    with DB.connection_context():
                        db_model = MLModel.get_or_none(
                            MLModel.f_job_id == job_parameters.get(
                                "model_version"),
                            MLModel.f_role == request_config["role"])
                    if not db_model:
                        model_info = model_utils.gather_model_info_data(model)
                        model_info['imported'] = 1
                        model_info['job_id'] = model_info['f_model_version']
                        model_info['size'] = model.calculate_model_file_size()
                        model_info['role'] = request_config["model_id"].split(
                            '#')[0]
                        model_info['party_id'] = request_config[
                            "model_id"].split('#')[1]
                        if model_utils.compare_version(
                                model_info['f_fate_version'], '1.5.1') == 'lt':
                            model_info['roles'] = model_info.get(
                                'f_train_runtime_conf', {}).get('role', {})
                            model_info['initiator_role'] = model_info.get(
                                'f_train_runtime_conf',
                                {}).get('initiator', {}).get('role')
                            model_info['initiator_party_id'] = model_info.get(
                                'f_train_runtime_conf',
                                {}).get('initiator', {}).get('party_id')
                            model_info[
                                'work_mode'] = adapter.get_job_work_mode()
                            model_info['parent'] = False if model_info.get(
                                'f_inference_dsl') else True
                        model_utils.save_model_info(model_info)
                    else:
                        stat_logger.info(
                            f'job id: {job_parameters.get("model_version")}, '
                            f'role: {request_config["role"]} model info already existed in database.'
                        )
                except peewee.IntegrityError as e:
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(
                                         archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(
                        response_code=210,
                        retmsg="Model {} {} is not exist.".format(
                            request_config.get("model_id"),
                            request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(
            request_config, model_operation)
        submit_result = DAGScheduler.submit(
            {
                'job_dsl': job_dsl,
                'job_runtime_conf': job_runtime_conf
            },
            job_id=job_id)
        data.update(submit_result)
        operation_record(data=job_runtime_conf,
                         oper_type=model_operation,
                         oper_status='')
        return get_json_result(job_id=job_id, data=data)
Code Example #22
File: migrate_model.py Project: zhlj98/FATE
def migration(config_data: dict):
    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=config_data["model_id"],
            role=config_data["local"]["role"],
            party_id=config_data["local"]["party_id"])
        model = pipelined_model.PipelinedModel(
            model_id=party_model_id,
            model_version=config_data["model_version"])
        if not model.exists():
            raise Exception("Can not found {} {} model local cache".format(
                config_data["model_id"], config_data["model_version"]))
        with DB.connection_context():
            if MLModel.get_or_none(MLModel.f_model_version ==
                                   config_data["unify_model_version"]):
                raise Exception(
                    "Unify model version {} has been occupied in database. "
                    "Please choose another unify model version and try again.".
                    format(config_data["unify_model_version"]))

        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        migrate_model = pipelined_model.PipelinedModel(
            model_id=model_utils.gen_party_model_id(
                model_id=model_utils.gen_model_id(config_data["migrate_role"]),
                role=config_data["local"]["role"],
                party_id=config_data["local"]["migrate_party_id"]),
            model_version=config_data["unify_model_version"])

        # migrate_model.create_pipelined_model()
        shutil.copytree(src=model.model_path, dst=migrate_model.model_path)

        pipeline = migrate_model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']

        # Use the collected pipeline model data and modify the related inner information of the model
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        train_runtime_conf["role"] = config_data["migrate_role"]
        train_runtime_conf["initiator"] = config_data["migrate_initiator"]

        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_id=model_utils.gen_model_id(train_runtime_conf["role"]),
            model_version=migrate_model.model_version)

        # update pipeline.pb file
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)
        pipeline.model_id = bytes(
            adapter.get_common_parameters().to_dict().get("model_id"), "utf-8")
        pipeline.model_version = bytes(
            adapter.get_common_parameters().to_dict().get("model_version"),
            "utf-8")

        # save updated pipeline.pb file
        migrate_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(migrate_model.model_path, "pipeline.pb"),
            os.path.join(migrate_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        # modify proto
        with open(
                os.path.join(migrate_model.model_path, 'define',
                             'define_meta.yaml'), 'r') as fin:
            define_yaml = yaml.safe_load(fin)

        for key, value in define_yaml['model_proto'].items():
            if key == 'pipeline':
                continue
            for v in value.keys():
                buffer_obj = migrate_model.read_component_model(key, v)
                module_name = define_yaml['component_define'].get(
                    key, {}).get('module_name')
                modified_buffer = model_migration(
                    model_contents=buffer_obj,
                    module_name=module_name,
                    old_guest_list=config_data['role']['guest'],
                    new_guest_list=config_data['migrate_role']['guest'],
                    old_host_list=config_data['role']['host'],
                    new_host_list=config_data['migrate_role']['host'],
                    old_arbiter_list=config_data.get('role',
                                                     {}).get('arbiter', None),
                    new_arbiter_list=config_data.get('migrate_role',
                                                     {}).get('arbiter', None))
                migrate_model.save_component_model(
                    component_name=key,
                    component_module_name=module_name,
                    model_alias=v,
                    model_buffers=modified_buffer)

        archive_path = migrate_model.packaging_model()
        shutil.rmtree(os.path.abspath(migrate_model.model_path))

        return (0, f"Migrating model successfully. " \
                  "The configuration of model has been modified automatically. " \
                  "New model id is: {}, model version is: {}. " \
                  "Model files can be found at '{}'.".format(adapter.get_common_parameters()["model_id"],
                                                             migrate_model.model_version,
                                                             os.path.abspath(archive_path)),
                {"model_id": migrate_model.model_id,
                 "model_version": migrate_model.model_version,
                 "path": os.path.abspath(archive_path)})

    except Exception as e:
        return 100, str(e), {}
Code Example #23
    def submit(cls, submit_job_conf: JobConfigurationBase, job_id: str = None):
        if not job_id:
            job_id = job_utils.generate_job_id()
        submit_result = {"job_id": job_id}
        schedule_logger(job_id).info(
            f"submit job, body {submit_job_conf.to_dict()}")
        try:
            dsl = submit_job_conf.dsl
            runtime_conf = deepcopy(submit_job_conf.runtime_conf)
            job_utils.check_job_runtime_conf(runtime_conf)
            authentication_utils.check_constraint(runtime_conf, dsl)
            job_initiator = runtime_conf["initiator"]
            conf_adapter = JobRuntimeConfigAdapter(runtime_conf)
            common_job_parameters = conf_adapter.get_common_parameters()

            if common_job_parameters.job_type != "predict":
                # generate job model info
                conf_version = schedule_utils.get_conf_version(runtime_conf)
                if conf_version != 2:
                    raise Exception(
                        "only the v2 version runtime conf is supported")
                common_job_parameters.model_id = model_utils.gen_model_id(
                    runtime_conf["role"])
                common_job_parameters.model_version = job_id
                train_runtime_conf = {}
            else:
                # check predict job parameters
                detect_utils.check_config(common_job_parameters.to_dict(),
                                          ["model_id", "model_version"])
                # get inference dsl from pipeline model as job dsl
                tracker = Tracker(
                    job_id=job_id,
                    role=job_initiator["role"],
                    party_id=job_initiator["party_id"],
                    model_id=common_job_parameters.model_id,
                    model_version=common_job_parameters.model_version)
                pipeline_model = tracker.get_pipeline_model()
                train_runtime_conf = json_loads(
                    pipeline_model.train_runtime_conf)
                if not model_utils.check_if_deployed(
                        role=job_initiator["role"],
                        party_id=job_initiator["party_id"],
                        model_id=common_job_parameters.model_id,
                        model_version=common_job_parameters.model_version):
                    raise Exception(
                        f"Model {common_job_parameters.model_id} {common_job_parameters.model_version} has not been deployed yet."
                    )
                dsl = json_loads(pipeline_model.inference_dsl)
            # dsl = ProviderManager.fill_fate_flow_provider(dsl)

            job = Job()
            job.f_job_id = job_id
            job.f_dsl = dsl
            job.f_train_runtime_conf = train_runtime_conf
            job.f_roles = runtime_conf["role"]
            job.f_initiator_role = job_initiator["role"]
            job.f_initiator_party_id = job_initiator["party_id"]
            job.f_role = job_initiator["role"]
            job.f_party_id = job_initiator["party_id"]

            path_dict = job_utils.save_job_conf(
                job_id=job_id,
                role=job.f_initiator_role,
                party_id=job.f_initiator_party_id,
                dsl=dsl,
                runtime_conf=runtime_conf,
                runtime_conf_on_party={},
                train_runtime_conf=train_runtime_conf,
                pipeline_dsl=None)

            if job.f_initiator_party_id not in runtime_conf["role"][
                    job.f_initiator_role]:
                msg = f"initiator party id {job.f_initiator_party_id} not in roles {runtime_conf['role']}"
                schedule_logger(job_id).info(msg)
                raise Exception(msg)

            # create common parameters on initiator
            JobController.create_common_job_parameters(
                job_id=job.f_job_id,
                initiator_role=job.f_initiator_role,
                common_job_parameters=common_job_parameters)
            job.f_runtime_conf = conf_adapter.update_common_parameters(
                common_parameters=common_job_parameters)
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job.f_dsl,
                runtime_conf=job.f_runtime_conf,
                train_runtime_conf=job.f_train_runtime_conf)

            # initiator runtime conf as template
            job.f_runtime_conf_on_party = job.f_runtime_conf.copy()
            job.f_runtime_conf_on_party[
                "job_parameters"] = common_job_parameters.to_dict()

            # inherit job
            job.f_inheritance_info = common_job_parameters.inheritance_info
            job.f_inheritance_status = JobInheritanceStatus.WAITING if common_job_parameters.inheritance_info else JobInheritanceStatus.PASS
            if job.f_inheritance_info:
                inheritance_jobs = JobSaver.query_job(
                    job_id=job.f_inheritance_info.get("job_id"),
                    role=job_initiator["role"],
                    party_id=job_initiator["party_id"])
                inheritance_tasks = JobSaver.query_task(
                    job_id=job.f_inheritance_info.get("job_id"),
                    role=job_initiator["role"],
                    party_id=job_initiator["party_id"],
                    only_latest=True)
                job_utils.check_job_inheritance_parameters(
                    job, inheritance_jobs, inheritance_tasks)

            status_code, response = FederatedScheduler.create_job(job=job)
            if status_code != FederatedSchedulingStatusCode.SUCCESS:
                job.f_status = JobStatus.FAILED
                job.f_tag = "submit_failed"
                FederatedScheduler.sync_job_status(job=job)
                raise Exception("create job failed", response)
            else:
                need_run_components = {}
                for role in response:
                    need_run_components[role] = {}
                    for party, res in response[role].items():
                        need_run_components[role][party] = [
                            name for name, value in
                            res["data"]["components"].items()
                            if value["need_run"] is True
                        ]
                if common_job_parameters.federated_mode == FederatedMode.MULTIPLE:
                    # create the task holder in db to record information of all participants in the initiator for scheduling
                    for role, party_ids in job.f_roles.items():
                        for party_id in party_ids:
                            if role == job.f_initiator_role and party_id == job.f_initiator_party_id:
                                continue
                            if not need_run_components[role][party_id]:
                                continue
                            JobController.initialize_tasks(
                                job_id=job_id,
                                role=role,
                                party_id=party_id,
                                run_on_this_party=False,
                                initiator_role=job.f_initiator_role,
                                initiator_party_id=job.f_initiator_party_id,
                                job_parameters=common_job_parameters,
                                dsl_parser=dsl_parser,
                                components=need_run_components[role][party_id])
                job.f_status = JobStatus.WAITING
                status_code, response = FederatedScheduler.sync_job_status(
                    job=job)
                if status_code != FederatedSchedulingStatusCode.SUCCESS:
                    raise Exception("set job to waiting status failed")

            schedule_logger(job_id).info(
                f"submit job successfully, job id is {job.f_job_id}, model id is {common_job_parameters.model_id}"
            )
            logs_directory = job_utils.get_job_log_directory(job_id)
            result = {
                "code": RetCode.SUCCESS,
                "message": "success",
                "model_info": {
                    "model_id": common_job_parameters.model_id,
                    "model_version": common_job_parameters.model_version
                },
                "logs_directory": logs_directory,
                "board_url": job_utils.get_board_url(
                    job_id, job_initiator["role"], job_initiator["party_id"])
            }
            warn_parameter = JobRuntimeConfigAdapter(
                submit_job_conf.runtime_conf).check_removed_parameter()
            if warn_parameter:
                result["message"] = f"[WARN] {warn_parameter} is removed, it does not take effect!"
            submit_result.update(result)
            submit_result.update(path_dict)
        except Exception as e:
            submit_result["code"] = RetCode.OPERATING_ERROR
            submit_result["message"] = exception_to_trace_string(e)
            schedule_logger(job_id).exception(e)
        return submit_result
Code Example #24
File: dag_scheduler.py Project: kunchengit/FATE
    def submit(cls, job_data, job_id=None):
        if not job_id:
            job_id = job_utils.generate_job_id()
        schedule_logger(job_id).info('submit job, job_id {}, body {}'.format(
            job_id, job_data))
        job_dsl = job_data.get('job_dsl', {})
        job_runtime_conf = job_data.get('job_runtime_conf', {})
        job_utils.check_job_runtime_conf(job_runtime_conf)
        authentication_utils.check_constraint(job_runtime_conf, job_dsl)

        job_initiator = job_runtime_conf['initiator']
        conf_adapter = JobRuntimeConfigAdapter(job_runtime_conf)
        common_job_parameters = conf_adapter.get_common_parameters()

        if common_job_parameters.job_type != 'predict':
            # generate job model info
            common_job_parameters.model_id = model_utils.gen_model_id(
                job_runtime_conf['role'])
            common_job_parameters.model_version = job_id
            train_runtime_conf = {}
        else:
            # check predict job parameters
            detect_utils.check_config(common_job_parameters.to_dict(),
                                      ['model_id', 'model_version'])
            # get inference dsl from pipeline model as job dsl
            tracker = Tracker(
                job_id=job_id,
                role=job_initiator['role'],
                party_id=job_initiator['party_id'],
                model_id=common_job_parameters.model_id,
                model_version=common_job_parameters.model_version)
            pipeline_model = tracker.get_output_model('pipeline')
            train_runtime_conf = json_loads(
                pipeline_model['Pipeline'].train_runtime_conf)
            if not model_utils.check_if_deployed(
                    role=job_initiator['role'],
                    party_id=job_initiator['party_id'],
                    model_id=common_job_parameters.model_id,
                    model_version=common_job_parameters.model_version):
                raise Exception(
                    f"Model {common_job_parameters.model_id} {common_job_parameters.model_version} has not been deployed yet."
                )
            job_dsl = json_loads(pipeline_model['Pipeline'].inference_dsl)

        job = Job()
        job.f_job_id = job_id
        job.f_dsl = job_dsl
        job.f_train_runtime_conf = train_runtime_conf
        job.f_roles = job_runtime_conf['role']
        job.f_work_mode = common_job_parameters.work_mode
        job.f_initiator_role = job_initiator['role']
        job.f_initiator_party_id = job_initiator['party_id']
        job.f_role = job_initiator['role']
        job.f_party_id = job_initiator['party_id']

        path_dict = job_utils.save_job_conf(
            job_id=job_id,
            role=job.f_initiator_role,
            job_dsl=job_dsl,
            job_runtime_conf=job_runtime_conf,
            job_runtime_conf_on_party={},
            train_runtime_conf=train_runtime_conf,
            pipeline_dsl=None)

        if job.f_initiator_party_id not in job_runtime_conf['role'][
                job.f_initiator_role]:
            schedule_logger(job_id).info(
                "initiator party id error: {}".format(job.f_initiator_party_id))
            raise Exception(
                "initiator party id error: {}".format(job.f_initiator_party_id))

        # create common parameters on initiator
        JobController.backend_compatibility(
            job_parameters=common_job_parameters)
        JobController.adapt_job_parameters(
            role=job.f_initiator_role,
            job_parameters=common_job_parameters,
            create_initiator_baseline=True)

        job.f_runtime_conf = conf_adapter.update_common_parameters(
            common_parameters=common_job_parameters)
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job.f_dsl,
            runtime_conf=job.f_runtime_conf,
            train_runtime_conf=job.f_train_runtime_conf)

        # initiator runtime conf as template
        job.f_runtime_conf_on_party = job.f_runtime_conf.copy()
        job.f_runtime_conf_on_party[
            "job_parameters"] = common_job_parameters.to_dict()

        if common_job_parameters.work_mode == WorkMode.CLUSTER:
            # save the status information of all participants on the initiator side for scheduling
            for role, party_ids in job.f_roles.items():
                for party_id in party_ids:
                    if role == job.f_initiator_role and party_id == job.f_initiator_party_id:
                        continue
                    JobController.initialize_tasks(job_id, role, party_id,
                                                   False, job.f_initiator_role,
                                                   job.f_initiator_party_id,
                                                   common_job_parameters,
                                                   dsl_parser)

        status_code, response = FederatedScheduler.create_job(job=job)
        if status_code != FederatedSchedulingStatusCode.SUCCESS:
            job.f_status = JobStatus.FAILED
            job.f_tag = "submit_failed"
            FederatedScheduler.sync_job_status(job=job)
            raise Exception("create job failed", response)

        schedule_logger(job_id).info(
            'submit job successfully, job id is {}, model id is {}'.format(
                job.f_job_id, common_job_parameters.model_id))
        logs_directory = job_utils.get_job_log_directory(job_id)
        submit_result = {
            "job_id": job_id,
            "model_info": {
                "model_id": common_job_parameters.model_id,
                "model_version": common_job_parameters.model_version
            },
            "logs_directory": logs_directory,
            "board_url": job_utils.get_board_url(job_id, job_initiator['role'],
                                                 job_initiator['party_id'])
        }
        submit_result.update(path_dict)
        return submit_result
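A minimal sketch of how a caller might invoke this classmethod. The DSL and runtime conf below are trimmed placeholders (a real FATE job needs full component definitions, and the exact job_parameters keys vary across versions):

# Hypothetical caller-side sketch; not a runnable FATE job by itself.
job_data = {
    "job_dsl": {"components": {}},  # component graph omitted for brevity
    "job_runtime_conf": {
        "initiator": {"role": "guest", "party_id": 9999},
        "role": {"guest": [9999], "host": [10000], "arbiter": [10000]},
        "job_parameters": {"work_mode": 0},  # assumed keys; 0 standalone, 1 cluster
    },
}
submit_result = DAGScheduler.submit(job_data)
print(submit_result["job_id"], submit_result["model_info"]["model_version"])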
Code Example #25
File: deploy_model.py Project: kunchengit/FATE
def deploy(config_data):
    model_id = config_data.get('model_id')
    model_version = config_data.get('model_version')
    local_role = config_data.get('local').get('role')
    local_party_id = config_data.get('local').get('party_id')
    child_model_version = config_data.get('child_model_version')

    try:
        party_model_id = model_utils.gen_party_model_id(
            model_id=model_id, role=local_role, party_id=local_party_id)
        model = PipelinedModel(model_id=party_model_id,
                               model_version=model_version)
        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        # check whether the model can go through the deploy process (parent/child)
        if not check_before_deploy(model):
            raise Exception('Child model cannot be deployed.')

        # copy proto content from parent model and generate a child model
        deploy_model = PipelinedModel(model_id=party_model_id,
                                      model_version=child_model_version)
        shutil.copytree(src=model.model_path, dst=deploy_model.model_path)
        pipeline = deploy_model.read_component_model('pipeline',
                                                     'pipeline')['Pipeline']

        # modify two pipeline fields (model_version / train_runtime_conf)
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_version=deploy_model.model_version)
        pipeline.model_version = child_model_version
        pipeline.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)

        parser = get_dsl_parser_by_version(
            train_runtime_conf.get('dsl_version', '1'))
        train_dsl = json_loads(pipeline.train_dsl)
        parent_predict_dsl = json_loads(pipeline.inference_dsl)

        if str(train_runtime_conf.get('dsl_version', '1')) == '1':
            predict_dsl = parent_predict_dsl
        elif config_data.get('dsl') or config_data.get('predict_dsl'):
            predict_dsl = config_data.get('dsl') or config_data.get('predict_dsl')
            if not isinstance(predict_dsl, dict):
                predict_dsl = json_loads(predict_dsl)
        else:
            if config_data.get('cpn_list'):
                cpn_list = config_data.pop('cpn_list')
            else:
                cpn_list = list(train_dsl.get('components', {}).keys())
            # dsl_version is not '1' on this branch, so use the parser
            # matching the recorded version directly
            parser = schedule_utils.get_dsl_parser_by_version(
                train_runtime_conf.get('dsl_version', '1'))
            predict_dsl = parser.deploy_component(cpn_list, train_dsl)

        # save predict dsl into child model file
        parser.verify_dsl(predict_dsl, "predict")
        inference_dsl = parser.get_predict_dsl(
            role=local_role,
            predict_dsl=predict_dsl,
            setting_conf_prefix=file_utils.get_federatedml_setting_conf_directory())
        pipeline.inference_dsl = json_dumps(inference_dsl, byte=True)
        if model_utils.compare_version(pipeline.fate_version, '1.5.0') == 'gt':
            pipeline.parent_info = json_dumps(
                {
                    'parent_model_id': model_id,
                    'parent_model_version': model_version
                },
                byte=True)
            pipeline.parent = False
            runtime_conf_on_party = json_loads(pipeline.runtime_conf_on_party)
            runtime_conf_on_party['job_parameters'][
                'model_version'] = child_model_version
            pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party,
                                                        byte=True)

        # save model file
        deploy_model.save_pipeline(pipeline)
        shutil.copyfile(
            os.path.join(deploy_model.model_path, "pipeline.pb"),
            os.path.join(deploy_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        model_info = model_utils.gather_model_info_data(deploy_model)
        model_info['job_id'] = model_info['f_model_version']
        model_info['size'] = deploy_model.calculate_model_file_size()
        model_info['role'] = local_role
        model_info['party_id'] = local_party_id
        model_info['work_mode'] = adapter.get_job_work_mode()
        model_info['parent'] = not model_info.get('f_inference_dsl')
        if model_utils.compare_version(model_info['f_fate_version'],
                                       '1.5.0') == 'eq':
            model_info['roles'] = model_info.get('f_train_runtime_conf',
                                                 {}).get('role', {})
            model_info['initiator_role'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator', {}).get('role')
            model_info['initiator_party_id'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator',
                                                {}).get('party_id')
        model_utils.save_model_info(model_info)

    except Exception as e:
        stat_logger.exception(e)
        return 100, f"deploy model of role {local_role} {local_party_id} failed, details: {str(e)}"
    else:
        return 0, f"deploy model of role {local_role} {local_party_id} success"
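A hedged sketch of the config_data this function expects, inferred from the config_data.get(...) calls above; ids, versions and component names are placeholders:

# Hypothetical deploy config for this function (placeholder values).
config_data = {
    "model_id": "guest-9999#host-10000#model",       # placeholder model id
    "model_version": "202101011200000000001",        # parent (training) job id
    "local": {"role": "guest", "party_id": 9999},
    "child_model_version": "202101021200000000002",  # version for the new child model
    "cpn_list": ["dataio_0", "hetero_lr_0"],         # optional: components to keep
}
retcode, retmsg = deploy(config_data)
assert retcode == 0, retmsg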
Code Example #26
File: deploy_model.py Project: FederatedAI/FATE-Flow
def deploy(config_data):
    model_id = config_data.get('model_id')
    model_version = config_data.get('model_version')
    local_role = config_data.get('local').get('role')
    local_party_id = config_data.get('local').get('party_id')
    child_model_version = config_data.get('child_model_version')
    components_checkpoint = config_data.get('components_checkpoint', {})
    warning_msg = ""

    try:
        party_model_id = gen_party_model_id(model_id=model_id,
                                            role=local_role,
                                            party_id=local_party_id)
        model = PipelinedModel(model_id=party_model_id,
                               model_version=model_version)
        model_data = model.collect_models(in_bytes=True)
        if "pipeline.pipeline:Pipeline" not in model_data:
            raise Exception("Can not found pipeline file in model.")

        # check whether the model can go through the deploy process (parent/child)
        if not check_before_deploy(model):
            raise Exception('Child model cannot be deployed.')

        # copy proto content from parent model and generate a child model
        deploy_model = PipelinedModel(model_id=party_model_id,
                                      model_version=child_model_version)
        shutil.copytree(src=model.model_path,
                        dst=deploy_model.model_path,
                        ignore=lambda src, names: {'checkpoint'}
                        if src == model.model_path else {})
        pipeline_model = deploy_model.read_pipeline_model()

        train_runtime_conf = json_loads(pipeline_model.train_runtime_conf)
        runtime_conf_on_party = json_loads(
            pipeline_model.runtime_conf_on_party)
        dsl_version = train_runtime_conf.get("dsl_version", "1")

        parser = get_dsl_parser_by_version(dsl_version)
        train_dsl = json_loads(pipeline_model.train_dsl)
        parent_predict_dsl = json_loads(pipeline_model.inference_dsl)

        if config_data.get('dsl') or config_data.get('predict_dsl'):
            inference_dsl = config_data.get('dsl') or config_data.get('predict_dsl')
            if not isinstance(inference_dsl, dict):
                inference_dsl = json_loads(inference_dsl)
        else:
            if config_data.get('cpn_list', None):
                cpn_list = config_data.pop('cpn_list')
            else:
                cpn_list = list(train_dsl.get('components', {}).keys())
            if int(dsl_version) == 1:
                # convert v1 dsl to v2 dsl
                inference_dsl, warning_msg = parser.convert_dsl_v1_to_v2(
                    parent_predict_dsl)
            else:
                inference_dsl = parser.deploy_component(cpn_list, train_dsl)

        # convert v1 conf to v2 conf
        if int(dsl_version) == 1:
            components = parser.get_components_light_weight(inference_dsl)

            from fate_flow.db.component_registry import ComponentRegistry
            job_providers = parser.get_job_providers(
                dsl=inference_dsl, provider_detail=ComponentRegistry.REGISTRY)
            cpn_role_parameters = dict()
            for cpn in components:
                cpn_name = cpn.get_name()
                role_params = parser.parse_component_role_parameters(
                    component=cpn_name,
                    dsl=inference_dsl,
                    runtime_conf=train_runtime_conf,
                    provider_detail=ComponentRegistry.REGISTRY,
                    provider_name=job_providers[cpn_name]["provider"]["name"],
                    provider_version=job_providers[cpn_name]["provider"]
                    ["version"])
                cpn_role_parameters[cpn_name] = role_params
            train_runtime_conf = parser.convert_conf_v1_to_v2(
                train_runtime_conf, cpn_role_parameters)

        adapter = JobRuntimeConfigAdapter(train_runtime_conf)
        train_runtime_conf = adapter.update_model_id_version(
            model_version=deploy_model.model_version)
        pipeline_model.model_version = child_model_version
        pipeline_model.train_runtime_conf = json_dumps(train_runtime_conf,
                                                       byte=True)

        # save inference dsl into child model file
        parser = get_dsl_parser_by_version(2)
        parser.verify_dsl(inference_dsl, "predict")
        inference_dsl = JobSaver.fill_job_inference_dsl(
            job_id=model_version,
            role=local_role,
            party_id=local_party_id,
            dsl_parser=parser,
            origin_inference_dsl=inference_dsl)
        pipeline_model.inference_dsl = json_dumps(inference_dsl, byte=True)

        if compare_version(pipeline_model.fate_version, '1.5.0') == 'gt':
            pipeline_model.parent_info = json_dumps(
                {
                    'parent_model_id': model_id,
                    'parent_model_version': model_version
                },
                byte=True)
            pipeline_model.parent = False
            runtime_conf_on_party['job_parameters'][
                'model_version'] = child_model_version
            pipeline_model.runtime_conf_on_party = json_dumps(
                runtime_conf_on_party, byte=True)

        # save model file
        deploy_model.save_pipeline(pipeline_model)
        shutil.copyfile(
            os.path.join(deploy_model.model_path, "pipeline.pb"),
            os.path.join(deploy_model.model_path, "variables", "data",
                         "pipeline", "pipeline", "Pipeline"))

        model_info = gather_model_info_data(deploy_model)
        model_info['job_id'] = model_info['f_model_version']
        model_info['size'] = deploy_model.calculate_model_file_size()
        model_info['role'] = local_role
        model_info['party_id'] = local_party_id
        model_info['parent'] = not model_info.get('f_inference_dsl')
        if compare_version(model_info['f_fate_version'], '1.5.0') == 'eq':
            model_info['roles'] = model_info.get('f_train_runtime_conf',
                                                 {}).get('role', {})
            model_info['initiator_role'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator', {}).get('role')
            model_info['initiator_party_id'] = model_info.get(
                'f_train_runtime_conf', {}).get('initiator',
                                                {}).get('party_id')
        save_model_info(model_info)

        for component_name, component in train_dsl.get('components',
                                                       {}).items():
            step_index = components_checkpoint.get(component_name,
                                                   {}).get('step_index')
            step_name = components_checkpoint.get(component_name,
                                                  {}).get('step_name')
            if step_index is not None:
                # a numeric step_index takes precedence; ignore any step_name
                step_index = int(step_index)
                step_name = None
            elif step_name is None:
                # neither step_index nor step_name given, skip this component
                continue

            checkpoint_manager = CheckpointManager(
                role=local_role,
                party_id=local_party_id,
                model_id=model_id,
                model_version=model_version,
                component_name=component_name,
                mkdir=False,
            )
            checkpoint_manager.load_checkpoints_from_disk()
            if checkpoint_manager.latest_checkpoint is not None:
                checkpoint_manager.deploy(
                    child_model_version,
                    component['output']['model'][0] if component.get(
                        'output', {}).get('model') else 'default',
                    step_index,
                    step_name,
                )
    except Exception as e:
        stat_logger.exception(e)
        return 100, f"deploy model of role {local_role} {local_party_id} failed, details: {str(e)}"
    else:
        msg = f"deploy model of role {local_role} {local_party_id} success"
        if warning_msg:
            msg = msg + f", warning: {warning_msg}"
        return 0, msg
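Compared with the previous variant, this version also migrates per-component checkpoints. A sketch of the extra components_checkpoint key, reusing the config_data sketch from the previous example (component and step names are placeholders); step_index takes precedence over step_name, and components with neither are skipped:

# Hypothetical checkpoint pinning for two components (placeholder names).
config_data["components_checkpoint"] = {
    "hetero_lr_0": {"step_index": 5},                   # index pin wins; step_name ignored
    "hetero_secureboost_0": {"step_name": "round_10"},  # name pin used when no index
}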
Code Example #27
File: model_app.py Project: tarada/FATE
def operate_model(model_operation):
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [
            ModelOperation.STORE, ModelOperation.RESTORE,
            ModelOperation.EXPORT, ModelOperation.IMPORT
    ]:
        raise Exception(
            'Unsupported model operation: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(
        model_id=request_config["model_id"],
        role=request_config["role"],
        party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                file = request.files.get('file')
                file_path = os.path.join(TEMP_DIRECTORY, file.filename)
                # if not os.path.exists(file_path):
                #     raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                #                     'please check the path: {}'.format(file_path))
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    file.save(file_path)
                except Exception as e:
                    # file_path points at a file; os.remove, not rmtree, cleans it up
                    if os.path.exists(file_path):
                        os.remove(file_path)
                    raise e
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                model.unpack_model(file_path)

                pipeline = model.read_component_model('pipeline',
                                                      'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                permitted_party_id = []
                for key, value in train_runtime_conf.get('role', {}).items():
                    for v in value:
                        permitted_party_id.extend([v, str(v)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception(
                        "party id {} is not in model roles, please check if "
                        "the party id is valid.".format(request_config["party_id"]))
                try:
                    with DB.connection_context():
                        model = MLModel.get_or_none(
                            MLModel.f_job_id == train_runtime_conf[
                                "job_parameters"]["model_version"],
                            MLModel.f_role == request_config["role"])
                        if not model:
                            MLModel.create(
                                f_role=request_config["role"],
                                f_party_id=request_config["party_id"],
                                f_roles=train_runtime_conf["role"],
                                f_job_id=train_runtime_conf["job_parameters"]
                                ["model_version"],
                                f_model_id=train_runtime_conf["job_parameters"]
                                ["model_id"],
                                f_model_version=train_runtime_conf[
                                    "job_parameters"]["model_version"],
                                f_initiator_role=train_runtime_conf[
                                    "initiator"]["role"],
                                f_initiator_party_id=train_runtime_conf[
                                    "initiator"]["party_id"],
                                f_runtime_conf=train_runtime_conf,
                                f_work_mode=train_runtime_conf[
                                    "job_parameters"]["work_mode"],
                                f_dsl=json_loads(pipeline.train_dsl),
                                f_imported=1,
                                f_job_status='complete')
                        else:
                            stat_logger.info(
                                f'job id: {train_runtime_conf["job_parameters"]["model_version"]}, '
                                f'role: {request_config["role"]} model info already existed in database.'
                            )
                except peewee.IntegrityError as e:
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(
                    model_id=request_config["model_id"],
                    model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(
                                         archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(
                        response_code=210,
                        retmsg="Model {} {} does not exist.".format(
                            request_config.get("model_id"),
                            request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(
            request_config, model_operation)
        job_id, job_dsl_path, job_runtime_conf_path, logs_directory, model_info, board_url = DAGScheduler.submit(
            {
                'job_dsl': job_dsl,
                'job_runtime_conf': job_runtime_conf
            },
            job_id=job_id)
        data.update({
            'job_dsl_path': job_dsl_path,
            'job_runtime_conf_path': job_runtime_conf_path,
            'board_url': board_url,
            'logs_directory': logs_directory
        })
        operation_record(data=job_runtime_conf,
                         oper_type=model_operation,
                         oper_status='')
        return get_json_result(job_id=job_id, data=data)
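A hedged client-side sketch for the export branch of this handler; the host, port and route prefix are assumptions that depend on how FATE Flow is deployed:

import requests

# Hypothetical endpoint URL; adjust host, port and route prefix to your deployment.
resp = requests.post(
    "http://127.0.0.1:9380/v1/model/export",
    json={
        "model_id": "guest-9999#host-10000#model",   # placeholder values
        "model_version": "202101011200000000001",
        "role": "guest",
        "party_id": "9999",
    },
)
with open("model_export.zip", "wb") as f:
    f.write(resp.content)  # the export branch streams back the packaged archive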