예제 #1
0
        def do_store(store_partition_inner, is_shuffle_inner, total_writers_inner, reduce_op_inner):
            """Drain the transfer broker for a partition into its store.

            The broker tag comes from the partition id when ``is_shuffle_inner``
            is true, otherwise from the transfer id. Each received pair is
            written through one write batch: ``put`` when ``reduce_op_inner``
            is None, otherwise ``merge`` with a ``do_merge`` partial bound to
            ``reduce_op_inner`` and the partition's serdes.

            Returns the number of pairs written. Errors are logged and
            re-raised; the broker is removed in every case.
            """
            if is_shuffle_inner:
                tag = self.__generate_tag(store_partition_inner._id)
            else:
                tag = self.__transfer_id
            written = 0
            try:
                broker = TransferService.get_or_create_broker(tag, write_signals=total_writers_inner)
                L.trace(f"do_store start for tag={tag}")
                pairs = TransferPair.bin_batch_to_pair(b.data for b in broker)

                serdes = create_serdes(store_partition_inner._store_locator._serdes)
                # merger is None exactly when no reduce op was supplied.
                merger = None if reduce_op_inner is None else functools.partial(
                    do_merge, merge_func=reduce_op_inner, serdes=serdes)

                with create_adapter(store_partition_inner) as db:
                    L.trace(f"do_store create_db for tag={tag} for partition={store_partition_inner}")
                    with db.new_batch() as wb:
                        for key, value in pairs:
                            if merger is None:
                                wb.put(key, value)
                            else:
                                wb.merge(merger, key, value)
                            written += 1
                    L.trace(f"do_store done for tag={tag} for partition={store_partition_inner}")
            except Exception as e:
                L.exception(f'Error in do_store for tag={tag}')
                raise e
            finally:
                TransferService.remove_broker(tag)
            return written
예제 #2
0
 def generate_broker():
     """Yield the input partition's pairs as binary batches for a broker.

     When ``er_pair._key`` is set, it is deserialized with ``key_serdes`` and
     passed as ``limit`` to ``TransferPair.pair_to_bin_batch``.

     The broker registered under ``tag`` is removed when the generator
     finishes, is closed, or fails -- including failures while opening the
     adapter/iterator or decoding the limit, which previously skipped removal.
     """
     try:
         with create_adapter(task._inputs[0]) as db, db.iteritems() as rb:
             limit = None if er_pair._key is None else key_serdes.deserialize(er_pair._key)
             yield from TransferPair.pair_to_bin_batch(rb, limit=limit)
     finally:
         TransferService.remove_broker(tag)
예제 #3
0
 def do_store(store_partition_inner, is_shuffle_inner,
              total_writers_inner):
     """Drain the transfer broker for a partition into its store.

     The broker tag is derived from the partition id when
     ``is_shuffle_inner`` is true, otherwise the transfer id is used.
     Every received pair is ``put`` into the partition through a single
     write batch.

     Returns:
         The number of pairs written.

     Raises:
         Any exception from the broker or storage layer, after logging it
         with its traceback. The broker is removed whether or not storing
         succeeds.
     """
     done_cnt = 0
     tag = self.__generate_tag(
         store_partition_inner._id
     ) if is_shuffle_inner else self.__transfer_id
     try:
         broker = TransferService.get_or_create_broker(
             tag, write_signals=total_writers_inner)
         L.debug(f"do_store start for tag: {tag}")
         batches = TransferPair.bin_batch_to_pair(b.data
                                                  for b in broker)
         with create_adapter(store_partition_inner) as db:
             L.debug(
                 f"do_store create_db for tag: {tag} for partition: {store_partition_inner}"
             )
             with db.new_batch() as wb:
                 for k, v in batches:
                     wb.put(k, v)
                     done_cnt += 1
             L.debug(
                 f"do_store done for tag: {tag} for partition: {store_partition_inner}"
             )
     except Exception:
         # Log with traceback, then propagate the original exception unchanged.
         L.exception(f'Error in do_store for tag {tag}')
         raise
     finally:
         # Always release the broker created above, including on failure;
         # removing it only on the success path leaked it on errors.
         TransferService.remove_broker(tag)
     return done_cnt
예제 #4
0
    def run(self):
        """Drain the transfer broker for ``self.tag`` into ``self.partition``.

        Runs under the per-tag lock ``PutBatchTask._partition_lock[self.tag]``.
        Batches are polled from the broker with a short timeout. Each batch's
        header is decoded as an ``ErRollSiteHeader``, and its binary payload is
        written pair by pair into a single write batch on the partition's
        adapter. Per-batch pair counts are recorded on the tag's
        ``_BatchStreamStatus``. That status is marked done when a header's
        stage equals ``FINISH_STATUS``. ``bss.check_finish()`` is called once
        the broker becomes closable.

        Raises:
            TimeoutError: if the sum of expired poll timeouts since the last
                received batch exceeds
                ``EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC``.
            Exception: any error is logged together with the last decoded
                header and re-raised. The broker is removed in every case.
        """
        # batch stream must be executed serially, and reinit.
        # TODO:0:  remove lock to bss
        # Last decoded header; initialised here so the error log below can
        # report it even if no batch was received.
        rs_header = None
        with PutBatchTask._partition_lock[self.tag]:   # tag includes partition info in tag generation
            L.trace(f"do_store start for tag={self.tag}, partition_id={self.partition._id}")
            bss = _BatchStreamStatus.get_or_create(self.tag)
            try:
                broker = TransferService.get_or_create_broker(self.tag, write_signals=1)

                # Sum of poll timeouts that expired since the last batch arrived.
                iter_wait = 0
                iter_timeout = int(CoreConfKeys.EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC.get())

                batch = None
                batch_get_interval = 0.1
                with create_adapter(self.partition) as db, db.new_batch() as wb:
                    # Poll with a short timeout (instead of `for batch in broker:`)
                    # so idle time can be accumulated and an idle stream times out.
                    while not broker.is_closable():
                        try:
                            batch = broker.get(block=True, timeout=batch_get_interval)
                        except queue.Empty as e:
                            iter_wait += batch_get_interval
                            if iter_wait > iter_timeout:
                                raise TimeoutError(f'timeout in PutBatchTask.run. tag={self.tag}, iter_timeout={iter_timeout}, iter_wait={iter_wait}')
                            else:
                                continue
                        except BrokerClosed as e:
                            # The loop condition re-checks is_closable().
                            continue

                        # A batch arrived: reset the idle counter.
                        iter_wait = 0
                        rs_header = ErRollSiteHeader.from_proto_string(batch.header.ext)
                        batch_pairs = 0
                        if batch.data:
                            bin_data = ArrayByteBuffer(batch.data)
                            reader = PairBinReader(pair_buffer=bin_data, data=batch.data)
                            for k_bytes, v_bytes in reader.read_all():
                                wb.put(k_bytes, v_bytes)
                                batch_pairs += 1
                        bss.count_batch(rs_header, batch_pairs)
                        # TODO:0
                        # Records the data type carried by the most recent header.
                        bss._data_type = rs_header._data_type
                        if rs_header._stage == FINISH_STATUS:
                            bss.set_done(rs_header)  # starting from 0

                bss.check_finish()
                # NOTE(review): an earlier comment said remove_broker(tag) is called in the
                # get_status phase when finished or on exception, but the `finally` below
                # removes the broker unconditionally -- confirm which is intended.
            except Exception as e:
                L.exception(f'_run_put_batch error, tag={self.tag}, '
                        f'rs_key={rs_header.get_rs_key() if rs_header is not None else None}, rs_header={rs_header}')
                raise e
            finally:
                TransferService.remove_broker(self.tag)
예제 #5
0
파일: storage.py 프로젝트: mask-pp/eggroll
    def wait_finish(cls, tag, timeout):
        """Wait up to ``timeout`` for the batch stream ``tag`` to finish.

        If the stream's finish event is set in time, the transfer broker for
        ``tag`` is removed. Returns a ``BSS`` snapshot of the stream's batch,
        stream and pair counters together with the finished flag.
        """
        status = cls.get_or_create(tag)
        is_done = status._stream_finish_event.wait(timeout)
        if is_done:
            TransferService.remove_broker(tag)
            # NOTE(review): the recorder entry for `tag` is not deleted here
            # (a `del cls._recorder[tag]` was left commented out) -- confirm
            # whether entries are cleaned up elsewhere.

        pair_counter = status._batch_seq_to_pair_counter
        snapshot = dict(
            tag=status._tag,
            is_finished=is_done,
            total_batches=status._total_batches,
            batch_seq_to_pair_counter=pair_counter,
            total_streams=status._total_streams,
            stream_seq_to_pair_counter=status._stream_seq_to_pair_counter,
            stream_seq_to_batch_seq=status._stream_seq_to_batch_seq,
            total_pairs=sum(pair_counter.values()),
            data_type=status._data_type,
        )
        return BSS(**snapshot)
예제 #6
0
    def test_recv(self):
        """Start a gRPC transfer server in a thread and drain the 'test' broker.

        Each truthy item received is printed with a running index. An empty
        poll prints 'empty'. A closed broker ends the loop. The server thread
        is then joined with a one-second timeout.
        """
        def start_server():
            server = GrpcTransferService()
            server.start(options={
                TransferConfKeys.CONFKEY_TRANSFER_SERVICE_PORT: transfer_port
            })

        server_thread = threading.Thread(target=start_server)
        server_thread.start()

        broker = TransferService.get_or_create_broker('test')
        received = 0
        while not broker.is_closable():
            try:
                item = broker.get(block=True, timeout=0.1)
            except queue.Empty:
                print('empty')
                continue
            except BrokerClosed:
                break
            if item:
                print(f'recv: {received}: {item}')
                received += 1

        server_thread.join(1)
예제 #7
0
 def do_store(store_partition_inner, is_shuffle_inner,
              total_writers_inner):
     """Drain the transfer broker for a partition into its store.

     The broker tag is derived from the partition id when
     ``is_shuffle_inner`` is true, otherwise the transfer id is used.
     Every received pair is ``put`` into the partition through a single
     write batch.

     Returns:
         The number of pairs written.

     Raises:
         Any exception from the broker or storage layer, after logging it.
         The broker is removed whether or not storing succeeds.
     """
     tag = self.__generate_tag(
         store_partition_inner._id
     ) if is_shuffle_inner else self.__transfer_id
     done_cnt = 0
     try:
         broker = TransferService.get_or_create_broker(
             tag, write_signals=total_writers_inner)
         L.debug(f"do_store_start:{tag}")
         batches = TransferPair.bin_batch_to_pair(b.data for b in broker)
         with create_adapter(store_partition_inner) as db:
             L.debug(
                 f"do_store_create_db: {tag} for partition: {store_partition_inner}"
             )
             with db.new_batch() as wb:
                 for k, v in batches:
                     wb.put(k, v)
                     done_cnt += 1
             L.debug(
                 f"do_store_done: {tag} for partition: {store_partition_inner}"
             )
     except Exception:
         # Log with traceback, then propagate the original exception unchanged.
         L.exception(f'Error in do_store for tag={tag}')
         raise
     finally:
         # Release the broker created above even when storing fails, so it
         # is not left behind under `tag`.
         TransferService.remove_broker(tag)
     return done_cnt
예제 #8
0
    def test_pull(self):
        """Publish test data on a local broker, start the server, then pull it back.

        Fills the ``test_pull`` broker and signals that writing has finished.
        It then starts the transfer server via ``test_start_server`` on the
        executor and drains the broker returned by ``TransferClient.recv``
        for the same tag. Finally it waits up to 10 seconds for the server
        future.
        """
        from time import sleep

        tag = 'test_pull'
        send_broker = TransferService.get_or_create_broker(tag)
        self.__put_to_broker(send_broker)
        send_broker.signal_write_finish()

        # The server is started by test_start_server; the locally built
        # GrpcTransferService/options pair was never used and is dropped.
        future = self.__executor_pool.submit(self.test_start_server)

        # Fixed pause before connecting, presumably to let the server start.
        sleep(1)

        transfer_client = TransferClient()
        recv_broker = transfer_client.recv(transfer_endpont, tag)

        while not recv_broker.is_closable():
            print('recv:', recv_broker.get())

        future.result(timeout=10)
예제 #9
0
    def run_task(self, task: ErTask):
        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] start. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] start. task_name={task._name}, task_id={task._id}')
        functors = task._job._functors
        result = task

        if task._name == 'get':
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.trace(
                    f"get: key: {self.functor_serdes.deserialize(f._key)}, path: {input_adapter.path}"
                )
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            tag = f'{task._id}'
            er_pair = create_functor(functors[0]._body)
            input_store_head = task._job._inputs[0]
            key_serdes = create_serdes(input_store_head._store_locator._serdes)

            def generate_broker():
                """Yield the input partition's pairs as binary batches.

                When ``er_pair._key`` is set, it is deserialized with
                ``key_serdes`` and passed as ``limit`` to
                ``TransferPair.pair_to_bin_batch``. The broker under ``tag`` is
                removed when the generator finishes, is closed, or fails --
                including failures while opening the adapter/iterator or
                decoding the limit.
                """
                try:
                    with create_adapter(
                            task._inputs[0]) as db, db.iteritems() as rb:
                        limit = None if er_pair._key is None else key_serdes.deserialize(
                            er_pair._key)
                        yield from TransferPair.pair_to_bin_batch(rb,
                                                                  limit=limit)
                finally:
                    TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        elif task._name == 'putBatch':
            partition = task._outputs[0]
            tag = f'{task._id}'
            PutBatchTask(tag, partition).run()

        elif task._name == 'putAll':
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.trace(f'egg_pair putAll: transfer service tag={tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?

        elif task._name == 'put':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                value = input_adapter.put(f._key, f._value)
                #result = ErPair(key=f._key, value=bytes(value))

        elif task._name == 'destroy':
            input_store_locator = task._inputs[0]._store_locator
            namespace = input_store_locator._namespace
            name = input_store_locator._name
            store_type = input_store_locator._store_type
            L.debug(
                f'destroying store_type={store_type}, namespace={namespace}, name={name}'
            )
            if name == '*':
                from eggroll.roll_pair.utils.pair_utils import get_db_path, get_data_dir
                target_paths = list()
                if store_type == '*':
                    data_dir = get_data_dir()
                    store_types = os.listdir(data_dir)
                    for store_type in store_types:
                        target_paths.append('/'.join(
                            [data_dir, store_type, namespace]))
                else:
                    db_path = get_db_path(task._inputs[0])
                    target_paths.append(db_path[:db_path.rfind('*')])

                real_data_dir = os.path.realpath(get_data_dir())
                for path in target_paths:
                    realpath = os.path.realpath(path)
                    if os.path.exists(path):
                        if realpath == "/" \
                                or realpath == real_data_dir \
                                or not realpath.startswith(real_data_dir):
                            raise ValueError(
                                f'trying to delete a dangerous path: {realpath}'
                            )
                        else:
                            shutil.rmtree(path)
            else:
                options = task._job._options
                with create_adapter(task._inputs[0],
                                    options=options) as input_adapter:
                    input_adapter.destroy(options=options)

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                if input_adapter.delete(f._key):
                    L.trace("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))

            self._run_unary(map_wrapper, task, shuffle=True)

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':
            reduce_op = create_functor(functors[1]._body)
            shuffle = create_functor(functors[2]._body)

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if isinstance(value, Iterable):
                    for k1, v1 in value:
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k1),
                                                   value_serdes.serialize(v1)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k1),
                                                  value_serdes.serialize(v1))
                else:
                    key = input_iterator.key()
                    output_writebatch.put((key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper,
                            task,
                            shuffle=shuffle,
                            reduce_op=reduce_op)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':
            shuffle = create_functor(functors[1]._body)

            def flat_map_wraaper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k2),
                                                   value_serdes.serialize(v2)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k2),
                                                  value_serdes.serialize(v2))

            self._run_unary(flat_map_wraaper, task, shuffle=shuffle)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # skips next(r_iter) to avoid duplicate codes for the 3rd time
                except StopIteration as e:
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                """Inner-join by probing the right store for each left key.

                For every left pair whose key exists in the right store, this
                writes ``f(left_value, right_value)`` under the original left
                key. The result is serialized with the left value serdes.
                When the key serdes types differ, the left key is re-encoded
                with the right serdes only for the lookup.
                """
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    r_v_bytes = right_iterator.adapter.get(k_probe)
                    # Only a missing key (None) is a non-match; an empty stored
                    # value must still join.
                    if r_v_bytes is not None:
                        # Output keyed with the left serdes, as merge_join_wrapper does.
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def merge_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                              left_value_serdes,
                                              right_iterator, right_key_serdes,
                                              right_value_serdes,
                                              output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge subtract_by_key cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )

                is_same_serdes = left_key_serdes == right_key_serdes
                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left = next(l_iter)
                except StopIteration:
                    is_left_stopped = True
                    k_left = None
                    v_left = None

                try:
                    k_right_raw, v_right = next(r_iter)
                except StopIteration:
                    is_left_stopped = False
                    k_right_raw = None
                    v_right = None

                # left is None, output must be None
                if k_left is None:
                    return

                try:
                    if k_left is None:
                        raise StopIteration()
                    if k_right_raw is None:
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        if k_left < k_right:
                            output_writebatch.put(k_left, v_left)
                            k_left, v_left = next(l_iter)
                            is_left_stopped = True
                        elif k_left == k_right:
                            is_equal = True
                            is_left_stopped = True
                            k_left, v_left = next(l_iter)
                            is_left_stopped = False
                            is_equal = False
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        else:
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        is_left_stopped = True
                except StopIteration as e:
                    pass
                if not is_left_stopped and not is_equal:
                    try:
                        if k_left is not None and v_left is not None:
                            output_writebatch.put(k_left, v_left)
                            while True:
                                k_left, v_left = next(l_iter)
                                output_writebatch.put(k_left, v_left)
                    except StopIteration as e:
                        pass
                elif is_left_stopped and not is_equal and k_left is not None:
                    output_writebatch.put(k_left, v_left)

                return

            def hash_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                             left_value_serdes, right_iterator,
                                             right_key_serdes,
                                             right_value_serdes,
                                             output_writebatch):
                """Emit the left pairs whose key is absent from the right store.

                When the key serdes types differ, each left key is re-encoded
                with the right key serdes for the lookup only. Pairs with no
                match (``get`` returns ``None``) are written unchanged under
                their original left-encoded key.
                """
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    if right_iterator.adapter.get(k_probe) is None:
                        # Write the left-encoded key (not the probe key), matching
                        # the keys merge_subtract_by_key_wrapper emits.
                        output_writebatch.put(k_left, v_left)

            subtract_by_key_type = task._job._options.get(
                'subtract_by_key_type', 'merge')
            if subtract_by_key_type == 'merge':
                self._run_binary(merge_subtract_by_key_wrapper, task)
            else:
                self._run_binary(hash_subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def merge_union_wrapper(left_iterator, left_key_serdes,
                                    left_value_serdes, right_iterator,
                                    right_key_serdes, right_value_serdes,
                                    output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge union cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                k_left = None
                v_left_bytes = None
                k_right = None
                v_right_bytes = None
                none_none = (None, None)

                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left_bytes = next(l_iter, none_none)
                    k_right_raw, v_right_bytes = next(r_iter, none_none)

                    if k_left is None and k_right_raw is None:
                        return
                    elif k_left is None:
                        is_left_stopped = True
                        raise StopIteration()
                    elif k_right_raw is None:
                        is_left_stopped = False
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        while k_right < k_left:
                            if is_same_serdes:
                                output_writebatch.put(k_right, v_right_bytes)
                            else:
                                output_writebatch.put(
                                    k_right,
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))

                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        is_left_stopped = True
                        while k_left < k_right:
                            output_writebatch.put(k_left, v_left_bytes)
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            is_equal = True
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            is_left_stopped = True
                            k_left, v_left_bytes = next(l_iter)

                            is_left_stopped = False
                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                            is_equal = False
                except StopIteration as e:
                    pass

                if not is_left_stopped:
                    try:
                        output_writebatch.put(k_left, v_left_bytes)
                        while True:
                            k_left, v_left_bytes = next(l_iter)
                            output_writebatch.put(k_left, v_left_bytes)
                    except StopIteration as e:
                        pass
                else:
                    try:
                        if not is_equal:
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                        while True:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                    except StopIteration as e:
                        pass

                    # end of merge union wrapper
                    return

            def hash_union_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                """Union two partitions using point lookups on the right store.

                Keys present on both sides are combined with the user functor
                ``f(left_value, right_value)``; keys present on only one side
                are copied through. Output keys and values are written in the
                left side's serdes, matching ``merge_union_wrapper``.
                """
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)

                # Pass 1: every left entry, merged with its right match if any.
                for k_left, v_left in left_iterator:
                    # Probe the right store with a key in the right serdes, but
                    # keep writing the original (left-serdes) key to the output.
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    v_right = right_iterator.adapter.get(k_probe)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                # Pass 2: right-only entries, converted to the left serdes.
                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        k_out = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right))
                        v_out = left_value_serdes.serialize(
                            right_value_serdes.deserialize(v_right))
                    else:
                        k_out, v_out = k_right, v_right
                    # Skip keys already written in pass 1.
                    # NOTE(review): relies on output_writebatch.get() seeing
                    # the pass-1 writes, as the original code did.
                    if output_writebatch.get(k_out) is None:
                        output_writebatch.put(k_out, v_out)

            union_type = task._job._options.get('union_type', 'merge')

            if union_type == 'merge':
                self._run_binary(merge_union_wrapper, task)
            else:
                self._run_binary(hash_union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(f(task)))

        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] end. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] end. task_name={task._name}, task_id={task._id}')

        return result
예제 #10
0
    def run_task(self, task: ErTask):
        """Execute one egg-side task and return its result.

        The operation is selected by ``task._name``; user functions are read
        from ``task._job._functors``.

        Args:
            task: the task to run. ``task._inputs`` / ``task._outputs`` hold
                the store partition(s) the operation reads from / writes to.

        Returns:
            An ``ErPair`` for operations that produce a value (``get``,
            ``count``, ``reduce``, ``aggregate``, ``withStores``); otherwise
            the ``task`` itself.
        """
        L.info("start run task")
        functors = task._job._functors
        result = task

        if task._name == 'get':
            L.info("egg_pair get call")
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                # Diagnostic output goes through the logger, not stdout.
                L.debug("get key:{} and path is:{}".format(
                    self.functor_serdes.deserialize(f._key),
                    input_adapter.path))
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            L.info("egg_pair getAll call")
            tag = f'{task._id}'

            def generate_broker():
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    try:
                        yield from TransferPair.pair_to_bin_batch(rb)
                    finally:
                        # Unregister the broker once the stream is exhausted
                        # or closed; otherwise it stays registered forever.
                        TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            L.info('egg_pair count call')
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        # TODO:1: multiprocessor scenario
        elif task._name == 'putAll':
            L.info("egg_pair putAll call")
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.info(f'egg_pair transfer service tag:{tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?
            L.debug(f"putAll result:{store_broker_result}")

        elif task._name == 'put':
            L.info("egg_pair put call")
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                input_adapter.put(f._key, f._value)

        elif task._name == 'destroy':
            with create_adapter(task._inputs[0]) as input_adapter:
                input_adapter.destroy()
                L.info("finish destroy")

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.info("delete k:{}, its value:{}".format(
                    f._key, input_adapter.get(f._key)))
                if input_adapter.delete(f._key):
                    L.info("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))
                L.info('finish calculating')

            self._run_unary(map_wrapper, task, shuffle=True)
            L.info('map finished')

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            L.info('reduce finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            L.info('ready to aggregate')
            seq_op_result = self.aggregate_seq(task=task)
            L.info('aggregate finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, shuffle_broker):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    L.info("value of mapPartitions:{}".format(value))
                    if isinstance(value, Iterable):
                        for k1, v1 in value:
                            shuffle_broker.put((key_serdes.serialize(k1),
                                                value_serdes.serialize(v1)))
                    else:
                        # Scalar result: store it under the partition's last key.
                        key = input_iterator.key()
                        shuffle_broker.put(
                            (key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper, task, shuffle=True)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':

            def flat_map_wrapper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        output_writebatch.put(key_serdes.serialize(k2),
                                              value_serdes.serialize(v2))

            self._run_unary(flat_map_wrapper, task)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                # Collect the whole partition into one list stored under the
                # partition's last key.
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                # Sort-merge join: both sides must iterate in key order.
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = type(left_key_serdes) == type(
                    right_key_serdes)

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # skips next(r_iter) to avoid duplicate codes for the 3rd time
                except StopIteration:
                    # Either side exhausted: no further matches possible.
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    # Probe the right store in its own key serdes, but write
                    # the original left-serdes key (the value is left serdes).
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    r_v_bytes = right_iterator.adapter.get(k_probe)
                    # `is not None` so an empty-bytes value still counts as a match.
                    if r_v_bytes is not None:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                        left_value_serdess, right_iterator,
                                        right_key_serdes, right_value_serdess,
                                        output_writebatch):
                L.info("sub wrapper")
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    # Probe in the right serdes; keep the left key for output.
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    if right_iterator.adapter.get(k_probe) is None:
                        output_writebatch.put(k_left, v_left)

            self._run_binary(subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def union_wrapper(left_iterator, left_key_serdes,
                              left_value_serdess, right_iterator,
                              right_key_serdes, right_value_serdess,
                              output_writebatch):
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)

                # Pass 1: every left entry, merged with its right match if any.
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_probe = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_probe = k_left
                    v_right = right_iterator.adapter.get(k_probe)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdess.deserialize(v_left),
                                    right_value_serdess.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdess.serialize(v_final))

                # Pass 2: right-only entries, converted to the left serdes so
                # the lookup below matches the keys written in pass 1.
                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        k_out = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right))
                        v_out = left_value_serdess.serialize(
                            right_value_serdess.deserialize(v_right))
                    else:
                        k_out, v_out = k_right, v_right
                    if output_writebatch.get(k_out) is None:
                        output_writebatch.put(k_out, v_out)

            self._run_binary(union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(
                key=self.functor_serdes.serialize(task._inputs[0]._id),
                value=self.functor_serdes.serialize(f(task._inputs)))
        return result