Example #1
async def _wait_response(lock: asyncio.Lock, queue: asyncio.Queue):
    """Wait for the direct ACK message, and post False if timeout reached."""
    # TODO: Need to consider the risk of this. We may be unlocking a prior send command.
    # This would mean that the prior command will terminate. What happens when the
    # prior command returns a direct ACK then this command returns a direct ACK?
    # Do not believe this is an issue but need to test.
    if lock.locked():
        lock.release()
    await lock.acquire()
    try:
        await asyncio.wait_for(lock.acquire(), TIMEOUT)
    except asyncio.TimeoutError:
        if lock.locked():
            await queue.put(ResponseStatus.FAILURE)
    if lock.locked():
        lock.release()
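The double acquire is the whole trick here: the first acquire() takes the lock, and the second one can only complete once the ACK handler releases it, so wait_for() is effectively waiting for the ACK with a timeout. Below is a runnable sketch of that handshake; it reuses _wait_response from the example, while TIMEOUT, ResponseStatus, and the ack_handler are stubs assumed for illustration, not the library's real definitions.

import asyncio
import enum

TIMEOUT = 2  # hypothetical timeout in seconds


class ResponseStatus(enum.Enum):  # stand-in for the real enum
    FAILURE = 0
    SUCCESS = 1


async def ack_handler(lock: asyncio.Lock, queue: asyncio.Queue):
    """Simulate the device ACK: post a result and release the lock so the
    waiter's second acquire() completes before the timeout."""
    await asyncio.sleep(0.1)
    await queue.put(ResponseStatus.SUCCESS)
    if lock.locked():
        lock.release()


async def main():
    lock, queue = asyncio.Lock(), asyncio.Queue()
    await asyncio.gather(_wait_response(lock, queue), ack_handler(lock, queue))
    print(await queue.get())  # ResponseStatus.SUCCESS


asyncio.run(main())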
Example #2
    async def archive(self, record: ArchiveRecord):
        """Process a storage write, only allowing one write at a time to each path
        """
        async with self.ability_to_create_path_lock:
            lock = self.path_locks.get(record.filepath)

            if lock is None:
                lock = Lock()
                self.path_locks[record.filepath] = lock

        # Skip writes when a given path is already locked
        if lock.locked():
            self.log.info("Skipping archive of %s", record.filepath)
            return

        async with lock:
            try:
                async with self.session.create_client('s3',
                                                      aws_secret_access_key=self.settings.s3_secret_access_key,
                                                      aws_access_key_id=self.settings.s3_access_key_id,
                                                      endpoint_url=self.settings.s3_endpoint_url,
                                                      region_name=self.settings.s3_region_name,
                                                      ) as client:
                    self.log.info("Processing storage write of %s", record.filepath)
                    file_key = s3_key(self.settings.workspace_prefix, record.filepath)
                    await client.put_object(Bucket=self.settings.s3_bucket, Key=file_key, Body=record.content)
                    self.log.info("Done with storage write of %s", record.filepath)
            except Exception as e:
                self.log.error(
                    'Error while archiving file: %s %s',
                    record.filepath,
                    e,
                    exc_info=True,
                )
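The same pattern, stripped to its essentials: one registry lock guards the per-path lock map, and writers skip rather than queue when a path is busy. A minimal runnable sketch, with the S3 upload replaced by a sleep; the class and method names are illustrative, not the original's.

import asyncio


class PathLockingWriter:
    def __init__(self):
        self.creation_lock = asyncio.Lock()  # guards the lock registry
        self.path_locks: dict[str, asyncio.Lock] = {}

    async def write(self, path: str, content: bytes):
        # Serialize creation so two coroutines never make two locks for one path.
        async with self.creation_lock:
            lock = self.path_locks.setdefault(path, asyncio.Lock())

        # Drop the write if another coroutine is already writing this path.
        if lock.locked():
            print(f"skipping busy path {path}")
            return

        async with lock:
            await asyncio.sleep(0.1)  # stand-in for the real S3 upload
            print(f"wrote {len(content)} bytes to {path}")


async def main():
    writer = PathLockingWriter()
    await asyncio.gather(writer.write("a.ipynb", b"one"),
                         writer.write("a.ipynb", b"two"),
                         writer.write("b.ipynb", b"three"))

asyncio.run(main())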
Example #3
class UpdateManager(object):
    def __init__(self, name=None):
        self.lock = Lock()
        self.name = name or random_key(6)
        self.loss_history = []
        self.n_updates = 0
        self._reset_state()

    def _reset_state(self):
        self.update_name = "update_{}_{:05d}".format(self.name, self.n_updates)
        self.clients = set()
        self.client_responses = dict()
        self.update_meta = None

    @property
    def in_progress(self):
        return self.lock.locked()

    @property
    def clients_left(self):
        return len(self.clients) - len(self.client_responses)

    def __len__(self):
        #return self.in_progress * len(self.clients)
        return len(self.clients)

    async def start_update(self, **update_meta):
        print("starting update")
        #if self.in_progress:
        #    raise UpdateInProgress
        self._reset_state()
        #await self.lock.acquire()
        self.update_meta = update_meta

    def end_update(self):
        #self.lock.release()
        self.n_updates += 1
        return self.client_responses


    def client_start(self, client_id):
        #if not self.in_progress:
        #    raise UpdateNotInProgress
        self.clients.add(client_id)

    def client_end(self, client_id, response):
        #if not self.in_progress:
        #    raise UpdateNotInProgress
        
        # the client might not have checked in yet in certain scenarios
        self.clients.add(client_id)
        
        self.client_responses[client_id] = response
        
        print("Update finished: {} [{}/{}]".format(
            client_id,
            len(self.client_responses),
            len(self.clients))
        )
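A hypothetical driver for one round of the manager above, assuming UpdateManager and the module's random_key helper are in scope. Note that with the lock acquire/release calls commented out, in_progress always reports False.

import asyncio


async def run_round():
    manager = UpdateManager(name="demo")
    await manager.start_update(lr=0.01)            # begin a round, store metadata
    manager.client_start("client-a")               # clients check in
    manager.client_start("client-b")
    manager.client_end("client-a", {"loss": 0.41})
    manager.client_end("client-b", {"loss": 0.38})
    print(manager.clients_left)                    # 0: every client responded
    responses = manager.end_update()               # close the round, bump the counter
    print(responses, manager.n_updates)


asyncio.run(run_round())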
Example #4
async def route_actions(action: Actions, trade_lock: Lock, testnet: bool,
                        ohlcv: dict):
    if trade_lock.locked():
        logger.info('There is an ongoing trade, bailing out')
        return

    logger.info(
        'There is no ongoing trade, we can instantiate a playbook if applicable'
    )

    if action == Actions.SHORT:
        logger.info('Signal suggests short!')
    elif action == Actions.LONG:
        logger.info('Signal suggests long!')
    elif action == Actions.NOTHING:
        logger.info('Signal suggests doing nothing..')
        return

    playbooks = {
        'HitAndRun': HitAndRun,
        'Fractalism': Fractalism,
        'FractalismFibo': FractalismFibo,
        'DCA': DCA,
    }
    _playbook = playbooks.get(PLAYBOOK)
    if not _playbook:
        raise NotImplementedError(
            f'Playbook {PLAYBOOK} is not yet implemented')

    exchange = get_ccxt_client(exchange=TRADES_EXCHANGE,
                               api_key=API_KEY,
                               api_secret=API_SECRET,
                               testnet=testnet)

    notification_channels = {'telegram': Telegram, 'webhook': Webhook}
    _notification = notification_channels.get(NOTIFY_USING)
    if not _notification:
        raise NotImplementedError(
            f'The notification channel {NOTIFY_USING} is not yet implemented')

    notification = _notification()

    playbook = _playbook(action=action,
                         exchange=exchange,
                         trade_lock=trade_lock,
                         logger=logger,
                         symbol=TRADE_SYMBOL,
                         timeframe=TIMEFRAME,
                         notification=notification,
                         leverage=int(LEVERAGE),
                         ohlcv=ohlcv)

    await playbook.play()
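route_actions only checks the lock; it is the playbook that is expected to hold trade_lock for the lifetime of a trade, which is what makes the locked() guard meaningful. A hedged sketch of that contract follows; the real playbook interfaces may differ.

import asyncio


class SketchPlaybook:
    """Illustrative playbook: hold trade_lock while a trade is open."""

    def __init__(self, trade_lock: asyncio.Lock, **kwargs):
        self.trade_lock = trade_lock

    async def play(self):
        async with self.trade_lock:   # route_actions bails out while this is held
            await asyncio.sleep(1)    # stand-in for entry/exit order management
        # lock released: the next signal can instantiate a new playbook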
Example #5
    async def archive(self, record: ArchiveRecord):
        """Process a record to write to storage.

        Acquire a path lock before archive. Writing to storage will only be
        allowed to a path if a valid `path_lock` is held and the path is not
        locked by another process.

        Parameters
        ----------

        record : ArchiveRecord
            A notebook and where it should be written to storage
        """
        async with self.path_lock_ready:
            lock = self.path_locks.get(record.filepath)

            if lock is None:
                lock = Lock()
                self.path_locks[record.filepath] = lock

        # Skip writes when a given path is already locked
        if lock.locked():
            self.log.info("Skipping archive of %s", record.filepath)
            return

        async with lock:
            try:
                async with self.session.create_client(
                        's3',
                        aws_secret_access_key=self.settings.s3_secret_access_key,
                        aws_access_key_id=self.settings.s3_access_key_id,
                        endpoint_url=self.settings.s3_endpoint_url,
                        region_name=self.settings.s3_region_name,
                ) as client:
                    self.log.info("Processing storage write of %s",
                                  record.filepath)
                    file_key = s3_key(self.settings.workspace_prefix,
                                      record.filepath)
                    await client.put_object(Bucket=self.settings.s3_bucket,
                                            Key=file_key,
                                            Body=record.content)
                    self.log.info("Done with storage write of %s",
                                  record.filepath)
            except Exception as e:
                self.log.error('Error while archiving file: %s %s',
                               record.filepath,
                               e,
                               exc_info=True)
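The locked() check followed immediately by `async with lock` is safe here because, with CPython's asyncio.Lock, no await point falls between the check and an uncontended acquire. The skip-if-busy idiom both archive variants use can be factored into a small helper; this is a sketch, not part of the original code.

import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def try_lock(lock: asyncio.Lock):
    """Yield True while holding the lock, or yield False immediately
    (without blocking) if it is already held elsewhere."""
    if lock.locked():
        yield False
        return
    async with lock:
        yield True

# usage:
#     async with try_lock(lock) as acquired:
#         if not acquired:
#             return  # skip the write, another coroutine owns the path
#         ...         # do the write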
Example #6
    async def keeproles(self, ctx, boolean: bool = None):
        """Get the current keeproles value on this server or change it.
        Keeproles makes the bot save every users roles so it can give them even if that user rejoins
        but only the roles the bot can give"""
        guild = ctx.guild
        current = self.cache.keeproles(guild.id)

        if current == boolean:
            return await ctx.send('Keeproles is already set to %s' % boolean)

        lock = self._guild_locks['keeproles'].get(guild.id, None)
        if lock is None:
            lock = Lock()  # the explicit loop argument was removed in Python 3.10
            self._guild_locks['keeproles'][guild.id] = lock

        if lock.locked():
            return await ctx.send('Hol up b')

        if boolean:
            t = time.time()
            await lock.acquire()
            try:
                bot_member = guild.get_member(self.bot.user.id)
                perms = bot_member.guild_permissions
                if not perms.administrator and not perms.manage_roles:
                    return await ctx.send(
                        'This bot needs manage roles permissions to enable this feature'
                    )
                msg = await ctx.send('indexing roles')
                if not await self.bot.dbutils.index_guild_member_roles(guild):
                    return await ctx.send('Failed to index user roles')

                await msg.edit(
                    content='Indexed roles in {0:.2f}s'.format(time.time() - t)
                )
            except discord.DiscordException:
                pass
            finally:
                lock.release()

        await self.cache.set_keeproles(guild.id, boolean)
        await ctx.send('Keeproles set to %s' % str(boolean))
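Example #7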
class ListenHandle:
    """Base listen handle container."""
    def __init__(self, handle, app):
        self._lock = Lock()
        self._app = app
        self.is_active = True
        self._handle = handle

    async def cancel(self):
        """Cancel the listener."""
        if self._lock.locked() or not self.is_active or self._handle is None:
            return

        self.is_active = False
        async with self._lock:
            await self._do_cancel()
            self._on_cancelled()

    def _on_cancelled(self):
        """Set properties once cancelled."""
        self.is_active = False
        self._handle = None
        self._app = None

    async def _do_cancel(self):
        """Perform handle cancel."""
        pass

    def __str__(self):
        return str(self._handle)  # the handle itself may not be a string

    def __eq__(self, o):
        if self._handle is None:
            return False
        if isinstance(o, ListenHandle):
            return self._handle.__eq__(o._handle)
        return self._handle.__eq__(o)

    def __hash__(self):
        return hash(self._handle)
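A concrete listener would override _do_cancel; a hypothetical subclass follows. The app API shown (remove_timer) is an assumption for illustration.

class TimerListenHandle(ListenHandle):
    """Illustrative handle for a timer registration."""

    async def _do_cancel(self):
        # hypothetical app call; the base class's cancel() guards against
        # double cancellation via the lock and the is_active flag
        await self._app.remove_timer(self._handle)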
Example #8
async def healthcheck_job(
    job_params: "JobParams",
    on_result: t.Callable[[MetricsCollection], t.Awaitable[None]],
    on_error: t.Callable[["JobParams", Exception], t.Awaitable[None]],
    sync_lock: asyncio.Lock,
    settings: Settings,
):
    """
    Makes request to target url, produce metrics and send them to output queue
    :param job_params: contains JobParams instance with details about this job
    :param on_result: callback function to call with results
    :param on_error: callback function to call with any errors
    :param sync_lock: lock to keep one instance of each health check job at a time
    :param settings: instance of application settings
    """
    if sync_lock.locked():
        logger.debug(f"Health check locked for {job_params}. Skipping.")
        return

    async with sync_lock:
        start_at = datetime.datetime.utcnow()
        try:
            async with ClientSession(timeout=ClientTimeout(
                    total=settings.request_timeout)) as session:
                async with session.get(job_params.url) as response:
                    result = HealthcheckJobResult(
                        url=job_params.url,
                        request_start_at=start_at,
                        response_received_at=datetime.datetime.utcnow(),
                        response_headers=response.headers,
                        response_content=await response.read(),
                        response_status=response.status,
                    )
            metrics = collect_metrics(result,
                                      regex_pattern=job_params.body_regex)
            await on_result(metrics)
        except Exception as e:
            await on_error(job_params, e)
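One lock per target is what makes the guard useful. Here is a hedged sketch of a scheduler that fires the job on a fixed period and lets an in-flight run cause the next one to skip; the period and the wiring are assumptions.

import asyncio


async def run_healthcheck_forever(job_params, on_result, on_error, settings,
                                  period: float = 10.0):
    lock = asyncio.Lock()  # dedicated to this one health check target
    while True:
        # fire and forget: if the previous run still holds the lock,
        # the new invocation logs a debug message and returns immediately
        asyncio.create_task(
            healthcheck_job(job_params, on_result, on_error, lock, settings))
        await asyncio.sleep(period)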
Example #9
    async def flush_cache(self, lock: asyncio.Lock = None) -> None:
        """Work around the fact that asyncio has no reentrant lock: either we
        grab the storing_lock ourselves, or the caller passes us the
        already-locked storing_lock.
        """
        has_lock_arg = lock is not None
        if not has_lock_arg:
            lock = self.storing_lock
            await lock.acquire()

        assert lock is self.storing_lock and lock.locked()

        try:
            for table_name, batches in self._batches.items():
                table = pa.Table.from_batches(batches)
                await self.write_table(table_name, table)
            self._batches.clear()

            for event in self.flush_events:
                event.set()
            self.flush_events.clear()
        finally:
            # release in a finally block so an error in write_table cannot
            # leave the lock held forever (only if we acquired it ourselves)
            if not has_lock_arg:
                lock.release()
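The two supported call shapes, as a hypothetical method on the same class; store() and the flush threshold are assumptions for illustration.

    async def store(self, table_name: str, batch) -> None:
        """Hypothetical caller that already holds storing_lock, so it must
        pass the locked lock in rather than let flush_cache re-acquire it."""
        async with self.storing_lock:
            self._batches.setdefault(table_name, []).append(batch)
            if len(self._batches[table_name]) >= 100:  # assumed threshold
                await self.flush_cache(lock=self.storing_lock)

    # elsewhere (e.g. on shutdown), with no lock held:
    #     await self.flush_cache()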
Example #10
    class ActivityDecorator(ActivityBase):
        def __init__(self, target):
            if not iscoroutinefunction(target):
                raise TypeError("I find your lack of async disturbing.")
            self.target = target  #: target coroutine for this activity.

            self.__trigger_obj = None

            self.trigger_obj = trigger

            self.mode = mode

            self.lock = Lock()

            # Create empty parameter dictionary
            self.parameters = inspect.signature(target).parameters
            self.empty_param_dict = {}
            for param in self.parameters:
                if param != "self":
                    self.empty_param_dict[param] = None

            self._args = args
            self._kwargs = kwargs

        @property
        def trigger_obj(self):
            """
            Trigger connected to this activity.
            :rtype: urban_journey.TriggerBase
            """
            return self.__trigger_obj

        @trigger_obj.setter
        def trigger_obj(self, trigger):
            if trigger is not None:
                if self.__trigger_obj is not None:
                    self.__trigger_obj.remove_activity(self)
                self.__trigger_obj = trigger
                self.__trigger_obj.add_activity(self)

        async def trigger(self, senders, sender_parameters, instance, *args, **kwargs):
            """
            Called by the trigger.

            :param senders: Dictionary with string typed key containing
            """
            try:
                if self.lock.locked():
                    if self.mode is ActivityMode.drop:
                        return
                async with self.lock:  # "with (await lock)" was removed in Python 3.10
                    # TODO: Remove support for "instance is None". This is currently only meant to be used in unittests.

                    # Create new parameters dictionary and fill it in with the data coming in from the triggers.
                    params = copy(self.empty_param_dict)
                    for param in params:
                        if param in sender_parameters:
                            params[param] = sender_parameters[param]

                    if instance is None:
                        await self.target(*args, *self._args, **kwargs, **self._kwargs, **params)
                    else:
                        await self.target(instance, *args, *self._args, **kwargs, **self._kwargs, **params)

            except Exception as e:
                # On exception, let the exception handler of the root node deal with it.
                if instance is None:
                    print_exception(*sys.exc_info())
                    raise e
                else:
                    instance.root.handle_exception(sys.exc_info())

        def __call__(self, *args, **kwargs):
            return self.target(*args, **kwargs)
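Example #11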
class TaskExecuter:
    """Execute queued tasks and cache results in Redis."""

    @classmethod
    async def async_init(cls,
                         # child class of ABC Loader
                         loader: Loader,
                         # queue where tasks will be fetched from.
                         task_queue: TaskQueue,
                         # key to the Redis list where results of completed tasks will be stored (as pickled Python objects).
                         task_results_key: Optional[str] = 'crawler_cluster_results',
                         # logging.Logger instance
                         logger: Logger = None) -> 'TaskExecuter':
        """Return an initialized TaskExecuter instance and start the task-processing loop."""
        self = cls()
        self.logger = logger or get_logger()
        self.loader = loader
        self.task_queue = task_queue
        # use same Redis connection as task queue.
        self.redis = task_queue.redis
        # check for terminal errors.
        if not self.ok:
            await self.shutdown()
        self.task_results_key = task_results_key
        # keep track of task in progress so it can be re-cached in case of failure.
        self._task_in_progres = None
        self._task_in_progres_lock = Lock()
        # execute all queued tasks.
        asyncio.create_task(self._execute_queued_tasks())
        self.logger.info(f"Running: {self}")
        return self

    @property
    def ok(self) -> bool:
        """Return False if a terminal error has been encountered."""
        return self.task_queue.ok and self.loader.ok

    @property
    def task_in_progres(self) -> Dict[str, Any]:
        """Return metadata of currently running task."""
        return self._task_in_progres

    async def cache_task_results(self, task_data: Dict[str, Any]) -> None:
        """Pickle serialize function results and add to results list."""
        # lock to eliminate the possibility of the current task being requeued after
        # its result is set; this can happen if shutdown() and cache_task_results()
        # are called at approximately the same time.
        self.logger.info(f"Caching task result: {task_data}")
        async with self._task_in_progres_lock:
            try:
                await self.redis.lpush(self.task_results_key, pickle.dumps(task_data))
            except PicklingError as e:
                self.logger.error(
                    f"Failed to pickle function result! Task: {task_data}. Error: {e}")
            # mark that function is no longer in progress.
            self._task_in_progres = None

    async def shutdown(self, sig=None) -> None:
        """Close connections and save any incomplete task."""
        try:
            if sig is not None:
                self.logger.info(f"Caught signal: {sig.name}")
            self.logger.info(f"Shutting down {self}")
            # check if there is an unfinished task that should be requeued.
            await self._requeue_current_task()
        finally:
            # exit so systemd can restart the service.
            sys.exit(0)

    async def _execute_queued_tasks(self) -> None:
        """Execute the next task in queue."""
        # get task from queue.
        task_data = await self.task_queue.get()
        # execute task.
        await self._execute_task(task_data)
        # repeat.
        asyncio.create_task(self._execute_queued_tasks())

    async def _execute_task(self, task_data: Dict[str, Any]):
        """Execute task and add result to Redis results list."""
        # keep reference to task in progress so it can be re-queued in case shutdown() is called.
        self._task_in_progres = task_data
        # execute this task's function.
        result_data = await self._call_function(task_data)
        # check if worker has encountered terminal error during execution of function.
        if not self.ok:
            self.logger.error(
                f"Encountered terminal error while executing task: {task_data}. Shutting down.")
            # shut down so this worker node can be restarted (by systemd, etc.)
            await self.shutdown()
        if isinstance(result_data, dict) and 'is_last' in result_data:
            # this is part of a multipart result.
            await self._process_multipart_result_part(task_data, result_data)
        else:
            # this is not a multipart result, so cache this data as the final result.
            task_data['result'] = result_data
            return await self.cache_task_results(task_data)

    async def _process_multipart_result_part(self, task_data: Dict[str, Any], result_data: Dict[str, Any]):
        """Tasks may be long-running tasks that create 'child' tasks as they're executing. 
        These tasks should contain the field 'is_last' in their result_data dict, indicting if the result is from the last child task.
        """
        multipart_list_key = f"tmp_multipart_results::{task_data['task_id']}"
        if 'result' in result_data:
            await self.redis.lpush(multipart_list_key, pickle.dumps(result_data.pop('result')))
        if result_data['is_last']:
            # this is the final part of a multipart result.
            del result_data['is_last']
            result_data['multipart_list'] = multipart_list_key
            result_data['task_id'] = task_data['task_id']
            return await self.cache_task_results(result_data)

    async def _call_function(self, task_data: Dict[str, Any]) -> Any:
        """Call a function, importing it if needed.
        task_data is required to have fields two fields: 'module' and 'function', 
        which specify the module name and function name of the function that should be called.
        """
        if 'module' not in task_data or 'function' not in task_data:
            raise ValueError(
                "task_data must contain keys 'module' and 'function', which correspond to the module name and function name of the function that should be called.")
        # get function from a dynamically imported module.
        function = get_function(task_data['module'], task_data['function'])
        if function is not None:
            kwargs = self._get_function_kwargs(function, task_data)
            try:
                # execute function.
                start = time()
                result = await asyncio.wait_for(function(**kwargs), task_data.get('task_timeout', 75))
                finish = time()
                self.logger.info(
                    f"Executed task ({round(finish-start,1)} s) {task_data}")
                return result
            except Exception as e:
                self.logger.error(
                    f"Error executing user-provided function ({type(e)}): {e}. Task: {task_data}")
                await self.loader.handle_exception(e)

    def _get_function_kwargs(self, function: Any, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create dict of keyword arguments for function call."""
        # remove any fields that were added by the queue client.
        kwargs = {}
        if task_data:
            kwargs.update({k: v for k, v in task_data.items() if k not in (
                'module', 'function', 'task_id')})
        try:
            func_params = inspect.signature(function).parameters
        except Exception as e:
            self.logger.error(
                f"Could not get signature of function {function}! Error: {e}")
            return kwargs
        # check if the function requires reference to the Loader instance.
        if 'loader' in func_params:
            kwargs['loader'] = self.loader
        # check if the function requires reference to the TaskQueue instance.
        if 'task_queue' in func_params:
            kwargs['task_queue'] = self.task_queue
        # check if the function requires reference to the TaskExecuter instance.
        if 'task_executer' in func_params:
            kwargs['task_executer'] = self
        # check if function requires reference to the task ID.
        if 'task_id' in func_params:
            kwargs['task_id'] = task_data['task_id']
        return kwargs

    async def _requeue_current_task(self) -> None:
        """Requeue any currently in progress and incomplete task."""
        if self._task_in_progres is not None and not self._task_in_progres_lock.locked():
            self.logger.info(
                f"Requeuing task: {self._task_in_progres}")
            await self.task_queue._priority_put(self._task_in_progres)

    def __str__(self):
        return f"TaskExecuter (Results list key: {self.task_results_key}, {self.loader}, {self.task_queue})"
Example #12
class Supervisor:
    def __init__(
        self,
        pool: Executor,
        nvim: Nvim,
        vars_dir: Path,
        match: MatchOptions,
        comp: CompleteOptions,
        limits: Limits,
        reviewer: PReviewer,
    ) -> None:
        self.pool = pool
        self.vars_dir = vars_dir
        self.match, self.comp, self.limits = match, comp, limits
        self.nvim, self._reviewer = nvim, reviewer

        self.idling = Condition()
        self._workers: MutableMapping[Worker, BaseClient] = WeakKeyDictionary()

        self._lock = Lock()
        self._task: Optional[Task] = None
        self._tasks: Sequence[Task] = ()

    @property
    def clients(self) -> AbstractSet[BaseClient]:
        return {*self._workers.values()}

    def register(self, worker: Worker, assoc: BaseClient) -> None:
        self._reviewer.register(assoc)
        self._workers[worker] = assoc

    def notify_idle(self) -> None:
        async def cont() -> None:
            async with self.idling:
                self.idling.notify_all()

        go(self.nvim, aw=cont())

    async def interrupt(self) -> None:
        g = gather(*chain(((self._task, ) if self._task else ()), self._tasks))
        self._task, self._tasks = None, ()
        await cancel(g)

    def collect(self, context: Context) -> Awaitable[Sequence[Metric]]:
        loop: AbstractEventLoop = self.nvim.loop
        t1, done = monotonic(), False
        timeout = (self.limits.completion_manual_timeout
                   if context.manual else self.limits.completion_auto_timeout)

        acc: MutableSequence[Metric] = []

        async def supervise(worker: Worker, assoc: BaseClient) -> None:
            instance, items = uuid4(), 0

            with with_suppress(), timeit(f"WORKER -- {assoc.short_name}"):
                await self._reviewer.s_begin(assoc, instance=instance)
                try:
                    async for completion in worker.work(context):
                        if not done and completion:
                            metric = self._reviewer.trans(
                                instance, completion=completion)
                            acc.append(metric)
                            items += 1
                        else:
                            await sleep(0)
                finally:
                    elapsed = monotonic() - t1
                    await self._reviewer.s_end(
                        instance,
                        interrupted=done,
                        elapsed=elapsed,
                        items=items,
                    )

        async def cont() -> Sequence[Metric]:
            nonlocal done

            with with_suppress(), timeit("COLLECTED -- ALL"):
                if self._lock.locked():
                    log.warn("%s", "SHOULD NOT BE LOCKED <><> supervisor")
                async with self._lock:
                    await self._reviewer.begin(context)
                    self._tasks = tasks = tuple(
                        loop.create_task(supervise(worker, assoc=assoc))
                        for worker, assoc in self._workers.items())
                    try:
                        if not tasks:
                            return ()
                        else:
                            _, pending = await wait(tasks, timeout=timeout)
                            if not acc:
                                for fut in as_completed(pending):
                                    await fut
                                    if acc:
                                        break
                            return acc
                    finally:
                        done = True

        self._task = loop.create_task(cont())
        return self._task
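The lock inside cont() is never expected to be contended; callers are supposed to interrupt the previous collection round first. A hedged sketch of that contract (the event-handler name is an assumption):

async def on_new_context(supervisor: Supervisor, context: Context) -> None:
    await supervisor.interrupt()                  # cancel the previous round
    metrics = await supervisor.collect(context)   # the lock warning stays silent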
Example #13
class BasicEngine:
    # basic transaction engine
    def __init__(self, filename: str):
        if not isfile(filename):
            with open(filename, 'wb') as file:
                # indicator
                file.write(OP)
                # root
                file.write(pack('Q', 9))
                self.root = IndexNode(is_leaf=True)
                self.root.dump(file)
        else:
            with open(filename, 'rb+') as file:
                if file.read(1) == OP:
                    file.close()
                    p = Process(target=repair, args=(filename, ))
                    p.start()
                    p.join()
                    return self.__init__(filename)
                else:
                    ptr = unpack('Q', file.read(8))[0]
                    file.seek(ptr)
                    self.root = IndexNode(file=file)
                    file.seek(0)
                    file.write(OP)

        self.allocator = Allocator()
        self.async_file = AsyncFile(filename)
        self.command_que = SortedList()
        self.file = open(filename, 'rb+', buffering=0)
        self.lock = Lock()
        self.on_interval = (0, 1)
        self.on_write = False
        self.task_que = TaskQue()

    def malloc(self, size: int) -> int:
        def is_inside(ptr: int) -> bool:
            if self.on_write:
                begin, end = self.on_interval
                return min(ptr + size, end) - max(ptr, begin) >= 0
            return False  # no write in flight; the annotation promises a bool

        ptr = self.allocator.malloc(size)
        if ptr and is_inside(ptr):
            self.free(ptr, size)
            ptr = 0
        if not ptr:
            ptr = self.async_file.size
            if is_inside(ptr):
                ptr += 1
                self.async_file.size += 1
            self.async_file.size += size
        return ptr

    def free(self, ptr: int, size: int):
        self.allocator.free(ptr, size)

    def time_travel(self, token: Task, node: IndexNode):
        address = node.nth_value_ads(0)
        for i in range(len(node.ptrs_value)):
            ptr = self.task_que.get(token, address, node.ptr)
            if ptr:
                node.ptrs_value[i] = ptr
            address += 8
        if not node.is_leaf:
            for i in range(len(node.ptrs_child)):
                ptr = self.task_que.get(token, address, node.ptr)
                if ptr:
                    node.ptrs_child[i] = ptr
                address += 8

    def a_command_done(self, token: Task):
        token.command_num -= 1
        if token.command_num == 0:
            self.task_que.clean()
            if not self.task_que.que and self.lock.locked():
                self.lock.release()

    # cumulation
    def do_cum(self, token: Task, free_nodes, command_map):
        def func():
            for node in free_nodes:
                self.free(node.ptr, node.size)

        token.free_param = func
        for ptr, param in command_map.items():
            data, depend = param if isinstance(param, tuple) else (param, 0)
            self.ensure_write(token, ptr, data, depend)
        self.time_travel(token, self.root)
        self.root = self.root.clone()

    def ensure_write(self, token: Task, ptr: int, data: bytes, depend=0):
        async def coro():
            while self.command_que:
                ptr, token, data, depend = self.command_que.pop(0)
                cancel = depend and self.task_que.is_canceled(token, depend)
                if not cancel:
                    cancel = self.task_que.is_canceled(token, ptr)
                if not cancel:
                    # ensure the boundaries are not adjacent
                    self.on_interval = (ptr - 1, ptr + len(data) + 1)
                    await self.async_file.write(ptr, data)
                self.a_command_done(token)
            self.on_write = False

        if not self.on_write:
            self.on_write = True
            ensure_future(coro())
        # sorted by ptr and token.id
        self.command_que.append((ptr, token, data, depend))
        token.command_num += 1

    def close(self):
        self.file.seek(0)
        self.file.write(ED)
        self.file.close()
        self.async_file.close()
Example #14
class BasicEngine:
    # basic transaction engine
    def __init__(self, filename: str):
        if not isfile(filename):
            with open(filename, 'wb') as file:
                # indicator
                file.write(OP)
                # root
                file.write(pack('Q', 9))
                self.root = IndexNode(is_leaf=True)
                self.root.dump(file)
        else:
            with open(filename, 'rb+') as file:
                if file.read(1) == OP:
                    file.close()
                    return BasicEngine.repair(filename)
                else:
                    ptr = unpack('Q', file.read(8))[0]
                    file.seek(ptr)
                    self.root = IndexNode(file=file)
                    file.seek(0)
                    file.write(OP)

        self.allocator = Allocator()
        self.async_file = AsyncFile(filename)
        self.command_que = SortedList()
        self.file = open(filename, 'rb+', buffering=0)
        self.lock = Lock()
        self.on_interval = (0, 1)
        self.on_write = False
        self.task_que = TaskQue()

    def malloc(self, size: int) -> int:
        def is_inside(ptr: int) -> bool:
            begin, end = self.on_interval
            return begin <= ptr <= end or begin <= ptr + size <= end

        ptr = self.allocator.malloc(size)
        if ptr and is_inside(ptr):
            self.free(ptr, size)
            ptr = 0
        if not ptr:
            ptr = self.async_file.size
            self.async_file.size += size
        return ptr

    def free(self, ptr: int, size: int):
        self.allocator.free(ptr, size)

    def time_travel(self, token: Task, node: IndexNode):
        address = node.nth_value_ads(0)
        for i in range(len(node.ptrs_value)):
            ptr = self.task_que.get(token, address, node.ptr)
            if ptr:
                node.ptrs_value[i] = ptr
            address += 8
        if not node.is_leaf:
            for i in range(len(node.ptrs_child)):
                ptr = self.task_que.get(token, address, node.ptr)
                if ptr:
                    node.ptrs_child[i] = ptr
                address += 8

    def a_command_done(self, token: Task):
        token.command_num -= 1
        if token.command_num == 0:
            self.task_que.clean()
            if not self.task_que.que and self.lock.locked():
                self.lock.release()

    # cum = cumulation
    def do_cum(self, token: Task, free_nodes, command_map):
        for node in free_nodes:
            self.free(node.ptr, node.size)
        for ptr, param in command_map.items():
            data, depend = param if isinstance(param, tuple) else (param, 0)
            self.ensure_write(token, ptr, data, depend)
        self.time_travel(token, self.root)
        self.root = self.root.clone()

    def ensure_write(self, token: Task, ptr: int, data: bytes, depend=0):
        async def coro():
            while self.command_que:
                ptr, token, data, depend = self.command_que.pop(0)
                cancel = depend and self.task_que.is_canceled(token, depend)
                if not cancel:
                    cancel = self.task_que.is_canceled(token, ptr)
                if not cancel:
                    # ensure the boundaries are not adjacent
                    self.on_interval = (ptr - 1, ptr + len(data) + 1)
                    await self.async_file.write(ptr, data)
                self.a_command_done(token)
            self.on_write = False

        if not self.on_write:
            self.on_write = True
            ensure_future(coro())
        # sorted by ptr and token.id
        self.command_que.append((ptr, token, data, depend))
        token.command_num += 1

    def close(self):
        self.file.seek(0)
        self.file.write(ED)
        self.file.close()
        self.async_file.close()

    @staticmethod
    def repair(filename: str):
        temp = '__' + filename
        size = getsize(filename)
        with open(filename, 'rb') as file, open('$' + temp, 'wb') as items:
            file.seek(9)
            while file.tell() != size:
                indicator = file.read(1)
                if indicator != ED:
                    continue
                with suppress(EOFError, UnpicklingError):
                    item = load(file)
                    if isinstance(item, tuple) and len(item) == 2:
                        dump(item, items)
        rename('$' + temp, temp)
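Both engine versions use the same idiom in a_command_done: the lock acts as a "writes in flight" gate that is released only once the task queue drains, so anyone awaiting the lock blocks until the engine is idle. A standalone sketch of that idiom, with illustrative names:

import asyncio


class FlushGate:
    def __init__(self):
        self.lock = asyncio.Lock()
        self.pending = 0

    async def submit(self):
        if not self.lock.locked():
            await self.lock.acquire()  # first command in flight closes the gate
        self.pending += 1

    def done(self):
        self.pending -= 1
        if self.pending == 0 and self.lock.locked():
            self.lock.release()        # everything drained: open the gate

    async def wait_idle(self):
        async with self.lock:          # completes only when nothing is in flight
            pass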
Example #15
class _Keycloak:
    """Class that handles the communication with keycloak"""

    last_token_reception_time: float
    """UTC timestamp"""

    last_token: str
    """Token"""

    async_lock: Lock
    """Asynchronous lock"""

    session: ClientSession
    """Aiohttp session"""

    async def setup(self):
        """
        Setup keycloak, call this method only inside the startup event
        """

        # timeout to obtain a token
        timeout = ClientTimeout(total=10)
        # connection options
        connector = TCPConnector(ttl_dns_cache=300, limit=1, ssl=False)

        # Instantiate a session
        self.session = ClientSession(
            raise_for_status=True,
            json_serialize=lambda x: orjson.dumps(x).decode(),
            timeout=timeout,
            connector=connector,
        )
        self.async_lock = Lock()
        self.last_token = await self._get_token()
        self.last_token_reception_time = time.time()

    async def close(self):
        """
        Close gracefully Keycloak session
        """
        await self.session.close()

    async def _get_token(self):
        """
        Private method used to make the http request to keycloak
        in order to obtain a valid token
        """
        # Get Logger
        logger = get_logger()
        # Get settings
        settings = get_keycloak_settings()

        try:
            async with self.session.post(
                    url=settings.token_request_url,
                    data={
                        "client_id": settings.client_id,
                        "username": settings.username_keycloak,
                        "password": settings.password,
                        "grant_type": settings.grant_type,
                        "client_secret": settings.client_secret,
                    },
            ) as resp:
                return Token.parse_obj(
                    await resp.json(encoding="utf-8",
                                    loads=orjson.loads,
                                    content_type=None)).access_token

        except ClientResponseError as exc:
            await logger.error({
                "method": exc.request_info.method,
                "url": exc.request_info.url,
                "client_id": settings.client_id,
                "username": settings.username_keycloak,
                "password": settings.password,
                "grant_type": settings.grant_type,
                "client_secret": settings.client_secret,
                "status_code": exc.status,
                "error": exc.message,
            })
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Wrong keycloak credentials",
            )
        except TimeoutError:
            # Keycloak is in starvation
            await logger.warning({
                "url": settings.token_request_url,
                "error": "Keycloak is in starvation",
            })
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Keycloak service not available",
            )

    async def get_ublox_token(self):
        """
        Obtain a token from keycloak. If 150 seconds are passed since the last token was obtained,
        it will obtain a fresh one and adjust it's timestamp
        """

        # Obtain actual timestamp
        now = time.time()

        # Check if it's been at least 150 seconds since the last token was obtained
        if now - self.last_token_reception_time >= 150:

            # Check if the lock is already acquired by a coroutine
            if self.async_lock.locked():

                # Await until the lock is released by the other coroutine
                # and after that release it
                await self.async_lock.acquire()
                self.async_lock.release()

                # Here we are sure that the token was updated
                return self.last_token

            # Only one coroutine has to update the token
            async with self.async_lock:
                # Update token and timestamp
                self.last_token = await self._get_token()
                self.last_token_reception_time = time.time()

        # return the stored token
        return self.last_token
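A minimal, standalone sketch of the refresh-once pattern get_ublox_token uses: late callers wait on the lock instead of refreshing again, then reuse the fresh token. All names and the fake fetch are illustrative.

import asyncio
import time


class TokenCache:
    def __init__(self, ttl=150):
        self.ttl = ttl
        self.lock = asyncio.Lock()
        self.token = None
        self.fetched_at = 0.0

    async def _fetch(self):
        await asyncio.sleep(0.1)       # stand-in for the HTTP round trip
        return f"token-{time.time():.0f}"

    async def get(self):
        if time.time() - self.fetched_at >= self.ttl:
            if self.lock.locked():
                # another coroutine is refreshing; wait for it, then reuse
                async with self.lock:
                    return self.token
            async with self.lock:
                self.token = await self._fetch()
                self.fetched_at = time.time()
        return self.token


async def main():
    cache = TokenCache()
    print(await asyncio.gather(*(cache.get() for _ in range(5))))

asyncio.run(main())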
Example #16
class RatBoard(abc.Mapping):
    """
    The Rat Board
    """

    __slots__ = [
        "_storage_by_uuid",
        "_storage_by_client",
        "_handler",
        "_storage_by_index",
        "_index_counter",
        "_offline",
        "_modification_lock",
        "__weakref__",
    ]

    def __init__(self,
                 api_handler: typing.Optional[FuelratsApiABC] = None,
                 offline: bool = True):
        self._handler: typing.Optional[FuelratsApiABC] = api_handler
        """
        fuelrats.com API handler
        """
        self._storage_by_uuid: typing.Dict[UUID, Rescue] = {}
        """
        internal rescue storage keyed by uuid
        """
        self._storage_by_client: typing.Dict[str, Rescue] = {}
        """
        internal rescue storage keyed by client
        """
        self._storage_by_index: typing.Dict[int, Rescue] = {}
        """
        internal rescue storage keyed by board index
        """
        self._index_counter = itertools.count()
        """
        Internal counter for tracking used indexes
        """
        self._offline = offline

        self._modification_lock = Lock()
        """
        Modification lock to prevent concurrent modification of the board.
        """

        super(RatBoard, self).__init__()

    @property
    def api_handler(self):
        """ Api handler reference """
        return self._handler

    @api_handler.setter
    def api_handler(self, value: FuelratsApiABC):
        if not isinstance(value, FuelratsApiABC):
            raise TypeError(type(value))
        self._handler = value

    @api_handler.deleter
    def api_handler(self):
        self._handler = None

    async def on_online(self):
        logger.info("Rescue board online.")
        self._offline = False
        # TODO get API version from remote and log it
        # TODO emit canned offline events to API

    async def on_offline(self):
        logger.warning("Rescue board now offline.")
        self._offline = True

    def __getitem__(self, key: _KEY_TYPE) -> Rescue:
        if isinstance(key, str):
            return self._storage_by_client[key.casefold()]

        if isinstance(key, UUID):
            return self._storage_by_uuid[key]

        if isinstance(key, int):
            return self._storage_by_index[key]

        # not one of our key types,
        return super(RatBoard, self).__getitem__(key)

    def __len__(self) -> int:
        return len(self._storage_by_uuid)

    def __iter__(self) -> typing.Iterator[UUID]:
        return iter(self._storage_by_uuid)

    @property
    def _free_case_number(self) -> int:
        """
        returns the next unused index in the series.

        Returns:
            int: next free board index
        """
        # this line is so magical it gets its own method.
        # basically, this returns the next index from self._index_counter'
        # that is not already in use (contains & filterfalse)
        return next(
            itertools.filterfalse(self.__contains__, self._index_counter))

    @property
    def free_case_number(self) -> int:
        """
        Returns a unused case number

        Returns:
            int: unused case number

        Notes:
            This method will attempt to return values smaller than the defined
            :obj:`CYCLE_AT` whenever possible, though will return values
            in excess of :obj:`CYCLE_AT` as necessary.
        """
        next_free = self._free_case_number

        # if we are larger or equal to the CYCLE_AT point, reset the counter
        overflow = next_free >= cycle_at
        if overflow:
            self._index_counter = itertools.count()
            # get the next free index from the underlying magic
            return self._free_case_number

        # return the next index from the magic
        return next_free

    async def append(self, rescue: Rescue, overwrite: bool = False) -> None:
        """
        Append a rescue to ourselves

        If the rescue doesn't have a board index, it will be assigned one.

        Args:
            rescue (Rescue): object to append
            overwrite(bool): overwrite existing cases
        """
        logger.trace("acquiring modification lock...")
        async with self._modification_lock:
            # ensure the rescue has a board index, because if this is null it breaks all the things.
            if rescue.board_index is None:
                rescue.board_index = self.free_case_number
            logger.trace("acquired modification lock.")
            if (rescue.api_id in self
                    or rescue.board_index in self) and not overwrite:
                raise ValueError(
                    "Attempted to append a rescue that already exists to the board"
                )
            self._storage_by_uuid[rescue.api_id] = rescue
            self._storage_by_index[rescue.board_index] = rescue

            if rescue.irc_nickname:
                self._storage_by_client[
                    rescue.irc_nickname.casefold()] = rescue
        logger.trace("released modification lock.")

    @property
    def online(self):
        """ is this module in online mode """
        return not self._offline and self._handler is not None

    def __contains__(self, key: _KEY_TYPE) -> bool:
        if isinstance(key, str):
            return key.casefold() in self._storage_by_client
        if isinstance(key, UUID):
            return key in self._storage_by_uuid
        return key in self._storage_by_index

    def __delitem__(self, key: _KEY_TYPE):
        # Sanity check.
        if not self._modification_lock.locked():
            raise RuntimeError(
                "attempted to delete a rescue without acquiring the lock first!"
            )
        # Get the target.
        target = self[key]

        # Purge it key by key.
        del self._storage_by_uuid[target.api_id]
        del self._storage_by_index[target.board_index]
        if target.irc_nickname and target.irc_nickname.casefold(
        ) in self._storage_by_client:
            del self._storage_by_client[target.irc_nickname.casefold()]

    @asynccontextmanager
    async def modify_rescue(
            self,
            key: BoardKey,
            impersonation: typing.Optional[Impersonation] = None) -> Rescue:
        """
        Context manager to modify a Rescue

        Args:
            impersonation: User account this modification was issued by
            key ():

        Yields:
            Rescue: rescue to modify based on its `key`
        """
        logger.trace("acquiring modification lock...")
        async with self._modification_lock:
            logger.trace("acquired modification lock.")
            if isinstance(key, Rescue):
                key = key.board_index

            target = self[key]

            # most tracked attributes may be modified in here, so we pop the rescue
            # from tracking and append it after

            del self[key]

            self._modification_lock.release()
            try:
                # Yield so the caller can modify the rescue
                yield target

            finally:
                # we need to be sure to re-append the rescue upon completion
                # (so errors don't drop cases)
                await self.append(target)
                # append will reacquire the lock, so don't reacquire it ourselves (no rlocks in asyncio),
                # but the context manager will complain if we don't re-acquire it before exiting.
                await self._modification_lock.acquire()
            # If we are in online mode, emit update event to API.
            if self.online:
                logger.trace("updating API...")
                await self._handler.update_rescue(target,
                                                  impersonating=impersonation)

        logger.trace("released modification lock.")

    async def create_rescue(self, *args, overwrite=False, **kwargs) -> Rescue:
        """
        Creates a rescue. In online mode this will perform creation actions against the API.
        In the event of an API error, the rescue will still be created locally, though an
        exception is raised.

        Args:
            *args (): args to pass to :class:`Rescue`'s constructor
            overwrite (bool): overwrite existing rescues
            **kwargs (): keyword arguments to pass to Rescue's constructor

        Returns:
            created rescue object

        Raises:
            ApiError: Something went wrong in API creation, rescue has been created locally.
        """
        index = self.free_case_number
        logger.trace("instantiating local rescue object...")
        rescue = Rescue(*args, board_index=index, **kwargs)

        try:
            if not self.online:
                logger.warning("creating case in offline mode...")
            else:
                logger.trace("creating rescue on API...")
                rescue = await self._handler.create_rescue(rescue,
                                                           impersonating=None)

        except ApiException:
            logger.exception("unable to create rescue on API!")
            # Emit upstream so the caller knows something went wrong
            raise

        finally:
            rescue.board_index = index
            # Always append it to ourselves, regardless of API errors
            await self.append(rescue, overwrite=overwrite)

        return rescue

    async def remove_rescue(self, target: BoardKey):
        """ removes a rescue from active tracking """
        if isinstance(target, Rescue):
            target = target.board_index
        logger.trace("Acquiring modification lock...")
        async with self._modification_lock:
            logger.trace("Acquired modification lock.")
            # TODO: add to internal deck in offline mode so we can push to the API when we eventually
            del self[target]
        logger.trace("Released modification lock.")
Example #17
class Actor(object):

    """
    Main actor model.

    Args:
        inbox (GeneratorQueue): Inbox to consume from.
        outbox (GeneratorQueue): Outbox to publish to.
        loop (AbstractEventLoop): Event loop.
    """

    running = False
    _force_stop = False

    def __init__(self, inbox, outbox, loop=None):
        self.inbox = inbox
        self.outbox = outbox

        if not loop:
            loop = get_event_loop()

        self._loop = loop
        # the explicit loop argument was removed in Python 3.10
        self._pause_lock = Lock()
        self._stop_event = Event()
        self._test = None
        self.__testy = None

        self.on_init()

    @property
    def paused(self):
        """Indicate if actor is paused."""
        return self._pause_lock.locked()

    async def start(self):
        """Main public entry point to start the actor."""
        await self.initialize()
        await self._start()

    async def initialize(self):
        """Initialize the actor before starting."""
        await self.on_start()

        if self._force_stop:
            return

        if not self.running:
            self.running = True

    async def _start(self):
        """Run the event loop and force the on_stop event."""
        try:
            await self._run()
        finally:
            await self.on_stop()

    async def resume(self):
        """Resume the actor."""
        await self.on_resume()
        self._pause_lock.release()

    async def pause(self):
        """Pause the actor."""
        await self._pause_lock.acquire()
        await self.on_pause()

    async def _block_if_paused(self):
        """Block on the pause lock."""
        if self.paused:
            await self._pause_lock.acquire()
            self._pause_lock.release()  # release() is not a coroutine

    async def _run(self):
        """Main event loop."""
        while self.running:
            await self._block_if_paused()

            await self._process()

    async def publish(self, data):
        """Push data to the outbox."""
        await self.outbox.put(data)

    async def stop(self):
        """Stop the actor."""
        self.inbox = None
        self.outbox = None
        self.running = False
        self._force_stop = True

        self._stop_event.set()

        try:
            self._pause_lock.release()
        except RuntimeError:
            pass

    async def _process(self):
        """Process incoming messages."""
        if not self.inbox:
            return

        pending = {self.inbox.get(), self._stop_event.wait()}
        result = await get_first_completed(pending, self._loop)

        if self.running:
            await self.on_message(result)

    async def on_message(self, data):
        """Called when the actor receives a message."""
        raise NotImplementedError

    def on_init(self):
        """Called after the actor class is instantiated."""
        pass

    async def on_start(self):
        """Called before the actor starts ingesting the inbox."""
        pass

    async def on_stop(self):
        """Called after actor dies."""
        pass

    async def on_pause(self):
        """Called before the actor is paused."""
        pass

    async def on_resume(self):
        """Called before the actor is resumed."""
        pass
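A minimal subclass wired to plain asyncio queues, assuming the module's get_first_completed helper is importable and that queues satisfy the inbox/outbox API. This is a sketch of the pause lock's observable behaviour, not the project's real harness.

import asyncio


class EchoActor(Actor):
    async def on_message(self, data):
        await self.publish(data)


async def demo():
    inbox, outbox = asyncio.Queue(), asyncio.Queue()
    actor = EchoActor(inbox, outbox)
    asyncio.ensure_future(actor.start())

    await inbox.put("hello")
    print(await outbox.get())   # "hello"

    await actor.pause()
    print(actor.paused)         # True: the pause lock is held
    await actor.resume()        # releases the lock; the loop continues

    await actor.stop()

asyncio.run(demo())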
Example #18
class Plan(HTMLParser):
    def __init__(self, db_path, auto_update_times, plan_user, plan_password):
        self.db_path = db_path
        self.auto_update_times = auto_update_times
        self.last_attr = []
        self.row = []
        self.currentPlan = "",
        self.planUser = plan_user
        self.planPassword = plan_password
        self.conn = None
        self.last_prepared_row = []
        self.db_mutex = Lock()
        super().__init__()

    async def get_database(self):
        await self.db_mutex.acquire()
        self.conn = sqlCipher.connect(self.db_path)
        self.conn.execute('pragma key="' + KEY + '"')
        return self.conn

    def close_database(self):
        if self.db_mutex.locked():
            self.conn.close()
            self.db_mutex.release()
            return True
        else:
            return False

    @staticmethod
    def localize_time(time):
        return pytz.timezone('Europe/Berlin').localize(time)

    def get_urls(self):
        # Iso format is for example 2019-10-29T19:20:31.875466
        current_time = datetime.datetime.now().isoformat()
        # Cut off last 3 digits and add 'Z' to get correct format
        current_time = current_time[:-3] + "Z"

        # Parameters required for the server to accept our data request
        params = {
            "UserId": self.planUser,
            "UserPw": self.planPassword,
            "AppVersion": "2.5.9",
            "Language": "de",
            "OsVersion": "28 8.0",
            "AppId": str(uuid.uuid4()),
            "Device": "SM-G935F",
            "BundleId": "de.heinekingmedia.dsbmobile",
            "Date": current_time,
            "LastUpdate": current_time
        }
        # Convert params into the right format
        params_bytestring = json.dumps(params,
                                       separators=(',', ':')).encode("UTF-8")
        params_compressed = base64.b64encode(
            gzip.compress(params_bytestring)).decode("UTF-8")

        # Send the request
        json_data = {"req": {"Data": params_compressed, "DataType": 1}}
        r = requests.post("https://app.dsbcontrol.de/JsonHandler.ashx/GetData",
                          json=json_data)

        try:
            if r.status_code == 200:
                # Decompress response
                data_compressed = json.loads(r.content)["d"]
            else:
                raise PlanError('could not fetch online plans')
        except KeyError:
            raise PlanError('could not fetch online plans')

        data = json.loads(gzip.decompress(base64.b64decode(data_compressed)))

        if data['Resultcode'] != 0:
            raise PlanError('could not fetch online plans')

        for menuItem in data['ResultMenuItems']:
            if menuItem['Title'] == 'Inhalte':
                urls = []

                for child in menuItem['Childs']:
                    if child['MethodName'] == 'timetable':

                        for innerChild in child['Root']['Childs']:
                            urls.append(innerChild['Childs'][0]['Detail'])

                return urls

        return []

    async def run_database_operation(self, run, *args):
        # Release the database mutex first so the nested operation can call
        # get_database() itself, then re-acquire it for the caller.
        self.close_database()
        try:
            return await run(*args)
        finally:
            await self.get_database()

    async def update(self):
        try:
            await self.get_database()

            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS dates (side TEXT UNIQUE, date TEXT)"
            )
            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS info (side TEXT, info TEXT)")

            urls = self.get_urls()

            for index, url in enumerate(urls):
                website = requests.get(url)
                self.currentPlan = 'Plan' + str(index)

                self.conn.execute("DELETE FROM dates WHERE side == ?",
                                  (self.currentPlan, ))
                self.conn.execute("DELETE FROM info WHERE side == ?",
                                  (self.currentPlan, ))

                self.conn.execute("DROP TABLE IF EXISTS " +
                                  quote_identifier(self.currentPlan))
                self.conn.execute("""CREATE TABLE """ +
                                  quote_identifier(self.currentPlan) + """ (
                            grade TEXT,
                            lessons TEXT,
                            kind TEXT, 
                            subjects TEXT, 
                            rooms TEXT, 
                            text TEXT
                    );""")

                self.feed(codecs.decode(website.content, website.encoding))

            # noinspection SqlResolve
            self.conn.execute(
                "REPLACE INTO dates (side, date) VALUES (\"intern\", ?)",
                (int(datetime.datetime.now().timestamp()), ))
        finally:
            self.conn.commit()
            self.close_database()

            self.last_attr = []
            self.row = []
            self.currentPlan = ""
            self.last_prepared_row = []

    def handle_starttag(self, tag, attrs):
        if tag != "span":
            self.last_attr = attrs

    def handle_endtag(self, tag):
        if tag == "tr" or tag == "div":
            if len(self.last_attr) > 0 and self.last_attr[0][
                    0] == "class" and len(self.row) > 0:
                if self.last_attr[0][1] == "info":
                    # region insert info on side into table
                    info = ""

                    for part in self.row:
                        info += part

                    self.conn.execute(
                        "INSERT INTO info (side, info) VALUES (?, ?);",
                        (self.currentPlan, info))
                    # endregion

                elif self.last_attr[0][1] == "mon_title":
                    # region insert date of plan into table
                    self.conn.execute(
                        "INSERT INTO dates (side, date) VALUES (?, ?)",
                        (self.currentPlan, self.row[0]))
                    # endregion

                elif self.last_attr[0][1] == "list" and len(self.row) == 7:
                    prepared_row = [
                        prepare_grade(self.row[0]), self.row[1],
                        prepare_kind(self.row[2]),
                        prepare_subject(self.row[3], self.row[4]),
                        prepare_room(self.row[5]), self.row[6]
                    ]

                    if prepared_row[:5] == [
                            '\xa0', '\xa0', '\xa0', '\xa0', '\xa0'
                    ]:
                        # region completes text at the previous row in the table
                        self.conn.execute(
                            """
                                UPDATE """ +
                            quote_identifier(self.currentPlan) + """
                                SET text = ?
                                WHERE 
                                grade == ? and lessons == ? and kind == ? and subjects == ? and rooms == ? and text == ?
                            """,
                            (self.last_prepared_row[5] + " " + prepared_row[5],
                             self.last_prepared_row[0],
                             self.last_prepared_row[1],
                             self.last_prepared_row[2],
                             self.last_prepared_row[3],
                             self.last_prepared_row[4],
                             self.last_prepared_row[5]))
                        # endregion
                    else:
                        # region insert new row into table
                        if prepared_row[0] == '\xa0':
                            prepared_row[0] = self.last_prepared_row[0]

                        if prepared_row[0] == "AG":
                            if prepared_row[3] == "Klettern":
                                prepared_row[3] = "Kletter"

                            prepared_row[0] = prepared_row[3] + " AG"
                            prepared_row[3] = '\xa0'

                        self.conn.execute(
                            """
                                  INSERT INTO """ +
                            quote_identifier(self.currentPlan) + """
                                  (grade, lessons, kind, subjects, rooms, text) VALUES (?, ?, ?, ?, ?, ?)
                            """, prepared_row)
                        # endregion

                    self.last_prepared_row = prepared_row

                self.row = []

    def handle_data(self, data):
        if len(self.last_attr) > 0 and self.last_attr[0][
                0] == "class" and self.lasttag != "th":
            if self.last_attr[0][1] == "list" or \
               self.last_attr[0][1] == "info" or \
               self.last_attr[0][1] == "mon_title":
                if data != '\r\n':
                    self.row.append(data)

    async def search(self, user_id, search):
        result = {}

        try:
            await self.get_database()

            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS searches (user_id TEXT UNIQUE, search TEXT)"
            )

            if user_id is not None:
                if search is None or len(search) < 1:
                    search = await self.run_database_operation(
                        self.get_last_user_search, user_id)
                else:
                    # region save the search into the table
                    self.conn.execute(
                        "REPLACE INTO searches (user_id, search) VALUES (?, ?)",
                        (user_id, search))
                    self.conn.commit()
                    # endregion

                search = search.split()
                search[0] = "%" + search[0] + "%"

                for i in range(1, len(search)):
                    search[i] = "% " + search[i] + " %"

            # region check for updates
            update_date = await self.run_database_operation(
                self.get_update_date)
            update_date = datetime.datetime.utcfromtimestamp(update_date)
            update_date = pytz.utc.localize(update_date)

            for auto_update_time in self.auto_update_times:
                auto_update_time = datetime.datetime.strptime(
                    str(datetime.datetime.now().day) + "." +
                    str(datetime.datetime.now().month) + "." +
                    str(datetime.datetime.now().year) + " " + auto_update_time,
                    "%d.%m.%Y %H:%M")

                auto_update_time = self.localize_time(auto_update_time)
                now = pytz.utc.localize(datetime.datetime.now())

                if update_date < auto_update_time < now:
                    await self.run_database_operation(self.update)
                    update_date = datetime.datetime.now()
                    update_date = pytz.utc.localize(update_date)

            # endregion

            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS dates (side TEXT UNIQUE, date TEXT)"
            )
            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS info (side TEXT, info TEXT)")

            for index in range(2):
                # region get plan date
                plan_date = self.conn.execute(
                    "SELECT date FROM dates WHERE side == ?",
                    ('Plan' + str(index), ))
                plan_date = plan_date.fetchone()

                if plan_date is None or len(plan_date) < 1:
                    raise PlanError("missing plan date for " +
                                    quote_identifier('Plan' + str(index)))

                plan_date = plan_date[0]

                if datetime.datetime(  # Today at 0:00
                        datetime.datetime.now().year,
                        datetime.datetime.now().month,
                        datetime.datetime.now().
                        day) > datetime.datetime.strptime(
                            plan_date[:plan_date.find(" ")],
                            "%d.%m.%Y"):  # Date of plan at 0:00
                    continue

                # endregion

                # region get info
                info = self.conn.execute(
                    "SELECT info FROM info WHERE side == ?",
                    ('Plan' + str(index), ))
                info = info.fetchall()
                # endregion

                # region get entries
                if user_id is not None:
                    sql_query = "SELECT * FROM " + quote_identifier(
                        'Plan' + str(index)) + " WHERE grade LIKE ? AND ("

                    if len(search) > 1:
                        sql_query += "subjects LIKE ? OR " * (len(search) - 1)
                        sql_query = sql_query[:-3] + ")"
                    else:
                        sql_query = sql_query[:-5]

                    entries = self.conn.execute(sql_query, search)
                    entries = entries.fetchall()
                else:
                    entries = []
                # endregion

                result[plan_date] = (info, entries)

        finally:
            self.close_database()

        return result

    async def get_update_date(self):
        try:
            await self.get_database()
            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS dates (side TEXT UNIQUE, date TEXT)"
            )
            # noinspection SqlResolve
            cursor = self.conn.execute(
                "SELECT date FROM dates WHERE side == \"intern\"")

            update_date = cursor.fetchone()
            if update_date is None or len(update_date) < 1:
                await self.run_database_operation(self.update)
                update_date = (datetime.datetime.now().timestamp(), )
        finally:
            self.close_database()

        return int(update_date[0])

    async def get_last_user_search(self, user_id):
        try:
            await self.get_database()

            self.conn.execute(
                "CREATE TABLE IF NOT EXISTS searches (user_id TEXT UNIQUE, search TEXT)"
            )
            search = self.conn.cursor()
            search.execute("SELECT search FROM searches WHERE user_id == ?",
                           (user_id, ))

            search = search.fetchone()
            if search is None or len(search) < 1:
                raise PlanError("no last search found")
        finally:
            self.close_database()

        return search[0]
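The database helpers above tie an asyncio.Lock to the connection lifetime: get_database() acquires the mutex and opens the connection, close_database() closes and releases only when the mutex is held, and run_database_operation() temporarily gives the mutex up so a nested coroutine that calls get_database() itself does not deadlock. Below is a stripped-down sketch of that pattern against plain sqlite3; the Database class and its names are illustrative, and the KEY pragma from the original is omitted.

import asyncio
import sqlite3


class Database:
    def __init__(self, path):
        self.path = path
        self.conn = None
        self.mutex = asyncio.Lock()

    async def open(self):
        await self.mutex.acquire()  # one coroutine at a time
        self.conn = sqlite3.connect(self.path)
        return self.conn

    def close(self):
        if self.mutex.locked():
            self.conn.close()
            self.mutex.release()
            return True
        return False

    async def run_nested(self, operation, *args):
        # Give the mutex up so the nested operation can call open() itself,
        # then re-acquire it for the caller.
        self.close()
        try:
            return await operation(*args)
        finally:
            await self.open()


async def count_tables(db):
    conn = await db.open()
    try:
        return conn.execute(
            'SELECT COUNT(*) FROM sqlite_master').fetchone()[0]
    finally:
        db.close()


async def main():
    db = Database(':memory:')
    await db.open()
    # Calling count_tables(db) directly here would deadlock, because it
    # tries to acquire the mutex this coroutine already holds.
    print(await db.run_nested(count_tables, db))
    db.close()


asyncio.run(main())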
Example #19
class NodeContext:
    def __init__(self, app_config, data_dir, name, config_user=None):
        self.app_config = app_config
        self.data_dir = Path(data_dir).resolve()
        self.name = name

        self.host_up = False
        self.genesis = None
        self.follows = None
        self.ts_start = 0
        self.ts_heartbeat = 0
        self.failure = None

        self.log = logging.getLogger(__name__ + '.' + name)
        self.maintenance_lock = Lock()

        self.cookie_exec = None
        self.cookie_data = None

        self.load_config(config_user)
        self.host = HostGCP(self.config, self.app_config.gcp_credentials_file)

    def __str__(self):
        return (
            'NodeContext(' + self.name + \
            (' up=' + ('1' if self.host_up else '0')) + \
            (' fail=' + self.failure.name if self.failure else '') + \
            (' ts_start=' + str(int(self.ts_start))) + \
            (' ts_heartbeat=' + str(int(self.ts_heartbeat))) + \
            (' follows=' + self.follows.name if self.follows else ' leader') + \
            (' genesis=' + self.genesis if self.genesis else '') + \
            ')'
        )

    __repr__ = __str__

    def gen_cookie_exec(self):
        self.cookie_exec = uuid.uuid4()

    def gen_cookie_data(self):
        self.cookie_data = uuid.uuid4()

    # properties {{{

    @property
    def config_skeleton(self):
        return {
            'rnode_conf': {
                'server': {
                    'port': 40400,
                    'port-kademlia': 40404
                },
                'grpc': {
                    'port-external': 40401
                }
            },
            'hostname_suffix': '.',
            'hostname_ttl': 300,
            'resources_name_prefix': '',
            'templates': [],
            'timeout_heartbeat': 300,
            'timeout_start_rnode': 300,
            'timeout_start_host': 300,
            'host_metadata': {},
            'compute_timeout': 600,
        }

    @property
    def files_dir(self) -> Path:
        return self.data_dir / 'files'

    @property
    def config_file_user(self) -> Path:
        return self.data_dir / 'config.user.json'

    @property
    def config_file_auxiliary(self) -> Path:
        return self.data_dir / 'config.aux.json'

    @property
    def config_file_full(self) -> Path:
        return self.data_dir / 'config.full.json'

    @property
    def rnode_conf_file(self) -> Path:
        return self.files_dir / 'rnode.conf'

    @property
    def rnode_tls_key_file(self) -> Path:
        return self.files_dir / 'node.key.pem'

    # }}}

    # configuration {{{

    def load_config_template(self, name):
        try:
            path = resolve_path(self.app_config.node_config_templates_dir,
                                name + '.json')
            return read_json(path)
        except FileNotFoundError:
            raise NodeContextError(f'Template "{name}" does not exist')

    def load_config_merged(self, config_user):
        config = {}

        merge_list = [
            config_user,
            self.app_config.node_config_global,
            self.config_skeleton,
        ]
        merged_templates = set()

        i = 0
        while i < len(merge_list):
            for tpl_name in merge_list[i].get('templates', []):
                if not tpl_name in merged_templates:
                    tpl = self.load_config_template(tpl_name)
                    merge_list.insert(i + 1, tpl)
                    merged_templates.add(tpl_name)
            i += 1

        for part in merge_list[::-1]:
            config = merger.merge(config, part)

        return config

    def load_config_full(self, config_user):
        config = self.load_config_merged(config_user)
        config_aux = try_read_json(self.config_file_auxiliary, {})

        config, config_aux = add_missing_value_aux(
            config,
            config_aux,
            '.rnode_conf.casper."validator-private-key"',
            lambda: lib_rchain_key.generate_key_hex(),
        )

        config, config_aux = add_missing_value_aux(
            config,
            config_aux,
            '.rnode_tls_key',
            lambda: lib_rnode_tls.generate_key_pem(),
        )

        config = add_missing_value(
            config,
            '.rnode_id',
            lambda: lib_rnode_tls.get_node_id(config['rnode_tls_key']),
        )

        config = add_missing_value(
            config,
            '.resources_name',
            lambda: config['resources_name_prefix'] + self.name,
        )

        config = add_missing_value(
            config,
            '.hostname',
            lambda: self.name + config['hostname_suffix'],
        )

        if not config['hostname'].endswith('.'):
            config['hostname'] += '.'

        config = add_missing_value(
            config,
            '.rnode_addr',
            lambda: 'rnode://{}@{}?protocol={}&discovery={}'.format(
                config['rnode_id'],
                config['hostname'],
                config['rnode_conf']['server']['port'],
                config['rnode_conf']['server']['port-kademlia'],
            ),
        )

        return config, config_aux

    def load_update_config(self, config_user=None):
        if config_user is None:
            config_user = try_read_json(self.config_file_user, {})
        config, config_aux = self.load_config_full(config_user)

        os.makedirs(self.data_dir, exist_ok=True)
        if config_user:
            write_json(self.config_file_user, config_user)
        write_json(self.config_file_auxiliary, config_aux)
        write_json(self.config_file_full, config)

        os.makedirs(self.files_dir, exist_ok=True)
        write_json(self.rnode_conf_file, config['rnode_conf'])
        self.rnode_tls_key_file.write_text(config['rnode_tls_key'])

        self.config = config
        self.gen_cookie_exec()

    def load_config(self, config_user=None):
        if config_user is None and self.config_file_full.exists():
            self.config = read_json(self.config_file_full)
        else:
            self.load_update_config(config_user)

    # }}}

    # lifecycle {{{

    async def _stop(self, clean):
        self.host_up = False
        self.failure = None
        self.log.info('Stopping')
        await self.host.stop(clean)
        self.log.info('Stopped')

    async def _start(self):
        self.log.info('Starting')
        await self.host.start()
        if not self.host_up:
            self.ts_start = time.time()
        self.log.info('Started')

    async def _try_start(self):
        try:
            if self.maintenance_lock.locked():
                return
            async with self.maintenance_lock:
                await self._start()
        except:
            self.log.exception('Start failed')
            raise

    async def _try_restart(self, clean):
        try:
            if self.maintenance_lock.locked():
                return
            async with self.maintenance_lock:
                skip_start = False
                try:
                    await self._stop(clean)
                except CancelledError:
                    skip_start = True
                    raise
                finally:
                    if not skip_start:
                        await self._start()
        except:
            self.log.exception('Restart failed')
            raise

    def try_start_async(self):
        self.log.info('Scheduling start')
        create_task(self._try_start())

    def try_restart_async(self, clean_data=False):
        self.log.info('Scheduling restart')
        clean = HostClean.DATA if clean_data else HostClean.STOP
        create_task(self._try_restart(clean))

    # }}}

    def heartbeat(self, msg):
        self.log.info('Received heartbeat message')

        if self.maintenance_lock.locked():
            self.log.info('Ignoring due to active maintenance')
            return {}

        now = time.time()
        if not self.host_up:
            self.log.info('Host is up')
            self.host_up = True
            self.ts_start = now
        self.ts_heartbeat = now

        if 'cookie_exec' in msg and not self.cookie_exec:
            self.cookie_exec = msg['cookie_exec']

        if 'cookie_data' in msg and not self.cookie_data:
            self.cookie_data = msg['cookie_data']

        if 'genesis' in msg and self.genesis != msg['genesis']:
            self.genesis = msg['genesis']

        reply = {
            'cookie_exec': self.cookie_exec,
            'cookie_data': self.cookie_data,
            'rnode_package_url': self.config['rnode_package_url']
        }

        if self.follows:
            reply['mode'] = 'follower'
            reply['leader'] = self.follows.config['rnode_addr']
        else:
            reply['mode'] = 'leader'

        self.log.info('Sending reply')
        if self.log.isEnabledFor(logging.INFO):
            for k in sorted(reply.keys()):
                self.log.info('  reply[%s] = %s', k, reply[k])

        return reply

    def _check_timeouts(self, ts):
        if (self.host_up
                and ts > self.ts_heartbeat + self.config['timeout_heartbeat']):
            return NodeFailure.TIMEOUT_HEARTBEAT
        if (self.host_up and not self.genesis
                and ts > self.ts_start + self.config['timeout_start_rnode']):
            return NodeFailure.TIMEOUT_START_RNODE
        if (not self.host_up
                and ts > self.ts_start + self.config['timeout_start_host']):
            return NodeFailure.TIMEOUT_START_HOST
        return None

    def check_failure(self, ts):
        if self.maintenance_lock.locked():
            return None
        if not self.failure:
            new_failure = self._check_timeouts(ts)
            if new_failure:
                self.log.info('Failure: %s', new_failure.name)
                self.failure = new_failure
        return self.failure
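_try_start() and _try_restart() above check maintenance_lock.locked() before entering the async with block, so an overlapping maintenance attempt is skipped rather than queued behind the running one. The check and the acquire happen without an intervening await, so no other task can slip in between them. A minimal sketch of the idiom follows, with the real stop/start work replaced by a sleep; all names are illustrative.

import asyncio


class Node:
    def __init__(self, name):
        self.name = name
        self.maintenance_lock = asyncio.Lock()

    async def _restart(self):
        print(self.name, 'restarting')
        await asyncio.sleep(0.2)  # stands in for the real stop/start work
        print(self.name, 'restarted')

    async def try_restart(self):
        # Skip instead of queueing when maintenance is already running.
        if self.maintenance_lock.locked():
            print(self.name, 'maintenance in progress, skipping')
            return
        async with self.maintenance_lock:
            await self._restart()


async def main():
    node = Node('node-0')
    # Two overlapping restarts: the second one is skipped.
    await asyncio.gather(node.try_restart(), node.try_restart())


asyncio.run(main())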
Example #20
class TaskQueue(Generic[TTask]):
    """
    TaskQueue keeps priority-order track of pending tasks, with a limit on number pending.

    A producer of tasks will insert pending tasks with await add(), which will not return until
    all tasks have been added to the queue.

    A task consumer calls await get() to retrieve tasks for processing. Tasks will be returned in
    priority order. If no tasks are pending, get()
    will pause until at least one is available. Only one consumer will have a task "checked out"
    from get() at a time.

    After tasks are successfully completed, the consumer will call complete() to remove them from
    the queue. The consumer doesn't need to complete all tasks, but any uncompleted tasks will be
    considered abandoned. Another consumer can pick them up at the next get() call.
    """

    # a function that determines the priority order (lower int is higher priority)
    _order_fn: FunctionProperty[Callable[[TTask], Any]]

    # batches of tasks that have been started but not completed
    _in_progress: Dict[int, Tuple[TTask, ...]]

    # all tasks that have been placed in the queue and have not been started
    _open_queue: 'PriorityQueue[Tuple[Any, TTask]]'

    # all tasks that have been placed in the queue and have not been completed
    _tasks: Set[TTask]

    def __init__(self,
                 maxsize: int = 0,
                 order_fn: Callable[[TTask], Any] = identity,
                 *,
                 loop: AbstractEventLoop = None) -> None:
        self._maxsize = maxsize
        self._full_lock = Lock(loop=loop)
        self._open_queue = PriorityQueue(maxsize, loop=loop)
        self._order_fn = order_fn
        self._id_generator = count()
        self._tasks = set()
        self._in_progress = {}

    async def add(self, tasks: Tuple[TTask, ...]) -> None:
        """
        add() will insert as many tasks as can be inserted until the queue fills up.
        Then it will pause until the queue is no longer full, and continue adding tasks.
        It will finally return when all tasks have been inserted.
        """
        if not isinstance(tasks, tuple):
            raise ValidationError(
                f"must pass a tuple of tasks to add(), but got {tasks!r}")

        already_pending = self._tasks.intersection(tasks)
        if already_pending:
            raise ValidationError(
                f"Duplicate tasks detected: {already_pending!r} are already present in the queue"
            )

        # make sure to insert the highest-priority items first, in case queue fills up
        remaining = tuple(
            sorted((self._order_fn(task), task) for task in tasks))

        while remaining:
            num_tasks = len(self._tasks)

            if self._maxsize <= 0:
                # no cap at all, immediately insert all tasks
                open_slots = len(remaining)
            elif num_tasks < self._maxsize:
                # there is room to add at least one more task
                open_slots = self._maxsize - num_tasks
            else:
                # wait until there is room in the queue
                await self._full_lock.acquire()

                # the current number of tasks has changed, can't reuse num_tasks
                num_tasks = len(self._tasks)
                open_slots = self._maxsize - num_tasks

            queueing, remaining = remaining[:open_slots], remaining[
                open_slots:]

            for task in queueing:
                # There will always be room in _open_queue until _maxsize is reached
                try:
                    self._open_queue.put_nowait(task)
                except QueueFull as exc:
                    task_idx = queueing.index(task)
                    qsize = self._open_queue.qsize()
                    raise QueueFull(
                        f'TaskQueue unsuccessful in adding task {task[1]!r} because qsize={qsize}, '
                        f'num_tasks={num_tasks}, maxsize={self._maxsize}, open_slots={open_slots}, '
                        f'num queueing={len(queueing)}, len(_tasks)={len(self._tasks)}, task_idx='
                        f'{task_idx}, queuing={queueing}, original msg: {exc}',
                    )

            unranked_queued = tuple(task for _rank, task in queueing)
            self._tasks.update(unranked_queued)

            if self._full_lock.locked() and len(self._tasks) < self._maxsize:
                self._full_lock.release()

    def get_nowait(self,
                   max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
        """
        Get pending tasks. If no tasks are pending, raise an exception.

        :param max_results: return up to this many pending tasks. If None, return all pending tasks.
        :return: (batch_id, tasks to attempt)
        :raise ~asyncio.QueueFull: if no tasks are available
        """
        if self._open_queue.empty():
            raise QueueFull("No tasks are available to get")
        else:
            pending_tasks = self._get_nowait(max_results)

            # Generate a pending batch of tasks, so uncompleted tasks can be inferred
            next_id = next(self._id_generator)
            self._in_progress[next_id] = pending_tasks

            return (next_id, pending_tasks)

    async def get(self,
                  max_results: int = None) -> Tuple[int, Tuple[TTask, ...]]:
        """
        Get pending tasks. If no tasks are pending, wait until a task is added.

        :param max_results: return up to this many pending tasks. If None, return all pending tasks.
        :return: (batch_id, tasks to attempt)
        """
        if max_results is not None and max_results < 1:
            raise ValidationError(
                "Must request at least one task to process, not {max_results!r}"
            )

        # if the queue is empty, wait until at least one item is available
        queue = self._open_queue
        if queue.empty():
            _rank, first_task = await queue.get()
        else:
            _rank, first_task = queue.get_nowait()

        # In order to return from get() as soon as possible, never await again.
        # Instead, take only the tasks that are already available.
        if max_results is None:
            remaining_count = None
        else:
            remaining_count = max_results - 1
        remaining_tasks = self._get_nowait(remaining_count)

        # Combine the first and remaining tasks
        all_tasks = (first_task, ) + remaining_tasks

        # Generate a pending batch of tasks, so uncompleted tasks can be inferred
        next_id = next(self._id_generator)
        self._in_progress[next_id] = all_tasks

        return (next_id, all_tasks)

    def _get_nowait(self, max_results: int = None) -> Tuple[TTask, ...]:
        queue = self._open_queue

        # How many results do we want?
        available = queue.qsize()
        if max_results is None:
            num_tasks = available
        else:
            num_tasks = min((available, max_results))

        # Combine the remaining tasks with the first task we already pulled.
        ranked_tasks = tuple(queue.get_nowait() for _ in range(num_tasks))

        # strip out the rank value used internally for sorting in the priority queue
        return tuple(task for _rank, task in ranked_tasks)

    def complete(self, batch_id: int, completed: Tuple[TTask, ...]) -> None:
        if batch_id not in self._in_progress:
            raise ValidationError(
                f"batch id {batch_id} not recognized, with tasks {completed!r}"
            )

        attempted = self._in_progress.pop(batch_id)

        unrecognized_tasks = set(completed).difference(attempted)
        if unrecognized_tasks:
            self._in_progress[batch_id] = attempted
            raise ValidationError(
                f"cannot complete tasks {unrecognized_tasks!r} in this batch, only {attempted!r}"
            )

        incomplete = set(attempted).difference(completed)

        for task in incomplete:
            # These tasks are already counted in the total task count, so there will be room
            self._open_queue.put_nowait((self._order_fn(task), task))

        self._tasks.difference_update(completed)

        if self._full_lock.locked() and len(self._tasks) < self._maxsize:
            self._full_lock.release()

    def __contains__(self, task: TTask) -> bool:
        """Determine if a task has been added and not yet completed"""
        return task in self._tasks
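The producer/consumer protocol from the docstring looks roughly like this in use. This is a hedged sketch: it assumes the TaskQueue class above is in scope, that integer tasks with the default identity order function are acceptable, and that it runs on a Python version that still accepts the loop argument used in __init__ (that argument was removed from asyncio primitives in Python 3.10).

import asyncio


async def main():
    queue = TaskQueue(maxsize=10)  # the TaskQueue class above

    # Producer: add() returns once all tasks are queued.
    await queue.add((3, 1, 2))

    # Consumer: tasks come back in priority order.
    batch_id, tasks = await queue.get(max_results=2)
    print(tasks)  # (1, 2)

    # Complete only the first task; task 2 is abandoned and re-queued.
    queue.complete(batch_id, (tasks[0],))

    batch_id, rest = await queue.get()
    print(rest)  # (2, 3): the abandoned task plus the still-pending one
    queue.complete(batch_id, rest)


asyncio.run(main())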
Example #21
class Pokefusion:
    RANDOM = '%'

    def __init__(self, client, bot):
        self._last_dex_number = 0
        self._pokemon = {}
        self._poke_reverse = {}
        self._last_updated = 0
        self._client = client
        self._data_folder = os.path.join(os.getcwd(), 'data', 'pokefusion')
        self._driver_lock = Lock(loop=bot.loop)
        self._bot = bot
        self._update_lock = Lock(loop=bot.loop)

        p = self.bot.config.chromedriver
        options = Options()
        options.add_argument('--headless')
        options.add_argument('--disable-gpu')
        binary = self.bot.config.chrome
        if binary:
            options.binary_location = binary

        self.driver = Chrome(p, chrome_options=options)

    @property
    def bot(self):
        return self._bot

    @property
    def last_dex_number(self):
        return self._last_dex_number

    @property
    def client(self):
        return self._client

    def is_dex_number(self, s):
        # A string longer than 5 digits cannot be a valid dex number,
        # so skip the int conversion entirely
        if len(s) > 5:
            return False
        try:
            return int(s) <= self.last_dex_number
        except ValueError:
            return False

    async def cache_types(self, start=1):
        name = 'sprPKMType_{}.png'
        url = 'http://pokefusion.japeal.com/sprPKMType_{}.png'
        while True:
            r = await self.client.get(url.format(start))
            if r.status == 404:
                r.close()
                break

            with open(os.path.join(self._data_folder, name.format(start)), 'wb') as f:
                f.write(await r.read())

            start += 1

    async def update_cache(self):
        if self._update_lock.locked():
            # If an update is already in progress, wait for it to finish, then return
            await self._update_lock.acquire()
            self._update_lock.release()
            return

        await self._update_lock.acquire()
        success = False
        try:
            logger.info('Updating pokecache')
            r = await self.client.get('http://pokefusion.japeal.com/PKMSelectorV3.php')
            soup = BeautifulSoup(await r.text(), 'lxml')
            selector = soup.find(id='s1')
            if selector is None:
                logger.debug('Failed to update pokefusion cache')
                return False

            pokemon = selector.find_all('option')
            for idx, p in enumerate(pokemon[1:]):
                name = ' #'.join(p.text.split(' #')[:-1])
                self._pokemon[name.lower()] = idx + 1
                self._poke_reverse[idx + 1] = name.lower()

            self._last_dex_number = len(pokemon) - 1
            types = filter(lambda f: f.startswith('sprPKMType_'), os.listdir(self._data_folder))
            await self.cache_types(start=max(len(list(types)), 1))
            self._last_updated = time.time()
            success = True
        except Exception:
            logger.exception('Failed to update pokefusion cache')
        finally:
            self._update_lock.release()

        # Returning outside the finally block keeps exceptions such as
        # CancelledError from being silently swallowed.
        return success

    def get_by_name(self, name):
        poke = self._pokemon.get(name.lower())
        if poke is None:
            for poke_, v in self._pokemon.items():
                if name in poke_:
                    return v
        return poke

    def get_by_dex_n(self, n: int):
        return n if n <= self.last_dex_number else None

    def get_pokemon(self, name):
        if name == self.RANDOM and self.last_dex_number > 0:
            return randint(1, self._last_dex_number)
        if self.is_dex_number(name):
            return int(name)
        else:
            return self.get_by_name(name)

    async def get_url(self, url):
        # Run the blocking Selenium page load in the thread pool so it does
        # not stall the event loop. The single headless Chrome instance is
        # shared, so access to it is serialized via self._driver_lock.

        # If the lock is not already held, hold it for the duration of this page load
        unlock = False
        if not self._driver_lock.locked():
            await self._driver_lock.acquire()
            unlock = True

        f = partial(self.driver.get, url)
        await self.bot.loop.run_in_executor(self.bot.threadpool, f)
        if unlock:
            try:
                self._driver_lock.release()
            except RuntimeError:
                pass

    async def fuse(self, poke1=RANDOM, poke2=RANDOM, poke3=None):
        # Update cache once per day
        if time.time() - self._last_updated > 86400:
            if not await self.update_cache():
                raise BotException('Could not cache pokemon')

        dex_n = []
        for p in (poke1, poke2):
            poke = self.get_pokemon(p)
            if poke is None:
                raise NoPokeFoundException(p)
            dex_n.append(poke)

        if poke3 is None:
            color = 0
        else:
            color = self.get_pokemon(poke3)
            if color is None:
                raise NoPokeFoundException(poke3)

        url = 'http://pokefusion.japeal.com/PKMColourV5.php?ver=3.2&p1={}&p2={}&c={}&e=noone'.format(*dex_n, color)
        async with self._driver_lock:
            try:
                await self.get_url(url)
            except UnexpectedAlertPresentException:
                self.driver.switch_to.alert.accept()
                raise BotException('Invalid pokemon given')

            data = self.driver.execute_script("return document.getElementById('image1').src")
            types = self.driver.execute_script("return document.querySelectorAll('*[width=\"30\"]')")
            name = self.driver.execute_script("return document.getElementsByTagName('b')[0].textContent")

        data = data.replace('data:image/png;base64,', '', 1)
        img = Image.open(BytesIO(base64.b64decode(data)))
        type_imgs = []

        for tp in types:
            file = tp.get_attribute('src').split('/')[-1].split('?')[0]
            try:
                im = Image.open(os.path.join(self._data_folder, file))
                type_imgs.append(im)
            except (FileNotFoundError, OSError):
                raise BotException('Error while getting type images')

        bg = Image.open(os.path.join(self._data_folder, 'poke_bg.png'))

        # Paste pokemon in the middle of the background
        x, y = (bg.width//2-img.width//2, bg.height//2-img.height//2)
        bg.paste(img, (x, y), img)

        w, h = type_imgs[0].size
        padding = 2
        # Total width of all type images combined with padding
        type_w = len(type_imgs) * (w + padding)
        width = bg.width
        start_x = (width - type_w)//2
        y = y + img.height

        for tp in type_imgs:
            bg.paste(tp, (start_x, y), tp)
            start_x += w + padding

        font = ImageFont.truetype(os.path.join('M-1c', 'mplus-1c-bold.ttf'), 36)
        draw = ImageDraw.Draw(bg)
        w, h = draw.textsize(name, font)
        draw.text(((bg.width-w)//2, bg.height//2-img.height//2 - h), name, font=font, fill='black')

        s = 'Fusion of {} and {}'.format(self._poke_reverse[dex_n[0]], self._poke_reverse[dex_n[1]])
        if color:
            s += ' using the color palette of {}'.format(self._poke_reverse[color])
        return bg, s
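update_cache() above uses the acquire-then-release trick: a caller that finds _update_lock already held simply waits for the running update to finish and returns, instead of refreshing the cache a second time. A minimal sketch of just that idiom, with the network fetch faked by a sleep; names are illustrative.

import asyncio


class Cache:
    def __init__(self):
        self._update_lock = asyncio.Lock()
        self.version = 0

    async def update(self):
        if self._update_lock.locked():
            # An update is already running: wait for it to finish,
            # then return without doing the work again.
            async with self._update_lock:
                return self.version
        async with self._update_lock:
            await asyncio.sleep(0.2)  # stands in for the real refresh
            self.version += 1
            return self.version


async def main():
    cache = Cache()
    # Three concurrent callers trigger exactly one refresh.
    print(await asyncio.gather(*(cache.update() for _ in range(3))))  # [1, 1, 1]


asyncio.run(main())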
Example #22
class SocketModeClient(AsyncBaseSocketModeClient):
    logger: Logger
    web_client: AsyncWebClient
    app_token: str
    wss_uri: Optional[str]
    auto_reconnect_enabled: bool
    message_queue: Queue
    message_listeners: List[Union[AsyncWebSocketMessageListener, Callable[
        ["AsyncBaseSocketModeClient", dict, Optional[str]],
        Awaitable[None]], ]]
    socket_mode_request_listeners: List[
        Union[AsyncSocketModeRequestListener,
              Callable[["AsyncBaseSocketModeClient", SocketModeRequest],
                       Awaitable[None]], ]]

    message_receiver: Optional[Future]
    message_processor: Future

    proxy: Optional[str]
    ping_interval: float
    trace_enabled: bool

    last_ping_pong_time: Optional[float]
    current_session: Optional[ClientWebSocketResponse]
    current_session_monitor: Optional[Future]

    auto_reconnect_enabled: bool
    default_auto_reconnect_enabled: bool
    closed: bool
    stale: bool
    connect_operation_lock: Lock

    on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]]
    on_error_listeners: List[Callable[[WSMessage], Awaitable[None]]]
    on_close_listeners: List[Callable[[WSMessage], Awaitable[None]]]

    def __init__(
        self,
        app_token: str,
        logger: Optional[Logger] = None,
        web_client: Optional[AsyncWebClient] = None,
        proxy: Optional[str] = None,
        auto_reconnect_enabled: bool = True,
        ping_interval: float = 5,
        trace_enabled: bool = False,
        on_message_listeners: Optional[List[Callable[[WSMessage],
                                                     Awaitable[None]]]] = None,
        on_error_listeners: Optional[List[Callable[[WSMessage],
                                                   Awaitable[None]]]] = None,
        on_close_listeners: Optional[List[Callable[[WSMessage],
                                                   Awaitable[None]]]] = None,
    ):
        """Socket Mode client

        Args:
            app_token: App-level token
            logger: Custom logger
            web_client: Web API client
            auto_reconnect_enabled: True if automatic reconnection is enabled (default: True)
            ping_interval: interval for ping-pong with Slack servers (seconds)
            trace_enabled: True if more verbose logs to see what's happening under the hood
            proxy: the HTTP proxy URL
            on_message_listeners: listener functions for on_message
            on_error_listeners: listener functions for on_error
            on_close_listeners: listener functions for on_close
        """
        self.app_token = app_token
        self.logger = logger or logging.getLogger(__name__)
        self.web_client = web_client or AsyncWebClient()
        self.closed = False
        self.stale = False
        self.connect_operation_lock = Lock()
        self.proxy = proxy
        if self.proxy is None or len(self.proxy.strip()) == 0:
            env_variable = load_http_proxy_from_env(self.logger)
            if env_variable is not None:
                self.proxy = env_variable

        self.default_auto_reconnect_enabled = auto_reconnect_enabled
        self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
        self.ping_interval = ping_interval
        self.trace_enabled = trace_enabled
        self.last_ping_pong_time = None

        self.wss_uri = None
        self.message_queue = Queue()
        self.message_listeners = []
        self.socket_mode_request_listeners = []
        self.current_session = None
        self.current_session_monitor = None

        # https://docs.aiohttp.org/en/stable/client_reference.html
        # Unless you are connecting to a large, unknown number of different servers
        # over the lifetime of your application,
        # it is suggested you use a single session for the lifetime of your application
        # to benefit from connection pooling.
        self.aiohttp_client_session = aiohttp.ClientSession()

        self.on_message_listeners = on_message_listeners or []
        self.on_error_listeners = on_error_listeners or []
        self.on_close_listeners = on_close_listeners or []

        self.message_receiver = None
        self.message_processor = asyncio.ensure_future(self.process_messages())

    async def monitor_current_session(self) -> None:
        # In the asyncio runtime, accessing a shared object (self.current_session here) from
        # multiple tasks can cause race conditions and errors.
        # To avoid such, we access only the session that is active when this loop starts.
        session: ClientWebSocketResponse = self.current_session
        session_id: str = self.build_session_id(session)

        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"A new monitor_current_session() execution loop for {session_id} started"
            )
        try:
            logging_interval = 100
            counter_for_logging = 0

            while not self.closed:
                if session != self.current_session:
                    if self.logger.level <= logging.DEBUG:
                        self.logger.debug(
                            f"The monitor_current_session task for {session_id} is now cancelled"
                        )
                    break
                try:
                    if self.trace_enabled and self.logger.level <= logging.DEBUG:
                        # The logging here is for detailed investigation on potential issues in this client.
                        # If you don't see this log for a while, it means that
                        # this receive_messages execution is no longer working for some reason.
                        counter_for_logging += 1
                        if counter_for_logging >= logging_interval:
                            counter_for_logging = 0
                            log_message = (
                                "#monitor_current_session method has been verifying if this session is active "
                                f"(session: {session_id}, logging interval: {logging_interval})"
                            )
                            self.logger.debug(log_message)

                    await asyncio.sleep(self.ping_interval)

                    if session is not None and session.closed is False:
                        t = time.time()
                        if self.last_ping_pong_time is None:
                            self.last_ping_pong_time = float(t)
                        try:
                            await session.ping(f"sdk-ping-pong:{t}")
                        except Exception as e:
                            # The ping() method can fail for some reason.
                            # To establish a new connection even in this scenario,
                            # we ignore the exception here.
                            self.logger.warning(
                                f"Failed to send a ping message ({session_id}): {e}"
                            )

                    if self.auto_reconnect_enabled:
                        should_reconnect = False
                        if session is None or session.closed:
                            self.logger.info(
                                f"The session ({session_id}) seems to be already closed. Reconnecting..."
                            )
                            should_reconnect = True

                        if await self.is_ping_pong_failing():
                            disconnected_seconds = int(
                                time.time() - self.last_ping_pong_time)
                            self.logger.info(
                                f"The session ({session_id}) seems to be stale. Reconnecting..."
                                f" (reason: disconnected for {disconnected_seconds}+ seconds)"
                            )
                            self.stale = True
                            self.last_ping_pong_time = None
                            should_reconnect = True

                        if should_reconnect is True or not await self.is_connected(
                        ):
                            await self.connect_to_new_endpoint()

                except Exception as e:
                    self.logger.error(
                        f"Failed to check the current session ({session_id}) or reconnect to the server "
                        f"(error: {type(e).__name__}, message: {e})")
        except asyncio.CancelledError:
            if self.logger.level <= logging.DEBUG:
                self.logger.debug(
                    f"The monitor_current_session task for {session_id} is now cancelled"
                )
            raise

    async def receive_messages(self) -> None:
        # In the asyncio runtime, accessing a shared object (self.current_session here) from
        # multiple tasks can cause race conditions and errors.
        # To avoid such, we access only the session that is active when this loop starts.
        session = self.current_session
        session_id = self.build_session_id(session)
        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"A new receive_messages() execution loop with {session_id} started"
            )
        try:
            consecutive_error_count = 0
            logging_interval = 100
            counter_for_logging = 0

            while not self.closed:
                if session != self.current_session:
                    if self.logger.level <= logging.DEBUG:
                        self.logger.debug(
                            f"The running receive_messages task for {session_id} is now cancelled"
                        )
                    break
                try:
                    message: WSMessage = await session.receive()
                    # just in case, checking if the value is not None
                    if message is not None:
                        if self.logger.level <= logging.DEBUG:
                            # The following logging prints every single received message
                            # except empty message data ones.
                            type = WSMsgType(message.type)
                            message_type = (type.name if type is not None else
                                            message.type)
                            message_data = message.data
                            if isinstance(message_data, bytes):
                                message_data = message_data.decode("utf-8")
                            if len(message_data) > 0:
                                # To skip the empty message that Slack server-side often sends
                                self.logger.debug(f"Received message "
                                                  f"(type: {message_type}, "
                                                  f"data: {message_data}, "
                                                  f"extra: {message.extra}, "
                                                  f"session: {session_id})")

                            if self.trace_enabled:
                                # The logging here is for detailed troubleshooting of potential issues in this client.
                                # If you don't see this log for a while, it can mean that
                                # this receive_messages execution is no longer working for some reason.
                                counter_for_logging += 1
                                if counter_for_logging >= logging_interval:
                                    counter_for_logging = 0
                                    log_message = (
                                        "#receive_messages method has been working without any issues "
                                        f"(session: {session_id}, logging interval: {logging_interval})"
                                    )
                                    self.logger.debug(log_message)

                        if message.type == WSMsgType.TEXT:
                            message_data = message.data
                            await self.enqueue_message(message_data)
                            for listener in self.on_message_listeners:
                                await listener(message)
                        elif message.type == WSMsgType.CLOSE:
                            if self.auto_reconnect_enabled:
                                self.logger.info(
                                    f"Received CLOSE event from {session_id}. Reconnecting..."
                                )
                                await self.connect_to_new_endpoint()
                            for listener in self.on_close_listeners:
                                await listener(message)
                        elif message.type == WSMsgType.ERROR:
                            for listener in self.on_error_listeners:
                                await listener(message)
                        elif message.type == WSMsgType.CLOSED:
                            await asyncio.sleep(self.ping_interval)
                            continue
                        elif message.type == WSMsgType.PING:
                            await session.pong(message.data)
                            continue
                        elif message.type == WSMsgType.PONG:
                            if message.data is not None:
                                str_message_data = message.data.decode("utf-8")
                                elements = str_message_data.split(":")
                                if (len(elements) == 2
                                        and elements[0] == "sdk-ping-pong"):
                                    try:
                                        self.last_ping_pong_time = float(
                                            elements[1])
                                    except Exception as e:
                                        self.logger.warning(
                                            f"Failed to parse the last_ping_pong_time value from {str_message_data}"
                                            f" - error : {e}, session: {session_id}"
                                        )
                            continue

                    consecutive_error_count = 0

                except Exception as e:
                    consecutive_error_count += 1
                    self.logger.error(
                        f"Failed to receive or enqueue a message: {type(e).__name__}, {e} ({session_id})"
                    )
                    if isinstance(e, ClientConnectionError):
                        await asyncio.sleep(self.ping_interval)
                    else:
                        await asyncio.sleep(consecutive_error_count)
        except asyncio.CancelledError:
            if self.logger.level <= logging.DEBUG:
                self.logger.debug(
                    f"The running receive_messages task for {session_id} is now cancelled"
                )
            raise

    async def is_ping_pong_failing(self) -> bool:
        if self.last_ping_pong_time is None:
            return False
        disconnected_seconds = int(time.time() - self.last_ping_pong_time)
        return disconnected_seconds >= (self.ping_interval * 4)

    async def is_connected(self) -> bool:
        connected: bool = (not self.closed and not self.stale
                           and self.current_session is not None
                           and not self.current_session.closed
                           and not await self.is_ping_pong_failing())
        if self.logger.level <= logging.DEBUG and connected is False:
            # Prints more detailed information about the inactive connection
            is_ping_pong_failing = await self.is_ping_pong_failing()
            session_id = await self.session_id()
            self.logger.debug(
                "Inactive connection detected ("
                f"session_id: {session_id}, "
                f"closed: {self.closed}, "
                f"stale: {self.stale}, "
                f"current_session.closed: {self.current_session.closed}, "
                f"is_ping_pong_failing: {is_ping_pong_failing}"
                ")")
        return connected

    async def session_id(self) -> str:
        return self.build_session_id(self.current_session)

    async def connect(self):
        old_session: Optional[ClientWebSocketResponse] = (
            None if self.current_session is None else self.current_session)
        if self.wss_uri is None:
            # If the underlying WSS URL does not exist,
            # acquiring a new active WSS URL from the server-side first
            self.wss_uri = await self.issue_new_wss_url()

        self.current_session = await self.aiohttp_client_session.ws_connect(
            self.wss_uri,
            autoping=False,
            heartbeat=self.ping_interval,
            proxy=self.proxy,
        )
        session_id: str = await self.session_id()
        self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
        self.stale = False
        self.logger.info(f"A new session ({session_id}) has been established")

        # The first ping from the new connection
        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"Sending a ping message with the newly established connection ({session_id})..."
            )
        t = time.time()
        await self.current_session.ping(f"sdk-ping-pong:{t}")

        if self.current_session_monitor is not None:
            self.current_session_monitor.cancel()

        self.current_session_monitor = asyncio.ensure_future(
            self.monitor_current_session())
        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"A new monitor_current_session() executor has been recreated for {session_id}"
            )

        if self.message_receiver is not None:
            self.message_receiver.cancel()

        self.message_receiver = asyncio.ensure_future(self.receive_messages())
        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"A new receive_messages() executor has been recreated for {session_id}"
            )

        if old_session is not None:
            await old_session.close()
            old_session_id = self.build_session_id(old_session)
            self.logger.info(
                f"The old session ({old_session_id}) has been abandoned")

    async def disconnect(self):
        if self.current_session is not None:
            await self.current_session.close()
        session_id = await self.session_id()
        self.logger.info(
            f"The current session ({session_id}) has been abandoned by disconnect() method call"
        )

    async def send_message(self, message: str):
        session_id = await self.session_id()
        if self.logger.level <= logging.DEBUG:
            self.logger.debug(
                f"Sending a message: {message} from session: {session_id}")
        try:
            await self.current_session.send_str(message)
        except ConnectionError as e:
            # We rarely get this exception while replacing the underlying WebSocket connections.
            # We can do one more try here as the self.current_session should be ready now.
            if self.logger.level <= logging.DEBUG:
                self.logger.debug(
                    f"Failed to send a message (error: {e}, message: {message}, session: {session_id})"
                    " as the underlying connection was replaced. Retrying the same request only one time..."
                )
            # Although acquiring self.connect_operation_lock for the first send attempt as well
            # would be the safest option, we avoid that extra synchronization for performance.
            # That's why we retry here instead.
            try:
                await self.connect_operation_lock.acquire()
                if await self.is_connected():
                    await self.current_session.send_str(message)
                else:
                    self.logger.warning(
                        f"The current session ({session_id}) is no longer active. "
                        "Failed to send a message")
                    raise e
            finally:
                if self.connect_operation_lock.locked():
                    self.connect_operation_lock.release()

    async def close(self):
        self.closed = True
        self.auto_reconnect_enabled = False
        await self.disconnect()
        if self.message_processor is not None:
            self.message_processor.cancel()
        if self.current_session_monitor is not None:
            self.current_session_monitor.cancel()
        if self.message_receiver is not None:
            self.message_receiver.cancel()
        if self.aiohttp_client_session is not None:
            await self.aiohttp_client_session.close()

    @classmethod
    def build_session_id(cls, session: ClientWebSocketResponse) -> str:
        if session is None:
            return ""
        return "s_" + str(hash(session))
Example #23
import time
from asyncio import Lock, get_event_loop
from collections import deque


class FakeAioRedisConnection:
    def __init__(self, max_age_seconds=120, loop=None):
        assert max_age_seconds >= 1
        self.store = {}
        self.max_age = max_age_seconds
        self.lock = Lock()
        self.closed = False
        self.loop = loop or get_event_loop()
        self.clean_keys()

    async def pexpire(self, key, pexpire):
        # pexpire is given in milliseconds; expire() takes seconds
        await self.expire(key, pexpire / 1000)

    async def expire(self, key, expire):
        async with self.lock:
            item = self.store.get(key, None)
            if not item:
                return None
            self.store[key] = (item[0], time.time() + expire)

    async def persist(self, key):
        async with self.lock:
            item = self.store.get(key, None)
            if not item:
                return None
            self.store[key] = (item[0], float('inf'))

    async def set(self, key, value, pexpire=None):
        async with self.lock:
            if not pexpire:
                self.store[key] = (
                        str(value).encode(), float('inf'))
            else:
                self.store[key] = (
                        str(value).encode(), time.time() + pexpire / 1000)

    async def get(self, key):
        async with self.lock:
            item = self.store.get(key, None)
            if not item:
                return None
            return item[0]

    async def delete(self, key, default=None):
        async with self.lock:
            item = self.store.get(key, None)
            if not item:
                return None
            del self.store[key]
            return item[0]

    async def lpush(self, key, value):
        async with self.lock:
            li, timeout = self.store.get(key, (None, None))
            if not li:
                self.store[key] = (deque([value]), float('inf'))
            else:
                li.appendleft(value)
                self.store[key] = (li, timeout)
            return

    async def lpushx(self, key, value):
        async with self.lock:
            li, timeout = self.store.get(key, (None, None))
            if not li:
                return None
            li.appendleft(value)
            self.store[key] = (li, timeout)
            return

    async def rpush(self, key, value):
        async with self.lock:
            li, timeout = self.store.get(key, (None, None))
            if not li:
                self.store[key] = (deque([value]), float('inf'))
            else:
                li.append(value)
                self.store[key] = (li, timeout)
            return

    async def rpushx(self, key, value):
        async with self.lock:
            li, timeout = self.store.get(key, (None, None))
            if not li:
                return None
            li.append(value)
            self.store[key] = (li, timeout)
            return

    async def lpop(self, key):
        async with self.lock:
            li, _ = self.store.get(key, (None, None))
            if not li:
                return None
            elem = li.popleft()
            return elem

    async def rpop(self, key):
        async with self.lock:
            li, _ = self.store.get(key, (None, None))
            if not li:
                return None
            elem = li.pop()
            return elem

    async def lrange(self, key, start, end):
        async with self.lock:
            li, _ = self.store.get(key, (None, None))
            if not li:
                return []
            items = list(li)
            # Mirror Redis LRANGE semantics: "end" is inclusive, -1 means the last element
            if end == -1:
                return items[start:]
            return items[start:end + 1]

    async def llen(self, key):
        async with self.lock:
            li, _ = self.store.get(key, (None, None))
            if not li:
                return None
            return len(li)

    async def incr(self, key):
        async with self.lock:
            value, timeout = self.store.get(key, (0, float('inf')))
            value += 1
            self.store[key] = (value, timeout)
            # Redis INCR returns the new value
            return value

    async def incrby(self, key, increment):
        async with self.lock:
            value, timeout = self.store.get(key, (0, float('inf')))
            value += increment
            self.store[key] = (value, timeout)
            # Redis INCRBY returns the new value
            return value

    async def quit(self):
        async with self.lock:
            self.closed = True
            self.store = {}
        return

    def clean_keys(self):
        if self.closed:
            return
        # Skip the sweep while a coroutine holds the lock; this method is
        # synchronous, so the lock state cannot change during the loop
        if not self.lock.locked():
            now = time.time()
            for k in list(self.store.keys()):
                if now > self.store[k][1]:
                    del self.store[k]
        # Re-schedule the next sweep via the loop in a thread-safe way
        self.loop.call_soon_threadsafe(
            self.loop.call_later,
            self.max_age,
            self.clean_keys)
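
To close out this example, here is a short sketch of how FakeAioRedisConnection might be exercised in a test. It is illustrative only: the keys and values are made up, and it assumes the modernized `async with` lock usage above.

import asyncio

async def demo():
    conn = FakeAioRedisConnection(max_age_seconds=5)
    # Values are stored encoded, so reads come back as bytes
    await conn.set("greeting", "hello", pexpire=2000)
    assert await conn.get("greeting") == b"hello"
    # List operations mirror their Redis counterparts
    await conn.rpush("queue", b"job-1")
    await conn.rpush("queue", b"job-2")
    assert await conn.lpop("queue") == b"job-1"
    assert await conn.llen("queue") == 1
    await conn.quit()

asyncio.run(demo())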