def _send(self, transfer_type, name: str, tag: str, dst_party: Party, rubbish: Rubbish, table: _DTable, obj=None):
    """Ship ``table`` (or ``obj`` stored into it) to ``dst_party`` via the stub.

    Whatever is sent is registered with ``rubbish`` so it can be cleaned
    up after the transfer completes.
    """
    key = f"{name}-{tag}"
    if transfer_type == federation_pb2.OBJECT:
        # Plain objects ride inside the table under the tagged key.
        table.put(key, obj)
        rubbish.add_obj(table, key)
    else:
        rubbish.add_table(table)
    meta = TransferMeta(
        job=basic_meta_pb2.Job(jobId=self._session_id, name=name),
        tag=tag,
        src=self.local_party.to_pb(),
        dst=dst_party.to_pb(),
        dataDesc=TransferDataDesc(
            transferDataType=transfer_type,
            storageLocator=_get_storage_locator(table),
            taggedVariableName=_ser_des.serialize(key)),
        type=federation_pb2.SEND)
    self._stub.send(meta)
def get(self, name, tag, idx=-1):
    """Receive remote object(s)/table(s) named ``name`` with ``tag``.

    When ``idx`` addresses a valid source party, only that party is polled
    and the single result is returned; otherwise every source party is
    polled and a list of results is returned.
    """
    algorithm, sub_name = self.__check_authorization(name, is_send=False)
    auth_dict = self.trans_conf.get(algorithm)
    src_role = auth_dict.get(sub_name).get('src')
    src_party_ids = self.__get_parties(src_role)
    want_single = 0 <= idx < len(src_party_ids)
    # Either the one requested party, or every party of the source role.
    party_ids = [src_party_ids[idx]] if want_single else src_party_ids
    job = basic_meta_pb2.Job(jobId=self.job_id, name=name)
    LOGGER.debug(
        "[GET] {} {} getting remote object {} from {} {}".format(self.role, self.party_id, tag, src_role, party_ids))
    dst = cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role)
    # Fan the receives out to the worker pool, one future per source party.
    futures = []
    for party_id in party_ids:
        src = cluster_comm_pb2.Party(partyId="{}".format(party_id), name=src_role)
        trans_meta = cluster_comm_pb2.TransferMeta(job=job, tag=tag, src=src, dst=dst,
                                                   type=cluster_comm_pb2.RECV)
        futures.append(self.__pool.submit(_thread_receive, self.stub.recv, self.stub.checkStatus, trans_meta))
    rtn = []
    for recv_meta in (f.result() for f in futures):
        desc = recv_meta.dataDesc
        persistent = desc.storageLocator.type != storage_basic_pb2.IN_MEMORY
        dest_table = _EggRoll.get_instance().table(name=desc.storageLocator.name,
                                                   namespace=desc.storageLocator.namespace,
                                                   persistent=persistent)
        if desc.transferDataType == cluster_comm_pb2.OBJECT:
            # The payload object was parked in the table under a tagged key.
            tagged_key = _serdes.deserialize(desc.taggedVariableName)
            rtn.append(dest_table.get(tagged_key))
            LOGGER.debug("[GET] Got remote object {}".format(tagged_key))
        else:
            rtn.append(dest_table)
            LOGGER.debug(
                "[GET] Got remote table {} from {} {} to {} {}".format(
                    dest_table, recv_meta.src.name, recv_meta.src.partyId,
                    recv_meta.dst.name, recv_meta.dst.partyId))
    return rtn[0] if want_single else rtn
def remote(self, obj, name: str, tag: str, role=None, idx=-1):
    """Send ``obj`` to remote parties.

    A ``_DTable`` is sent "by reference" (only its storage meta travels);
    any other object is first stored into the per-job object-storage table
    and the table meta is sent instead.

    Args:
        obj: the payload, either a ``_DTable`` or a serializable object.
        name: transfer-variable name; used for the authorization lookup.
        tag: tag distinguishing repeated transfers of the same variable.
        role: destination role. Required when ``idx`` is given; when both
            are omitted, every role authorized as a destination receives
            the payload.
        idx: index of a single destination party within ``role``; a
            negative value broadcasts to all parties of the chosen role(s).

    Raises:
        ValueError: if ``idx`` is given without ``role``, or ``role`` is
            not an authorized destination for ``name``.
    """
    algorithm, sub_name = self.__check_authorization(name)
    auth_dict = self.trans_conf.get(algorithm)
    src = cluster_comm_pb2.Party(partyId="{}".format(self.party_id), name=self.role)
    if idx >= 0:
        if role is None:
            # BUG FIX: the original formatted the value of `role` (None on
            # this path) into the message, yielding the nonsensical
            # "None cannot be None if idx specified".
            raise ValueError("role cannot be None if idx specified")
        parties = {role: [self.__get_parties(role)[idx]]}
    elif role is not None:
        if role not in auth_dict.get(sub_name).get('dst'):
            raise ValueError("{} is not allowed to receive {}".format(role, name))
        parties = {role: self.__get_parties(role)}
    else:
        # No role specified: broadcast to every authorized destination role.
        parties = {_role: self.__get_parties(_role)
                   for _role in auth_dict.get(sub_name).get('dst')}
    # The Job message is identical for every destination; build it once.
    job = basic_meta_pb2.Job(jobId=self.job_id, name=name)
    for _role, _partyIds in parties.items():
        for _partyId in _partyIds:
            _tagged_key = self.__remote__object_key(self.job_id, name, tag, self.role,
                                                    self.party_id, _role, _partyId)
            if isinstance(obj, _DTable):
                # A table is sent "by reference": only its meta travels.
                desc = cluster_comm_pb2.TransferDataDesc(
                    transferDataType=cluster_comm_pb2.DTABLE,
                    storageLocator=self.__get_locator(obj),
                    taggedVariableName=_serdes.serialize(_tagged_key))
            else:
                # Park the object in the job's object-storage table and send
                # that table's meta instead.
                _table = _EggRoll.get_instance().table(OBJECT_STORAGE_NAME, self.job_id)
                _table.put(_tagged_key, obj)
                desc = cluster_comm_pb2.TransferDataDesc(
                    transferDataType=cluster_comm_pb2.OBJECT,
                    storageLocator=self.__get_locator(_table),
                    taggedVariableName=_serdes.serialize(_tagged_key))
            LOGGER.debug("[REMOTE] Sending {}".format(_tagged_key))
            dst = cluster_comm_pb2.Party(partyId="{}".format(_partyId), name=_role)
            self.stub.send(cluster_comm_pb2.TransferMeta(
                job=job, tag=tag, src=src, dst=dst, dataDesc=desc,
                type=cluster_comm_pb2.SEND))
            LOGGER.debug("[REMOTE] Sent {}".format(_tagged_key))
def _thread_receive(receive_func, check_func, name, tag, session_id, src_party, dst_party):
    """Blockingly receive one transfer from ``src_party`` and return it.

    Returns a 3-tuple ``(value, obj_rubbish, fragment_rubbish)``:
      - for a DTABLE transfer: ``(table, table, None)``;
      - for a plain OBJECT: ``(obj, (obj_table, tagged_key), None)``;
      - for a split ("large") object: the split head plus the fragment
        table/keys needed to reassemble and later clean it up.

    Raises:
        IOError: on an unrecognized ``transferDataType``.
    """
    log_msg = f"src={src_party}, dst={dst_party}, name={name}, tag={tag}, session_id={session_id}"
    LOGGER.debug(f"[GET] start: {log_msg}")
    job = basic_meta_pb2.Job(jobId=session_id, name=name)
    transfer_meta = TransferMeta(job=job, tag=tag, src=src_party.to_pb(), dst=dst_party.to_pb(),
                                 type=federation_pb2.RECV)
    # Block until the peer has published the transfer, then read its descriptor.
    recv_meta = _await_ready(receive_func, check_func, transfer_meta)
    desc = recv_meta.dataDesc
    if desc.transferDataType == federation_pb2.DTABLE:
        LOGGER.debug(
            f"[GET] table ready: src={src_party}, dst={dst_party}, name={name}, tag={tag}, session_id={session_id}"
        )
        table = _get_table(desc)
        return table, table, None
    if desc.transferDataType == federation_pb2.OBJECT:
        # Objects live in the per-source object-storage table under a tagged key.
        obj_table = _cache_get_obj_storage_table[src_party]
        __tagged_key = _ser_des.deserialize(desc.taggedVariableName)
        obj = obj_table.get(__tagged_key)
        if not is_split_head(obj):
            # Small object: delivered whole, no fragments to fetch.
            LOGGER.debug(f"[GET] object ready: {log_msg}")
            return obj, (obj_table, __tagged_key), None
        # Split head: the payload was chunked into `num_split` fragments that
        # must be collected separately.
        num_split = obj.num_split()
        LOGGER.debug(f"[GET] num_fragments={num_split}: {log_msg}")
        fragment_keys = []
        fragment_table = obj_table
        if REMOTE_FRAGMENT_OBJECT_USE_D_TABLE:
            # All fragments arrive together in one dedicated table, keyed 0..num_split-1.
            LOGGER.debug(f"[GET] getting fragments table: {log_msg}")
            job = basic_meta_pb2.Job(jobId=session_id, name=name)
            transfer_meta = TransferMeta(job=job, tag=f"{tag}.fragments_table",
                                         src=src_party.to_pb(), dst=dst_party.to_pb(),
                                         type=federation_pb2.RECV)
            _resp_meta = _await_ready(receive_func, check_func, transfer_meta)
            table = _get_table(_resp_meta.dataDesc)
            fragment_table = table
            fragment_keys.extend(list(range(num_split)))
        else:
            # Fragments are received one by one, each under its own tag; the
            # order of these RPCs matters to the sender, so keep it sequential.
            for i in range(num_split):
                LOGGER.debug(
                    f"[GET] getting fragments({i + 1}/{num_split}): {log_msg}")
                job = basic_meta_pb2.Job(jobId=session_id, name=name)
                transfer_meta = TransferMeta(job=job, tag=_fragment_tag(tag, i),
                                             src=src_party.to_pb(), dst=dst_party.to_pb(),
                                             type=federation_pb2.RECV)
                _resp_meta = _await_ready(receive_func, check_func, transfer_meta)
                LOGGER.debug(
                    f"[GET] fragments({i + 1}/{num_split}) ready: {log_msg}")
                __fragment_tagged_key = _ser_des.deserialize(
                    _resp_meta.dataDesc.taggedVariableName)
                fragment_keys.append(__fragment_tagged_key)
        LOGGER.debug(f"[GET] large object ready: {log_msg}")
        return obj, (obj_table, __tagged_key), (fragment_table, fragment_keys)
    else:
        raise IOError(
            f"unknown transfer type: {recv_meta.dataDesc.transferDataType}")