def clean_task(self, roles, party_ids):
    """Best-effort cleanup of this task's temporary tables.

    For every (role, party_id) pair the per-party temporary session namespace
    is cleaned; the shared federation namespace (the task id itself) is
    cleaned once. Errors are logged and swallowed so cleanup never fails the
    caller.

    :param roles: comma-separated role names, e.g. "guest,host"
    :param party_ids: comma-separated party ids, e.g. "9999,10000"
    """
    schedule_logger(self.job_id).info('clean task {} on {} {}'.format(self.task_id, self.role, self.party_id))
    try:
        for role in roles.split(','):
            for party_id in party_ids.split(','):
                # clean up temporary tables of each participant's session
                namespace_clean = job_utils.generate_session_id(task_id=self.task_id, role=role, party_id=party_id)
                session.clean_tables(namespace=namespace_clean, regex_string='*')
                schedule_logger(self.job_id).info('clean table by namespace {} on {} {} done'.format(namespace_clean, self.role, self.party_id))
        # clean up the last tables of the federation.
        # Hoisted out of the loops: the namespace is self.task_id regardless of
        # role/party, so repeating this per participant was redundant work.
        namespace_clean = self.task_id
        session.clean_tables(namespace=namespace_clean, regex_string='*')
        schedule_logger(self.job_id).info('clean table by namespace {} on {} {} done'.format(namespace_clean, self.role, self.party_id))
    except Exception as e:
        # best-effort: log and continue so the final "done" message is always emitted
        schedule_logger(self.job_id).exception(e)
    schedule_logger(self.job_id).info('clean task {} on {} {} done'.format(self.task_id, self.role, self.party_id))
def clean_task(self, runtime_conf):
    """Clean up this task's temporary computing tables and federation resources.

    Steps (order matters):
      1. open a computing session bound to this task's temporary namespace;
      2. cleanup tables under the per-task computing namespace;
      3. cleanup tables under the federation (task-version) namespace;
      4. stop the computing session;
      5. if the federation engine is RabbitMQ (and role is not "local"),
         initialize a federation session and clean up per-party resources.

    :param runtime_conf: job runtime configuration dict; its 'role' mapping
        (role name -> list of party ids) drives the RabbitMQ party cleanup.
    :return: True on success, False if any step raised (the error is logged).
    """
    schedule_logger(self.job_id).info('clean task {} {} on {} {}'.format(
        self.task_id, self.task_version, self.role, self.party_id))
    try:
        sess = session.Session(
            computing_type=self.job_parameters.computing_engine,
            federation_type=self.job_parameters.federation_engine)
        # clean up temporary tables
        computing_temp_namespace = job_utils.generate_session_id(
            task_id=self.task_id,
            task_version=self.task_version,
            role=self.role,
            party_id=self.party_id)
        if self.job_parameters.computing_engine == ComputingEngine.EGGROLL:
            # a single processor is enough for cleanup work
            session_options = {"eggroll.session.processors.per.node": 1}
        else:
            session_options = {}
        # dedicated "<namespace>_clean" session id so the cleanup session does
        # not collide with the task's own computing session
        sess.init_computing(
            computing_session_id=f"{computing_temp_namespace}_clean",
            options=session_options)
        sess.computing.cleanup(namespace=computing_temp_namespace, name="*")
        schedule_logger(self.job_id).info(
            'clean table by namespace {} on {} {} done'.format(
                computing_temp_namespace, self.role, self.party_id))
        # clean up the last tables of the federation
        federation_temp_namespace = job_utils.generate_task_version_id(
            self.task_id, self.task_version)
        sess.computing.cleanup(namespace=federation_temp_namespace, name="*")
        schedule_logger(self.job_id).info(
            'clean table by namespace {} on {} {} done'.format(
                federation_temp_namespace, self.role, self.party_id))
        sess.computing.stop()
        if self.job_parameters.federation_engine == FederationEngine.RABBITMQ and self.role != "local":
            schedule_logger(self.job_id).info('rabbitmq start clean up')
            # one Party per (role, party_id) pair declared in the runtime conf
            parties = [
                Party(k, p) for k, v in runtime_conf['role'].items() for p in v
            ]
            federation_session_id = job_utils.generate_task_version_id(
                self.task_id, self.task_version)
            # deep-copy so the caller's runtime_conf is not mutated by the
            # injected "local" entry below
            component_parameters_on_party = copy.deepcopy(runtime_conf)
            component_parameters_on_party["local"] = {
                "role": self.role,
                "party_id": self.party_id
            }
            sess.init_federation(
                federation_session_id=federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=self.job_parameters.engines_address.get(
                    EngineType.FEDERATION, {}))
            # NOTE(review): reaches into the private _federation_session
            # attribute; no public cleanup API appears to be available here
            sess._federation_session.cleanup(parties)
            schedule_logger(self.job_id).info('rabbitmq clean up success')
        return True
    except Exception as e:
        schedule_logger(self.job_id).exception(e)
        return False
def run(self, component_parameters=None, args=None):
    """Export (download) a stored table to a local text file.

    Opens the storage table named by the component parameters, writes one
    line per record ("key" alone when the value is falsy, otherwise
    "key<delimiter>value"), reports progress to the job controller every
    10000 lines, and finally records a "count" metric.

    :param component_parameters: dict holding "DownloadParam", "role", "local"
    :param args: unused here (kept for the component interface)
    """
    self.parameters = component_parameters["DownloadParam"]
    self.parameters["role"] = component_parameters["role"]
    self.parameters["local"] = component_parameters["local"]
    name, namespace = self.parameters.get("name"), self.parameters.get(
        "namespace")
    with open(os.path.abspath(self.parameters["output_path"]), "w") as fout:
        with storage.Session.build(
                session_id=job_utils.generate_session_id(
                    self.tracker.task_id,
                    self.tracker.task_version,
                    self.tracker.role,
                    self.tracker.party_id,
                    suffix="storage",
                    random_end=True),
                name=name,
                namespace=namespace) as storage_session:
            data_table = storage_session.get_table()
            count = data_table.count()
            LOGGER.info('===== begin to export data =====')
            lines = 0
            job_info = {
                "job_id": self.tracker.job_id,
                "role": self.tracker.role,
                "party_id": self.tracker.party_id,
            }
            for key, value in data_table.collect():
                if not value:
                    fout.write(key + "\n")
                else:
                    fout.write(key + self.parameters.get("delimiter", ",") +
                               str(value) + "\n")
                lines += 1
                if lines % 2000 == 0:
                    LOGGER.info(
                        "===== export {} lines =====".format(lines))
                if lines % 10000 == 0:
                    # integer-percent progress; count > 0 is guaranteed here
                    # because at least 10000 records were collected
                    job_info["progress"] = lines / count * 100 // 1
                    ControllerClient.update_job(job_info=job_info)
            job_info["progress"] = 100
            ControllerClient.update_job(job_info=job_info)
            # reuse the count computed above instead of issuing a second
            # (potentially expensive, distributed) count() call
            self.callback_metric(
                metric_name='data_access',
                metric_namespace='download',
                metric_data=[Metric("count", count)])
            LOGGER.info("===== export {} lines totally =====".format(lines))
            LOGGER.info('===== export data finish =====')
            LOGGER.info('===== export data file path:{} ====='.format(
                os.path.abspath(self.parameters["output_path"])))
def _run(self, cpn_input: ComponentInputProtocol):
    """Export (download) a stored table to a local text file.

    Newer component-protocol variant: reads parameters from *cpn_input*,
    fetches the table through a computing Session, writes one line per
    record, reports progress every 10000 lines and records a "count" metric.

    :param cpn_input: component input carrying parameters and role info
    :raises Exception: if the requested table does not exist
    """
    self.parameters = cpn_input.parameters
    self.parameters["role"] = cpn_input.roles["role"]
    self.parameters["local"] = cpn_input.roles["local"]
    name, namespace = self.parameters.get("name"), self.parameters.get(
        "namespace")
    with open(os.path.abspath(self.parameters["output_path"]), "w") as fw:
        session = Session(
            job_utils.generate_session_id(
                self.tracker.task_id,
                self.tracker.task_version,
                self.tracker.role,
                self.tracker.party_id,
            ))
        data_table = session.get_table(name=name, namespace=namespace)
        if not data_table:
            raise Exception(f"no found table {name} {namespace}")
        count = data_table.count()
        LOGGER.info("===== begin to export data =====")
        lines = 0
        job_info = {
            "job_id": self.tracker.job_id,
            "role": self.tracker.role,
            "party_id": self.tracker.party_id,
        }
        for key, value in data_table.collect():
            if not value:
                fw.write(key + "\n")
            else:
                fw.write(key + self.parameters.get("delimiter", ",") +
                         str(value) + "\n")
            lines += 1
            if lines % 2000 == 0:
                LOGGER.info("===== export {} lines =====".format(lines))
            if lines % 10000 == 0:
                # integer-percent progress; count > 0 here because at least
                # 10000 records have been collected
                job_info["progress"] = lines / count * 100 // 1
                ControllerClient.update_job(job_info=job_info)
        job_info["progress"] = 100
        ControllerClient.update_job(job_info=job_info)
        # reuse the count computed above instead of a second, potentially
        # expensive distributed count() call
        self.callback_metric(
            metric_name="data_access",
            metric_namespace="download",
            metric_data=[Metric("count", count)],
        )
        LOGGER.info("===== export {} lines totally =====".format(lines))
        LOGGER.info("===== export data finish =====")
        LOGGER.info("===== export data file path:{} =====".format(
            os.path.abspath(self.parameters["output_path"])))
def run(self, component_parameters=None, args=None):
    """Read an input table and publish it under a task-scoped output table.

    Resolves the input table from "ReaderParam", copies it into an output
    table whose engine matches the job's computing engine (conversion when
    the engines differ, format transform otherwise), then logs output data
    info and a "data_info" metric with a small sample of column values.

    :param component_parameters: dict holding "ReaderParam" (one table key)
        and optionally "output_data_name"
    :param args: dict holding "job_parameters" (engines_address, computing_engine)
    """
    self.parameters = component_parameters["ReaderParam"]
    output_storage_address = args["job_parameters"].engines_address[
        EngineType.STORAGE]
    # ReaderParam is expected to carry a single table entry; only the first
    # key is used
    table_key = [key for key in self.parameters.keys()][0]
    computing_engine = args["job_parameters"].computing_engine
    output_table_namespace, output_table_name = data_utils.default_output_table_info(
        task_id=self.tracker.task_id,
        task_version=self.tracker.task_version)
    input_table_meta, output_table_address, output_table_engine = self.convert_check(
        input_name=self.parameters[table_key]['name'],
        input_namespace=self.parameters[table_key]['namespace'],
        output_name=output_table_name,
        output_namespace=output_table_namespace,
        computing_engine=computing_engine,
        output_storage_address=output_storage_address)
    with storage.Session.build(session_id=job_utils.generate_session_id(
            self.tracker.task_id,
            self.tracker.task_version,
            self.tracker.role,
            self.tracker.party_id,
            suffix="storage",
            random_end=True),
            storage_engine=input_table_meta.get_engine(
            )) as input_table_session:
        input_table = input_table_session.get_table(
            name=input_table_meta.get_name(),
            namespace=input_table_meta.get_namespace())
        # update real count to meta info
        input_table.count()
        # Table replication is required
        if input_table_meta.get_engine() != output_table_engine:
            LOGGER.info(
                f"the {input_table_meta.get_engine()} engine input table needs to be converted to {output_table_engine} engine to support computing engine {computing_engine}"
            )
        else:
            LOGGER.info(
                f"the {input_table_meta.get_engine()} input table needs to be transform format"
            )
        # second session bound to the OUTPUT engine; the copy runs with both
        # sessions open
        with storage.Session.build(
                session_id=job_utils.generate_session_id(
                    self.tracker.task_id,
                    self.tracker.task_version,
                    self.tracker.role,
                    self.tracker.party_id,
                    suffix="storage",
                    random_end=True),
                storage_engine=output_table_engine
        ) as output_table_session:
            output_table = output_table_session.create_table(
                address=output_table_address,
                name=output_table_name,
                namespace=output_table_namespace,
                partitions=input_table_meta.partitions)
            self.copy_table(src_table=input_table, dest_table=output_table)
            # update real count to meta info
            output_table.count()
            output_table_meta = StorageTableMeta(
                name=output_table.get_name(),
                namespace=output_table.get_namespace())
    self.tracker.log_output_data_info(
        data_name=component_parameters.get('output_data_name')[0]
        if component_parameters.get('output_data_name') else table_key,
        table_namespace=output_table_meta.get_namespace(),
        table_name=output_table_meta.get_name())
    headers_str = output_table_meta.get_schema().get('header')
    table_info = {}
    if output_table_meta.get_schema() and headers_str:
        # header may be a comma-joined string (display case) or already a list
        if isinstance(headers_str, str):
            data_list = [headers_str.split(',')]
            is_display = True
        else:
            data_list = [headers_str]
            is_display = False
        if is_display:
            for data in output_table_meta.get_part_of_data():
                data_list.append(data[1].split(','))
            # transpose so each row is one column's values; keep at most 5
            # distinct sample values per column
            data = np.array(data_list)
            Tdata = data.transpose()
            for data in Tdata:
                table_info[data[0]] = ','.join(list(set(data[1:]))[:5])
    data_info = {
        "table_name": self.parameters[table_key]['name'],
        "namespace": self.parameters[table_key]['namespace'],
        "table_info": table_info,
        "partitions": output_table_meta.get_partitions(),
        "storage_engine": output_table_meta.get_engine()
    }
    # PATH-engine inputs report file count/path instead of a record count
    if input_table_meta.get_engine() in [StorageEngine.PATH]:
        data_info["file_count"] = output_table_meta.get_count()
        data_info["file_path"] = input_table_meta.get_address().path
    else:
        data_info["count"] = output_table_meta.get_count()
    self.tracker.set_metric_meta(metric_namespace="reader_namespace",
                                 metric_name="reader_name",
                                 metric_meta=MetricMeta(
                                     name='reader',
                                     metric_type='data_info',
                                     extra_metas=data_info))
def run_task(cls):
    """Entry point of a task-executor subprocess.

    Phase 1 (first try): parse CLI arguments, load job conf/DSL, resolve the
    component and its per-party parameters. On failure, mark FAILED and
    return early (returns None in that case).
    Phase 2 (second try): set up logging, tracker, computing/federation
    sessions, import and run the component class, persist output data and
    model, and report the final status to the driver in ``finally``.

    :return: task_info dict with run/status/timing fields (None if phase 1
        failed before timing was initialized).
    """
    task_info = {}
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-j', '--job_id', required=True, type=str,
                            help="job id")
        parser.add_argument('-n', '--component_name', required=True, type=str,
                            help="component name")
        parser.add_argument('-t', '--task_id', required=True, type=str,
                            help="task id")
        parser.add_argument('-v', '--task_version', required=True, type=int,
                            help="task version")
        parser.add_argument('-r', '--role', required=True, type=str,
                            help="role")
        parser.add_argument('-p', '--party_id', required=True, type=int,
                            help="party id")
        parser.add_argument('-c', '--config', required=True, type=str,
                            help="task parameters")
        parser.add_argument('--run_ip', help="run ip", type=str)
        parser.add_argument('--job_server', help="job server", type=str)
        args = parser.parse_args()
        schedule_logger(args.job_id).info('enter task process')
        schedule_logger(args.job_id).info(args)
        # init function args
        if args.job_server:
            RuntimeConfig.init_config(
                JOB_SERVER_HOST=args.job_server.split(':')[0],
                HTTP_PORT=args.job_server.split(':')[1])
            RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
        job_id = args.job_id
        component_name = args.component_name
        task_id = args.task_id
        task_version = args.task_version
        role = args.role
        party_id = args.party_id
        executor_pid = os.getpid()
        task_info.update({
            "job_id": job_id,
            "component_name": component_name,
            "task_id": task_id,
            "task_version": task_version,
            "role": role,
            "party_id": party_id,
            "run_ip": args.run_ip,
            "run_pid": executor_pid
        })
        start_time = current_timestamp()
        job_conf = job_utils.get_job_conf(job_id, role)
        job_dsl = job_conf["job_dsl_path"]
        job_runtime_conf = job_conf["job_runtime_conf_path"]
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_dsl,
            runtime_conf=job_runtime_conf,
            train_runtime_conf=job_conf["train_runtime_conf_path"],
            pipeline_dsl=job_conf["pipeline_dsl_path"])
        # position of this party within its role list selects the right
        # per-party parameter set below
        party_index = job_runtime_conf["role"][role].index(party_id)
        job_args_on_party = TaskExecutor.get_job_args_on_party(
            dsl_parser, job_runtime_conf, role, party_id)
        component = dsl_parser.get_component_info(
            component_name=component_name)
        component_parameters = component.get_role_parameters()
        component_parameters_on_party = component_parameters[role][
            party_index] if role in component_parameters else {}
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()
        component_parameters_on_party[
            'output_data_name'] = task_output_dsl.get('data')
        task_parameters = RunParameters(
            **file_utils.load_json_conf(args.config))
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()
    except Exception as e:
        traceback.print_exc()
        schedule_logger().exception(e)
        task_info["party_status"] = TaskStatus.FAILED
        # NOTE(review): early return yields None while the normal path
        # returns task_info; no driver report happens for this failure mode
        return
    try:
        job_log_dir = os.path.join(
            job_utils.get_job_log_directory(job_id=job_id), role,
            str(party_id))
        task_log_dir = os.path.join(job_log_dir, component_name)
        log.LoggerFactory.set_directory(directory=task_log_dir,
                                        parent_log_dir=job_log_dir,
                                        append_to_parent_log=True,
                                        force=True)
        tracker = Tracker(job_id=job_id,
                          role=role,
                          party_id=party_id,
                          component_name=component_name,
                          task_id=task_id,
                          task_version=task_version,
                          model_id=job_parameters.model_id,
                          model_version=job_parameters.model_version,
                          component_module_name=module_name,
                          job_parameters=job_parameters)
        tracker_client = TrackerClient(
            job_id=job_id,
            role=role,
            party_id=party_id,
            component_name=component_name,
            task_id=task_id,
            task_version=task_version,
            model_id=job_parameters.model_id,
            model_version=job_parameters.model_version,
            component_module_name=module_name,
            job_parameters=job_parameters)
        # "CodePath" looks like "a/b/module.py/ClassName": the package is
        # everything up to the file (with ".py" stripped), the class is last
        run_class_paths = component_parameters_on_party.get(
            'CodePath').split('/')
        run_class_package = '.'.join(
            run_class_paths[:-2]) + '.' + run_class_paths[-2].replace(
                '.py', '')
        run_class_name = run_class_paths[-1]
        task_info["party_status"] = TaskStatus.RUNNING
        cls.report_task_update_to_driver(task_info=task_info)
        # init environment, process is shared globally
        RuntimeConfig.init_config(
            WORK_MODE=job_parameters.work_mode,
            COMPUTING_ENGINE=job_parameters.computing_engine,
            FEDERATION_ENGINE=job_parameters.federation_engine,
            FEDERATED_MODE=job_parameters.federated_mode)
        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
        else:
            session_options = {}
        sess = session.Session(
            computing_type=job_parameters.computing_engine,
            federation_type=job_parameters.federation_engine)
        computing_session_id = job_utils.generate_session_id(
            task_id, task_version, role, party_id)
        sess.init_computing(computing_session_id=computing_session_id,
                            options=session_options)
        federation_session_id = job_utils.generate_task_version_id(
            task_id, task_version)
        component_parameters_on_party[
            "job_parameters"] = job_parameters.to_dict()
        sess.init_federation(
            federation_session_id=federation_session_id,
            runtime_conf=component_parameters_on_party,
            service_conf=job_parameters.engines_address.get(
                EngineType.FEDERATION, {}))
        sess.as_default()
        schedule_logger().info('Run {} {} {} {} {} task'.format(
            job_id, component_name, task_id, role, party_id))
        schedule_logger().info("Component parameters on party {}".format(
            component_parameters_on_party))
        schedule_logger().info("Task input dsl {}".format(task_input_dsl))
        task_run_args = cls.get_task_run_args(
            job_id=job_id,
            role=role,
            party_id=party_id,
            task_id=task_id,
            task_version=task_version,
            job_args=job_args_on_party,
            job_parameters=job_parameters,
            task_parameters=task_parameters,
            input_dsl=task_input_dsl,
        )
        # data-access components need the raw job parameters as well
        if module_name in {"Upload", "Download", "Reader", "Writer"}:
            task_run_args["job_parameters"] = job_parameters
        run_object = getattr(importlib.import_module(run_class_package),
                             run_class_name)()
        run_object.set_tracker(tracker=tracker_client)
        run_object.set_task_version_id(
            task_version_id=job_utils.generate_task_version_id(
                task_id, task_version))
        # add profile logs
        profile.profile_start()
        run_object.run(component_parameters_on_party, task_run_args)
        profile.profile_ends()
        output_data = run_object.save_data()
        if not isinstance(output_data, list):
            output_data = [output_data]
        for index in range(0, len(output_data)):
            # fall back to the positional index when no data name is declared
            data_name = task_output_dsl.get(
                'data')[index] if task_output_dsl.get(
                    'data') else '{}'.format(index)
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=output_data[index],
                output_storage_engine=job_parameters.storage_engine,
                output_storage_address=job_parameters.engines_address.get(
                    EngineType.STORAGE, {}))
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(
                    data_name=data_name,
                    table_namespace=persistent_table_namespace,
                    table_name=persistent_table_name)
        output_model = run_object.export_model()
        # There is only one model output at the current dsl version.
        tracker.save_output_model(
            output_model, task_output_dsl['model'][0]
            if task_output_dsl.get('model') else 'default')
        task_info["party_status"] = TaskStatus.SUCCESS
    except Exception as e:
        task_info["party_status"] = TaskStatus.FAILED
        schedule_logger().exception(e)
    finally:
        # always report timing + final status, even after failure
        try:
            task_info["end_time"] = current_timestamp()
            task_info["elapsed"] = task_info["end_time"] - start_time
            cls.report_task_update_to_driver(task_info=task_info)
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(
            task_id, role, party_id, timestamp_to_date(start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(
            task_id, role, party_id,
            timestamp_to_date(task_info["end_time"])))
        schedule_logger().info('task {} {} {} takes {}s'.format(
            task_id, role, party_id, int(task_info["elapsed"]) / 1000))
        schedule_logger().info('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))
    print('Finish {} {} {} {} {} {} task {}'.format(
        job_id, component_name, task_id, task_version, role, party_id,
        task_info["party_status"]))
    return task_info
def run_task():
    """Entry point of a (legacy) task-executor subprocess.

    Phase 1 (first try): parse CLI arguments, load the task config, resolve
    component parameters; on failure mark the Task FAILED and return.
    Phase 2 (second try): set up logging/tracking, init session and
    federation, import and run the component class, persist outputs; the
    ``finally`` block syncs the final status and logs timing.
    """
    task = Task()
    task.f_create_time = current_timestamp()
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-j', '--job_id', required=True, type=str,
                            help="job id")
        parser.add_argument('-n', '--component_name', required=True, type=str,
                            help="component name")
        parser.add_argument('-t', '--task_id', required=True, type=str,
                            help="task id")
        parser.add_argument('-r', '--role', required=True, type=str,
                            help="role")
        parser.add_argument('-p', '--party_id', required=True, type=str,
                            help="party id")
        parser.add_argument('-c', '--config', required=True, type=str,
                            help="task config")
        parser.add_argument('--processors_per_node',
                            help="processors_per_node", type=int)
        parser.add_argument('--job_server', help="job server", type=str)
        args = parser.parse_args()
        schedule_logger(args.job_id).info('enter task process')
        schedule_logger(args.job_id).info(args)
        # init function args
        if args.job_server:
            RuntimeConfig.init_config(HTTP_PORT=args.job_server.split(':')[1])
            RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
        job_id = args.job_id
        component_name = args.component_name
        task_id = args.task_id
        role = args.role
        # party id arrives as a string on this CLI; normalized to int here
        party_id = int(args.party_id)
        executor_pid = os.getpid()
        task_config = file_utils.load_json_conf(args.config)
        job_parameters = task_config['job_parameters']
        job_initiator = task_config['job_initiator']
        job_args = task_config['job_args']
        task_input_dsl = task_config['input']
        task_output_dsl = task_config['output']
        component_parameters = TaskExecutor.get_parameters(
            job_id, component_name, role, party_id)
        task_parameters = task_config['task_parameters']
        module_name = task_config['module_name']
        TaskExecutor.monkey_patch()
    except Exception as e:
        traceback.print_exc()
        schedule_logger().exception(e)
        task.f_status = TaskStatus.FAILED
        return
    try:
        job_log_dir = os.path.join(
            job_utils.get_job_log_directory(job_id=job_id), role,
            str(party_id))
        task_log_dir = os.path.join(job_log_dir, component_name)
        log_utils.LoggerFactory.set_directory(directory=task_log_dir,
                                              parent_log_dir=job_log_dir,
                                              append_to_parent_log=True,
                                              force=True)
        task.f_job_id = job_id
        task.f_component_name = component_name
        task.f_task_id = task_id
        task.f_role = role
        task.f_party_id = party_id
        task.f_operator = 'python_operator'
        tracker = Tracking(job_id=job_id,
                           role=role,
                           party_id=party_id,
                           component_name=component_name,
                           task_id=task_id,
                           model_id=job_parameters['model_id'],
                           model_version=job_parameters['model_version'],
                           component_module_name=module_name)
        task.f_start_time = current_timestamp()
        task.f_run_ip = get_lan_ip()
        task.f_run_pid = executor_pid
        # "CodePath" looks like "a/b/module.py/ClassName"
        run_class_paths = component_parameters.get('CodePath').split('/')
        run_class_package = '.'.join(
            run_class_paths[:-2]) + '.' + run_class_paths[-2].replace(
                '.py', '')
        run_class_name = run_class_paths[-1]
        task.f_status = TaskStatus.RUNNING
        TaskExecutor.sync_task_status(
            job_id=job_id,
            component_name=component_name,
            task_id=task_id,
            role=role,
            party_id=party_id,
            initiator_party_id=job_initiator.get('party_id', None),
            initiator_role=job_initiator.get('role', None),
            task_info=task.to_json())
        # init environment, process is shared globally
        RuntimeConfig.init_config(WORK_MODE=job_parameters['work_mode'],
                                  BACKEND=job_parameters.get('backend', 0))
        if args.processors_per_node and args.processors_per_node > 0 and RuntimeConfig.BACKEND == Backend.EGGROLL:
            session_options = {
                "eggroll.session.processors.per.node":
                args.processors_per_node
            }
        else:
            session_options = {}
        session.init(job_id=job_utils.generate_session_id(
            task_id, role, party_id),
                     mode=RuntimeConfig.WORK_MODE,
                     backend=RuntimeConfig.BACKEND,
                     options=session_options)
        federation.init(job_id=task_id, runtime_conf=component_parameters)
        schedule_logger().info('run {} {} {} {} {} task'.format(
            job_id, component_name, task_id, role, party_id))
        schedule_logger().info(component_parameters)
        schedule_logger().info(task_input_dsl)
        task_run_args = TaskExecutor.get_task_run_args(
            job_id=job_id,
            role=role,
            party_id=party_id,
            task_id=task_id,
            job_args=job_args,
            job_parameters=job_parameters,
            task_parameters=task_parameters,
            input_dsl=task_input_dsl,
            if_save_as_task_input_data=job_parameters.get(
                "save_as_task_input_data", SAVE_AS_TASK_INPUT_DATA_SWITCH))
        run_object = getattr(importlib.import_module(run_class_package),
                             run_class_name)()
        run_object.set_tracker(tracker=tracker)
        run_object.set_taskid(taskid=task_id)
        run_object.run(component_parameters, task_run_args)
        output_data = run_object.save_data()
        tracker.save_output_data_table(
            output_data,
            task_output_dsl.get('data')[0]
            if task_output_dsl.get('data') else 'component')
        output_model = run_object.export_model()
        # There is only one model output at the current dsl version.
        tracker.save_output_model(
            output_model, task_output_dsl['model'][0]
            if task_output_dsl.get('model') else 'default')
        task.f_status = TaskStatus.COMPLETE
    except Exception as e:
        task.f_status = TaskStatus.FAILED
        schedule_logger().exception(e)
    finally:
        # sync_success distinguishes "task ran" from "status sync failed":
        # if the final sync raises, the task is reported as FAILED below
        sync_success = False
        try:
            task.f_end_time = current_timestamp()
            task.f_elapsed = task.f_end_time - task.f_start_time
            task.f_update_time = current_timestamp()
            TaskExecutor.sync_task_status(
                job_id=job_id,
                component_name=component_name,
                task_id=task_id,
                role=role,
                party_id=party_id,
                initiator_party_id=job_initiator.get('party_id', None),
                initiator_role=job_initiator.get('role', None),
                task_info=task.to_json())
            sync_success = True
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(
            task_id, role, party_id, timestamp_to_date(task.f_start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(
            task_id, role, party_id, timestamp_to_date(task.f_end_time)))
        schedule_logger().info('task {} {} {} takes {}s'.format(
            task_id, role, party_id, int(task.f_elapsed) / 1000))
        schedule_logger().info('finish {} {} {} {} {} {} task'.format(
            job_id, component_name, task_id, role, party_id,
            task.f_status if sync_success else TaskStatus.FAILED))
    print('finish {} {} {} {} {} {} task'.format(
        job_id, component_name, task_id, role, party_id,
        task.f_status if sync_success else TaskStatus.FAILED))
def get_task_run_args(job_id, role, party_id, task_id, job_args,
                      job_parameters, task_parameters, input_dsl,
                      if_save_as_task_input_data, filter_type=None,
                      filter_attr=None):
    """Resolve a task's declared inputs (data tables and models) into runtime args.

    For 'data' inputs: each key "component.data_name" (or "args.data_name"
    for job-level inputs) is resolved to a table; optionally the table is
    copied ("save as") into the task's own session namespace first.
    For 'model'/'isometric_model' inputs: the referenced component's output
    model is loaded via Tracking.

    :param input_dsl: mapping input_type -> input detail from the job DSL
    :param if_save_as_task_input_data: when truthy, copy input tables into
        the task session namespace before use
    :param filter_type: when set, only input types in this collection are
        resolved
    :param filter_attr: when set with a "data" key, tables are replaced by a
        dict of the listed attributes (via their get_<attr>() accessors)
    :return: dict input_type -> component_name -> resolved value(s)
    """
    task_run_args = {}
    for input_type, input_detail in input_dsl.items():
        if filter_type and input_type not in filter_type:
            continue
        if input_type == 'data':
            this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
            for data_type, data_list in input_detail.items():
                for data_key in data_list:
                    # data_key format: "<component_name>.<data_name>"
                    data_key_item = data_key.split('.')
                    search_component_name, search_data_name = data_key_item[0], data_key_item[1]
                    if search_component_name == 'args':
                        # NOTE(review): if search_data_name is missing from
                        # job_args['data'], .get(search_data_name) returns
                        # None and the chained .get() raises — presumably the
                        # DSL guarantees the key exists; verify upstream
                        if job_args.get('data', {}).get(search_data_name).get('namespace', '') and job_args.get(
                                'data', {}).get(search_data_name).get('name', ''):
                            data_table = session.table(
                                namespace=job_args['data'][search_data_name]['namespace'],
                                name=job_args['data'][search_data_name]['name'])
                        else:
                            data_table = None
                    else:
                        data_table = Tracking(job_id=job_id, role=role, party_id=party_id,
                                              component_name=search_component_name).get_output_data_table(
                            data_name=search_data_name)
                    args_from_component = this_type_args[search_component_name] = this_type_args.get(
                        search_component_name, {})
                    # todo: If the same component has more than one identical input, save as is repeated
                    if if_save_as_task_input_data:
                        if data_table:
                            schedule_logger().info("start save as task {} input data table {} {}".format(
                                task_id, data_table.get_namespace(), data_table.get_name()))
                            # preserve metas/schema across the save_as copy
                            origin_table_metas = data_table.get_metas()
                            origin_table_schema = data_table.schema
                            save_as_options = {"store_type": StoreTypes.ROLLPAIR_IN_MEMORY} if SAVE_AS_TASK_INPUT_DATA_IN_MEMORY else {}
                            data_table = data_table.save_as(
                                namespace=job_utils.generate_session_id(task_id=task_id,
                                                                        role=role,
                                                                        party_id=party_id),
                                name=data_table.get_name(),
                                partition=task_parameters['input_data_partition']
                                if task_parameters.get('input_data_partition', 0) > 0
                                else data_table.get_partitions(),
                                options=save_as_options)
                            data_table.save_metas(origin_table_metas)
                            data_table.schema = origin_table_schema
                            schedule_logger().info("save as task {} input data table to {} {} done".format(
                                task_id, data_table.get_namespace(), data_table.get_name()))
                        else:
                            schedule_logger().info("pass save as task {} input data table, because the table is none".format(task_id))
                    else:
                        schedule_logger().info("pass save as task {} input data table, because the switch is off".format(task_id))
                    if not data_table or not filter_attr or not filter_attr.get("data", None):
                        args_from_component[data_type] = data_table
                    else:
                        # expose only the requested attributes instead of the table object
                        args_from_component[data_type] = dict(
                            [(a, getattr(data_table, "get_{}".format(a))()) for a in filter_attr["data"]])
        elif input_type in ['model', 'isometric_model']:
            this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
            for dsl_model_key in input_detail:
                dsl_model_key_items = dsl_model_key.split('.')
                # accepted forms: "component.alias" or "pipeline.component.alias"
                if len(dsl_model_key_items) == 2:
                    search_component_name, search_model_alias = dsl_model_key_items[0], dsl_model_key_items[1]
                elif len(dsl_model_key_items) == 3 and dsl_model_key_items[0] == 'pipeline':
                    search_component_name, search_model_alias = dsl_model_key_items[1], dsl_model_key_items[2]
                else:
                    raise Exception('get input {} failed'.format(input_type))
                models = Tracking(job_id=job_id, role=role, party_id=party_id,
                                  component_name=search_component_name,
                                  model_id=job_parameters['model_id'],
                                  model_version=job_parameters['model_version']).get_output_model(
                    model_alias=search_model_alias)
                this_type_args[search_component_name] = models
    return task_run_args
def run(self, component_parameters=None, args=None):
    """Upload a local file into a storage-engine table.

    Validates the input file, resolves table name/namespace (auto-generated
    from the file name when absent), optionally destroys a pre-existing
    table, builds an engine-specific address, creates the table and loads
    the data, then removes the temporary upload directory if one was used.

    :param component_parameters: dict holding "UploadParam", "role", "local"
    :param args: dict holding "job_parameters" (fallback storage engine/address)
    :raises Exception: missing/empty input file, bad 'head' flag, bad partition count
    :raises RuntimeError: unsupported storage engine
    """
    self.parameters = component_parameters["UploadParam"]
    LOGGER.info(self.parameters)
    LOGGER.info(args)
    self.parameters["role"] = component_parameters["role"]
    self.parameters["local"] = component_parameters["local"]
    storage_engine = self.parameters["storage_engine"]
    storage_address = self.parameters["storage_address"]
    # if not set storage, use job storage as default
    if not storage_engine:
        storage_engine = args["job_parameters"].storage_engine
    if not storage_address:
        storage_address = args["job_parameters"].engines_address[
            EngineType.STORAGE]
    job_id = self.task_version_id.split("_")[0]
    if not os.path.isabs(self.parameters.get("file", "")):
        self.parameters["file"] = os.path.join(
            file_utils.get_project_base_directory(), self.parameters["file"])
    if not os.path.exists(self.parameters["file"]):
        raise Exception("%s is not exist, please check the configure" %
                        (self.parameters["file"]))
    if not os.path.getsize(self.parameters["file"]):
        raise Exception("%s is an empty file" % (self.parameters["file"]))
    name, namespace = self.parameters.get("name"), self.parameters.get(
        "namespace")
    _namespace, _table_name = self.generate_table_name(
        self.parameters["file"])
    if namespace is None:
        namespace = _namespace
    if name is None:
        name = _table_name
    read_head = self.parameters['head']
    if read_head == 0:
        head = False
    elif read_head == 1:
        head = True
    else:
        raise Exception("'head' in conf.json should be 0 or 1")
    partitions = self.parameters["partition"]
    if partitions <= 0 or partitions >= self.MAX_PARTITIONS:
        raise Exception(
            "Error number of partition, it should between %d and %d" %
            (0, self.MAX_PARTITIONS))
    # first session: optionally destroy a pre-existing table of the same name
    with storage.Session.build(session_id=job_utils.generate_session_id(
            self.tracker.task_id,
            self.tracker.task_version,
            self.tracker.role,
            self.tracker.party_id,
            suffix="storage",
            random_end=True),
            namespace=namespace,
            name=name) as storage_session:
        if self.parameters.get("destroy", False):
            table = storage_session.get_table()
            if table:
                LOGGER.info(
                    f"destroy table name: {name} namespace: {namespace} engine: {table.get_engine()}"
                )
                table.destroy()
            else:
                LOGGER.info(
                    f"can not found table name: {name} namespace: {namespace}, pass destroy"
                )
    address_dict = storage_address.copy()
    # second session: bound to the target storage engine for the actual load
    with storage.Session.build(
            session_id=job_utils.generate_session_id(
                self.tracker.task_id,
                self.tracker.task_version,
                self.tracker.role,
                self.tracker.party_id,
                suffix="storage",
                random_end=True),
            storage_engine=storage_engine,
            options=self.parameters.get("options")) as storage_session:
        # engine-specific address fields
        if storage_engine in {
                StorageEngine.EGGROLL, StorageEngine.STANDALONE
        }:
            upload_address = {
                "name": name,
                "namespace": namespace,
                "storage_type": EggRollStorageType.ROLLPAIR_LMDB
            }
        elif storage_engine in {StorageEngine.MYSQL}:
            upload_address = {"db": namespace, "name": name}
        elif storage_engine in {StorageEngine.HDFS}:
            upload_address = {
                "path": data_utils.default_input_fs_path(
                    name=name,
                    namespace=namespace,
                    prefix=address_dict.get("path_prefix"))
            }
        else:
            raise RuntimeError(
                f"can not support this storage engine: {storage_engine}")
        address_dict.update(upload_address)
        LOGGER.info(
            f"upload to {storage_engine} storage, address: {address_dict}")
        address = storage.StorageTableMeta.create_address(
            storage_engine=storage_engine, address_dict=address_dict)
        self.parameters["partitions"] = partitions
        self.parameters["name"] = name
        self.table = storage_session.create_table(address=address,
                                                  **self.parameters)
        data_table_count = self.save_data_table(job_id, name, namespace,
                                                head)
        self.table.get_meta().update_metas(in_serialized=True)
        LOGGER.info("------------load data finish!-----------------")
        # rm tmp file (best-effort: a failure here must not fail the upload)
        try:
            if '{}/fate_upload_tmp'.format(job_id) in self.parameters['file']:
                LOGGER.info("remove tmp upload file")
                LOGGER.info(os.path.dirname(self.parameters["file"]))
                shutil.rmtree(os.path.dirname(self.parameters["file"]))
        except Exception:
            # was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed
            LOGGER.info("remove tmp file failed")
            LOGGER.info("file: {}".format(self.parameters["file"]))
        LOGGER.info("total data_count: {}".format(data_table_count))
        LOGGER.info("table name: {}, table namespace: {}".format(
            name, namespace))
def start_task_worker(cls, worker_name, task: Task, task_parameters: RunParameters = None, executable: list = None, extra_env: dict = None, **kwargs): worker_id, config_dir, log_dir = cls.get_process_dirs( worker_name=worker_name, job_id=task.f_job_id, role=task.f_role, party_id=task.f_party_id, task=task) session_id = job_utils.generate_session_id(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id) federation_session_id = job_utils.generate_task_version_id( task.f_task_id, task.f_task_version) info_kwargs = {} specific_cmd = [] if worker_name is WorkerName.TASK_EXECUTOR: from fate_flow.worker.task_executor import TaskExecutor module_file_path = sys.modules[TaskExecutor.__module__].__file__ else: raise Exception(f"not support {worker_name} worker") if task_parameters is None: task_parameters = RunParameters(**job_utils.get_job_parameters( task.f_job_id, task.f_role, task.f_party_id)) config = task_parameters.to_dict() config["src_user"] = kwargs.get("src_user") config_path, result_path = cls.get_config(config_dir=config_dir, config=config, log_dir=log_dir) if executable: process_cmd = executable else: process_cmd = [sys.executable or "python3"] common_cmd = [ module_file_path, "--job_id", task.f_job_id, "--component_name", task.f_component_name, "--task_id", task.f_task_id, "--task_version", task.f_task_version, "--role", task.f_role, "--party_id", task.f_party_id, "--config", config_path, '--result', result_path, "--log_dir", log_dir, "--parent_log_dir", os.path.dirname(log_dir), "--worker_id", worker_id, "--run_ip", RuntimeConfig.JOB_SERVER_HOST, "--job_server", f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}", "--session_id", session_id, "--federation_session_id", federation_session_id, ] process_cmd.extend(common_cmd) process_cmd.extend(specific_cmd) env = cls.get_env(task.f_job_id, task.f_provider_info) if extra_env: env.update(extra_env) schedule_logger(task.f_job_id).info( f"task {task.f_task_id} {task.f_task_version} on 
{task.f_role} {task.f_party_id} {worker_name} worker subprocess is ready" ) p = process_utils.run_subprocess(job_id=task.f_job_id, config_dir=config_dir, process_cmd=process_cmd, added_env=env, log_dir=log_dir, cwd_dir=config_dir, process_name=worker_name.value, process_id=worker_id) cls.save_worker_info(task=task, worker_name=worker_name, worker_id=worker_id, run_ip=RuntimeConfig.JOB_SERVER_HOST, run_pid=p.pid, config=config, cmd=process_cmd, **info_kwargs) return {"run_pid": p.pid, "worker_id": worker_id, "cmd": process_cmd}
def _run(self, cpn_input: ComponentInputProtocol): self.parameters = cpn_input.parameters LOGGER.info(self.parameters) self.parameters["role"] = cpn_input.roles["role"] self.parameters["local"] = cpn_input.roles["local"] storage_engine = self.parameters["storage_engine"].upper() storage_address = self.parameters["storage_address"] # if not set storage, use job storage as default if not storage_engine: storage_engine = cpn_input.job_parameters.storage_engine self.storage_engine = storage_engine if not storage_address: storage_address = cpn_input.job_parameters.engines_address[ EngineType.STORAGE] job_id = self.task_version_id.split("_")[0] if not os.path.isabs(self.parameters.get("file", "")): self.parameters["file"] = os.path.join(get_fate_flow_directory(), self.parameters["file"]) if not os.path.exists(self.parameters["file"]): raise Exception("%s is not exist, please check the configure" % (self.parameters["file"])) if not os.path.getsize(self.parameters["file"]): raise Exception("%s is an empty file" % (self.parameters["file"])) name, namespace = self.parameters.get("name"), self.parameters.get( "namespace") _namespace, _table_name = self.generate_table_name( self.parameters["file"]) if namespace is None: namespace = _namespace if name is None: name = _table_name read_head = self.parameters["head"] if read_head == 0: head = False elif read_head == 1: head = True else: raise Exception("'head' in conf.json should be 0 or 1") partitions = self.parameters["partition"] if partitions <= 0 or partitions >= self.MAX_PARTITIONS: raise Exception( "Error number of partition, it should between %d and %d" % (0, self.MAX_PARTITIONS)) self.session_id = job_utils.generate_session_id( self.tracker.task_id, self.tracker.task_version, self.tracker.role, self.tracker.party_id, ) sess = Session.get_global() self.session = sess if self.parameters.get("destroy", False): table = sess.get_table(namespace=namespace, name=name) if table: LOGGER.info( f"destroy table name: {name} 
namespace: {namespace} engine: {table.engine}" ) try: table.destroy() except Exception as e: LOGGER.error(e) else: LOGGER.info( f"can not found table name: {name} namespace: {namespace}, pass destroy" ) address_dict = storage_address.copy() storage_session = sess.storage(storage_engine=storage_engine, options=self.parameters.get("options")) upload_address = {} if storage_engine in {StorageEngine.EGGROLL, StorageEngine.STANDALONE}: upload_address = { "name": name, "namespace": namespace, "storage_type": EggRollStoreType.ROLLPAIR_LMDB, } elif storage_engine in {StorageEngine.MYSQL, StorageEngine.HIVE}: if not address_dict.get("db") or not address_dict.get("name"): upload_address = {"db": namespace, "name": name} elif storage_engine in {StorageEngine.PATH}: upload_address = {"path": self.parameters["file"]} elif storage_engine in {StorageEngine.HDFS}: upload_address = { "path": default_input_fs_path( name=name, namespace=namespace, prefix=address_dict.get("path_prefix"), ) } elif storage_engine in {StorageEngine.LOCALFS}: upload_address = { "path": default_input_fs_path(name=name, namespace=namespace, storage_engine=storage_engine) } else: raise RuntimeError( f"can not support this storage engine: {storage_engine}") address_dict.update(upload_address) LOGGER.info( f"upload to {storage_engine} storage, address: {address_dict}") address = storage.StorageTableMeta.create_address( storage_engine=storage_engine, address_dict=address_dict) self.parameters["partitions"] = partitions self.parameters["name"] = name self.table = storage_session.create_table( address=address, origin=StorageTableOrigin.UPLOAD, **self.parameters) if storage_engine not in [StorageEngine.PATH]: data_table_count = self.save_data_table(job_id, name, namespace, storage_engine, head) else: data_table_count = self.get_data_table_count( self.parameters["file"], name, namespace) self.table.meta.update_metas(in_serialized=True) DataTableTracker.create_table_tracker( table_name=name, 
table_namespace=namespace, entity_info={ "job_id": job_id, "have_parent": False }, ) LOGGER.info("------------load data finish!-----------------") # rm tmp file try: if "{}/fate_upload_tmp".format(job_id) in self.parameters["file"]: LOGGER.info("remove tmp upload file") LOGGER.info(os.path.dirname(self.parameters["file"])) shutil.rmtree(os.path.dirname(self.parameters["file"])) except: LOGGER.info("remove tmp file failed") LOGGER.info("file: {}".format(self.parameters["file"])) LOGGER.info("total data_count: {}".format(data_table_count)) LOGGER.info("table name: {}, table namespace: {}".format( name, namespace))