Example #1
 def clean_task(self, roles, party_ids):
     schedule_logger(self.job_id).info('clean task {} on {} {}'.format(self.task_id,
                                                                       self.role,
                                                                       self.party_id))
     try:
         for role in roles.split(','):
             for party_id in party_ids.split(','):
                 # clean up temporary tables
                 namespace_clean = job_utils.generate_session_id(task_id=self.task_id,
                                                                 role=role,
                                                                 party_id=party_id)
                 session.clean_tables(namespace=namespace_clean, regex_string='*')
                 schedule_logger(self.job_id).info('clean table by namespace {} on {} {} done'.format(namespace_clean,
                                                                                                      self.role,
                                                                                                      self.party_id))
                 # clean up the last tables of the federation
                 namespace_clean = self.task_id
                 session.clean_tables(namespace=namespace_clean, regex_string='*')
                 schedule_logger(self.job_id).info('clean table by namespace {} on {} {} done'.format(namespace_clean,
                                                                                                      self.role,
                                                                                                      self.party_id))
                 
     except Exception as e:
         schedule_logger(self.job_id).exception(e)
     schedule_logger(self.job_id).info('clean task {} on {} {} done'.format(self.task_id,
                                                                            self.role,
                                                                            self.party_id))
Example #2
 def clean_task(self, runtime_conf):
     schedule_logger(self.job_id).info('clean task {} {} on {} {}'.format(
         self.task_id, self.task_version, self.role, self.party_id))
     try:
         sess = session.Session(
             computing_type=self.job_parameters.computing_engine,
             federation_type=self.job_parameters.federation_engine)
         # clean up temporary tables
         computing_temp_namespace = job_utils.generate_session_id(
             task_id=self.task_id,
             task_version=self.task_version,
             role=self.role,
             party_id=self.party_id)
         if self.job_parameters.computing_engine == ComputingEngine.EGGROLL:
             session_options = {"eggroll.session.processors.per.node": 1}
         else:
             session_options = {}
         sess.init_computing(
             computing_session_id=f"{computing_temp_namespace}_clean",
             options=session_options)
         sess.computing.cleanup(namespace=computing_temp_namespace,
                                name="*")
         schedule_logger(self.job_id).info(
             'clean table by namespace {} on {} {} done'.format(
                 computing_temp_namespace, self.role, self.party_id))
         # clean up the last tables of the federation
         federation_temp_namespace = job_utils.generate_task_version_id(
             self.task_id, self.task_version)
         sess.computing.cleanup(namespace=federation_temp_namespace,
                                name="*")
         schedule_logger(self.job_id).info(
             'clean table by namespace {} on {} {} done'.format(
                 federation_temp_namespace, self.role, self.party_id))
         sess.computing.stop()
         if self.job_parameters.federation_engine == FederationEngine.RABBITMQ and self.role != "local":
             schedule_logger(self.job_id).info('rabbitmq start clean up')
             parties = [
                 Party(k, p) for k, v in runtime_conf['role'].items()
                 for p in v
             ]
             federation_session_id = job_utils.generate_task_version_id(
                 self.task_id, self.task_version)
             component_parameters_on_party = copy.deepcopy(runtime_conf)
             component_parameters_on_party["local"] = {
                 "role": self.role,
                 "party_id": self.party_id
             }
             sess.init_federation(
                 federation_session_id=federation_session_id,
                 runtime_conf=component_parameters_on_party,
                 service_conf=self.job_parameters.engines_address.get(
                     EngineType.FEDERATION, {}))
             sess._federation_session.cleanup(parties)
             schedule_logger(self.job_id).info('rabbitmq clean up success')
         return True
     except Exception as e:
         schedule_logger(self.job_id).exception(e)
         return False
Example #3
 def run(self, component_parameters=None, args=None):
     self.parameters = component_parameters["DownloadParam"]
     self.parameters["role"] = component_parameters["role"]
     self.parameters["local"] = component_parameters["local"]
     name, namespace = self.parameters.get("name"), self.parameters.get(
         "namespace")
     with open(os.path.abspath(self.parameters["output_path"]),
               "w") as fout:
         with storage.Session.build(
                 session_id=job_utils.generate_session_id(
                     self.tracker.task_id,
                     self.tracker.task_version,
                     self.tracker.role,
                     self.tracker.party_id,
                     suffix="storage",
                     random_end=True),
                 name=name,
                 namespace=namespace) as storage_session:
             data_table = storage_session.get_table()
             count = data_table.count()
             LOGGER.info('===== begin to export data =====')
             lines = 0
             job_info = {}
             job_info["job_id"] = self.tracker.job_id
             job_info["role"] = self.tracker.role
             job_info["party_id"] = self.tracker.party_id
             for key, value in data_table.collect():
                 if not value:
                     fout.write(key + "\n")
                 else:
                     fout.write(key +
                                self.parameters.get("delimiter", ",") +
                                str(value) + "\n")
                 lines += 1
                 if lines % 2000 == 0:
                     LOGGER.info(
                         "===== export {} lines =====".format(lines))
                 if lines % 10000 == 0:
                     job_info["progress"] = lines / count * 100 // 1
                     ControllerClient.update_job(job_info=job_info)
             job_info["progress"] = 100
             ControllerClient.update_job(job_info=job_info)
             self.callback_metric(
                 metric_name='data_access',
                 metric_namespace='download',
                 metric_data=[Metric("count", data_table.count())])
         LOGGER.info("===== export {} lines totally =====".format(lines))
         LOGGER.info('===== export data finish =====')
         LOGGER.info('===== export data file path:{} ====='.format(
             os.path.abspath(self.parameters["output_path"])))
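A side note on the progress bookkeeping above: lines / count * 100 // 1 floors the percentage but still yields a float. A minimal, self-contained illustration (the values are made up):

    lines, count = 12345, 50000
    progress = lines / count * 100 // 1       # 24.0 (a float, floored)
    progress_int = int(lines * 100 // count)  # 24, an equivalent integer form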
Example #4
 def _run(self, cpn_input: ComponentInputProtocol):
     self.parameters = cpn_input.parameters
     self.parameters["role"] = cpn_input.roles["role"]
     self.parameters["local"] = cpn_input.roles["local"]
     name, namespace = self.parameters.get("name"), self.parameters.get(
         "namespace")
     with open(os.path.abspath(self.parameters["output_path"]), "w") as fw:
         session = Session(
             job_utils.generate_session_id(
                 self.tracker.task_id,
                 self.tracker.task_version,
                 self.tracker.role,
                 self.tracker.party_id,
             ))
         data_table = session.get_table(name=name, namespace=namespace)
         if not data_table:
             raise Exception(f"no found table {name} {namespace}")
         count = data_table.count()
         LOGGER.info("===== begin to export data =====")
         lines = 0
         job_info = {}
         job_info["job_id"] = self.tracker.job_id
         job_info["role"] = self.tracker.role
         job_info["party_id"] = self.tracker.party_id
         for key, value in data_table.collect():
             if not value:
                 fw.write(key + "\n")
             else:
                 fw.write(key + self.parameters.get("delimiter", ",") +
                          str(value) + "\n")
             lines += 1
             if lines % 2000 == 0:
                 LOGGER.info("===== export {} lines =====".format(lines))
             if lines % 10000 == 0:
                 job_info["progress"] = lines / count * 100 // 1
                 ControllerClient.update_job(job_info=job_info)
         job_info["progress"] = 100
         ControllerClient.update_job(job_info=job_info)
         self.callback_metric(
             metric_name="data_access",
             metric_namespace="download",
             metric_data=[Metric("count", data_table.count())],
         )
         LOGGER.info("===== export {} lines totally =====".format(lines))
         LOGGER.info("===== export data finish =====")
         LOGGER.info("===== export data file path:{} =====".format(
             os.path.abspath(self.parameters["output_path"])))
Example #5
    def run(self, component_parameters=None, args=None):
        self.parameters = component_parameters["ReaderParam"]
        output_storage_address = args["job_parameters"].engines_address[
            EngineType.STORAGE]
        table_key = [key for key in self.parameters.keys()][0]
        computing_engine = args["job_parameters"].computing_engine
        output_table_namespace, output_table_name = data_utils.default_output_table_info(
            task_id=self.tracker.task_id,
            task_version=self.tracker.task_version)
        input_table_meta, output_table_address, output_table_engine = self.convert_check(
            input_name=self.parameters[table_key]['name'],
            input_namespace=self.parameters[table_key]['namespace'],
            output_name=output_table_name,
            output_namespace=output_table_namespace,
            computing_engine=computing_engine,
            output_storage_address=output_storage_address)
        with storage.Session.build(
                session_id=job_utils.generate_session_id(
                    self.tracker.task_id,
                    self.tracker.task_version,
                    self.tracker.role,
                    self.tracker.party_id,
                    suffix="storage",
                    random_end=True),
                storage_engine=input_table_meta.get_engine()) as input_table_session:
            input_table = input_table_session.get_table(
                name=input_table_meta.get_name(),
                namespace=input_table_meta.get_namespace())
            # update real count to meta info
            input_table.count()
            # Table replication is required
            if input_table_meta.get_engine() != output_table_engine:
                LOGGER.info(
                    f"the {input_table_meta.get_engine()} engine input table needs to be converted to {output_table_engine} engine to support computing engine {computing_engine}"
                )
            else:
                LOGGER.info(
                    f"the {input_table_meta.get_engine()} input table needs to be transform format"
                )
            with storage.Session.build(
                    session_id=job_utils.generate_session_id(
                        self.tracker.task_id,
                        self.tracker.task_version,
                        self.tracker.role,
                        self.tracker.party_id,
                        suffix="storage",
                        random_end=True),
                    storage_engine=output_table_engine
            ) as output_table_session:
                output_table = output_table_session.create_table(
                    address=output_table_address,
                    name=output_table_name,
                    namespace=output_table_namespace,
                    partitions=input_table_meta.partitions)
                self.copy_table(src_table=input_table, dest_table=output_table)
                # update real count to meta info
                output_table.count()
                output_table_meta = StorageTableMeta(
                    name=output_table.get_name(),
                    namespace=output_table.get_namespace())
        self.tracker.log_output_data_info(
            data_name=component_parameters.get('output_data_name')[0]
            if component_parameters.get('output_data_name') else table_key,
            table_namespace=output_table_meta.get_namespace(),
            table_name=output_table_meta.get_name())
        headers_str = output_table_meta.get_schema().get('header')
        table_info = {}
        if output_table_meta.get_schema() and headers_str:
            if isinstance(headers_str, str):
                data_list = [headers_str.split(',')]
                is_display = True
            else:
                data_list = [headers_str]
                is_display = False
            if is_display:
                for data in output_table_meta.get_part_of_data():
                    data_list.append(data[1].split(','))
                data = np.array(data_list)
                Tdata = data.transpose()
                for data in Tdata:
                    table_info[data[0]] = ','.join(list(set(data[1:]))[:5])
        data_info = {
            "table_name": self.parameters[table_key]['name'],
            "namespace": self.parameters[table_key]['namespace'],
            "table_info": table_info,
            "partitions": output_table_meta.get_partitions(),
            "storage_engine": output_table_meta.get_engine()
        }
        if input_table_meta.get_engine() in [StorageEngine.PATH]:
            data_info["file_count"] = output_table_meta.get_count()
            data_info["file_path"] = input_table_meta.get_address().path
        else:
            data_info["count"] = output_table_meta.get_count()

        self.tracker.set_metric_meta(metric_namespace="reader_namespace",
                                     metric_name="reader_name",
                                     metric_meta=MetricMeta(
                                         name='reader',
                                         metric_type='data_info',
                                         extra_metas=data_info))
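The table_info block in Example #5 transposes a small header-plus-rows matrix so that each column name maps to a short sample of its distinct values. A toy run with assumed data (purely illustrative, not taken from the example):

    import numpy as np

    data_list = [["id", "age", "label"],   # header row
                 ["1", "30", "1"],         # part-of-data rows, already split on ','
                 ["2", "45", "0"]]
    tdata = np.array(data_list).transpose()
    table_info = {col[0]: ','.join(list(set(col[1:]))[:5]) for col in tdata}
    # e.g. {'id': '1,2', 'age': '30,45', 'label': '1,0'} (set order is arbitrary)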
Example #6
    def run_task(cls):
        task_info = {}
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j',
                                '--job_id',
                                required=True,
                                type=str,
                                help="job id")
            parser.add_argument('-n',
                                '--component_name',
                                required=True,
                                type=str,
                                help="component name")
            parser.add_argument('-t',
                                '--task_id',
                                required=True,
                                type=str,
                                help="task id")
            parser.add_argument('-v',
                                '--task_version',
                                required=True,
                                type=int,
                                help="task version")
            parser.add_argument('-r',
                                '--role',
                                required=True,
                                type=str,
                                help="role")
            parser.add_argument('-p',
                                '--party_id',
                                required=True,
                                type=int,
                                help="party id")
            parser.add_argument('-c',
                                '--config',
                                required=True,
                                type=str,
                                help="task parameters")
            parser.add_argument('--run_ip', help="run ip", type=str)
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            if args.job_server:
                RuntimeConfig.init_config(
                    JOB_SERVER_HOST=args.job_server.split(':')[0],
                    HTTP_PORT=args.job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            task_version = args.task_version
            role = args.role
            party_id = args.party_id
            executor_pid = os.getpid()
            task_info.update({
                "job_id": job_id,
                "component_name": component_name,
                "task_id": task_id,
                "task_version": task_version,
                "role": role,
                "party_id": party_id,
                "run_ip": args.run_ip,
                "run_pid": executor_pid
            })
            start_time = current_timestamp()
            job_conf = job_utils.get_job_conf(job_id, role)
            job_dsl = job_conf["job_dsl_path"]
            job_runtime_conf = job_conf["job_runtime_conf_path"]
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job_dsl,
                runtime_conf=job_runtime_conf,
                train_runtime_conf=job_conf["train_runtime_conf_path"],
                pipeline_dsl=job_conf["pipeline_dsl_path"])
            party_index = job_runtime_conf["role"][role].index(party_id)
            job_args_on_party = TaskExecutor.get_job_args_on_party(
                dsl_parser, job_runtime_conf, role, party_id)
            component = dsl_parser.get_component_info(
                component_name=component_name)
            component_parameters = component.get_role_parameters()
            component_parameters_on_party = component_parameters[role][
                party_index] if role in component_parameters else {}
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            component_parameters_on_party[
                'output_data_name'] = task_output_dsl.get('data')
            task_parameters = RunParameters(
                **file_utils.load_json_conf(args.config))
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task_info["party_status"] = TaskStatus.FAILED
            return
        try:
            job_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role,
                str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log.LoggerFactory.set_directory(directory=task_log_dir,
                                            parent_log_dir=job_log_dir,
                                            append_to_parent_log=True,
                                            force=True)

            tracker = Tracker(job_id=job_id,
                              role=role,
                              party_id=party_id,
                              component_name=component_name,
                              task_id=task_id,
                              task_version=task_version,
                              model_id=job_parameters.model_id,
                              model_version=job_parameters.model_version,
                              component_module_name=module_name,
                              job_parameters=job_parameters)
            tracker_client = TrackerClient(
                job_id=job_id,
                role=role,
                party_id=party_id,
                component_name=component_name,
                task_id=task_id,
                task_version=task_version,
                model_id=job_parameters.model_id,
                model_version=job_parameters.model_version,
                component_module_name=module_name,
                job_parameters=job_parameters)
            run_class_paths = component_parameters_on_party.get(
                'CodePath').split('/')
            run_class_package = '.'.join(
                run_class_paths[:-2]) + '.' + run_class_paths[-2].replace(
                    '.py', '')
            run_class_name = run_class_paths[-1]
            task_info["party_status"] = TaskStatus.RUNNING
            cls.report_task_update_to_driver(task_info=task_info)

            # init environment, process is shared globally
            RuntimeConfig.init_config(
                WORK_MODE=job_parameters.work_mode,
                COMPUTING_ENGINE=job_parameters.computing_engine,
                FEDERATION_ENGINE=job_parameters.federation_engine,
                FEDERATED_MODE=job_parameters.federated_mode)

            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
            else:
                session_options = {}

            sess = session.Session(
                computing_type=job_parameters.computing_engine,
                federation_type=job_parameters.federation_engine)
            computing_session_id = job_utils.generate_session_id(
                task_id, task_version, role, party_id)
            sess.init_computing(computing_session_id=computing_session_id,
                                options=session_options)
            federation_session_id = job_utils.generate_task_version_id(
                task_id, task_version)
            component_parameters_on_party[
                "job_parameters"] = job_parameters.to_dict()
            sess.init_federation(
                federation_session_id=federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=job_parameters.engines_address.get(
                    EngineType.FEDERATION, {}))
            sess.as_default()

            schedule_logger().info('Run {} {} {} {} {} task'.format(
                job_id, component_name, task_id, role, party_id))
            schedule_logger().info("Component parameters on party {}".format(
                component_parameters_on_party))
            schedule_logger().info("Task input dsl {}".format(task_input_dsl))
            task_run_args = cls.get_task_run_args(
                job_id=job_id,
                role=role,
                party_id=party_id,
                task_id=task_id,
                task_version=task_version,
                job_args=job_args_on_party,
                job_parameters=job_parameters,
                task_parameters=task_parameters,
                input_dsl=task_input_dsl,
            )
            if module_name in {"Upload", "Download", "Reader", "Writer"}:
                task_run_args["job_parameters"] = job_parameters
            run_object = getattr(importlib.import_module(run_class_package),
                                 run_class_name)()
            run_object.set_tracker(tracker=tracker_client)
            run_object.set_task_version_id(
                task_version_id=job_utils.generate_task_version_id(
                    task_id, task_version))
            # add profile logs
            profile.profile_start()
            run_object.run(component_parameters_on_party, task_run_args)
            profile.profile_ends()
            output_data = run_object.save_data()
            if not isinstance(output_data, list):
                output_data = [output_data]
            for index in range(0, len(output_data)):
                data_name = task_output_dsl.get(
                    'data')[index] if task_output_dsl.get(
                        'data') else '{}'.format(index)
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=output_data[index],
                    output_storage_engine=job_parameters.storage_engine,
                    output_storage_address=job_parameters.engines_address.get(
                        EngineType.STORAGE, {}))
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(
                        data_name=data_name,
                        table_namespace=persistent_table_namespace,
                        table_name=persistent_table_name)
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(
                output_model, task_output_dsl['model'][0]
                if task_output_dsl.get('model') else 'default')
            task_info["party_status"] = TaskStatus.SUCCESS
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            try:
                task_info["end_time"] = current_timestamp()
                task_info["elapsed"] = task_info["end_time"] - start_time
                cls.report_task_update_to_driver(task_info=task_info)
            except Exception as e:
                task_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(
            task_id, role, party_id, timestamp_to_date(start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(
            task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
        schedule_logger().info('task {} {} {} takes {}s'.format(
            task_id, role, party_id,
            int(task_info["elapsed"]) / 1000))
        schedule_logger().info('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))

        print('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))
        return task_info
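The CodePath handling in Example #6 converts a slash-separated path ending in <module>.py/<ClassName> into an importable dotted package plus a class name. A standalone sketch with a hypothetical path value:

    code_path = "federatedml/model_base.py/ModelBase"   # hypothetical CodePath
    run_class_paths = code_path.split('/')
    run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
    run_class_name = run_class_paths[-1]
    # run_class_package == "federatedml.model_base", run_class_name == "ModelBase";
    # importlib.import_module(run_class_package) then loads the module and
    # getattr(module, run_class_name)() instantiates the component.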
Example #7
    def run_task():
        task = Task()
        task.f_create_time = current_timestamp()
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j', '--job_id', required=True, type=str, help="job id")
            parser.add_argument('-n', '--component_name', required=True, type=str,
                                help="component name")
            parser.add_argument('-t', '--task_id', required=True, type=str, help="task id")
            parser.add_argument('-r', '--role', required=True, type=str, help="role")
            parser.add_argument('-p', '--party_id', required=True, type=str, help="party id")
            parser.add_argument('-c', '--config', required=True, type=str, help="task config")
            parser.add_argument('--processors_per_node', help="processors_per_node", type=int)
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            if args.job_server:
                RuntimeConfig.init_config(HTTP_PORT=args.job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            role = args.role
            party_id = int(args.party_id)
            executor_pid = os.getpid()
            task_config = file_utils.load_json_conf(args.config)
            job_parameters = task_config['job_parameters']
            job_initiator = task_config['job_initiator']
            job_args = task_config['job_args']
            task_input_dsl = task_config['input']
            task_output_dsl = task_config['output']
            component_parameters = TaskExecutor.get_parameters(job_id, component_name, role, party_id)
            task_parameters = task_config['task_parameters']
            module_name = task_config['module_name']
            TaskExecutor.monkey_patch()
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task.f_status = TaskStatus.FAILED
            return
        try:
            job_log_dir = os.path.join(job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log_utils.LoggerFactory.set_directory(directory=task_log_dir, parent_log_dir=job_log_dir,
                                                  append_to_parent_log=True, force=True)

            task.f_job_id = job_id
            task.f_component_name = component_name
            task.f_task_id = task_id
            task.f_role = role
            task.f_party_id = party_id
            task.f_operator = 'python_operator'
            tracker = Tracking(job_id=job_id, role=role, party_id=party_id, component_name=component_name,
                               task_id=task_id,
                               model_id=job_parameters['model_id'],
                               model_version=job_parameters['model_version'],
                               component_module_name=module_name)
            task.f_start_time = current_timestamp()
            task.f_run_ip = get_lan_ip()
            task.f_run_pid = executor_pid
            run_class_paths = component_parameters.get('CodePath').split('/')
            run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
            run_class_name = run_class_paths[-1]
            task.f_status = TaskStatus.RUNNING
            TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                          party_id=party_id, initiator_party_id=job_initiator.get('party_id', None),
                                          initiator_role=job_initiator.get('role', None),
                                          task_info=task.to_json())

            # init environment, process is shared globally
            RuntimeConfig.init_config(WORK_MODE=job_parameters['work_mode'],
                                      BACKEND=job_parameters.get('backend', 0))
            if args.processors_per_node and args.processors_per_node > 0 and RuntimeConfig.BACKEND == Backend.EGGROLL:
                session_options = {"eggroll.session.processors.per.node": args.processors_per_node}
            else:
                session_options = {}
            session.init(job_id=job_utils.generate_session_id(task_id, role, party_id),
                         mode=RuntimeConfig.WORK_MODE,
                         backend=RuntimeConfig.BACKEND,
                         options=session_options)
            federation.init(job_id=task_id, runtime_conf=component_parameters)

            schedule_logger().info('run {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id))
            schedule_logger().info(component_parameters)
            schedule_logger().info(task_input_dsl)
            task_run_args = TaskExecutor.get_task_run_args(job_id=job_id, role=role, party_id=party_id,
                                                           task_id=task_id,
                                                           job_args=job_args,
                                                           job_parameters=job_parameters,
                                                           task_parameters=task_parameters,
                                                           input_dsl=task_input_dsl,
                                                           if_save_as_task_input_data=job_parameters.get("save_as_task_input_data", SAVE_AS_TASK_INPUT_DATA_SWITCH)
                                                           )
            run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
            run_object.set_tracker(tracker=tracker)
            run_object.set_taskid(taskid=task_id)
            run_object.run(component_parameters, task_run_args)
            output_data = run_object.save_data()
            tracker.save_output_data_table(output_data, task_output_dsl.get('data')[0] if task_output_dsl.get('data') else 'component')
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(output_model, task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')
            task.f_status = TaskStatus.COMPLETE
        except Exception as e:
            task.f_status = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            sync_success = False
            try:
                task.f_end_time = current_timestamp()
                task.f_elapsed = task.f_end_time - task.f_start_time
                task.f_update_time = current_timestamp()
                TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                              party_id=party_id,
                                              initiator_party_id=job_initiator.get('party_id', None),
                                              initiator_role=job_initiator.get('role', None),
                                              task_info=task.to_json())
                sync_success = True
            except Exception as e:
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(task_id, role, party_id, timestamp_to_date(task.f_start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(task_id, role, party_id, timestamp_to_date(task.f_end_time)))
        schedule_logger().info('task {} {} {} takes {}s'.format(task_id, role, party_id, int(task.f_elapsed)/1000))
        schedule_logger().info(
            'finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))

        print('finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))
Example #8
    def get_task_run_args(job_id, role, party_id, task_id, job_args, job_parameters, task_parameters, input_dsl,
                          if_save_as_task_input_data, filter_type=None, filter_attr=None):
        task_run_args = {}
        for input_type, input_detail in input_dsl.items():
            if filter_type and input_type not in filter_type:
                continue
            if input_type == 'data':
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for data_type, data_list in input_detail.items():
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        search_component_name, search_data_name = data_key_item[0], data_key_item[1]
                        if search_component_name == 'args':
                            if job_args.get('data', {}).get(search_data_name).get('namespace', '') and job_args.get(
                                    'data', {}).get(search_data_name).get('name', ''):

                                data_table = session.table(
                                    namespace=job_args['data'][search_data_name]['namespace'],
                                    name=job_args['data'][search_data_name]['name'])
                            else:
                                data_table = None
                        else:
                            data_table = Tracking(job_id=job_id, role=role, party_id=party_id,
                                                  component_name=search_component_name).get_output_data_table(
                                data_name=search_data_name)
                        args_from_component = this_type_args[search_component_name] = this_type_args.get(
                            search_component_name, {})
                        # todo: If the same component has more than one identical input, save as is repeated
                        if if_save_as_task_input_data:
                            if data_table:
                                schedule_logger().info("start save as task {} input data table {} {}".format(
                                    task_id,
                                    data_table.get_namespace(),
                                    data_table.get_name()))
                                origin_table_metas = data_table.get_metas()
                                origin_table_schema = data_table.schema
                                save_as_options = {"store_type": StoreTypes.ROLLPAIR_IN_MEMORY} if SAVE_AS_TASK_INPUT_DATA_IN_MEMORY else {}
                                data_table = data_table.save_as(
                                    namespace=job_utils.generate_session_id(task_id=task_id,
                                                                            role=role,
                                                                            party_id=party_id),
                                    name=data_table.get_name(),
                                    partition=task_parameters['input_data_partition'] if task_parameters.get('input_data_partition', 0) > 0 else data_table.get_partitions(),
                                    options=save_as_options)
                                data_table.save_metas(origin_table_metas)
                                data_table.schema = origin_table_schema
                                schedule_logger().info("save as task {} input data table to {} {} done".format(
                                    task_id,
                                    data_table.get_namespace(),
                                    data_table.get_name()))
                            else:
                                schedule_logger().info("pass save as task {} input data table, because the table is none".format(task_id))
                        else:
                            schedule_logger().info("pass save as task {} input data table, because the switch is off".format(task_id))
                        if not data_table or not filter_attr or not filter_attr.get("data", None):
                            args_from_component[data_type] = data_table
                        else:
                            args_from_component[data_type] = dict([(a, getattr(data_table, "get_{}".format(a))()) for a in filter_attr["data"]])
            elif input_type in ['model', 'isometric_model']:
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for dsl_model_key in input_detail:
                    dsl_model_key_items = dsl_model_key.split('.')
                    if len(dsl_model_key_items) == 2:
                        search_component_name, search_model_alias = dsl_model_key_items[0], dsl_model_key_items[1]
                    elif len(dsl_model_key_items) == 3 and dsl_model_key_items[0] == 'pipeline':
                        search_component_name, search_model_alias = dsl_model_key_items[1], dsl_model_key_items[2]
                    else:
                        raise Exception('get input {} failed'.format(input_type))
                    models = Tracking(job_id=job_id, role=role, party_id=party_id, component_name=search_component_name,
                                      model_id=job_parameters['model_id'],
                                      model_version=job_parameters['model_version']).get_output_model(
                        model_alias=search_model_alias)
                    this_type_args[search_component_name] = models
        return task_run_args
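In Example #8, model inputs are referenced either as component.alias or as pipeline.component.alias. A minimal sketch of that key-splitting logic in isolation (the keys shown are made up):

    def split_model_key(dsl_model_key):
        items = dsl_model_key.split('.')
        if len(items) == 2:                              # "hetero_lr_0.model"
            return items[0], items[1]
        if len(items) == 3 and items[0] == 'pipeline':   # "pipeline.hetero_lr_0.model"
            return items[1], items[2]
        raise Exception('get input model failed')

    print(split_model_key("hetero_lr_0.model"))            # ('hetero_lr_0', 'model')
    print(split_model_key("pipeline.hetero_lr_0.model"))   # ('hetero_lr_0', 'model')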
Example #9
 def run(self, component_parameters=None, args=None):
     self.parameters = component_parameters["UploadParam"]
     LOGGER.info(self.parameters)
     LOGGER.info(args)
     self.parameters["role"] = component_parameters["role"]
     self.parameters["local"] = component_parameters["local"]
     storage_engine = self.parameters["storage_engine"]
     storage_address = self.parameters["storage_address"]
     # if not set storage, use job storage as default
     if not storage_engine:
         storage_engine = args["job_parameters"].storage_engine
     if not storage_address:
         storage_address = args["job_parameters"].engines_address[
             EngineType.STORAGE]
     job_id = self.task_version_id.split("_")[0]
     if not os.path.isabs(self.parameters.get("file", "")):
         self.parameters["file"] = os.path.join(
             file_utils.get_project_base_directory(),
             self.parameters["file"])
     if not os.path.exists(self.parameters["file"]):
         raise Exception("%s is not exist, please check the configure" %
                         (self.parameters["file"]))
     if not os.path.getsize(self.parameters["file"]):
         raise Exception("%s is an empty file" % (self.parameters["file"]))
     name, namespace = self.parameters.get("name"), self.parameters.get(
         "namespace")
     _namespace, _table_name = self.generate_table_name(
         self.parameters["file"])
     if namespace is None:
         namespace = _namespace
     if name is None:
         name = _table_name
     read_head = self.parameters['head']
     if read_head == 0:
         head = False
     elif read_head == 1:
         head = True
     else:
         raise Exception("'head' in conf.json should be 0 or 1")
     partitions = self.parameters["partition"]
     if partitions <= 0 or partitions >= self.MAX_PARTITIONS:
         raise Exception(
             "Error number of partition, it should between %d and %d" %
             (0, self.MAX_PARTITIONS))
     with storage.Session.build(
             session_id=job_utils.generate_session_id(
                 self.tracker.task_id,
                 self.tracker.task_version,
                 self.tracker.role,
                 self.tracker.party_id,
                 suffix="storage",
                 random_end=True),
             namespace=namespace,
             name=name) as storage_session:
         if self.parameters.get("destroy", False):
             table = storage_session.get_table()
             if table:
                 LOGGER.info(
                     f"destroy table name: {name} namespace: {namespace} engine: {table.get_engine()}"
                 )
                 table.destroy()
             else:
                 LOGGER.info(
                     f"can not found table name: {name} namespace: {namespace}, pass destroy"
                 )
     address_dict = storage_address.copy()
     with storage.Session.build(
             session_id=job_utils.generate_session_id(
                 self.tracker.task_id,
                 self.tracker.task_version,
                 self.tracker.role,
                 self.tracker.party_id,
                 suffix="storage",
                 random_end=True),
             storage_engine=storage_engine,
             options=self.parameters.get("options")) as storage_session:
         if storage_engine in {
                 StorageEngine.EGGROLL, StorageEngine.STANDALONE
         }:
             upload_address = {
                 "name": name,
                 "namespace": namespace,
                 "storage_type": EggRollStorageType.ROLLPAIR_LMDB
             }
         elif storage_engine in {StorageEngine.MYSQL}:
             upload_address = {"db": namespace, "name": name}
         elif storage_engine in {StorageEngine.HDFS}:
             upload_address = {
                 "path":
                 data_utils.default_input_fs_path(
                     name=name,
                     namespace=namespace,
                     prefix=address_dict.get("path_prefix"))
             }
         else:
             raise RuntimeError(
                 f"can not support this storage engine: {storage_engine}")
         address_dict.update(upload_address)
         LOGGER.info(
             f"upload to {storage_engine} storage, address: {address_dict}")
         address = storage.StorageTableMeta.create_address(
             storage_engine=storage_engine, address_dict=address_dict)
         self.parameters["partitions"] = partitions
         self.parameters["name"] = name
         self.table = storage_session.create_table(address=address,
                                                   **self.parameters)
         data_table_count = self.save_data_table(job_id, name, namespace,
                                                 head)
         self.table.get_meta().update_metas(in_serialized=True)
     LOGGER.info("------------load data finish!-----------------")
     # rm tmp file
     try:
         if '{}/fate_upload_tmp'.format(job_id) in self.parameters['file']:
             LOGGER.info("remove tmp upload file")
             LOGGER.info(os.path.dirname(self.parameters["file"]))
             shutil.rmtree(os.path.dirname(self.parameters["file"]))
     except:
         LOGGER.info("remove tmp file failed")
     LOGGER.info("file: {}".format(self.parameters["file"]))
     LOGGER.info("total data_count: {}".format(data_table_count))
     LOGGER.info("table name: {}, table namespace: {}".format(
         name, namespace))
Example #10
    def start_task_worker(cls,
                          worker_name,
                          task: Task,
                          task_parameters: RunParameters = None,
                          executable: list = None,
                          extra_env: dict = None,
                          **kwargs):
        worker_id, config_dir, log_dir = cls.get_process_dirs(
            worker_name=worker_name,
            job_id=task.f_job_id,
            role=task.f_role,
            party_id=task.f_party_id,
            task=task)

        session_id = job_utils.generate_session_id(task.f_task_id,
                                                   task.f_task_version,
                                                   task.f_role,
                                                   task.f_party_id)
        federation_session_id = job_utils.generate_task_version_id(
            task.f_task_id, task.f_task_version)

        info_kwargs = {}
        specific_cmd = []
        if worker_name is WorkerName.TASK_EXECUTOR:
            from fate_flow.worker.task_executor import TaskExecutor
            module_file_path = sys.modules[TaskExecutor.__module__].__file__
        else:
            raise Exception(f"not support {worker_name} worker")

        if task_parameters is None:
            task_parameters = RunParameters(**job_utils.get_job_parameters(
                task.f_job_id, task.f_role, task.f_party_id))

        config = task_parameters.to_dict()
        config["src_user"] = kwargs.get("src_user")
        config_path, result_path = cls.get_config(config_dir=config_dir,
                                                  config=config,
                                                  log_dir=log_dir)

        if executable:
            process_cmd = executable
        else:
            process_cmd = [sys.executable or "python3"]

        common_cmd = [
            module_file_path,
            "--job_id",
            task.f_job_id,
            "--component_name",
            task.f_component_name,
            "--task_id",
            task.f_task_id,
            "--task_version",
            task.f_task_version,
            "--role",
            task.f_role,
            "--party_id",
            task.f_party_id,
            "--config",
            config_path,
            '--result',
            result_path,
            "--log_dir",
            log_dir,
            "--parent_log_dir",
            os.path.dirname(log_dir),
            "--worker_id",
            worker_id,
            "--run_ip",
            RuntimeConfig.JOB_SERVER_HOST,
            "--job_server",
            f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
            "--session_id",
            session_id,
            "--federation_session_id",
            federation_session_id,
        ]
        process_cmd.extend(common_cmd)
        process_cmd.extend(specific_cmd)
        env = cls.get_env(task.f_job_id, task.f_provider_info)
        if extra_env:
            env.update(extra_env)
        schedule_logger(task.f_job_id).info(
            f"task {task.f_task_id} {task.f_task_version} on {task.f_role} {task.f_party_id} {worker_name} worker subprocess is ready"
        )
        p = process_utils.run_subprocess(job_id=task.f_job_id,
                                         config_dir=config_dir,
                                         process_cmd=process_cmd,
                                         added_env=env,
                                         log_dir=log_dir,
                                         cwd_dir=config_dir,
                                         process_name=worker_name.value,
                                         process_id=worker_id)
        cls.save_worker_info(task=task,
                             worker_name=worker_name,
                             worker_id=worker_id,
                             run_ip=RuntimeConfig.JOB_SERVER_HOST,
                             run_pid=p.pid,
                             config=config,
                             cmd=process_cmd,
                             **info_kwargs)
        return {"run_pid": p.pid, "worker_id": worker_id, "cmd": process_cmd}
Example #11
 def _run(self, cpn_input: ComponentInputProtocol):
     self.parameters = cpn_input.parameters
     LOGGER.info(self.parameters)
     self.parameters["role"] = cpn_input.roles["role"]
     self.parameters["local"] = cpn_input.roles["local"]
     storage_engine = self.parameters["storage_engine"].upper()
     storage_address = self.parameters["storage_address"]
     # if not set storage, use job storage as default
     if not storage_engine:
         storage_engine = cpn_input.job_parameters.storage_engine
     self.storage_engine = storage_engine
     if not storage_address:
         storage_address = cpn_input.job_parameters.engines_address[
             EngineType.STORAGE]
     job_id = self.task_version_id.split("_")[0]
     if not os.path.isabs(self.parameters.get("file", "")):
         self.parameters["file"] = os.path.join(get_fate_flow_directory(),
                                                self.parameters["file"])
     if not os.path.exists(self.parameters["file"]):
         raise Exception("%s is not exist, please check the configure" %
                         (self.parameters["file"]))
     if not os.path.getsize(self.parameters["file"]):
         raise Exception("%s is an empty file" % (self.parameters["file"]))
     name, namespace = self.parameters.get("name"), self.parameters.get(
         "namespace")
     _namespace, _table_name = self.generate_table_name(
         self.parameters["file"])
     if namespace is None:
         namespace = _namespace
     if name is None:
         name = _table_name
     read_head = self.parameters["head"]
     if read_head == 0:
         head = False
     elif read_head == 1:
         head = True
     else:
         raise Exception("'head' in conf.json should be 0 or 1")
     partitions = self.parameters["partition"]
     if partitions <= 0 or partitions >= self.MAX_PARTITIONS:
         raise Exception(
             "Error number of partition, it should between %d and %d" %
             (0, self.MAX_PARTITIONS))
     self.session_id = job_utils.generate_session_id(
         self.tracker.task_id,
         self.tracker.task_version,
         self.tracker.role,
         self.tracker.party_id,
     )
     sess = Session.get_global()
     self.session = sess
     if self.parameters.get("destroy", False):
         table = sess.get_table(namespace=namespace, name=name)
         if table:
             LOGGER.info(
                 f"destroy table name: {name} namespace: {namespace} engine: {table.engine}"
             )
             try:
                 table.destroy()
             except Exception as e:
                 LOGGER.error(e)
         else:
             LOGGER.info(
                 f"can not found table name: {name} namespace: {namespace}, pass destroy"
             )
     address_dict = storage_address.copy()
     storage_session = sess.storage(storage_engine=storage_engine,
                                    options=self.parameters.get("options"))
     upload_address = {}
     if storage_engine in {StorageEngine.EGGROLL, StorageEngine.STANDALONE}:
         upload_address = {
             "name": name,
             "namespace": namespace,
             "storage_type": EggRollStoreType.ROLLPAIR_LMDB,
         }
     elif storage_engine in {StorageEngine.MYSQL, StorageEngine.HIVE}:
         if not address_dict.get("db") or not address_dict.get("name"):
             upload_address = {"db": namespace, "name": name}
     elif storage_engine in {StorageEngine.PATH}:
         upload_address = {"path": self.parameters["file"]}
     elif storage_engine in {StorageEngine.HDFS}:
         upload_address = {
             "path":
             default_input_fs_path(
                 name=name,
                 namespace=namespace,
                 prefix=address_dict.get("path_prefix"),
             )
         }
     elif storage_engine in {StorageEngine.LOCALFS}:
         upload_address = {
             "path":
             default_input_fs_path(name=name,
                                   namespace=namespace,
                                   storage_engine=storage_engine)
         }
     else:
         raise RuntimeError(
             f"can not support this storage engine: {storage_engine}")
     address_dict.update(upload_address)
     LOGGER.info(
         f"upload to {storage_engine} storage, address: {address_dict}")
     address = storage.StorageTableMeta.create_address(
         storage_engine=storage_engine, address_dict=address_dict)
     self.parameters["partitions"] = partitions
     self.parameters["name"] = name
     self.table = storage_session.create_table(
         address=address,
         origin=StorageTableOrigin.UPLOAD,
         **self.parameters)
     if storage_engine not in [StorageEngine.PATH]:
         data_table_count = self.save_data_table(job_id, name, namespace,
                                                 storage_engine, head)
     else:
         data_table_count = self.get_data_table_count(
             self.parameters["file"], name, namespace)
     self.table.meta.update_metas(in_serialized=True)
     DataTableTracker.create_table_tracker(
         table_name=name,
         table_namespace=namespace,
         entity_info={
             "job_id": job_id,
             "have_parent": False
         },
     )
     LOGGER.info("------------load data finish!-----------------")
     # rm tmp file
     try:
         if "{}/fate_upload_tmp".format(job_id) in self.parameters["file"]:
             LOGGER.info("remove tmp upload file")
             LOGGER.info(os.path.dirname(self.parameters["file"]))
             shutil.rmtree(os.path.dirname(self.parameters["file"]))
     except:
         LOGGER.info("remove tmp file failed")
     LOGGER.info("file: {}".format(self.parameters["file"]))
     LOGGER.info("total data_count: {}".format(data_table_count))
     LOGGER.info("table name: {}, table namespace: {}".format(
         name, namespace))