Example #1
    def _run_binary(self, func, task):
        left_key_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
        left_value_serdes = create_serdes(
            task._inputs[0]._store_locator._serdes)

        right_key_serdes = create_serdes(
            task._inputs[1]._store_locator._serdes)
        right_value_serdes = create_serdes(
            task._inputs[1]._store_locator._serdes)

        output_key_serdes = create_serdes(
            task._outputs[0]._store_locator._serdes)
        output_value_serdes = create_serdes(
            task._outputs[0]._store_locator._serdes)

        if left_key_serdes != output_key_serdes or \
                left_value_serdes != output_value_serdes:
            raise ValueError(
                f"input key-value serdes:{(left_key_serdes, left_value_serdes)}"
                f" differ from output key-value serdes:{(output_key_serdes, output_value_serdes)}"
            )

        with create_adapter(task._inputs[0]) as left_adapter, \
                create_adapter(task._inputs[1]) as right_adapter, \
                create_adapter(task._outputs[0]) as output_adapter, \
                left_adapter.iteritems() as left_iterator, \
                right_adapter.iteritems() as right_iterator, \
                output_adapter.new_batch() as output_writebatch:
            try:
                func(left_iterator, left_key_serdes, left_value_serdes,
                     right_iterator, right_key_serdes, right_value_serdes,
                     output_writebatch)
            except Exception as e:
                raise EnvironmentError("exec task:{} error".format(task), e)
Example #2
    def aggregate_seq(self, task: ErTask):
        functors = task._job._functors
        is_reduce = functors[0]._name == 'reduce'
        zero_value = None if is_reduce or functors[0] is None \
            else create_functor(functors[0]._body)
        if is_reduce:
            seq_op = create_functor(functors[0]._body)
        else:
            seq_op = create_functor(functors[1]._body)

        first = True
        seq_op_result = zero_value
        input_partition = task._inputs[0]
        input_key_serdes = create_serdes(
            input_partition._store_locator._serdes)
        input_value_serdes = input_key_serdes

        with create_adapter(input_partition) as input_adapter, \
            input_adapter.iteritems() as input_iter:
            for k_bytes, v_bytes in input_iter:
                v = input_value_serdes.deserialize(v_bytes)
                if is_reduce and first:
                    seq_op_result = v
                    first = False
                else:
                    seq_op_result = seq_op(seq_op_result, v)

        return seq_op_result
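
The method above is essentially a fold over deserialized values: 'reduce' seeds the accumulator with the first value, while 'aggregate' starts from a zero value. Below is a minimal, dependency-free sketch of the same idea; the names are hypothetical and pickle stands in for the store's serdes.

import pickle

def aggregate_seq_sketch(pairs, seq_op, zero_value=None, is_reduce=False):
    # pairs: iterable of (key_bytes, value_bytes); values are pickled here for illustration
    result = zero_value
    first = True
    for _, v_bytes in pairs:
        v = pickle.loads(v_bytes)
        if is_reduce and first:
            result = v  # a plain reduce seeds the accumulator with the first value
            first = False
        else:
            result = seq_op(result, v)
    return result

# usage: summing serialized integers
data = [(str(i).encode(), pickle.dumps(i)) for i in range(5)]
assert aggregate_seq_sketch(data, lambda a, b: a + b, is_reduce=True) == 10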
Example #3
        def do_store(store_partition_inner, is_shuffle_inner, total_writers_inner, reduce_op_inner):
            done_cnt = 0
            tag = self.__generate_tag(store_partition_inner._id) if is_shuffle_inner else self.__transfer_id
            try:
                broker = TransferService.get_or_create_broker(tag, write_signals=total_writers_inner)
                L.trace(f"do_store start for tag={tag}")
                batches = TransferPair.bin_batch_to_pair(b.data for b in broker)

                serdes = create_serdes(store_partition_inner._store_locator._serdes)
                if reduce_op_inner is not None:
                    merger = functools.partial(do_merge,
                                               merge_func=reduce_op_inner,
                                               serdes=serdes)
                else:
                    merger = None

                with create_adapter(store_partition_inner) as db:
                    L.trace(f"do_store create_db for tag={tag} for partition={store_partition_inner}")
                    with db.new_batch() as wb:
                        for k, v in batches:
                            if reduce_op_inner is None:
                                wb.put(k, v)
                            else:
                                wb.merge(merger, k, v)
                            done_cnt += 1
                    L.trace(f"do_store done for tag={tag} for partition={store_partition_inner}")
            except Exception as e:
                L.exception(f'Error in do_store for tag={tag}')
                raise e
            finally:
                TransferService.remove_broker(tag)
            return done_cnt
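
Here wb.merge(merger, k, v) folds values arriving for an already-written key using reduce_op_inner, with do_merge deserializing both sides before combining. A rough, self-contained sketch of that merge-on-write behaviour follows; DictWriteBatch is a toy stand-in rather than the eggroll adapter API, and pickle replaces the store serdes.

import functools
import pickle

def do_merge(old_bytes, new_bytes, merge_func, serdes=pickle):
    # combine an existing stored value with an incoming one, both serialized
    return serdes.dumps(merge_func(serdes.loads(old_bytes), serdes.loads(new_bytes)))

class DictWriteBatch:
    """Toy in-memory stand-in for a store write batch."""
    def __init__(self):
        self.data = {}

    def put(self, k, v):
        self.data[k] = v

    def merge(self, merger, k, v):
        self.data[k] = merger(self.data[k], v) if k in self.data else v

wb = DictWriteBatch()
merger = functools.partial(do_merge, merge_func=lambda a, b: a + b)
for k, v in [(b'a', pickle.dumps(1)), (b'a', pickle.dumps(2)), (b'b', pickle.dumps(3))]:
    wb.merge(merger, k, v)
assert pickle.loads(wb.data[b'a']) == 3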
Example #4
 def generate_broker():
     with create_adapter(task._inputs[0]) as db, db.iteritems() as rb:
         limit = None if er_pair._key is None else key_serdes.deserialize(er_pair._key)
         try:
             yield from TransferPair.pair_to_bin_batch(rb, limit=limit)
         finally:
             TransferService.remove_broker(tag)
Example #5
 def do_store(store_partition_inner, is_shuffle_inner,
              total_writers_inner):
     done_cnt = 0
     tag = self.__generate_tag(
         store_partition_inner._id
     ) if is_shuffle_inner else self.__transfer_id
     try:
         broker = TransferService.get_or_create_broker(
             tag, write_signals=total_writers_inner)
         L.debug(f"do_store start for tag: {tag}")
         batches = TransferPair.bin_batch_to_pair(b.data
                                                  for b in broker)
         with create_adapter(store_partition_inner) as db:
             L.debug(
                 f"do_store create_db for tag: {tag} for partition: {store_partition_inner}"
             )
             with db.new_batch() as wb:
                 for k, v in batches:
                     wb.put(k, v)
                     done_cnt += 1
             L.debug(
                 f"do_store done for tag: {tag} for partition: {store_partition_inner}"
             )
         TransferService.remove_broker(tag)
     except Exception as e:
         L.error(f'Error in do_store for tag {tag}')
         raise e
     return done_cnt
Example #6
 def func(partitions):
     part_main, part_obf, part_ret = partitions
     serder_main = create_serdes(part_main._store_locator._serdes)
     serder_obf = create_serdes(part_obf._store_locator._serdes)
     serder_ret = create_serdes(part_ret._store_locator._serdes)
     with create_adapter(part_main) as db_main, \
             create_adapter(part_obf) as db_obf,\
             create_adapter(part_ret) as db_ret:
         with db_main.iteritems() as rb, db_ret.new_batch() as wb:
             for k, v in rb:
                 mat = serder_main.deserialize(v)
                 obfs = []
                 L.debug(f"obf qsize: {db_obf.count()}, {part_obf}")
                 for i in range(mat._ndarray.size):
                     obfs.append(serder_obf.deserialize(db_obf.get()))
                 obfs = np.array(obfs).reshape(mat._ndarray.shape)
                 wb.put(
                     k,
                     serder_ret.serialize(
                         serder_main.deserialize(v).encrypt(obfs=obfs)))
Example #7
    def run(self):
        # batch streams must be executed serially and then reinitialized.
        # TODO:0:  remove lock to bss
        rs_header = None
        with PutBatchTask._partition_lock[self.tag]:   # tag includes partition info in tag generation
            L.trace(f"do_store start for tag={self.tag}, partition_id={self.partition._id}")
            bss = _BatchStreamStatus.get_or_create(self.tag)
            try:
                broker = TransferService.get_or_create_broker(self.tag, write_signals=1)

                iter_wait = 0
                iter_timeout = int(CoreConfKeys.EGGROLL_CORE_FIFOBROKER_ITER_TIMEOUT_SEC.get())

                batch = None
                batch_get_interval = 0.1
                with create_adapter(self.partition) as db, db.new_batch() as wb:
                    #for batch in broker:
                    while not broker.is_closable():
                        try:
                            batch = broker.get(block=True, timeout=batch_get_interval)
                        except queue.Empty as e:
                            iter_wait += batch_get_interval
                            if iter_wait > iter_timeout:
                                raise TimeoutError(f'timeout in PutBatchTask.run. tag={self.tag}, iter_timeout={iter_timeout}, iter_wait={iter_wait}')
                            else:
                                continue
                        except BrokerClosed as e:
                            continue

                        iter_wait = 0
                        rs_header = ErRollSiteHeader.from_proto_string(batch.header.ext)
                        batch_pairs = 0
                        if batch.data:
                            bin_data = ArrayByteBuffer(batch.data)
                            reader = PairBinReader(pair_buffer=bin_data, data=batch.data)
                            for k_bytes, v_bytes in reader.read_all():
                                wb.put(k_bytes, v_bytes)
                                batch_pairs += 1
                        bss.count_batch(rs_header, batch_pairs)
                        # TODO:0
                        bss._data_type = rs_header._data_type
                        if rs_header._stage == FINISH_STATUS:
                            bss.set_done(rs_header)  # starting from 0

                bss.check_finish()
                # TransferService.remove_broker(tag) will be called once the get_status phase finishes or an exception occurs
            except Exception as e:
                L.exception(f'_run_put_batch error, tag={self.tag}, '
                        f'rs_key={rs_header.get_rs_key() if rs_header is not None else None}, rs_header={rs_header}')
                raise e
            finally:
                TransferService.remove_broker(self.tag)
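
The core of the loop above is polling a broker with a short timeout, resetting the wait counter whenever a batch arrives and failing once nothing has arrived for iter_timeout seconds. A stripped-down sketch of that polling pattern over a plain queue.Queue is shown below; drain_with_timeout is a hypothetical helper, not the eggroll broker API.

import queue

def drain_with_timeout(q, is_closable, iter_timeout=30.0, poll_interval=0.1):
    # yield items from q until the producer marks it closable; raise if nothing
    # arrives for iter_timeout seconds in a row
    waited = 0.0
    while not is_closable():
        try:
            item = q.get(block=True, timeout=poll_interval)
        except queue.Empty:
            waited += poll_interval
            if waited > iter_timeout:
                raise TimeoutError(f'no item received for {iter_timeout}s')
            continue
        waited = 0.0
        yield item

# usage
q = queue.Queue()
for i in range(3):
    q.put(i)
assert list(drain_with_timeout(q, is_closable=q.empty)) == [0, 1, 2]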
Example #8
 def _run_unary(self, func, task, shuffle=False):
     key_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     value_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     with create_adapter(task._inputs[0]) as input_db:
         L.debug(f"create_store_adapter: {task._inputs[0]}")
         with input_db.iteritems() as rb:
             L.debug(f"create_store_adapter_iter: {task._inputs[0]}")
             from eggroll.roll_pair.transfer_pair import TransferPair, BatchBroker
             if shuffle:
                 total_partitions = task._inputs[
                     0]._store_locator._total_partitions
                 output_store = task._job._outputs[0]
                 shuffle_broker = FifoBroker()
                 write_bb = BatchBroker(shuffle_broker)
                 try:
                     shuffler = TransferPair(transfer_id=task._job._id)
                     store_future = shuffler.store_broker(
                         task._outputs[0], True, total_partitions)
                     scatter_future = shuffler.scatter(
                         shuffle_broker,
                         partitioner(hash_func=hash_code,
                                     total_partitions=total_partitions),
                         output_store)
                     func(rb, key_serdes, value_serdes, write_bb)
                 finally:
                     write_bb.signal_write_finish()
                 scatter_results = scatter_future.result()
                 store_result = store_future.result()
                 L.debug(f"scatter_result:{scatter_results}")
                 L.debug(f"gather_result:{store_result}")
             else:
                 # TODO: modification may be needed when store options finished
                 with create_adapter(task._outputs[0],
                                     options=task._job._options
                                     ) as db, db.new_batch() as wb:
                     func(rb, key_serdes, value_serdes, wb)
             L.debug(f"close_store_adapter:{task._inputs[0]}")
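
In the shuffle branch, partitioner(hash_func=hash_code, total_partitions=...) decides which output partition each serialized key goes to before scatter sends it over the transfer service. A toy illustration of that hash-partitioning step follows; the partitioner and hash_code shown here are simplified stand-ins, not the eggroll implementations.

def partitioner(hash_func, total_partitions):
    # build a function mapping a serialized key to a partition index
    return lambda key: hash_func(key) % total_partitions

def hash_code(key: bytes) -> int:
    # simplified stand-in hash; the real eggroll hash_code differs
    return sum(key)

part = partitioner(hash_func=hash_code, total_partitions=4)
buckets = {}
for k in [b'apple', b'pear', b'plum', b'fig']:
    buckets.setdefault(part(k), []).append(k)
print(buckets)  # keys grouped by their target partition index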
Example #9
 def func_asyn(partitions):
     part1 = partitions[0]
     serder1 = create_serdes(part1._store_locator._serdes)
     with create_adapter(part1) as db1:
         i = 0
         while True:
             try:
                 db1.put(serder1.serialize(pub_key.gen_obfuscator()))
             except InterruptedError as e:
                 L.info(
                     f"finish creating async obfuscator:{ns}.{name}: {i}")
                 break
             if i % 10000 == 0:
                 L.debug(f"generating async obfuscator:{ns}.{name}: {i}")
             i += 1
Example #10
 def do_store(store_partition_inner, is_shuffle_inner,
              total_writers_inner):
     tag = self.__generate_tag(
         store_partition_inner._id
     ) if is_shuffle_inner else self.__transfer_id
     broker = TransferService.get_or_create_broker(
         tag, write_signals=total_writers_inner)
     L.debug(f"do_store_start:{tag}")
     done_cnt = 0
     batches = TransferPair.bin_batch_to_pair(b.data for b in broker)
     with create_adapter(store_partition_inner) as db:
         L.debug(
             f"do_store_create_db: {tag} for partition: {store_partition_inner}"
         )
         with db.new_batch() as wb:
             for k, v in batches:
                 wb.put(k, v)
                 done_cnt += 1
         L.debug(
             f"do_store_done: {tag} for partition: {store_partition_inner}"
         )
     TransferService.remove_broker(tag)
     return done_cnt
Example #11
    def _run_unary(self, func, task, shuffle=False, reduce_op=None):
        input_store_head = task._job._inputs[0]
        output_store_head = task._job._outputs[0]
        input_key_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        input_value_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        output_key_serdes = create_serdes(
            output_store_head._store_locator._serdes)
        output_value_serdes = create_serdes(
            output_store_head._store_locator._serdes)

        if input_key_serdes != output_key_serdes or \
                input_value_serdes != output_value_serdes:
            raise ValueError(
                f"input key-value serdes:{(input_key_serdes, input_value_serdes)}"
                f" differ from output key-value serdes:{(output_key_serdes, output_value_serdes)}"
            )

        if shuffle:
            from eggroll.roll_pair.transfer_pair import TransferPair
            input_total_partitions = input_store_head._store_locator._total_partitions
            output_total_partitions = output_store_head._store_locator._total_partitions
            output_store = output_store_head

            my_server_node_id = get_static_er_conf().get(
                'server_node_id', None)
            shuffler = TransferPair(transfer_id=task._job._id)
            if not task._outputs or \
                    (my_server_node_id is not None
                     and my_server_node_id != task._outputs[0]._processor._server_node_id):
                store_future = None
            else:
                store_future = shuffler.store_broker(
                    store_partition=task._outputs[0],
                    is_shuffle=True,
                    total_writers=input_total_partitions,
                    reduce_op=reduce_op)

            if not task._inputs or \
                    (my_server_node_id is not None
                     and my_server_node_id != task._inputs[0]._processor._server_node_id):
                scatter_future = None
            else:
                shuffle_broker = FifoBroker()
                write_bb = BatchBroker(shuffle_broker)
                try:
                    scatter_future = shuffler.scatter(
                        input_broker=shuffle_broker,
                        partition_function=partitioner(
                            hash_func=hash_code,
                            total_partitions=output_total_partitions),
                        output_store=output_store)
                    with create_adapter(task._inputs[0]) as input_db, \
                        input_db.iteritems() as rb:
                        func(rb, input_key_serdes, input_value_serdes,
                             write_bb)
                finally:
                    write_bb.signal_write_finish()

            if scatter_future:
                scatter_results = scatter_future.result()
            else:
                scatter_results = 'no scatter for this partition'
            if store_future:
                store_results = store_future.result()
            else:
                store_results = 'no store for this partition'
        else:  # no shuffle
            with create_adapter(task._inputs[0]) as input_db, \
                    input_db.iteritems() as rb, \
                    create_adapter(task._outputs[0], options=task._job._options) as db, \
                    db.new_batch() as wb:
                func(rb, input_key_serdes, input_value_serdes, wb)
            L.trace(f"close_store_adapter:{task._inputs[0]}")
Example #12
    def run_task(self, task: ErTask):
        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] start. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] start. task_name={task._name}, task_id={task._id}')
        functors = task._job._functors
        result = task

        if task._name == 'get':
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.trace(
                    f"get: key: {self.functor_serdes.deserialize(f._key)}, path: {input_adapter.path}"
                )
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            tag = f'{task._id}'
            er_pair = create_functor(functors[0]._body)
            input_store_head = task._job._inputs[0]
            key_serdes = create_serdes(input_store_head._store_locator._serdes)

            def generate_broker():
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    limit = None if er_pair._key is None else key_serdes.deserialize(
                        er_pair._key)
                    try:
                        yield from TransferPair.pair_to_bin_batch(rb,
                                                                  limit=limit)
                    finally:
                        TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        elif task._name == 'putBatch':
            partition = task._outputs[0]
            tag = f'{task._id}'
            PutBatchTask(tag, partition).run()

        elif task._name == 'putAll':
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.trace(f'egg_pair putAll: transfer service tag={tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?

        elif task._name == 'put':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                value = input_adapter.put(f._key, f._value)
                #result = ErPair(key=f._key, value=bytes(value))

        elif task._name == 'destroy':
            input_store_locator = task._inputs[0]._store_locator
            namespace = input_store_locator._namespace
            name = input_store_locator._name
            store_type = input_store_locator._store_type
            L.debug(
                f'destroying store_type={store_type}, namespace={namespace}, name={name}'
            )
            if name == '*':
                from eggroll.roll_pair.utils.pair_utils import get_db_path, get_data_dir
                target_paths = list()
                if store_type == '*':
                    data_dir = get_data_dir()
                    store_types = os.listdir(data_dir)
                    for store_type in store_types:
                        target_paths.append('/'.join(
                            [data_dir, store_type, namespace]))
                else:
                    db_path = get_db_path(task._inputs[0])
                    target_paths.append(db_path[:db_path.rfind('*')])

                real_data_dir = os.path.realpath(get_data_dir())
                for path in target_paths:
                    realpath = os.path.realpath(path)
                    if os.path.exists(path):
                        if realpath == "/" \
                                or realpath == real_data_dir \
                                or not realpath.startswith(real_data_dir):
                            raise ValueError(
                                f'trying to delete a dangerous path: {realpath}'
                            )
                        else:
                            shutil.rmtree(path)
            else:
                options = task._job._options
                with create_adapter(task._inputs[0],
                                    options=options) as input_adapter:
                    input_adapter.destroy(options=options)

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                if input_adapter.delete(f._key):
                    L.trace("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))

            self._run_unary(map_wrapper, task, shuffle=True)

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':
            reduce_op = create_functor(functors[1]._body)
            shuffle = create_functor(functors[2]._body)

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if isinstance(value, Iterable):
                    for k1, v1 in value:
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k1),
                                                   value_serdes.serialize(v1)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k1),
                                                  value_serdes.serialize(v1))
                else:
                    key = input_iterator.key()
                    output_writebatch.put((key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper,
                            task,
                            shuffle=shuffle,
                            reduce_op=reduce_op)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':
            shuffle = create_functor(functors[1]._body)

            def flat_map_wraaper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k2),
                                                   value_serdes.serialize(v2)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k2),
                                                  value_serdes.serialize(v2))

            self._run_unary(flat_map_wraaper, task, shuffle=shuffle)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # next(r_iter) is intentionally skipped here so the right-side advance code is not duplicated a third time
                except StopIteration as e:
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    r_v_bytes = right_iterator.adapter.get(k_left)
                    if r_v_bytes:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def merge_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                              left_value_serdes,
                                              right_iterator, right_key_serdes,
                                              right_value_serdes,
                                              output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge subtract_by_key cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )

                is_same_serdes = left_key_serdes == right_key_serdes
                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left = next(l_iter)
                except StopIteration:
                    is_left_stopped = True
                    k_left = None
                    v_left = None

                try:
                    k_right_raw, v_right = next(r_iter)
                except StopIteration:
                    is_left_stopped = False
                    k_right_raw = None
                    v_right = None

                # left is None, output must be None
                if k_left is None:
                    return

                try:
                    if k_left is None:
                        raise StopIteration()
                    if k_right_raw is None:
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        if k_left < k_right:
                            output_writebatch.put(k_left, v_left)
                            k_left, v_left = next(l_iter)
                            is_left_stopped = True
                        elif k_left == k_right:
                            is_equal = True
                            is_left_stopped = True
                            k_left, v_left = next(l_iter)
                            is_left_stopped = False
                            is_equal = False
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        else:
                            k_right_raw, v_right = next(r_iter)
                            k_right = left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right_raw))
                        is_left_stopped = True
                except StopIteration as e:
                    pass
                if not is_left_stopped and not is_equal:
                    try:
                        if k_left is not None and v_left is not None:
                            output_writebatch.put(k_left, v_left)
                            while True:
                                k_left, v_left = next(l_iter)
                                output_writebatch.put(k_left, v_left)
                    except StopIteration as e:
                        pass
                elif is_left_stopped and not is_equal and k_left is not None:
                    output_writebatch.put(k_left, v_left)

                return

            def hash_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                             left_value_serdes, right_iterator,
                                             right_key_serdes,
                                             right_value_serdes,
                                             output_writebatch):
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)

            subtract_by_key_type = task._job._options.get(
                'subtract_by_key_type', 'merge')
            if subtract_by_key_type == 'merge':
                self._run_binary(merge_subtract_by_key_wrapper, task)
            else:
                self._run_binary(hash_subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def merge_union_wrapper(left_iterator, left_key_serdes,
                                    left_value_serdes, right_iterator,
                                    right_key_serdes, right_value_serdes,
                                    output_writebatch):
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge union cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                k_left = None
                v_left_bytes = None
                k_right = None
                v_right_bytes = None
                none_none = (None, None)

                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left_bytes = next(l_iter, none_none)
                    k_right_raw, v_right_bytes = next(r_iter, none_none)

                    if k_left is None and k_right_raw is None:
                        return
                    elif k_left is None:
                        is_left_stopped = True
                        raise StopIteration()
                    elif k_right_raw is None:
                        is_left_stopped = False
                        raise StopIteration()

                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        is_left_stopped = False
                        while k_right < k_left:
                            if is_same_serdes:
                                output_writebatch.put(k_right, v_right_bytes)
                            else:
                                output_writebatch.put(
                                    k_right,
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))

                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        is_left_stopped = True
                        while k_left < k_right:
                            output_writebatch.put(k_left, v_left_bytes)
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            is_equal = True
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            is_left_stopped = True
                            k_left, v_left_bytes = next(l_iter)

                            is_left_stopped = False
                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                            is_equal = False
                except StopIteration as e:
                    pass

                if not is_left_stopped:
                    try:
                        output_writebatch.put(k_left, v_left_bytes)
                        while True:
                            k_left, v_left_bytes = next(l_iter)
                            output_writebatch.put(k_left, v_left_bytes)
                    except StopIteration as e:
                        pass
                else:
                    try:
                        if not is_equal:
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                        while True:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                    except StopIteration as e:
                        pass

                    # end of merge union wrapper
                    return

            def hash_union_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        final_v_bytes = output_writebatch.get(
                            left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right)))
                    else:
                        final_v_bytes = output_writebatch.get(k_right)

                    if final_v_bytes is None:
                        output_writebatch.put(k_right, v_right)

            union_type = task._job._options.get('union_type', 'merge')

            if union_type == 'merge':
                self._run_binary(merge_union_wrapper, task)
            else:
                self._run_binary(hash_union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(f(task)))

        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] end. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] end. task_name={task._name}, task_id={task._id}')

        return result
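
The merge_join_wrapper above walks two key-sorted iterators in lockstep, advancing whichever side has the smaller key and emitting the combined value on a match. A compact, self-contained version of that merge-join idea follows; unlike the wrapper, it ignores serdes and advances both sides after a match.

def merge_join_sketch(left, right, combine):
    # left, right: iterables of (key, value) sorted by key; yields joined pairs
    l_iter, r_iter = iter(left), iter(right)
    try:
        k_left, v_left = next(l_iter)
        k_right, v_right = next(r_iter)
        while True:
            if k_left < k_right:
                k_left, v_left = next(l_iter)
            elif k_right < k_left:
                k_right, v_right = next(r_iter)
            else:
                yield k_left, combine(v_left, v_right)
                k_left, v_left = next(l_iter)
                k_right, v_right = next(r_iter)
    except StopIteration:
        return

left = [(1, 'a'), (2, 'b'), (4, 'd')]
right = [(2, 'x'), (3, 'y'), (4, 'z')]
assert list(merge_join_sketch(left, right, lambda a, b: a + b)) == [(2, 'bx'), (4, 'dz')]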
Example #13
 def get_eggs(task):
     _input = task._inputs[0]
     with create_adapter(_input) as input_adapter:
         eggs = input_adapter.get('eggs')
         return eggs
Example #14
        def _push_partition(ertask):
            rs_header._partition_id = ertask._inputs[0]._id
            L.trace(
                f"pushing rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )

            from eggroll.core.grpc.factory import GrpcChannelFactory
            from eggroll.core.proto import proxy_pb2_grpc
            grpc_channel_factory = GrpcChannelFactory()

            with create_adapter(ertask._inputs[0]) as db, db.iteritems() as rb:
                # NOTICE AGAIN: all modifications to rs_header are confined to bs_helper.
                # rs_header is shared by bs_helper and this scope, so any modification made
                # in bs_helper is visible here. Remember that Python passes object references
                # by value: the 'pointer' is copied, but the referenced object stays mutable.
                bs_helper = _BatchStreamHelper(rs_header)
                bin_batch_streams = bs_helper._generate_batch_streams(
                    pair_iter=rb,
                    batches_per_stream=batches_per_stream,
                    body_bytes=body_bytes)

                channel = grpc_channel_factory.create_channel(endpoint)
                stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
                for batch_stream in bin_batch_streams:
                    batch_stream_data = list(batch_stream)
                    cur_retry = 0
                    exception = None
                    while cur_retry < max_retry_cnt:
                        L.trace(
                            f'pushing rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}'
                        )
                        try:
                            stub.push(bs_helper.generate_packet(
                                batch_stream_data, cur_retry),
                                      timeout=per_stream_timeout)
                            exception = None
                            break
                        except Exception as e:
                            if cur_retry < max_retry_cnt - long_retry_cnt:
                                retry_interval = round(
                                    min(2 * cur_retry, 20) +
                                    random.random() * 10, 3)
                            else:
                                retry_interval = round(
                                    300 + random.random() * 10, 3)
                            L.warn(
                                f"push rp partition error. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, max_retry_cnt={max_retry_cnt}, cur_retry={cur_retry}, retry_interval={retry_interval}",
                                exc_info=e)
                            time.sleep(retry_interval)
                            if isinstance(e, RpcError) and e.code(
                            ).name == 'UNAVAILABLE':
                                channel = grpc_channel_factory.create_channel(
                                    endpoint, refresh=True)
                                stub = proxy_pb2_grpc.DataTransferServiceStub(
                                    channel)

                            exception = e
                        finally:
                            cur_retry += 1
                    if exception is not None:
                        L.exception(
                            f"push partition failed. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, cur_retry={cur_retry}",
                            exc_info=exception)
                        raise exception
                    L.trace(
                        f'pushed rollpair partition stream. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}, retry count={cur_retry - 1}'
                    )

            L.trace(
                f"pushed rollpair partition. rs_key={rs_key}, partition_id={rs_header._partition_id}, rs_header={rs_header}"
            )
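
The retry loop above uses a short, jittered backoff for the early attempts and a long (roughly five-minute) backoff for the final long_retry_cnt attempts, recreating the gRPC channel when the error is UNAVAILABLE. Below is a minimal sketch of just that backoff schedule; call_with_retry is a hypothetical helper and the channel-refresh logic is omitted.

import random
import time

def call_with_retry(do_call, max_retry_cnt=10, long_retry_cnt=3):
    # short jittered backoff for early attempts, long backoff for the last few
    last_exc = None
    for cur_retry in range(max_retry_cnt):
        try:
            return do_call()
        except Exception as e:
            last_exc = e
            if cur_retry < max_retry_cnt - long_retry_cnt:
                retry_interval = round(min(2 * cur_retry, 20) + random.random() * 10, 3)
            else:
                retry_interval = round(300 + random.random() * 10, 3)
            time.sleep(retry_interval)
    raise last_exc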
Example #15
 def generate_broker():
     with create_adapter(
             task._inputs[0]) as db, db.iteritems() as rb:
         yield from TransferPair.pair_to_bin_batch(rb)
Example #16
    def run_task(self, task: ErTask):
        L.info("start run task")
        L.debug("start run task")
        functors = task._job._functors
        result = task

        if task._name == 'get':
            L.info("egg_pair get call")
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                print("get key:{} and path is:{}".format(
                    self.functor_serdes.deserialize(f._key),
                    input_adapter.path))
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            L.info("egg_pair getAll call")
            tag = f'{task._id}'

            def generate_broker():
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    yield from TransferPair.pair_to_bin_batch(rb)
                    # TODO:0 how to remove?
                    # TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            L.info('egg_pair count call')
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        # TODO:1: multiprocessor scenario
        elif task._name == 'putAll':
            L.info("egg_pair putAll call")
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.info(f'egg_pair transfer service tag:{tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?
            L.debug(f"putAll result:{store_broker_result}")

        if task._name == 'put':
            L.info("egg_pair put call")
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                value = input_adapter.put(f._key, f._value)
                #result = ErPair(key=f._key, value=bytes(value))

        if task._name == 'destroy':
            with create_adapter(task._inputs[0]) as input_adapter:
                input_adapter.destroy()
                L.info("finish destroy")

        if task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.info("delete k:{}, its value:{}".format(
                    f._key, input_adapter.get(f._key)))
                if input_adapter.delete(f._key):
                    L.info("delete k success")

        if task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))
                L.info('finish calculating')

            self._run_unary(map_wrapper, task, shuffle=True)
            L.info('map finished')

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            L.info('reduce finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            L.info('ready to aggregate')
            seq_op_result = self.aggregate_seq(task=task)
            L.info('aggregate finished')
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, shuffle_broker):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    L.info("value of mapPartitions:{}".format(value))
                    if isinstance(value, Iterable):
                        for k1, v1 in value:
                            shuffle_broker.put((key_serdes.serialize(k1),
                                                value_serdes.serialize(v1)))
                    else:
                        key = input_iterator.key()
                        shuffle_broker.put(
                            (key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper, task, shuffle=True)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':

            def flat_map_wraaper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
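                # flatMap: every input pair may expand into any number of output
                # pairs, each re-serialized and written directly to the output batch.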
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        output_writebatch.put(key_serdes.serialize(k2),
                                              value_serdes.serialize(v2))

            self._run_unary(flat_map_wrapper, task)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
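                # glom: gather every deserialized pair of the partition into one
                # list, stored under the last raw key seen.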
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
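                # sample: Bernoulli sampling with a seeded RandomState for
                # reproducibility; each raw pair is kept with probability `fraction`.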
                fraction = create_functor(functors[0]._body)
                seed = create_functor(functors[1]._body)
                input_iterator.first()
                random_state = np.random.RandomState(seed)
                for k, v in input_iterator:
                    if random_state.rand() < fraction:
                        output_writebatch.put(k, v)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
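                # filter: keep the raw pair only when the predicate f holds on the
                # deserialized key and value.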
                f = create_functor(functors[0]._body)
                for k, v in input_iterator:
                    if f(key_serdes.deserialize(k),
                         value_serdes.deserialize(v)):
                        output_writebatch.put(k, v)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
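                # sort-merge join: both stores must support sorted iteration; the
                # side with the smaller key is advanced until the keys match, and
                # the matched values are combined with f. When key serdes differ,
                # right keys are re-encoded with the left serdes so byte comparison
                # stays consistent.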
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                is_same_serdes = type(left_key_serdes) == type(
                    right_key_serdes)

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)

                try:
                    k_left, v_left_bytes = next(l_iter)
                    k_right_raw, v_right_bytes = next(r_iter)
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        while k_right < k_left:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        while k_left < k_right:
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            k_left, v_left_bytes = next(l_iter)
                            # the right iterator is intentionally not advanced here;
                            # the outer loop's key comparison handles it, avoiding a
                            # third copy of the deserialize-and-advance code
                except StopIteration:
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
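                # lookup join: probe the right adapter by key for every left pair
                # and emit f(left_value, right_value) when the key exists on the
                # right.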
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    r_v_bytes = right_iterator.adapter.get(k_left)
                    if r_v_bytes:
                        #L.info("egg join:{}".format(right_value_serdes.deserialize(r_v_bytes)))
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                        left_value_serdes, right_iterator,
                                        right_key_serdes, right_value_serdes,
                                        output_writebatch):
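                # subtractByKey: emit a left pair only when its key has no match in
                # the right store.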
                L.info("sub wrapper")
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)

            self._run_binary(subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def union_wrapper(left_iterator, left_key_serdes,
                              left_value_serdes, right_iterator,
                              right_key_serdes, right_value_serdes,
                              output_writebatch):
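                # union: the first pass walks the left store, merging values with f
                # where the right store holds the same key; the second pass walks
                # the right store and adds any key the first pass did not write.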
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        k_left = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    v_right = right_iterator.adapter.get(k_left)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        final_v_bytes = output_writebatch.get(
                            left_key_serdes.serialize(
                                right_key_serdes.deserialize(k_right)))
                    else:
                        final_v_bytes = output_writebatch.get(k_right)

                    if final_v_bytes is None:
                        output_writebatch.put(k_right, v_right)

            self._run_binary(union_wrapper, task)

        elif task._name == 'withStores':
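            # withStores: apply f to the raw input partitions and return its
            # serialized result as an ErPair keyed by the first input partition's id.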
            f = create_functor(functors[0]._body)
            result = ErPair(
                key=self.functor_serdes.serialize(task._inputs[0]._id),
                value=self.functor_serdes.serialize(f(task._inputs)))
        return result