Ejemplo n.º 1
0
 def _run_unary(self, func, task, shuffle=False):
     """Run a single-input (unary) operator over one task's input partition.

     ``func`` is called as ``func(rb, key_serdes, value_serdes, writer)``
     where ``rb`` iterates the input partition ``task._inputs[0]`` and both
     serdes are built from that partition's store locator.

     * ``shuffle=False``: ``writer`` is a write batch on the output partition
       ``task._outputs[0]``, opened with ``task._job._options``.
     * ``shuffle=True``: ``writer`` is a ``BatchBroker`` feeding a
       ``TransferPair`` scatter that routes each record to an output
       partition chosen by ``hash_code``; this task's output partition is
       received through ``TransferPair.store_broker``.
     """
     key_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     value_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     with create_adapter(task._inputs[0]) as input_db:
         L.debug(f"create_store_adatper: {task._inputs[0]}")
         with input_db.iteritems() as rb:
             L.debug(f"create_store_adatper_iter: {task._inputs[0]}")
             from eggroll.roll_pair.transfer_pair import TransferPair, BatchBroker
             if shuffle:
                 input_total_partitions = task._inputs[
                     0]._store_locator._total_partitions
                 output_store = task._job._outputs[0]
                 # Records must be hashed into the OUTPUT store's partition
                 # range; sizing the partitioner with the input count routes
                 # keys to wrong / nonexistent partitions when the counts differ.
                 output_total_partitions = \
                     output_store._store_locator._total_partitions
                 shuffle_broker = FifoBroker()
                 write_bb = BatchBroker(shuffle_broker)
                 try:
                     shuffler = TransferPair(transfer_id=task._job._id)
                     # Third argument: presumably the number of writers the
                     # store side waits for (one per input partition) --
                     # matches ``total_writers`` in the newer signature; confirm.
                     store_future = shuffler.store_broker(
                         task._outputs[0], True, input_total_partitions)
                     scatter_future = shuffler.scatter(
                         shuffle_broker,
                         partitioner(hash_func=hash_code,
                                     total_partitions=output_total_partitions),
                         output_store)
                     func(rb, key_serdes, value_serdes, write_bb)
                 finally:
                     # Always signal end-of-stream, even if ``func`` raised,
                     # so the scatter side is not left waiting for more input.
                     write_bb.signal_write_finish()
                 scatter_results = scatter_future.result()
                 store_result = store_future.result()
                 L.debug(f"scatter_result:{scatter_results}")
                 L.debug(f"gather_result:{store_result}")
             else:
                 # TODO: modification may be needed when store options finished
                 with create_adapter(task._outputs[0],
                                     options=task._job._options
                                     ) as db, db.new_batch() as wb:
                     func(rb, key_serdes, value_serdes, wb)
             L.debug(f"close_store_adatper:{task._inputs[0]}")
Ejemplo n.º 2
0
    def _run_unary(self, func, task, shuffle=False, reduce_op=None):
        """Run a single-input (unary) operator for one task.

        ``func`` is called as ``func(rb, key_serdes, value_serdes, writer)``
        where ``rb`` iterates ``task._inputs[0]`` and the serdes are those of
        the job's input store.

        Args:
            func: per-partition callback that writes its results to ``writer``.
            task: the task to run; store heads come from ``task._job`` and
                partitions from ``task._inputs`` / ``task._outputs``.
            shuffle: if true, records written by ``func`` are re-partitioned
                across the output store by ``hash_code`` via ``TransferPair``;
                otherwise they go straight into a write batch on
                ``task._outputs[0]``.
            reduce_op: forwarded to ``TransferPair.store_broker`` on the
                receiving side of a shuffle; unused without shuffle.

        Raises:
            ValueError: if the input and output stores use different serdes.
        """
        input_store_head = task._job._inputs[0]
        output_store_head = task._job._outputs[0]
        input_key_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        input_value_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        output_key_serdes = create_serdes(
            output_store_head._store_locator._serdes)
        output_value_serdes = create_serdes(
            output_store_head._store_locator._serdes)

        # ``func`` serializes its output with the INPUT serdes, so the output
        # store must use the same ones.
        if input_key_serdes != output_key_serdes or \
                input_value_serdes != output_value_serdes:
            raise ValueError(
                f"input key-value serdes:{(input_key_serdes, input_value_serdes)} "
                f"differ from output key-value serdes:{(output_key_serdes, output_value_serdes)}"
            )

        if shuffle:
            from eggroll.roll_pair.transfer_pair import TransferPair
            input_total_partitions = input_store_head._store_locator._total_partitions
            output_total_partitions = output_store_head._store_locator._total_partitions

            my_server_node_id = get_static_er_conf().get(
                'server_node_id', None)
            shuffler = TransferPair(transfer_id=task._job._id)

            def is_local(partitions):
                # A side is handled here only if it has a partition and, when
                # this node's id is configured, that partition lives here.
                return bool(partitions) and (
                    my_server_node_id is None
                    or my_server_node_id == partitions[0]._processor._server_node_id)

            # Receiving side: gathers records scattered by every input
            # partition into this task's output partition.
            store_future = None
            if is_local(task._outputs):
                store_future = shuffler.store_broker(
                    store_partition=task._outputs[0],
                    is_shuffle=True,
                    total_writers=input_total_partitions,
                    reduce_op=reduce_op)

            # Sending side: run ``func`` over the local input partition and
            # scatter its output by key hash across the output partitions.
            scatter_future = None
            if is_local(task._inputs):
                shuffle_broker = FifoBroker()
                write_bb = BatchBroker(shuffle_broker)
                try:
                    scatter_future = shuffler.scatter(
                        input_broker=shuffle_broker,
                        partition_function=partitioner(
                            hash_func=hash_code,
                            total_partitions=output_total_partitions),
                        output_store=output_store_head)
                    with create_adapter(task._inputs[0]) as input_db, \
                            input_db.iteritems() as rb:
                        func(rb, input_key_serdes, input_value_serdes,
                             write_bb)
                finally:
                    # Signal end-of-stream even if ``func`` raised, so the
                    # scatter side is not left waiting for more input.
                    write_bb.signal_write_finish()

            # Block until both local sides have finished.
            if scatter_future is not None:
                scatter_results = scatter_future.result()
            else:
                scatter_results = 'no scatter for this partition'
            if store_future is not None:
                store_results = store_future.result()
            else:
                store_results = 'no store for this partition'
            L.trace(f"scatter_results:{scatter_results}, store_results:{store_results}")
        else:  # no shuffle
            with create_adapter(task._inputs[0]) as input_db, \
                    input_db.iteritems() as rb, \
                    create_adapter(task._outputs[0], options=task._job._options) as db, \
                    db.new_batch() as wb:
                func(rb, input_key_serdes, input_value_serdes, wb)
            L.trace(f"close_store_adatper:{task._inputs[0]}")
Ejemplo n.º 3
0
    def run_task(self, task: ErTask):
        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] start. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] start. task_name={task._name}, task_id={task._id}')
        functors = task._job._functors
        result = task

        if task._name == 'get':
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.trace(
                    f"get: key: {self.functor_serdes.deserialize(f._key)}, path: {input_adapter.path}"
                )
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            tag = f'{task._id}'
            er_pair = create_functor(functors[0]._body)
            input_store_head = task._job._inputs[0]
            key_serdes = create_serdes(input_store_head._store_locator._serdes)

            def generate_broker():
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    limit = None if er_pair._key is None else key_serdes.deserialize(
                        er_pair._key)
                    try:
                        yield from TransferPair.pair_to_bin_batch(rb,
                                                                  limit=limit)
                    finally:
                        TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        elif task._name == 'putBatch':
            partition = task._outputs[0]
            tag = f'{task._id}'
            PutBatchTask(tag, partition).run()

        elif task._name == 'putAll':
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.trace(f'egg_pair putAll: transfer service tag={tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?

        elif task._name == 'put':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                value = input_adapter.put(f._key, f._value)
                #result = ErPair(key=f._key, value=bytes(value))

        elif task._name == 'destroy':
            input_store_locator = task._inputs[0]._store_locator
            namespace = input_store_locator._namespace
            name = input_store_locator._name
            store_type = input_store_locator._store_type
            L.debug(
                f'destroying store_type={store_type}, namespace={namespace}, name={name}'
            )
            if name == '*':
                from eggroll.roll_pair.utils.pair_utils import get_db_path, get_data_dir
                target_paths = list()
                if store_type == '*':
                    data_dir = get_data_dir()
                    store_types = os.listdir(data_dir)
                    for store_type in store_types:
                        target_paths.append('/'.join(
                            [data_dir, store_type, namespace]))
                else:
                    db_path = get_db_path(task._inputs[0])
                    target_paths.append(db_path[:db_path.rfind('*')])

                real_data_dir = os.path.realpath(get_data_dir())
                for path in target_paths:
                    realpath = os.path.realpath(path)
                    if os.path.exists(path):
                        if realpath == "/" \
                                or realpath == real_data_dir \
                                or not realpath.startswith(real_data_dir):
                            raise ValueError(
                                f'trying to delete a dangerous path: {realpath}'
                            )
                        else:
                            shutil.rmtree(path)
            else:
                options = task._job._options
                with create_adapter(task._inputs[0],
                                    options=options) as input_adapter:
                    input_adapter.destroy(options=options)

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                if input_adapter.delete(f._key):
                    L.trace("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))

            self._run_unary(map_wrapper, task, shuffle=True)

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':
            reduce_op = create_functor(functors[1]._body)
            shuffle = create_functor(functors[2]._body)

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if isinstance(value, Iterable):
                    for k1, v1 in value:
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k1),
                                                   value_serdes.serialize(v1)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k1),
                                                  value_serdes.serialize(v1))
                else:
                    key = input_iterator.key()
                    output_writebatch.put((key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper,
                            task,
                            shuffle=shuffle,
                            reduce_op=reduce_op)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':
            shuffle = create_functor(functors[1]._body)

            def flat_map_wraaper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k2),
                                                   value_serdes.serialize(v2)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k2),
                                                  value_serdes.serialize(v2))

            self._run_unary(flat_map_wraaper, task, shuffle=shuffle)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # skips next(r_iter) to avoid duplicate codes for the 3rd time
                except StopIteration as e:
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    r_v_bytes = right_iterator.adapter.get(k_left)
                    if r_v_bytes:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def merge_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                              left_value_serdes,
                                              right_iterator, right_key_serdes,
                                              right_value_serdes,
                                              output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge subtract_by_key cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )

                is_same_serdes = left_key_serdes == right_key_serdes
                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left = next(l_iter)
                except StopIteration:
                    is_left_stopped = True
                    k_left = None
                    v_left = None

                try:
                    k_right_raw, v_right = next(r_iter)
                except StopIteration:
                    is_left_stopped = False
                    k_right_raw = None
                    v_right = None

                # left is None, output must be None
                if k_left is None:
                    return

                try:
                    if k_left is None:
                        raise StopIteration()
                    if k_right_raw is None:
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        if k_left < k_right:
                            output_writebatch.put(k_left, v_left)
                            k_left, v_left = next(l_iter)
                            is_left_stopped = True
                        elif k_left == k_right:
                            is_equal = True
                            is_left_stopped = True
                            k_left, v_left = next(l_iter)
                            is_left_stopped = False
                            is_equal = False
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        else:
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        is_left_stopped = True
                except StopIteration as e:
                    pass
                if not is_left_stopped and not is_equal:
                    try:
                        if k_left is not None and v_left is not None:
                            output_writebatch.put(k_left, v_left)
                            while True:
                                k_left, v_left = next(l_iter)
                                output_writebatch.put(k_left, v_left)
                    except StopIteration as e:
                        pass
                elif is_left_stopped and not is_equal and k_left is not None:
                    output_writebatch.put(k_left, v_left)

                return

            def hash_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                             left_value_serdes, right_iterator,
                                             right_key_serdes,
                                             right_value_serdes,
                                             output_writebatch):
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)

            subtract_by_key_type = task._job._options.get(
                'subtract_by_key_type', 'merge')
            if subtract_by_key_type == 'merge':
                self._run_binary(merge_subtract_by_key_wrapper, task)
            else:
                self._run_binary(hash_subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def merge_union_wrapper(left_iterator, left_key_serdes,
                                    left_value_serdes, right_iterator,
                                    right_key_serdes, right_value_serdes,
                                    output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge union cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                k_left = None
                v_left_bytes = None
                k_right = None
                v_right_bytes = None
                none_none = (None, None)

                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left_bytes = next(l_iter, none_none)
                    k_right_raw, v_right_bytes = next(r_iter, none_none)

                    if k_left is None and k_right_raw is None:
                        return
                    elif k_left is None:
                        is_left_stopped = True
                        raise StopIteration()
                    elif k_right_raw is None:
                        is_left_stopped = False
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        while k_right < k_left:
                            if is_same_serdes:
                                output_writebatch.put(k_right, v_right_bytes)
                            else:
                                output_writebatch.put(
                                    k_right,
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))

                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        is_left_stopped = True
                        while k_left < k_right:
                            output_writebatch.put(k_left, v_left_bytes)
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            is_equal = True
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            is_left_stopped = True
                            k_left, v_left_bytes = next(l_iter)

                            is_left_stopped = False
                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                            is_equal = False
                except StopIteration as e:
                    pass

                if not is_left_stopped:
                    try:
                        output_writebatch.put(k_left, v_left_bytes)
                        while True:
                            k_left, v_left_bytes = next(l_iter)
                            output_writebatch.put(k_left, v_left_bytes)
                    except StopIteration as e:
                        pass
                else:
                    try:
                        if not is_equal:
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                        while True:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                    except StopIteration as e:
                        pass

                    # end of merge union wrapper
                    return

            def hash_union_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        final_v_bytes = output_writebatch.get(
                            left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right)))
                    else:
                        final_v_bytes = output_writebatch.get(k_right)

                    if final_v_bytes is None:
                        output_writebatch.put(k_right, v_right)

            union_type = task._job._options.get('union_type', 'merge')

            if union_type == 'merge':
                self._run_binary(merge_union_wrapper, task)
            else:
                self._run_binary(hash_union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(f(task)))

        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] end. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] end. task_name={task._name}, task_id={task._id}')

        return result
Ejemplo n.º 4
0
    def run_task(self, task: ErTask):
        """Execute a single egg-side task and return its result.

        Dispatches on ``task._name``. Direct store operations (``get``,
        ``getAll``, ``count``, ``putAll``, ``put``, ``destroy``, ``delete``)
        act on the task's input/output partitions through ``create_adapter``
        or ``TransferPair``. Transformation tasks deserialize the user
        functor(s) from ``task._job._functors`` with ``create_functor``, wrap
        them in an iterator-level wrapper and hand the wrapper to
        ``self._run_unary`` (one input) or ``self._run_binary`` (two inputs).

        For binary tasks whose two inputs use different key serdes, output
        keys (and, for union, re-encoded right-side values) are written in
        the *left* input's serdes, matching the left-serdes values written.

        Args:
            task: the task to run; ``task._name`` selects the operation.

        Returns:
            An ``ErPair`` for tasks that produce a value (``get``, ``count``,
            ``reduce``, ``aggregate``, ``withStores``); otherwise ``task``
            itself.
        """
        L.info("start run task")
        functors = task._job._functors
        result = task

        if task._name == 'get':
            L.info("egg_pair get call")
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                # Diagnostic output goes through the logger, not stdout.
                L.debug("get key:{} and path is:{}".format(
                    self.functor_serdes.deserialize(f._key),
                    input_adapter.path))
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            L.info("egg_pair getAll call")
            tag = f'{task._id}'

            def generate_broker():
                # Generator: the adapter is opened on the first pull and
                # closed once the generator finishes or is closed.
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    yield from TransferPair.pair_to_bin_batch(rb)
                    # TODO:0 how to remove?
                    # TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            L.info('egg_pair count call')
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        # TODO:1: multiprocessor scenario
        elif task._name == 'putAll':
            L.info("egg_pair putAll call")
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.info(f'egg_pair transfer service tag:{tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?
            L.debug(f"putAll result:{store_broker_result}")

        elif task._name == 'put':
            L.info("egg_pair put call")
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                input_adapter.put(f._key, f._value)

        elif task._name == 'destroy':
            with create_adapter(task._inputs[0]) as input_adapter:
                input_adapter.destroy()
                L.info("finish destroy")

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.info("delete k:{}, its value:{}".format(
                    f._key, input_adapter.get(f._key)))
                if input_adapter.delete(f._key):
                    L.info("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                # Keys are passed through as raw bytes; only values change.
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                # New keys may belong to other partitions, so results are
                # pushed to the shuffle broker instead of a local batch.
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))
                L.info('finish calculating')

            self._run_unary(map_wrapper, task, shuffle=True)
            L.info('map finished')

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            L.info('reduce finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            L.info('ready to aggregate')
            seq_op_result = self.aggregate_seq(task=task)
            L.info('aggregate finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, shuffle_broker):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    L.info("value of mapPartitions:{}".format(value))
                    if isinstance(value, Iterable):
                        for k1, v1 in value:
                            shuffle_broker.put((key_serdes.serialize(k1),
                                                value_serdes.serialize(v1)))
                    else:
                        # Scalar result: store it under the partition's
                        # last key.
                        key = input_iterator.key()
                        shuffle_broker.put(
                            (key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper, task, shuffle=True)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                # The single collapsed value is stored under the last key.
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':

            def flat_map_wrapper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        output_writebatch.put(key_serdes.serialize(k2),
                                              value_serdes.serialize(v2))

            self._run_unary(flat_map_wrapper, task)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                # Gather every (key, value) of the partition into one list
                # stored under the partition's last key.
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                # functors[0] carries the sampling fraction, functors[1] the
                # seed for a deterministic RandomState.
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                # Sorted-merge join: requires both stores to iterate in key
                # order so the two cursors can advance monotonically.
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = type(left_key_serdes) == type(
                    right_key_serdes)

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    # Right keys are compared in the left serdes encoding.
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # skips next(r_iter) to avoid duplicate codes for the 3rd time
                except StopIteration:
                    # Either side exhausted: no further matches possible.
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    # Re-encode only for the right-side lookup; the output
                    # key stays in the left serdes like the output value.
                    if is_diff_serdes:
                        k_lookup = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_lookup = k_left
                    r_v_bytes = right_iterator.adapter.get(k_lookup)
                    # None means "absent"; an empty stored value must still join.
                    if r_v_bytes is not None:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                        left_value_serdess, right_iterator,
                                        right_key_serdes, right_value_serdess,
                                        output_writebatch):
                L.info("sub wrapper")
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    # Re-encode only for the lookup; keep the left encoding
                    # for the written key, matching the written left value.
                    if is_diff_serdes:
                        k_lookup = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_lookup = k_left
                    v_right = right_iterator.adapter.get(k_lookup)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)

            self._run_binary(subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def union_wrapper(left_iterator, left_key_serdes,
                              left_value_serdess, right_iterator,
                              right_key_serdes, right_value_serdess,
                              output_writebatch):
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                # Pass 1: every left key, combined with the right value via f
                # when the key also exists on the right side.
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_lookup = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        k_lookup = k_left
                    v_right = right_iterator.adapter.get(k_lookup)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdess.deserialize(v_left),
                                    right_value_serdess.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdess.serialize(v_final))

                # Pass 2: right-only keys. They are re-encoded into the left
                # serdes so the output is uniformly encoded with pass 1.
                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        k_out = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right))
                    else:
                        k_out = k_right

                    if output_writebatch.get(k_out) is not None:
                        # Already written in pass 1 (key present on both sides).
                        continue

                    if is_diff_serdes:
                        v_right = left_value_serdess.serialize(
                            right_value_serdess.deserialize(v_right))
                    output_writebatch.put(k_out, v_right)

            self._run_binary(union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(
                key=self.functor_serdes.serialize(task._inputs[0]._id),
                value=self.functor_serdes.serialize(f(task._inputs)))
        return result