Beispiel #1
0
def _get_table(transfer_data_desc):
    """Look up the table referenced by a transfer data descriptor.

    The table name and namespace are read from
    ``transfer_data_desc.storageLocator``. The table is requested as
    persistent unless the locator's type is ``storage_basic_pb2.IN_MEMORY``.

    Returns whatever ``_EggRoll.get_instance().table(...)`` returns.
    """
    locator = transfer_data_desc.storageLocator
    table_kwargs = {
        'name': locator.name,
        'namespace': locator.namespace,
        'persistent': locator.type != storage_basic_pb2.IN_MEMORY,
    }
    return _EggRoll.get_instance().table(**table_kwargs)
Beispiel #2
0
def init(job_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         naming_policy: NamingPolicy = NamingPolicy.DEFAULT):
    """Initialise the process-wide EggRoll runtime, at most once.

    Returns immediately if ``RuntimeInstance.EGGROLL`` is already set.
    Otherwise it configures the log directory, records ``mode`` on
    ``RuntimeInstance``, builds the backend selected by ``mode`` and finally
    requests the ``"__clustercomm__"`` table in the ``job_id`` namespace.

    Args:
        job_id: Job identifier. When ``None``, a ``uuid1`` string is generated
            and ``LoggerFactory.setDirectory()`` is called without arguments;
            otherwise logs are directed to ``<project base>/logs/<job_id>``.
        mode: ``WorkMode.STANDALONE`` or ``WorkMode.CLUSTER``; any other value
            falls through to the ``simple_roll`` backend.
        naming_policy: Handed to ``EggRollContext``.
    """
    if RuntimeInstance.EGGROLL:
        return

    if job_id is not None:
        log_dir = os.path.join(file_utils.get_project_base_directory(),
                               'logs', job_id)
        LoggerFactory.setDirectory(log_dir)
    else:
        job_id = str(uuid.uuid1())
        LoggerFactory.setDirectory()

    RuntimeInstance.MODE = mode
    context = EggRollContext(naming_policy=naming_policy)

    # Backends are imported lazily so only the selected one is loaded.
    if mode == WorkMode.STANDALONE:
        from eggroll.api.standalone.eggroll import Standalone
        backend = Standalone(job_id=job_id, eggroll_context=context)
    elif mode == WorkMode.CLUSTER:
        from eggroll.api.cluster.eggroll import _EggRoll
        from eggroll.api.cluster.eggroll import init as cluster_init
        cluster_init(job_id, eggroll_context=context)
        backend = _EggRoll.get_instance()
    else:
        from eggroll.api.cluster import simple_roll
        simple_roll.init(job_id)
        backend = simple_roll.EggRoll.get_instance()

    RuntimeInstance.EGGROLL = backend
    RuntimeInstance.EGGROLL.table("__clustercomm__", job_id, partition=10)
Beispiel #3
0
    def get(self, name, tag, idx=-1):
        """Receive what the source parties sent under ``name`` / ``tag``.

        The source role is looked up in ``self.trans_conf`` for the algorithm
        and sub-name returned by the authorization check. One receive call is
        submitted to ``self.__pool`` per source party, and all of them are
        awaited before any result is processed.

        Args:
            name: Transfer variable name; used for the authorization check and
                as the ``Job`` name.
            tag: Tag placed in the ``TransferMeta``.
            idx: Index into the source party list. If it is a valid index,
                only that party is queried; any other value queries all
                source parties.

        Returns:
            A single item when ``idx`` is a valid index, otherwise a list with
            one item per source party. For ``OBJECT`` transfers the item is the
            value stored in the destination table under the deserialized
            tagged key; for anything else it is the destination table itself.
        """
        algorithm, sub_name = self.__check_authorization(name, is_send=False)
        src_role = self.trans_conf.get(algorithm).get(sub_name).get('src')
        src_party_ids = self.__get_parties(src_role)

        # A valid idx narrows the request to one party; otherwise ask them all.
        single_party = 0 <= idx < len(src_party_ids)
        party_ids = [src_party_ids[idx]] if single_party else src_party_ids

        job = basic_meta_pb2.Job(jobId=self.job_id, name=name)

        LOGGER.debug(
            "[GET] {} {} getting remote object {} from {} {}".format(self.role, self.party_id, tag, src_role,
                                                                     party_ids))

        # Submit every receive before waiting on any of them.
        futures = []
        for party_id in party_ids:
            sender = cluster_comm_pb2.Party(partyId="{}".format(party_id), name=src_role)
            receiver = cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role)
            trans_meta = cluster_comm_pb2.TransferMeta(job=job, tag=tag, src=sender, dst=receiver,
                                                       type=cluster_comm_pb2.RECV)
            futures.append(self.__pool.submit(_thread_receive, self.stub.recv, self.stub.checkStatus, trans_meta))
        received = [future.result() for future in futures]

        rtn = []
        for recv_meta in received:
            desc = recv_meta.dataDesc
            locator = desc.storageLocator
            is_persistent = locator.type != storage_basic_pb2.IN_MEMORY
            dest_table = _EggRoll.get_instance().table(name=locator.name,
                                                       namespace=locator.namespace,
                                                       persistent=is_persistent)
            if desc.transferDataType != cluster_comm_pb2.OBJECT:
                # Not a plain object: hand the table itself back to the caller.
                rtn.append(dest_table)
                src = recv_meta.src
                dst = recv_meta.dst
                LOGGER.debug(
                    "[GET] Got remote table {} from {} {} to {} {}".format(dest_table, src.name, src.partyId, dst.name,
                                                                           dst.partyId))
                continue

            # Plain object: it was stored in the table under its tagged key.
            tagged_key = _serdes.deserialize(desc.taggedVariableName)
            rtn.append(dest_table.get(tagged_key))
            LOGGER.debug("[GET] Got remote object {}".format(tagged_key))

        return rtn[0] if single_party else rtn
Beispiel #4
0
    def remote(self, obj, name: str, tag: str, role=None, idx=-1):
        """Send ``obj`` to remote parties under ``name`` / ``tag``.

        Destination parties are chosen as follows:

        * ``idx >= 0``: only the ``idx``-th party of ``role`` (``role`` is
          required). Note that the allowed-destination check below is not
          applied on this path.
        * ``idx < 0`` and ``role`` given: every party of ``role``, provided
          ``role`` is listed as a destination for ``name`` in the transfer
          configuration.
        * neither given: every party of every destination role configured
          for ``name``.

        A ``_DTable`` is sent by reference (only its storage locator is
        transferred). Any other object is first put into the
        ``OBJECT_STORAGE_NAME`` table of this job under a tagged key, and that
        table's locator is sent.

        Args:
            obj: A ``_DTable`` or an arbitrary object to send.
            name: Transfer variable name; used for the authorization check and
                as the ``Job`` name.
            tag: Tag placed in the ``TransferMeta`` and the tagged key.
            role: Destination role, or ``None`` for all configured roles.
            idx: Index of a single destination party within ``role``; negative
                means "all parties".

        Raises:
            ValueError: If ``idx`` is non-negative but ``role`` is ``None``, or
                if ``role`` is given (with negative ``idx``) and is not an
                allowed destination for ``name``.
        """
        algorithm, sub_name = self.__check_authorization(name)

        auth_dict = self.trans_conf.get(algorithm)

        src = cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role)

        if idx >= 0:
            if role is None:
                # The original message formatted `role` (always None here),
                # producing "None cannot be None ..."; name the parameter.
                raise ValueError("role cannot be None if idx specified")
            parties = {role: [self.__get_parties(role)[idx]]}
        elif role is not None:
            if role not in auth_dict.get(sub_name).get('dst'):
                raise ValueError("{} is not allowed to receive {}".format(role, name))
            parties = {role: self.__get_parties(role)}
        else:
            parties = {}
            for _role in auth_dict.get(sub_name).get('dst'):
                parties[_role] = self.__get_parties(_role)

        # The job descriptor is identical for every destination; build it once.
        job = basic_meta_pb2.Job(jobId=self.job_id, name=name)

        for _role, _partyIds in parties.items():
            for _partyId in _partyIds:
                _tagged_key = self.__remote__object_key(self.job_id, name, tag, self.role, self.party_id, _role,
                                                        _partyId)

                if isinstance(obj, _DTable):
                    # If it is a table, send the meta right away.
                    desc = cluster_comm_pb2.TransferDataDesc(transferDataType=cluster_comm_pb2.DTABLE,
                                                             storageLocator=self.__get_locator(obj),
                                                             taggedVariableName=_serdes.serialize(_tagged_key))
                else:
                    # If it is an object, put the object in the table and
                    # send the table meta.
                    _table = _EggRoll.get_instance().table(OBJECT_STORAGE_NAME, self.job_id)
                    _table.put(_tagged_key, obj)
                    storage_locator = self.__get_locator(_table)
                    desc = cluster_comm_pb2.TransferDataDesc(transferDataType=cluster_comm_pb2.OBJECT,
                                                             storageLocator=storage_locator,
                                                             taggedVariableName=_serdes.serialize(_tagged_key))

                LOGGER.debug("[REMOTE] Sending {}".format(_tagged_key))

                dst = cluster_comm_pb2.Party(partyId="{}".format(_partyId), name=_role)
                self.stub.send(cluster_comm_pb2.TransferMeta(job=job, tag=tag, src=src, dst=dst, dataDesc=desc,
                                                             type=cluster_comm_pb2.SEND))
                LOGGER.debug("[REMOTE] Sent {}".format(_tagged_key))
Beispiel #5
0
def init(session_id=None,
         mode: WorkMode = WorkMode.STANDALONE,
         server_conf_path="eggroll/conf/server_conf.json",
         eggroll_session: EggrollSession = None,
         computing_engine_conf=None,
         naming_policy=NamingPolicy.DEFAULT,
         tag=None,
         job_id=None,
         chunk_size=100000):
    """Initialise the process-wide EggRoll runtime, at most once.

    Returns immediately if ``RuntimeInstance.EGGROLL`` is already set.
    Otherwise it directs logs to ``<project base>/logs/<session_id>``,
    records ``mode`` on ``RuntimeInstance``, builds the backend selected by
    ``mode`` and finally requests the ``"__clustercomm__"`` table in the
    ``job_id`` namespace.

    Args:
        session_id: Session identifier; a ``uuid1`` string is generated when
            it is falsy.
        mode: ``WorkMode.STANDALONE`` or ``WorkMode.CLUSTER``; any other value
            falls through to the ``simple_roll`` backend.
        server_conf_path: Forwarded to the cluster initializer only.
        eggroll_session: Session to use for the standalone backend. When
            ``None``, one is built from ``session_id`` and ``naming_policy``.
            It is not forwarded to the cluster or ``simple_roll``
            initializers. A supplied session is used as-is: ``session_id``
            is not derived from it.
        computing_engine_conf: Forwarded to the cluster initializer only.
        naming_policy: Used for a newly built session and forwarded to the
            cluster initializer.
        tag: Forwarded to the cluster initializer only.
        job_id: Job identifier; defaults to ``session_id`` when falsy.
        chunk_size: Accepted for interface compatibility; not used here.
    """
    if RuntimeInstance.EGGROLL:
        return
    if not session_id:
        session_id = str(uuid.uuid1())
    LoggerFactory.setDirectory(
        os.path.join(file_utils.get_project_base_directory(), 'logs',
                     session_id))

    if not job_id:
        job_id = session_id
    RuntimeInstance.MODE = mode

    # Previously the argument was overwritten unconditionally, silently
    # discarding any session the caller passed in.
    if eggroll_session is None:
        eggroll_session = EggrollSession(session_id=session_id,
                                         naming_policy=naming_policy)
    if mode == WorkMode.STANDALONE:
        from eggroll.api.standalone.eggroll import Standalone
        RuntimeInstance.EGGROLL = Standalone(eggroll_session=eggroll_session)
    elif mode == WorkMode.CLUSTER:
        from eggroll.api.cluster.eggroll import _EggRoll
        from eggroll.api.cluster.eggroll import init as c_init
        c_init(session_id=session_id,
               server_conf_path=server_conf_path,
               computing_engine_conf=computing_engine_conf,
               naming_policy=naming_policy,
               tag=tag,
               job_id=job_id)
        RuntimeInstance.EGGROLL = _EggRoll.get_instance()
    else:
        from eggroll.api.cluster import simple_roll
        simple_roll.init(job_id)
        RuntimeInstance.EGGROLL = simple_roll.EggRoll.get_instance()
    RuntimeInstance.EGGROLL.table("__clustercomm__", job_id, partition=10)
Beispiel #6
0
def _create_fragment_obj_table(namespace, persistent=True):
    """Request a table in ``namespace`` under a freshly generated unique name.

    The name comes from ``generateUniqueId()`` on the ``_EggRoll`` instance;
    ``persistent`` is passed through to ``table(...)`` unchanged.
    """
    instance = _EggRoll.get_instance()
    unique_name = instance.generateUniqueId()
    table = instance.table(name=unique_name,
                           namespace=namespace,
                           persistent=persistent)
    return table
Beispiel #7
0
def _create_table(name, namespace, persistent=True):
    """Request the table ``name`` in ``namespace`` from the ``_EggRoll`` instance.

    ``persistent`` is passed through to ``table(...)`` unchanged.
    """
    eggroll = _EggRoll.get_instance()
    table = eggroll.table(name=name, namespace=namespace, persistent=persistent)
    return table