def _get_remote_obj_store_table(parties, rollsite):
    """Load the 'object' store table addressed to each remote party.

    Args:
        parties: iterable of ``(role, party_id)`` entries; only indexes 0 and
            1 of each entry are read, and the party id is stringified.
        rollsite: object supplying ``roll_site_session_id``, ``name``,
            ``tag``, ``local_role``, ``party_id`` and ``ctx.rp_ctx``.

    Returns:
        A list with one loaded table per entry of ``parties``, in order.
    """
    from eggroll.roll_site.utils.roll_site_utils import create_store_name
    from eggroll.core.transfer_model import ErRollSiteHeader

    def _load_for(entry):
        # Header describing an 'object' transfer from the local party to
        # the remote one; its derived store name is the table name.
        header = ErRollSiteHeader(
            roll_site_session_id=rollsite.roll_site_session_id,
            name=rollsite.name,
            tag=rollsite.tag,
            src_role=rollsite.local_role,
            src_party_id=rollsite.party_id,
            dst_role=entry[0],
            dst_party_id=str(entry[1]),
            data_type='object',
            options={})
        store_name = create_store_name(header)
        # The roll site session id doubles as the store namespace.
        return rollsite.ctx.rp_ctx.load(rollsite.roll_site_session_id,
                                        store_name)

    return [_load_for(entry) for entry in parties]
def __init__(self, options: dict = None):
    """Initialise the adapter from ``options``.

    Args:
        options: must contain ``'er_partition'``. When a non-empty
            ``'roll_site_header'`` is present, ``'proxy_endpoint'`` and
            ``'obj_type'`` are also required and the adapter is marked
            writable.
    """
    options = {} if options is None else options
    super().__init__(options)

    partition = options['er_partition']
    self.partition = partition
    # The partition's own store locator is reused as-is; no dedicated
    # ROLLPAIR_ROLLSITE locator is built here.
    self.store_locator = partition._store_locator
    self.partition_id = partition._id
    self.namespace = self.store_locator._namespace

    self.roll_site_header_string = options.get('roll_site_header')
    self.is_writable = False
    if not self.roll_site_header_string:
        # No header supplied: the adapter stays non-writable.
        return

    self.roll_site_header = ErRollSiteHeader.from_proto_string(
        self.roll_site_header_string.encode(stringify_charset))
    self.proxy_endpoint = ErEndpoint.from_proto_string(
        options['proxy_endpoint'].encode(stringify_charset))
    self.obj_type = options['obj_type']
    self.is_writable = True

    L.info(
        f"writable RollSiteAdapter: {self.namespace}, {self.partition_id}. proxy_endpoint: {self.proxy_endpoint}, partition: {self.partition}"
    )
def push(self, obj, parties: list = None, options: dict = None):
    """Start one asynchronous push of ``obj`` per destination party.

    Args:
        obj: a ``RollPair`` (pushed as 'rollpair') or any other value
            (pushed as 'object').
        parties: iterable of ``(role, party_id)`` entries.
        options: forwarded unchanged to the push implementation; ``None``
            is replaced by an empty dict.

    Returns:
        The list of futures returned by ``self._run_thread``, one per party.
    """
    push_options = {} if options is None else options

    submitted = []
    for party in parties:
        # The latch is raised before the work is handed to a thread.
        self.ctx.pushing_latch.count_up()

        target_role = party[0]
        target_party_id = str(party[1])
        is_rollpair = isinstance(obj, RollPair)

        header = ErRollSiteHeader(
            roll_site_session_id=self.roll_site_session_id,
            name=self.name,
            tag=self.tag,
            src_role=self.local_role,
            src_party_id=self.party_id,
            dst_role=target_role,
            dst_party_id=target_party_id,
            data_type='rollpair' if is_rollpair else 'object')

        if is_rollpair:
            worker = self._impl_instance._push_rollpair
        else:
            worker = self._impl_instance._push_bytes
        submitted.append(self._run_thread(worker, obj, header, push_options))

    return submitted
def count_batch(self, rs_header: ErRollSiteHeader, batch_pairs):
    """Record the pair count of one received batch.

    The first header seen is remembered (together with its rs_key) and the
    header-arrival event is set. The per-batch counter is overwritten, the
    per-stream counter is incremented, and the stream is mapped to this
    batch's sequence id.

    Args:
        rs_header: header of the batch; ``_batch_seq`` and ``_stream_seq``
            are read from it.
        batch_pairs: number of pairs contained in the batch.
    """
    L.trace(f'count batch. rs_key={rs_header.get_rs_key()}, rs_header={rs_header}, batch_pairs={batch_pairs}')

    batch_seq, stream_seq = rs_header._batch_seq, rs_header._stream_seq

    first_header = self._rs_header is None
    if first_header:
        self._rs_header = rs_header
        self._rs_key = rs_header.get_rs_key()
        L.debug(f"header arrived. rs_key={rs_header.get_rs_key()}, rs_header={rs_header}")
        self._header_arrive_event.set()

    self._batch_seq_to_pair_counter[batch_seq] = batch_pairs
    self._stream_seq_to_pair_counter[stream_seq] += batch_pairs
    self._stream_seq_to_batch_seq[stream_seq] = batch_seq
def pull(self, parties: list = None):
    """Submit one receive task per source party.

    Records the wall-clock and ``perf_counter`` start times on ``self``,
    then, for each ``(role, party_id)`` in ``parties``, builds a
    ``get_status`` packet and submits ``RollSite._thread_receive`` to the
    class-level executor.

    Returns:
        The list of submitted futures, one per party.
    """
    self._pull_start_wall_time = time.time()
    self._pull_start_cpu_time = time.perf_counter()

    pending = []
    for src_role, raw_party_id in parties:
        src_party_id = str(raw_party_id)
        roll_site_header = ErRollSiteHeader(
            roll_site_session_id=self.roll_site_session_id,
            name=self.name,
            tag=self.tag,
            src_role=src_role,
            src_party_id=src_party_id,
            dst_role=self.local_role,
            dst_party_id=self.party_id)
        _tagged_key = create_store_name(roll_site_header)

        # The tagged key is the task id; the stringified header travels as
        # the model name.
        task_info = proxy_pb2.Task(
            taskId=_tagged_key,
            model=proxy_pb2.Model(name=_stringify(roll_site_header)))
        topic_src = proxy_pb2.Topic(name="get_status",
                                    partyId=src_party_id,
                                    role=src_role,
                                    callback=None)
        topic_dst = proxy_pb2.Topic(name="get_status",
                                    partyId=self.party_id,
                                    role=self.local_role,
                                    callback=None)
        get_status_command = proxy_pb2.Command(name="get_status")
        # NOTE(review): this Conf is constructed but never attached to the
        # metadata/packet below; kept to preserve the original behavior.
        conf_test = proxy_pb2.Conf(overallTimeout=1000,
                                   completionWaitTimeout=1000,
                                   packetIntervalTimeout=1000,
                                   maxRetries=10)
        packet = proxy_pb2.Packet(
            header=proxy_pb2.Metadata(task=task_info,
                                      src=topic_src,
                                      dst=topic_dst,
                                      command=get_status_command,
                                      operator="getStatus",
                                      seq=0,
                                      ack=0))

        namespace = self.roll_site_session_id
        L.info(
            f"pulling prepared tagged_key: {_tagged_key}, packet:{to_one_line_string(packet)}"
        )
        pending.append(
            RollSite.receive_exeutor_pool.submit(RollSite._thread_receive,
                                                 self, packet, namespace,
                                                 roll_site_header))
    return pending
def run(self):
    """Drain the transfer broker for ``self.tag`` into ``self.partition``.

    While holding the per-tag partition lock, batches are fetched from the
    broker until ``broker.is_closable()`` is true. The pairs of every batch
    are written through a write batch on the partition adapter and counted
    in the ``_BatchStreamStatus`` of the tag. The broker for the tag is
    removed on every exit path.

    Raises:
        TimeoutError: when consecutive empty polls add up to more than
            ``EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC`` seconds.
        Exception: any other error is logged with the last seen header and
            re-raised.
    """
    # batch stream must be executed serially, and reinit.
    # TODO:0: remove lock to bss
    rs_header = None
    with PutBatchTask._partition_lock[self.tag]:  # tag includes partition info in tag generation
        L.trace(f"do_store start for tag={self.tag}, partition_id={self.partition._id}")
        bss = _BatchStreamStatus.get_or_create(self.tag)
        try:
            broker = TransferService.get_or_create_broker(self.tag, write_signals=1)

            # iter_wait accumulates the time spent on empty polls since the
            # last successfully received batch.
            iter_wait = 0
            iter_timeout = int(CoreConfKeys.EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC.get())

            batch = None
            batch_get_interval = 0.1
            with create_adapter(self.partition) as db, db.new_batch() as wb:
                # for batch in broker:
                while not broker.is_closable():
                    try:
                        batch = broker.get(block=True, timeout=batch_get_interval)
                    except queue.Empty as e:
                        # Nothing arrived within this poll interval.
                        iter_wait += batch_get_interval
                        if iter_wait > iter_timeout:
                            raise TimeoutError(f'timeout in PutBatchTask.run. tag={self.tag}, iter_timeout={iter_timeout}, iter_wait={iter_wait}')
                        else:
                            continue
                    except BrokerClosed as e:
                        # Re-evaluate is_closable() in the loop condition.
                        continue

                    # A batch arrived: reset the idle timer.
                    iter_wait = 0
                    rs_header = ErRollSiteHeader.from_proto_string(batch.header.ext)
                    batch_pairs = 0
                    if batch.data:
                        bin_data = ArrayByteBuffer(batch.data)
                        reader = PairBinReader(pair_buffer=bin_data, data=batch.data)
                        for k_bytes, v_bytes in reader.read_all():
                            wb.put(k_bytes, v_bytes)
                            batch_pairs += 1
                    # Batches without data are still counted (with 0 pairs).
                    bss.count_batch(rs_header, batch_pairs)
                    # TODO:0
                    bss._data_type = rs_header._data_type
                    if rs_header._stage == FINISH_STATUS:
                        bss.set_done(rs_header)  # starting from 0
            bss.check_finish()
            # TransferService.remove_broker(tag) will be called when the get_status phase finishes or an exception is raised
        except Exception as e:
            # rs_header is still None if the failure happened before the
            # first batch header was parsed.
            L.exception(f'_run_put_batch error, tag={self.tag}, '
                        f'rs_key={rs_header.get_rs_key() if rs_header is not None else None}, rs_header={rs_header}')
            raise e
        finally:
            TransferService.remove_broker(self.tag)
def __init__(self, options: dict = None):
    """Initialise the adapter and its retry/timeout settings.

    Args:
        options: must contain ``'er_partition'``. When a non-empty
            ``'roll_site_header'`` is present, ``'proxy_endpoint'`` and
            ``'obj_type'`` are also required, the partition id is stamped
            into the header options, and the adapter is marked writable.
            The same dict is used to resolve the ``RollSiteConfKeys``
            retry and timeout values.
    """
    options = {} if options is None else options
    super().__init__(options)

    def _int_conf(conf_key):
        # Resolve a configuration key against ``options`` as an int.
        return int(conf_key.get_with(options))

    partition = options['er_partition']
    self.partition = partition
    # The partition's own store locator is reused as-is; no dedicated
    # ROLLPAIR_ROLLSITE locator is built here.
    self.store_locator = partition._store_locator
    self.partition_id = partition._id
    self.namespace = self.store_locator._namespace

    self.roll_site_header_string = options.get('roll_site_header')
    self.is_writable = False
    if self.roll_site_header_string:
        self.roll_site_header = ErRollSiteHeader.from_proto_string(
            self.roll_site_header_string.encode(stringify_charset))
        self.roll_site_header._options['partition_id'] = self.partition_id
        self.proxy_endpoint = ErEndpoint.from_proto_string(
            options['proxy_endpoint'].encode(stringify_charset))
        self.obj_type = options['obj_type']
        self.is_writable = True

        L.trace(
            f"writable RollSiteAdapter: {self.namespace}, partition_id={self.partition_id}. proxy_endpoint={self.proxy_endpoint}, partition={self.partition}"
        )

    self.unarycall_max_retry_cnt = _int_conf(
        RollSiteConfKeys.EGGROLL_ROLLSITE_UNARYCALL_CLIENT_MAX_RETRY)
    self.push_max_retry_cnt = _int_conf(
        RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_CLIENT_MAX_RETRY)
    self.push_overall_timeout = _int_conf(
        RollSiteConfKeys.EGGROLL_ROLLSITE_OVERALL_TIMEOUT_SEC)
    self.push_completion_wait_timeout = _int_conf(
        RollSiteConfKeys.EGGROLL_ROLLSITE_COMPLETION_WAIT_TIMEOUT_SEC)
    self.push_packet_interval_timeout = _int_conf(
        RollSiteConfKeys.EGGROLL_ROLLSITE_PACKET_INTERVAL_TIMEOUT_SEC)
def pull(self, parties: list = None):
    """Submit one ``_pull_one`` task per source party.

    Args:
        parties: iterable of ``(role, party_id)`` pairs; each party id is
            stringified before being placed in the header.

    Returns:
        The list of futures from the receive executor pool, in the order
        of ``parties``.
    """
    # Headers are produced lazily so each one is built right before its
    # task is submitted, exactly as in a plain loop.
    headers = (
        ErRollSiteHeader(
            roll_site_session_id=self.roll_site_session_id,
            name=self.name,
            tag=self.tag,
            src_role=role,
            src_party_id=str(party_id),
            dst_role=self.local_role,
            dst_party_id=self.party_id)
        for role, party_id in parties)

    return [
        self._receive_executor_pool.submit(self._pull_one, header)
        for header in headers
    ]
def _pull_one(self, rs_header: ErRollSiteHeader, options: dict = None):
    """Wait for one transfer described by ``rs_header`` and return its data.

    Steps, as implemented below:
      1. Poll for the sender's header until it arrives or the accumulated
         wait reaches ``self.pull_header_timeout``.
      2. Poll the per-partition status up to ``self.pull_max_retry`` times
         until every partition reports finished.
      3. Load the received store; for data type ``"object"`` rebuild the
         Python object from its chunks and destroy the store, otherwise
         return the store itself.

    Args:
        rs_header: header identifying the transfer; its rs_key is used as
            the store name and as part of the transfer tag prefix.
        options: base options copied for each store load; ``None`` becomes
            an empty dict.

    Returns:
        The unpickled object for ``"object"`` transfers, otherwise the
        loaded store.

    Raises:
        IOError: header not received in time, no progress between two
            consecutive status polls, or retry budget exhausted.
        ValueError: (from the status helper) the store could not be loaded.
    """
    if options is None:
        options = {}

    start_time = time.time()
    rs_key = rp_name = rs_header.get_rs_key()
    rp_namespace = self.roll_site_session_id
    transfer_tag_prefix = "putBatch-" + rs_header.get_rs_key() + "-"
    last_total_batches = None
    last_cur_pairs = -1
    pull_attempts = 0
    data_type = None
    L.debug(f'pulling rs_key={rs_key}')
    try:
        # make sure rollpair already created
        # Values are copied into locals so the nested functions below do
        # not reference ``self``.
        pull_header_interval = self.pull_header_interval
        pull_header_timeout = self.pull_header_timeout  # skips pickling self
        pull_interval = self.pull_interval  # skips pickling self

        def get_partition_status(task):
            # Status of the put-batch task for the partition in task._inputs[0].
            put_batch_task = PutBatchTask(
                transfer_tag_prefix + str(task._inputs[0]._id), None)
            return put_batch_task.get_status(pull_interval)

        def get_status(roll_site):
            # Aggregate per-partition status into
            # (status by partition, all finished?, total batches, total pairs).
            pull_status = {}
            total_pairs = 0
            total_batches = 0
            all_finished = True

            final_options = options.copy()
            final_options['create_if_missing'] = False
            store = roll_site.ctx.rp_ctx.load(name=rp_name,
                                              namespace=rp_namespace,
                                              options=final_options)
            if store is None:
                raise ValueError(
                    f'illegal state for rp_name={rp_name}, rp_namespace={rp_namespace}'
                )

            all_status = store.with_stores(
                get_partition_status,
                options={"__op": "get_partition_status"})

            for part_id, part_status in all_status:
                if not part_status.is_finished:
                    all_finished = False
                pull_status[part_id] = part_status
                total_batches += part_status.total_batches
                total_pairs += part_status.total_pairs

            return pull_status, all_finished, total_batches, total_pairs

        def clear_status(task):
            # Clear the put-batch status of the partition in task._inputs[0].
            return PutBatchTask(transfer_tag_prefix +
                                str(task._inputs[0]._id)).clear_status()

        # Phase 1: wait for the header. wait_time grows by
        # pull_header_interval per attempt, not by measured elapsed time.
        wait_time = 0
        header_response = None
        while wait_time < pull_header_timeout and \
                (header_response is None or not isinstance(header_response[0][1], ErRollSiteHeader)):
            final_options = options.copy()
            final_options['create_if_missing'] = True
            final_options['total_partitions'] = 1
            # The header is always requested with the "0" tag suffix.
            header_response = self.ctx.rp_ctx.load(name=STATUS_TABLE_NAME,
                                                   namespace=rp_namespace,
                                                   options=final_options) \
                .with_stores(lambda x: PutBatchTask(transfer_tag_prefix + "0").get_header(pull_header_interval),
                             options={"__op": "pull_header"})
            wait_time += pull_header_interval

            # pull_status, all_finished, total_batches, total_pairs = stat_all_status(self)

            L.debug(
                f"roll site get header_response: rs_key={rs_key}, rs_header={rs_header}, wait_time={wait_time}"
            )

        if header_response is None or not isinstance(
                header_response[0][1], ErRollSiteHeader):
            raise IOError(
                f"roll site pull header failed: rs_key={rs_key}, rs_header={rs_header}, timeout={self.pull_header_timeout}"
            )
        else:
            header: ErRollSiteHeader = header_response[0][1]
            # TODO:0: push bytes has only one partition, that means it has finished, need not get_status
            data_type = header._data_type
            L.debug(
                f"roll site pull header successful: rs_key={rs_key}, rs_header={header}"
            )

        # Phase 2: poll until every partition has finished.
        pull_status = {}
        for cur_retry in range(self.pull_max_retry):
            pull_attempts = cur_retry
            pull_status, all_finished, total_batches, cur_pairs = get_status(
                self)

            if not all_finished:
                L.debug(
                    f'getting status NOT finished for rs_key={rs_key}, '
                    f'rs_header={rs_header}, '
                    f'cur_status={pull_status}, '
                    f'attempts={pull_attempts}, '
                    f'cur_pairs={cur_pairs}, '
                    f'last_cur_pairs={last_cur_pairs}, '
                    f'total_batches={total_batches}, '
                    f'last_total_batches={last_total_batches}, '
                    f'elapsed={time.time() - start_time}')
                # Same non-zero pair count on two consecutive polls is
                # treated as a stalled transfer.
                if last_cur_pairs == cur_pairs and cur_pairs > 0:
                    raise IOError(
                        f"roll site pull waiting failed because there is no updated progress: rs_key={rs_key}, "
                        f"rs_header={rs_header}, pull_status={pull_status}, last_cur_pairs={last_cur_pairs}, cur_pairs={cur_pairs}"
                    )
            else:
                L.debug(
                    f"getting status DO finished for rs_key={rs_key}, rs_header={rs_header}, pull_status={pull_status}, cur_pairs={cur_pairs}, total_batches={total_batches}"
                )
                rp = self.ctx.rp_ctx.load(name=rp_name,
                                          namespace=rp_namespace)

                # Status clearing runs on the executor while the result is
                # assembled; it is awaited before returning.
                clear_future = self._receive_executor_pool.submit(
                    rp.with_stores,
                    clear_status,
                    options={"__op": "clear_status"})
                if data_type == "object":
                    # Chunks are ordered by their big-endian integer key,
                    # concatenated and unpickled.
                    # NOTE(review): pickle.loads runs on bytes received
                    # through the transfer - confirm the peers are trusted.
                    result = pickle.loads(b''.join(
                        map(
                            lambda t: t[1],
                            sorted(rp.get_all(),
                                   key=lambda x: int.from_bytes(
                                       x[0], "big")))))
                    rp.destroy()
                    L.debug(
                        f"pulled object: rs_key={rs_key}, rs_header={rs_header}, is_none={result is None}, "
                        f"elapsed={time.time() - start_time}")
                else:
                    result = rp

                    # rp.count() is only evaluated when DEBUG is enabled.
                    if L.isEnabledFor(logging.DEBUG):
                        L.debug(
                            f"pulled roll_pair: rs_key={rs_key}, rs_header={rs_header}, rp.count={rp.count()}, "
                            f"elapsed={time.time() - start_time}")

                clear_future.result()
                return result
            last_total_batches = total_batches
            last_cur_pairs = cur_pairs
        raise IOError(
            f"roll site pull failed. max try exceeded: {self.pull_max_retry}, rs_key={rs_key}, "
            f"rs_header={rs_header}, pull_status={pull_status}")
    except Exception as e:
        L.exception(
            f"fatal error: when pulling rs_key={rs_key}, rs_header={rs_header}, attempts={pull_attempts}"
        )
        raise e
def _push_rollpair(self, rp: RollPair, rs_header: ErRollSiteHeader, options: dict = None):
    """Push every partition of ``rp`` to the proxy endpoint as batch streams.

    For each partition, ``_push_partition`` (run via ``rp.with_stores``)
    reads the pairs, packs them into batch streams and sends every stream
    with ``stub.push``, retrying on failure.

    Fixes relative to the previous version:
      * ``wrapee_cls`` is copied into the header options when it is itself
        provided; it used to be guarded by ``serdes is not None``.
      * the pushing latch is counted down in a ``finally`` block, so a
        failed push no longer leaves the latch permanently raised.

    Args:
        rp: the roll pair to push.
        rs_header: header for the transfer; it is mutated here
            (``_total_partitions``, ``_options``) and per partition
            (``_partition_id``).
        options: optional ``'serdes'`` and ``'wrapee_cls'`` entries are
            copied into ``rs_header._options``; ``None`` becomes ``{}``.

    Raises:
        Exception: the last error of a stream whose retries are exhausted
            (re-raised from within ``rp.with_stores``).
    """
    if options is None:
        options = {}

    try:
        rs_key = rs_header.get_rs_key()
        if L.isEnabledFor(logging.DEBUG):
            L.debug(
                f"pushing rollpair: rs_key={rs_key}, rs_header={rs_header}, rp.count={rp.count()}"
            )
        start_time = time.time()

        rs_header._total_partitions = rp.get_partitions()
        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls

        # Plain locals, so the nested function below does not reference
        # ``self``.
        batches_per_stream = self.push_batches_per_stream
        body_bytes = self.batch_body_bytes
        endpoint = self.ctx.proxy_endpoint
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        def _push_partition(ertask):
            rs_header._partition_id = ertask._inputs[0]._id
            L.trace(
                f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

            from eggroll.core.grpc.factory import GrpcChannelFactory
            from eggroll.core.proto import proxy_pb2_grpc
            grpc_channel_factory = GrpcChannelFactory()

            with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
                # NOTICE AGAIN: all modifications to rs_header are limited in bs_helper.
                # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
                # Remember that python's object references are passed by value,
                # meaning the 'pointer' is copied, while the contents are modifiable
                bs_helper = _BatchStreamHelper(rs_header)
                bin_batch_streams = bs_helper._generate_batch_streams(
                    pair_iter=rb,
                    batches_per_stream=batches_per_stream,
                    body_bytes=body_bytes)

                channel = grpc_channel_factory.create_channel(endpoint)
                stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                for batch_stream in bin_batch_streams:
                    # Materialised so the same stream can be resent on retry.
                    batch_stream_data = list(batch_stream)
                    cur_retry = 0
                    exception = None
                    while cur_retry < max_retry_cnt:
                        L.trace(
                            f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                        )
                        try:
                            stub.push(bs_helper.generate_packet(
                                batch_stream_data, cur_retry),
                                      timeout=per_stream_timeout)
                            exception = None
                            break
                        except Exception as e:
                            # Short, growing back-off first; the last
                            # ``long_retry_cnt`` attempts wait ~300 s.
                            if cur_retry < max_retry_cnt - long_retry_cnt:
                                retry_interval = round(
                                    min(2 * cur_retry, 20) +
                                    random.random() * 10, 3)
                            else:
                                retry_interval = round(
                                    300 + random.random() * 10, 3)
                            L.warn(
                                f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                                exc_info=e)
                            time.sleep(retry_interval)
                            # An unavailable channel is rebuilt before the
                            # next attempt.
                            if isinstance(e, RpcError) and e.code(
                            ).name == 'UNAVAILABLE':
                                channel = grpc_channel_factory.create_channel(
                                    endpoint, refresh=True)
                                stub = proxy_pb2_grpc.DataTransferServiceStub(
                                    channel)
                            exception = e
                        finally:
                            # Also runs on the success path before ``break``.
                            cur_retry += 1
                    if exception is not None:
                        L.exception(
                            f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                            exc_info=exception)
                        raise exception
                    L.trace(
                        f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
                    )

            L.trace(
                f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

        rp.with_stores(_push_partition, options={"__op": "push_partition"})
        if L.isEnabledFor(logging.DEBUG):
            L.debug(
                f"pushed rollpair: rs_key={rs_key}, rs_header={rs_header}, count={rp.count()}, elapsed={time.time() - start_time}"
            )
    finally:
        # Release the latch on success and on failure alike.
        self.ctx.pushing_latch.count_down()
def _push_bytes(self, obj, rs_header: ErRollSiteHeader, options: dict = None):
    """Pickle ``obj`` and push it to the proxy endpoint as batch streams.

    The pickled bytes are cut into ``self.batch_body_bytes``-sized chunks,
    each keyed by a 4-byte big-endian counter, and sent as a single
    partition (``_partition_id`` 0 of 1).

    Fixes relative to the previous version:
      * ``wrapee_cls`` is copied into the header options when it is itself
        provided; it used to be guarded by ``serdes is not None``.
      * chunking stops at the end of the data; ``<=`` produced an extra
        empty chunk when the pickled length was an exact multiple of the
        chunk size.
      * the long-retry threshold uses ``<`` so it matches
        ``_push_rollpair``.
      * the pushing latch is counted down in a ``finally`` block, so a
        failed push no longer leaves the latch permanently raised.

    Args:
        obj: any picklable object.
        rs_header: header for the transfer; mutated here.
        options: optional ``'serdes'`` and ``'wrapee_cls'`` entries are
            copied into ``rs_header._options``; ``None`` becomes ``{}``.

    Raises:
        Exception: the last error of a stream whose retries are exhausted.
    """
    if options is None:
        options = {}

    try:
        start_time = time.time()

        rs_key = rs_header.get_rs_key()
        int_size = 4

        if L.isEnabledFor(logging.DEBUG):
            L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

        def _generate_obj_bytes(py_obj, body_bytes):
            # Yield (4-byte big-endian chunk index, chunk bytes) pairs.
            key_id = 0
            obj_bytes = pickle.dumps(py_obj)
            obj_bytes_len = len(obj_bytes)
            cur_pos = 0

            while cur_pos < obj_bytes_len:
                yield key_id.to_bytes(
                    int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
                key_id += 1
                cur_pos += body_bytes

        rs_header._partition_id = 0
        rs_header._total_partitions = 1

        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls

        # NOTICE: all modifications to rs_header are limited in bs_helper.
        # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
        # Remember that python's object references are passed by value,
        # meaning the 'pointer' is copied, while the contents are modifiable
        bs_helper = _BatchStreamHelper(rs_header=rs_header)
        bin_batch_streams = bs_helper._generate_batch_streams(
            pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
            batches_per_stream=self.push_batches_per_stream,
            body_bytes=self.batch_body_bytes)

        grpc_channel_factory = GrpcChannelFactory()
        channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        # if use stub.push.future here, retry mechanism is a problem to solve
        for batch_stream in bin_batch_streams:
            cur_retry = 0

            # Materialised so the same stream can be resent on retry.
            batch_stream_data = list(batch_stream)
            exception = None
            while cur_retry < max_retry_cnt:
                L.trace(
                    f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}'
                )
                try:
                    stub.push(bs_helper.generate_packet(
                        batch_stream_data, cur_retry),
                              timeout=per_stream_timeout)
                    exception = None
                    break
                except Exception as e:
                    # Short, growing back-off first; the last
                    # ``long_retry_cnt`` attempts wait ~300 s.
                    if cur_retry < max_retry_cnt - long_retry_cnt:
                        retry_interval = round(
                            min(2 * cur_retry, 20) + random.random() * 10, 3)
                    else:
                        retry_interval = round(300 + random.random() * 10, 3)
                    L.warn(
                        f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                        exc_info=e)
                    time.sleep(retry_interval)
                    # An unavailable channel is rebuilt before the next
                    # attempt.
                    if isinstance(e, RpcError) and e.code().name == 'UNAVAILABLE':
                        channel = grpc_channel_factory.create_channel(
                            self.ctx.proxy_endpoint, refresh=True)
                        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                    exception = e
                finally:
                    # Also runs on the success path before ``break``.
                    cur_retry += 1
            if exception is not None:
                L.exception(
                    f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                    exc_info=exception)
                raise exception
            L.trace(
                f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}'
            )

        L.debug(
            f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}"
        )
    finally:
        # Release the latch on success and on failure alike.
        self.ctx.pushing_latch.count_down()
def push(self, obj, parties: list = None):
    """Submit one asynchronous push of ``obj`` per destination party.

    For every ``(role, party_id)`` entry a header and tagged key are built,
    the source roll pair is prepared (``obj`` itself when it is a
    ``RollPair``, otherwise a store holding ``obj`` under the tagged key),
    and ``map_values`` is submitted to ``RollSite.receive_exeutor_pool``.

    Fix relative to the previous version: ``map_values`` runs on a pool
    thread after the loop may already have moved on, but it read ``rp``
    and ``_options`` through the enclosing scope. Those names are rebound
    on every iteration, so a task could operate on another party's store.
    They are now bound as default arguments when the function is defined.

    Args:
        obj: a ``RollPair`` or any other object.
        parties: iterable of ``(role, party_id)`` entries.

    Returns:
        The list of submitted futures; each resolves to its tagged key.
    """
    L.info(
        f"pushing: self:{self.__dict__}, obj_type:{type(obj)}, parties:{parties}"
    )
    self._push_start_wall_time = time.time()
    self._push_start_cpu_time = time.perf_counter()
    futures = []
    for role_party_id in parties:
        self.ctx.pushing_task_count += 1
        _role = role_party_id[0]
        _party_id = str(role_party_id[1])

        _options = {}
        obj_type = 'rollpair' if isinstance(obj, RollPair) else 'object'
        roll_site_header = ErRollSiteHeader(
            roll_site_session_id=self.roll_site_session_id,
            name=self.name,
            tag=self.tag,
            src_role=self.local_role,
            src_party_id=self.party_id,
            dst_role=_role,
            dst_party_id=_party_id,
            data_type=obj_type,
            options=_options)
        _tagged_key = create_store_name(roll_site_header)
        L.debug(f"pushing start party:{type(obj)}, {_tagged_key}")
        namespace = self.roll_site_session_id

        if isinstance(obj, RollPair):
            rp = obj
        else:
            # Plain objects are stored under the tagged key first.
            rp = self.ctx.rp_ctx.load(namespace, _tagged_key)
            rp.put(_tagged_key, obj)
        rp.disable_gc()
        L.info(f"pushing prepared: {type(obj)}, tag_key:{_tagged_key}")

        # The trailing defaults freeze this iteration's values; callers
        # pass only the first three arguments.
        def map_values(_tagged_key, is_standalone, roll_site_header,
                       rp=rp, namespace=namespace, obj_type=obj_type,
                       _options=_options):
            if is_standalone:
                dst_name = _tagged_key
                store_type = rp.get_store_type()
            else:
                dst_name = DELIM.join([
                    _tagged_key, self.dst_host,
                    str(self.dst_port), obj_type
                ])
                store_type = StoreTypes.ROLLPAIR_ROLLSITE
            if is_standalone:
                # Standalone: only record where the data lives in the
                # status table.
                status_rp = self.ctx.rp_ctx.load(
                    namespace,
                    STATUS_TABLE_NAME + DELIM + self.roll_site_session_id,
                    options=_options)
                status_rp.disable_gc()
                if isinstance(obj, RollPair):
                    status_rp.put(_tagged_key,
                                  (obj_type.encode("utf-8"), rp.get_name(),
                                   rp.get_namespace()))
                else:
                    status_rp.put(
                        _tagged_key,
                        (obj_type.encode("utf-8"), dst_name, namespace))
            else:
                # Otherwise: identity map_values into a ROLLPAIR_ROLLSITE
                # output store carrying the header and endpoint options.
                store = rp.get_store()
                store_locator = store._store_locator
                new_store_locator = ErStoreLocator(
                    store_type=store_type,
                    namespace=namespace,
                    name=dst_name,
                    total_partitions=store_locator._total_partitions,
                    partitioner=store_locator._partitioner,
                    serdes=store_locator._serdes)

                # TODO:0: move options from job to store when database modification finished

                options = {
                    "roll_site_header": roll_site_header,
                    "proxy_endpoint": self.ctx.proxy_endpoint,
                    "obj_type": obj_type
                }

                if isinstance(obj, RollPair):
                    roll_site_header._options[
                        'total_partitions'] = obj.get_store(
                        )._store_locator._total_partitions
                    L.info(
                        f"RollSite.push: pushing {roll_site_header}, type: RollPair, count: {obj.count()}"
                    )
                else:
                    L.info(
                        f"RollSite.push: pushing {roll_site_header}, type: object"
                    )
                rp.map_values(
                    lambda v: v,
                    output=ErStore(store_locator=new_store_locator),
                    options=options)

            L.info(
                f"RollSite.push: push {roll_site_header} done. type:{type(obj)}"
            )
            return _tagged_key

        future = RollSite.receive_exeutor_pool.submit(
            map_values, _tagged_key, self._is_standalone, roll_site_header)
        # The temporary store of a pushed object is handed to the callback
        # only in the non-standalone case.
        if not self._is_standalone and (obj_type == 'object'
                                        or obj_type == b'object'):
            tmp_rp = rp
        else:
            tmp_rp = None

        future.add_done_callback(
            functools.partial(self._push_callback, tmp_rp=tmp_rp))
        futures.append(future)

    return futures
def push(self, obj, parties: list = None):
    """Submit one asynchronous push of ``obj`` per destination party.

    For every ``(role, party_id)`` entry a header and tagged key are built,
    the source roll pair is prepared (``obj`` itself when it is a
    ``RollPair``, otherwise a store holding ``obj`` under the tagged key),
    and ``map_values`` is submitted to ``self.process_pool``. The pool is
    shut down without waiting once all parties are submitted.

    Fix relative to the previous version: ``map_values`` runs
    asynchronously but read ``rp``, ``roll_site_header`` and ``_options``
    through the enclosing scope. Those names are rebound on every
    iteration, so a task could push one party's data with another party's
    header or store. They are now bound as default arguments when the
    function is defined.

    Args:
        obj: a ``RollPair`` or any other object.
        parties: iterable of ``(role, party_id)`` entries.

    Returns:
        The list of submitted futures; each resolves to its tagged key.
    """
    L.info(f"pushing: self:{self.__dict__}, obj_type:{type(obj)}, parties:{parties}")
    self._push_start_wall_time = time.time()
    self._push_start_cpu_time = time.perf_counter()
    # Incremented once per push() call, regardless of the party count.
    self.ctx.pushing_task_count += 1
    futures = []
    for role_party_id in parties:
        _role = role_party_id[0]
        _party_id = str(role_party_id[1])

        _options = {}
        obj_type = 'rollpair' if isinstance(obj, RollPair) else 'object'
        roll_site_header = ErRollSiteHeader(
            roll_site_session_id=self.roll_site_session_id,
            name=self.name,
            tag=self.tag,
            src_role=self.local_role,
            src_party_id=self.party_id,
            dst_role=_role,
            dst_party_id=_party_id,
            data_type=obj_type,
            options=_options)
        _tagged_key = create_store_name(roll_site_header)
        L.debug(f"pushing start party:{type(obj)}, {_tagged_key}")
        namespace = self.roll_site_session_id

        if isinstance(obj, RollPair):
            rp = obj
        else:
            # Plain objects are stored under the tagged key first.
            rp = self.ctx.rp_ctx.load(namespace, _tagged_key)
            rp.put(_tagged_key, obj)
        rp.disable_gc()
        L.info(f"pushing prepared: {type(obj)}, tag_key:{_tagged_key}")

        # The trailing defaults freeze this iteration's values; the pool
        # passes only the tagged key.
        def map_values(_tagged_key, rp=rp, roll_site_header=roll_site_header,
                       namespace=namespace, obj_type=obj_type,
                       _options=_options):
            is_standalone = self.ctx.rp_ctx.get_session().get_option(
                SessionConfKeys.CONFKEY_SESSION_DEPLOY_MODE) == DeployModes.STANDALONE

            if is_standalone:
                dst_name = _tagged_key
                store_type = rp.get_store_type()
            else:
                dst_name = DELIM.join([_tagged_key,
                                       self.dst_host,
                                       str(self.dst_port),
                                       obj_type])
                store_type = StoreTypes.ROLLPAIR_ROLLSITE
            if is_standalone:
                # Standalone: only record where the data lives in the
                # status table.
                status_rp = self.ctx.rp_ctx.load(namespace,
                                                 STATUS_TABLE_NAME + DELIM + self.roll_site_session_id,
                                                 options=_options)
                status_rp.disable_gc()
                if isinstance(obj, RollPair):
                    status_rp.put(_tagged_key, (obj_type.encode("utf-8"), rp.get_name(), rp.get_namespace()))
                else:
                    status_rp.put(_tagged_key, (obj_type.encode("utf-8"), dst_name, namespace))
            else:
                # Otherwise: identity map_values into a ROLLPAIR_ROLLSITE
                # output store carrying the header and endpoint options.
                store = rp.get_store()
                store_locator = store._store_locator
                new_store_locator = ErStoreLocator(store_type=store_type,
                                                   namespace=namespace,
                                                   name=dst_name,
                                                   total_partitions=store_locator._total_partitions,
                                                   partitioner=store_locator._partitioner,
                                                   serdes=store_locator._serdes)

                # TODO:0: move options from job to store when database modification finished

                options = {"roll_site_header": roll_site_header,
                           "proxy_endpoint": self.ctx.proxy_endpoint,
                           "obj_type": obj_type}

                if isinstance(obj, RollPair):
                    roll_site_header._options['total_partitions'] = obj.get_store()._store_locator._total_partitions
                    L.debug(f"pushing map_values: {dst_name}, count: {obj.count()}, tag_key:{_tagged_key}")
                rp.map_values(lambda v: v,
                              output=ErStore(store_locator=new_store_locator),
                              options=options)

            L.info(f"pushing map_values done:{type(obj)}, tag_key:{_tagged_key}")
            return _tagged_key

        future = self.process_pool.submit(map_values, _tagged_key)
        future.add_done_callback(self._push_callback)
        futures.append(future)

    # No further submissions are accepted; already-submitted work continues.
    self.process_pool.shutdown(wait=False)

    return futures