예제 #1
0
def do_load_model():
    """Load a model onto the serving side and report the outcome as JSON.

    The request body is passed through ``adapter_servings_config`` (which
    presumably fills in the serving addresses -- confirm) and then handed to
    ``publish_model.load_model``.

    Two best-effort follow-ups run afterwards. Each is wrapped in its own
    ``try`` so that a failure is only logged and never changes the response:

    * On a successful load (falsy ``retcode``), the matching ``MLModel``
      row has its ``f_loaded_times`` counter incremented.
    * Whatever the load result, the model directory under
      ``model_local_cache`` is copied to ``loaded_model_backup``, unless a
      backup for that model version already exists.

    Returns:
        The response built by ``get_json_result`` from the ``(retcode,
        retmsg)`` pair reported by ``publish_model.load_model``.
    """
    request_data = request.json
    adapter_servings_config(request_data)
    retcode, retmsg = publish_model.load_model(config_data=request_data)
    load_ok = not retcode

    # Bookkeeping: count how many times this model version has been loaded.
    try:
        if load_ok:
            with DB.connection_context():
                local_info = request_data.get("local")
                job_params = request_data.get("job_parameters")
                record = MLModel.get_or_none(
                    MLModel.f_role == local_info.get("role"),
                    MLModel.f_party_id == local_info.get("party_id"),
                    MLModel.f_model_id == job_params.get("model_id"),
                    MLModel.f_model_version == job_params.get("model_version"),
                )
                if record:
                    record.f_loaded_times = record.f_loaded_times + 1
                    record.save()
    except Exception as modify_err:
        stat_logger.exception(modify_err)

    # Keep a one-time snapshot of the model files that were (attempted to be) loaded.
    try:
        local_info = request_data.get("local")
        job_params = request_data.get("job_parameters")
        party_model_id = gen_party_model_id(role=local_info.get("role"),
                                            party_id=local_info.get("party_id"),
                                            model_id=job_params.get("model_id"))
        base_dir = file_utils.get_project_base_directory()
        model_version = job_params.get("model_version")
        src_model_path = os.path.join(base_dir, 'model_local_cache', party_model_id, model_version)
        dst_model_path = os.path.join(base_dir, 'loaded_model_backup', party_model_id, model_version)
        if not os.path.exists(dst_model_path):
            shutil.copytree(src=src_model_path, dst=dst_model_path)
    except Exception as copy_err:
        stat_logger.exception(copy_err)

    outcome = "success" if load_ok else "failed"
    operation_record(request_data, "load", outcome)
    return get_json_result(retcode=retcode, retmsg=retmsg)
예제 #2
0
def do_load_model():
    """Load a deployed model onto the serving side and report the outcome.

    Steps:
      1. Attach the registered ``servings`` URLs to the request body.
      2. If ``enable_model_store`` is set in the base config, synchronise
         the local pipeline model with the model store:
         * upload it when only the local copy exists;
         * download it when only the stored copy exists.
      3. Reject the request with retcode 100 if ``check_if_deployed``
         reports that the model version is not deployed.
      4. Call ``publish_model.load_model``. On success, increment the
         matching ``MLModel`` row's ``f_loaded_times``. This is best effort:
         errors are logged, not raised.
      5. Record the operation and return the JSON result.

    Returns:
        The response built by ``get_json_result``.

    Raises:
        KeyError: if the request body lacks ``local.role``,
            ``local.party_id``, ``job_parameters.model_id`` or
            ``job_parameters.model_version``.
    """
    request_data = request.json
    request_data['servings'] = RuntimeConfig.SERVICE_DB.get_urls('servings')

    role = request_data['local']['role']
    party_id = request_data['local']['party_id']
    model_id = request_data['job_parameters']['model_id']
    model_version = request_data['job_parameters']['model_version']
    party_model_id = model_utils.gen_party_model_id(model_id, role, party_id)

    # NOTE(review): the store sync runs before the deployment check,
    # presumably so check_if_deployed can see a model just restored from
    # the store -- confirm before reordering.
    if get_base_config('enable_model_store', False):
        pipeline_model = pipelined_model.PipelinedModel(party_model_id, model_version)

        component_parameters = {
            'model_id': party_model_id,
            'model_version': model_version,
            'store_address': ServiceRegistry.MODEL_STORE_ADDRESS,
        }
        model_storage = get_model_storage(component_parameters)

        # Probe each side once; only a one-sided copy needs syncing.
        exists_locally = pipeline_model.exists()
        exists_in_store = model_storage.exists(**component_parameters)
        if exists_locally and not exists_in_store:
            stat_logger.info(
                f'Uploading {pipeline_model.model_path} to model storage.')
            model_storage.store(**component_parameters)
        elif not exists_locally and exists_in_store:
            stat_logger.info(
                f'Downloading {pipeline_model.model_path} from model storage.')
            model_storage.restore(**component_parameters)

    if not model_utils.check_if_deployed(role, party_id, model_id, model_version):
        return get_json_result(
            retcode=100,
            retmsg="Only deployed models could be used to execute process of loading. "
                   "Please deploy model before loading.")

    retcode, retmsg = publish_model.load_model(request_data)

    if not retcode:
        # The model is already loaded at this point, so a failed counter
        # update is only logged; it must not turn the response into an error.
        try:
            with DB.connection_context():
                model = MLModel.get_or_none(
                    MLModel.f_role == role,
                    MLModel.f_party_id == party_id,
                    MLModel.f_model_id == model_id,
                    MLModel.f_model_version == model_version,
                )
                if model:
                    model.f_loaded_times += 1
                    model.save()
        except Exception as modify_err:
            stat_logger.exception(modify_err)

    operation_record(request_data, "load",
                     "success" if not retcode else "failed")
    return get_json_result(retcode=retcode, retmsg=retmsg)