コード例 #1
0
 def _send(self,
           transfer_type,
           name: str,
           tag: str,
           dst_party: Party,
           rubbish: Rubbish,
           table: _DTable,
           obj=None):
     """Record the outgoing data with ``rubbish`` and send its transfer metadata.

     When ``transfer_type`` is ``federation_pb2.OBJECT``, ``obj`` is first put
     into ``table`` under the key ``"{name}-{tag}"`` and only that key is
     registered via ``rubbish.add_obj``. For any other transfer type the whole
     ``table`` is registered via ``rubbish.add_table``. A ``TransferMeta`` of
     type ``federation_pb2.SEND`` is then handed to ``self._stub.send``. It
     describes the table's storage locator and the serialized key.

     Args:
         transfer_type: ``federation_pb2`` transfer data type (e.g. ``OBJECT``).
         name: variable name; also used as the job name in the metadata.
         tag: tag distinguishing this transfer of ``name``.
         dst_party: receiving party; converted with ``to_pb()``.
         rubbish: collector the stored key or table is registered with.
         table: table that holds (or is) the transferred data.
         obj: value stored in ``table`` for ``OBJECT`` transfers.
     """
     key = f"{name}-{tag}"
     is_object = transfer_type == federation_pb2.OBJECT
     if is_object:
         # Objects travel inside ``table``; only their key is registered.
         table.put(key, obj)
         rubbish.add_obj(table, key)
     else:
         rubbish.add_table(table)

     descriptor = TransferDataDesc(
         transferDataType=transfer_type,
         storageLocator=_get_storage_locator(table),
         taggedVariableName=_ser_des.serialize(key),
     )
     meta = TransferMeta(
         job=basic_meta_pb2.Job(jobId=self._session_id, name=name),
         tag=tag,
         src=self.local_party.to_pb(),
         dst=dst_party.to_pb(),
         dataDesc=descriptor,
         type=federation_pb2.SEND,
     )
     self._stub.send(meta)
コード例 #2
0
    def get(self, name, tag, idx=-1):
        """Receive the value(s) remoted to this party under ``name`` / ``tag``.

        Authorization is checked with ``self.__check_authorization(name,
        is_send=False)``. The sending role is read from the ``'src'`` entry of
        the transfer configuration. One ``_thread_receive`` task per source
        party is submitted to ``self.__pool``, and all of them are awaited
        before any data is read.

        Args:
            name: variable name; also used as the job name.
            tag: tag distinguishing this transfer.
            idx: index into the source role's party list. If
                ``0 <= idx < len(parties)``, only that party is received from
                and a single value is returned. Any other value, including the
                default ``-1``, receives from every source party and returns a
                list.

        Returns:
            For ``OBJECT`` transfers, the value stored in the received table
            under the deserialized tagged key. Otherwise, the received table
            itself. A single value when ``idx`` selects one party; otherwise a
            list in source-party order.
        """
        algorithm, sub_name = self.__check_authorization(name, is_send=False)
        src_role = self.trans_conf.get(algorithm).get(sub_name).get('src')
        src_party_ids = self.__get_parties(src_role)

        single = 0 <= idx < len(src_party_ids)
        party_ids = [src_party_ids[idx]] if single else src_party_ids

        job = basic_meta_pb2.Job(jobId=self.job_id, name=name)

        LOGGER.debug(
            "[GET] {} {} getting remote object {} from {} {}".format(self.role, self.party_id, tag, src_role,
                                                                     party_ids))

        # Fan out one blocking receive per source party, then wait for all.
        futures = []
        for party_id in party_ids:
            trans_meta = cluster_comm_pb2.TransferMeta(
                job=job,
                tag=tag,
                src=cluster_comm_pb2.Party(partyId="{}".format(party_id), name=src_role),
                dst=cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role),
                type=cluster_comm_pb2.RECV)
            futures.append(self.__pool.submit(_thread_receive, self.stub.recv, self.stub.checkStatus, trans_meta))
        received = [future.result() for future in futures]

        values = []
        for recv_meta in received:
            desc = recv_meta.dataDesc
            locator = desc.storageLocator
            dest_table = _EggRoll.get_instance().table(name=locator.name,
                                                       namespace=locator.namespace,
                                                       persistent=locator.type != storage_basic_pb2.IN_MEMORY)
            if desc.transferDataType != cluster_comm_pb2.OBJECT:
                # Table transfers hand back the table handle itself.
                values.append(dest_table)
                sender, receiver = recv_meta.src, recv_meta.dst
                LOGGER.debug(
                    "[GET] Got remote table {} from {} {} to {} {}".format(dest_table, sender.name, sender.partyId,
                                                                           receiver.name, receiver.partyId))
                continue
            tagged_key = _serdes.deserialize(desc.taggedVariableName)
            values.append(dest_table.get(tagged_key))
            LOGGER.debug("[GET] Got remote object {}".format(tagged_key))

        return values[0] if single else values
コード例 #3
0
    def remote(self, obj, name: str, tag: str, role=None, idx=-1):
        """Send ``obj`` to the parties allowed to receive ``name``.

        Authorization is checked with ``self.__check_authorization(name)``. The
        permitted receiving roles are read from the ``'dst'`` entry of the
        transfer configuration.

        Target selection:
          * ``role is None``: every party of every permitted role.
          * ``role`` given: every party of ``role``, or only party ``idx`` of
            ``role`` when ``idx >= 0``.

        ``_DTable`` values are sent as metadata only (their storage locator).
        Any other object is first put into the ``OBJECT_STORAGE_NAME`` table of
        this job, and that table's locator is sent.

        Args:
            obj: value or ``_DTable`` to send.
            name: variable name; also used as the job name.
            tag: tag distinguishing this transfer.
            role: receiving role; required when ``idx >= 0``.
            idx: index into ``role``'s party list; negative means all parties.

        Raises:
            ValueError: if ``idx >= 0`` without a ``role``, or if ``role`` is not
                a permitted receiver of ``name``.
            IndexError: if ``idx`` is out of range for ``role``'s parties.
        """
        algorithm, sub_name = self.__check_authorization(name)
        allowed_roles = self.trans_conf.get(algorithm).get(sub_name).get('dst')

        if role is None:
            if idx >= 0:
                # The previous message formatted ``role`` (always None here),
                # yielding "None cannot be None ...".
                raise ValueError("role cannot be None if idx ({}) is specified".format(idx))
            parties = {_role: self.__get_parties(_role) for _role in allowed_roles}
        else:
            # Enforce 'dst' authorization whether or not idx is given.
            # Previously the idx path skipped this check.
            if role not in allowed_roles:
                raise ValueError("{} is not allowed to receive {}".format(role, name))
            role_parties = self.__get_parties(role)
            if idx >= 0:
                if idx >= len(role_parties):
                    raise IndexError("idx {} out of range: role {} has {} parties".format(
                        idx, role, len(role_parties)))
                role_parties = [role_parties[idx]]
            parties = {role: role_parties}

        src = cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role)
        job = basic_meta_pb2.Job(jobId=self.job_id, name=name)
        is_table = isinstance(obj, _DTable)

        for _role, _partyIds in parties.items():
            for _partyId in _partyIds:
                _tagged_key = self.__remote__object_key(self.job_id, name, tag, self.role, self.party_id, _role,
                                                        _partyId)

                if is_table:
                    # A table is already in storage: send its locator right away.
                    data_type = cluster_comm_pb2.DTABLE
                    storage_locator = self.__get_locator(obj)
                else:
                    # A plain object is put into the job's object-storage table,
                    # and that table's locator is sent.
                    _table = _EggRoll.get_instance().table(OBJECT_STORAGE_NAME, self.job_id)
                    _table.put(_tagged_key, obj)
                    data_type = cluster_comm_pb2.OBJECT
                    storage_locator = self.__get_locator(_table)

                desc = cluster_comm_pb2.TransferDataDesc(transferDataType=data_type,
                                                         storageLocator=storage_locator,
                                                         taggedVariableName=_serdes.serialize(_tagged_key))

                LOGGER.debug("[REMOTE] Sending {}".format(_tagged_key))

                dst = cluster_comm_pb2.Party(partyId="{}".format(_partyId), name=_role)
                self.stub.send(cluster_comm_pb2.TransferMeta(job=job, tag=tag, src=src, dst=dst, dataDesc=desc,
                                                             type=cluster_comm_pb2.SEND))
                LOGGER.debug("[REMOTE] Sent {}".format(_tagged_key))
コード例 #4
0
def _thread_receive(receive_func, check_func, name, tag, session_id, src_party,
                    dst_party):
    """Wait for the transfer tagged ``tag`` from ``src_party`` and load its data.

    Each wait builds a ``federation_pb2.RECV`` ``TransferMeta`` and blocks in
    ``_await_ready(receive_func, check_func, meta)``.

    Returns:
        A 3-tuple:
          * ``DTABLE`` transfer: ``(table, table, None)``.
          * ``OBJECT`` transfer whose value is not a split head:
            ``(obj, (obj_table, key), None)``.
          * ``OBJECT`` transfer whose value is a split head (``is_split_head``):
            ``(obj, (obj_table, key), (fragment_table, fragment_keys))``.
            The fragments come either from one extra ``"{tag}.fragments_table"``
            transfer (keys ``0..num_split-1``) when
            ``REMOTE_FRAGMENT_OBJECT_USE_D_TABLE`` is set, or from one transfer
            per fragment tagged ``_fragment_tag(tag, i)``.

    Raises:
        IOError: if the transfer type is neither ``DTABLE`` nor ``OBJECT``.
    """
    log_msg = f"src={src_party}, dst={dst_party}, name={name}, tag={tag}, session_id={session_id}"
    LOGGER.debug(f"[GET] start: {log_msg}")

    def _wait_for(transfer_tag):
        # Build a RECV request for ``transfer_tag`` and block until it is ready.
        meta = TransferMeta(job=basic_meta_pb2.Job(jobId=session_id, name=name),
                            tag=transfer_tag,
                            src=src_party.to_pb(),
                            dst=dst_party.to_pb(),
                            type=federation_pb2.RECV)
        return _await_ready(receive_func, check_func, meta)

    recv_meta = _wait_for(tag)
    desc = recv_meta.dataDesc
    data_type = desc.transferDataType

    if data_type == federation_pb2.DTABLE:
        LOGGER.debug(
            f"[GET] table ready: src={src_party}, dst={dst_party}, name={name}, tag={tag}, session_id={session_id}"
        )
        received_table = _get_table(desc)
        return received_table, received_table, None

    if data_type != federation_pb2.OBJECT:
        raise IOError(
            f"unknown transfer type: {recv_meta.dataDesc.transferDataType}")

    obj_table = _cache_get_obj_storage_table[src_party]
    tagged_key = _ser_des.deserialize(desc.taggedVariableName)
    obj = obj_table.get(tagged_key)

    if not is_split_head(obj):
        LOGGER.debug(f"[GET] object ready: {log_msg}")
        return obj, (obj_table, tagged_key), None

    # ``obj`` is only the head of a split object; gather its fragments.
    num_split = obj.num_split()
    LOGGER.debug(f"[GET] num_fragments={num_split}: {log_msg}")

    if REMOTE_FRAGMENT_OBJECT_USE_D_TABLE:
        LOGGER.debug(f"[GET] getting fragments table: {log_msg}")
        table_meta = _wait_for(f"{tag}.fragments_table")
        fragment_table = _get_table(table_meta.dataDesc)
        fragment_keys = list(range(num_split))
    else:
        fragment_table = obj_table
        fragment_keys = []
        for position in range(num_split):
            LOGGER.debug(
                f"[GET] getting fragments({position + 1}/{num_split}): {log_msg}")
            fragment_meta = _wait_for(_fragment_tag(tag, position))
            LOGGER.debug(
                f"[GET] fragments({position + 1}/{num_split}) ready: {log_msg}")
            fragment_keys.append(
                _ser_des.deserialize(fragment_meta.dataDesc.taggedVariableName))

    LOGGER.debug(f"[GET] large object ready: {log_msg}")
    return obj, (obj_table, tagged_key), (fragment_table, fragment_keys)