class TransferClient(object):
    """gRPC client for the transfer service: streams batches out (``send``)
    and pulls batches in (``recv``)."""

    def __init__(self):
        self.__grpc_channel_factory = GrpcChannelFactory()
        # Disabled tuning knobs, kept for reference:
        # self.__bin_packet_len = 32 << 20
        # self.__chunk_size = 100

    @_exception_logger
    def send(self, broker, endpoint: ErEndpoint, tag):
        """Start a client-streaming ``send`` call to ``endpoint`` and return its future.

        Args:
            broker: either a generator of raw payloads (each wrapped into a
                ``TransferBatch`` whose header id is the payload's position) or
                an object handed to
                ``TransferService.transfer_batch_generator_from_broker``.
            endpoint: target transfer endpoint.
            tag: written into each generated batch header and sent as call
                metadata under ``TRANSFER_BROKER_NAME``.

        Returns:
            The future returned by ``stub.send.future``.
        """
        try:
            import types
            stub = transfer_pb2_grpc.TransferServiceStub(
                self.__grpc_channel_factory.create_channel(endpoint))
            if not isinstance(broker, types.GeneratorType):
                batches = TransferService.transfer_batch_generator_from_broker(broker, tag)
            else:
                batches = (
                    transfer_pb2.TransferBatch(
                        header=transfer_pb2.TransferHeader(id=seq, tag=tag),
                        data=payload)
                    for seq, payload in enumerate(broker))
            return stub.send.future(batches, metadata=[(TRANSFER_BROKER_NAME, tag)])
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.send')
            raise e

    @_exception_logger
    def recv(self, endpoint: ErEndpoint, tag, broker):
        """Start a server-streaming ``recv`` call for ``tag`` against ``endpoint``.

        Args:
            endpoint: source transfer endpoint.
            tag: sent in the request header and as call metadata.
            broker: if None, the raw response iterator is returned; otherwise a
                background thread drains the responses into ``broker`` and
                ``broker`` itself is returned immediately.
        """
        try:
            L.debug(f'TransferClient.recv for endpoint: {endpoint}, tag: {tag}')

            @_exception_logger
            def fill_broker(iterable: Iterable, broker):
                # Runs on the background thread. signal_write_finish() is only
                # reached when every response was put; on failure it is not
                # called and the error is logged and re-raised.
                try:
                    for item in iter(iterable):
                        broker.put(item)
                    broker.signal_write_finish()
                except Exception as e:
                    L.error(f'Fail to fill broker for tag: {tag}, endpoint: {endpoint}')
                    raise e

            stub = transfer_pb2_grpc.TransferServiceStub(
                self.__grpc_channel_factory.create_channel(endpoint))
            request = transfer_pb2.TransferBatch(
                header=transfer_pb2.TransferHeader(id=1, tag=tag))
            responses = stub.recv(request, metadata=[(TRANSFER_BROKER_NAME, tag)])

            if broker is None:
                return responses
            Thread(target=fill_broker, args=[responses, broker]).start()
            return broker
        except Exception as e:
            L.error(f'Error calling to {endpoint} in TransferClient.recv')
            raise e
Example #2
0
 def __init__(self):
     """Create this client's channel factory and, once per process, the shared
     command executor pool.

     The pool lives on ``CommandClient._executor_pool``. It is checked once
     without the lock (fast path) and again after acquiring
     ``CommandClient._executor_pool_lock``, so only the first initializer
     builds it.
     """
     self._channel_factory = GrpcChannelFactory()
     if CommandClient._executor_pool is not None:
         return
     with CommandClient._executor_pool_lock:
         if CommandClient._executor_pool is not None:
             return
         pool_type = CoreConfKeys.EGGROLL_CORE_DEFAULT_EXECUTOR_POOL.get()
         pool_size = int(
             CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get())
         CommandClient._executor_pool = create_executor_pool(
             canonical_name=pool_type,
             max_workers=pool_size,
             thread_name_prefix="command_client")
Example #3
0
class RollSiteContext:
    """Per-session RollSite context: local role/party, proxy endpoint, and
    bookkeeping that lets session exit wait for outstanding pushes."""
    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self,
                 roll_site_session_id,
                 rp_ctx: RollPairContext,
                 options: dict = None):
        """Build the context.

        Args:
            roll_site_session_id: id of this roll site session.
            rp_ctx: roll pair context whose session gets an exit task that
                waits for pending pushes.
            options: must contain "self_role", "self_party_id" and
                "proxy_endpoint" (a "host:port" string or an ``ErEndpoint``).

        Raises:
            KeyError: if a required option is missing.
            ValueError: if "proxy_endpoint" is neither str nor ErEndpoint.
        """
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx

        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        self._options = options
        endpoint = options["proxy_endpoint"]
        if isinstance(endpoint, str):
            splitted = endpoint.split(':')
            self.proxy_endpoint = ErEndpoint(host=splitted[0].strip(),
                                             port=int(splitted[1].strip()))
        elif isinstance(endpoint, ErEndpoint):
            self.proxy_endpoint = endpoint
        else:
            raise ValueError("endpoint only support str and ErEndpoint type")

        self.is_standalone = RollSiteConfKeys.EGGROLL_ROLLSITE_DEPLOY_MODE.get_with(
            options) == "standalone"
        # Standalone deployments talk to no proxy, so no stub is created.
        if self.is_standalone:
            self.stub = None
        else:
            channel = self.grpc_channel_factory.create_channel(
                self.proxy_endpoint)
            self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        self.pushing_latch = CountDownLatch(0)
        self.rp_ctx.get_session().add_exit_task(self._wait_push_complete)
        self._wait_push_exit_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_OVERALL_TIMEOUT_SEC.
            get_with(options))

        L.info(f"inited RollSiteContext: {self.__dict__}")

    def _wait_push_complete(self):
        """Session exit task: wait (bounded by the configured overall push
        timeout) for ``pushing_latch`` and log an error if pushes remain."""
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(f"running roll site exit func for er session={session_id},"
               f" roll site session id={self.roll_site_session_id}")
        residual_count = self.pushing_latch.await_latch(
            self._wait_push_exit_timeout)
        if residual_count != 0:
            L.error(
                f"exit session when not finish push: "
                f"residual_count={residual_count}, timeout={self._wait_push_exit_timeout}"
            )

    def load(self, name: str, tag: str, options: dict = None):
        """Create a ``RollSite`` for ``name``/``tag`` bound to this context.

        The context-level options are merged with ``options``; per-call
        entries win on conflicts. Neither input dict is modified.
        """
        if options is None:
            options = {}
        # dict.update() returns None, so the copy must be merged in a separate
        # statement; chaining it would hand RollSite options=None.
        final_options = self._options.copy()
        final_options.update(options)
        return RollSite(name, tag, self, options=final_options)
Example #4
0
class RollSiteWriteBatch(PairWriteBatch):
    """Write batch that serializes key/value pairs into a local buffer and
    pushes the buffered bytes to the roll site proxy as ``proxy_pb2.Packet``s."""
    grpc_channel_factory = GrpcChannelFactory()

    # TODO:0: check if secure channel needed
    def __init__(self, adapter: RollSiteAdapter, options: dict = None):
        """Bind to ``adapter`` and allocate the send buffer.

        The buffer size comes from ``options``, then the static conf, then the
        default of ``EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE``.
        """
        if options is None:
            options = {}
        self.adapter = adapter

        self.roll_site_header: ErRollSiteHeader = adapter.roll_site_header
        self.namespace = adapter.namespace
        self.name = create_store_name(self.roll_site_header)

        self.tagged_key = ''
        self.obj_type = adapter.obj_type

        self.proxy_endpoint = adapter.proxy_endpoint
        channel = self.grpc_channel_factory.create_channel(self.proxy_endpoint)
        self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)

        static_er_conf = get_static_er_conf()
        self.__bin_packet_len = int(
            options.get(
                RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.key,
                static_er_conf.get(
                    RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.key,
                    RollSiteConfKeys.EGGROLL_ROLLSITE_ADAPTER_SENDBUF_SIZE.
                    default_value)))
        self.total_written = 0

        self.ba = bytearray(self.__bin_packet_len)
        self.buffer = ArrayByteBuffer(self.ba)
        self.writer = PairBinWriter(pair_buffer=self.buffer)

        self.push_cnt = 0

        self.topic_src = proxy_pb2.Topic(
            name=self.name,
            partyId=self.roll_site_header._src_party_id,
            role=self.roll_site_header._src_role,
            callback=None)
        self.topic_dst = proxy_pb2.Topic(
            name=self.name,
            partyId=self.roll_site_header._dst_party_id,
            role=self.roll_site_header._dst_role,
            callback=None)

    def __repr__(self):
        return f'<ErRollSiteWriteBatch(' \
               f'adapter={self.adapter}, ' \
               f'roll_site_header={self.roll_site_header}, ' \
               f'namespace={self.namespace}, ' \
               f'name={self.name}, ' \
               f'obj_type={self.obj_type}, ' \
               f'proxy_endpoint={self.proxy_endpoint}) ' \
               f'at {hex(id(self))}>'

    def generate_message(self, obj, metadata):
        """Yield a single packet carrying ``obj``; bumps ``metadata.seq`` each
        time the generator is consumed (i.e. once per push attempt)."""
        data = proxy_pb2.Data(value=obj)
        metadata.seq += 1
        packet = proxy_pb2.Packet(header=metadata, body=data)
        yield packet

    # TODO:0: configurable
    def push(self, obj):
        """Send ``obj`` (bytes) to the proxy in one packet, retrying on failure.

        Up to 100 attempts are made, sleeping ``min(0.1 * attempt, 30)``
        seconds between attempts (never after the last one). The push count is
        increased only on success.

        Raises:
            GrpcCallError: if every attempt failed.
        """
        L.debug(
            f'pushing for task: {self.name}, partition id: {self.adapter.partition_id}, push cnt: {self.get_push_count()}'
        )
        task_info = proxy_pb2.Task(
            taskId=self.name,
            model=proxy_pb2.Model(name=self.adapter.roll_site_header_string,
                                  dataKey=self.namespace))

        command_test = proxy_pb2.Command()

        # TODO: build a proxy_pb2.Conf from config and attach it to the
        # metadata; a hard-coded Conf used to be constructed here but was never sent.

        metadata = proxy_pb2.Metadata(task=task_info,
                                      src=self.topic_src,
                                      dst=self.topic_dst,
                                      command=command_test,
                                      seq=0,
                                      ack=0)

        max_retry_cnt = 100
        exception = None
        for i in range(1, max_retry_cnt + 1):
            try:
                self.stub.push(self.generate_message(obj, metadata))
                exception = None
                self.increase_push_count()
                break
            except Exception as e:
                exception = e
                L.info(
                    f'caught exception in pushing {self.name}, partition_id: {self.adapter.partition_id}: {e}. retrying. current retry count: {i}, max_retry_cnt: {max_retry_cnt}'
                )
                # Back off only if another attempt will actually follow.
                if i < max_retry_cnt:
                    time.sleep(min(0.1 * i, 30))

        if exception:
            raise GrpcCallError("error in push", self.proxy_endpoint,
                                exception)

    def write(self):
        """Push everything buffered so far and start over with an empty buffer."""
        bin_data = bytes(self.ba[0:self.buffer.get_offset()])
        self.push(bin_data)
        self.buffer = ArrayByteBuffer(self.ba)
        # The writer holds a reference to the buffer; rebind it too, otherwise
        # later writes land in the stale buffer while flushes read the new one.
        self.writer = PairBinWriter(pair_buffer=self.buffer)

    def send_end(self):
        """Send the "markEnd" control packet whose ``seq`` is the push count.

        Raises:
            GrpcCallError: if the unary call fails (no retry).
        """
        L.info(f"send_end tagged_key:{self.tagged_key}")
        task_info = proxy_pb2.Task(
            taskId=self.name,
            model=proxy_pb2.Model(name=self.adapter.roll_site_header_string,
                                  dataKey=self.namespace))

        command_test = proxy_pb2.Command(name="set_status")

        metadata = proxy_pb2.Metadata(task=task_info,
                                      src=self.topic_src,
                                      dst=self.topic_dst,
                                      command=command_test,
                                      operator="markEnd",
                                      seq=self.get_push_count(),
                                      ack=0)

        packet = proxy_pb2.Packet(header=metadata)

        try:
            # TODO:0: retry and sleep for all grpc call in RollSite
            self.stub.unaryCall(packet)
        except Exception as e:
            raise GrpcCallError('send_end', self.proxy_endpoint, e)

    def close(self):
        """Flush the remaining buffered bytes, then send the end marker."""
        bin_batch = bytes(self.ba[0:self.buffer.get_offset()])
        self.push(bin_batch)
        self.send_end()
        L.info(f'closing RollSiteWriteBatch for name: {self.name}, '
               f'total push count: {self.push_cnt}')

    def put(self, k, v):
        """Append one key/value pair, flushing to the proxy when the buffer is full.

        When the writer raises ``IndexError`` (buffer full), the bytes written
        so far are pushed and the buffer is replaced by one of at least
        ``len(k) + len(v) + 1024`` bytes (the 1024 presumably covers framing
        overhead) so that an oversized pair can still be written.
        """
        if self.obj_type == 'object':
            L.debug(f"set tagged_key: {k}")
            self.tagged_key = _serdes.deserialize(k)
        try:
            self.writer.write(k, v)
        except IndexError:
            bin_batch = bytes(self.ba[0:self.buffer.get_offset()])
            self.push(bin_batch)
            # TODO:0: replace 1024 with constant
            self.ba = bytearray(
                max(self.__bin_packet_len,
                    len(k) + len(v) + 1024))
            self.buffer = ArrayByteBuffer(self.ba)
            self.writer = PairBinWriter(pair_buffer=self.buffer)
            self.writer.write(k, v)
        except Exception:
            L.error(f"Unexpected error: {sys.exc_info()[0]}")
            raise

    def increase_push_count(self):
        self.push_cnt += 1

    def get_push_count(self):
        return self.push_cnt
Example #5
0
 def __init__(self):
     """Give this client its own gRPC channel factory (``_channel_factory``)."""
     factory = GrpcChannelFactory()
     self._channel_factory = factory
Example #6
0
class CommandClient(object):
    """Client for the command service: synchronous calls plus fan-out of
    synchronous calls onto a shared class-level thread pool."""
    executor = ThreadPoolExecutor(max_workers=int(
        CoreConfKeys.EGGROLL_CORE_CLIENT_COMMAND_EXECUTOR_POOL_MAX_SIZE.get()),
                                  thread_name_prefix="command_client")

    def __init__(self):
        self._channel_factory = GrpcChannelFactory()

    def simple_sync_send(self,
                         input: RpcMessage,
                         output_type,
                         endpoint: ErEndpoint,
                         command_uri: CommandURI,
                         serdes_type=SerdesTypes.PROTOBUF):
        """Single-argument convenience wrapper around ``sync_send``.

        Returns:
            The first decoded result, or None if the call returned no results.

        Raises:
            CommandCallError: propagated from ``sync_send``.
        """
        results = self.sync_send(inputs=[input],
                                 output_types=[output_type],
                                 endpoint=endpoint,
                                 command_uri=command_uri,
                                 serdes_type=serdes_type)
        return results[0] if results else None

    def sync_send(self,
                  inputs: list,
                  output_types: list,
                  endpoint: ErEndpoint,
                  command_uri: CommandURI,
                  serdes_type=SerdesTypes.PROTOBUF):
        """Call ``command_uri`` on ``endpoint`` and decode the results.

        Args:
            inputs: messages serialized with ``_to_proto_string`` as call args.
            output_types: types whose ``from_proto_string`` decodes each result,
                paired positionally with the results (extra entries on either
                side are dropped by ``zip``).
            serdes_type: accepted for interface compatibility; not used here.

        Returns:
            List of decoded results; a None raw result stays None.

        Raises:
            CommandCallError: wrapping (and chained from) any failure.
        """
        request = None
        try:
            request = ErCommandRequest(id=time_now(),
                                       uri=command_uri._uri,
                                       args=_map_and_listify(
                                           _to_proto_string, inputs))
            start = time.time()
            L.debug(f"[CC] calling: {endpoint} {command_uri} {request}")
            _channel = self._channel_factory.create_channel(endpoint)
            _command_stub = command_pb2_grpc.CommandServiceStub(_channel)
            response = _command_stub.call(request.to_proto())
            er_response = ErCommandResponse.from_proto(response)
            elapsed = time.time() - start
            L.debug(
                f"[CC] called (elapsed: {elapsed}): {endpoint}, {command_uri}, {request}, {er_response}"
            )
            byte_results = er_response._results
            return [
                output_type.from_proto_string(raw) if raw is not None else None
                for output_type, raw in zip(output_types, byte_results)
            ]
        except Exception as e:
            L.error(
                f'Error calling to {endpoint}, command_uri: {command_uri}, req:{request}',
                exc_info=e)
            # Chain the cause so the original traceback survives the wrapping.
            raise CommandCallError(command_uri, endpoint, e) from e

    def async_call(self,
                   args,
                   output_types: list,
                   command_uri: CommandURI,
                   serdes_type=SerdesTypes.PROTOBUF,
                   callback=None):
        """Submit one ``sync_send`` per ``(inputs, endpoint)`` pair in ``args``.

        Returns:
            Futures in the same order as ``args``; ``callback``, if given, is
            attached to each via ``add_done_callback``.
        """
        futures = list()
        for inputs, endpoint in args:
            f = self.executor.submit(self.sync_send,
                                     inputs=inputs,
                                     output_types=output_types,
                                     endpoint=endpoint,
                                     command_uri=command_uri,
                                     serdes_type=serdes_type)
            if callback:
                f.add_done_callback(callback)
            futures.append(f)

        return futures
Example #7
0
def _grpc_server_options():
    """Return the gRPC server channel options shared by the egg_pair servers.

    The command server and a separately-started transfer server both use this
    list, so their settings cannot drift apart.
    """
    return [
        ('grpc.max_metadata_size',
         int(CoreConfKeys.
             EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_METADATA_SIZE.get())),
        ('grpc.max_send_message_length',
         int(CoreConfKeys.
             EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
        ('grpc.max_receive_message_length',
         int(CoreConfKeys.
             EGGROLL_CORE_GRPC_SERVER_CHANNEL_MAX_INBOUND_MESSAGE_SIZE.get())),
        ('grpc.keepalive_time_ms',
         int(CoreConfKeys.CONFKEY_CORE_GRPC_CHANNEL_KEEPALIVE_TIME_SEC.get())
         * 1000),
        ('grpc.keepalive_timeout_ms',
         int(CoreConfKeys.
             CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_TIMEOUT_SEC.get()) *
         1000),
        ('grpc.keepalive_permit_without_calls',
         int(CoreConfKeys.
             CONFKEY_CORE_GRPC_SERVER_CHANNEL_KEEPALIVE_WITHOUT_CALLS_ENABLED.
             get())),
        ('grpc.per_rpc_retry_buffer_size',
         int(CoreConfKeys.CONFKEY_CORE_GRPC_SERVER_CHANNEL_RETRY_BUFFER_SIZE.
             get())),
        ('grpc.so_reuseport', False),
    ]


def serve(args):
    """Run an egg_pair process until SIGTERM/SIGINT.

    Starts the command gRPC server on ``args.port``. If
    ``args.transfer_port == "-1"``, the transfer service shares the command
    server; otherwise it gets its own server. When ``args.cluster_manager``
    ("host:port") is set, the processor registers with a RUNNING heartbeat
    and sends a STOPPED heartbeat on shutdown.

    Raises:
        ValueError: if a cluster manager is given without ``args.session_id``.
    """
    prefix = 'v1/egg-pair'

    set_data_dir(args.data_dir)

    CommandRouter.get_instance().register(
        service_name=f"{prefix}/runTask",
        route_to_module_name="eggroll.roll_pair.egg_pair",
        route_to_class_name="EggPair",
        route_to_method_name="run_task")

    max_workers = int(
        RollPairConfKeys.
        EGGROLL_ROLLPAIR_EGGPAIR_SERVER_EXECUTOR_POOL_MAX_SIZE.get())
    executor_pool_type = CoreConfKeys.EGGROLL_CORE_DEFAULT_EXECUTOR_POOL.get()
    command_server = grpc.server(
        create_executor_pool(canonical_name=executor_pool_type,
                             max_workers=max_workers,
                             thread_name_prefix="eggpair-command-server"),
        options=_grpc_server_options())

    command_servicer = CommandServicer()
    command_pb2_grpc.add_CommandServiceServicer_to_server(
        command_servicer, command_server)

    transfer_servicer = GrpcTransferServicer()

    port = args.port
    transfer_port = args.transfer_port

    # add_insecure_port returns the port actually bound.
    port = command_server.add_insecure_port(f'[::]:{port}')

    if transfer_port == "-1":
        # Transfer service shares the command server and its port.
        transfer_server = command_server
        transfer_port = port
        transfer_pb2_grpc.add_TransferServiceServicer_to_server(
            transfer_servicer, transfer_server)
    else:
        transfer_server_max_workers = int(
            RollPairConfKeys.
            EGGROLL_ROLLPAIR_EGGPAIR_DATA_SERVER_EXECUTOR_POOL_MAX_SIZE.get())
        # Same options as the command server; keepalive_time_ms previously
        # read the keepalive-without-calls flag here by mistake.
        transfer_server = grpc.server(
            create_executor_pool(canonical_name=executor_pool_type,
                                 max_workers=transfer_server_max_workers,
                                 thread_name_prefix="transfer_server"),
            options=_grpc_server_options())
        transfer_port = transfer_server.add_insecure_port(
            f'[::]:{transfer_port}')
        transfer_pb2_grpc.add_TransferServiceServicer_to_server(
            transfer_servicer, transfer_server)
        transfer_server.start()
    pid = os.getpid()

    L.info(
        f"starting egg_pair service, port: {port}, transfer port: {transfer_port}, pid: {pid}"
    )
    command_server.start()

    cluster_manager = args.cluster_manager
    myself = None
    cluster_manager_client = None
    if cluster_manager:
        session_id = args.session_id
        server_node_id = int(args.server_node_id)
        static_er_conf = get_static_er_conf()
        static_er_conf['server_node_id'] = server_node_id

        if not session_id:
            raise ValueError('session id is missing')
        options = {SessionConfKeys.CONFKEY_SESSION_ID: args.session_id}
        myself = ErProcessor(id=int(args.processor_id),
                             server_node_id=server_node_id,
                             processor_type=ProcessorTypes.EGG_PAIR,
                             command_endpoint=ErEndpoint(host='localhost',
                                                         port=port),
                             transfer_endpoint=ErEndpoint(host='localhost',
                                                          port=transfer_port),
                             pid=pid,
                             options=options,
                             status=ProcessorStatus.RUNNING)

        cluster_manager_host, cluster_manager_port = cluster_manager.strip(
        ).split(':')

        L.info(f'egg_pair cluster_manager: {cluster_manager}')
        cluster_manager_client = ClusterManagerClient(
            options={
                ClusterManagerConfKeys.CONFKEY_CLUSTER_MANAGER_HOST:
                cluster_manager_host,
                ClusterManagerConfKeys.CONFKEY_CLUSTER_MANAGER_PORT:
                cluster_manager_port
            })
        cluster_manager_client.heartbeat(myself)

        if platform.system() == "Windows":
            t1 = threading.Thread(target=stop_processor,
                                  args=[cluster_manager_client, myself])
            t1.start()

    L.info(f'egg_pair started at port={port}, transfer_port={transfer_port}')

    run = True

    def exit_gracefully(signum, frame):
        nonlocal run
        run = False
        # Log the signal number itself; signal.getsignal() would return the
        # installed handler (this function), not the signal.
        L.info(
            f'egg_pair {args.processor_id} at port={port}, transfer_port={transfer_port}, pid={pid} receives signum={signum}, stopping gracefully.'
        )

    signal.signal(signal.SIGTERM, exit_gracefully)
    signal.signal(signal.SIGINT, exit_gracefully)

    while run:
        time.sleep(1)

    if cluster_manager:
        L.info(f'sending exit heartbeat to cm')
        myself._status = ProcessorStatus.STOPPED
        cluster_manager_client.heartbeat(myself)

    GrpcChannelFactory.shutdown_all_now()

    L.info(f'closing RocksDB open dbs')
    #todo:1: move to RocksdbAdapter and provide a cleanup method
    # NOTE(review): `del db` only unbinds the loop variable; db_dict still
    # references every handle, so nothing is released here. A proper close
    # belongs in RocksdbAdapter (see todo above).
    from eggroll.core.pair_store.rocksdb import RocksdbAdapter
    for path, db in RocksdbAdapter.db_dict.items():
        del db

    gc.collect()

    L.info(f'system metric at exit: {get_system_metric(1)}')
    L.info(
        f'egg_pair {args.processor_id} at port={port}, transfer_port={transfer_port}, pid={pid} stopped gracefully'
    )
Example #8
0
        def _push_partition(ertask):
            """Push one rollpair partition to the proxy ``endpoint`` as batch streams.

            Closure over the enclosing scope's ``rs_header``, ``rs_key``,
            ``endpoint``, ``batches_per_stream``, ``body_bytes``,
            ``max_retry_cnt``, ``long_retry_cnt`` and ``per_stream_timeout``.
            Each stream gets up to ``max_retry_cnt`` attempts. The last failure
            is logged and re-raised.
            """
            rs_header._partition_id = ertask._inputs[0]._id
            L.trace(
                f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

            from eggroll.core.grpc.factory import GrpcChannelFactory
            from eggroll.core.proto import proxy_pb2_grpc
            grpc_channel_factory = GrpcChannelFactory()

            with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
                # NOTICE AGAIN: all modifications to rs_header are limited in bs_helper.
                # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
                # Remind that python's object references are passed by value,
                # meaning the 'pointer' is copied, while the contents are modificable
                bs_helper = _BatchStreamHelper(rs_header)
                bin_batch_streams = bs_helper._generate_batch_streams(
                    pair_iter=rb,
                    batches_per_stream=batches_per_stream,
                    body_bytes=body_bytes)

                channel = grpc_channel_factory.create_channel(endpoint)
                stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                for batch_stream in bin_batch_streams:
                    # Materialize so the same stream can be replayed on retry.
                    batch_stream_data = list(batch_stream)
                    cur_retry = 0
                    exception = None
                    # NOTE(review): if max_retry_cnt <= 0 no attempt is made and
                    # the stream is skipped without error — confirm intended.
                    while cur_retry < max_retry_cnt:
                        L.trace(
                            f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                        )
                        try:
                            stub.push(bs_helper.generate_packet(
                                batch_stream_data, cur_retry),
                                      timeout=per_stream_timeout)
                            exception = None
                            break
                        except Exception as e:
                            exception = e
                            # Short jittered back-off first, then long (~5 min) back-off.
                            if cur_retry < max_retry_cnt - long_retry_cnt:
                                retry_interval = round(
                                    min(2 * cur_retry, 20) +
                                    random.random() * 10, 3)
                            else:
                                retry_interval = round(
                                    300 + random.random() * 10, 3)
                            L.warning(
                                f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                                exc_info=e)
                            # Sleeping/reconnecting after the final attempt only
                            # delays the error, so do it only if we will retry.
                            if cur_retry + 1 < max_retry_cnt:
                                time.sleep(retry_interval)
                                if isinstance(e, RpcError) and e.code(
                                ).name == 'UNAVAILABLE':
                                    channel = grpc_channel_factory.create_channel(
                                        endpoint, refresh=True)
                                    stub = proxy_pb2_grpc.DataTransferServiceStub(
                                        channel)
                        finally:
                            cur_retry += 1
                    if exception is not None:
                        L.exception(
                            f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                            exc_info=exception)
                        raise exception
                    L.trace(
                        f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
                    )

            L.trace(
                f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )
Example #9
0
    def _push_bytes(self,
                    obj,
                    rs_header: ErRollSiteHeader,
                    options: dict = None):
        """Pickle ``obj`` and push it to the proxy as a single-partition stream.

        ``rs_header`` is mutated: partition id 0, total partitions 1, and the
        optional ``serdes`` / ``wrapee_cls`` entries of ``options`` are copied
        into ``rs_header._options``. Each batch stream gets up to
        ``self.push_max_retry`` attempts. On success ``self.ctx.pushing_latch``
        is counted down.

        Raises:
            Exception: the last push error once all attempts are exhausted.
        """
        if options is None:
            options = {}
        start_time = time.time()

        rs_key = rs_header.get_rs_key()
        int_size = 4

        if L.isEnabledFor(logging.DEBUG):
            L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

        def _generate_obj_bytes(py_obj, body_bytes):
            # Yields (4-byte big-endian chunk index, chunk) over the pickled object.
            # NOTE(review): with `<=` an extra empty chunk is yielded when the
            # pickled length is an exact multiple of body_bytes — confirm the
            # receiver relies on / tolerates this before changing it.
            key_id = 0
            obj_bytes = pickle.dumps(py_obj)
            obj_bytes_len = len(obj_bytes)
            cur_pos = 0

            while cur_pos <= obj_bytes_len:
                yield key_id.to_bytes(
                    int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
                key_id += 1
                cur_pos += body_bytes

        rs_header._partition_id = 0
        rs_header._total_partitions = 1

        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        # Record wrapee_cls only when the caller actually supplied one.
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls
        # NOTICE: all modifications to rs_header are limited in bs_helper.
        # rs_header is shared by bs_helper and here. any modification in bs_helper affects this header.
        # Remind that python's object references are passed by value,
        # meaning the 'pointer' is copied, while the contents are modificable
        bs_helper = _BatchStreamHelper(rs_header=rs_header)
        bin_batch_streams = bs_helper._generate_batch_streams(
            pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
            batches_per_stream=self.push_batches_per_stream,
            body_bytes=self.batch_body_bytes)

        grpc_channel_factory = GrpcChannelFactory()
        channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        # if use stub.push.future here, retry mechanism is a problem to solve
        for batch_stream in bin_batch_streams:
            cur_retry = 0

            # Materialize so the same stream can be replayed on retry.
            batch_stream_data = list(batch_stream)
            exception = None
            while cur_retry < max_retry_cnt:
                L.trace(
                    f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}'
                )
                try:
                    stub.push(bs_helper.generate_packet(
                        batch_stream_data, cur_retry),
                              timeout=per_stream_timeout)
                    exception = None
                    break
                except Exception as e:
                    exception = e
                    # Short jittered back-off first, then long (~5 min) back-off.
                    if cur_retry <= max_retry_cnt - long_retry_cnt:
                        retry_interval = round(
                            min(2 * cur_retry, 20) + random.random() * 10, 3)
                    else:
                        retry_interval = round(300 + random.random() * 10, 3)
                    L.warning(
                        f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                        exc_info=e)
                    # Sleeping/reconnecting after the final attempt only delays
                    # the error, so do it only if another attempt follows.
                    if cur_retry + 1 < max_retry_cnt:
                        time.sleep(retry_interval)
                        if isinstance(e,
                                      RpcError) and e.code().name == 'UNAVAILABLE':
                            channel = grpc_channel_factory.create_channel(
                                self.ctx.proxy_endpoint, refresh=True)
                            stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                finally:
                    cur_retry += 1
            if exception is not None:
                L.exception(
                    f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                    exc_info=exception)
                raise exception
            L.trace(
                f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}'
            )

        L.debug(
            f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}"
        )
        self.ctx.pushing_latch.count_down()
Example #10
0
class RollSiteContext:
    """Context shared by the RollSite objects of one roll site session.

    Holds the local role / party id, the proxy endpoint, the registered
    communication implementations and a latch used to wait for in-flight
    pushes on session exit. When ``EGGROLL_ROLLSITE_PUSH_SESSION_ENABLED``
    is set, a dedicated ``ErSession`` named ``<roll_site_session_id>_push``
    is created for pushing and stopped when the owning session exits.
    """
    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self,
                 roll_site_session_id,
                 rp_ctx: RollPairContext,
                 options: dict = None):
        """Build the context and register its exit tasks.

        Args:
            roll_site_session_id: id of this roll site session; also used as
                the prefix of the optional push session id.
            rp_ctx: RollPairContext whose session receives this context's
                exit tasks.
            options: must contain ``self_role``, ``self_party_id`` and
                ``proxy_endpoint`` (a ``"host:port"`` string or an
                ErEndpoint); RollSite conf keys are also read from it.

        Raises:
            KeyError: a required option is missing.
            ValueError: ``proxy_endpoint`` is malformed or of another type.
        """
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx

        # Read every option that can fail *before* starting the push
        # session: once started, that session is only stopped by an exit
        # task registered at the end of __init__, so an early failure here
        # must not leave it running.
        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        self._options = options
        self.proxy_endpoint = self._parse_proxy_endpoint(
            options["proxy_endpoint"])
        self.is_standalone = RollSiteConfKeys.EGGROLL_ROLLSITE_DEPLOY_MODE.get_with(
            options) == "standalone"
        # Assigned before _wait_push_complete is registered as an exit task
        # so that task never runs against a half-initialized instance.
        self._wait_push_exit_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_OVERALL_TIMEOUT_SEC.
            get_with(options))

        self._registered_comm_types = dict()
        self.register_comm_type('grpc', RollSiteGrpc)

        self.push_session_enabled = RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_SESSION_ENABLED.get_with(
            options)
        if self.push_session_enabled:
            # create session for push roll_pair and object
            self._push_session = ErSession(
                session_id=roll_site_session_id + "_push",
                options=rp_ctx.get_session().get_all_options())
            self._push_rp_ctx = RollPairContext(session=self._push_session)
            L.info(
                f"push_session={self._push_session.get_session_id()} enabled")
        else:
            self._push_session = None
            self._push_rp_ctx = None

        # Finished pushes call count_down() on this latch;
        # _wait_push_complete awaits it when the session exits.
        self.pushing_latch = CountDownLatch(0)
        session = self.rp_ctx.get_session()
        session.add_exit_task(self._wait_push_complete)
        if self.push_session_enabled:
            session.add_exit_task(self._push_session.stop)

        L.info(f"inited RollSiteContext: {self.__dict__}")

    @staticmethod
    def _parse_proxy_endpoint(endpoint):
        """Return ``endpoint`` as an ErEndpoint.

        An ErEndpoint is returned unchanged. A string must have the form
        ``"host:port"``; the port is the text after the last ``:`` (so a
        bracketed IPv6 host such as ``"[::1]:9370"`` also works) and both
        parts are whitespace-stripped.

        Raises:
            ValueError: the string lacks a host or port, the port is not an
                integer, or ``endpoint`` is neither a str nor an ErEndpoint.
        """
        if isinstance(endpoint, str):
            host, sep, port = endpoint.rpartition(':')
            host, port = host.strip(), port.strip()
            if not sep or not host or not port:
                raise ValueError(
                    f"proxy_endpoint must be 'host:port', got {endpoint!r}")
            return ErEndpoint(host=host, port=int(port))
        if isinstance(endpoint, ErEndpoint):
            return endpoint
        raise ValueError("endpoint only support str and ErEndpoint type")

    def _wait_push_complete(self):
        """Exit task: wait up to the overall push timeout for pending pushes.

        Logs an error with the residual latch count if pushes are still
        outstanding when the wait ends.
        """
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(f"running roll site exit func for er session={session_id},"
               f" roll site session id={self.roll_site_session_id}")
        residual_count = self.pushing_latch.await_latch(
            self._wait_push_exit_timeout)
        if residual_count != 0:
            L.error(
                f"exit session when not finish push: "
                f"residual_count={residual_count}, timeout={self._wait_push_exit_timeout}"
            )

    def load(self, name: str, tag: str, options: dict = None):
        """Create a RollSite for ``name``/``tag``.

        The context's options are copied and overlaid with ``options``; the
        context's own options are not modified.
        """
        if options is None:
            options = {}
        final_options = self._options.copy()
        final_options.update(options)
        return RollSite(name, tag, self, options=final_options)

    def register_comm_type(self, name, clazz):
        """Register (or replace) the implementation class for ``name``."""
        self._registered_comm_types[name] = clazz

    def get_comm_impl(self, name):
        """Return the implementation class registered for ``name``.

        Raises:
            ValueError: nothing is registered under ``name``.
        """
        try:
            return self._registered_comm_types[name]
        except KeyError:
            raise ValueError(f'comm_type={name} is not registered') from None
Example #11
0
class RollSiteContext:
    """Context shared by the RollSite objects of one roll site session.

    Outside standalone deploy mode it opens a gRPC stub to the proxy and
    registers the (roll site session, er session) pair with it. On session
    exit it waits for pending push tasks to finish.
    """
    grpc_channel_factory = GrpcChannelFactory()

    def __init__(self,
                 roll_site_session_id,
                 rp_ctx: RollPairContext,
                 options: dict = None):
        """Build the context.

        Args:
            roll_site_session_id: id of this roll site session.
            rp_ctx: RollPairContext whose session provides the deploy mode
                and receives the exit task.
            options: must contain ``self_role``, ``self_party_id`` and
                ``proxy_endpoint``.

        Raises:
            KeyError: a required option is missing.
            GrpcCallError: registering the session pair with the proxy failed.
        """
        if options is None:
            options = {}
        self.roll_site_session_id = roll_site_session_id
        self.rp_ctx = rp_ctx

        self.role = options["self_role"]
        self.party_id = str(options["self_party_id"])
        self.proxy_endpoint = options["proxy_endpoint"]
        self.is_standalone = self.rp_ctx.get_session().get_option(
            SessionConfKeys.CONFKEY_SESSION_DEPLOY_MODE) == "standalone"
        if self.is_standalone:
            # No proxy in standalone mode.
            self.stub = None
        else:
            channel = self.grpc_channel_factory.create_channel(
                self.proxy_endpoint)
            self.stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
            self.init_job_session_pair(self.roll_site_session_id,
                                       self.rp_ctx.session_id)

        # NOTE(review): presumably incremented/decremented by push tasks
        # elsewhere; only read here, in push_complete.
        self.pushing_task_count = 0
        self.rp_ctx.get_session().add_exit_task(self.push_complete)

        L.info(f"inited RollSiteContext: {self.__dict__}")

    def push_complete(self):
        """Exit task: poll until ``pushing_task_count`` reaches zero.

        Sleeps ``min(0.1 * try_count, 60)`` seconds between polls and gives
        up with a warning after ``max_try_count`` sleeps. The count is
        re-checked after every sleep, including the last one, so tasks that
        finish during the final wait are reported as finished.
        """
        session_id = self.rp_ctx.get_session().get_session_id()
        L.info(
            f"running roll site exit func for er session: {session_id}, roll site session id: {self.roll_site_session_id}"
        )
        try_count = 0
        max_try_count = 800
        while self.pushing_task_count:
            if try_count >= max_try_count:
                L.warn(
                    f"try times reach {max_try_count} for session: {session_id}, exiting"
                )
                return
            L.info(
                f"session: {session_id} "
                f"waiting for all push tasks complete. "
                f"current try_count: {try_count}, "
                f"current pushing task count: {self.pushing_task_count}")
            try_count += 1
            time.sleep(min(0.1 * try_count, 60))
        L.info(f"session: {session_id} finishes all pushing tasks")

    # todo:1: add options?
    def load(self, name: str, tag: str, options: dict = None):
        """Create a RollSite for ``name``/``tag`` with ``options`` (default {})."""
        if options is None:
            options = {}
        return RollSite(name, tag, self, options=options)

    # todo:1: try-except as decorator
    def init_job_session_pair(self, roll_site_session_id, er_session_id):
        """Tell the proxy which er session backs ``roll_site_session_id``.

        Sends one ``init_job_session_pair`` unary packet whose src and dst
        topics are both this party/role.

        Raises:
            GrpcCallError: building or sending the packet failed; the original
                exception is chained as ``__cause__``.
        """
        try:
            task_info = proxy_pb2.Task(model=proxy_pb2.Model(
                name=roll_site_session_id,
                dataKey=bytes(er_session_id, encoding='utf8')))
            topic_src = proxy_pb2.Topic(name="init_job_session_pair",
                                        partyId=self.party_id,
                                        role=self.role,
                                        callback=None)
            topic_dst = proxy_pb2.Topic(name="init_job_session_pair",
                                        partyId=self.party_id,
                                        role=self.role,
                                        callback=None)
            command_test = proxy_pb2.Command(name="init_job_session_pair")

            metadata = proxy_pb2.Metadata(task=task_info,
                                          src=topic_src,
                                          dst=topic_dst,
                                          command=command_test,
                                          operator="init_job_session_pair",
                                          seq=0,
                                          ack=0)
            packet = proxy_pb2.Packet(header=metadata)

            self.stub.unaryCall(packet)
            L.info(
                f"send RollSiteContext init to Proxy: {to_one_line_string(packet)}"
            )
        except Exception as e:
            raise GrpcCallError("init_job_session_pair", self.proxy_endpoint,
                                e) from e