示例#1
0
    def put(self, k, v, options: dict = None):
        """Store one key/value pair on the egg that owns the key's partition.

        Both ``k`` and ``v`` are serialized with the store's serdes. The target
        partition is chosen by ``self.partitioner`` on the serialized key, and a
        single ``RollPair.PUT`` task is sent synchronously to that partition's egg.

        :param k: key to store.
        :param v: value to store.
        :param options: accepted for API compatibility; not read here.
        :return: the raw (not deserialized) ``_value`` of the ``ErPair`` reply.
        """
        if options is None:
            options = {}
        store_locator = self.__store._store_locator
        k = create_serdes(store_locator._serdes).serialize(k)
        v = create_serdes(store_locator._serdes).serialize(v)
        pair = ErPair(key=k, value=v)

        partition_id = self.partitioner(k)
        egg = self.ctx.route_to_egg(self.__store._partitions[partition_id])
        task_inputs = [ErPartition(id=partition_id, store_locator=store_locator)]
        task_outputs = [ErPartition(id=0, store_locator=store_locator)]

        job_id = generate_job_id(self.__session_id, RollPair.PUT)
        put_functor = ErFunctor(name=RollPair.PUT, body=cloudpickle.dumps(pair))
        # The job itself declares no outputs; only the task carries one.
        put_job = ErJob(id=job_id,
                        name=RollPair.PUT,
                        inputs=[self.__store],
                        outputs=[],
                        functors=[put_functor])
        put_task = ErTask(id=generate_task_id(job_id, partition_id),
                          name=RollPair.PUT,
                          inputs=task_inputs,
                          outputs=task_outputs,
                          job=put_job)

        reply = self.__command_client.simple_sync_send(
            input=put_task,
            output_type=ErPair,
            endpoint=egg._command_endpoint,
            command_uri=CommandURI(f'{RollPair.EGG_PAIR_URI_PREFIX}/{RollPair.RUN_TASK}'),
            serdes_type=self.__command_serdes,
        )
        return reply._value
示例#2
0
    def aggregate_seq(self, task: ErTask):
        """Fold every value of the task's first input partition with a seq op.

        For a ``'reduce'`` task, ``functors[0]`` is the binary op. The fold is
        seeded with the first value seen, so an empty partition yields ``None``.
        For any other task, ``functors[0]`` holds the zero value and
        ``functors[1]`` the binary op.

        :param task: task whose ``_inputs[0]`` partition is iterated.
        :return: the folded result (deserialized values combined by the seq op).
        """
        functors = task._job._functors
        head_functor = functors[0]
        is_reduce = head_functor._name == 'reduce'

        if is_reduce or head_functor is None:
            zero_value = None
        else:
            zero_value = create_functor(head_functor._body)
        seq_op = create_functor((head_functor if is_reduce else functors[1])._body)

        partition = task._inputs[0]
        value_serdes = create_serdes(partition._store_locator._serdes)

        accumulator = zero_value
        # A reduce has no zero value: the first element becomes the seed.
        seeded = not is_reduce
        with create_adapter(partition) as adapter, \
                adapter.iteritems() as pairs:
            for _, raw_value in pairs:
                value = value_serdes.deserialize(raw_value)
                if seeded:
                    accumulator = seq_op(accumulator, value)
                else:
                    accumulator = value
                    seeded = True

        return accumulator
示例#3
0
    def get(self, k, options: dict = None):
        """Fetch the value stored under ``k`` from the egg owning its partition.

        :param k: key to look up; serialized with the store's serdes.
        :param options: accepted for API compatibility; not read here.
        :return: the value deserialized with ``self.value_serdes``, or ``None``
            when the egg replies with an empty byte string.
        """
        if options is None:
            options = {}
        locator = self.__store._store_locator
        serialized_key = create_serdes(locator._serdes).serialize(k)
        request = ErPair(key=serialized_key, value=None)
        partition_id = self.partitioner(serialized_key)
        egg = self.ctx.route_to_egg(self.__store._partitions[partition_id])
        task_inputs = [ErPartition(id=partition_id, store_locator=locator)]
        task_outputs = [ErPartition(id=partition_id, store_locator=locator)]

        job_id = generate_job_id(self.__session_id, RollPair.GET)
        get_job = ErJob(id=job_id,
                        name=RollPair.GET,
                        inputs=[self.__store],
                        outputs=[self.__store],
                        functors=[ErFunctor(name=RollPair.GET, body=cloudpickle.dumps(request))])
        get_task = ErTask(id=generate_task_id(job_id, partition_id),
                          name=RollPair.GET,
                          inputs=task_inputs,
                          outputs=task_outputs,
                          job=get_job)

        reply = self.__command_client.simple_sync_send(
            input=get_task,
            output_type=ErPair,
            endpoint=egg._command_endpoint,
            command_uri=self.RUN_TASK_URI,
            serdes_type=self.__command_serdes,
        )

        raw_value = reply._value
        # An empty payload is how a missing key comes back.
        if raw_value == b'':
            return None
        return self.value_serdes.deserialize(raw_value)
示例#4
0
    def get_all(self, limit=None, options: dict = None):
        """Yield every (key, value) pair of the store, optionally capped by ``limit``.

        This is a generator. Nothing runs until the first item is requested,
        including the ``limit`` validation.

        A ``RollPair.GET_ALL`` job is run, carrying the serialized ``limit``.
        The pairs are then gathered through a ``TransferPair`` keyed by the
        job id. Keys and values are deserialized with ``self.key_serdes`` /
        ``self.value_serdes``.

        :param limit: optional positive ``int`` passed to the eggs with the job.
        :param options: accepted for API compatibility; not read here.
        :raises ValueError: if ``limit`` is given but is not a positive ``int``.
        """
        if options is None:
            options = {}

        # The previous check joined these conditions with ``and``. As a result,
        # int limits (0 and negatives included) were never rejected, and
        # non-numeric limits crashed on ``<=`` with TypeError.
        if limit is not None and (not isinstance(limit, int) or limit <= 0):
            raise ValueError(f"limit:{limit} must be positive int")

        job_id = generate_job_id(self.__session_id, RollPair.GET_ALL)
        serialized_limit = None
        if limit is not None:
            serialized_limit = create_serdes(
                self.__store._store_locator._serdes).serialize(limit)
        er_pair = ErPair(key=serialized_limit, value=None)

        job = ErJob(id=job_id,
                    name=RollPair.GET_ALL,
                    inputs=[self.__store],
                    outputs=[self.__store],
                    functors=[ErFunctor(name=RollPair.GET_ALL, body=cloudpickle.dumps(er_pair))])
        task_results = self._run_job(job=job)
        # Result unused; the call is kept so its checks/side effects still run.
        self.__get_output_from_result(task_results)

        populated_store = self.ctx.populate_processor(self.__store)
        transfer_pair = TransferPair(transfer_id=job_id)
        done_cnt = 0
        for k, v in transfer_pair.gather(populated_store):
            done_cnt += 1
            yield self.key_serdes.deserialize(k), self.value_serdes.deserialize(v)
        L.trace(f"get_all: namespace={self.get_namespace()} name={self.get_name()}, count={done_cnt}")
示例#5
0
        def do_store(store_partition_inner, is_shuffle_inner, total_writers_inner, reduce_op_inner):
            # Drain the transfer broker for this tag into the partition's store.
            # Each pair is either put (no reduce_op_inner) or merged via do_merge.
            # Returns the number of pairs written. The broker for the tag is
            # removed on every exit path; errors are logged and re-raised.
            if is_shuffle_inner:
                tag = self.__generate_tag(store_partition_inner._id)
            else:
                tag = self.__transfer_id

            written = 0
            try:
                broker = TransferService.get_or_create_broker(tag, write_signals=total_writers_inner)
                L.trace(f"do_store start for tag={tag}")
                pairs = TransferPair.bin_batch_to_pair(batch.data for batch in broker)

                serdes = create_serdes(store_partition_inner._store_locator._serdes)
                merger = None if reduce_op_inner is None else functools.partial(
                    do_merge, merge_func=reduce_op_inner, serdes=serdes)

                with create_adapter(store_partition_inner) as db:
                    L.trace(f"do_store create_db for tag={tag} for partition={store_partition_inner}")
                    with db.new_batch() as wb:
                        for key, value in pairs:
                            if merger is None:
                                wb.put(key, value)
                            else:
                                wb.merge(merger, key, value)
                            written += 1
                    L.trace(f"do_store done for tag={tag} for partition={store_partition_inner}")
            except Exception as e:
                L.exception(f'Error in do_store for tag={tag}')
                raise e
            finally:
                TransferService.remove_broker(tag)
            return written
示例#6
0
 def func(partitions):
     """Encrypt every matrix in the main partition with pre-generated obfuscators.

     ``partitions`` is unpacked as (main, obfuscator, result). For each
     (key, value) in the main partition:
       1. the value is deserialized into a matrix;
       2. one obfuscator per element of ``mat._ndarray`` is pulled from the
          obfuscator partition with ``db_obf.get()``;
       3. the obfuscators are reshaped to the matrix's ``_ndarray`` shape;
       4. ``mat.encrypt(obfs=...)`` is written to the result partition under
          the same key.
     """
     part_main, part_obf, part_ret = partitions
     serder_main = create_serdes(part_main._store_locator._serdes)
     serder_obf = create_serdes(part_obf._store_locator._serdes)
     serder_ret = create_serdes(part_ret._store_locator._serdes)
     with create_adapter(part_main) as db_main, \
             create_adapter(part_obf) as db_obf, \
             create_adapter(part_ret) as db_ret:
         with db_main.iteritems() as rb, db_ret.new_batch() as wb:
             for k, v in rb:
                 # Deserialize once and reuse. Previously ``v`` was decoded a
                 # second time just to call ``encrypt``.
                 mat = serder_main.deserialize(v)
                 L.debug(f"obf qsize: {db_obf.count()}, {part_obf}")
                 obfs = np.array(
                     [serder_obf.deserialize(db_obf.get())
                      for _ in range(mat._ndarray.size)]
                 ).reshape(mat._ndarray.shape)
                 wb.put(k, serder_ret.serialize(mat.encrypt(obfs=obfs)))
示例#7
0
    def _run_binary(self, func, task):
        """Run a two-input task by streaming both inputs through ``func``.

        ``func`` is called once with:
        (left_iterator, left_key_serdes, left_value_serdes,
         right_iterator, right_key_serdes, right_value_serdes,
         output_writebatch).
        The iterators come from ``task._inputs[0]`` and ``task._inputs[1]``;
        the write batch targets ``task._outputs[0]``.

        :raises ValueError: if the left input's serdes differ from the output's.
        :raises EnvironmentError: wrapping any exception raised by ``func``.
            The original exception is chained as ``__cause__``.
        """
        left_locator = task._inputs[0]._store_locator
        right_locator = task._inputs[1]._store_locator
        output_locator = task._outputs[0]._store_locator

        left_key_serdes = create_serdes(left_locator._serdes)
        left_value_serdes = create_serdes(left_locator._serdes)
        right_key_serdes = create_serdes(right_locator._serdes)
        right_value_serdes = create_serdes(right_locator._serdes)
        output_key_serdes = create_serdes(output_locator._serdes)
        output_value_serdes = create_serdes(output_locator._serdes)

        if left_key_serdes != output_key_serdes or \
                left_value_serdes != output_value_serdes:
            # The two fragments are concatenated, so the separating space
            # has to be explicit.
            raise ValueError(
                f"input key-value serdes:{(left_key_serdes, left_value_serdes)} "
                f"differ from output key-value serdes:{(output_key_serdes, output_value_serdes)}"
            )

        with create_adapter(task._inputs[0]) as left_adapter, \
                create_adapter(task._inputs[1]) as right_adapter, \
                create_adapter(task._outputs[0]) as output_adapter, \
                left_adapter.iteritems() as left_iterator, \
                right_adapter.iteritems() as right_iterator, \
                output_adapter.new_batch() as output_writebatch:
            try:
                func(left_iterator, left_key_serdes, left_value_serdes,
                     right_iterator, right_key_serdes, right_value_serdes,
                     output_writebatch)
            except Exception as e:
                # Chain the cause so the real traceback is not lost.
                raise EnvironmentError("exec task:{} error".format(task), e) from e
示例#8
0
 def _run_unary(self, func, task, shuffle=False):
     """Run a single-input task, optionally shuffling the output by key hash.

     ``func(input_iterator, key_serdes, value_serdes, writer)`` is called once
     with an iterator over ``task._inputs[0]``.

     Without shuffle, ``writer`` is a write batch on ``task._outputs[0]``.

     With shuffle, ``writer`` is a ``BatchBroker``. The pairs it receives are
     scattered (hashed with ``hash_code``) to the job's output store, while a
     store broker fills this task's output partition. Both futures are
     awaited before returning.
     """
     key_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     value_serdes = create_serdes(task._inputs[0]._store_locator._serdes)
     with create_adapter(task._inputs[0]) as input_db:
         L.debug(f"create_store_adatper: {task._inputs[0]}")
         with input_db.iteritems() as rb:
             L.debug(f"create_store_adatper_iter: {task._inputs[0]}")
             from eggroll.roll_pair.transfer_pair import TransferPair, BatchBroker
             if shuffle:
                 # Passed to store_broker as its expected writer count.
                 input_total_partitions = task._inputs[
                     0]._store_locator._total_partitions
                 output_store = task._job._outputs[0]
                 # Keys must be hashed into the *output* store's partition
                 # count. Hashing with the input count misroutes keys whenever
                 # the two stores are partitioned differently.
                 output_total_partitions = output_store._store_locator._total_partitions
                 shuffle_broker = FifoBroker()
                 write_bb = BatchBroker(shuffle_broker)
                 try:
                     shuffler = TransferPair(transfer_id=task._job._id)
                     store_future = shuffler.store_broker(
                         task._outputs[0], True, input_total_partitions)
                     scatter_future = shuffler.scatter(
                         shuffle_broker,
                         partitioner(hash_func=hash_code,
                                     total_partitions=output_total_partitions),
                         output_store)
                     func(rb, key_serdes, value_serdes, write_bb)
                 finally:
                     # Always signal completion so the scatter side can finish.
                     write_bb.signal_write_finish()
                 scatter_results = scatter_future.result()
                 store_result = store_future.result()
                 L.debug(f"scatter_result:{scatter_results}")
                 L.debug(f"gather_result:{store_result}")
             else:
                 # TODO: modification may be needed when store options finished
                 with create_adapter(task._outputs[0],
                                     options=task._job._options
                                     ) as db, db.new_batch() as wb:
                     func(rb, key_serdes, value_serdes, wb)
             L.debug(f"close_store_adatper:{task._inputs[0]}")
示例#9
0
 def func_asyn(partitions):
     """Keep putting fresh obfuscators into the first partition until interrupted.

     Each iteration serializes ``pub_key.gen_obfuscator()`` and ``put``s it into
     the adapter for ``partitions[0]``. The loop ends only when that raises
     ``InterruptedError``. Progress is logged every 10000 items.
     ``pub_key``, ``ns`` and ``name`` come from an enclosing scope that is not
     visible here.
     """
     target = partitions[0]
     serdes = create_serdes(target._store_locator._serdes)
     with create_adapter(target) as adapter:
         produced = 0
         while True:
             try:
                 adapter.put(serdes.serialize(pub_key.gen_obfuscator()))
             except InterruptedError:
                 L.info(
                     f"finish create asyn obfuscato:{ns}.{name}: {produced}")
                 break
             if not produced % 10000:
                 L.debug(f"generating asyn obfuscato:{ns}.{name}: {produced}")
             produced += 1
示例#10
0
 def __init__(self, er_store: ErStore, rp_ctx: RollPairContext):
     """Bind this RollPair to ``er_store`` inside the context ``rp_ctx``.

     Sets up the command client and serdes, hash-based key partitioning over
     the store's partition count, and GC bookkeeping. ``er_store`` is recorded
     with the context's gc recorder.
     """
     self.__store = er_store
     self.ctx = rp_ctx
     self.__command_serdes = SerdesTypes.PROTOBUF
     self.__roll_pair_master = rp_ctx.get_roll()
     self.__command_client = CommandClient()

     self.functor_serdes = create_serdes(SerdesTypes.CLOUD_PICKLE)
     self.value_serdes = self.get_store_serdes()
     self.key_serdes = self.get_store_serdes()

     partition_count = er_store._store_locator._total_partitions
     self.partitioner = partitioner(hash_code, partition_count)
     self.egg_router = default_egg_router
     self.__session_id = rp_ctx.session_id

     self.gc_enable = rp_ctx.rpc_gc_enable
     self.gc_recorder = rp_ctx.gc_recorder
     self.gc_recorder.record(er_store)
示例#11
0
    def delete(self, k, options: dict = None):
        """Delete ``k`` from the store by running a ``RollPair.DELETE`` job.

        :param k: key to delete; serialized with the store's serdes.
        :param options: accepted for API compatibility; not read here.
        :return: ``None``.
        """
        if options is None:
            options = {}
        serialized_key = create_serdes(self.__store._store_locator._serdes).serialize(k)
        delete_request = ErPair(key=serialized_key, value=None)
        # NOTE(review): the partition/egg lookup below is not used by the job
        # that follows. It is kept so any lookup errors still surface here.
        partition_id = self.partitioner(serialized_key)
        self.ctx.route_to_egg(self.__store._partitions[partition_id])

        delete_job = ErJob(
            id=generate_job_id(self.__session_id, RollPair.DELETE),
            name=RollPair.DELETE,
            inputs=[self.__store],
            outputs=[],
            functors=[ErFunctor(name=RollPair.DELETE, body=cloudpickle.dumps(delete_request))],
        )
        self._run_job(job=delete_job, create_output_if_missing=False)
示例#12
0
    def get(self, k, options: dict = None):
        """Fetch the value for ``k`` by sending a ``RollPair.GET`` task to its egg.

        :param k: key to look up; serialized with the store's serdes.
        :param options: accepted for API compatibility; not read here.
        :return: the value deserialized with ``self.value_serdes``, or ``None``
            when the reply's value is an empty byte string.
        """
        if options is None:
            options = {}
        L.debug(f"get k: {k}")
        locator = self.__store._store_locator
        k = create_serdes(locator._serdes).serialize(k)
        request = ErPair(key=k, value=None)
        partition_id = self.partitioner(k)
        egg = self.ctx.route_to_egg(self.__store._partitions[partition_id])
        L.info(
            f"partitions count: {locator._total_partitions}, target partition: {partition_id}, endpoint: {egg._command_endpoint}"
        )
        task_inputs = [ErPartition(id=partition_id, store_locator=locator)]
        task_outputs = [ErPartition(id=partition_id, store_locator=locator)]

        job_id = generate_job_id(self.__session_id, RollPair.GET)
        get_job = ErJob(id=job_id,
                        name=RollPair.GET,
                        inputs=[self.__store],
                        outputs=[],
                        functors=[ErFunctor(body=cloudpickle.dumps(request))])
        get_task = ErTask(id=generate_task_id(job_id, partition_id),
                          name=RollPair.GET,
                          inputs=task_inputs,
                          outputs=task_outputs,
                          job=get_job)

        reply = self.__command_client.simple_sync_send(
            input=get_task,
            output_type=ErPair,
            endpoint=egg._command_endpoint,
            command_uri=CommandURI(
                f'{RollPair.EGG_PAIR_URI_PREFIX}/{RollPair.RUN_TASK}'),
            serdes_type=self.__command_serdes)

        if reply._value == b'':
            return None
        return self.value_serdes.deserialize(reply._value)
示例#13
0
 def __init__(self, er_store: ErStore, rp_ctx: RollPairContext):
     """Bind this RollPair to ``er_store`` inside the context ``rp_ctx``.

     Keys are partitioned with ``mmh3.hash`` over the store's partition
     count. ``er_store`` is recorded with the context's gc recorder.

     :raises ValueError: if ``rp_ctx`` is falsy.
     """
     if not rp_ctx:
         raise ValueError('rp_ctx cannot be None')
     self.__store = er_store
     self.ctx = rp_ctx
     self.__command_serdes = SerdesTypes.PROTOBUF
     self.__command_client = CommandClient()

     self.functor_serdes = create_serdes(SerdesTypes.CLOUD_PICKLE)
     self.value_serdes = self.get_store_serdes()
     self.key_serdes = self.get_store_serdes()

     import mmh3
     self.partitioner = partitioner(
         mmh3.hash, er_store._store_locator._total_partitions)
     self.egg_router = default_egg_router
     self.__session_id = rp_ctx.session_id

     self.gc_enable = rp_ctx.rpc_gc_enable
     self.gc_recorder = rp_ctx.gc_recorder
     self.gc_recorder.record(er_store)
     self.destroyed = False
示例#14
0
    def delete(self, k, options: dict = None):
        """Delete ``k`` by sending a ``RollPair.DELETE`` task to the owning egg.

        The reply is not inspected.

        :param k: key to delete; serialized with the store's serdes.
        :param options: accepted for API compatibility; not read here.
        :return: ``None``.
        """
        if options is None:
            options = {}
        locator = self.__store._store_locator
        key = create_serdes(locator._serdes).serialize(k)
        request = ErPair(key=key, value=None)
        partition_id = self.partitioner(key)
        egg = self.ctx.route_to_egg(self.__store._partitions[partition_id])
        L.info(egg._command_endpoint)
        L.info(f"count: {locator._total_partitions}")

        def _partition_ref():
            # A fresh ErPartition for the target partition of this key.
            return ErPartition(id=partition_id, store_locator=locator)

        task_inputs = [_partition_ref()]
        task_outputs = [_partition_ref()]

        job_id = generate_job_id(self.__session_id, RollPair.DELETE)
        delete_job = ErJob(id=job_id,
                           name=RollPair.DELETE,
                           inputs=[self.__store],
                           outputs=[],
                           functors=[ErFunctor(body=cloudpickle.dumps(request))])
        delete_task = ErTask(id=generate_task_id(job_id, partition_id),
                             name=RollPair.DELETE,
                             inputs=task_inputs,
                             outputs=task_outputs,
                             job=delete_job)
        L.info("start send req")
        self.__command_client.simple_sync_send(
            input=delete_task,
            output_type=ErPair,
            endpoint=egg._command_endpoint,
            command_uri=CommandURI(
                f'{RollPair.EGG_PAIR_URI_PREFIX}/{RollPair.RUN_TASK}'),
            serdes_type=self.__command_serdes)
示例#15
0
    def _run_unary(self, func, task, shuffle=False, reduce_op=None):
        """Run a single-input task, optionally shuffling its output by key hash.

        ``func(input_iterator, key_serdes, value_serdes, writer)`` is called
        with an iterator over ``task._inputs[0]``.

        Without shuffle, ``writer`` is a write batch on ``task._outputs[0]``
        (opened with ``task._job._options``).

        With shuffle, the work is split by server node:
        - a store broker fills ``task._outputs[0]``, with ``reduce_op``
          merging values, only if that partition lives on this node;
        - the scatter (which feeds ``func``'s output to the output store,
          hashed with ``hash_code`` over the output partition count) runs
          only if the input partition lives on this node.
        Whichever futures were started are awaited before returning.

        :raises ValueError: if input and output store serdes differ.
        """
        input_store_head = task._job._inputs[0]
        output_store_head = task._job._outputs[0]
        input_key_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        input_value_serdes = create_serdes(
            input_store_head._store_locator._serdes)
        output_key_serdes = create_serdes(
            output_store_head._store_locator._serdes)
        output_value_serdes = create_serdes(
            output_store_head._store_locator._serdes)

        if input_key_serdes != output_key_serdes or \
                input_value_serdes != output_value_serdes:
            # The two fragments are concatenated, so the separating space
            # has to be explicit.
            raise ValueError(
                f"input key-value serdes:{(input_key_serdes, input_value_serdes)} "
                f"differ from output key-value serdes:{(output_key_serdes, output_value_serdes)}"
            )

        if shuffle:
            from eggroll.roll_pair.transfer_pair import TransferPair
            input_total_partitions = input_store_head._store_locator._total_partitions
            output_total_partitions = output_store_head._store_locator._total_partitions
            output_store = output_store_head

            my_server_node_id = get_static_er_conf().get(
                'server_node_id', None)
            shuffler = TransferPair(transfer_id=task._job._id)
            # Only store if this node owns the output partition
            # (or the node id is unknown).
            if not task._outputs or \
                    (my_server_node_id is not None
                     and my_server_node_id != task._outputs[0]._processor._server_node_id):
                store_future = None
            else:
                store_future = shuffler.store_broker(
                    store_partition=task._outputs[0],
                    is_shuffle=True,
                    total_writers=input_total_partitions,
                    reduce_op=reduce_op)

            # Only scatter if this node owns the input partition
            # (or the node id is unknown).
            if not task._inputs or \
                    (my_server_node_id is not None
                     and my_server_node_id != task._inputs[0]._processor._server_node_id):
                scatter_future = None
            else:
                shuffle_broker = FifoBroker()
                write_bb = BatchBroker(shuffle_broker)
                try:
                    scatter_future = shuffler.scatter(
                        input_broker=shuffle_broker,
                        partition_function=partitioner(
                            hash_func=hash_code,
                            total_partitions=output_total_partitions),
                        output_store=output_store)
                    with create_adapter(task._inputs[0]) as input_db, \
                            input_db.iteritems() as rb:
                        func(rb, input_key_serdes, input_value_serdes,
                             write_bb)
                finally:
                    # Always signal completion so the scatter side can finish.
                    write_bb.signal_write_finish()

            # Block on whichever halves ran so their errors reach the caller.
            if scatter_future is not None:
                scatter_future.result()
            if store_future is not None:
                store_future.result()
        else:  # no shuffle
            with create_adapter(task._inputs[0]) as input_db, \
                    input_db.iteritems() as rb, \
                    create_adapter(task._outputs[0], options=task._job._options) as db, \
                    db.new_batch() as wb:
                func(rb, input_key_serdes, input_value_serdes, wb)
            L.trace(f"close_store_adatper:{task._inputs[0]}")
示例#16
0
 def get_store_serdes(self):
     """Return ``create_serdes`` applied to this store locator's serdes type."""
     serdes_type = self.__store._store_locator._serdes
     return create_serdes(serdes_type)
示例#17
0
    def run_task(self, task: ErTask):
        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] start. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] start. task_name={task._name}, task_id={task._id}')
        functors = task._job._functors
        result = task

        if task._name == 'get':
            # TODO:1: move to create_serdes
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                L.trace(
                    f"get: key: {self.functor_serdes.deserialize(f._key)}, path: {input_adapter.path}"
                )
                value = input_adapter.get(f._key)
                result = ErPair(key=f._key, value=value)
        elif task._name == 'getAll':
            tag = f'{task._id}'
            er_pair = create_functor(functors[0]._body)
            input_store_head = task._job._inputs[0]
            key_serdes = create_serdes(input_store_head._store_locator._serdes)

            def generate_broker():
                with create_adapter(
                        task._inputs[0]) as db, db.iteritems() as rb:
                    limit = None if er_pair._key is None else key_serdes.deserialize(
                        er_pair._key)
                    try:
                        yield from TransferPair.pair_to_bin_batch(rb,
                                                                  limit=limit)
                    finally:
                        TransferService.remove_broker(tag)

            TransferService.set_broker(tag, generate_broker())
        elif task._name == 'count':
            with create_adapter(task._inputs[0]) as input_adapter:
                result = ErPair(key=self.functor_serdes.serialize('result'),
                                value=self.functor_serdes.serialize(
                                    input_adapter.count()))

        elif task._name == 'putBatch':
            partition = task._outputs[0]
            tag = f'{task._id}'
            PutBatchTask(tag, partition).run()

        elif task._name == 'putAll':
            output_partition = task._outputs[0]
            tag = f'{task._id}'
            L.trace(f'egg_pair putAll: transfer service tag={tag}')
            tf = TransferPair(tag)
            store_broker_result = tf.store_broker(output_partition,
                                                  False).result()
            # TODO:2: should wait complete?, command timeout?

        elif task._name == 'put':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                value = input_adapter.put(f._key, f._value)
                #result = ErPair(key=f._key, value=bytes(value))

        elif task._name == 'destroy':
            input_store_locator = task._inputs[0]._store_locator
            namespace = input_store_locator._namespace
            name = input_store_locator._name
            store_type = input_store_locator._store_type
            L.debug(
                f'destroying store_type={store_type}, namespace={namespace}, name={name}'
            )
            if name == '*':
                from eggroll.roll_pair.utils.pair_utils import get_db_path, get_data_dir
                target_paths = list()
                if store_type == '*':
                    data_dir = get_data_dir()
                    store_types = os.listdir(data_dir)
                    for store_type in store_types:
                        target_paths.append('/'.join(
                            [data_dir, store_type, namespace]))
                else:
                    db_path = get_db_path(task._inputs[0])
                    target_paths.append(db_path[:db_path.rfind('*')])

                real_data_dir = os.path.realpath(get_data_dir())
                for path in target_paths:
                    realpath = os.path.realpath(path)
                    if os.path.exists(path):
                        if realpath == "/" \
                                or realpath == real_data_dir \
                                or not realpath.startswith(real_data_dir):
                            raise ValueError(
                                f'trying to delete a dangerous path: {realpath}'
                            )
                        else:
                            shutil.rmtree(path)
            else:
                options = task._job._options
                with create_adapter(task._inputs[0],
                                    options=options) as input_adapter:
                    input_adapter.destroy(options=options)

        elif task._name == 'delete':
            f = create_functor(functors[0]._body)
            with create_adapter(task._inputs[0]) as input_adapter:
                if input_adapter.delete(f._key):
                    L.trace("delete k success")

        elif task._name == 'mapValues':
            f = create_functor(functors[0]._body)

            def map_values_wrapper(input_iterator, key_serdes, value_serdes,
                                   output_writebatch):
                for k_bytes, v_bytes in input_iterator:
                    v = value_serdes.deserialize(v_bytes)
                    output_writebatch.put(k_bytes,
                                          value_serdes.serialize(f(v)))

            self._run_unary(map_values_wrapper, task)
        elif task._name == 'map':
            f = create_functor(functors[0]._body)

            def map_wrapper(input_iterator, key_serdes, value_serdes,
                            shuffle_broker):
                for k_bytes, v_bytes in input_iterator:
                    k1, v1 = f(key_serdes.deserialize(k_bytes),
                               value_serdes.deserialize(v_bytes))
                    shuffle_broker.put(
                        (key_serdes.serialize(k1), value_serdes.serialize(v1)))

            self._run_unary(map_wrapper, task, shuffle=True)

        elif task._name == 'reduce':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'aggregate':
            seq_op_result = self.aggregate_seq(task=task)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(seq_op_result))

        elif task._name == 'mapPartitions':
            reduce_op = create_functor(functors[1]._body)
            shuffle = create_functor(functors[2]._body)

            def map_partitions_wrapper(input_iterator, key_serdes,
                                       value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if isinstance(value, Iterable):
                    for k1, v1 in value:
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k1),
                                                   value_serdes.serialize(v1)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k1),
                                                  value_serdes.serialize(v1))
                else:
                    key = input_iterator.key()
                    output_writebatch.put((key, value_serdes.serialize(value)))

            self._run_unary(map_partitions_wrapper,
                            task,
                            shuffle=shuffle,
                            reduce_op=reduce_op)

        elif task._name == 'collapsePartitions':

            def collapse_partitions_wrapper(input_iterator, key_serdes,
                                            value_serdes, output_writebatch):
                f = create_functor(functors[0]._body)
                value = f(generator(key_serdes, value_serdes, input_iterator))
                if input_iterator.last():
                    key = input_iterator.key()
                    output_writebatch.put(key, value_serdes.serialize(value))

            self._run_unary(collapse_partitions_wrapper, task)

        elif task._name == 'flatMap':
            shuffle = create_functor(functors[1]._body)

            def flat_map_wraaper(input_iterator, key_serdes, value_serdes,
                                 output_writebatch):
                f = create_functor(functors[0]._body)
                for k1, v1 in input_iterator:
                    for k2, v2 in f(key_serdes.deserialize(k1),
                                    value_serdes.deserialize(v1)):
                        if shuffle:
                            output_writebatch.put((key_serdes.serialize(k2),
                                                   value_serdes.serialize(v2)))
                        else:
                            output_writebatch.put(key_serdes.serialize(k2),
                                                  value_serdes.serialize(v2))

            self._run_unary(flat_map_wraaper, task, shuffle=shuffle)

        elif task._name == 'glom':

            def glom_wrapper(input_iterator, key_serdes, value_serdes,
                             output_writebatch):
                k_tmp = None
                v_list = []
                for k, v in input_iterator:
                    v_list.append((key_serdes.deserialize(k),
                                   value_serdes.deserialize(v)))
                    k_tmp = k
                if k_tmp is not None:
                    output_writebatch.put(k_tmp,
                                          value_serdes.serialize(v_list))

            self._run_unary(glom_wrapper, task)

        elif task._name == 'sample':

            def sample_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                """Bernoulli-sample a partition, copying kept entries unchanged.

                ``fraction`` and ``seed`` are decoded from the first two
                functors. Each entry is kept when a fresh ``rand()`` draw from a
                seeded ``np.random.RandomState`` is below ``fraction``. Runs with
                the same seed therefore keep the same entries.
                """
                keep_ratio = create_functor(functors[0]._body)
                rng_seed = create_functor(functors[1]._body)
                input_iterator.first()
                rng = np.random.RandomState(rng_seed)
                for raw_key, raw_value in input_iterator:
                    # Written as ``not (x < r)`` (not ``x >= r``) so a NaN
                    # fraction still rejects every entry.
                    if not (rng.rand() < keep_ratio):
                        continue
                    output_writebatch.put(raw_key, raw_value)

            self._run_unary(sample_wrapper, task)

        elif task._name == 'filter':

            def filter_wrapper(input_iterator, key_serdes, value_serdes,
                               output_writebatch):
                """Copy through the raw entries for which the predicate is truthy.

                The predicate sees the deserialized key and value. Kept entries
                are written with their original bytes, with no re-serialization.
                """
                predicate = create_functor(functors[0]._body)
                for raw_key, raw_value in input_iterator:
                    keep = predicate(key_serdes.deserialize(raw_key),
                                     value_serdes.deserialize(raw_value))
                    if keep:
                        output_writebatch.put(raw_key, raw_value)

            self._run_unary(filter_wrapper, task)

        elif task._name == 'join':

            def merge_join_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                """Inner-join two sorted partitions in a single merge pass.

                For every key present on both sides the user functor is called as
                ``f(left_value, right_value)``. The result is written under the
                left key bytes, serialized with the left value serdes. When the
                key serdes objects compare unequal, right keys are re-encoded
                with the left key serdes so the byte-wise comparisons line up.

                Raises:
                    RuntimeError: if either adapter does not report sorted order.
                """
                both_sorted = (left_iterator.adapter.is_sorted()
                               and right_iterator.adapter.is_sorted())
                if not both_sorted:
                    raise RuntimeError(
                        f"merge join cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                join_fn = create_functor(functors[0]._body)
                same_key_serdes = left_key_serdes == right_key_serdes

                left_entries = iter(left_iterator)
                right_entries = iter(right_iterator)

                def pull_right():
                    # Next right entry, key expressed in left-serdes encoding.
                    # StopIteration propagates to the handler below.
                    raw_key, raw_value = next(right_entries)
                    if not same_key_serdes:
                        raw_key = left_key_serdes.serialize(
                            right_key_serdes.deserialize(raw_key))
                    return raw_key, raw_value

                try:
                    left_key, left_bytes = next(left_entries)
                    right_key, right_bytes = pull_right()
                    while True:
                        while right_key < left_key:
                            right_key, right_bytes = pull_right()
                        while left_key < right_key:
                            left_key, left_bytes = next(left_entries)
                        if left_key != right_key:
                            continue
                        joined = join_fn(
                            left_value_serdes.deserialize(left_bytes),
                            right_value_serdes.deserialize(right_bytes))
                        output_writebatch.put(
                            left_key, left_value_serdes.serialize(joined))
                        # Only the left cursor moves here. On the next pass the
                        # right cursor is advanced past this key by the first
                        # inner loop.
                        left_key, left_bytes = next(left_entries)
                except StopIteration:
                    # Either side running out ends the inner join.
                    return

            def hash_join_wrapper(left_iterator, left_key_serdes,
                                  left_value_serdes, right_iterator,
                                  right_key_serdes, right_value_serdes,
                                  output_writebatch):
                """Inner-join by probing the right adapter for every left key.

                For each left entry whose key exists on the right, the user
                functor is called as ``f(left_value, right_value)``. The result
                is written under the ORIGINAL left key bytes, serialized with the
                left value serdes. This is the same key encoding the merge-based
                join writes.

                When the key serdes types differ, the left key is re-encoded with
                the right key serdes only for the lookup. Previously the
                re-encoded key leaked into the output.
                """
                f = create_functor(functors[0]._body)
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, l_v_bytes in left_iterator:
                    if is_diff_serdes:
                        lookup_key = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        lookup_key = k_left
                    r_v_bytes = right_iterator.adapter.get(lookup_key)
                    # A missing key is assumed to come back as None, the same
                    # check the subtract/union hash wrappers use. An empty
                    # stored value is therefore still a match.
                    if r_v_bytes is not None:
                        output_writebatch.put(
                            k_left,
                            left_value_serdes.serialize(
                                f(left_value_serdes.deserialize(l_v_bytes),
                                  right_value_serdes.deserialize(r_v_bytes))))

            join_type = task._job._options.get('join_type', 'merge')

            if join_type == 'merge':
                self._run_binary(merge_join_wrapper, task)
            else:
                self._run_binary(hash_join_wrapper, task)

        elif task._name == 'subtractByKey':

            def merge_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                              left_value_serdes,
                                              right_iterator, right_key_serdes,
                                              right_value_serdes,
                                              output_writebatch):
                """Write every left entry whose key is absent from the right side.

                Both sides must come from sorted adapters, so a single forward
                merge pass suffices. Each left entry is written at most once,
                with its original key and value bytes. When the key serdes
                objects compare unequal, right keys are re-encoded with the left
                key serdes before the byte-wise comparison.

                Raises:
                    RuntimeError: if either adapter does not report sorted order.
                """
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge subtract_by_key cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )

                is_same_serdes = left_key_serdes == right_key_serdes
                r_iter = iter(right_iterator)

                def next_right_key():
                    # Next right key in left-serdes encoding, or None once the
                    # right side is exhausted.
                    entry = next(r_iter, None)
                    if entry is None:
                        return None
                    k_right_raw = entry[0]
                    if is_same_serdes:
                        return k_right_raw
                    return left_key_serdes.serialize(
                        right_key_serdes.deserialize(k_right_raw))

                k_right = next_right_key()
                for k_left, v_left in left_iterator:
                    # Skip every right key that sorts before the current left key.
                    while k_right is not None and k_right < k_left:
                        k_right = next_right_key()
                    # After the skip, k_right is either exhausted or >= k_left.
                    # Only an exact match removes the entry.
                    if k_right is None or k_left < k_right:
                        output_writebatch.put(k_left, v_left)

            def hash_subtract_by_key_wrapper(left_iterator, left_key_serdes,
                                             left_value_serdes, right_iterator,
                                             right_key_serdes,
                                             right_value_serdes,
                                             output_writebatch):
                """Write every left entry whose key has no match on the right side.

                Each left key is probed directly in the right adapter, which is
                assumed to return None for a missing key. When the key serdes
                types differ, the key is re-encoded with the right key serdes
                for the lookup only. The output always keeps the original left
                key and value bytes, matching the merge-based implementation.
                Previously the right-encoded key was written instead.
                """
                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    if is_diff_serdes:
                        lookup_key = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        lookup_key = k_left
                    if right_iterator.adapter.get(lookup_key) is None:
                        output_writebatch.put(k_left, v_left)

            subtract_by_key_type = task._job._options.get(
                'subtract_by_key_type', 'merge')
            if subtract_by_key_type == 'merge':
                self._run_binary(merge_subtract_by_key_wrapper, task)
            else:
                self._run_binary(hash_subtract_by_key_wrapper, task)

        elif task._name == 'union':

            def merge_union_wrapper(left_iterator, left_key_serdes,
                                    left_value_serdes, right_iterator,
                                    right_key_serdes, right_value_serdes,
                                    output_writebatch):
                """Union two sorted partitions in a single merge pass.

                Keys present on both sides are combined with the user functor as
                ``f(left_value, right_value)``. Keys present on only one side are
                copied through. Left entries keep their raw bytes. When the key
                serdes objects compare unequal, right keys and values are
                re-encoded with the left key/value serdes before being compared
                or written.

                State flags:
                    is_left_stopped: set True just before the left iterator is
                        advanced (or when left is empty), and False just before
                        the right iterator is advanced (or when right is empty).
                        When StopIteration ends the main loop, the flag tells
                        which side ran out. The other side is then drained.
                    is_equal: True while the current right entry has already been
                        merged with a left entry. It must then not be emitted
                        again when draining the right side.

                Raises:
                    RuntimeError: if either adapter does not report sorted order.
                """
                if not left_iterator.adapter.is_sorted(
                ) or not right_iterator.adapter.is_sorted():
                    raise RuntimeError(
                        f"merge union cannot be applied: not both store types support sorting. "
                        f"left type: {type(left_iterator.adapter)}, is_sorted: {left_iterator.adapter.is_sorted()}; "
                        f"right type: {type(right_iterator.adapter)}, is_sorted: {right_iterator.adapter.is_sorted()}"
                    )
                f = create_functor(functors[0]._body)
                # Equality (not type) comparison, unlike the hash-based variant.
                is_same_serdes = left_key_serdes == right_key_serdes

                l_iter = iter(left_iterator)
                r_iter = iter(right_iterator)
                k_left = None
                v_left_bytes = None
                k_right = None
                v_right_bytes = None
                # Sentinel returned by next() for an empty side.
                none_none = (None, None)

                is_left_stopped = False
                is_equal = False
                try:
                    k_left, v_left_bytes = next(l_iter, none_none)
                    k_right_raw, v_right_bytes = next(r_iter, none_none)

                    # Empty-side shortcuts: jump straight to the drain logic
                    # below with the flag pointing at the exhausted side.
                    if k_left is None and k_right_raw is None:
                        return
                    elif k_left is None:
                        is_left_stopped = True
                        raise StopIteration()
                    elif k_right_raw is None:
                        is_left_stopped = False
                        raise StopIteration()

                    # k_right is the right key in left-serdes encoding, used
                    # for byte-wise comparison against k_left.
                    if is_same_serdes:
                        k_right = k_right_raw
                    else:
                        k_right = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right_raw))

                    while True:
                        # Emit right-only keys that sort before the current left key.
                        is_left_stopped = False
                        while k_right < k_left:
                            if is_same_serdes:
                                output_writebatch.put(k_right, v_right_bytes)
                            else:
                                output_writebatch.put(
                                    k_right,
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))

                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))

                        # Emit left-only keys that sort before the current right key.
                        is_left_stopped = True
                        while k_left < k_right:
                            output_writebatch.put(k_left, v_left_bytes)
                            k_left, v_left_bytes = next(l_iter)

                        if k_left == k_right:
                            # Key on both sides: combine, then advance both cursors.
                            is_equal = True
                            output_writebatch.put(
                                k_left,
                                left_value_serdes.serialize(
                                    f(
                                        left_value_serdes.deserialize(
                                            v_left_bytes),
                                        right_value_serdes.deserialize(
                                            v_right_bytes))))
                            is_left_stopped = True
                            k_left, v_left_bytes = next(l_iter)

                            is_left_stopped = False
                            k_right_raw, v_right_bytes = next(r_iter)

                            if is_same_serdes:
                                k_right = k_right_raw
                            else:
                                k_right = left_key_serdes.serialize(
                                    right_key_serdes.deserialize(k_right_raw))
                            is_equal = False
                except StopIteration as e:
                    # One side ran out; the flags tell which one to drain.
                    pass

                if not is_left_stopped:
                    # Right side exhausted: the current left entry has not been
                    # emitted yet, so write it and then the rest of the left side.
                    try:
                        output_writebatch.put(k_left, v_left_bytes)
                        while True:
                            k_left, v_left_bytes = next(l_iter)
                            output_writebatch.put(k_left, v_left_bytes)
                    except StopIteration as e:
                        pass
                else:
                    # Left side exhausted: write the current right entry unless it
                    # was already merged, then the rest of the right side,
                    # re-encoded to left serdes when needed.
                    try:
                        if not is_equal:
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                        while True:
                            k_right_raw, v_right_bytes = next(r_iter)
                            if is_same_serdes:
                                output_writebatch.put(k_right_raw,
                                                      v_right_bytes)
                            else:
                                output_writebatch.put(
                                    left_key_serdes.serialize(
                                        right_key_serdes.deserialize(
                                            k_right_raw)),
                                    left_value_serdes.serialize(
                                        right_value_serdes.deserialize(
                                            v_right_bytes)))
                    except StopIteration as e:
                        pass

                    # end of merge union wrapper
                    # (this return sits inside the else-branch; both branches
                    # end the function here anyway)
                    return

            def hash_union_wrapper(left_iterator, left_key_serdes,
                                   left_value_serdes, right_iterator,
                                   right_key_serdes, right_value_serdes,
                                   output_writebatch):
                """Union two partitions by probing the right adapter per left key.

                Pass 1 writes every left entry under its ORIGINAL left key
                bytes. Keys also present on the right are combined as
                ``f(left_value, right_value)`` and serialized with the left value
                serdes.

                Pass 2 re-scans the right side and writes each right-only entry.
                Its key is expressed in left key serdes and, when the serdes
                differ, its value is re-encoded with the left value serdes. This
                matches the merge-based union.

                Previously pass 1 wrote right-encoded keys when the serdes
                differed. Pass 2 then looked them up in left encoding, missed,
                and overwrote the combined values with raw right bytes.
                """
                f = create_functor(functors[0]._body)

                is_diff_serdes = type(left_key_serdes) != type(
                    right_key_serdes)
                for k_left, v_left in left_iterator:
                    # Re-encode only for the lookup; the output keeps k_left.
                    if is_diff_serdes:
                        lookup_key = right_key_serdes.serialize(
                            left_key_serdes.deserialize(k_left))
                    else:
                        lookup_key = k_left
                    v_right = right_iterator.adapter.get(lookup_key)
                    if v_right is None:
                        output_writebatch.put(k_left, v_left)
                    else:
                        v_final = f(left_value_serdes.deserialize(v_left),
                                    right_value_serdes.deserialize(v_right))
                        output_writebatch.put(
                            k_left, left_value_serdes.serialize(v_final))

                right_iterator.first()
                for k_right, v_right in right_iterator:
                    if is_diff_serdes:
                        out_key = left_key_serdes.serialize(
                            right_key_serdes.deserialize(k_right))
                    else:
                        out_key = k_right

                    # Keys already written in pass 1 were present on both
                    # sides and have been combined.
                    if output_writebatch.get(out_key) is not None:
                        continue

                    if is_diff_serdes:
                        out_value = left_value_serdes.serialize(
                            right_value_serdes.deserialize(v_right))
                    else:
                        out_value = v_right
                    output_writebatch.put(out_key, out_value)

            union_type = task._job._options.get('union_type', 'merge')

            if union_type == 'merge':
                self._run_binary(merge_union_wrapper, task)
            else:
                self._run_binary(hash_union_wrapper, task)

        elif task._name == 'withStores':
            f = create_functor(functors[0]._body)
            result = ErPair(key=self.functor_serdes.serialize(
                task._inputs[0]._id),
                            value=self.functor_serdes.serialize(f(task)))

        if L.isEnabledFor(logging.TRACE):
            L.trace(
                f'[RUNTASK] end. task_name={task._name}, inputs={task._inputs}, outputs={task._outputs}, task_id={task._id}'
            )
        else:
            L.debug(
                f'[RUNTASK] end. task_name={task._name}, task_id={task._id}')

        return result
示例#18
0
 def __init__(self):
     self.functor_serdes = create_serdes(SerdesTypes.CLOUD_PICKLE)