Example #1
def _get_remote_obj_store_table(parties, rollsite):
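    """Build an ErRollSiteHeader for each (role, party_id) pair and load the
    corresponding object store table from the RollPair context."""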
    from eggroll.roll_site.utils.roll_site_utils import create_store_name
    from eggroll.core.transfer_model import ErRollSiteHeader
    tables = []
    for role_party_id in parties:
        _role = role_party_id[0]
        _party_id = str(role_party_id[1])

        _options = {}
        obj_type = 'object'
        roll_site_header = ErRollSiteHeader(
            roll_site_session_id=rollsite.roll_site_session_id,
            name=rollsite.name,
            tag=rollsite.tag,
            src_role=rollsite.local_role,
            src_party_id=rollsite.party_id,
            dst_role=_role,
            dst_party_id=_party_id,
            data_type=obj_type,
            options=_options)
        _tagged_key = create_store_name(roll_site_header)
        namespace = rollsite.roll_site_session_id
        tables.append(rollsite.ctx.rp_ctx.load(namespace, _tagged_key))

    return tables
Example #2
    def __init__(self, options: dict = None):
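        """Adapt an ErPartition for RollSite transfer: cache its store
        locator, partition id and namespace, and mark the adapter writable
        when a roll_site_header (with proxy endpoint and object type) is
        supplied in the options."""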
        if options is None:
            options = {}

        super().__init__(options)

        er_partition = options['er_partition']
        self.partition = er_partition
        self.store_locator = er_partition._store_locator
        self.partition_id = er_partition._id

        self.namespace = self.store_locator._namespace

        #_store_type = StoreTypes.ROLLPAIR_ROLLSITE
        # self._store_locator = meta_pb2.StoreLocator(storeType=_store_type,
        #                                             namespace=self.namespace,
        #                                             name=self.store_locator._name,
        #                                             partitioner=self.store_locator._partitioner,
        #                                             serdes=self.store_locator._serdes,
        #                                             totalPartitions=self.store_locator._total_partitions)

        self.roll_site_header_string = options.get('roll_site_header', None)
        self.is_writable = False
        if self.roll_site_header_string:
            self.roll_site_header = ErRollSiteHeader.from_proto_string(
                self.roll_site_header_string.encode(stringify_charset))
            self.proxy_endpoint = ErEndpoint.from_proto_string(
                options['proxy_endpoint'].encode(stringify_charset))
            self.obj_type = options['obj_type']
            self.is_writable = True

            L.info(
                f"writable RollSiteAdapter: {self.namespace}, {self.partition_id}. proxy_endpoint: {self.proxy_endpoint}, partition: {self.partition}"
            )
Example #3
    def push(self, obj, parties: list = None, options: dict = None):
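        """Asynchronously push obj to every destination (role, party_id):
        build an ErRollSiteHeader per party, then dispatch a worker thread
        for either a RollPair or a plain object. Returns a list of futures."""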
        if options is None:
            options = {}

        futures = []
        for role_party_id in parties:
            self.ctx.pushing_latch.count_up()
            dst_role = role_party_id[0]
            dst_party_id = str(role_party_id[1])
            data_type = 'rollpair' if isinstance(obj, RollPair) else 'object'
            rs_header = ErRollSiteHeader(
                roll_site_session_id=self.roll_site_session_id,
                name=self.name,
                tag=self.tag,
                src_role=self.local_role,
                src_party_id=self.party_id,
                dst_role=dst_role,
                dst_party_id=dst_party_id,
                data_type=data_type)

            if isinstance(obj, RollPair):
                future = self._run_thread(self._impl_instance._push_rollpair,
                                          obj, rs_header, options)
            else:
                future = self._run_thread(self._impl_instance._push_bytes, obj,
                                          rs_header, options)
            futures.append(future)
        return futures
Example #4
    def count_batch(self, rs_header: ErRollSiteHeader, batch_pairs):
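        """Record the pair count of one received batch: latch the first
        header seen, then update the per-batch and per-stream counters."""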
        L.trace(f'count batch. rs_key={rs_header.get_rs_key()}, rs_header={rs_header}, batch_pairs={batch_pairs}')
        batch_seq_id = rs_header._batch_seq
        stream_seq_id = rs_header._stream_seq
        if self._rs_header is None:
            self._rs_header = rs_header
            self._rs_key = rs_header.get_rs_key()
            L.debug(f"header arrived. rs_key={rs_header.get_rs_key()}, rs_header={rs_header}")
            self._header_arrive_event.set()
        self._batch_seq_to_pair_counter[batch_seq_id] = batch_pairs
        self._stream_seq_to_pair_counter[stream_seq_id] += batch_pairs
        self._stream_seq_to_batch_seq[stream_seq_id] = batch_seq_id
Example #5
    def pull(self, parties: list = None):
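        """For each (src_role, src_party_id), build a get_status probe packet
        and submit a receive task to the executor pool. Returns futures."""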
        self._pull_start_wall_time = time.time()
        self._pull_start_cpu_time = time.perf_counter()
        futures = []
        for src_role, src_party_id in parties:
            src_party_id = str(src_party_id)
            roll_site_header = ErRollSiteHeader(
                roll_site_session_id=self.roll_site_session_id,
                name=self.name,
                tag=self.tag,
                src_role=src_role,
                src_party_id=src_party_id,
                dst_role=self.local_role,
                dst_party_id=self.party_id)
            _tagged_key = create_store_name(roll_site_header)

            name = _tagged_key

            model = proxy_pb2.Model(name=_stringify(roll_site_header))
            task_info = proxy_pb2.Task(taskId=name, model=model)
            topic_src = proxy_pb2.Topic(name="get_status",
                                        partyId=src_party_id,
                                        role=src_role,
                                        callback=None)
            topic_dst = proxy_pb2.Topic(name="get_status",
                                        partyId=self.party_id,
                                        role=self.local_role,
                                        callback=None)
            get_status_command = proxy_pb2.Command(name="get_status")
            conf_test = proxy_pb2.Conf(overallTimeout=1000,
                                       completionWaitTimeout=1000,
                                       packetIntervalTimeout=1000,
                                       maxRetries=10)

            metadata = proxy_pb2.Metadata(task=task_info,
                                          src=topic_src,
                                          dst=topic_dst,
                                          command=get_status_command,
                                          operator="getStatus",
                                          seq=0,
                                          ack=0)

            packet = proxy_pb2.Packet(header=metadata)
            namespace = self.roll_site_session_id
            L.info(
                f"pulling prepared tagged_key: {_tagged_key}, packet:{to_one_line_string(packet)}"
            )
            futures.append(
                RollSite.receive_exeutor_pool.submit(RollSite._thread_receive,
                                                     self, packet, namespace,
                                                     roll_site_header))

        return futures
Example #6
    def run(self):
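        """Drain one batch stream under the per-tag partition lock: read
        batches from the transfer broker into the partition's write batch,
        count the pairs per batch, and mark the stream done once a header
        with FINISH_STATUS arrives."""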
        # batch streams must be executed serially, and reinitialized.
        # TODO:0:  remove lock to bss
        rs_header = None
        with PutBatchTask._partition_lock[self.tag]:   # tag includes partition info in tag generation
            L.trace(f"do_store start for tag={self.tag}, partition_id={self.partition._id}")
            bss = _BatchStreamStatus.get_or_create(self.tag)
            try:
                broker = TransferService.get_or_create_broker(self.tag, write_signals=1)

                iter_wait = 0
                iter_timeout = int(CoreConfKeys.EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC.get())

                batch = None
                batch_get_interval = 0.1
                with create_adapter(self.partition) as db, db.new_batch() as wb:
                    #for batch in broker:
                    while not broker.is_closable():
                        try:
                            batch = broker.get(block=True, timeout=batch_get_interval)
                        except queue.Empty as e:
                            iter_wait += batch_get_interval
                            if iter_wait > iter_timeout:
                                raise TimeoutError(f'timeout in PutBatchTask.run. tag={self.tag}, iter_timeout={iter_timeout}, iter_wait={iter_wait}')
                            else:
                                continue
                        except BrokerClosed as e:
                            continue

                        iter_wait = 0
                        rs_header = ErRollSiteHeader.from_proto_string(batch.header.ext)
                        batch_pairs = 0
                        if batch.data:
                            bin_data = ArrayByteBuffer(batch.data)
                            reader = PairBinReader(pair_buffer=bin_data, data=batch.data)
                            for k_bytes, v_bytes in reader.read_all():
                                wb.put(k_bytes, v_bytes)
                                batch_pairs += 1
                        bss.count_batch(rs_header, batch_pairs)
                        # TODO:0
                        bss._data_type = rs_header._data_type
                        if rs_header._stage == FINISH_STATUS:
                            bss.set_done(rs_header)  # starting from 0

                bss.check_finish()
                # TransferService.remove_broker(tag) will be called once the get_status phase finishes or an exception is raised
            except Exception as e:
                L.exception(f'_run_put_batch error, tag={self.tag}, '
                        f'rs_key={rs_header.get_rs_key() if rs_header is not None else None}, rs_header={rs_header}')
                raise e
            finally:
                TransferService.remove_broker(self.tag)
Example #7
    def __init__(self, options: dict = None):
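        """Same adapter initialization as above, extended to record the
        partition id in the header options and to read retry and timeout
        limits from RollSiteConfKeys."""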
        if options is None:
            options = {}

        super().__init__(options)

        er_partition = options['er_partition']
        self.partition = er_partition
        self.store_locator = er_partition._store_locator
        self.partition_id = er_partition._id

        self.namespace = self.store_locator._namespace

        #_store_type = StoreTypes.ROLLPAIR_ROLLSITE
        # self._store_locator = meta_pb2.StoreLocator(storeType=_store_type,
        #                                             namespace=self.namespace,
        #                                             name=self.store_locator._name,
        #                                             partitioner=self.store_locator._partitioner,
        #                                             serdes=self.store_locator._serdes,
        #                                             totalPartitions=self.store_locator._total_partitions)

        self.roll_site_header_string = options.get('roll_site_header', None)
        self.is_writable = False
        if self.roll_site_header_string:
            self.roll_site_header = ErRollSiteHeader.from_proto_string(
                self.roll_site_header_string.encode(stringify_charset))
            self.roll_site_header._options['partition_id'] = self.partition_id
            self.proxy_endpoint = ErEndpoint.from_proto_string(
                options['proxy_endpoint'].encode(stringify_charset))
            self.obj_type = options['obj_type']
            self.is_writable = True

            L.trace(
                f"writable RollSiteAdapter: {self.namespace}, partition_id={self.partition_id}. proxy_endpoint={self.proxy_endpoint}, partition={self.partition}"
            )

        self.unarycall_max_retry_cnt = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_UNARYCALL_CLIENT_MAX_RETRY.
            get_with(options))
        self.push_max_retry_cnt = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PUSH_CLIENT_MAX_RETRY.get_with(
                options))
        self.push_overall_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_OVERALL_TIMEOUT_SEC.get_with(
                options))
        self.push_completion_wait_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_COMPLETION_WAIT_TIMEOUT_SEC.
            get_with(options))
        self.push_packet_interval_timeout = int(
            RollSiteConfKeys.EGGROLL_ROLLSITE_PACKET_INTERVAL_TIMEOUT_SEC.
            get_with(options))
Example #8
    def pull(self, parties: list = None):
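        """Submit one _pull_one task per (src_role, src_party_id) to the
        receive executor pool and return the futures."""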
        futures = []
        for src_role, src_party_id in parties:
            src_party_id = str(src_party_id)
            rs_header = ErRollSiteHeader(
                roll_site_session_id=self.roll_site_session_id,
                name=self.name,
                tag=self.tag,
                src_role=src_role,
                src_party_id=src_party_id,
                dst_role=self.local_role,
                dst_party_id=self.party_id)
            futures.append(
                self._receive_executor_pool.submit(self._pull_one, rs_header))
        return futures
Example #9
    def _pull_one(self, rs_header: ErRollSiteHeader, options: dict = None):
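        """Pull the transfer described by rs_header: wait for its header to
        arrive, poll the per-partition statuses until every partition is
        finished (failing fast when progress stalls), then return either the
        RollPair or, for 'object' transfers, the reassembled pickled object."""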
        if options is None:
            options = {}

        start_time = time.time()
        rs_key = rp_name = rs_header.get_rs_key()
        rp_namespace = self.roll_site_session_id
        transfer_tag_prefix = "putBatch-" + rs_header.get_rs_key() + "-"
        last_total_batches = None
        last_cur_pairs = -1
        pull_attempts = 0
        data_type = None
        L.debug(f'pulling rs_key={rs_key}')
        try:
            # make sure rollpair already created
            pull_header_interval = self.pull_header_interval
            pull_header_timeout = self.pull_header_timeout  # skips pickling self
            pull_interval = self.pull_interval  # skips pickling self

            def get_partition_status(task):
                put_batch_task = PutBatchTask(
                    transfer_tag_prefix + str(task._inputs[0]._id), None)
                return put_batch_task.get_status(pull_interval)

            def get_status(roll_site):
                pull_status = {}
                total_pairs = 0
                total_batches = 0
                all_finished = True

                final_options = options.copy()
                final_options['create_if_missing'] = False
                store = roll_site.ctx.rp_ctx.load(name=rp_name,
                                                  namespace=rp_namespace,
                                                  options=final_options)

                if store is None:
                    raise ValueError(
                        f'illegal state for rp_name={rp_name}, rp_namespace={rp_namespace}'
                    )

                all_status = store.with_stores(
                    get_partition_status,
                    options={"__op": "get_partition_status"})

                for part_id, part_status in all_status:
                    if not part_status.is_finished:
                        all_finished = False
                    pull_status[part_id] = part_status
                    total_batches += part_status.total_batches
                    total_pairs += part_status.total_pairs

                return pull_status, all_finished, total_batches, total_pairs

            def clear_status(task):
                return PutBatchTask(transfer_tag_prefix +
                                    str(task._inputs[0]._id)).clear_status()

            wait_time = 0
            header_response = None
            while wait_time < pull_header_timeout and \
                    (header_response is None or not isinstance(header_response[0][1], ErRollSiteHeader)):
                final_options = options.copy()
                final_options['create_if_missing'] = True
                final_options['total_partitions'] = 1
                header_response = self.ctx.rp_ctx.load(name=STATUS_TABLE_NAME,
                                                       namespace=rp_namespace,
                                                       options=final_options) \
                    .with_stores(lambda x: PutBatchTask(transfer_tag_prefix + "0").get_header(pull_header_interval), options={"__op": "pull_header"})
                wait_time += pull_header_interval

                #pull_status, all_finished, total_batches, total_pairs = stat_all_status(self)
                L.debug(
                    f"roll site get header_response: rs_key={rs_key}, rs_header={rs_header}, wait_time={wait_time}"
                )

            if header_response is None or not isinstance(
                    header_response[0][1], ErRollSiteHeader):
                raise IOError(
                    f"roll site pull header failed: rs_key={rs_key}, rs_header={rs_header}, timeout={self.pull_header_timeout}"
                )
            else:
                header: ErRollSiteHeader = header_response[0][1]
                # TODO:0: push_bytes has only one partition, which means it is already finished; get_status is unnecessary
                data_type = header._data_type
                L.debug(
                    f"roll site pull header successful: rs_key={rs_key}, rs_header={header}"
                )

            pull_status = {}
            for cur_retry in range(self.pull_max_retry):
                pull_attempts = cur_retry
                pull_status, all_finished, total_batches, cur_pairs = get_status(
                    self)

                if not all_finished:
                    L.debug(
                        f'getting status NOT finished for rs_key={rs_key}, '
                        f'rs_header={rs_header}, '
                        f'cur_status={pull_status}, '
                        f'attempts={pull_attempts}, '
                        f'cur_pairs={cur_pairs}, '
                        f'last_cur_pairs={last_cur_pairs}, '
                        f'total_batches={total_batches}, '
                        f'last_total_batches={last_total_batches}, '
                        f'elapsed={time.time() - start_time}')
                    if last_cur_pairs == cur_pairs and cur_pairs > 0:
                        raise IOError(
                            f"roll site pull waiting failed because there is no updated progress: rs_key={rs_key}, "
                            f"rs_header={rs_header}, pull_status={pull_status}, last_cur_pairs={last_cur_pairs}, cur_pairs={cur_pairs}"
                        )
                else:
                    L.debug(
                        f"getting status finished for rs_key={rs_key}, rs_header={rs_header}, pull_status={pull_status}, cur_pairs={cur_pairs}, total_batches={total_batches}"
                    )
                    rp = self.ctx.rp_ctx.load(name=rp_name,
                                              namespace=rp_namespace)

                    clear_future = self._receive_executor_pool.submit(
                        rp.with_stores,
                        clear_status,
                        options={"__op": "clear_status"})
                    if data_type == "object":
                        result = pickle.loads(b''.join(
                            map(
                                lambda t: t[1],
                                sorted(rp.get_all(),
                                       key=lambda x: int.from_bytes(
                                           x[0], "big")))))
                        rp.destroy()
                        L.debug(
                            f"pulled object: rs_key={rs_key}, rs_header={rs_header}, is_none={result is None}, "
                            f"elapsed={time.time() - start_time}")
                    else:
                        result = rp

                        if L.isEnabledFor(logging.DEBUG):
                            L.debug(
                                f"pulled roll_pair: rs_key={rs_key}, rs_header={rs_header}, rp.count={rp.count()}, "
                                f"elapsed={time.time() - start_time}")

                    clear_future.result()
                    return result
                last_total_batches = total_batches
                last_cur_pairs = cur_pairs
            raise IOError(
                f"roll site pull failed. max try exceeded: {self.pull_max_retry}, rs_key={rs_key}, "
                f"rs_header={rs_header}, pull_status={pull_status}")
        except Exception as e:
            L.exception(
                f"fatal error: when pulling rs_key={rs_key}, rs_header={rs_header}, attempts={pull_attempts}"
            )
            raise e
Example #10
    def _push_rollpair(self,
                       rp: RollPair,
                       rs_header: ErRollSiteHeader,
                       options: dict = None):
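        """Push a RollPair partition by partition: serialize each partition's
        pairs into batch streams and send them through the proxy gRPC
        endpoint, with bounded retries and a channel refresh on UNAVAILABLE."""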
        if options is None:
            options = {}
        rs_key = rs_header.get_rs_key()
        if L.isEnabledFor(logging.DEBUG):
            L.debug(
                f"pushing rollpair: rs_key={rs_key}, rs_header={rs_header}, rp.count={rp.count()}"
            )
        start_time = time.time()

        rs_header._total_partitions = rp.get_partitions()
        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls

        batches_per_stream = self.push_batches_per_stream
        body_bytes = self.batch_body_bytes
        endpoint = self.ctx.proxy_endpoint
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        def _push_partition(ertask):
            rs_header._partition_id = ertask._inputs[0]._id
            L.trace(
                f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

            from eggroll.core.grpc.factory import GrpcChannelFactory
            from eggroll.core.proto import proxy_pb2_grpc
            grpc_channel_factory = GrpcChannelFactory()

            with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
                # NOTICE AGAIN: all modifications to rs_header are confined to bs_helper.
                # rs_header is shared by bs_helper and this scope: any modification in bs_helper affects this header.
                # Remember that Python passes object references by value,
                # i.e. the 'pointer' is copied while the referenced contents remain modifiable.
                bs_helper = _BatchStreamHelper(rs_header)
                bin_batch_streams = bs_helper._generate_batch_streams(
                    pair_iter=rb,
                    batches_per_stream=batches_per_stream,
                    body_bytes=body_bytes)

                channel = grpc_channel_factory.create_channel(endpoint)
                stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                for batch_stream in bin_batch_streams:
                    batch_stream_data = list(batch_stream)
                    cur_retry = 0
                    exception = None
                    while cur_retry < max_retry_cnt:
                        L.trace(
                            f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                        )
                        try:
                            stub.push(bs_helper.generate_packet(
                                batch_stream_data, cur_retry),
                                      timeout=per_stream_timeout)
                            exception = None
                            break
                        except Exception as e:
                            if cur_retry < max_retry_cnt - long_retry_cnt:
                                retry_interval = round(
                                    min(2 * cur_retry, 20) +
                                    random.random() * 10, 3)
                            else:
                                retry_interval = round(
                                    300 + random.random() * 10, 3)
                            L.warn(
                                f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                                exc_info=e)
                            time.sleep(retry_interval)
                            if isinstance(e, RpcError) and e.code(
                            ).name == 'UNAVAILABLE':
                                channel = grpc_channel_factory.create_channel(
                                    endpoint, refresh=True)
                                stub = proxy_pb2_grpc.DataTransferServiceStub(
                                    channel)

                            exception = e
                        finally:
                            cur_retry += 1
                    if exception is not None:
                        L.exception(
                            f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                            exc_info=exception)
                        raise exception
                    L.trace(
                        f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
                    )

            L.trace(
                f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

        rp.with_stores(_push_partition, options={"__op": "push_partition"})
        if L.isEnabledFor(logging.DEBUG):
            L.debug(
                f"pushed rollpair: rs_key={rs_key}, rs_header={rs_header}, count={rp.count()}, elapsed={time.time() - start_time}"
            )
        self.ctx.pushing_latch.count_down()
Example #11
    def _push_bytes(self,
                    obj,
                    rs_header: ErRollSiteHeader,
                    options: dict = None):
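        """Push a plain Python object: pickle it, slice the bytes into
        sequence-keyed chunks forming a single-partition stream, and send the
        batch streams through the proxy gRPC endpoint with bounded retries."""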
        if options is None:
            options = {}
        start_time = time.time()

        rs_key = rs_header.get_rs_key()
        int_size = 4

        if L.isEnabledFor(logging.DEBUG):
            L.debug(f"pushing object: rs_key={rs_key}, rs_header={rs_header}")

        def _generate_obj_bytes(py_obj, body_bytes):
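            # keys are 4-byte big-endian sequence ids so the receiver can
            # reassemble the chunks in order; note that '<=' also emits a
            # final empty chunk when the length is an exact multiple of body_bytes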
            key_id = 0
            obj_bytes = pickle.dumps(py_obj)
            obj_bytes_len = len(obj_bytes)
            cur_pos = 0

            while cur_pos <= obj_bytes_len:
                yield key_id.to_bytes(
                    int_size, "big"), obj_bytes[cur_pos:cur_pos + body_bytes]
                key_id += 1
                cur_pos += body_bytes

        rs_header._partition_id = 0
        rs_header._total_partitions = 1

        serdes = options.get('serdes', None)
        if serdes is not None:
            rs_header._options['serdes'] = serdes

        wrapee_cls = options.get('wrapee_cls', None)
        if wrapee_cls is not None:
            rs_header._options['wrapee_cls'] = wrapee_cls
        # NOTICE: all modifications to rs_header are confined to bs_helper.
        # rs_header is shared by bs_helper and this scope: any modification in bs_helper affects this header.
        # Remember that Python passes object references by value,
        # i.e. the 'pointer' is copied while the referenced contents remain modifiable.
        bs_helper = _BatchStreamHelper(rs_header=rs_header)
        bin_batch_streams = bs_helper._generate_batch_streams(
            pair_iter=_generate_obj_bytes(obj, self.batch_body_bytes),
            batches_per_stream=self.push_batches_per_stream,
            body_bytes=self.batch_body_bytes)

        grpc_channel_factory = GrpcChannelFactory()
        channel = grpc_channel_factory.create_channel(self.ctx.proxy_endpoint)
        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
        max_retry_cnt = self.push_max_retry
        long_retry_cnt = self.push_long_retry
        per_stream_timeout = self.push_per_stream_timeout

        # if stub.push.future were used here, the retry mechanism would be a problem to solve
        for batch_stream in bin_batch_streams:
            cur_retry = 0

            batch_stream_data = list(batch_stream)
            exception = None
            while cur_retry < max_retry_cnt:
                L.trace(
                    f'pushing object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry}'
                )
                try:
                    stub.push(bs_helper.generate_packet(
                        batch_stream_data, cur_retry),
                              timeout=per_stream_timeout)
                    exception = None
                    break
                except Exception as e:
                    if cur_retry <= max_retry_cnt - long_retry_cnt:
                        retry_interval = round(
                            min(2 * cur_retry, 20) + random.random() * 10, 3)
                    else:
                        retry_interval = round(300 + random.random() * 10, 3)
                    L.warn(
                        f"push object error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                        exc_info=e)
                    time.sleep(retry_interval)
                    if isinstance(e,
                                  RpcError) and e.code().name == 'UNAVAILABLE':
                        channel = grpc_channel_factory.create_channel(
                            self.ctx.proxy_endpoint, refresh=True)
                        stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                    exception = e
                finally:
                    cur_retry += 1
            if exception is not None:
                L.exception(
                    f"push object failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                    exc_info=exception)
                raise exception
            L.trace(
                f'pushed object stream. rs_key={rs_key}, rs_header={rs_header}, cur_retry={cur_retry - 1}'
            )

        L.debug(
            f"pushed object: rs_key={rs_key}, rs_header={rs_header}, is_none={obj is None}, elapsed={time.time() - start_time}"
        )
        self.ctx.pushing_latch.count_down()
Example #12
    def push(self, obj, parties: list = None):
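        """Push obj to each destination party: plain objects are first staged
        in a RollPair under their tagged key, then map_values relays the data
        either via the local status table (standalone) or through a rollsite
        store (cluster)."""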
        L.info(
            f"pushing: self:{self.__dict__}, obj_type:{type(obj)}, parties:{parties}"
        )
        self._push_start_wall_time = time.time()
        self._push_start_cpu_time = time.perf_counter()
        futures = []
        for role_party_id in parties:
            self.ctx.pushing_task_count += 1
            _role = role_party_id[0]
            _party_id = str(role_party_id[1])

            _options = {}
            obj_type = 'rollpair' if isinstance(obj, RollPair) else 'object'
            roll_site_header = ErRollSiteHeader(
                roll_site_session_id=self.roll_site_session_id,
                name=self.name,
                tag=self.tag,
                src_role=self.local_role,
                src_party_id=self.party_id,
                dst_role=_role,
                dst_party_id=_party_id,
                data_type=obj_type,
                options=_options)
            _tagged_key = create_store_name(roll_site_header)
            L.debug(f"pushing start party:{type(obj)}, {_tagged_key}")
            namespace = self.roll_site_session_id

            if isinstance(obj, RollPair):
                rp = obj
            else:
                rp = self.ctx.rp_ctx.load(namespace, _tagged_key)
                rp.put(_tagged_key, obj)
            rp.disable_gc()
            L.info(f"pushing prepared: {type(obj)}, tag_key:{_tagged_key}")

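            # standalone mode records the object in the status table directly;
            # cluster mode relays it via map_values into a ROLLPAIR_ROLLSITE
            # store whose name encodes the destination host, port and object type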
            def map_values(_tagged_key, is_standalone, roll_site_header):
                if is_standalone:
                    dst_name = _tagged_key
                    store_type = rp.get_store_type()
                else:
                    dst_name = DELIM.join([
                        _tagged_key, self.dst_host,
                        str(self.dst_port), obj_type
                    ])
                    store_type = StoreTypes.ROLLPAIR_ROLLSITE
                if is_standalone:
                    status_rp = self.ctx.rp_ctx.load(
                        namespace,
                        STATUS_TABLE_NAME + DELIM + self.roll_site_session_id,
                        options=_options)
                    status_rp.disable_gc()
                    if isinstance(obj, RollPair):
                        status_rp.put(_tagged_key,
                                      (obj_type.encode("utf-8"), rp.get_name(),
                                       rp.get_namespace()))
                    else:
                        status_rp.put(
                            _tagged_key,
                            (obj_type.encode("utf-8"), dst_name, namespace))
                else:
                    store = rp.get_store()
                    store_locator = store._store_locator
                    new_store_locator = ErStoreLocator(
                        store_type=store_type,
                        namespace=namespace,
                        name=dst_name,
                        total_partitions=store_locator._total_partitions,
                        partitioner=store_locator._partitioner,
                        serdes=store_locator._serdes)

                    # TODO:0: move options from job to store when database modification finished

                    options = {
                        "roll_site_header": roll_site_header,
                        "proxy_endpoint": self.ctx.proxy_endpoint,
                        "obj_type": obj_type
                    }

                    if isinstance(obj, RollPair):
                        roll_site_header._options['total_partitions'] = \
                            obj.get_store()._store_locator._total_partitions
                        L.info(
                            f"RollSite.push: pushing {roll_site_header}, type: RollPair, count: {obj.count()}"
                        )
                    else:
                        L.info(
                            f"RollSite.push: pushing {roll_site_header}, type: object"
                        )
                    rp.map_values(
                        lambda v: v,
                        output=ErStore(store_locator=new_store_locator),
                        options=options)

                L.info(
                    f"RollSite.push: push {roll_site_header} done. type:{type(obj)}"
                )
                return _tagged_key

            future = RollSite.receive_exeutor_pool.submit(
                map_values, _tagged_key, self._is_standalone, roll_site_header)
            if not self._is_standalone and (obj_type == 'object'
                                            or obj_type == b'object'):
                tmp_rp = rp
            else:
                tmp_rp = None

            future.add_done_callback(
                functools.partial(self._push_callback, tmp_rp=tmp_rp))
            futures.append(future)

        return futures
Example #13
    def push(self, obj, parties: list = None):
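        """Variant of push that detects standalone deployment inside
        map_values and submits the relay work to a process pool."""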
        L.info(f"pushing: self:{self.__dict__}, obj_type:{type(obj)}, parties:{parties}")
        self._push_start_wall_time = time.time()
        self._push_start_cpu_time = time.perf_counter()
        self.ctx.pushing_task_count += 1
        futures = []
        for role_party_id in parties:
            _role = role_party_id[0]
            _party_id = str(role_party_id[1])

            _options = {}
            obj_type = 'rollpair' if isinstance(obj, RollPair) else 'object'
            roll_site_header = ErRollSiteHeader(
                roll_site_session_id=self.roll_site_session_id,
                name=self.name,
                tag=self.tag,
                src_role=self.local_role,
                src_party_id=self.party_id,
                dst_role=_role,
                dst_party_id=_party_id,
                data_type=obj_type,
                options=_options)
            _tagged_key = create_store_name(roll_site_header)
            L.debug(f"pushing start party:{type(obj)}, {_tagged_key}")
            namespace = self.roll_site_session_id

            if isinstance(obj, RollPair):
                rp = obj
            else:
                rp = self.ctx.rp_ctx.load(namespace, _tagged_key)
                rp.put(_tagged_key, obj)
            rp.disable_gc()
            L.info(f"pushing prepared: {type(obj)}, tag_key:{_tagged_key}")

            def map_values(_tagged_key):
                is_standalone = self.ctx.rp_ctx.get_session().get_option(
                        SessionConfKeys.CONFKEY_SESSION_DEPLOY_MODE) == DeployModes.STANDALONE

                if is_standalone:
                    dst_name = _tagged_key
                    store_type = rp.get_store_type()
                else:
                    dst_name = DELIM.join([_tagged_key,
                                           self.dst_host,
                                           str(self.dst_port),
                                           obj_type])
                    store_type = StoreTypes.ROLLPAIR_ROLLSITE
                if is_standalone:
                    status_rp = self.ctx.rp_ctx.load(namespace, STATUS_TABLE_NAME + DELIM + self.roll_site_session_id, options=_options)
                    status_rp.disable_gc()
                    if isinstance(obj, RollPair):
                        status_rp.put(_tagged_key, (obj_type.encode("utf-8"), rp.get_name(), rp.get_namespace()))
                    else:
                        status_rp.put(_tagged_key, (obj_type.encode("utf-8"), dst_name, namespace))
                else:
                    store = rp.get_store()
                    store_locator = store._store_locator
                    new_store_locator = ErStoreLocator(store_type=store_type,
                                                       namespace=namespace,
                                                       name=dst_name,
                                                       total_partitions=store_locator._total_partitions,
                                                       partitioner=store_locator._partitioner,
                                                       serdes=store_locator._serdes)

                    # TODO:0: move options from job to store when database modification finished

                    options = {"roll_site_header": roll_site_header,
                               "proxy_endpoint": self.ctx.proxy_endpoint,
                               "obj_type": obj_type}

                    if isinstance(obj, RollPair):
                        roll_site_header._options['total_partitions'] = obj.get_store()._store_locator._total_partitions
                        L.debug(f"pushing map_values: {dst_name}, count: {obj.count()}, tag_key:{_tagged_key}")
                    rp.map_values(lambda v: v,
                                  output=ErStore(store_locator=new_store_locator),
                                  options=options)

                L.info(f"pushing map_values done:{type(obj)}, tag_key:{_tagged_key}")
                return _tagged_key

            future = self.process_pool.submit(map_values, _tagged_key)
            future.add_done_callback(self._push_callback)
            futures.append(future)

        self.process_pool.shutdown(wait=False)

        return futures