def get_predict_dsl():
    request_data = request.json
    request_data['query_filters'] = ['inference_dsl']
    retcode, retmsg, data = model_utils.query_model_info_from_file(**request_data)
    if data:
        for d in data:
            if d.get("f_role") in {"guest", "host"}:
                _data = d
                break
        else:
            return error_response(210, "cannot find guest or host model, please get predict dsl on guest or host.")
        if request_data.get("filename"):
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY, request_data.get("filename"))
            with open(temp_filepath, "w") as fout:
                fout.write(json_dumps(_data['f_inference_dsl'], indent=4))
            return send_file(open(temp_filepath, "rb"), as_attachment=True,
                             attachment_filename=request_data.get("filename"))
        else:
            return get_json_result(data=_data['f_inference_dsl'])
    return error_response(210, "No model found, please check if arguments are specified correctly.")
def component_output_data_download():
    request_data = request.json
    tasks = JobSaver.query_task(only_latest=True,
                                job_id=request_data['job_id'],
                                component_name=request_data['component_name'],
                                role=request_data['role'],
                                party_id=request_data['party_id'])
    if not tasks:
        raise ValueError(f'task not found, please check if the parameters are correct: {request_data}')
    import_component_output_depend(tasks[0].f_provider_info)
    try:
        output_tables_meta = get_component_output_tables_meta(task_data=request_data)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
    limit = request_data.get('limit', -1)
    if not output_tables_meta:
        return error_response(response_code=210, retmsg='no data')
    if limit == 0:
        return error_response(response_code=210, retmsg='limit is 0')
    tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(request_data['job_id'],
                                                                request_data['component_name'],
                                                                request_data['role'],
                                                                request_data['party_id'])
    return TableStorage.send_table(output_tables_meta, tar_file_name, limit=limit,
                                   need_head=request_data.get("head", True))
def validate_component_param():
    if not request.json or not isinstance(request.json, dict):
        return error_response(400, 'bad request')
    required_keys = [
        'component_name',
        'component_module_name',
    ]
    config_keys = ['role']
    dsl_version = int(request.json.get('dsl_version', 0))
    if dsl_version == 1:
        config_keys += ['role_parameters', 'algorithm_parameters']
        parser_class = DSLParser
    elif dsl_version == 2:
        config_keys += ['component_parameters']
        parser_class = DSLParserV2
    else:
        return error_response(400, 'unsupported dsl_version')
    try:
        check_config(request.json, required_keys + config_keys)
    except Exception as e:
        return error_response(400, str(e))
    try:
        parser_class.validate_component_param(
            get_federatedml_setting_conf_directory(),
            {i: request.json[i] for i in config_keys},
            *[request.json[i] for i in required_keys])
    except Exception as e:
        return error_response(400, str(e))
    return get_json_result()
def get_component_summary():
    request_data = request.json
    try:
        required_params = ["job_id", "component_name", "role", "party_id"]
        detect_utils.check_config(request_data, required_params)
        tracker = Tracker(job_id=request_data["job_id"],
                          component_name=request_data["component_name"],
                          role=request_data["role"],
                          party_id=request_data["party_id"],
                          task_id=request_data.get("task_id", None),
                          task_version=request_data.get("task_version", None))
        summary = tracker.read_summary_from_db()
        if summary:
            if request_data.get("filename"):
                temp_filepath = os.path.join(TEMP_DIRECTORY, request_data.get("filename"))
                with open(temp_filepath, "w") as fout:
                    fout.write(json.dumps(summary, indent=4))
                return send_file(open(temp_filepath, "rb"),
                                 as_attachment=True,
                                 attachment_filename=request_data.get("filename"))
            else:
                return get_json_result(data=summary)
        return error_response(210, "No component summary found, please check if arguments are specified correctly.")
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
def get_config():
    kwargs = {}
    job_configuration = None
    for i in ('job_id', 'role', 'party_id'):
        if request.json.get(i) is None:
            return error_response(400, f"'{i}' is required.")
        kwargs[i] = str(request.json[i])
    for i in ('component_name', 'task_id', 'task_version'):
        if request.json.get(i) is None:
            break
        kwargs[i] = str(request.json[i])
    else:
        # for/else: only reached when all three optional keys were provided,
        # in which case try the task-level job conf first
        try:
            job_configuration = job_utils.get_task_using_job_conf(**kwargs)
        except Exception:
            pass
    if job_configuration is None:
        job_configuration = job_utils.get_job_configuration(kwargs['job_id'], kwargs['role'], kwargs['party_id'])
    if job_configuration is None:
        return error_response(404, 'Job not found.')
    return get_json_result(data=job_configuration.to_dict())
def transfer_model():
    party_model_id = request.json.get('namespace')
    model_version = request.json.get('name')
    if not party_model_id or not model_version:
        return error_response(400, 'namespace and name are required')
    model_data = publish_model.download_model(party_model_id, model_version)
    if model_data is None:
        return error_response(404, 'model not found')
    return get_json_result(data=model_data)
def component_output_data_download():
    request_data = request.json
    output_data_table = get_component_output_data_table(task_data=request_data)
    limit = request_data.get('limit', -1)
    if not output_data_table:
        return error_response(response_code=500, retmsg='no data')
    if limit == 0:
        return error_response(response_code=500, retmsg='limit is 0')
    output_data_count = 0
    have_data_label = False
    output_tmp_dir = os.path.join(os.getcwd(), 'tmp/{}'.format(fate_uuid()))
    output_file_path = '{}/output_%s'.format(output_tmp_dir)
    output_data_file_path = output_file_path % 'data.csv'
    os.makedirs(os.path.dirname(output_data_file_path), exist_ok=True)
    with open(output_data_file_path, 'w') as fw:
        for k, v in output_data_table.collect():
            data_line, have_data_label = get_component_output_data_line(src_key=k, src_value=v)
            fw.write('{}\n'.format(','.join(map(lambda x: str(x), data_line))))
            output_data_count += 1
            if output_data_count == limit:
                break
    if output_data_count:
        # get meta
        header = get_component_output_data_meta(output_data_table=output_data_table,
                                                have_data_label=have_data_label)
        output_data_meta_file_path = output_file_path % 'data_meta.json'
        with open(output_data_meta_file_path, 'w') as fw:
            json.dump({'header': header}, fw, indent=4)
        if request_data.get('head', True):
            with open(output_data_file_path, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('{}\n'.format(','.join(header)) + content)
        # tar
        memory_file = io.BytesIO()
        tar = tarfile.open(fileobj=memory_file, mode='w:gz')
        tar.add(output_data_file_path, os.path.relpath(output_data_file_path, output_tmp_dir))
        tar.add(output_data_meta_file_path, os.path.relpath(output_data_meta_file_path, output_tmp_dir))
        tar.close()
        memory_file.seek(0)
        try:
            shutil.rmtree(os.path.dirname(output_data_file_path))
        except Exception as e:
            # warning
            stat_logger.warning(e)
        tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(request_data['job_id'],
                                                                    request_data['component_name'],
                                                                    request_data['role'],
                                                                    request_data['party_id'])
        return send_file(memory_file, attachment_filename=tar_file_name, as_attachment=True)
def get_mysql_info():
    if IS_STANDALONE:
        return error_response(404, 'mysql only available on cluster mode')
    try:
        with DB.connection_context():
            DB.random()
    except Exception as e:
        return error_response(503, str(e))
    return error_response(200)
def list_job():
    limit, offset = parse_limit_and_offset()
    query = {
        'tag': ('!=', 'submit_failed'),
    }
    for i in ('job_id', 'description'):
        if request.json.get(i) is not None:
            query[i] = ('contains', request.json[i])
    if request.json.get('party_id') is not None:
        try:
            query['party_id'] = int(request.json['party_id'])
        except Exception:
            return error_response(400, f"Invalid parameter 'party_id'.")
        query['party_id'] = ('contains', query['party_id'])
    if request.json.get('partner') is not None:
        # match the partner's party id against the serialized roles field
        query['roles'] = ('contains', request.json['partner'])
    for i in ('role', 'status'):
        if request.json.get(i) is None:
            continue
        if isinstance(request.json[i], str):
            request.json[i] = [request.json[i]]
        if not isinstance(request.json[i], list):
            return error_response(400, f"Invalid parameter '{i}'.")
        request.json[i] = set(request.json[i])
        for j in request.json[i]:
            if j not in valid_query_parameters[i]:
                return error_response(400, f"Invalid parameter '{i}'.")
        query[i] = ('in_', request.json[i])
    jobs, count = job_utils.list_job(limit, offset, query, parse_order_by(('create_time', 'desc')))
    jobs = [job.to_human_model_dict() for job in jobs]
    for job in jobs:
        job['party_id'] = int(job['party_id'])
        job['partners'] = set()
        for i in ('guest', 'host', 'arbiter'):
            job['partners'].update(job['roles'].get(i, []))
        job['partners'].discard(job['party_id'])
        job['partners'] = sorted(job['partners'])
    return get_json_result(data={
        'jobs': jobs,
        'count': count,
    })
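# Illustrative only: given a request body like the one below (values are made up,
# and 'guest'/'success' are assumed to pass the valid_query_parameters check),
# the filter-building logic above produces the following query dict.
sample_body = {
    "limit": 20, "page": 1,
    "job_id": "20220101",   # substring match
    "party_id": 9999,       # validated as int, then matched with 'contains'
    "role": "guest",        # a single string is promoted to a list, then a set
    "status": ["success"],
}
expected_query = {
    'tag': ('!=', 'submit_failed'),
    'job_id': ('contains', '20220101'),
    'party_id': ('contains', 9999),
    'role': ('in_', {'guest'}),
    'status': ('in_', {'success'}),
}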
def parse_order_by(default=None):
    order_by = []
    if request.json.get('order_by') is not None:
        if request.json['order_by'] not in valid_query_parameters['order_by']:
            abort(error_response(400, f"Invalid parameter 'order_by'."))
        order_by.append(request.json['order_by'])
    if request.json.get('order') is not None:
        if request.json['order'] not in valid_query_parameters['order']:
            abort(error_response(400, f"Invalid parameter 'order'."))
        order_by.append(request.json['order'])
    return order_by or default
def get_eggroll_info():
    if IS_STANDALONE:
        return error_response(404, 'eggroll only available on cluster mode')
    if PROXY != CoordinationProxyService.ROLLSITE:
        return error_response(404, 'coordination communication protocol is not rollsite')
    conf = ServiceRegistry.FATE_ON_EGGROLL['rollsite']
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        r = s.connect_ex((conf['host'], conf['port']))
        if r != 0:
            return error_response(503)
    return error_response(200)
def dsl_generator():
    data = request.json
    cpn_str = data.get("cpn_str", "")
    try:
        if not cpn_str:
            raise Exception("Component list should not be empty.")
        if isinstance(cpn_str, list):
            cpn_list = cpn_str
        else:
            if "/" in cpn_str or "\\" in cpn_str:
                raise Exception("Component list string should not contain '/' or '\\'.")
            cpn_str = cpn_str.replace(" ", "").replace("\n", "").strip(",[]")
            cpn_list = cpn_str.split(",")
        train_dsl = json_loads(data.get("train_dsl"))
        parser = schedule_utils.get_dsl_parser_by_version(data.get("version", "2"))
        predict_dsl = parser.deploy_component(cpn_list, train_dsl)
        if data.get("filename"):
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY, data.get("filename"))
            with open(temp_filepath, "w") as fout:
                fout.write(json.dumps(predict_dsl, indent=4))
            return send_file(open(temp_filepath, 'rb'), as_attachment=True,
                             attachment_filename=data.get("filename"))
        return get_json_result(data=predict_dsl)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, "DSL generating failed. For more details, "
                                   "please check logs/fate_flow/fate_flow_stat.log.")
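# A minimal sketch of the cpn_str normalization above, using a made-up
# client-supplied value; it is not part of the endpoint, just an illustration
# of how a bracketed, comma-separated string becomes a component list.
cpn_str = "[reader_0, data_transform_0,\n hetero_lr_0]"
cpn_str = cpn_str.replace(" ", "").replace("\n", "").strip(",[]")
cpn_list = cpn_str.split(",")
assert cpn_list == ['reader_0', 'data_transform_0', 'hetero_lr_0']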
def parse_limit_and_offset():
    try:
        limit = int(request.json.get('limit', 0))
        page = int(request.json.get('page', 1)) - 1
    except Exception:
        abort(error_response(400, f"Invalid parameter 'limit' or 'page'."))
    return limit, limit * page
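# A tiny check of the pagination math above, with illustrative values only:
# 'page' is 1-based in the request and converted to a 0-based offset multiplier,
# so {"limit": 20, "page": 3} selects rows 40..59.
limit, page = 20, 3 - 1
assert (limit, limit * page) == (20, 40)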
def get_fateboard_info():
    host = ServiceRegistry.FATEBOARD.get('host')
    port = ServiceRegistry.FATEBOARD.get('port')
    if not host or not port:
        return error_response(404, 'fateboard is not configured')
    return get_json_result(data={
        'host': host,
        'port': port,
    })
def get_fateboard_info():
    fateboard = get_base_config('fateboard', {})
    host = fateboard.get('host')
    port = fateboard.get('port')
    if not host or not port:
        return error_response(404, 'fateboard is not configured')
    return get_json_result(data={
        'host': host,
        'port': port,
    })
def get_checkpoint():
    checkpoint_manager = load_checkpoints()
    if 'step_index' in request.json:
        try:
            request.json['step_index'] = int(request.json['step_index'])
        except Exception:
            return error_response(400, "Invalid 'step_index'")
        checkpoint = checkpoint_manager.get_checkpoint_by_index(request.json['step_index'])
    elif 'step_name' in request.json:
        checkpoint = checkpoint_manager.get_checkpoint_by_name(request.json['step_name'])
    else:
        return error_response(400, "'step_index' or 'step_name' is required")
    if checkpoint is None:
        return error_response(404, "The checkpoint was not found.")
    return get_json_result(data=checkpoint.to_dict(True))
def check_job_log_dir():
    job_id = str(request.json['job_id'])
    job_log_dir = job_utils.get_job_log_directory(job_id=job_id)
    if not os.path.exists(job_log_dir):
        abort(error_response(404, f"Log file path: '{job_log_dir}' not found. "
                                  f"Please check if the job id is valid."))
    return job_id, job_log_dir
def list_task():
    limit, offset = parse_limit_and_offset()
    query = {}
    for i in ('job_id', 'role', 'party_id', 'component_name'):
        if request.json.get(i) is not None:
            query[i] = request.json[i]
    if query.get('role') is not None:
        if query['role'] not in valid_query_parameters['role']:
            return error_response(400, f"Invalid parameter 'role'.")
    if query.get('party_id') is not None:
        try:
            query['party_id'] = int(query['party_id'])
        except Exception:
            return error_response(400, f"Invalid parameter 'party_id'.")
    tasks, count = job_utils.list_task(limit, offset, query, parse_order_by(('create_time', 'asc')))
    return get_json_result(data={
        'tasks': [task.to_human_model_dict() for task in tasks],
        'count': count,
    })
def authentication():
    if request.json and request.form:
        return error_response(400)
    if not (HTTP_APP_KEY and HTTP_SECRET_KEY):
        return
    for i in [
        'TIMESTAMP',
        'NONCE',
        'APP_KEY',
        'SIGNATURE',
    ]:
        if not request.headers.get(i):
            return error_response(401)
    try:
        timestamp = int(request.headers['TIMESTAMP']) / 1000
    except Exception:
        return error_response(400, 'Invalid TIMESTAMP')
    now = time()
    if not now - MAX_TIMESTAMP_INTERVAL < timestamp < now + MAX_TIMESTAMP_INTERVAL:
        return error_response(
            425, f'TIMESTAMP is more than {MAX_TIMESTAMP_INTERVAL} seconds away from the server time')
    if not request.headers['NONCE']:
        return error_response(400, 'Invalid NONCE')
    if request.headers['APP_KEY'] != HTTP_APP_KEY:
        return error_response(401, 'Unknown APP_KEY')
    signature = b64encode(HMAC(HTTP_SECRET_KEY.encode('ascii'), b'\n'.join([
        request.headers['TIMESTAMP'].encode('ascii'),
        request.headers['NONCE'].encode('ascii'),
        request.headers['APP_KEY'].encode('ascii'),
        request.full_path.rstrip('?').encode('ascii'),
        request.data if request.json else b'',
        # quote_via: `urllib.parse.quote` replaces spaces with `%20`
        # safe: unreserved characters from rfc3986
        urlencode(sorted(request.form.items()), quote_via=quote, safe='-._~').encode('ascii')
        if request.form else b'',
    ]), 'sha1').digest()).decode('ascii')
    if signature != request.headers['SIGNATURE']:
        return error_response(403)
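# A client-side sketch of producing headers that the verification above accepts.
# Nothing below is part of FATE Flow itself: the function name, the example path
# and the key values are placeholders; only the signing scheme (newline-joined
# fields, HMAC-SHA1, base64) is taken from the handler code above.
import time
import uuid
from base64 import b64encode
from hmac import HMAC


def example_signed_headers(app_key, secret_key, full_path, json_body: bytes):
    """Build TIMESTAMP/NONCE/APP_KEY/SIGNATURE headers for a JSON request."""
    timestamp = str(round(time.time() * 1000))  # milliseconds, as expected above
    nonce = str(uuid.uuid4())
    signature = b64encode(HMAC(secret_key.encode('ascii'), b'\n'.join([
        timestamp.encode('ascii'),
        nonce.encode('ascii'),
        app_key.encode('ascii'),
        full_path.encode('ascii'),  # request path without a trailing '?'
        json_body,                  # raw body for a JSON request
        b'',                        # empty form component for a JSON request
    ]), 'sha1').digest()).decode('ascii')
    return {
        'TIMESTAMP': timestamp,
        'NONCE': nonce,
        'APP_KEY': app_key,
        'SIGNATURE': signature,
        'Content-Type': 'application/json',
    }

# Usage sketch (path and payload are placeholders):
#   body = json.dumps({"limit": 20}).encode('utf-8')
#   headers = example_signed_headers(app_key, secret_key, '/v1/job/list/job', body)
#   requests.post(base_url + '/v1/job/list/job', data=body, headers=headers)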
def load_checkpoints():
    required_args = ['role', 'party_id', 'model_id', 'model_version', 'component_name']
    try:
        check_config(request.json, required_args)
    except Exception as e:
        abort(error_response(400, str(e)))
    checkpoint_manager = CheckpointManager(**{i: request.json[i] for i in required_args}, mkdir=False)
    checkpoint_manager.load_checkpoints_from_disk()
    return checkpoint_manager
def table_download():
    request_data = request.json
    from fate_flow.component_env_utils.env_utils import import_component_output_depend
    import_component_output_depend()
    data_table_meta = storage.StorageTableMeta(name=request_data.get("name"),
                                               namespace=request_data.get("namespace"))
    if not data_table_meta:
        return error_response(
            response_code=210,
            retmsg=f'table not found: {request_data.get("namespace")}, {request_data.get("name")}')
    tar_file_name = 'table_{}_{}.tar.gz'.format(request_data.get("namespace"), request_data.get("name"))
    return TableStorage.send_table(
        output_tables_meta={"table": data_table_meta},
        tar_file_name=tar_file_name,
        need_head=request_data.get("head", True))
def get_predict_conf():
    request_data = request.json
    required_parameters = ['model_id', 'model_version']
    check_config(request_data, required_parameters)
    model_dir = os.path.join(get_project_base_directory(), 'model_local_cache')
    model_fp_list = glob.glob(
        model_dir + f"/guest#*#{request_data['model_id']}/{request_data['model_version']}")
    if model_fp_list:
        fp = model_fp_list[0]
        pipeline_model = PipelinedModel(model_id=fp.split('/')[-2], model_version=fp.split('/')[-1])
        pipeline = pipeline_model.read_component_model('pipeline', 'pipeline')['Pipeline']
        predict_dsl = json_loads(pipeline.inference_dsl)
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        parser = schedule_utils.get_dsl_parser_by_version(train_runtime_conf.get('dsl_version', '1'))
        predict_conf = parser.generate_predict_conf_template(predict_dsl=predict_dsl,
                                                             train_conf=train_runtime_conf,
                                                             model_id=request_data['model_id'],
                                                             model_version=request_data['model_version'])
    else:
        predict_conf = ''
    if predict_conf:
        if request_data.get("filename"):
            os.makedirs(TEMP_DIRECTORY, exist_ok=True)
            temp_filepath = os.path.join(TEMP_DIRECTORY, request_data.get("filename"))
            with open(temp_filepath, "w") as fout:
                fout.write(json_dumps(predict_conf, indent=4))
            return send_file(open(temp_filepath, "rb"), as_attachment=True,
                             attachment_filename=request_data.get("filename"))
        else:
            return get_json_result(data=predict_conf)
    return error_response(210, "No model found, please check if arguments are specified correctly.")
def job_log():
    job_id = request.json.get('job_id', '')
    job_log_dir = job_utils.get_job_log_directory(job_id=job_id)
    if os.path.exists(job_log_dir):
        memory_file = io.BytesIO()
        tar = tarfile.open(fileobj=memory_file, mode='w:gz')
        for root, dir, files in os.walk(job_log_dir):
            for file in files:
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, job_log_dir)
                tar.add(full_path, rel_path)
        tar.close()
        memory_file.seek(0)
        return send_file(memory_file,
                         attachment_filename='job_{}_log.tar.gz'.format(job_id),
                         as_attachment=True)
    else:
        return error_response(
            210, "Log file path: {} not found. Please check if the job id is valid.".format(job_log_dir))
def register():
    info = request.json or request.form.to_dict()
    if not Path(info["path"]).is_dir():
        return error_response(400, "invalid path")
    provider = ComponentProvider(name=info["name"],
                                 version=info["version"],
                                 path=info["path"],
                                 class_path=info.get("class_path", ComponentRegistry.get_default_class_path()))
    code, std = WorkerManager.start_general_worker(worker_name=WorkerName.PROVIDER_REGISTRAR, provider=provider)
    if code == 0:
        ComponentRegistry.load()
        if ComponentRegistry.get_providers().get(provider.name, {}).get(provider.version, None) is None:
            return get_json_result(retcode=RetCode.OPERATING_ERROR,
                                   retmsg="provider not loaded into memory")
        else:
            return get_json_result()
    else:
        return get_json_result(retcode=RetCode.OPERATING_ERROR, retmsg=f"register failed:\n{std}")
def operate_model(model_operation):
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [ModelOperation.STORE, ModelOperation.RESTORE,
                               ModelOperation.EXPORT, ModelOperation.IMPORT]:
        raise Exception('Unsupported model operation: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(model_id=request_config["model_id"],
                                                    role=request_config["role"],
                                                    party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                file = request.files.get('file')
                file_path = os.path.join(TEMP_DIRECTORY, file.filename)
                # if not os.path.exists(file_path):
                #     raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                #                     'please check the path: {}'.format(file_path))
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    file.save(file_path)
                except Exception as e:
                    shutil.rmtree(file_path)
                    raise e
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(model_id=request_config["model_id"],
                                                       model_version=request_config["model_version"])
                model.unpack_model(file_path)
                pipeline = model.read_component_model('pipeline', 'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                permitted_party_id = []
                for key, value in train_runtime_conf.get('role', {}).items():
                    for v in value:
                        permitted_party_id.extend([v, str(v)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception("party id {} is not in model roles, please check if "
                                    "the party id is valid.".format(request_config["party_id"]))
                try:
                    adapter = JobRuntimeConfigAdapter(train_runtime_conf)
                    job_parameters = adapter.get_common_parameters().to_dict()
                    with DB.connection_context():
                        db_model = MLModel.get_or_none(
                            MLModel.f_job_id == job_parameters.get("model_version"),
                            MLModel.f_role == request_config["role"])
                    if not db_model:
                        model_info = model_utils.gather_model_info_data(model)
                        model_info['imported'] = 1
                        model_info['job_id'] = model_info['f_model_version']
                        model_info['size'] = model.calculate_model_file_size()
                        model_info['role'] = request_config["model_id"].split('#')[0]
                        model_info['party_id'] = request_config["model_id"].split('#')[1]
                        if model_utils.compare_version(model_info['f_fate_version'], '1.5.1') == 'lt':
                            model_info['roles'] = model_info.get('f_train_runtime_conf', {}).get('role', {})
                            model_info['initiator_role'] = model_info.get(
                                'f_train_runtime_conf', {}).get('initiator', {}).get('role')
                            model_info['initiator_party_id'] = model_info.get(
                                'f_train_runtime_conf', {}).get('initiator', {}).get('party_id')
                            model_info['work_mode'] = adapter.get_job_work_mode()
                            model_info['parent'] = False if model_info.get('f_inference_dsl') else True
                        model_utils.save_model_info(model_info)
                    else:
                        stat_logger.info(f'job id: {job_parameters.get("model_version")}, '
                                         f'role: {request_config["role"]} model info already existed in database.')
                except peewee.IntegrityError as e:
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(model_id=request_config["model_id"],
                                                       model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(response_code=210,
                                         retmsg="Model {} {} does not exist.".format(
                                             request_config.get("model_id"),
                                             request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(request_config, model_operation)
        submit_result = DAGScheduler.submit({'job_dsl': job_dsl, 'job_runtime_conf': job_runtime_conf},
                                            job_id=job_id)
        data.update(submit_result)
        operation_record(data=job_runtime_conf, oper_type=model_operation, oper_status='')
        return get_json_result(job_id=job_id, data=data)
def operate_model(model_operation):
    request_config = request.json or request.form.to_dict()
    job_id = job_utils.generate_job_id()
    if model_operation not in [ModelOperation.STORE, ModelOperation.RESTORE,
                               ModelOperation.EXPORT, ModelOperation.IMPORT]:
        raise Exception('Unsupported model operation: {}'.format(model_operation))
    required_arguments = ["model_id", "model_version", "role", "party_id"]
    check_config(request_config, required_arguments=required_arguments)
    request_config["model_id"] = gen_party_model_id(model_id=request_config["model_id"],
                                                    role=request_config["role"],
                                                    party_id=request_config["party_id"])
    if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
        if model_operation == ModelOperation.IMPORT:
            try:
                file = request.files.get('file')
                file_path = os.path.join(TEMP_DIRECTORY, file.filename)
                # if not os.path.exists(file_path):
                #     raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                #                     'please check the path: {}'.format(file_path))
                try:
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)
                    file.save(file_path)
                except Exception as e:
                    shutil.rmtree(file_path)
                    raise e
                request_config['file'] = file_path
                model = pipelined_model.PipelinedModel(model_id=request_config["model_id"],
                                                       model_version=request_config["model_version"])
                model.unpack_model(file_path)
                pipeline = model.read_component_model('pipeline', 'pipeline')['Pipeline']
                train_runtime_conf = json_loads(pipeline.train_runtime_conf)
                permitted_party_id = []
                for key, value in train_runtime_conf.get('role', {}).items():
                    for v in value:
                        permitted_party_id.extend([v, str(v)])
                if request_config["party_id"] not in permitted_party_id:
                    shutil.rmtree(model.model_path)
                    raise Exception("party id {} is not in model roles, please check if "
                                    "the party id is valid.".format(request_config["party_id"]))
                try:
                    with DB.connection_context():
                        model = MLModel.get_or_none(
                            MLModel.f_job_id == train_runtime_conf["job_parameters"]["model_version"],
                            MLModel.f_role == request_config["role"])
                        if not model:
                            MLModel.create(
                                f_role=request_config["role"],
                                f_party_id=request_config["party_id"],
                                f_roles=train_runtime_conf["role"],
                                f_job_id=train_runtime_conf["job_parameters"]["model_version"],
                                f_model_id=train_runtime_conf["job_parameters"]["model_id"],
                                f_model_version=train_runtime_conf["job_parameters"]["model_version"],
                                f_initiator_role=train_runtime_conf["initiator"]["role"],
                                f_initiator_party_id=train_runtime_conf["initiator"]["party_id"],
                                f_runtime_conf=train_runtime_conf,
                                f_work_mode=train_runtime_conf["job_parameters"]["work_mode"],
                                f_dsl=json_loads(pipeline.train_dsl),
                                f_imported=1,
                                f_job_status='complete')
                        else:
                            stat_logger.info(f'job id: {train_runtime_conf["job_parameters"]["model_version"]}, '
                                             f'role: {request_config["role"]} model info already existed in database.')
                except peewee.IntegrityError as e:
                    stat_logger.exception(e)
                operation_record(request_config, "import", "success")
                return get_json_result()
            except Exception:
                operation_record(request_config, "import", "failed")
                raise
        else:
            try:
                model = pipelined_model.PipelinedModel(model_id=request_config["model_id"],
                                                       model_version=request_config["model_version"])
                if model.exists():
                    archive_file_path = model.packaging_model()
                    operation_record(request_config, "export", "success")
                    return send_file(archive_file_path,
                                     attachment_filename=os.path.basename(archive_file_path),
                                     as_attachment=True)
                else:
                    operation_record(request_config, "export", "failed")
                    res = error_response(response_code=210,
                                         retmsg="Model {} {} does not exist.".format(
                                             request_config.get("model_id"),
                                             request_config.get("model_version")))
                    return res
            except Exception as e:
                operation_record(request_config, "export", "failed")
                stat_logger.exception(e)
                return error_response(response_code=210, retmsg=str(e))
    else:
        data = {}
        job_dsl, job_runtime_conf = gen_model_operation_job_config(request_config, model_operation)
        job_id, job_dsl_path, job_runtime_conf_path, logs_directory, model_info, board_url = DAGScheduler.submit(
            {'job_dsl': job_dsl, 'job_runtime_conf': job_runtime_conf}, job_id=job_id)
        data.update({'job_dsl_path': job_dsl_path,
                     'job_runtime_conf_path': job_runtime_conf_path,
                     'board_url': board_url,
                     'logs_directory': logs_directory})
        operation_record(data=job_runtime_conf, oper_type=model_operation, oper_status='')
        return get_json_result(job_id=job_id, data=data)
def internal_server_error(e):
    stat_logger.exception(e)
    return error_response(500, str(e))
def download_model(party_model_id, model_version):
    party_model_id = party_model_id.replace('~', '#')
    model_data = publish_model.download_model(party_model_id, model_version)
    if model_data is None:
        return error_response(404, 'model not found')
    return get_json_result(data=model_data)
def component_output_data_download():
    request_data = request.json
    try:
        output_tables_meta = get_component_output_tables_meta(task_data=request_data)
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
    limit = request_data.get('limit', -1)
    if not output_tables_meta:
        return error_response(response_code=210, retmsg='no data')
    if limit == 0:
        return error_response(response_code=210, retmsg='limit is 0')
    have_data_label = False
    output_data_file_list = []
    output_data_meta_file_list = []
    output_tmp_dir = os.path.join(os.getcwd(), 'tmp/{}'.format(fate_uuid()))
    for output_name, output_table_meta in output_tables_meta.items():
        output_data_count = 0
        is_str = False
        output_data_file_path = "{}/{}.csv".format(output_tmp_dir, output_name)
        os.makedirs(os.path.dirname(output_data_file_path), exist_ok=True)
        with open(output_data_file_path, 'w') as fw:
            with storage.Session.build(name=output_table_meta.get_name(),
                                       namespace=output_table_meta.get_namespace()) as storage_session:
                output_table = storage_session.get_table()
                for k, v in output_table.collect():
                    data_line, have_data_label, is_str = get_component_output_data_line(src_key=k, src_value=v)
                    fw.write('{}\n'.format(','.join(map(lambda x: str(x), data_line))))
                    output_data_count += 1
                    if output_data_count == limit:
                        break
        if output_data_count:
            # get meta
            output_data_file_list.append(output_data_file_path)
            header = get_component_output_data_schema(output_table_meta=output_table_meta,
                                                      have_data_label=have_data_label,
                                                      is_str=is_str)
            output_data_meta_file_path = "{}/{}.meta".format(output_tmp_dir, output_name)
            output_data_meta_file_list.append(output_data_meta_file_path)
            with open(output_data_meta_file_path, 'w') as fw:
                json.dump({'header': header}, fw, indent=4)
            if request_data.get('head', True) and header:
                with open(output_data_file_path, 'r+') as f:
                    content = f.read()
                    f.seek(0, 0)
                    f.write('{}\n'.format(','.join(header)) + content)
    # tar
    memory_file = io.BytesIO()
    tar = tarfile.open(fileobj=memory_file, mode='w:gz')
    for index in range(0, len(output_data_file_list)):
        tar.add(output_data_file_list[index], os.path.relpath(output_data_file_list[index], output_tmp_dir))
        tar.add(output_data_meta_file_list[index], os.path.relpath(output_data_meta_file_list[index], output_tmp_dir))
    tar.close()
    memory_file.seek(0)
    output_data_file_list.extend(output_data_meta_file_list)
    for path in output_data_file_list:
        try:
            shutil.rmtree(os.path.dirname(path))
        except Exception as e:
            # warning
            stat_logger.warning(e)
    tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(request_data['job_id'],
                                                                request_data['component_name'],
                                                                request_data['role'],
                                                                request_data['party_id'])
    return send_file(memory_file, attachment_filename=tar_file_name, as_attachment=True)