Beispiel #1
0
def create_adapter(er_partition: ErPartition, options: dict = None):
    """Create a pair adapter bound to ``er_partition``.

    The adapter options are filled in with the partition's store type, the
    database path computed by ``get_db_path`` and the partition itself, then
    handed to ``create_pair_adapter``.

    Note that when the caller supplies ``options``, that same dict is updated
    in place (no copy is made); a fresh dict is used only when ``options`` is
    ``None``.

    :param er_partition: partition the adapter should operate on.
    :param options: optional adapter options; mutated in place when given.
    :return: whatever ``create_pair_adapter`` returns for these options.
    """
    adapter_options = {} if options is None else options
    store_locator = er_partition._store_locator

    # Keys are written one by one, in this order, so that a failure in
    # get_db_path leaves exactly the same partial state as before.
    adapter_options['store_type'] = store_locator._store_type
    adapter_options['path'] = get_db_path(er_partition)
    adapter_options['er_partition'] = er_partition

    return create_pair_adapter(options=adapter_options)
Beispiel #2
0
    def run_task(self, task: ErTask):
        """Execute one task and return its result.

        The operation is selected by ``task._name``. Storage operations
        (``get``, ``getAll``, ``count``, ``put``, ``putBatch``, ``putAll``,
        ``delete``, ``destroy``) work directly on adapters / transfer
        brokers; single-input computations are delegated to
        ``self._run_unary`` and two-input computations (``join``,
        ``subtractByKey``, ``union``) to ``self._run_binary`` with a locally
        defined wrapper function.

        :param task: the task to run; its ``_name``, ``_id``, ``_inputs``,
            ``_outputs`` and ``_job`` (``_functors``, ``_options``,
            ``_inputs``) attributes are read.
        :return: an ``ErPair`` for ``get``, ``count``, ``reduce``,
            ``aggregate`` and ``withStores``; for every other name
            (including names not handled here) the ``task`` itself.
        :raises ValueError: in ``destroy`` with ``name == '*'`` when a target
            path resolves to ``/``, to the data directory itself, or to a
            location outside the data directory.
        """
        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] start. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] start. task_name={task._name}, task_id={task._id}')
        functors = task._job._functors
        # Default result: branches that produce no value hand the task back.
        result = task

        if task._name == 'get':
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                # Guarded so the key is only deserialized when the message
                # will actually be emitted.
                if L.isEnabledFor(logging.TRACE):
                    L.trace(
                        f"get: key: {self.functor_serdes.deserialize(f._key)}, path: {input_adapter.path}"
                    )
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            tag = f'{task._id}'
            er_pair = create_functor(functors[0]._body)
            input_store_head = task._job._inputs[0]
            key_serdes = create_serdes(input_store_head._store_locator._serdes)

            def generate_broker():
                # Lazily streams the partition as binary batches; the broker
                # registration is removed once the generator finishes or is
                # closed.
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    limit = None if er_pair._key is None else key_serdes.deserialize(
                        er_pair._key)
                    try:
                        yield from TransferPair.pair_to_bin_batch(rb,
                                                                  limit=limit)
                    finally:
                        TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        elif task._name == 'putBatch':
            partition = task._outputs[0]
            tag = f'{task._id}'
            PutBatchTask(tag, partition).run()

        elif task._name == 'putAll':
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.trace(f'egg_pair putAll: transfer service tag={tag}')
            tf = TransferPair(tag)
            # Blocks until the store_broker future resolves; its value is
            # not used.
            tf.store_broker(output_partition, False).result()
            # TODO:2: should wait complete?, command timeout?

        elif task._name == 'put':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                # The adapter's return value is intentionally not propagated;
                # the task itself is returned.
                input_adapter.put(f._key, f._value)

        elif task._name == 'destroy':
            input_store_locator = task._inputs[0]._store_locator
            namespace = input_store_locator._namespace
            name = input_store_locator._name
            store_type = input_store_locator._store_type
            L.debug(
                f'destroying store_type={store_type}, namespace={namespace}, name={name}'
            )
            if name == '*':
                # Wildcard name: remove whole directories from disk instead
                # of going through an adapter.
                from eggroll.roll_pair.utils.pair_utils import get_db_path, get_data_dir
                target_paths = list()
                if store_type == '*':
                    data_dir = get_data_dir()
                    for each_store_type in os.listdir(data_dir):
                        target_paths.append('/'.join(
                            [data_dir, each_store_type, namespace]))
                else:
                    db_path = get_db_path(task._inputs[0])
                    # Keep everything up to (not including) the last '*'.
                    target_paths.append(db_path[:db_path.rfind('*')])

                real_data_dir = os.path.realpath(get_data_dir())
                for path in target_paths:
                    realpath = os.path.realpath(path)
                    if not os.path.exists(path):
                        continue

                    # Containment is checked on whole path components: a
                    # plain string-prefix test would accept a sibling such
                    # as '<data_dir>-old' as being "inside" the data dir.
                    try:
                        is_inside_data_dir = os.path.commonpath(
                            [realpath, real_data_dir]) == real_data_dir
                    except ValueError:
                        # commonpath cannot relate the two paths (e.g.
                        # different drives): treat as outside.
                        is_inside_data_dir = False

                    if realpath == "/" \
                            or realpath == real_data_dir \
                            or not is_inside_data_dir:
                        raise ValueError(
                            f'trying to delete a dangerous path: {realpath}'
                        )
                    shutil.rmtree(path)
            else:
                options = task._job._options
                with create_adapter(task._inputs[0],
                                    options=options) as input_adapter:
                    input_adapter.destroy(options=options)

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                if input_adapter.delete(f._key):
                    L.trace("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                # Keys pass through as raw bytes; only values are transformed.
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                # f maps (key, value) to a new (key, value); the result is
                # handed to the shuffle broker as a single tuple.
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))

            self._run_unary(map_wrapper, task, shuffle=True)

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':
            reduce_op = create_functor(functors[1]._body)
            # Used as a truth value below — presumably a bool flag carried
            # as a functor body; confirm against the caller.
            shuffle = create_functor(functors[2]._body)

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if isinstance(value, Iterable):
                    for k1, v1 in value:
                        if shuffle:
                            # Shuffle target takes one (key, value) tuple.
                            output_writebatch.put((key_serdes.serialize(k1),
                                                   value_serdes.serialize(v1)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k1),
                                                  value_serdes.serialize(v1))
                else:
                    # NOTE(review): this branch always uses the tuple form of
                    # put, even when shuffle is falsy — confirm that is
                    # intended.
                    key = input_iterator.key()
                    output_writebatch.put((key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper,
                            task,
                            shuffle=shuffle,
                            reduce_op=reduce_op)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                # The single collapsed value is written under the key the
                # iterator reports after last() succeeds.
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':
            # Used as a truth value below — presumably a bool flag; confirm
            # against the caller.
            shuffle = create_functor(functors[1]._body)

            def flat_map_wrapper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    # f yields zero or more (key, value) pairs per input pair.
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k2),
                                                   value_serdes.serialize(v2)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k2),
                                                  value_serdes.serialize(v2))

            self._run_unary(flat_map_wrapper, task, shuffle=shuffle)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                # Collect every (key, value) of the partition into one list,
                # stored under the last key seen; nothing is written for an
                # empty partition.
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                # Seeded generator: the same seed selects the same records
                # for the same input order.
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                # Records that satisfy the predicate are copied as raw bytes.
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                # Inner join by walking both key-ordered inputs in lockstep.
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                # Either side running out ends the join: the StopIteration
                # raised by next() is the loop's exit condition.
                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    # Right keys are compared in the left serdes' byte form.
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # skips next(r_iter) to avoid duplicate codes for the 3rd time
                except StopIteration:
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                # Inner join by probing the right adapter for every left key.
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    r_v_bytes = right_iterator.adapter.get(k_left)
                    # "Missing" is None, as in the hash subtract/union
                    # wrappers; a present but empty value must still join.
                    if r_v_bytes is not None:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def merge_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                              left_value_serdes,
                                              right_iterator, right_key_serdes,
                                              right_value_serdes,
                                              output_writebatch):
                # Emits left records whose key does not occur on the right,
                # walking both key-ordered inputs in lockstep.
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge subtract_by_key cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )

                is_same_serdes = left_key_serdes == right_key_serdes
                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                # The two flags record where the merge loop was interrupted
                # so the tail handling below knows whether the current left
                # record is still pending.
                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left = next(l_iter)
                except StopIteration:
                    is_left_stopped = True
                    k_left = None
                    v_left = None

                try:
                    k_right_raw, v_right = next(r_iter)
                except StopIteration:
                    is_left_stopped = False
                    k_right_raw = None
                    v_right = None

                # left is None, output must be None
                if k_left is None:
                    return

                try:
                    # Empty right side: jump straight to the tail handling,
                    # which copies all of left.
                    if k_right_raw is None:
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        if k_left < k_right:
                            output_writebatch.put(k_left, v_left)
                            k_left, v_left = next(l_iter)
                            is_left_stopped = True
                        elif k_left == k_right:
                            is_equal = True
                            is_left_stopped = True
                            k_left, v_left = next(l_iter)
                            is_left_stopped = False
                            is_equal = False
                            k_right_raw, v_right = next(r_iter)
                            # Same conversion rule as for the first right
                            # key, so all comparisons use one byte form.
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                        else:
                            k_right_raw, v_right = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                        is_left_stopped = True
                except StopIteration:
                    pass
                if not is_left_stopped and not is_equal:
                    # Right side ran out (or left ended right after a put):
                    # flush the pending left record and the rest of left.
                    try:
                        if k_left is not None and v_left is not None:
                            output_writebatch.put(k_left, v_left)
                            while True:
                                k_left, v_left = next(l_iter)
                                output_writebatch.put(k_left, v_left)
                    except StopIteration:
                        pass
                elif is_left_stopped and not is_equal and k_left is not None:
                    output_writebatch.put(k_left, v_left)

                return

            def hash_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                             left_value_serdes, right_iterator,
                                             right_key_serdes,
                                             right_value_serdes,
                                             output_writebatch):
                # Emits left records whose key lookup on the right adapter
                # returns None.
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)

            subtract_by_key_type = task._job._options.get(
                'subtract_by_key_type', 'merge')
            if subtract_by_key_type == 'merge':
                self._run_binary(merge_subtract_by_key_wrapper, task)
            else:
                self._run_binary(hash_subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def merge_union_wrapper(left_iterator, left_key_serdes,
                                    left_value_serdes, right_iterator,
                                    right_key_serdes, right_value_serdes,
                                    output_writebatch):
                # Emits every key from either side; values of keys present
                # on both sides are combined with f. Both inputs are walked
                # in key order.
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge union cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                k_left = None
                v_left_bytes = None
                k_right = None
                v_right_bytes = None
                none_none = (None, None)

                # is_left_stopped tells the tail handling which side ran out;
                # is_equal tells it whether the pending right record was
                # already emitted as part of a combined pair.
                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left_bytes = next(l_iter, none_none)
                    k_right_raw, v_right_bytes = next(r_iter, none_none)

                    if k_left is None and k_right_raw is None:
                        return
                    elif k_left is None:
                        is_left_stopped = True
                        raise StopIteration()
                    elif k_right_raw is None:
                        is_left_stopped = False
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        while k_right < k_left:
                            if is_same_serdes:
                                output_writebatch.put(k_right, v_right_bytes)
                            else:
                                output_writebatch.put(
                                    k_right,
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))

                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        is_left_stopped = True
                        while k_left < k_right:
                            output_writebatch.put(k_left, v_left_bytes)
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            is_equal = True
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            is_left_stopped = True
                            k_left, v_left_bytes = next(l_iter)

                            is_left_stopped = False
                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                            is_equal = False
                except StopIteration:
                    pass

                if not is_left_stopped:
                    # Right side is exhausted: flush the pending left record
                    # and the remainder of left.
                    try:
                        output_writebatch.put(k_left, v_left_bytes)
                        while True:
                            k_left, v_left_bytes = next(l_iter)
                            output_writebatch.put(k_left, v_left_bytes)
                    except StopIteration:
                        pass
                else:
                    # Left side is exhausted: flush the pending right record
                    # (unless it was just merged) and the remainder of right,
                    # converted to the left serdes when they differ.
                    try:
                        if not is_equal:
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                        while True:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                    except StopIteration:
                        pass

            def hash_union_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                # Pass 1: every left record, combined with the right value
                # when the right adapter has the key.
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                # Pass 2: right records whose key was not written in pass 1.
                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        final_v_bytes = output_writebatch.get(
                            left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right)))
                    else:
                        final_v_bytes = output_writebatch.get(k_right)

                    if final_v_bytes is None:
                        # NOTE(review): key and value are written in the
                        # right serdes' byte form even when the serdes
                        # differ — confirm that is intended.
                        output_writebatch.put(k_right, v_right)

            union_type = task._job._options.get('union_type', 'merge')

            if union_type == 'merge':
                self._run_binary(merge_union_wrapper, task)
            else:
                self._run_binary(hash_union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(f(task)))

        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] end. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] end. task_name={task._name}, task_id={task._id}')

        return result