Example #1
def put(self, k, v):
    from eggroll.api.utils import eggroll_serdes
    self.serde = eggroll_serdes.get_serdes()
    self.key = k
    self.value = v
    self.batch.put(k, v)
    self.write()
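The import and the get_serdes() lookup above run on every call to put. A minimal refactor sketch that hoists both into the constructor (BatchWriter, its batch member, and write() are hypothetical stand-ins for the snippet's host class):

from eggroll.api.utils import eggroll_serdes

class BatchWriter:
    # hypothetical host class; 'batch' is any object exposing put(k, v)
    def __init__(self, batch):
        self.serde = eggroll_serdes.get_serdes()  # resolved once, not per call
        self.batch = batch

    def put(self, k, v):
        self.key, self.value = k, v
        self.batch.put(k, v)
        self.write()

    def write(self):
        pass  # flush hook; the original snippet delegates to its own write()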
Example #2
def _create_serde(self):
    return eggroll_serdes.get_serdes()
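Everything the examples do with the returned serdes goes through a serialize/deserialize pair. A quick round-trip check, assuming the default serdes is pickle-based as the snippets below suggest:

from eggroll.api.utils import eggroll_serdes

serde = eggroll_serdes.get_serdes()
raw = serde.serialize({'feature': [1, 2, 3]})            # -> bytes
assert serde.deserialize(raw) == {'feature': [1, 2, 3]}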
Example #3
class _EggRoll(object):
    value_serdes = eggroll_serdes.get_serdes()
    instance = None
    unique_id_template = '_EggRoll_%s_%s_%s_%.20f_%d'
    host_name = 'unknown'
    host_ip = 'unknown'
    chunk_size = CHUNK_SIZE_DEFAULT

    @staticmethod
    def get_instance():
        if _EggRoll.instance is None:
            raise EnvironmentError("eggroll should be initialized before use")
        return _EggRoll.instance

    def get_channel(self):
        return self.channel

    def __init__(self, eggroll_session):
        if _EggRoll.instance is not None:
            raise EnvironmentError("eggroll should be initialized only once")

        host = eggroll_session.get_conf(EGGROLL_ROLL_HOST)
        port = eggroll_session.get_conf(EGGROLL_ROLL_PORT)
        self.chunk_size = eggroll_session.get_chunk_size()
        self.host = host
        self.port = port

        self.channel = grpc.insecure_channel(
            target="{}:{}".format(host, port),
            options=[('grpc.max_send_message_length', -1),
                     ('grpc.max_receive_message_length', -1)])
        self.session_id = eggroll_session.get_session_id()
        self.kv_stub = kv_pb2_grpc.KVServiceStub(self.channel)
        self.proc_stub = processor_pb2_grpc.ProcessServiceStub(self.channel)
        self.session_stub = node_manager_pb2_grpc.SessionServiceStub(
            self.channel)
        self.eggroll_session = eggroll_session
        _EggRoll.instance = self

        self.session_stub.getOrCreateSession(
            self.eggroll_session.to_protobuf())

        # todo: move to eggrollSession
        try:
            self.host_name = socket.gethostname()
            self.host_ip = socket.gethostbyname(self.host_name)
        except socket.gaierror as e:
            self.host_name = 'unknown'
            self.host_ip = 'unknown'

    def get_eggroll_session(self):
        return self.eggroll_session

    def stop(self):
        self.session_stub.stopSession(self.eggroll_session.to_protobuf())
        self.eggroll_session.run_cleanup_tasks()
        _EggRoll.instance = None
        self.channel.close()

    def is_stopped(self):
        return (self.instance is None)

    def table(self,
              name,
              namespace,
              partition=1,
              create_if_missing=True,
              error_if_exist=False,
              persistent=True,
              in_place_computing=False,
              persistent_engine=StoreType.LMDB):
        _type = to_pb_store_type(persistent_engine, persistent)
        storage_locator = storage_basic_pb2.StorageLocator(type=_type,
                                                           namespace=namespace,
                                                           name=name)
        create_table_info = kv_pb2.CreateTableInfo(
            storageLocator=storage_locator, fragmentCount=partition)
        _table = self._create_table(create_table_info)
        _table.set_in_place_computing(in_place_computing)
        LOGGER.debug("created table: %s", _table)
        return _table

    def parallelize(self,
                    data: Iterable,
                    include_key=False,
                    name=None,
                    partition=1,
                    namespace=None,
                    create_if_missing=True,
                    error_if_exist=False,
                    persistent=False,
                    chunk_size=100000,
                    in_place_computing=False,
                    persistent_engine=StoreType.LMDB):
        if namespace is None:
            namespace = _EggRoll.get_instance().session_id
        if name is None:
            name = str(uuid.uuid1())

        _type = to_pb_store_type(persistent_engine, persistent)

        storage_locator = storage_basic_pb2.StorageLocator(type=_type,
                                                           namespace=namespace,
                                                           name=name)
        create_table_info = kv_pb2.CreateTableInfo(
            storageLocator=storage_locator, fragmentCount=partition)
        _table = self._create_table(create_table_info)
        _table.set_in_place_computing(in_place_computing)
        _iter = data if include_key else enumerate(data)
        _table.put_all(_iter, chunk_size=chunk_size)
        LOGGER.debug("created table: %s", _table)
        return _table

    def cleanup(self,
                name,
                namespace,
                persistent,
                persistent_engine=StoreType.LMDB):
        if namespace is None or name is None:
            raise ValueError("neither name nor namespace can be None")

        _type = to_pb_store_type(persistent_engine, persistent)

        storage_locator = storage_basic_pb2.StorageLocator(type=_type,
                                                           namespace=namespace,
                                                           name=name)
        _table = _DTable(storage_locator=storage_locator)

        self.destroy_all(_table)

        LOGGER.debug("cleaned up: %s", _table)

    def generateUniqueId(self):
        return self.unique_id_template % (self.session_id, self.host_name,
                                          self.host_ip, time.time(),
                                          random.randint(10000, 99999))

    @staticmethod
    def serialize_and_hash_func(func):
        pickled_function = cloudpickle.dumps(func)
        func_id = str(uuid.uuid1())
        return func_id, pickled_function

    def _create_table(self, create_table_info):
        info = self.kv_stub.createIfAbsent(create_table_info)
        return _DTable(info.storageLocator, info.fragmentCount)

    def _create_table_from_locator(self, storage_locator, template: _DTable):
        create_table_info = kv_pb2.CreateTableInfo(
            storageLocator=storage_locator, fragmentCount=template._partitions)
        result = self._create_table(create_table_info)
        result.set_in_place_computing(template.get_in_place_computing())
        return result

    @staticmethod
    def __generate_operand(kvs: Iterable, use_serialize=True):
        for k, v in kvs:
            yield kv_pb2.Operand(key=_EggRoll.value_serdes.serialize(k)
                                 if use_serialize else bytes_to_string(k),
                                 value=_EggRoll.value_serdes.serialize(v)
                                 if use_serialize else v)

    @staticmethod
    def _deserialize_operand(operand: kv_pb2.Operand,
                             include_key=False,
                             use_serialize=True):
        if operand.value and len(operand.value) > 0:
            if use_serialize:
                return (
                    _EggRoll.value_serdes.deserialize(operand.key),
                    _EggRoll.value_serdes.deserialize(operand.value)
                ) if include_key else _EggRoll.value_serdes.deserialize(
                    operand.value)
            else:
                return (bytes_to_string(operand.key),
                        operand.value) if include_key else operand.value
        return None
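    # __generate_operand and _deserialize_operand are inverses: with
    # use_serialize=True a pair ('k', 1) travels as
    # Operand(key=serialize('k'), value=serialize(1)) and is decoded on read;
    # with use_serialize=False keys and values pass through as raw
    # strings/bytes.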

    '''
    Storage apis
    '''

    def kv_to_bytes(self, **kwargs):
        use_serialize = kwargs.get("use_serialize", True)
        # "in" checks are required: kwargs.get(...) is None cannot distinguish
        # a missing argument from an explicit None
        if "k" in kwargs and "v" in kwargs:
            k, v = kwargs["k"], kwargs["v"]
            return (self.value_serdes.serialize(k), self.value_serdes.serialize(v)) if use_serialize \
                else (string_to_bytes(k), string_to_bytes(v))
        elif "k" in kwargs:
            k = kwargs["k"]
            return self.value_serdes.serialize(
                k) if use_serialize else string_to_bytes(k)
        elif "v" in kwargs:
            v = kwargs["v"]
            return self.value_serdes.serialize(
                v) if use_serialize else string_to_bytes(v)

    def put(self, _table, k, v, use_serialize=True):
        k, v = self.kv_to_bytes(k=k, v=v, use_serialize=use_serialize)
        self.kv_stub.put(kv_pb2.Operand(key=k, value=v),
                         metadata=_get_meta(_table))

    def put_if_absent(self, _table, k, v, use_serialize=True):
        k, v = self.kv_to_bytes(k=k, v=v, use_serialize=use_serialize)
        operand = self.kv_stub.putIfAbsent(kv_pb2.Operand(key=k, value=v),
                                           metadata=_get_meta(_table))
        return self._deserialize_operand(operand, use_serialize=use_serialize)

    @staticmethod
    def action(_table, host, port, chunked_iter, use_serialize):
        _table.set_gc_disable()
        _EggRoll.get_instance().get_channel().close()
        _EggRoll.get_instance().channel = grpc.insecure_channel(
            target="{}:{}".format(host, port),
            options=[('grpc.max_send_message_length', -1),
                     ('grpc.max_receive_message_length', -1)])

        _EggRoll.get_instance().kv_stub = kv_pb2_grpc.KVServiceStub(
            _EggRoll.get_instance().channel)
        _EggRoll.get_instance(
        ).proc_stub = processor_pb2_grpc.ProcessServiceStub(
            _EggRoll.get_instance().channel)

        operand = _EggRoll.get_instance().__generate_operand(
            chunked_iter, use_serialize)
        _EggRoll.get_instance().kv_stub.putAll(operand,
                                               metadata=_get_meta(_table))

    def put_all(self,
                _table,
                kvs: Iterable,
                use_serialize=True,
                chunk_size=100000,
                skip_chunk=0):
        global gc_tag
        gc_tag = False
        skipped_chunk = 0

        # note: the session-level chunk size overrides the chunk_size argument
        chunk_size = self.chunk_size
        if chunk_size < CHUNK_SIZE_MIN:
            chunk_size = CHUNK_SIZE_DEFAULT

        host = self.host
        port = self.port
        process_pool_size = cpu_count()

        with ProcessPoolExecutor(process_pool_size) as executor:
            if isinstance(kvs, Sequence):  # Sequence
                for chunked_iter in split_every_yield(kvs, chunk_size):
                    if skipped_chunk < skip_chunk:
                        skipped_chunk += 1
                    else:
                        future = executor.submit(_EggRoll.action, _table, host,
                                                 port, chunked_iter,
                                                 use_serialize)
            else:  # other Iterable types
                try:
                    index = 0
                    while True:
                        chunked_iter = split_every(kvs, index, chunk_size,
                                                   skip_chunk)
                        chunked_iter_ = copy.deepcopy(chunked_iter)
                        next(chunked_iter_)
                        future = executor.submit(_EggRoll.action, _table, host,
                                                 port, chunked_iter,
                                                 use_serialize)
                        index += 1
                except StopIteration as e:
                    LOGGER.debug("StopIteration")
            executor.shutdown(wait=True)
        gc_tag = True

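    # put_all flow: the input is split into chunks of self.chunk_size pairs;
    # each chunk is handed to a worker process (action above), which re-opens
    # its own gRPC channel and streams the chunk via kv_stub.putAll. The
    # module-level gc_tag flag stays False while uploads are in flight,
    # presumably consulted by table GC logic elsewhere.
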
    def delete(self, _table, k, use_serialize=True):
        k = self.kv_to_bytes(k=k, use_serialize=use_serialize)
        operand = self.kv_stub.delOne(kv_pb2.Operand(key=k),
                                      metadata=_get_meta(_table))
        return self._deserialize_operand(operand, use_serialize=use_serialize)

    def get(self, _table, k, use_serialize=True):
        k = self.kv_to_bytes(k=k, use_serialize=use_serialize)
        operand = self.kv_stub.get(kv_pb2.Operand(key=k),
                                   metadata=_get_meta(_table))
        return self._deserialize_operand(operand, use_serialize=use_serialize)

    def iterate(self, _table, _range):
        return self.kv_stub.iterate(_range, metadata=_get_meta(_table))

    def destroy(self, _table):
        self.kv_stub.destroy(empty, metadata=_get_meta(_table))

    def destroy_all(self, _table):
        self.kv_stub.destroyAll(empty, metadata=_get_meta(_table))

    def count(self, _table):
        return self.kv_stub.count(empty, metadata=_get_meta(_table)).value

    '''
    Computing apis
    '''

    def map(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=func, stub_func=self.proc_stub.map)

    def map_values(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=func, stub_func=self.proc_stub.mapValues)

    def map_partitions(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table,
            user_func=func,
            stub_func=self.proc_stub.mapPartitions)

    def map_partitions2(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table,
            user_func=func,
            stub_func=self.proc_stub.mapPartitions2)

    def reduce(self, _table: _DTable, func):
        unary_p = self.__create_unary_process(table=_table, func=func)

        values = [
            _EggRoll._deserialize_operand(operand)
            for operand in self.proc_stub.reduce(unary_p)
        ]
        values = [v for v in filter(partial(is_not, None), values)]
        if len(values) <= 0:
            return None
        if len(values) == 1:
            return values[0]
        else:
            val, *remain = values
            for _nv in remain:
                val = func(val, _nv)
        return val

    def join(self, _left: _DTable, _right: _DTable, func):
        return self.__do_binary_process_and_create_table(
            left=_left,
            right=_right,
            user_func=func,
            stub_func=self.proc_stub.join)

    def glom(self, _table: _DTable):
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=None, stub_func=self.proc_stub.glom)

    def sample(self, _table: _DTable, fraction, seed):
        if fraction < 0 or fraction > 1:
            raise ValueError("fraction must be in [0, 1]")

        func = lambda: (fraction, seed)
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=func, stub_func=self.proc_stub.sample)

    def subtractByKey(self, _left: _DTable, _right: _DTable):
        return self.__do_binary_process_and_create_table(
            left=_left,
            right=_right,
            user_func=None,
            stub_func=self.proc_stub.subtractByKey)

    def filter(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=func, stub_func=self.proc_stub.filter)

    def union(self, _left: _DTable, _right: _DTable, func):
        return self.__do_binary_process_and_create_table(
            left=_left,
            right=_right,
            user_func=func,
            stub_func=self.proc_stub.union)

    def flatMap(self, _table: _DTable, func):
        return self.__do_unary_process_and_create_table(
            table=_table, user_func=func, stub_func=self.proc_stub.flatMap)

    def __create_storage_locator(self, namespace, name, _type):
        return storage_basic_pb2.StorageLocator(namespace=namespace,
                                                name=name,
                                                type=_type)

    def __create_storage_locator_from_dtable(self, _table: _DTable):
        return self.__create_storage_locator(_table._namespace, _table._name,
                                             _table._type)

    def __create_task_info(self, func, is_in_place_computing):
        if func:
            func_id, func_bytes = self.serialize_and_hash_func(func)
        else:
            func_id = str(uuid.uuid1())
            func_bytes = b'blank'

        return processor_pb2.TaskInfo(task_id=self.session_id,
                                      function_id=func_id,
                                      function_bytes=func_bytes,
                                      isInPlaceComputing=is_in_place_computing)

    def __create_unary_process(self, table: _DTable, func):
        operand = self.__create_storage_locator_from_dtable(table)
        task_info = self.__create_task_info(
            func=func, is_in_place_computing=table.get_in_place_computing())

        return processor_pb2.UnaryProcess(
            info=task_info,
            operand=operand,
            session=self.eggroll_session.to_protobuf())

    def __do_unary_process(self, table: _DTable, user_func, stub_func):
        process = self.__create_unary_process(table=table, func=user_func)

        return stub_func(process)

    def __do_unary_process_and_create_table(self, table: _DTable, user_func,
                                            stub_func):
        resp = self.__do_unary_process(table=table,
                                       user_func=user_func,
                                       stub_func=stub_func)
        return self._create_table_from_locator(resp, table)

    def __create_binary_process(self, left: _DTable, right: _DTable, func,
                                session):
        left_op = self.__create_storage_locator_from_dtable(left)
        right_op = self.__create_storage_locator_from_dtable(right)
        task_info = self.__create_task_info(
            func=func, is_in_place_computing=left.get_in_place_computing())

        return processor_pb2.BinaryProcess(
            info=task_info,
            left=left_op,
            right=right_op,
            session=self.eggroll_session.to_protobuf())

    def __do_binary_process(self, left: _DTable, right: _DTable, user_func,
                            stub_func):
        process = self.__create_binary_process(
            left=left,
            right=right,
            func=user_func,
            session=self.eggroll_session.to_protobuf())

        return stub_func(process)

    def __do_binary_process_and_create_table(self, left: _DTable,
                                             right: _DTable, user_func,
                                             stub_func):
        resp = self.__do_binary_process(left=left,
                                        right=right,
                                        user_func=user_func,
                                        stub_func=stub_func)
        return self._create_table_from_locator(resp, left)
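A hedged end-to-end sketch of driving the client above. The EggrollSession constructor and add_conf setter are assumptions inferred from the names used in __init__; the real session API may differ:

session = EggrollSession(session_id='demo')          # hypothetical constructor
session.add_conf(EGGROLL_ROLL_HOST, 'localhost')     # hypothetical setter
session.add_conf(EGGROLL_ROLL_PORT, 8011)

roll = _EggRoll(session)                             # registers the singleton
table = roll.parallelize(range(10), include_key=False, partition=2)
doubled = roll.map_values(table, lambda v: v * 2)    # executed by the processor
print(roll.count(doubled))                           # -> 10
roll.stop()                                          # closes the channel, clears the singleton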
Example #4
def __init__(self, data_dir):
    self._serdes = eggroll_serdes.get_serdes()
    Processor.TEMP_DIR = os.sep.join([data_dir, 'in_memory'])
    Processor.LMDB_DIR = os.sep.join([data_dir, 'lmdb'])
    Processor.LEVEL_DB_DIR = os.sep.join([data_dir, 'level_db'])
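Note that os.sep.join is a plain string join on the path separator, not os.path.join; it performs no normalization, so data_dir must not carry a trailing slash:

import os

print(os.sep.join(['/data/eggroll', 'lmdb']))    # -> /data/eggroll/lmdb
print(os.sep.join(['/data/eggroll/', 'lmdb']))   # -> /data/eggroll//lmdb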
Example #5
class EggRoll(object):
    __instance = None
    _serdes = eggroll_serdes.get_serdes()
    egg_list = []
    init_flag = False
    proc_list = []
    proc_egg_map = {}

    @staticmethod
    def get_instance():
        if EggRoll.__instance is None:
            raise EnvironmentError("eggroll should be initialized before use")
        return EggRoll.__instance

    def __init__(self, job_id):
        if EggRoll.__instance is not None:
            raise Exception("This class is a singleton!")
        EggRoll.init()
        self.job_id = str(uuid.uuid1()) if job_id is None else job_id
        self._meta_table = _DTable(self, storage_basic_pb2.LMDB, "__META__",
                                   "__META__", 10)
        EggRoll.__instance = self

    def table(self,
              name,
              namespace,
              partition=1,
              create_if_missing=True,
              error_if_exist=False,
              persistent=True):
        _type = storage_basic_pb2.LMDB if persistent else storage_basic_pb2.IN_MEMORY
        _table_key = "{}.{}.{}".format(_type, namespace, name)
        _old_partition = self._meta_table.put_if_absent(_table_key, partition)
        return _DTable(EggRoll.get_instance(), _type, namespace, name,
                       partition if _old_partition is None else _old_partition)

    def parallelize(self,
                    data: Iterable,
                    include_key=False,
                    name=None,
                    partition=1,
                    namespace=None,
                    create_if_missing=True,
                    error_if_exist=False,
                    persistent=False):
        eggroll = EggRoll.get_instance()
        if name is None:
            name = str(uuid.uuid1())
        if namespace is None and persistent:
            raise ValueError("namespace cannot be None for persistent table")
        elif namespace is None:
            namespace = eggroll.job_id
        _table = self.table(name, namespace, partition, persistent=persistent)
        _iter = data if include_key else enumerate(data)
        eggroll.put(_table, _iter)
        return _table

    def _merge(self, iters):
        ''' Merge sorted iterators. '''
        entries = []
        for _id, it in enumerate(map(iter, iters)):
            try:
                op = next(it)
                entries.append([op.key, op.value, _id, it])
            except StopIteration:
                pass
        heapify(entries)
        while entries:
            key, value, _, it = entry = entries[0]
            yield self._serdes.deserialize(key), self._serdes.deserialize(
                value)
            try:
                op = next(it)
                entry[0], entry[1] = op.key, op.value
                heapreplace(entries, entry)
            except StopIteration:
                heappop(entries)

    @staticmethod
    def init():
        if EggRoll.init_flag:
            return
        config = file_utils.load_json_conf('eggroll/conf/mock_roll.json')
        egg_ids = config.get('eggs')

        for egg_id in egg_ids:
            target = config.get('storage').get(egg_id)
            channel = grpc.insecure_channel(
                target,
                options=[('grpc.max_send_message_length', -1),
                         ('grpc.max_receive_message_length', -1)])
            EggRoll.egg_list.append(kv_pb2_grpc.KVServiceStub(channel))
            procs = config.get('procs').get(egg_id)
            for proc in procs:
                _channel = grpc.insecure_channel(
                    proc,
                    options=[('grpc.max_send_message_length', -1),
                             ('grpc.max_receive_message_length', -1)])
                _stub = processor_pb2_grpc.ProcessServiceStub(_channel)
                proc_info = (_channel, _stub)
                i = len(EggRoll.proc_list)
                EggRoll.proc_egg_map[i] = int(egg_id) - 1
                EggRoll.proc_list.append(proc_info)
        EggRoll.init_flag = True

    def serialize_and_hash_func(self, func):
        pickled_function = pickle.dumps(func)
        func_id = str(uuid.uuid1())
        return func_id, pickled_function

    @record_metrics
    def map(self, _table, func):
        func_id, func_bytes = self.serialize_and_hash_func(func)
        results = []

        for partition in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, partition)
            unary_p = processor_pb2.UnaryProcess(
                operand=operand,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id + "_inter",
                                            function_bytes=func_bytes))

            proc_id = partition % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.map.future(unary_p))
        for r in results:
            result = r.result()

        return _DTable(self, result.type, result.namespace, result.name,
                       _table.partition).save_as(func_id, result.namespace,
                                                 _table.partition)

    @record_metrics
    def mapPartitions(self, _table, func):
        func_id, func_bytes = self.serialize_and_hash_func(func)
        results = []

        for partition in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, partition)
            unary_p = processor_pb2.UnaryProcess(
                operand=operand,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id,
                                            function_bytes=func_bytes))

            proc_id = partition % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.mapPartitions.future(unary_p))
        for r in results:
            result = r.result()
        return _DTable(self, result.type, result.namespace, result.name,
                       _table.partition)

    @record_metrics
    def mapValues(self, _table, func):
        func_id, func_bytes = self.serialize_and_hash_func(func)
        results = []
        for partition in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, partition)
            unary_p = processor_pb2.UnaryProcess(
                operand=operand,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id,
                                            function_bytes=func_bytes))

            proc_id = partition % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.mapValues.future(unary_p))

        for r in results:
            result = r.result()
        return _DTable(self, result.type, result.namespace, result.name,
                       _table.partition)

    @record_metrics
    def glom(self, _table):
        results = []
        func_id = str(uuid.uuid1())
        for p in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, p)
            unary_p = processor_pb2.UnaryProcess(operand=operand,
                                                 info=processor_pb2.TaskInfo(
                                                     task_id=self.job_id,
                                                     function_id=func_id))
            proc_id = p % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.glom.future(unary_p))
        for r in results:
            result = r.result()
        return _DTable(self, result.type, result.namespace, result.name,
                       _table.partition)

    @record_metrics
    def sample(self, _table, fraction, seed):
        if fraction < 0 or fraction > 1:
            raise ValueError("fraction must be in [0, 1]")
        func_bytes = self._serdes.serialize((fraction, seed))
        results = []
        func_id = str(uuid.uuid1())
        for p in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, p)
            unary_p = processor_pb2.UnaryProcess(
                operand=operand,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id,
                                            function_bytes=func_bytes))
            proc_id = p % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.sample.future(unary_p))
        for r in results:
            result = r.result()
        return _DTable(self, result.type, result.namespace, result.name,
                       _table.partition)

    @record_metrics
    def reduce(self, _table, func):
        func_id, func_bytes = self.serialize_and_hash_func(func)
        rtn = None
        results = []
        for partition in range(_table.partition):
            operand = EggRoll.__get_storage_locator(_table, partition)
            proc_id = partition % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            unary_p = processor_pb2.UnaryProcess(
                operand=operand,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id,
                                            function_bytes=func_bytes))
            results = results + list(stub.reduce(unary_p))
        rs = []
        for val in results:
            if len(val.value) > 0:
                rs.append(self._serdes.deserialize(val.value))
        rs = [r for r in filter(partial(is_not, None), rs)]
        if len(rs) <= 0:
            return rtn
        rtn = rs[0]
        for r in rs[1:]:
            rtn = func(rtn, r)
        return rtn

    @record_metrics
    def join(self, left, right, func):
        func_id, func_bytes = self.serialize_and_hash_func(func)

        results = []
        res = None
        for partition in range(left.partition):
            l_op = EggRoll.__get_storage_locator(left, partition)
            r_op = EggRoll.__get_storage_locator(right, partition)
            binary_p = processor_pb2.BinaryProcess(
                left=l_op,
                right=r_op,
                info=processor_pb2.TaskInfo(task_id=self.job_id,
                                            function_id=func_id,
                                            function_bytes=func_bytes))
            proc_id = partition % len(self.proc_list)
            channel, stub = self.proc_list[proc_id]
            results.append(stub.join.future(binary_p))
        for r in results:
            res = r.result()
        return _DTable(self, res.type, res.namespace, res.name, left.partition)

    @staticmethod
    def __get_storage_locator(_table, fragment=None):
        if fragment is None:
            fragment = _table.partition
        return StorageLocator(name=_table.name,
                              namespace=_table.namespace,
                              type=_table.type,
                              fragment=fragment)

    def split_gen(self, _iter: Iterable, num):
        gens = tee(_iter, num)
        return (self.dispatch_gen(gen, i, num) for i, gen in enumerate(gens))

    def dispatch_gen(self, _iter: Iterable, p, total):
        for k, v in _iter:
            _p, i = self.__get_index(k, total)
            if _p == p:
                yield kv_pb2.Operand(key=self._serdes.serialize(k),
                                     value=self._serdes.serialize(v))

    def put(self, _table, kv_list):
        gens = self.split_gen(kv_list, _table.partition)
        results = []

        for p, gen in enumerate(gens):
            i = p % len(self.proc_list)
            stub = self.egg_list[i]
            meta = self.__get_meta(_table, str(p))
            stub.putAll(gen, metadata=meta)
        for r in results:
            r.result()
        return True

    def put_if_absent(self, _table, k, v):
        p, i = self.__get_index(k, _table.partition)
        stub = self.egg_list[i]
        meta = self.__get_meta(_table, str(p))
        rtn = stub.putIfAbsent(kv_pb2.Operand(key=self._serdes.serialize(k),
                                              value=self._serdes.serialize(v)),
                               metadata=meta).value
        rtn = self._serdes.deserialize(rtn) if len(rtn) > 0 else None
        return rtn

    def get(self, _table, k_list):
        res = []
        for k in k_list:
            p, i = self.__get_index(k, _table.partition)
            stub = self.egg_list[i]
            op = stub.get(kv_pb2.Operand(key=self._serdes.serialize(k)),
                          metadata=self.__get_meta(_table, str(p)))
            res.append(self.__get_pair(op))
        return res

    def delete(self, _table, k):
        p, i = self.__get_index(k, _table.partition)
        stub = self.egg_list[i]
        op = stub.delOne(kv_pb2.Operand(key=self._serdes.serialize(k)),
                         metadata=self.__get_meta(_table, str(p)))
        return self.__get_pair(op)

    def iterate(self, _table):
        iters = []
        for p in range(_table.partition):
            proc_id = p % len(EggRoll.proc_list)
            i = self.__get_index_by_proc(proc_id)
            stub = self.egg_list[i]
            iters.append(
                _PartitionIterator(stub, self.__get_meta(_table, str(p))))
        return self._merge(iters)

    def destroy(self, _table):
        for p in range(_table.partition):
            proc_id = p % len(EggRoll.proc_list)
            i = self.__get_index_by_proc(proc_id)
            stub = self.egg_list[i]
            stub.destroy(kv_pb2.Empty(),
                         metadata=self.__get_meta(_table, str(p)))

    def count(self, _table):
        count = 0
        for p in range(_table.partition):
            proc_id = p % len(EggRoll.proc_list)
            i = self.__get_index_by_proc(proc_id)
            stub = self.egg_list[i]
            count += stub.count(kv_pb2.Empty(),
                                metadata=self.__get_meta(_table, str(p))).value
        return count

    @staticmethod
    def __get_meta(_table, fragment):
        return (('store_type', _table.type),
                ('table_name', _table.name),
                ('name_space', _table.namespace),
                ('fragment', fragment))

    @cached(cache=TTLCache(maxsize=100, ttl=360))
    def __calc_hash(self, k):
        k_bytes = hashlib.sha1(self._serdes.serialize(k)).digest()
        return int.from_bytes(k_bytes, byteorder='little')

    def __key_to_partition(self, k, partitions):
        i = self.__calc_hash(k)
        return i % partitions
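    # Illustration: because __calc_hash runs the serialized key through
    # SHA-1, __key_to_partition('user-42', 4) yields the same index on every
    # call and in every process, unlike Python's per-process salted hash().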

    @staticmethod
    def __get_index_by_proc(proc_id):
        egg_id = EggRoll.proc_egg_map[proc_id]
        return egg_id

    def __get_index(self, k, partitions):
        p, proc_id = self.__get_proc(k, partitions)
        return p, self.__get_index_by_proc(proc_id)

    def __get_proc(self, k, partitions):
        p = self.__key_to_partition(k, partitions)
        return p, p % len(self.proc_list)

    def __get_pair(self, op):
        key = self._serdes.deserialize(op.key)
        value = self._serdes.deserialize(op.value) if len(op.value) > 0 else None
        return key, value
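A hedged usage sketch for this mock client; it assumes eggroll/conf/mock_roll.json lists reachable egg and processor endpoints:

roll = EggRoll(job_id='demo-job')
table = roll.parallelize(range(100), partition=2)   # namespace defaults to job_id
squared = roll.mapValues(table, lambda v: v * v)
print(roll.count(squared))                          # -> 100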
Example #6
def __init__(self, data_dir):
    self._serdes = eggroll_serdes.get_serdes()
    Processor.TEMP_DIR = os.sep.join([data_dir, 'lmdb_temporary'])
    Processor.DATA_DIR = os.sep.join([data_dir, 'lmdb'])
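Finally, the serialize/deserialize contract every example relies on can be satisfied by a minimal pickle-backed stand-in (a sketch; eggroll's real serdes may additionally restrict what it is willing to unpickle):

import pickle

class PickleSerdes:
    # minimal serdes exposing the same surface the examples use
    @staticmethod
    def serialize(obj) -> bytes:
        return pickle.dumps(obj)

    @staticmethod
    def deserialize(data: bytes):
        return pickle.loads(data)

serdes = PickleSerdes()
assert serdes.deserialize(serdes.serialize((1, 'a'))) == (1, 'a')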