def job_view():
    request_data = request.json
    check_request_parameters(request_data)
    job_tracker = Tracker(job_id=request_data['job_id'], role=request_data['role'],
                          party_id=request_data['party_id'])
    job_view_data = job_tracker.get_job_view()
    if job_view_data:
        job_metric_list = job_tracker.get_metric_list(job_level=True)
        job_view_data['model_summary'] = {}
        for metric_namespace, namespace_metrics in job_metric_list.items():
            job_view_data['model_summary'][metric_namespace] = job_view_data['model_summary'].get(metric_namespace, {})
            for metric_name in namespace_metrics:
                job_view_data['model_summary'][metric_namespace][metric_name] = \
                    job_view_data['model_summary'][metric_namespace].get(metric_name, {})
                for metric_data in job_tracker.get_job_metric_data(metric_namespace=metric_namespace,
                                                                   metric_name=metric_name):
                    job_view_data['model_summary'][metric_namespace][metric_name][metric_data.key] = metric_data.value
        return get_json_result(retcode=0, retmsg='success', data=job_view_data)
    else:
        return get_json_result(retcode=101, retmsg='error')
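# A minimal client-side sketch of calling the handler above. The route and port
# are assumptions for illustration (fate_flow commonly listens on 9380); check
# the blueprint registration for the real path. Ids in the payload are made up.
import requests

response = requests.post(
    "http://127.0.0.1:9380/v1/tracking/job/view",  # assumed route, not taken from this file
    json={"job_id": "202201010000000000", "role": "guest", "party_id": 9999},
)
print(response.json())  # e.g. {"retcode": 0, "retmsg": "success", "data": {..., "model_summary": {...}}}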
def _run(self, cpn_input: ComponentInputProtocol):
    self.parameters = cpn_input.parameters
    LOGGER.info(self.parameters)
    for k, v in self.parameters.items():
        if hasattr(self, k):
            setattr(self, k, v)
    tracker = Tracker(job_id=self.job_id,
                      role=self.tracker.role,
                      party_id=self.tracker.party_id,
                      component_name=self.component_name)
    LOGGER.info(f"query cache by cache key: {self.cache_key} cache name: {self.cache_name}")
    # TODO: use tracker client instead of tracker
    caches = tracker.query_output_cache(cache_key=self.cache_key, cache_name=self.cache_name)
    if not caches:
        raise Exception("cannot find this cache")
    elif len(caches) > 1:
        raise Exception(f"found {len(caches)} caches, only support one, please check parameters")
    else:
        cache = caches[0]
        self.cache_output = cache
        tracker.job_id = self.tracker.job_id
        tracker.component_name = self.tracker.component_name
        metric_meta = cache.to_dict()
        metric_meta.pop("data")
        metric_meta["component_name"] = self.component_name
        self.tracker.set_metric_meta(metric_namespace="cache_loader",
                                     metric_name=cache.name,
                                     metric_meta=MetricMeta(name="cache",
                                                            metric_type="cache_info",
                                                            extra_metas=metric_meta))
def component_metric_all():
    request_data = request.json
    check_request_parameters(request_data)
    tracker = Tracker(job_id=request_data['job_id'],
                      component_name=request_data['component_name'],
                      role=request_data['role'],
                      party_id=request_data['party_id'])
    metrics = tracker.get_metric_list()
    all_metric_data = {}
    if metrics:
        for metric_namespace, metric_names in metrics.items():
            all_metric_data[metric_namespace] = all_metric_data.get(metric_namespace, {})
            for metric_name in metric_names:
                all_metric_data[metric_namespace][metric_name] = all_metric_data[metric_namespace].get(metric_name, {})
                metric_data, metric_meta = get_metric_all_data(tracker=tracker,
                                                               metric_namespace=metric_namespace,
                                                               metric_name=metric_name)
                all_metric_data[metric_namespace][metric_name]['data'] = metric_data
                all_metric_data[metric_namespace][metric_name]['meta'] = metric_meta
        return get_json_result(retcode=0, retmsg='success', data=all_metric_data)
    else:
        return get_json_result(retcode=0, retmsg='no data', data={})
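# Illustration of the nested structure component_metric_all builds: namespace ->
# metric name -> {"data", "meta"}. All names and values below are hypothetical.
example_all_metric_data = {
    "train": {
        "loss": {
            "data": [[0, 0.69], [1, 0.55]],  # points as returned by get_metric_all_data
            "meta": {"metric_type": "LOSS", "unit_name": "iters"},
        }
    }
}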
def get_component_summary():
    request_data = request.json
    try:
        tracker = Tracker(job_id=request_data["job_id"],
                          component_name=request_data["component_name"],
                          role=request_data["role"],
                          party_id=request_data["party_id"],
                          task_id=request_data.get("task_id", None),
                          task_version=request_data.get("task_version", None))
        summary = tracker.read_summary_from_db()
        if summary:
            if request_data.get("filename"):
                temp_filepath = os.path.join(TEMP_DIRECTORY, request_data.get("filename"))
                with open(temp_filepath, "w") as fout:
                    fout.write(json.dumps(summary, indent=4))
                return send_file(open(temp_filepath, "rb"), as_attachment=True,
                                 attachment_filename=request_data.get("filename"))
            else:
                return get_json_result(data=summary)
        return error_response(210, "No component summary found, please check if arguments are specified correctly.")
    except Exception as e:
        stat_logger.exception(e)
        return error_response(210, str(e))
def get_component_output_tables_meta(task_data):
    check_request_parameters(task_data)
    tracker = Tracker(job_id=task_data['job_id'],
                      component_name=task_data['component_name'],
                      role=task_data['role'],
                      party_id=task_data['party_id'])
    output_data_table_infos = tracker.get_output_data_info()
    output_tables_meta = tracker.get_output_data_table(output_data_infos=output_data_table_infos)
    return output_tables_meta
def component_metrics():
    request_data = request.json
    check_request_parameters(request_data)
    tracker = Tracker(job_id=request_data['job_id'],
                      component_name=request_data['component_name'],
                      role=request_data['role'],
                      party_id=request_data['party_id'])
    metrics = tracker.get_metric_list()
    if metrics:
        return get_json_result(retcode=0, retmsg='success', data=metrics)
    else:
        return get_json_result(retcode=0, retmsg='no data', data={})
def get_table_meta(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id)
    table_meta_dict = tracker.get_table_meta(request_data)
    return get_json_result(data=table_meta_dict)
def save_component_summary(job_id: str, component_name: str, task_version: int,
                           task_id: str, role: str, party_id: int):
    request_data = request.json
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id)
    summary_data = request_data['summary']
    tracker.insert_summary_into_db(summary_data)
    return get_json_result()
def get_component_output_tables_meta(task_data):
    check_request_parameters(task_data)
    tracker = Tracker(job_id=task_data['job_id'],
                      component_name=task_data['component_name'],
                      role=task_data['role'],
                      party_id=task_data['party_id'])
    job_dsl_parser = schedule_utils.get_job_dsl_parser_by_job_id(job_id=task_data['job_id'])
    if not job_dsl_parser:
        raise Exception('cannot get dag parser, please check if the parameters are correct')
    component = job_dsl_parser.get_component_info(task_data['component_name'])
    if not component:
        raise Exception('cannot find component, please check if the parameters are correct')
    output_data_table_infos = tracker.get_output_data_info()
    output_tables_meta = tracker.get_output_data_table(output_data_infos=output_data_table_infos)
    return output_tables_meta
def save_output_data_info(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id)
    tracker.insert_output_data_info_into_db(data_name=request_data["data_name"],
                                            table_namespace=request_data["table_namespace"],
                                            table_name=request_data["table_name"])
    return get_json_result()
def clean_table(cls, job_id, role, party_id, component_name):
    # clean data table
    stat_logger.info('start delete {} {} {} {} data table'.format(job_id, role, party_id, component_name))
    tracker = Tracker(job_id=job_id, role=role, party_id=party_id, component_name=component_name)
    output_data_table_infos = tracker.get_output_data_info()
    if output_data_table_infos:
        delete_tables_by_table_infos(output_data_table_infos)
        stat_logger.info('delete {} {} {} {} data table success'.format(job_id, role, party_id, component_name))
def read_output_data_info(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id)
    output_data_infos = tracker.read_output_data_info_from_db(data_name=request_data["data_name"])
    response_data = []
    for output_data_info in output_data_infos:
        response_data.append(output_data_info.to_human_model_dict())
    return get_json_result(data=response_data)
def save_metric_meta(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id)
    metric_meta = deserialize_b64(request_data['metric_meta'])
    tracker.save_metric_meta(metric_namespace=request_data['metric_namespace'],
                             metric_name=request_data['metric_name'],
                             metric_meta=metric_meta,
                             job_level=request_data['job_level'])
    return get_json_result()
def save_component_model(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    model_id = request_data.get("model_id")
    model_version = request_data.get("model_version")
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id,
                      model_id=model_id, model_version=model_version)
    tracker.write_output_model(request_data.get("component_model"))
    return get_json_result()
def get_input_table_info(parameters, role, party_id):
    search_type = data_utils.get_input_search_type(parameters)
    if search_type is InputSearchType.TABLE_INFO:
        return parameters["namespace"], parameters["name"]
    elif search_type is InputSearchType.JOB_COMPONENT_OUTPUT:
        output_data_infos = Tracker.query_output_data_infos(job_id=parameters["job_id"],
                                                            component_name=parameters["component_name"],
                                                            data_name=parameters["data_name"],
                                                            role=role,
                                                            party_id=party_id)
        if not output_data_infos:
            raise Exception("cannot find input table, please check parameters")
        else:
            namespace, name = output_data_infos[0].f_table_namespace, output_data_infos[0].f_table_name
            LOGGER.info(f"found input table {namespace} {name} by {parameters}")
            return namespace, name
    else:
        raise ParameterError(f"cannot find input table info by parameters {parameters}")
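# The two parameter shapes get_input_table_info accepts, matching the branches
# above. All values are illustrative only.
direct_table_info = {"namespace": "experiment", "name": "breast_hetero_guest"}
job_component_output = {
    "job_id": "202201010000000000",
    "component_name": "reader_0",
    "data_name": "data",
}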
def get_component_model(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    model_id = request_data.get("model_id")
    model_version = request_data.get("model_version")
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id,
                      model_id=model_id, model_version=model_version)
    data = tracker.read_output_model(model_alias=request_data.get("search_model_alias"), parse=False)
    return get_json_result(data=data)
def __init__(self, job_id: str, role: str, party_id: int,
             model_id: str = None, model_version: str = None,
             component_name: str = None, component_module_name: str = None,
             task_id: str = None, task_version: int = None,
             job_parameters: RunParameters = None):
    self.job_id = job_id
    self.role = role
    self.party_id = party_id
    self.model_id = model_id
    self.model_version = model_version
    self.component_name = component_name if component_name else 'pipeline'
    self.module_name = component_module_name if component_module_name else 'Pipeline'
    self.task_id = task_id
    self.task_version = task_version
    self.job_parameters = job_parameters
    self.job_tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                               component_name=component_name,
                               task_id=task_id, task_version=task_version,
                               model_id=model_id, model_version=model_version,
                               job_parameters=job_parameters)
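# A construction sketch for the wrapper above. The enclosing class name is not
# shown in this snippet, so TrackerClient is an assumption. Only job_id, role
# and party_id are required; component_name defaults to 'pipeline'.
tracker_client = TrackerClient(job_id="202201010000000000", role="guest", party_id=9999)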
def component_output_model():
    request_data = request.json
    check_request_parameters(request_data)
    job_dsl, job_runtime_conf, runtime_conf_on_party, train_runtime_conf = job_utils.get_job_configuration(
        job_id=request_data['job_id'], role=request_data['role'], party_id=request_data['party_id'])
    try:
        model_id = runtime_conf_on_party['job_parameters']['model_id']
        model_version = runtime_conf_on_party['job_parameters']['model_version']
    except Exception as e:
        job_dsl, job_runtime_conf, train_runtime_conf = job_utils.get_model_configuration(
            job_id=request_data['job_id'], role=request_data['role'], party_id=request_data['party_id'])
        if any([job_dsl, job_runtime_conf, train_runtime_conf]):
            adapter = JobRuntimeConfigAdapter(job_runtime_conf)
            common_parameters = adapter.get_common_parameters().to_dict()
            model_id = common_parameters.get('model_id')
            model_version = common_parameters.get('model_version')
        else:
            stat_logger.exception(e)
            stat_logger.error(f"Cannot find model info by filters: job id: {request_data.get('job_id')}, "
                              f"role: {request_data.get('role')}, party id: {request_data.get('party_id')}")
            raise Exception(f"Cannot find model info by filters: job id: {request_data.get('job_id')}, "
                            f"role: {request_data.get('role')}, party id: {request_data.get('party_id')}")
    tracker = Tracker(job_id=request_data['job_id'], component_name=request_data['component_name'],
                      role=request_data['role'], party_id=request_data['party_id'],
                      model_id=model_id, model_version=model_version)
    dag = schedule_utils.get_job_dsl_parser(dsl=job_dsl,
                                            runtime_conf=job_runtime_conf,
                                            train_runtime_conf=train_runtime_conf)
    component = dag.get_component_info(request_data['component_name'])
    output_model_json = {}
    # There is only one model output at the current dsl version.
    output_model = tracker.get_output_model(component.get_output()['model'][0]
                                            if component.get_output().get('model') else 'default')
    for buffer_name, buffer_object in output_model.items():
        if buffer_name.endswith('Param'):
            output_model_json = json_format.MessageToDict(buffer_object, including_default_value_fields=True)
    if output_model_json:
        component_define = tracker.get_component_define()
        this_component_model_meta = {}
        for buffer_name, buffer_object in output_model.items():
            if buffer_name.endswith('Meta'):
                this_component_model_meta['meta_data'] = json_format.MessageToDict(
                    buffer_object, including_default_value_fields=True)
        this_component_model_meta.update(component_define)
        return get_json_result(retcode=0, retmsg='success', data=output_model_json, meta=this_component_model_meta)
    else:
        return get_json_result(retcode=0, retmsg='no data', data={})
def query_component_output_data_info():
    output_data_infos = Tracker.query_output_data_infos(**request.json)
    if not output_data_infos:
        return get_json_result(retcode=101, retmsg='find data view failed')
    return get_json_result(retcode=0, retmsg='success',
                           data=[output_data_info.to_json() for output_data_info in output_data_infos])
def component_metric_data():
    request_data = request.json
    check_request_parameters(request_data)
    tracker = Tracker(job_id=request_data['job_id'],
                      component_name=request_data['component_name'],
                      role=request_data['role'],
                      party_id=request_data['party_id'])
    metric_data, metric_meta = get_metric_all_data(tracker=tracker,
                                                   metric_namespace=request_data['metric_namespace'],
                                                   metric_name=request_data['metric_name'])
    if metric_data or metric_meta:
        return get_json_result(retcode=0, retmsg='success', data=metric_data, meta=metric_meta)
    else:
        return get_json_result(retcode=0, retmsg='no data', data=[], meta={})
def clean_task(cls, job_id, task_id, task_version, role, party_id, content_type: TaskCleanResourceType):
    status = set()
    if content_type == TaskCleanResourceType.METRICS:
        tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                          task_id=task_id, task_version=task_version)
        status.add(tracker.clean_metrics())
    elif content_type == TaskCleanResourceType.TABLE:
        jobs = JobSaver.query_job(job_id=job_id, role=role, party_id=party_id)
        if jobs:
            job = jobs[0]
            job_parameters = RunParameters(**job.f_runtime_conf_on_party["job_parameters"])
            tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                              task_id=task_id, task_version=task_version,
                              job_parameters=job_parameters)
            status.add(tracker.clean_task(job.f_runtime_conf_on_party))
    if len(status) == 1 and True in status:
        return True
    else:
        return False
def initialize_job_tracker(cls, job_id, role, party_id, job_parameters, roles, is_initiator, dsl_parser):
    tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                      model_id=job_parameters["model_id"],
                      model_version=job_parameters["model_version"])
    if job_parameters.get("job_type", "") != "predict":
        tracker.init_pipelined_model()
    partner = {}
    show_role = {}
    for _role, _role_party in roles.items():
        if is_initiator or _role == role:
            show_role[_role] = show_role.get(_role, [])
            for _party_id in _role_party:
                if is_initiator or _party_id == party_id:
                    show_role[_role].append(_party_id)
        if _role != role:
            partner[_role] = partner.get(_role, [])
            partner[_role].extend(_role_party)
        else:
            for _party_id in _role_party:
                if _party_id != party_id:
                    partner[_role] = partner.get(_role, [])
                    partner[_role].append(_party_id)
    job_args = dsl_parser.get_args_input()
    dataset = cls.get_dataset(is_initiator, role, party_id, roles, job_args)
    tracker.log_job_view({'partner': partner, 'dataset': dataset, 'roles': show_role})
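# Worked example of the partner/show_role bookkeeping above: an initiator guest
# 9999 with roles {"guest": [9999], "host": [10000, 10001]} (made-up party ids)
# ends up logging the following.
example_show_role = {"guest": [9999], "host": [10000, 10001]}  # the initiator sees every party
example_partner = {"host": [10000, 10001]}                     # every role except its own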
def load_task_tracker(cls, tasks: dict):
    tracker_dict = {}
    for key, task in tasks.items():
        schedule_logger(task.f_job_id).info(
            f"task: {task.f_job_id}, {task.f_role}, {task.f_party_id}, {task.f_component_name}, {task.f_task_version}")
        tracker = Tracker(job_id=task.f_job_id, role=task.f_role, party_id=task.f_party_id,
                          component_name=task.f_component_name,
                          task_id=task.f_task_id, task_version=task.f_task_version)
        tracker_dict[key] = tracker
    return tracker_dict
def get_component_model_run_parameters(job_id, component_name, task_version, task_id, role, party_id):
    request_data = request.json
    model_id = request_data.get("model_id")
    model_version = request_data.get("model_version")
    tracker = Tracker(job_id=job_id, component_name=component_name,
                      task_id=task_id, task_version=task_version,
                      role=role, party_id=party_id,
                      model_id=model_id, model_version=model_version)
    data = tracker.pipelined_model.read_model_run_parameters()
    return get_json_result(data=data)
def save_pipelined_model(cls, job_id, role, party_id):
    schedule_logger(job_id).info('job {} on {} {} start to save pipeline'.format(job_id, role, party_id))
    job_dsl, job_runtime_conf, runtime_conf_on_party, train_runtime_conf = job_utils.get_job_configuration(
        job_id=job_id, role=role, party_id=party_id)
    job_parameters = runtime_conf_on_party.get('job_parameters', {})
    if role in job_parameters.get("assistant_role", []):
        return
    model_id = job_parameters['model_id']
    model_version = job_parameters['model_version']
    job_type = job_parameters.get('job_type', '')
    work_mode = job_parameters['work_mode']
    roles = runtime_conf_on_party['role']
    initiator_role = runtime_conf_on_party['initiator']['role']
    initiator_party_id = runtime_conf_on_party['initiator']['party_id']
    if job_type == 'predict':
        return
    dag = schedule_utils.get_job_dsl_parser(dsl=job_dsl,
                                            runtime_conf=job_runtime_conf,
                                            train_runtime_conf=train_runtime_conf)
    predict_dsl = dag.get_predict_dsl(role=role)
    pipeline = pipeline_pb2.Pipeline()
    pipeline.inference_dsl = json_dumps(predict_dsl, byte=True)
    pipeline.train_dsl = json_dumps(job_dsl, byte=True)
    pipeline.train_runtime_conf = json_dumps(job_runtime_conf, byte=True)
    pipeline.fate_version = RuntimeConfig.get_env("FATE")
    pipeline.model_id = model_id
    pipeline.model_version = model_version
    pipeline.parent = True
    pipeline.loaded_times = 0
    pipeline.roles = json_dumps(roles, byte=True)
    pipeline.work_mode = work_mode
    pipeline.initiator_role = initiator_role
    pipeline.initiator_party_id = initiator_party_id
    pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)
    pipeline.parent_info = json_dumps({}, byte=True)
    tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                      model_id=model_id, model_version=model_version)
    tracker.save_pipelined_model(pipelined_buffer_object=pipeline)
    if role != 'local':
        tracker.save_machine_learning_model_info()
    schedule_logger(job_id).info('job {} on {} {} save pipeline successfully'.format(job_id, role, party_id))
def component_output_data_table(job_id, component_name, role, party_id):
    output_data_infos = Tracker.query_output_data_infos(job_id=job_id, component_name=component_name,
                                                        role=role, party_id=party_id)
    if output_data_infos:
        return get_json_result(retcode=0, retmsg='success',
                               data=[{'table_name': output_data_info.f_table_name,
                                      'table_namespace': output_data_info.f_table_namespace,
                                      'data_name': output_data_info.f_data_name}
                                     for output_data_info in output_data_infos])
    else:
        return get_json_result(retcode=100,
                               retmsg='table not found, please check if the parameters are correct')
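# Success-branch response sketch for component_output_data_table; table names
# follow no guaranteed scheme and are shown here as hypothetical values.
example_response = {
    "retcode": 0,
    "retmsg": "success",
    "data": [{
        "table_name": "output_data_202201010000000000_reader_0_0",
        "table_namespace": "output_data_202201010000000000",
        "data_name": "data",
    }],
}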
def get_job_all_table(job):
    dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job.f_dsl,
                                                   runtime_conf=job.f_runtime_conf,
                                                   train_runtime_conf=job.f_train_runtime_conf)
    _, hierarchical_structure = dsl_parser.get_dsl_hierarchical_structure()
    component_table = {}
    component_output_tables = Tracker.query_output_data_infos(job_id=job.f_job_id,
                                                              role=job.f_role,
                                                              party_id=job.f_party_id)
    for component_name_list in hierarchical_structure:
        for component_name in component_name_list:
            component_table[component_name] = {}
            component_input_table = get_component_input_table(dsl_parser, job, component_name)
            component_table[component_name]['input'] = component_input_table
            component_table[component_name]['output'] = {}
            for output_table in component_output_tables:
                if output_table.f_component_name == component_name:
                    component_table[component_name]['output'][output_table.f_data_name] = \
                        {'name': output_table.f_table_name, 'namespace': output_table.f_table_namespace}
    return component_table
def save_pipelined_model(cls, job_id, role, party_id):
    schedule_logger(job_id).info(f"start to save pipeline model on {role} {party_id}")
    job_configuration = job_utils.get_job_configuration(job_id=job_id, role=role, party_id=party_id)
    runtime_conf_on_party = job_configuration.runtime_conf_on_party
    job_parameters = runtime_conf_on_party.get('job_parameters', {})
    if role in job_parameters.get("assistant_role", []):
        return
    model_id = job_parameters['model_id']
    model_version = job_parameters['model_version']
    job_type = job_parameters.get('job_type', '')
    roles = runtime_conf_on_party['role']
    initiator_role = runtime_conf_on_party['initiator']['role']
    initiator_party_id = runtime_conf_on_party['initiator']['party_id']
    if job_type == 'predict':
        return
    dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_configuration.dsl,
                                                   runtime_conf=job_configuration.runtime_conf,
                                                   train_runtime_conf=job_configuration.train_runtime_conf)
    components_parameters = {}
    tasks = JobSaver.query_task(job_id=job_id, role=role, party_id=party_id, only_latest=True)
    for task in tasks:
        components_parameters[task.f_component_name] = task.f_component_parameters
    predict_dsl = schedule_utils.fill_inference_dsl(dsl_parser,
                                                    origin_inference_dsl=job_configuration.dsl,
                                                    components_parameters=components_parameters)
    pipeline = pipeline_pb2.Pipeline()
    pipeline.inference_dsl = json_dumps(predict_dsl, byte=True)
    pipeline.train_dsl = json_dumps(job_configuration.dsl, byte=True)
    pipeline.train_runtime_conf = json_dumps(job_configuration.runtime_conf, byte=True)
    pipeline.fate_version = RuntimeConfig.get_env("FATE")
    pipeline.model_id = model_id
    pipeline.model_version = model_version
    pipeline.parent = True
    pipeline.loaded_times = 0
    pipeline.roles = json_dumps(roles, byte=True)
    pipeline.initiator_role = initiator_role
    pipeline.initiator_party_id = initiator_party_id
    pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)
    pipeline.parent_info = json_dumps({}, byte=True)
    tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                      model_id=model_id, model_version=model_version,
                      job_parameters=RunParameters(**job_parameters))
    tracker.save_pipeline_model(pipeline_buffer_object=pipeline)
    if role != 'local':
        tracker.save_machine_learning_model_info()
    schedule_logger(job_id).info(f"save pipeline on {role} {party_id} successfully")
def _run_(self):
    # TODO: all function calls where errors should be thrown
    args = self.args
    start_time = current_timestamp()
    try:
        LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        self.report_info.update({
            "job_id": args.job_id,
            "component_name": args.component_name,
            "task_id": args.task_id,
            "task_version": args.task_version,
            "role": args.role,
            "party_id": args.party_id,
            "run_ip": args.run_ip,
            "run_pid": self.run_pid
        })
        operation_client = OperationClient()
        job_configuration = JobConfiguration(**operation_client.get_job_conf(
            args.job_id, args.role, args.party_id, args.component_name, args.task_id, args.task_version))
        task_parameters_conf = args.config
        dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_configuration.dsl,
                                                       runtime_conf=job_configuration.runtime_conf,
                                                       train_runtime_conf=job_configuration.train_runtime_conf,
                                                       pipeline_dsl=None)
        job_parameters = dsl_parser.get_job_parameters(job_configuration.runtime_conf)
        user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
        LOGGER.info(f"user name: {user_name}")
        src_user = task_parameters_conf.get("src_user")
        task_parameters = RunParameters(**task_parameters_conf)
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()
        job_args_on_party = TaskExecutor.get_job_args_on_party(
            dsl_parser, job_configuration.runtime_conf_on_party, args.role, args.party_id)
        component = dsl_parser.get_component_info(component_name=args.component_name)
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()
        kwargs = {
            'job_id': args.job_id,
            'role': args.role,
            'party_id': args.party_id,
            'component_name': args.component_name,
            'task_id': args.task_id,
            'task_version': args.task_version,
            'model_id': job_parameters.model_id,
            'model_version': job_parameters.model_version,
            'component_module_name': module_name,
            'job_parameters': job_parameters,
        }
        tracker = Tracker(**kwargs)
        tracker_client = TrackerClient(**kwargs)
        checkpoint_manager = CheckpointManager(**kwargs)
        self.report_info["party_status"] = TaskStatus.RUNNING
        self.report_task_info_to_driver()
        previous_components_parameters = tracker_client.get_model_run_parameters()
        LOGGER.info(f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}")
        component_provider, component_parameters_on_party, user_specified_parameters = \
            ProviderManager.get_component_run_info(dsl_parser=dsl_parser,
                                                   component_name=args.component_name,
                                                   role=args.role,
                                                   party_id=args.party_id,
                                                   previous_components_parameters=previous_components_parameters)
        RuntimeConfig.set_component_provider(component_provider)
        LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
        flow_feeded_parameters = {"output_data_name": task_output_dsl.get("data")}
        # init environment, process is shared globally
        RuntimeConfig.init_config(COMPUTING_ENGINE=job_parameters.computing_engine,
                                  FEDERATION_ENGINE=job_parameters.federation_engine,
                                  FEDERATED_MODE=job_parameters.federated_mode)
        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
            session_options["python.path"] = os.getenv("PYTHONPATH")
            session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
        else:
            session_options = {}
        sess = session.Session(session_id=args.session_id)
        sess.as_global()
        sess.init_computing(computing_session_id=args.session_id, options=session_options)
        component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
        roles = job_configuration.runtime_conf["role"]
        if set(roles) == {"local"}:
            LOGGER.info("only local roles, skip federation init")
        else:
            sess.init_federation(federation_session_id=args.federation_session_id,
                                 runtime_conf=component_parameters_on_party,
                                 service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))
        LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
        LOGGER.info(f"task input dsl {task_input_dsl}")
        task_run_args, input_table_list = self.get_task_run_args(job_id=args.job_id, role=args.role,
                                                                 party_id=args.party_id,
                                                                 task_id=args.task_id,
                                                                 task_version=args.task_version,
                                                                 job_args=job_args_on_party,
                                                                 job_parameters=job_parameters,
                                                                 task_parameters=task_parameters,
                                                                 input_dsl=task_input_dsl)
        if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
            task_run_args["job_parameters"] = job_parameters
        LOGGER.info(f"task input args {task_run_args}")
        need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
        provider_interface = provider_utils.get_provider_interface(provider=component_provider)
        run_object = provider_interface.get(module_name,
                                            ComponentRegistry.get_provider_components(
                                                provider_name=component_provider.name,
                                                provider_version=component_provider.version)).get_run_obj(self.args.role)
        flow_feeded_parameters.update({"table_info": input_table_list})
        cpn_input = ComponentInput(
            tracker=tracker_client,
            checkpoint_manager=checkpoint_manager,
            task_version_id=job_utils.generate_task_version_id(args.task_id, args.task_version),
            parameters=component_parameters_on_party["ComponentParam"],
            datasets=task_run_args.get("data", None),
            caches=task_run_args.get("cache", None),
            models=dict(
                model=task_run_args.get("model"),
                isometric_model=task_run_args.get("isometric_model"),
            ),
            job_parameters=job_parameters,
            roles=dict(
                role=component_parameters_on_party["role"],
                local=component_parameters_on_party["local"],
            ),
            flow_feeded_parameters=flow_feeded_parameters,
        )
        profile_log_enabled = False
        try:
            if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                profile_log_enabled = True
        except Exception as e:
            LOGGER.warning(e)
        if profile_log_enabled:
            # add profile logs
            LOGGER.info("profile logging is enabled")
            profile.profile_start()
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()
            profile.profile_ends()
        else:
            LOGGER.info("profile logging is disabled")
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()
        output_table_list = []
        LOGGER.info(f"task output data {cpn_output.data}")
        for index, data in enumerate(cpn_output.data):
            data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else '{}'.format(index)
            # TODO: the token depends on the engine type, maybe in job parameters
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=data,
                output_storage_engine=job_parameters.storage_engine,
                token={"username": user_name})
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(data_name=data_name,
                                             table_namespace=persistent_table_namespace,
                                             table_name=persistent_table_name)
                output_table_list.append({"namespace": persistent_table_namespace,
                                          "name": persistent_table_name})
        self.log_output_data_table_tracker(args.job_id, input_table_list, output_table_list)
        # There is only one model output at the current dsl version.
        tracker_client.save_component_output_model(
            model_buffers=cpn_output.model,
            model_alias=task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
            user_specified_run_parameters=user_specified_parameters)
        if cpn_output.cache is not None:
            for i, cache in enumerate(cpn_output.cache):
                if cache is None:
                    continue
                name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                if isinstance(cache, DataCache):
                    tracker.tracking_output_cache(cache, cache_name=name)
                elif isinstance(cache, tuple):
                    tracker.save_output_cache(cache_data=cache[0],
                                              cache_meta=cache[1],
                                              cache_name=name,
                                              output_storage_engine=job_parameters.storage_engine,
                                              output_storage_address=job_parameters.engines_address.get(
                                                  EngineType.STORAGE, {}),
                                              token={"username": user_name})
                else:
                    raise RuntimeError(f"cannot support cache of type {type(cache)} as module run object output")
        if need_run:
            self.report_info["party_status"] = TaskStatus.SUCCESS
        else:
            self.report_info["party_status"] = TaskStatus.PASS
    except PassError as e:
        self.report_info["party_status"] = TaskStatus.PASS
    except Exception as e:
        traceback.print_exc()
        self.report_info["party_status"] = TaskStatus.FAILED
        LOGGER.exception(e)
    finally:
        try:
            self.report_info["end_time"] = current_timestamp()
            self.report_info["elapsed"] = self.report_info["end_time"] - start_time
            self.report_task_info_to_driver()
        except Exception as e:
            self.report_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            LOGGER.exception(e)
    msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} " \
          f"with {self.report_info['party_status']}"
    LOGGER.info(msg)
    print(msg)
    return self.report_info
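# Sketch of the report_info dict _run_ returns; the keys come from the updates
# above, the values are illustrative. party_status ends as one of RUNNING,
# SUCCESS, PASS or FAILED depending on the branches taken.
example_report_info = {
    "job_id": "202201010000000000",
    "component_name": "dataio_0",
    "task_id": "202201010000000000_dataio_0",
    "task_version": 0,
    "role": "guest",
    "party_id": 9999,
    "run_ip": "127.0.0.1",
    "run_pid": 12345,
    "party_status": "success",
    "end_time": 1640995200000,
    "elapsed": 1234,
}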
def get_task_run_args(cls, job_id, role, party_id, task_id, task_version,
                      job_args, job_parameters: RunParameters, task_parameters: RunParameters,
                      input_dsl, filter_type=None, filter_attr=None, get_input_table=False):
    task_run_args = {}
    input_table = {}
    input_table_info_list = []
    if 'idmapping' in role:
        return {}
    for input_type, input_detail in input_dsl.items():
        if filter_type and input_type not in filter_type:
            continue
        if input_type == 'data':
            this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
            for data_type, data_list in input_detail.items():
                data_dict = {}
                for data_key in data_list:
                    data_key_item = data_key.split('.')
                    data_dict[data_key_item[0]] = {data_type: []}
                for data_key in data_list:
                    data_key_item = data_key.split('.')
                    search_component_name, search_data_name = data_key_item[0], data_key_item[1]
                    storage_table_meta = None
                    tracker_client = TrackerClient(job_id=job_id, role=role, party_id=party_id,
                                                   component_name=search_component_name,
                                                   task_id=task_id, task_version=task_version)
                    if search_component_name == 'args':
                        if job_args.get('data', {}).get(search_data_name).get('namespace', '') and \
                                job_args.get('data', {}).get(search_data_name).get('name', ''):
                            storage_table_meta = storage.StorageTableMeta(
                                name=job_args['data'][search_data_name]['name'],
                                namespace=job_args['data'][search_data_name]['namespace'])
                    else:
                        upstream_output_table_infos_json = tracker_client.get_output_data_info(
                            data_name=search_data_name)
                        if upstream_output_table_infos_json:
                            tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                                              component_name=search_component_name,
                                              task_id=task_id, task_version=task_version)
                            upstream_output_table_infos = []
                            for _ in upstream_output_table_infos_json:
                                upstream_output_table_infos.append(
                                    fill_db_model_object(
                                        Tracker.get_dynamic_db_model(TrackingOutputDataInfo, job_id)(), _))
                            output_tables_meta = tracker.get_output_data_table(upstream_output_table_infos)
                            if output_tables_meta:
                                storage_table_meta = output_tables_meta.get(search_data_name, None)
                    args_from_component = this_type_args[search_component_name] = this_type_args.get(
                        search_component_name, {})
                    if get_input_table and storage_table_meta:
                        input_table[data_key] = {'namespace': storage_table_meta.get_namespace(),
                                                 'name': storage_table_meta.get_name()}
                        computing_table = None
                    elif storage_table_meta:
                        LOGGER.info(f"load computing table use {task_parameters.computing_partitions}")
                        computing_table = session.get_computing_session().load(
                            storage_table_meta.get_address(),
                            schema=storage_table_meta.get_schema(),
                            partitions=task_parameters.computing_partitions)
                        input_table_info_list.append({'namespace': storage_table_meta.get_namespace(),
                                                      'name': storage_table_meta.get_name()})
                    else:
                        computing_table = None
                    if not computing_table or not filter_attr or not filter_attr.get("data", None):
                        data_dict[search_component_name][data_type].append(computing_table)
                        args_from_component[data_type] = data_dict[search_component_name][data_type]
                    else:
                        args_from_component[data_type] = dict(
                            [(a, getattr(computing_table, "get_{}".format(a))()) for a in filter_attr["data"]])
        elif input_type == "cache":
            this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
            for search_key in input_detail:
                search_component_name, cache_name = search_key.split(".")
                tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                                  component_name=search_component_name)
                this_type_args[search_component_name] = tracker.get_output_cache(cache_name=cache_name)
        elif input_type in {'model', 'isometric_model'}:
            this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
            for dsl_model_key in input_detail:
                dsl_model_key_items = dsl_model_key.split('.')
                if len(dsl_model_key_items) == 2:
                    search_component_name, search_model_alias = dsl_model_key_items[0], dsl_model_key_items[1]
                elif len(dsl_model_key_items) == 3 and dsl_model_key_items[0] == 'pipeline':
                    search_component_name, search_model_alias = dsl_model_key_items[1], dsl_model_key_items[2]
                else:
                    raise Exception('get input {} failed'.format(input_type))
                tracker_client = TrackerClient(job_id=job_id, role=role, party_id=party_id,
                                               component_name=search_component_name,
                                               model_id=job_parameters.model_id,
                                               model_version=job_parameters.model_version)
                models = tracker_client.read_component_output_model(search_model_alias)
                this_type_args[search_component_name] = models
        else:
            raise Exception(f"unsupported input type {input_type}")
    if get_input_table:
        return input_table
    return task_run_args, input_table_info_list
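# Sketch of an input_dsl that get_task_run_args resolves, matching the three
# branches above. Entries are "component.name" references; the component and
# data names here are hypothetical.
example_input_dsl = {
    "data": {"train_data": ["data_transform_0.data"]},
    "cache": ["intersection_0.cache"],
    "model": ["hetero_lr_0.hetero_lr"],
}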