예제 #1
0
파일: eggroll.py 프로젝트: pangzx1/FATE1.0
def init(job_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         naming_policy: NamingPolicy = NamingPolicy.DEFAULT):
    """
    Set up the process-wide EggRoll runtime; no-op if one is already set.

    Parameters
    ----------
    job_id : str, optional
        Job identifier. When None, a ``uuid1`` string is generated and
        ``LoggerFactory.set_directory()`` is called with no argument;
        otherwise logs go to ``<project base>/logs/<job_id>``.
    mode : WorkMode
        ``WorkMode.STANDALONE`` builds a ``Standalone`` runtime,
        ``WorkMode.CLUSTER`` initialises the cluster client and uses its
        ``_EggRoll`` instance; any other value falls back to ``simple_roll``.
    naming_policy : NamingPolicy
        Used to build the ``EggRollContext`` handed to the standalone and
        cluster runtimes (the ``simple_roll`` path does not receive it).

    Notes
    -----
    After the runtime is stored in ``RuntimeInstance.EGGROLL``, a
    ``"__federation__"`` table is requested from it with ``job_id`` and
    ``partition=10``.
    """
    # A runtime is already installed: leave it untouched.
    if RuntimeInstance.EGGROLL:
        return

    if job_id is not None:
        log_dir = os.path.join(file_utils.get_project_base_directory(),
                               'logs', job_id)
        LoggerFactory.set_directory(log_dir)
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.set_directory()

    RuntimeInstance.MODE = mode
    context = EggRollContext(naming_policy=naming_policy)

    if mode == WorkMode.STANDALONE:
        from arch.api.standalone.eggroll import Standalone
        runtime = Standalone(job_id=job_id, eggroll_context=context)
    elif mode == WorkMode.CLUSTER:
        from arch.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(job_id, eggroll_context=context)
        runtime = _EggRoll.get_instance()
    else:
        # Any mode other than standalone/cluster is served by simple_roll.
        from arch.api.cluster import simple_roll
        simple_roll.init(job_id)
        runtime = simple_roll.EggRoll.get_instance()

    RuntimeInstance.EGGROLL = runtime
    RuntimeInstance.EGGROLL.table("__federation__", job_id, partition=10)
예제 #2
0
def init(job_id=None,
         mode: typing.Union[int, WorkMode] = WorkMode.STANDALONE,
         backend: typing.Union[int, Backend] = Backend.EGGROLL,
         persistent_engine: StoreType = StoreType.LMDB,
         set_log_dir=True):
    """
    Create the process-wide table session unless one already exists.

    Parameters
    ----------
    job_id : str, optional
        Job/session identifier; a ``uuid1`` string is generated when None.
    mode : int or WorkMode
        Work mode; a plain int is converted with ``WorkMode(mode)``.
    backend : int or Backend
        Computing backend; a plain int is converted with ``Backend(backend)``.
    persistent_engine : StoreType
        Currently unused by this function.
        NOTE(review): not forwarded to ``build_session`` — confirm intended.
    set_log_dir : bool
        When True, configure ``LoggerFactory``: ``<project base>/logs/<job_id>``
        for an explicit ``job_id``, or ``set_directory()`` with no argument
        for a generated one.
    """
    # Int arguments are promoted to their enums before anything else,
    # even when a session already exists.
    mode = WorkMode(mode) if isinstance(mode, int) else mode
    backend = Backend(backend) if isinstance(backend, int) else backend

    # Already initialised: keep the existing session.
    if RuntimeInstance.SESSION:
        return

    generated_id = job_id is None
    if generated_id:
        job_id = str(uuid.uuid1())

    if set_log_dir:
        if generated_id:
            LoggerFactory.set_directory()
        else:
            LoggerFactory.set_directory(
                os.path.join(file_utils.get_project_base_directory(),
                             'logs', job_id))

    RuntimeInstance.MODE = mode
    RuntimeInstance.Backend = backend

    from arch.api.table.session import build_session
    RuntimeInstance.SESSION = build_session(job_id=job_id,
                                            work_mode=mode,
                                            backend=backend)
예제 #3
0
def init(job_id=None, mode: WorkMode = WorkMode.STANDALONE):
    """
    Set up the process-wide EggRoll runtime for ``mode``.

    There is no guard against repeated calls: every call re-runs the
    initialisation below and overwrites ``RuntimeInstance.EGGROLL``.

    Parameters
    ----------
    job_id : str, optional
        Job identifier. When None, a ``uuid1`` string is generated and
        ``LoggerFactory.setDirectory()`` is called with no argument;
        otherwise logs go to ``<project base>/logs/<job_id>``.
    mode : WorkMode
        ``WorkMode.STANDALONE`` builds a ``Standalone`` runtime,
        ``WorkMode.CLUSTER`` initialises the cluster client and uses its
        ``_EggRoll`` instance; any other value falls back to ``simple_roll``.
    """
    if job_id is not None:
        LoggerFactory.setDirectory(
            os.path.join(file_utils.get_project_base_directory(),
                         'logs', job_id))
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.setDirectory()

    RuntimeInstance.MODE = mode

    if mode == WorkMode.STANDALONE:
        from arch.api.standalone.eggroll import Standalone
        runtime = Standalone(job_id=job_id)
    elif mode == WorkMode.CLUSTER:
        from arch.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(job_id)
        runtime = _EggRoll.get_instance()
    else:
        # Any mode other than standalone/cluster is served by simple_roll.
        from arch.api.cluster import simple_roll
        simple_roll.init(job_id)
        runtime = simple_roll.EggRoll.get_instance()

    RuntimeInstance.EGGROLL = runtime
    # Request the "__federation__" table with 10 partitions for this job.
    RuntimeInstance.EGGROLL.table("__federation__", job_id, partition=10)
예제 #4
0
import functools
import typing

from arch.api import session
from arch.api.utils.log_utils import LoggerFactory
from fate_flow.entity.metric import MetricType, MetricMeta, Metric
from federatedml.framework.h**o.procedure import aggregator
from federatedml.model_base import ModelBase
from federatedml.nn.homo_nn import nn_model
from federatedml.nn.homo_nn.nn_model import restore_nn_model
from federatedml.optim.convergence import converge_func_factory
from federatedml.param.homo_nn_param import HomoNNParam
from federatedml.transfer_variable.transfer_class.homo_transfer_variable import HomoTransferVariable
from federatedml.util import consts

# Module-level logger obtained from the project's LoggerFactory.
Logger = LoggerFactory.get_logger()
# Keys under which the model meta and model param objects are stored in a
# model dict (used by _build_model_dict / _extract_meta / _extract_param).
MODEL_META_NAME = "HomoNNModelMeta"
MODEL_PARAM_NAME = "HomoNNModelParam"


def _build_model_dict(meta, param):
    """Pack ``meta`` and ``param`` into one dict keyed by MODEL_META_NAME / MODEL_PARAM_NAME."""
    bundle = dict()
    bundle[MODEL_META_NAME] = meta
    bundle[MODEL_PARAM_NAME] = param
    return bundle


def _extract_param(model_dict: dict):
    """Return the entry stored under MODEL_PARAM_NAME, or None when it is absent."""
    return model_dict.get(MODEL_PARAM_NAME)


def _extract_meta(model_dict: dict):
    """Return the entry stored under MODEL_META_NAME, or None when it is absent."""
    return model_dict.get(MODEL_META_NAME)
예제 #5
0
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# -*- coding: utf-8 -*-
from arch.api.utils import file_utils
from arch.api.utils.log_utils import LoggerFactory
# Import-time side effect: configure LoggerFactory's directory (no argument,
# so whatever LoggerFactory.setDirectory uses when none is given).
LoggerFactory.setDirectory()
# Module logger registered under the name "task_manager".
logger = LoggerFactory.getLogger("task_manager")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

API_VERSION = "v1"
ROLE = 'manager'
SERVERS = 'servers'
MAX_CONCURRENT_JOB_RUN = 5
# Workflow data types recognised by default.
DEFAULT_WORKFLOW_DATA_TYPE = [
    'train_input',
    'data_input',
    'id_library_input',
    'model',
    'predict_input',
    'predict_output',
    'evaluation_output',
    'intersect_data_output',
]
_ONE_DAY_IN_SECONDS = 24 * 60 * 60
DEFAULT_GRPC_OVERALL_TIMEOUT = 60 * 1000  # ms (one minute)
HEADERS = {
    'Content-Type': 'application/json',
예제 #6
0
def init(job_id=None,
         mode: typing.Union[int, WorkMode] = WorkMode.STANDALONE,
         backend: typing.Union[int, Backend] = Backend.EGGROLL,
         persistent_engine: str = StoreTypes.ROLLPAIR_LMDB,
         eggroll_version=None,
         set_log_dir=True,
         options: dict = None):
    """
    Initializes session, should be called before all.

    Calling ``init`` again once a session exists is a no-op.

    Parameters
    ----------
    job_id : string
      job id and default table namespace of this runtime. A ``uuid1``
      string is generated when omitted.
    mode : WorkMode
      set work mode,

        - standalone: `WorkMode.STANDALONE` or 0
        - cluster: `WorkMode.CLUSTER` or 1
    backend : Backend
      set computing backend,

        - eggroll: `Backend.EGGROLL` or 0
        - spark: `Backend.SPARK` or 1
    persistent_engine : str
      storage engine passed through to the session builder.
    eggroll_version : None or int
      selects the builder implementation: ``< 2`` uses the 1.x builder,
      anything else the 2.x builder. Defaults to the module-level
      ``_EGGROLL_VERSION`` when None.
    set_log_dir : bool
      whether to configure the ``LoggerFactory`` directory. When True, logs
      go to ``<project base>/logs/<job_id>`` for an explicit ``job_id``, or
      ``set_directory()`` is called with no argument for a generated one.
      When False, the log directory is left untouched in both cases.
    options : None or dict
      additional options, forwarded only to the 2.x builders.

    Returns
    -------
    None
      nothing returns

    Raises
    ------
    ValueError
      if ``backend`` is neither eggroll nor spark.

    Examples
    --------
    >>> from arch.api import session, WorkMode, Backend
    >>> session.init("a_job_id", WorkMode.STANDALONE, Backend.EGGROLL)
    """
    # A session already exists: keep it.
    if RuntimeInstance.SESSION:
        return

    # Plain ints are accepted for convenience and promoted to their enums.
    if isinstance(mode, int):
        mode = WorkMode(mode)
    if isinstance(backend, int):
        backend = Backend(backend)

    if job_id is None:
        job_id = str(uuid.uuid1())
        # Honour set_log_dir for generated ids too (it was previously ignored).
        if set_log_dir:
            LoggerFactory.set_directory()
    elif set_log_dir:
        LoggerFactory.set_directory(
            os.path.join(file_utils.get_project_base_directory(), 'logs',
                         job_id))

    if eggroll_version is None:
        eggroll_version = _EGGROLL_VERSION

    if backend.is_eggroll():
        if eggroll_version < 2:
            from arch.api.impl.based_1x import build
            builder = build.Builder(session_id=job_id,
                                    work_mode=mode,
                                    persistent_engine=persistent_engine)
        else:
            from arch.api.impl.based_2x import build
            builder = build.Builder(session_id=job_id,
                                    work_mode=mode,
                                    persistent_engine=persistent_engine,
                                    options=options)

    elif backend.is_spark():
        if eggroll_version < 2:
            from arch.api.impl.based_spark.based_1x import build
            builder = build.Builder(session_id=job_id,
                                    work_mode=mode,
                                    persistent_engine=persistent_engine)
        else:
            from arch.api.impl.based_spark.based_2x import build
            builder = build.Builder(session_id=job_id,
                                    work_mode=mode,
                                    persistent_engine=persistent_engine,
                                    options=options)

    else:
        raise ValueError(f"backend: {backend} unknown")

    # Runtime state is only published once the builder was created successfully.
    RuntimeInstance.MODE = mode
    RuntimeInstance.BACKEND = backend
    RuntimeInstance.BUILDER = builder
    RuntimeInstance.SESSION = builder.build_session()