Example #1
    def _run_(self):
        # TODO: all function calls in this block should raise errors so they are caught and reported below
        args = self.args
        start_time = current_timestamp()
        try:
            LOGGER.info(
                f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task'
            )
            self.report_info.update({
                "job_id": args.job_id,
                "component_name": args.component_name,
                "task_id": args.task_id,
                "task_version": args.task_version,
                "role": args.role,
                "party_id": args.party_id,
                "run_ip": args.run_ip,
                "run_pid": self.run_pid
            })
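            # fetch the job configuration from the driver and build a DSL parser for it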
            operation_client = OperationClient()
            job_configuration = JobConfiguration(
                **operation_client.get_job_conf(
                    args.job_id, args.role, args.party_id, args.component_name,
                    args.task_id, args.task_version))
            task_parameters_conf = args.config
            dsl_parser = schedule_utils.get_job_dsl_parser(
                dsl=job_configuration.dsl,
                runtime_conf=job_configuration.runtime_conf,
                train_runtime_conf=job_configuration.train_runtime_conf,
                pipeline_dsl=None)

            job_parameters = dsl_parser.get_job_parameters(
                job_configuration.runtime_conf)
            user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
            LOGGER.info(f"user name: {user_name}")
            src_user = task_parameters_conf.get("src_user")
            task_parameters = RunParameters(**task_parameters_conf)
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()

            job_args_on_party = TaskExecutor.get_job_args_on_party(
                dsl_parser, job_configuration.runtime_conf_on_party, args.role,
                args.party_id)
            component = dsl_parser.get_component_info(
                component_name=args.component_name)
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()

            kwargs = {
                'job_id': args.job_id,
                'role': args.role,
                'party_id': args.party_id,
                'component_name': args.component_name,
                'task_id': args.task_id,
                'task_version': args.task_version,
                'model_id': job_parameters.model_id,
                'model_version': job_parameters.model_version,
                'component_module_name': module_name,
                'job_parameters': job_parameters,
            }
            tracker = Tracker(**kwargs)
            tracker_client = TrackerClient(**kwargs)
            checkpoint_manager = CheckpointManager(**kwargs)

            self.report_info["party_status"] = TaskStatus.RUNNING
            self.report_task_info_to_driver()

            previous_components_parameters = tracker_client.get_model_run_parameters()
            LOGGER.info(
                f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}"
            )

            component_provider, component_parameters_on_party, user_specified_parameters = ProviderManager.get_component_run_info(
                dsl_parser=dsl_parser,
                component_name=args.component_name,
                role=args.role,
                party_id=args.party_id,
                previous_components_parameters=previous_components_parameters)
            RuntimeConfig.set_component_provider(component_provider)
            LOGGER.info(
                f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}"
            )
            flow_feeded_parameters = {
                "output_data_name": task_output_dsl.get("data")
            }

            # init environment (RuntimeConfig is shared process-wide)
            RuntimeConfig.init_config(
                COMPUTING_ENGINE=job_parameters.computing_engine,
                FEDERATION_ENGINE=job_parameters.federation_engine,
                FEDERATED_MODE=job_parameters.federated_mode)

            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
                session_options["python.path"] = os.getenv("PYTHONPATH")
                session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
            else:
                session_options = {}

            sess = session.Session(session_id=args.session_id)
            sess.as_global()
            sess.init_computing(computing_session_id=args.session_id,
                                options=session_options)
            component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
            roles = job_configuration.runtime_conf["role"]
            if set(roles) == {"local"}:
                LOGGER.info(f"only local roles, pass init federation")
            else:
                sess.init_federation(
                    federation_session_id=args.federation_session_id,
                    runtime_conf=component_parameters_on_party,
                    service_conf=job_parameters.engines_address.get(
                        EngineType.FEDERATION, {}))
            LOGGER.info(
                f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task'
            )
            LOGGER.info(
                f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}"
            )
            LOGGER.info(f"task input dsl {task_input_dsl}")
            task_run_args, input_table_list = self.get_task_run_args(
                job_id=args.job_id,
                role=args.role,
                party_id=args.party_id,
                task_id=args.task_id,
                task_version=args.task_version,
                job_args=job_args_on_party,
                job_parameters=job_parameters,
                task_parameters=task_parameters,
                input_dsl=task_input_dsl,
            )
            if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
                task_run_args["job_parameters"] = job_parameters
            LOGGER.info(f"task input args {task_run_args}")

            need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
            provider_interface = provider_utils.get_provider_interface(
                provider=component_provider)
            run_object = provider_interface.get(
                module_name,
                ComponentRegistry.get_provider_components(
                    provider_name=component_provider.name,
                    provider_version=component_provider.version)
            ).get_run_obj(self.args.role)
            flow_feeded_parameters.update({"table_info": input_table_list})
            cpn_input = ComponentInput(
                tracker=tracker_client,
                checkpoint_manager=checkpoint_manager,
                task_version_id=job_utils.generate_task_version_id(
                    args.task_id, args.task_version),
                parameters=component_parameters_on_party["ComponentParam"],
                datasets=task_run_args.get("data", None),
                caches=task_run_args.get("cache", None),
                models=dict(
                    model=task_run_args.get("model"),
                    isometric_model=task_run_args.get("isometric_model"),
                ),
                job_parameters=job_parameters,
                roles=dict(
                    role=component_parameters_on_party["role"],
                    local=component_parameters_on_party["local"],
                ),
                flow_feeded_parameters=flow_feeded_parameters,
            )
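            # FATE_PROFILE_LOG_ENABLED (env var) toggles profiling around the component run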
            profile_log_enabled = False
            try:
                if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                    profile_log_enabled = True
            except Exception as e:
                LOGGER.warning(e)
            if profile_log_enabled:
                # add profile logs
                LOGGER.info("profile logging is enabled")
                profile.profile_start()
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
                profile.profile_ends()
            else:
                LOGGER.info("profile logging is disabled")
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()

            output_table_list = []
            LOGGER.info(f"task output data {cpn_output.data}")
            for index, data in enumerate(cpn_output.data):
                data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else str(index)
                # TODO: the token depends on the engine type; it may belong in job parameters
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=data,
                    output_storage_engine=job_parameters.storage_engine,
                    token={"username": user_name})
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(
                        data_name=data_name,
                        table_namespace=persistent_table_namespace,
                        table_name=persistent_table_name)
                    output_table_list.append({
                        "namespace": persistent_table_namespace,
                        "name": persistent_table_name
                    })
            self.log_output_data_table_tracker(args.job_id, input_table_list,
                                               output_table_list)

            # There is only one model output in the current DSL version.
            tracker_client.save_component_output_model(
                model_buffers=cpn_output.model,
                model_alias=task_output_dsl['model'][0]
                if task_output_dsl.get('model') else 'default',
                user_specified_run_parameters=user_specified_parameters)
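            # output caches are either DataCache objects or (data, meta) tuples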
            if cpn_output.cache is not None:
                for i, cache in enumerate(cpn_output.cache):
                    if cache is None:
                        continue
                    name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                    if isinstance(cache, DataCache):
                        tracker.tracking_output_cache(cache, cache_name=name)
                    elif isinstance(cache, tuple):
                        tracker.save_output_cache(
                            cache_data=cache[0],
                            cache_meta=cache[1],
                            cache_name=name,
                            output_storage_engine=job_parameters.storage_engine,
                            output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                            token={"username": user_name})
                    else:
                        raise RuntimeError(
                            f"unsupported output cache type {type(cache)} from module run object"
                        )
            if need_run:
                self.report_info["party_status"] = TaskStatus.SUCCESS
            else:
                self.report_info["party_status"] = TaskStatus.PASS
        except PassError:
            self.report_info["party_status"] = TaskStatus.PASS
        except Exception as e:
            traceback.print_exc()
            self.report_info["party_status"] = TaskStatus.FAILED
            LOGGER.exception(e)
        finally:
            try:
                self.report_info["end_time"] = current_timestamp()
                self.report_info["elapsed"] = self.report_info["end_time"] - start_time
                self.report_task_info_to_driver()
            except Exception as e:
                self.report_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                LOGGER.exception(e)
        msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} with {self.report_info['party_status']}"
        LOGGER.info(msg)
        print(msg)
        return self.report_info
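
Both examples share the same reporting discipline: record a start timestamp, mark the task RUNNING before the component executes, downgrade to FAILED on any exception, and always report the end time and elapsed milliseconds from the finally block. Below is a minimal, self-contained sketch of that pattern; run_component and report_to_driver are hypothetical stand-ins, not FATE APIs:

    import time
    import traceback

    def current_timestamp():
        # milliseconds since the epoch, as in the examples above
        return int(time.time() * 1000)

    def execute(run_component, report_to_driver):
        start_time = current_timestamp()
        report_info = {"party_status": "RUNNING"}
        report_to_driver(report_info)
        try:
            run_component()
            report_info["party_status"] = "SUCCESS"
        except Exception:
            traceback.print_exc()
            report_info["party_status"] = "FAILED"
        finally:
            # reported even when the run fails, so the driver always sees timing
            report_info["end_time"] = current_timestamp()
            report_info["elapsed"] = report_info["end_time"] - start_time
            report_to_driver(report_info)
        return report_info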
Example #2
    def run_task(cls, **kwargs):
        task_info = {}
        try:
            job_id, component_name, task_id, task_version, role, party_id, run_ip, config, job_server = cls.get_run_task_args(kwargs)
            if job_server:
                RuntimeConfig.init_config(JOB_SERVER_HOST=job_server.split(':')[0],
                                          HTTP_PORT=job_server.split(':')[1])
                RuntimeConfig.set_process_role(ProcessRole.EXECUTOR)
            executor_pid = os.getpid()
            task_info.update({
                "job_id": job_id,
                "component_name": component_name,
                "task_id": task_id,
                "task_version": task_version,
                "role": role,
                "party_id": party_id,
                "run_ip": run_ip,
                "run_pid": executor_pid
            })
            start_time = current_timestamp()
            operation_client = OperationClient()
            job_conf = operation_client.get_job_conf(job_id, role)
            job_dsl = job_conf["job_dsl_path"]
            job_runtime_conf = job_conf["job_runtime_conf_path"]
            dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_dsl,
                                                           runtime_conf=job_runtime_conf,
                                                           train_runtime_conf=job_conf["train_runtime_conf_path"],
                                                           pipeline_dsl=job_conf["pipeline_dsl_path"]
                                                           )
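            # locate this party's position in the role list to select its parameters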
            party_index = job_runtime_conf["role"][role].index(party_id)
            job_args_on_party = TaskExecutor.get_job_args_on_party(dsl_parser, job_runtime_conf, role, party_id)
            component = dsl_parser.get_component_info(component_name=component_name)
            component_parameters = component.get_role_parameters()
            component_parameters_on_party = (component_parameters[role][party_index]
                                             if role in component_parameters else {})
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            component_parameters_on_party['output_data_name'] = task_output_dsl.get('data')
            json_conf = operation_client.load_json_conf(job_id, config)
            user_name = dsl_parser.get_job_parameters().get(role, {}).get(party_id, {}).get("user", '')
            schedule_logger(job_id).info(f"user name: {user_name}")
            src_user = json_conf.get("src_user")
            task_parameters = RunParameters(**json_conf)
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
            job_log_dir = os.path.join(job_utils.get_job_log_directory(job_id=job_id), role, str(party_id))
            task_log_dir = os.path.join(job_log_dir, component_name)
            log.LoggerFactory.set_directory(directory=task_log_dir, parent_log_dir=job_log_dir,
                                            append_to_parent_log=True, force=True)

            tracker = Tracker(job_id=job_id, role=role, party_id=party_id, component_name=component_name,
                              task_id=task_id,
                              task_version=task_version,
                              model_id=job_parameters.model_id,
                              model_version=job_parameters.model_version,
                              component_module_name=module_name,
                              job_parameters=job_parameters)
            tracker_client = TrackerClient(job_id=job_id, role=role, party_id=party_id,
                                           component_name=component_name,
                                           task_id=task_id,
                                           task_version=task_version,
                                           model_id=job_parameters.model_id,
                                           model_version=job_parameters.model_version,
                                           component_module_name=module_name,
                                           job_parameters=job_parameters)
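            # CodePath has the form "pkg/.../module.py/ClassName"; derive the import path and class name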
            run_class_paths = component_parameters_on_party.get('CodePath').split('/')
            run_class_package = '.'.join(run_class_paths[:-2]) + '.' + run_class_paths[-2].replace('.py', '')
            run_class_name = run_class_paths[-1]
            task_info["party_status"] = TaskStatus.RUNNING
            cls.report_task_update_to_driver(task_info=task_info)

            # init environment (RuntimeConfig is shared process-wide)
            RuntimeConfig.init_config(WORK_MODE=job_parameters.work_mode,
                                      COMPUTING_ENGINE=job_parameters.computing_engine,
                                      FEDERATION_ENGINE=job_parameters.federation_engine,
                                      FEDERATED_MODE=job_parameters.federated_mode)

            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
            else:
                session_options = {}

            sess = session.Session(computing_type=job_parameters.computing_engine, federation_type=job_parameters.federation_engine)
            computing_session_id = job_utils.generate_session_id(task_id, task_version, role, party_id)
            sess.init_computing(computing_session_id=computing_session_id, options=session_options)
            federation_session_id = job_utils.generate_task_version_id(task_id, task_version)
            component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
            sess.init_federation(federation_session_id=federation_session_id,
                                 runtime_conf=component_parameters_on_party,
                                 service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))
            sess.as_default()

            schedule_logger().info('Run {} {} {} {} {} task'.format(job_id, component_name, task_id, role, party_id))
            schedule_logger().info("Component parameters on party {}".format(component_parameters_on_party))
            schedule_logger().info("Task input dsl {}".format(task_input_dsl))
            task_run_args = cls.get_task_run_args(job_id=job_id, role=role, party_id=party_id,
                                                  task_id=task_id,
                                                  task_version=task_version,
                                                  job_args=job_args_on_party,
                                                  job_parameters=job_parameters,
                                                  task_parameters=task_parameters,
                                                  input_dsl=task_input_dsl,
                                                  )
            if module_name in {"Upload", "Download", "Reader", "Writer"}:
                task_run_args["job_parameters"] = job_parameters
            run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
            run_object.set_tracker(tracker=tracker_client)
            run_object.set_task_version_id(task_version_id=job_utils.generate_task_version_id(task_id, task_version))
            # add profile logs
            profile.profile_start()
            run_object.run(component_parameters_on_party, task_run_args)
            # profile.profile_ends()
            output_data = run_object.save_data()
            if not isinstance(output_data, list):
                output_data = [output_data]
            for index, data in enumerate(output_data):
                data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else str(index)
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=data,
                    output_storage_engine=job_parameters.storage_engine,
                    output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                    user_name=user_name)
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(data_name=data_name,
                                                 table_namespace=persistent_table_namespace,
                                                 table_name=persistent_table_name)
            output_model = run_object.export_model()
            # There is only one model output in the current DSL version.
            tracker.save_output_model(output_model,
                                      task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
                                      tracker_client=tracker_client)
            task_info["party_status"] = TaskStatus.SUCCESS
        except Exception as e:
            traceback.print_exc()
            task_info["party_status"] = TaskStatus.FAILED
            schedule_logger().exception(e)
        finally:
            try:
                task_info["end_time"] = current_timestamp()
                task_info["elapsed"] = task_info["end_time"] - start_time
                cls.report_task_update_to_driver(task_info=task_info)
            except Exception as e:
                task_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                schedule_logger().exception(e)
        schedule_logger().info(
            'task {} {} {} start time: {}'.format(task_id, role, party_id, timestamp_to_date(start_time)))
        schedule_logger().info(
            'task {} {} {} end time: {}'.format(task_id, role, party_id, timestamp_to_date(task_info["end_time"])))
        schedule_logger().info(
            'task {} {} {} takes {}s'.format(task_id, role, party_id, int(task_info["elapsed"]) / 1000))
        schedule_logger().info(
            'Finish {} {} {} {} {} {} task {}'.format(job_id, component_name, task_id, task_version, role, party_id,
                                                      task_info["party_status"]))

        print('Finish {} {} {} {} {} {} task {}'.format(job_id, component_name, task_id, task_version, role, party_id,
                                                        task_info["party_status"]))
        return task_info
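
Example #2 resolves the component implementation dynamically from the "CodePath" parameter, which has the form "pkg/subpkg/module.py/ClassName". A stripped-down sketch of that resolution follows; the example path is illustrative and only imports if the corresponding package is installed:

    import importlib

    def load_run_class(code_path: str):
        # "federatedml/model_base.py/ModelBase" -> module "federatedml.model_base", class "ModelBase"
        parts = code_path.split('/')
        module_path = '.'.join(parts[:-2] + [parts[-2].replace('.py', '')])
        class_name = parts[-1]
        return getattr(importlib.import_module(module_path), class_name)

    # hypothetical usage, assuming the federatedml package is importable:
    # run_object = load_run_class("federatedml/model_base.py/ModelBase")()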