def _run_(self):
    # todo: All function calls where errors should be thrown
    args = self.args
    start_time = current_timestamp()
    try:
        LOGGER.info(
            f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        self.report_info.update({
            "job_id": args.job_id,
            "component_name": args.component_name,
            "task_id": args.task_id,
            "task_version": args.task_version,
            "role": args.role,
            "party_id": args.party_id,
            "run_ip": args.run_ip,
            "run_pid": self.run_pid
        })
        operation_client = OperationClient()
        job_configuration = JobConfiguration(**operation_client.get_job_conf(
            args.job_id, args.role, args.party_id, args.component_name, args.task_id, args.task_version))
        task_parameters_conf = args.config
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_configuration.dsl,
            runtime_conf=job_configuration.runtime_conf,
            train_runtime_conf=job_configuration.train_runtime_conf,
            pipeline_dsl=None)
        job_parameters = dsl_parser.get_job_parameters(job_configuration.runtime_conf)
        user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
        LOGGER.info(f"user name:{user_name}")
        src_user = task_parameters_conf.get("src_user")
        task_parameters = RunParameters(**task_parameters_conf)
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()

        job_args_on_party = TaskExecutor.get_job_args_on_party(
            dsl_parser, job_configuration.runtime_conf_on_party, args.role, args.party_id)
        component = dsl_parser.get_component_info(component_name=args.component_name)
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()

        kwargs = {
            'job_id': args.job_id,
            'role': args.role,
            'party_id': args.party_id,
            'component_name': args.component_name,
            'task_id': args.task_id,
            'task_version': args.task_version,
            'model_id': job_parameters.model_id,
            'model_version': job_parameters.model_version,
            'component_module_name': module_name,
            'job_parameters': job_parameters,
        }
        tracker = Tracker(**kwargs)
        tracker_client = TrackerClient(**kwargs)
        checkpoint_manager = CheckpointManager(**kwargs)

        self.report_info["party_status"] = TaskStatus.RUNNING
        self.report_task_info_to_driver()

        previous_components_parameters = tracker_client.get_model_run_parameters()
        LOGGER.info(
            f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}")

        component_provider, component_parameters_on_party, user_specified_parameters = \
            ProviderManager.get_component_run_info(
                dsl_parser=dsl_parser,
                component_name=args.component_name,
                role=args.role,
                party_id=args.party_id,
                previous_components_parameters=previous_components_parameters)
        RuntimeConfig.set_component_provider(component_provider)
        LOGGER.info(
            f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")

        flow_feeded_parameters = {"output_data_name": task_output_dsl.get("data")}

        # init environment, process is shared globally
        RuntimeConfig.init_config(
            COMPUTING_ENGINE=job_parameters.computing_engine,
            FEDERATION_ENGINE=job_parameters.federation_engine,
            FEDERATED_MODE=job_parameters.federated_mode)

        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
            session_options["python.path"] = os.getenv("PYTHONPATH")
            session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
        else:
            session_options = {}

        sess = session.Session(session_id=args.session_id)
        sess.as_global()
        sess.init_computing(computing_session_id=args.session_id, options=session_options)
        component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
        roles = job_configuration.runtime_conf["role"]
        if set(roles) == {"local"}:
            LOGGER.info("only local roles, pass init federation")
        else:
            sess.init_federation(
                federation_session_id=args.federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))

        LOGGER.info(
            f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        LOGGER.info(
            f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
        LOGGER.info(f"task input dsl {task_input_dsl}")

        task_run_args, input_table_list = self.get_task_run_args(
            job_id=args.job_id,
            role=args.role,
            party_id=args.party_id,
            task_id=args.task_id,
            task_version=args.task_version,
            job_args=job_args_on_party,
            job_parameters=job_parameters,
            task_parameters=task_parameters,
            input_dsl=task_input_dsl,
        )
        if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
            task_run_args["job_parameters"] = job_parameters
        LOGGER.info(f"task input args {task_run_args}")

        need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
        provider_interface = provider_utils.get_provider_interface(provider=component_provider)
        run_object = provider_interface.get(
            module_name,
            ComponentRegistry.get_provider_components(
                provider_name=component_provider.name,
                provider_version=component_provider.version)).get_run_obj(self.args.role)
        flow_feeded_parameters.update({"table_info": input_table_list})
        cpn_input = ComponentInput(
            tracker=tracker_client,
            checkpoint_manager=checkpoint_manager,
            task_version_id=job_utils.generate_task_version_id(args.task_id, args.task_version),
            parameters=component_parameters_on_party["ComponentParam"],
            datasets=task_run_args.get("data", None),
            caches=task_run_args.get("cache", None),
            models=dict(
                model=task_run_args.get("model"),
                isometric_model=task_run_args.get("isometric_model"),
            ),
            job_parameters=job_parameters,
            roles=dict(
                role=component_parameters_on_party["role"],
                local=component_parameters_on_party["local"],
            ),
            flow_feeded_parameters=flow_feeded_parameters,
        )

        profile_log_enabled = False
        try:
            if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                profile_log_enabled = True
        except Exception as e:
            LOGGER.warning(e)
        if profile_log_enabled:
            # add profile logs
            LOGGER.info("profile logging is enabled")
            profile.profile_start()
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()
            profile.profile_ends()
        else:
            LOGGER.info("profile logging is disabled")
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()

        output_table_list = []
        LOGGER.info(f"task output data {cpn_output.data}")
        for index, data in enumerate(cpn_output.data):
            data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else '{}'.format(index)
            # todo: the token depends on the engine type, maybe in job parameters
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=data,
                output_storage_engine=job_parameters.storage_engine,
                token={"username": user_name})
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(
                    data_name=data_name,
                    table_namespace=persistent_table_namespace,
                    table_name=persistent_table_name)
                output_table_list.append({"namespace": persistent_table_namespace, "name": persistent_table_name})
        self.log_output_data_table_tracker(args.job_id, input_table_list, output_table_list)

        # There is only one model output at the current dsl version.
        tracker_client.save_component_output_model(
            model_buffers=cpn_output.model,
            model_alias=task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
            user_specified_run_parameters=user_specified_parameters)

        if cpn_output.cache is not None:
            for i, cache in enumerate(cpn_output.cache):
                if cache is None:
                    continue
                name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                if isinstance(cache, DataCache):
                    tracker.tracking_output_cache(cache, cache_name=name)
                elif isinstance(cache, tuple):
                    tracker.save_output_cache(
                        cache_data=cache[0],
                        cache_meta=cache[1],
                        cache_name=name,
                        output_storage_engine=job_parameters.storage_engine,
                        output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                        token={"username": user_name})
                else:
                    raise RuntimeError(
                        f"can not support type {type(cache)} module run object output cache")

        if need_run:
            self.report_info["party_status"] = TaskStatus.SUCCESS
        else:
            self.report_info["party_status"] = TaskStatus.PASS
    except PassError as e:
        self.report_info["party_status"] = TaskStatus.PASS
    except Exception as e:
        traceback.print_exc()
        self.report_info["party_status"] = TaskStatus.FAILED
        LOGGER.exception(e)
    finally:
        try:
            self.report_info["end_time"] = current_timestamp()
            self.report_info["elapsed"] = self.report_info["end_time"] - start_time
            self.report_task_info_to_driver()
        except Exception as e:
            self.report_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            LOGGER.exception(e)
    msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} with {self.report_info['party_status']}"
    LOGGER.info(msg)
    print(msg)
    return self.report_info

@classmethod
def run_task(cls, **kwargs):
    task_info = {}
    try:
        job_id, component_name, task_id, task_version, role, party_id, run_ip, config, job_server = \
            cls.get_run_task_args(kwargs)
        if job_server:
            RuntimeConfig.init_config(JOB_SERVER_HOST=job_server.split(':')[0],
                                      HTTP_PORT=job_server.split(':')[1])
            RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
        executor_pid = os.getpid()
        task_info.update({
            "job_id": job_id,
            "component_name": component_name,
            "task_id": task_id,
            "task_version": task_version,
            "role": role,
            "party_id": party_id,
            "run_ip": run_ip,
            "run_pid": executor_pid
        })
        start_time = current_timestamp()
        operation_client = OperationClient()
        job_conf = operation_client.get_job_conf(job_id, role)
        job_dsl = job_conf["job_dsl_path"]
        job_runtime_conf = job_conf["job_runtime_conf_path"]
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_dsl,
            runtime_conf=job_runtime_conf,
            train_runtime_conf=job_conf["train_runtime_conf_path"],
            pipeline_dsl=job_conf["pipeline_dsl_path"])
        party_index = job_runtime_conf["role"][role].index(party_id)
        job_args_on_party = TaskExecutor.get_job_args_on_party(dsl_parser, job_runtime_conf, role, party_id)
        component = dsl_parser.get_component_info(component_name=component_name)
        component_parameters = component.get_role_parameters()
        component_parameters_on_party = component_parameters[role][party_index] if role in component_parameters else {}
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()
        component_parameters_on_party['output_data_name'] = task_output_dsl.get('data')
        json_conf = operation_client.load_json_conf(job_id, config)
        user_name = dsl_parser.get_job_parameters().get(role, {}).get(party_id, {}).get("user", '')
        schedule_logger(job_id).info(f"user name:{user_name}")
        src_user = json_conf.get("src_user")
        task_parameters = RunParameters(**json_conf)
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()

        job_log_dir = os.path.join(job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
        task_log_dir = os.path.join(job_log_dir, component_name)
        log.LoggerFactory.set_directory(directory=task_log_dir,
                                        parent_log_dir=job_log_dir,
                                        append_to_parent_log=True,
                                        force=True)

        tracker = Tracker(job_id=job_id,
                          role=role,
                          party_id=party_id,
                          component_name=component_name,
                          task_id=task_id,
                          task_version=task_version,
                          model_id=job_parameters.model_id,
                          model_version=job_parameters.model_version,
                          component_module_name=module_name,
                          job_parameters=job_parameters)
        tracker_client = TrackerClient(job_id=job_id,
                                       role=role,
                                       party_id=party_id,
                                       component_name=component_name,
                                       task_id=task_id,
                                       task_version=task_version,
                                       model_id=job_parameters.model_id,
                                       model_version=job_parameters.model_version,
                                       component_module_name=module_name,
                                       job_parameters=job_parameters)

        run_class_paths = component_parameters_on_party.get('CodePath').split('/')
        run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
        run_class_name = run_class_paths[-1]

        task_info["party_status"] = TaskStatus.RUNNING
        cls.report_task_update_to_driver(task_info=task_info)

        # init environment, process is shared globally
        RuntimeConfig.init_config(WORK_MODE=job_parameters.work_mode,
                                  COMPUTING_ENGINE=job_parameters.computing_engine,
                                  FEDERATION_ENGINE=job_parameters.federation_engine,
                                  FEDERATED_MODE=job_parameters.federated_mode)

        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
        else:
            session_options = {}

        sess = session.Session(computing_type=job_parameters.computing_engine,
                               federation_type=job_parameters.federation_engine)
        computing_session_id = job_utils.generate_session_id(task_id, task_version, role, party_id)
        sess.init_computing(computing_session_id=computing_session_id, options=session_options)
        federation_session_id = job_utils.generate_task_version_id(task_id, task_version)
        component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
        sess.init_federation(federation_session_id=federation_session_id,
                             runtime_conf=component_parameters_on_party,
                             service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))
        sess.as_default()

        schedule_logger().info('Run {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id))
        schedule_logger().info("Component parameters on party {}".format(component_parameters_on_party))
        schedule_logger().info("Task input dsl {}".format(task_input_dsl))
        task_run_args = cls.get_task_run_args(job_id=job_id,
                                              role=role,
                                              party_id=party_id,
                                              task_id=task_id,
                                              task_version=task_version,
                                              job_args=job_args_on_party,
                                              job_parameters=job_parameters,
                                              task_parameters=task_parameters,
                                              input_dsl=task_input_dsl,
                                              )
        if module_name in {"Upload", "Download", "Reader", "Writer"}:
            task_run_args["job_parameters"] = job_parameters
        run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
        run_object.set_tracker(tracker=tracker_client)
        run_object.set_task_version_id(task_version_id=job_utils.generate_task_version_id(task_id, task_version))
        # add profile logs
        profile.profile_start()
        run_object.run(component_parameters_on_party, task_run_args)
        # profile.profile_ends()
        output_data = run_object.save_data()
        if not isinstance(output_data, list):
            output_data = [output_data]
        for index in range(0, len(output_data)):
            data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else '{}'.format(index)
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=output_data[index],
                output_storage_engine=job_parameters.storage_engine,
                output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                user_name=user_name)
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(data_name=data_name,
                                             table_namespace=persistent_table_namespace,
                                             table_name=persistent_table_name)
        output_model = run_object.export_model()
        # There is only one model output at the current dsl version.
        tracker.save_output_model(output_model,
                                  task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
                                  tracker_client=tracker_client)
        task_info["party_status"] = TaskStatus.SUCCESS
    except Exception as e:
        traceback.print_exc()
        task_info["party_status"] = TaskStatus.FAILED
        schedule_logger().exception(e)
    finally:
        try:
            task_info["end_time"] = current_timestamp()
            task_info["elapsed"] = task_info["end_time"] - start_time
            cls.report_task_update_to_driver(task_info=task_info)
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            schedule_logger().exception(e)
    schedule_logger().info(
        'task {} {} {} start time: {}'.format(task_id, role, party_id, timestamp_to_date(start_time)))
    schedule_logger().info(
        'task {} {} {} end time: {}'.format(task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
    schedule_logger().info(
        'task {} {} {} takes {}s'.format(task_id, role, party_id, int(task_info["elapsed"]) / 1000))
    schedule_logger().info(
        'Finish {} {} {} {} {} {} task {}'.format(job_id, component_name, task_id, task_version, role, party_id,
                                                  task_info["party_status"]))
    print('Finish {} {} {} {} {} {} task {}'.format(job_id, component_name, task_id, task_version, role, party_id,
                                                    task_info["party_status"]))
    return task_info