Ejemplo n.º 1
0
def init(job_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         naming_policy: NamingPolicy = NamingPolicy.DEFAULT):
    """Set up the process-wide EggRoll backend for a job.

    The call is a no-op when ``RuntimeInstance.EGGROLL`` is already set, so
    only the first successful call takes effect.

    Args:
        job_id: Job identifier. When ``None``, a fresh ``uuid1`` string is
            generated and the logger directory is left at its default;
            otherwise logs go to ``<project base>/logs/<job_id>``.
        mode: ``WorkMode.STANDALONE`` selects the standalone backend,
            ``WorkMode.CLUSTER`` the cluster backend; any other value falls
            through to ``simple_roll``.
        naming_policy: Passed to the ``EggRollContext`` handed to the
            standalone and cluster backends (not used by ``simple_roll``).
    """
    # Already initialised: keep the existing backend untouched.
    if RuntimeInstance.EGGROLL:
        return

    if job_id is not None:
        log_dir = os.path.join(file_utils.get_project_base_directory(),
                               'logs', job_id)
        LoggerFactory.set_directory(log_dir)
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.set_directory()

    RuntimeInstance.MODE = mode
    context = EggRollContext(naming_policy=naming_policy)

    if mode == WorkMode.STANDALONE:
        from arch.api.standalone.eggroll import Standalone
        backend = Standalone(job_id=job_id, eggroll_context=context)
    elif mode == WorkMode.CLUSTER:
        from arch.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(job_id, eggroll_context=context)
        backend = _EggRoll.get_instance()
    else:
        # Any mode that is neither STANDALONE nor CLUSTER uses simple_roll.
        from arch.api.cluster import simple_roll
        simple_roll.init(job_id)
        backend = simple_roll.EggRoll.get_instance()
    RuntimeInstance.EGGROLL = backend

    # Request the job's "__federation__" table (10 partitions) up front.
    RuntimeInstance.EGGROLL.table("__federation__", job_id, partition=10)
Ejemplo n.º 2
0
    def get(self, name, tag, idx=-1):
        """Fetch what the source parties sent to this party for ``name``/``tag``.

        The source role is read from ``self.trans_conf`` after the
        authorization check. For each selected source party, the status
        table is polled via ``check_status_and_get_value``. All polls run
        together on ``self._loop``, and this call blocks until they finish.

        Args:
            name: Transfer-variable name checked by ``__check_authorization``.
            tag: Tag distinguishing this transfer; part of the object key.
            idx: When ``0 <= idx < number of source parties``, only that
                party is queried and a single object is returned. Any other
                value (including the default ``-1``) queries every source
                party and returns a list in source-party order.

        Returns:
            A single fetched object or a list of them (see ``idx``). A status
            result that is a tuple is treated as a table reference
            ``(store_type, name, namespace, partition, ...)`` and opened via
            ``Standalone`` (persistent iff ``store_type == StoreType.LMDB``).
            Any other result is used as a key into the object-storage table.
        """
        algorithm, sub_name = self.__check_authorization(name, is_send=False)
        src_role = self.trans_conf.get(algorithm).get(sub_name).get('src')
        src_party_ids = self.__get_parties(src_role)

        # Non-negative in-range idx -> one party; otherwise all of them.
        want_single = 0 <= idx < len(src_party_ids)
        party_ids = [src_party_ids[idx]] if want_single else src_party_ids

        status_table = _get_meta_table(STATUS_TABLE_NAME, self.job_id)

        LOGGER.debug("[GET] {} {} getting remote object {} from {} {}".format(
            self.role, self.party_id, tag, src_role, party_ids))

        pending = [
            check_status_and_get_value(
                status_table,
                self.__remote__object_key(self.job_id, name, tag, src_role,
                                          pid, self.role, self.party_id))
            for pid in party_ids
        ]
        results = self._loop.run_until_complete(asyncio.gather(*pending))

        object_table = _get_meta_table(OBJECT_STORAGE_NAME, self.job_id)
        fetched = []
        for result in results:
            if not isinstance(result, tuple):
                # Plain value: it is the key of the stored object.
                fetched.append(object_table.get(result))
                continue
            persistent = result[0] == StoreType.LMDB
            fetched.append(Standalone.get_instance().table(
                name=result[1], namespace=result[2],
                persistent=persistent, partition=result[3]))

        return fetched[0] if want_single else fetched
Ejemplo n.º 3
0
def init(job_id=None, mode: WorkMode = WorkMode.STANDALONE):
    """Set up the process-wide EggRoll backend for a job.

    There is no "already initialised" guard: every call unconditionally
    overwrites ``RuntimeInstance.MODE`` and ``RuntimeInstance.EGGROLL``.

    Args:
        job_id: Job identifier. When ``None``, a fresh ``uuid1`` string is
            generated and the logger directory is left at its default;
            otherwise logs go to ``<project base>/logs/<job_id>``.
        mode: ``WorkMode.STANDALONE`` selects the standalone backend,
            ``WorkMode.CLUSTER`` the cluster backend; any other value falls
            through to ``simple_roll``.
    """
    if job_id is not None:
        log_dir = os.path.join(file_utils.get_project_base_directory(),
                               'logs', job_id)
        LoggerFactory.setDirectory(log_dir)
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.setDirectory()

    RuntimeInstance.MODE = mode

    if mode == WorkMode.STANDALONE:
        from arch.api.standalone.eggroll import Standalone
        backend = Standalone(job_id=job_id)
    elif mode == WorkMode.CLUSTER:
        from arch.api.cluster.eggroll import _EggRoll, init as cluster_init
        cluster_init(job_id)
        backend = _EggRoll.get_instance()
    else:
        # Any mode that is neither STANDALONE nor CLUSTER uses simple_roll.
        from arch.api.cluster import simple_roll
        simple_roll.init(job_id)
        backend = simple_roll.EggRoll.get_instance()
    RuntimeInstance.EGGROLL = backend

    # Request the job's "__federation__" table (10 partitions) up front.
    RuntimeInstance.EGGROLL.table("__federation__", job_id, partition=10)
Ejemplo n.º 4
0
def _get_meta_table(_name, _job_id):
    """Return the standalone table ``_name`` for job ``_job_id`` (10 partitions).

    ``_job_id`` is passed as the second positional argument of ``table`` —
    presumably its namespace; confirm against ``Standalone.table``.
    """
    standalone = Standalone.get_instance()
    return standalone.table(_name, _job_id, partition=10)