Beispiel #1
0
 def __call__(self, *args, **kwargs) -> Optional[Any]:
     """Run this callable now, or submit it as a background ``Task``.

     A ``sync`` keyword argument, if supplied, is consumed here (it is not
     forwarded) and overrides the instance default. The call is invoked
     directly only when sync execution is not disabled and sync is wanted;
     otherwise a new ``Task`` is created.

     If an async limit is configured, this callable's tasks whose status is
     ``'submitted'`` or ``'running'`` are counted inside a write transaction
     on the task namespace; when that count reaches the limit no task is
     created and ``None`` is returned.

     Returns:
         The result of ``self.invoke`` for a synchronous call; otherwise the
         newly created ``Task``, or ``None`` if the async limit was reached.
     """
     if 'sync' in kwargs:
         want_sync = bool(kwargs.pop('sync'))
     else:
         want_sync = self.__default_sync

     if not self.__disable_sync and want_sync:
         return self.invoke(args=args, kwargs=kwargs)

     def submit_task():
         # Single place for the Task construction used by both async paths.
         return Task(asyncable=self,
                     args=args,
                     kwargs=kwargs,
                     site_uuid=self.site_uuid,
                     create=True,
                     bind=False)

     if self.__async_limit is None:
         return submit_task()

     _, env, _, _, _, _ = get_environment_threadsafe(
         self.storage_path, constants.TASK_NAMESPACE, create=True)
     # Counting and creation happen inside one write transaction
     # (presumably so the limit check is atomic with creation — confirm
     # that Task creation joins this transaction).
     with transaction_context(env, write=True):
         active = 0
         for task in self.tasks():
             if task.status in ['submitted', 'running']:
                 active += 1
                 if active == self.__async_limit:
                     return None
         return submit_task()
Beispiel #2
0
def wait(*args):
    """Block until a condition holds, re-checking when the given entities change.

    Call as ``wait(entity, ..., condition)``. Every positional argument must
    be an ``Entity``, except that the LAST one may be a zero-argument
    condition (a plain function/lambda or a bound method). Without a
    condition, ``lambda: True`` is used.

    With no entities, the condition is polled on every tick of
    ``polling_loop``. With entities, the condition is evaluated on the first
    tick and afterwards only when some entity's ``version`` differs from the
    versions recorded at the previous check. Versions are read inside a
    read transaction on the entities' shared environment.

    Raises:
        ValueError: if called with no arguments, if a non-condition argument
            is not an ``Entity``, or if the entities do not all share the
            same ``site_uuid`` and ``namespace``.
    """
    entities = list(args)
    if not entities:
        raise ValueError('wait() requires at least one entity or a condition')

    # Accept bound methods too (previously rejected by the Entity check).
    # ``callable()`` is deliberately not used: entities may themselves be
    # callable and must not be mistaken for the condition.
    if isinstance(entities[-1], (types.FunctionType, types.MethodType)):
        condition = entities.pop()
    else:
        condition = lambda: True

    if not all(isinstance(entity, Entity) for entity in entities):
        raise ValueError('wait() arguments must be Entity objects, '
                         'optionally followed by a condition callable')

    if entities:
        # All entities are read in a single transaction below, so they must
        # live in the same site and namespace.
        first_key = (entities[0].site_uuid, entities[0].namespace)
        if any((entity.site_uuid, entity.namespace) != first_key
               for entity in entities):
            raise ValueError('all entities passed to wait() must share the '
                             'same site and namespace')

    versions = []
    for _ in polling_loop(
            getenv(constants.ADAPTER_POLLING_INTERVAL_ENVNAME, float)):
        if not entities:
            if condition():
                break
            continue
        _, env, _, _, _, _ = get_environment_threadsafe(
            entities[0].storage_path, entities[0].namespace, create=False)
        with transaction_context(env, write=False):
            if not versions:
                # First tick: record baseline versions and check once.
                versions = [entity.version for entity in entities]
                if condition():
                    break
            elif any(old != entity.version
                     for old, entity in zip(versions, entities)):
                # Something changed since the last check: re-evaluate, and
                # only move the baseline forward if we keep waiting.
                if condition():
                    break
                versions = [entity.version for entity in entities]
Beispiel #3
0
 def __setstate__(self, from_wire: Any):
     """Restore this entity from its pickled wire form.

     ``from_wire`` is unpacked as ``(site_uuid, namespace, name)``. The
     storage path is resolved from the site UUID, the namespace environment
     is opened with ``create=False``, and the entity's UUID and descriptor
     are read by name in a read transaction before databases are bound.

     Raises:
         ObjectNotFoundError: if ``name`` has no entry in the name database.
     """
     self._site_uuid, self._namespace, wire_name = from_wire
     self._storage_path = get_storage_path(self._site_uuid)
     self._create = False
     self._encname = wire_name.encode('utf-8')
     (_, self._env, self._namedb, self._attrdb,
      self._versdb, self._descdb) = get_environment_threadsafe(
          self._storage_path, self._namespace, create=False)
     self._userdb = []
     with transaction_context(self._env, write=False) as (txn, _, _):
         uuid_raw = txn.get(key=self._encname, db=self._namedb)
         # LMDB may hand back a memoryview; normalise to bytes.
         if isinstance(uuid_raw, memoryview):
             uuid_raw = bytes(uuid_raw)
         self._uuid_bytes = uuid_raw
         if uuid_raw is None:
             raise ObjectNotFoundError()
         descriptor_raw = txn.get(key=uuid_raw, db=self._descdb)
         if isinstance(descriptor_raw, memoryview):
             descriptor_raw = bytes(descriptor_raw)
         descriptor = orjson.loads(descriptor_raw)
         self._versioned = descriptor['versioned']
     self.__bind_databases(descriptor=descriptor)
     self.__class__.__initialize_class__()
Beispiel #4
0
def get_concurrency(*, site_uuid: Optional[str] = None) -> int:
    """Return the cluster concurrency recorded in the cluster-state Dict.

    If no ``'concurrency'`` entry exists yet, it is seeded from
    ``getenv(constants.CLUSTER_CONCURRENCY_ENVNAME, int)`` (presumably an
    environment variable). The existence check and the seeding write run
    inside a single write transaction on the Dict's environment.
    """
    cluster_state = Dict(constants.CLUSTER_STATE_DICT_PATH,
                         site_uuid=site_uuid, create=True, bind=True)
    _, state_env, _, _, _, _ = get_environment_threadsafe(
        cluster_state.storage_path, cluster_state.namespace, create=False)
    with transaction_context(state_env, write=True):
        if 'concurrency' not in cluster_state:
            seeded = getenv(constants.CLUSTER_CONCURRENCY_ENVNAME, int)
            cluster_state['concurrency'] = seeded
    return cluster_state['concurrency']
Beispiel #5
0
 def get(self, name: str):
     """Load the entity called ``name`` from this namespace.

     Raises:
         ObjectNotFoundError: if ``load_entity`` finds nothing for ``name``.
     """
     _, env, name_db, _, _, descriptor_db = get_environment_threadsafe(
         self._storage_path, self._path, create=self._create)
     with transaction_context(env, write=False, iterator=True) as (_, cursors, _):
         found = load_entity(name_db, descriptor_db, cursors,
                             self._site_uuid, self._path, name=name)
         if found is None:
             raise ObjectNotFoundError()
         return found
Beispiel #6
0
 def tasks(self) -> Iterator[Task]:
     """Yield the ``Task`` entities whose task-namespace key starts with this UUID.

     Keys are scanned in order from ``self.uuid``; the scan stops at the first
     key that does not start with it. For each matching key, the second
     ``':'``-separated field is looked up by name in the task namespace.
     Lookups that raise ``KeyError``/``ObjectNotFoundError``, or that resolve
     to a non-``Task``, are skipped. The read transaction remains open while
     this generator is being consumed.
     """
     _, env, _, _, _, _ = get_environment_threadsafe(
         self.storage_path, constants.TASK_NAMESPACE, create=True)
     with transaction_context(env, write=False) as (txn, _, _):
         cursor = txn.cursor()
         task_namespace = Namespace(constants.TASK_NAMESPACE,
                                    site_uuid=self.site_uuid)
         positioned = cursor.set_range(self.uuid.encode('utf-8'))
         while positioned:
             raw_key = cursor.key()
             if isinstance(raw_key, memoryview):
                 raw_key = bytes(raw_key)
             key = raw_key.decode('utf-8')
             if not key.startswith(self.uuid):
                 break
             task_name = key.split(':')[1]
             try:
                 candidate = task_namespace.get(task_name)
                 if isinstance(candidate, Task):
                     yield candidate
             except (KeyError, ObjectNotFoundError):
                 pass
             positioned = cursor.next()
Beispiel #7
0
        self,
        /, *,
        include_hidden: Optional[bool] = None
    ) -> Iterator[Tuple[str, Dict[str, Any]]]:
        for name, descriptor in self.descriptors(
            include_hidden = include_hidden if include_hidden is not None else self._include_hidden
        ):
            yield (name, descriptor['metadata'])

    def descriptors(
        self,
        /, *,
        include_hidden: Optional[bool] = None
    ) -> Iterator[Tuple[str, Descriptor]]:
        """Yield ``(name, descriptor)`` pairs for entities in this namespace.

        Args:
            include_hidden: forwarded to ``descriptor_iter``; when ``None``,
                the instance's ``_include_hidden`` setting is used.

        The read transaction (and its cursors) stays open while the caller
        iterates, and is closed when iteration finishes or the generator is
        closed.
        """
        if include_hidden is None:
            include_hidden = self._include_hidden
        _, env, name_db, _, _, descriptor_db = get_environment_threadsafe(
            self._storage_path, self._path, create = self._create
        )
        # ``yield from`` instead of ``return``: returning the iterator would
        # leave the ``with`` block immediately, closing the transaction and
        # its cursors before the caller has consumed a single item.
        with transaction_context(env, write = False, iterator = True) as (_, cursors, _):
            yield from descriptor_iter(
                name_db,
                descriptor_db,
                cursors,
                include_hidden = include_hidden,
            )

    def names(
        self,
        /, *,
        include_hidden: Optional[bool] = None
    ) -> Iterator[str]:
Beispiel #8
0
    return thread.local.default_site

def import_site(path: str, /, *, create: bool = False):
    """Register the site at ``path`` via ``set_default_site``.

    Forwards ``create`` unchanged and always passes
    ``overwrite_thread_local=False``.
    """
    set_default_site(
        path,
        create=create,
        overwrite_thread_local=False,
    )

@functools.lru_cache(None)
def get_storage_path(site_uuid: str) -> str:
    """Return the storage path registered in ``site_map`` for ``site_uuid``.

    Successful lookups are memoised by ``lru_cache``; a failed lookup raises,
    and raised exceptions are not cached.

    Raises:
        SiteNotFoundError: if ``site_uuid`` is not a key of ``site_map``.
    """
    if site_uuid not in site_map:
        raise SiteNotFoundError()
    return site_map[site_uuid]

def get_site_uuid(path: str, /, *, create: bool = False) -> str:
    """Return the site UUID for the site stored at ``path``.

    ``path`` is normalised with ``os.path.abspath``. On the first lookup of a
    path, the root-namespace environment is opened (passing ``create``) and
    both the uuid->path and path->uuid entries are added to ``site_map``.
    The membership test is repeated while holding ``site_map_lock`` so that
    concurrent first lookups register the path only once.
    """
    abs_path = os.path.abspath(path)
    if abs_path in site_map:
        return site_map[abs_path]
    with site_map_lock:
        # Another thread may have registered this path while we waited.
        if abs_path not in site_map:
            new_uuid, _, _, _, _, _ = get_environment_threadsafe(
                abs_path, constants.ROOT_NAMESPACE, create=create)
            site_map[new_uuid] = abs_path
            site_map[abs_path] = new_uuid
    return site_map[abs_path]
Beispiel #9
0
        else:
            if thread.local.default_site is not None:
                self._storage_path, self._site_uuid = thread.local.default_site
            else:
                raise SiteNotSpecifiedError()

        self._env: lmdb.Environment
        self._namedb: lmdb._Database
        self._attrdb: lmdb._Database
        self._versdb: lmdb._Database
        self._descdb: lmdb._Database

        _, self._env, self._namedb, self._attrdb, self._versdb, self._descdb = \
        get_environment_threadsafe(
            self._storage_path,
            self._namespace,
            create = create
        )

        self._userdb: List[lmdb._Database] = []

        self._uuid_bytes: bytes
        self._versioned: bool

        if bind:
            self.__bind_or_create(
                db_properties=db_properties
                if db_properties is not None else [],
                versioned=versioned,
                metadata=metadata if metadata is not None else {},
                on_init=on_init,
Beispiel #10
0
logger = logging.getLogger(__name__)

if __name__ == '__main__':

    try:

        node_uid = getenv(constants.NODE_UID_ENVNAME, str)
        cluster_uid = getenv(constants.CLUSTER_UID_ENVNAME, str)

        logger.info('worker (%s) started for site %s', node_uid,
                    get_default_site())

        submit_queue = Queue(constants.SUBMIT_QUEUE_PATH, create=True)

        _, environment, _, _, _, _ = get_environment_threadsafe(
            submit_queue.storage_path, submit_queue.namespace, create=False)

        termination_queue = Queue(constants.NODE_TERMINATION_QUEUE_PATH,
                                  create=True)

        polling_interval = getenv(constants.WORKER_POLLING_INTERVAL_ENVNAME,
                                  float)

        for i in polling_loop(polling_interval):
            while True:
                try:
                    if len(termination_queue):
                        _ = termination_queue.get()
                        logger.info('worker (%s) terminating on request',
                                    node_uid)
                        sys.exit(0)
Beispiel #11
0
            storage_path = get_storage_path(site_uuid)
        else:
            if thread.local.default_site is not None:
                storage_path, _ = thread.local.default_site
            else:
                raise SiteNotSpecifiedError()
    elif isinstance(obj, Namespace):
        namespace = obj.path
        storage_path = obj.storage_path
    elif isinstance(obj, Entity):
        namespace = obj.namespace
        storage_path = obj.storage_path
    else:
        raise ValueError()
    _, env, _, _, _, _ = get_environment_threadsafe(storage_path,
                                                    namespace,
                                                    create=False)
    return transaction_context(env, write=True)


def snapshot(obj: Optional[Union[str, Namespace, Entity]] = None,
             /,
             *,
             site_uuid: Optional[str] = None) -> ContextManager:
    if obj is None or isinstance(obj, str):
        namespace = resolve_namespace(obj)
        if site_uuid is not None:
            storage_path = get_storage_path(site_uuid)
        else:
            if thread.local.default_site is not None:
                storage_path, _ = thread.local.default_site