Example #1
    # install signal handlers: clean up on SIGTERM, reap child processes on SIGCHLD
    signal.signal(signal.SIGTERM, job_utils.cleaning)
    signal.signal(signal.SIGCHLD, job_utils.wait_child_process)
    # init db
    init_flow_db()
    init_arch_db()
    # init runtime config
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--standalone_node',
                        default=False,
                        help="if standalone node mode or not ",
                        action='store_true')
    args = parser.parse_args()
    RuntimeConfig.init_env()
    RuntimeConfig.set_process_role(ProcessRole.DRIVER)
    PrivilegeAuth.init()
    ServiceUtils.register()
    ResourceManager.initialize()
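    # background daemons; the intervals appear to be milliseconds (5s and 2s)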
    Detector(interval=5 * 1000).start()
    DAGScheduler(interval=2 * 1000).start()
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
                 (cygrpc.ChannelArgKey.max_receive_message_length, -1)])

    proxy_pb2_grpc.add_DataTransferServiceServicer_to_server(
        UnaryService(), server)
    server.add_insecure_port("{}:{}".format(IP, GRPC_PORT))
    server.start()
    # start http server
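The block above is the standard grpcio server pattern: a thread-pool server with unlimited message sizes, a registered servicer, an insecure port, and a non-blocking start(). A minimal self-contained sketch of the same pattern, assuming only the grpcio package (the host, port, and commented-out servicer registration are placeholders):

    from concurrent import futures

    import grpc

    def serve(host="127.0.0.1", port=9360):
        # the string keys below are the public equivalents of the
        # cygrpc.ChannelArgKey constants used above; -1 means unlimited
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10),
            options=[("grpc.max_send_message_length", -1),
                     ("grpc.max_receive_message_length", -1)])
        # a real server would register its servicers here, e.g.
        # proxy_pb2_grpc.add_DataTransferServiceServicer_to_server(...)
        server.add_insecure_port("{}:{}".format(host, port))
        server.start()  # non-blocking; RPCs run on the thread pool
        server.wait_for_termination()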
Example #2
    def run_task(cls):
        task_info = {}
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j',
                                '--job_id',
                                required=True,
                                type=str,
                                help="job id")
            parser.add_argument('-n',
                                '--component_name',
                                required=True,
                                type=str,
                                help="component name")
            parser.add_argument('-t',
                                '--task_id',
                                required=True,
                                type=str,
                                help="task id")
            parser.add_argument('-v',
                                '--task_version',
                                required=True,
                                type=int,
                                help="task version")
            parser.add_argument('-r',
                                '--role',
                                required=True,
                                type=str,
                                help="role")
            parser.add_argument('-p',
                                '--party_id',
                                required=True,
                                type=int,
                                help="party id")
            parser.add_argument('-c',
                                '--config',
                                required=True,
                                type=str,
                                help="task parameters")
            parser.add_argument('--run_ip', help="run ip", type=str)
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            if args.job_server:
                RuntimeConfig.init_config(
                    JOB_SERVER_HOST=args.job_server.split(':')[0],
                    HTTP_PORT=args.job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            task_version = args.task_version
            role = args.role
            party_id = args.party_id
            executor_pid = os.getpid()
            task_info.update({
                "job_id": job_id,
                "component_name": component_name,
                "task_id": task_id,
                "task_version": task_version,
                "role": role,
                "party_id": party_id,
                "run_ip": args.run_ip,
                "run_pid": executor_pid
            })
            start_time = current_timestamp()
            job_conf = job_utils.get_job_conf(job_id, role)
            job_dsl = job_conf["job_dsl_path"]
            job_runtime_conf = job_conf["job_runtime_conf_path"]
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job_dsl,
                runtime_conf=job_runtime_conf,
                train_runtime_conf=job_conf["train_runtime_conf_path"],
                pipeline_dsl=job_conf["pipeline_dsl_path"])
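            # locate this party's slice of the runtime conf and pull the
            # component's parameters and input/output bindings from the DSL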
            party_index = job_runtime_conf["role"][role].index(party_id)
            job_args_on_party = TaskExecutor.get_job_args_on_party(
                dsl_parser, job_runtime_conf, role, party_id)
            component = dsl_parser.get_component_info(
                component_name=component_name)
            component_parameters = component.get_role_parameters()
            component_parameters_on_party = component_parameters[role][
                party_index] if role in component_parameters else {}
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            component_parameters_on_party[
                'output_data_name'] = task_output_dsl.get('data')
            task_parameters = RunParameters(
                **file_utils.load_json_conf(args.config))
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task_info["party_status"] = TaskStatus.FAILED
            return
        try:
            job_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role,
                str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log.LoggerFactory.set_directory(directory=task_log_dir,
                                            parent_log_dir=job_log_dir,
                                            append_to_parent_log=True,
                                            force=True)

            tracker = Tracker(job_id=job_id,
                              role=role,
                              party_id=party_id,
                              component_name=component_name,
                              task_id=task_id,
                              task_version=task_version,
                              model_id=job_parameters.model_id,
                              model_version=job_parameters.model_version,
                              component_module_name=module_name,
                              job_parameters=job_parameters)
            tracker_client = TrackerClient(
                job_id=job_id,
                role=role,
                party_id=party_id,
                component_name=component_name,
                task_id=task_id,
                task_version=task_version,
                model_id=job_parameters.model_id,
                model_version=job_parameters.model_version,
                component_module_name=module_name,
                job_parameters=job_parameters)
            run_class_paths = component_parameters_on_party.get(
                'CodePath').split('/')
            run_class_package = '.'.join(
                run_class_paths[:-2]) + '.' + run_class_paths[-2].replace(
                    '.py', '')
            run_class_name = run_class_paths[-1]
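            # mark this party's task RUNNING and tell the driver before the
            # potentially long component run begins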
            task_info["party_status"] = TaskStatus.RUNNING
            cls.report_task_update_to_driver(task_info=task_info)

            # init environment, process is shared globally
            RuntimeConfig.init_config(
                WORK_MODE=job_parameters.work_mode,
                COMPUTING_ENGINE=job_parameters.computing_engine,
                FEDERATION_ENGINE=job_parameters.federation_engine,
                FEDERATED_MODE=job_parameters.federated_mode)

            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
            else:
                session_options = {}

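            # create computing and federation sessions scoped to this task run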
            sess = session.Session(
                computing_type=job_parameters.computing_engine,
                federation_type=job_parameters.federation_engine)
            computing_session_id = job_utils.generate_session_id(
                task_id, task_version, role, party_id)
            sess.init_computing(computing_session_id=computing_session_id,
                                options=session_options)
            federation_session_id = job_utils.generate_task_version_id(
                task_id, task_version)
            component_parameters_on_party[
                "job_parameters"] = job_parameters.to_dict()
            sess.init_federation(
                federation_session_id=federation_session_id,
                runtime_conf=component_parameters_on_party,
                service_conf=job_parameters.engines_address.get(
                    EngineType.FEDERATION, {}))
            sess.as_default()

            schedule_logger().info('Run {} {} {} {} {} task'.format(
                job_id, component_name, task_id, role, party_id))
            schedule_logger().info("Component parameters on party {}".format(
                component_parameters_on_party))
            schedule_logger().info("Task input dsl {}".format(task_input_dsl))
            task_run_args = cls.get_task_run_args(
                job_id=job_id,
                role=role,
                party_id=party_id,
                task_id=task_id,
                task_version=task_version,
                job_args=job_args_on_party,
                job_parameters=job_parameters,
                task_parameters=task_parameters,
                input_dsl=task_input_dsl,
            )
            if module_name in {"Upload", "Download", "Reader", "Writer"}:
                task_run_args["job_parameters"] = job_parameters
            run_object = getattr(importlib.import_module(run_class_package),
                                 run_class_name)()
            run_object.set_tracker(tracker=tracker_client)
            run_object.set_task_version_id(
                task_version_id=job_utils.generate_task_version_id(
                    task_id, task_version))
            # add profile logs
            profile.profile_start()
            run_object.run(component_parameters_on_party, task_run_args)
            profile.profile_ends()
            output_data = run_object.save_data()
            if not isinstance(output_data, list):
                output_data = [output_data]
            output_data_names = task_output_dsl.get('data')
            for index, single_output_data in enumerate(output_data):
                data_name = output_data_names[index] if output_data_names else str(index)
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=single_output_data,
                    output_storage_engine=job_parameters.storage_engine,
                    output_storage_address=job_parameters.engines_address.get(
                        EngineType.STORAGE, {}))
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(
                        data_name=data_name,
                        table_namespace=persistent_table_namespace,
                        table_name=persistent_table_name)
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(
                output_model, task_output_dsl['model'][0]
                if task_output_dsl.get('model') else 'default')
            task_info["party_status"] = TaskStatus.SUCCESS
        except Exception as e:
            task_info["party_status"] = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            try:
                task_info["end_time"] = current_timestamp()
                task_info["elapsed"] = task_info["end_time"] - start_time
                cls.report_task_update_to_driver(task_info=task_info)
            except Exception as e:
                task_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(
            task_id, role, party_id, timestamp_to_date(start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(
            task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
        schedule_logger().info('task {} {} {} takes {}s'.format(
            task_id, role, party_id,
            int(task_info["elapsed"]) / 1000))
        schedule_logger().info('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))

        print('Finish {} {} {} {} {} {} task {}'.format(
            job_id, component_name, task_id, task_version, role, party_id,
            task_info["party_status"]))
        return task_info
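Given the argument parser above, each task runs as its own OS process and receives its identity entirely through command-line flags. A sketch of how a driver might spawn one, assuming a hypothetical entry-point module named task_executor (the real module path depends on the deployment):

    import subprocess
    import sys

    def spawn_task_executor(job_id, component_name, task_id, task_version,
                            role, party_id, config_path,
                            run_ip=None, job_server=None):
        # build the command line expected by run_task()'s parser
        cmd = [sys.executable, "-m", "task_executor",  # hypothetical module
               "-j", job_id,
               "-n", component_name,
               "-t", task_id,
               "-v", str(task_version),
               "-r", role,
               "-p", str(party_id),
               "-c", config_path]
        if run_ip:
            cmd += ["--run_ip", run_ip]
        if job_server:
            # "host:port"; run_task() splits it to reach the job server
            cmd += ["--job_server", job_server]
        return subprocess.Popen(cmd)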
Example #3
    def run_task():
        task = Task()
        task.f_create_time = current_timestamp()
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j', '--job_id', required=True, type=str, help="job id")
            parser.add_argument('-n', '--component_name', required=True, type=str,
                                help="component name")
            parser.add_argument('-t', '--task_id', required=True, type=str, help="task id")
            parser.add_argument('-r', '--role', required=True, type=str, help="role")
            parser.add_argument('-p', '--party_id', required=True, type=str, help="party id")
            parser.add_argument('-c', '--config', required=True, type=str, help="task config")
            parser.add_argument('--processors_per_node', help="processors_per_node", type=int)
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            if args.job_server:
                RuntimeConfig.init_config(HTTP_PORT=args.job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            role = args.role
            party_id = int(args.party_id)
            executor_pid = os.getpid()
            task_config = file_utils.load_json_conf(args.config)
            job_parameters = task_config['job_parameters']
            job_initiator = task_config['job_initiator']
            job_args = task_config['job_args']
            task_input_dsl = task_config['input']
            task_output_dsl = task_config['output']
            component_parameters = TaskExecutor.get_parameters(job_id, component_name, role, party_id)
            task_parameters = task_config['task_parameters']
            module_name = task_config['module_name']
            TaskExecutor.monkey_patch()
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task.f_status = TaskStatus.FAILED
            return
        try:
            job_log_dir = os.path.join(job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log_utils.LoggerFactory.set_directory(directory=task_log_dir, parent_log_dir=job_log_dir,
                                                  append_to_parent_log=True, force=True)

            task.f_job_id = job_id
            task.f_component_name = component_name
            task.f_task_id = task_id
            task.f_role = role
            task.f_party_id = party_id
            task.f_operator = 'python_operator'
            tracker = Tracking(job_id=job_id, role=role, party_id=party_id, component_name=component_name,
                               task_id=task_id,
                               model_id=job_parameters['model_id'],
                               model_version=job_parameters['model_version'],
                               component_module_name=module_name)
            task.f_start_time = current_timestamp()
            task.f_run_ip = get_lan_ip()
            task.f_run_pid = executor_pid
            run_class_paths = component_parameters.get('CodePath').split('/')
            run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
            run_class_name = run_class_paths[-1]
            task.f_status = TaskStatus.RUNNING
            TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                          party_id=party_id, initiator_party_id=job_initiator.get('party_id', None),
                                          initiator_role=job_initiator.get('role', None),
                                          task_info=task.to_json())

            # init environment, process is shared globally
            RuntimeConfig.init_config(WORK_MODE=job_parameters['work_mode'],
                                      BACKEND=job_parameters.get('backend', 0))
            if args.processors_per_node and args.processors_per_node > 0 and RuntimeConfig.BACKEND == Backend.EGGROLL:
                session_options = {"eggroll.session.processors.per.node": args.processors_per_node}
            else:
                session_options = {}
            session.init(job_id=job_utils.generate_session_id(task_id, role, party_id),
                         mode=RuntimeConfig.WORK_MODE,
                         backend=RuntimeConfig.BACKEND,
                         options=session_options)
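            # federation is keyed by the task id, so the matching task on the
            # other parties joins the same channel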
            federation.init(job_id=task_id, runtime_conf=component_parameters)

            schedule_logger().info('run {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id))
            schedule_logger().info(component_parameters)
            schedule_logger().info(task_input_dsl)
            task_run_args = TaskExecutor.get_task_run_args(job_id=job_id, role=role, party_id=party_id,
                                                           task_id=task_id,
                                                           job_args=job_args,
                                                           job_parameters=job_parameters,
                                                           task_parameters=task_parameters,
                                                           input_dsl=task_input_dsl,
                                                           if_save_as_task_input_data=job_parameters.get("save_as_task_input_data", SAVE_AS_TASK_INPUT_DATA_SWITCH)
                                                           )
            run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
            run_object.set_tracker(tracker=tracker)
            run_object.set_taskid(taskid=task_id)
            run_object.run(component_parameters, task_run_args)
            output_data = run_object.save_data()
            tracker.save_output_data_table(output_data, task_output_dsl.get('data')[0] if task_output_dsl.get('data') else 'component')
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(output_model, task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')
            task.f_status = TaskStatus.COMPLETE
        except Exception as e:
            task.f_status = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            sync_success = False
            try:
                task.f_end_time = current_timestamp()
                task.f_elapsed = task.f_end_time - task.f_start_time
                task.f_update_time = current_timestamp()
                TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                              party_id=party_id,
                                              initiator_party_id=job_initiator.get('party_id', None),
                                              initiator_role=job_initiator.get('role', None),
                                              task_info=task.to_json())
                sync_success = True
            except Exception as e:
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info('task {} {} {} start time: {}'.format(task_id, role, party_id, timestamp_to_date(task.f_start_time)))
        schedule_logger().info('task {} {} {} end time: {}'.format(task_id, role, party_id, timestamp_to_date(task.f_end_time)))
        schedule_logger().info('task {} {} {} takes {}s'.format(task_id, role, party_id, int(task.f_elapsed) / 1000))
        schedule_logger().info(
            'finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))

        print('finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))
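Both Example #2 and Example #3 resolve the component implementation from a slash-separated CodePath value ("package/dirs/module.py/ClassName") and instantiate it with importlib. A standalone sketch of that convention (the example path is hypothetical):

    import importlib

    def load_run_class(code_path):
        # e.g. "federatedml/linear_model/hetero_lr.py/HeteroLR" (hypothetical)
        parts = code_path.split('/')
        # directories form the package, the file (minus .py) the module,
        # and the final element names the class
        module_path = '.'.join(parts[:-2]) + '.' + parts[-2].replace('.py', '')
        class_name = parts[-1]
        return getattr(importlib.import_module(module_path), class_name)

    # usage mirroring the examples above:
    # run_object = load_run_class(component_parameters.get('CodePath'))()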