示例#1
0
文件: eggroll.py 项目: zzzcq/eggroll
def init(job_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         naming_policy: NamingPolicy = NamingPolicy.DEFAULT):
    """Set up the process-wide eggroll runtime, at most once.

    Returns immediately when ``RuntimeInstance.EGGROLL`` is already truthy.
    Otherwise the logger directory is configured, ``mode`` is stored on
    ``RuntimeInstance.MODE``, a runtime is built for that mode and stored on
    ``RuntimeInstance.EGGROLL``, and finally ``table("__clustercomm__",
    job_id, partition=10)`` is called on it.

    Args:
        job_id: Job identifier. When ``None`` a ``uuid1`` string is generated
            and ``LoggerFactory.setDirectory()`` is called without a path;
            otherwise logs go to ``<project base>/logs/<job_id>``.
        mode: ``WorkMode.STANDALONE`` or ``WorkMode.CLUSTER``; any other value
            falls back to ``simple_roll``.
        naming_policy: Forwarded to ``EggRollContext``.
    """
    if RuntimeInstance.EGGROLL:
        return

    if job_id is not None:
        log_dir = os.path.join(file_utils.get_project_base_directory(),
                               'logs', job_id)
        LoggerFactory.setDirectory(log_dir)
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.setDirectory()

    RuntimeInstance.MODE = mode
    context = EggRollContext(naming_policy=naming_policy)

    if mode == WorkMode.STANDALONE:
        from eggroll.api.standalone.eggroll import Standalone
        runtime = Standalone(job_id=job_id, eggroll_context=context)
    elif mode == WorkMode.CLUSTER:
        from eggroll.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(job_id, eggroll_context=context)
        runtime = _EggRoll.get_instance()
    else:
        from eggroll.api.cluster import simple_roll
        simple_roll.init(job_id)
        runtime = simple_roll.EggRoll.get_instance()

    RuntimeInstance.EGGROLL = runtime
    RuntimeInstance.EGGROLL.table("__clustercomm__", job_id, partition=10)
示例#2
0
def build_eggroll_runtime(work_mode: WorkMode,
                          eggroll_session: EggrollSession):
    """Return the eggroll runtime that corresponds to ``work_mode``.

    Args:
        work_mode: Object exposing ``is_standalone()`` and ``is_cluster()``.
        eggroll_session: Session handed to the runtime constructor /
            ``eggroll_init``.

    Returns:
        A new ``Standalone`` in standalone mode. In cluster mode, the result
        of ``eggroll_init`` when no ``_EggRoll`` instance exists yet,
        otherwise the already-existing ``_EggRoll.instance``.

    Raises:
        ValueError: ``work_mode`` is neither standalone nor cluster.
    """
    if work_mode.is_standalone():
        from eggroll.api.standalone.eggroll import Standalone
        return Standalone(eggroll_session)

    if work_mode.is_cluster():
        from eggroll.api.cluster.eggroll import eggroll_init, _EggRoll
        if _EggRoll.instance is None:
            return eggroll_init(eggroll_session)
        # Cluster mode is supported: reuse the existing runtime instead of
        # falling through to the "not supported" error below.
        return _EggRoll.instance

    raise ValueError(f"work_mode: {work_mode} not supported!")
示例#3
0
    def get(self, name, tag, idx=-1):
        """Fetch the remote object(s) published under ``name`` / ``tag``.

        The source role is looked up in ``self.trans_conf`` after the
        authorization check, and the status table is polled (via
        ``check_status_and_get_value``) once per source party.

        Args:
            name: Transfer variable name, checked by the authorization helper.
            tag: Tag that is part of the remote object key.
            idx: Index into the source party list. When it is a valid index
                only that party is queried; any other value (default ``-1``)
                queries every source party.

        Returns:
            The single fetched object when ``idx`` is a valid index,
            otherwise a list with one entry per source party. Tuple-shaped
            status values are turned into tables through
            ``Standalone.get_instance().table``; other values are used as
            keys into the object storage table.
        """
        algorithm, sub_name = self.__check_authorization(name, is_send=False)
        src_role = self.trans_conf.get(algorithm).get(sub_name).get('src')
        src_party_ids = self.__get_parties(src_role)

        # A valid index selects exactly one party; anything else means "all".
        single_party = 0 <= idx < len(src_party_ids)
        party_ids = [src_party_ids[idx]] if single_party else src_party_ids

        status_table = _get_meta_table(STATUS_TABLE_NAME, self.job_id)

        LOGGER.debug("[GET] {} {} getting remote object {} from {} {}".format(
            self.role, self.party_id, tag, src_role, party_ids))

        pending = [
            check_status_and_get_value(
                status_table,
                self.__remote__object_key(self.job_id, name, tag,
                                          src_role, party_id,
                                          self.role, self.party_id))
            for party_id in party_ids
        ]
        results = self._loop.run_until_complete(asyncio.gather(*pending))

        object_table = _get_meta_table(OBJECT_STORAGE_NAME, self.job_id)
        rtn = []
        for r in results:
            if not isinstance(r, tuple):
                rtn.append(object_table.get(r))
                continue
            # Tuple layout as indexed here: (store type, name, namespace,
            # partition).
            rtn.append(Standalone.get_instance().table(
                name=r[1],
                namespace=r[2],
                persistent=(r[0] == StoreType.LMDB),
                partition=r[3]))

        return rtn[0] if single_party else rtn
示例#4
0
def maybe_create_eggroll_client():
    """Create the eggroll client inside a Spark task when it is missing.

    The ``_EGGROLL_CLIENT`` local property of the current ``TaskContext`` is
    hex-decoded and unpickled into ``(mode, eggroll_session)``. For
    ``mode == 1`` a cluster ``_EggRoll`` runtime is created and registered on
    the session, but only if ``_EggRoll.instance`` is still ``None``; for any
    other mode a ``Standalone`` is constructed from the session.

    WARNING: this is a workaround and may be removed or adjusted in future!

    NOTE(review): the property value goes through ``pickle.loads`` -- it must
    only ever come from a trusted driver.
    """
    import pickle
    from pyspark.taskcontext import TaskContext

    encoded = TaskContext.get().getLocalProperty(_EGGROLL_CLIENT)
    mode, eggroll_session = pickle.loads(bytes.fromhex(encoded))

    if mode == 1:
        from eggroll.api.cluster.eggroll import _EggRoll
        if _EggRoll.instance is not None:
            # A runtime already exists in this process; nothing to do.
            return
        from eggroll.api import ComputingEngine
        runtime = _EggRoll(eggroll_session=eggroll_session)
        eggroll_session.set_runtime(ComputingEngine.EGGROLL_DTABLE, runtime)
        return

    from eggroll.api.standalone.eggroll import Standalone
    Standalone(eggroll_session)
示例#5
0
文件: eggroll.py 项目: zazd/eggroll
def init(session_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         server_conf_path="eggroll/conf/server_conf.json",
         eggroll_session: EggrollSession = None,
         computing_engine_conf=None,
         naming_policy=NamingPolicy.DEFAULT,
         tag=None,
         job_id=None,
         chunk_size=100000):
    """Set up the process-wide eggroll runtime, at most once.

    Returns immediately when ``RuntimeInstance.EGGROLL`` is already truthy.
    Otherwise the logger directory is set to
    ``<project base>/logs/<session_id>``, ``mode`` is stored on
    ``RuntimeInstance.MODE``, a runtime is built for that mode and stored on
    ``RuntimeInstance.EGGROLL``, and finally ``table("__clustercomm__",
    job_id, partition=10)`` is called on it.

    Args:
        session_id: Session identifier; a ``uuid1`` string is generated when
            it is falsy.
        mode: ``WorkMode.STANDALONE`` or ``WorkMode.CLUSTER``; any other value
            falls back to ``simple_roll``.
        server_conf_path: Forwarded to the cluster ``init`` only.
        eggroll_session: NOTE(review): the value passed in is never read --
            a fresh ``EggrollSession`` is always built from ``session_id``
            and ``naming_policy``. Confirm whether that is intended.
        computing_engine_conf: Forwarded to the cluster ``init`` only.
        naming_policy: Used for the new ``EggrollSession`` and forwarded to
            the cluster ``init``.
        tag: Forwarded to the cluster ``init`` only.
        job_id: Defaults to ``session_id`` when falsy.
        chunk_size: NOTE(review): not referenced in this function.
    """
    if RuntimeInstance.EGGROLL:
        return

    session_id = session_id or str(uuid.uuid1())
    log_dir = os.path.join(file_utils.get_project_base_directory(), 'logs',
                           session_id)
    LoggerFactory.setDirectory(log_dir)

    job_id = job_id or session_id
    RuntimeInstance.MODE = mode

    # Deliberately kept as in the original: this replaces whatever session
    # the caller supplied.
    eggroll_session = EggrollSession(session_id=session_id,
                                     naming_policy=naming_policy)

    if mode == WorkMode.STANDALONE:
        from eggroll.api.standalone.eggroll import Standalone
        runtime = Standalone(eggroll_session=eggroll_session)
    elif mode == WorkMode.CLUSTER:
        from eggroll.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(session_id=session_id,
                     server_conf_path=server_conf_path,
                     computing_engine_conf=computing_engine_conf,
                     naming_policy=naming_policy,
                     tag=tag,
                     job_id=job_id)
        runtime = _EggRoll.get_instance()
    else:
        from eggroll.api.cluster import simple_roll
        simple_roll.init(job_id)
        runtime = simple_roll.EggRoll.get_instance()

    RuntimeInstance.EGGROLL = runtime
    RuntimeInstance.EGGROLL.table("__clustercomm__", job_id, partition=10)
示例#6
0
    def get(self, name: str, tag: str,
            parties: Union[Party, list]) -> Tuple[list, Rubbish]:
        """Fetch the objects published under ``name`` / ``tag`` by ``parties``.

        Args:
            name: Transfer variable name, checked by ``_get_side_auth``.
            tag: Tag that is part of the remote object key.
            parties: A single ``Party`` or a list of parties to read from; a
                single ``Party`` is wrapped into a one-element list.

        Returns:
            ``(values, rubbish)``. ``values`` holds one entry per party:
            tuple-shaped status values are turned into tables through
            ``Standalone.get_instance().table``, other values are used as
            keys into the object storage table. ``rubbish`` is a ``Rubbish``
            on which every such table, and every object/status key, has been
            registered.
        """
        if isinstance(parties, Party):
            parties = [parties]
        self._get_side_auth(name=name, parties=parties)

        status_table = _get_meta_table(STATUS_TABLE_NAME, self._session_id)
        LOGGER.debug(
            f"[GET] {self._local_party} getting {name}.{tag} from {parties}")

        pending = [
            check_status_and_get_value(
                status_table,
                self.__remote__object_key(self._session_id, name, tag,
                                          party.role, party.party_id,
                                          self._role, self._party_id))
            for party in parties
        ]
        results = self._loop.run_until_complete(asyncio.gather(*pending))

        rubbish = Rubbish(name, tag)
        object_table = _get_meta_table(OBJECT_STORAGE_NAME, self._session_id)
        rtn = []
        for r in results:
            LOGGER.debug(
                f"[GET] {self._local_party} getting {r} from {parties}")

            if not isinstance(r, tuple):
                # todo: should standalone mode split large object?
                rtn.append(object_table.get(r))
                rubbish.add_obj(object_table, r)
                rubbish.add_obj(status_table, r)
                continue

            # Tuple layout as indexed here: (store type, name, namespace,
            # partition).
            table = Standalone.get_instance().table(
                name=r[1],
                namespace=r[2],
                persistent=(r[0] == StoreType.LMDB),
                partition=r[3])
            rtn.append(table)
            rubbish.add_table(table)

        return rtn, rubbish
示例#7
0
def _get_meta_table(_name, _job_id):
    """Return the ``Standalone`` table for ``_name`` / ``_job_id``.

    Both arguments are passed positionally to ``table`` in that order, with
    ``partition=10``.
    """
    standalone = Standalone.get_instance()
    return standalone.table(_name, _job_id, partition=10)