Example #1
0
 def clean_task(self, runtime_conf):
     """Clean up the temporary tables (and RabbitMQ resources) left by this task.

     Two namespaces are wiped through the computing session: the session id
     generated from task id / task version / role / party id, and the
     task-version id. When the federation engine is RabbitMQ and the role is
     not "local", a federation session is additionally initialised and its
     ``cleanup`` is invoked with every party listed in
     ``runtime_conf['role']``.

     :param runtime_conf: job runtime configuration; its ``role`` mapping
         (role name -> iterable of party ids) is read, and a deep copy of the
         whole configuration is handed to the federation session.
     :return: True when every step succeeded, False when any step raised
         (the exception is logged, not propagated).
     """
     schedule_logger(self.job_id).info('clean task {} {} on {} {}'.format(
         self.task_id, self.task_version, self.role, self.party_id))
     try:
         sess = session.Session(
             computing_type=self.job_parameters.computing_engine,
             federation_type=self.job_parameters.federation_engine)
         # clean up temporary tables
         computing_temp_namespace = job_utils.generate_session_id(
             task_id=self.task_id,
             task_version=self.task_version,
             role=self.role,
             party_id=self.party_id)
         if self.job_parameters.computing_engine == ComputingEngine.EGGROLL:
             session_options = {"eggroll.session.processors.per.node": 1}
         else:
             session_options = {}
         sess.init_computing(
             computing_session_id=f"{computing_temp_namespace}_clean",
             options=session_options)
         try:
             sess.computing.cleanup(namespace=computing_temp_namespace,
                                    name="*")
             schedule_logger(self.job_id).info(
                 'clean table by namespace {} on {} {} done'.format(
                     computing_temp_namespace, self.role, self.party_id))
             # clean up the last tables of the federation
             federation_temp_namespace = job_utils.generate_task_version_id(
                 self.task_id, self.task_version)
             sess.computing.cleanup(namespace=federation_temp_namespace,
                                    name="*")
             schedule_logger(self.job_id).info(
                 'clean table by namespace {} on {} {} done'.format(
                     federation_temp_namespace, self.role, self.party_id))
         finally:
             # Stop the computing session even when a cleanup call raised, so
             # the "_clean" session started above is never left running.
             sess.computing.stop()
         if self.job_parameters.federation_engine == FederationEngine.RABBITMQ and self.role != "local":
             schedule_logger(self.job_id).info('rabbitmq start clean up')
             parties = [
                 Party(k, p) for k, v in runtime_conf['role'].items()
                 for p in v
             ]
             federation_session_id = job_utils.generate_task_version_id(
                 self.task_id, self.task_version)
             # Work on a copy so the caller's runtime_conf is not mutated.
             component_parameters_on_party = copy.deepcopy(runtime_conf)
             component_parameters_on_party["local"] = {
                 "role": self.role,
                 "party_id": self.party_id
             }
             sess.init_federation(
                 federation_session_id=federation_session_id,
                 runtime_conf=component_parameters_on_party,
                 service_conf=self.job_parameters.engines_address.get(
                     EngineType.FEDERATION, {}))
             sess._federation_session.cleanup(parties)
             schedule_logger(self.job_id).info('rabbitmq clean up success')
         return True
     except Exception as e:
         # Best-effort cleanup: report the failure through the return value.
         schedule_logger(self.job_id).exception(e)
         return False
Example #2
0
 def run(cls):
     """Parse the command line and destroy all sessions of one session manager.

     ``--session`` and ``-c/--command`` are required on the command line;
     ``--computing``, ``--federation`` and ``--storage`` are accepted as
     optional flags. Only ``--session`` is read here: the text before its
     first underscore selects the job logger, and the full value is used as
     the session id whose sessions are destroyed.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--session',
                         required=True,
                         type=str,
                         help="session manager id")
     parser.add_argument('--computing', help="computing engine", type=str)
     parser.add_argument('--federation', help="federation engine", type=str)
     parser.add_argument('--storage', help="storage engine", type=str)
     # Kept (and still required) so existing invocations keep parsing, even
     # though the value is not consumed in this method.
     parser.add_argument('-c',
                         '--command',
                         required=True,
                         type=str,
                         help="command")
     args = parser.parse_args()
     session_id = args.session
     # The job id is whatever precedes the first underscore of the session id.
     fate_job_id = session_id.split('_')[0]
     with session.Session(session_id=session_id,
                          options={"logger":
                                   schedule_logger(fate_job_id)}) as sess:
         sess.destroy_all_sessions()
Example #3
0
    def run_task(cls):
        """Run one component task in this process, driven by command-line arguments.

        Phase 1 (first ``try``): parse the arguments, load the job
        configuration and DSL parser, resolve the component and its
        per-party parameters, and load the task parameters from the JSON
        file given by ``--config``. If anything here raises, the error is
        printed and logged, ``party_status`` is set to FAILED in the local
        ``task_info`` dict, and the method returns ``None`` without
        reporting to the driver.

        Phase 2 (second ``try``): configure logging, create the trackers,
        report RUNNING to the driver, initialise the runtime config and the
        computing/federation sessions, import and run the component class,
        then persist its output data and model. Any exception marks the
        task FAILED; otherwise it is marked SUCCESS.

        The ``finally`` clause always records ``end_time``/``elapsed`` and
        reports the final ``task_info`` to the driver.

        :return: the ``task_info`` dict after phase 2, or ``None`` when
            phase 1 failed.
        """
        task_info = {}
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j',
                                '--job_id',
                                required=True,
                                type=str,
                                help="job id")
            parser.add_argument('-n',
                                '--component_name',
                                required=True,
                                type=str,
                                help="component name")
            parser.add_argument('-t',
                                '--task_id',
                                required=True,
                                type=str,
                                help="task id")
            parser.add_argument('-v',
                                '--task_version',
                                required=True,
                                type=int,
                                help="task version")
            parser.add_argument('-r',
                                '--role',
                                required=True,
                                type=str,
                                help="role")
            parser.add_argument('-p',
                                '--party_id',
                                required=True,
                                type=int,
                                help="party id")
            parser.add_argument('-c',
                                '--config',
                                required=True,
                                type=str,
                                help="task parameters")
            parser.add_argument('--run_ip', help="run ip", type=str)
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            # --job_server is split on ':' into host (index 0) and port (index 1)
            if args.job_server:
                RuntimeConfig.init_config(
                    JOB_SERVER_HOST=args.job_server.split(':')[0],
                    HTTP_PORT=args.job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            task_version = args.task_version
            role = args.role
            party_id = args.party_id
            executor_pid = os.getpid()
            task_info.update({
                "job_id": job_id,
                "component_name": component_name,
                "task_id": task_id,
                "task_version": task_version,
                "role": role,
                "party_id": party_id,
                "run_ip": args.run_ip,
                "run_pid": executor_pid
            })
            # start_time is also read in the ``finally`` clause further down
            start_time = current_timestamp()
            job_conf = job_utils.get_job_conf(job_id, role)
            job_dsl = job_conf["job_dsl_path"]
            job_runtime_conf = job_conf["job_runtime_conf_path"]
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job_dsl,
                runtime_conf=job_runtime_conf,
                train_runtime_conf=job_conf["train_runtime_conf_path"],
                pipeline_dsl=job_conf["pipeline_dsl_path"])
            # position of this party among the parties of its role
            party_index = job_runtime_conf["role"][role].index(party_id)
            job_args_on_party = TaskExecutor.get_job_args_on_party(
                dsl_parser, job_runtime_conf, role, party_id)
            component = dsl_parser.get_component_info(
                component_name=component_name)
            component_parameters = component.get_role_parameters()
            # fall back to an empty dict when the role has no parameters
            component_parameters_on_party = component_parameters[role][
                party_index] if role in component_parameters else {}
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            component_parameters_on_party[
                'output_data_name'] = task_output_dsl.get('data')
            task_parameters = RunParameters(
                **file_utils.load_json_conf(args.config))
            # job_parameters is an alias of task_parameters from here on
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
        except Exception as e:
            # Setup failed: nothing is reported to the driver on this path and
            # the method returns None (bare return), unlike the normal path.
            traceback.print_exc()
            schedule_logger().exception(e)
            task_info["party_status"] = TaskStatus.FAILED
            return
        try:
            job_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role,
                str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log.LoggerFactory.set_directory(directory=task_log_dir,
                                            parent_log_dir=job_log_dir,
                                            append_to_parent_log=True,
                                            force=True)

            # Tracker persists outputs below; TrackerClient is what the
            # component itself receives via set_tracker().
            tracker = Tracker(job_id=job_id,
                              role=role,
                              party_id=party_id,
                              component_name=component_name,
                              task_id=task_id,
                              task_version=task_version,
                              model_id=job_parameters.model_id,
                              model_version=job_parameters.model_version,
                              component_module_name=module_name,
                              job_parameters=job_parameters)
            tracker_client = TrackerClient(
                job_id=job_id,
                role=role,
                party_id=party_id,
                component_name=component_name,
                task_id=task_id,
                task_version=task_version,
                model_id=job_parameters.model_id,
                model_version=job_parameters.model_version,
                component_module_name=module_name,
                job_parameters=job_parameters)
            # 'CodePath' is split on '/': the last segment is the class name,
            # the segment before it is the module file (".py" stripped), and
            # everything earlier forms the dotted package prefix.
            run_class_paths = component_parameters_on_party.get(
                'CodePath').split('/')
            run_class_package = '.'.join(
                run_class_paths[:-2]) + '.' + run_class_paths[-2].replace(
                    '.py', '')
            run_class_name = run_class_paths[-1]
            task_info["party_status"] = TaskStatus.RUNNING
            cls.report_task_update_to_driver(task_info=task_info)

            # init environment, process is shared globally
            RuntimeConfig.init_config(
                WORK_MODE=job_parameters.work_mode,
                COMPUTING_ENGINE=job_parameters.computing_engine,
                FEDERATION_ENGINE=job_parameters.federation_engine,
                FEDERATED_MODE=job_parameters.federated_mode)

            # only the EGGROLL engine gets session options (a copy of eggroll_run)
            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
            else:
                session_options = {}

            sess = session.Session(
                computing_type=job_parameters.computing_engine,
                federation_type=job_parameters.federation_engine)
            computing_session_id = job_utils.generate_session_id(
                task_id, task_version, role, party_id)
            sess.init_computing(computing_session_id=computing_session_id,
                                options=session_options)
            federation_session_id = job_utils.generate_task_version_id(
                task_id, task_version)
            component_parameters_on_party[
                "job_parameters"] = job_parameters.to_dict()
            sess.init_federation(
                federation_session_id=federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=job_parameters.engines_address.get(
                    EngineType.FEDERATION, {}))
            sess.as_default()

            schedule_logger().info('Run {} {} {} {} {} task'.format(
                job_id, component_name, task_id, role, party_id))
            schedule_logger().info("Component parameters on party {}".format(
                component_parameters_on_party))
            schedule_logger().info("Task input dsl {}".format(task_input_dsl))
            task_run_args = cls.get_task_run_args(
                job_id=job_id,
                role=role,
                party_id=party_id,
                task_id=task_id,
                task_version=task_version,
                job_args=job_args_on_party,
                job_parameters=job_parameters,
                task_parameters=task_parameters,
                input_dsl=task_input_dsl,
            )
            # these modules additionally receive the job parameters in their run args
            if module_name in {"Upload", "Download", "Reader", "Writer"}:
                task_run_args["job_parameters"] = job_parameters
            # import the component module and instantiate its run class
            run_object = getattr(importlib.import_module(run_class_package),
                                 run_class_name)()
            run_object.set_tracker(tracker=tracker_client)
            run_object.set_task_version_id(
                task_version_id=job_utils.generate_task_version_id(
                    task_id, task_version))
            # add profile logs
            profile.profile_start()
            run_object.run(component_parameters_on_party, task_run_args)
            profile.profile_ends()
            output_data = run_object.save_data()
            # normalise a single result into a one-element list
            if not isinstance(output_data, list):
                output_data = [output_data]
            for index in range(0, len(output_data)):
                # use the name from the output dsl when it declares 'data',
                # otherwise fall back to the positional index as a string
                data_name = task_output_dsl.get(
                    'data')[index] if task_output_dsl.get(
                        'data') else '{}'.format(index)
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=output_data[index],
                    output_storage_engine=job_parameters.storage_engine,
                    output_storage_address=job_parameters.engines_address.get(
                        EngineType.STORAGE, {}))
                # only log outputs for which both namespace and name came back
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(
                        data_name=data_name,
                        table_namespace=persistent_table_namespace,
                        table_name=persistent_table_name)
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(
                output_model, task_output_dsl['model'][0]
                if task_output_dsl.get('model') else 'default')
            task_info["party_status"] = TaskStatus.SUCCESS
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            # Always record timing and send the final status to the driver;
            # a failure while doing so also marks the task FAILED locally.
            try:
                task_info["end_time"] = current_timestamp()
                task_info["elapsed"] = task_info["end_time"] - start_time
                cls.report_task_update_to_driver(task_info=task_info)
            except Exception as e:
                task_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(
            task_id, role, party_id, timestamp_to_date(start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(
            task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
        # elapsed is divided by 1000 for the "s" log line
        # (presumably timestamps are in milliseconds - TODO confirm)
        schedule_logger().info('task {} {} {} takes {}s'.format(
            task_id, role, party_id,
            int(task_info["elapsed"]) / 1000))
        schedule_logger().info('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))

        # the same summary is also written to stdout
        print('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))
        return task_info
Example #4
0
    def _run_(self):
        """Execute the component task described by ``self.args`` and report its status.

        The method fills ``self.report_info``, fetches the job configuration
        through an ``OperationClient``, builds the DSL parser, creates the
        tracker / tracker client / checkpoint manager, reports RUNNING,
        initialises the global session (computing, and federation unless the
        only role is "local"), runs the provider's run object, and then
        persists the output data tables, the output model and any output
        caches.

        Final ``party_status``:
          * SUCCESS - the run finished and ``need_run`` is truthy;
          * PASS    - ``need_run`` is falsy, or a ``PassError`` was raised;
          * FAILED  - any other exception (printed and logged), or a failure
            while sending the final report.

        The ``finally`` clause always records ``end_time`` / ``elapsed`` and
        reports to the driver.

        :return: ``self.report_info``.
        """
        # todo: All function calls where errors should be thrown
        args = self.args
        start_time = current_timestamp()
        try:
            LOGGER.info(
                f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task'
            )
            self.report_info.update({
                "job_id": args.job_id,
                "component_name": args.component_name,
                "task_id": args.task_id,
                "task_version": args.task_version,
                "role": args.role,
                "party_id": args.party_id,
                "run_ip": args.run_ip,
                "run_pid": self.run_pid
            })
            operation_client = OperationClient()
            job_configuration = JobConfiguration(
                **operation_client.get_job_conf(
                    args.job_id, args.role, args.party_id, args.component_name,
                    args.task_id, args.task_version))
            task_parameters_conf = args.config
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job_configuration.dsl,
                runtime_conf=job_configuration.runtime_conf,
                train_runtime_conf=job_configuration.train_runtime_conf,
                pipeline_dsl=None)

            # This first job_parameters value is only used to look up the
            # user name; the name is rebound to task_parameters just below.
            job_parameters = dsl_parser.get_job_parameters(
                job_configuration.runtime_conf)
            user_name = job_parameters.get(args.role,
                                           {}).get(args.party_id,
                                                   {}).get("user", '')
            LOGGER.info(f"user name:{user_name}")
            # NOTE(review): src_user is assigned but not read again in this method
            src_user = task_parameters_conf.get("src_user")
            task_parameters = RunParameters(**task_parameters_conf)
            # job_parameters is an alias of task_parameters from here on
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()

            job_args_on_party = TaskExecutor.get_job_args_on_party(
                dsl_parser, job_configuration.runtime_conf_on_party, args.role,
                args.party_id)
            component = dsl_parser.get_component_info(
                component_name=args.component_name)
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()

            # shared constructor arguments for the three helper objects below
            kwargs = {
                'job_id': args.job_id,
                'role': args.role,
                'party_id': args.party_id,
                'component_name': args.component_name,
                'task_id': args.task_id,
                'task_version': args.task_version,
                'model_id': job_parameters.model_id,
                'model_version': job_parameters.model_version,
                'component_module_name': module_name,
                'job_parameters': job_parameters,
            }
            tracker = Tracker(**kwargs)
            tracker_client = TrackerClient(**kwargs)
            checkpoint_manager = CheckpointManager(**kwargs)

            self.report_info["party_status"] = TaskStatus.RUNNING
            self.report_task_info_to_driver()

            previous_components_parameters = tracker_client.get_model_run_parameters(
            )
            LOGGER.info(
                f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}"
            )

            component_provider, component_parameters_on_party, user_specified_parameters = ProviderManager.get_component_run_info(
                dsl_parser=dsl_parser,
                component_name=args.component_name,
                role=args.role,
                party_id=args.party_id,
                previous_components_parameters=previous_components_parameters)
            RuntimeConfig.set_component_provider(component_provider)
            LOGGER.info(
                f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}"
            )
            # extra parameters passed to the component input; "table_info" is
            # added further down once the input tables are known
            flow_feeded_parameters = {
                "output_data_name": task_output_dsl.get("data")
            }

            # init environment, process is shared globally
            RuntimeConfig.init_config(
                COMPUTING_ENGINE=job_parameters.computing_engine,
                FEDERATION_ENGINE=job_parameters.federation_engine,
                FEDERATED_MODE=job_parameters.federated_mode)

            # for EGGROLL, session options are a copy of eggroll_run plus the
            # PYTHONPATH / VIRTUAL_ENV values of this process's environment
            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
                session_options["python.path"] = os.getenv("PYTHONPATH")
                session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
            else:
                session_options = {}

            sess = session.Session(session_id=args.session_id)
            sess.as_global()
            sess.init_computing(computing_session_id=args.session_id,
                                options=session_options)
            component_parameters_on_party[
                "job_parameters"] = job_parameters.to_dict()
            roles = job_configuration.runtime_conf["role"]
            # federation is skipped when "local" is the only role in the job
            if set(roles) == {"local"}:
                LOGGER.info(f"only local roles, pass init federation")
            else:
                sess.init_federation(
                    federation_session_id=args.federation_session_id,
                    runtime_conf=component_parameters_on_party,
                    service_conf=job_parameters.engines_address.get(
                        EngineType.FEDERATION, {}))
            LOGGER.info(
                f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task'
            )
            LOGGER.info(
                f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}"
            )
            LOGGER.info(f"task input dsl {task_input_dsl}")
            task_run_args, input_table_list = self.get_task_run_args(
                job_id=args.job_id,
                role=args.role,
                party_id=args.party_id,
                task_id=args.task_id,
                task_version=args.task_version,
                job_args=job_args_on_party,
                job_parameters=job_parameters,
                task_parameters=task_parameters,
                input_dsl=task_input_dsl,
            )
            # these modules additionally receive the job parameters in their run args
            if module_name in {
                    "Upload", "Download", "Reader", "Writer", "Checkpoint"
            }:
                task_run_args["job_parameters"] = job_parameters
            LOGGER.info(f"task input args {task_run_args}")

            # need_run (default True) decides between SUCCESS and PASS at the end
            need_run = component_parameters_on_party.get("ComponentParam",
                                                         {}).get(
                                                             "need_run", True)
            provider_interface = provider_utils.get_provider_interface(
                provider=component_provider)
            run_object = provider_interface.get(
                module_name,
                ComponentRegistry.get_provider_components(
                    provider_name=component_provider.name,
                    provider_version=component_provider.version)).get_run_obj(
                        self.args.role)
            flow_feeded_parameters.update({"table_info": input_table_list})
            cpn_input = ComponentInput(
                tracker=tracker_client,
                checkpoint_manager=checkpoint_manager,
                task_version_id=job_utils.generate_task_version_id(
                    args.task_id, args.task_version),
                parameters=component_parameters_on_party["ComponentParam"],
                datasets=task_run_args.get("data", None),
                caches=task_run_args.get("cache", None),
                models=dict(
                    model=task_run_args.get("model"),
                    isometric_model=task_run_args.get("isometric_model"),
                ),
                job_parameters=job_parameters,
                roles=dict(
                    role=component_parameters_on_party["role"],
                    local=component_parameters_on_party["local"],
                ),
                flow_feeded_parameters=flow_feeded_parameters,
            )
            # Profiling is switched on when FATE_PROFILE_LOG_ENABLED parses to
            # an int greater than 0; an unparsable value only logs a warning.
            profile_log_enabled = False
            try:
                if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                    profile_log_enabled = True
            except Exception as e:
                LOGGER.warning(e)
            # Both branches run the component and wait for remote completion;
            # they differ only in the profile_start()/profile_ends() calls.
            if profile_log_enabled:
                # add profile logs
                LOGGER.info("profile logging is enabled")
                profile.profile_start()
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
                profile.profile_ends()
            else:
                LOGGER.info("profile logging is disabled")
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()

            output_table_list = []
            LOGGER.info(f"task output data {cpn_output.data}")
            for index, data in enumerate(cpn_output.data):
                # use the name from the output dsl when it declares 'data',
                # otherwise fall back to the positional index as a string
                data_name = task_output_dsl.get(
                    'data')[index] if task_output_dsl.get(
                        'data') else '{}'.format(index)
                #todo: the token depends on the engine type, maybe in job parameters
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=data,
                    output_storage_engine=job_parameters.storage_engine,
                    token={"username": user_name})
                # only record outputs for which both namespace and name came back
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(
                        data_name=data_name,
                        table_namespace=persistent_table_namespace,
                        table_name=persistent_table_name)
                    output_table_list.append({
                        "namespace": persistent_table_namespace,
                        "name": persistent_table_name
                    })
            self.log_output_data_table_tracker(args.job_id, input_table_list,
                                               output_table_list)

            # There is only one model output at the current dsl version.
            tracker_client.save_component_output_model(
                model_buffers=cpn_output.model,
                model_alias=task_output_dsl['model'][0]
                if task_output_dsl.get('model') else 'default',
                user_specified_run_parameters=user_specified_parameters)
            if cpn_output.cache is not None:
                for i, cache in enumerate(cpn_output.cache):
                    # None entries are skipped but still consume an index
                    if cache is None:
                        continue
                    name = task_output_dsl.get(
                        "cache")[i] if "cache" in task_output_dsl else str(i)
                    # DataCache objects are tracked as-is; tuples are saved as
                    # (cache_data, cache_meta); anything else is rejected.
                    if isinstance(cache, DataCache):
                        tracker.tracking_output_cache(cache, cache_name=name)
                    elif isinstance(cache, tuple):
                        tracker.save_output_cache(
                            cache_data=cache[0],
                            cache_meta=cache[1],
                            cache_name=name,
                            output_storage_engine=job_parameters.
                            storage_engine,
                            output_storage_address=job_parameters.
                            engines_address.get(EngineType.STORAGE, {}),
                            token={"username": user_name})
                    else:
                        raise RuntimeError(
                            f"can not support type {type(cache)} module run object output cache"
                        )
            if need_run:
                self.report_info["party_status"] = TaskStatus.SUCCESS
            else:
                self.report_info["party_status"] = TaskStatus.PASS
        except PassError as e:
            # PassError is not treated as a failure: the task is marked PASS
            self.report_info["party_status"] = TaskStatus.PASS
        except Exception as e:
            traceback.print_exc()
            self.report_info["party_status"] = TaskStatus.FAILED
            LOGGER.exception(e)
        finally:
            # Always record timing and send the final status to the driver;
            # a failure while doing so also marks the task FAILED locally.
            try:
                self.report_info["end_time"] = current_timestamp()
                self.report_info[
                    "elapsed"] = self.report_info["end_time"] - start_time
                self.report_task_info_to_driver()
            except Exception as e:
                self.report_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                LOGGER.exception(e)
        # the closing summary goes to both the log and stdout
        msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} with {self.report_info['party_status']}"
        LOGGER.info(msg)
        print(msg)
        return self.report_info
Example #5
0
def get_big_data(guest_data_size, host_data_size, guest_feature_num, host_feature_num, include_path, host_data_type,
                 conf: Config, encryption_type, match_rate, sparsity, force, split_host, output_path, parallelize):
    """Generate the synthetic guest/host data sets listed in a testsuite file.

    Every entry of the testsuite's ``data`` list yields one data set. Entries
    whose role contains ``guest`` get dense, labelled rows; all other entries
    get rows of ``host_data_type``. Data is either appended to CSV files under
    the output directory or, when ``parallelize`` is set, produced through a
    computing session and persisted as tables.

    Args:
        guest_data_size: number of guest rows.
        host_data_size: number of host rows (ids ``0 .. host_data_size - 1``).
        guest_feature_num: number of features per guest row.
        host_feature_num: number of features per host row.
        include_path: path of the JSON testsuite file (``str`` or path object).
        host_data_type: ``'dense'``, ``'tag'`` or ``'tag_value'``.
        conf: configuration object; ``conf.cache_directory`` is the output
            directory when ``output_path`` is ``None``.
        encryption_type: forwarded to ``id_encryption`` for every id range.
        match_rate: value in ``(0, 1]``; guest ids start at
            ``host_data_size - int(guest_data_size * match_rate)``.
        sparsity: the tag vocabulary holds ``round(feature_num / sparsity)``
            entries.
        force: remove an already existing output file instead of skipping it.
        split_host: divide the host id range between the data sets instead of
            giving each host data set the full range.
        output_path: output directory, or ``None`` to use the cache directory.
        parallelize: persist tables through a session instead of writing files.

    Raises:
        Exception: if ``match_rate`` is outside ``(0, 1]`` or ``include_path``
            is not an existing file.
    """
    global big_data_dir

    def _batch_bounds(start_num, end_num, progress):
        """Yield half-open ``(first_id, last_id)`` batches covering ``[start_num, end_num)``.

        Reports the batch index to ``progress`` before each batch is produced.
        """
        data_num = end_num - start_num
        # round() is 0 for fewer than 50 rows; clamp to 1 so the division below
        # cannot raise ZeroDivisionError for small data sets.
        section_data_size = max(1, round(data_num / 100))
        for batch in range(round(data_num / section_data_size) + 1):
            progress.set_time_percent(batch)
            offset = section_data_size * batch
            if offset >= data_num:
                break
            yield start_num + offset, start_num + min(offset + section_data_size, data_num)

    def list_tag_value(feature_nums, head):
        """Return one ``name:value;...`` row with a rounded normal sample per feature."""
        return ";".join([head[k] + ':' + str(round(v, 4)) for k, v in enumerate(np.random.randn(feature_nums))])

    def list_tag(feature_nums, data_list):
        """Return ``feature_nums`` tags drawn at random from ``data_list``, joined by ``;``."""
        return ";".join(random.choice(data_list) for _ in range(feature_nums))

    def _generate_tag_value_data(data_path, start_num, end_num, feature_nums, progress):
        """Append header-less ``id,feature`` rows in tag:value format to ``data_path``."""
        head = ['x' + str(i) for i in range(feature_nums)]
        for first_id, last_id in _batch_bounds(start_num, end_num, progress):
            output_data = pd.DataFrame(columns=["id"])
            output_data["id"] = id_encryption(encryption_type, first_id, last_id)
            output_data['feature'] = [list_tag_value(feature_nums, head) for _ in range(last_id - first_id)]
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_dens_data(data_path, start_num, end_num, feature_nums, label_flag, progress):
        """Append a header line and dense rows (optionally with a 0/1 label ``y``) to ``data_path``."""
        head_1 = ['id', 'y'] if label_flag else ['id']
        head_2 = ['x' + str(i) for i in range(feature_nums)]
        # An empty frame is written first so the file starts with the column names.
        pd.DataFrame(columns=head_1 + head_2).to_csv(data_path, mode='a+', index=False)
        for first_id, last_id in _batch_bounds(start_num, end_num, progress):
            batch_size = last_id - first_id
            df_data_1 = pd.DataFrame(columns=head_1)
            df_data_1["id"] = id_encryption(encryption_type, first_id, last_id)
            if label_flag:
                df_data_1["y"] = [round(np.random.random()) for _ in range(batch_size)]
            feature = np.random.randint(-10000, 10000, size=[batch_size, feature_nums]) / 10000
            df_data_2 = pd.DataFrame(feature, columns=head_2)
            output_data = pd.concat([df_data_1, df_data_2], axis=1)
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_tag_data(data_path, start_num, end_num, feature_nums, sparsity, progress):
        """Append header-less ``id,feature`` rows of random tags to ``data_path``."""
        valid_set = range(2019120799, 2019120799 + round(feature_nums / sparsity))
        data = list(map(str, valid_set))
        for first_id, last_id in _batch_bounds(start_num, end_num, progress):
            output_data = pd.DataFrame(columns=["id"])
            output_data["id"] = id_encryption(encryption_type, first_id, last_id)
            output_data['feature'] = [list_tag(feature_nums, data_list=data) for _ in range(last_id - first_id)]
            output_data.to_csv(data_path, mode='a+', index=False, header=False)

    def _generate_parallelize_data(start_num, end_num, feature_nums, table_name, namespace, label_flag, data_type,
                                   partition, progress):
        """Build rows for ids ``[start_num, end_num)`` in the computing session and persist them as a table."""

        def expand_id_range(k, v):
            """Expand one ``(first id, feature count)`` seed pair into its ``(id, features)`` rows."""
            first_id = int(k)
            last_id = min(first_id + step, end_num)
            # Seed values travel as strings; numpy needs an integer dimension.
            width = int(v)
            if label_flag:
                return [(id_encryption(encryption_type, ids, ids + 1)[0],
                         ",".join([str(round(np.random.random()))]
                                  + [str(round(val, 4)) for val in np.random.randn(width)]))
                        for ids in range(first_id, last_id)]
            if data_type == 'tag':
                valid_set = range(2019120799, 2019120799 + round(feature_nums / sparsity))
                data = list(map(str, valid_set))
                return [(id_encryption(encryption_type, ids, ids + 1)[0],
                         ";".join([random.choice(data) for _ in range(width)]))
                        for ids in range(first_id, last_id)]
            if data_type == 'tag_value':
                # The feature name carries the column index, as in list_tag_value.
                return [(id_encryption(encryption_type, ids, ids + 1)[0],
                         ";".join([f"x{col}" + ':' + str(round(val, 4))
                                   for col, val in enumerate(np.random.randn(width))]))
                        for ids in range(first_id, last_id)]
            if data_type == 'dense':
                return [(id_encryption(encryption_type, ids, ids + 1)[0],
                         ",".join([str(round(val, 4)) for val in np.random.randn(width)]))
                        for ids in range(first_id, last_id)]
            raise ValueError(f"unsupported data_type: {data_type}")

        data_num = end_num - start_num
        # Clamp to 1 so ranges with fewer than 10 rows do not divide by zero.
        step = 10000 if data_num > 10000 else max(1, int(data_num / 10))
        # One seed per chunk, keyed by the chunk's first id. Seeds start at
        # start_num and the count is rounded up so exactly [start_num, end_num)
        # is produced, which is what the row-count check below expects.
        chunk_count = int(np.ceil(data_num / step))
        table_list = [(f"{start_num + i * step}", f"{feature_nums}") for i in range(chunk_count)]
        table = sess.computing.parallelize(table_list, partition=partition, include_key=True)
        table = table.flatMap(functools.partial(expand_id_range))
        if data_type != "dense":
            schema = None
        elif label_flag:
            schema = {"sid": "id", "header": ",".join(["y"] + [f"x{i}" for i in range(feature_nums)])}
        else:
            schema = {"sid": "id", "header": ",".join([f"x{i}" for i in range(feature_nums)])}

        # Drop a table left over from an earlier run before persisting the new one.
        h_table = sess.get_table(name=table_name, namespace=namespace)
        if h_table:
            h_table.destroy()

        table_meta = sess.persistent(computing_table=table, name=table_name, namespace=namespace, schema=schema)

        storage_session = sess.storage()
        s_table = storage_session.get_table(namespace=table_meta.get_namespace(), name=table_meta.get_name())
        if s_table.count() == data_num:
            progress.set_time_percent(100)
        from fate_flow.manager.data_manager import DataTableTracker
        DataTableTracker.create_table_tracker(
            table_name=table_name,
            table_namespace=namespace,
            entity_info={}
        )

    def data_save(data_info, table_names, namespaces, partition_list):
        """Generate every data set in ``data_info`` (file name -> role), showing a progress bar for each."""
        data_count = 0
        for idx, (data_name, role) in enumerate(data_info.items()):
            is_guest = 'guest' in role
            label_flag = is_guest
            data_type = 'dense' if is_guest else host_data_type
            if split_host and ('host' in role):
                # Each host data set gets the next slice of the host id range.
                share = int(np.ceil(host_data_size / len(data_info)))
                host_start_num = share * data_count
                host_end_num = min(share * (data_count + 1), host_data_size)
                data_count += 1
            else:
                host_end_num = host_data_size
                host_start_num = 0
            out_path = os.path.join(str(big_data_dir), data_name)
            if os.path.exists(out_path) and os.path.isfile(out_path) and not parallelize:
                if force:
                    remove_file(out_path)
                else:
                    echo.echo('{} Already exists'.format(out_path))
                    continue
            data_i = (idx + 1) / len(data_info)
            filled = int(24 * data_i)
            bar_label = f'dataget  [{"#" * filled}{"-" * (24 - filled)}]  {idx + 1}/{len(data_info)}'
            start = time.time()
            progress = data_progress(bar_label, start)
            # The bar is redrawn by a background thread until the switch is turned off.
            thread = threading.Thread(target=run, args=[progress])
            thread.start()

            try:
                if is_guest:
                    first_id, last_id, feature_num = guest_start_num, guest_end_num, guest_feature_num
                else:
                    first_id, last_id, feature_num = host_start_num, host_end_num, host_feature_num

                if parallelize:
                    _generate_parallelize_data(first_id, last_id, feature_num, table_names[idx],
                                               namespaces[idx], label_flag, data_type, partition_list[idx], progress)
                elif data_type == 'dense':
                    _generate_dens_data(out_path, first_id, last_id, feature_num, label_flag, progress)
                elif data_type == 'tag':
                    _generate_tag_data(out_path, first_id, last_id, feature_num, sparsity, progress)
                elif data_type == 'tag_value':
                    _generate_tag_value_data(out_path, first_id, last_id, feature_num, progress)
                progress.set_switch(False)
                # Give the progress thread one polling interval to finish drawing.
                time.sleep(1)
            except Exception:
                # Best effort: log under a correlation id and carry on with the next data set.
                exception_id = uuid.uuid1()
                echo.echo(f"exception_id={exception_id}")
                LOGGER.exception(f"exception id: {exception_id}")
            finally:
                progress.set_switch(False)
                echo.stdout_newline()

    def run(p):
        """Redraw progress bar ``p`` once per second while its switch is on."""
        while p.get_switch():
            time.sleep(1)
            p.progress(p.get_time_percent())

    if not 0 < match_rate <= 1:
        raise Exception(f"The value is between (0-1), Please check match_rate:{match_rate}")
    # The last int(guest_data_size * match_rate) host ids are shared with the guest.
    guest_start_num = host_data_size - int(guest_data_size * match_rate)
    guest_end_num = guest_start_num + guest_data_size

    if not os.path.isfile(include_path):
        raise Exception(f'Input file error, please check{include_path}.')
    # The built-in open() accepts both str and path objects.
    with open(include_path, "r") as f:
        testsuite_config = json.load(f)

    try:
        if output_path is not None:
            big_data_dir = os.path.abspath(output_path)
        else:
            big_data_dir = os.path.abspath(conf.cache_directory)
    except Exception:
        raise Exception('{}path does not exist'.format(big_data_dir))

    date_set = {}
    table_name_list = []
    table_namespace_list = []
    partition_list = []
    for upload_dict in testsuite_config.get('data'):
        date_set[os.path.basename(upload_dict.get('file'))] = upload_dict.get('role')
        table_name_list.append(upload_dict.get('table_name'))
        table_namespace_list.append(upload_dict.get('namespace'))
        partition_list.append(upload_dict.get('partition', 8))

    if parallelize:
        # `sess` is read by _generate_parallelize_data through the enclosing scope.
        with session.Session() as sess:
            session_id = str(uuid.uuid1())
            sess.init_computing(session_id)
            data_save(data_info=date_set, table_names=table_name_list, namespaces=table_namespace_list,
                      partition_list=partition_list)
    else:
        data_save(data_info=date_set, table_names=table_name_list, namespaces=table_namespace_list,
                  partition_list=partition_list)
        echo.echo(f'Data storage address, please check{big_data_dir}')
Example #6
0
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import uuid

import numpy as np
from fate_arch import session


# Open a session and start its computing engine with default settings.
sess = session.Session()
sess.init_computing()

# Ten rows keyed 0..9, each holding ten uniform random floats joined by commas.
data = [(row_id, ",".join(map(str, np.random.random(10)))) for row_id in range(10)]

# Distribute the rows over four partitions and print each row's feature string.
c_table = session.get_session().computing.parallelize(data, include_key=True, partition=4)
for _, row in c_table.collect():
    print(row)
print()

# Persist the computing table under a unique name in the "experiment" namespace.
table_meta = sess.persistent(computing_table=c_table, namespace="experiment", name=str(uuid.uuid1()))