예제 #1
0
 def get_env(cls, job_id, provider_info):
     """Build the environment for a provider-backed worker process.

     Copies the provider's own env, points PYTHONPATH at the provider's
     parent directory, and adds FATE_JOB_ID when a job id is given.
     """
     component_provider = ComponentProvider(**provider_info)
     worker_env = dict(component_provider.env)
     worker_env["PYTHONPATH"] = os.path.dirname(component_provider.path)
     if job_id:
         worker_env["FATE_JOB_ID"] = job_id
     return worker_env
예제 #2
0
    def _run(self):
        """Initialize tasks for every component listed in the job config.

        Parses the job DSL, then for each component resolves its parameters
        for this role/party. Components that have parameters get a task
        created via TaskController; others are marked as not needing to run.

        Returns:
            dict: {component_name: {"need_run": bool}}
        """
        result = {}
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=self.args.dsl,
            runtime_conf=self.args.runtime_conf,
            train_runtime_conf=self.args.train_runtime_conf,
            pipeline_dsl=self.args.pipeline_dsl)

        provider = ComponentProvider(**self.args.config["provider"])
        common_task_info = self.args.config["common_task_info"]
        log_msg = f"initialize the components: {self.args.config['components']}"
        LOGGER.info(
            start_log(log_msg,
                      role=self.args.role,
                      party_id=self.args.party_id))
        for component_name in self.args.config["components"]:
            result[component_name] = {}

            parameters, user_specified_parameters = ProviderManager.get_component_parameters(
                dsl_parser=dsl_parser,
                component_name=component_name,
                role=self.args.role,
                party_id=self.args.party_id,
                provider=provider)
            if parameters:
                # Build task_info fresh per component so state from a previous
                # iteration can never leak in. (A duplicate construction that
                # was immediately overwritten here has been removed.)
                task_info = {}
                task_info.update(common_task_info)
                task_info["component_name"] = component_name
                task_info["component_module"] = parameters["module"]
                task_info["provider_info"] = provider.to_dict()
                task_info["component_parameters"] = parameters
                TaskController.create_task(
                    role=self.args.role,
                    party_id=self.args.party_id,
                    run_on_this_party=common_task_info["run_on_this_party"],
                    task_info=task_info)
                result[component_name]["need_run"] = True
            else:
                # This party does not participate in the component, pass.
                result[component_name]["need_run"] = False
        LOGGER.info(
            successful_log(log_msg,
                           role=self.args.role,
                           party_id=self.args.party_id))
        return result
예제 #3
0
 def get_fate_flow_provider(cls):
     """Return the ComponentProvider describing fate_flow's built-in components."""
     flow_path = get_fate_flow_python_directory("fate_flow")
     return ComponentProvider(
         name="fate_flow",
         version=get_versions()["FATEFlow"],
         path=flow_path,
         class_path=ComponentRegistry.get_default_class_path())
예제 #4
0
    def run(self, task: Task, run_parameters, run_parameters_path, config_dir,
            log_dir, cwd_dir, **kwargs):
        """Build a spark-submit command line for the task and start the
        task-executor worker with it.

        Args:
            task: the Task to run; its task id and role name the Spark app.
            run_parameters: task parameters; ``spark_run`` carries extra
                spark-submit options plus an optional nested "conf" dict.
            run_parameters_path, config_dir, log_dir, cwd_dir: supplied by
                the caller's engine interface; not used directly here.

        Returns:
            Whatever WorkerManager.start_task_worker returns.

        Raises:
            RuntimeError: no Spark home configured and pyspark not importable.
            ValueError: a deploy mode other than "client" was requested.
        """
        spark_home = ServiceRegistry.FATE_ON_SPARK.get("spark", {}).get("home")
        if not spark_home:
            # No Spark home in conf/service_conf.yaml: fall back to the
            # installed pyspark package location.
            try:
                import pyspark
                spark_home = pyspark.__path__[0]
            except Exception as e:
                # Chain the cause so the underlying import error is visible.
                raise RuntimeError("can not import pyspark") from e

        # additional configs
        spark_submit_config = run_parameters.spark_run

        deploy_mode = spark_submit_config.get("deploy-mode", "client")
        if deploy_mode not in ["client"]:
            raise ValueError(f"deploy mode {deploy_mode} not supported")

        spark_submit_cmd = os.path.join(spark_home, "bin/spark-submit")
        executable = [
            spark_submit_cmd, f"--name={task.f_task_id}#{task.f_role}"
        ]
        # Plain options become --key=value; entries under "conf" become
        # repeated "--conf key=value" pairs.
        for k, v in spark_submit_config.items():
            if k != "conf":
                executable.append(f"--{k}={v}")
        if "conf" in spark_submit_config:
            for ck, cv in spark_submit_config["conf"].items():
                executable.append("--conf")
                executable.append(f"{ck}={cv}")
        extra_env = {"SPARK_HOME": spark_home}
        if DEPENDENT_DISTRIBUTION:
            # Ship the FATE dependencies (python env + source archives) to
            # the executors and point driver/executor python at them.
            dependence = Dependence()
            dependence.init(provider=ComponentProvider(**task.f_provider_info))
            executor_env_pythonpath, executor_python_env, driver_python_env, archives = \
                dependence.get_task_dependence_info()
            schedule_logger(task.f_job_id).info(
                f"executor_env_python {executor_python_env},"
                f"driver_env_python {driver_python_env}, archives {archives}")
            executable.extend([
                "--archives",
                archives,
                "--conf",
                f"spark.pyspark.python={executor_python_env}",
                "--conf",
                f"spark.executorEnv.PYTHONPATH={executor_env_pythonpath}",
                "--conf",
                f"spark.pyspark.driver.python={driver_python_env}",
            ])
        return WorkerManager.start_task_worker(
            worker_name=WorkerName.TASK_EXECUTOR,
            task=task,
            task_parameters=run_parameters,
            executable=executable,
            extra_env=extra_env)
예제 #5
0
 def get_default_fate_provider(cls):
     """Return the default FATE algorithm ComponentProvider.

     Resolves the configured provider path under the FATE python directory
     and fails fast when it does not exist on disk.
     """
     segments = JobDefaultConfig.default_component_provider_path.split("/")
     provider_path = file_utils.get_fate_python_directory(*segments)
     if not os.path.exists(provider_path):
         raise Exception(f"default fate provider not exists: {provider_path}")
     return ComponentProvider(
         name="fate",
         version=get_versions()["FATE"],
         path=provider_path,
         class_path=ComponentRegistry.get_default_class_path())
예제 #6
0
 def get_provider_object(cls, provider_info, check_registration=True):
     """Build a ComponentProvider from registry metadata.

     Args:
         provider_info: dict with at least "name" and "version".
         check_registration: raise if the name/version is not registered.

     Returns:
         ComponentProvider for the requested name/version, falling back to
         the registry's default class_path when none is registered.

     Raises:
         Exception: check_registration is True and the provider is unknown.
     """
     name, version = provider_info["name"], provider_info["version"]
     # Look the registry entry up once instead of three separate chains.
     registered = ComponentRegistry.get_providers().get(name, {}).get(
         version, None)
     if check_registration and registered is None:
         raise Exception(f"{name} {version} provider is not registered")
     provider_meta = registered if registered is not None else {}
     path = provider_meta.get("path", [])
     class_path = provider_meta.get("class_path", None)
     if class_path is None:
         class_path = ComponentRegistry.REGISTRY["default_settings"][
             "class_path"]
     return ComponentProvider(name=name,
                              version=version,
                              path=path,
                              class_path=class_path)
    def instantiate_component_provider(provider_detail,
                                       alias=None,
                                       module=None,
                                       provider_name=None,
                                       provider_version=None,
                                       local_role=None,
                                       local_party_id=None,
                                       detect=True,
                                       provider_cache=None,
                                       job_parameters=None):
        """Resolve and instantiate the provider interface for a component.

        When both provider_name and provider_version are given, builds the
        interface directly (recording it in provider_cache when supplied).
        Otherwise resolves the name/version pair first and recurses with
        the resolved values.
        """
        if provider_name and provider_version:
            version_meta = provider_detail["providers"][provider_name][
                provider_version]
            interface = provider_utils.get_provider_interface(
                ComponentProvider(
                    name=provider_name,
                    version=provider_version,
                    path=version_meta["path"],
                    class_path=ComponentRegistry.get_default_class_path()))
            if provider_cache is not None:
                # Record the interface under name -> version for reuse.
                provider_cache.setdefault(provider_name,
                                          {})[provider_version] = interface
            return interface

        provider_name, provider_version = RuntimeConfParserUtil.get_component_provider(
            alias=alias,
            module=module,
            provider_detail=provider_detail,
            local_role=local_role,
            local_party_id=local_party_id,
            job_parameters=job_parameters,
            provider_cache=provider_cache,
            detect=detect)

        return RuntimeConfParserUtil.instantiate_component_provider(
            provider_detail,
            provider_name=provider_name,
            provider_version=provider_version)
예제 #8
0
def register():
    """HTTP endpoint: register a component provider described by the request.

    Validates the provider path, runs the provider-registrar worker, then
    reloads the registry and confirms the provider landed in memory.
    """
    info = request.json or request.form.to_dict()
    if not Path(info["path"]).is_dir():
        return error_response(400, "invalid path")

    provider = ComponentProvider(
        name=info["name"],
        version=info["version"],
        path=info["path"],
        class_path=info.get("class_path",
                            ComponentRegistry.get_default_class_path()))
    code, std = WorkerManager.start_general_worker(
        worker_name=WorkerName.PROVIDER_REGISTRAR, provider=provider)
    if code != 0:
        return get_json_result(retcode=RetCode.OPERATING_ERROR,
                               retmsg=f"register failed:\n{std}")

    # Registrar succeeded; verify the registry actually picked it up.
    ComponentRegistry.load()
    registered = ComponentRegistry.get_providers().get(provider.name, {}).get(
        provider.version, None)
    if registered is None:
        return get_json_result(retcode=RetCode.OPERATING_ERROR,
                               retmsg=f"not load into memory")
    return get_json_result()
예제 #9
0
    def start_general_worker(cls,
                             worker_name: WorkerName,
                             job_id: str = "",
                             role: str = "",
                             party_id: int = 0,
                             provider: ComponentProvider = None,
                             initialized_config: dict = None,
                             run_in_subprocess: bool = True,
                             **kwargs):
        """Start a non-task worker (provider registrar, dependence upload,
        or task initializer) as a subprocess or directly in-process.

        Args:
            worker_name: one of PROVIDER_REGISTRAR, DEPENDENCE_UPLOAD,
                TASK_INITIALIZER; anything else raises.
            job_id, role, party_id: job context; appended to the command
                line only when job_id is set.
            provider: required for PROVIDER_REGISTRAR / DEPENDENCE_UPLOAD.
            initialized_config: required for TASK_INITIALIZER.
            run_in_subprocess: run the worker module as a child process;
                forced True when RuntimeConfig.DEBUG is set.
            **kwargs: worker-specific extras — "dependence_type" for
                DEPENDENCE_UPLOAD, plus optional "callback"/"callback_param".

        Returns:
            (returncode, result dict) for synchronous runs; DEPENDENCE_UPLOAD
            subprocess runs are asynchronous and return None implicitly.

        Raises:
            ValueError: required provider/initialized_config missing.
            Exception: unsupported worker name, subprocess failure or
                120s timeout.
        """
        if RuntimeConfig.DEBUG:
            # In debug mode always isolate the worker in a subprocess.
            run_in_subprocess = True
        # Snapshot of the call's local arguments; later merged with kwargs
        # to populate the optional callback's parameters.
        participate = locals()
        worker_id, config_dir, log_dir = cls.get_process_dirs(
            worker_name=worker_name,
            job_id=job_id,
            role=role,
            party_id=party_id)
        if worker_name in [
                WorkerName.PROVIDER_REGISTRAR, WorkerName.DEPENDENCE_UPLOAD
        ]:
            if not provider:
                raise ValueError("no provider argument")
            config = {"provider": provider.to_dict()}
            if worker_name == WorkerName.PROVIDER_REGISTRAR:
                from fate_flow.worker.provider_registrar import ProviderRegistrar
                module = ProviderRegistrar
                # File of the worker module: used as the subprocess entry point.
                module_file_path = sys.modules[
                    ProviderRegistrar.__module__].__file__
                specific_cmd = []
            elif worker_name == WorkerName.DEPENDENCE_UPLOAD:
                from fate_flow.worker.dependence_upload import DependenceUpload
                module = DependenceUpload
                module_file_path = sys.modules[
                    DependenceUpload.__module__].__file__
                specific_cmd = [
                    '--dependence_type',
                    kwargs.get("dependence_type")
                ]
            provider_info = provider.to_dict()
        elif worker_name is WorkerName.TASK_INITIALIZER:
            if not initialized_config:
                raise ValueError("no initialized_config argument")
            config = initialized_config
            # Persist the job conf files so the initializer subprocess can
            # read them back via the --dsl/--runtime_conf/... arguments.
            job_conf = job_utils.save_using_job_conf(job_id=job_id,
                                                     role=role,
                                                     party_id=party_id,
                                                     config_dir=config_dir)

            from fate_flow.worker.task_initializer import TaskInitializer
            module = TaskInitializer
            module_file_path = sys.modules[TaskInitializer.__module__].__file__
            specific_cmd = [
                '--dsl',
                job_conf["dsl_path"],
                '--runtime_conf',
                job_conf["runtime_conf_path"],
                '--train_runtime_conf',
                job_conf["train_runtime_conf_path"],
                '--pipeline_dsl',
                job_conf["pipeline_dsl_path"],
            ]
            provider_info = initialized_config["provider"]
        else:
            raise Exception(f"not support {worker_name} worker")
        config_path, result_path = cls.get_config(config_dir=config_dir,
                                                  config=config,
                                                  log_dir=log_dir)

        # Command line shared by every worker kind; worker-specific flags
        # are appended below via specific_cmd.
        process_cmd = [
            sys.executable or "python3",
            module_file_path,
            "--config",
            config_path,
            '--result',
            result_path,
            "--log_dir",
            log_dir,
            "--parent_log_dir",
            os.path.dirname(log_dir),
            "--worker_id",
            worker_id,
            "--run_ip",
            RuntimeConfig.JOB_SERVER_HOST,
            "--job_server",
            f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
        ]

        if job_id:
            process_cmd.extend([
                "--job_id",
                job_id,
                "--role",
                role,
                "--party_id",
                party_id,
            ])

        process_cmd.extend(specific_cmd)
        if run_in_subprocess:
            p = process_utils.run_subprocess(job_id=job_id,
                                             config_dir=config_dir,
                                             process_cmd=process_cmd,
                                             added_env=cls.get_env(
                                                 job_id, provider_info),
                                             log_dir=log_dir,
                                             cwd_dir=config_dir,
                                             process_name=worker_name.value,
                                             process_id=worker_id)
            participate["pid"] = p.pid
            # Message is identical in both branches; only the target logger
            # differs (per-job scheduler log vs. global stat log).
            if job_id and role and party_id:
                logger = schedule_logger(job_id)
                msg = f"{worker_name} worker {worker_id} subprocess {p.pid}"
            else:
                logger = stat_logger
                msg = f"{worker_name} worker {worker_id} subprocess {p.pid}"
            logger.info(ready_log(msg=msg, role=role, party_id=party_id))

            # asynchronous
            if worker_name in [WorkerName.DEPENDENCE_UPLOAD]:
                # Fire-and-forget: do not wait for the subprocess; invoke the
                # caller-supplied callback with the requested subset of the
                # locals/kwargs captured above.
                if kwargs.get("callback") and kwargs.get("callback_param"):
                    callback_param = {}
                    participate.update(participate.get("kwargs", {}))
                    for k, v in participate.items():
                        if k in kwargs.get("callback_param"):
                            callback_param[k] = v
                    kwargs.get("callback")(**callback_param)
            else:
                # Synchronous workers: wait up to 120s, then report and
                # return (or raise) based on the exit code.
                try:
                    p.wait(timeout=120)
                    if p.returncode == 0:
                        logger.info(
                            successful_log(msg=msg,
                                           role=role,
                                           party_id=party_id))
                    else:
                        logger.info(
                            failed_log(msg=msg, role=role, party_id=party_id))
                    if p.returncode == 0:
                        return p.returncode, load_json_conf(result_path)
                    else:
                        std_path = process_utils.get_std_path(
                            log_dir=log_dir,
                            process_name=worker_name.value,
                            process_id=worker_id)
                        raise Exception(
                            f"run error, please check logs: {std_path}, {log_dir}/INFO.log"
                        )
                except subprocess.TimeoutExpired as e:
                    err = failed_log(msg=f"{msg} run timeout",
                                     role=role,
                                     party_id=party_id)
                    logger.exception(err)
                    raise Exception(err)
                finally:
                    # Best-effort cleanup of the child process; errors here
                    # are logged but never propagated.
                    try:
                        p.kill()
                        p.poll()
                    except Exception as e:
                        logger.exception(e)
        else:
            # In-process execution: translate the command line back into
            # keyword arguments and call the worker's run() directly.
            kwargs = cls.cmd_to_func_kwargs(process_cmd)
            code, message, result = module().run(**kwargs)
            if code == 0:
                return code, result
            else:
                raise Exception(message)
예제 #10
0
 def _run(self):
     """Upload this provider's dependencies to Hadoop storage."""
     provider_conf = self.args.config.get("provider")
     self.upload_dependencies_to_hadoop(
         provider=ComponentProvider(**provider_conf),
         dependence_type=self.args.dependence_type)
예제 #11
0
    def check_upload(cls,
                     job_id,
                     provider_group,
                     fate_flow_version_provider_info,
                     storage_engine=FateDependenceStorageEngine.HDFS.value):
        """Decide which provider dependencies need (re-)uploading to storage.

        For every provider version in provider_group, both the FATE source
        code and the python environment dependence types are checked against
        the stored dependency metadata.

        Args:
            job_id: used for scheduler logging only.
            provider_group: {version: provider_info} mapping to check.
            fate_flow_version_provider_info: provider info for fate_flow;
                the first value (if any) is used to detect fate_flow code
                changes via its snapshot time.
            storage_engine: dependency storage backend (HDFS by default).

        Returns:
            (check_tag, need_upload, upload_details): check_tag is False
            when an upload is in progress or required; need_upload is True
            when at least one upload is required; upload_details maps
            version -> {dependence_type: provider}.
        """
        schedule_logger(job_id).info(
            "start Check if need to upload dependencies")
        schedule_logger(job_id).info(f"{provider_group}")
        upload_details = {}
        check_tag = True
        upload_total = 0
        for version, provider_info in provider_group.items():
            upload_details[version] = {}
            provider = ComponentProvider(**provider_info)
            for dependence_type in [
                    FateDependenceName.Fate_Source_Code.value,
                    FateDependenceName.Python_Env.value
            ]:
                schedule_logger(job_id).info(f"{dependence_type}")
                dependencies_storage_info = DependenceRegistry.get_dependencies_storage_meta(
                    storage_engine=storage_engine,
                    version=provider.version,
                    type=dependence_type,
                    get_or_one=True)
                need_upload = False
                if dependencies_storage_info:
                    if dependencies_storage_info.f_upload_status:
                        # version dependence uploading
                        check_tag = False
                        continue
                    elif not dependencies_storage_info.f_storage_path:
                        # Metadata row exists but nothing was stored yet.
                        need_upload = True
                        upload_total += 1

                    elif dependence_type == FateDependenceName.Fate_Source_Code.value:
                        if provider.name == ComponentProviderName.FATE.value:
                            # Re-upload when the fate_flow code (if update
                            # checking is on) or the FATE code changed since
                            # the stored snapshot times.
                            check_fate_flow_provider_status = False
                            if fate_flow_version_provider_info.values():
                                flow_provider = ComponentProvider(
                                    **list(fate_flow_version_provider_info.
                                           values())[0])
                                check_fate_flow_provider_status = DependenceRegistry.get_modify_time(flow_provider.path) \
                                                                  != dependencies_storage_info.f_fate_flow_snapshot_time
                            if FATE_FLOW_UPDATE_CHECK and check_fate_flow_provider_status:
                                need_upload = True
                                upload_total += 1
                            elif DependenceRegistry.get_modify_time(provider.path) != \
                                    dependencies_storage_info.f_snapshot_time:
                                need_upload = True
                                upload_total += 1
                        elif provider.name == ComponentProviderName.FATE_FLOW.value and FATE_FLOW_UPDATE_CHECK:
                            if DependenceRegistry.get_modify_time(provider.path) != \
                                    dependencies_storage_info.f_fate_flow_snapshot_time:
                                need_upload = True
                                upload_total += 1
                else:
                    # No stored metadata at all: first upload for this
                    # version/type combination.
                    need_upload = True
                    upload_total += 1
                if need_upload:
                    upload_details[version][dependence_type] = provider
        if upload_total > 0:
            check_tag = False
        schedule_logger(job_id).info(
            f"check dependencies result: {check_tag}, {upload_details}")
        return check_tag, upload_total > 0, upload_details
예제 #12
0
 def _run(self):
     """Register the provider and its components, then persist the registry."""
     provider = ComponentProvider(**self.args.config.get("provider"))
     components = ComponentRegistry.register_provider(provider)
     ComponentRegistry.register_components(provider, components)
     ComponentRegistry.dump()
     stat_logger.info(json_dumps(ComponentRegistry.REGISTRY, indent=4))