class TransferClient(object):
    """Client side of the transfer service: streams batches to, or pulls batches from, an endpoint."""

    def __init__(self):
        # Channels for every call made by this client come from this factory.
        self.__grpc_channel_factory = GrpcChannelFactory()

    @_exception_logger
    def send(self, broker, endpoint: ErEndpoint, tag):
        """Start streaming the contents of ``broker`` to ``endpoint`` under ``tag``.

        ``broker`` may be a generator, in which case each yielded item becomes
        the ``data`` of a ``TransferBatch`` whose header id is the item's index,
        or any other object, in which case
        ``TransferService.transfer_batch_generator_from_broker`` builds the
        request stream.

        Returns the object produced by ``stub.send.future(...)``; this method
        does not wait on it.

        Raises whatever exception occurred, after logging it.
        """
        try:
            channel = self.__grpc_channel_factory.create_channel(endpoint)
            stub = transfer_pb2_grpc.TransferServiceStub(channel)

            from types import GeneratorType
            if not isinstance(broker, GeneratorType):
                batches = TransferService.transfer_batch_generator_from_broker(broker, tag)
            else:
                # Lazily wrap every generated payload in a numbered batch.
                batches = (
                    transfer_pb2.TransferBatch(
                        header=transfer_pb2.TransferHeader(id=seq, tag=tag),
                        data=payload)
                    for seq, payload in enumerate(broker))

            return stub.send.future(batches, metadata=[(TRANSFER_BROKER_NAME, tag)])
        except Exception as err:
            L.error(f'Error calling to {endpoint} in TransferClient.send')
            raise err

    @_exception_logger
    def recv(self, endpoint: ErEndpoint, tag, broker):
        """Request the stream identified by ``tag`` from ``endpoint``.

        When ``broker`` is ``None`` the response iterator returned by
        ``stub.recv`` is handed back directly. Otherwise a thread is started
        that puts every response into ``broker`` and then calls
        ``broker.signal_write_finish()``; ``broker`` itself is returned
        immediately.

        Raises whatever exception occurred while issuing the call, after
        logging it.
        """
        try:
            L.debug(f'TransferClient.recv for endpoint: {endpoint}, tag: {tag}')

            @_exception_logger
            def fill_broker(iterable: Iterable, broker):
                # Runs on the worker thread: drain the responses into the broker.
                try:
                    for item in iter(iterable):
                        broker.put(item)
                    broker.signal_write_finish()
                except Exception as fill_err:
                    L.error(f'Fail to fill broker for tag: {tag}, endpoint: {endpoint}')
                    raise fill_err

            channel = self.__grpc_channel_factory.create_channel(endpoint)
            stub = transfer_pb2_grpc.TransferServiceStub(channel)

            header = transfer_pb2.TransferHeader(id=1, tag=tag)
            request = transfer_pb2.TransferBatch(header=header)
            response_iter = stub.recv(request, metadata=[(TRANSFER_BROKER_NAME, tag)])

            if broker is None:
                return response_iter

            filler = Thread(target=fill_broker, args=(response_iter, broker))
            filler.start()
            return broker
        except Exception as err:
            L.error(f'Error calling to {endpoint} in TransferClient.recv')
            raise err
class CommandClient(object):
    """Client that sends command requests to a command service endpoint over gRPC."""

    # Thread pool shared by all instances; async_call submits its work here.
    # Its size is read from configuration when the class body is evaluated.
    executor = ThreadPoolExecutor(
        max_workers=int(
            CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get()),
        thread_name_prefix="command_client")

    def __init__(self):
        self._channel_factory = GrpcChannelFactory()

    def simple_sync_send(self, input: RpcMessage, output_type, endpoint: ErEndpoint, command_uri: CommandURI, serdes_type=SerdesTypes.PROTOBUF):
        """Send a single input and return the first decoded result.

        Delegates to ``sync_send`` with one-element lists. Returns ``None``
        when ``sync_send`` yields no results.
        """
        results = self.sync_send(
            inputs=[input],
            output_types=[output_type],
            endpoint=endpoint,
            command_uri=command_uri,
            serdes_type=serdes_type)

        if len(results) == 0:
            return None
        return results[0]

    def sync_send(self, inputs: list, output_types: list, endpoint: ErEndpoint, command_uri: CommandURI, serdes_type=SerdesTypes.PROTOBUF):
        """Call ``command_uri`` on ``endpoint`` and decode the response.

        ``inputs`` are serialized with ``_to_proto_string`` into the request
        args. Each returned byte result is paired positionally with the entry
        of ``output_types`` and decoded through ``from_proto_string``; a
        ``None`` result stays ``None``. Returns an empty list when the response
        carries no results. ``serdes_type`` is accepted but not read here.

        Raises ``CommandCallError`` wrapping any exception from building,
        sending or decoding the call.
        """
        request = None
        try:
            proto_args = _map_and_listify(_to_proto_string, inputs)
            request = ErCommandRequest(id=time_now(), uri=command_uri._uri, args=proto_args)

            start = time.time()
            L.debug(f"[CC] calling: {endpoint} {command_uri} {request}")

            _channel = self._channel_factory.create_channel(endpoint)
            _command_stub = command_pb2_grpc.CommandServiceStub(_channel)
            er_response = ErCommandResponse.from_proto(
                _command_stub.call(request.to_proto()))

            elapsed = time.time() - start
            L.debug(
                f"[CC] called (elapsed: {elapsed}): {endpoint}, {command_uri}, {request}, {er_response}"
            )

            byte_results = er_response._results
            if len(byte_results) == 0:
                return []

            return [
                out_type.from_proto_string(raw) if raw is not None else None
                for out_type, raw in zip(output_types, byte_results)
            ]
        except Exception as e:
            L.error(
                f'Error calling to {endpoint}, command_uri: {command_uri}, req:{request}',
                exc_info=e)
            raise CommandCallError(command_uri, endpoint, e)

    def async_call(self, args, output_types: list, command_uri: CommandURI, serdes_type=SerdesTypes.PROTOBUF, callback=None):
        """Submit one ``sync_send`` per ``(inputs, endpoint)`` pair in ``args``.

        Each call runs on the class-level executor. If ``callback`` is truthy
        it is attached to every future as a done-callback. Returns the list of
        futures in submission order.
        """
        pending = []
        for inputs, endpoint in args:
            future = self.executor.submit(
                self.sync_send,
                inputs=inputs,
                output_types=output_types,
                endpoint=endpoint,
                command_uri=command_uri,
                serdes_type=serdes_type)
            if callback:
                future.add_done_callback(callback)
            pending.append(future)
        return pending
def _push_partition(ertask):
    """Push every batch stream of one rollpair partition, retrying each stream on failure.

    The partition is ``ertask._inputs[0]``; its ``_id`` is written into the
    shared ``rs_header`` before anything is sent. Each batch stream is
    materialized once and re-sent on failure up to ``max_retry_cnt`` times.
    When all attempts for a stream fail, the last exception is logged and
    re-raised.
    """
    partition = ertask._inputs[0]
    rs_header._partition_id = partition._id
    L.trace(
        f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
    )

    from eggroll.core.grpc.factory import GrpcChannelFactory
    from eggroll.core.proto import proxy_pb2_grpc
    channel_factory = GrpcChannelFactory()

    with create_adapter(partition) as adapter, adapter.iteritems() as pairs:
        # rs_header is the very same object the helper holds, so whatever the
        # helper changes on it is visible here too. Header changes are meant to
        # happen only inside the helper.
        bs_helper = _BatchStreamHelper(rs_header)
        streams = bs_helper._generate_batch_streams(
            pair_iter=pairs,
            batches_per_stream=batches_per_stream,
            body_bytes=body_bytes)

        channel = channel_factory.create_channel(endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        for stream in streams:
            # Keep the stream's batches in memory so a retry can resend them.
            stream_batches = list(stream)
            last_error = None
            cur_retry = 0

            while cur_retry < max_retry_cnt:
                L.trace(
                    f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                )
                try:
                    packets = bs_helper.generate_packet(stream_batches, cur_retry)
                    stub.push(packets, timeout=per_stream_timeout)
                    last_error = None
                    break
                except Exception as e:
                    # Early attempts back off briefly (capped at 20s plus
                    # jitter); the remaining attempts wait about five minutes.
                    use_short_wait = cur_retry < max_retry_cnt - long_retry_cnt
                    retry_interval = (
                        round(min(2 * cur_retry, 20) + random.random() * 10, 3)
                        if use_short_wait
                        else round(300 + random.random() * 10, 3))
                    L.warn(
                        f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                        exc_info=e)
                    time.sleep(retry_interval)

                    # An UNAVAILABLE status gets a refreshed channel and a new stub.
                    if isinstance(e, RpcError) and e.code().name == 'UNAVAILABLE':
                        channel = channel_factory.create_channel(endpoint, refresh=True)
                        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                    last_error = e
                finally:
                    # Counts the attempt on success and on failure alike.
                    cur_retry += 1

            if last_error is not None:
                L.exception(
                    f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                    exc_info=last_error)
                raise last_error

            L.trace(
                f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
            )

    L.trace(
        f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
    )
def _push_bytes(self, obj, rs_header: ErRollSiteHeader, options: dict = None):
    """Pickle ``obj`` and push it to the proxy endpoint as a single partition.

    The pickled bytes are cut into chunks of ``self.batch_body_bytes``, keyed
    by a 4-byte big-endian counter, grouped into batch streams and sent one
    stream at a time. A failed stream is retried up to ``self.push_max_retry``
    times.

    Args:
        obj: The Python object to send; it is serialized with ``pickle.dumps``.
        rs_header: Header describing the transfer. It is mutated in place:
            ``_partition_id`` is set to 0, ``_total_partitions`` to 1, and the
            recognized options below are copied into ``rs_header._options``.
        options: Optional dict. ``'serdes'`` and ``'wrapee_cls'`` are each
            copied into the header when present and not ``None``.

    Raises:
        Exception: The last exception raised by ``stub.push`` once every
            attempt for a stream has failed.
    """
    if options is None:
        options = {}
    start_time = time.time()

    rs_key = rs_header.get_rs_key()
    int_size = 4

    if L.isEnabledFor(logging.DEBUG):
        L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

    def _generate_obj_bytes(py_obj, body_bytes):
        # Yields (key, chunk) pairs: key is the chunk index as 4 big-endian
        # bytes, chunk is the next body_bytes-sized slice of the pickle.
        key_id = 0
        obj_bytes = pickle.dumps(py_obj)
        obj_bytes_len = len(obj_bytes)
        cur_pos = 0

        # NOTE(review): because of `<=`, a pickle whose size is an exact
        # multiple of body_bytes gets one extra empty chunk at the end.
        # Left as is; confirm the receiving side before changing it.
        while cur_pos <= obj_bytes_len:
            yield key_id.to_bytes(
                int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
            key_id += 1
            cur_pos += body_bytes

    rs_header._partition_id = 0
    rs_header._total_partitions = 1

    serdes = options.get('serdes', None)
    if serdes is not None:
        rs_header._options['serdes'] = serdes

    wrapee_cls = options.get('wrapee_cls', None)
    # Guard on wrapee_cls itself. The previous check tested `serdes`, which
    # dropped wrapee_cls when no serdes was given and stored None when only
    # serdes was given.
    if wrapee_cls is not None:
        rs_header._options['wrapee_cls'] = wrapee_cls

    # NOTICE: all modifications to rs_header are limited in bs_helper.
    # rs_header is shared by bs_helper and here. Any modification made in
    # bs_helper affects this header: Python passes the object reference by
    # value, so the reference is copied while the contents stay modifiable.
    bs_helper = _BatchStreamHelper(rs_header=rs_header)
    bin_batch_streams = bs_helper._generate_batch_streams(
        pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
        batches_per_stream=self.push_batches_per_stream,
        body_bytes=self.batch_body_bytes)

    grpc_channel_factory = GrpcChannelFactory()
    channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
    stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
    max_retry_cnt = self.push_max_retry
    long_retry_cnt = self.push_long_retry
    per_stream_timeout = self.push_per_stream_timeout

    # if use stub.push.future here, retry mechanism is a problem to solve
    for batch_stream in bin_batch_streams:
        cur_retry = 0

        # Materialize the stream so the same batches can be resent on retry.
        batch_stream_data = list(batch_stream)
        exception = None
        while cur_retry < max_retry_cnt:
            L.trace(
                f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}'
            )
            try:
                stub.push(bs_helper.generate_packet(
                    batch_stream_data, cur_retry), timeout=per_stream_timeout)
                exception = None
                break
            except Exception as e:
                # NOTE(review): _push_partition uses `<` for this threshold,
                # so the two paths switch to the long interval one attempt
                # apart. Confirm which is intended.
                if cur_retry <= max_retry_cnt - long_retry_cnt:
                    retry_interval = round(
                        min(2 * cur_retry, 20) + random.random() * 10, 3)
                else:
                    retry_interval = round(300 + random.random() * 10, 3)
                L.warn(
                    f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                    exc_info=e)
                time.sleep(retry_interval)

                # An UNAVAILABLE status gets a refreshed channel and a new stub.
                if isinstance(e, RpcError) and e.code().name == 'UNAVAILABLE':
                    channel = grpc_channel_factory.create_channel(
                        self.ctx.proxy_endpoint, refresh=True)
                    stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                exception = e
            finally:
                # Counts the attempt on success and on failure alike.
                cur_retry += 1

        if exception is not None:
            L.exception(
                f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                exc_info=exception)
            raise exception
        L.trace(
            f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}'
        )

    L.debug(
        f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}"
    )
    # Reached only when every stream was pushed; a failure above raises first.
    self.ctx.pushing_latch.count_down()