コード例 #1
0
class TransferClient(object):
    """Client side of the transfer service.

    ``send`` streams batches out to an endpoint; ``recv`` pulls a stream back,
    either returning the raw response iterator or draining it into a broker on
    a background thread.
    """

    def __init__(self):
        # Every call asks this factory for the channel to its endpoint.
        self.__grpc_channel_factory = GrpcChannelFactory()

    @_exception_logger
    def send(self, broker, endpoint: ErEndpoint, tag):
        """Start an asynchronous ``send`` stream to ``endpoint``.

        Args:
            broker: either a generator of payloads, each wrapped into a
                ``TransferBatch`` whose header id is its 0-based position, or
                any other object, which is converted into batches by
                ``TransferService.transfer_batch_generator_from_broker``.
            endpoint: target ``ErEndpoint``.
            tag: transfer tag; used in the batch headers and sent as the
                ``TRANSFER_BROKER_NAME`` call-metadata value.

        Returns:
            The future returned by ``stub.send.future``.

        Raises:
            Any error raised while setting up the call, after logging it.
        """
        try:
            transfer_stub = transfer_pb2_grpc.TransferServiceStub(
                self.__grpc_channel_factory.create_channel(endpoint))
            import types
            if not isinstance(broker, types.GeneratorType):
                batches = TransferService.transfer_batch_generator_from_broker(
                    broker, tag)
            else:
                batches = (
                    transfer_pb2.TransferBatch(
                        header=transfer_pb2.TransferHeader(id=seq, tag=tag),
                        data=payload)
                    for seq, payload in enumerate(broker))
            call_metadata = [(TRANSFER_BROKER_NAME, tag)]
            return transfer_stub.send.future(batches, metadata=call_metadata)
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.send')
            raise e

    @_exception_logger
    def recv(self, endpoint: ErEndpoint, tag, broker):
        """Open a ``recv`` stream from ``endpoint`` for ``tag``.

        Args:
            endpoint: source ``ErEndpoint``.
            tag: transfer tag for the request header (id 1) and the
                ``TRANSFER_BROKER_NAME`` call-metadata value.
            broker: ``None`` to get the raw response iterator back; otherwise a
                broker that a background ``Thread`` fills with ``put`` and then
                closes with ``signal_write_finish``.

        Returns:
            The response iterator when ``broker`` is None, else ``broker``.

        Raises:
            Any error raised while setting up the call, after logging it.
        """
        try:
            L.debug(f'TransferClient.recv for endpoint: {endpoint}, tag: {tag}')

            @_exception_logger
            def fill_broker(iterable: Iterable, broker):
                # Runs on the background thread: drain the stream into the broker.
                # NOTE(review): if iteration fails, signal_write_finish is never
                # called; confirm readers of the broker cannot block forever.
                try:
                    for item in iterable:
                        broker.put(item)
                    broker.signal_write_finish()
                except Exception as e:
                    L.error(f'Fail to fill broker for tag: {tag}, endpoint: {endpoint}')
                    raise e

            transfer_stub = transfer_pb2_grpc.TransferServiceStub(
                self.__grpc_channel_factory.create_channel(endpoint))
            recv_request = transfer_pb2.TransferBatch(
                header=transfer_pb2.TransferHeader(id=1, tag=tag))
            response_iter = transfer_stub.recv(
                recv_request, metadata=[(TRANSFER_BROKER_NAME, tag)])
            if broker is None:
                return response_iter

            Thread(target=fill_broker, args=[response_iter, broker]).start()
            return broker
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.recv')
            raise e
コード例 #2
0
class CommandClient(object):
    """Client for the gRPC ``CommandService``.

    Offers blocking calls (``simple_sync_send`` / ``sync_send``) and a fan-out
    helper (``async_call``) that runs ``sync_send`` on a shared thread pool.
    """

    # Class-level pool shared by every CommandClient instance; its size is read
    # from EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE when the class is
    # defined.
    executor = ThreadPoolExecutor(max_workers=int(
        CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get()),
                                  thread_name_prefix="command_client")

    def __init__(self):
        self._channel_factory = GrpcChannelFactory()

    def simple_sync_send(self,
                         input: RpcMessage,
                         output_type,
                         endpoint: ErEndpoint,
                         command_uri: CommandURI,
                         serdes_type=SerdesTypes.PROTOBUF):
        """Call ``command_uri`` with one input and return one result.

        Thin wrapper over ``sync_send`` using one-element ``inputs`` and
        ``output_types`` lists. (``input`` shadows the builtin but is kept
        because callers may pass it by keyword.)

        Returns:
            The first decoded result, or ``None`` when the response carries no
            results.

        Raises:
            CommandCallError: propagated from ``sync_send``.
        """
        results = self.sync_send(inputs=[input],
                                 output_types=[output_type],
                                 endpoint=endpoint,
                                 command_uri=command_uri,
                                 serdes_type=serdes_type)
        return results[0] if results else None

    def sync_send(self,
                  inputs: list,
                  output_types: list,
                  endpoint: ErEndpoint,
                  command_uri: CommandURI,
                  serdes_type=SerdesTypes.PROTOBUF):
        """Invoke ``command_uri`` on ``endpoint`` and decode the results.

        Args:
            inputs: messages serialized with ``_to_proto_string`` into the
                request args.
            output_types: types whose ``from_proto_string`` decodes each
                result, matched to results by position.
            endpoint: target ``ErEndpoint``.
            command_uri: command to invoke.
            serdes_type: accepted for API compatibility; not used here.

        Returns:
            A list of decoded results (``None`` stays ``None``), or ``[]`` when
            the response carries no results. Results beyond
            ``len(output_types)`` are dropped by ``zip``.

        Raises:
            CommandCallError: wrapping any failure; the original exception is
                chained as ``__cause__``.
        """
        # Defined before the try so the error log can always reference it.
        request = None
        try:
            request = ErCommandRequest(id=time_now(),
                                       uri=command_uri._uri,
                                       args=_map_and_listify(
                                           _to_proto_string, inputs))
            start = time.time()
            L.debug(f"[CC] calling: {endpoint} {command_uri} {request}")
            channel = self._channel_factory.create_channel(endpoint)
            command_stub = command_pb2_grpc.CommandServiceStub(channel)
            response = command_stub.call(request.to_proto())
            er_response = ErCommandResponse.from_proto(response)
            elapsed = time.time() - start
            L.debug(
                f"[CC] called (elapsed: {elapsed}): {endpoint}, {command_uri}, {request}, {er_response}"
            )
            byte_results = er_response._results

            # Early return keeps output_types untouched when nothing came back
            # (so callers passing output_types=None for no-result commands work).
            if not byte_results:
                return []
            return [
                output_type.from_proto_string(raw) if raw is not None else None
                for output_type, raw in zip(output_types, byte_results)
            ]
        except Exception as e:
            L.error(
                f'Error calling to {endpoint}, command_uri: {command_uri}, req:{request}',
                exc_info=e)
            raise CommandCallError(command_uri, endpoint, e) from e

    def async_call(self,
                   args,
                   output_types: list,
                   command_uri: CommandURI,
                   serdes_type=SerdesTypes.PROTOBUF,
                   callback=None):
        """Submit one ``sync_send`` per ``(inputs, endpoint)`` pair to the pool.

        Args:
            args: iterable of ``(inputs, endpoint)`` tuples.
            output_types: passed unchanged to every ``sync_send``.
            command_uri: command invoked on every endpoint.
            serdes_type: passed through to ``sync_send``.
            callback: optional; registered on each future with
                ``add_done_callback``.

        Returns:
            A list of futures, in the same order as ``args``.
        """
        futures = []
        for inputs, endpoint in args:
            future = self.executor.submit(self.sync_send,
                                          inputs=inputs,
                                          output_types=output_types,
                                          endpoint=endpoint,
                                          command_uri=command_uri,
                                          serdes_type=serdes_type)
            if callback:
                future.add_done_callback(callback)
            futures.append(future)

        return futures
コード例 #3
0
        def _push_partition(ertask):
            """Push one roll-pair partition through the proxy's DataTransferService.

            Reads the partition ``ertask._inputs[0]`` via ``create_adapter``,
            turns its pairs into batch streams with ``_BatchStreamHelper`` and
            sends each stream with ``stub.push``. Each stream is retried up to
            ``max_retry_cnt`` times. After a failure at attempt index
            ``cur_retry``, the back-off before the next attempt is
            ``min(2 * cur_retry, 20) + U(0, 10)`` seconds while
            ``cur_retry < max_retry_cnt - long_retry_cnt``, and
            ``300 + U(0, 10)`` seconds after that. There is no back-off after
            the final attempt.

            Uses closure variables from the enclosing scope: ``rs_header``,
            ``rs_key``, ``endpoint``, ``batches_per_stream``, ``body_bytes``,
            ``max_retry_cnt``, ``long_retry_cnt`` and ``per_stream_timeout``.
            Mutates ``rs_header._partition_id``.
            # NOTE(review): rs_header is a closure variable; confirm each
            # invocation works on its own copy if partitions run concurrently.

            Raises:
                Exception: the last ``stub.push`` error of a stream once its
                    attempts are exhausted.
            """
            rs_header._partition_id = ertask._inputs[0]._id
            L.trace(
                f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

            from eggroll.core.grpc.factory import GrpcChannelFactory
            from eggroll.core.proto import proxy_pb2_grpc
            grpc_channel_factory = GrpcChannelFactory()

            with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
                # NOTICE AGAIN: all modifications to rs_header are limited to bs_helper.
                # rs_header is shared by bs_helper and this function, so any
                # modification made in bs_helper is visible here. Python passes
                # object references by value: the reference is copied, but the
                # object it points to stays modifiable.
                bs_helper = _BatchStreamHelper(rs_header)
                bin_batch_streams = bs_helper._generate_batch_streams(
                    pair_iter=rb,
                    batches_per_stream=batches_per_stream,
                    body_bytes=body_bytes)

                channel = grpc_channel_factory.create_channel(endpoint)
                stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                for batch_stream in bin_batch_streams:
                    # Materialize the stream so identical data can be re-sent on retry.
                    batch_stream_data = list(batch_stream)
                    cur_retry = 0
                    exception = None
                    while cur_retry < max_retry_cnt:
                        L.trace(
                            f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                        )
                        try:
                            stub.push(bs_helper.generate_packet(
                                batch_stream_data, cur_retry),
                                      timeout=per_stream_timeout)
                            exception = None
                            break
                        except Exception as e:
                            exception = e
                            # Sleeping after the final attempt would only delay
                            # raising the error, so back off only if another
                            # attempt follows.
                            has_next_attempt = cur_retry + 1 < max_retry_cnt
                            if not has_next_attempt:
                                retry_interval = 0
                            elif cur_retry < max_retry_cnt - long_retry_cnt:
                                # Short back-off: grows 2s per attempt up to 20s, plus 0-10s jitter.
                                retry_interval = round(
                                    min(2 * cur_retry, 20) +
                                    random.random() * 10, 3)
                            else:
                                # Long back-off for the trailing long_retry_cnt slots.
                                retry_interval = round(
                                    300 + random.random() * 10, 3)
                            L.warning(
                                f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                                exc_info=e)
                            if has_next_attempt:
                                time.sleep(retry_interval)
                            # On UNAVAILABLE, ask the factory for a refreshed channel
                            # and rebuild the stub before the next attempt.
                            if isinstance(e, RpcError) and e.code(
                            ).name == 'UNAVAILABLE':
                                channel = grpc_channel_factory.create_channel(
                                    endpoint, refresh=True)
                                stub = proxy_pb2_grpc.DataTransferServiceStub(
                                    channel)
                        finally:
                            # Also runs after a successful break, hence
                            # "retry count={cur_retry - 1}" below.
                            cur_retry += 1
                    if exception is not None:
                        L.exception(
                            f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                            exc_info=exception)
                        raise exception
                    L.trace(
                        f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
                    )

            L.trace(
                f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )
コード例 #4
0
    def _push_bytes(self,
                    obj,
                    rs_header: ErRollSiteHeader,
                    options: dict = None):
        """Pickle ``obj`` and push it through the proxy as a single-partition stream.

        The pickled bytes are cut into ``self.batch_body_bytes``-sized chunks,
        each keyed by a 4-byte big-endian sequence number. ``_BatchStreamHelper``
        groups the chunks into batch streams, and each stream is sent with
        ``stub.push`` to ``self.ctx.proxy_endpoint``.

        Each stream gets up to ``self.push_max_retry`` attempts. After a failure
        at attempt index ``cur_retry``, the back-off before the next attempt is
        ``min(2 * cur_retry, 20) + U(0, 10)`` seconds while
        ``cur_retry < push_max_retry - push_long_retry``, and ``300 + U(0, 10)``
        seconds after that. There is no back-off after the final attempt.
        ``self.ctx.pushing_latch`` is counted down only after every stream was
        pushed.

        Args:
            obj: object to send; serialized with ``pickle.dumps``.
            rs_header: roll-site header for this transfer. Mutated in place: the
                partition id and total partitions are set to 0 and 1, and the
                ``serdes`` / ``wrapee_cls`` options (when given) are copied into
                ``rs_header._options``.
            options: optional dict; the ``'serdes'`` and ``'wrapee_cls'`` keys
                are read.

        Raises:
            Exception: the last ``stub.push`` error of a stream once its
                attempts are exhausted.
        """
        if options is None:
            options = {}
        start_time = time.time()

        rs_key = rs_header.get_rs_key()
        int_size = 4

        if L.isEnabledFor(logging.DEBUG):
            L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

        def _generate_obj_bytes(py_obj, body_bytes):
            # Yields (int_size-byte big-endian chunk index, slice of pickled bytes).
            # NOTE(review): the `<=` bound emits a trailing empty chunk when the
            # pickled length is an exact multiple of body_bytes; kept as-is in
            # case the receiver depends on it. Confirm before tightening to `<`.
            key_id = 0
            obj_bytes = pickle.dumps(py_obj)
            obj_bytes_len = len(obj_bytes)
            cur_pos = 0

            while cur_pos <= obj_bytes_len:
                yield key_id.to_bytes(
                    int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
                key_id += 1
                cur_pos += body_bytes

        rs_header._partition_id = 0
        rs_header._total_partitions = 1

        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        # Guard on wrapee_cls itself. The previous check on `serdes` dropped a
        # given wrapee_cls when serdes was absent and stored None when only
        # serdes was given.
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls
        # NOTICE: all modifications to rs_header are limited to bs_helper.
        # rs_header is shared by bs_helper and this method, so any modification
        # made in bs_helper is visible here. Python passes object references by
        # value: the reference is copied, but the object stays modifiable.
        bs_helper = _BatchStreamHelper(rs_header=rs_header)
        bin_batch_streams = bs_helper._generate_batch_streams(
            pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
            batches_per_stream=self.push_batches_per_stream,
            body_bytes=self.batch_body_bytes)

        grpc_channel_factory = GrpcChannelFactory()
        channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        # Streams are pushed synchronously: stub.push.future would complicate
        # the retry logic below.
        for batch_stream in bin_batch_streams:
            cur_retry = 0

            # Materialize the stream so identical data can be re-sent on retry.
            batch_stream_data = list(batch_stream)
            exception = None
            while cur_retry < max_retry_cnt:
                L.trace(
                    f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}'
                )
                try:
                    stub.push(bs_helper.generate_packet(
                        batch_stream_data, cur_retry),
                              timeout=per_stream_timeout)
                    exception = None
                    break
                except Exception as e:
                    exception = e
                    # Sleeping after the final attempt would only delay raising
                    # the error, so back off only if another attempt follows.
                    has_next_attempt = cur_retry + 1 < max_retry_cnt
                    if not has_next_attempt:
                        retry_interval = 0
                    elif cur_retry < max_retry_cnt - long_retry_cnt:
                        # Short back-off (same threshold as the partition push
                        # path): grows 2s per attempt up to 20s, plus 0-10s jitter.
                        retry_interval = round(
                            min(2 * cur_retry, 20) + random.random() * 10, 3)
                    else:
                        # Long back-off for the trailing long_retry_cnt slots.
                        retry_interval = round(300 + random.random() * 10, 3)
                    L.warning(
                        f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                        exc_info=e)
                    if has_next_attempt:
                        time.sleep(retry_interval)
                    # On UNAVAILABLE, ask the factory for a refreshed channel and
                    # rebuild the stub before the next attempt.
                    if isinstance(e,
                                  RpcError) and e.code().name == 'UNAVAILABLE':
                        channel = grpc_channel_factory.create_channel(
                            self.ctx.proxy_endpoint, refresh=True)
                        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                finally:
                    # Also runs after a successful break, hence
                    # "cur_retry={cur_retry - 1}" in the trace below.
                    cur_retry += 1
            if exception is not None:
                L.exception(
                    f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                    exc_info=exception)
                raise exception
            L.trace(
                f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}'
            )

        L.debug(
            f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}"
        )
        # Reached only when every stream was pushed; a failure raises above
        # without counting the latch down.
        self.ctx.pushing_latch.count_down()