Code example #1
 def init_config(**kwargs):
     for k, v in kwargs.items():
         if hasattr(RuntimeConfig, k):
             setattr(RuntimeConfig, k, v)
             if k == 'HTTP_PORT':
                 setattr(
                     RuntimeConfig, 'JOB_SERVER_HOST',
                     "{}:{}".format(get_lan_ip(), RuntimeConfig.HTTP_PORT))
Code example #2
 def init(hosts, use_configuation_center, fate_flow_zk_path, fate_flow_port, model_transfer_path):
     if use_configuation_center:
         zk = CenterConfig.get_zk(hosts)
         zk.start()
         model_host = 'http://{}:{}{}'.format(get_lan_ip(), fate_flow_port, model_transfer_path)
         fate_flow_zk_path = '{}/{}'.format(fate_flow_zk_path, parse.quote(model_host, safe=' '))
         try:
             zk.create(fate_flow_zk_path, makepath=True)
         except:
             pass
         zk.stop()
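
Example #2 publishes this node's model-transfer URL under a ZooKeeper path so other services can discover it. Below is a minimal sketch of the same registration written directly against kazoo; CenterConfig.get_zk() is assumed to return a comparable client, and the hosts and paths here are placeholders. Catching NodeExistsError, rather than a bare except, keeps the "already registered" case explicit.

    # Sketch only: hosts, base path and transfer path are illustrative placeholders.
    from urllib import parse
    from kazoo.client import KazooClient
    from kazoo.exceptions import NodeExistsError

    def register_model_endpoint_sketch(zk_hosts, base_path, lan_ip, http_port, transfer_path):
        zk = KazooClient(hosts=zk_hosts)
        zk.start()
        try:
            endpoint = 'http://{}:{}{}'.format(lan_ip, http_port, transfer_path)
            node_path = '{}/{}'.format(base_path, parse.quote(endpoint, safe=' '))
            try:
                zk.create(node_path, makepath=True)   # register this node
            except NodeExistsError:
                pass                                  # already registered
        finally:
            zk.stop()
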
Code example #3
File: job_detector.py Project: pangzx1/FATE1.1
 def run_do(self):
     try:
         running_tasks = job_utils.query_task(status='running',
                                              run_ip=get_lan_ip())
         stop_job_ids = set()
         detect_logger.info('start to detect running job..')
         for task in running_tasks:
             try:
                 process_exist = job_utils.check_job_process(
                     int(task.f_run_pid))
                 if not process_exist:
                     detect_logger.info(
                         'job {} component {} on {} {} task {} {} process does not exist'
                         .format(task.f_job_id, task.f_component_name,
                                 task.f_role, task.f_party_id,
                                 task.f_task_id, task.f_run_pid))
                     stop_job_ids.add(task.f_job_id)
             except Exception as e:
                 detect_logger.exception(e)
         if stop_job_ids:
             schedule_logger().info(
                 'start to stop jobs: {}'.format(stop_job_ids))
         for job_id in stop_job_ids:
             jobs = job_utils.query_job(job_id=job_id)
             if jobs:
                 initiator_party_id = jobs[0].f_initiator_party_id
                 job_work_mode = jobs[0].f_work_mode
                 if len(jobs) > 1:
                     # more than one record for this job means this party is the initiator
                     my_party_id = initiator_party_id
                 else:
                     my_party_id = jobs[0].f_party_id
                     initiator_party_id = jobs[0].f_initiator_party_id
                 api_utils.federated_api(
                     job_id=job_id,
                     method='POST',
                     endpoint='/{}/job/stop'.format(API_VERSION),
                     src_party_id=my_party_id,
                     dest_party_id=initiator_party_id,
                     src_role=None,
                     json_body={'job_id': job_id},
                     work_mode=job_work_mode)
                 schedule_logger(job_id).info(
                     'send stop job {} command'.format(job_id))
     except Exception as e:
         detect_logger.exception(e)
     finally:
         detect_logger.info('finish detect running job')
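
The detector above decides that a task has died when job_utils.check_job_process() reports that the recorded PID no longer exists. A minimal sketch of such a PID liveness check is shown below; it is an assumption about what the helper does, not its actual implementation.

    import errno
    import os

    def check_pid_alive_sketch(pid):
        # Sending signal 0 performs error checking only; no signal is delivered.
        if pid <= 0:
            return False
        try:
            os.kill(pid, 0)
        except OSError as e:
            if e.errno == errno.ESRCH:    # no such process
                return False
            if e.errno == errno.EPERM:    # exists, but owned by another user
                return True
            raise
        return True
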
Code example #4
File: job_controller.py Project: pangzx1/FL1.0
    def update_job_status(job_id, role, party_id, job_info, create=False):
        job_tracker = Tracking(job_id=job_id, role=role, party_id=party_id)
        job_info['f_run_ip'] = get_lan_ip()
        if create:
            dsl = json_loads(job_info['f_dsl'])
            runtime_conf = json_loads(job_info['f_runtime_conf'])
            train_runtime_conf = json_loads(job_info['f_train_runtime_conf'])
            save_job_conf(job_id=job_id,
                          job_dsl=dsl,
                          job_runtime_conf=runtime_conf)
            roles = json_loads(job_info['f_roles'])
            partner = {}
            show_role = {}
            is_initiator = job_info.get('f_is_initiator', 0)
            for _role, _role_party in roles.items():
                if is_initiator or _role == role:
                    show_role[_role] = show_role.get(_role, [])
                    for _party_id in _role_party:
                        if is_initiator or _party_id == party_id:
                            show_role[_role].append(_party_id)

                if _role != role:
                    partner[_role] = partner.get(_role, [])
                    partner[_role].extend(_role_party)
                else:
                    for _party_id in _role_party:
                        if _party_id != party_id:
                            partner[_role] = partner.get(_role, [])
                            partner[_role].append(_party_id)

            dag = get_job_dsl_parser(dsl=dsl,
                                     runtime_conf=runtime_conf,
                                     train_runtime_conf=train_runtime_conf)
            job_args = dag.get_args_input()
            dataset = {}
            for _role, _role_party_args in job_args.items():
                if is_initiator or _role == role:
                    for _party_index in range(len(_role_party_args)):
                        _party_id = roles[_role][_party_index]
                        if is_initiator or _party_id == party_id:
                            dataset[_role] = dataset.get(_role, {})
                            dataset[_role][_party_id] = dataset[_role].get(_party_id, {})
                            for _data_type, _data_location in _role_party_args[_party_index]['args']['data'].items():
                                dataset[_role][_party_id][_data_type] = '{}.{}'.format(_data_location['namespace'],
                                                                                       _data_location['name'])
            job_tracker.log_job_view({'partner': partner, 'dataset': dataset, 'roles': show_role})
        job_tracker.save_job_info(role=role, party_id=party_id, job_info=job_info, create=create)
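
The partner bookkeeping above relies on the d[k] = d.get(k, default) idiom to grow nested structures. The same accumulation can be written more directly with collections.defaultdict; the sketch below covers only the partner-building step, with illustrative role and party values.

    from collections import defaultdict

    def build_partner_sketch(roles, my_role, my_party_id):
        # roles maps role name -> list of party ids, as decoded from job_info['f_roles'].
        partner = defaultdict(list)
        for _role, _party_ids in roles.items():
            if _role != my_role:
                partner[_role].extend(_party_ids)
            else:
                others = [p for p in _party_ids if p != my_party_id]
                if others:
                    partner[_role].extend(others)
        return dict(partner)

    # build_partner_sketch({'guest': [9999], 'host': [10000, 10001]}, 'host', 10000)
    # -> {'guest': [9999], 'host': [10001]}
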
Code example #5
File: fate_flow_client.py Project: errord/FATE-1
def call_fun(func, config_data, dsl_path, config_path):
    ip = server_conf.get(SERVERS).get(ROLE).get('host')
    if ip in ['localhost', '127.0.0.1']:
        ip = get_lan_ip()
    http_port = server_conf.get(SERVERS).get(ROLE).get('http.port')
    server_url = "http://{}:{}/{}".format(ip, http_port, API_VERSION)

    if func in JOB_OPERATE_FUNC:
        if func == 'submit_job':
            if not config_path:
                raise Exception(
                    'the following arguments are required: {}'.format(
                        'runtime conf path'))
            dsl_data = {}
            if dsl_path or config_data.get('job_parameters', {}).get(
                    'job_type', '') == 'predict':
                if dsl_path:
                    dsl_path = os.path.abspath(dsl_path)
                    with open(dsl_path, 'r') as f:
                        dsl_data = json.load(f)
            else:
                raise Exception(
                    'the following arguments are required: {}'.format(
                        'dsl path'))
            post_data = {'job_dsl': dsl_data, 'job_runtime_conf': config_data}
            response = requests.post("/".join(
                [server_url, "job", func.rstrip('_job')]),
                                     json=post_data)
            try:
                if response.json()['retcode'] == 999:
                    start_cluster_standalone_job_server()
                    response = requests.post("/".join(
                        [server_url, "job",
                         func.rstrip('_job')]),
                                             json=post_data)
            except:
                pass
        elif func == 'data_view_query':
            response = requests.post("/".join(
                [server_url, "job", func.replace('_', '/')]),
                                     json=config_data)
        else:
            if func != 'query_job':
                detect_utils.check_config(config=config_data,
                                          required_arguments=['job_id'])
            post_data = config_data
            response = requests.post("/".join(
                [server_url, "job", func.rstrip('_job')]),
                                     json=post_data)
            if func == 'query_job':
                response = response.json()
                if response['retcode'] == 0:
                    for i in range(len(response['data'])):
                        del response['data'][i]['f_runtime_conf']
                        del response['data'][i]['f_dsl']
    elif func in JOB_FUNC:
        if func == 'job_config':
            detect_utils.check_config(config=config_data,
                                      required_arguments=[
                                          'job_id', 'role', 'party_id',
                                          'output_path'
                                      ])
            response = requests.post("/".join(
                [server_url, func.replace('_', '/')]),
                                     json=config_data)
            response_data = response.json()
            if response_data['retcode'] == 0:
                job_id = response_data['data']['job_id']
                download_directory = os.path.join(
                    config_data['output_path'], 'job_{}_config'.format(job_id))
                os.makedirs(download_directory, exist_ok=True)
                for k, v in response_data['data'].items():
                    if k == 'job_id':
                        continue
                    with open('{}/{}.json'.format(download_directory, k),
                              'w') as fw:
                        json.dump(v, fw, indent=4)
                del response_data['data']['dsl']
                del response_data['data']['runtime_conf']
                response_data['directory'] = download_directory
                response_data['retmsg'] = 'download successfully, please check {} directory'.format(download_directory)
                response = response_data
        elif func == 'job_log':
            detect_utils.check_config(
                config=config_data,
                required_arguments=['job_id', 'output_path'])
            job_id = config_data['job_id']
            tar_file_name = 'job_{}_log.tar.gz'.format(job_id)
            extract_dir = os.path.join(config_data['output_path'],
                                       'job_{}_log'.format(job_id))
            with closing(
                    requests.get("/".join([server_url,
                                           func.replace('_', '/')]),
                                 json=config_data,
                                 stream=True)) as response:
                if response.status_code == 200:
                    download_from_request(http_response=response,
                                          tar_file_name=tar_file_name,
                                          extract_dir=extract_dir)
                    response = {
                        'retcode': 0,
                        'directory': extract_dir,
                        'retmsg': 'download successfully, please check {} directory'.format(extract_dir)
                    }
                else:
                    response = response.json()
    elif func in TASK_OPERATE_FUNC:
        response = requests.post("/".join(
            [server_url, "job", "task",
             func.rstrip('_task')]),
                                 json=config_data)
    elif func in TRACKING_FUNC:
        if func != 'component_metric_delete':
            detect_utils.check_config(config=config_data,
                                      required_arguments=[
                                          'job_id', 'component_name', 'role',
                                          'party_id'
                                      ])
        if func == 'component_output_data':
            detect_utils.check_config(config=config_data,
                                      required_arguments=['output_path'])
            tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(
                config_data['job_id'], config_data['component_name'],
                config_data['role'], config_data['party_id'])
            extract_dir = os.path.join(config_data['output_path'],
                                       tar_file_name.replace('.tar.gz', ''))
            with closing(
                    requests.get("/".join([
                        server_url, "tracking",
                        func.replace('_', '/'), 'download'
                    ]),
                                 json=config_data,
                                 stream=True)) as response:
                if response.status_code == 200:
                    download_from_request(http_response=response,
                                          tar_file_name=tar_file_name,
                                          extract_dir=extract_dir)
                    response = {
                        'retcode': 0,
                        'directory': extract_dir,
                        'retmsg': 'download successfully, please check {} directory'.format(extract_dir)
                    }
                else:
                    response = response.json()

        else:
            response = requests.post("/".join(
                [server_url, "tracking",
                 func.replace('_', '/')]),
                                     json=config_data)
    elif func in DATA_FUNC:
        response = requests.post("/".join(
            [server_url, "data", func.replace('_', '/')]),
                                 json=config_data)
        try:
            if response.json()['retcode'] == 999:
                start_cluster_standalone_job_server()
                response = requests.post("/".join([server_url, "data", func]),
                                         json=config_data)
        except:
            pass
    elif func in TABLE_FUNC:
        if func == "table_info":
            detect_utils.check_config(
                config=config_data,
                required_arguments=['namespace', 'table_name'])
            response = requests.post("/".join([server_url, "table", func]),
                                     json=config_data)
        else:
            response = requests.post("/".join(
                [server_url, "table",
                 func.lstrip('table_')]),
                                     json=config_data)
    elif func in MODEL_FUNC:
        if func == "version":
            detect_utils.check_config(config=config_data,
                                      required_arguments=['namespace'])
        response = requests.post("/".join([server_url, "model", func]),
                                 json=config_data)
    elif func in PERMISSION_FUNC:
        detect_utils.check_config(
            config=config_data,
            required_arguments=['src_party_id', 'src_role'])
        response = requests.post("/".join(
            [server_url, "permission",
             func.replace('_', '/')]),
                                 json=config_data)
    return response.json() if isinstance(
        response, requests.models.Response) else response
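
One caveat in the routing above: str.rstrip() and str.lstrip() strip a set of characters, not a literal suffix or prefix. func.rstrip('_job') and func.lstrip('table_') happen to produce the intended endpoint for the function names this client uses, but they are fragile for names whose stem ends (or begins) with those characters. A hypothetical suffix-stripping helper, not part of the client, avoids the surprise:

    def strip_suffix(name, suffix):
        # str.rstrip('_job') removes any trailing '_', 'j', 'o', 'b' characters,
        # so 'blob_job'.rstrip('_job') == 'bl'. Removing an exact suffix is safer
        # (Python 3.9+ offers str.removesuffix for the same purpose).
        return name[:-len(suffix)] if suffix and name.endswith(suffix) else name

    # strip_suffix('submit_job', '_job') -> 'submit'
    # 'blob_job'.rstrip('_job')          -> 'bl'  (character-set semantics)
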
Code example #6
File: settings.py Project: zhaobobo10/KubeFATE
    'host': 'mysql',
    'port': 3306,
    'max_connections': 100,
    'stale_timeout': 30,
}

REDIS = {
    'host': 'redis',
    'port': 6379,
    'password': '******',
    'max_connections': 500
}

REDIS_QUEUE_DB_INDEX = 0
JOB_MODULE_CONF = file_utils.load_json_conf("fate_flow/job_module_conf.json")

"""
Services
"""
server_conf = file_utils.load_json_conf("arch/conf/server_conf.json")
PROXY_HOST = server_conf.get(SERVERS).get('proxy').get('host')
PROXY_PORT = server_conf.get(SERVERS).get('proxy').get('port')
BOARD_HOST = server_conf.get(SERVERS).get('fateboard').get('host')
if BOARD_HOST == 'localhost':
    BOARD_HOST = get_lan_ip()
BOARD_PORT = server_conf.get(SERVERS).get('fateboard').get('port')
SERVINGS = server_conf.get(SERVERS).get('servings')
BOARD_DASHBOARD_URL = 'http://%s:%d/index.html#/dashboard?job_id={}&role={}&party_id={}' % (BOARD_HOST, BOARD_PORT)
RuntimeConfig.init_config(WORK_MODE=WORK_MODE)
RuntimeConfig.init_config(HTTP_PORT=HTTP_PORT)
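
Example #6 ends by seeding RuntimeConfig via init_config; combined with example #1, the HTTP_PORT call is what fills JOB_SERVER_HOST with this node's LAN address. The following is a minimal, self-contained sketch of that pattern. get_lan_ip_sketch() is an assumption about how such a helper is commonly written, not FATE's actual implementation.

    # Minimal sketch of the pattern in examples #1 and #6 (assumed helper, not FATE's code).
    import socket

    def get_lan_ip_sketch():
        # Common technique: connect a UDP socket toward a routable address and read
        # the local endpoint; no packet is actually sent for a SOCK_DGRAM connect.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect(('8.8.8.8', 53))
            return s.getsockname()[0]
        finally:
            s.close()

    class RuntimeConfigSketch:
        WORK_MODE = 0
        HTTP_PORT = None
        JOB_SERVER_HOST = None

        @classmethod
        def init_config(cls, **kwargs):
            # Mirror the hasattr() guard in example #1: only known settings are accepted.
            for k, v in kwargs.items():
                if hasattr(cls, k):
                    setattr(cls, k, v)
                    if k == 'HTTP_PORT':
                        cls.JOB_SERVER_HOST = '{}:{}'.format(get_lan_ip_sketch(), v)
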
Code example #7
    def run_task():
        task = Task()
        task.f_create_time = current_timestamp()
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j', '--job_id', required=True, type=str, help="job id")
            parser.add_argument('-n', '--component_name', required=True, type=str,
                                help="component name")
            parser.add_argument('-t', '--task_id', required=True, type=str, help="task id")
            parser.add_argument('-r', '--role', required=True, type=str, help="role")
            parser.add_argument('-p', '--party_id', required=True, type=str, help="party id")
            parser.add_argument('-c', '--config', required=True, type=str, help="task config")
            parser.add_argument('--job_server', help="job server", type=str)
            args = parser.parse_args()
            schedule_logger(args.job_id).info('enter task process')
            schedule_logger(args.job_id).info(args)
            # init function args
            if args.job_server:
                RuntimeConfig.init_config(HTTP_PORT=args.job_server.split(':')[1])
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            role = args.role
            party_id = int(args.party_id)
            task_config = file_utils.load_json_conf(args.config)
            job_parameters = task_config['job_parameters']
            job_initiator = task_config['job_initiator']
            job_args = task_config['job_args']
            task_input_dsl = task_config['input']
            task_output_dsl = task_config['output']
            parameters = TaskExecutor.get_parameters(job_id, component_name, role, party_id)
            # parameters = task_config['parameters']
            module_name = task_config['module_name']
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task.f_status = TaskStatus.FAILED
            return
        try:
            # init environment, process is shared globally
            RuntimeConfig.init_config(WORK_MODE=job_parameters['work_mode'],
                                      BACKEND=job_parameters.get('backend', 0))
            session.init(job_id='{}_{}_{}'.format(task_id, role, party_id), mode=RuntimeConfig.WORK_MODE, backend=RuntimeConfig.BACKEND)
            federation.init(job_id=task_id, runtime_conf=parameters)
            job_log_dir = os.path.join(job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log_utils.LoggerFactory.set_directory(directory=task_log_dir, parent_log_dir=job_log_dir,
                                                  append_to_parent_log=True, force=True)

            task.f_job_id = job_id
            task.f_component_name = component_name
            task.f_task_id = task_id
            task.f_role = role
            task.f_party_id = party_id
            task.f_operator = 'python_operator'
            tracker = Tracking(job_id=job_id, role=role, party_id=party_id, component_name=component_name,
                               task_id=task_id,
                               model_id=job_parameters['model_id'],
                               model_version=job_parameters['model_version'],
                               module_name=module_name)
            task.f_start_time = current_timestamp()
            task.f_run_ip = get_lan_ip()
            task.f_run_pid = os.getpid()
            run_class_paths = parameters.get('CodePath').split('/')
            run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py','')
            run_class_name = run_class_paths[-1]
            task_run_args = TaskExecutor.get_task_run_args(job_id=job_id, role=role, party_id=party_id,
                                                           job_parameters=job_parameters, job_args=job_args,
                                                           input_dsl=task_input_dsl)
            run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
            run_object.set_tracker(tracker=tracker)
            run_object.set_taskid(taskid=task_id)
            task.f_status = TaskStatus.RUNNING
            TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                          party_id=party_id, initiator_party_id=job_initiator.get('party_id', None),
                                          initiator_role=job_initiator.get('role', None),
                                          task_info=task.to_json())

            schedule_logger().info('run {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id))
            schedule_logger().info(parameters)
            schedule_logger().info(task_input_dsl)
            run_object.run(parameters, task_run_args)
            output_data = run_object.save_data()
            tracker.save_output_data_table(output_data, task_output_dsl.get('data')[0] if task_output_dsl.get('data') else 'component')
            output_model = run_object.export_model()
            # There is only one model output at the current dsl version.
            tracker.save_output_model(output_model, task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')
            task.f_status = TaskStatus.SUCCESS
        except Exception as e:
            traceback.print_exc()
            schedule_logger().exception(e)
            task.f_status = TaskStatus.FAILED
        finally:
            sync_success = False
            try:
                session.stop()
                task.f_end_time = current_timestamp()
                task.f_elapsed = task.f_end_time - task.f_start_time
                task.f_update_time = current_timestamp()
                TaskExecutor.sync_task_status(job_id=job_id, component_name=component_name, task_id=task_id, role=role,
                                              party_id=party_id,
                                              initiator_party_id=job_initiator.get('party_id', None),
                                              initiator_role=job_initiator.get('role', None),
                                              task_info=task.to_json())
                sync_success = True
            except Exception as e:
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info(
            'finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))
        print('finish {} {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id, task.f_status if sync_success else TaskStatus.FAILED))
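
The CodePath handling above (run_class_paths) turns a slash-separated path, whose second-to-last segment is a .py file and whose last segment is a class name, into an importable module plus class. A compact sketch of that conversion, with a hypothetical path for illustration:

    import importlib

    def load_run_object_sketch(code_path):
        # e.g. (hypothetical) 'federatedml/linear_model/hetero_lr_guest.py/HeteroLRGuest'
        parts = code_path.split('/')
        module_name = '.'.join(parts[:-2]) + '.' + parts[-2].replace('.py', '')
        class_name = parts[-1]
        return getattr(importlib.import_module(module_name), class_name)()
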
Code example #8
    def run_component(job_id, job_runtime_conf, job_parameters, job_initiator,
                      job_args, dag, component):
        parameters = component.get_role_parameters()
        component_name = component.get_name()
        module_name = component.get_module()
        task_id = job_utils.generate_task_id(job_id=job_id,
                                             component_name=component_name)
        schedule_logger(job_id).info('job {} run component {}'.format(
            job_id, component_name))
        for role, partys_parameters in parameters.items():
            for party_index in range(len(partys_parameters)):
                party_parameters = partys_parameters[party_index]
                if role in job_args:
                    party_job_args = job_args[role][party_index]['args']
                else:
                    party_job_args = {}
                dest_party_id = party_parameters.get('local',
                                                     {}).get('party_id')

                response = federated_api(
                    job_id=job_id,
                    method='POST',
                    endpoint='/{}/schedule/{}/{}/{}/{}/{}/run'.format(
                        API_VERSION, job_id, component_name, task_id, role,
                        dest_party_id),
                    src_party_id=job_initiator['party_id'],
                    dest_party_id=dest_party_id,
                    src_role=job_initiator['role'],
                    json_body={
                        'job_parameters': job_parameters,
                        'job_initiator': job_initiator,
                        'job_args': party_job_args,
                        'parameters': party_parameters,
                        'module_name': module_name,
                        'input': component.get_input(),
                        'output': component.get_output(),
                        'job_server': {
                            'ip': get_lan_ip(),
                            'http_port': RuntimeConfig.HTTP_PORT
                        }
                    },
                    work_mode=job_parameters['work_mode'])
                if response['retcode']:
                    if 'not authorized' in response['retmsg']:
                        raise Exception(
                            'run component {} not authorized'.format(
                                component_name))
        component_task_status = TaskScheduler.check_task_status(
            job_id=job_id, component=component)
        job_status = TaskScheduler.check_job_status(job_id)
        if component_task_status and job_status:
            task_success = True
        else:
            task_success = False
        schedule_logger(job_id).info('job {} component {} run {}'.format(
            job_id, component_name, 'success' if task_success else 'failed'))
        # update progress
        TaskScheduler.sync_job_status(
            job_id=job_id,
            roles=job_runtime_conf['role'],
            work_mode=job_parameters['work_mode'],
            initiator_party_id=job_initiator['party_id'],
            initiator_role=job_initiator['role'],
            job_info=job_utils.update_job_progress(
                job_id=job_id, dag=dag, current_task_id=task_id).to_json())
        TaskScheduler.stop(job_id=job_id, component_name=component_name)
        if task_success:
            next_components = dag.get_next_components(component_name)
            schedule_logger(job_id).info(
                'job {} component {} next components is {}'.format(
                    job_id, component_name, [
                        next_component.get_name()
                        for next_component in next_components
                    ]))
            for next_component in next_components:
                try:
                    schedule_logger(job_id).info(
                        'job {} check component {} dependencies status'.format(
                            job_id, next_component.get_name()))
                    dependencies_status = TaskScheduler.check_dependencies(
                        job_id=job_id, dag=dag, component=next_component)
                    job_status = TaskScheduler.check_job_status(job_id)
                    schedule_logger(job_id).info(
                        'job {} component {} dependencies status is {}, job status is {}'
                        .format(job_id, next_component.get_name(),
                                dependencies_status, job_status))
                    if dependencies_status and job_status:
                        run_status = TaskScheduler.run_component(
                            job_id, job_runtime_conf, job_parameters,
                            job_initiator, job_args, dag, next_component)
                    else:
                        run_status = False
                except Exception as e:
                    schedule_logger(job_id).exception(e)
                    run_status = False
                if not run_status:
                    return False
            return True
        else:
            if component_task_status is None:
                end_status = JobStatus.TIMEOUT
            else:
                end_status = JobStatus.FAILED
            TaskScheduler.stop(job_id=job_id, end_status=end_status)
            return False
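
Stripped of the federated dispatch and logging, run_component() is a depth-first walk of the job DAG: run one component, then recurse into each successor whose dependencies (and the overall job status) are still healthy, failing fast otherwise. A skeleton of that control flow, with run_one() and deps_ready() standing in for the federated_api call and TaskScheduler.check_dependencies():

    def run_component_sketch(dag, component, run_one, deps_ready):
        # Skeleton of the recursion above; error handling and progress sync omitted.
        if not run_one(component):
            return False
        for nxt in dag.get_next_components(component.get_name()):
            if not deps_ready(nxt):
                return False
            if not run_component_sketch(dag, nxt, run_one, deps_ready):
                return False
        return True
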
Code example #9
File: task_executor.py Project: pangzx1/FL1.0
    def run_task():
        task = Task()
        task.f_create_time = current_timestamp()
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('-j', '--job_id', required=True, type=str, help="job id")
            parser.add_argument('-n', '--component_name', required=True, type=str,
                                help="component name")
            parser.add_argument('-t', '--task_id', required=True, type=str, help="task id")
            parser.add_argument('-r', '--role', required=True, type=str, help="role")
            parser.add_argument('-p', '--party_id', required=True, type=str, help="party id")
            parser.add_argument('-c', '--config', required=True, type=str, help="task config")
            args = parser.parse_args()
            schedule_logger.info('enter task process')
            schedule_logger.info(args)
            # init function args
            job_id = args.job_id
            component_name = args.component_name
            task_id = args.task_id
            role = args.role
            party_id = int(args.party_id)
            task_config = file_utils.load_json_conf(args.config)
            job_parameters = task_config.get('job_parameters', None)
            job_initiator = task_config.get('job_initiator', None)
            job_args = task_config.get('job_args', {})
            task_input_dsl = task_config.get('input', {})
            task_output_dsl = task_config.get('output', {})
            parameters = task_config.get('parameters', {})
            module_name = task_config.get('module_name', '')
        except Exception as e:
            schedule_logger.exception(e)
            task.f_status = TaskStatus.FAILED
            return
        try:
            # init environment
            RuntimeConfig.init_config(WORK_MODE=job_parameters['work_mode'])
            storage.init_storage(job_id=task_id,
                                 work_mode=RuntimeConfig.WORK_MODE)
            federation.init(job_id=task_id, runtime_conf=parameters)
            job_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role,
                str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log_utils.LoggerFactory.set_directory(directory=task_log_dir,
                                                  parent_log_dir=job_log_dir,
                                                  append_to_parent_log=True,
                                                  force=True)

            task.f_job_id = job_id
            task.f_component_name = component_name
            task.f_task_id = task_id
            task.f_role = role
            task.f_party_id = party_id
            task.f_operator = 'python_operator'
            tracker = Tracking(job_id=job_id,
                               role=role,
                               party_id=party_id,
                               component_name=component_name,
                               task_id=task_id,
                               model_id=job_parameters['model_id'],
                               model_version=job_parameters['model_version'],
                               module_name=module_name)
            task.f_start_time = current_timestamp()
            task.f_run_ip = get_lan_ip()
            task.f_run_pid = os.getpid()
            run_class_paths = parameters.get('CodePath').split('/')
            # strip the exact '.py' suffix to form the module name; rstrip('.py')
            # would strip a character set and mangle names ending in '.', 'p' or 'y'
            run_class_package = '.'.join(
                run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
            run_class_name = run_class_paths[-1]
            task_run_args = TaskExecutor.get_task_run_args(
                job_id=job_id,
                role=role,
                party_id=party_id,
                job_parameters=job_parameters,
                job_args=job_args,
                input_dsl=task_input_dsl)
            run_object = getattr(importlib.import_module(run_class_package),
                                 run_class_name)()
            run_object.set_tracker(tracker=tracker)
            run_object.set_taskid(taskid=task_id)
            task.f_status = TaskStatus.RUNNING
            TaskExecutor.sync_task_status(job_id=job_id,
                                          component_name=component_name,
                                          task_id=task_id,
                                          role=role,
                                          party_id=party_id,
                                          initiator_party_id=job_initiator.get(
                                              'party_id', None),
                                          task_info=task.to_json())

            schedule_logger.info('run {} {} {} {} {} task'.format(
                job_id, component_name, task_id, role, party_id))
            schedule_logger.info(parameters)
            schedule_logger.info(task_input_dsl)
            run_object.run(parameters, task_run_args)
            if task_output_dsl:
                if task_output_dsl.get('data', []):
                    output_data = run_object.save_data()
                    tracker.save_output_data_table(
                        output_data,
                        task_output_dsl.get('data')[0])
                if task_output_dsl.get('model', []):
                    output_model = run_object.export_model()
                    # There is only one model output at the current dsl version.
                    tracker.save_output_model(output_model,
                                              task_output_dsl['model'][0])
            task.f_status = TaskStatus.SUCCESS
        except Exception as e:
            schedule_logger.exception(e)
            task.f_status = TaskStatus.FAILED
        finally:
            try:
                task.f_end_time = current_timestamp()
                task.f_elapsed = task.f_end_time - task.f_start_time
                task.f_update_time = current_timestamp()
                TaskExecutor.sync_task_status(
                    job_id=job_id,
                    component_name=component_name,
                    task_id=task_id,
                    role=role,
                    party_id=party_id,
                    initiator_party_id=job_initiator.get('party_id', None),
                    task_info=task.to_json())
            except Exception as e:
                schedule_logger.exception(e)
        schedule_logger.info('finish {} {} {} {} {} {} task'.format(
            job_id, component_name, task_id, role, party_id, task.f_status))
        print('finish {} {} {} {} {} {} task'.format(job_id, component_name,
                                                     task_id, role, party_id,
                                                     task.f_status))
Code example #10
File: task_scheduler.py Project: zhilangtaosha/FATE
    def start_task(job_id, component_name, task_id, role, party_id,
                   task_config):
        schedule_logger(job_id).info(
            'job {} {} {} {} task subprocess is ready'.format(
                job_id, component_name, role, party_id, task_config))
        task_process_start_status = False
        try:
            task_dir = os.path.join(job_utils.get_job_directory(job_id=job_id),
                                    role, party_id, component_name)
            os.makedirs(task_dir, exist_ok=True)
            task_config_path = os.path.join(task_dir, 'task_config.json')
            with open(task_config_path, 'w') as fw:
                json.dump(task_config, fw)

            try:
                backend = task_config['job_parameters']['backend']
            except KeyError:
                backend = 0
                schedule_logger(job_id).warning(
                    "failed to get backend, set as 0")

            backend = Backend(backend)

            if backend.is_eggroll():
                process_cmd = [
                    'python3',
                    sys.modules[TaskExecutor.__module__].__file__,
                    '-j',
                    job_id,
                    '-n',
                    component_name,
                    '-t',
                    task_id,
                    '-r',
                    role,
                    '-p',
                    party_id,
                    '-c',
                    task_config_path,
                    '--job_server',
                    '{}:{}'.format(get_lan_ip(), HTTP_PORT),
                ]
            elif backend.is_spark():
                if "SPARK_HOME" not in os.environ:
                    raise EnvironmentError("SPARK_HOME not found")
                spark_submit_config = task_config['job_parameters'].get(
                    "spark_submit_config", dict())
                deploy_mode = spark_submit_config.get("deploy-mode", "client")
                queue = spark_submit_config.get("queue", "default")
                driver_memory = spark_submit_config.get("driver-memory", "1g")
                num_executors = spark_submit_config.get("num-executors", 2)
                executor_memory = spark_submit_config.get(
                    "executor-memory", "1g")
                executor_cores = spark_submit_config.get("executor-cores", 1)

                if deploy_mode not in ["client"]:
                    raise ValueError(
                        f"deploy mode {deploy_mode} not supported")
                spark_home = os.environ["SPARK_HOME"]
                spark_submit_cmd = os.path.join(spark_home, "bin/spark-submit")
                process_cmd = [
                    spark_submit_cmd,
                    f'--name={task_id}#{role}',
                    f'--deploy-mode={deploy_mode}',
                    f'--queue={queue}',
                    f'--driver-memory={driver_memory}',
                    f'--num-executors={num_executors}',
                    f'--executor-memory={executor_memory}',
                    f'--executor-cores={executor_cores}',
                    sys.modules[TaskExecutor.__module__].__file__,
                    '-j',
                    job_id,
                    '-n',
                    component_name,
                    '-t',
                    task_id,
                    '-r',
                    role,
                    '-p',
                    party_id,
                    '-c',
                    task_config_path,
                    '--job_server',
                    '{}:{}'.format(get_lan_ip(), HTTP_PORT),
                ]
            else:
                raise ValueError(f"backend {backend} not supported")

            task_log_dir = os.path.join(
                job_utils.get_job_log_directory(job_id=job_id), role, party_id,
                component_name)
            schedule_logger(job_id).info(
                'job {} {} {} {} task subprocess start'.format(
                    job_id, component_name, role, party_id, task_config))
            p = job_utils.run_subprocess(config_dir=task_dir,
                                         process_cmd=process_cmd,
                                         log_dir=task_log_dir)
            if p:
                task_process_start_status = True
        except Exception as e:
            schedule_logger(job_id).exception(e)
        finally:
            schedule_logger(job_id).info(
                'job {} component {} on {} {} start task subprocess {}'.format(
                    job_id, component_name, role, party_id,
                    'success' if task_process_start_status else 'failed'))
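
The actual process launch is delegated to job_utils.run_subprocess(). Below is a minimal sketch of what such a helper plausibly does (an assumption, not FATE's implementation): start the command built above as a child process, record its PID under the task directory, and redirect stdout and stderr into the task log directory.

    import os
    import subprocess

    def run_subprocess_sketch(config_dir, process_cmd, log_dir):
        # Assumed behaviour of job_utils.run_subprocess: launch the task process,
        # persist its pid, and capture its output under the task log directory.
        os.makedirs(config_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
        std_log = open(os.path.join(log_dir, 'std.log'), 'a')
        error_log = open(os.path.join(log_dir, 'error.log'), 'a')
        p = subprocess.Popen(process_cmd, cwd=config_dir, stdout=std_log, stderr=error_log)
        with open(os.path.join(config_dir, 'pid'), 'w') as f:
            f.write(str(p.pid))
        return p
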