def clean_task(self, runtime_conf):
    schedule_logger(self.job_id).info('clean task {} {} on {} {}'.format(
        self.task_id, self.task_version, self.role, self.party_id))
    try:
        sess = session.Session(
            computing_type=self.job_parameters.computing_engine,
            federation_type=self.job_parameters.federation_engine)
        # clean up temporary tables
        computing_temp_namespace = job_utils.generate_session_id(
            task_id=self.task_id,
            task_version=self.task_version,
            role=self.role,
            party_id=self.party_id)
        if self.job_parameters.computing_engine == ComputingEngine.EGGROLL:
            session_options = {"eggroll.session.processors.per.node": 1}
        else:
            session_options = {}
        sess.init_computing(
            computing_session_id=f"{computing_temp_namespace}_clean",
            options=session_options)
        sess.computing.cleanup(namespace=computing_temp_namespace, name="*")
        schedule_logger(self.job_id).info(
            'clean table by namespace {} on {} {} done'.format(
                computing_temp_namespace, self.role, self.party_id))
        # clean up the last tables of the federation
        federation_temp_namespace = job_utils.generate_task_version_id(
            self.task_id, self.task_version)
        sess.computing.cleanup(namespace=federation_temp_namespace, name="*")
        schedule_logger(self.job_id).info(
            'clean table by namespace {} on {} {} done'.format(
                federation_temp_namespace, self.role, self.party_id))
        sess.computing.stop()
        if (self.job_parameters.federation_engine == FederationEngine.RABBITMQ
                and self.role != "local"):
            schedule_logger(self.job_id).info('rabbitmq start clean up')
            parties = [
                Party(k, p)
                for k, v in runtime_conf['role'].items()
                for p in v
            ]
            federation_session_id = job_utils.generate_task_version_id(
                self.task_id, self.task_version)
            component_parameters_on_party = copy.deepcopy(runtime_conf)
            component_parameters_on_party["local"] = {
                "role": self.role,
                "party_id": self.party_id
            }
            sess.init_federation(
                federation_session_id=federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=self.job_parameters.engines_address.get(
                    EngineType.FEDERATION, {}))
            sess._federation_session.cleanup(parties)
            schedule_logger(self.job_id).info('rabbitmq clean up success')
        return True
    except Exception as e:
        schedule_logger(self.job_id).exception(e)
        return False
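
# The Party comprehension in clean_task expands runtime_conf['role'] into one Party per
# (role, party_id) pair. A minimal illustration, assuming made-up party ids; the dict
# shape is inferred from the comprehension itself.
example_runtime_conf = {"role": {"guest": [9999], "host": [10000, 10001]}}
example_parties = [Party(k, p)
                   for k, v in example_runtime_conf['role'].items()
                   for p in v]
# -> [Party(guest, 9999), Party(host, 10000), Party(host, 10001)]
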
def run(cls):
    parser = argparse.ArgumentParser()
    parser.add_argument('--session', required=True, type=str,
                        help="session manager id")
    parser.add_argument('--computing', help="computing engine", type=str)
    parser.add_argument('--federation', help="federation engine", type=str)
    parser.add_argument('--storage', help="storage engine", type=str)
    parser.add_argument('-c', '--command', required=True, type=str,
                        help="command")
    args = parser.parse_args()
    session_id = args.session
    fate_job_id = session_id.split('_')[0]
    command = args.command
    with session.Session(session_id=session_id,
                         options={"logger": schedule_logger(fate_job_id)}) as sess:
        sess.destroy_all_sessions()
def run_task(cls):
    task_info = {}
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument('-j', '--job_id', required=True, type=str, help="job id")
        parser.add_argument('-n', '--component_name', required=True, type=str, help="component name")
        parser.add_argument('-t', '--task_id', required=True, type=str, help="task id")
        parser.add_argument('-v', '--task_version', required=True, type=int, help="task version")
        parser.add_argument('-r', '--role', required=True, type=str, help="role")
        parser.add_argument('-p', '--party_id', required=True, type=int, help="party id")
        parser.add_argument('-c', '--config', required=True, type=str, help="task parameters")
        parser.add_argument('--run_ip', help="run ip", type=str)
        parser.add_argument('--job_server', help="job server", type=str)
        args = parser.parse_args()
        schedule_logger(args.job_id).info('enter task process')
        schedule_logger(args.job_id).info(args)
        # init function args
        if args.job_server:
            RuntimeConfig.init_config(
                JOB_SERVER_HOST=args.job_server.split(':')[0],
                HTTP_PORT=args.job_server.split(':')[1])
            RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
        job_id = args.job_id
        component_name = args.component_name
        task_id = args.task_id
        task_version = args.task_version
        role = args.role
        party_id = args.party_id
        executor_pid = os.getpid()
        task_info.update({
            "job_id": job_id,
            "component_name": component_name,
            "task_id": task_id,
            "task_version": task_version,
            "role": role,
            "party_id": party_id,
            "run_ip": args.run_ip,
            "run_pid": executor_pid
        })
        start_time = current_timestamp()
        job_conf = job_utils.get_job_conf(job_id, role)
        job_dsl = job_conf["job_dsl_path"]
        job_runtime_conf = job_conf["job_runtime_conf_path"]
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_dsl,
            runtime_conf=job_runtime_conf,
            train_runtime_conf=job_conf["train_runtime_conf_path"],
            pipeline_dsl=job_conf["pipeline_dsl_path"])
        party_index = job_runtime_conf["role"][role].index(party_id)
        job_args_on_party = TaskExecutor.get_job_args_on_party(
            dsl_parser, job_runtime_conf, role, party_id)
        component = dsl_parser.get_component_info(component_name=component_name)
        component_parameters = component.get_role_parameters()
        component_parameters_on_party = component_parameters[role][
            party_index] if role in component_parameters else {}
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()
        component_parameters_on_party['output_data_name'] = task_output_dsl.get('data')
        task_parameters = RunParameters(**file_utils.load_json_conf(args.config))
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()
    except Exception as e:
        traceback.print_exc()
        schedule_logger().exception(e)
        task_info["party_status"] = TaskStatus.FAILED
        return
    try:
        job_log_dir = os.path.join(
            job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
        task_log_dir = os.path.join(job_log_dir, component_name)
        log.LoggerFactory.set_directory(directory=task_log_dir,
                                        parent_log_dir=job_log_dir,
                                        append_to_parent_log=True,
                                        force=True)
        tracker = Tracker(job_id=job_id,
                          role=role,
                          party_id=party_id,
                          component_name=component_name,
                          task_id=task_id,
                          task_version=task_version,
                          model_id=job_parameters.model_id,
                          model_version=job_parameters.model_version,
                          component_module_name=module_name,
                          job_parameters=job_parameters)
        tracker_client = TrackerClient(job_id=job_id,
                                       role=role,
                                       party_id=party_id,
                                       component_name=component_name,
                                       task_id=task_id,
                                       task_version=task_version,
                                       model_id=job_parameters.model_id,
                                       model_version=job_parameters.model_version,
                                       component_module_name=module_name,
                                       job_parameters=job_parameters)
        run_class_paths = component_parameters_on_party.get('CodePath').split('/')
        run_class_package = '.'.join(
            run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
        run_class_name = run_class_paths[-1]
        task_info["party_status"] = TaskStatus.RUNNING
        cls.report_task_update_to_driver(task_info=task_info)

        # init environment, process is shared globally
        RuntimeConfig.init_config(WORK_MODE=job_parameters.work_mode,
                                  COMPUTING_ENGINE=job_parameters.computing_engine,
                                  FEDERATION_ENGINE=job_parameters.federation_engine,
                                  FEDERATED_MODE=job_parameters.federated_mode)

        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
        else:
            session_options = {}

        sess = session.Session(computing_type=job_parameters.computing_engine,
                               federation_type=job_parameters.federation_engine)
        computing_session_id = job_utils.generate_session_id(
            task_id, task_version, role, party_id)
        sess.init_computing(computing_session_id=computing_session_id,
                            options=session_options)
        federation_session_id = job_utils.generate_task_version_id(task_id, task_version)
        component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
        sess.init_federation(federation_session_id=federation_session_id,
                             runtime_conf=component_parameters_on_party,
                             service_conf=job_parameters.engines_address.get(
                                 EngineType.FEDERATION, {}))
        sess.as_default()

        schedule_logger().info('Run {} {} {} {} {} task'.format(
            job_id, component_name, task_id, role, party_id))
        schedule_logger().info("Component parameters on party {}".format(
            component_parameters_on_party))
        schedule_logger().info("Task input dsl {}".format(task_input_dsl))
        task_run_args = cls.get_task_run_args(
            job_id=job_id,
            role=role,
            party_id=party_id,
            task_id=task_id,
            task_version=task_version,
            job_args=job_args_on_party,
            job_parameters=job_parameters,
            task_parameters=task_parameters,
            input_dsl=task_input_dsl,
        )
        if module_name in {"Upload", "Download", "Reader", "Writer"}:
            task_run_args["job_parameters"] = job_parameters
        run_object = getattr(importlib.import_module(run_class_package),
                             run_class_name)()
        run_object.set_tracker(tracker=tracker_client)
        run_object.set_task_version_id(
            task_version_id=job_utils.generate_task_version_id(task_id, task_version))
        # add profile logs
        profile.profile_start()
        run_object.run(component_parameters_on_party, task_run_args)
        profile.profile_ends()
        output_data = run_object.save_data()
        if not isinstance(output_data, list):
            output_data = [output_data]
        for index in range(0, len(output_data)):
            data_name = task_output_dsl.get('data')[index] if task_output_dsl.get(
                'data') else '{}'.format(index)
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=output_data[index],
                output_storage_engine=job_parameters.storage_engine,
                output_storage_address=job_parameters.engines_address.get(
                    EngineType.STORAGE, {}))
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(
                    data_name=data_name,
                    table_namespace=persistent_table_namespace,
                    table_name=persistent_table_name)
        output_model = run_object.export_model()
        # There is only one model output at the current dsl version.
        tracker.save_output_model(
            output_model,
            task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')
        task_info["party_status"] = TaskStatus.SUCCESS
    except Exception as e:
        task_info["party_status"] = TaskStatus.FAILED
        schedule_logger().exception(e)
    finally:
        try:
            task_info["end_time"] = current_timestamp()
            task_info["elapsed"] = task_info["end_time"] - start_time
            cls.report_task_update_to_driver(task_info=task_info)
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            schedule_logger().exception(e)
    schedule_logger().info('task {} {} {} start time: {}'.format(
        task_id, role, party_id, timestamp_to_date(start_time)))
    schedule_logger().info('task {} {} {} end time: {}'.format(
        task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
    schedule_logger().info('task {} {} {} takes {}s'.format(
        task_id, role, party_id, int(task_info["elapsed"]) / 1000))
    schedule_logger().info('Finish {} {} {} {} {} {} task {}'.format(
        job_id, component_name, task_id, task_version, role, party_id,
        task_info["party_status"]))
    print('Finish {} {} {} {} {} {} task {}'.format(
        job_id, component_name, task_id, task_version, role, party_id,
        task_info["party_status"]))
    return task_info
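
# A minimal sketch of launching run_task as a separate executor process. Only the
# command-line flags come from the argparse definition above; the script name, ids,
# and config path below are placeholders, not values from the source.
import subprocess
import sys

cmd = [
    sys.executable, "task_executor.py",            # placeholder module path
    "-j", "202201010000000000000",                 # illustrative job id
    "-n", "hetero_lr_0",                           # illustrative component name
    "-t", "202201010000000000000_hetero_lr_0",     # illustrative task id
    "-v", "0",
    "-r", "guest",
    "-p", "9999",
    "-c", "/path/to/task_parameters.json",         # placeholder config path
    "--run_ip", "127.0.0.1",
    "--job_server", "127.0.0.1:9380",
]
subprocess.Popen(cmd)
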
def _run_(self):
    # todo: All function calls where errors should be thrown
    args = self.args
    start_time = current_timestamp()
    try:
        LOGGER.info(
            f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        self.report_info.update({
            "job_id": args.job_id,
            "component_name": args.component_name,
            "task_id": args.task_id,
            "task_version": args.task_version,
            "role": args.role,
            "party_id": args.party_id,
            "run_ip": args.run_ip,
            "run_pid": self.run_pid
        })
        operation_client = OperationClient()
        job_configuration = JobConfiguration(**operation_client.get_job_conf(
            args.job_id, args.role, args.party_id, args.component_name,
            args.task_id, args.task_version))
        task_parameters_conf = args.config
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_configuration.dsl,
            runtime_conf=job_configuration.runtime_conf,
            train_runtime_conf=job_configuration.train_runtime_conf,
            pipeline_dsl=None)

        job_parameters = dsl_parser.get_job_parameters(job_configuration.runtime_conf)
        user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
        LOGGER.info(f"user name:{user_name}")
        src_user = task_parameters_conf.get("src_user")
        task_parameters = RunParameters(**task_parameters_conf)
        job_parameters = task_parameters
        if job_parameters.assistant_role:
            TaskExecutor.monkey_patch()

        job_args_on_party = TaskExecutor.get_job_args_on_party(
            dsl_parser, job_configuration.runtime_conf_on_party, args.role, args.party_id)
        component = dsl_parser.get_component_info(component_name=args.component_name)
        module_name = component.get_module()
        task_input_dsl = component.get_input()
        task_output_dsl = component.get_output()

        kwargs = {
            'job_id': args.job_id,
            'role': args.role,
            'party_id': args.party_id,
            'component_name': args.component_name,
            'task_id': args.task_id,
            'task_version': args.task_version,
            'model_id': job_parameters.model_id,
            'model_version': job_parameters.model_version,
            'component_module_name': module_name,
            'job_parameters': job_parameters,
        }
        tracker = Tracker(**kwargs)
        tracker_client = TrackerClient(**kwargs)
        checkpoint_manager = CheckpointManager(**kwargs)

        self.report_info["party_status"] = TaskStatus.RUNNING
        self.report_task_info_to_driver()

        previous_components_parameters = tracker_client.get_model_run_parameters()
        LOGGER.info(
            f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}")

        component_provider, component_parameters_on_party, user_specified_parameters = (
            ProviderManager.get_component_run_info(
                dsl_parser=dsl_parser,
                component_name=args.component_name,
                role=args.role,
                party_id=args.party_id,
                previous_components_parameters=previous_components_parameters))
        RuntimeConfig.set_component_provider(component_provider)
        LOGGER.info(
            f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")

        flow_feeded_parameters = {"output_data_name": task_output_dsl.get("data")}

        # init environment, process is shared globally
        RuntimeConfig.init_config(COMPUTING_ENGINE=job_parameters.computing_engine,
                                  FEDERATION_ENGINE=job_parameters.federation_engine,
                                  FEDERATED_MODE=job_parameters.federated_mode)

        if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
            session_options = task_parameters.eggroll_run.copy()
            session_options["python.path"] = os.getenv("PYTHONPATH")
            session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
        else:
            session_options = {}

        sess = session.Session(session_id=args.session_id)
        sess.as_global()
        sess.init_computing(computing_session_id=args.session_id,
                            options=session_options)
        component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
        roles = job_configuration.runtime_conf["role"]
        if set(roles) == {"local"}:
            LOGGER.info(f"only local roles, pass init federation")
        else:
            sess.init_federation(federation_session_id=args.federation_session_id,
                                 runtime_conf=component_parameters_on_party,
                                 service_conf=job_parameters.engines_address.get(
                                     EngineType.FEDERATION, {}))
        LOGGER.info(
            f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
        LOGGER.info(
            f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
        LOGGER.info(f"task input dsl {task_input_dsl}")
        task_run_args, input_table_list = self.get_task_run_args(
            job_id=args.job_id,
            role=args.role,
            party_id=args.party_id,
            task_id=args.task_id,
            task_version=args.task_version,
            job_args=job_args_on_party,
            job_parameters=job_parameters,
            task_parameters=task_parameters,
            input_dsl=task_input_dsl,
        )
        if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
            task_run_args["job_parameters"] = job_parameters
        LOGGER.info(f"task input args {task_run_args}")

        need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
        provider_interface = provider_utils.get_provider_interface(provider=component_provider)
        run_object = provider_interface.get(
            module_name,
            ComponentRegistry.get_provider_components(
                provider_name=component_provider.name,
                provider_version=component_provider.version)).get_run_obj(self.args.role)
        flow_feeded_parameters.update({"table_info": input_table_list})
        cpn_input = ComponentInput(
            tracker=tracker_client,
            checkpoint_manager=checkpoint_manager,
            task_version_id=job_utils.generate_task_version_id(args.task_id, args.task_version),
            parameters=component_parameters_on_party["ComponentParam"],
            datasets=task_run_args.get("data", None),
            caches=task_run_args.get("cache", None),
            models=dict(
                model=task_run_args.get("model"),
                isometric_model=task_run_args.get("isometric_model"),
            ),
            job_parameters=job_parameters,
            roles=dict(
                role=component_parameters_on_party["role"],
                local=component_parameters_on_party["local"],
            ),
            flow_feeded_parameters=flow_feeded_parameters,
        )
        profile_log_enabled = False
        try:
            if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                profile_log_enabled = True
        except Exception as e:
            LOGGER.warning(e)
        if profile_log_enabled:
            # add profile logs
            LOGGER.info("profile logging is enabled")
            profile.profile_start()
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()
            profile.profile_ends()
        else:
            LOGGER.info("profile logging is disabled")
            cpn_output = run_object.run(cpn_input)
            sess.wait_remote_all_done()

        output_table_list = []
        LOGGER.info(f"task output data {cpn_output.data}")
        for index, data in enumerate(cpn_output.data):
            data_name = task_output_dsl.get('data')[index] if task_output_dsl.get(
                'data') else '{}'.format(index)
            # todo: the token depends on the engine type, maybe in job parameters
            persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                computing_table=data,
                output_storage_engine=job_parameters.storage_engine,
                token={"username": user_name})
            if persistent_table_namespace and persistent_table_name:
                tracker.log_output_data_info(data_name=data_name,
                                             table_namespace=persistent_table_namespace,
                                             table_name=persistent_table_name)
                output_table_list.append({"namespace": persistent_table_namespace,
                                          "name": persistent_table_name})
        self.log_output_data_table_tracker(args.job_id, input_table_list, output_table_list)

        # There is only one model output at the current dsl version.
        tracker_client.save_component_output_model(
            model_buffers=cpn_output.model,
            model_alias=task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
            user_specified_run_parameters=user_specified_parameters)
        if cpn_output.cache is not None:
            for i, cache in enumerate(cpn_output.cache):
                if cache is None:
                    continue
                name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                if isinstance(cache, DataCache):
                    tracker.tracking_output_cache(cache, cache_name=name)
                elif isinstance(cache, tuple):
                    tracker.save_output_cache(
                        cache_data=cache[0],
                        cache_meta=cache[1],
                        cache_name=name,
                        output_storage_engine=job_parameters.storage_engine,
                        output_storage_address=job_parameters.engines_address.get(
                            EngineType.STORAGE, {}),
                        token={"username": user_name})
                else:
                    raise RuntimeError(
                        f"can not support type {type(cache)} module run object output cache")
        if need_run:
            self.report_info["party_status"] = TaskStatus.SUCCESS
        else:
            self.report_info["party_status"] = TaskStatus.PASS
    except PassError as e:
        self.report_info["party_status"] = TaskStatus.PASS
    except Exception as e:
        traceback.print_exc()
        self.report_info["party_status"] = TaskStatus.FAILED
        LOGGER.exception(e)
    finally:
        try:
            self.report_info["end_time"] = current_timestamp()
            self.report_info["elapsed"] = self.report_info["end_time"] - start_time
            self.report_task_info_to_driver()
        except Exception as e:
            self.report_info["party_status"] = TaskStatus.FAILED
            traceback.print_exc()
            LOGGER.exception(e)
    msg = (f"finish {args.component_name} {args.task_id} {args.task_version} "
           f"on {args.role} {args.party_id} with {self.report_info['party_status']}")
    LOGGER.info(msg)
    print(msg)
    return self.report_info
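
# The profile-log branch in _run_ is driven by the FATE_PROFILE_LOG_ENABLED environment
# variable; a quick, illustrative way to flip it before the executor process starts.
import os

os.environ["FATE_PROFILE_LOG_ENABLED"] = "1"   # any integer > 0 enables the profile_start()/profile_ends() path
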
def get_big_data(guest_data_size, host_data_size, guest_feature_num, host_feature_num,
                 include_path, host_data_type, conf: Config, encryption_type,
                 match_rate, sparsity, force, split_host, output_path, parallelize):
    global big_data_dir

    def list_tag_value(feature_nums, head):
        return ";".join([head[k] + ':' + str(round(v, 4))
                         for k, v in enumerate(np.random.randn(feature_nums))])

    def list_tag(feature_nums, data_list):
        data = ''
        for f in range(feature_nums):
            data += random.choice(data_list) + ";"
        return data[:-1]

    def _generate_tag_value_data(data_path, start_num, end_num, feature_nums, progress):
        data_num = end_num - start_num
        section_data_size = round(data_num / 100)
        iteration = round(data_num / section_data_size)
        head = ['x' + str(i) for i in range(feature_nums)]
        for batch in range(iteration + 1):
            progress.set_time_percent(batch)
            output_data = pd.DataFrame(columns=["id"])
            if section_data_size * (batch + 1) <= data_num:
                output_data["id"] = id_encryption(
                    encryption_type,
                    section_data_size * batch + start_num,
                    section_data_size * (batch + 1) + start_num)
                slicing_data_size = section_data_size
            elif section_data_size * batch < data_num:
                output_data['id'] = id_encryption(
                    encryption_type, section_data_size * batch + start_num, end_num)
                slicing_data_size = data_num - section_data_size * batch
            else:
                break
            feature = [list_tag_value(feature_nums, head) for i in range(slicing_data_size)]
            output_data['feature'] = feature
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_dense_data(data_path, start_num, end_num, feature_nums, label_flag, progress):
        if label_flag:
            head_1 = ['id', 'y']
        else:
            head_1 = ['id']
        data_num = end_num - start_num
        head_2 = ['x' + str(i) for i in range(feature_nums)]
        df_data_1 = pd.DataFrame(columns=head_1)
        head_data = pd.DataFrame(columns=head_1 + head_2)
        head_data.to_csv(data_path, mode='a+', index=False)
        section_data_size = round(data_num / 100)
        iteration = round(data_num / section_data_size)
        for batch in range(iteration + 1):
            progress.set_time_percent(batch)
            if section_data_size * (batch + 1) <= data_num:
                df_data_1["id"] = id_encryption(
                    encryption_type,
                    section_data_size * batch + start_num,
                    section_data_size * (batch + 1) + start_num)
                slicing_data_size = section_data_size
            elif section_data_size * batch < data_num:
                df_data_1 = pd.DataFrame(columns=head_1)
                df_data_1["id"] = id_encryption(
                    encryption_type, section_data_size * batch + start_num, end_num)
                slicing_data_size = data_num - section_data_size * batch
            else:
                break
            if label_flag:
                df_data_1["y"] = [round(np.random.random()) for x in range(slicing_data_size)]
            feature = np.random.randint(-10000, 10000,
                                        size=[slicing_data_size, feature_nums]) / 10000
            df_data_2 = pd.DataFrame(feature, columns=head_2)
            output_data = pd.concat([df_data_1, df_data_2], axis=1)
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_tag_data(data_path, start_num, end_num, feature_nums, sparsity, progress):
        data_num = end_num - start_num
        section_data_size = round(data_num / 100)
        iteration = round(data_num / section_data_size)
        valid_set = [x for x in range(2019120799, 2019120799 + round(feature_nums / sparsity))]
        data = list(map(str, valid_set))
        for batch in range(iteration + 1):
            progress.set_time_percent(batch)
            output_data = pd.DataFrame(columns=["id"])
            if section_data_size * (batch + 1) <= data_num:
                output_data["id"] = id_encryption(
                    encryption_type,
                    section_data_size * batch + start_num,
                    section_data_size * (batch + 1) + start_num)
                slicing_data_size = section_data_size
            elif section_data_size * batch < data_num:
                output_data["id"] = id_encryption(
                    encryption_type, section_data_size * batch + start_num, end_num)
                slicing_data_size = data_num - section_data_size * batch
            else:
                break
            feature = [list_tag(feature_nums, data_list=data) for i in range(slicing_data_size)]
            output_data['feature'] = feature
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_parallelize_data(start_num, end_num, feature_nums, table_name, namespace,
                                   label_flag, data_type, partition, progress):
        def expand_id_range(k, v):
            if label_flag:
                return [(id_encryption(encryption_type, ids, ids + 1)[0],
                         ",".join([str(round(np.random.random()))] +
                                  [str(round(i, 4)) for i in np.random.randn(v)]))
                        for ids in range(int(k), min(step + int(k), end_num))]
            else:
                if data_type == 'tag':
                    valid_set = [x for x in range(2019120799, 2019120799 + round(feature_nums / sparsity))]
                    data = list(map(str, valid_set))
                    return [(id_encryption(encryption_type, ids, ids + 1)[0],
                             ";".join([random.choice(data) for i in range(int(v))]))
                            for ids in range(int(k), min(step + int(k), data_num))]
                elif data_type == 'tag_value':
                    return [(id_encryption(encryption_type, ids, ids + 1)[0],
                             ";".join([f"x{i}" + ':' + str(round(i, 4)) for i in np.random.randn(v)]))
                            for ids in range(int(k), min(step + int(k), data_num))]
                elif data_type == 'dense':
                    return [(id_encryption(encryption_type, ids, ids + 1)[0],
                             ",".join([str(round(i, 4)) for i in np.random.randn(v)]))
                            for ids in range(int(k), min(step + int(k), data_num))]

        data_num = end_num - start_num
        step = 10000 if data_num > 10000 else int(data_num / 10)
        table_list = [(f"{i * step}", f"{feature_nums}")
                      for i in range(int(data_num / step) + start_num)]
        table = sess.computing.parallelize(table_list, partition=partition, include_key=True)
        table = table.flatMap(functools.partial(expand_id_range))
        if label_flag:
            schema = {"sid": "id",
                      "header": ",".join(["y"] + [f"x{i}" for i in range(feature_nums)])}
        else:
            schema = {"sid": "id",
                      "header": ",".join([f"x{i}" for i in range(feature_nums)])}
        if data_type != "dense":
            schema = None
        h_table = sess.get_table(name=table_name, namespace=namespace)
        if h_table:
            h_table.destroy()

        table_meta = sess.persistent(computing_table=table, name=table_name,
                                     namespace=namespace, schema=schema)

        storage_session = sess.storage()
        s_table = storage_session.get_table(namespace=table_meta.get_namespace(),
                                            name=table_meta.get_name())
        if s_table.count() == data_num:
            progress.set_time_percent(100)
        from fate_flow.manager.data_manager import DataTableTracker
        DataTableTracker.create_table_tracker(
            table_name=table_name,
            table_namespace=namespace,
            entity_info={}
        )

    def data_save(data_info, table_names, namespaces, partition_list):
        data_count = 0
        for idx, data_name in enumerate(data_info.keys()):
            label_flag = True if 'guest' in data_info[data_name] else False
            data_type = 'dense' if 'guest' in data_info[data_name] else host_data_type
            if split_host and ('host' in data_info[data_name]):
                host_end_num = (int(np.ceil(host_data_size / len(data_info))) * (data_count + 1)
                                if np.ceil(host_data_size / len(data_info)) * (data_count + 1) <= host_data_size
                                else host_data_size)
                host_start_num = int(np.ceil(host_data_size / len(data_info))) * data_count
                data_count += 1
            else:
                host_end_num = host_data_size
                host_start_num = 0
            out_path = os.path.join(str(big_data_dir), data_name)
            if os.path.exists(out_path) and os.path.isfile(out_path) and not parallelize:
                if force:
                    remove_file(out_path)
                else:
                    echo.echo('{} already exists'.format(out_path))
                    continue
            data_i = (idx + 1) / len(data_info)
            downLoad = f'dataget [{"#" * int(24 * data_i)}{"-" * (24 - int(24 * data_i))}] {idx + 1}/{len(data_info)}'
            start = time.time()
            progress = data_progress(downLoad, start)
            thread = threading.Thread(target=run, args=[progress])
            thread.start()
            try:
                if 'guest' in data_info[data_name]:
                    if not parallelize:
                        _generate_dense_data(out_path, guest_start_num, guest_end_num,
                                             guest_feature_num, label_flag, progress)
                    else:
                        _generate_parallelize_data(guest_start_num, guest_end_num,
                                                   guest_feature_num, table_names[idx],
                                                   namespaces[idx], label_flag, data_type,
                                                   partition_list[idx], progress)
                else:
                    if data_type == 'tag' and not parallelize:
                        _generate_tag_data(out_path, host_start_num, host_end_num,
                                           host_feature_num, sparsity, progress)
                    elif data_type == 'tag_value' and not parallelize:
                        _generate_tag_value_data(out_path, host_start_num, host_end_num,
                                                 host_feature_num, progress)
                    elif data_type == 'dense' and not parallelize:
                        _generate_dense_data(out_path, host_start_num, host_end_num,
                                             host_feature_num, label_flag, progress)
                    elif parallelize:
                        _generate_parallelize_data(host_start_num, host_end_num,
                                                   host_feature_num, table_names[idx],
                                                   namespaces[idx], label_flag, data_type,
                                                   partition_list[idx], progress)
                progress.set_switch(False)
                time.sleep(1)
            except Exception:
                exception_id = uuid.uuid1()
                echo.echo(f"exception_id={exception_id}")
                LOGGER.exception(f"exception id: {exception_id}")
            finally:
                progress.set_switch(False)
                echo.stdout_newline()

    def run(p):
        while p.get_switch():
            time.sleep(1)
            p.progress(p.get_time_percent())

    if not match_rate > 0 or not match_rate <= 1:
        raise Exception(f"match_rate should be in (0, 1], please check match_rate: {match_rate}")
    guest_start_num = host_data_size - int(guest_data_size * match_rate)
    guest_end_num = guest_start_num + guest_data_size

    if os.path.isfile(include_path):
        with include_path.open("r") as f:
            testsuite_config = json.load(f)
    else:
        raise Exception(f'Input file error, please check {include_path}.')
    try:
        if output_path is not None:
            big_data_dir = os.path.abspath(output_path)
        else:
            big_data_dir = os.path.abspath(conf.cache_directory)
    except Exception:
        raise Exception('{} path does not exist'.format(big_data_dir))
    date_set = {}
    table_name_list = []
    table_namespace_list = []
    partition_list = []
    for upload_dict in testsuite_config.get('data'):
        date_set[os.path.basename(upload_dict.get('file'))] = upload_dict.get('role')
        table_name_list.append(upload_dict.get('table_name'))
        table_namespace_list.append(upload_dict.get('namespace'))
        partition_list.append(upload_dict.get('partition', 8))

    if parallelize:
        with session.Session() as sess:
            session_id = str(uuid.uuid1())
            sess.init_computing(session_id)
            data_save(data_info=date_set, table_names=table_name_list,
                      namespaces=table_namespace_list, partition_list=partition_list)
    else:
        data_save(data_info=date_set, table_names=table_name_list,
                  namespaces=table_namespace_list, partition_list=partition_list)
    echo.echo(f'Data storage address, please check: {big_data_dir}')
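
# The _generate_* helpers above all share the same batching arithmetic; a standalone
# sketch of just that slicing logic, without pandas or file I/O, to show how a row
# range is split into roughly 100 sections plus a final partial batch.
def iter_batches(start_num, end_num):
    # assumes end_num - start_num is large enough (roughly > 50) that
    # section_data_size is non-zero, as the helpers above also require
    data_num = end_num - start_num
    section_data_size = round(data_num / 100)        # ~1% of the rows per batch
    iteration = round(data_num / section_data_size)
    for batch in range(iteration + 1):
        if section_data_size * (batch + 1) <= data_num:
            # full batch
            yield section_data_size * batch + start_num, section_data_size * (batch + 1) + start_num
        elif section_data_size * batch < data_num:
            # final partial batch
            yield section_data_size * batch + start_num, end_num
        else:
            break

# e.g. list(iter_batches(0, 10050)) yields 100 full (start, end) ranges of 100 ids
# followed by one partial range of 50.
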
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import uuid

import numpy as np

from fate_arch import session

sess = session.Session()
sess.init_computing()

data = []
for i in range(10):
    features = np.random.random(10)
    features = ",".join([str(x) for x in features])
    data.append((i, features))

c_table = session.get_session().computing.parallelize(data, include_key=True, partition=4)
for k, v in c_table.collect():
    print(v)
print()

table_meta = sess.persistent(computing_table=c_table,
                             namespace="experiment",
                             name=str(uuid.uuid1()))
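
# A possible follow-up to the script above, assuming the storage-session interface used
# in _generate_parallelize_data (sess.storage(), get_table(), count()): read the
# persisted table back and report its row count. Illustrative sketch, not part of the
# original script.
storage_session = sess.storage()
s_table = storage_session.get_table(namespace=table_meta.get_namespace(),
                                    name=table_meta.get_name())
print(f"persisted {s_table.count()} rows into "
      f"{table_meta.get_namespace()}.{table_meta.get_name()}")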