class TransferClient(object):
    """Client for the eggroll transfer service.

    Wraps a gRPC ``TransferServiceStub``: ``send`` streams batches to a remote
    endpoint, ``recv`` pulls a batch stream back, optionally draining it into a
    broker on a background thread.
    """

    def __init__(self):
        # One factory per client; channels are created lazily per endpoint.
        self.__grpc_channel_factory = GrpcChannelFactory()
        #self.__bin_packet_len = 32 << 20
        #self.__chunk_size = 100

    @_exception_logger
    def send(self, broker, endpoint: ErEndpoint, tag):
        """Stream the contents of `broker` to `endpoint` under `tag`.

        `broker` may be either a plain generator of raw payloads (wrapped into
        TransferBatch messages here) or a broker object handled by
        TransferService. Returns the gRPC call future; the caller is
        responsible for awaiting / checking it.
        """
        try:
            channel = self.__grpc_channel_factory.create_channel(endpoint)
            stub = transfer_pb2_grpc.TransferServiceStub(channel)
            import types
            if isinstance(broker, types.GeneratorType):
                # Raw generator: synthesize headers with sequential ids.
                requests = (transfer_pb2.TransferBatch(
                        header=transfer_pb2.TransferHeader(id=i, tag=tag),
                        data=d)
                    for i, d in enumerate(broker))
            else:
                requests = TransferService.transfer_batch_generator_from_broker(broker, tag)

            # Non-blocking: returns a future for the streaming call.
            future = stub.send.future(requests,
                                      metadata=[(TRANSFER_BROKER_NAME, tag)])
            return future
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.send')
            raise e

    @_exception_logger
    def recv(self, endpoint: ErEndpoint, tag, broker):
        """Receive a batch stream for `tag` from `endpoint`.

        If `broker` is None the raw response iterator is returned; otherwise a
        daemon-less thread drains the stream into `broker` and the broker is
        returned immediately (caller consumes it concurrently).
        """
        try:
            L.debug(f'TransferClient.recv for endpoint: {endpoint}, tag: {tag}')

            @_exception_logger
            def fill_broker(iterable: Iterable, broker):
                # Runs on the drain thread: copy every batch into the broker,
                # then mark the write side finished so readers can terminate.
                try:
                    iterator = iter(iterable)
                    for e in iterator:
                        broker.put(e)
                    broker.signal_write_finish()
                except Exception as e:
                    L.error(f'Fail to fill broker for tag: {tag}, endpoint: {endpoint}')
                    raise e

            channel = self.__grpc_channel_factory.create_channel(endpoint)
            stub = transfer_pb2_grpc.TransferServiceStub(channel)
            request = transfer_pb2.TransferBatch(
                header=transfer_pb2.TransferHeader(id=1, tag=tag))
            response_iter = stub.recv(
                request, metadata=[(TRANSFER_BROKER_NAME, tag)])

            if broker is None:
                return response_iter
            else:
                t = Thread(target=fill_broker, args=[response_iter, broker])
                t.start()
            return broker
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.recv')
            raise e
def __init__(self):
    """Initialize the client and lazily build the shared executor pool.

    Uses double-checked locking (check, acquire class-level lock, re-check)
    so the class-wide ``_executor_pool`` is created exactly once even when
    many clients are constructed concurrently.
    """
    self._channel_factory = GrpcChannelFactory()
    if CommandClient._executor_pool is None:
        with CommandClient._executor_pool_lock:
            if CommandClient._executor_pool is None:
                # Pool implementation and size both come from core config.
                _executor_pool_type = CoreConfKeys.EGGROLL_CORE_DEFAULT_EXECUTOR_POOL.get()
                _max_workers = int(
                    CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get())
                CommandClient._executor_pool = create_executor_pool(
                    canonical_name=_executor_pool_type,
                    max_workers=_max_workers,
                    thread_name_prefix="command_client")
class RollSiteContext:
    """Per-session roll_site context: resolves the proxy endpoint, creates the
    data-transfer stub (unless standalone), and tracks in-flight pushes via a
    countdown latch that is awaited on session exit."""

    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self, roll_site_session_id, rp_ctx: RollPairContext, options: dict = None):
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx

        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        self._options = options

        # proxy_endpoint may be given as "host:port" or as an ErEndpoint.
        endpoint = options["proxy_endpoint"]
        if isinstance(endpoint, str):
            splitted = endpoint.split(':')
            self.proxy_endpoint = ErEndpoint(host=splitted[0].strip(),
                                             port=int(splitted[1].strip()))
        elif isinstance(endpoint, ErEndpoint):
            self.proxy_endpoint = endpoint
        else:
            raise ValueError("endpoint only support str and ErEndpoint type")

        self.is_standalone = RollSiteConfKeys.EGGROLL_ROLLSITE_DEPLOY_MODE.get_with(
            options) == "standalone"
        if self.is_standalone:
            # No proxy in standalone deployments.
            self.stub = None
        else:
            channel = self.grpc_channel_factory.create_channel(self.proxy_endpoint)
            self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        self.pushing_latch = CountDownLatch(0)
        self.rp_ctx.get_session().add_exit_task(self._wait_push_complete)
        self._wait_push_exit_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_OVERALL_TIMEOUT_SEC.get_with(options))
        L.info(f"inited RollSiteContext: {self.__dict__}")

    def _wait_push_complete(self):
        """Session exit task: wait (bounded) for all outstanding pushes."""
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(f"running roll site exit func for er session={session_id},"
               f" roll site session id={self.roll_site_session_id}")
        residual_count = self.pushing_latch.await_latch(self._wait_push_exit_timeout)
        if residual_count != 0:
            L.error(f"exit session when not finish push: "
                    f"residual_count={residual_count}, timeout={self._wait_push_exit_timeout}")

    def load(self, name: str, tag: str, options: dict = None):
        """Create a RollSite handle, layering `options` over the context options."""
        if options is None:
            options = {}
        # BUG FIX: previously `final_options = self._options.copy().update(options)`.
        # dict.update() returns None, so RollSite was constructed with
        # options=None. Copy first, then merge (matches the sibling
        # RollSiteContext implementation in this codebase).
        final_options = self._options.copy()
        final_options.update(options)
        return RollSite(name, tag, self, options=final_options)
class RollSiteWriteBatch(PairWriteBatch):
    """Write batch that serializes key/value pairs into a fixed-size binary
    buffer and pushes filled buffers to the roll_site proxy over gRPC.

    Lifecycle: put(k, v)* -> [write()]* -> close() (flushes the residual
    buffer and sends a ``markEnd`` unary call).
    """

    grpc_channel_factory = GrpcChannelFactory()

    # TODO:0: check if secure channel needed
    def __init__(self, adapter: RollSiteAdapter, options: dict = None):
        if options is None:
            options = {}
        self.adapter = adapter
        self.roll_site_header: ErRollSiteHeader = adapter.roll_site_header
        self.namespace = adapter.namespace
        self.name = create_store_name(self.roll_site_header)
        self.tagged_key = ''
        self.obj_type = adapter.obj_type

        self.proxy_endpoint = adapter.proxy_endpoint
        channel = self.grpc_channel_factory.create_channel(self.proxy_endpoint)
        self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        static_er_conf = get_static_er_conf()
        # Send-buffer size: per-call options override static conf, which
        # overrides the config key's default value.
        self.__bin_packet_len = int(options.get(
            RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.key,
            static_er_conf.get(
                RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.key,
                RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.default_value)))
        self.total_written = 0

        # Reusable serialization buffer; writer appends pairs into it.
        self.ba = bytearray(self.__bin_packet_len)
        self.buffer = ArrayByteBuffer(self.ba)
        self.writer = PairBinWriter(pair_buffer=self.buffer)

        self.push_cnt = 0

        self.topic_src = proxy_pb2.Topic(name=self.name,
                                         partyId=self.roll_site_header._src_party_id,
                                         role=self.roll_site_header._src_role,
                                         callback=None)
        self.topic_dst = proxy_pb2.Topic(name=self.name,
                                         partyId=self.roll_site_header._dst_party_id,
                                         role=self.roll_site_header._dst_role,
                                         callback=None)

    def __repr__(self):
        return f'<ErRollSiteWriteBatch(' \
               f'adapter={self.adapter}, ' \
               f'roll_site_header={self.roll_site_header}' \
               f'namespace={self.namespace}, ' \
               f'name={self.name}, ' \
               f'obj_type={self.obj_type}, ' \
               f'proxy_endpoint={self.proxy_endpoint}) ' \
               f'at {hex(id(self))}>'

    def generate_message(self, obj, metadata):
        """Yield a single Packet wrapping `obj`; bumps metadata.seq in place."""
        data = proxy_pb2.Data(value=obj)
        metadata.seq += 1
        packet = proxy_pb2.Packet(header=metadata, body=data)
        yield packet

    # TODO:0: configurable
    def push(self, obj):
        """Push one binary payload to the proxy, retrying up to 100 times
        with linear backoff (capped at 30s). Raises GrpcCallError if all
        attempts fail."""
        L.debug(f'pushing for task: {self.name}, partition id: {self.adapter.partition_id}, push cnt: {self.get_push_count()}')
        task_info = proxy_pb2.Task(
            taskId=self.name,
            model=proxy_pb2.Model(name=self.adapter.roll_site_header_string,
                                  dataKey=self.namespace))

        command_test = proxy_pb2.Command()
        # TODO: conf test as config and use it
        # NOTE(review): conf_test is built but never attached to the request.
        conf_test = proxy_pb2.Conf(overallTimeout=200000,
                                   completionWaitTimeout=200000,
                                   packetIntervalTimeout=200000,
                                   maxRetries=10)

        metadata = proxy_pb2.Metadata(task=task_info,
                                      src=self.topic_src,
                                      dst=self.topic_dst,
                                      command=command_test,
                                      seq=0,
                                      ack=0)
        max_retry_cnt = 100
        exception = None
        for i in range(1, max_retry_cnt + 1):
            try:
                self.stub.push(self.generate_message(obj, metadata))
                exception = None
                self.increase_push_count()
                break
            except Exception as e:
                exception = e
                L.info(f'caught exception in pushing {self.name}, partition_id: {self.adapter.partition_id}: {e}. retrying. current retry count: {i}, max_retry_cnt: {max_retry_cnt}')
                time.sleep(min(0.1 * i, 30))
        if exception:
            raise GrpcCallError("error in push", self.proxy_endpoint, exception)

    def write(self):
        """Flush the current buffer contents and reset the buffer."""
        bin_data = bytes(self.ba[0:self.buffer.get_offset()])
        self.push(bin_data)
        self.buffer = ArrayByteBuffer(self.ba)

    def send_end(self):
        """Send the terminating ``markEnd`` unary call so the receiver knows
        the stream is complete."""
        L.info(f"send_end tagged_key:{self.tagged_key}")
        task_info = proxy_pb2.Task(
            taskId=self.name,
            model=proxy_pb2.Model(name=self.adapter.roll_site_header_string,
                                  dataKey=self.namespace))
        command_test = proxy_pb2.Command(name="set_status")
        conf_test = proxy_pb2.Conf(overallTimeout=20000,
                                   completionWaitTimeout=20000,
                                   packetIntervalTimeout=20000,
                                   maxRetries=10)
        metadata = proxy_pb2.Metadata(task=task_info,
                                      src=self.topic_src,
                                      dst=self.topic_dst,
                                      command=command_test,
                                      operator="markEnd",
                                      seq=self.get_push_count(),
                                      ack=0)
        packet = proxy_pb2.Packet(header=metadata)
        try:
            # TODO:0: retry and sleep for all grpc call in RollSite
            self.stub.unaryCall(packet)
        except Exception as e:
            raise GrpcCallError('send_end', self.proxy_endpoint, e)

    def close(self):
        """Flush any residual buffered data, then mark the stream ended."""
        bin_batch = bytes(self.ba[0:self.buffer.get_offset()])
        self.push(bin_batch)
        self.send_end()
        L.info(f'closing RollSiteWriteBatch for name: {self.name}, '
               f'total push count: {self.push_cnt}')

    def put(self, k, v):
        """Append one serialized pair to the buffer; on overflow, flush the
        full buffer and regrow it to fit at least this pair."""
        if self.obj_type == 'object':
            L.debug(f"set tagged_key: {k}")
            # Remember the deserialized key of the pushed object for logging.
            self.tagged_key = _serdes.deserialize(k)
        try:
            self.writer.write(k, v)
        except IndexError as e:
            # Buffer full: ship what we have, then allocate a buffer large
            # enough for this pair plus header slack.
            bin_batch = bytes(self.ba[0:self.buffer.get_offset()])
            self.push(bin_batch)
            # TODO:0: replace 1024 with constant
            self.ba = bytearray(max(self.__bin_packet_len, len(k) + len(v) + 1024))
            self.buffer = ArrayByteBuffer(self.ba)
            self.writer = PairBinWriter(pair_buffer=self.buffer)
            self.writer.write(k, v)
        except Exception as e:
            L.error(f"Unexpected error: {sys.exc_info()[0]}")
            raise e

    def increase_push_count(self):
        self.push_cnt += 1

    def get_push_count(self):
        return self.push_cnt
def __init__(self):
    """Set up a gRPC channel factory used to create channels to endpoints."""
    factory = GrpcChannelFactory()
    self._channel_factory = factory
class CommandClient(object):
    """Client for eggroll's command service.

    Serializes RpcMessages into ErCommandRequests, invokes the remote
    CommandService over gRPC, and deserializes the results back into the
    caller-specified output types. A class-level thread pool backs async
    fan-out calls.
    """

    # Shared executor for async_call; sized from core config at import time.
    executor = ThreadPoolExecutor(
        max_workers=int(CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get()),
        thread_name_prefix="command_client")

    def __init__(self):
        self._channel_factory = GrpcChannelFactory()

    def simple_sync_send(self, input: RpcMessage, output_type, endpoint: ErEndpoint,
                         command_uri: CommandURI, serdes_type=SerdesTypes.PROTOBUF):
        """Single-input convenience wrapper around sync_send.

        Returns the first decoded result, or None when the call produced no
        results.
        """
        results = self.sync_send(inputs=[input],
                                 output_types=[output_type],
                                 endpoint=endpoint,
                                 command_uri=command_uri,
                                 serdes_type=serdes_type)
        # Idiom fix: truthiness test instead of len().
        if results:
            return results[0]
        else:
            return None

    def sync_send(self, inputs: list, output_types: list, endpoint: ErEndpoint,
                  command_uri: CommandURI, serdes_type=SerdesTypes.PROTOBUF):
        """Call `command_uri` at `endpoint` with `inputs`, blocking for the reply.

        Each returned byte string is decoded via the matching entry of
        `output_types` (None entries stay None). Raises CommandCallError on
        any failure, chaining the original exception.
        """
        request = None
        try:
            request = ErCommandRequest(id=time_now(),
                                       uri=command_uri._uri,
                                       args=_map_and_listify(_to_proto_string, inputs))
            start = time.time()
            L.debug(f"[CC] calling: {endpoint} {command_uri} {request}")
            _channel = self._channel_factory.create_channel(endpoint)
            _command_stub = command_pb2_grpc.CommandServiceStub(_channel)
            response = _command_stub.call(request.to_proto())
            er_response = ErCommandResponse.from_proto(response)
            elapsed = time.time() - start
            L.debug(f"[CC] called (elapsed: {elapsed}): {endpoint}, {command_uri}, {request}, {er_response}")
            byte_results = er_response._results
            # Idiom fix: truthiness test instead of len().
            if byte_results:
                zipped = zip(output_types, byte_results)
                return list(
                    map(lambda t: t[0].from_proto_string(t[1]) if t[1] is not None else None,
                        zipped))
            else:
                return []
        except Exception as e:
            L.error(f'Error calling to {endpoint}, command_uri: {command_uri}, req:{request}',
                    exc_info=e)
            # FIX: chain the cause so the original traceback is preserved.
            raise CommandCallError(command_uri, endpoint, e) from e

    def async_call(self, args, output_types: list, command_uri: CommandURI,
                   serdes_type=SerdesTypes.PROTOBUF, callback=None):
        """Submit one sync_send per (inputs, endpoint) pair to the shared
        executor. Returns the list of futures; `callback`, if given, is
        attached to each."""
        futures = []
        for inputs, endpoint in args:
            f = self.executor.submit(self.sync_send,
                                     inputs=inputs,
                                     output_types=output_types,
                                     endpoint=endpoint,
                                     command_uri=command_uri,
                                     serdes_type=serdes_type)
            if callback:
                f.add_done_callback(callback)
            futures.append(f)
        return futures
def serve(args):
    """Start an egg_pair worker: command + transfer gRPC servers, optional
    cluster-manager heartbeat, then block until SIGTERM/SIGINT before a
    graceful shutdown.

    args: parsed CLI namespace (data_dir, port, transfer_port, cluster_manager,
    session_id, server_node_id, processor_id).
    """
    prefix = 'v1/egg-pair'
    set_data_dir(args.data_dir)

    CommandRouter.get_instance().register(
        service_name=f"{prefix}/runTask",
        route_to_module_name="eggroll.roll_pair.egg_pair",
        route_to_class_name="EggPair",
        route_to_method_name="run_task")

    max_workers = int(
        RollPairConfKeys.EGGROLL_ROLLPAIR_EGGPAIR_SERVER_EXECUTOR_POOL_MAX_SIZE.get())
    executor_pool_type = CoreConfKeys.EGGROLL_CORE_DEFAULT_EXECUTOR_POOL.get()
    command_server = grpc.server(
        create_executor_pool(canonical_name=executor_pool_type,
                             max_workers=max_workers,
                             thread_name_prefix="eggpair-command-server"),
        options=[
            ("grpc.max_metadata_size",
             int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_METADATA_SIZE.get())),
            ('grpc.max_send_message_length',
             int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
            ('grpc.max_receive_message_length',
             int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
            ('grpc.keepalive_time_ms',
             int(CoreConfKeys.CONFKEY_CORE_GRPC_CHANNEL_KEEPALIVE_TIME_SEC.get()) * 1000),
            ('grpc.keepalive_timeout_ms',
             int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_TIMEOUT_SEC.get()) * 1000),
            ('grpc.keepalive_permit_without_calls',
             int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_WITHOUT_CALLS_ENABLED.get())),
            ('grpc.per_rpc_retry_buffer_size',
             int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_RETRY_BUFFER_SIZE.get())),
            ('grpc.so_reuseport', False)])

    command_servicer = CommandServicer()
    command_pb2_grpc.add_CommandServiceServicer_to_server(command_servicer,
                                                          command_server)

    transfer_servicer = GrpcTransferServicer()
    port = args.port
    transfer_port = args.transfer_port
    port = command_server.add_insecure_port(f'[::]:{port}')
    if transfer_port == "-1":
        # Share a single server/port for both command and transfer services.
        transfer_server = command_server
        transfer_port = port
        transfer_pb2_grpc.add_TransferServiceServicer_to_server(transfer_servicer,
                                                                transfer_server)
    else:
        transfer_server_max_workers = int(
            RollPairConfKeys.EGGROLL_ROLLPAIR_EGGPAIR_DATA_SERVER_EXECUTOR_POOL_MAX_SIZE.get())
        transfer_server = grpc.server(
            create_executor_pool(canonical_name=executor_pool_type,
                                 max_workers=transfer_server_max_workers,
                                 thread_name_prefix="transfer_server"),
            options=[
                ('grpc.max_metadata_size',
                 int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_METADATA_SIZE.get())),
                ('grpc.max_send_message_length',
                 int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
                ('grpc.max_receive_message_length',
                 int(CoreConfKeys.EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
                # BUG FIX: keepalive_time_ms was previously computed from
                # CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_WITHOUT_CALLS_ENABLED
                # (a boolean flag), a copy-paste error. Use the same keepalive
                # time key as the command server above.
                ('grpc.keepalive_time_ms',
                 int(CoreConfKeys.CONFKEY_CORE_GRPC_CHANNEL_KEEPALIVE_TIME_SEC.get()) * 1000),
                ('grpc.keepalive_timeout_ms',
                 int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_TIMEOUT_SEC.get()) * 1000),
                ('grpc.keepalive_permit_without_calls',
                 int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_WITHOUT_CALLS_ENABLED.get())),
                ('grpc.per_rpc_retry_buffer_size',
                 int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_RETRY_BUFFER_SIZE.get())),
                ('grpc.so_reuseport', False)])
        transfer_port = transfer_server.add_insecure_port(f'[::]:{transfer_port}')
        transfer_pb2_grpc.add_TransferServiceServicer_to_server(transfer_servicer,
                                                                transfer_server)
        transfer_server.start()

    pid = os.getpid()
    L.info(f"starting egg_pair service, port: {port}, transfer port: {transfer_port}, pid: {pid}")
    command_server.start()

    cluster_manager = args.cluster_manager
    myself = None
    cluster_manager_client = None
    if cluster_manager:
        session_id = args.session_id
        server_node_id = int(args.server_node_id)
        static_er_conf = get_static_er_conf()
        static_er_conf['server_node_id'] = server_node_id

        if not session_id:
            raise ValueError('session id is missing')
        options = {SessionConfKeys.CONFKEY_SESSION_ID: args.session_id}
        myself = ErProcessor(id=int(args.processor_id),
                             server_node_id=server_node_id,
                             processor_type=ProcessorTypes.EGG_PAIR,
                             command_endpoint=ErEndpoint(host='localhost', port=port),
                             transfer_endpoint=ErEndpoint(host='localhost', port=transfer_port),
                             pid=pid,
                             options=options,
                             status=ProcessorStatus.RUNNING)

        cluster_manager_host, cluster_manager_port = cluster_manager.strip().split(':')
        L.info(f'egg_pair cluster_manager: {cluster_manager}')
        cluster_manager_client = ClusterManagerClient(options={
            ClusterManagerConfKeys.CONFKEY_CLUSTER_MANAGER_HOST: cluster_manager_host,
            ClusterManagerConfKeys.CONFKEY_CLUSTER_MANAGER_PORT: cluster_manager_port})
        cluster_manager_client.heartbeat(myself)

        if platform.system() == "Windows":
            t1 = threading.Thread(target=stop_processor,
                                  args=[cluster_manager_client, myself])
            t1.start()

    L.info(f'egg_pair started at port={port}, transfer_port={transfer_port}')

    run = True

    def exit_gracefully(signum, frame):
        # Signal handler: flip the flag so the main loop falls through.
        nonlocal run
        run = False
        L.info(f'egg_pair {args.processor_id} at port={port}, transfer_port={transfer_port}, pid={pid} receives signum={signal.getsignal(signum)}, stopping gracefully.')

    signal.signal(signal.SIGTERM, exit_gracefully)
    signal.signal(signal.SIGINT, exit_gracefully)

    while run:
        time.sleep(1)

    L.info(f'sending exit heartbeat to cm')
    if cluster_manager:
        myself._status = ProcessorStatus.STOPPED
        cluster_manager_client.heartbeat(myself)
    GrpcChannelFactory.shutdown_all_now()

    L.info(f'closing RocksDB open dbs')
    #todo:1: move to RocksdbAdapter and provide a cleanup method
    from eggroll.core.pair_store.rocksdb import RocksdbAdapter
    # NOTE(review): `del db` only removes the loop variable binding; it does
    # not close the underlying DB handle. A proper close() on RocksdbAdapter
    # is needed (per the todo above); left as-is to avoid guessing its API.
    for path, db in RocksdbAdapter.db_dict.items():
        del db
    gc.collect()

    L.info(f'system metric at exit: {get_system_metric(1)}')
    L.info(f'egg_pair {args.processor_id} at port={port}, transfer_port={transfer_port}, pid={pid} stopped gracefully')
def _push_partition(ertask):
    """Push one roll_pair partition to the remote proxy as batch streams.

    Closure: reads rs_header, rs_key, endpoint, batches_per_stream, body_bytes,
    per_stream_timeout, max_retry_cnt and long_retry_cnt from the enclosing
    scope (not visible in this chunk — confirm against the outer function).
    Retries each stream with short backoff first, then long (~300s) backoff
    for the final long_retry_cnt attempts.
    """
    rs_header._partition_id = ertask._inputs[0]._id
    L.trace(f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}")

    from eggroll.core.grpc.factory import GrpcChannelFactory
    from eggroll.core.proto import proxy_pb2_grpc
    grpc_channel_factory = GrpcChannelFactory()

    with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
        # NOTICE AGAIN: all modifications to rs_header are limited in bs_helper.
        # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
        # Remind that python's object references are passed by value,
        # meaning the 'pointer' is copied, while the contents are modificable
        bs_helper = _BatchStreamHelper(rs_header)
        bin_batch_streams = bs_helper._generate_batch_streams(
            pair_iter=rb,
            batches_per_stream=batches_per_stream,
            body_bytes=body_bytes)

        channel = grpc_channel_factory.create_channel(endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
        for batch_stream in bin_batch_streams:
            # Materialize the stream so it can be replayed on retry.
            batch_stream_data = list(batch_stream)
            cur_retry = 0
            exception = None
            while cur_retry < max_retry_cnt:
                L.trace(f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}')
                try:
                    stub.push(bs_helper.generate_packet(batch_stream_data, cur_retry),
                              timeout=per_stream_timeout)
                    exception = None
                    break
                except Exception as e:
                    # Short jittered backoff early; long (~300s) backoff for
                    # the last long_retry_cnt attempts.
                    if cur_retry < max_retry_cnt - long_retry_cnt:
                        retry_interval = round(min(2 * cur_retry, 20) + random.random() * 10, 3)
                    else:
                        retry_interval = round(300 + random.random() * 10, 3)
                    L.warn(f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}", exc_info=e)
                    time.sleep(retry_interval)
                    # Recreate the channel on UNAVAILABLE so a dead connection
                    # is not reused.
                    if isinstance(e, RpcError) and e.code().name == 'UNAVAILABLE':
                        channel = grpc_channel_factory.create_channel(endpoint, refresh=True)
                        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                    exception = e
                finally:
                    cur_retry += 1
            if exception is not None:
                L.exception(f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}", exc_info=exception)
                raise exception
            L.trace(f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}')

    L.trace(f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}")
def _push_bytes(self, obj, rs_header: ErRollSiteHeader, options: dict = None):
    """Pickle `obj`, chunk it into fixed-size batches, and push the batch
    streams to the proxy endpoint with retry/backoff.

    obj: any picklable Python object.
    rs_header: shared header, mutated here (partition id 0 of 1) and by
        _BatchStreamHelper during streaming.
    options: may carry 'serdes' and 'wrapee_cls' overrides for the header.
    Raises the last push exception if a stream exhausts all retries.
    """
    if options is None:
        options = {}
    start_time = time.time()
    rs_key = rs_header.get_rs_key()
    int_size = 4

    if L.isEnabledFor(logging.DEBUG):
        L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

    def _generate_obj_bytes(py_obj, body_bytes):
        # Yield (4-byte big-endian sequence id, chunk) pairs covering the
        # pickled payload in body_bytes-sized slices.
        key_id = 0
        obj_bytes = pickle.dumps(py_obj)
        obj_bytes_len = len(obj_bytes)
        cur_pos = 0
        while cur_pos <= obj_bytes_len:
            yield key_id.to_bytes(int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
            key_id += 1
            cur_pos += body_bytes

    rs_header._partition_id = 0
    rs_header._total_partitions = 1

    serdes = options.get('serdes', None)
    if serdes is not None:
        rs_header._options['serdes'] = serdes

    wrapee_cls = options.get('wrapee_cls', None)
    # BUG FIX: this guard previously tested `serdes is not None` (copy-paste
    # from the block above), so wrapee_cls was stored — possibly as None —
    # whenever serdes was set, and dropped whenever it wasn't.
    if wrapee_cls is not None:
        rs_header._options['wrapee_cls'] = wrapee_cls

    # NOTICE: all modifications to rs_header are limited in bs_helper.
    # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
    # Remind that python's object references are passed by value,
    # meaning the 'pointer' is copied, while the contents are modificable
    bs_helper = _BatchStreamHelper(rs_header=rs_header)
    bin_batch_streams = bs_helper._generate_batch_streams(
        pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
        batches_per_stream=self.push_batches_per_stream,
        body_bytes=self.batch_body_bytes)

    grpc_channel_factory = GrpcChannelFactory()
    channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
    stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
    max_retry_cnt = self.push_max_retry
    long_retry_cnt = self.push_long_retry
    per_stream_timeout = self.push_per_stream_timeout

    # if use stub.push.future here, retry mechanism is a problem to solve
    for batch_stream in bin_batch_streams:
        cur_retry = 0
        # Materialize the stream so it can be replayed on retry.
        batch_stream_data = list(batch_stream)
        exception = None
        while cur_retry < max_retry_cnt:
            L.trace(f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}')
            try:
                stub.push(bs_helper.generate_packet(batch_stream_data, cur_retry),
                          timeout=per_stream_timeout)
                exception = None
                break
            except Exception as e:
                # Short jittered backoff early; long (~300s) backoff for the
                # last long_retry_cnt attempts.
                if cur_retry <= max_retry_cnt - long_retry_cnt:
                    retry_interval = round(min(2 * cur_retry, 20) + random.random() * 10, 3)
                else:
                    retry_interval = round(300 + random.random() * 10, 3)
                L.warn(f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}", exc_info=e)
                time.sleep(retry_interval)
                # Refresh the channel on UNAVAILABLE so a dead connection is
                # not reused.
                if isinstance(e, RpcError) and e.code().name == 'UNAVAILABLE':
                    channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint, refresh=True)
                    stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                exception = e
            finally:
                cur_retry += 1
        if exception is not None:
            L.exception(f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}", exc_info=exception)
            raise exception
        L.trace(f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}')

    L.debug(f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}")
    self.ctx.pushing_latch.count_down()
class RollSiteContext:
    """roll_site context with optional dedicated push session and pluggable
    comm implementations (registered by name, 'grpc' by default)."""

    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self, roll_site_session_id, rp_ctx: RollPairContext, options: dict = None):
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx
        self.push_session_enabled = RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_SESSION_ENABLED.get_with(options)
        if self.push_session_enabled:
            # create session for push roll_pair and object
            self._push_session = ErSession(
                session_id=roll_site_session_id + "_push",
                options=rp_ctx.get_session().get_all_options())
            self._push_rp_ctx = RollPairContext(session=self._push_session)
            L.info(f"push_session={self._push_session.get_session_id()} enabled")

            # Closure registered as an exit task below (only when enabled).
            def stop_push_session():
                self._push_session.stop()
        else:
            self._push_session = None
            self._push_rp_ctx = None
        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        self._options = options

        self._registered_comm_types = dict()
        self.register_comm_type('grpc', RollSiteGrpc)

        # proxy_endpoint may be given as "host:port" or as an ErEndpoint.
        endpoint = options["proxy_endpoint"]
        if isinstance(endpoint, str):
            splitted = endpoint.split(':')
            self.proxy_endpoint = ErEndpoint(host=splitted[0].strip(),
                                             port=int(splitted[1].strip()))
        elif isinstance(endpoint, ErEndpoint):
            self.proxy_endpoint = endpoint
        else:
            raise ValueError("endpoint only support str and ErEndpoint type")

        self.is_standalone = RollSiteConfKeys.EGGROLL_ROLLSITE_DEPLOY_MODE.get_with(options) == "standalone"
        # if self.is_standalone:
        #     self.stub = None
        # else:
        #     channel = self.grpc_channel_factory.create_channel(self.proxy_endpoint)
        #     self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        self.pushing_latch = CountDownLatch(0)
        self.rp_ctx.get_session().add_exit_task(self._wait_push_complete)
        if self.push_session_enabled:
            self.rp_ctx.get_session().add_exit_task(stop_push_session)
        self._wait_push_exit_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_OVERALL_TIMEOUT_SEC.get_with(options))
        L.info(f"inited RollSiteContext: {self.__dict__}")

    def _wait_push_complete(self):
        """Session exit task: wait (bounded) for all outstanding pushes."""
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(f"running roll site exit func for er session={session_id},"
               f" roll site session id={self.roll_site_session_id}")
        residual_count = self.pushing_latch.await_latch(self._wait_push_exit_timeout)
        if residual_count != 0:
            L.error(f"exit session when not finish push: "
                    f"residual_count={residual_count}, timeout={self._wait_push_exit_timeout}")

    def load(self, name: str, tag: str, options: dict = None):
        """Create a RollSite handle; `options` layered over context options."""
        if options is None:
            options = {}
        final_options = self._options.copy()
        final_options.update(options)
        return RollSite(name, tag, self, options=final_options)

    def register_comm_type(self, name, clazz):
        # Register a comm implementation class under a lookup name.
        self._registered_comm_types[name] = clazz

    def get_comm_impl(self, name):
        """Look up a registered comm implementation; raise if unknown."""
        if name in self._registered_comm_types:
            return self._registered_comm_types[name]
        else:
            raise ValueError(f'comm_type={name} is not registered')
class RollSiteContext:
    """Older roll_site context variant: counts in-flight push tasks with a
    plain counter and polls it on session exit; performs the
    init_job_session_pair handshake with the proxy."""

    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self, roll_site_session_id, rp_ctx: RollPairContext, options: dict = None):
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx
        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        # NOTE(review): unlike the sibling variants, the endpoint is taken
        # as-is (no "host:port" string parsing here).
        self.proxy_endpoint = options["proxy_endpoint"]
        self.is_standalone = self.rp_ctx.get_session().get_option(
            SessionConfKeys.CONFKEY_SESSION_DEPLOY_MODE) == "standalone"
        if self.is_standalone:
            self.stub = None
        else:
            channel = self.grpc_channel_factory.create_channel(self.proxy_endpoint)
            self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
            # Handshake requires the stub, so it only runs when not standalone.
            self.init_job_session_pair(self.roll_site_session_id,
                                       self.rp_ctx.session_id)
        self.pushing_task_count = 0
        self.rp_ctx.get_session().add_exit_task(self.push_complete)
        L.info(f"inited RollSiteContext: {self.__dict__}")

    def push_complete(self):
        """Session exit task: poll until all push tasks finish, backing off up
        to 60s per try, giving up after 800 tries."""
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(f"running roll site exit func for er session: {session_id}, roll site session id: {self.roll_site_session_id}")
        try_count = 0
        max_try_count = 800
        while True:
            if try_count >= max_try_count:
                L.warn(f"try times reach {max_try_count} for session: {session_id}, exiting")
                return
            if self.pushing_task_count:
                L.info(f"session: {session_id} "
                       f"waiting for all push tasks complete. "
                       f"current try_count: {try_count}, "
                       f"current pushing task count: {self.pushing_task_count}")
                try_count += 1
                time.sleep(min(0.1 * try_count, 60))
            else:
                L.info(f"session: {session_id} finishes all pushing tasks")
                return

    # todo:1: add options?
    def load(self, name: str, tag: str, options: dict = None):
        """Create a RollSite handle (options passed through unmerged here)."""
        if options is None:
            options = {}
        return RollSite(name, tag, self, options=options)

    # todo:1: try-except as decorator
    def init_job_session_pair(self, roll_site_session_id, er_session_id):
        """Tell the proxy which er session backs this roll_site session via a
        unary 'init_job_session_pair' call. Raises GrpcCallError on failure."""
        try:
            task_info = proxy_pb2.Task(model=proxy_pb2.Model(
                name=roll_site_session_id,
                dataKey=bytes(er_session_id, encoding='utf8')))
            # src and dst are both this party for the handshake.
            topic_src = proxy_pb2.Topic(name="init_job_session_pair",
                                        partyId=self.party_id,
                                        role=self.role,
                                        callback=None)
            topic_dst = proxy_pb2.Topic(name="init_job_session_pair",
                                        partyId=self.party_id,
                                        role=self.role,
                                        callback=None)
            command_test = proxy_pb2.Command(name="init_job_session_pair")
            # NOTE(review): conf_test is built but never attached to metadata.
            conf_test = proxy_pb2.Conf(overallTimeout=1000,
                                       completionWaitTimeout=1000,
                                       packetIntervalTimeout=1000,
                                       maxRetries=10)
            metadata = proxy_pb2.Metadata(task=task_info,
                                          src=topic_src,
                                          dst=topic_dst,
                                          command=command_test,
                                          operator="init_job_session_pair",
                                          seq=0,
                                          ack=0)
            packet = proxy_pb2.Packet(header=metadata)
            self.stub.unaryCall(packet)
            L.info(f"send RollSiteContext init to Proxy: {to_one_line_string(packet)}")
        except Exception as e:
            raise GrpcCallError("init_job_session_pair", self.proxy_endpoint, e)